diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..62d0249 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[report] +omit = + */tests/* + */__init__.py +show_missing = false +skip_covered = true +include = * \ No newline at end of file diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml new file mode 100644 index 0000000..c262386 --- /dev/null +++ b/.github/workflows/run-tests.yaml @@ -0,0 +1,53 @@ +name: Run Tests + +on: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Install Rye + uses: eifinger/setup-rye@v4 + with: + enable-cache: true + + - name: Set up Python ${{ matrix.python-version }} + run: | + rye pin ${{ matrix.python-version }} + rye sync + + - name: Run lints + run: rye run lint + + - name: Run tests + run: rye run pytest + + test-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rye + uses: eifinger/setup-rye@v4 + with: + enable-cache: true + + - name: Build package + run: rye build + + - name: Check build artifacts + run: | + ls -la dist/ + # Verify wheel and source distribution were created + test -f dist/*.whl + test -f dist/*.tar.gz diff --git a/.gitignore b/.gitignore index 4f4167b..16a79ef 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ codegen.log Brewfile.lock.json .DS_Store +.coverage \ No newline at end of file diff --git a/examples/demo.py b/examples/demo.py index 023e9ac..4b63f6b 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -9,7 +9,7 @@ client = Atlas() # Evaluations -evaluation = client.evaluations.create(model="random", benchmark="random") +evaluation = client.evaluations.create(model="random_model_id", benchmark="random_benchmark_id") # Results if evaluation is not None: diff --git a/pyproject.toml b/pyproject.toml index 62e0718..2d47a68 100644 
--- a/pyproject.toml +++ b/pyproject.toml @@ -3,13 +3,8 @@ name = "atlas" version = "1.0.0" description = "The official Python library for the LayerLens Atlas API" license = "Apache-2.0" -authors = [ -{ name = "LayerLens", email = "support@layerlens.ai" }, -] -dependencies = [ - "httpx>=0.23.0, <1", - "pydantic>=1.9.0, <3", -] +authors = [{ name = "LayerLens", email = "support@layerlens.ai" }] +dependencies = ["httpx>=0.23.0, <1", "pydantic>=1.9.0, <3"] requires-python = ">= 3.8" classifiers = [ "Typing :: Typed", @@ -25,7 +20,7 @@ classifiers = [ "Operating System :: MacOS", "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows", - "Topic :: Software Development :: Libraries :: Python Modules" + "Topic :: Software Development :: Libraries :: Python Modules", ] [project.urls] @@ -40,10 +35,11 @@ atlas = "atlas.cli:main" managed = true # version pins are in requirements-dev.lock dev-dependencies = [ - "pyright==1.1.399", - "mypy", - "pytest", - "ruff", + "pyright==1.1.399", + "mypy", + "pytest", + "ruff", + "pytest-cov>=6.2.1", ] [tool.rye.scripts] @@ -52,26 +48,27 @@ format = { chain = [ "fix:ruff", # run formatting again to fix any inconsistencies when imports are stripped "format:ruff", -]} +] } "format:ruff" = "ruff format" -"lint" = { chain = [ - "check:ruff", - "typecheck", - "check:importable", -]} +"lint" = { chain = ["check:ruff", "typecheck:src", "check:importable"] } +"lint:all" = { chain = ["check:ruff", "typecheck", "check:importable"] } "check:ruff" = "ruff check ." "fix:ruff" = "ruff check --fix ." 
"check:importable" = "python -c 'import atlas'" -typecheck = { chain = [ - "typecheck:pyright", - "typecheck:mypy" -]} +# Type checking for production code only (excludes tests) +"typecheck:src" = { chain = ["typecheck:pyright:src", "typecheck:mypy:src"] } + +# Type checking for all code including tests +typecheck = { chain = ["typecheck:pyright", "typecheck:mypy"] } + "typecheck:pyright" = "pyright" +"typecheck:pyright:src" = "pyright src" "typecheck:verify-types" = "pyright --verifytypes atlas --ignoreexternal" "typecheck:mypy" = "mypy ." +"typecheck:mypy:src" = "mypy src" [tool.ruff] line-length = 120 @@ -125,4 +122,14 @@ known-first-party = ["openai", "tests"] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] "tests/**.py" = ["T201", "T203"] -"examples/**.py" = ["T201", "T203"] \ No newline at end of file +"examples/**.py" = ["T201", "T203"] + +[tool.pyright] +include = ["src", "tests"] +exclude = ["**/__pycache__"] +reportMissingTypeStubs = false + +# Less strict settings for tests +executionEnvironments = [ + { root = "tests", reportGeneralTypeIssues = false, reportOptionalSubscript = false, reportOptionalMemberAccess = false, reportUntypedFunctionDecorator = false, reportUnknownArgumentType = false, reportUnknownMemberType = false, reportUnknownVariableType = false, reportUnnecessaryIsInstance = false, reportUnnecessaryComparison = false, reportArgumentType = false, reportCallIssue = false }, +] diff --git a/requirements-dev.lock b/requirements-dev.lock index 71ca3f7..7d97580 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -17,6 +17,8 @@ anyio==4.9.0 certifi==2025.7.14 # via httpcore # via httpx +coverage==7.10.2 + # via pytest-cov exceptiongroup==1.3.0 # via anyio # via pytest @@ -42,6 +44,7 @@ pathspec==0.12.1 # via mypy pluggy==1.6.0 # via pytest + # via pytest-cov pydantic==2.11.7 # via atlas pydantic-core==2.33.2 @@ -50,10 +53,13 @@ pygments==2.19.2 # via pytest pyright==1.1.399 pytest==8.4.1 + # via pytest-cov 
+pytest-cov==6.2.1 ruff==0.12.7 sniffio==1.3.1 # via anyio tomli==2.2.1 + # via coverage # via mypy # via pytest typing-extensions==4.14.1 diff --git a/scripts/test b/scripts/test old mode 100644 new mode 100755 diff --git a/src/atlas/_models.py b/src/atlas/_models.py index bf7ce43..0d39f5e 100644 --- a/src/atlas/_models.py +++ b/src/atlas/_models.py @@ -3,7 +3,7 @@ from typing import Dict, List, Union, Optional from datetime import timedelta -from pydantic import Field, BaseModel +from pydantic import Field, BaseModel, ConfigDict class Evaluation(BaseModel): @@ -105,7 +105,6 @@ class CustomBenchmark(BaseModel): class Benchmarks(BaseModel): + model_config = ConfigDict(populate_by_name=True) + benchmarks: List[Union[Benchmark, CustomBenchmark]] = Field(..., alias="datasets") - - class Config: - validate_by_name = True diff --git a/src/atlas/resources/models/models.py b/src/atlas/resources/models/models.py index 5ad09ce..a9aca0c 100644 --- a/src/atlas/resources/models/models.py +++ b/src/atlas/resources/models/models.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Union, Literal +from typing import List, Literal import httpx @@ -15,7 +15,7 @@ def get( *, type: Literal["public"] | Literal["custom"], timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, - ) -> List[Union[Model | CustomModel]] | None: + ) -> List[Model | CustomModel] | None: models = self._get( f"/organizations/{self._client.organization_id}/projects/{self._client.project_id}/models", params={ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9de38db --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +import os +from unittest import mock + +import pytest + + +@pytest.fixture +def env_vars(): + """Clean environment variables for testing.""" + env_keys = ["LAYERLENS_ATLAS_API_KEY", "LAYERLENS_ATLAS_ORG_ID", "LAYERLENS_ATLAS_PROJECT_ID"] + original_values = {key: os.environ.get(key) for key in env_keys} + + # Clear environment 
variables + for key in env_keys: + if key in os.environ: + del os.environ[key] + + yield + + # Restore original values + for key, value in original_values.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + + +@pytest.fixture +def mock_env_vars(): + """Mock environment variables with test values.""" + with mock.patch.dict(os.environ, { + "LAYERLENS_ATLAS_API_KEY": "test-api-key", + "LAYERLENS_ATLAS_ORG_ID": "test-org-id", + "LAYERLENS_ATLAS_PROJECT_ID": "test-project-id" + }): + yield \ No newline at end of file diff --git a/tests/resources/test_benchmarks.py b/tests/resources/test_benchmarks.py new file mode 100644 index 0000000..fe589cf --- /dev/null +++ b/tests/resources/test_benchmarks.py @@ -0,0 +1,392 @@ +from unittest.mock import Mock + +import httpx +import pytest + +from atlas._models import Benchmark, Benchmarks as BenchmarksData, CustomBenchmark +from atlas._constants import DEFAULT_TIMEOUT +from atlas.resources.benchmarks.benchmarks import Benchmarks + + +class TestBenchmarks: + """Test Benchmarks resource API methods.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + return client + + @pytest.fixture + def benchmarks_resource(self, mock_client): + """Benchmarks resource instance.""" + return Benchmarks(mock_client) + + @pytest.fixture + def sample_benchmark_data(self): + """Sample benchmark data for testing.""" + return { + "id": "benchmark-123", + "key": "mmlu", + "name": "MMLU", + "full_description": "Massive Multitask Language Understanding", + "language": "english", + "categories": ["reasoning", "knowledge"], + "subsets": ["math", "science", "history"], + "prompt_count": 15908, + "deprecated": False, + } + + @pytest.fixture + def sample_custom_benchmark_data(self): + """Sample custom benchmark data for testing.""" + return { + "id": 
"custom-benchmark-456", + "key": "my-benchmark", + "name": "My Custom Benchmark", + "description": "Custom benchmark description", + "system_prompt": "You are a helpful assistant", + "subsets": ["subset1", "subset2"], + "prompt_count": 100, + "version_count": 1, + "regex_pattern": r"Answer: (.+)", + "llm_judge_model_id": "gpt-4", + "custom_instructions": "Rate responses on scale 1-10", + "scoring_metric": "accuracy", + "metrics": ["accuracy", "precision"], + "files": ["data.jsonl"], + "disabled": False, + } + + @pytest.fixture + def mock_public_benchmarks_response(self, sample_benchmark_data): + """Mock BenchmarksData response with public benchmarks.""" + benchmark = Benchmark(**sample_benchmark_data) + return BenchmarksData(datasets=[benchmark]) + + @pytest.fixture + def mock_custom_benchmarks_response(self, sample_custom_benchmark_data): + """Mock BenchmarksData response with custom benchmarks.""" + custom_benchmark = CustomBenchmark(**sample_custom_benchmark_data) + return BenchmarksData(datasets=[custom_benchmark]) + + def test_benchmarks_initialization(self, mock_client): + """Benchmarks resource initializes correctly.""" + benchmarks = Benchmarks(mock_client) + + assert benchmarks._client is mock_client + assert benchmarks._get is mock_client.get_cast + + def test_get_public_benchmarks_success(self, benchmarks_resource, mock_public_benchmarks_response): + """get method returns public benchmarks successfully.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + + result = benchmarks_resource.get(type="public") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], Benchmark) + assert result[0].name == "MMLU" + assert result[0].key == "mmlu" + + def test_get_custom_benchmarks_success(self, benchmarks_resource, mock_custom_benchmarks_response): + """get method returns custom benchmarks successfully.""" + benchmarks_resource._get.return_value = mock_custom_benchmarks_response + + result = 
benchmarks_resource.get(type="custom") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], CustomBenchmark) + assert result[0].name == "My Custom Benchmark" + assert result[0].key == "my-benchmark" + + def test_get_benchmarks_request_parameters_public(self, benchmarks_resource, mock_public_benchmarks_response): + """get method makes correct API request for public benchmarks.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + + benchmarks_resource.get(type="public") + + benchmarks_resource._get.assert_called_once_with( + "/organizations/org-123/projects/proj-456/benchmarks", + params={"type": "public"}, + timeout=DEFAULT_TIMEOUT, + cast_to=BenchmarksData, + ) + + def test_get_benchmarks_request_parameters_custom(self, benchmarks_resource, mock_custom_benchmarks_response): + """get method makes correct API request for custom benchmarks.""" + benchmarks_resource._get.return_value = mock_custom_benchmarks_response + + benchmarks_resource.get(type="custom") + + benchmarks_resource._get.assert_called_once_with( + "/organizations/org-123/projects/proj-456/benchmarks", + params={"type": "custom"}, + timeout=DEFAULT_TIMEOUT, + cast_to=BenchmarksData, + ) + + def test_get_benchmarks_with_custom_timeout(self, benchmarks_resource, mock_public_benchmarks_response): + """get method accepts custom timeout.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + custom_timeout = 45.0 + + benchmarks_resource.get(type="public", timeout=custom_timeout) + + call_args = benchmarks_resource._get.call_args + assert call_args.kwargs["timeout"] == custom_timeout + + def test_get_benchmarks_with_httpx_timeout(self, benchmarks_resource, mock_public_benchmarks_response): + """get method accepts httpx.Timeout object.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + custom_timeout = httpx.Timeout(45.0) + + benchmarks_resource.get(type="public", timeout=custom_timeout) + + 
call_args = benchmarks_resource._get.call_args + assert call_args.kwargs["timeout"] is custom_timeout + + def test_get_benchmarks_none_response(self, benchmarks_resource): + """get method returns None when response is None.""" + benchmarks_resource._get.return_value = None + + result = benchmarks_resource.get(type="public") + + assert result is None + + def test_get_benchmarks_invalid_response_type(self, benchmarks_resource): + """get method handles non-BenchmarksData response gracefully.""" + benchmarks_resource._get.return_value = "invalid-response" + + result = benchmarks_resource.get(type="public") + + assert result is None + + def test_get_benchmarks_empty_response(self, benchmarks_resource): + """get method returns empty list when no benchmarks in response.""" + empty_response = BenchmarksData(datasets=[]) + benchmarks_resource._get.return_value = empty_response + + result = benchmarks_resource.get(type="public") + + assert result == [] + assert isinstance(result, list) + + def test_get_benchmarks_multiple_items(self, benchmarks_resource, sample_benchmark_data, sample_custom_benchmark_data): + """get method returns multiple benchmarks correctly.""" + _ = sample_custom_benchmark_data # Fixture used for side effects + benchmark = Benchmark(**sample_benchmark_data) + + # Create second benchmark with different data + benchmark2_data = sample_benchmark_data.copy() + benchmark2_data["id"] = "benchmark-456" + benchmark2_data["key"] = "hellaswag" + benchmark2_data["name"] = "HellaSwag" + benchmark2 = Benchmark(**benchmark2_data) + + response = BenchmarksData(datasets=[benchmark, benchmark2]) + benchmarks_resource._get.return_value = response + + result = benchmarks_resource.get(type="public") + + assert len(result) == 2 + assert result[0].key == "mmlu" + assert result[1].key == "hellaswag" + + def test_get_benchmarks_url_construction(self, benchmarks_resource, mock_public_benchmarks_response): + """get method constructs URL correctly with org and project IDs.""" + 
benchmarks_resource._client.organization_id = "custom-org" + benchmarks_resource._client.project_id = "custom-project" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + + benchmarks_resource.get(type="public") + + expected_url = "/organizations/custom-org/projects/custom-project/benchmarks" + call_args = benchmarks_resource._get.call_args + assert call_args[0][0] == expected_url + + @pytest.mark.parametrize("benchmark_type", ["public", "custom"]) + def test_get_benchmarks_type_parameter(self, benchmarks_resource, benchmark_type): + """get method accepts both public and custom types.""" + benchmarks_resource._get.return_value = BenchmarksData(datasets=[]) + + benchmarks_resource.get(type=benchmark_type) + + call_args = benchmarks_resource._get.call_args + assert call_args.kwargs["params"]["type"] == benchmark_type + + def test_get_benchmarks_cast_to_parameter(self, benchmarks_resource, mock_public_benchmarks_response): + """get method specifies correct cast_to parameter.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + + benchmarks_resource.get(type="public") + + call_args = benchmarks_resource._get.call_args + assert call_args.kwargs["cast_to"] is BenchmarksData + + def test_get_benchmarks_timeout_default(self, benchmarks_resource, mock_public_benchmarks_response): + """get method uses DEFAULT_TIMEOUT when no timeout specified.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + + benchmarks_resource.get(type="public") + + call_args = benchmarks_resource._get.call_args + assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT + + def test_get_benchmarks_with_none_timeout(self, benchmarks_resource, mock_public_benchmarks_response): + """get method accepts None timeout.""" + benchmarks_resource._get.return_value = mock_public_benchmarks_response + + benchmarks_resource.get(type="public", timeout=None) + + call_args = benchmarks_resource._get.call_args + assert call_args.kwargs["timeout"] 
is None + + +class TestBenchmarksErrorHandling: + """Test error handling in Benchmarks resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + return client + + @pytest.fixture + def benchmarks_resource(self, mock_client): + """Benchmarks resource instance.""" + return Benchmarks(mock_client) + + def test_get_benchmarks_handles_api_error(self, benchmarks_resource): + """get method propagates API errors.""" + from atlas._exceptions import APIStatusError + + mock_response = Mock() + mock_response.status_code = 404 + mock_response.headers = {} + + api_error = APIStatusError("Not Found", response=mock_response, body=None) + benchmarks_resource._get.side_effect = api_error + + with pytest.raises(APIStatusError): + benchmarks_resource.get(type="public") + + def test_get_benchmarks_handles_auth_error(self, benchmarks_resource): + """get method propagates authentication errors.""" + from atlas._exceptions import AuthenticationError + + mock_response = Mock() + mock_response.status_code = 401 + mock_response.headers = {} + + auth_error = AuthenticationError("Unauthorized", response=mock_response, body=None) + benchmarks_resource._get.side_effect = auth_error + + with pytest.raises(AuthenticationError): + benchmarks_resource.get(type="custom") + + def test_get_benchmarks_handles_connection_error(self, benchmarks_resource): + """get method propagates connection errors.""" + from atlas._exceptions import APIConnectionError + + mock_request = Mock() + connection_error = APIConnectionError(request=mock_request) + benchmarks_resource._get.side_effect = connection_error + + with pytest.raises(APIConnectionError): + benchmarks_resource.get(type="public") + + def test_get_benchmarks_handles_timeout_error(self, benchmarks_resource): + """get method propagates timeout errors.""" + from atlas._exceptions import APITimeoutError + + 
mock_request = Mock() + timeout_error = APITimeoutError(mock_request) + benchmarks_resource._get.side_effect = timeout_error + + with pytest.raises(APITimeoutError): + benchmarks_resource.get(type="public", timeout=1.0) + + +class TestBenchmarksTyping: + """Test type handling in Benchmarks resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + return client + + @pytest.fixture + def benchmarks_resource(self, mock_client): + """Benchmarks resource instance.""" + return Benchmarks(mock_client) + + def test_get_benchmarks_return_type_consistency(self, benchmarks_resource): + """get method returns consistent types.""" + # Test that the method returns either a list or None + benchmarks_resource._get.return_value = None + result = benchmarks_resource.get(type="public") + assert result is None + + # Test that it returns a list when successful + benchmarks_resource._get.return_value = BenchmarksData(datasets=[]) + result = benchmarks_resource.get(type="public") + assert isinstance(result, list) + + def test_get_benchmarks_mixed_benchmark_types(self, benchmarks_resource): + """get method can handle mixed benchmark types in response.""" + # Create mixed response with both Benchmark and CustomBenchmark + public_data = { + "id": "public-123", + "key": "mmlu", + "name": "MMLU", + "full_description": "Public benchmark", + "language": "english", + "categories": ["reasoning"], + "subsets": ["math"], + "prompt_count": 1000, + "deprecated": False, + } + + custom_data = { + "id": "custom-456", + "key": "my-bench", + "name": "My Benchmark", + "description": "Custom benchmark", + "system_prompt": None, + "subsets": ["custom"], + "prompt_count": 50, + "version_count": 1, + "regex_pattern": None, + "llm_judge_model_id": "gpt-4", + "custom_instructions": "Custom instructions", + "scoring_metric": None, + "metrics": ["accuracy"], + "files": 
["test.jsonl"], + "disabled": False, + } + + public_benchmark = Benchmark(**public_data) + custom_benchmark = CustomBenchmark(**custom_data) + + response = BenchmarksData(datasets=[public_benchmark, custom_benchmark]) + benchmarks_resource._get.return_value = response + + result = benchmarks_resource.get(type="public") # Type doesn't matter for this test + + assert len(result) == 2 + assert isinstance(result[0], Benchmark) + assert isinstance(result[1], CustomBenchmark) + assert result[0].key == "mmlu" + assert result[1].key == "my-bench" \ No newline at end of file diff --git a/tests/resources/test_evaluations.py b/tests/resources/test_evaluations.py new file mode 100644 index 0000000..2f4c444 --- /dev/null +++ b/tests/resources/test_evaluations.py @@ -0,0 +1,335 @@ +from unittest.mock import Mock + +import httpx +import pytest + +from atlas._models import Evaluation, Evaluations as EvaluationsData +from atlas._constants import DEFAULT_TIMEOUT +from atlas.resources.evaluations.evaluations import Evaluations + + +class TestEvaluations: + """Test Evaluations resource API methods.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + client.post_cast = Mock() + return client + + @pytest.fixture + def evaluations_resource(self, mock_client): + """Evaluations resource instance.""" + return Evaluations(mock_client) + + @pytest.fixture + def sample_evaluation_data(self): + """Sample evaluation data for testing.""" + return { + "id": "eval-123", + "status": "completed", + "status_description": "Evaluation completed successfully", + "submitted_at": 1640995200, + "finished_at": 1640995800, + "model_id": "model-456", + "model_name": "GPT-4", + "model_key": "gpt-4", + "model_company": "OpenAI", + "dataset_id": "dataset-789", + "dataset_name": "MMLU", + "average_duration": 2500, + "readability_score": 0.85, + "toxicity_score": 0.02, + 
"ethics_score": 0.92, + "accuracy": 0.89, + } + + @pytest.fixture + def mock_evaluations_response(self, sample_evaluation_data): + """Mock EvaluationsData response.""" + evaluation = Evaluation(**sample_evaluation_data) + return EvaluationsData(data=[evaluation]) + + def test_evaluations_initialization(self, mock_client): + """Evaluations resource initializes correctly.""" + evaluations = Evaluations(mock_client) + + assert evaluations._client is mock_client + assert evaluations._get is mock_client.get_cast + assert evaluations._post is mock_client.post_cast + + def test_create_evaluation_success(self, evaluations_resource, mock_evaluations_response): + """create method returns first evaluation on success.""" + evaluations_resource._post.return_value = mock_evaluations_response + + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + assert isinstance(result, Evaluation) + assert result.id == "eval-123" + assert result.model_name == "GPT-4" + assert result.dataset_name == "MMLU" + + def test_create_evaluation_request_parameters(self, evaluations_resource, mock_evaluations_response): + """create method makes correct API request.""" + evaluations_resource._post.return_value = mock_evaluations_response + + evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + evaluations_resource._post.assert_called_once_with( + "/organizations/org-123/projects/proj-456/evaluations", + body=[{ + "model_id": "gpt-4", + "dataset_id": "mmlu", + "is_custom_model": False, + "is_custom_dataset": False, + }], + timeout=DEFAULT_TIMEOUT, + cast_to=EvaluationsData, + ) + + def test_create_evaluation_with_custom_timeout(self, evaluations_resource, mock_evaluations_response): + """create method accepts custom timeout.""" + evaluations_resource._post.return_value = mock_evaluations_response + custom_timeout = 30.0 + + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=custom_timeout) + + call_args = evaluations_resource._post.call_args + assert 
call_args.kwargs["timeout"] == custom_timeout + + def test_create_evaluation_with_httpx_timeout(self, evaluations_resource, mock_evaluations_response): + """create method accepts httpx.Timeout object.""" + evaluations_resource._post.return_value = mock_evaluations_response + custom_timeout = httpx.Timeout(30.0) + + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=custom_timeout) + + call_args = evaluations_resource._post.call_args + assert call_args.kwargs["timeout"] is custom_timeout + + def test_create_evaluation_empty_response(self, evaluations_resource): + """create method returns None when no evaluations in response.""" + empty_response = EvaluationsData(data=[]) + evaluations_resource._post.return_value = empty_response + + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + assert result is None + + def test_create_evaluation_none_response(self, evaluations_resource): + """create method returns None when response is None.""" + evaluations_resource._post.return_value = None + + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + assert result is None + + def test_create_evaluation_invalid_response_type(self, evaluations_resource): + """create method handles non-EvaluationsData response gracefully.""" + evaluations_resource._post.return_value = "invalid-response" + + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + assert result is None + + def test_create_evaluation_multiple_evaluations_returns_first(self, evaluations_resource, sample_evaluation_data): + """create method returns first evaluation when multiple exist.""" + eval1 = Evaluation(**sample_evaluation_data) + eval2_data = sample_evaluation_data.copy() + eval2_data["id"] = "eval-456" + eval2 = Evaluation(**eval2_data) + + response = EvaluationsData(data=[eval1, eval2]) + evaluations_resource._post.return_value = response + + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + assert result.id 
== "eval-123" # First evaluation + assert result is not eval2 + + def test_create_evaluation_url_construction(self, evaluations_resource, mock_evaluations_response): + """create method constructs URL correctly with org and project IDs.""" + evaluations_resource._client.organization_id = "custom-org" + evaluations_resource._client.project_id = "custom-project" + evaluations_resource._post.return_value = mock_evaluations_response + + evaluations_resource.create(model="test-model", benchmark="test-benchmark") + + expected_url = "/organizations/custom-org/projects/custom-project/evaluations" + call_args = evaluations_resource._post.call_args + assert call_args[0][0] == expected_url + + def test_create_evaluation_request_body_structure(self, evaluations_resource, mock_evaluations_response): + """create method sends correct request body structure.""" + evaluations_resource._post.return_value = mock_evaluations_response + + evaluations_resource.create(model="custom-model", benchmark="custom-benchmark") + + call_args = evaluations_resource._post.call_args + body = call_args.kwargs["body"] + + assert isinstance(body, list) + assert len(body) == 1 + assert body[0]["model_id"] == "custom-model" + assert body[0]["dataset_id"] == "custom-benchmark" + assert body[0]["is_custom_model"] is False + assert body[0]["is_custom_dataset"] is False + + @pytest.mark.parametrize("model_name,benchmark_name", [ + ("gpt-3.5-turbo", "hellaswag"), + ("claude-3-opus", "arc-challenge"), + ("llama-2-70b", "truthfulqa"), + ("custom-model-123", "custom-benchmark-456"), + ]) + def test_create_evaluation_with_different_parameters(self, evaluations_resource, mock_evaluations_response, model_name, benchmark_name): + """create method works with various model and benchmark combinations.""" + evaluations_resource._post.return_value = mock_evaluations_response + + result = evaluations_resource.create(model=model_name, benchmark=benchmark_name) + + assert isinstance(result, Evaluation) + call_args = 
evaluations_resource._post.call_args + body = call_args.kwargs["body"][0] + assert body["model_id"] == model_name + assert body["dataset_id"] == benchmark_name + + def test_create_evaluation_cast_to_parameter(self, evaluations_resource, mock_evaluations_response): + """create method specifies correct cast_to parameter.""" + evaluations_resource._post.return_value = mock_evaluations_response + + evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + call_args = evaluations_resource._post.call_args + assert call_args.kwargs["cast_to"] is EvaluationsData + + def test_create_evaluation_timeout_default(self, evaluations_resource, mock_evaluations_response): + """create method uses DEFAULT_TIMEOUT when no timeout specified.""" + evaluations_resource._post.return_value = mock_evaluations_response + + evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + call_args = evaluations_resource._post.call_args + assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT + + def test_create_evaluation_with_none_timeout(self, evaluations_resource, mock_evaluations_response): + """create method accepts None timeout.""" + evaluations_resource._post.return_value = mock_evaluations_response + + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=None) + + call_args = evaluations_resource._post.call_args + assert call_args.kwargs["timeout"] is None + + +class TestEvaluationsErrorHandling: + """Test error handling in Evaluations resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.post_cast = Mock() + return client + + @pytest.fixture + def evaluations_resource(self, mock_client): + """Evaluations resource instance.""" + return Evaluations(mock_client) + + def test_create_evaluation_handles_api_error(self, evaluations_resource): + """create method propagates API errors.""" + from atlas._exceptions import APIStatusError + + 
mock_response = Mock() + mock_response.status_code = 400 + mock_response.headers = {} + + api_error = APIStatusError("Bad Request", response=mock_response, body=None) + evaluations_resource._post.side_effect = api_error + + with pytest.raises(APIStatusError): + evaluations_resource.create(model="invalid-model", benchmark="invalid-benchmark") + + def test_create_evaluation_handles_connection_error(self, evaluations_resource): + """create method propagates connection errors.""" + from atlas._exceptions import APIConnectionError + + mock_request = Mock() + connection_error = APIConnectionError(request=mock_request) + evaluations_resource._post.side_effect = connection_error + + with pytest.raises(APIConnectionError): + evaluations_resource.create(model="gpt-4", benchmark="mmlu") + + def test_create_evaluation_handles_timeout_error(self, evaluations_resource): + """create method propagates timeout errors.""" + from atlas._exceptions import APITimeoutError + + mock_request = Mock() + timeout_error = APITimeoutError(mock_request) + evaluations_resource._post.side_effect = timeout_error + + with pytest.raises(APITimeoutError): + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=1.0) + + +class TestEvaluationsResourceIntegration: + """Integration-style tests for Evaluations resource.""" + + def test_create_evaluation_end_to_end_flow(self): + """Test complete evaluation creation flow.""" + # Mock the full chain: client -> resource -> API call -> response + mock_client = Mock() + mock_client.organization_id = "test-org" + mock_client.project_id = "test-project" + + # Create sample evaluation data + evaluation_data = { + "id": "eval-integration-test", + "status": "submitted", + "status_description": "Evaluation submitted", + "submitted_at": 1640995200, + "finished_at": 0, + "model_id": "integration-model", + "model_name": "Integration Test Model", + "model_key": "integration-model", + "model_company": "TestCorp", + "dataset_id": "integration-dataset", + 
"dataset_name": "Integration Test Dataset", + "average_duration": 0, + "readability_score": 0.0, + "toxicity_score": 0.0, + "ethics_score": 0.0, + "accuracy": 0.0, + } + + evaluation = Evaluation(**evaluation_data) + response = EvaluationsData(data=[evaluation]) + mock_client.post_cast.return_value = response + + # Test the resource + evaluations_resource = Evaluations(mock_client) + result = evaluations_resource.create( + model="integration-model", + benchmark="integration-dataset" + ) + + # Verify the complete flow + assert result is not None + assert result.id == "eval-integration-test" + assert result.model_id == "integration-model" + assert result.dataset_id == "integration-dataset" + assert result.status == "submitted" + + # Verify the API call was made correctly + mock_client.post_cast.assert_called_once() + call_args = mock_client.post_cast.call_args + assert "/organizations/test-org/projects/test-project/evaluations" in call_args[0][0] + assert call_args.kwargs["body"][0]["model_id"] == "integration-model" + assert call_args.kwargs["body"][0]["dataset_id"] == "integration-dataset" \ No newline at end of file diff --git a/tests/resources/test_models_resource.py b/tests/resources/test_models_resource.py new file mode 100644 index 0000000..f1bb39c --- /dev/null +++ b/tests/resources/test_models_resource.py @@ -0,0 +1,449 @@ +from unittest.mock import Mock + +import httpx +import pytest + +from atlas._models import Model, Models as ModelsData, CustomModel +from atlas._constants import DEFAULT_TIMEOUT +from atlas.resources.models.models import Models + + +class TestModels: + """Test Models resource API methods.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + return client + + @pytest.fixture + def models_resource(self, mock_client): + """Models resource instance.""" + return Models(mock_client) + + @pytest.fixture + def 
sample_model_data(self): + """Sample model data for testing.""" + return { + "id": "model-123", + "key": "gpt-4", + "name": "GPT-4", + "company": "OpenAI", + "description": "Large language model", + "released_at": 1679875200, + "parameters": 1.76e12, + "modality": "text", + "context_length": 8192, + "architecture_type": "transformer", + "license": "proprietary", + "open_weights": False, + "region": "us-east-1", + "deprecated": False, + } + + @pytest.fixture + def sample_custom_model_data(self): + """Sample custom model data for testing.""" + return { + "id": "custom-model-456", + "key": "my-model", + "name": "My Custom Model", + "description": "Custom model description", + "max_tokens": 4096, + "api_url": "https://api.example.com/v1/chat", + "disabled": False, + } + + @pytest.fixture + def mock_public_models_response(self, sample_model_data): + """Mock ModelsData response with public models.""" + model = Model(**sample_model_data) + return ModelsData(models=[model]) + + @pytest.fixture + def mock_custom_models_response(self, sample_custom_model_data): + """Mock ModelsData response with custom models.""" + custom_model = CustomModel(**sample_custom_model_data) + return ModelsData(models=[custom_model]) + + def test_models_initialization(self, mock_client): + """Models resource initializes correctly.""" + models = Models(mock_client) + + assert models._client is mock_client + assert models._get is mock_client.get_cast + + def test_get_public_models_success(self, models_resource, mock_public_models_response): + """get method returns public models successfully.""" + models_resource._get.return_value = mock_public_models_response + + result = models_resource.get(type="public") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], Model) + assert result[0].name == "GPT-4" + assert result[0].key == "gpt-4" + assert result[0].company == "OpenAI" + + def test_get_custom_models_success(self, models_resource, 
mock_custom_models_response): + """get method returns custom models successfully.""" + models_resource._get.return_value = mock_custom_models_response + + result = models_resource.get(type="custom") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], CustomModel) + assert result[0].name == "My Custom Model" + assert result[0].key == "my-model" + assert result[0].api_url == "https://api.example.com/v1/chat" + + def test_get_models_request_parameters_public(self, models_resource, mock_public_models_response): + """get method makes correct API request for public models.""" + models_resource._get.return_value = mock_public_models_response + + models_resource.get(type="public") + + models_resource._get.assert_called_once_with( + "/organizations/org-123/projects/proj-456/models", + params={"type": "public"}, + timeout=DEFAULT_TIMEOUT, + cast_to=ModelsData, + ) + + def test_get_models_request_parameters_custom(self, models_resource, mock_custom_models_response): + """get method makes correct API request for custom models.""" + models_resource._get.return_value = mock_custom_models_response + + models_resource.get(type="custom") + + models_resource._get.assert_called_once_with( + "/organizations/org-123/projects/proj-456/models", + params={"type": "custom"}, + timeout=DEFAULT_TIMEOUT, + cast_to=ModelsData, + ) + + def test_get_models_with_custom_timeout(self, models_resource, mock_public_models_response): + """get method accepts custom timeout.""" + models_resource._get.return_value = mock_public_models_response + custom_timeout = 60.0 + + models_resource.get(type="public", timeout=custom_timeout) + + call_args = models_resource._get.call_args + assert call_args.kwargs["timeout"] == custom_timeout + + def test_get_models_with_httpx_timeout(self, models_resource, mock_public_models_response): + """get method accepts httpx.Timeout object.""" + models_resource._get.return_value = mock_public_models_response + custom_timeout = 
httpx.Timeout(60.0) + + models_resource.get(type="public", timeout=custom_timeout) + + call_args = models_resource._get.call_args + assert call_args.kwargs["timeout"] is custom_timeout + + def test_get_models_none_response(self, models_resource): + """get method returns None when response is None.""" + models_resource._get.return_value = None + + result = models_resource.get(type="public") + + assert result is None + + def test_get_models_invalid_response_type(self, models_resource): + """get method handles non-ModelsData response gracefully.""" + models_resource._get.return_value = "invalid-response" + + result = models_resource.get(type="public") + + assert result is None + + def test_get_models_empty_response(self, models_resource): + """get method returns empty list when no models in response.""" + empty_response = ModelsData(models=[]) + models_resource._get.return_value = empty_response + + result = models_resource.get(type="public") + + assert result == [] + assert isinstance(result, list) + + def test_get_models_multiple_items(self, models_resource, sample_model_data): + """get method returns multiple models correctly.""" + model1 = Model(**sample_model_data) + + # Create second model with different data + model2_data = sample_model_data.copy() + model2_data["id"] = "model-456" + model2_data["key"] = "gpt-3.5-turbo" + model2_data["name"] = "GPT-3.5 Turbo" + model2_data["parameters"] = 1.75e11 + model2 = Model(**model2_data) + + response = ModelsData(models=[model1, model2]) + models_resource._get.return_value = response + + result = models_resource.get(type="public") + + assert len(result) == 2 + assert result[0].key == "gpt-4" + assert result[1].key == "gpt-3.5-turbo" + assert result[0].parameters == 1.76e12 + assert result[1].parameters == 1.75e11 + + def test_get_models_url_construction(self, models_resource, mock_public_models_response): + """get method constructs URL correctly with org and project IDs.""" + models_resource._client.organization_id = 
"custom-org" + models_resource._client.project_id = "custom-project" + models_resource._get.return_value = mock_public_models_response + + models_resource.get(type="public") + + expected_url = "/organizations/custom-org/projects/custom-project/models" + call_args = models_resource._get.call_args + assert call_args[0][0] == expected_url + + @pytest.mark.parametrize("model_type", ["public", "custom"]) + def test_get_models_type_parameter(self, models_resource, model_type): + """get method accepts both public and custom types.""" + models_resource._get.return_value = ModelsData(models=[]) + + models_resource.get(type=model_type) + + call_args = models_resource._get.call_args + assert call_args.kwargs["params"]["type"] == model_type + + def test_get_models_cast_to_parameter(self, models_resource, mock_public_models_response): + """get method specifies correct cast_to parameter.""" + models_resource._get.return_value = mock_public_models_response + + models_resource.get(type="public") + + call_args = models_resource._get.call_args + assert call_args.kwargs["cast_to"] is ModelsData + + def test_get_models_timeout_default(self, models_resource, mock_public_models_response): + """get method uses DEFAULT_TIMEOUT when no timeout specified.""" + models_resource._get.return_value = mock_public_models_response + + models_resource.get(type="public") + + call_args = models_resource._get.call_args + assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT + + def test_get_models_with_none_timeout(self, models_resource, mock_public_models_response): + """get method accepts None timeout.""" + models_resource._get.return_value = mock_public_models_response + + models_resource.get(type="public", timeout=None) + + call_args = models_resource._get.call_args + assert call_args.kwargs["timeout"] is None + + def test_get_models_model_attributes(self, models_resource, mock_public_models_response): + """get method preserves all model attributes correctly.""" + models_resource._get.return_value 
= mock_public_models_response + + result = models_resource.get(type="public") + model = result[0] + + assert model.context_length == 8192 + assert model.open_weights is False + assert model.deprecated is False + assert model.region == "us-east-1" + assert model.license == "proprietary" + assert model.architecture_type == "transformer" + assert model.modality == "text" + + def test_get_models_custom_model_attributes(self, models_resource, mock_custom_models_response): + """get method preserves all custom model attributes correctly.""" + models_resource._get.return_value = mock_custom_models_response + + result = models_resource.get(type="custom") + custom_model = result[0] + + assert custom_model.max_tokens == 4096 + assert custom_model.disabled is False + assert custom_model.api_url == "https://api.example.com/v1/chat" + + +class TestModelsErrorHandling: + """Test error handling in Models resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + return client + + @pytest.fixture + def models_resource(self, mock_client): + """Models resource instance.""" + return Models(mock_client) + + def test_get_models_handles_api_error(self, models_resource): + """get method propagates API errors.""" + from atlas._exceptions import APIStatusError + + mock_response = Mock() + mock_response.status_code = 500 + mock_response.headers = {} + + api_error = APIStatusError("Internal Server Error", response=mock_response, body=None) + models_resource._get.side_effect = api_error + + with pytest.raises(APIStatusError): + models_resource.get(type="public") + + def test_get_models_handles_forbidden_error(self, models_resource): + """get method propagates permission errors.""" + from atlas._exceptions import PermissionDeniedError + + mock_response = Mock() + mock_response.status_code = 403 + mock_response.headers = {} + + permission_error = 
PermissionDeniedError("Forbidden", response=mock_response, body=None) + models_resource._get.side_effect = permission_error + + with pytest.raises(PermissionDeniedError): + models_resource.get(type="custom") + + def test_get_models_handles_connection_error(self, models_resource): + """get method propagates connection errors.""" + from atlas._exceptions import APIConnectionError + + mock_request = Mock() + connection_error = APIConnectionError(request=mock_request) + models_resource._get.side_effect = connection_error + + with pytest.raises(APIConnectionError): + models_resource.get(type="public") + + def test_get_models_handles_timeout_error(self, models_resource): + """get method propagates timeout errors.""" + from atlas._exceptions import APITimeoutError + + mock_request = Mock() + timeout_error = APITimeoutError(mock_request) + models_resource._get.side_effect = timeout_error + + with pytest.raises(APITimeoutError): + models_resource.get(type="public", timeout=5.0) + + +class TestModelsTyping: + """Test type handling in Models resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + return client + + @pytest.fixture + def models_resource(self, mock_client): + """Models resource instance.""" + return Models(mock_client) + + def test_get_models_return_type_consistency(self, models_resource): + """get method returns consistent types.""" + # Test that the method returns either a list or None + models_resource._get.return_value = None + result = models_resource.get(type="public") + assert result is None + + # Test that it returns a list when successful + models_resource._get.return_value = ModelsData(models=[]) + result = models_resource.get(type="public") + assert isinstance(result, list) + + def test_get_models_mixed_model_types(self, models_resource): + """get method can handle mixed model types in response.""" + # 
Create mixed response with both Model and CustomModel + public_data = { + "id": "public-123", + "key": "gpt-4", + "name": "GPT-4", + "company": "OpenAI", + "description": "Public model", + "released_at": 1679875200, + "parameters": 1.76e12, + "modality": "text", + "context_length": 8192, + "architecture_type": "transformer", + "license": "proprietary", + "open_weights": False, + "region": "us-east-1", + "deprecated": False, + } + + custom_data = { + "id": "custom-456", + "key": "my-model", + "name": "My Custom Model", + "description": "Custom model", + "max_tokens": 4096, + "api_url": "https://api.example.com/v1/chat", + "disabled": False, + } + + public_model = Model(**public_data) + custom_model = CustomModel(**custom_data) + + response = ModelsData(models=[public_model, custom_model]) + models_resource._get.return_value = response + + result = models_resource.get(type="public") # Type doesn't matter for this test + + assert len(result) == 2 + assert isinstance(result[0], Model) + assert isinstance(result[1], CustomModel) + assert result[0].key == "gpt-4" + assert result[1].key == "my-model" + assert hasattr(result[0], 'parameters') # Model-specific attribute + assert hasattr(result[1], 'max_tokens') # CustomModel-specific attribute + + def test_get_models_large_parameters_handling(self, models_resource): + """get method handles large parameter numbers correctly.""" + large_model_data = { + "id": "large-model", + "key": "claude-3-opus", + "name": "Claude 3 Opus", + "company": "Anthropic", + "description": "Very large language model", + "released_at": 1709251200, + "parameters": 1.3e14, # 130 trillion parameters + "modality": "text", + "context_length": 200000, + "architecture_type": "transformer", + "license": "proprietary", + "open_weights": False, + "region": "us-west-2", + "deprecated": False, + } + + large_model = Model(**large_model_data) + response = ModelsData(models=[large_model]) + models_resource._get.return_value = response + + result = 
models_resource.get(type="public") + + assert len(result) == 1 + assert result[0].parameters == 1.3e14 + assert result[0].context_length == 200000 + assert isinstance(result[0].parameters, float) + assert isinstance(result[0].context_length, int) \ No newline at end of file diff --git a/tests/resources/test_results.py b/tests/resources/test_results.py new file mode 100644 index 0000000..26c610a --- /dev/null +++ b/tests/resources/test_results.py @@ -0,0 +1,447 @@ +from datetime import timedelta +from unittest.mock import Mock + +import httpx +import pytest + +from atlas._models import Result, Results as ResultsData +from atlas._constants import DEFAULT_TIMEOUT +from atlas.resources.results.results import Results + + +class TestResults: + """Test Results resource API methods.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.get_cast = Mock() + return client + + @pytest.fixture + def results_resource(self, mock_client): + """Results resource instance.""" + return Results(mock_client) + + @pytest.fixture + def sample_result_data(self): + """Sample result data for testing.""" + return { + "subset": "mathematics", + "prompt": "What is the derivative of x^2?", + "result": "2x", + "truth": "2x", + "duration": timedelta(seconds=2.5), + "score": 1.0, + "metrics": { + "accuracy": 1.0, + "confidence": 0.95, + "reasoning_quality": 0.9 + } + } + + @pytest.fixture + def mock_results_response(self, sample_result_data): + """Mock ResultsData response.""" + result = Result(**sample_result_data) + return ResultsData(results=[result]) + + def test_results_initialization(self, mock_client): + """Results resource initializes correctly.""" + results = Results(mock_client) + + assert results._client is mock_client + assert results._get is mock_client.get_cast + + def test_get_results_success(self, results_resource, mock_results_response): + """get method returns results successfully.""" + results_resource._get.return_value = 
mock_results_response + + result = results_resource.get(evaluation_id="eval-123") + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], Result) + assert result[0].subset == "mathematics" + assert result[0].prompt == "What is the derivative of x^2?" + assert result[0].result == "2x" + assert result[0].score == 1.0 + + def test_get_results_request_parameters(self, results_resource, mock_results_response): + """get method makes correct API request.""" + results_resource._get.return_value = mock_results_response + + results_resource.get(evaluation_id="eval-456") + + results_resource._get.assert_called_once_with( + "/results", + params={"evaluation_id": "eval-456"}, + timeout=DEFAULT_TIMEOUT, + cast_to=ResultsData, + ) + + def test_get_results_with_custom_timeout(self, results_resource, mock_results_response): + """get method accepts custom timeout.""" + results_resource._get.return_value = mock_results_response + custom_timeout = 120.0 + + results_resource.get(evaluation_id="eval-123", timeout=custom_timeout) + + call_args = results_resource._get.call_args + assert call_args.kwargs["timeout"] == custom_timeout + + def test_get_results_with_httpx_timeout(self, results_resource, mock_results_response): + """get method accepts httpx.Timeout object.""" + results_resource._get.return_value = mock_results_response + custom_timeout = httpx.Timeout(120.0) + + results_resource.get(evaluation_id="eval-123", timeout=custom_timeout) + + call_args = results_resource._get.call_args + assert call_args.kwargs["timeout"] is custom_timeout + + def test_get_results_none_response(self, results_resource): + """get method returns None when response is None.""" + results_resource._get.return_value = None + + result = results_resource.get(evaluation_id="eval-123") + + assert result is None + + def test_get_results_invalid_response_type(self, results_resource): + """get method handles non-ResultsData response gracefully.""" + 
results_resource._get.return_value = "invalid-response" + + result = results_resource.get(evaluation_id="eval-123") + + assert result is None + + def test_get_results_empty_response(self, results_resource): + """get method returns empty list when no results in response.""" + empty_response = ResultsData(results=[]) + results_resource._get.return_value = empty_response + + result = results_resource.get(evaluation_id="eval-123") + + assert result == [] + assert isinstance(result, list) + + def test_get_results_multiple_items(self, results_resource, sample_result_data): + """get method returns multiple results correctly.""" + result1 = Result(**sample_result_data) + + # Create second result with different data + result2_data = sample_result_data.copy() + result2_data["subset"] = "science" + result2_data["prompt"] = "What is photosynthesis?" + result2_data["result"] = "Process of converting light to energy" + result2_data["truth"] = "Process of converting light to energy" + result2_data["score"] = 0.95 + result2_data["duration"] = timedelta(seconds=3.2) + result2 = Result(**result2_data) + + response = ResultsData(results=[result1, result2]) + results_resource._get.return_value = response + + result = results_resource.get(evaluation_id="eval-123") + + assert len(result) == 2 + assert result[0].subset == "mathematics" + assert result[1].subset == "science" + assert result[0].score == 1.0 + assert result[1].score == 0.95 + + def test_get_results_url_construction(self, results_resource, mock_results_response): + """get method uses correct URL endpoint.""" + results_resource._get.return_value = mock_results_response + + results_resource.get(evaluation_id="eval-123") + + call_args = results_resource._get.call_args + assert call_args[0][0] == "/results" + + def test_get_results_evaluation_id_parameter(self, results_resource, mock_results_response): + """get method correctly passes evaluation_id parameter.""" + results_resource._get.return_value = mock_results_response + + 
results_resource.get(evaluation_id="test-eval-789") + + call_args = results_resource._get.call_args + assert call_args.kwargs["params"]["evaluation_id"] == "test-eval-789" + + def test_get_results_cast_to_parameter(self, results_resource, mock_results_response): + """get method specifies correct cast_to parameter.""" + results_resource._get.return_value = mock_results_response + + results_resource.get(evaluation_id="eval-123") + + call_args = results_resource._get.call_args + assert call_args.kwargs["cast_to"] is ResultsData + + def test_get_results_timeout_default(self, results_resource, mock_results_response): + """get method uses DEFAULT_TIMEOUT when no timeout specified.""" + results_resource._get.return_value = mock_results_response + + results_resource.get(evaluation_id="eval-123") + + call_args = results_resource._get.call_args + assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT + + def test_get_results_with_none_timeout(self, results_resource, mock_results_response): + """get method accepts None timeout.""" + results_resource._get.return_value = mock_results_response + + results_resource.get(evaluation_id="eval-123", timeout=None) + + call_args = results_resource._get.call_args + assert call_args.kwargs["timeout"] is None + + def test_get_results_preserves_result_attributes(self, results_resource, mock_results_response): + """get method preserves all result attributes correctly.""" + results_resource._get.return_value = mock_results_response + + result = results_resource.get(evaluation_id="eval-123") + result_item = result[0] + + assert isinstance(result_item.duration, timedelta) + assert result_item.duration.total_seconds() == 2.5 + assert isinstance(result_item.metrics, dict) + assert result_item.metrics["accuracy"] == 1.0 + assert result_item.metrics["confidence"] == 0.95 + assert result_item.metrics["reasoning_quality"] == 0.9 + + @pytest.mark.parametrize("evaluation_id", [ + "eval-123", + "evaluation-456-abc", + "test_eval_789", + 
"long-evaluation-id-with-many-characters-123456789", + ]) + def test_get_results_with_different_evaluation_ids(self, results_resource, mock_results_response, evaluation_id): + """get method works with various evaluation ID formats.""" + results_resource._get.return_value = mock_results_response + + result = results_resource.get(evaluation_id=evaluation_id) + + assert isinstance(result, list) + call_args = results_resource._get.call_args + assert call_args.kwargs["params"]["evaluation_id"] == evaluation_id + + +class TestResultsErrorHandling: + """Test error handling in Results resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.get_cast = Mock() + return client + + @pytest.fixture + def results_resource(self, mock_client): + """Results resource instance.""" + return Results(mock_client) + + def test_get_results_handles_not_found_error(self, results_resource): + """get method propagates not found errors.""" + from atlas._exceptions import NotFoundError + + mock_response = Mock() + mock_response.status_code = 404 + mock_response.headers = {} + + not_found_error = NotFoundError("Evaluation not found", response=mock_response, body=None) + results_resource._get.side_effect = not_found_error + + with pytest.raises(NotFoundError): + results_resource.get(evaluation_id="nonexistent-eval") + + def test_get_results_handles_auth_error(self, results_resource): + """get method propagates authentication errors.""" + from atlas._exceptions import AuthenticationError + + mock_response = Mock() + mock_response.status_code = 401 + mock_response.headers = {} + + auth_error = AuthenticationError("Unauthorized", response=mock_response, body=None) + results_resource._get.side_effect = auth_error + + with pytest.raises(AuthenticationError): + results_resource.get(evaluation_id="eval-123") + + def test_get_results_handles_permission_error(self, results_resource): + """get method propagates permission errors.""" + from 
atlas._exceptions import PermissionDeniedError + + mock_response = Mock() + mock_response.status_code = 403 + mock_response.headers = {} + + permission_error = PermissionDeniedError("Access denied", response=mock_response, body=None) + results_resource._get.side_effect = permission_error + + with pytest.raises(PermissionDeniedError): + results_resource.get(evaluation_id="restricted-eval") + + def test_get_results_handles_server_error(self, results_resource): + """get method propagates server errors.""" + from atlas._exceptions import InternalServerError + + mock_response = Mock() + mock_response.status_code = 500 + mock_response.headers = {} + + server_error = InternalServerError("Internal server error", response=mock_response, body=None) + results_resource._get.side_effect = server_error + + with pytest.raises(InternalServerError): + results_resource.get(evaluation_id="eval-123") + + def test_get_results_handles_connection_error(self, results_resource): + """get method propagates connection errors.""" + from atlas._exceptions import APIConnectionError + + mock_request = Mock() + connection_error = APIConnectionError(request=mock_request) + results_resource._get.side_effect = connection_error + + with pytest.raises(APIConnectionError): + results_resource.get(evaluation_id="eval-123") + + def test_get_results_handles_timeout_error(self, results_resource): + """get method propagates timeout errors.""" + from atlas._exceptions import APITimeoutError + + mock_request = Mock() + timeout_error = APITimeoutError(mock_request) + results_resource._get.side_effect = timeout_error + + with pytest.raises(APITimeoutError): + results_resource.get(evaluation_id="eval-123", timeout=1.0) + + +class TestResultsDataHandling: + """Test data handling specifics in Results resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.get_cast = Mock() + return client + + @pytest.fixture + def results_resource(self, mock_client): + 
"""Results resource instance.""" + return Results(mock_client) + + def test_get_results_handles_complex_metrics(self, results_resource): + """get method handles complex metrics structures.""" + complex_result_data = { + "subset": "reasoning", + "prompt": "Complex reasoning question", + "result": "Complex answer", + "truth": "Expected answer", + "duration": timedelta(seconds=5.75), + "score": 0.87, + "metrics": { + "accuracy": 0.87, + "precision": 0.92, + "recall": 0.83, + "f1_score": 0.875, + "perplexity": 12.34, + "bleu_score": 0.78, + "rouge_1": 0.85, + "rouge_2": 0.72, + "rouge_l": 0.80, + "semantic_similarity": 0.91, + "factual_correctness": 0.95, + "reasoning_steps": 4.0 + } + } + + complex_result = Result(**complex_result_data) + response = ResultsData(results=[complex_result]) + results_resource._get.return_value = response + + result = results_resource.get(evaluation_id="eval-complex") + + assert len(result) == 1 + result_item = result[0] + + assert result_item.score == 0.87 + assert len(result_item.metrics) == 12 + assert result_item.metrics["f1_score"] == 0.875 + assert result_item.metrics["perplexity"] == 12.34 + assert result_item.metrics["reasoning_steps"] == 4.0 + + def test_get_results_handles_different_durations(self, results_resource): + """get method handles various duration formats.""" + durations_to_test = [ + timedelta(seconds=0.1), # Very short + timedelta(seconds=1.5), # Normal + timedelta(seconds=30.0), # Long + timedelta(minutes=2.5), # Very long + timedelta(hours=1), # Extremely long + ] + + results = [] + for i, duration in enumerate(durations_to_test): + result_data = { + "subset": f"test-{i}", + "prompt": f"Test prompt {i}", + "result": f"Test result {i}", + "truth": f"Test truth {i}", + "duration": duration, + "score": 0.8 + i * 0.05, + "metrics": {"accuracy": 0.8 + i * 0.05} + } + results.append(Result(**result_data)) + + response = ResultsData(results=results) + results_resource._get.return_value = response + + result = 
results_resource.get(evaluation_id="eval-durations") + + assert len(result) == 5 + assert result[0].duration == timedelta(seconds=0.1) + assert result[1].duration == timedelta(seconds=1.5) + assert result[2].duration == timedelta(seconds=30.0) + assert result[3].duration == timedelta(minutes=2.5) + assert result[4].duration == timedelta(hours=1) + + def test_get_results_handles_empty_metrics(self, results_resource): + """get method handles results with empty metrics.""" + result_data = { + "subset": "minimal", + "prompt": "Minimal test", + "result": "Minimal result", + "truth": "Minimal truth", + "duration": timedelta(seconds=1.0), + "score": 0.5, + "metrics": {} # Empty metrics + } + + minimal_result = Result(**result_data) + response = ResultsData(results=[minimal_result]) + results_resource._get.return_value = response + + result = results_resource.get(evaluation_id="eval-minimal") + + assert len(result) == 1 + assert result[0].metrics == {} + assert isinstance(result[0].metrics, dict) + + def test_get_results_return_type_consistency(self, results_resource): + """get method returns consistent types.""" + # Test that the method returns either a list or None + results_resource._get.return_value = None + result = results_resource.get(evaluation_id="eval-123") + assert result is None + + # Test that it returns a list when successful + results_resource._get.return_value = ResultsData(results=[]) + result = results_resource.get(evaluation_id="eval-123") + assert isinstance(result, list) \ No newline at end of file diff --git a/tests/test_base_client.py b/tests/test_base_client.py new file mode 100644 index 0000000..ba1a35a --- /dev/null +++ b/tests/test_base_client.py @@ -0,0 +1,231 @@ +from dataclasses import dataclass +from unittest.mock import Mock, patch + +import httpx +import pytest + +from atlas import _exceptions +from atlas._base_client import BaseClient + + +@dataclass +class ResponseModel: + """Test model for response casting.""" + name: str + value: int + 
+ +class TestBaseClient: + """Test BaseClient HTTP functionality.""" + + @pytest.fixture + def client(self): + """Create a BaseClient instance for testing.""" + return BaseClient(base_url="https://api.test.com") + + @pytest.fixture + def mock_response(self): + """Mock httpx Response.""" + mock = Mock(spec=httpx.Response) + mock.status_code = 200 + mock.raise_for_status.return_value = None + mock.json.return_value = {"name": "test", "value": 42} + return mock + + def test_init_sets_base_url(self): + """BaseClient initializes with correct base URL.""" + client = BaseClient(base_url="https://custom.api.com") + + assert str(client.base_url) == "https://custom.api.com" + + def test_init_with_headers(self): + """BaseClient accepts custom headers.""" + headers = {"X-Custom": "value"} + client = BaseClient(base_url="https://api.test.com", headers=headers) + + assert client.headers["X-Custom"] == "value" + + def test_auth_headers_empty_by_default(self, client): + """BaseClient auth_headers returns empty dict by default.""" + assert client.auth_headers == {} + + def test_default_headers_structure(self, client): + """BaseClient default_headers includes required headers.""" + headers = client.default_headers + + assert headers["Accept"] == "application/json" + assert headers["Content-Type"] == "application/json" + assert isinstance(headers, dict) + + def test_default_headers_includes_auth(self, client): + """default_headers merges auth_headers.""" + with patch.object(type(client), 'auth_headers', new_callable=lambda: property(lambda _: {"Authorization": "Bearer token"})): + headers = client.default_headers + + assert headers["Authorization"] == "Bearer token" + assert headers["Accept"] == "application/json" + + @patch('httpx.Client.request') + def test_request_cast_without_cast_to(self, mock_request, client, mock_response): + """_request_cast returns raw response when cast_to is None.""" + mock_request.return_value = mock_response + + result = client._request_cast("GET", 
"/test") + + assert result is mock_response + mock_request.assert_called_once_with( + method="GET", + url="/test", + json=None, + params=None, + headers=client.default_headers + ) + + @patch('httpx.Client.request') + def test_request_cast_with_cast_to(self, mock_request, client, mock_response): + """_request_cast casts response to specified type.""" + mock_request.return_value = mock_response + + result = client._request_cast("GET", "/test", cast_to=ResponseModel) + + assert isinstance(result, ResponseModel) + assert result.name == "test" + assert result.value == 42 + mock_response.json.assert_called_once() + + @patch('httpx.Client.request') + def test_request_cast_combines_headers(self, mock_request, client, mock_response): + """_request_cast merges default and custom headers.""" + mock_request.return_value = mock_response + custom_headers = {"X-Custom": "value"} + + client._request_cast("POST", "/test", headers=custom_headers) + + expected_headers = {**client.default_headers, **custom_headers} + mock_request.assert_called_once_with( + method="POST", + url="/test", + json=None, + params=None, + headers=expected_headers + ) + + @patch('httpx.Client.request') + def test_request_cast_with_body_and_params(self, mock_request, client, mock_response): + """_request_cast sends body and params correctly.""" + mock_request.return_value = mock_response + body = {"key": "value"} + params = {"filter": "active"} + + client._request_cast("POST", "/test", body=body, params=params) + + mock_request.assert_called_once_with( + method="POST", + url="/test", + json=body, + params=params, + headers=client.default_headers + ) + + @patch('httpx.Client.request') + def test_request_cast_handles_http_error(self, mock_request, client): + """_request_cast converts HTTPStatusError to APIStatusError.""" + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 404 + mock_response.text = "Not Found" + mock_response.headers = {} + mock_response.raise_for_status.side_effect = 
httpx.HTTPStatusError("404", request=Mock(), response=mock_response) + mock_request.return_value = mock_response + + with patch.object(client, '_make_status_error_from_response') as mock_make_error: + mock_make_error.side_effect = _exceptions.APIStatusError("Test error", response=mock_response, body=None) + + with pytest.raises(_exceptions.APIStatusError): + client._request_cast("GET", "/test") + + mock_make_error.assert_called_once_with(mock_response) + + @patch('httpx.Client.request') + def test_get_cast_delegates_correctly(self, mock_request, client, mock_response): + """get_cast delegates to _request_cast with GET method.""" + mock_request.return_value = mock_response + params = {"page": 1} + headers = {"X-Test": "value"} + + result = client.get_cast("/test", params=params, headers=headers, cast_to=ResponseModel) + + assert isinstance(result, ResponseModel) + mock_request.assert_called_once_with( + method="GET", + url="/test", + json=None, + params=params, + headers={**client.default_headers, **headers} + ) + + @patch('httpx.Client.request') + def test_post_cast_delegates_correctly(self, mock_request, client, mock_response): + """post_cast delegates to _request_cast with POST method.""" + mock_request.return_value = mock_response + body = {"name": "test"} + headers = {"X-Test": "value"} + + result = client.post_cast("/test", body=body, headers=headers, cast_to=ResponseModel) + + assert isinstance(result, ResponseModel) + mock_request.assert_called_once_with( + method="POST", + url="/test", + json=body, + params=None, + headers={**client.default_headers, **headers} + ) + + def test_make_status_error_from_response_with_json(self, client): + """_make_status_error_from_response parses JSON error body.""" + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 400 + mock_response.text = '{"error": "Bad Request", "code": 400}' + + with patch.object(client, '_make_status_error') as mock_make_error: + 
client._make_status_error_from_response(mock_response) + + mock_make_error.assert_called_once() + args, kwargs = mock_make_error.call_args + assert "Error code: 400" in args[0] + assert kwargs["body"] == {"error": "Bad Request", "code": 400} + assert kwargs["response"] is mock_response + + def test_make_status_error_from_response_with_text(self, client): + """_make_status_error_from_response handles plain text errors.""" + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + with patch.object(client, '_make_status_error') as mock_make_error: + client._make_status_error_from_response(mock_response) + + mock_make_error.assert_called_once() + args, kwargs = mock_make_error.call_args + assert args[0] == "Internal Server Error" + assert kwargs["body"] == "Internal Server Error" + + def test_make_status_error_from_response_empty_text(self, client): + """_make_status_error_from_response handles empty response text.""" + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 503 + mock_response.text = "" + + with patch.object(client, '_make_status_error') as mock_make_error: + client._make_status_error_from_response(mock_response) + + mock_make_error.assert_called_once() + args, _ = mock_make_error.call_args + assert args[0] == "Error code: 503" + + def test_make_status_error_not_implemented(self, client): + """_make_status_error raises NotImplementedError.""" + mock_response = Mock(spec=httpx.Response) + + with pytest.raises(NotImplementedError): + client._make_status_error("test", body=None, response=mock_response) \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..0384891 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,85 @@ +import pytest + +from atlas import Atlas +from atlas._exceptions import AtlasError + + +class TestAtlasClientInitialization: + """Test Atlas client initialization and configuration.""" 
+ + def test_init_with_explicit_params(self): + """Client initializes correctly with explicit parameters.""" + client = Atlas( + api_key="explicit-key", + organization_id="explicit-org", + project_id="explicit-project" + ) + + assert client.api_key == "explicit-key" + assert client.organization_id == "explicit-org" + assert client.project_id == "explicit-project" + + def test_init_from_environment(self, mock_env_vars): + """Client initializes from environment variables.""" + _ = mock_env_vars # Fixture used for side effects + client = Atlas() + + assert client.api_key == "test-api-key" + assert client.organization_id == "test-org-id" + assert client.project_id == "test-project-id" + + def test_explicit_params_override_env(self, mock_env_vars): + """Explicit parameters override environment variables.""" + _ = mock_env_vars # Fixture used for side effects + client = Atlas( + api_key="override-key", + organization_id="override-org" + ) + + assert client.api_key == "override-key" + assert client.organization_id == "override-org" + assert client.project_id == "test-project-id" + + def test_missing_api_key_raises_error(self, env_vars): + """Missing API key raises AtlasError.""" + _ = env_vars # Fixture used for side effects + with pytest.raises(AtlasError, match="api_key client option must be set"): + Atlas() + + def test_none_values_fallback_to_env(self, mock_env_vars): + """None values explicitly passed fallback to environment.""" + _ = mock_env_vars # Fixture used for side effects + client = Atlas( + api_key=None, + organization_id=None, + project_id=None + ) + + assert client.api_key == "test-api-key" + assert client.organization_id == "test-org-id" + assert client.project_id == "test-project-id" + + def test_optional_params_can_be_none(self): + """Organization and project IDs can be None.""" + client = Atlas(api_key="test-key") + + assert client.api_key == "test-key" + assert client.organization_id is None + assert client.project_id is None + + 
@pytest.mark.parametrize("base_url", [ + "https://custom.api.com", + "https://staging.layerlens.ai/api/v1" + ]) + def test_custom_base_url(self, base_url): + """Client accepts custom base URL.""" + client = Atlas(api_key="test-key", base_url=base_url) + + assert str(client.base_url).rstrip('/') == base_url.rstrip('/') + + def test_custom_timeout(self): + """Client accepts custom timeout.""" + import httpx + client = Atlas(api_key="test-key", timeout=30.0) + + assert isinstance(client.timeout, httpx.Timeout) \ No newline at end of file diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..dea1d9a --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,307 @@ +from http import HTTPStatus +from unittest.mock import Mock + +import httpx +import pytest + +from atlas._exceptions import ( + APIError, + AtlasError, + ConflictError, + NotFoundError, + APIStatusError, + RateLimitError, + APITimeoutError, + BadRequestError, + APIConnectionError, + AuthenticationError, + InternalServerError, + PermissionDeniedError, + UnprocessableEntityError, + APIResponseValidationError, +) + + +class TestExceptionHierarchy: + """Test exception inheritance and basic functionality.""" + + def test_atlas_error_is_base_exception(self): + """AtlasError inherits from Exception.""" + error = AtlasError("test message") + + assert isinstance(error, Exception) + assert str(error) == "test message" + + def test_api_error_inherits_from_atlas_error(self): + """APIError inherits from AtlasError.""" + mock_request = Mock(spec=httpx.Request) + error = APIError("api error", mock_request, body=None) + + assert isinstance(error, AtlasError) + assert isinstance(error, Exception) + + def test_api_status_error_inherits_from_api_error(self): + """APIStatusError inherits from APIError.""" + mock_response = Mock(spec=httpx.Response) + mock_response.request = Mock(spec=httpx.Request) + mock_response.status_code = 400 + mock_response.headers = {} + + error = 
APIStatusError("status error", response=mock_response, body=None) + + assert isinstance(error, APIError) + assert isinstance(error, AtlasError) + + @pytest.mark.parametrize("exception_class", [ + BadRequestError, + AuthenticationError, + PermissionDeniedError, + NotFoundError, + ConflictError, + UnprocessableEntityError, + RateLimitError, + InternalServerError, + ]) + def test_status_exceptions_inherit_from_api_status_error(self, exception_class): + """All status-specific exceptions inherit from APIStatusError.""" + mock_response = Mock(spec=httpx.Response) + mock_response.request = Mock(spec=httpx.Request) + mock_response.status_code = 400 + mock_response.headers = {} + + error = exception_class("test error", response=mock_response, body=None) + + assert isinstance(error, APIStatusError) + assert isinstance(error, APIError) + assert isinstance(error, AtlasError) + + +class TestAPIError: + """Test APIError functionality.""" + + @pytest.fixture + def mock_request(self): + """Mock httpx.Request.""" + return Mock(spec=httpx.Request) + + def test_api_error_stores_message_and_request(self, mock_request): + """APIError stores message, request, and body.""" + body = {"error": "test"} + error = APIError("test message", mock_request, body=body) + + assert error.message == "test message" + assert error.request is mock_request + assert error.body == body + assert str(error) == "test message" + + def test_api_error_with_none_body(self, mock_request): + """APIError handles None body.""" + error = APIError("test message", mock_request, body=None) + + assert error.body is None + assert error.message == "test message" + + def test_api_error_with_json_body(self, mock_request): + """APIError stores JSON body correctly.""" + body = {"error": "validation failed", "code": 422} + error = APIError("validation error", mock_request, body=body) + + assert error.body == body + assert isinstance(error.body, dict) + assert error.body["error"] == "validation failed" + + def 
test_api_error_with_string_body(self, mock_request): + """APIError stores string body correctly.""" + body = "Plain text error message" + error = APIError("server error", mock_request, body=body) + + assert error.body == body + + +class TestAPIResponseValidationError: + """Test APIResponseValidationError functionality.""" + + @pytest.fixture + def mock_response(self): + """Mock httpx.Response.""" + mock = Mock(spec=httpx.Response) + mock.request = Mock(spec=httpx.Request) + mock.status_code = 200 + return mock + + def test_validation_error_with_default_message(self, mock_response): + """APIResponseValidationError uses default message when none provided.""" + error = APIResponseValidationError(mock_response, body=None) + + assert error.message == "Data returned by API invalid for expected schema." + assert error.response is mock_response + assert error.status_code == 200 + + def test_validation_error_with_custom_message(self, mock_response): + """APIResponseValidationError uses custom message when provided.""" + custom_message = "Custom validation error" + error = APIResponseValidationError(mock_response, body=None, message=custom_message) + + assert error.message == custom_message + assert str(error) == custom_message + + def test_validation_error_stores_response_data(self, mock_response): + """APIResponseValidationError stores response and body.""" + body = {"invalid": "data"} + error = APIResponseValidationError(mock_response, body=body) + + assert error.response is mock_response + assert error.body == body + assert error.request is mock_response.request + + +class TestAPIStatusError: + """Test APIStatusError functionality.""" + + @pytest.fixture + def mock_response(self): + """Mock httpx.Response with headers.""" + mock = Mock(spec=httpx.Response) + mock.request = Mock(spec=httpx.Request) + mock.status_code = 404 + mock.headers = {"x-request-id": "req-123"} + return mock + + def test_status_error_stores_response_data(self, mock_response): + """APIStatusError 
stores response, status code, and request ID.""" + error = APIStatusError("not found", response=mock_response, body=None) + + assert error.response is mock_response + assert error.status_code == 404 + assert error.request_id == "req-123" + assert error.request is mock_response.request + + def test_status_error_without_request_id(self, mock_response): + """APIStatusError handles missing request ID header.""" + mock_response.headers = {} + error = APIStatusError("error", response=mock_response, body=None) + + assert error.request_id is None + + def test_status_error_with_body(self, mock_response): + """APIStatusError stores error body.""" + body = {"error": "Resource not found", "code": "NOT_FOUND"} + error = APIStatusError("not found", response=mock_response, body=body) + + assert error.body == body + + +class TestConnectionErrors: + """Test connection-related errors.""" + + @pytest.fixture + def mock_request(self): + """Mock httpx.Request.""" + return Mock(spec=httpx.Request) + + def test_api_connection_error_default_message(self, mock_request): + """APIConnectionError uses default message.""" + error = APIConnectionError(request=mock_request) + + assert error.message == "Connection error." + assert error.request is mock_request + assert error.body is None + + def test_api_connection_error_custom_message(self, mock_request): + """APIConnectionError accepts custom message.""" + custom_message = "Failed to connect to server" + error = APIConnectionError(message=custom_message, request=mock_request) + + assert error.message == custom_message + + def test_api_timeout_error_inherits_from_connection_error(self, mock_request): + """APITimeoutError inherits from APIConnectionError.""" + error = APITimeoutError(mock_request) + + assert isinstance(error, APIConnectionError) + assert isinstance(error, APIError) + assert error.message == "Request timed out." 
+ assert error.request is mock_request + + +class TestStatusCodeExceptions: + """Test HTTP status code specific exceptions.""" + + @pytest.fixture + def mock_response_factory(self): + """Factory for creating mock responses with different status codes.""" + def _create_response(status_code: int) -> Mock: + mock = Mock(spec=httpx.Response) + mock.request = Mock(spec=httpx.Request) + mock.status_code = status_code + mock.headers = {} + return mock + return _create_response + + @pytest.mark.parametrize("exception_class,expected_status", [ + (BadRequestError, HTTPStatus.BAD_REQUEST), + (AuthenticationError, HTTPStatus.UNAUTHORIZED), + (PermissionDeniedError, HTTPStatus.FORBIDDEN), + (NotFoundError, HTTPStatus.NOT_FOUND), + (ConflictError, HTTPStatus.CONFLICT), + (UnprocessableEntityError, HTTPStatus.UNPROCESSABLE_ENTITY), + (RateLimitError, HTTPStatus.TOO_MANY_REQUESTS), + ]) + def test_status_exception_has_correct_status_code(self, exception_class, expected_status, mock_response_factory): + """Status-specific exceptions have correct status codes.""" + mock_response = mock_response_factory(expected_status.value) + error = exception_class("test error", response=mock_response, body=None) + + assert error.status_code == expected_status.value + assert hasattr(error.__class__, 'status_code') + assert error.__class__.status_code == expected_status + + def test_bad_request_error_properties(self, mock_response_factory): + """BadRequestError has correct properties.""" + mock_response = mock_response_factory(400) + body = {"error": "Invalid request", "field": "name"} + error = BadRequestError("bad request", response=mock_response, body=body) + + assert error.status_code == 400 + assert error.body == body + assert isinstance(error, APIStatusError) + + def test_authentication_error_properties(self, mock_response_factory): + """AuthenticationError has correct properties.""" + mock_response = mock_response_factory(401) + error = AuthenticationError("unauthorized", 
response=mock_response, body=None) + + assert error.status_code == 401 + assert error.__class__.status_code == HTTPStatus.UNAUTHORIZED + + def test_internal_server_error_no_fixed_status(self, mock_response_factory): + """InternalServerError doesn't have a fixed status code.""" + mock_response = mock_response_factory(500) + error = InternalServerError("server error", response=mock_response, body=None) + + assert error.status_code == 500 + assert not hasattr(error.__class__, 'status_code') or error.__class__.status_code is None + + +class TestErrorMessages: + """Test error message handling and formatting.""" + + def test_exception_str_representation(self): + """Exception string representation shows message.""" + mock_request = Mock(spec=httpx.Request) + error = APIError("Test error message", mock_request, body=None) + + assert str(error) == "Test error message" + + def test_exception_with_complex_body(self): + """Exception handles complex body structures.""" + mock_request = Mock(spec=httpx.Request) + body = { + "error": {"code": "VALIDATION_ERROR", "details": ["Field 'name' is required"]}, + "request_id": "req-456" + } + error = APIError("Validation failed", mock_request, body=body) + + assert isinstance(error.body, dict) + assert error.body["error"]["code"] == "VALIDATION_ERROR" + assert error.body["request_id"] == "req-456" \ No newline at end of file diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..8daa315 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,501 @@ +from datetime import timedelta +from unittest.mock import Mock, patch + +import httpx +import pytest + +from atlas import Atlas +from atlas._models import ( + Model, + Models as ModelsData, + Result, + Results as ResultsData, + Benchmark, + Benchmarks as BenchmarksData, + Evaluation, + Evaluations as EvaluationsData, +) + + +class TestAtlasIntegration: + """Integration tests for full Atlas API workflows.""" + + @pytest.fixture + def 
atlas_client(self): + """Create Atlas client with mocked dependencies.""" + return Atlas( + api_key="test-api-key", + organization_id="test-org", + project_id="test-project" + ) + + @pytest.fixture + def sample_model_data(self): + """Sample model data for testing.""" + return { + "id": "model-gpt4", + "key": "gpt-4", + "name": "GPT-4", + "company": "OpenAI", + "description": "Large language model", + "released_at": 1679875200, + "parameters": 1.76e12, + "modality": "text", + "context_length": 8192, + "architecture_type": "transformer", + "license": "proprietary", + "open_weights": False, + "region": "us-east-1", + "deprecated": False, + } + + @pytest.fixture + def sample_benchmark_data(self): + """Sample benchmark data for testing.""" + return { + "id": "benchmark-mmlu", + "key": "mmlu", + "name": "MMLU", + "full_description": "Massive Multitask Language Understanding", + "language": "english", + "categories": ["reasoning", "knowledge"], + "subsets": ["math", "science", "history"], + "prompt_count": 15908, + "deprecated": False, + } + + @pytest.fixture + def sample_evaluation_data(self): + """Sample evaluation data for testing.""" + return { + "id": "eval-12345", + "status": "completed", + "status_description": "Evaluation completed successfully", + "submitted_at": 1640995200, + "finished_at": 1640995800, + "model_id": "model-gpt4", + "model_name": "GPT-4", + "model_key": "gpt-4", + "model_company": "OpenAI", + "dataset_id": "benchmark-mmlu", + "dataset_name": "MMLU", + "average_duration": 2500, + "readability_score": 0.85, + "toxicity_score": 0.02, + "ethics_score": 0.92, + "accuracy": 0.89, + } + + @pytest.fixture + def sample_result_data(self): + """Sample result data for testing.""" + return { + "subset": "mathematics", + "prompt": "What is the derivative of x^2?", + "result": "2x", + "truth": "2x", + "duration": timedelta(seconds=2.5), + "score": 1.0, + "metrics": { + "accuracy": 1.0, + "confidence": 0.95 + } + } + + +class TestCompleteEvaluationWorkflow: + 
"""Test complete evaluation workflow from start to finish.""" + + @pytest.fixture + def atlas_client(self): + """Atlas client for workflow testing.""" + return Atlas( + api_key="workflow-test-key", + organization_id="workflow-org", + project_id="workflow-project" + ) + + def test_complete_evaluation_workflow(self, atlas_client): + """Test complete workflow: get models/benchmarks -> create evaluation -> get results.""" + + # Mock data + model_data = { + "id": "model-123", "key": "gpt-4", "name": "GPT-4", "company": "OpenAI", + "description": "LLM", "released_at": 1679875200, "parameters": 1.76e12, + "modality": "text", "context_length": 8192, "architecture_type": "transformer", + "license": "proprietary", "open_weights": False, "region": "us-east-1", "deprecated": False, + } + + benchmark_data = { + "id": "bench-456", "key": "mmlu", "name": "MMLU", + "full_description": "MMLU benchmark", "language": "english", + "categories": ["reasoning"], "subsets": ["math"], "prompt_count": 1000, "deprecated": False, + } + + evaluation_data = { + "id": "eval-789", "status": "completed", "status_description": "Done", + "submitted_at": 1640995200, "finished_at": 1640995800, + "model_id": "model-123", "model_name": "GPT-4", "model_key": "gpt-4", "model_company": "OpenAI", + "dataset_id": "bench-456", "dataset_name": "MMLU", "average_duration": 2500, + "readability_score": 0.85, "toxicity_score": 0.02, "ethics_score": 0.92, "accuracy": 0.89, + } + + result_data = { + "subset": "math", "prompt": "2+2=?", "result": "4", "truth": "4", + "duration": timedelta(seconds=1.5), "score": 1.0, "metrics": {"accuracy": 1.0} + } + + # Create model objects + model = Model(**model_data) + benchmark = Benchmark(**benchmark_data) + evaluation = Evaluation(**evaluation_data) + result = Result(**result_data) + + # Mock responses + models_response = ModelsData(models=[model]) + benchmarks_response = BenchmarksData(datasets=[benchmark]) + evaluations_response = EvaluationsData(data=[evaluation]) + 
results_response = ResultsData(results=[result]) + + with patch.object(atlas_client, 'get_cast') as mock_get, \ + patch.object(atlas_client, 'post_cast') as mock_post: + + # Configure mocks for the workflow + mock_get.return_value = results_response # Get results + mock_post.return_value = evaluations_response # Create evaluation + + # Step 1: Create evaluation directly (Atlas client doesn't expose models/benchmarks resources) + created_evaluation = atlas_client.evaluations.create( + model="gpt-4", + benchmark="mmlu" + ) + assert created_evaluation.id == "eval-789" + assert created_evaluation.status == "completed" + + # Step 2: Get evaluation results + results = atlas_client.results.get(evaluation_id=created_evaluation.id) + assert len(results) == 1 + assert results[0].score == 1.0 + assert results[0].subset == "math" + + # Verify all API calls were made correctly + assert mock_get.call_count == 1 # Only results call + assert mock_post.call_count == 1 + + # Verify specific API calls + get_calls = mock_get.call_args_list + assert "/results" in get_calls[0][0][0] + + post_call = mock_post.call_args_list[0] + assert "/evaluations" in post_call[0][0] + + def test_workflow_with_error_handling(self, atlas_client): + """Test workflow handles errors gracefully.""" + from atlas._exceptions import NotFoundError + + mock_response = Mock() + mock_response.status_code = 404 + mock_response.headers = {} + + with patch.object(atlas_client, 'get_cast') as mock_get: + # Mock API error when getting results + api_error = NotFoundError("Results not found", response=mock_response, body=None) + mock_get.side_effect = api_error + + # Verify error is propagated + with pytest.raises(NotFoundError): + atlas_client.results.get(evaluation_id="test-eval") + + def test_workflow_with_custom_timeouts(self, atlas_client): + """Test workflow respects custom timeout settings.""" + result_data = { + "subset": "test", "prompt": "test", "result": "test", "truth": "test", + "duration": 
timedelta(seconds=1.0), "score": 1.0, "metrics": {"accuracy": 1.0} + } + + results_response = ResultsData(results=[Result(**result_data)]) + + with patch.object(atlas_client, 'get_cast') as mock_get: + mock_get.return_value = results_response + + # Test with custom timeout + custom_timeout = httpx.Timeout(30.0) + results = atlas_client.results.get(evaluation_id="test-eval", timeout=custom_timeout) + + assert len(results) == 1 + + # Verify timeout was passed correctly + call_args = mock_get.call_args + assert call_args.kwargs["timeout"] is custom_timeout + + +class TestResourceInteraction: + """Test interactions between different resources.""" + + @pytest.fixture + def atlas_client(self): + """Atlas client for resource interaction testing.""" + return Atlas( + api_key="interaction-test-key", + organization_id="interaction-org", + project_id="interaction-project" + ) + + def test_evaluation_creation_with_model_and_benchmark_objects(self, atlas_client): + """Test creating evaluation using model and benchmark objects.""" + + # Create model and benchmark objects + model_data = { + "id": "model-abc", "key": "claude-3", "name": "Claude 3", "company": "Anthropic", + "description": "Claude 3", "released_at": 1709251200, "parameters": 5e11, + "modality": "text", "context_length": 100000, "architecture_type": "transformer", + "license": "proprietary", "open_weights": False, "region": "us-west-2", "deprecated": False, + } + + benchmark_data = { + "id": "bench-xyz", "key": "hellaswag", "name": "HellaSwag", + "full_description": "HellaSwag benchmark", "language": "english", + "categories": ["reasoning"], "subsets": ["commonsense"], "prompt_count": 10042, "deprecated": False, + } + + evaluation_data = { + "id": "eval-interaction", "status": "submitted", "status_description": "Submitted", + "submitted_at": 1640995200, "finished_at": 0, + "model_id": "model-abc", "model_name": "Claude 3", "model_key": "claude-3", "model_company": "Anthropic", + "dataset_id": "bench-xyz", 
"dataset_name": "HellaSwag", "average_duration": 0, + "readability_score": 0.0, "toxicity_score": 0.0, "ethics_score": 0.0, "accuracy": 0.0, + } + + model = Model(**model_data) + benchmark = Benchmark(**benchmark_data) + evaluation = Evaluation(**evaluation_data) + + evaluations_response = EvaluationsData(data=[evaluation]) + + with patch.object(atlas_client, 'post_cast') as mock_post: + mock_post.return_value = evaluations_response + + # Create evaluation using model and benchmark keys + created_evaluation = atlas_client.evaluations.create( + model=model.key, + benchmark=benchmark.key + ) + + assert created_evaluation.id == "eval-interaction" + assert created_evaluation.model_key == model.key + assert created_evaluation.dataset_id == benchmark.id + + # Verify API call + call_args = mock_post.call_args + body = call_args.kwargs["body"][0] + assert body["model_id"] == model.key + assert body["dataset_id"] == benchmark.key + + def test_results_analysis_workflow(self, atlas_client): + """Test analyzing results from multiple evaluations.""" + + # Create multiple result objects + results_data = [ + { + "subset": "math", "prompt": "2+2=?", "result": "4", "truth": "4", + "duration": timedelta(seconds=1.0), "score": 1.0, "metrics": {"accuracy": 1.0} + }, + { + "subset": "math", "prompt": "3*3=?", "result": "9", "truth": "9", + "duration": timedelta(seconds=1.2), "score": 1.0, "metrics": {"accuracy": 1.0} + }, + { + "subset": "reading", "prompt": "What is the main idea?", "result": "Education", "truth": "Learning", + "duration": timedelta(seconds=2.8), "score": 0.7, "metrics": {"accuracy": 0.7} + }, + ] + + results = [Result(**data) for data in results_data] + results_response = ResultsData(results=results) + + with patch.object(atlas_client, 'get_cast') as mock_get: + mock_get.return_value = results_response + + # Get results + evaluation_results = atlas_client.results.get(evaluation_id="test-eval") + + # Analyze results + math_results = [r for r in evaluation_results if 
r.subset == "math"] + reading_results = [r for r in evaluation_results if r.subset == "reading"] + + assert len(math_results) == 2 + assert len(reading_results) == 1 + + # Calculate average scores + math_avg = sum(r.score for r in math_results) / len(math_results) + reading_avg = sum(r.score for r in reading_results) / len(reading_results) + + assert math_avg == 1.0 + assert reading_avg == 0.7 + + # Calculate average duration + avg_duration = sum((r.duration.total_seconds() for r in evaluation_results), 0.0) / len(evaluation_results) + expected_avg = (1.0 + 1.2 + 2.8) / 3 + assert abs(avg_duration - expected_avg) < 0.01 + + +class TestAtlasClientProperties: + """Test Atlas client resource properties and access.""" + + def test_client_has_all_resource_properties(self): + """Atlas client exposes all resource properties.""" + client = Atlas( + api_key="property-test-key", + organization_id="property-org", + project_id="property-project" + ) + + # Verify available resource properties exist + assert hasattr(client, 'evaluations') + assert hasattr(client, 'results') + + # Verify they are the correct types + from atlas.resources.results import Results + from atlas.resources.evaluations import Evaluations + + assert isinstance(client.evaluations, Evaluations) + assert isinstance(client.results, Results) + + def test_resource_properties_share_same_client(self): + """All resource properties share the same client instance.""" + client = Atlas( + api_key="shared-client-test", + organization_id="shared-org", + project_id="shared-project" + ) + + # Verify all resources use the same client + assert client.evaluations._client is client + assert client.results._client is client + + def test_client_configuration_propagates_to_resources(self): + """Client configuration (org_id, project_id) propagates to resources.""" + org_id = "config-test-org" + project_id = "config-test-project" + + client = Atlas( + api_key="config-test-key", + organization_id=org_id, + project_id=project_id + ) 
class TestConcurrentOperations:
    """Independence of separately constructed Atlas clients.

    Nothing in this class runs in parallel: each test builds two clients
    side by side and checks that their configuration and their (mocked)
    transport calls do not leak into one another.
    """

    def test_multiple_atlas_clients_independent(self):
        """Two clients keep their own settings and their own resource bindings."""
        first = Atlas(api_key="client-1-key", organization_id="org-1", project_id="project-1")
        second = Atlas(api_key="client-2-key", organization_id="org-2", project_id="project-2")

        # Configuration is held per instance.
        assert first.api_key != second.api_key
        assert first.organization_id != second.organization_id
        assert first.project_id != second.project_id

        # Each resource points back at the client that owns it.
        assert first.evaluations._client is not second.evaluations._client
        assert first.results._client is not second.results._client

    def test_resource_operations_isolated(self):
        """results.get on one client only reaches that client's get_cast."""
        first = Atlas(api_key="iso-test-1", organization_id="org-1", project_id="proj-1")
        second = Atlas(api_key="iso-test-2", organization_id="org-2", project_id="proj-2")

        # Both mocked transports hand back the same single-result payload.
        payload = ResultsData(
            results=[
                Result(
                    subset="test",
                    prompt="test",
                    result="test",
                    truth="test",
                    duration=timedelta(seconds=1.0),
                    score=1.0,
                    metrics={"accuracy": 1.0},
                )
            ]
        )

        with patch.object(first, "get_cast", return_value=payload) as first_get:
            with patch.object(second, "get_cast", return_value=payload) as second_get:
                fetched_first = first.results.get(evaluation_id="eval-1")
                fetched_second = second.results.get(evaluation_id="eval-2")

                # Both calls produced the one mocked result.
                for fetched in (fetched_first, fetched_second):
                    assert fetched is not None
                    assert len(fetched) == 1

                # Each client's transport was hit exactly once ...
                first_get.assert_called_once()
                second_get.assert_called_once()

                # ... and with its own evaluation id.
                assert first_get.call_args.kwargs["params"]["evaluation_id"] == "eval-1"
                assert second_get.call_args.kwargs["params"]["evaluation_id"] == "eval-2"
