From 9124c6ab2d6377a6fa71bd7530b02a53e892fac2 Mon Sep 17 00:00:00 2001 From: m-peko Date: Mon, 18 Aug 2025 18:10:41 +0200 Subject: [PATCH] Fix Pydantic types --- src/atlas/_client.py | 10 +-- src/atlas/models/__init__.py | 19 ++++-- src/atlas/models/api.py | 25 +++++-- src/atlas/models/benchmark.py | 30 ++++----- src/atlas/models/model.py | 28 ++++---- src/atlas/resources/benchmarks/benchmarks.py | 13 +++- .../resources/evaluations/evaluations.py | 13 +++- src/atlas/resources/models/models.py | 13 +++- src/atlas/resources/results/results.py | 8 +-- tests/resources/test_benchmarks.py | 28 ++++---- tests/resources/test_evaluations.py | 18 ++--- tests/resources/test_models_resource.py | 66 ++++++++++++------- tests/resources/test_results.py | 40 +++++------ tests/test_integration.py | 16 ++--- tests/test_models.py | 48 +++++++------- 15 files changed, 221 insertions(+), 154 deletions(-) diff --git a/src/atlas/_client.py b/src/atlas/_client.py index 7e70403..64c7e42 100644 --- a/src/atlas/_client.py +++ b/src/atlas/_client.py @@ -10,7 +10,7 @@ from . import _exceptions from ._utils import is_mapping -from .models import Organization +from .models import Organization, OrganizationResponse from ._constants import DEFAULT_TIMEOUT from ._exceptions import AtlasError, APIStatusError from ._base_client import BaseClient @@ -53,7 +53,7 @@ def __init__( if base_url is None: base_url = os.environ.get("LAYERLENS_ATLAS_BASE_URL") if base_url is None: - base_url = "https://8bg48mbhyi.execute-api.us-east-1.amazonaws.com/prod/api/v1/dgklmnr" + base_url = "https://8bg48mbhyi.execute-api.us-east-1.amazonaws.com/prod/api/v1" super().__init__( base_url=base_url, @@ -164,10 +164,10 @@ def _get_organization(self) -> Optional[Organization]: organization = super().get_cast( f"/organizations", timeout=30, - cast_to=Organization, + cast_to=OrganizationResponse, ) - if isinstance(organization, Organization): - return organization + if isinstance(organization, OrganizationResponse): + return organization.data return None diff --git a/src/atlas/models/__init__.py b/src/atlas/models/__init__.py index d596ffb..f522371 100644 --- a/src/atlas/models/__init__.py +++ b/src/atlas/models/__init__.py @@ -1,14 +1,23 @@ -from .api import Models, Results, Benchmarks, Pagination, Evaluations, ResultMetrics +from .api import ( + Pagination, + ResultMetrics, + ModelsResponse, + ResultsResponse, + BenchmarksResponse, + EvaluationsResponse, + OrganizationResponse, +) from .model import Model, CustomModel, PublicModel from .benchmark import Benchmark, CustomBenchmark, PublicBenchmark from .evaluation import Result, Evaluation, EvaluationStatus from .organization import Project, Organization __all__ = [ - "Benchmarks", - "Evaluations", - "Models", - "Results", + "BenchmarksResponse", + "EvaluationsResponse", + "ModelsResponse", + "OrganizationResponse", + "ResultsResponse", "Benchmark", "CustomBenchmark", "PublicBenchmark", diff --git a/src/atlas/models/api.py b/src/atlas/models/api.py index 62de046..3671b72 100644 --- a/src/atlas/models/api.py +++ b/src/atlas/models/api.py @@ -7,20 +7,31 @@ from .model import Model from .benchmark import Benchmark from .evaluation import Result, Evaluation +from .organization import Organization -class Benchmarks(BaseModel): - model_config = ConfigDict(populate_by_name=True) +class BenchmarksResponse(BaseModel): + class Data(BaseModel): + model_config = ConfigDict(populate_by_name=True) - benchmarks: List[Benchmark] = Field(..., alias="datasets") + benchmarks: List[Benchmark] = Field(..., alias="datasets") + data: Data -class Evaluations(BaseModel): + +class EvaluationsResponse(BaseModel): data: List[Evaluation] -class Models(BaseModel): - models: List[Model] +class ModelsResponse(BaseModel): + class Data(BaseModel): + models: List[Model] + + data: Data + + +class OrganizationResponse(BaseModel): + data: Organization class ResultMetrics(BaseModel): @@ -33,7 +44,7 @@ class Pagination(BaseModel): total_pages: int -class Results(BaseModel): +class ResultsResponse(BaseModel): evaluation_id: str results: List[Result] metrics: ResultMetrics diff --git a/src/atlas/models/benchmark.py b/src/atlas/models/benchmark.py index 0405937..ca30eb1 100644 --- a/src/atlas/models/benchmark.py +++ b/src/atlas/models/benchmark.py @@ -12,21 +12,21 @@ class Benchmark(BaseModel): class CustomBenchmark(Benchmark): - description: str - system_prompt: Optional[str] - prompt_count: int - version_count: int - regex_pattern: Optional[str] - llm_judge_model_id: str - custom_instructions: str - scoring_metric: Optional[str] - metrics: List[str] - files: List[str] - disabled: bool + description: Optional[str] = None + system_prompt: Optional[str] = None + prompt_count: Optional[int] = None + version_count: Optional[int] = None + regex_pattern: Optional[str] = None + llm_judge_model_id: Optional[str] = None + custom_instructions: Optional[str] = None + scoring_metric: Optional[str] = None + metrics: Optional[List[str]] = None + files: Optional[List[str]] = None + disabled: Optional[bool] = None class PublicBenchmark(Benchmark): - description: str = Field(..., alias="full_description") - language: str - prompt_count: int - deprecated: bool + description: Optional[str] = Field(None, alias="full_description") + language: Optional[str] = None + prompt_count: Optional[int] = None + deprecated: Optional[bool] = None diff --git a/src/atlas/models/model.py b/src/atlas/models/model.py index ad63814..be84e12 100644 --- a/src/atlas/models/model.py +++ b/src/atlas/models/model.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional + from pydantic import BaseModel @@ -11,19 +13,19 @@ class Model(BaseModel): class CustomModel(Model): - max_tokens: int - api_url: str - disabled: bool + max_tokens: Optional[int] = None + api_url: Optional[str] = None + disabled: Optional[bool] = None class PublicModel(Model): - company: str - released_at: int - parameters: float - modality: str - context_length: int - architecture_type: str - license: str - open_weights: bool - region: str - deprecated: bool + company: Optional[str] = None + released_at: Optional[int] = None + parameters: Optional[float] = None + modality: Optional[str] = None + context_length: Optional[int] = None + architecture_type: Optional[str] = None + license: Optional[str] = None + open_weights: Optional[bool] = None + region: Optional[str] = None + deprecated: Optional[bool] = None diff --git a/src/atlas/resources/benchmarks/benchmarks.py b/src/atlas/resources/benchmarks/benchmarks.py index fd69f96..9e13225 100644 --- a/src/atlas/resources/benchmarks/benchmarks.py +++ b/src/atlas/resources/benchmarks/benchmarks.py @@ -4,7 +4,7 @@ import httpx -from ...models import Benchmark, Benchmarks as BenchmarksResponse +from ...models import Benchmark, CustomBenchmark, PublicBenchmark, BenchmarksResponse from ..._resource import SyncAPIResource from ..._constants import DEFAULT_TIMEOUT @@ -34,14 +34,21 @@ def fetch(bench_type: str) -> BenchmarksResponse | None: benchmarks: List[Benchmark] = [] + def cast_benchmark(b: Benchmark, bench_type: str) -> Benchmark: + if bench_type == "custom": + return CustomBenchmark(**b.model_dump()) + elif bench_type == "public": + return PublicBenchmark(**b.model_dump()) + return b # fallback, just base class + if type is None: for t in ["custom", "public"]: resp = fetch(t) if resp: - benchmarks.extend(resp.benchmarks) + benchmarks.extend([cast_benchmark(b, t) for b in resp.data.benchmarks]) else: # fetch only one type resp = fetch(type) if resp: - benchmarks.extend(resp.benchmarks) + benchmarks.extend([cast_benchmark(b, type) for b in resp.data.benchmarks]) return benchmarks diff --git a/src/atlas/resources/evaluations/evaluations.py b/src/atlas/resources/evaluations/evaluations.py index 2506ff7..4b779ae 100644 --- a/src/atlas/resources/evaluations/evaluations.py +++ b/src/atlas/resources/evaluations/evaluations.py @@ -2,7 +2,14 @@ import httpx -from ...models import Model, Benchmark, Evaluation, Evaluations as EvaluationsResponse +from ...models import ( + Model, + Benchmark, + Evaluation, + CustomModel, + CustomBenchmark, + EvaluationsResponse, +) from ..._resource import SyncAPIResource from ..._constants import DEFAULT_TIMEOUT @@ -21,8 +28,8 @@ def create( { "model_id": model.id, "dataset_id": benchmark.id, - "is_custom_model": False, - "is_custom_dataset": False, + "is_custom_model": isinstance(model, CustomModel), + "is_custom_dataset": isinstance(benchmark, CustomBenchmark), } ], timeout=timeout, diff --git a/src/atlas/resources/models/models.py b/src/atlas/resources/models/models.py index f76f188..3ec3bea 100644 --- a/src/atlas/resources/models/models.py +++ b/src/atlas/resources/models/models.py @@ -4,7 +4,7 @@ import httpx -from ...models import Model, Models as ModelsResponse +from ...models import Model, CustomModel, PublicModel, ModelsResponse from ..._resource import SyncAPIResource from ..._constants import DEFAULT_TIMEOUT @@ -43,14 +43,21 @@ def fetch(model_type: str) -> ModelsResponse | None: models: List[Model] = [] + def cast_model(m: Model, model_type: str) -> Model: + if model_type == "custom": + return CustomModel(**m.model_dump()) + elif model_type == "public": + return PublicModel(**m.model_dump()) + return m # fallback, just base class + if type is None: # fetch both for t in ["custom", "public"]: resp = fetch(t) if resp: - models.extend(resp.models) + models.extend([cast_model(m, t) for m in resp.data.models]) else: # fetch only one type resp = fetch(type) if resp: - models.extend(resp.models) + models.extend([cast_model(m, type) for m in resp.data.models]) return models diff --git a/src/atlas/resources/results/results.py b/src/atlas/resources/results/results.py index 66b18f0..6a16cad 100644 --- a/src/atlas/resources/results/results.py +++ b/src/atlas/resources/results/results.py @@ -7,7 +7,7 @@ from ..._resource import SyncAPIResource from ..._constants import DEFAULT_TIMEOUT -from ...models.api import Results as ResultsData +from ...models.api import ResultsResponse DEFAULT_PAGE_SIZE = 100 @@ -20,7 +20,7 @@ def get( page: Optional[int] = None, page_size: Optional[int] = None, timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, - ) -> ResultsData | None: + ) -> ResultsResponse | None: """ Get evaluation results with optional pagination. @@ -31,7 +31,7 @@ def get( timeout: Request timeout Returns: - ResultsData object containing: + ResultsResponse object containing: - evaluation_id: The evaluation ID - results: List of Result objects for the current page - metrics: Contains total_count and score ranges @@ -77,6 +77,6 @@ def get( } try: - return ResultsData.model_validate(response_with_pagination) + return ResultsResponse.model_validate(response_with_pagination) except Exception: return None diff --git a/tests/resources/test_benchmarks.py b/tests/resources/test_benchmarks.py index 292438d..5bf0aa0 100644 --- a/tests/resources/test_benchmarks.py +++ b/tests/resources/test_benchmarks.py @@ -5,9 +5,9 @@ from atlas.models import ( Benchmark, - Benchmarks as BenchmarksData, CustomBenchmark, PublicBenchmark, + BenchmarksResponse, ) from atlas._constants import DEFAULT_TIMEOUT from atlas.resources.benchmarks.benchmarks import Benchmarks @@ -68,11 +68,11 @@ def sample_custom_benchmark_data(self): @pytest.fixture def mock_benchmarks_response(self, sample_benchmark_data, sample_custom_benchmark_data): - """Mock BenchmarksData response with public benchmarks.""" + """Mock BenchmarksResponse response with public benchmarks.""" public_benchmark = Benchmark(**sample_benchmark_data) custom_benchmark = CustomBenchmark(**sample_custom_benchmark_data) - return BenchmarksData(datasets=[public_benchmark, custom_benchmark]) + return BenchmarksResponse(data=BenchmarksResponse.Data(datasets=[public_benchmark, custom_benchmark])) def test_benchmarks_initialization(self, mock_client): """Benchmarks resource initializes correctly.""" @@ -86,7 +86,7 @@ def test_get_benchmarks_success(self, benchmarks_resource, mock_benchmarks_respo benchmarks_resource._get.side_effect = lambda *_, **kwargs: ( mock_benchmarks_response if kwargs.get("params", {}).get("type") == "public" - else BenchmarksData(benchmarks=[]) + else BenchmarksResponse(data=BenchmarksResponse.Data(benchmarks=[])) ) result = benchmarks_resource.get() @@ -113,13 +113,13 @@ def test_get_benchmarks_request_parameters(self, benchmarks_resource, mock_bench "/organizations/org-123/projects/proj-456/benchmarks", params={"type": "custom"}, timeout=DEFAULT_TIMEOUT, - cast_to=BenchmarksData, + cast_to=BenchmarksResponse, ), call( "/organizations/org-123/projects/proj-456/benchmarks", params={"type": "public"}, timeout=DEFAULT_TIMEOUT, - cast_to=BenchmarksData, + cast_to=BenchmarksResponse, ), ] @@ -150,7 +150,7 @@ def test_get_benchmarks_with_httpx_timeout(self, benchmarks_resource, mock_bench [ (None, []), ("invalid-response", []), - (BenchmarksData(datasets=[]), []), + (BenchmarksResponse(data=BenchmarksResponse.Data(datasets=[])), []), ], ids=["none_response", "invalid_type", "empty_response"], ) @@ -177,9 +177,11 @@ def test_get_benchmarks_multiple_items( benchmark2_data["name"] = "HellaSwag" benchmark2 = Benchmark(**benchmark2_data) - response = BenchmarksData(datasets=[benchmark, benchmark2]) + response = BenchmarksResponse(data=BenchmarksResponse.Data(datasets=[benchmark, benchmark2])) benchmarks_resource._get.side_effect = lambda *_, **kwargs: ( - response if kwargs.get("params", {}).get("type") == "public" else BenchmarksData(benchmarks=[]) + response + if kwargs.get("params", {}).get("type") == "public" + else BenchmarksResponse(data=BenchmarksResponse.Data(benchmarks=[])) ) result = benchmarks_resource.get() @@ -207,7 +209,7 @@ def test_get_benchmarks_cast_to_parameter(self, benchmarks_resource, mock_benchm benchmarks_resource.get() call_args = benchmarks_resource._get.call_args - assert call_args.kwargs["cast_to"] is BenchmarksData + assert call_args.kwargs["cast_to"] is BenchmarksResponse def test_get_benchmarks_timeout_default(self, benchmarks_resource, mock_benchmarks_response): """get method uses DEFAULT_TIMEOUT when no timeout specified.""" @@ -321,7 +323,7 @@ def test_get_benchmarks_return_type_consistency(self, benchmarks_resource): assert result == [] # Test that it returns a list when successful - benchmarks_resource._get.return_value = BenchmarksData(datasets=[]) + benchmarks_resource._get.return_value = BenchmarksResponse(data=BenchmarksResponse.Data(datasets=[])) result = benchmarks_resource.get() assert result == [] @@ -362,9 +364,9 @@ def test_get_benchmarks_mixed_benchmark_types(self, benchmarks_resource): custom_benchmark = CustomBenchmark(**custom_data) benchmarks_resource._get.side_effect = lambda *_, **kwargs: ( - BenchmarksData(benchmarks=[public_benchmark]) + BenchmarksResponse(data=BenchmarksResponse.Data(benchmarks=[public_benchmark])) if kwargs.get("params", {}).get("type") == "public" - else BenchmarksData(benchmarks=[custom_benchmark]) + else BenchmarksResponse(data=BenchmarksResponse.Data(benchmarks=[custom_benchmark])) ) result = benchmarks_resource.get() # Type doesn't matter for this test diff --git a/tests/resources/test_evaluations.py b/tests/resources/test_evaluations.py index a5cf899..d510d02 100644 --- a/tests/resources/test_evaluations.py +++ b/tests/resources/test_evaluations.py @@ -3,7 +3,7 @@ import httpx import pytest -from atlas.models import Evaluation, Evaluations as EvaluationsData, EvaluationStatus +from atlas.models import Evaluation, EvaluationStatus, EvaluationsResponse from atlas._constants import DEFAULT_TIMEOUT from atlas.resources.evaluations.evaluations import Evaluations @@ -61,9 +61,9 @@ def sample_evaluation_data(self): @pytest.fixture def mock_evaluations_response(self, sample_evaluation_data): - """Mock EvaluationsData response.""" + """Mock EvaluationsResponse response.""" evaluation = Evaluation(**sample_evaluation_data) - return EvaluationsData(data=[evaluation]) + return EvaluationsResponse(data=[evaluation]) def test_evaluations_initialization(self, mock_client): """Evaluations resource initializes correctly.""" @@ -113,7 +113,7 @@ def test_create_evaluation_request_parameters( } ], timeout=DEFAULT_TIMEOUT, - cast_to=EvaluationsData, + cast_to=EvaluationsResponse, ) def test_create_evaluation_with_custom_timeout( @@ -158,7 +158,7 @@ def test_create_evaluation_with_httpx_timeout( def test_create_evaluation_empty_response(self, mock_model, mock_benchmark, evaluations_resource): """create method returns None when no evaluations in response.""" - empty_response = EvaluationsData(data=[]) + empty_response = EvaluationsResponse(data=[]) evaluations_resource._post.return_value = empty_response result = evaluations_resource.create(model=mock_model, benchmark=mock_benchmark) @@ -174,7 +174,7 @@ def test_create_evaluation_none_response(self, mock_model, mock_benchmark, evalu assert result is None def test_create_evaluation_invalid_response_type(self, mock_model, mock_benchmark, evaluations_resource): - """create method handles non-EvaluationsData response gracefully.""" + """create method handles non-EvaluationsResponse response gracefully.""" evaluations_resource._post.return_value = "invalid-response" result = evaluations_resource.create(model=mock_model, benchmark=mock_benchmark) @@ -190,7 +190,7 @@ def test_create_evaluation_multiple_evaluations_returns_first( eval2_data["id"] = "eval-456" eval2 = Evaluation(**eval2_data) - response = EvaluationsData(data=[eval1, eval2]) + response = EvaluationsResponse(data=[eval1, eval2]) evaluations_resource._post.return_value = response result = evaluations_resource.create(model=mock_model, benchmark=mock_benchmark) @@ -251,7 +251,7 @@ def test_create_evaluation_cast_to_parameter( evaluations_resource.create(model=mock_model, benchmark=mock_benchmark) call_args = evaluations_resource._post.call_args - assert call_args.kwargs["cast_to"] is EvaluationsData + assert call_args.kwargs["cast_to"] is EvaluationsResponse def test_create_evaluation_timeout_default( self, @@ -390,7 +390,7 @@ def test_create_evaluation_end_to_end_flow(self): } evaluation = Evaluation(**evaluation_data) - response = EvaluationsData(data=[evaluation]) + response = EvaluationsResponse(data=[evaluation]) mock_client.post_cast.return_value = response # Test the resource diff --git a/tests/resources/test_models_resource.py b/tests/resources/test_models_resource.py index 7cac60b..48a5ab6 100644 --- a/tests/resources/test_models_resource.py +++ b/tests/resources/test_models_resource.py @@ -3,7 +3,7 @@ import httpx import pytest -from atlas.models import Models as ModelsData, CustomModel, PublicModel +from atlas.models import CustomModel, PublicModel, ModelsResponse from atlas._constants import DEFAULT_TIMEOUT from atlas.resources.models.models import Models @@ -60,15 +60,15 @@ def sample_custom_model_data(self): @pytest.fixture def mock_public_models_response(self, sample_model_data): - """Mock ModelsData response with public models.""" + """Mock ModelsResponse response with public models.""" model = PublicModel(**sample_model_data) - return ModelsData(models=[model]) + return ModelsResponse(data=ModelsResponse.Data(models=[model])) @pytest.fixture def mock_custom_models_response(self, sample_custom_model_data): - """Mock ModelsData response with custom models.""" + """Mock ModelsResponse response with custom models.""" custom_model = CustomModel(**sample_custom_model_data) - return ModelsData(models=[custom_model]) + return ModelsResponse(data=ModelsResponse.Data(models=[custom_model])) def test_models_initialization(self, mock_client): """Models resource initializes correctly.""" @@ -80,7 +80,9 @@ def test_models_initialization(self, mock_client): def test_get_public_models_success(self, models_resource, mock_public_models_response): """get method returns public models successfully.""" models_resource._get.side_effect = lambda *_, **kwargs: ( - mock_public_models_response if kwargs.get("params", {}).get("type") == "public" else ModelsData(models=[]) + mock_public_models_response + if kwargs.get("params", {}).get("type") == "public" + else ModelsResponse(data=ModelsResponse.Data(models=[])) ) result = models_resource.get() @@ -95,7 +97,9 @@ def test_get_public_models_success(self, models_resource, mock_public_models_res def test_get_custom_models_success(self, models_resource, mock_custom_models_response): """get method returns custom models successfully.""" models_resource._get.side_effect = lambda *_, **kwargs: ( - mock_custom_models_response if kwargs.get("params", {}).get("type") == "custom" else ModelsData(models=[]) + mock_custom_models_response + if kwargs.get("params", {}).get("type") == "custom" + else ModelsResponse(data=ModelsResponse.Data(models=[])) ) result = models_resource.get() @@ -118,13 +122,13 @@ def test_get_models_request_parameters_public(self, models_resource, mock_public "/organizations/org-123/projects/proj-456/models", params={"type": "custom"}, timeout=DEFAULT_TIMEOUT, - cast_to=ModelsData, + cast_to=ModelsResponse, ), call( "/organizations/org-123/projects/proj-456/models", params={"type": "public"}, timeout=DEFAULT_TIMEOUT, - cast_to=ModelsData, + cast_to=ModelsResponse, ), ] @@ -141,13 +145,13 @@ def test_get_models_request_parameters_custom(self, models_resource, mock_custom "/organizations/org-123/projects/proj-456/models", params={"type": "custom"}, timeout=DEFAULT_TIMEOUT, - cast_to=ModelsData, + cast_to=ModelsResponse, ), call( "/organizations/org-123/projects/proj-456/models", params={"type": "public"}, timeout=DEFAULT_TIMEOUT, - cast_to=ModelsData, + cast_to=ModelsResponse, ), ] @@ -178,7 +182,10 @@ def test_get_models_with_httpx_timeout(self, models_resource, mock_public_models [ (None, []), # None response ("invalid-response", []), # Invalid type - (ModelsData(models=[]), []), # Empty ModelsData + ( + ModelsResponse(data=ModelsResponse.Data(models=[])), + [], + ), # Empty ModelsResponse ], ) def test_get_models_responses(self, models_resource, mock_response, expected): @@ -188,7 +195,7 @@ def test_get_models_responses(self, models_resource, mock_response, expected): result = models_resource.get() assert result == expected - if isinstance(mock_response, ModelsData): + if isinstance(mock_response, ModelsResponse): assert isinstance(result, list) def test_get_models_multiple_items(self, models_resource, sample_model_data): @@ -203,10 +210,12 @@ def test_get_models_multiple_items(self, models_resource, sample_model_data): model2_data["parameters"] = 1.75e11 model2 = PublicModel(**model2_data) - response = ModelsData(models=[model1, model2]) + response = ModelsResponse(data=ModelsResponse.Data(models=[model1, model2])) models_resource._get.side_effect = lambda *_, **kwargs: ( - response if kwargs.get("params", {}).get("type") == "public" else ModelsData(models=[]) + response + if kwargs.get("params", {}).get("type") == "public" + else ModelsResponse(data=ModelsResponse.Data(models=[])) ) result = models_resource.get() @@ -232,13 +241,15 @@ def test_get_models_url_construction(self, models_resource, mock_public_models_r def test_get_models_cast_to_parameter(self, models_resource, mock_public_models_response): """get method specifies correct cast_to parameter.""" models_resource._get.side_effect = lambda *_, **kwargs: ( - mock_public_models_response if kwargs.get("params", {}).get("type") == "public" else ModelsData(models=[]) + mock_public_models_response + if kwargs.get("params", {}).get("type") == "public" + else ModelsResponse(data=ModelsResponse.Data(models=[])) ) models_resource.get() call_args = models_resource._get.call_args - assert call_args.kwargs["cast_to"] is ModelsData + assert call_args.kwargs["cast_to"] is ModelsResponse def test_get_models_timeout_default(self, models_resource, mock_public_models_response): """get method uses DEFAULT_TIMEOUT when no timeout specified.""" @@ -260,7 +271,11 @@ def test_get_models_with_none_timeout(self, models_resource, mock_public_models_ def test_get_models_model_attributes(self, models_resource, mock_public_models_response): """get method preserves all model attributes correctly.""" - models_resource._get.return_value = mock_public_models_response + models_resource._get.side_effect = lambda *_, **kwargs: ( + mock_public_models_response + if kwargs.get("params", {}).get("type") == "public" + else ModelsResponse(data=ModelsResponse.Data(models=[])) + ) result = models_resource.get() model = result[0] @@ -374,7 +389,10 @@ def models_resource(self, mock_client): "mock_response, expected_type", [ (None, list), # None response - (ModelsData(models=[]), list), # Empty ModelsData + ( + ModelsResponse(data=ModelsResponse.Data(models=[])), + list, + ), # Empty ModelsResponse ], ) def test_get_models_return_type_consistency(self, models_resource, mock_response, expected_type): @@ -419,9 +437,9 @@ def test_get_models_mixed_model_types(self, models_resource): custom_model = CustomModel(**custom_data) models_resource._get.side_effect = lambda *_, **kwargs: ( - ModelsData(models=[public_model]) + ModelsResponse(data=ModelsResponse.Data(models=[public_model])) if kwargs.get("params", {}).get("type") == "public" - else ModelsData(models=[custom_model]) + else ModelsResponse(data=ModelsResponse.Data(models=[custom_model])) ) result = models_resource.get() # Type doesn't matter for this test @@ -454,9 +472,11 @@ def test_get_models_large_parameters_handling(self, models_resource): } large_model = PublicModel(**large_model_data) - response = ModelsData(models=[large_model]) + response = ModelsResponse(data=ModelsResponse.Data(models=[large_model])) models_resource._get.side_effect = lambda *_, **kwargs: ( - response if kwargs.get("params", {}).get("type") == "public" else ModelsData(models=[]) + response + if kwargs.get("params", {}).get("type") == "public" + else ModelsResponse(data=ModelsResponse.Data(models=[])) ) result = models_resource.get() diff --git a/tests/resources/test_results.py b/tests/resources/test_results.py index db05b4f..c3031b3 100644 --- a/tests/resources/test_results.py +++ b/tests/resources/test_results.py @@ -4,7 +4,7 @@ import httpx import pytest -from atlas.models import Result, Results as ResultsData, Pagination, ResultMetrics +from atlas.models import Result, Pagination, ResultMetrics, ResultsResponse from atlas._constants import DEFAULT_TIMEOUT from atlas.resources.results.results import Results @@ -60,12 +60,12 @@ def test_results_initialization(self, mock_client): assert results._get is mock_client.get_cast def test_get_results_success(self, results_resource, mock_results_response): - """get method returns ResultsData successfully.""" + """get method returns ResultsResponse successfully.""" results_resource._get.return_value = mock_results_response result = results_resource.get(evaluation_id="eval-123") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert result.evaluation_id == "eval-123" assert len(result.results) == 1 assert isinstance(result.results[0], Result) @@ -121,7 +121,7 @@ def test_get_results_none_response(self, results_resource): assert result is None def test_get_results_invalid_response_type(self, results_resource): - """get method handles non-ResultsData response gracefully.""" + """get method handles non-ResultsResponse response gracefully.""" results_resource._get.return_value = "invalid-response" result = results_resource.get(evaluation_id="eval-123") @@ -129,7 +129,7 @@ def test_get_results_invalid_response_type(self, results_resource): assert result is None def test_get_results_empty_response(self, results_resource): - """get method returns ResultsData with empty results list when no results in response.""" + """get method returns ResultsResponse with empty results list when no results in response.""" empty_response = { "evaluation_id": "eval-123", "results": [], @@ -145,7 +145,7 @@ def test_get_results_empty_response(self, results_resource): result = results_resource.get(evaluation_id="eval-123") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert result.evaluation_id == "eval-123" assert result.results == [] assert isinstance(result.results, list) @@ -180,7 +180,7 @@ def test_get_results_multiple_items(self, results_resource, sample_result_data): result = results_resource.get(evaluation_id="eval-123") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert len(result.results) == 2 assert result.results[0].subset == "mathematics" assert result.results[1].subset == "science" @@ -262,7 +262,7 @@ def test_get_results_with_different_evaluation_ids(self, results_resource, mock_ result = results_resource.get(evaluation_id=evaluation_id) - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) call_args = results_resource._get.call_args assert call_args.kwargs["params"]["evaluation_id"] == evaluation_id @@ -416,7 +416,7 @@ def test_get_results_handles_complex_metrics(self, results_resource): result = results_resource.get(evaluation_id="eval-complex") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert len(result.results) == 1 result_item = result.results[0] @@ -464,7 +464,7 @@ def test_get_results_handles_different_durations(self, results_resource): result = results_resource.get(evaluation_id="eval-durations") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert len(result.results) == 5 assert result.results[0].duration == timedelta(seconds=0.1) assert result.results[1].duration == timedelta(seconds=1.5) @@ -499,19 +499,19 @@ def test_get_results_handles_empty_metrics(self, results_resource): result = results_resource.get(evaluation_id="eval-minimal") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert len(result.results) == 1 assert result.results[0].metrics == {} assert isinstance(result.results[0].metrics, dict) def test_get_results_return_type_consistency(self, results_resource): """get method returns consistent types.""" - # Test that the method returns either a ResultsData object or None + # Test that the method returns either a ResultsResponse object or None results_resource._get.return_value = None result = results_resource.get(evaluation_id="eval-123") assert result is None - # Test that it returns a ResultsData object when successful + # Test that it returns a ResultsResponse object when successful empty_response = { "evaluation_id": "eval-123", "results": [], @@ -525,7 +525,7 @@ def test_get_results_return_type_consistency(self, results_resource): } results_resource._get.return_value = empty_response result = results_resource.get(evaluation_id="eval-123") - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) class TestResultsPagination: @@ -590,7 +590,7 @@ def test_get_results_with_pagination_parameters(self, results_resource, sample_r ) # Verify the response structure - assert isinstance(result_data, ResultsData) + assert isinstance(result_data, ResultsResponse) assert result_data.evaluation_id == "eval-paginated" assert result_data.pagination.total_count == 250 assert result_data.pagination.page_size == 50 @@ -663,7 +663,7 @@ def test_get_results_pagination_metadata_calculation(self, results_resource, sam result = results_resource.get(evaluation_id="eval-math", page=3, page_size=50) # Should have calculated pagination correctly - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert result.pagination.total_count == 487 assert result.pagination.page_size == 50 assert result.pagination.total_pages == 10 # ceil(487 / 50) = 10 @@ -760,7 +760,7 @@ def test_get_results_with_zero_total_count_in_metrics(self, results_resource): result = results_resource.get(evaluation_id="eval-123") # Should handle zero total_count gracefully - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert result.pagination.total_count == 0 assert result.pagination.total_pages == 0 @@ -805,7 +805,7 @@ def test_get_results_extreme_pagination_values(self, results_resource): result = results_resource.get(evaluation_id="eval-extreme", page_size=1) - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert result.pagination.total_count == 999999 assert result.pagination.page_size == 1 assert result.pagination.total_pages == 999999 # ceil(999999 / 1) @@ -828,7 +828,7 @@ def test_get_results_zero_page_size_edge_case(self, results_resource): # Pass 0 as page_size result = results_resource.get(evaluation_id="eval-123", page_size=0) - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) # Should use 0 as provided (though this might cause division by zero, it's handled) assert result.pagination.page_size == 0 @@ -856,7 +856,7 @@ def test_get_results_negative_page_values(self, results_resource): assert params["page"] == "-1" assert params["pageSize"] == "-50" - assert isinstance(result, ResultsData) + assert isinstance(result, ResultsResponse) assert result.pagination.page_size == -50 # total_pages calculation with negative page_size assert result.pagination.total_pages == 0 # math.ceil handles negative divisors diff --git a/tests/test_integration.py b/tests/test_integration.py index 2a26d10..5a7b65e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -8,11 +8,11 @@ from atlas.models import ( Model, Result, - Results as ResultsData, Benchmark, Evaluation, - Evaluations as EvaluationsData, + ResultsResponse, EvaluationStatus, + EvaluationsResponse, ) @@ -173,8 +173,8 @@ def test_complete_evaluation_workflow(self, atlas_client): result = Result(**result_data) # Mock responses - evaluations_response = EvaluationsData(data=[evaluation]) - results_response = ResultsData( + evaluations_response = EvaluationsResponse(data=[evaluation]) + results_response = ResultsResponse( evaluation_id="eval-789", results=[result], metrics={ @@ -257,7 +257,7 @@ def test_workflow_with_custom_timeouts(self, atlas_client): "metrics": {"accuracy": 1.0}, } - results_response = ResultsData( + results_response = ResultsResponse( evaluation_id="test-eval", results=[Result(**result_data)], metrics={ @@ -363,7 +363,7 @@ def test_evaluation_creation_with_model_and_benchmark_objects(self, atlas_client benchmark = Benchmark(**benchmark_data) evaluation = Evaluation(**evaluation_data) - evaluations_response = EvaluationsData(data=[evaluation]) + evaluations_response = EvaluationsResponse(data=[evaluation]) with patch.object(atlas_client, "post_cast") as mock_post: mock_post.return_value = evaluations_response @@ -416,7 +416,7 @@ def test_results_analysis_workflow(self, atlas_client): ] results = [Result(**data) for data in results_data] - results_response = ResultsData( + results_response = ResultsResponse( evaluation_id="test-eval", results=results, metrics={ @@ -559,7 +559,7 @@ def test_resource_operations_isolated(self, mock_org1, mock_org2): "metrics": {"accuracy": 1.0}, } - results_response = ResultsData( + results_response = ResultsResponse( evaluation_id="test-eval", results=[Result(**result_data)], metrics={ diff --git a/tests/test_models.py b/tests/test_models.py index a3eb0d6..0692b03 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,19 +4,19 @@ from pydantic import ValidationError from atlas.models import ( - Models, Result, - Results, - Benchmarks, Evaluation, Pagination, CustomModel, - Evaluations, PublicModel, ResultMetrics, + ModelsResponse, CustomBenchmark, PublicBenchmark, + ResultsResponse, EvaluationStatus, + BenchmarksResponse, + EvaluationsResponse, ) @@ -108,7 +108,7 @@ def evaluation_data(self): def test_evaluations_with_list_of_evaluations(self, evaluation_data): """Evaluations model accepts list of Evaluation objects.""" evaluations_data = {"data": [evaluation_data, evaluation_data]} - evaluations = Evaluations(**evaluations_data) + evaluations = EvaluationsResponse(**evaluations_data) assert len(evaluations.data) == 2 assert all(isinstance(eval, Evaluation) for eval in evaluations.data) @@ -116,7 +116,7 @@ def test_evaluations_with_list_of_evaluations(self, evaluation_data): def test_evaluations_empty_list(self): """Evaluations model accepts empty list.""" - evaluations = Evaluations(data=[]) + evaluations = EvaluationsResponse(data=[]) assert evaluations.data == [] assert isinstance(evaluations.data, list) @@ -124,7 +124,7 @@ def test_evaluations_empty_list(self): def test_evaluations_invalid_data_structure(self): """Evaluations model validates data structure.""" with pytest.raises(ValidationError): - Evaluations(data="not-a-list") # type: ignore[arg-type] + EvaluationsResponse(data="not-a-list") # type: ignore[arg-type] class TestResult: @@ -311,7 +311,7 @@ def valid_pagination_data(self): def test_results_with_pagination(self, valid_result_data, valid_metrics_data, valid_pagination_data): """Results model accepts all required fields including pagination.""" - results = Results( + results = ResultsResponse( evaluation_id="eval-123", results=[valid_result_data, valid_result_data], metrics=valid_metrics_data, @@ -329,7 +329,7 @@ def test_results_with_pagination(self, valid_result_data, valid_metrics_data, va def test_results_field_types(self, valid_result_data, valid_metrics_data, valid_pagination_data): """Results model enforces correct field types.""" - results = Results( + results = ResultsResponse( evaluation_id="eval-456", results=[valid_result_data], metrics=valid_metrics_data, @@ -343,7 +343,7 @@ def test_results_field_types(self, valid_result_data, valid_metrics_data, valid_ def test_results_empty_results_list(self, valid_metrics_data, valid_pagination_data): """Results model handles empty results list.""" - results = Results( + results = ResultsResponse( evaluation_id="eval-empty", results=[], metrics=valid_metrics_data, @@ -360,7 +360,7 @@ def test_results_validation_errors(self, valid_result_data, valid_metrics_data, """Results model validates required fields.""" # Test missing evaluation_id with pytest.raises(ValidationError): - Results( + ResultsResponse( results=[valid_result_data], metrics=valid_metrics_data, pagination=valid_pagination_data, @@ -368,7 +368,7 @@ def test_results_validation_errors(self, valid_result_data, valid_metrics_data, # Test missing metrics with pytest.raises(ValidationError): - Results( + ResultsResponse( evaluation_id="eval-123", results=[valid_result_data], pagination=valid_pagination_data, @@ -376,7 +376,7 @@ def test_results_validation_errors(self, valid_result_data, valid_metrics_data, # Test missing pagination with pytest.raises(ValidationError): - Results( + ResultsResponse( evaluation_id="eval-123", results=[valid_result_data], metrics=valid_metrics_data, @@ -386,7 +386,7 @@ def test_results_nested_model_validation(self, valid_result_data, valid_paginati """Results model validates nested models.""" # Test invalid metrics with pytest.raises(ValidationError): - Results( + ResultsResponse( evaluation_id="eval-123", results=[valid_result_data], metrics="invalid-metrics", # Should be ResultMetrics object @@ -395,7 +395,7 @@ def test_results_nested_model_validation(self, valid_result_data, valid_paginati # Test invalid pagination with pytest.raises(ValidationError): - Results( + ResultsResponse( evaluation_id="eval-123", results=[valid_result_data], metrics={ @@ -534,11 +534,13 @@ def test_models_with_mixed_model_types(self): "disabled": False, } - models = Models(models=[PublicModel(**model_data), CustomModel(**custom_model_data)]) # type: ignore[arg-type] + models = ModelsResponse( + data=ModelsResponse.Data(models=[PublicModel(**model_data), CustomModel(**custom_model_data)]) + ) - assert len(models.models) == 2 - assert isinstance(models.models[0], PublicModel) - assert isinstance(models.models[1], CustomModel) + assert len(models.data.models) == 2 + assert isinstance(models.data.models[0], PublicModel) + assert isinstance(models.data.models[1], CustomModel) class TestBenchmark: @@ -640,9 +642,9 @@ def test_benchmarks_with_datasets_alias(self): } # Using the alias 'datasets' - benchmarks = Benchmarks(datasets=[PublicBenchmark(**benchmark_data)]) # type: ignore[arg-type] + benchmarks = BenchmarksResponse(data=BenchmarksResponse.Data(datasets=[PublicBenchmark(**benchmark_data)])) - assert isinstance(benchmarks.benchmarks[0], PublicBenchmark) + assert isinstance(benchmarks.data.benchmarks[0], PublicBenchmark) def test_benchmarks_field_validation(self): """Benchmarks validates field structure correctly.""" @@ -657,9 +659,9 @@ def test_benchmarks_field_validation(self): "deprecated": False, } - benchmarks = Benchmarks(datasets=[benchmark_data]) # type: ignore[arg-type] + benchmarks = BenchmarksResponse(data=BenchmarksResponse.Data(datasets=[benchmark_data])) - assert len(benchmarks.benchmarks) == 1 + assert len(benchmarks.data.benchmarks) == 1 class TestModelSerialization: