diff --git a/src/judgeval/trainer/fireworks_trainer.py b/src/judgeval/trainer/fireworks_trainer.py index 65d57aae..e76853f4 100644 --- a/src/judgeval/trainer/fireworks_trainer.py +++ b/src/judgeval/trainer/fireworks_trainer.py @@ -1,7 +1,14 @@ import asyncio import json from typing import Optional, Callable, Any, List, Union, Dict -from fireworks import Dataset # type: ignore[import-not-found,import-untyped] + +try: + from fireworks import Dataset # type: ignore[import-not-found] +except ImportError as e: + raise ImportError( + "Fireworks is not installed. Please install it with 'pip install fireworks'" + ) from e + from .config import TrainerConfig, ModelConfig from .base_trainer import BaseTrainer from .trainable_model import TrainableModel diff --git a/src/judgeval/trainer/trainable_model.py b/src/judgeval/trainer/trainable_model.py index 5c64ebfe..a1ab7b78 100644 --- a/src/judgeval/trainer/trainable_model.py +++ b/src/judgeval/trainer/trainable_model.py @@ -1,7 +1,14 @@ import time -from fireworks import LLM # type: ignore[import-not-found,import-untyped] -from .config import TrainerConfig, ModelConfig from typing import Optional, Dict, Any, Callable + +try: + from fireworks import LLM # type: ignore[import-not-found] +except ImportError as e: + raise ImportError( + "Fireworks is not installed. Please install it with 'pip install fireworks'" + ) from e + +from .config import TrainerConfig, ModelConfig from .console import _model_spinner_progress, _print_model_progress from judgeval.exceptions import JudgmentRuntimeError diff --git a/src/judgeval/v1/tracer/base_tracer.py b/src/judgeval/v1/tracer/base_tracer.py index ef669aea..5148e4ab 100644 --- a/src/judgeval/v1/tracer/base_tracer.py +++ b/src/judgeval/v1/tracer/base_tracer.py @@ -1,11 +1,30 @@ from __future__ import annotations +import asyncio +import contextvars import datetime import functools import inspect import time from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, overload +from collections.abc import ( + Generator as ABCGenerator, + AsyncGenerator as ABCAsyncGenerator, +) +from types import TracebackType +from typing import ( + Any, + AsyncGenerator, + Callable, + Coroutine, + Dict, + Generator, + Optional, + Tuple, + TypeVar, + overload, + cast, +) from opentelemetry import trace from opentelemetry.sdk.trace.export import SpanExporter @@ -38,6 +57,7 @@ ) C = TypeVar("C", bound=Callable[..., Any]) +T = TypeVar("T") class BaseTracer(ABC): @@ -342,6 +362,7 @@ def observe( func: C, span_type: Optional[str] = "span", span_name: Optional[str] = None, + disable_generator_yield_span: bool = False, ) -> C: ... @overload @@ -350,6 +371,7 @@ def observe( func: None = None, span_type: Optional[str] = "span", span_name: Optional[str] = None, + disable_generator_yield_span: bool = False, ) -> Callable[[C], C]: ... def observe( @@ -357,9 +379,12 @@ def observe( func: Optional[C] = None, span_type: Optional[str] = "span", span_name: Optional[str] = None, + disable_generator_yield_span: bool = False, ) -> C | Callable[[C], C]: if func is None: - return lambda f: self.observe(f, span_type, span_name) # type: ignore[return-value] + return lambda f: self.observe( + f, span_type, span_name, disable_generator_yield_span + ) tracer = self.get_tracer() name = span_name or func.__name__ @@ -371,19 +396,18 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: with tracer.start_as_current_span(name) as span: if span_type: span.set_attribute(AttributeKeys.JUDGMENT_SPAN_KIND, span_type) - try: - input_data = _format_inputs(func, args, kwargs) span.set_attribute( - AttributeKeys.JUDGMENT_INPUT, self.serializer(input_data) + AttributeKeys.JUDGMENT_INPUT, + _serialize( + self.serializer, _format_inputs(func, args, kwargs) + ), ) - self.get_span_processor().emit_partial() - result = await func(*args, **kwargs) - span.set_attribute( - AttributeKeys.JUDGMENT_OUTPUT, self.serializer(result) + AttributeKeys.JUDGMENT_OUTPUT, + _serialize(self.serializer, result), ) return result except Exception as e: @@ -391,35 +415,59 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: span.set_status(Status(StatusCode.ERROR, str(e))) raise - return async_wrapper # type: ignore[return-value] + return cast(C, async_wrapper) else: @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - with tracer.start_as_current_span(name) as span: + with tracer.start_as_current_span(name, end_on_exit=False) as span: if span_type: span.set_attribute(AttributeKeys.JUDGMENT_SPAN_KIND, span_type) - try: - input_data = _format_inputs(func, args, kwargs) span.set_attribute( - AttributeKeys.JUDGMENT_INPUT, self.serializer(input_data) + AttributeKeys.JUDGMENT_INPUT, + _serialize( + self.serializer, _format_inputs(func, args, kwargs) + ), ) - self.get_span_processor().emit_partial() - result = func(*args, **kwargs) - - span.set_attribute( - AttributeKeys.JUDGMENT_OUTPUT, self.serializer(result) - ) - return result except Exception as e: span.record_exception(e) span.set_status(Status(StatusCode.ERROR, str(e))) + span.end() raise - return sync_wrapper # type: ignore[return-value] + if inspect.isgenerator(result): + span.set_attribute(AttributeKeys.JUDGMENT_OUTPUT, "") + return _ObservedSyncGenerator( + result, + span, + self.serializer, + tracer, + contextvars.copy_context(), + disable_generator_yield_span, + ) + if inspect.isasyncgen(result): + span.set_attribute( + AttributeKeys.JUDGMENT_OUTPUT, "" + ) + return _ObservedAsyncGenerator( + result, + span, + self.serializer, + tracer, + contextvars.copy_context(), + disable_generator_yield_span, + ) + span.set_attribute( + AttributeKeys.JUDGMENT_OUTPUT, + _serialize(self.serializer, result), + ) + span.end() + return result + + return cast(C, sync_wrapper) @overload def agent(self, func: C, /, *, identifier: Optional[str] = None) -> C: ... @@ -433,64 +481,51 @@ def agent( self, func: Optional[C] = None, /, *, identifier: Optional[str] = None ) -> C | Callable[[C], C]: if func is None: - return lambda f: self.agent(f, identifier=identifier) # type: ignore[return-value] + return lambda f: self.agent(f, identifier=identifier) + + class_name = ( + func.__qualname__.rsplit(".", 1)[0] + if hasattr(func, "__qualname__") and "." in func.__qualname__ + else None + ) - class_name = None - if hasattr(func, "__qualname__") and "." in func.__qualname__: - parts = func.__qualname__.split(".") - if len(parts) >= 2: - class_name = parts[-2] + def build_context(args: Tuple[Any, ...]) -> Any: + ctx = set_value(AGENT_ID_KEY, str(uuid4())) + parent_id = get_value(AGENT_ID_KEY) + if parent_id: + ctx = set_value(PARENT_AGENT_ID_KEY, parent_id, context=ctx) + if class_name: + ctx = set_value(AGENT_CLASS_NAME_KEY, class_name, context=ctx) + if identifier and args and hasattr(args[0], identifier): + ctx = set_value( + AGENT_INSTANCE_NAME_KEY, + str(getattr(args[0], identifier)), + context=ctx, + ) + return ctx if inspect.iscoroutinefunction(func): @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - agent_id = str(uuid4()) - parent_agent_id = get_value(AGENT_ID_KEY) - ctx = set_value(AGENT_ID_KEY, agent_id) - if parent_agent_id: - ctx = set_value(PARENT_AGENT_ID_KEY, parent_agent_id, context=ctx) - if class_name: - ctx = set_value(AGENT_CLASS_NAME_KEY, class_name, context=ctx) - if identifier and args: - instance = args[0] - if hasattr(instance, identifier): - instance_name = str(getattr(instance, identifier)) - ctx = set_value( - AGENT_INSTANCE_NAME_KEY, instance_name, context=ctx - ) - token = attach(ctx) + token = attach(build_context(args)) try: return await func(*args, **kwargs) finally: detach(token) - return async_wrapper # type: ignore[return-value] + return cast(C, async_wrapper) else: @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - agent_id = str(uuid4()) - parent_agent_id = get_value(AGENT_ID_KEY) - ctx = set_value(AGENT_ID_KEY, agent_id) - if parent_agent_id: - ctx = set_value(PARENT_AGENT_ID_KEY, parent_agent_id, context=ctx) - if class_name: - ctx = set_value(AGENT_CLASS_NAME_KEY, class_name, context=ctx) - if identifier and args: - instance = args[0] - if hasattr(instance, identifier): - instance_name = str(getattr(instance, identifier)) - ctx = set_value( - AGENT_INSTANCE_NAME_KEY, instance_name, context=ctx - ) - token = attach(ctx) + token = attach(build_context(args)) try: return func(*args, **kwargs) finally: detach(token) - return sync_wrapper # type: ignore[return-value] + return cast(C, sync_wrapper) def wrap(self, client: ApiClient) -> ApiClient: return wrap_provider(self, client) @@ -518,3 +553,263 @@ def _format_inputs( return inputs except Exception: return {} + + +def _serialize(serializer: Callable[[Any], str], value: Any) -> Any: + return value if isinstance(value, (str, int, float, bool)) else serializer(value) + + +class _ObservedSyncGenerator(ABCGenerator[Any, Any, Any]): + def __init__( + self, + generator: Generator[Any, Any, Any], + span: Span, + serializer: Callable[[Any], str], + tracer: trace.Tracer, + context: contextvars.Context, + disable_generator_yield_span: bool = False, + ) -> None: + self._generator = generator + self._span = span + self._serializer = serializer + self._tracer = tracer + self._context = context + self._closed = False + self._disable_generator_yield_span = disable_generator_yield_span + + def __iter__(self) -> "_ObservedSyncGenerator": + return self + + def __next__(self) -> Any: + return self.send(None) + + def send(self, value: Any) -> Any: + if self._closed: + raise StopIteration + try: + item = self._context.run(self._generator.send, value) + + if not self._disable_generator_yield_span: + with trace.use_span(self._span): + span_name = str(getattr(self._span, "name", "generator_item")) + with self._tracer.start_as_current_span( + span_name, + attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "generator_item"}, + ) as child_span: + child_span.set_attribute( + AttributeKeys.JUDGMENT_OUTPUT, + _serialize(self._serializer, item), + ) + + return item + except StopIteration: + self._finish() + raise + except Exception as e: + self._record_error(e) + raise + + @overload + def throw( + self, + __typ: type[BaseException], + __val: object = ..., + __tb: Optional[TracebackType] = ..., + ) -> Any: ... + + @overload + def throw( + self, + __typ: BaseException, + __val: None = ..., + __tb: Optional[TracebackType] = ..., + ) -> Any: ... + + def throw( + self, + __typ: type[BaseException] | BaseException, + __val: object = None, + __tb: Optional[TracebackType] = None, + ) -> Any: + if self._closed: + raise StopIteration + try: + if isinstance(__typ, type): + item = self._context.run(self._generator.throw, __typ, __val, __tb) + else: + item = self._context.run(self._generator.throw, __typ, None, __tb) + + if not self._disable_generator_yield_span: + with trace.use_span(self._span): + span_name = str(getattr(self._span, "name", "generator_item")) + with self._tracer.start_as_current_span( + span_name, + attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "generator_item"}, + ) as child_span: + child_span.set_attribute( + AttributeKeys.JUDGMENT_OUTPUT, + _serialize(self._serializer, item), + ) + + return item + except StopIteration: + self._finish() + raise + except Exception as e: + self._record_error(e) + raise + + def close(self) -> None: + try: + self._generator.close() + finally: + self._finish() + + def _record_error(self, exc: BaseException) -> None: + self._span.record_exception(exc) + self._span.set_status(Status(StatusCode.ERROR, str(exc))) + self._finish() + + def _finish(self) -> None: + if self._closed: + return + self._closed = True + self._span.set_attribute(AttributeKeys.JUDGMENT_SPAN_KIND, "generator") + self._span.end() + + def __del__(self) -> None: + self._finish() + + +class _ObservedAsyncGenerator(ABCAsyncGenerator[Any, Any]): + def __init__( + self, + generator: AsyncGenerator[Any, Any], + span: Span, + serializer: Callable[[Any], str], + tracer: trace.Tracer, + context: contextvars.Context, + disable_generator_yield_span: bool = False, + ) -> None: + self._generator = generator + self._span = span + self._serializer = serializer + self._tracer = tracer + self._context = context + self._closed = False + self._disable_generator_yield_span = disable_generator_yield_span + + def _create_task(self, coro: Coroutine[Any, Any, T]) -> "asyncio.Task[T]": + # Python 3.11 added the context kwarg to asyncio.create_task + # @ref https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + try: + return asyncio.create_task(coro, context=self._context) + except TypeError: + return self._context.run(lambda: asyncio.create_task(coro)) + + def __aiter__(self) -> "_ObservedAsyncGenerator": + return self + + async def __anext__(self) -> Any: + return await self.asend(None) + + async def asend(self, value: Any) -> Any: + if self._closed: + raise StopAsyncIteration + try: + item = await self._create_task(self._generator.asend(value)) + + if not self._disable_generator_yield_span: + with trace.use_span(self._span): + span_name = str(getattr(self._span, "name", "generator_item")) + with self._tracer.start_as_current_span( + span_name, + attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "generator_item"}, + ) as child_span: + child_span.set_attribute( + AttributeKeys.JUDGMENT_OUTPUT, + _serialize(self._serializer, item), + ) + + return item + except StopAsyncIteration: + self._finish() + raise + except Exception as e: + self._record_error(e) + raise + + @overload + async def athrow( + self, + __typ: type[BaseException], + __val: object = ..., + __tb: Optional[TracebackType] = ..., + ) -> Any: ... + + @overload + async def athrow( + self, + __typ: BaseException, + __val: None = ..., + __tb: Optional[TracebackType] = ..., + ) -> Any: ... + + async def athrow( + self, + __typ: type[BaseException] | BaseException, + __val: object = None, + __tb: Optional[TracebackType] = None, + ) -> Any: + if self._closed: + raise StopAsyncIteration + try: + if isinstance(__typ, type): + item = await self._create_task( + self._generator.athrow(__typ, __val, __tb) + ) + else: + item = await self._create_task( + self._generator.athrow(__typ, None, __tb) + ) + + if not self._disable_generator_yield_span: + with trace.use_span(self._span): + span_name = str(getattr(self._span, "name", "generator_item")) + with self._tracer.start_as_current_span( + span_name, + attributes={AttributeKeys.JUDGMENT_SPAN_KIND: "generator_item"}, + ) as child_span: + child_span.set_attribute( + AttributeKeys.JUDGMENT_OUTPUT, + _serialize(self._serializer, item), + ) + + return item + except StopAsyncIteration: + self._finish() + raise + except Exception as e: + self._record_error(e) + raise + + async def aclose(self) -> None: + try: + await self._generator.aclose() + finally: + self._finish() + + def _record_error(self, exc: BaseException) -> None: + self._span.record_exception(exc) + self._span.set_status(Status(StatusCode.ERROR, str(exc))) + self._finish() + + def _finish(self) -> None: + if self._closed: + return + self._closed = True + self._span.set_attribute(AttributeKeys.JUDGMENT_SPAN_KIND, "generator") + self._span.end() + + def __del__(self) -> None: + self._finish() diff --git a/src/judgeval/v1/tracer/exporters/span_store.py b/src/judgeval/v1/tracer/exporters/span_store.py index e3a8e6ee..8286d267 100644 --- a/src/judgeval/v1/tracer/exporters/span_store.py +++ b/src/judgeval/v1/tracer/exporters/span_store.py @@ -48,3 +48,6 @@ def get_by_trace_id(self, trace_id: str) -> List[ReadableSpan]: def clear_trace(self, trace_id: str) -> None: if trace_id in self._spans_by_trace: del self._spans_by_trace[trace_id] + + def clear(self) -> None: + self._spans_by_trace.clear() diff --git a/src/judgeval/v1/tracer/tracer_factory.py b/src/judgeval/v1/tracer/tracer_factory.py index c2907004..2ca62f10 100644 --- a/src/judgeval/v1/tracer/tracer_factory.py +++ b/src/judgeval/v1/tracer/tracer_factory.py @@ -21,13 +21,10 @@ def create( self, project_name: str, enable_evaluation: bool = True, - serializer: Optional[Callable[[Any], str]] = None, + serializer: Callable[[Any], str] = safe_serialize, filter_tracer: Optional[FilterTracerCallback] = None, initialize: bool = True, ) -> Tracer: - if serializer is None: - serializer = safe_serialize - return Tracer( project_name=project_name, enable_evaluation=enable_evaluation, diff --git a/src/judgeval/v1/trainers/fireworks_trainer.py b/src/judgeval/v1/trainers/fireworks_trainer.py index 40ce4037..756ae837 100644 --- a/src/judgeval/v1/trainers/fireworks_trainer.py +++ b/src/judgeval/v1/trainers/fireworks_trainer.py @@ -4,7 +4,12 @@ import json from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING -from fireworks import Dataset # type: ignore[import-not-found,import-untyped] +try: + from fireworks import Dataset # type: ignore[import-not-found] +except ImportError as e: + raise ImportError( + "Fireworks is not installed. Please install it with 'pip install fireworks'" + ) from e if TYPE_CHECKING: from judgeval.v1.trainers.config import TrainerConfig, ModelConfig @@ -33,11 +38,11 @@ class FireworksTrainer(BaseTrainer): def __init__( self, - config: "TrainerConfig", - trainable_model: "TrainableModel", - tracer: "Tracer", + config: TrainerConfig, + trainable_model: TrainableModel, + tracer: Tracer, project_name: Optional[str] = None, - client: Optional["JudgmentSyncClient"] = None, + client: Optional[JudgmentSyncClient] = None, ): super().__init__(config, trainable_model, tracer, project_name) if client is None: @@ -334,6 +339,10 @@ async def run_reinforcement_learning( "Starting reinforcement training", step_num, self.config.num_steps ) job = self.trainable_model.perform_reinforcement_step(dataset, step) + if job is None: + raise JudgmentRuntimeError( + "Failed to perform reinforcement training step. Job is None." + ) last_state = None with _spinner_progress( diff --git a/src/judgeval/v1/trainers/trainable_model.py b/src/judgeval/v1/trainers/trainable_model.py index 5c64ebfe..5911d66b 100644 --- a/src/judgeval/v1/trainers/trainable_model.py +++ b/src/judgeval/v1/trainers/trainable_model.py @@ -1,7 +1,16 @@ +from __future__ import annotations + import time -from fireworks import LLM # type: ignore[import-not-found,import-untyped] -from .config import TrainerConfig, ModelConfig from typing import Optional, Dict, Any, Callable + +try: + from fireworks import LLM # type: ignore[import-not-found] +except ImportError as e: + raise ImportError( + "Fireworks is not installed. Please install it with 'pip install fireworks'" + ) from e + +from .config import TrainerConfig, ModelConfig from .console import _model_spinner_progress, _print_model_progress from judgeval.exceptions import JudgmentRuntimeError diff --git a/src/tests/v1/tracer/test_generator.py b/src/tests/v1/tracer/test_generator.py new file mode 100644 index 00000000..2762cad7 --- /dev/null +++ b/src/tests/v1/tracer/test_generator.py @@ -0,0 +1,407 @@ +import pytest +import asyncio +import contextvars +from typing import Tuple, Generator +from unittest.mock import patch, MagicMock +from judgeval.v1.tracer.tracer import Tracer +from judgeval.v1.tracer.exporters.in_memory_span_exporter import InMemorySpanExporter +from judgeval.v1.tracer.exporters.span_store import SpanStore +from judgeval.judgment_attribute_keys import AttributeKeys + + +@pytest.fixture +def tracer() -> Generator[Tuple[Tracer, SpanStore], None, None]: + from opentelemetry.trace import _TRACER_PROVIDER_SET_ONCE, _TRACER_PROVIDER + + try: + _TRACER_PROVIDER_SET_ONCE._done = False # type: ignore[attr-defined] + _TRACER_PROVIDER._default = None # type: ignore[attr-defined] + except Exception: + pass + + mock_client = MagicMock() + mock_client.organization_id = "test_org" + mock_client.base_url = "http://test.com/" + + def serializer(x: object) -> str: + return str(x) + + span_store = SpanStore() + exporter = InMemorySpanExporter(span_store) + + with patch("judgeval.v1.utils.resolve_project_id") as mock_resolve: + mock_resolve.return_value = "test_project_id" + + with patch.object(Tracer, "get_span_exporter", return_value=exporter): + tracer_instance = Tracer( + project_name="generator-test", + enable_evaluation=False, + api_client=mock_client, + serializer=serializer, + initialize=True, + ) + + yield tracer_instance, span_store + + tracer_instance.force_flush() + tracer_instance.shutdown() + + +def test_sync_generator_basic(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="sync_gen") + def sync_generator(): + yield 1 + yield 2 + yield 3 + + result = list(sync_generator()) + assert result == [1, 2, 3] + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 4 + + generator_span = [ + s + for s in spans + if s.attributes + and s.attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND) == "generator" + ][0] + assert generator_span.name == "sync_gen" + + +def test_async_generator_basic(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="async_gen") + async def async_generator(): + yield 1 + yield 2 + yield 3 + + async def run_test(): + result = [] + async for item in async_generator(): + result.append(item) + return result + + result = asyncio.run(run_test()) + assert result == [1, 2, 3] + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 4 + + generator_span = [ + s + for s in spans + if s.attributes + and s.attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND) == "generator" + ][0] + assert generator_span.name == "async_gen" + + +def test_generator_context_preservation(tracer: Tuple[Tracer, SpanStore]) -> None: + test_var: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "test_var", default=None + ) + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="parent_with_context") + def parent_function(): + test_var.set("TEST_VALUE") + + @tracer_instance.observe(span_name="gen_with_context") + def generator_with_context(): + for i in range(3): + assert test_var.get() == "TEST_VALUE", f"Context lost at iteration {i}" + yield i + + return list(generator_with_context()) + + result = parent_function() + assert result == [0, 1, 2] + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 5 + + +def test_async_generator_context_preservation(tracer: Tuple[Tracer, SpanStore]) -> None: + test_var: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "test_var", default=None + ) + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="async_parent_with_context") + async def async_parent_function(): + test_var.set("ASYNC_TEST_VALUE") + + @tracer_instance.observe(span_name="async_gen_with_context") + async def async_generator_with_context(): + for i in range(3): + assert test_var.get() == "ASYNC_TEST_VALUE", ( + f"Context lost at iteration {i}" + ) + yield i + + result = [] + async for item in async_generator_with_context(): + result.append(item) + return result + + result = asyncio.run(async_parent_function()) + assert result == [0, 1, 2] + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 5 + + +def test_generator_with_customer_id(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="parent_with_customer") + def parent_with_customer(): + tracer_instance.set_customer_id("gen-customer") + + @tracer_instance.observe(span_name="child_generator") + def child_generator(): + yield 1 + yield 2 + yield 3 + + return list(child_generator()) + + result = parent_with_customer() + assert result == [1, 2, 3] + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 5 + + child_spans = [ + s + for s in spans + if s.name == "child_generator" + or ( + s.attributes + and s.attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND) == "generator_item" + ) + ] + for span in child_spans: + if span.attributes: + assert ( + span.attributes.get(AttributeKeys.JUDGMENT_CUSTOMER_ID) + == "gen-customer" + ) + + +def test_generator_exception_handling(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="failing_generator") + def failing_generator(): + yield 1 + yield 2 + raise ValueError("Generator error") + + with pytest.raises(ValueError, match="Generator error"): + list(failing_generator()) + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 3 + + generator_span = [s for s in spans if s.name == "failing_generator"][0] + assert generator_span.name == "failing_generator" + + +def test_async_generator_exception_handling(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="failing_async_generator") + async def failing_async_generator(): + yield 1 + yield 2 + raise ValueError("Async generator error") + + async def run_test(): + result = [] + async for item in failing_async_generator(): + result.append(item) + + with pytest.raises(ValueError, match="Async generator error"): + asyncio.run(run_test()) + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 3 + + generator_span = [s for s in spans if s.name == "failing_async_generator"][0] + assert generator_span.name == "failing_async_generator" + + +def test_generator_partial_consumption(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="partial_generator") + def partial_generator(): + yield 1 + yield 2 + yield 3 + yield 4 + yield 5 + + gen = partial_generator() + assert next(gen) == 1 + assert next(gen) == 2 + gen.close() + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 3 + + generator_span = [s for s in spans if s.name == "partial_generator"][0] + assert generator_span.name == "partial_generator" + assert generator_span.end_time is not None + + +def test_async_generator_partial_consumption(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="async_partial_generator") + async def async_partial_generator(): + yield 1 + yield 2 + yield 3 + yield 4 + yield 5 + + async def run_test(): + gen = async_partial_generator() + assert await gen.__anext__() == 1 + assert await gen.__anext__() == 2 + await gen.aclose() + + asyncio.run(run_test()) + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 3 + + generator_span = [s for s in spans if s.name == "async_partial_generator"][0] + assert generator_span.name == "async_partial_generator" + assert generator_span.end_time is not None + + +def test_generator_parent_child_relationship(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="parent_function") + def parent_function(): + @tracer_instance.observe(span_name="child_generator") + def child_generator(): + yield 1 + yield 2 + + return list(child_generator()) + + parent_function() + + tracer_instance.force_flush() + spans = span_store.get_all() + assert len(spans) == 4 + + parent_span = [s for s in spans if s.name == "parent_function"][0] + child_gen_span = [ + s + for s in spans + if s.name == "child_generator" + and s.attributes + and s.attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND) == "generator" + ][0] + + assert parent_span.context is not None + assert child_gen_span.context is not None + assert child_gen_span.context.trace_id == parent_span.context.trace_id + assert child_gen_span.parent is not None + assert child_gen_span.parent.span_id == parent_span.context.span_id + + +def test_sync_generator_with_child_spans(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="parent_gen") + def parent_generator(): + yield 1 + yield 2 + yield 3 + + result = list(parent_generator()) + assert result == [1, 2, 3] + + tracer_instance.force_flush() + spans = span_store.get_all() + + assert len(spans) == 4 + + child_spans = [ + s + for s in spans + if s.attributes + and s.attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND) == "generator_item" + ] + + assert len(child_spans) == 3 + + for child_span in child_spans: + output = ( + child_span.attributes.get(AttributeKeys.JUDGMENT_OUTPUT) + if child_span.attributes + else None + ) + assert output is not None + + +def test_async_generator_with_child_spans(tracer: Tuple[Tracer, SpanStore]) -> None: + tracer_instance, span_store = tracer + + @tracer_instance.observe(span_name="async_parent_gen") + async def async_parent_generator(): + yield "a" + yield "b" + yield "c" + + async def run_test(): + result = [] + async for item in async_parent_generator(): + result.append(item) + return result + + result = asyncio.run(run_test()) + assert result == ["a", "b", "c"] + + tracer_instance.force_flush() + spans = span_store.get_all() + + assert len(spans) == 4 + + child_spans = [ + s + for s in spans + if s.attributes + and s.attributes.get(AttributeKeys.JUDGMENT_SPAN_KIND) == "generator_item" + ] + + assert len(child_spans) == 3 + + for child_span in child_spans: + output = ( + child_span.attributes.get(AttributeKeys.JUDGMENT_OUTPUT) + if child_span.attributes + else None + ) + assert output is not None diff --git a/src/tests/v1/tracer/test_tracer.py b/src/tests/v1/tracer/test_tracer.py index b1542a2d..3b41fe7b 100644 --- a/src/tests/v1/tracer/test_tracer.py +++ b/src/tests/v1/tracer/test_tracer.py @@ -1,20 +1,27 @@ import pytest +from typing import Any from unittest.mock import MagicMock, patch from judgeval.v1.tracer.tracer import Tracer from opentelemetry.sdk.trace import TracerProvider @pytest.fixture -def mock_client(): - return MagicMock() +def mock_client() -> MagicMock: + client = MagicMock() + client.organization_id = "test_org" + client.base_url = "http://test.com/" + return client @pytest.fixture -def serializer(): - return lambda x: str(x) +def serializer() -> Any: + def serialize(x: object) -> str: + return str(x) + return serialize -def test_tracer_initialization(mock_client, serializer): + +def test_tracer_initialization(mock_client: MagicMock, serializer: Any) -> None: tracer = Tracer( project_name="test_project", enable_evaluation=True, @@ -27,21 +34,28 @@ def test_tracer_initialization(mock_client, serializer): assert tracer._tracer_provider is None -def test_tracer_initialization_with_initialize(mock_client, serializer): +def test_tracer_initialization_with_initialize( + mock_client: MagicMock, serializer: Any +) -> None: with patch("judgeval.v1.tracer.tracer.trace.set_tracer_provider"): - tracer = Tracer( - project_name="test_project", - enable_evaluation=True, - api_client=mock_client, - serializer=serializer, - initialize=True, - ) - - assert tracer._tracer_provider is not None - assert isinstance(tracer._tracer_provider, TracerProvider) - - -def test_tracer_force_flush_without_initialization(mock_client, serializer): + with patch( + "judgeval.v1.utils.resolve_project_id", return_value="test_project_id" + ): + tracer = Tracer( + project_name="test_project", + enable_evaluation=True, + api_client=mock_client, + serializer=serializer, + initialize=True, + ) + + assert tracer._tracer_provider is not None + assert isinstance(tracer._tracer_provider, TracerProvider) + + +def test_tracer_force_flush_without_initialization( + mock_client: MagicMock, serializer: Any +) -> None: tracer = Tracer( project_name="test_project", enable_evaluation=True, @@ -54,24 +68,29 @@ def test_tracer_force_flush_without_initialization(mock_client, serializer): assert result is False -def test_tracer_force_flush_with_initialization(mock_client, serializer): +def test_tracer_force_flush_with_initialization( + mock_client: MagicMock, serializer: Any +) -> None: with patch("judgeval.v1.tracer.tracer.trace.set_tracer_provider"): - tracer = Tracer( - project_name="test_project", - enable_evaluation=True, - api_client=mock_client, - serializer=serializer, - initialize=True, - ) - - tracer._tracer_provider.force_flush = MagicMock(return_value=True) - result = tracer.force_flush(timeout_millis=5000) - - assert result is True - tracer._tracer_provider.force_flush.assert_called_once_with(5000) - - -def test_tracer_shutdown_without_initialization(mock_client, serializer): + with patch( + "judgeval.v1.utils.resolve_project_id", return_value="test_project_id" + ): + tracer = Tracer( + project_name="test_project", + enable_evaluation=True, + api_client=mock_client, + serializer=serializer, + initialize=True, + ) + + assert tracer._tracer_provider is not None + result = tracer.force_flush(timeout_millis=5000) + assert isinstance(result, bool) + + +def test_tracer_shutdown_without_initialization( + mock_client: MagicMock, serializer: Any +) -> None: tracer = Tracer( project_name="test_project", enable_evaluation=True, @@ -83,17 +102,20 @@ def test_tracer_shutdown_without_initialization(mock_client, serializer): tracer.shutdown() -def test_tracer_shutdown_with_initialization(mock_client, serializer): +def test_tracer_shutdown_with_initialization( + mock_client: MagicMock, serializer: Any +) -> None: with patch("judgeval.v1.tracer.tracer.trace.set_tracer_provider"): - tracer = Tracer( - project_name="test_project", - enable_evaluation=True, - api_client=mock_client, - serializer=serializer, - initialize=True, - ) - - tracer._tracer_provider.shutdown = MagicMock() - tracer.shutdown(timeout_millis=10000) - - tracer._tracer_provider.shutdown.assert_called_once() + with patch( + "judgeval.v1.utils.resolve_project_id", return_value="test_project_id" + ): + tracer = Tracer( + project_name="test_project", + enable_evaluation=True, + api_client=mock_client, + serializer=serializer, + initialize=True, + ) + + assert tracer._tracer_provider is not None + tracer.shutdown(timeout_millis=10000)