b/tests/test_models.py @@ -0,0 +1,552 @@ +from datetime import timedelta + +import pytest +from pydantic import ValidationError + +from atlas._models import ( + Model, + Models, + Result, + Results, + Benchmark, + Benchmarks, + Evaluation, + CustomModel, + Evaluations, + CustomBenchmark, +) + + +class TestEvaluation: + """Test Evaluation model validation and serialization.""" + + @pytest.fixture + def valid_evaluation_data(self): + """Valid evaluation data for testing.""" + return { + "id": "eval-123", + "status": "completed", + "status_description": "Evaluation completed successfully", + "submitted_at": 1640995200, + "finished_at": 1640995800, + "model_id": "model-456", + "model_name": "GPT-4", + "model_key": "gpt-4", + "model_company": "OpenAI", + "dataset_id": "dataset-789", + "dataset_name": "MMLU", + "average_duration": 2500, + "readability_score": 0.85, + "toxicity_score": 0.02, + "ethics_score": 0.92, + "accuracy": 0.89, + } + + def test_evaluation_creation_with_valid_data(self, valid_evaluation_data): + """Evaluation model creates successfully with valid data.""" + evaluation = Evaluation(**valid_evaluation_data) + + assert evaluation.id == "eval-123" + assert evaluation.status == "completed" + assert evaluation.model_name == "GPT-4" + assert evaluation.accuracy == 0.89 + assert evaluation.readability_score == 0.85 + + def test_evaluation_field_types(self, valid_evaluation_data): + """Evaluation model enforces correct field types.""" + evaluation = Evaluation(**valid_evaluation_data) + + assert isinstance(evaluation.id, str) + assert isinstance(evaluation.submitted_at, int) + assert isinstance(evaluation.readability_score, float) + assert isinstance(evaluation.accuracy, float) + + def test_evaluation_validation_errors(self, valid_evaluation_data): + """Evaluation model validates field types and requirements.""" + # Test string field with wrong type + invalid_data = valid_evaluation_data.copy() + invalid_data["id"] = 123 + with 
class TestEvaluations:
    """Test the Evaluations collection model (a ``data`` list of Evaluation)."""

    @pytest.fixture
    def evaluation_data(self):
        """Return one raw evaluation payload (plain dict) with every field set."""
        return {
            "id": "eval-1",
            "status": "completed",
            "status_description": "Done",
            "submitted_at": 1640995200,
            "finished_at": 1640995800,
            "model_id": "model-1",
            "model_name": "Test Model",
            "model_key": "test-model",
            "model_company": "TestCorp",
            "dataset_id": "dataset-1",
            "dataset_name": "Test Dataset",
            "average_duration": 1000,
            "readability_score": 0.8,
            "toxicity_score": 0.1,
            "ethics_score": 0.9,
            "accuracy": 0.85,
        }

    def test_evaluations_with_list_of_evaluations(self, evaluation_data):
        """A list of raw dicts under ``data`` is validated into Evaluation instances."""
        evaluations_data = {"data": [evaluation_data, evaluation_data]}
        evaluations = Evaluations(**evaluations_data)

        assert len(evaluations.data) == 2
        # The loop variable is deliberately not called ``eval``: that name
        # would shadow the builtin inside the generator expression.
        assert all(isinstance(item, Evaluation) for item in evaluations.data)
        assert evaluations.data[0].id == "eval-1"

    def test_evaluations_empty_list(self):
        """An empty ``data`` list is accepted and stored as a list."""
        evaluations = Evaluations(data=[])

        assert evaluations.data == []
        assert isinstance(evaluations.data, list)

    def test_evaluations_invalid_data_structure(self):
        """A non-list ``data`` value is rejected with ValidationError."""
        with pytest.raises(ValidationError):
            Evaluations(data="not-a-list")  # type: ignore[arg-type]
