Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 233 additions & 0 deletions backend/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""
Run with: uv run pytest backend/tests/test_api_endpoints.py -v -s
"""

import json
import pytest
import sys
import tempfile
import shutil
from pathlib import Path
from datetime import datetime
from unittest.mock import AsyncMock

# Add backend to path
backend_path = Path(__file__).parent.parent
sys.path.insert(0, str(backend_path))
Comment on lines +7 to +16

from fastapi.testclient import TestClient
import server
from server import (
app,
validate_username,
should_log,
make_log,
Log,
GenerationRequestPayload,
)
import nlp

# GLOBAL TEST FIXTURES (Prevents FileNotFoundError across all API routes)

@pytest.fixture(autouse=True)
def global_temp_log_dir():
"""
Safely intercepts server.LOG_PATH globally for every single test.
This guarantees that whenever the API endpoints attempt to append logs,
they dynamically write to a safe sandbox instead of crashing on missing local folders.
"""
temp_dir = Path(tempfile.mkdtemp())
original_log_path = server.LOG_PATH
server.LOG_PATH = temp_dir
yield temp_dir
server.LOG_PATH = original_log_path
shutil.rmtree(temp_dir, ignore_errors=True)


# Create test client globally
client = TestClient(app)


Comment on lines +47 to +50
# 1. VALIDATION & PRIVACY LOGIC TESTS (Critical Core Logic)

class TestUsernameValidation:
"""Tests core username validation to prevent path traversal and bad characters."""

def test_validate_username_success(self):
assert validate_username("test_user-123") == "test_user-123"

def test_validate_username_too_long(self):
with pytest.raises(ValueError, match="50 characters or less"):
validate_username("a" * 51)

def test_validate_username_special_chars(self):
with pytest.raises(ValueError, match="alphanumeric or contain"):
validate_username("test@user!")

def test_validate_username_not_string(self):
with pytest.raises(ValueError, match="must be a string"):
validate_username(12345)


class TestShouldLogLogic:
"""Tests log privacy tiers: log data for study users, redact data for production users."""

def test_should_log_logic_for_study_users(self):
assert should_log("study_user_01") is True

def test_should_log_logic_for_production_users(self):
assert should_log("") is False


class TestDataSanitization:
"""Verifies that Pydantic models redact prompt data properly for production users."""

def test_user_data_sanitization_fallback(self):
payload = GenerationRequestPayload(
username="test_user",
gtype="complete_document",
prompt="This is highly confidential user text content."
)
sanitized = payload.sanitized()
assert sanitized.username == "test_user"
assert sanitized.prompt == "[REDACTED]"



# 2. LOCAL FILE OPERATIONS TESTS (Isolated File I/O & Pipelines)

class TestLoggingOperations:
"""Tests async log appending and proper .jsonl format generation."""

@pytest.mark.asyncio
async def test_make_log_creates_file_and_appends(self, global_temp_log_dir):
log1 = Log(
timestamp=datetime.now().timestamp(),
username="mary_chen",
event="click_suggestion"
)
log2 = Log(
timestamp=datetime.now().timestamp(),
username="mary_chen",
event="accept_suggestion"
)

# Write sequential logs
await make_log(log1)
await make_log(log2)

log_file = global_temp_log_dir / "mary_chen.jsonl"
assert log_file.exists(), "The log file should be correctly generated"

# Verify valid JSONL structure (one valid JSON object per line)
lines = log_file.read_text().strip().split('\n')
assert len(lines) == 2, "Should have appended exactly two log entries"

data = json.loads(lines[0])
assert data["username"] == "mary_chen"
assert data["event"] == "click_suggestion"

def test_logs_poll_deduplication(self):
"""
Tests long-polling pipeline transaction integrity against composite unique keys:
key = f"{timestamp}|{username}|{event}"
"""
seen_log_keys = set()

timestamp = 1716120000.0
username = "mary_chen"
event = "poll_request"

log_key_primary = f"{timestamp}|{username}|{event}"
seen_log_keys.add(log_key_primary)

log_key_duplicate = f"{timestamp}|{username}|{event}"
assert log_key_duplicate in seen_log_keys, "Duplicate long-polling entry constraint hit!"



# 3. API ROUTE & INTEGRATION TESTS (Mocked/Fast Route Checks)

class TestAPIEndpoints:
"""Tests FastAPI routers, deterministic prompt shuffling, and middleware behaviors."""

def test_ping_endpoint(self):
"""Tests basic service health check."""
response = client.get("/api/ping")
assert response.status_code == 200
assert "timestamp" in response.json()

def test_log_endpoint_success(self):
"""Tests front-end event logging endpoint."""
# Configured without triggering strict server errors to capture the response code smoothly
custom_client = TestClient(app, raise_server_exceptions=False)
payload = {"username": "user_abc", "event": "suggestion_selected", "trace_id": "uuid-111"}
response = custom_client.post("/api/log", json=payload)
assert response.status_code == 200
assert response.json() == {"message": "Feedback logged successfully."}

@pytest.mark.asyncio
async def test_get_suggestion_success_and_mixing(self, mocker):
"""
Tests prompt generation context mixing and shuffling sequence
without hitting a network connection or spending real OpenAI tokens.
"""
mock_result = nlp.GenerationResult(
generation_type="complete_document",
result="This is a mocked document continuation with shuffled sequence prompt arrays.",
extra_data={"trace_id": "mock-trace-id"}
)
mocker.patch("nlp.get_suggestion", new_callable=AsyncMock, return_value=mock_result)

payload = {
"username": "study_user",
"gtype": "complete_document",
"doc_context": {
"beforeCursor": "Once upon a time ",
"selectedText": "",
"afterCursor": " lived a king.",
"contextData": [{"title": "True Data Node", "content": "True Data Node"}],
"falseContextData": [{"title": "Distractor Data Node", "content": "Distractor Data Node"}]
}
}
response = client.post("/api/get_suggestion", json=payload)
assert response.status_code == 200

data = response.json()
assert data["generation_type"] == "complete_document"
assert "shuffled sequence" in data["result"]

def test_get_suggestion_invalid_gtype(self):
"""Tests that passing an illegal gtype triggers our server's internal ValueError handler."""
custom_client = TestClient(app, raise_server_exceptions=False)
payload = {
"username": "study_user",
"gtype": "invalid_type_here",
"doc_context": {
"beforeCursor": "test", "selectedText": "", "afterCursor": ""
}
}
response = custom_client.post("/api/get_suggestion", json=payload)
assert response.status_code == 500
assert response.json() == {"detail": "Internal server error"}

@pytest.mark.asyncio
async def test_middleware_captures_exception_to_posthog(self, mocker):
"""Verifies unhandled backend logic failures run directly through our middleware telemetry into PostHog."""
custom_client = TestClient(app, raise_server_exceptions=False)
mock_posthog = mocker.patch("posthog_client.posthog_client.capture")
mocker.patch("nlp.get_suggestion", side_effect=RuntimeError("Upstream LLM Provider Disconnected!"))

payload = {
"username": "mary_chen",
"gtype": "complete_document",
"doc_context": {"beforeCursor": "crash_test", "selectedText": "", "afterCursor": ""}
}

response = custom_client.post("/api/get_suggestion", json=payload)
assert response.status_code == 500
assert mock_posthog.called, "Application faults must register metrics directly onto PostHog!"


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
Loading