diff --git a/backend/routes/playground.py b/backend/routes/playground.py index 8ed7f40..bb11e96 100644 --- a/backend/routes/playground.py +++ b/backend/routes/playground.py @@ -7,13 +7,16 @@ - Global circuit breaker: 10k searches/hour (cost protection) """ import os +import re +import httpx from typing import Optional from fastapi import APIRouter, HTTPException, Request, Response -from pydantic import BaseModel +from pydantic import BaseModel, field_validator import time from dependencies import indexer, cache, repo_manager, redis_client from services.input_validator import InputValidator +from services.repo_validator import RepoValidator from services.observability import logger from services.playground_limiter import PlaygroundLimiter, get_playground_limiter @@ -27,6 +30,15 @@ SESSION_COOKIE_MAX_AGE = 86400 # 24 hours IS_PRODUCTION = os.getenv("ENVIRONMENT", "development").lower() == "production" +# GitHub validation config +GITHUB_URL_PATTERN = re.compile( + r"^https?://github\.com/(?P<owner>[a-zA-Z0-9_.-]+)/(?P<repo>[a-zA-Z0-9_.-]+)/?$" +) +ANONYMOUS_FILE_LIMIT = 200 # Max files for anonymous indexing +GITHUB_API_BASE = "https://api.github.com" +GITHUB_API_TIMEOUT = 10.0 # seconds +VALIDATION_CACHE_TTL = 300 # 5 minutes + class PlaygroundSearchRequest(BaseModel): query: str @@ -34,6 +46,24 @@ class PlaygroundSearchRequest(BaseModel): max_results: int = 10 +class ValidateRepoRequest(BaseModel): + """Request body for GitHub repo validation.""" + github_url: str + + @field_validator("github_url") + @classmethod + def validate_github_url_format(cls, v: str) -> str: + """Basic URL format validation.""" + v = v.strip() + if not v: + raise ValueError("GitHub URL is required") + if not v.startswith(("http://", "https://")): + raise ValueError("URL must start with http:// or https://") + if "github.com" not in v.lower(): + raise ValueError("URL must be a GitHub repository URL") + return v + + async def load_demo_repos(): """Load pre-indexed demo repos. 
Called from main.py on startup.""" # Note: We mutate DEMO_REPO_IDS dict, no need for 'global' statement @@ -305,3 +335,259 @@ async def get_playground_stats(): limiter = _get_limiter() stats = limiter.get_usage_stats() return stats + + +def _parse_github_url(url: str) -> tuple[Optional[str], Optional[str], Optional[str]]: + """ + Parse GitHub URL to extract owner and repo. + + Returns: + (owner, repo, error) - error is None if successful + """ + match = GITHUB_URL_PATTERN.match(url.strip().rstrip("/")) + if not match: + return None, None, "Invalid GitHub URL format. Expected: https://github.com/owner/repo" + return match.group("owner"), match.group("repo"), None + + +async def _fetch_repo_metadata(owner: str, repo: str) -> dict: + """ + Fetch repository metadata from GitHub API. + + Returns dict with repo info or error details. + """ + url = f"{GITHUB_API_BASE}/repos/{owner}/{repo}" + headers = { + "Accept": "application/vnd.github.v3+json", + "User-Agent": "OpenCodeIntel/1.0", + } + + # Add GitHub token if available (for higher rate limits) + github_token = os.getenv("GITHUB_TOKEN") + if github_token: + headers["Authorization"] = f"token {github_token}" + + async with httpx.AsyncClient(timeout=GITHUB_API_TIMEOUT) as client: + try: + response = await client.get(url, headers=headers) + + if response.status_code == 404: + return {"error": "not_found", "message": "Repository not found"} + if response.status_code == 403: + return { + "error": "rate_limited", + "message": "GitHub API rate limit exceeded" + } + if response.status_code != 200: + return { + "error": "api_error", + "message": f"GitHub API error: {response.status_code}" + } + + return response.json() + except httpx.TimeoutException: + return {"error": "timeout", "message": "GitHub API request timed out"} + except Exception as e: + logger.error("GitHub API request failed", error=str(e)) + return {"error": "request_failed", "message": str(e)} + + +async def _count_code_files( + owner: str, repo: str, 
default_branch: str +) -> tuple[int, Optional[str]]: + """ + Count code files in repository using GitHub tree API. + + Returns: + (file_count, error) - error is None if successful + """ + url = f"{GITHUB_API_BASE}/repos/{owner}/{repo}/git/trees/{default_branch}?recursive=1" + headers = { + "Accept": "application/vnd.github.v3+json", + "User-Agent": "OpenCodeIntel/1.0", + } + + github_token = os.getenv("GITHUB_TOKEN") + if github_token: + headers["Authorization"] = f"token {github_token}" + + async with httpx.AsyncClient(timeout=GITHUB_API_TIMEOUT) as client: + try: + response = await client.get(url, headers=headers) + + if response.status_code == 404: + return 0, "Could not fetch repository tree" + if response.status_code == 403: + return 0, "GitHub API rate limit exceeded" + if response.status_code != 200: + return 0, f"GitHub API error: {response.status_code}" + + data = response.json() + + # Check if tree was truncated (very large repos) + if data.get("truncated", False): + # For truncated trees, estimate from repo size + # GitHub's size is in KB, rough estimate: 1 code file per 5KB + return -1, "truncated" + + # Count files with code extensions + code_extensions = RepoValidator.CODE_EXTENSIONS + skip_dirs = RepoValidator.SKIP_DIRS + + count = 0 + for item in data.get("tree", []): + if item.get("type") != "blob": + continue + + path = item.get("path", "") + + # Skip if in excluded directory + path_parts = path.split("/") + if any(part in skip_dirs for part in path_parts): + continue + + # Check extension + ext = "." + path.rsplit(".", 1)[-1] if "." 
in path else "" + if ext.lower() in code_extensions: + count += 1 + + return count, None + except httpx.TimeoutException: + return 0, "GitHub API request timed out" + except Exception as e: + logger.error("GitHub tree API failed", error=str(e)) + return 0, str(e) + + +@router.post("/validate-repo") +async def validate_github_repo(request: ValidateRepoRequest, req: Request): + """ + Validate a GitHub repository URL for anonymous indexing. + + Checks: + - URL format is valid + - Repository exists and is public + - File count is within anonymous limit (200 files) + + Response varies based on validation result (see issue #124). + """ + start_time = time.time() + + # Check cache first + cache_key = f"validate:{request.github_url}" + cached = cache.get(cache_key) if cache else None + if cached: + logger.info("Returning cached validation", url=request.github_url[:50]) + return cached + + # Parse URL + owner, repo_name, parse_error = _parse_github_url(request.github_url) + if parse_error: + return { + "valid": False, + "reason": "invalid_url", + "message": parse_error, + } + + # Fetch repo metadata from GitHub + metadata = await _fetch_repo_metadata(owner, repo_name) + + if "error" in metadata: + error_type = metadata["error"] + if error_type == "not_found": + return { + "valid": False, + "reason": "not_found", + "message": "Repository not found. Check the URL or ensure it's public.", + } + elif error_type == "rate_limited": + raise HTTPException( + status_code=429, + detail={"message": "GitHub API rate limit exceeded. Try again later."} + ) + else: + raise HTTPException( + status_code=502, + detail={"message": metadata.get("message", "Failed to fetch repository info")} + ) + + # Check if private + is_private = metadata.get("private", False) + if is_private: + return { + "valid": True, + "repo_name": repo_name, + "owner": owner, + "is_public": False, + "can_index": False, + "reason": "private", + "message": "This repository is private. 
" + "Anonymous indexing only supports public repositories.", + } + + # Get file count + default_branch = metadata.get("default_branch", "main") + file_count, count_error = await _count_code_files(owner, repo_name, default_branch) + + # Handle truncated tree (very large repo) + if count_error == "truncated": + # Estimate from repo size (GitHub size is in KB) + repo_size_kb = metadata.get("size", 0) + # Rough estimate: 1 code file per 3KB for code repos + file_count = max(repo_size_kb // 3, ANONYMOUS_FILE_LIMIT + 1) + logger.info("Using estimated file count for large repo", + owner=owner, repo=repo_name, estimated=file_count) + + elif count_error: + logger.warning("Could not count files", owner=owner, repo=repo_name, error=count_error) + # Fall back to size-based estimate + repo_size_kb = metadata.get("size", 0) + file_count = max(repo_size_kb // 3, 1) + + # Build response + response_time_ms = int((time.time() - start_time) * 1000) + + if file_count > ANONYMOUS_FILE_LIMIT: + result = { + "valid": True, + "repo_name": repo_name, + "owner": owner, + "is_public": True, + "default_branch": default_branch, + "file_count": file_count, + "size_kb": metadata.get("size", 0), + "language": metadata.get("language"), + "stars": metadata.get("stargazers_count", 0), + "can_index": False, + "reason": "too_large", + "message": f"Repository has {file_count:,} code files. 
" + f"Anonymous limit is {ANONYMOUS_FILE_LIMIT}.", + "limit": ANONYMOUS_FILE_LIMIT, + "response_time_ms": response_time_ms, + } + else: + result = { + "valid": True, + "repo_name": repo_name, + "owner": owner, + "is_public": True, + "default_branch": default_branch, + "file_count": file_count, + "size_kb": metadata.get("size", 0), + "language": metadata.get("language"), + "stars": metadata.get("stargazers_count", 0), + "can_index": True, + "message": "Ready to index", + "response_time_ms": response_time_ms, + } + + # Cache successful validations + if cache: + cache.set(cache_key, result, ttl=VALIDATION_CACHE_TTL) + + logger.info("Validated GitHub repo", + owner=owner, repo=repo_name, + file_count=file_count, can_index=result["can_index"], + response_time_ms=response_time_ms) + + return result diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 52815ad..dddc112 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -5,7 +5,7 @@ import pytest import sys from pathlib import Path -from unittest.mock import Mock, MagicMock, patch +from unittest.mock import MagicMock, patch import os # Set test environment BEFORE imports @@ -20,6 +20,35 @@ os.environ["SUPABASE_ANON_KEY"] = "test-anon-key" os.environ["SUPABASE_JWT_SECRET"] = "test-jwt-secret" +# ============================================================================= +# EARLY PATCHING - runs during collection, before any imports +# ============================================================================= +# These patches prevent external service initialization during test collection + +_pinecone_patcher = patch('pinecone.Pinecone') +_mock_pinecone = _pinecone_patcher.start() +_pc_instance = MagicMock() +_pc_instance.list_indexes.return_value.names.return_value = [] +_pc_instance.Index.return_value = MagicMock() +_mock_pinecone.return_value = _pc_instance + +_openai_patcher = patch('openai.AsyncOpenAI') +_mock_openai = _openai_patcher.start() +_openai_client = MagicMock() 
+_mock_openai.return_value = _openai_client + +_supabase_patcher = patch('supabase.create_client') +_mock_supabase = _supabase_patcher.start() +_supabase_client = MagicMock() +_supabase_client.table.return_value.select.return_value.execute.return_value.data = [] +# Auth should reject by default - set user to None +_auth_response = MagicMock() +_auth_response.user = None +_supabase_client.auth.get_user.return_value = _auth_response +_mock_supabase.return_value = _supabase_client + +# ============================================================================= + # Add backend to path backend_dir = Path(__file__).parent.parent sys.path.insert(0, str(backend_dir)) @@ -50,29 +79,29 @@ def mock_pinecone(): yield mock -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session", autouse=True) def mock_redis(): """ Mock Redis client globally. - + Includes hash operations for session management (#127). """ with patch('redis.Redis') as mock: redis_instance = MagicMock() - + # Connection redis_instance.ping.return_value = True - + # String operations (legacy + IP/global limits) redis_instance.get.return_value = None redis_instance.set.return_value = True redis_instance.incr.return_value = 1 redis_instance.delete.return_value = 1 - + # TTL operations redis_instance.expire.return_value = True redis_instance.ttl.return_value = 86400 - + # Hash operations (session management #127) redis_instance.type.return_value = b'hash' redis_instance.hset.return_value = 1 @@ -84,7 +113,7 @@ def mock_redis(): redis_instance.hincrby.return_value = 1 redis_instance.hexists.return_value = False redis_instance.hdel.return_value = 1 - + mock.return_value = redis_instance yield mock @@ -95,30 +124,30 @@ def mock_supabase(): with patch('supabase.create_client') as mock: client = MagicMock() table = MagicMock() - + # Mock the fluent interface for tables execute_result = MagicMock() execute_result.data = [] execute_result.count = 0 - + table.select.return_value = table 
table.insert.return_value = table table.update.return_value = table - table.delete.return_value = table + table.delete.return_value = table table.eq.return_value = table table.order.return_value = table table.limit.return_value = table table.upsert.return_value = table table.execute.return_value = execute_result - + client.table.return_value = table - + # Mock auth.get_user to reject invalid tokens # By default, return response with user=None (invalid token) auth_response = MagicMock() auth_response.user = None client.auth.get_user.return_value = auth_response - + mock.return_value = client yield mock @@ -141,7 +170,7 @@ def client(): from fastapi.testclient import TestClient from main import app from middleware.auth import AuthContext - + # Override the require_auth dependency to always return a valid context async def mock_require_auth(): return AuthContext( @@ -149,12 +178,12 @@ async def mock_require_auth(): email="test@example.com", tier="enterprise" ) - + from middleware.auth import require_auth app.dependency_overrides[require_auth] = mock_require_auth - + yield TestClient(app) - + # Cleanup app.dependency_overrides.clear() @@ -169,7 +198,7 @@ def client_no_auth(): @pytest.fixture def valid_headers(): - """Valid authentication headers (not actually used with mocked auth, but kept for compatibility)""" + """Valid auth headers (kept for compatibility with mocked auth).""" return {"Authorization": "Bearer test-secret-key"} @@ -182,6 +211,7 @@ def sample_repo_payload(): "branch": "main" } + @pytest.fixture def malicious_payloads(): """Collection of malicious inputs for security testing""" diff --git a/backend/tests/test_multi_tenancy.py b/backend/tests/test_multi_tenancy.py index 46b7ac1..17f2693 100644 --- a/backend/tests/test_multi_tenancy.py +++ b/backend/tests/test_multi_tenancy.py @@ -9,8 +9,7 @@ 4. 
404 is returned (not 403) to prevent info leakage """ import pytest -from unittest.mock import MagicMock, patch, AsyncMock -from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch import sys import os from pathlib import Path @@ -34,10 +33,22 @@ # ============== TEST DATA ============== REPOS_DB = [ - {"id": "repo-user1-a", "name": "User1 Repo A", "user_id": "user-1", "status": "indexed", "local_path": "/repos/1a", "file_count": 10}, - {"id": "repo-user1-b", "name": "User1 Repo B", "user_id": "user-1", "status": "indexed", "local_path": "/repos/1b", "file_count": 20}, - {"id": "repo-user2-a", "name": "User2 Repo A", "user_id": "user-2", "status": "indexed", "local_path": "/repos/2a", "file_count": 15}, - {"id": "repo-user2-b", "name": "User2 Repo B", "user_id": "user-2", "status": "indexed", "local_path": "/repos/2b", "file_count": 25}, + { + "id": "repo-user1-a", "name": "User1 Repo A", "user_id": "user-1", + "status": "indexed", "local_path": "/repos/1a", "file_count": 10 + }, + { + "id": "repo-user1-b", "name": "User1 Repo B", "user_id": "user-1", + "status": "indexed", "local_path": "/repos/1b", "file_count": 20 + }, + { + "id": "repo-user2-a", "name": "User2 Repo A", "user_id": "user-2", + "status": "indexed", "local_path": "/repos/2a", "file_count": 15 + }, + { + "id": "repo-user2-b", "name": "User2 Repo B", "user_id": "user-2", + "status": "indexed", "local_path": "/repos/2b", "file_count": 25 + }, ] @@ -45,19 +56,19 @@ class TestSupabaseServiceOwnership: """Unit tests for ownership verification methods in SupabaseService""" - + def test_list_repositories_for_user_method_exists(self): """list_repositories_for_user method should exist with correct signature""" from services.supabase_service import SupabaseService import inspect - + # Verify method exists assert hasattr(SupabaseService, 'list_repositories_for_user') - + sig = inspect.signature(SupabaseService.list_repositories_for_user) params = list(sig.parameters.keys()) assert 
'user_id' in params, "Method should accept user_id parameter" - + def test_get_repository_with_owner_returns_none_for_wrong_user(self): """ get_repository_with_owner should query with both repo_id AND user_id filters. @@ -65,47 +76,47 @@ def test_get_repository_with_owner_returns_none_for_wrong_user(self): The actual SQL filtering is tested via integration tests. """ from services.supabase_service import SupabaseService - + # Verify method exists and has correct signature import inspect sig = inspect.signature(SupabaseService.get_repository_with_owner) params = list(sig.parameters.keys()) - + assert 'repo_id' in params, "Method should accept repo_id parameter" assert 'user_id' in params, "Method should accept user_id parameter" - + def test_verify_repo_ownership_returns_false_for_wrong_user(self): """ verify_repo_ownership should query with both repo_id AND user_id filters. This test verifies the method signature and return type. """ from services.supabase_service import SupabaseService - + # Verify method exists and has correct signature import inspect sig = inspect.signature(SupabaseService.verify_repo_ownership) params = list(sig.parameters.keys()) - + assert 'repo_id' in params, "Method should accept repo_id parameter" assert 'user_id' in params, "Method should accept user_id parameter" - + # Verify return type annotation is bool return_annotation = sig.return_annotation assert return_annotation == bool, "Method should return bool" - + def test_verify_repo_ownership_returns_true_for_owner(self): """verify_repo_ownership method should exist with correct signature""" from services.supabase_service import SupabaseService import inspect - + # Verify method exists assert hasattr(SupabaseService, 'verify_repo_ownership') - + sig = inspect.signature(SupabaseService.verify_repo_ownership) params = list(sig.parameters.keys()) assert 'repo_id' in params assert 'user_id' in params - + # Return type should be bool assert sig.return_annotation == bool @@ -114,40 +125,40 @@ 
def test_verify_repo_ownership_returns_true_for_owner(self): class TestRepoManagerOwnership: """Unit tests for ownership methods in RepoManager""" - + def test_list_repos_for_user_delegates_to_supabase(self): """list_repos_for_user should call supabase list_repositories_for_user""" with patch('services.repo_manager.get_supabase_service') as mock_get_db: mock_db = MagicMock() mock_db.list_repositories_for_user.return_value = [REPOS_DB[0], REPOS_DB[1]] mock_get_db.return_value = mock_db - + from services.repo_manager import RepositoryManager - + with patch.object(RepositoryManager, '_sync_existing_repos'): manager = RepositoryManager() manager.db = mock_db - + result = manager.list_repos_for_user("user-1") - + mock_db.list_repositories_for_user.assert_called_once_with("user-1") assert len(result) == 2 - + def test_verify_ownership_delegates_to_supabase(self): """verify_ownership should call supabase verify_repo_ownership""" with patch('services.repo_manager.get_supabase_service') as mock_get_db: mock_db = MagicMock() mock_db.verify_repo_ownership.return_value = False mock_get_db.return_value = mock_db - + from services.repo_manager import RepositoryManager - + with patch.object(RepositoryManager, '_sync_existing_repos'): manager = RepositoryManager() manager.db = mock_db - + result = manager.verify_ownership("repo-user2-a", "user-1") - + mock_db.verify_repo_ownership.assert_called_once_with("repo-user2-a", "user-1") assert result is False @@ -156,44 +167,42 @@ def test_verify_ownership_delegates_to_supabase(self): class TestSecurityHelpers: """Test the get_repo_or_404 and verify_repo_access helpers""" - + def test_get_repo_or_404_raises_404_for_wrong_user(self): """get_repo_or_404 should raise 404 if user doesn't own repo""" - with patch('dependencies.repo_manager') as mock_manager: - mock_manager.get_repo_for_user.return_value = None - - from dependencies import get_repo_or_404 - from fastapi import HTTPException - + from dependencies import get_repo_or_404, 
repo_manager + from fastapi import HTTPException + + with patch.object(repo_manager, 'get_repo_for_user', return_value=None): with pytest.raises(HTTPException) as exc_info: get_repo_or_404("repo-user2-a", "user-1") - + assert exc_info.value.status_code == 404 assert "not found" in exc_info.value.detail.lower() - + def test_get_repo_or_404_returns_repo_for_owner(self): """get_repo_or_404 should return repo if user owns it""" - with patch('dependencies.repo_manager') as mock_manager: - expected_repo = REPOS_DB[0] - mock_manager.get_repo_for_user.return_value = expected_repo - - from dependencies import get_repo_or_404 - + from dependencies import get_repo_or_404, repo_manager + + expected_repo = REPOS_DB[0] + with patch.object(repo_manager, 'get_repo_for_user', return_value=expected_repo): result = get_repo_or_404("repo-user1-a", "user-1") - assert result == expected_repo - + def test_verify_repo_access_raises_404_for_wrong_user(self): """verify_repo_access should raise 404 if user doesn't own repo""" - with patch('dependencies.repo_manager') as mock_manager: - mock_manager.verify_ownership.return_value = False - - from dependencies import verify_repo_access - from fastapi import HTTPException - + from dependencies import verify_repo_access, repo_manager + from fastapi import HTTPException + + with patch.object(repo_manager, 'verify_ownership', return_value=False): with pytest.raises(HTTPException) as exc_info: verify_repo_access("repo-user2-a", "user-1") - + + assert exc_info.value.status_code == 404 + + with pytest.raises(HTTPException) as exc_info: + verify_repo_access("repo-user2-a", "user-1") + assert exc_info.value.status_code == 404 @@ -201,99 +210,109 @@ def test_verify_repo_access_raises_404_for_wrong_user(self): class TestDevApiKeySecurity: """Test that dev API key is properly secured (Issue #8)""" - + + def _reload_auth_module(self): + """Helper to reload auth module with current env vars.""" + import importlib + import middleware.auth as auth_module + 
importlib.reload(auth_module) + return auth_module + def test_dev_key_without_debug_mode_fails(self): """Dev key should not work without DEBUG=true""" original_debug = os.environ.get("DEBUG") + original_dev_key = os.environ.get("DEV_API_KEY") os.environ["DEBUG"] = "false" - + try: - # Need to reload module to pick up env change - import importlib - import middleware.auth as auth_module - importlib.reload(auth_module) - + auth_module = self._reload_auth_module() result = auth_module._validate_api_key("test-dev-key") assert result is None, "Dev key should not work without DEBUG mode" finally: + # Restore original state os.environ["DEBUG"] = original_debug or "true" - + if original_dev_key: + os.environ["DEV_API_KEY"] = original_dev_key + self._reload_auth_module() # Reload to known good state + def test_dev_key_without_explicit_env_var_fails(self): """Dev key should require explicit DEV_API_KEY env var""" original_debug = os.environ.get("DEBUG") original_dev_key = os.environ.get("DEV_API_KEY") - + os.environ["DEBUG"] = "true" if "DEV_API_KEY" in os.environ: del os.environ["DEV_API_KEY"] - + try: - import importlib - import middleware.auth as auth_module - importlib.reload(auth_module) - + auth_module = self._reload_auth_module() result = auth_module._validate_api_key("some-random-key") assert result is None, "Dev key should not work without explicit DEV_API_KEY" finally: + # Restore original state os.environ["DEBUG"] = original_debug or "true" if original_dev_key: os.environ["DEV_API_KEY"] = original_dev_key - + self._reload_auth_module() # Reload to known good state + def test_dev_key_works_with_debug_and_env_var(self): """Dev key should work when DEBUG=true AND DEV_API_KEY is set""" + original_dev_key = os.environ.get("DEV_API_KEY") os.environ["DEBUG"] = "true" os.environ["DEV_API_KEY"] = "my-secret-dev-key" - + try: - import importlib - import middleware.auth as auth_module - importlib.reload(auth_module) - + auth_module = self._reload_auth_module() result = 
auth_module._validate_api_key("my-secret-dev-key") assert result is not None, "Dev key should work with DEBUG and DEV_API_KEY" assert result.api_key_name == "development" assert result.tier == "enterprise" finally: - os.environ["DEV_API_KEY"] = "test-dev-key" # Restore - + # Restore original state + if original_dev_key: + os.environ["DEV_API_KEY"] = original_dev_key + else: + os.environ["DEV_API_KEY"] = "test-dev-key" + self._reload_auth_module() # Reload to known good state + def test_wrong_dev_key_fails_even_in_debug(self): """Wrong dev key should fail even in DEBUG mode""" + original_dev_key = os.environ.get("DEV_API_KEY") os.environ["DEBUG"] = "true" os.environ["DEV_API_KEY"] = "correct-key" - + try: - import importlib - import middleware.auth as auth_module - importlib.reload(auth_module) - + auth_module = self._reload_auth_module() result = auth_module._validate_api_key("wrong-key") assert result is None, "Wrong dev key should not work" finally: - os.environ["DEV_API_KEY"] = "test-dev-key" + # Restore original state + if original_dev_key: + os.environ["DEV_API_KEY"] = original_dev_key + else: + os.environ["DEV_API_KEY"] = "test-dev-key" + self._reload_auth_module() # Reload to known good state # ============== INFO LEAKAGE TESTS ============== class TestInfoLeakagePrevention: """Test that 404 is returned instead of 403 to prevent info leakage""" - + def test_nonexistent_and_unauthorized_get_same_error(self): """Both non-existent repo and unauthorized access should return identical 404""" - with patch('dependencies.repo_manager') as mock_manager: - # Both cases return None from get_repo_for_user - mock_manager.get_repo_for_user.return_value = None - - from dependencies import get_repo_or_404 - from fastapi import HTTPException - + from dependencies import get_repo_or_404, repo_manager + from fastapi import HTTPException + + with patch.object(repo_manager, 'get_repo_for_user', return_value=None): # Non-existent repo with pytest.raises(HTTPException) as exc1: 
get_repo_or_404("does-not-exist", "user-1") - + # Other user's repo (also returns None because no ownership) with pytest.raises(HTTPException) as exc2: get_repo_or_404("repo-user2-a", "user-1") - + # Both should have identical error assert exc1.value.status_code == exc2.value.status_code == 404 assert exc1.value.detail == exc2.value.detail @@ -306,42 +325,42 @@ class TestEndpointOwnershipIntegration: These tests verify that endpoints actually call ownership verification. They mock at the right level to ensure the security helpers are used. """ - + def test_list_repos_calls_user_filtered_method(self): """GET /api/repos should call list_repos_for_user, not list_repos""" # This is a code inspection test - we verify the correct method is called import ast - + with open(backend_dir / "routes" / "repos.py") as f: source = f.read() - + # Check that list_repos_for_user is used in list_repositories function assert "list_repos_for_user" in source, "Should use list_repos_for_user" - + # And that the old unfiltered method is NOT used in that endpoint # (This is a simple check - in production you'd use proper AST analysis) tree = ast.parse(source) - + for node in ast.walk(tree): if isinstance(node, ast.FunctionDef) and node.name == "list_repositories": func_source = ast.unparse(node) assert "list_repos_for_user" in func_source # Make sure we're not calling the unfiltered version assert "repo_manager.list_repos()" not in func_source - + def test_repo_endpoints_use_ownership_verification(self): """All repo-specific endpoints should use get_repo_or_404 or verify_repo_access""" # Check repos.py for index_repository with open(backend_dir / "routes" / "repos.py") as f: repos_source = f.read() - + # Check analysis.py for analysis endpoints with open(backend_dir / "routes" / "analysis.py") as f: analysis_source = f.read() - + # Endpoints in repos.py assert "def index_repository" in repos_source, "Endpoint index_repository not found" - + # Endpoints in analysis.py analysis_endpoints = [ 
"get_dependency_graph", @@ -349,25 +368,25 @@ def test_repo_endpoints_use_ownership_verification(self): "get_repository_insights", "get_style_analysis", ] - + for endpoint in analysis_endpoints: assert f"def {endpoint}" in analysis_source, f"Endpoint {endpoint} not found" - + # Verify ownership checks exist in each file assert "get_repo_or_404" in repos_source or "verify_repo_access" in repos_source assert "get_repo_or_404" in analysis_source or "verify_repo_access" in analysis_source - + def test_search_endpoint_verifies_repo_ownership(self): """POST /api/search should verify repo ownership""" with open(backend_dir / "routes" / "search.py") as f: source = f.read() - + assert "verify_repo_access" in source, "search_code should verify repo ownership" - + def test_explain_endpoint_verifies_repo_ownership(self): """POST /api/explain should verify repo ownership""" with open(backend_dir / "routes" / "search.py") as f: source = f.read() - + # explain_code is in the same file, check for ownership verification assert "get_repo_or_404" in source, "explain_code should verify repo ownership" diff --git a/backend/tests/test_validate_repo.py b/backend/tests/test_validate_repo.py new file mode 100644 index 0000000..03bd027 --- /dev/null +++ b/backend/tests/test_validate_repo.py @@ -0,0 +1,371 @@ +""" +Tests for the validate-repo endpoint (Issue #124). +Tests GitHub URL validation for anonymous indexing. + +Note: These tests rely on conftest.py for Pinecone/OpenAI mocking. 
+""" +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +# Import directly - conftest.py handles external service mocking +from routes.playground import ( + _parse_github_url, + GITHUB_URL_PATTERN, + ANONYMOUS_FILE_LIMIT, + ValidateRepoRequest, +) + + +# ============================================================================= +# URL PARSING TESTS +# ============================================================================= + +class TestParseGitHubUrl: + """Tests for URL parsing.""" + + def test_valid_https_url(self): + owner, repo, error = _parse_github_url("https://github.com/facebook/react") + assert owner == "facebook" + assert repo == "react" + assert error is None + + def test_valid_http_url(self): + owner, repo, error = _parse_github_url("http://github.com/user/repo") + assert owner == "user" + assert repo == "repo" + assert error is None + + def test_url_with_trailing_slash(self): + owner, repo, error = _parse_github_url("https://github.com/owner/repo/") + assert owner == "owner" + assert repo == "repo" + assert error is None + + def test_url_with_dots_and_dashes(self): + owner, repo, error = _parse_github_url( + "https://github.com/my-org/my.repo-name" + ) + assert owner == "my-org" + assert repo == "my.repo-name" + assert error is None + + def test_invalid_url_wrong_domain(self): + owner, repo, error = _parse_github_url("https://gitlab.com/user/repo") + assert owner is None + assert repo is None + assert "Invalid GitHub URL format" in error + + def test_invalid_url_no_repo(self): + owner, repo, error = _parse_github_url("https://github.com/justowner") + assert owner is None + assert error is not None + + def test_invalid_url_with_path(self): + owner, repo, error = _parse_github_url( + "https://github.com/owner/repo/tree/main" + ) + assert owner is None + assert error is not None + + def test_invalid_url_blob_path(self): + owner, repo, error = _parse_github_url( + "https://github.com/owner/repo/blob/main/file.py" + ) + assert 
owner is None + assert error is not None + + +class TestGitHubUrlPattern: + """Tests for the regex pattern.""" + + def test_pattern_matches_standard(self): + match = GITHUB_URL_PATTERN.match("https://github.com/user/repo") + assert match is not None + assert match.group("owner") == "user" + assert match.group("repo") == "repo" + + def test_pattern_rejects_subpath(self): + match = GITHUB_URL_PATTERN.match("https://github.com/user/repo/issues") + assert match is None + + +# ============================================================================= +# REQUEST MODEL TESTS +# ============================================================================= + +class TestValidateRepoRequest: + """Tests for the request model validation.""" + + def test_invalid_url_format(self): + """Test with malformed URL.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + ValidateRepoRequest(github_url="not-a-url") + + def test_non_github_url(self): + """Test with non-GitHub URL.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + ValidateRepoRequest(github_url="https://gitlab.com/user/repo") + + def test_empty_url(self): + """Test with empty URL.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + ValidateRepoRequest(github_url="") + + def test_valid_request_model(self): + """Test valid request passes validation.""" + req = ValidateRepoRequest(github_url="https://github.com/user/repo") + assert req.github_url == "https://github.com/user/repo" + + def test_url_with_whitespace_trimmed(self): + """Test that whitespace is trimmed.""" + req = ValidateRepoRequest(github_url=" https://github.com/user/repo ") + assert req.github_url == "https://github.com/user/repo" + + +# ============================================================================= +# GITHUB API TESTS +# ============================================================================= + +class TestFetchRepoMetadata: + """Tests for 
GitHub API interaction.""" + + @pytest.mark.asyncio + async def test_repo_not_found(self): + """Test handling of 404 response.""" + from routes.playground import _fetch_repo_metadata + + mock_response = MagicMock() + mock_response.status_code = 404 + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + result = await _fetch_repo_metadata("nonexistent", "repo") + assert result["error"] == "not_found" + + @pytest.mark.asyncio + async def test_rate_limited(self): + """Test handling of 403 rate limit response.""" + from routes.playground import _fetch_repo_metadata + + mock_response = MagicMock() + mock_response.status_code = 403 + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + result = await _fetch_repo_metadata("user", "repo") + assert result["error"] == "rate_limited" + + @pytest.mark.asyncio + async def test_successful_fetch(self): + """Test successful metadata fetch.""" + from routes.playground import _fetch_repo_metadata + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "name": "repo", + "owner": {"login": "user"}, + "private": False, + "default_branch": "main", + "stargazers_count": 100, + "language": "Python", + "size": 1024, + } + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + 
result = await _fetch_repo_metadata("user", "repo") + assert result["name"] == "repo" + assert result["private"] is False + assert result["stargazers_count"] == 100 + + @pytest.mark.asyncio + async def test_timeout_handling(self): + """Test timeout is handled gracefully.""" + from routes.playground import _fetch_repo_metadata + import httpx + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.side_effect = httpx.TimeoutException("timeout") + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + result = await _fetch_repo_metadata("user", "repo") + assert result["error"] == "timeout" + + +# ============================================================================= +# FILE COUNTING TESTS +# ============================================================================= + +class TestCountCodeFiles: + """Tests for file counting logic.""" + + @pytest.mark.asyncio + async def test_count_python_files(self): + """Test counting Python files.""" + from routes.playground import _count_code_files + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "truncated": False, + "tree": [ + {"type": "blob", "path": "app.py"}, + {"type": "blob", "path": "utils.py"}, + {"type": "blob", "path": "README.md"}, # Not code + {"type": "tree", "path": "src"}, # Directory + ] + } + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + count, error = await _count_code_files("user", "repo", "main") + assert count == 2 # Only .py files + assert error is None + + @pytest.mark.asyncio + async def test_skip_node_modules(self): + """Test that node_modules 
is skipped.""" + from routes.playground import _count_code_files + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "truncated": False, + "tree": [ + {"type": "blob", "path": "index.js"}, + {"type": "blob", "path": "node_modules/lodash/index.js"}, + {"type": "blob", "path": "src/app.js"}, + ] + } + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + count, error = await _count_code_files("user", "repo", "main") + assert count == 2 # index.js and src/app.js, not node_modules + assert error is None + + @pytest.mark.asyncio + async def test_truncated_tree(self): + """Test handling of truncated tree response.""" + from routes.playground import _count_code_files + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "truncated": True, + "tree": [] + } + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + count, error = await _count_code_files("user", "repo", "main") + assert count == -1 + assert error == "truncated" + + @pytest.mark.asyncio + async def test_multiple_extensions(self): + """Test counting multiple file types.""" + from routes.playground import _count_code_files + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "truncated": False, + "tree": [ + {"type": "blob", "path": "app.py"}, + {"type": "blob", "path": "utils.js"}, + {"type": "blob", "path": "main.go"}, + {"type": "blob", "path": "lib.rs"}, + {"type": "blob", 
"path": "config.json"}, # Not code + {"type": "blob", "path": "style.css"}, # Not code + ] + } + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + count, error = await _count_code_files("user", "repo", "main") + assert count == 4 # py, js, go, rs + assert error is None + + @pytest.mark.asyncio + async def test_skip_git_directory(self): + """Test that .git directory is skipped.""" + from routes.playground import _count_code_files + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "truncated": False, + "tree": [ + {"type": "blob", "path": "app.py"}, + {"type": "blob", "path": ".git/hooks/pre-commit.py"}, + ] + } + + with patch("routes.playground.httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_instance.__aenter__.return_value = mock_instance + mock_instance.__aexit__.return_value = None + mock_client.return_value = mock_instance + + count, error = await _count_code_files("user", "repo", "main") + assert count == 1 # Only app.py + assert error is None + + +# ============================================================================= +# CONSTANTS TESTS +# ============================================================================= + +class TestAnonymousFileLimit: + """Tests for file limit constant.""" + + def test_limit_value(self): + """Verify anonymous file limit is 200.""" + assert ANONYMOUS_FILE_LIMIT == 200