class TestModel:
    """Checks on the Model schema: construction, field types, rejection of bad input."""

    @pytest.fixture
    def valid_model_data(self):
        """Return a complete, well-typed payload for Model."""
        return dict(
            id="model-123",
            key="gpt-4",
            name="GPT-4",
            company="OpenAI",
            description="Large language model",
            released_at=1679875200,
            parameters=1.76e12,
            modality="text",
            context_length=8192,
            architecture_type="transformer",
            license="proprietary",
            open_weights=False,
            region="us-east-1",
            deprecated=False,
        )

    def test_model_creation(self, valid_model_data):
        """A valid payload builds a Model whose attributes mirror the input."""
        built = Model(**valid_model_data)

        assert built.id == "model-123"
        assert built.name == "GPT-4"
        assert built.parameters == 1.76e12
        assert built.open_weights is False
        assert built.deprecated is False

    def test_model_boolean_fields(self, valid_model_data):
        """The two flag fields come back as real bools."""
        built = Model(**valid_model_data)

        assert isinstance(built.open_weights, bool)
        assert isinstance(built.deprecated, bool)
        assert built.open_weights is False

    def test_model_numeric_fields(self, valid_model_data):
        """Numeric fields come back as float / int as supplied."""
        built = Model(**valid_model_data)

        assert isinstance(built.parameters, float)
        assert isinstance(built.context_length, int)
        assert isinstance(built.released_at, int)

    def test_model_field_validation(self, valid_model_data):
        """Non-numeric strings in numeric fields raise ValidationError."""
        # Checked in order: the float field first, then the int field.
        bad_values = (
            ("parameters", "not-a-number"),
            ("context_length", "not-an-int"),
        )
        for field_name, bad_value in bad_values:
            with pytest.raises(ValidationError):
                Model(**{**valid_model_data, field_name: bad_value})
