diff --git a/docs/README.md b/docs/README.md index 51f04e1..ced013e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -36,8 +36,12 @@ evaluation = client.evaluations.create( # Get results if evaluation: - results = client.results.get(evaluation_id=evaluation.id) - print(f"Evaluation completed with {len(results)} results") + results_data = client.results.get(evaluation_id=evaluation.id) + if results_data: + print(f"Evaluation completed with {len(results_data.results)} results") + print(f"Total results available: {results_data.pagination.total_count}") + if results_data.pagination.total_pages > 1: + print(f"Results span {results_data.pagination.total_pages} pages - use pagination to access all") ``` ## Navigation diff --git a/docs/api-reference/results.md b/docs/api-reference/results.md index 54d6196..28dca26 100644 --- a/docs/api-reference/results.md +++ b/docs/api-reference/results.md @@ -8,34 +8,49 @@ Results contain detailed information about each test case in an evaluation, incl ## Methods -### `get(evaluation_id, timeout=None)` +### `get(evaluation_id, page=None, page_size=None, timeout=None)` -Retrieves detailed results for a specific evaluation. +Retrieves detailed results for a specific evaluation with optional pagination support. #### Parameters | Parameter | Type | Required | Description | |-----------|------|----------|-------------| | `evaluation_id` | `str` | Yes | The evaluation identifier to get results for | +| `page` | `int \| None` | No | Page number for pagination (1-based). If not provided, returns first page or all results based on API default | +| `page_size` | `int \| None` | No | Number of results per page (default: 100). Maximum allowed may be limited by API | | `timeout` | `float \| httpx.Timeout \| None` | No | Override request timeout | #### Returns -Returns a list of `Result` objects if successful, `None` if no results are found or the evaluation doesn't exist. 
+Returns a `ResultsData` object containing results, evaluation metadata, and pagination information if successful, or `None` if no results are found or the evaluation doesn't exist. -#### Example +The `ResultsData` object includes: +- `results`: List of `Result` objects for the current page +- `evaluation_id`: The evaluation ID +- `metrics`: Performance metrics including score ranges +- `pagination`: Pagination metadata (total_count, page_size, total_pages) +#### Examples + +##### Basic Usage (First Page) ```python from atlas import Atlas client = Atlas() -# Get results for a specific evaluation -results = client.results.get(evaluation_id="eval_12345") +# Get the first page of results for a specific evaluation +results_data = client.results.get(evaluation_id="eval_12345") -if results: - print(f"Retrieved {len(results)} results") - for i, result in enumerate(results[:3]): # Show first 3 +if results_data: + print(f"Evaluation ID: {results_data.evaluation_id}") + print(f"Retrieved {len(results_data.results)} results") + print(f"Total available: {results_data.pagination.total_count}") + print(f"Page size: {results_data.pagination.page_size}") + print(f"Total pages: {results_data.pagination.total_pages}") + + # Access individual results + for i, result in enumerate(results_data.results[:3]): # Show first 3 print(f"\nResult {i+1}:") print(f" Subset: {result.subset}") print(f" Score: {result.score}") @@ -44,16 +59,103 @@ else: print("No results found or evaluation doesn't exist") ``` +##### Paginated Access +```python +# Get specific page with custom page size +results_data = client.results.get( + evaluation_id="eval_12345", + page=2, + page_size=50 +) + +if results_data: + print(f"Page 2 of {results_data.pagination.total_pages}") + print(f"Showing {len(results_data.results)} of {results_data.pagination.total_count} total results") + + # Process current page + for result in results_data.results: + # Process each result + pass +``` + +##### Iterating Through All Pages +```python +# 
Process all results by iterating through pages +evaluation_id = "eval_12345" +page = 1 +page_size = 100 + +while True: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + print(f"Processing page {page}/{results_data.pagination.total_pages}") + + # Process current page results + for result in results_data.results: + # Your processing logic here + pass + + # Move to next page + if page >= results_data.pagination.total_pages: + break + page += 1 + +print("Finished processing all results") +``` + #### With Custom Timeout ```python -# Get results with custom timeout (2 minutes) -results = client.results.get( +# Get results with custom timeout (2 minutes) and pagination +results_data = client.results.get( evaluation_id="eval_12345", + page=1, + page_size=50, timeout=120.0 ) ``` +## Pagination Information + +The `pagination` object in the response provides detailed pagination metadata: + +```python +results_data = client.results.get(evaluation_id="eval_12345", page=1, page_size=50) + +if results_data: + pagination = results_data.pagination + + print(f"Current page info:") + print(f" Total results available: {pagination.total_count}") + print(f" Results per page: {pagination.page_size}") + print(f" Total pages: {pagination.total_pages}") + print(f" Results on current page: {len(results_data.results)}") + + # Calculate current page number (if needed) + # Page number isn't stored in pagination object, so track it yourself + current_page = 1 # You would track this in your code + print(f" Current page: {current_page}") + + # Check if there are more pages + has_more_pages = current_page < pagination.total_pages + print(f" Has more pages: {has_more_pages}") +``` + +### Pagination Properties + +| Property | Type | Description | +|----------|------|-------------| +| `total_count` | `int` | Total number of results available across all pages | +| `page_size` | 
`int` | Number of results per page (as requested or default) | +| `total_pages` | `int` | Total number of pages available | + ## Result Object Each `Result` object contains the following properties: @@ -92,16 +194,18 @@ def analyze_evaluation_results(evaluation_id: str): try: # Get results - results = client.results.get(evaluation_id=evaluation_id) + results_data = client.results.get(evaluation_id=evaluation_id) - if not results: - print(f"āŒ No results found for evaluation {evaluation_id}") + if not results_data: + print(f"No results found for evaluation {evaluation_id}") return - print(f"šŸ“Š Analysis for evaluation {evaluation_id}") - print(f"šŸ“ˆ Total test cases: {len(results)}") + results = results_data.results + print(f"Analysis for evaluation {evaluation_id}") + print(f"Total test cases: {results_data.pagination.total_count}") + print(f"Results on current page: {len(results)}") - # Calculate overall statistics + # Calculate overall statistics for current page total_score = sum(result.score for result in results) avg_score = total_score / len(results) correct_answers = sum(1 for result in results if result.score > 0.5) @@ -128,16 +232,16 @@ def analyze_evaluation_results(evaluation_id: str): subset_stats[result.subset]["scores"].append(result.score) subset_stats[result.subset]["count"] += 1 - print(f"\nšŸ“‹ Performance by Subset:") + print(f"\nPerformance by Subset:") for subset, stats in subset_stats.items(): subset_avg = sum(stats["scores"]) / len(stats["scores"]) subset_acc = sum(1 for s in stats["scores"] if s > 0.5) / len(stats["scores"]) print(f" {subset}: {subset_acc:.1%} accuracy ({subset_avg:.3f} avg score, {stats['count']} cases)") # Show some example results - print(f"\nšŸ” Sample Results:") + print(f"\nSample Results:") for i, result in enumerate(results[:3]): - status = "āœ… Correct" if result.score > 0.5 else "āŒ Incorrect" + status = "Correct" if result.score > 0.5 else "Incorrect" print(f"\n Example {i+1} [{result.subset}] - {status}") 
print(f" Prompt: {result.prompt[:100]}...") print(f" Model Answer: {result.result[:100]}...") @@ -150,13 +254,13 @@ def analyze_evaluation_results(evaluation_id: str): return results except atlas.NotFoundError: - print(f"āŒ Evaluation {evaluation_id} not found") + print(f"Evaluation {evaluation_id} not found") except atlas.AuthenticationError: - print("āŒ Authentication failed - check your API key") + print("Authentication failed - check your API key") except atlas.APIConnectionError as e: - print(f"āŒ Connection error: {e}") + print(f"Connection error: {e}") except atlas.APIError as e: - print(f"āŒ API error: {e}") + print(f"API error: {e}") return None @@ -168,31 +272,116 @@ if __name__ == "__main__": ## Working with Large Result Sets -For evaluations with many test cases, consider processing results in batches: +For evaluations with many test cases, use pagination to efficiently process results: ```python from atlas import Atlas -def process_results_efficiently(evaluation_id: str): +def process_results_efficiently(evaluation_id: str, page_size: int = 100): + """Process large result sets using pagination""" client = Atlas() - results = client.results.get(evaluation_id=evaluation_id) - if not results: + # Get first page to understand total scope + first_page = client.results.get(evaluation_id=evaluation_id, page=1, page_size=page_size) + if not first_page: + print("No results found") return - print(f"Processing {len(results)} results...") + total_count = first_page.pagination.total_count + total_pages = first_page.pagination.total_pages - # Process in chunks to avoid memory issues with very large result sets - chunk_size = 100 - for i in range(0, len(results), chunk_size): - chunk = results[i:i+chunk_size] + print(f"Processing {total_count} results across {total_pages} pages...") + + # Process each page + for page_num in range(1, total_pages + 1): + print(f"Processing page {page_num}/{total_pages}...") - print(f"Processing results {i+1}-{min(i+chunk_size, 
len(results))}...") + # Get current page (reuse first_page for page 1) + if page_num == 1: + results_data = first_page + else: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page_num, + page_size=page_size + ) - # Process this chunk - for result in chunk: + if not results_data: + print(f"Failed to get page {page_num}") + continue + + # Process current page + for result in results_data.results: # Your processing logic here pass + + print(f"Completed page {page_num} ({len(results_data.results)} results)") + + print(f"Finished processing all {total_count} results") + +# Usage +process_results_efficiently("eval_12345", page_size=50) +``` + +### Memory-Efficient Processing + +The pagination approach is more memory-efficient than loading all results at once: + +```python +# Good - Memory efficient with pagination +def analyze_large_evaluation(evaluation_id: str): + client = Atlas() + + # Aggregate statistics across pages + total_processed = 0 + total_score = 0 + total_correct = 0 + + page = 1 + page_size = 100 + + while True: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + # Process current page + page_score = sum(r.score for r in results_data.results) + page_correct = sum(1 for r in results_data.results if r.score > 0.5) + + total_score += page_score + total_correct += page_correct + total_processed += len(results_data.results) + + print(f"Page {page}: {len(results_data.results)} results, {page_correct} correct") + + # Check if we're done + if page >= results_data.pagination.total_pages: + break + + page += 1 + + # Final statistics + overall_accuracy = total_correct / total_processed if total_processed > 0 else 0 + overall_avg_score = total_score / total_processed if total_processed > 0 else 0 + + print(f"\nFinal Results:") + print(f" Total processed: {total_processed}") + print(f" Overall accuracy: 
{overall_accuracy:.1%}") + print(f" Overall average score: {overall_avg_score:.3f}") + +# Bad - Loads everything into memory at once (may cause issues with large datasets) +def analyze_evaluation_inefficient(evaluation_id: str): + results_data = client.results.get(evaluation_id=evaluation_id) # No pagination + # This could load thousands of results into memory + for result in results_data.results: + # Process all results at once + pass ``` ## Filtering and Analysis @@ -296,14 +485,14 @@ def safe_get_results(client, evaluation_id): Results can contain thousands of individual test cases. Consider: ```python -# āœ… Good - check result size first +# Good - check result size first results = client.results.get(evaluation_id="eval_12345") if results: print(f"Retrieved {len(results)} results") if len(results) > 1000: print("Large result set - consider processing in chunks") -# āŒ Bad - not considering memory usage +# Bad - not considering memory usage results = client.results.get(evaluation_id="eval_12345") # Process all results in memory without considering size ``` @@ -338,41 +527,41 @@ def get_cached_results(client, evaluation_id, cache_dir="cache"): ### 1. Always Check for Results ```python -# āœ… Good - check if results exist +# Good - check if results exist results = client.results.get(evaluation_id="eval_12345") if results: print(f"Found {len(results)} results") else: print("No results available") -# āŒ Bad - assume results exist +# Bad - assume results exist results = client.results.get(evaluation_id="eval_12345") print(f"Found {len(results)} results") # Could raise AttributeError ``` ### 2. 
Handle Large Result Sets Appropriately ```python -# ✅ Good - process in chunks for large sets +# Good - process in chunks for large sets if len(results) > 1000: for i in range(0, len(results), 100): chunk = results[i:i+100] process_chunk(chunk) -# ❌ Bad - process everything in memory +# Bad - process everything in memory for result in results: # Could be thousands of results expensive_processing(result) ``` ### 3. Use Meaningful Analysis ```python -# ✅ Good - extract meaningful insights +# Good - extract meaningful insights subset_performance = {} for result in results: if result.subset not in subset_performance: subset_performance[result.subset] = [] subset_performance[result.subset].append(result.score) -# ❌ Bad - just print raw data +# Bad - just print raw data for result in results: print(result.score) # Not very useful ``` diff --git a/docs/examples/advanced-usage.md b/docs/examples/advanced-usage.md index 6db48c5..95695f0 100644 --- a/docs/examples/advanced-usage.md +++ b/docs/examples/advanced-usage.md @@ -24,6 +24,120 @@ Required environment variables: - `LAYERLENS_ATLAS_ORG_ID` - Your organization ID - `LAYERLENS_ATLAS_PROJECT_ID` - Your project ID +## Pagination Best Practices + +### Understanding Pagination + +The Atlas SDK returns large result sets one page at a time. When an evaluation has more results than the page size (default: 100), you'll need to iterate through the pages to access all of the data. 
+ +```python +from atlas import Atlas + +def understand_pagination(evaluation_id: str): + """Understand pagination metadata""" + client = Atlas() + + # Get first page + results_data = client.results.get(evaluation_id=evaluation_id) + + if results_data: + pagination = results_data.pagination + + print(f" Pagination Overview:") + print(f" Total results: {pagination.total_count:,}") + print(f" Page size: {pagination.page_size}") + print(f" Total pages: {pagination.total_pages}") + print(f" Current page has: {len(results_data.results)} results") + + # Calculate some useful info + is_paginated = pagination.total_pages > 1 + results_per_page = pagination.page_size + last_page_size = pagination.total_count % pagination.page_size or pagination.page_size + + print(f"\n Analysis:") + print(f" Is paginated: {is_paginated}") + print(f" Results per page: {results_per_page}") + print(f" Last page size: {last_page_size}") + + if is_paginated: + print(f"\n To access all {pagination.total_count:,} results:") + print(f" - Iterate through {pagination.total_pages} pages") + print(f" - Or use batch processing patterns") + + return pagination + + return None + +# Usage +pagination_info = understand_pagination("eval_12345") +``` + +### Efficient Pagination Strategies + +```python +def efficient_pagination_strategies(): + """Demonstrate different pagination approaches""" + client = Atlas() + evaluation_id = "eval_12345" + + # Strategy 1: Small pages for real-time processing + print(" Strategy 1: Small pages for real-time feedback") + page_size = 25 + page = 1 + + while True: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + print(f" Processing page {page}: {len(results_data.results)} results") + + # Process immediately + for result in results_data.results: + # Real-time processing logic + pass + + if page >= results_data.pagination.total_pages: + break + page += 1 + + 
print("\n Strategy 2: Large pages for batch processing") + page_size = 200 # Larger pages + page = 1 + + while True: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + print(f" Batch processing page {page}: {len(results_data.results)} results") + + # Batch process entire page + process_batch(results_data.results) + + if page >= results_data.pagination.total_pages: + break + page += 1 + +def process_batch(results): + """Process a batch of results efficiently""" + # Batch processing logic here + pass + +# Usage +efficient_pagination_strategies() +``` + ## Batch Processing ### Running Multiple Evaluations @@ -55,7 +169,7 @@ def run_evaluation_batch(models, benchmarks): 'benchmark': benchmark, 'evaluation_id': evaluation.id }) - print(f"āœ… Created: {evaluation.id}") + print(f" Created: {evaluation.id}") else: results['failed'].append({ 'model': model, @@ -68,7 +182,7 @@ def run_evaluation_batch(models, benchmarks): time.sleep(60) except atlas.APIError as e: - print(f"āŒ Failed: {e}") + print(f" Failed: {e}") results['failed'].append({ 'model': model, 'benchmark': benchmark, @@ -84,8 +198,8 @@ models = ["gpt-4", "claude-3-opus"] benchmarks = ["mmlu", "hellaswag"] batch_results = run_evaluation_batch(models, benchmarks) -print(f"āœ… Successful: {len(batch_results['successful'])}") -print(f"āŒ Failed: {len(batch_results['failed'])}") +print(f" Successful: {len(batch_results['successful'])}") +print(f" Failed: {len(batch_results['failed'])}") ``` ## Error Handling Patterns @@ -109,7 +223,7 @@ def create_evaluation_with_retries(model, benchmark, max_retries=3): ) if evaluation: - print(f"āœ… Success on attempt {attempt + 1}") + print(f" Success on attempt {attempt + 1}") return evaluation except atlas.RateLimitError as e: @@ -124,15 +238,15 @@ def create_evaluation_with_retries(model, benchmark, max_retries=3): raise except atlas.NotFoundError: - 
print(f"āŒ Model '{model}' or benchmark '{benchmark}' not found") + print(f" Model '{model}' or benchmark '{benchmark}' not found") return None except atlas.AuthenticationError: - print("āŒ Authentication failed - check your API key") + print(" Authentication failed - check your API key") raise except atlas.APIError as e: - print(f"āŒ API error on attempt {attempt + 1}: {e}") + print(f" API error on attempt {attempt + 1}: {e}") if attempt < max_retries - 1: time.sleep(2 ** attempt) # Exponential backoff else: @@ -205,7 +319,7 @@ def analyze_evaluation_results(evaluation_id: str) -> Dict: # Usage analysis = analyze_evaluation_results("eval_123") if "error" not in analysis: - print(f"šŸ“Š Analysis Results:") + print(f" Analysis Results:") print(f" Total results: {analysis['total_results']}") print(f" Overall accuracy: {analysis['overall_accuracy']:.2%}") print(f" Average duration: {analysis['avg_duration']:.2f}s") @@ -357,9 +471,9 @@ def check_atlas_health(): # Usage health = check_atlas_health() if health["status"] == "healthy": - print("āœ… Atlas service is healthy") + print(" Atlas service is healthy") else: - print(f"āŒ Atlas service is unhealthy: {health['error']}") + print(f" Atlas service is unhealthy: {health['error']}") ``` ## Integration Patterns diff --git a/docs/examples/retrieving-results.md b/docs/examples/retrieving-results.md index 9b4261a..19c6b68 100644 --- a/docs/examples/retrieving-results.md +++ b/docs/examples/retrieving-results.md @@ -14,13 +14,16 @@ client = Atlas() # Get results for a specific evaluation evaluation_id = "eval_12345" # Replace with your evaluation ID -results = client.results.get(evaluation_id=evaluation_id) +results_data = client.results.get(evaluation_id=evaluation_id) -if results: - print(f"šŸ“Š Retrieved {len(results)} results") +if results_data: + print(f"Evaluation: {results_data.evaluation_id}") + print(f"Retrieved {len(results_data.results)} results (page 1)") + print(f"Total available: 
{results_data.pagination.total_count}") + print(f"Total pages: {results_data.pagination.total_pages}") # Show first few results - for i, result in enumerate(results[:3]): + for i, result in enumerate(results_data.results[:3]): print(f"\nResult {i+1}:") print(f" Subset: {result.subset}") print(f" Prompt: {result.prompt[:100]}...") @@ -29,7 +32,42 @@ if results: print(f" Score: {result.score}") print(f" Duration: {result.duration}") else: - print("āŒ No results found") + print("No results found") +``` + +### Paginated Result Retrieval + +```python +from atlas import Atlas + +# Initialize client +client = Atlas() + +def get_paginated_results(evaluation_id: str, page_size: int = 50): + """Get results with pagination control""" + + # Get specific page + results_data = client.results.get( + evaluation_id=evaluation_id, + page=2, # Get second page + page_size=page_size + ) + + if results_data: + pagination = results_data.pagination + print(f"Pagination Info:") + print(f" Total results: {pagination.total_count}") + print(f" Page size: {pagination.page_size}") + print(f" Total pages: {pagination.total_pages}") + print(f" Current page results: {len(results_data.results)}") + + return results_data + else: + print("No results found") + return None + +# Usage +paginated_results = get_paginated_results("eval_12345", page_size=25) ``` ### Complete Evaluation Workflow @@ -43,19 +81,19 @@ def complete_evaluation_workflow(model: str, benchmark: str): client = Atlas() # Step 1: Create evaluation - print(f"šŸ”„ Creating evaluation: {model} + {benchmark}") + print(f"Creating evaluation: {model} + {benchmark}") evaluation = client.evaluations.create(model=model, benchmark=benchmark) if not evaluation: - print("āŒ Failed to create evaluation") + print("Failed to create evaluation") return None - print(f"āœ… Evaluation created: {evaluation.id}") + print(f"Evaluation created: {evaluation.id}") print(f" Status: {evaluation.status}") # Step 2: Wait for completion (simplified polling) # In 
production, use webhooks instead of polling - print("ā³ Waiting for evaluation to complete...") + print("Waiting for evaluation to complete...") # Note: This is a simplified example. In practice, you'd: # 1. Use webhooks for real-time updates @@ -63,28 +101,35 @@ def complete_evaluation_workflow(model: str, benchmark: str): # 3. Handle various status states properly if evaluation.status == "completed": - print("šŸŽ‰ Evaluation completed!") + print("Evaluation completed!") # Step 3: Retrieve results - results = client.results.get(evaluation_id=evaluation.id) + results_data = client.results.get(evaluation_id=evaluation.id) - if results: - print(f"šŸ“Š Retrieved {len(results)} detailed results") + if results_data: + results = results_data.results + print(f"Retrieved {len(results)} results from page 1") + print(f"Total results available: {results_data.pagination.total_count}") - # Basic analysis + # Basic analysis for current page correct_answers = sum(1 for r in results if r.score > 0.5) accuracy = correct_answers / len(results) avg_duration = sum(r.duration for r in results) / len(results) - print(f"šŸ“ˆ Quick Analysis:") + print(f"Quick Analysis (Page 1):") print(f" Accuracy: {accuracy:.1%} ({correct_answers}/{len(results)})") print(f" Average Duration: {avg_duration}") - return results + # Note about pagination + if results_data.pagination.total_pages > 1: + print(f"Note: This evaluation has {results_data.pagination.total_pages} pages total") + print(f" Use pagination to process all {results_data.pagination.total_count} results") + + return results_data else: - print("āŒ No results available") + print("No results available") else: - print(f"ā° Evaluation status: {evaluation.status}") + print(f"Evaluation status: {evaluation.status}") print(" Check back later for results") return None @@ -103,16 +148,58 @@ from collections import defaultdict, Counter import statistics from datetime import timedelta -def analyze_evaluation_performance(evaluation_id: str): +def 
analyze_evaluation_performance(evaluation_id: str, use_all_pages: bool = True): """Comprehensive performance analysis of evaluation results""" client = Atlas() - results = client.results.get(evaluation_id=evaluation_id) - if not results: - print(f"āŒ No results found for evaluation {evaluation_id}") - return None + if use_all_pages: + # Get all results across all pages for complete analysis + all_results = [] + page = 1 + page_size = 100 + + while True: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + all_results.extend(results_data.results) + + # Use pagination info from the first page + if page == 1: + total_count = results_data.pagination.total_count + total_pages = results_data.pagination.total_pages + print(f"Loading {total_count} results from {total_pages} pages...") + + print(f" Loaded page {page}/{total_pages}") + + if page >= results_data.pagination.total_pages: + break + + page += 1 + + results = all_results + + if not results: + print(f"No results found for evaluation {evaluation_id}") + return None + + else: + # Analyze just the first page + results_data = client.results.get(evaluation_id=evaluation_id, page=1, page_size=100) + if not results_data: + print(f"No results found for evaluation {evaluation_id}") + return None + + results = results_data.results + print(f"Analyzing first page only ({len(results)} of {results_data.pagination.total_count} total results)") - print(f"šŸ“Š Performance Analysis for {evaluation_id}") + print(f"Performance Analysis for {evaluation_id}") print(f"{'='*60}") # Overall statistics @@ -164,7 +251,7 @@ def analyze_evaluation_performance(evaluation_id: str): else: score_ranges["Zero (0.0)"] += 1 - print(f"\nšŸ“ˆ Score Distribution:") + print(f"\nScore Distribution:") for range_name, count in score_ranges.items(): percentage = count / total_cases * 100 print(f" {range_name}: {count:,} 
({percentage:.1f}%)") @@ -176,7 +263,7 @@ def analyze_evaluation_performance(evaluation_id: str): subset_stats[result.subset]["scores"].append(result.score) subset_stats[result.subset]["durations"].append(result.duration) - print(f"\nšŸ“‹ Performance by Subset:") + print(f"\nPerformance by Subset:") print(f"{'Subset':<25} {'Cases':<8} {'Accuracy':<10} {'Avg Score':<10} {'Avg Duration':<12}") print("-" * 75) @@ -197,8 +284,213 @@ def analyze_evaluation_performance(evaluation_id: str): "subset_stats": dict(subset_stats) } +# Usage - analyze all results across all pages +analysis = analyze_evaluation_performance("eval_12345", use_all_pages=True) + +# Usage - analyze only first page (faster for quick checks) +quick_analysis = analyze_evaluation_performance("eval_12345", use_all_pages=False) +``` + +## Pagination Patterns + +### Pattern 1: Processing All Results Across Pages + +```python +from atlas import Atlas + +def process_all_results(evaluation_id: str): + """Process all results by iterating through all pages""" + client = Atlas() + + # Aggregate statistics across all pages + total_results = 0 + total_score = 0 + total_correct = 0 + all_subsets = set() + + page = 1 + page_size = 100 + + print("Processing all pages...") + + while True: + print(f"Fetching page {page}...") + + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + # Show progress on first page + if page == 1: + print(f"Total: {results_data.pagination.total_count} results across {results_data.pagination.total_pages} pages") + + # Process current page + current_results = results_data.results + page_score = sum(r.score for r in current_results) + page_correct = sum(1 for r in current_results if r.score > 0.5) + page_subsets = set(r.subset for r in current_results) + + # Aggregate + total_results += len(current_results) + total_score += page_score + total_correct += page_correct + 
all_subsets.update(page_subsets) + + print(f" Page {page}: {len(current_results)} results, {page_correct} correct, {len(page_subsets)} subsets") + + # Check if we're done + if page >= results_data.pagination.total_pages: + break + + page += 1 + + # Final summary + if total_results > 0: + overall_accuracy = total_correct / total_results + overall_avg_score = total_score / total_results + + print(f"\n Final Statistics:") + print(f" Total results processed: {total_results:,}") + print(f" Overall accuracy: {overall_accuracy:.1%}") + print(f" Overall average score: {overall_avg_score:.3f}") + print(f" Unique subsets: {len(all_subsets)}") + print(f" Subsets: {', '.join(sorted(all_subsets))}") + + return { + "total_results": total_results, + "accuracy": overall_accuracy if total_results > 0 else 0, + "avg_score": overall_avg_score if total_results > 0 else 0, + "subsets": list(all_subsets) + } + +# Usage +stats = process_all_results("eval_12345") +``` + +### Pattern 2: Selective Page Processing + +```python +def process_specific_pages(evaluation_id: str, start_page: int = 1, end_page: int = None): + """Process only specific pages of results""" + client = Atlas() + + # Get first page to understand scope + first_page = client.results.get(evaluation_id=evaluation_id, page=1, page_size=100) + if not first_page: + print(" No results found") + return None + + total_pages = first_page.pagination.total_pages + total_count = first_page.pagination.total_count + + # Set end page if not specified + if end_page is None: + end_page = total_pages + + # Validate range + end_page = min(end_page, total_pages) + start_page = max(start_page, 1) + + print(f" Processing pages {start_page}-{end_page} of {total_pages} (total: {total_count} results)") + + processed_results = [] + + for page_num in range(start_page, end_page + 1): + # Reuse first page if processing from page 1 + if page_num == 1 and start_page == 1: + results_data = first_page + else: + results_data = client.results.get( + 
evaluation_id=evaluation_id, + page=page_num, + page_size=100 + ) + + if not results_data: + print(f" Failed to get page {page_num}") + continue + + processed_results.extend(results_data.results) + print(f" Processed page {page_num}: {len(results_data.results)} results") + + print(f" Processed {len(processed_results)} results from pages {start_page}-{end_page}") + return processed_results + +# Usage examples +first_100_results = process_specific_pages("eval_12345", start_page=1, end_page=1) +middle_pages = process_specific_pages("eval_12345", start_page=5, end_page=10) +last_few_pages = process_specific_pages("eval_12345", start_page=18, end_page=20) +``` + +### Pattern 3: Smart Pagination with Early Stopping + +```python +def analyze_with_early_stopping(evaluation_id: str, min_accuracy_threshold: float = 0.7): + """Stop processing if accuracy drops below threshold""" + client = Atlas() + + page = 1 + page_size = 100 + total_processed = 0 + total_correct = 0 + + print(f"🎯 Processing until accuracy drops below {min_accuracy_threshold:.1%}") + + while True: + results_data = client.results.get( + evaluation_id=evaluation_id, + page=page, + page_size=page_size + ) + + if not results_data or not results_data.results: + break + + # Process current page + current_results = results_data.results + page_correct = sum(1 for r in current_results if r.score > 0.5) + + total_processed += len(current_results) + total_correct += page_correct + + current_accuracy = total_correct / total_processed + page_accuracy = page_correct / len(current_results) + + print(f" Page {page}: {page_accuracy:.1%} accuracy ({page_correct}/{len(current_results)})") + print(f" Running total: {current_accuracy:.1%} accuracy ({total_correct}/{total_processed})") + + # Check early stopping condition + if current_accuracy < min_accuracy_threshold and page > 1: + print(f" Stopping early: accuracy ({current_accuracy:.1%}) below threshold ({min_accuracy_threshold:.1%})") + break + + # Check if we've 
processed all pages + if page >= results_data.pagination.total_pages: + print(f" Processed all {results_data.pagination.total_pages} pages") + break + + page += 1 + + final_accuracy = total_correct / total_processed if total_processed > 0 else 0 + print(f"\n Final Results:") + print(f" Pages processed: {page}/{results_data.pagination.total_pages if 'results_data' in locals() else '?'}") + print(f" Results processed: {total_processed}") + print(f" Final accuracy: {final_accuracy:.1%}") + + return { + "pages_processed": page, + "results_processed": total_processed, + "accuracy": final_accuracy, + "stopped_early": page < (results_data.pagination.total_pages if 'results_data' in locals() else 1) + } + # Usage -analysis = analyze_evaluation_performance("eval_12345") +early_stop_results = analyze_with_early_stopping("eval_12345", min_accuracy_threshold=0.8) ``` ### Comparative Analysis @@ -216,7 +508,7 @@ def compare_evaluation_results(evaluation_ids: List[str], labels: List[str] = No elif not labels: labels = [f"Eval {i+1}" for i in range(len(evaluation_ids))] - print(f"šŸ“Š Comparing {len(evaluation_ids)} evaluations") + print(f" Comparing {len(evaluation_ids)} evaluations") print(f"{'='*80}") # Collect results for all evaluations @@ -225,15 +517,15 @@ def compare_evaluation_results(evaluation_ids: List[str], labels: List[str] = No results = client.results.get(evaluation_id=eval_id) if results: all_results[label] = results - print(f"āœ… Loaded {len(results)} results for {label}") + print(f" Loaded {len(results)} results for {label}") else: - print(f"āŒ No results found for {label} ({eval_id})") + print(f" No results found for {label} ({eval_id})") if not all_results: - print("āŒ No results to compare") + print(" No results to compare") return - print(f"\nšŸ“ˆ Comparative Analysis:") + print(f"\n Comparative Analysis:") print(f"{'Metric':<20} " + " ".join(f"{label:<15}" for label in labels)) print("-" * (20 + 15 * len(labels))) @@ -266,7 +558,7 @@ def 
compare_evaluation_results(evaluation_ids: List[str], labels: List[str] = No best_accuracy_label = next(label for label, data in metrics.items() if data == best_accuracy) best_speed_label = next(label for label, data in metrics.items() if data == best_speed) - print(f"\nšŸ† Winners:") + print(f"\n Winners:") print(f" Best Accuracy: {best_accuracy_label} ({best_accuracy['accuracy']:.1%})") print(f" Fastest: {best_speed_label} ({best_speed['avg_duration']})") @@ -280,7 +572,7 @@ def compare_evaluation_results(evaluation_ids: List[str], labels: List[str] = No common_subsets = common_subsets.intersection(result_subsets) if common_subsets: - print(f"\nšŸ“‹ Subset Comparison ({len(common_subsets)} common subsets):") + print(f"\n Subset Comparison ({len(common_subsets)} common subsets):") print(f"{'Subset':<25} " + " ".join(f"{label} Acc":<12 for label in labels)) print("-" * (25 + 12 * len(labels))) @@ -316,21 +608,21 @@ def analyze_failures(evaluation_id: str, error_threshold: float = 0.3): results = client.results.get(evaluation_id=evaluation_id) if not results: - print(f"āŒ No results found for evaluation {evaluation_id}") + print(f" No results found for evaluation {evaluation_id}") return None # Find poor-performing cases poor_results = [r for r in results if r.score < error_threshold] good_results = [r for r in results if r.score >= error_threshold] - print(f"šŸ” Error Analysis for {evaluation_id}") + print(f" Error Analysis for {evaluation_id}") print(f"{'='*60}") print(f"Total cases: {len(results)}") print(f"Poor performance (< {error_threshold}): {len(poor_results)} ({len(poor_results)/len(results):.1%})") print(f"Good performance (>= {error_threshold}): {len(good_results)} ({len(good_results)/len(results):.1%})") if not poor_results: - print("šŸŽ‰ No poor-performing cases found!") + print(" No poor-performing cases found!") return {"poor_results": [], "analysis": "No errors to analyze"} # Analyze failure patterns by subset @@ -340,7 +632,7 @@ def 
analyze_failures(evaluation_id: str, error_threshold: float = 0.3): failure_by_subset[result.subset] = [] failure_by_subset[result.subset].append(result) - print(f"\nāŒ Failure Distribution by Subset:") + print(f"\n Failure Distribution by Subset:") for subset, failures in sorted(failure_by_subset.items(), key=lambda x: len(x[1]), reverse=True): total_in_subset = len([r for r in results if r.subset == subset]) failure_rate = len(failures) / total_in_subset @@ -349,7 +641,7 @@ def analyze_failures(evaluation_id: str, error_threshold: float = 0.3): # Show worst-performing examples worst_results = sorted(poor_results, key=lambda x: x.score)[:5] - print(f"\nšŸ” Worst Performing Examples:") + print(f"\n Worst Performing Examples:") for i, result in enumerate(worst_results, 1): print(f"\n Example {i} [Score: {result.score:.3f}]") print(f" Subset: {result.subset}") @@ -362,7 +654,7 @@ def analyze_failures(evaluation_id: str, error_threshold: float = 0.3): print(f" Additional Metrics: {result.metrics}") # Common failure patterns - print(f"\nšŸ” Common Patterns in Failures:") + print(f"\n Common Patterns in Failures:") # Analyze prompt lengths poor_prompt_lengths = [len(r.prompt) for r in poor_results] @@ -424,11 +716,11 @@ def process_results_in_batches(evaluation_id: str, batch_size: int = 100, proces results = client.results.get(evaluation_id=evaluation_id) if not results: - print(f"āŒ No results found for evaluation {evaluation_id}") + print(f" No results found for evaluation {evaluation_id}") return None total_results = len(results) - print(f"šŸ“Š Processing {total_results:,} results in batches of {batch_size}") + print(f" Processing {total_results:,} results in batches of {batch_size}") if not processor_func: # Default processor: just count scores @@ -446,7 +738,7 @@ def process_results_in_batches(evaluation_id: str, batch_size: int = 100, proces batch_num = i // batch_size + 1 total_batches = (total_results + batch_size - 1) // batch_size - print(f"šŸ”„ 
Processing batch {batch_num}/{total_batches} ({len(batch)} items)") + print(f" Processing batch {batch_num}/{total_batches} ({len(batch)} items)") start_time = time.time() batch_result = processor_func(batch) @@ -460,7 +752,7 @@ def process_results_in_batches(evaluation_id: str, batch_size: int = 100, proces batch_results.append(batch_result) - print(f" āœ… Completed in {batch_result['processing_time']:.2f}s") + print(f" Completed in {batch_result['processing_time']:.2f}s") # Small delay to prevent overwhelming the system if batch_num < total_batches: @@ -471,7 +763,7 @@ def process_results_in_batches(evaluation_id: str, batch_size: int = 100, proces total_correct = sum(br.get("correct", 0) for br in batch_results) overall_accuracy = total_correct / total_results - print(f"\nšŸ“ˆ Batch Processing Summary:") + print(f"\n Batch Processing Summary:") print(f" Total batches: {len(batch_results)}") print(f" Total processing time: {total_processing_time:.2f}s") print(f" Average time per batch: {total_processing_time/len(batch_results):.2f}s") @@ -588,7 +880,7 @@ class ResultsCache: print(f"šŸ’¾ Cached {len(results)} results for {evaluation_id}") except Exception as e: - print(f"āŒ Error caching results: {e}") + print(f" Error caching results: {e}") def load_results(self, evaluation_id: str, format: str = "pickle"): """Load results from cache""" @@ -606,7 +898,7 @@ class ResultsCache: return results except Exception as e: - print(f"āŒ Error loading cached results: {e}") + print(f" Error loading cached results: {e}") return None def get_metadata(self, evaluation_id: str): @@ -616,7 +908,7 @@ class ResultsCache: with open(metadata_path, 'r') as f: return json.load(f) except Exception as e: - print(f"āŒ Error loading metadata: {e}") + print(f" Error loading metadata: {e}") return None def get_results_with_cache(evaluation_id: str, cache: ResultsCache = None, force_refresh: bool = False): @@ -648,15 +940,15 @@ def get_results_with_cache(evaluation_id: str, cache: 
ResultsCache = None, force cache.save_results(evaluation_id, results) return results else: - print(f"āŒ No results found for evaluation {evaluation_id}") + print(f" No results found for evaluation {evaluation_id}") return None except atlas.APIError as e: - print(f"āŒ Error fetching results: {e}") + print(f" Error fetching results: {e}") # Try to return cached results as fallback if cache.is_cached(evaluation_id): - print(f"šŸ”„ Falling back to cached results...") + print(f" Falling back to cached results...") return cache.load_results(evaluation_id) return None @@ -679,7 +971,7 @@ evaluation_ids = ["eval_001", "eval_002", "eval_003"] for eval_id in evaluation_ids: results = get_results_with_cache(eval_id, cache) if results: - print(f"āœ… {eval_id}: {len(results)} results cached") + print(f" {eval_id}: {len(results)} results cached") print(f"\nšŸ“ Cache contents:") for cache_file in cache.cache_dir.glob("*.json"): @@ -707,7 +999,7 @@ def export_results_to_csv(evaluation_id: str, output_path: str = None): results = client.results.get(evaluation_id=evaluation_id) if not results: - print(f"āŒ No results found for evaluation {evaluation_id}") + print(f" No results found for evaluation {evaluation_id}") return None if not output_path: @@ -748,11 +1040,11 @@ def export_results_to_csv(evaluation_id: str, output_path: str = None): writer.writerow(row) - print(f"šŸ“„ Exported {len(results)} results to {output_path}") + print(f" Exported {len(results)} results to {output_path}") return output_path except Exception as e: - print(f"āŒ Error exporting to CSV: {e}") + print(f" Error exporting to CSV: {e}") return None def generate_summary_report(evaluation_ids: list, output_path: str = None): @@ -776,7 +1068,7 @@ def generate_summary_report(evaluation_ids: list, output_path: str = None): results = client.results.get(evaluation_id=eval_id) if not results: - f.write("āŒ No results found\n\n") + f.write(" No results found\n\n") continue # Calculate statistics @@ -810,7 +1102,7 
@@ def generate_summary_report(evaluation_ids: list, output_path: str = None): f.write("END OF REPORT\n") - print(f"šŸ“Š Summary report generated: {output_path}") + print(f" Summary report generated: {output_path}") return output_path # Usage examples diff --git a/scripts/test b/scripts/test old mode 100755 new mode 100644 diff --git a/scripts/test_coverage b/scripts/test_coverage new file mode 100755 index 0000000..c22dd6e --- /dev/null +++ b/scripts/test_coverage @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +echo "==> Running tests" +rye run pytest tests/ --cov=src/atlas --cov-report=term --tb=no \ No newline at end of file diff --git a/src/atlas/_models.py b/src/atlas/_models.py index 0d39f5e..ac87d2b 100644 --- a/src/atlas/_models.py +++ b/src/atlas/_models.py @@ -36,11 +36,24 @@ class Result(BaseModel): truth: str duration: timedelta score: float - metrics: Dict[str, float] + metrics: Dict[str, Optional[float]] + + +class ResultMetrics(BaseModel): + total_count: int + + +class Pagination(BaseModel): + total_count: int + page_size: int + total_pages: int class Results(BaseModel): + evaluation_id: str results: List[Result] + metrics: ResultMetrics + pagination: Pagination class Model(BaseModel): @@ -106,5 +119,5 @@ class CustomBenchmark(BaseModel): class Benchmarks(BaseModel): model_config = ConfigDict(populate_by_name=True) - + benchmarks: List[Union[Benchmark, CustomBenchmark]] = Field(..., alias="datasets") diff --git a/src/atlas/resources/results/results.py b/src/atlas/resources/results/results.py index a7b1f03..82f84a6 100644 --- a/src/atlas/resources/results/results.py +++ b/src/atlas/resources/results/results.py @@ -1,29 +1,78 @@ from __future__ import annotations -from typing import List +import math +from typing import Optional import httpx -from ..._models import Result, Results as ResultsData +from ..._models import Results as ResultsData from ..._resource import SyncAPIResource from ..._constants import DEFAULT_TIMEOUT 
+DEFAULT_PAGE_SIZE = 100 + class Results(SyncAPIResource): def get( self, *, evaluation_id: str, + page: Optional[int] = None, + page_size: Optional[int] = None, timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, - ) -> List[Result] | None: - results = self._get( + ) -> ResultsData | None: + """ + Get evaluation results with optional pagination. + + Args: + evaluation_id: The ID of the evaluation to get results for + page: Page number for pagination (1-based, defaults to 1 if not provided) + page_size: Number of results per page (default: 100, optional) + timeout: Request timeout + + Returns: + ResultsData object containing: + - evaluation_id: The evaluation ID + - results: List of Result objects for the current page + - metrics: Contains total_count and score ranges + - pagination: Calculated pagination info (total_count, page_size, total_pages) + or None if the request fails + """ + params = {"evaluation_id": evaluation_id} + + # Set default page_size if not provided + effective_page_size = page_size if page_size is not None else DEFAULT_PAGE_SIZE + + # Set default page to 1 if not provided + effective_page = page if page is not None else 1 + + params["page"] = str(effective_page) + if page_size is not None: + params["pageSize"] = str(page_size) + + # Get the response with cast_to to get parsed data + response_data = self._get( f"/results", - params={ - "evaluation_id": evaluation_id, - }, + params=params, timeout=timeout, - cast_to=ResultsData, + cast_to=dict, ) - if isinstance(results, ResultsData): - return results.results - return None + + if not response_data or not isinstance(response_data, dict): + return None + + # Calculate pagination info + metrics = response_data.get("metrics", {}) + total_count = metrics.get("total_count", 0) + total_pages = math.ceil(total_count / effective_page_size) if total_count > 0 and effective_page_size > 0 else 0 + + # Add pagination to the response + response_with_pagination = { + **response_data, + "pagination": 
{"total_count": total_count, "page_size": effective_page_size, "total_pages": total_pages}, + } + + try: + return ResultsData.model_validate(response_with_pagination) + except Exception: + return None diff --git a/tests/conftest.py b/tests/conftest.py index 9de38db..54d61a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,14 @@ def env_vars(): """Clean environment variables for testing.""" env_keys = ["LAYERLENS_ATLAS_API_KEY", "LAYERLENS_ATLAS_ORG_ID", "LAYERLENS_ATLAS_PROJECT_ID"] original_values = {key: os.environ.get(key) for key in env_keys} - + # Clear environment variables for key in env_keys: if key in os.environ: del os.environ[key] - + yield - + # Restore original values for key, value in original_values.items(): if value is not None: @@ -28,9 +28,12 @@ def env_vars(): @pytest.fixture def mock_env_vars(): """Mock environment variables with test values.""" - with mock.patch.dict(os.environ, { - "LAYERLENS_ATLAS_API_KEY": "test-api-key", - "LAYERLENS_ATLAS_ORG_ID": "test-org-id", - "LAYERLENS_ATLAS_PROJECT_ID": "test-project-id" - }): - yield \ No newline at end of file + with mock.patch.dict( + os.environ, + { + "LAYERLENS_ATLAS_API_KEY": "test-api-key", + "LAYERLENS_ATLAS_ORG_ID": "test-org-id", + "LAYERLENS_ATLAS_PROJECT_ID": "test-project-id", + }, + ): + yield diff --git a/tests/resources/test_benchmarks.py b/tests/resources/test_benchmarks.py index fe589cf..2ffbdf5 100644 --- a/tests/resources/test_benchmarks.py +++ b/tests/resources/test_benchmarks.py @@ -76,16 +76,16 @@ def mock_custom_benchmarks_response(self, sample_custom_benchmark_data): def test_benchmarks_initialization(self, mock_client): """Benchmarks resource initializes correctly.""" benchmarks = Benchmarks(mock_client) - + assert benchmarks._client is mock_client assert benchmarks._get is mock_client.get_cast def test_get_public_benchmarks_success(self, benchmarks_resource, mock_public_benchmarks_response): """get method returns public benchmarks successfully.""" 
benchmarks_resource._get.return_value = mock_public_benchmarks_response - + result = benchmarks_resource.get(type="public") - + assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], Benchmark) @@ -95,9 +95,9 @@ def test_get_public_benchmarks_success(self, benchmarks_resource, mock_public_be def test_get_custom_benchmarks_success(self, benchmarks_resource, mock_custom_benchmarks_response): """get method returns custom benchmarks successfully.""" benchmarks_resource._get.return_value = mock_custom_benchmarks_response - + result = benchmarks_resource.get(type="custom") - + assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], CustomBenchmark) @@ -107,9 +107,9 @@ def test_get_custom_benchmarks_success(self, benchmarks_resource, mock_custom_be def test_get_benchmarks_request_parameters_public(self, benchmarks_resource, mock_public_benchmarks_response): """get method makes correct API request for public benchmarks.""" benchmarks_resource._get.return_value = mock_public_benchmarks_response - + benchmarks_resource.get(type="public") - + benchmarks_resource._get.assert_called_once_with( "/organizations/org-123/projects/proj-456/benchmarks", params={"type": "public"}, @@ -120,9 +120,9 @@ def test_get_benchmarks_request_parameters_public(self, benchmarks_resource, moc def test_get_benchmarks_request_parameters_custom(self, benchmarks_resource, mock_custom_benchmarks_response): """get method makes correct API request for custom benchmarks.""" benchmarks_resource._get.return_value = mock_custom_benchmarks_response - + benchmarks_resource.get(type="custom") - + benchmarks_resource._get.assert_called_once_with( "/organizations/org-123/projects/proj-456/benchmarks", params={"type": "custom"}, @@ -134,9 +134,9 @@ def test_get_benchmarks_with_custom_timeout(self, benchmarks_resource, mock_publ """get method accepts custom timeout.""" benchmarks_resource._get.return_value = mock_public_benchmarks_response custom_timeout = 
45.0 - + benchmarks_resource.get(type="public", timeout=custom_timeout) - + call_args = benchmarks_resource._get.call_args assert call_args.kwargs["timeout"] == custom_timeout @@ -144,55 +144,57 @@ def test_get_benchmarks_with_httpx_timeout(self, benchmarks_resource, mock_publi """get method accepts httpx.Timeout object.""" benchmarks_resource._get.return_value = mock_public_benchmarks_response custom_timeout = httpx.Timeout(45.0) - + benchmarks_resource.get(type="public", timeout=custom_timeout) - + call_args = benchmarks_resource._get.call_args assert call_args.kwargs["timeout"] is custom_timeout def test_get_benchmarks_none_response(self, benchmarks_resource): """get method returns None when response is None.""" benchmarks_resource._get.return_value = None - + result = benchmarks_resource.get(type="public") - + assert result is None def test_get_benchmarks_invalid_response_type(self, benchmarks_resource): """get method handles non-BenchmarksData response gracefully.""" benchmarks_resource._get.return_value = "invalid-response" - + result = benchmarks_resource.get(type="public") - + assert result is None def test_get_benchmarks_empty_response(self, benchmarks_resource): """get method returns empty list when no benchmarks in response.""" empty_response = BenchmarksData(datasets=[]) benchmarks_resource._get.return_value = empty_response - + result = benchmarks_resource.get(type="public") - + assert result == [] assert isinstance(result, list) - def test_get_benchmarks_multiple_items(self, benchmarks_resource, sample_benchmark_data, sample_custom_benchmark_data): + def test_get_benchmarks_multiple_items( + self, benchmarks_resource, sample_benchmark_data, sample_custom_benchmark_data + ): """get method returns multiple benchmarks correctly.""" _ = sample_custom_benchmark_data # Fixture used for side effects benchmark = Benchmark(**sample_benchmark_data) - + # Create second benchmark with different data benchmark2_data = sample_benchmark_data.copy() 
benchmark2_data["id"] = "benchmark-456" benchmark2_data["key"] = "hellaswag" benchmark2_data["name"] = "HellaSwag" benchmark2 = Benchmark(**benchmark2_data) - + response = BenchmarksData(datasets=[benchmark, benchmark2]) benchmarks_resource._get.return_value = response - + result = benchmarks_resource.get(type="public") - + assert len(result) == 2 assert result[0].key == "mmlu" assert result[1].key == "hellaswag" @@ -202,9 +204,9 @@ def test_get_benchmarks_url_construction(self, benchmarks_resource, mock_public_ benchmarks_resource._client.organization_id = "custom-org" benchmarks_resource._client.project_id = "custom-project" benchmarks_resource._get.return_value = mock_public_benchmarks_response - + benchmarks_resource.get(type="public") - + expected_url = "/organizations/custom-org/projects/custom-project/benchmarks" call_args = benchmarks_resource._get.call_args assert call_args[0][0] == expected_url @@ -213,36 +215,36 @@ def test_get_benchmarks_url_construction(self, benchmarks_resource, mock_public_ def test_get_benchmarks_type_parameter(self, benchmarks_resource, benchmark_type): """get method accepts both public and custom types.""" benchmarks_resource._get.return_value = BenchmarksData(datasets=[]) - + benchmarks_resource.get(type=benchmark_type) - + call_args = benchmarks_resource._get.call_args assert call_args.kwargs["params"]["type"] == benchmark_type def test_get_benchmarks_cast_to_parameter(self, benchmarks_resource, mock_public_benchmarks_response): """get method specifies correct cast_to parameter.""" benchmarks_resource._get.return_value = mock_public_benchmarks_response - + benchmarks_resource.get(type="public") - + call_args = benchmarks_resource._get.call_args assert call_args.kwargs["cast_to"] is BenchmarksData def test_get_benchmarks_timeout_default(self, benchmarks_resource, mock_public_benchmarks_response): """get method uses DEFAULT_TIMEOUT when no timeout specified.""" benchmarks_resource._get.return_value = 
mock_public_benchmarks_response - + benchmarks_resource.get(type="public") - + call_args = benchmarks_resource._get.call_args assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT def test_get_benchmarks_with_none_timeout(self, benchmarks_resource, mock_public_benchmarks_response): """get method accepts None timeout.""" benchmarks_resource._get.return_value = mock_public_benchmarks_response - + benchmarks_resource.get(type="public", timeout=None) - + call_args = benchmarks_resource._get.call_args assert call_args.kwargs["timeout"] is None @@ -267,50 +269,50 @@ def benchmarks_resource(self, mock_client): def test_get_benchmarks_handles_api_error(self, benchmarks_resource): """get method propagates API errors.""" from atlas._exceptions import APIStatusError - + mock_response = Mock() mock_response.status_code = 404 mock_response.headers = {} - + api_error = APIStatusError("Not Found", response=mock_response, body=None) benchmarks_resource._get.side_effect = api_error - + with pytest.raises(APIStatusError): benchmarks_resource.get(type="public") def test_get_benchmarks_handles_auth_error(self, benchmarks_resource): """get method propagates authentication errors.""" from atlas._exceptions import AuthenticationError - + mock_response = Mock() mock_response.status_code = 401 mock_response.headers = {} - + auth_error = AuthenticationError("Unauthorized", response=mock_response, body=None) benchmarks_resource._get.side_effect = auth_error - + with pytest.raises(AuthenticationError): benchmarks_resource.get(type="custom") def test_get_benchmarks_handles_connection_error(self, benchmarks_resource): """get method propagates connection errors.""" from atlas._exceptions import APIConnectionError - + mock_request = Mock() connection_error = APIConnectionError(request=mock_request) benchmarks_resource._get.side_effect = connection_error - + with pytest.raises(APIConnectionError): benchmarks_resource.get(type="public") def test_get_benchmarks_handles_timeout_error(self, 
benchmarks_resource): """get method propagates timeout errors.""" from atlas._exceptions import APITimeoutError - + mock_request = Mock() timeout_error = APITimeoutError(mock_request) benchmarks_resource._get.side_effect = timeout_error - + with pytest.raises(APITimeoutError): benchmarks_resource.get(type="public", timeout=1.0) @@ -338,7 +340,7 @@ def test_get_benchmarks_return_type_consistency(self, benchmarks_resource): benchmarks_resource._get.return_value = None result = benchmarks_resource.get(type="public") assert result is None - + # Test that it returns a list when successful benchmarks_resource._get.return_value = BenchmarksData(datasets=[]) result = benchmarks_resource.get(type="public") @@ -358,7 +360,7 @@ def test_get_benchmarks_mixed_benchmark_types(self, benchmarks_resource): "prompt_count": 1000, "deprecated": False, } - + custom_data = { "id": "custom-456", "key": "my-bench", @@ -376,17 +378,17 @@ def test_get_benchmarks_mixed_benchmark_types(self, benchmarks_resource): "files": ["test.jsonl"], "disabled": False, } - + public_benchmark = Benchmark(**public_data) custom_benchmark = CustomBenchmark(**custom_data) - + response = BenchmarksData(datasets=[public_benchmark, custom_benchmark]) benchmarks_resource._get.return_value = response - + result = benchmarks_resource.get(type="public") # Type doesn't matter for this test - + assert len(result) == 2 assert isinstance(result[0], Benchmark) assert isinstance(result[1], CustomBenchmark) assert result[0].key == "mmlu" - assert result[1].key == "my-bench" \ No newline at end of file + assert result[1].key == "my-bench" diff --git a/tests/resources/test_evaluations.py b/tests/resources/test_evaluations.py index 2f4c444..a4e9652 100644 --- a/tests/resources/test_evaluations.py +++ b/tests/resources/test_evaluations.py @@ -57,7 +57,7 @@ def mock_evaluations_response(self, sample_evaluation_data): def test_evaluations_initialization(self, mock_client): """Evaluations resource initializes correctly.""" 
evaluations = Evaluations(mock_client) - + assert evaluations._client is mock_client assert evaluations._get is mock_client.get_cast assert evaluations._post is mock_client.post_cast @@ -65,9 +65,9 @@ def test_evaluations_initialization(self, mock_client): def test_create_evaluation_success(self, evaluations_resource, mock_evaluations_response): """create method returns first evaluation on success.""" evaluations_resource._post.return_value = mock_evaluations_response - + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + assert isinstance(result, Evaluation) assert result.id == "eval-123" assert result.model_name == "GPT-4" @@ -76,17 +76,19 @@ def test_create_evaluation_success(self, evaluations_resource, mock_evaluations_ def test_create_evaluation_request_parameters(self, evaluations_resource, mock_evaluations_response): """create method makes correct API request.""" evaluations_resource._post.return_value = mock_evaluations_response - + evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + evaluations_resource._post.assert_called_once_with( "/organizations/org-123/projects/proj-456/evaluations", - body=[{ - "model_id": "gpt-4", - "dataset_id": "mmlu", - "is_custom_model": False, - "is_custom_dataset": False, - }], + body=[ + { + "model_id": "gpt-4", + "dataset_id": "mmlu", + "is_custom_model": False, + "is_custom_dataset": False, + } + ], timeout=DEFAULT_TIMEOUT, cast_to=EvaluationsData, ) @@ -95,9 +97,9 @@ def test_create_evaluation_with_custom_timeout(self, evaluations_resource, mock_ """create method accepts custom timeout.""" evaluations_resource._post.return_value = mock_evaluations_response custom_timeout = 30.0 - + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=custom_timeout) - + call_args = evaluations_resource._post.call_args assert call_args.kwargs["timeout"] == custom_timeout @@ -105,9 +107,9 @@ def test_create_evaluation_with_httpx_timeout(self, evaluations_resource, mock_e """create method 
accepts httpx.Timeout object.""" evaluations_resource._post.return_value = mock_evaluations_response custom_timeout = httpx.Timeout(30.0) - + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=custom_timeout) - + call_args = evaluations_resource._post.call_args assert call_args.kwargs["timeout"] is custom_timeout @@ -115,25 +117,25 @@ def test_create_evaluation_empty_response(self, evaluations_resource): """create method returns None when no evaluations in response.""" empty_response = EvaluationsData(data=[]) evaluations_resource._post.return_value = empty_response - + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + assert result is None def test_create_evaluation_none_response(self, evaluations_resource): """create method returns None when response is None.""" evaluations_resource._post.return_value = None - + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + assert result is None def test_create_evaluation_invalid_response_type(self, evaluations_resource): """create method handles non-EvaluationsData response gracefully.""" evaluations_resource._post.return_value = "invalid-response" - + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + assert result is None def test_create_evaluation_multiple_evaluations_returns_first(self, evaluations_resource, sample_evaluation_data): @@ -142,12 +144,12 @@ def test_create_evaluation_multiple_evaluations_returns_first(self, evaluations_ eval2_data = sample_evaluation_data.copy() eval2_data["id"] = "eval-456" eval2 = Evaluation(**eval2_data) - + response = EvaluationsData(data=[eval1, eval2]) evaluations_resource._post.return_value = response - + result = evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + assert result.id == "eval-123" # First evaluation assert result is not eval2 @@ -156,9 +158,9 @@ def test_create_evaluation_url_construction(self, evaluations_resource, mock_eva evaluations_resource._client.organization_id 
= "custom-org" evaluations_resource._client.project_id = "custom-project" evaluations_resource._post.return_value = mock_evaluations_response - + evaluations_resource.create(model="test-model", benchmark="test-benchmark") - + expected_url = "/organizations/custom-org/projects/custom-project/evaluations" call_args = evaluations_resource._post.call_args assert call_args[0][0] == expected_url @@ -166,12 +168,12 @@ def test_create_evaluation_url_construction(self, evaluations_resource, mock_eva def test_create_evaluation_request_body_structure(self, evaluations_resource, mock_evaluations_response): """create method sends correct request body structure.""" evaluations_resource._post.return_value = mock_evaluations_response - + evaluations_resource.create(model="custom-model", benchmark="custom-benchmark") - + call_args = evaluations_resource._post.call_args body = call_args.kwargs["body"] - + assert isinstance(body, list) assert len(body) == 1 assert body[0]["model_id"] == "custom-model" @@ -179,18 +181,23 @@ def test_create_evaluation_request_body_structure(self, evaluations_resource, mo assert body[0]["is_custom_model"] is False assert body[0]["is_custom_dataset"] is False - @pytest.mark.parametrize("model_name,benchmark_name", [ - ("gpt-3.5-turbo", "hellaswag"), - ("claude-3-opus", "arc-challenge"), - ("llama-2-70b", "truthfulqa"), - ("custom-model-123", "custom-benchmark-456"), - ]) - def test_create_evaluation_with_different_parameters(self, evaluations_resource, mock_evaluations_response, model_name, benchmark_name): + @pytest.mark.parametrize( + "model_name,benchmark_name", + [ + ("gpt-3.5-turbo", "hellaswag"), + ("claude-3-opus", "arc-challenge"), + ("llama-2-70b", "truthfulqa"), + ("custom-model-123", "custom-benchmark-456"), + ], + ) + def test_create_evaluation_with_different_parameters( + self, evaluations_resource, mock_evaluations_response, model_name, benchmark_name + ): """create method works with various model and benchmark combinations.""" 
evaluations_resource._post.return_value = mock_evaluations_response - + result = evaluations_resource.create(model=model_name, benchmark=benchmark_name) - + assert isinstance(result, Evaluation) call_args = evaluations_resource._post.call_args body = call_args.kwargs["body"][0] @@ -200,27 +207,27 @@ def test_create_evaluation_with_different_parameters(self, evaluations_resource, def test_create_evaluation_cast_to_parameter(self, evaluations_resource, mock_evaluations_response): """create method specifies correct cast_to parameter.""" evaluations_resource._post.return_value = mock_evaluations_response - + evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + call_args = evaluations_resource._post.call_args assert call_args.kwargs["cast_to"] is EvaluationsData def test_create_evaluation_timeout_default(self, evaluations_resource, mock_evaluations_response): """create method uses DEFAULT_TIMEOUT when no timeout specified.""" evaluations_resource._post.return_value = mock_evaluations_response - + evaluations_resource.create(model="gpt-4", benchmark="mmlu") - + call_args = evaluations_resource._post.call_args assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT def test_create_evaluation_with_none_timeout(self, evaluations_resource, mock_evaluations_response): """create method accepts None timeout.""" evaluations_resource._post.return_value = mock_evaluations_response - + evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=None) - + call_args = evaluations_resource._post.call_args assert call_args.kwargs["timeout"] is None @@ -245,36 +252,36 @@ def evaluations_resource(self, mock_client): def test_create_evaluation_handles_api_error(self, evaluations_resource): """create method propagates API errors.""" from atlas._exceptions import APIStatusError - + mock_response = Mock() mock_response.status_code = 400 mock_response.headers = {} - + api_error = APIStatusError("Bad Request", response=mock_response, body=None) 
evaluations_resource._post.side_effect = api_error - + with pytest.raises(APIStatusError): evaluations_resource.create(model="invalid-model", benchmark="invalid-benchmark") def test_create_evaluation_handles_connection_error(self, evaluations_resource): """create method propagates connection errors.""" from atlas._exceptions import APIConnectionError - + mock_request = Mock() connection_error = APIConnectionError(request=mock_request) evaluations_resource._post.side_effect = connection_error - + with pytest.raises(APIConnectionError): evaluations_resource.create(model="gpt-4", benchmark="mmlu") def test_create_evaluation_handles_timeout_error(self, evaluations_resource): """create method propagates timeout errors.""" from atlas._exceptions import APITimeoutError - + mock_request = Mock() timeout_error = APITimeoutError(mock_request) evaluations_resource._post.side_effect = timeout_error - + with pytest.raises(APITimeoutError): evaluations_resource.create(model="gpt-4", benchmark="mmlu", timeout=1.0) @@ -288,7 +295,7 @@ def test_create_evaluation_end_to_end_flow(self): mock_client = Mock() mock_client.organization_id = "test-org" mock_client.project_id = "test-project" - + # Create sample evaluation data evaluation_data = { "id": "eval-integration-test", @@ -308,28 +315,25 @@ def test_create_evaluation_end_to_end_flow(self): "ethics_score": 0.0, "accuracy": 0.0, } - + evaluation = Evaluation(**evaluation_data) response = EvaluationsData(data=[evaluation]) mock_client.post_cast.return_value = response - + # Test the resource evaluations_resource = Evaluations(mock_client) - result = evaluations_resource.create( - model="integration-model", - benchmark="integration-dataset" - ) - + result = evaluations_resource.create(model="integration-model", benchmark="integration-dataset") + # Verify the complete flow assert result is not None assert result.id == "eval-integration-test" assert result.model_id == "integration-model" assert result.dataset_id == "integration-dataset" 
assert result.status == "submitted" - + # Verify the API call was made correctly mock_client.post_cast.assert_called_once() call_args = mock_client.post_cast.call_args assert "/organizations/test-org/projects/test-project/evaluations" in call_args[0][0] assert call_args.kwargs["body"][0]["model_id"] == "integration-model" - assert call_args.kwargs["body"][0]["dataset_id"] == "integration-dataset" \ No newline at end of file + assert call_args.kwargs["body"][0]["dataset_id"] == "integration-dataset" diff --git a/tests/resources/test_models_resource.py b/tests/resources/test_models_resource.py index f1bb39c..94ba5f2 100644 --- a/tests/resources/test_models_resource.py +++ b/tests/resources/test_models_resource.py @@ -73,16 +73,16 @@ def mock_custom_models_response(self, sample_custom_model_data): def test_models_initialization(self, mock_client): """Models resource initializes correctly.""" models = Models(mock_client) - + assert models._client is mock_client assert models._get is mock_client.get_cast def test_get_public_models_success(self, models_resource, mock_public_models_response): """get method returns public models successfully.""" models_resource._get.return_value = mock_public_models_response - + result = models_resource.get(type="public") - + assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], Model) @@ -93,9 +93,9 @@ def test_get_public_models_success(self, models_resource, mock_public_models_res def test_get_custom_models_success(self, models_resource, mock_custom_models_response): """get method returns custom models successfully.""" models_resource._get.return_value = mock_custom_models_response - + result = models_resource.get(type="custom") - + assert isinstance(result, list) assert len(result) == 1 assert isinstance(result[0], CustomModel) @@ -106,9 +106,9 @@ def test_get_custom_models_success(self, models_resource, mock_custom_models_res def test_get_models_request_parameters_public(self, models_resource, 
mock_public_models_response): """get method makes correct API request for public models.""" models_resource._get.return_value = mock_public_models_response - + models_resource.get(type="public") - + models_resource._get.assert_called_once_with( "/organizations/org-123/projects/proj-456/models", params={"type": "public"}, @@ -119,9 +119,9 @@ def test_get_models_request_parameters_public(self, models_resource, mock_public def test_get_models_request_parameters_custom(self, models_resource, mock_custom_models_response): """get method makes correct API request for custom models.""" models_resource._get.return_value = mock_custom_models_response - + models_resource.get(type="custom") - + models_resource._get.assert_called_once_with( "/organizations/org-123/projects/proj-456/models", params={"type": "custom"}, @@ -133,9 +133,9 @@ def test_get_models_with_custom_timeout(self, models_resource, mock_public_model """get method accepts custom timeout.""" models_resource._get.return_value = mock_public_models_response custom_timeout = 60.0 - + models_resource.get(type="public", timeout=custom_timeout) - + call_args = models_resource._get.call_args assert call_args.kwargs["timeout"] == custom_timeout @@ -143,42 +143,42 @@ def test_get_models_with_httpx_timeout(self, models_resource, mock_public_models """get method accepts httpx.Timeout object.""" models_resource._get.return_value = mock_public_models_response custom_timeout = httpx.Timeout(60.0) - + models_resource.get(type="public", timeout=custom_timeout) - + call_args = models_resource._get.call_args assert call_args.kwargs["timeout"] is custom_timeout def test_get_models_none_response(self, models_resource): """get method returns None when response is None.""" models_resource._get.return_value = None - + result = models_resource.get(type="public") - + assert result is None def test_get_models_invalid_response_type(self, models_resource): """get method handles non-ModelsData response gracefully.""" 
models_resource._get.return_value = "invalid-response" - + result = models_resource.get(type="public") - + assert result is None def test_get_models_empty_response(self, models_resource): """get method returns empty list when no models in response.""" empty_response = ModelsData(models=[]) models_resource._get.return_value = empty_response - + result = models_resource.get(type="public") - + assert result == [] assert isinstance(result, list) def test_get_models_multiple_items(self, models_resource, sample_model_data): """get method returns multiple models correctly.""" model1 = Model(**sample_model_data) - + # Create second model with different data model2_data = sample_model_data.copy() model2_data["id"] = "model-456" @@ -186,12 +186,12 @@ def test_get_models_multiple_items(self, models_resource, sample_model_data): model2_data["name"] = "GPT-3.5 Turbo" model2_data["parameters"] = 1.75e11 model2 = Model(**model2_data) - + response = ModelsData(models=[model1, model2]) models_resource._get.return_value = response - + result = models_resource.get(type="public") - + assert len(result) == 2 assert result[0].key == "gpt-4" assert result[1].key == "gpt-3.5-turbo" @@ -203,9 +203,9 @@ def test_get_models_url_construction(self, models_resource, mock_public_models_r models_resource._client.organization_id = "custom-org" models_resource._client.project_id = "custom-project" models_resource._get.return_value = mock_public_models_response - + models_resource.get(type="public") - + expected_url = "/organizations/custom-org/projects/custom-project/models" call_args = models_resource._get.call_args assert call_args[0][0] == expected_url @@ -214,46 +214,46 @@ def test_get_models_url_construction(self, models_resource, mock_public_models_r def test_get_models_type_parameter(self, models_resource, model_type): """get method accepts both public and custom types.""" models_resource._get.return_value = ModelsData(models=[]) - + models_resource.get(type=model_type) - + call_args = 
models_resource._get.call_args assert call_args.kwargs["params"]["type"] == model_type def test_get_models_cast_to_parameter(self, models_resource, mock_public_models_response): """get method specifies correct cast_to parameter.""" models_resource._get.return_value = mock_public_models_response - + models_resource.get(type="public") - + call_args = models_resource._get.call_args assert call_args.kwargs["cast_to"] is ModelsData def test_get_models_timeout_default(self, models_resource, mock_public_models_response): """get method uses DEFAULT_TIMEOUT when no timeout specified.""" models_resource._get.return_value = mock_public_models_response - + models_resource.get(type="public") - + call_args = models_resource._get.call_args assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT def test_get_models_with_none_timeout(self, models_resource, mock_public_models_response): """get method accepts None timeout.""" models_resource._get.return_value = mock_public_models_response - + models_resource.get(type="public", timeout=None) - + call_args = models_resource._get.call_args assert call_args.kwargs["timeout"] is None def test_get_models_model_attributes(self, models_resource, mock_public_models_response): """get method preserves all model attributes correctly.""" models_resource._get.return_value = mock_public_models_response - + result = models_resource.get(type="public") model = result[0] - + assert model.context_length == 8192 assert model.open_weights is False assert model.deprecated is False @@ -265,10 +265,10 @@ def test_get_models_model_attributes(self, models_resource, mock_public_models_r def test_get_models_custom_model_attributes(self, models_resource, mock_custom_models_response): """get method preserves all custom model attributes correctly.""" models_resource._get.return_value = mock_custom_models_response - + result = models_resource.get(type="custom") custom_model = result[0] - + assert custom_model.max_tokens == 4096 assert custom_model.disabled is False 
assert custom_model.api_url == "https://api.example.com/v1/chat" @@ -294,50 +294,50 @@ def models_resource(self, mock_client): def test_get_models_handles_api_error(self, models_resource): """get method propagates API errors.""" from atlas._exceptions import APIStatusError - + mock_response = Mock() mock_response.status_code = 500 mock_response.headers = {} - + api_error = APIStatusError("Internal Server Error", response=mock_response, body=None) models_resource._get.side_effect = api_error - + with pytest.raises(APIStatusError): models_resource.get(type="public") def test_get_models_handles_forbidden_error(self, models_resource): """get method propagates permission errors.""" from atlas._exceptions import PermissionDeniedError - + mock_response = Mock() mock_response.status_code = 403 mock_response.headers = {} - + permission_error = PermissionDeniedError("Forbidden", response=mock_response, body=None) models_resource._get.side_effect = permission_error - + with pytest.raises(PermissionDeniedError): models_resource.get(type="custom") def test_get_models_handles_connection_error(self, models_resource): """get method propagates connection errors.""" from atlas._exceptions import APIConnectionError - + mock_request = Mock() connection_error = APIConnectionError(request=mock_request) models_resource._get.side_effect = connection_error - + with pytest.raises(APIConnectionError): models_resource.get(type="public") def test_get_models_handles_timeout_error(self, models_resource): """get method propagates timeout errors.""" from atlas._exceptions import APITimeoutError - + mock_request = Mock() timeout_error = APITimeoutError(mock_request) models_resource._get.side_effect = timeout_error - + with pytest.raises(APITimeoutError): models_resource.get(type="public", timeout=5.0) @@ -365,7 +365,7 @@ def test_get_models_return_type_consistency(self, models_resource): models_resource._get.return_value = None result = models_resource.get(type="public") assert result is None - + # 
Test that it returns a list when successful models_resource._get.return_value = ModelsData(models=[]) result = models_resource.get(type="public") @@ -390,7 +390,7 @@ def test_get_models_mixed_model_types(self, models_resource): "region": "us-east-1", "deprecated": False, } - + custom_data = { "id": "custom-456", "key": "my-model", @@ -400,22 +400,22 @@ def test_get_models_mixed_model_types(self, models_resource): "api_url": "https://api.example.com/v1/chat", "disabled": False, } - + public_model = Model(**public_data) custom_model = CustomModel(**custom_data) - + response = ModelsData(models=[public_model, custom_model]) models_resource._get.return_value = response - + result = models_resource.get(type="public") # Type doesn't matter for this test - + assert len(result) == 2 assert isinstance(result[0], Model) assert isinstance(result[1], CustomModel) assert result[0].key == "gpt-4" assert result[1].key == "my-model" - assert hasattr(result[0], 'parameters') # Model-specific attribute - assert hasattr(result[1], 'max_tokens') # CustomModel-specific attribute + assert hasattr(result[0], "parameters") # Model-specific attribute + assert hasattr(result[1], "max_tokens") # CustomModel-specific attribute def test_get_models_large_parameters_handling(self, models_resource): """get method handles large parameter numbers correctly.""" @@ -435,15 +435,15 @@ def test_get_models_large_parameters_handling(self, models_resource): "region": "us-west-2", "deprecated": False, } - + large_model = Model(**large_model_data) response = ModelsData(models=[large_model]) models_resource._get.return_value = response - + result = models_resource.get(type="public") - + assert len(result) == 1 assert result[0].parameters == 1.3e14 assert result[0].context_length == 200000 assert isinstance(result[0].parameters, float) - assert isinstance(result[0].context_length, int) \ No newline at end of file + assert isinstance(result[0].context_length, int) diff --git a/tests/resources/test_results.py 
b/tests/resources/test_results.py index 26c610a..2529670 100644 --- a/tests/resources/test_results.py +++ b/tests/resources/test_results.py @@ -4,7 +4,7 @@ import httpx import pytest -from atlas._models import Result, Results as ResultsData +from atlas._models import Result, Results as ResultsData, Pagination, ResultMetrics from atlas._constants import DEFAULT_TIMEOUT from atlas.resources.results.results import Results @@ -34,60 +34,71 @@ def sample_result_data(self): "truth": "2x", "duration": timedelta(seconds=2.5), "score": 1.0, - "metrics": { - "accuracy": 1.0, - "confidence": 0.95, - "reasoning_quality": 0.9 - } + "metrics": {"accuracy": 1.0, "confidence": 0.95, "reasoning_quality": 0.9}, } @pytest.fixture def mock_results_response(self, sample_result_data): - """Mock ResultsData response.""" - result = Result(**sample_result_data) - return ResultsData(results=[result]) + """Mock raw API response with pagination.""" + return { + "evaluation_id": "eval-123", + "results": [sample_result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } def test_results_initialization(self, mock_client): """Results resource initializes correctly.""" results = Results(mock_client) - + assert results._client is mock_client assert results._get is mock_client.get_cast def test_get_results_success(self, results_resource, mock_results_response): - """get method returns results successfully.""" + """get method returns ResultsData successfully.""" results_resource._get.return_value = mock_results_response - + result = results_resource.get(evaluation_id="eval-123") - - assert isinstance(result, list) - assert len(result) == 1 - assert isinstance(result[0], Result) - assert result[0].subset == "mathematics" - assert result[0].prompt == "What is the derivative of x^2?" 
- assert result[0].result == "2x" - assert result[0].score == 1.0 + + assert isinstance(result, ResultsData) + assert result.evaluation_id == "eval-123" + assert len(result.results) == 1 + assert isinstance(result.results[0], Result) + assert result.results[0].subset == "mathematics" + assert result.results[0].prompt == "What is the derivative of x^2?" + assert result.results[0].result == "2x" + assert result.results[0].score == 1.0 + assert isinstance(result.metrics, ResultMetrics) + assert isinstance(result.pagination, Pagination) + assert result.pagination.total_count == 1 + assert result.pagination.page_size == 100 + assert result.pagination.total_pages == 1 def test_get_results_request_parameters(self, results_resource, mock_results_response): """get method makes correct API request.""" results_resource._get.return_value = mock_results_response - + results_resource.get(evaluation_id="eval-456") - + results_resource._get.assert_called_once_with( "/results", - params={"evaluation_id": "eval-456"}, + params={"evaluation_id": "eval-456", "page": "1"}, timeout=DEFAULT_TIMEOUT, - cast_to=ResultsData, + cast_to=dict, ) def test_get_results_with_custom_timeout(self, results_resource, mock_results_response): """get method accepts custom timeout.""" results_resource._get.return_value = mock_results_response custom_timeout = 120.0 - + results_resource.get(evaluation_id="eval-123", timeout=custom_timeout) - + call_args = results_resource._get.call_args assert call_args.kwargs["timeout"] == custom_timeout @@ -95,42 +106,55 @@ def test_get_results_with_httpx_timeout(self, results_resource, mock_results_res """get method accepts httpx.Timeout object.""" results_resource._get.return_value = mock_results_response custom_timeout = httpx.Timeout(120.0) - + results_resource.get(evaluation_id="eval-123", timeout=custom_timeout) - + call_args = results_resource._get.call_args assert call_args.kwargs["timeout"] is custom_timeout def test_get_results_none_response(self, 
results_resource): """get method returns None when response is None.""" results_resource._get.return_value = None - + result = results_resource.get(evaluation_id="eval-123") - + assert result is None def test_get_results_invalid_response_type(self, results_resource): """get method handles non-ResultsData response gracefully.""" results_resource._get.return_value = "invalid-response" - + result = results_resource.get(evaluation_id="eval-123") - + assert result is None def test_get_results_empty_response(self, results_resource): - """get method returns empty list when no results in response.""" - empty_response = ResultsData(results=[]) + """get method returns ResultsData with empty results list when no results in response.""" + empty_response = { + "evaluation_id": "eval-123", + "results": [], + "metrics": { + "total_count": 0, + "min_toxicity_score": None, + "max_toxicity_score": None, + "min_readability_score": None, + "max_readability_score": None, + }, + } results_resource._get.return_value = empty_response - + result = results_resource.get(evaluation_id="eval-123") - - assert result == [] - assert isinstance(result, list) + + assert isinstance(result, ResultsData) + assert result.evaluation_id == "eval-123" + assert result.results == [] + assert isinstance(result.results, list) + assert result.pagination.total_count == 0 def test_get_results_multiple_items(self, results_resource, sample_result_data): """get method returns multiple results correctly.""" result1 = Result(**sample_result_data) - + # Create second result with different data result2_data = sample_result_data.copy() result2_data["subset"] = "science" @@ -140,70 +164,82 @@ def test_get_results_multiple_items(self, results_resource, sample_result_data): result2_data["score"] = 0.95 result2_data["duration"] = timedelta(seconds=3.2) result2 = Result(**result2_data) - - response = ResultsData(results=[result1, result2]) + + response = { + "evaluation_id": "eval-123", + "results": [sample_result_data, 
result2_data], + "metrics": { + "total_count": 2, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } results_resource._get.return_value = response - + result = results_resource.get(evaluation_id="eval-123") - - assert len(result) == 2 - assert result[0].subset == "mathematics" - assert result[1].subset == "science" - assert result[0].score == 1.0 - assert result[1].score == 0.95 + + assert isinstance(result, ResultsData) + assert len(result.results) == 2 + assert result.results[0].subset == "mathematics" + assert result.results[1].subset == "science" + assert result.results[0].score == 1.0 + assert result.results[1].score == 0.95 + assert result.pagination.total_count == 2 def test_get_results_url_construction(self, results_resource, mock_results_response): """get method uses correct URL endpoint.""" results_resource._get.return_value = mock_results_response - + results_resource.get(evaluation_id="eval-123") - + call_args = results_resource._get.call_args assert call_args[0][0] == "/results" def test_get_results_evaluation_id_parameter(self, results_resource, mock_results_response): """get method correctly passes evaluation_id parameter.""" results_resource._get.return_value = mock_results_response - + results_resource.get(evaluation_id="test-eval-789") - + call_args = results_resource._get.call_args assert call_args.kwargs["params"]["evaluation_id"] == "test-eval-789" def test_get_results_cast_to_parameter(self, results_resource, mock_results_response): """get method specifies correct cast_to parameter.""" results_resource._get.return_value = mock_results_response - + results_resource.get(evaluation_id="eval-123") - + call_args = results_resource._get.call_args - assert call_args.kwargs["cast_to"] is ResultsData + assert call_args.kwargs["cast_to"] is dict def test_get_results_timeout_default(self, results_resource, mock_results_response): """get method uses DEFAULT_TIMEOUT when no 
timeout specified.""" results_resource._get.return_value = mock_results_response - + results_resource.get(evaluation_id="eval-123") - + call_args = results_resource._get.call_args assert call_args.kwargs["timeout"] is DEFAULT_TIMEOUT def test_get_results_with_none_timeout(self, results_resource, mock_results_response): """get method accepts None timeout.""" results_resource._get.return_value = mock_results_response - + results_resource.get(evaluation_id="eval-123", timeout=None) - + call_args = results_resource._get.call_args assert call_args.kwargs["timeout"] is None def test_get_results_preserves_result_attributes(self, results_resource, mock_results_response): """get method preserves all result attributes correctly.""" results_resource._get.return_value = mock_results_response - + result = results_resource.get(evaluation_id="eval-123") - result_item = result[0] - + result_item = result.results[0] + assert isinstance(result_item.duration, timedelta) assert result_item.duration.total_seconds() == 2.5 assert isinstance(result_item.metrics, dict) @@ -211,19 +247,22 @@ def test_get_results_preserves_result_attributes(self, results_resource, mock_re assert result_item.metrics["confidence"] == 0.95 assert result_item.metrics["reasoning_quality"] == 0.9 - @pytest.mark.parametrize("evaluation_id", [ - "eval-123", - "evaluation-456-abc", - "test_eval_789", - "long-evaluation-id-with-many-characters-123456789", - ]) + @pytest.mark.parametrize( + "evaluation_id", + [ + "eval-123", + "evaluation-456-abc", + "test_eval_789", + "long-evaluation-id-with-many-characters-123456789", + ], + ) def test_get_results_with_different_evaluation_ids(self, results_resource, mock_results_response, evaluation_id): """get method works with various evaluation ID formats.""" results_resource._get.return_value = mock_results_response - + result = results_resource.get(evaluation_id=evaluation_id) - - assert isinstance(result, list) + + assert isinstance(result, ResultsData) call_args = 
results_resource._get.call_args assert call_args.kwargs["params"]["evaluation_id"] == evaluation_id @@ -246,78 +285,78 @@ def results_resource(self, mock_client): def test_get_results_handles_not_found_error(self, results_resource): """get method propagates not found errors.""" from atlas._exceptions import NotFoundError - + mock_response = Mock() mock_response.status_code = 404 mock_response.headers = {} - + not_found_error = NotFoundError("Evaluation not found", response=mock_response, body=None) results_resource._get.side_effect = not_found_error - + with pytest.raises(NotFoundError): results_resource.get(evaluation_id="nonexistent-eval") def test_get_results_handles_auth_error(self, results_resource): """get method propagates authentication errors.""" from atlas._exceptions import AuthenticationError - + mock_response = Mock() mock_response.status_code = 401 mock_response.headers = {} - + auth_error = AuthenticationError("Unauthorized", response=mock_response, body=None) results_resource._get.side_effect = auth_error - + with pytest.raises(AuthenticationError): results_resource.get(evaluation_id="eval-123") def test_get_results_handles_permission_error(self, results_resource): """get method propagates permission errors.""" from atlas._exceptions import PermissionDeniedError - + mock_response = Mock() mock_response.status_code = 403 mock_response.headers = {} - + permission_error = PermissionDeniedError("Access denied", response=mock_response, body=None) results_resource._get.side_effect = permission_error - + with pytest.raises(PermissionDeniedError): results_resource.get(evaluation_id="restricted-eval") def test_get_results_handles_server_error(self, results_resource): """get method propagates server errors.""" from atlas._exceptions import InternalServerError - + mock_response = Mock() mock_response.status_code = 500 mock_response.headers = {} - + server_error = InternalServerError("Internal server error", response=mock_response, body=None) 
results_resource._get.side_effect = server_error - + with pytest.raises(InternalServerError): results_resource.get(evaluation_id="eval-123") def test_get_results_handles_connection_error(self, results_resource): """get method propagates connection errors.""" from atlas._exceptions import APIConnectionError - + mock_request = Mock() connection_error = APIConnectionError(request=mock_request) results_resource._get.side_effect = connection_error - + with pytest.raises(APIConnectionError): results_resource.get(evaluation_id="eval-123") def test_get_results_handles_timeout_error(self, results_resource): """get method propagates timeout errors.""" from atlas._exceptions import APITimeoutError - + mock_request = Mock() timeout_error = APITimeoutError(mock_request) results_resource._get.side_effect = timeout_error - + with pytest.raises(APITimeoutError): results_resource.get(evaluation_id="eval-123", timeout=1.0) @@ -358,19 +397,29 @@ def test_get_results_handles_complex_metrics(self, results_resource): "rouge_l": 0.80, "semantic_similarity": 0.91, "factual_correctness": 0.95, - "reasoning_steps": 4.0 - } + "reasoning_steps": 4.0, + }, + } + + response = { + "evaluation_id": "eval-complex", + "results": [complex_result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, } - - complex_result = Result(**complex_result_data) - response = ResultsData(results=[complex_result]) results_resource._get.return_value = response - + result = results_resource.get(evaluation_id="eval-complex") - - assert len(result) == 1 - result_item = result[0] - + + assert isinstance(result, ResultsData) + assert len(result.results) == 1 + result_item = result.results[0] + assert result_item.score == 0.87 assert len(result_item.metrics) == 12 assert result_item.metrics["f1_score"] == 0.875 @@ -380,14 +429,14 @@ def test_get_results_handles_complex_metrics(self, results_resource): def 
test_get_results_handles_different_durations(self, results_resource): """get method handles various duration formats.""" durations_to_test = [ - timedelta(seconds=0.1), # Very short - timedelta(seconds=1.5), # Normal - timedelta(seconds=30.0), # Long - timedelta(minutes=2.5), # Very long - timedelta(hours=1), # Extremely long + timedelta(seconds=0.1), # Very short + timedelta(seconds=1.5), # Normal + timedelta(seconds=30.0), # Long + timedelta(minutes=2.5), # Very long + timedelta(hours=1), # Extremely long ] - - results = [] + + results_data = [] for i, duration in enumerate(durations_to_test): result_data = { "subset": f"test-{i}", @@ -396,21 +445,32 @@ def test_get_results_handles_different_durations(self, results_resource): "truth": f"Test truth {i}", "duration": duration, "score": 0.8 + i * 0.05, - "metrics": {"accuracy": 0.8 + i * 0.05} + "metrics": {"accuracy": 0.8 + i * 0.05}, } - results.append(Result(**result_data)) - - response = ResultsData(results=results) + results_data.append(result_data) + + response = { + "evaluation_id": "eval-durations", + "results": results_data, + "metrics": { + "total_count": 5, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } results_resource._get.return_value = response - + result = results_resource.get(evaluation_id="eval-durations") - - assert len(result) == 5 - assert result[0].duration == timedelta(seconds=0.1) - assert result[1].duration == timedelta(seconds=1.5) - assert result[2].duration == timedelta(seconds=30.0) - assert result[3].duration == timedelta(minutes=2.5) - assert result[4].duration == timedelta(hours=1) + + assert isinstance(result, ResultsData) + assert len(result.results) == 5 + assert result.results[0].duration == timedelta(seconds=0.1) + assert result.results[1].duration == timedelta(seconds=1.5) + assert result.results[2].duration == timedelta(seconds=30.0) + assert result.results[3].duration == timedelta(minutes=2.5) 
+ assert result.results[4].duration == timedelta(hours=1) def test_get_results_handles_empty_metrics(self, results_resource): """get method handles results with empty metrics.""" @@ -421,27 +481,377 @@ def test_get_results_handles_empty_metrics(self, results_resource): "truth": "Minimal truth", "duration": timedelta(seconds=1.0), "score": 0.5, - "metrics": {} # Empty metrics + "metrics": {}, # Empty metrics + } + + response = { + "evaluation_id": "eval-minimal", + "results": [result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, } - - minimal_result = Result(**result_data) - response = ResultsData(results=[minimal_result]) results_resource._get.return_value = response - + result = results_resource.get(evaluation_id="eval-minimal") - - assert len(result) == 1 - assert result[0].metrics == {} - assert isinstance(result[0].metrics, dict) + + assert isinstance(result, ResultsData) + assert len(result.results) == 1 + assert result.results[0].metrics == {} + assert isinstance(result.results[0].metrics, dict) def test_get_results_return_type_consistency(self, results_resource): """get method returns consistent types.""" - # Test that the method returns either a list or None + # Test that the method returns either a ResultsData object or None results_resource._get.return_value = None result = results_resource.get(evaluation_id="eval-123") assert result is None - - # Test that it returns a list when successful - results_resource._get.return_value = ResultsData(results=[]) + + # Test that it returns a ResultsData object when successful + empty_response = { + "evaluation_id": "eval-123", + "results": [], + "metrics": { + "total_count": 0, + "min_toxicity_score": None, + "max_toxicity_score": None, + "min_readability_score": None, + "max_readability_score": None, + }, + } + results_resource._get.return_value = empty_response + result = 
results_resource.get(evaluation_id="eval-123") + assert isinstance(result, ResultsData) + + +class TestResultsPagination: + """Test pagination functionality in Results resource.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.get_cast = Mock() + return client + + @pytest.fixture + def results_resource(self, mock_client): + """Results resource instance.""" + return Results(mock_client) + + @pytest.fixture + def sample_result_data(self): + """Sample result data for testing.""" + return { + "subset": "mathematics", + "prompt": "What is the derivative of x^2?", + "result": "2x", + "truth": "2x", + "duration": timedelta(seconds=2.5), + "score": 1.0, + "metrics": {"accuracy": 1.0, "confidence": 0.95}, + } + + def test_get_results_with_pagination_parameters(self, results_resource, sample_result_data): + """get method accepts pagination parameters.""" + mock_response = { + "evaluation_id": "eval-paginated", + "results": [sample_result_data], + "metrics": { + "total_count": 250, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.5, + "min_readability_score": 0.3, + "max_readability_score": 0.95, + }, + } + results_resource._get.return_value = mock_response + + result_data = results_resource.get( + evaluation_id="eval-paginated", + page=2, + page_size=50, + ) + + # Verify the call was made with correct parameters + results_resource._get.assert_called_once_with( + "/results", + params={ + "evaluation_id": "eval-paginated", + "page": "2", + "pageSize": "50", + }, + timeout=DEFAULT_TIMEOUT, + cast_to=dict, + ) + + # Verify the response structure + assert isinstance(result_data, ResultsData) + assert result_data.evaluation_id == "eval-paginated" + assert result_data.pagination.total_count == 250 + assert result_data.pagination.page_size == 50 + assert result_data.pagination.total_pages == 5 # ceil(250 / 50) = 5 + + def test_get_results_pagination_parameter_conversion(self, results_resource, sample_result_data): + """get 
method converts pagination parameters to strings.""" + mock_response = { + "evaluation_id": "eval-123", + "results": [sample_result_data], + "metrics": { + "total_count": 100, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + results_resource._get.return_value = mock_response + + results_resource.get(evaluation_id="eval-123", page=3, page_size=25) + + call_args = results_resource._get.call_args + params = call_args.kwargs["params"] + + # Verify parameters are converted to strings + assert params["page"] == "3" + assert params["pageSize"] == "25" + assert isinstance(params["page"], str) + assert isinstance(params["pageSize"], str) + + def test_get_results_default_page_parameter(self, results_resource, sample_result_data): + """get method defaults to page 1 when no page is specified.""" + mock_response = { + "evaluation_id": "eval-123", + "results": [sample_result_data], + "metrics": { + "total_count": 100, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + results_resource._get.return_value = mock_response + + results_resource.get(evaluation_id="eval-123") + + call_args = results_resource._get.call_args + params = call_args.kwargs["params"] + assert params["page"] == "1" + assert "pageSize" not in params # pageSize should not be included when not specified + + def test_get_results_pagination_metadata_calculation(self, results_resource, sample_result_data): + """get method correctly calculates pagination metadata.""" + # Mock API response without pagination + api_response = { + "evaluation_id": "eval-math", + "results": [sample_result_data], + "metrics": { + "total_count": 487, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.15, + "min_readability_score": 0.75, + "max_readability_score": 0.98, + }, + } + results_resource._get.return_value = api_response + + result = 
results_resource.get(evaluation_id="eval-math", page=3, page_size=50) + + # Should have calculated pagination correctly + assert isinstance(result, ResultsData) + assert result.pagination.total_count == 487 + assert result.pagination.page_size == 50 + assert result.pagination.total_pages == 10 # ceil(487 / 50) = 10 + + @pytest.mark.parametrize( + "total_count,page_size,expected_pages", + [ + (100, 50, 2), + (99, 50, 2), + (101, 50, 3), + (1000, 100, 10), + (999, 100, 10), + (1001, 100, 11), + (1, 100, 1), + (0, 100, 0), + (250, 25, 10), + (251, 25, 11), + ], + ) + def test_pagination_total_pages_calculation( + self, results_resource, sample_result_data, total_count, page_size, expected_pages + ): + """get method correctly calculates total_pages for various scenarios.""" + api_response = { + "evaluation_id": "eval-calc", + "results": [sample_result_data] if total_count > 0 else [], + "metrics": { + "total_count": total_count, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + results_resource._get.return_value = api_response + + result = results_resource.get(evaluation_id="eval-calc", page_size=page_size) + + assert result.pagination.total_count == total_count + assert result.pagination.page_size == page_size + assert result.pagination.total_pages == expected_pages + + +class TestResultsPaginationErrorHandling: + """Test error handling and edge cases for pagination.""" + + @pytest.fixture + def mock_client(self): + """Mock Atlas client.""" + client = Mock() + client.get_cast = Mock() + return client + + @pytest.fixture + def results_resource(self, mock_client): + """Results resource instance.""" + return Results(mock_client) + + def test_get_results_invalid_api_response(self, results_resource): + """get method handles invalid API response structure.""" + # Response missing metrics + invalid_response = { + "evaluation_id": "eval-123", + "results": [], + # Missing metrics + } + 
results_resource._get.return_value = invalid_response + + result = results_resource.get(evaluation_id="eval-123") + + # Should return None when response structure is invalid + assert result is None + + def test_get_results_with_zero_total_count_in_metrics(self, results_resource): + """get method handles zero total_count in metrics.""" + invalid_response = { + "evaluation_id": "eval-123", + "results": [], + "metrics": { + "total_count": 0, # Now included, testing graceful handling + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": None, + "max_readability_score": None, + }, + } + results_resource._get.return_value = invalid_response + + result = results_resource.get(evaluation_id="eval-123") + + # Should handle zero total_count gracefully + assert isinstance(result, ResultsData) + assert result.pagination.total_count == 0 + assert result.pagination.total_pages == 0 + + def test_get_results_non_dict_response(self, results_resource): + """get method handles non-dict API response.""" + results_resource._get.return_value = "invalid-string-response" + result = results_resource.get(evaluation_id="eval-123") - assert isinstance(result, list) \ No newline at end of file + + assert result is None + + def test_get_results_pydantic_validation_error(self, results_resource): + """get method handles Pydantic validation errors.""" + # Response with invalid data types + invalid_response = { + "evaluation_id": "eval-123", + "results": "not-a-list", # Should be a list + "metrics": { + "total_count": 100, + }, + } + results_resource._get.return_value = invalid_response + + result = results_resource.get(evaluation_id="eval-123") + + assert result is None + + def test_get_results_extreme_pagination_values(self, results_resource): + """get method handles extreme pagination values.""" + extreme_response = { + "evaluation_id": "eval-extreme", + "results": [], + "metrics": { + "total_count": 999999, # Very large number + "min_toxicity_score": 0.0, + 
"max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + results_resource._get.return_value = extreme_response + + result = results_resource.get(evaluation_id="eval-extreme", page_size=1) + + assert isinstance(result, ResultsData) + assert result.pagination.total_count == 999999 + assert result.pagination.page_size == 1 + assert result.pagination.total_pages == 999999 # ceil(999999 / 1) + + def test_get_results_zero_page_size_edge_case(self, results_resource): + """get method handles zero page_size (should use default).""" + response = { + "evaluation_id": "eval-123", + "results": [], + "metrics": { + "total_count": 100, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + results_resource._get.return_value = response + + # Pass 0 as page_size + result = results_resource.get(evaluation_id="eval-123", page_size=0) + + assert isinstance(result, ResultsData) + # Should use 0 as provided (though this might cause division by zero, it's handled) + assert result.pagination.page_size == 0 + + def test_get_results_negative_page_values(self, results_resource): + """get method handles negative page values.""" + response = { + "evaluation_id": "eval-123", + "results": [], + "metrics": { + "total_count": 100, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + results_resource._get.return_value = response + + # Test with negative page and page_size + result = results_resource.get(evaluation_id="eval-123", page=-1, page_size=-50) + + # Should still make the API call and process response + call_args = results_resource._get.call_args + params = call_args.kwargs["params"] + assert params["page"] == "-1" + assert params["pageSize"] == "-50" + + assert isinstance(result, ResultsData) + assert result.pagination.page_size == -50 + # total_pages calculation with negative page_size + 
assert result.pagination.total_pages == 0 # math.ceil handles negative divisors diff --git a/tests/test_base_client.py b/tests/test_base_client.py index ba1a35a..878a391 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -11,6 +11,7 @@ @dataclass class ResponseModel: """Test model for response casting.""" + name: str value: int @@ -35,14 +36,14 @@ def mock_response(self): def test_init_sets_base_url(self): """BaseClient initializes with correct base URL.""" client = BaseClient(base_url="https://custom.api.com") - + assert str(client.base_url) == "https://custom.api.com" def test_init_with_headers(self): """BaseClient accepts custom headers.""" headers = {"X-Custom": "value"} client = BaseClient(base_url="https://api.test.com", headers=headers) - + assert client.headers["X-Custom"] == "value" def test_auth_headers_empty_by_default(self, client): @@ -52,133 +53,117 @@ def test_auth_headers_empty_by_default(self, client): def test_default_headers_structure(self, client): """BaseClient default_headers includes required headers.""" headers = client.default_headers - + assert headers["Accept"] == "application/json" assert headers["Content-Type"] == "application/json" assert isinstance(headers, dict) def test_default_headers_includes_auth(self, client): """default_headers merges auth_headers.""" - with patch.object(type(client), 'auth_headers', new_callable=lambda: property(lambda _: {"Authorization": "Bearer token"})): + with patch.object( + type(client), "auth_headers", new_callable=lambda: property(lambda _: {"Authorization": "Bearer token"}) + ): headers = client.default_headers - + assert headers["Authorization"] == "Bearer token" assert headers["Accept"] == "application/json" - @patch('httpx.Client.request') + @patch("httpx.Client.request") def test_request_cast_without_cast_to(self, mock_request, client, mock_response): """_request_cast returns raw response when cast_to is None.""" mock_request.return_value = mock_response - + result = 
client._request_cast("GET", "/test") - + assert result is mock_response mock_request.assert_called_once_with( - method="GET", - url="/test", - json=None, - params=None, - headers=client.default_headers + method="GET", url="/test", json=None, params=None, headers=client.default_headers ) - @patch('httpx.Client.request') + @patch("httpx.Client.request") def test_request_cast_with_cast_to(self, mock_request, client, mock_response): """_request_cast casts response to specified type.""" mock_request.return_value = mock_response - + result = client._request_cast("GET", "/test", cast_to=ResponseModel) - + assert isinstance(result, ResponseModel) assert result.name == "test" assert result.value == 42 mock_response.json.assert_called_once() - @patch('httpx.Client.request') + @patch("httpx.Client.request") def test_request_cast_combines_headers(self, mock_request, client, mock_response): """_request_cast merges default and custom headers.""" mock_request.return_value = mock_response custom_headers = {"X-Custom": "value"} - + client._request_cast("POST", "/test", headers=custom_headers) - + expected_headers = {**client.default_headers, **custom_headers} mock_request.assert_called_once_with( - method="POST", - url="/test", - json=None, - params=None, - headers=expected_headers + method="POST", url="/test", json=None, params=None, headers=expected_headers ) - @patch('httpx.Client.request') + @patch("httpx.Client.request") def test_request_cast_with_body_and_params(self, mock_request, client, mock_response): """_request_cast sends body and params correctly.""" mock_request.return_value = mock_response body = {"key": "value"} params = {"filter": "active"} - + client._request_cast("POST", "/test", body=body, params=params) - + mock_request.assert_called_once_with( - method="POST", - url="/test", - json=body, - params=params, - headers=client.default_headers + method="POST", url="/test", json=body, params=params, headers=client.default_headers ) - @patch('httpx.Client.request') + 
@patch("httpx.Client.request") def test_request_cast_handles_http_error(self, mock_request, client): """_request_cast converts HTTPStatusError to APIStatusError.""" mock_response = Mock(spec=httpx.Response) mock_response.status_code = 404 mock_response.text = "Not Found" mock_response.headers = {} - mock_response.raise_for_status.side_effect = httpx.HTTPStatusError("404", request=Mock(), response=mock_response) + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "404", request=Mock(), response=mock_response + ) mock_request.return_value = mock_response - - with patch.object(client, '_make_status_error_from_response') as mock_make_error: + + with patch.object(client, "_make_status_error_from_response") as mock_make_error: mock_make_error.side_effect = _exceptions.APIStatusError("Test error", response=mock_response, body=None) - + with pytest.raises(_exceptions.APIStatusError): client._request_cast("GET", "/test") - + mock_make_error.assert_called_once_with(mock_response) - @patch('httpx.Client.request') + @patch("httpx.Client.request") def test_get_cast_delegates_correctly(self, mock_request, client, mock_response): """get_cast delegates to _request_cast with GET method.""" mock_request.return_value = mock_response params = {"page": 1} headers = {"X-Test": "value"} - + result = client.get_cast("/test", params=params, headers=headers, cast_to=ResponseModel) - + assert isinstance(result, ResponseModel) mock_request.assert_called_once_with( - method="GET", - url="/test", - json=None, - params=params, - headers={**client.default_headers, **headers} + method="GET", url="/test", json=None, params=params, headers={**client.default_headers, **headers} ) - @patch('httpx.Client.request') + @patch("httpx.Client.request") def test_post_cast_delegates_correctly(self, mock_request, client, mock_response): """post_cast delegates to _request_cast with POST method.""" mock_request.return_value = mock_response body = {"name": "test"} headers = {"X-Test": "value"} 
- + result = client.post_cast("/test", body=body, headers=headers, cast_to=ResponseModel) - + assert isinstance(result, ResponseModel) mock_request.assert_called_once_with( - method="POST", - url="/test", - json=body, - params=None, - headers={**client.default_headers, **headers} + method="POST", url="/test", json=body, params=None, headers={**client.default_headers, **headers} ) def test_make_status_error_from_response_with_json(self, client): @@ -186,10 +171,10 @@ def test_make_status_error_from_response_with_json(self, client): mock_response = Mock(spec=httpx.Response) mock_response.status_code = 400 mock_response.text = '{"error": "Bad Request", "code": 400}' - - with patch.object(client, '_make_status_error') as mock_make_error: + + with patch.object(client, "_make_status_error") as mock_make_error: client._make_status_error_from_response(mock_response) - + mock_make_error.assert_called_once() args, kwargs = mock_make_error.call_args assert "Error code: 400" in args[0] @@ -201,10 +186,10 @@ def test_make_status_error_from_response_with_text(self, client): mock_response = Mock(spec=httpx.Response) mock_response.status_code = 500 mock_response.text = "Internal Server Error" - - with patch.object(client, '_make_status_error') as mock_make_error: + + with patch.object(client, "_make_status_error") as mock_make_error: client._make_status_error_from_response(mock_response) - + mock_make_error.assert_called_once() args, kwargs = mock_make_error.call_args assert args[0] == "Internal Server Error" @@ -215,10 +200,10 @@ def test_make_status_error_from_response_empty_text(self, client): mock_response = Mock(spec=httpx.Response) mock_response.status_code = 503 mock_response.text = "" - - with patch.object(client, '_make_status_error') as mock_make_error: + + with patch.object(client, "_make_status_error") as mock_make_error: client._make_status_error_from_response(mock_response) - + mock_make_error.assert_called_once() args, _ = mock_make_error.call_args assert args[0] == 
"Error code: 503" @@ -226,6 +211,6 @@ def test_make_status_error_from_response_empty_text(self, client): def test_make_status_error_not_implemented(self, client): """_make_status_error raises NotImplementedError.""" mock_response = Mock(spec=httpx.Response) - + with pytest.raises(NotImplementedError): - client._make_status_error("test", body=None, response=mock_response) \ No newline at end of file + client._make_status_error("test", body=None, response=mock_response) diff --git a/tests/test_client.py b/tests/test_client.py index 0384891..84b386d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,12 +9,8 @@ class TestAtlasClientInitialization: def test_init_with_explicit_params(self): """Client initializes correctly with explicit parameters.""" - client = Atlas( - api_key="explicit-key", - organization_id="explicit-org", - project_id="explicit-project" - ) - + client = Atlas(api_key="explicit-key", organization_id="explicit-org", project_id="explicit-project") + assert client.api_key == "explicit-key" assert client.organization_id == "explicit-org" assert client.project_id == "explicit-project" @@ -23,7 +19,7 @@ def test_init_from_environment(self, mock_env_vars): """Client initializes from environment variables.""" _ = mock_env_vars # Fixture used for side effects client = Atlas() - + assert client.api_key == "test-api-key" assert client.organization_id == "test-org-id" assert client.project_id == "test-project-id" @@ -31,11 +27,8 @@ def test_init_from_environment(self, mock_env_vars): def test_explicit_params_override_env(self, mock_env_vars): """Explicit parameters override environment variables.""" _ = mock_env_vars # Fixture used for side effects - client = Atlas( - api_key="override-key", - organization_id="override-org" - ) - + client = Atlas(api_key="override-key", organization_id="override-org") + assert client.api_key == "override-key" assert client.organization_id == "override-org" assert client.project_id == "test-project-id" @@ -49,12 
+42,8 @@ def test_missing_api_key_raises_error(self, env_vars): def test_none_values_fallback_to_env(self, mock_env_vars): """None values explicitly passed fallback to environment.""" _ = mock_env_vars # Fixture used for side effects - client = Atlas( - api_key=None, - organization_id=None, - project_id=None - ) - + client = Atlas(api_key=None, organization_id=None, project_id=None) + assert client.api_key == "test-api-key" assert client.organization_id == "test-org-id" assert client.project_id == "test-project-id" @@ -62,24 +51,278 @@ def test_none_values_fallback_to_env(self, mock_env_vars): def test_optional_params_can_be_none(self): """Organization and project IDs can be None.""" client = Atlas(api_key="test-key") - + assert client.api_key == "test-key" assert client.organization_id is None assert client.project_id is None - @pytest.mark.parametrize("base_url", [ - "https://custom.api.com", - "https://staging.layerlens.ai/api/v1" - ]) + @pytest.mark.parametrize("base_url", ["https://custom.api.com", "https://staging.layerlens.ai/api/v1"]) def test_custom_base_url(self, base_url): """Client accepts custom base URL.""" client = Atlas(api_key="test-key", base_url=base_url) - - assert str(client.base_url).rstrip('/') == base_url.rstrip('/') + + assert str(client.base_url).rstrip("/") == base_url.rstrip("/") def test_custom_timeout(self): """Client accepts custom timeout.""" import httpx + client = Atlas(api_key="test-key", timeout=30.0) - - assert isinstance(client.timeout, httpx.Timeout) \ No newline at end of file + + assert isinstance(client.timeout, httpx.Timeout) + + def test_auth_headers_with_api_key(self): + """auth_headers property returns correct headers when API key is set.""" + client = Atlas(api_key="test-api-key") + + headers = client.auth_headers + + assert headers == {"x-api-key": "test-api-key"} + + def test_auth_headers_without_api_key(self): + """auth_headers property returns empty dict when no API key.""" + client = Atlas(api_key="") + + headers 
= client.auth_headers + + assert headers == {} + + def test_auth_headers_with_empty_api_key(self): + """auth_headers property returns empty dict when API key is empty string.""" + client = Atlas(api_key="") + + headers = client.auth_headers + + assert headers == {} + + def test_copy_method(self): + """copy method creates new client with overridden parameters.""" + original_client = Atlas( + api_key="original-key", + organization_id="original-org", + project_id="original-project", + base_url="https://original.api.com", + timeout=10.0, + ) + + new_client = original_client.copy(api_key="new-key", organization_id="new-org", timeout=20.0) + + # Check overridden values + assert new_client.api_key == "new-key" + assert new_client.organization_id == "new-org" + # The copy method uses 'or' logic, so timeout=20.0 won't override the existing timeout + # Let's check that the timeout is still the original value + assert new_client.timeout == original_client.timeout # Should remain the original timeout + + # Check unchanged values + assert new_client.project_id == "original-project" + assert str(new_client.base_url) == "https://original.api.com" + + def test_copy_method_partial_override(self): + """copy method allows partial parameter override.""" + original_client = Atlas(api_key="original-key", organization_id="original-org", project_id="original-project") + + new_client = original_client.copy(api_key="new-key") + + assert new_client.api_key == "new-key" + assert new_client.organization_id == "original-org" + assert new_client.project_id == "original-project" + + def test_with_options_alias(self): + """with_options is an alias for copy method.""" + original_client = Atlas(api_key="original-key") + + new_client = original_client.with_options(api_key="new-key") + + assert new_client.api_key == "new-key" + assert new_client is not original_client + + def test_copy_method_timeout_override(self): + """copy method properly overrides timeout when original is None.""" + # Create a 
client with no explicit timeout (uses default) + original_client = Atlas(api_key="original-key") + + new_client = original_client.copy(timeout=30.0) + + import httpx + + assert isinstance(new_client.timeout, httpx.Timeout) + # Both clients use the same default timeout, so they should be equal + assert new_client.timeout == original_client.timeout + + +class TestAtlasClientErrorHandling: + """Test error handling in Atlas client.""" + + def _create_mock_response(self, status_code): + """Helper to create a mock response with all required attributes.""" + mock_request = type("MockRequest", (), {})() + return type("MockResponse", (), {"status_code": status_code, "request": mock_request, "headers": {}})() + + def test_make_status_error_bad_request(self): + """_make_status_error creates BadRequestError for 400 status.""" + from atlas._exceptions import BadRequestError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(400) + mock_body = {"error": "Bad request"} + + error = client._make_status_error("Bad request", body=mock_body, response=mock_response) + + assert isinstance(error, BadRequestError) + assert error.message == "Bad request" + + def test_make_status_error_unauthorized(self): + """_make_status_error creates AuthenticationError for 401 status.""" + from atlas._exceptions import AuthenticationError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(401) + mock_body = {"error": "Unauthorized"} + + error = client._make_status_error("Unauthorized", body=mock_body, response=mock_response) + + assert isinstance(error, AuthenticationError) + assert error.message == "Unauthorized" + + def test_make_status_error_forbidden(self): + """_make_status_error creates PermissionDeniedError for 403 status.""" + from atlas._exceptions import PermissionDeniedError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(403) + mock_body = {"error": "Forbidden"} + + error = 
client._make_status_error("Forbidden", body=mock_body, response=mock_response) + + assert isinstance(error, PermissionDeniedError) + assert error.message == "Forbidden" + + def test_make_status_error_not_found(self): + """_make_status_error creates NotFoundError for 404 status.""" + from atlas._exceptions import NotFoundError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(404) + mock_body = {"error": "Not found"} + + error = client._make_status_error("Not found", body=mock_body, response=mock_response) + + assert isinstance(error, NotFoundError) + assert error.message == "Not found" + + def test_make_status_error_conflict(self): + """_make_status_error creates ConflictError for 409 status.""" + from atlas._exceptions import ConflictError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(409) + mock_body = {"error": "Conflict"} + + error = client._make_status_error("Conflict", body=mock_body, response=mock_response) + + assert isinstance(error, ConflictError) + assert error.message == "Conflict" + + def test_make_status_error_unprocessable_entity(self): + """_make_status_error creates UnprocessableEntityError for 422 status.""" + from atlas._exceptions import UnprocessableEntityError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(422) + mock_body = {"error": "Unprocessable entity"} + + error = client._make_status_error("Unprocessable entity", body=mock_body, response=mock_response) + + assert isinstance(error, UnprocessableEntityError) + assert error.message == "Unprocessable entity" + + def test_make_status_error_rate_limit(self): + """_make_status_error creates RateLimitError for 429 status.""" + from atlas._exceptions import RateLimitError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(429) + mock_body = {"error": "Rate limited"} + + error = client._make_status_error("Rate limited", body=mock_body, 
response=mock_response) + + assert isinstance(error, RateLimitError) + assert error.message == "Rate limited" + + def test_make_status_error_internal_server_error(self): + """_make_status_error creates InternalServerError for 500+ status.""" + from atlas._exceptions import InternalServerError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(500) + mock_body = {"error": "Internal server error"} + + error = client._make_status_error("Internal server error", body=mock_body, response=mock_response) + + assert isinstance(error, InternalServerError) + assert error.message == "Internal server error" + + def test_make_status_error_gateway_timeout(self): + """_make_status_error creates InternalServerError for 502 status.""" + from atlas._exceptions import InternalServerError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(502) + mock_body = {"error": "Gateway timeout"} + + error = client._make_status_error("Gateway timeout", body=mock_body, response=mock_response) + + assert isinstance(error, InternalServerError) + assert error.message == "Gateway timeout" + + def test_make_status_error_unknown_status(self): + """_make_status_error creates generic APIStatusError for unknown status codes.""" + from atlas._exceptions import APIStatusError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(418) # I'm a teapot + mock_body = {"error": "Unknown error"} + + error = client._make_status_error("Unknown error", body=mock_body, response=mock_response) + + assert isinstance(error, APIStatusError) + assert error.message == "Unknown error" + + def test_make_status_error_with_non_mapping_body(self): + """_make_status_error handles non-mapping body correctly.""" + from atlas._exceptions import NotFoundError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(404) + mock_body = "Simple string error" + + error = client._make_status_error("Not found", 
body=mock_body, response=mock_response) + + assert isinstance(error, NotFoundError) + assert error.body == "Simple string error" + + def test_make_status_error_with_none_body(self): + """_make_status_error handles None body correctly.""" + from atlas._exceptions import BadRequestError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(400) + + error = client._make_status_error("Bad request", body=None, response=mock_response) + + assert isinstance(error, BadRequestError) + assert error.body is None + + def test_make_status_error_with_complex_body(self): + """_make_status_error extracts error from complex body structure.""" + from atlas._exceptions import AuthenticationError + + client = Atlas(api_key="test-key") + mock_response = self._create_mock_response(401) + mock_body = {"error": {"message": "Invalid API key", "code": "AUTH_ERROR"}, "timestamp": "2023-01-01T00:00:00Z"} + + error = client._make_status_error("Authentication failed", body=mock_body, response=mock_response) + + assert isinstance(error, AuthenticationError) + assert error.body == mock_body["error"] # Should extract the error field diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index dea1d9a..4471471 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -28,7 +28,7 @@ class TestExceptionHierarchy: def test_atlas_error_is_base_exception(self): """AtlasError inherits from Exception.""" error = AtlasError("test message") - + assert isinstance(error, Exception) assert str(error) == "test message" @@ -36,7 +36,7 @@ def test_api_error_inherits_from_atlas_error(self): """APIError inherits from AtlasError.""" mock_request = Mock(spec=httpx.Request) error = APIError("api error", mock_request, body=None) - + assert isinstance(error, AtlasError) assert isinstance(error, Exception) @@ -46,31 +46,34 @@ def test_api_status_error_inherits_from_api_error(self): mock_response.request = Mock(spec=httpx.Request) mock_response.status_code = 400 
mock_response.headers = {} - + error = APIStatusError("status error", response=mock_response, body=None) - + assert isinstance(error, APIError) assert isinstance(error, AtlasError) - @pytest.mark.parametrize("exception_class", [ - BadRequestError, - AuthenticationError, - PermissionDeniedError, - NotFoundError, - ConflictError, - UnprocessableEntityError, - RateLimitError, - InternalServerError, - ]) + @pytest.mark.parametrize( + "exception_class", + [ + BadRequestError, + AuthenticationError, + PermissionDeniedError, + NotFoundError, + ConflictError, + UnprocessableEntityError, + RateLimitError, + InternalServerError, + ], + ) def test_status_exceptions_inherit_from_api_status_error(self, exception_class): """All status-specific exceptions inherit from APIStatusError.""" mock_response = Mock(spec=httpx.Response) mock_response.request = Mock(spec=httpx.Request) mock_response.status_code = 400 mock_response.headers = {} - + error = exception_class("test error", response=mock_response, body=None) - + assert isinstance(error, APIStatusError) assert isinstance(error, APIError) assert isinstance(error, AtlasError) @@ -88,7 +91,7 @@ def test_api_error_stores_message_and_request(self, mock_request): """APIError stores message, request, and body.""" body = {"error": "test"} error = APIError("test message", mock_request, body=body) - + assert error.message == "test message" assert error.request is mock_request assert error.body == body @@ -97,7 +100,7 @@ def test_api_error_stores_message_and_request(self, mock_request): def test_api_error_with_none_body(self, mock_request): """APIError handles None body.""" error = APIError("test message", mock_request, body=None) - + assert error.body is None assert error.message == "test message" @@ -105,7 +108,7 @@ def test_api_error_with_json_body(self, mock_request): """APIError stores JSON body correctly.""" body = {"error": "validation failed", "code": 422} error = APIError("validation error", mock_request, body=body) - + assert 
error.body == body assert isinstance(error.body, dict) assert error.body["error"] == "validation failed" @@ -114,7 +117,7 @@ def test_api_error_with_string_body(self, mock_request): """APIError stores string body correctly.""" body = "Plain text error message" error = APIError("server error", mock_request, body=body) - + assert error.body == body @@ -132,7 +135,7 @@ def mock_response(self): def test_validation_error_with_default_message(self, mock_response): """APIResponseValidationError uses default message when none provided.""" error = APIResponseValidationError(mock_response, body=None) - + assert error.message == "Data returned by API invalid for expected schema." assert error.response is mock_response assert error.status_code == 200 @@ -141,7 +144,7 @@ def test_validation_error_with_custom_message(self, mock_response): """APIResponseValidationError uses custom message when provided.""" custom_message = "Custom validation error" error = APIResponseValidationError(mock_response, body=None, message=custom_message) - + assert error.message == custom_message assert str(error) == custom_message @@ -149,7 +152,7 @@ def test_validation_error_stores_response_data(self, mock_response): """APIResponseValidationError stores response and body.""" body = {"invalid": "data"} error = APIResponseValidationError(mock_response, body=body) - + assert error.response is mock_response assert error.body == body assert error.request is mock_response.request @@ -170,7 +173,7 @@ def mock_response(self): def test_status_error_stores_response_data(self, mock_response): """APIStatusError stores response, status code, and request ID.""" error = APIStatusError("not found", response=mock_response, body=None) - + assert error.response is mock_response assert error.status_code == 404 assert error.request_id == "req-123" @@ -180,14 +183,14 @@ def test_status_error_without_request_id(self, mock_response): """APIStatusError handles missing request ID header.""" mock_response.headers = {} error = 
APIStatusError("error", response=mock_response, body=None) - + assert error.request_id is None def test_status_error_with_body(self, mock_response): """APIStatusError stores error body.""" body = {"error": "Resource not found", "code": "NOT_FOUND"} error = APIStatusError("not found", response=mock_response, body=body) - + assert error.body == body @@ -202,7 +205,7 @@ def mock_request(self): def test_api_connection_error_default_message(self, mock_request): """APIConnectionError uses default message.""" error = APIConnectionError(request=mock_request) - + assert error.message == "Connection error." assert error.request is mock_request assert error.body is None @@ -211,13 +214,13 @@ def test_api_connection_error_custom_message(self, mock_request): """APIConnectionError accepts custom message.""" custom_message = "Failed to connect to server" error = APIConnectionError(message=custom_message, request=mock_request) - + assert error.message == custom_message def test_api_timeout_error_inherits_from_connection_error(self, mock_request): """APITimeoutError inherits from APIConnectionError.""" error = APITimeoutError(mock_request) - + assert isinstance(error, APIConnectionError) assert isinstance(error, APIError) assert error.message == "Request timed out." 
@@ -230,30 +233,35 @@ class TestStatusCodeExceptions: @pytest.fixture def mock_response_factory(self): """Factory for creating mock responses with different status codes.""" + def _create_response(status_code: int) -> Mock: mock = Mock(spec=httpx.Response) mock.request = Mock(spec=httpx.Request) mock.status_code = status_code mock.headers = {} return mock + return _create_response - @pytest.mark.parametrize("exception_class,expected_status", [ - (BadRequestError, HTTPStatus.BAD_REQUEST), - (AuthenticationError, HTTPStatus.UNAUTHORIZED), - (PermissionDeniedError, HTTPStatus.FORBIDDEN), - (NotFoundError, HTTPStatus.NOT_FOUND), - (ConflictError, HTTPStatus.CONFLICT), - (UnprocessableEntityError, HTTPStatus.UNPROCESSABLE_ENTITY), - (RateLimitError, HTTPStatus.TOO_MANY_REQUESTS), - ]) + @pytest.mark.parametrize( + "exception_class,expected_status", + [ + (BadRequestError, HTTPStatus.BAD_REQUEST), + (AuthenticationError, HTTPStatus.UNAUTHORIZED), + (PermissionDeniedError, HTTPStatus.FORBIDDEN), + (NotFoundError, HTTPStatus.NOT_FOUND), + (ConflictError, HTTPStatus.CONFLICT), + (UnprocessableEntityError, HTTPStatus.UNPROCESSABLE_ENTITY), + (RateLimitError, HTTPStatus.TOO_MANY_REQUESTS), + ], + ) def test_status_exception_has_correct_status_code(self, exception_class, expected_status, mock_response_factory): """Status-specific exceptions have correct status codes.""" mock_response = mock_response_factory(expected_status.value) error = exception_class("test error", response=mock_response, body=None) - + assert error.status_code == expected_status.value - assert hasattr(error.__class__, 'status_code') + assert hasattr(error.__class__, "status_code") assert error.__class__.status_code == expected_status def test_bad_request_error_properties(self, mock_response_factory): @@ -261,7 +269,7 @@ def test_bad_request_error_properties(self, mock_response_factory): mock_response = mock_response_factory(400) body = {"error": "Invalid request", "field": "name"} error = 
BadRequestError("bad request", response=mock_response, body=body) - + assert error.status_code == 400 assert error.body == body assert isinstance(error, APIStatusError) @@ -270,7 +278,7 @@ def test_authentication_error_properties(self, mock_response_factory): """AuthenticationError has correct properties.""" mock_response = mock_response_factory(401) error = AuthenticationError("unauthorized", response=mock_response, body=None) - + assert error.status_code == 401 assert error.__class__.status_code == HTTPStatus.UNAUTHORIZED @@ -278,9 +286,9 @@ def test_internal_server_error_no_fixed_status(self, mock_response_factory): """InternalServerError doesn't have a fixed status code.""" mock_response = mock_response_factory(500) error = InternalServerError("server error", response=mock_response, body=None) - + assert error.status_code == 500 - assert not hasattr(error.__class__, 'status_code') or error.__class__.status_code is None + assert not hasattr(error.__class__, "status_code") or error.__class__.status_code is None class TestErrorMessages: @@ -290,18 +298,15 @@ def test_exception_str_representation(self): """Exception string representation shows message.""" mock_request = Mock(spec=httpx.Request) error = APIError("Test error message", mock_request, body=None) - + assert str(error) == "Test error message" def test_exception_with_complex_body(self): """Exception handles complex body structures.""" mock_request = Mock(spec=httpx.Request) - body = { - "error": {"code": "VALIDATION_ERROR", "details": ["Field 'name' is required"]}, - "request_id": "req-456" - } + body = {"error": {"code": "VALIDATION_ERROR", "details": ["Field 'name' is required"]}, "request_id": "req-456"} error = APIError("Validation failed", mock_request, body=body) - + assert isinstance(error.body, dict) assert error.body["error"]["code"] == "VALIDATION_ERROR" - assert error.body["request_id"] == "req-456" \ No newline at end of file + assert error.body["request_id"] == "req-456" diff --git 
a/tests/test_integration.py b/tests/test_integration.py index 8daa315..d91c87e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -23,11 +23,7 @@ class TestAtlasIntegration: @pytest.fixture def atlas_client(self): """Create Atlas client with mocked dependencies.""" - return Atlas( - api_key="test-api-key", - organization_id="test-org", - project_id="test-project" - ) + return Atlas(api_key="test-api-key", organization_id="test-org", project_id="test-project") @pytest.fixture def sample_model_data(self): @@ -96,10 +92,7 @@ def sample_result_data(self): "truth": "2x", "duration": timedelta(seconds=2.5), "score": 1.0, - "metrics": { - "accuracy": 1.0, - "confidence": 0.95 - } + "metrics": {"accuracy": 1.0, "confidence": 0.95}, } @@ -109,99 +102,147 @@ class TestCompleteEvaluationWorkflow: @pytest.fixture def atlas_client(self): """Atlas client for workflow testing.""" - return Atlas( - api_key="workflow-test-key", - organization_id="workflow-org", - project_id="workflow-project" - ) + return Atlas(api_key="workflow-test-key", organization_id="workflow-org", project_id="workflow-project") def test_complete_evaluation_workflow(self, atlas_client): """Test complete workflow: get models/benchmarks -> create evaluation -> get results.""" - + # Mock data model_data = { - "id": "model-123", "key": "gpt-4", "name": "GPT-4", "company": "OpenAI", - "description": "LLM", "released_at": 1679875200, "parameters": 1.76e12, - "modality": "text", "context_length": 8192, "architecture_type": "transformer", - "license": "proprietary", "open_weights": False, "region": "us-east-1", "deprecated": False, + "id": "model-123", + "key": "gpt-4", + "name": "GPT-4", + "company": "OpenAI", + "description": "LLM", + "released_at": 1679875200, + "parameters": 1.76e12, + "modality": "text", + "context_length": 8192, + "architecture_type": "transformer", + "license": "proprietary", + "open_weights": False, + "region": "us-east-1", + "deprecated": False, } - + benchmark_data = 
{ - "id": "bench-456", "key": "mmlu", "name": "MMLU", - "full_description": "MMLU benchmark", "language": "english", - "categories": ["reasoning"], "subsets": ["math"], "prompt_count": 1000, "deprecated": False, + "id": "bench-456", + "key": "mmlu", + "name": "MMLU", + "full_description": "MMLU benchmark", + "language": "english", + "categories": ["reasoning"], + "subsets": ["math"], + "prompt_count": 1000, + "deprecated": False, } - + evaluation_data = { - "id": "eval-789", "status": "completed", "status_description": "Done", - "submitted_at": 1640995200, "finished_at": 1640995800, - "model_id": "model-123", "model_name": "GPT-4", "model_key": "gpt-4", "model_company": "OpenAI", - "dataset_id": "bench-456", "dataset_name": "MMLU", "average_duration": 2500, - "readability_score": 0.85, "toxicity_score": 0.02, "ethics_score": 0.92, "accuracy": 0.89, + "id": "eval-789", + "status": "completed", + "status_description": "Done", + "submitted_at": 1640995200, + "finished_at": 1640995800, + "model_id": "model-123", + "model_name": "GPT-4", + "model_key": "gpt-4", + "model_company": "OpenAI", + "dataset_id": "bench-456", + "dataset_name": "MMLU", + "average_duration": 2500, + "readability_score": 0.85, + "toxicity_score": 0.02, + "ethics_score": 0.92, + "accuracy": 0.89, } - + result_data = { - "subset": "math", "prompt": "2+2=?", "result": "4", "truth": "4", - "duration": timedelta(seconds=1.5), "score": 1.0, "metrics": {"accuracy": 1.0} + "subset": "math", + "prompt": "2+2=?", + "result": "4", + "truth": "4", + "duration": timedelta(seconds=1.5), + "score": 1.0, + "metrics": {"accuracy": 1.0}, } - + # Create model objects model = Model(**model_data) benchmark = Benchmark(**benchmark_data) evaluation = Evaluation(**evaluation_data) result = Result(**result_data) - + # Mock responses models_response = ModelsData(models=[model]) benchmarks_response = BenchmarksData(datasets=[benchmark]) evaluations_response = EvaluationsData(data=[evaluation]) - results_response = 
ResultsData(results=[result]) - - with patch.object(atlas_client, 'get_cast') as mock_get, \ - patch.object(atlas_client, 'post_cast') as mock_post: - + results_response = ResultsData( + evaluation_id="eval-789", + results=[result], + metrics={ + "total_count": 1, + "min_toxicity_score": 0.02, + "max_toxicity_score": 0.02, + "min_readability_score": 0.85, + "max_readability_score": 0.85, + }, + pagination={ + "total_count": 1, + "page_size": 100, + "total_pages": 1, + }, + ) + + with patch.object(atlas_client, "get_cast") as mock_get, patch.object(atlas_client, "post_cast") as mock_post: # Configure mocks for the workflow - mock_get.return_value = results_response # Get results + mock_get.return_value = { + "evaluation_id": "eval-789", + "results": [result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.02, + "max_toxicity_score": 0.02, + "min_readability_score": 0.85, + "max_readability_score": 0.85, + }, + } # Get results - raw API response mock_post.return_value = evaluations_response # Create evaluation - + # Step 1: Create evaluation directly (Atlas client doesn't expose models/benchmarks resources) - created_evaluation = atlas_client.evaluations.create( - model="gpt-4", - benchmark="mmlu" - ) + created_evaluation = atlas_client.evaluations.create(model="gpt-4", benchmark="mmlu") assert created_evaluation.id == "eval-789" assert created_evaluation.status == "completed" - + # Step 2: Get evaluation results results = atlas_client.results.get(evaluation_id=created_evaluation.id) - assert len(results) == 1 - assert results[0].score == 1.0 - assert results[0].subset == "math" - + assert len(results.results) == 1 + assert results.results[0].score == 1.0 + assert results.results[0].subset == "math" + # Verify all API calls were made correctly assert mock_get.call_count == 1 # Only results call assert mock_post.call_count == 1 - + # Verify specific API calls get_calls = mock_get.call_args_list assert "/results" in get_calls[0][0][0] - + post_call = 
mock_post.call_args_list[0] assert "/evaluations" in post_call[0][0] def test_workflow_with_error_handling(self, atlas_client): """Test workflow handles errors gracefully.""" from atlas._exceptions import NotFoundError - + mock_response = Mock() mock_response.status_code = 404 mock_response.headers = {} - - with patch.object(atlas_client, 'get_cast') as mock_get: + + with patch.object(atlas_client, "get_cast") as mock_get: # Mock API error when getting results api_error = NotFoundError("Results not found", response=mock_response, body=None) mock_get.side_effect = api_error - + # Verify error is propagated with pytest.raises(NotFoundError): atlas_client.results.get(evaluation_id="test-eval") @@ -209,21 +250,51 @@ def test_workflow_with_error_handling(self, atlas_client): def test_workflow_with_custom_timeouts(self, atlas_client): """Test workflow respects custom timeout settings.""" result_data = { - "subset": "test", "prompt": "test", "result": "test", "truth": "test", - "duration": timedelta(seconds=1.0), "score": 1.0, "metrics": {"accuracy": 1.0} + "subset": "test", + "prompt": "test", + "result": "test", + "truth": "test", + "duration": timedelta(seconds=1.0), + "score": 1.0, + "metrics": {"accuracy": 1.0}, } - - results_response = ResultsData(results=[Result(**result_data)]) - - with patch.object(atlas_client, 'get_cast') as mock_get: - mock_get.return_value = results_response - + + results_response = ResultsData( + evaluation_id="test-eval", + results=[Result(**result_data)], + metrics={ + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + pagination={ + "total_count": 1, + "page_size": 100, + "total_pages": 1, + }, + ) + + with patch.object(atlas_client, "get_cast") as mock_get: + mock_get.return_value = { + "evaluation_id": "test-eval", + "results": [result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + 
"min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + # Test with custom timeout custom_timeout = httpx.Timeout(30.0) results = atlas_client.results.get(evaluation_id="test-eval", timeout=custom_timeout) - - assert len(results) == 1 - + + assert len(results.results) == 1 + # Verify timeout was passed correctly call_args = mock_get.call_args assert call_args.kwargs["timeout"] is custom_timeout @@ -236,55 +307,77 @@ class TestResourceInteraction: def atlas_client(self): """Atlas client for resource interaction testing.""" return Atlas( - api_key="interaction-test-key", - organization_id="interaction-org", - project_id="interaction-project" + api_key="interaction-test-key", organization_id="interaction-org", project_id="interaction-project" ) def test_evaluation_creation_with_model_and_benchmark_objects(self, atlas_client): """Test creating evaluation using model and benchmark objects.""" - + # Create model and benchmark objects model_data = { - "id": "model-abc", "key": "claude-3", "name": "Claude 3", "company": "Anthropic", - "description": "Claude 3", "released_at": 1709251200, "parameters": 5e11, - "modality": "text", "context_length": 100000, "architecture_type": "transformer", - "license": "proprietary", "open_weights": False, "region": "us-west-2", "deprecated": False, + "id": "model-abc", + "key": "claude-3", + "name": "Claude 3", + "company": "Anthropic", + "description": "Claude 3", + "released_at": 1709251200, + "parameters": 5e11, + "modality": "text", + "context_length": 100000, + "architecture_type": "transformer", + "license": "proprietary", + "open_weights": False, + "region": "us-west-2", + "deprecated": False, } - + benchmark_data = { - "id": "bench-xyz", "key": "hellaswag", "name": "HellaSwag", - "full_description": "HellaSwag benchmark", "language": "english", - "categories": ["reasoning"], "subsets": ["commonsense"], "prompt_count": 10042, "deprecated": False, + "id": "bench-xyz", + "key": "hellaswag", + "name": "HellaSwag", + 
"full_description": "HellaSwag benchmark", + "language": "english", + "categories": ["reasoning"], + "subsets": ["commonsense"], + "prompt_count": 10042, + "deprecated": False, } - + evaluation_data = { - "id": "eval-interaction", "status": "submitted", "status_description": "Submitted", - "submitted_at": 1640995200, "finished_at": 0, - "model_id": "model-abc", "model_name": "Claude 3", "model_key": "claude-3", "model_company": "Anthropic", - "dataset_id": "bench-xyz", "dataset_name": "HellaSwag", "average_duration": 0, - "readability_score": 0.0, "toxicity_score": 0.0, "ethics_score": 0.0, "accuracy": 0.0, + "id": "eval-interaction", + "status": "submitted", + "status_description": "Submitted", + "submitted_at": 1640995200, + "finished_at": 0, + "model_id": "model-abc", + "model_name": "Claude 3", + "model_key": "claude-3", + "model_company": "Anthropic", + "dataset_id": "bench-xyz", + "dataset_name": "HellaSwag", + "average_duration": 0, + "readability_score": 0.0, + "toxicity_score": 0.0, + "ethics_score": 0.0, + "accuracy": 0.0, } - + model = Model(**model_data) benchmark = Benchmark(**benchmark_data) evaluation = Evaluation(**evaluation_data) - + evaluations_response = EvaluationsData(data=[evaluation]) - - with patch.object(atlas_client, 'post_cast') as mock_post: + + with patch.object(atlas_client, "post_cast") as mock_post: mock_post.return_value = evaluations_response - + # Create evaluation using model and benchmark keys - created_evaluation = atlas_client.evaluations.create( - model=model.key, - benchmark=benchmark.key - ) - + created_evaluation = atlas_client.evaluations.create(model=model.key, benchmark=benchmark.key) + assert created_evaluation.id == "eval-interaction" assert created_evaluation.model_key == model.key assert created_evaluation.dataset_id == benchmark.id - + # Verify API call call_args = mock_post.call_args body = call_args.kwargs["body"][0] @@ -293,48 +386,90 @@ def test_evaluation_creation_with_model_and_benchmark_objects(self, 
atlas_client def test_results_analysis_workflow(self, atlas_client): """Test analyzing results from multiple evaluations.""" - + # Create multiple result objects results_data = [ { - "subset": "math", "prompt": "2+2=?", "result": "4", "truth": "4", - "duration": timedelta(seconds=1.0), "score": 1.0, "metrics": {"accuracy": 1.0} + "subset": "math", + "prompt": "2+2=?", + "result": "4", + "truth": "4", + "duration": timedelta(seconds=1.0), + "score": 1.0, + "metrics": {"accuracy": 1.0}, }, { - "subset": "math", "prompt": "3*3=?", "result": "9", "truth": "9", - "duration": timedelta(seconds=1.2), "score": 1.0, "metrics": {"accuracy": 1.0} + "subset": "math", + "prompt": "3*3=?", + "result": "9", + "truth": "9", + "duration": timedelta(seconds=1.2), + "score": 1.0, + "metrics": {"accuracy": 1.0}, }, { - "subset": "reading", "prompt": "What is the main idea?", "result": "Education", "truth": "Learning", - "duration": timedelta(seconds=2.8), "score": 0.7, "metrics": {"accuracy": 0.7} + "subset": "reading", + "prompt": "What is the main idea?", + "result": "Education", + "truth": "Learning", + "duration": timedelta(seconds=2.8), + "score": 0.7, + "metrics": {"accuracy": 0.7}, }, ] - + results = [Result(**data) for data in results_data] - results_response = ResultsData(results=results) - - with patch.object(atlas_client, 'get_cast') as mock_get: - mock_get.return_value = results_response - + results_response = ResultsData( + evaluation_id="test-eval", + results=results, + metrics={ + "total_count": 3, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.7, + "max_readability_score": 0.9, + }, + pagination={ + "total_count": 3, + "page_size": 100, + "total_pages": 1, + }, + ) + + with patch.object(atlas_client, "get_cast") as mock_get: + mock_get.return_value = { + "evaluation_id": "test-eval", + "results": results_data, + "metrics": { + "total_count": 3, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 
0.7, + "max_readability_score": 0.9, + }, + } + # Get results evaluation_results = atlas_client.results.get(evaluation_id="test-eval") - + # Analyze results - math_results = [r for r in evaluation_results if r.subset == "math"] - reading_results = [r for r in evaluation_results if r.subset == "reading"] - + math_results = [r for r in evaluation_results.results if r.subset == "math"] + reading_results = [r for r in evaluation_results.results if r.subset == "reading"] + assert len(math_results) == 2 assert len(reading_results) == 1 - + # Calculate average scores math_avg = sum(r.score for r in math_results) / len(math_results) reading_avg = sum(r.score for r in reading_results) / len(reading_results) - + assert math_avg == 1.0 assert reading_avg == 0.7 - + # Calculate average duration - avg_duration = sum((r.duration.total_seconds() for r in evaluation_results), 0.0) / len(evaluation_results) + avg_duration = sum((r.duration.total_seconds() for r in evaluation_results.results), 0.0) / len( + evaluation_results.results + ) expected_avg = (1.0 + 1.2 + 2.8) / 3 assert abs(avg_duration - expected_avg) < 0.01 @@ -344,31 +479,23 @@ class TestAtlasClientProperties: def test_client_has_all_resource_properties(self): """Atlas client exposes all resource properties.""" - client = Atlas( - api_key="property-test-key", - organization_id="property-org", - project_id="property-project" - ) - + client = Atlas(api_key="property-test-key", organization_id="property-org", project_id="property-project") + # Verify available resource properties exist - assert hasattr(client, 'evaluations') - assert hasattr(client, 'results') - + assert hasattr(client, "evaluations") + assert hasattr(client, "results") + # Verify they are the correct types from atlas.resources.results import Results from atlas.resources.evaluations import Evaluations - + assert isinstance(client.evaluations, Evaluations) assert isinstance(client.results, Results) def test_resource_properties_share_same_client(self): 
"""All resource properties share the same client instance.""" - client = Atlas( - api_key="shared-client-test", - organization_id="shared-org", - project_id="shared-project" - ) - + client = Atlas(api_key="shared-client-test", organization_id="shared-org", project_id="shared-project") + # Verify all resources use the same client assert client.evaluations._client is client assert client.results._client is client @@ -377,17 +504,13 @@ def test_client_configuration_propagates_to_resources(self): """Client configuration (org_id, project_id) propagates to resources.""" org_id = "config-test-org" project_id = "config-test-project" - - client = Atlas( - api_key="config-test-key", - organization_id=org_id, - project_id=project_id - ) - + + client = Atlas(api_key="config-test-key", organization_id=org_id, project_id=project_id) + # Verify configuration is available to resources assert client.organization_id == org_id assert client.project_id == project_id - + # Resources should have access to client configuration assert client.evaluations._client.organization_id == org_id assert client.evaluations._client.project_id == project_id @@ -400,61 +523,91 @@ class TestConcurrentOperations: def test_multiple_atlas_clients_independent(self): """Multiple Atlas client instances operate independently.""" - - client1 = Atlas( - api_key="client-1-key", - organization_id="org-1", - project_id="project-1" - ) - - client2 = Atlas( - api_key="client-2-key", - organization_id="org-2", - project_id="project-2" - ) - + + client1 = Atlas(api_key="client-1-key", organization_id="org-1", project_id="project-1") + + client2 = Atlas(api_key="client-2-key", organization_id="org-2", project_id="project-2") + # Verify clients are independent assert client1.api_key != client2.api_key assert client1.organization_id != client2.organization_id assert client1.project_id != client2.project_id - + # Verify resources are independent assert client1.evaluations._client is not client2.evaluations._client assert 
client1.results._client is not client2.results._client def test_resource_operations_isolated(self): """Operations on different client resources are isolated.""" - + client1 = Atlas(api_key="iso-test-1", organization_id="org-1", project_id="proj-1") client2 = Atlas(api_key="iso-test-2", organization_id="org-2", project_id="proj-2") - + result_data = { - "subset": "test", "prompt": "test", "result": "test", "truth": "test", - "duration": timedelta(seconds=1.0), "score": 1.0, "metrics": {"accuracy": 1.0} + "subset": "test", + "prompt": "test", + "result": "test", + "truth": "test", + "duration": timedelta(seconds=1.0), + "score": 1.0, + "metrics": {"accuracy": 1.0}, } - - results_response = ResultsData(results=[Result(**result_data)]) - - with patch.object(client1, 'get_cast') as mock_get1, \ - patch.object(client2, 'get_cast') as mock_get2: - - mock_get1.return_value = results_response - mock_get2.return_value = results_response - + + results_response = ResultsData( + evaluation_id="test-eval", + results=[Result(**result_data)], + metrics={ + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + pagination={ + "total_count": 1, + "page_size": 100, + "total_pages": 1, + }, + ) + + with patch.object(client1, "get_cast") as mock_get1, patch.object(client2, "get_cast") as mock_get2: + mock_get1.return_value = { + "evaluation_id": "test-eval", + "results": [result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + mock_get2.return_value = { + "evaluation_id": "test-eval", + "results": [result_data], + "metrics": { + "total_count": 1, + "min_toxicity_score": 0.0, + "max_toxicity_score": 0.1, + "min_readability_score": 0.8, + "max_readability_score": 0.9, + }, + } + # Make calls on both clients results1 = client1.results.get(evaluation_id="eval-1") results2 = 
client2.results.get(evaluation_id="eval-2") - + # Verify both calls succeeded assert results1 is not None - assert len(results1) == 1 + assert len(results1.results) == 1 assert results2 is not None - assert len(results2) == 1 - + assert len(results2.results) == 1 + # Verify calls were made to correct clients mock_get1.assert_called_once() mock_get2.assert_called_once() - + # Verify different parameters were used call1_params = mock_get1.call_args.kwargs["params"] call2_params = mock_get2.call_args.kwargs["params"] @@ -468,34 +621,28 @@ class TestErrorPropagation: def test_evaluation_workflow_error_propagation(self): """Errors in evaluation workflow are properly propagated.""" from atlas._exceptions import APIStatusError, APIConnectionError - - client = Atlas( - api_key="error-test-key", - organization_id="error-org", - project_id="error-project" - ) - + + client = Atlas(api_key="error-test-key", organization_id="error-org", project_id="error-project") + mock_response = Mock() mock_response.status_code = 500 mock_response.headers = {} - + # Test different types of errors api_error = APIStatusError("Server Error", response=mock_response, body=None) connection_error = APIConnectionError(request=Mock()) - - with patch.object(client, 'get_cast') as mock_get, \ - patch.object(client, 'post_cast') as mock_post: - + + with patch.object(client, "get_cast") as mock_get, patch.object(client, "post_cast") as mock_post: # Test API error in results.get mock_get.side_effect = api_error with pytest.raises(APIStatusError): client.results.get(evaluation_id="test-eval") - + # Test connection error in evaluations.create mock_post.side_effect = connection_error with pytest.raises(APIConnectionError): client.evaluations.create(model="gpt-4", benchmark="mmlu") - + # Verify errors didn't interfere with each other assert mock_get.called - assert mock_post.called \ No newline at end of file + assert mock_post.called diff --git a/tests/test_models.py b/tests/test_models.py index 
14784e7..4d26157 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,8 +11,10 @@ Benchmark, Benchmarks, Evaluation, + Pagination, CustomModel, Evaluations, + ResultMetrics, CustomBenchmark, ) @@ -45,7 +47,7 @@ def valid_evaluation_data(self): def test_evaluation_creation_with_valid_data(self, valid_evaluation_data): """Evaluation model creates successfully with valid data.""" evaluation = Evaluation(**valid_evaluation_data) - + assert evaluation.id == "eval-123" assert evaluation.status == "completed" assert evaluation.model_name == "GPT-4" @@ -55,7 +57,7 @@ def test_evaluation_creation_with_valid_data(self, valid_evaluation_data): def test_evaluation_field_types(self, valid_evaluation_data): """Evaluation model enforces correct field types.""" evaluation = Evaluation(**valid_evaluation_data) - + assert isinstance(evaluation.id, str) assert isinstance(evaluation.submitted_at, int) assert isinstance(evaluation.readability_score, float) @@ -68,8 +70,8 @@ def test_evaluation_validation_errors(self, valid_evaluation_data): invalid_data["id"] = 123 with pytest.raises(ValidationError): Evaluation(**invalid_data) - - # Test int field with wrong type + + # Test int field with wrong type invalid_data = valid_evaluation_data.copy() invalid_data["submitted_at"] = "not-an-int" with pytest.raises(ValidationError): @@ -78,10 +80,10 @@ def test_evaluation_validation_errors(self, valid_evaluation_data): def test_evaluation_missing_required_fields(self): """Evaluation model requires all fields.""" incomplete_data = {"id": "eval-123", "status": "pending"} - + with pytest.raises(ValidationError) as exc_info: Evaluation(**incomplete_data) # type: ignore[arg-type] - + errors = exc_info.value.errors() assert len(errors) > 5 @@ -89,7 +91,7 @@ def test_evaluation_json_serialization(self, valid_evaluation_data): """Evaluation model serializes to JSON correctly.""" evaluation = Evaluation(**valid_evaluation_data) json_data = evaluation.model_dump() - + assert json_data["id"] 
== "eval-123" assert json_data["accuracy"] == 0.89 assert isinstance(json_data, dict) @@ -124,7 +126,7 @@ def test_evaluations_with_list_of_evaluations(self, evaluation_data): """Evaluations model accepts list of Evaluation objects.""" evaluations_data = {"data": [evaluation_data, evaluation_data]} evaluations = Evaluations(**evaluations_data) - + assert len(evaluations.data) == 2 assert all(isinstance(eval, Evaluation) for eval in evaluations.data) assert evaluations.data[0].id == "eval-1" @@ -132,12 +134,12 @@ def test_evaluations_with_list_of_evaluations(self, evaluation_data): def test_evaluations_empty_list(self): """Evaluations model accepts empty list.""" evaluations = Evaluations(data=[]) - + assert evaluations.data == [] assert isinstance(evaluations.data, list) def test_evaluations_invalid_data_structure(self): - """Evaluations model validates data structure.""" + """Evaluations model validates data structure.""" with pytest.raises(ValidationError): Evaluations(data="not-a-list") # type: ignore[arg-type] @@ -161,7 +163,7 @@ def valid_result_data(self): def test_result_creation(self, valid_result_data): """Result model creates with valid data.""" result = Result(**valid_result_data) - + assert result.subset == "math" assert result.prompt == "What is 2+2?" 
assert result.score == 1.0 @@ -171,14 +173,14 @@ def test_result_creation(self, valid_result_data): def test_result_timedelta_handling(self, valid_result_data): """Result model handles timedelta correctly.""" result = Result(**valid_result_data) - + assert result.duration == timedelta(seconds=1.5) assert result.duration.total_seconds() == 1.5 def test_result_metrics_validation(self, valid_result_data): """Result model validates metrics as dict.""" result = Result(**valid_result_data) - + assert result.metrics["accuracy"] == 1.0 assert result.metrics["confidence"] == 0.95 assert len(result.metrics) == 2 @@ -187,17 +189,118 @@ def test_result_invalid_metrics_type(self, valid_result_data): """Result model rejects invalid metrics type.""" invalid_data = valid_result_data.copy() invalid_data["metrics"] = "not-a-dict" - + with pytest.raises(ValidationError): Result(**invalid_data) +class TestResultMetrics: + """Test ResultMetrics model.""" + + @pytest.fixture + def valid_metrics_data(self): + """Valid result metrics data for testing.""" + return { + "total_count": 150, + } + + def test_result_metrics_creation(self, valid_metrics_data): + """ResultMetrics model creates with valid data.""" + metrics = ResultMetrics(**valid_metrics_data) + + assert metrics.total_count == 150 + + def test_result_metrics_optional_fields(self): + """ResultMetrics model handles optional score fields.""" + metrics = ResultMetrics( + total_count=100, + ) + + assert metrics.total_count == 100 + + def test_result_metrics_field_types(self, valid_metrics_data): + """ResultMetrics model enforces correct field types.""" + metrics = ResultMetrics(**valid_metrics_data) + + assert isinstance(metrics.total_count, int) + + def test_result_metrics_invalid_total_count(self, valid_metrics_data): + """ResultMetrics model validates total_count as integer.""" + invalid_data = valid_metrics_data.copy() + invalid_data["total_count"] = "not-an-int" + + with pytest.raises(ValidationError): + 
ResultMetrics(**invalid_data) + + +class TestPagination: + """Test Pagination model.""" + + @pytest.fixture + def valid_pagination_data(self): + """Valid pagination data for testing.""" + return { + "total_count": 250, + "page_size": 100, + "total_pages": 3, + } + + def test_pagination_creation(self, valid_pagination_data): + """Pagination model creates with valid data.""" + pagination = Pagination(**valid_pagination_data) + + assert pagination.total_count == 250 + assert pagination.page_size == 100 + assert pagination.total_pages == 3 + + def test_pagination_field_types(self, valid_pagination_data): + """Pagination model enforces correct field types.""" + pagination = Pagination(**valid_pagination_data) + + assert isinstance(pagination.total_count, int) + assert isinstance(pagination.page_size, int) + assert isinstance(pagination.total_pages, int) + + def test_pagination_zero_values(self): + """Pagination model handles zero values correctly.""" + pagination = Pagination( + total_count=0, + page_size=100, + total_pages=0, + ) + + assert pagination.total_count == 0 + assert pagination.page_size == 100 + assert pagination.total_pages == 0 + + def test_pagination_validation_errors(self, valid_pagination_data): + """Pagination model validates field types.""" + # Test invalid total_count + invalid_data = valid_pagination_data.copy() + invalid_data["total_count"] = "not-an-int" + with pytest.raises(ValidationError): + Pagination(**invalid_data) + + # Test invalid page_size + invalid_data = valid_pagination_data.copy() + invalid_data["page_size"] = 3.14 + with pytest.raises(ValidationError): + Pagination(**invalid_data) + + # Test invalid total_pages + invalid_data = valid_pagination_data.copy() + invalid_data["total_pages"] = "not-an-int" + with pytest.raises(ValidationError): + Pagination(**invalid_data) + + class TestResults: - """Test Results collection model.""" + """Test Results collection model with pagination.""" - def test_results_with_result_list(self): - 
"""Results model accepts list of Result objects.""" - result_data = { + @pytest.fixture + def valid_result_data(self): + """Valid result data for testing.""" + return { "subset": "test", "prompt": "test prompt", "result": "test result", @@ -206,10 +309,117 @@ def test_results_with_result_list(self): "score": 0.8, "metrics": {"score": 0.8}, } - results = Results(results=[result_data, result_data]) # type: ignore[arg-type] - + + @pytest.fixture + def valid_metrics_data(self): + """Valid metrics data for testing.""" + return { + "total_count": 150, + } + + @pytest.fixture + def valid_pagination_data(self): + """Valid pagination data for testing.""" + return { + "total_count": 150, + "page_size": 100, + "total_pages": 2, + } + + def test_results_with_pagination(self, valid_result_data, valid_metrics_data, valid_pagination_data): + """Results model accepts all required fields including pagination.""" + results = Results( + evaluation_id="eval-123", + results=[valid_result_data, valid_result_data], + metrics=valid_metrics_data, + pagination=valid_pagination_data, + ) + + assert results.evaluation_id == "eval-123" assert len(results.results) == 2 assert all(isinstance(result, Result) for result in results.results) + assert isinstance(results.metrics, ResultMetrics) + assert isinstance(results.pagination, Pagination) + assert results.pagination.total_count == 150 + assert results.pagination.page_size == 100 + assert results.pagination.total_pages == 2 + + def test_results_field_types(self, valid_result_data, valid_metrics_data, valid_pagination_data): + """Results model enforces correct field types.""" + results = Results( + evaluation_id="eval-456", + results=[valid_result_data], + metrics=valid_metrics_data, + pagination=valid_pagination_data, + ) + + assert isinstance(results.evaluation_id, str) + assert isinstance(results.results, list) + assert isinstance(results.metrics, ResultMetrics) + assert isinstance(results.pagination, Pagination) + + def 
test_results_empty_results_list(self, valid_metrics_data, valid_pagination_data): + """Results model handles empty results list.""" + results = Results( + evaluation_id="eval-empty", + results=[], + metrics=valid_metrics_data, + pagination=valid_pagination_data, + ) + + assert results.evaluation_id == "eval-empty" + assert len(results.results) == 0 + assert isinstance(results.results, list) + assert isinstance(results.metrics, ResultMetrics) + assert isinstance(results.pagination, Pagination) + + def test_results_validation_errors(self, valid_result_data, valid_metrics_data, valid_pagination_data): + """Results model validates required fields.""" + # Test missing evaluation_id + with pytest.raises(ValidationError): + Results( + results=[valid_result_data], + metrics=valid_metrics_data, + pagination=valid_pagination_data, + ) + + # Test missing metrics + with pytest.raises(ValidationError): + Results( + evaluation_id="eval-123", + results=[valid_result_data], + pagination=valid_pagination_data, + ) + + # Test missing pagination + with pytest.raises(ValidationError): + Results( + evaluation_id="eval-123", + results=[valid_result_data], + metrics=valid_metrics_data, + ) + + def test_results_nested_model_validation(self, valid_result_data, valid_pagination_data): + """Results model validates nested models.""" + # Test invalid metrics + with pytest.raises(ValidationError): + Results( + evaluation_id="eval-123", + results=[valid_result_data], + metrics="invalid-metrics", # Should be ResultMetrics object + pagination=valid_pagination_data, + ) + + # Test invalid pagination + with pytest.raises(ValidationError): + Results( + evaluation_id="eval-123", + results=[valid_result_data], + metrics={ + "total_count": 100, + }, + pagination="invalid-pagination", # Should be Pagination object + ) class TestModel: @@ -238,7 +448,7 @@ def valid_model_data(self): def test_model_creation(self, valid_model_data): """Model creates with valid data.""" model = Model(**valid_model_data) - + 
assert model.id == "model-123" assert model.name == "GPT-4" assert model.parameters == 1.76e12 @@ -248,7 +458,7 @@ def test_model_creation(self, valid_model_data): def test_model_boolean_fields(self, valid_model_data): """Model handles boolean fields correctly.""" model = Model(**valid_model_data) - + assert isinstance(model.open_weights, bool) assert isinstance(model.deprecated, bool) assert model.open_weights is False @@ -256,7 +466,7 @@ def test_model_boolean_fields(self, valid_model_data): def test_model_numeric_fields(self, valid_model_data): """Model validates numeric fields.""" model = Model(**valid_model_data) - + assert isinstance(model.parameters, float) assert isinstance(model.context_length, int) assert isinstance(model.released_at, int) @@ -268,7 +478,7 @@ def test_model_field_validation(self, valid_model_data): invalid_data["parameters"] = "not-a-number" with pytest.raises(ValidationError): Model(**invalid_data) - + # Test int field validation invalid_data = valid_model_data.copy() invalid_data["context_length"] = "not-an-int" @@ -295,7 +505,7 @@ def valid_custom_model_data(self): def test_custom_model_creation(self, valid_custom_model_data): """CustomModel creates with valid data.""" model = CustomModel(**valid_custom_model_data) - + assert model.id == "custom-123" assert model.max_tokens == 4096 assert model.api_url == "https://api.example.com/v1/chat" @@ -304,7 +514,7 @@ def test_custom_model_creation(self, valid_custom_model_data): def test_custom_model_url_validation(self, valid_custom_model_data): """CustomModel stores URL as string.""" model = CustomModel(**valid_custom_model_data) - + assert isinstance(model.api_url, str) assert model.api_url.startswith("https://") @@ -330,7 +540,7 @@ def test_models_with_mixed_model_types(self): "region": "us-east-1", "deprecated": False, } - + custom_model_data = { "id": "custom-1", "key": "my-model", @@ -340,9 +550,9 @@ def test_models_with_mixed_model_types(self): "api_url": "https://api.example.com", 
"disabled": False, } - + models = Models(models=[model_data, custom_model_data]) # type: ignore[arg-type] - + assert len(models.models) == 2 assert isinstance(models.models[0], Model) assert isinstance(models.models[1], CustomModel) @@ -369,7 +579,7 @@ def valid_benchmark_data(self): def test_benchmark_creation(self, valid_benchmark_data): """Benchmark creates with valid data.""" benchmark = Benchmark(**valid_benchmark_data) - + assert benchmark.id == "bench-123" assert benchmark.name == "MMLU" assert len(benchmark.categories) == 2 @@ -379,7 +589,7 @@ def test_benchmark_creation(self, valid_benchmark_data): def test_benchmark_list_fields(self, valid_benchmark_data): """Benchmark handles list fields correctly.""" benchmark = Benchmark(**valid_benchmark_data) - + assert isinstance(benchmark.categories, list) assert isinstance(benchmark.subsets, list) assert "reasoning" in benchmark.categories @@ -413,7 +623,7 @@ def valid_custom_benchmark_data(self): def test_custom_benchmark_creation(self, valid_custom_benchmark_data): """CustomBenchmark creates with all fields.""" benchmark = CustomBenchmark(**valid_custom_benchmark_data) - + assert benchmark.id == "custom-bench-123" assert benchmark.system_prompt == "You are a helpful assistant" assert benchmark.regex_pattern == r"Answer: (.+)" @@ -438,9 +648,9 @@ def test_custom_benchmark_optional_fields(self): "files": ["test.jsonl"], "disabled": False, } - + benchmark = CustomBenchmark(**minimal_data) - + assert benchmark.system_prompt is None assert benchmark.regex_pattern is None assert benchmark.scoring_metric is None @@ -462,10 +672,10 @@ def test_benchmarks_with_datasets_alias(self): "prompt_count": 10, "deprecated": False, } - + # Using the alias 'datasets' benchmarks = Benchmarks(datasets=[benchmark_data]) # type: ignore[arg-type] - + assert len(benchmarks.benchmarks) == 1 assert isinstance(benchmarks.benchmarks[0], Benchmark) @@ -477,15 +687,15 @@ def test_benchmarks_field_validation(self): "key": "test", "name": 
"Test", "full_description": "Test benchmark", - "language": "english", + "language": "english", "categories": ["test"], "subsets": ["test"], "prompt_count": 10, "deprecated": False, } - + benchmarks = Benchmarks(datasets=[benchmark_data]) # type: ignore[arg-type] - + assert len(benchmarks.benchmarks) == 1 @@ -512,12 +722,12 @@ def test_round_trip_serialization(self): "ethics_score": 0.92, "accuracy": 0.89, } - + # Create model, serialize, then deserialize evaluation = Evaluation(**original_data) serialized = evaluation.model_dump() deserialized = Evaluation(**serialized) - + assert deserialized.id == evaluation.id assert deserialized.accuracy == evaluation.accuracy assert deserialized == evaluation @@ -525,7 +735,7 @@ def test_round_trip_serialization(self): def test_json_compatibility(self): """Models work with JSON serialization.""" import json - + model_data = { "id": "model-123", "key": "gpt-4", @@ -542,11 +752,11 @@ def test_json_compatibility(self): "region": "us-east-1", "deprecated": False, } - + model = Model(**model_data) json_str = json.dumps(model.model_dump()) parsed_data = json.loads(json_str) reconstructed = Model(**parsed_data) - + assert reconstructed.name == model.name - assert reconstructed.parameters == model.parameters \ No newline at end of file + assert reconstructed.parameters == model.parameters diff --git a/tests/test_resource.py b/tests/test_resource.py index e1af72e..254d016 100644 --- a/tests/test_resource.py +++ b/tests/test_resource.py @@ -24,7 +24,7 @@ def resource_instance(self, mock_client): def test_resource_initialization(self, mock_client): """SyncAPIResource initializes correctly with client.""" resource = SyncAPIResource(mock_client) - + assert resource._client is mock_client assert resource._get is mock_client.get_cast assert resource._post is mock_client.post_cast @@ -32,15 +32,15 @@ def test_resource_initialization(self, mock_client): def test_resource_stores_client_reference(self, resource_instance, mock_client): 
"""Resource maintains reference to the client.""" assert resource_instance._client is mock_client - assert hasattr(resource_instance, '_client') + assert hasattr(resource_instance, "_client") def test_resource_delegates_get_to_client(self, resource_instance, mock_client): """_get method delegates to client.get_cast.""" assert resource_instance._get is mock_client.get_cast - + # Verify it's the same method reference assert callable(resource_instance._get) - + # Test delegation works resource_instance._get("/test", params={"key": "value"}) mock_client.get_cast.assert_called_once_with("/test", params={"key": "value"}) @@ -48,50 +48,50 @@ def test_resource_delegates_get_to_client(self, resource_instance, mock_client): def test_resource_delegates_post_to_client(self, resource_instance, mock_client): """_post method delegates to client.post_cast.""" assert resource_instance._post is mock_client.post_cast - + # Verify it's the same method reference assert callable(resource_instance._post) - + # Test delegation works resource_instance._post("/test", body={"data": "test"}) mock_client.post_cast.assert_called_once_with("/test", body={"data": "test"}) def test_resource_sleep_method_exists(self, resource_instance): """Resource has _sleep method.""" - assert hasattr(resource_instance, '_sleep') + assert hasattr(resource_instance, "_sleep") assert callable(resource_instance._sleep) - @patch('time.sleep') + @patch("time.sleep") def test_resource_sleep_delegates_to_time_sleep(self, mock_time_sleep, resource_instance): """_sleep method delegates to time.sleep.""" sleep_duration = 2.5 - + resource_instance._sleep(sleep_duration) - + mock_time_sleep.assert_called_once_with(sleep_duration) - @patch('time.sleep') + @patch("time.sleep") def test_resource_sleep_with_different_durations(self, mock_time_sleep, resource_instance): """_sleep method works with various duration values.""" durations = [0.1, 1.0, 5.0, 10.5, 60.0] - + for duration in durations: mock_time_sleep.reset_mock() 
resource_instance._sleep(duration) mock_time_sleep.assert_called_once_with(duration) - @patch('time.sleep') + @patch("time.sleep") def test_resource_sleep_with_zero_duration(self, mock_time_sleep, resource_instance): """_sleep method handles zero duration.""" resource_instance._sleep(0.0) - + mock_time_sleep.assert_called_once_with(0.0) - @patch('time.sleep') + @patch("time.sleep") def test_resource_sleep_with_integer_duration(self, mock_time_sleep, resource_instance): """_sleep method handles integer duration values.""" resource_instance._sleep(3) - + mock_time_sleep.assert_called_once_with(3) def test_resource_initialization_with_different_clients(self): @@ -100,24 +100,24 @@ def test_resource_initialization_with_different_clients(self): client1 = Mock() client1.get_cast = Mock(return_value="get_result_1") client1.post_cast = Mock(return_value="post_result_1") - - client2 = Mock() + + client2 = Mock() client2.get_cast = Mock(return_value="get_result_2") client2.post_cast = Mock(return_value="post_result_2") - + resource1 = SyncAPIResource(client1) resource2 = SyncAPIResource(client2) - + # Verify each resource uses its own client assert resource1._client is client1 assert resource2._client is client2 assert resource1._get is client1.get_cast assert resource2._get is client2.get_cast - + # Verify method calls go to correct clients result1 = resource1._get("/test1") result2 = resource2._get("/test2") - + assert result1 == "get_result_1" assert result2 == "get_result_2" client1.get_cast.assert_called_once_with("/test1") @@ -129,97 +129,97 @@ class TestSyncAPIResourceInheritance: def test_resource_can_be_subclassed(self): """SyncAPIResource can be subclassed for specific resources.""" - + class TestResource(SyncAPIResource): def get_data(self, id: str): return self._get(f"/data/{id}") - + def create_data(self, data: dict): return self._post("/data", body=data) - + mock_client = Mock() mock_client.get_cast = Mock(return_value={"id": "123", "data": "test"}) 
mock_client.post_cast = Mock(return_value={"id": "456", "created": True}) - + resource = TestResource(mock_client) - + # Test inherited initialization assert resource._client is mock_client assert resource._get is mock_client.get_cast assert resource._post is mock_client.post_cast - + # Test custom methods using inherited functionality get_result = resource.get_data("123") create_result = resource.create_data({"name": "test"}) - + assert get_result == {"id": "123", "data": "test"} assert create_result == {"id": "456", "created": True} - + mock_client.get_cast.assert_called_once_with("/data/123") mock_client.post_cast.assert_called_once_with("/data", body={"name": "test"}) def test_subclass_can_override_methods(self): """Subclasses can override resource methods.""" - + class CustomResource(SyncAPIResource): def __init__(self, client): super().__init__(client) self.custom_property = "custom_value" - + def _sleep(self, seconds: float) -> None: # Custom sleep implementation self.last_sleep_duration = seconds super()._sleep(seconds) - + mock_client = Mock() mock_client.get_cast = Mock() mock_client.post_cast = Mock() - + resource = CustomResource(mock_client) - + # Test custom property assert resource.custom_property == "custom_value" - + # Test overridden method - with patch('time.sleep') as mock_time_sleep: + with patch("time.sleep") as mock_time_sleep: resource._sleep(1.5) - + assert resource.last_sleep_duration == 1.5 mock_time_sleep.assert_called_once_with(1.5) def test_multiple_resource_instances_independent(self): """Multiple resource instances maintain independence.""" - + class ResourceA(SyncAPIResource): def method_a(self): return self._get("/resource-a") - + class ResourceB(SyncAPIResource): def method_b(self): return self._post("/resource-b", body={"type": "b"}) - + client1 = Mock() client1.get_cast = Mock(return_value="result_a") client1.post_cast = Mock() - + client2 = Mock() client2.get_cast = Mock() client2.post_cast = Mock(return_value="result_b") - + 
resource_a = ResourceA(client1) resource_b = ResourceB(client2) - + # Test that resources are independent result_a = resource_a.method_a() result_b = resource_b.method_b() - + assert result_a == "result_a" assert result_b == "result_b" - + # Verify correct clients were called client1.get_cast.assert_called_once_with("/resource-a") client2.post_cast.assert_called_once_with("/resource-b", body={"type": "b"}) - + # Verify cross-contamination didn't occur client1.post_cast.assert_not_called() client2.get_cast.assert_not_called() @@ -244,25 +244,25 @@ def resource_instance(self, mock_client): def test_resource_propagates_get_errors(self, resource_instance, mock_client): """Resource propagates errors from _get calls.""" from atlas._exceptions import APIStatusError - + mock_response = Mock() mock_response.status_code = 404 mock_response.headers = {} - + api_error = APIStatusError("Not Found", response=mock_response, body=None) mock_client.get_cast.side_effect = api_error - + with pytest.raises(APIStatusError): resource_instance._get("/test") def test_resource_propagates_post_errors(self, resource_instance, mock_client): """Resource propagates errors from _post calls.""" from atlas._exceptions import APIConnectionError - + mock_request = Mock() connection_error = APIConnectionError(request=mock_request) mock_client.post_cast.side_effect = connection_error - + with pytest.raises(APIConnectionError): resource_instance._post("/test", body={"data": "test"}) @@ -270,19 +270,19 @@ def test_resource_handles_client_method_missing(self): """Resource handles clients missing required methods gracefully.""" # Create a client without the required methods incomplete_client = object() # Plain object with no methods - + # This should fail during initialization since the methods don't exist with pytest.raises(AttributeError): SyncAPIResource(incomplete_client) # type: ignore[arg-type] - @patch('time.sleep') + @patch("time.sleep") def test_resource_sleep_handles_exceptions(self, 
mock_time_sleep, resource_instance): """_sleep method handles exceptions from time.sleep.""" mock_time_sleep.side_effect = KeyboardInterrupt("Interrupted") - + with pytest.raises(KeyboardInterrupt): resource_instance._sleep(1.0) - + mock_time_sleep.assert_called_once_with(1.0) @@ -291,39 +291,39 @@ class TestSyncAPIResourceTyping: def test_resource_client_attribute_typing(self): """Resource._client maintains proper typing.""" - + # Test with properly typed client (would be Atlas in real usage) mock_client = Mock() mock_client.get_cast = Mock() mock_client.post_cast = Mock() - + resource = SyncAPIResource(mock_client) - + # Verify the client is stored and accessible assert resource._client is mock_client - assert hasattr(resource, '_client') + assert hasattr(resource, "_client") def test_resource_method_signatures(self): """Resource methods have expected signatures.""" import inspect - + # Check _sleep method signature sleep_sig = inspect.signature(SyncAPIResource._sleep) sleep_params = list(sleep_sig.parameters.keys()) - - assert 'self' in sleep_params - assert 'seconds' in sleep_params + + assert "self" in sleep_params + assert "seconds" in sleep_params assert len(sleep_params) == 2 def test_resource_initialization_signature(self): """Resource __init__ has expected signature.""" import inspect - + init_sig = inspect.signature(SyncAPIResource.__init__) init_params = list(init_sig.parameters.keys()) - - assert 'self' in init_params - assert 'client' in init_params + + assert "self" in init_params + assert "client" in init_params assert len(init_params) == 2 @@ -332,7 +332,7 @@ class TestSyncAPIResourceRealWorldUsage: def test_resource_with_retry_logic(self): """Resource can implement retry logic using _sleep.""" - + class RetryableResource(SyncAPIResource): def get_with_retry(self, url: str, max_retries: int = 3): for attempt in range(max_retries): @@ -341,21 +341,17 @@ def get_with_retry(self, url: str, max_retries: int = 3): except Exception as e: if attempt == 
max_retries - 1: raise - self._sleep(2 ** attempt) # Exponential backoff - + self._sleep(2**attempt) # Exponential backoff + mock_client = Mock() # First two calls fail, third succeeds - mock_client.get_cast.side_effect = [ - Exception("First failure"), - Exception("Second failure"), - {"success": True} - ] - + mock_client.get_cast.side_effect = [Exception("First failure"), Exception("Second failure"), {"success": True}] + resource = RetryableResource(mock_client) - - with patch.object(resource, '_sleep') as mock_sleep: + + with patch.object(resource, "_sleep") as mock_sleep: result = resource.get_with_retry("/test") - + assert result == {"success": True} assert mock_client.get_cast.call_count == 3 assert mock_sleep.call_count == 2 @@ -364,13 +360,13 @@ def get_with_retry(self, url: str, max_retries: int = 3): def test_resource_with_complex_workflow(self): """Resource can implement complex workflows.""" - + class WorkflowResource(SyncAPIResource): def create_and_wait(self, data: dict, poll_interval: float = 1.0): # Create resource created = self._post("/create", body=data) resource_id = created["id"] # type: ignore[index] - + # Poll until complete while True: status = self._get(f"/status/{resource_id}") @@ -378,26 +374,26 @@ def create_and_wait(self, data: dict, poll_interval: float = 1.0): return self._get(f"/result/{resource_id}") elif status["state"] == "failed": # type: ignore[index] raise Exception("Workflow failed") - + self._sleep(poll_interval) - + mock_client = Mock() mock_client.post_cast.return_value = {"id": "workflow-123"} - + # Mock status progression: pending -> running -> completed mock_client.get_cast.side_effect = [ {"state": "pending"}, {"state": "running"}, {"state": "completed"}, - {"result": "workflow complete"} + {"result": "workflow complete"}, ] - + resource = WorkflowResource(mock_client) - - with patch.object(resource, '_sleep') as mock_sleep: + + with patch.object(resource, "_sleep") as mock_sleep: result = 
resource.create_and_wait({"name": "test"}) - + assert result == {"result": "workflow complete"} assert mock_client.post_cast.call_count == 1 assert mock_client.get_cast.call_count == 4 - assert mock_sleep.call_count == 2 # Two sleeps during polling \ No newline at end of file + assert mock_sleep.call_count == 2 # Two sleeps during polling diff --git a/tests/test_utils.py b/tests/test_utils.py index 2c026ae..85b438b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,31 +18,34 @@ class TestTypeguards: def test_is_dict_with_dict(self): """is_dict returns True for dict objects.""" test_dict = {"key": "value", "number": 42} - + assert is_dict(test_dict) is True def test_is_dict_with_empty_dict(self): """is_dict returns True for empty dict.""" empty_dict = {} - + assert is_dict(empty_dict) is True def test_is_dict_with_nested_dict(self): """is_dict returns True for nested dict structures.""" nested_dict = {"outer": {"inner": {"deep": "value"}}} - + assert is_dict(nested_dict) is True - @pytest.mark.parametrize("non_dict_value", [ - "string", - 123, - [1, 2, 3], - (1, 2, 3), - {"key", "value"}, # set - None, - True, - object(), - ]) + @pytest.mark.parametrize( + "non_dict_value", + [ + "string", + 123, + [1, 2, 3], + (1, 2, 3), + {"key", "value"}, # set + None, + True, + object(), + ], + ) def test_is_dict_with_non_dict_objects(self, non_dict_value): """is_dict returns False for non-dict objects.""" assert is_dict(non_dict_value) is False @@ -50,49 +53,53 @@ def test_is_dict_with_non_dict_objects(self, non_dict_value): def test_is_mapping_with_dict(self): """is_mapping returns True for dict objects.""" test_dict = {"key": "value"} - + assert is_mapping(test_dict) is True def test_is_mapping_with_custom_mapping(self): """is_mapping returns True for custom Mapping implementations.""" from collections import UserDict, OrderedDict - + ordered_dict = OrderedDict([("a", 1), ("b", 2)]) user_dict = UserDict({"x": 10, "y": 20}) - + assert is_mapping(ordered_dict) is 
True assert is_mapping(user_dict) is True def test_is_mapping_with_mapping_subclass(self): """is_mapping returns True for Mapping subclasses.""" + class CustomMapping(Mapping): def __init__(self): self._data = {"custom": "mapping"} - + def __getitem__(self, key): return self._data[key] - + def __iter__(self): return iter(self._data) - + def __len__(self): return len(self._data) - + custom_mapping = CustomMapping() - + assert is_mapping(custom_mapping) is True assert custom_mapping["custom"] == "mapping" - @pytest.mark.parametrize("non_mapping_value", [ - "string", - 123, - [1, 2, 3], - (1, 2, 3), - {"key", "value"}, # set - None, - True, - object(), - ]) + @pytest.mark.parametrize( + "non_mapping_value", + [ + "string", + 123, + [1, 2, 3], + (1, 2, 3), + {"key", "value"}, # set + None, + True, + object(), + ], + ) def test_is_mapping_with_non_mapping_objects(self, non_mapping_value): """is_mapping returns False for non-mapping objects.""" assert is_mapping(non_mapping_value) is False @@ -131,31 +138,31 @@ def mock_log_record(self): def test_filter_initialization(self): """SensitiveHeadersFilter initializes correctly.""" filter_instance = SensitiveHeadersFilter() - + assert isinstance(filter_instance, logging.Filter) - assert hasattr(filter_instance, 'filter') + assert hasattr(filter_instance, "filter") def test_filter_returns_true_by_default(self, filter_instance, mock_log_record): """filter method always returns True to allow logging.""" result = filter_instance.filter(mock_log_record) - + assert result is True def test_filter_handles_record_without_headers(self, filter_instance, mock_log_record): """filter handles log records without headers gracefully.""" mock_log_record.args = {"message": "test", "data": "value"} - + result = filter_instance.filter(mock_log_record) - + assert result is True assert mock_log_record.args["message"] == "test" def test_filter_handles_non_dict_args(self, filter_instance, mock_log_record): """filter handles log records with non-dict 
args.""" mock_log_record.args = "string args" - + result = filter_instance.filter(mock_log_record) - + assert result is True def test_filter_redacts_sensitive_headers(self, filter_instance, mock_log_record): @@ -168,9 +175,9 @@ def test_filter_redacts_sensitive_headers(self, filter_instance, mock_log_record "user-agent": "atlas-python-sdk", } } - + result = filter_instance.filter(mock_log_record) - + assert result is True headers = mock_log_record.args["headers"] assert headers["content-type"] == "application/json" @@ -187,9 +194,9 @@ def test_filter_handles_case_insensitive_headers(self, filter_instance, mock_log "AUTHORIZATION": "Bearer another-token", } } - + result = filter_instance.filter(mock_log_record) - + assert result is True headers = mock_log_record.args["headers"] assert headers["X-API-KEY"] == "" @@ -205,12 +212,12 @@ def test_filter_preserves_original_args_structure(self, filter_instance, mock_lo "x-api-key": "secret", "content-type": "application/json", }, - "body": {"data": "test"} + "body": {"data": "test"}, } mock_log_record.args = original_args - + result = filter_instance.filter(mock_log_record) - + assert result is True assert mock_log_record.args["method"] == "POST" assert mock_log_record.args["url"] == "/test" @@ -224,9 +231,9 @@ def test_filter_creates_copy_of_headers(self, filter_instance, mock_log_record): "content-type": "application/json", } mock_log_record.args = {"headers": original_headers} - + filter_instance.filter(mock_log_record) - + # Original headers should be unchanged assert original_headers["x-api-key"] == "secret-key" # Record headers should be modified @@ -243,9 +250,9 @@ def test_filter_handles_non_string_header_keys(self, filter_instance, mock_log_r ("tuple", "key"): "tuple-value", } } - + result = filter_instance.filter(mock_log_record) - + assert result is True headers = mock_log_record.args["headers"] assert headers[123] == "numeric-key" # Non-string keys unchanged @@ -254,13 +261,10 @@ def 
test_filter_handles_non_string_header_keys(self, filter_instance, mock_log_r def test_filter_handles_non_dict_headers(self, filter_instance, mock_log_record): """filter handles cases where headers is not a dict.""" - mock_log_record.args = { - "headers": "not-a-dict", - "other": "data" - } - + mock_log_record.args = {"headers": "not-a-dict", "other": "data"} + result = filter_instance.filter(mock_log_record) - + assert result is True assert mock_log_record.args["headers"] == "not-a-dict" assert mock_log_record.args["other"] == "data" @@ -268,9 +272,9 @@ def test_filter_handles_non_dict_headers(self, filter_instance, mock_log_record) def test_filter_with_empty_headers(self, filter_instance, mock_log_record): """filter handles empty headers dict.""" mock_log_record.args = {"headers": {}} - + result = filter_instance.filter(mock_log_record) - + assert result is True assert mock_log_record.args["headers"] == {} @@ -278,21 +282,18 @@ def test_filter_with_complex_header_values(self, filter_instance, mock_log_recor """filter redacts complex header values.""" mock_log_record.args = { "headers": { - "authorization": { - "type": "Bearer", - "token": "complex-token-123" - }, + "authorization": {"type": "Bearer", "token": "complex-token-123"}, "x-api-key": ["key1", "key2", "key3"], "content-type": "application/json", } } - + result = filter_instance.filter(mock_log_record) - + assert result is True headers = mock_log_record.args["headers"] assert headers["authorization"] == "" - assert headers["x-api-key"] == "" + assert headers["x-api-key"] == "" assert headers["content-type"] == "application/json" @pytest.mark.parametrize("sensitive_header", list(SENSITIVE_HEADERS)) @@ -304,9 +305,9 @@ def test_filter_redacts_all_sensitive_headers(self, filter_instance, mock_log_re "safe-header": "safe-value", } } - + result = filter_instance.filter(mock_log_record) - + assert result is True headers = mock_log_record.args["headers"] assert headers[sensitive_header] == "" @@ -319,7 +320,7 @@ 
class TestUtilsIntegration: def test_sensitive_filter_with_real_logging(self): """SensitiveHeadersFilter works with real logging setup.""" filter_instance = SensitiveHeadersFilter() - + # Create a mock LogRecord directly mock_record = Mock() mock_record.args = { @@ -328,13 +329,13 @@ def test_sensitive_filter_with_real_logging(self): "content-type": "application/json", } } - + # Process the record through our filter result = filter_instance.filter(mock_record) - + # Verify filter returns True (allowing the log) assert result is True - + # Verify sensitive data was redacted assert mock_record.args["headers"]["x-api-key"] == "" assert mock_record.args["headers"]["content-type"] == "application/json" @@ -343,35 +344,26 @@ def test_typeguards_with_complex_data_structures(self): """Type guards work correctly with complex nested structures.""" complex_structure = { "metadata": { - "headers": { - "authorization": "Bearer token", - "x-api-key": "secret" - }, - "params": ["param1", "param2"] + "headers": {"authorization": "Bearer token", "x-api-key": "secret"}, + "params": ["param1", "param2"], }, - "data": { - "nested": { - "deep": { - "value": 42 - } - } - } + "data": {"nested": {"deep": {"value": 42}}}, } - + # Test type guards at different levels assert is_dict(complex_structure) assert is_mapping(complex_structure) assert is_dict(complex_structure["metadata"]) assert is_dict(complex_structure["metadata"]["headers"]) assert not is_dict(complex_structure["metadata"]["params"]) - + # Test with the filter filter_instance = SensitiveHeadersFilter() mock_record = Mock(spec=logging.LogRecord) mock_record.args = complex_structure["metadata"] - + result = filter_instance.filter(mock_record) - + assert result is True assert mock_record.args["headers"]["authorization"] == "" - assert mock_record.args["headers"]["x-api-key"] == "" \ No newline at end of file + assert mock_record.args["headers"]["x-api-key"] == ""