diff --git a/packages/nvidia_nat_adk/src/nat/plugins/adk/callback_handler.py b/packages/nvidia_nat_adk/src/nat/plugins/adk/callback_handler.py index b0cd7aa9dd..aee646243a 100644 --- a/packages/nvidia_nat_adk/src/nat/plugins/adk/callback_handler.py +++ b/packages/nvidia_nat_adk/src/nat/plugins/adk/callback_handler.py @@ -59,6 +59,7 @@ def __init__(self): # Original references to Google ADK Tool and LLM methods (for uninstrumenting if needed) self._original_tool_call = None self._original_llm_call = None + self._original_adk_llm_call = None self._instrumented = False def instrument(self) -> None: @@ -77,6 +78,7 @@ def instrument(self) -> None: logger.exception("litellm import failed; skipping instrumentation") return try: + import google.adk.models.lite_llm as adk_lite_llm from google.adk.tools.function_tool import FunctionTool except Exception as _e: logger.exception("ADK import failed; skipping instrumentation") @@ -85,9 +87,12 @@ def instrument(self) -> None: # Save the originals self._original_tool_call = FunctionTool.run_async self._original_llm_call = litellm.acompletion + self._original_adk_llm_call = adk_lite_llm.acompletion + wrapped_llm = self._llm_call_monkey_patch() FunctionTool.run_async = self._tool_use_monkey_patch() - litellm.acompletion = self._llm_call_monkey_patch() + litellm.acompletion = wrapped_llm + adk_lite_llm.acompletion = wrapped_llm logger.debug("ADKProfilerHandler instrumentation applied successfully.") self._instrumented = True @@ -97,6 +102,7 @@ def uninstrument(self) -> None: Add an explicit unpatch to avoid side-effects across tests/process lifetime. """ try: + import google.adk.models.lite_llm as adk_lite_llm import litellm from google.adk.tools.function_tool import FunctionTool if self._original_tool_call is not None: @@ -107,6 +113,10 @@ def uninstrument(self) -> None: litellm.acompletion = self._original_llm_call self._original_llm_call = None + if self._original_adk_llm_call is not None: + adk_lite_llm.acompletion = self._original_adk_llm_call + self._original_adk_llm_call = None + self._instrumented = False self.last_call_ts = 0.0 logger.debug("ADKProfilerHandler uninstrumented successfully.")