class TestBenchmark:
    """Checks on the Benchmark schema: construction and its list-valued fields."""

    @pytest.fixture
    def valid_benchmark_data(self):
        """Return a complete payload for Benchmark."""
        return dict(
            id="bench-123",
            key="mmlu",
            name="MMLU",
            full_description="Massive Multitask Language Understanding",
            language="english",
            categories=["reasoning", "knowledge"],
            subsets=["math", "science", "history"],
            prompt_count=15908,
            deprecated=False,
        )

    def test_benchmark_creation(self, valid_benchmark_data):
        """A valid payload builds a Benchmark whose attributes mirror the input."""
        built = Benchmark(**valid_benchmark_data)

        assert built.id == "bench-123"
        assert built.name == "MMLU"
        assert len(built.categories) == 2
        assert len(built.subsets) == 3
        assert built.prompt_count == 15908

    def test_benchmark_list_fields(self, valid_benchmark_data):
        """categories and subsets are stored as lists containing the given items."""
        built = Benchmark(**valid_benchmark_data)
        categories, subsets = built.categories, built.subsets

        assert isinstance(categories, list)
        assert isinstance(subsets, list)
        assert "reasoning" in categories
        assert "math" in subsets
class TestBenchmarks:
    """Test the Benchmarks collection, which is populated through the 'datasets' alias."""

    @pytest.fixture
    def benchmark_data(self):
        """Return one raw benchmark payload (plain dict) shared by the tests below."""
        return {
            "id": "bench-1",
            "key": "test",
            "name": "Test",
            "full_description": "Test benchmark",
            "language": "english",
            "categories": ["test"],
            "subsets": ["test"],
            "prompt_count": 10,
            "deprecated": False,
        }

    def test_benchmarks_with_datasets_alias(self, benchmark_data):
        """Input given as 'datasets' is exposed on the ``benchmarks`` attribute."""
        # Using the alias 'datasets'
        benchmarks = Benchmarks(datasets=[benchmark_data])  # type: ignore[arg-type]

        assert len(benchmarks.benchmarks) == 1
        assert isinstance(benchmarks.benchmarks[0], Benchmark)

    def test_benchmarks_field_validation(self, benchmark_data):
        """Well-formed input is accepted; a non-list value is rejected."""
        # Input goes through the 'datasets' alias here as well. Whether the
        # plain 'benchmarks' field name is also accepted depends on the model
        # configuration, which this test does not exercise.
        benchmarks = Benchmarks(datasets=[benchmark_data])  # type: ignore[arg-type]

        assert len(benchmarks.benchmarks) == 1

        # Negative case, mirroring the equivalent Evaluations test: the
        # collection must be a list, not a bare string.
        with pytest.raises(ValidationError):
            Benchmarks(datasets="not-a-list")  # type: ignore[arg-type]
