-
Notifications
You must be signed in to change notification settings - Fork 87
Aaryan/trace custom scorers #625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: staging
Are you sure you want to change the base?
Changes from 3 commits
e32ef92
de3b36d
93fd3db
ffeb50a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| from typing import List | ||
| from pydantic import BaseModel | ||
| from judgeval.data.judgment_types import OtelTraceSpan | ||
|
|
||
|
|
||
| class TraceSpanData(OtelTraceSpan): | ||
| pass | ||
|
|
||
|
|
||
| class TraceData(BaseModel): | ||
| trace_spans: List[TraceSpanData] |
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,193 @@ | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Infrastructure for executing evaluations of `Trace`s using one or more `TraceScorer`s. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||
| from tqdm.asyncio import tqdm_asyncio | ||||||||||||||||||||||||||
| from typing import List, Optional, Callable, Union | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from judgeval.data import ( | ||||||||||||||||||||||||||
| ScoringResult, | ||||||||||||||||||||||||||
| generate_scoring_result, | ||||||||||||||||||||||||||
| create_scorer_data, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from judgeval.data.otel_trace import TraceData | ||||||||||||||||||||||||||
| from judgeval.scorers.trace_scorer import TraceScorer | ||||||||||||||||||||||||||
| from judgeval.scorers.utils import clone_scorers | ||||||||||||||||||||||||||
| from judgeval.logger import judgeval_logger | ||||||||||||||||||||||||||
| from judgeval.judges import JudgevalJudge | ||||||||||||||||||||||||||
| from judgeval.env import JUDGMENT_DEFAULT_GPT_MODEL | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def safe_a_score_trace(scorer: TraceScorer, trace: TraceData): | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Scoring task function when not using a progress indicator! | ||||||||||||||||||||||||||
| "Safely" scores an `Trace` using a `TraceScorer` by gracefully handling any exceptions that may occur. | ||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||
| scorer (TraceScorer): The `TraceScorer` to use for scoring the trace. | ||||||||||||||||||||||||||
| trace (Trace): The `Trace` to be scored. | ||||||||||||||||||||||||||
|
Comment on lines
+30
to
+31
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type hint for the
Suggested change
|
||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| score = await scorer.a_score_trace(trace) | ||||||||||||||||||||||||||
| if score is None: | ||||||||||||||||||||||||||
| raise Exception("a_score_trace need to return a score") | ||||||||||||||||||||||||||
| elif score < 0: | ||||||||||||||||||||||||||
| judgeval_logger.warning("score cannot be less than 0 , setting to 0") | ||||||||||||||||||||||||||
| score = 0 | ||||||||||||||||||||||||||
| elif score > 1: | ||||||||||||||||||||||||||
| judgeval_logger.warning("score cannot be greater than 1 , setting to 1") | ||||||||||||||||||||||||||
| score = 1 | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| scorer.score = score | ||||||||||||||||||||||||||
| scorer.success = scorer.success_check() | ||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||
| judgeval_logger.error(f"Error during scoring: {str(e)}") | ||||||||||||||||||||||||||
| scorer.error = str(e) | ||||||||||||||||||||||||||
| scorer.success = False | ||||||||||||||||||||||||||
| scorer.score = 0 | ||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def a_execute_trace_scoring( | ||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] [CodeDuplication] The logic in this file for scoring traces is almost a complete duplicate of the logic for scoring examples in To avoid code duplication and improve maintainability, consider creating a generic scoring execution utility that can be reused for both Context for Agents |
||||||||||||||||||||||||||
| traces: List[TraceData], | ||||||||||||||||||||||||||
| scorers: List[TraceScorer], | ||||||||||||||||||||||||||
| model: Optional[Union[str, List[str], JudgevalJudge]] = JUDGMENT_DEFAULT_GPT_MODEL, | ||||||||||||||||||||||||||
| ignore_errors: bool = False, | ||||||||||||||||||||||||||
| throttle_value: int = 0, | ||||||||||||||||||||||||||
| max_concurrent: int = 100, | ||||||||||||||||||||||||||
| show_progress: bool = True, | ||||||||||||||||||||||||||
| ) -> List[ScoringResult]: | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Executes evaluations of `Trace`s asynchronously using one or more `TraceScorer`s. | ||||||||||||||||||||||||||
| Each `Trace` will be evaluated by all of the `TraceScorer`s in the `scorers` list. | ||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||
| traces (List[List[TraceSpan]]): A list of `TraceSpan` objects to be evaluated. | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type hint for the
Suggested change
|
||||||||||||||||||||||||||
| scorers (List[TraceScorer]): A list of `TraceScorer` objects to evaluate the traces. | ||||||||||||||||||||||||||
| ignore_errors (bool): Whether to ignore errors during evaluation. | ||||||||||||||||||||||||||
| throttle_value (int): The amount of time to wait between starting each task. | ||||||||||||||||||||||||||
| max_concurrent (int): The maximum number of concurrent tasks. | ||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| List[ScoringResult]: A list of `ScoringResult` objects containing the evaluation results. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| semaphore = asyncio.Semaphore(max_concurrent) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def execute_with_semaphore(func: Callable, *args, **kwargs): | ||||||||||||||||||||||||||
| async with semaphore: | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| return await func(*args, **kwargs) | ||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||
| judgeval_logger.error(f"Error executing function: {e}") | ||||||||||||||||||||||||||
| if kwargs.get("ignore_errors", False): | ||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| for scorer in scorers: | ||||||||||||||||||||||||||
| if not scorer.model and isinstance(model, str): | ||||||||||||||||||||||||||
| scorer._add_model(model) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| scoring_results: List[Optional[ScoringResult]] = [None for _ in traces] | ||||||||||||||||||||||||||
| tasks = [] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if show_progress: | ||||||||||||||||||||||||||
| with tqdm_asyncio( | ||||||||||||||||||||||||||
| desc=f"Evaluating {len(traces)} trace(s) in parallel", | ||||||||||||||||||||||||||
| unit="TraceData", | ||||||||||||||||||||||||||
| total=len(traces), | ||||||||||||||||||||||||||
| bar_format="{desc}: |{bar}|{percentage:3.0f}% ({n_fmt}/{total_fmt}) [Time Taken: {elapsed}, {rate_fmt}{postfix}]", | ||||||||||||||||||||||||||
| ) as pbar: | ||||||||||||||||||||||||||
| for i, trace in enumerate(traces): | ||||||||||||||||||||||||||
| if isinstance(trace, TraceData): | ||||||||||||||||||||||||||
| if len(scorers) == 0: | ||||||||||||||||||||||||||
| pbar.update(1) | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| cloned_scorers = clone_scorers(scorers) # type: ignore | ||||||||||||||||||||||||||
| task = execute_with_semaphore( | ||||||||||||||||||||||||||
| func=a_eval_traces_helper, | ||||||||||||||||||||||||||
| scorers=cloned_scorers, | ||||||||||||||||||||||||||
| trace=trace, | ||||||||||||||||||||||||||
| scoring_results=scoring_results, | ||||||||||||||||||||||||||
| score_index=i, | ||||||||||||||||||||||||||
| ignore_errors=ignore_errors, | ||||||||||||||||||||||||||
| pbar=pbar, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| tasks.append(asyncio.create_task(task)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| await asyncio.sleep(throttle_value) | ||||||||||||||||||||||||||
| await asyncio.gather(*tasks) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| for i, trace in enumerate(traces): | ||||||||||||||||||||||||||
| if isinstance(trace, TraceData): | ||||||||||||||||||||||||||
| if len(scorers) == 0: | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| cloned_scorers = clone_scorers(scorers) # type: ignore | ||||||||||||||||||||||||||
| task = execute_with_semaphore( | ||||||||||||||||||||||||||
| func=a_eval_traces_helper, | ||||||||||||||||||||||||||
| scorers=cloned_scorers, | ||||||||||||||||||||||||||
| trace=trace, | ||||||||||||||||||||||||||
| scoring_results=scoring_results, | ||||||||||||||||||||||||||
| score_index=i, | ||||||||||||||||||||||||||
| ignore_errors=ignore_errors, | ||||||||||||||||||||||||||
| pbar=None, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| tasks.append(asyncio.create_task(task)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| await asyncio.sleep(throttle_value) | ||||||||||||||||||||||||||
| await asyncio.gather(*tasks) | ||||||||||||||||||||||||||
|
Comment on lines
+97
to
+143
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is significant code duplication between the To improve maintainability, consider refactoring this to a single loop. You could use a context manager for the progress bar that does nothing when import contextlib
# ...
progress_context = tqdm_asyncio(...) if show_progress else contextlib.nullcontext()
with progress_context as pbar:
for i, trace in enumerate(traces):
# ... common logic for creating tasks ...
# pass pbar to helper, which can handle if it's None |
||||||||||||||||||||||||||
| return [result for result in scoring_results if result is not None] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| async def a_eval_traces_helper( | ||||||||||||||||||||||||||
| scorers: List[TraceScorer], | ||||||||||||||||||||||||||
| trace: TraceData, | ||||||||||||||||||||||||||
| scoring_results: List[ScoringResult], | ||||||||||||||||||||||||||
| score_index: int, | ||||||||||||||||||||||||||
| ignore_errors: bool, | ||||||||||||||||||||||||||
| pbar: Optional[tqdm_asyncio] = None, | ||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Evaluate a single trace asynchronously using a list of scorers. | ||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||
| scorers (List[TraceScorer]): List of TraceScorer objects to evaluate the trace. | ||||||||||||||||||||||||||
| trace (Trace): The trace to be evaluated. | ||||||||||||||||||||||||||
| scoring_results (List[TestResult]): List to store the scoring results. | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type hint for the
Suggested change
|
||||||||||||||||||||||||||
| score_index (int): Index at which the result should be stored in scoring_results. | ||||||||||||||||||||||||||
| ignore_errors (bool): Flag to indicate whether to ignore errors during scoring. | ||||||||||||||||||||||||||
| pbar (Optional[tqdm_asyncio]): Optional progress bar for tracking progress. | ||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| None | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| # scoring the Trace | ||||||||||||||||||||||||||
| scoring_start_time = time.perf_counter() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| tasks = [safe_a_score_trace(scorer, trace) for scorer in scorers] | ||||||||||||||||||||||||||
| await asyncio.gather(*tasks) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| success = True | ||||||||||||||||||||||||||
| scorer_data_list = [] | ||||||||||||||||||||||||||
| for scorer in scorers: | ||||||||||||||||||||||||||
| if getattr(scorer, "skipped", False): | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
| scorer_data = create_scorer_data(scorer) | ||||||||||||||||||||||||||
| for s in scorer_data: | ||||||||||||||||||||||||||
| success = success and s.success | ||||||||||||||||||||||||||
| scorer_data_list.extend(scorer_data) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| scoring_end_time = time.perf_counter() | ||||||||||||||||||||||||||
| run_duration = scoring_end_time - scoring_start_time | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| scoring_result = generate_scoring_result( | ||||||||||||||||||||||||||
| trace.trace_spans[0], scorer_data_list, run_duration, success | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| scoring_results[score_index] = scoring_result | ||||||||||||||||||||||||||
|
Comment on lines
+187
to
+190
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing
Suggested change
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if pbar is not None: | ||||||||||||||||||||||||||
| pbar.update(1) | ||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| from judgeval.scorers.base_scorer import BaseScorer | ||
| from judgeval.data import TraceData | ||
|
|
||
|
|
||
| class TraceScorer(BaseScorer): | ||
| score_type: str = "Custom Trace" | ||
|
|
||
| async def a_score_trace(self, trace: TraceData, *args, **kwargs) -> float: | ||
| """ | ||
| Asynchronously measures the score on a single trace | ||
| """ | ||
| raise NotImplementedError( | ||
| "You must implement the `a_score_trace` method in your custom scorer" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[BestPractice]
The logic to find the scorer class by walking the AST has been made more complex by the introduction of the
is_traceflag. This section is becoming difficult to read and maintain. Consider extracting the AST parsing and validation logic into a dedicated helper function to improve clarity and separation of concerns.Context for Agents