class TestSyncAPIResource:
    """Behaviour of the SyncAPIResource base class against a mocked client."""

    @pytest.fixture
    def mock_client(self):
        """Return a stand-in client exposing get_cast and post_cast mocks."""
        return Mock(get_cast=Mock(), post_cast=Mock())

    @pytest.fixture
    def resource_instance(self, mock_client):
        """Return a SyncAPIResource bound to the mocked client."""
        return SyncAPIResource(mock_client)

    def test_resource_initialization(self, mock_client):
        """Construction stores the client and aliases its get_cast / post_cast."""
        built = SyncAPIResource(mock_client)

        assert built._client is mock_client
        assert built._get is mock_client.get_cast
        assert built._post is mock_client.post_cast

    def test_resource_stores_client_reference(self, resource_instance, mock_client):
        """The resource exposes the very client object it was given."""
        assert resource_instance._client is mock_client
        assert hasattr(resource_instance, "_client")

    def test_resource_delegates_get_to_client(self, resource_instance, mock_client):
        """_get is client.get_cast, and calling it forwards the arguments unchanged."""
        getter = resource_instance._get
        assert getter is mock_client.get_cast
        assert callable(getter)

        # Calling through the alias must land on the client's mock.
        resource_instance._get("/test", params={"key": "value"})
        mock_client.get_cast.assert_called_once_with("/test", params={"key": "value"})

    def test_resource_delegates_post_to_client(self, resource_instance, mock_client):
        """_post is client.post_cast, and calling it forwards the arguments unchanged."""
        poster = resource_instance._post
        assert poster is mock_client.post_cast
        assert callable(poster)

        # Calling through the alias must land on the client's mock.
        resource_instance._post("/test", body={"data": "test"})
        mock_client.post_cast.assert_called_once_with("/test", body={"data": "test"})

    def test_resource_sleep_method_exists(self, resource_instance):
        """A callable _sleep attribute is present."""
        assert hasattr(resource_instance, "_sleep")
        assert callable(resource_instance._sleep)

    def test_resource_sleep_delegates_to_time_sleep(self, resource_instance):
        """_sleep(x) results in exactly one time.sleep(x)."""
        sleep_duration = 2.5

        with patch("time.sleep") as mock_time_sleep:
            resource_instance._sleep(sleep_duration)

            mock_time_sleep.assert_called_once_with(sleep_duration)

    def test_resource_sleep_with_different_durations(self, resource_instance):
        """Several float durations are each passed through unchanged."""
        with patch("time.sleep") as mock_time_sleep:
            for duration in (0.1, 1.0, 5.0, 10.5, 60.0):
                # Start each round from a clean call history.
                mock_time_sleep.reset_mock()
                resource_instance._sleep(duration)
                mock_time_sleep.assert_called_once_with(duration)

    def test_resource_sleep_with_zero_duration(self, resource_instance):
        """A zero duration is still forwarded to time.sleep."""
        with patch("time.sleep") as mock_time_sleep:
            resource_instance._sleep(0.0)

            mock_time_sleep.assert_called_once_with(0.0)

    def test_resource_sleep_with_integer_duration(self, resource_instance):
        """An int duration is forwarded as given."""
        with patch("time.sleep") as mock_time_sleep:
            resource_instance._sleep(3)

            mock_time_sleep.assert_called_once_with(3)

    def test_resource_initialization_with_different_clients(self):
        """Resources built on different clients talk only to their own client."""
        client_one = Mock(
            get_cast=Mock(return_value="get_result_1"),
            post_cast=Mock(return_value="post_result_1"),
        )
        client_two = Mock(
            get_cast=Mock(return_value="get_result_2"),
            post_cast=Mock(return_value="post_result_2"),
        )

        resource_one = SyncAPIResource(client_one)
        resource_two = SyncAPIResource(client_two)

        # Each resource is wired to the client it was constructed with.
        assert resource_one._client is client_one
        assert resource_two._client is client_two
        assert resource_one._get is client_one.get_cast
        assert resource_two._get is client_two.get_cast

        # Calls are routed to, and answered by, the matching client.
        answer_one = resource_one._get("/test1")
        answer_two = resource_two._get("/test2")

        assert answer_one == "get_result_1"
        assert answer_two == "get_result_2"
        client_one.get_cast.assert_called_once_with("/test1")
        client_two.get_cast.assert_called_once_with("/test2")
class TestSyncAPIResourceErrorHandling:
    """Errors raised underneath SyncAPIResource reach the caller unchanged."""

    @pytest.fixture
    def mock_client(self):
        """Return a stand-in client whose get_cast / post_cast can be made to raise."""
        return Mock(get_cast=Mock(), post_cast=Mock())

    @pytest.fixture
    def resource_instance(self, mock_client):
        """Return a SyncAPIResource bound to the mocked client."""
        return SyncAPIResource(mock_client)

    def test_resource_propagates_get_errors(self, resource_instance, mock_client):
        """An APIStatusError raised by get_cast surfaces from _get."""
        from atlas._exceptions import APIStatusError

        not_found = Mock(status_code=404, headers={})
        mock_client.get_cast.side_effect = APIStatusError(
            "Not Found", response=not_found, body=None
        )

        with pytest.raises(APIStatusError):
            resource_instance._get("/test")

    def test_resource_propagates_post_errors(self, resource_instance, mock_client):
        """An APIConnectionError raised by post_cast surfaces from _post."""
        from atlas._exceptions import APIConnectionError

        mock_client.post_cast.side_effect = APIConnectionError(request=Mock())

        with pytest.raises(APIConnectionError):
            resource_instance._post("/test", body={"data": "test"})

    def test_resource_handles_client_method_missing(self):
        """Constructing a resource on a client without get_cast/post_cast raises AttributeError."""
        # A bare object() has neither get_cast nor post_cast.
        bare_client = object()

        with pytest.raises(AttributeError):
            SyncAPIResource(bare_client)  # type: ignore[arg-type]

    def test_resource_sleep_handles_exceptions(self, resource_instance):
        """A KeyboardInterrupt raised inside time.sleep propagates out of _sleep."""
        with patch("time.sleep") as mock_time_sleep:
            mock_time_sleep.side_effect = KeyboardInterrupt("Interrupted")

            with pytest.raises(KeyboardInterrupt):
                resource_instance._sleep(1.0)

            mock_time_sleep.assert_called_once_with(1.0)
class TestSyncAPIResourceRealWorldUsage:
    """Test SyncAPIResource in realistic usage scenarios.

    Each test defines a small subclass built only from the inherited
    ``_get`` / ``_post`` / ``_sleep`` primitives and drives it against a mocked
    client, with ``_sleep`` patched so no real waiting happens.
    """

    def test_resource_with_retry_logic(self):
        """A subclass can retry _get with exponential backoff via _sleep."""

        class RetryableResource(SyncAPIResource):
            def get_with_retry(self, url: str, max_retries: int = 3):
                for attempt in range(max_retries):
                    try:
                        return self._get(url)
                    except Exception:
                        # Out of attempts: let the last failure propagate.
                        if attempt == max_retries - 1:
                            raise
                        self._sleep(2 ** attempt)  # Exponential backoff

        mock_client = Mock()
        # First two calls fail, third succeeds
        mock_client.get_cast.side_effect = [
            Exception("First failure"),
            Exception("Second failure"),
            {"success": True},
        ]

        resource = RetryableResource(mock_client)

        with patch.object(resource, '_sleep') as mock_sleep:
            result = resource.get_with_retry("/test")

        assert result == {"success": True}
        assert mock_client.get_cast.call_count == 3
        # One backoff sleep after each of the two failures.
        assert mock_sleep.call_count == 2
        mock_sleep.assert_any_call(1)  # 2^0
        mock_sleep.assert_any_call(2)  # 2^1

    def test_resource_with_complex_workflow(self):
        """A subclass can create a resource, poll its status, then fetch the result."""

        class WorkflowResource(SyncAPIResource):
            def create_and_wait(self, data: dict, poll_interval: float = 1.0):
                # Create resource
                created = self._post("/create", body=data)
                resource_id = created["id"]  # type: ignore[index]

                # Poll until complete
                while True:
                    status = self._get(f"/status/{resource_id}")
                    if status["state"] == "completed":  # type: ignore[index]
                        return self._get(f"/result/{resource_id}")
                    elif status["state"] == "failed":  # type: ignore[index]
                        raise Exception("Workflow failed")

                    self._sleep(poll_interval)

        mock_client = Mock()
        mock_client.post_cast.return_value = {"id": "workflow-123"}

        # Mock status progression: pending -> running -> completed,
        # followed by the payload returned for the final /result call.
        mock_client.get_cast.side_effect = [
            {"state": "pending"},
            {"state": "running"},
            {"state": "completed"},
            {"result": "workflow complete"},
        ]

        resource = WorkflowResource(mock_client)

        with patch.object(resource, '_sleep') as mock_sleep:
            result = resource.create_and_wait({"name": "test"})

        assert result == {"result": "workflow complete"}
        assert mock_client.post_cast.call_count == 1
        # Three status polls plus one result fetch.
        assert mock_client.get_cast.call_count == 4
        assert mock_sleep.call_count == 2  # Two sleeps during polling
class TestSensitiveHeaders:
    """Checks on the SENSITIVE_HEADERS constant."""

    def test_sensitive_headers_constant(self):
        """The constant is a set that includes the API-key and authorization headers."""
        for expected_name in ("x-api-key", "authorization"):
            assert expected_name in SENSITIVE_HEADERS
        assert isinstance(SENSITIVE_HEADERS, set)

    def test_sensitive_headers_lowercase(self):
        """Every entry is already in lowercase form."""
        assert all(name == name.lower() for name in SENSITIVE_HEADERS)
"application/json", + "x-api-key": "secret-key-123", + "authorization": "Bearer token-456", + "user-agent": "atlas-python-sdk", + } + } + + result = filter_instance.filter(mock_log_record) + + assert result is True + headers = mock_log_record.args["headers"] + assert headers["content-type"] == "application/json" + assert headers["x-api-key"] == "" + assert headers["authorization"] == "" + assert headers["user-agent"] == "atlas-python-sdk" + + def test_filter_handles_case_insensitive_headers(self, filter_instance, mock_log_record): + """filter redacts headers regardless of case.""" + mock_log_record.args = { + "headers": { + "X-API-KEY": "secret-key-123", + "Authorization": "Bearer token-456", + "AUTHORIZATION": "Bearer another-token", + } + } + + result = filter_instance.filter(mock_log_record) + + assert result is True + headers = mock_log_record.args["headers"] + assert headers["X-API-KEY"] == "" + assert headers["Authorization"] == "" + assert headers["AUTHORIZATION"] == "" + + def test_filter_preserves_original_args_structure(self, filter_instance, mock_log_record): + """filter preserves the original args structure.""" + original_args = { + "method": "POST", + "url": "/test", + "headers": { + "x-api-key": "secret", + "content-type": "application/json", + }, + "body": {"data": "test"} + } + mock_log_record.args = original_args + + result = filter_instance.filter(mock_log_record) + + assert result is True + assert mock_log_record.args["method"] == "POST" + assert mock_log_record.args["url"] == "/test" + assert mock_log_record.args["body"] == {"data": "test"} + assert mock_log_record.args["headers"]["content-type"] == "application/json" + + def test_filter_creates_copy_of_headers(self, filter_instance, mock_log_record): + """filter creates a copy of headers dict to avoid modifying original.""" + original_headers = { + "x-api-key": "secret-key", + "content-type": "application/json", + } + mock_log_record.args = {"headers": original_headers} + + 
filter_instance.filter(mock_log_record) + + # Original headers should be unchanged + assert original_headers["x-api-key"] == "secret-key" + # Record headers should be modified + assert mock_log_record.args["headers"]["x-api-key"] == "" + # They should be different objects + assert mock_log_record.args["headers"] is not original_headers + + def test_filter_handles_non_string_header_keys(self, filter_instance, mock_log_record): + """filter handles non-string header keys gracefully.""" + mock_log_record.args = { + "headers": { + 123: "numeric-key", + "x-api-key": "secret-key", + ("tuple", "key"): "tuple-value", + } + } + + result = filter_instance.filter(mock_log_record) + + assert result is True + headers = mock_log_record.args["headers"] + assert headers[123] == "numeric-key" # Non-string keys unchanged + assert headers["x-api-key"] == "" # String keys processed + assert headers[("tuple", "key")] == "tuple-value" + + def test_filter_handles_non_dict_headers(self, filter_instance, mock_log_record): + """filter handles cases where headers is not a dict.""" + mock_log_record.args = { + "headers": "not-a-dict", + "other": "data" + } + + result = filter_instance.filter(mock_log_record) + + assert result is True + assert mock_log_record.args["headers"] == "not-a-dict" + assert mock_log_record.args["other"] == "data" + + def test_filter_with_empty_headers(self, filter_instance, mock_log_record): + """filter handles empty headers dict.""" + mock_log_record.args = {"headers": {}} + + result = filter_instance.filter(mock_log_record) + + assert result is True + assert mock_log_record.args["headers"] == {} + + def test_filter_with_complex_header_values(self, filter_instance, mock_log_record): + """filter redacts complex header values.""" + mock_log_record.args = { + "headers": { + "authorization": { + "type": "Bearer", + "token": "complex-token-123" + }, + "x-api-key": ["key1", "key2", "key3"], + "content-type": "application/json", + } + } + + result = 
filter_instance.filter(mock_log_record) + + assert result is True + headers = mock_log_record.args["headers"] + assert headers["authorization"] == "" + assert headers["x-api-key"] == "" + assert headers["content-type"] == "application/json" + + @pytest.mark.parametrize("sensitive_header", list(SENSITIVE_HEADERS)) + def test_filter_redacts_all_sensitive_headers(self, filter_instance, mock_log_record, sensitive_header): + """filter redacts all headers defined in SENSITIVE_HEADERS.""" + mock_log_record.args = { + "headers": { + sensitive_header: f"secret-value-for-{sensitive_header}", + "safe-header": "safe-value", + } + } + + result = filter_instance.filter(mock_log_record) + + assert result is True + headers = mock_log_record.args["headers"] + assert headers[sensitive_header] == "" + assert headers["safe-header"] == "safe-value" + + +class TestUtilsIntegration: + """Test integration scenarios for utility functions.""" + + def test_sensitive_filter_with_real_logging(self): + """SensitiveHeadersFilter works with real logging setup.""" + filter_instance = SensitiveHeadersFilter() + + # Create a mock LogRecord directly + mock_record = Mock() + mock_record.args = { + "headers": { + "x-api-key": "secret-key-123", + "content-type": "application/json", + } + } + + # Process the record through our filter + result = filter_instance.filter(mock_record) + + # Verify filter returns True (allowing the log) + assert result is True + + # Verify sensitive data was redacted + assert mock_record.args["headers"]["x-api-key"] == "" + assert mock_record.args["headers"]["content-type"] == "application/json" + + def test_typeguards_with_complex_data_structures(self): + """Type guards work correctly with complex nested structures.""" + complex_structure = { + "metadata": { + "headers": { + "authorization": "Bearer token", + "x-api-key": "secret" + }, + "params": ["param1", "param2"] + }, + "data": { + "nested": { + "deep": { + "value": 42 + } + } + } + } + + # Test type guards at different 
levels + assert is_dict(complex_structure) + assert is_mapping(complex_structure) + assert is_dict(complex_structure["metadata"]) + assert is_dict(complex_structure["metadata"]["headers"]) + assert not is_dict(complex_structure["metadata"]["params"]) + + # Test with the filter + filter_instance = SensitiveHeadersFilter() + mock_record = Mock(spec=logging.LogRecord) + mock_record.args = complex_structure["metadata"] + + result = filter_instance.filter(mock_record) + + assert result is True + assert mock_record.args["headers"]["authorization"] == "" + assert mock_record.args["headers"]["x-api-key"] == "" \ No newline at end of file