From 7f3f8ded87937b6d8517d3ef97d4000f1d579b62 Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Fri, 3 Oct 2025 18:42:58 +0000 Subject: [PATCH 1/8] enable tracing for webarena verified --- src/agentlab/experiments/loop.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index de4b976a..4511790e 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -425,6 +425,9 @@ def run(self): exp_dir=self.exp_dir, use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False), ) + # webarena_verified hack, enable playwright tracing + if self.env_args.task_name.startswith("webarena_verified"): + env.unwrapped.context.tracing.start(snapshots=True) logger.debug("Environment created.") step_info = StepInfo(step=0) @@ -504,6 +507,9 @@ def run(self): logger.exception(f"Error while saving experiment info: {e}") try: if env is not None: + # webarena_verified hack, close playwright tracing + if self.env_args.task_name.startswith("webarena_verified"): + env.unwrapped.context.tracing.stop(path=self.exp_dir / "pw_traces" / f"{self.exp_name}.zip") env.close() except Exception as e: logger.exception(f"Error while closing the environment: {e}") @@ -915,6 +921,8 @@ def _get_env_name(task_name: str): import browsergym.workarena elif task_name.startswith("webarena"): import browsergym.webarena + import browsergym.webarena_verified + import browsergym.webarenalite elif task_name.startswith("visualwebarena"): import browsergym.visualwebarena elif task_name.startswith("assistantbench"): From a0d7f37fc3fb4e5f1589113e2b5c64492e32ee5b Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Thu, 16 Oct 2025 19:32:40 +0000 Subject: [PATCH 2/8] move webarena verified specific code to bgym/webarena_verified --- src/agentlab/experiments/loop.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index 4511790e..758d2d9d 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -425,9 +425,6 @@ def run(self): exp_dir=self.exp_dir, use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False), ) - # webarena_verified hack, enable playwright tracing - if self.env_args.task_name.startswith("webarena_verified"): - env.unwrapped.context.tracing.start(snapshots=True) logger.debug("Environment created.") step_info = StepInfo(step=0) @@ -507,9 +504,6 @@ def run(self): logger.exception(f"Error while saving experiment info: {e}") try: if env is not None: - # webarena_verified hack, close playwright tracing - if self.env_args.task_name.startswith("webarena_verified"): - env.unwrapped.context.tracing.stop(path=self.exp_dir / "pw_traces" / f"{self.exp_name}.zip") env.close() except Exception as e: logger.exception(f"Error while closing the environment: {e}") From 23c4977d0c35674f5c3dde441101c283f46b421e Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Sat, 25 Oct 2025 02:35:03 +0000 Subject: [PATCH 3/8] update gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 6234b3cd..92a1d38c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +.vscode/ +.local/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From 723b38cbf498b928c3fd623f675bb882d284acf2 Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Tue, 4 Nov 2025 14:53:40 +0000 Subject: [PATCH 4/8] add webarena-verified to readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8397c605..30d261ab 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ AgentLab Features: | Benchmark | Setup
Link | # Task
Template| Seed
Diversity | Max
Step | Multi-tab | Hosted Method | BrowserGym
Leaderboard | |-----------|------------|---------|----------------|-----------|-----------|---------------|----------------------| | [WebArena](https://webarena.dev/) | [setup](https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/webarena/README.md) | 812 | None | 30 | yes | self hosted (docker) | soon | +| [WebArena-Verified](https://github.com/ServiceNow/platform-labs-webarena-verified) | [setup](https://github.com/ServiceNow/BrowserGym/blob/wa_verified/browsergym/webarena_verified/README.md) | 812 | None | 30 | yes | self hosted | soon | | [WorkArena](https://github.com/ServiceNow/WorkArena) L1 | [setup](https://github.com/ServiceNow/WorkArena?tab=readme-ov-file#getting-started) | 33 | High | 30 | no | demo instance | soon | | [WorkArena](https://github.com/ServiceNow/WorkArena) L2 | [setup](https://github.com/ServiceNow/WorkArena?tab=readme-ov-file#getting-started) | 341 | High | 50 | no | demo instance | soon | | [WorkArena](https://github.com/ServiceNow/WorkArena) L3 | [setup](https://github.com/ServiceNow/WorkArena?tab=readme-ov-file#getting-started) | 341 | High | 50 | no | demo instance | soon | From b3179258e7c8de45bded5d6853159c3e59a48fc9 Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Thu, 20 Nov 2025 16:10:31 +0000 Subject: [PATCH 5/8] update pricing model to work with azure models --- src/agentlab/llm/chat_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 188747ac..5487c71f 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -433,7 +433,7 @@ def __init__( min_retry_wait_time=min_retry_wait_time, client_class=OpenAI, client_args=client_args, - pricing_func=tracking.get_pricing_openai, + pricing_func=partial(tracking.get_pricing_litellm, model_name=model_name), log_probs=log_probs, ) From 58de958efe5e73c0d5b0a375c505815032e6e295 Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Thu, 4 Dec 2025 15:48:29 +0000 Subject: [PATCH 6/8] optional import --- src/agentlab/experiments/loop.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index b6b12709..75dd9f40 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -914,8 +914,12 @@ def _get_env_name(task_name: str): import browsergym.workarena elif task_name.startswith("webarena"): import browsergym.webarena - import browsergym.webarena_verified import browsergym.webarenalite + + try: + import browsergym.webarena_verified + except ImportError: + logger.warning("browsergym.webarena_verified not found. Skipping import.") elif task_name.startswith("visualwebarena"): import browsergym.visualwebarena elif task_name.startswith("assistantbench"): From 5dfec6dce846570e3c94fae723fc864037b4ebdb Mon Sep 17 00:00:00 2001 From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:47:19 -0500 Subject: [PATCH 7/8] CUA like agent with tool use and hint support. (#318) * overlay_utils can return array if needed. * exact goal loading in the tool-use-agent * add tool use cua_like_agent * remove unused imports * make extra-dependency user facing and update dev dependency group * update CI/CD env installation (dev is default group) * update makefile to use uv * black and remove unneeded items pip list from code formatting CI/CD. --------- Co-authored-by: Patrice Bechard --- .github/workflows/code_format.yml | 7 +- .github/workflows/darglint.yml | 5 +- .github/workflows/unit_tests.yml | 2 +- Makefile | 18 +- pyproject.toml | 26 +- src/agentlab/agents/agent_utils.py | 8 +- .../agents/tool_use_agent/cua_like_agent.py | 767 ++++++++++++++++++ .../agents/tool_use_agent/tool_use_agent.py | 13 +- src/agentlab/utils/hinting.py | 1 - uv.lock | 44 +- 10 files changed, 829 insertions(+), 62 deletions(-) create mode 100644 src/agentlab/agents/tool_use_agent/cua_like_agent.py diff --git a/.github/workflows/code_format.yml b/.github/workflows/code_format.yml index f5551244..7c052c4e 100644 --- a/.github/workflows/code_format.yml +++ b/.github/workflows/code_format.yml @@ -24,13 +24,10 @@ jobs: enable-cache: true - name: Set up Python - run: uv python install 3.11 + run: uv python install 3.12 - name: Install dependencies - run: uv sync --frozen --extra dev - - - name: List packages - run: uv pip list + run: uv sync --frozen - name: Code Formatting run: uv run black src/ --check --diff diff --git a/.github/workflows/darglint.yml b/.github/workflows/darglint.yml index 7fca9321..76c947ae 100644 --- a/.github/workflows/darglint.yml +++ b/.github/workflows/darglint.yml @@ -27,10 +27,7 @@ jobs: run: uv python install 3.12 # this fails in 3.11 - name: Install dependencies - run: uv sync --frozen --extra dev - - - name: List packages - run: uv pip list + run: uv sync --frozen - name: Darglint checks run: uv run darglint -v 2 -z short src/ diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 7ce65722..d3194d93 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -32,7 +32,7 @@ jobs: run: uv python install 3.11 - name: Install AgentLab - run: uv sync --frozen --extra dev + run: uv sync --frozen - name: List packages run: uv pip list diff --git a/Makefile b/Makefile index 23799f32..1c33ce09 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,14 @@ .PHONY: test setup miniwob lint stop-miniwob osworld setup: - @pip install -e . - @playwright install chromium --with-deps - @python -c 'import nltk; nltk.download("punkt_tab")' + @uv sync --python 3.12 + @uv run playwright install chromium --with-deps + @uv run python -c 'import nltk; nltk.download("punkt_tab")' miniwob: stop-miniwob @git clone https://github.com/Farama-Foundation/miniwob-plusplus.git || true @cd miniwob-plusplus && git checkout 7fd85d71a4b60325c6585396ec4f48377d049838 - @python -m http.server 8080 --directory miniwob-plusplus/miniwob/html & echo $$! > .miniwob-server.pid + @uv run python -m http.server 8080 --directory miniwob-plusplus/miniwob/html & echo $$! > .miniwob-server.pid @sleep 3 @echo "MiniWob server started on http://localhost:8080" @@ -22,14 +22,14 @@ stop-miniwob: @echo "MiniWob server stopped" run-tests: - @MINIWOB_URL="http://localhost:8080/miniwob/" pytest -n 5 --durations=10 -m 'not pricy' tests/ + @MINIWOB_URL="http://localhost:8080/miniwob/" uv run pytest -n 5 --durations=10 -m 'not pricy' tests/ @echo "Tests completed" test: setup miniwob check-miniwob run-tests stop-miniwob lint: setup - @black src/ --check --diff - @darglint -v 2 -z short src/ + @uv run black src/ --check --diff + @uv run darglint -v 2 -z short src/ osworld: @echo "Setting up OSWorld..." @@ -42,9 +42,9 @@ osworld: sed -i.bak 's/tqdm~=.*/tqdm/' requirements.txt && \ sed -i.bak 's/pandas~=.*/pandas/' requirements.txt @echo "Installing OSWorld requirements..." - @cd OSWorld && pip install -r requirements.txt + @cd OSWorld && uv pip install -r requirements.txt @echo "Installing OSWorld in development mode..." - @cd OSWorld && pip install -e . + @cd OSWorld && uv pip install -e . @echo "OSWorld setup completed!" @echo "Next steps:" @echo "1. Configure your VM (VMware/VirtualBox) according to OSWorld documentation" diff --git a/pyproject.toml b/pyproject.toml index 64285668..98fde2d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,24 +49,20 @@ dependencies = [ "pillow", "gymnasium>=0.27", "torch>=2.2.2", - "safetensors>=0.4.0", - "transformers>=4.38.2", "anthropic>=0.62.0", "litellm>=1.75.3", "python-dotenv>=1.1.1", ] [project.optional-dependencies] -dev = [ - "black[jupyter]>=24.2.0", - "blacken-docs", - "pre-commit", - "pytest==7.3.2", - "flaky", - "pytest-xdist", - "pytest-playwright", +hint = [ + "sentence-transformers>=5.0.0", +] +transformers = [ + "transformers>=4.38.2", ] + [project.urls] Homepage = "https://github.com/ServiceNow/AgentLab" @@ -107,9 +103,13 @@ dev = [ "darglint>=1.8.1", "ipykernel>=6.30.1", "pip>=25.2", -] -hint = [ - "sentence-transformers>=5.0.0", + "black[jupyter]>=24.2.0", + "blacken-docs", + "pre-commit", + "pytest==7.3.2", + "flaky", + "pytest-xdist", + "pytest-playwright", ] diff --git a/src/agentlab/agents/agent_utils.py b/src/agentlab/agents/agent_utils.py index 179a94d2..954977b2 100644 --- a/src/agentlab/agents/agent_utils.py +++ b/src/agentlab/agents/agent_utils.py @@ -5,6 +5,7 @@ from agentlab.analyze import overlay_utils from agentlab.llm.llm_utils import img_to_base_64 +import numpy as np def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image: @@ -135,7 +136,7 @@ def zoom_webpage(page: Page, zoom_factor: float = 1.5): return page -def overlay_action(obs, action): +def overlay_action(obs, action, return_array=False): """Overlays actions on screenshot in-place""" act_img = copy.deepcopy(obs["screenshot"]) act_img = Image.fromarray(act_img) @@ -153,4 +154,7 @@ def overlay_action(obs, action): pass overlay_utils.annotate_action(act_img, action, properties=new_obs_properties) - return img_to_base_64(act_img) + if return_array: + return np.array(act_img) + else: + return img_to_base_64(act_img) diff --git a/src/agentlab/agents/tool_use_agent/cua_like_agent.py b/src/agentlab/agents/tool_use_agent/cua_like_agent.py new file mode 100644 index 00000000..8a6259a2 --- /dev/null +++ b/src/agentlab/agents/tool_use_agent/cua_like_agent.py @@ -0,0 +1,767 @@ +import json +import logging +import os +from abc import ABC, abstractmethod +from copy import copy, deepcopy +from dataclasses import asdict, dataclass, field +from typing import Any, Literal + +import bgym +from bgym import Benchmark as BgymBenchmark +from browsergym.core.observation import extract_screenshot +from browsergym.utils.obs import ( + flatten_axtree_to_str, + flatten_dom_to_str, + overlay_som, + prune_html, +) + +from agentlab.agents.agent_args import AgentArgs +from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark +from agentlab.llm.base_api import BaseModelArgs +from agentlab.llm.litellm_api import LiteLLMModelArgs +from agentlab.llm.llm_utils import image_to_png_base64_url +from agentlab.llm.response_api import ( + APIPayload, + LLMOutput, + MessageBuilder, +) +from agentlab.llm.tracking import cost_tracker_decorator +from agentlab.utils.hinting import HintsSource + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +ADDITIONAL_ACTION_INSTRUCTIONS = """ +**Important Rules:** +- Coordinates (x, y) must be NUMBERS, not strings +- Do NOT use named parameters for coordinates unless necessary for clarity +- Button parameter is optional, defaults to 'left' +- String values must be in quotes +- Call send_msg_to_user only with a single number in the answer when sending the final answer for evaluation. + +**Correct Examples:** +- mouse_click(347, 192) +- mouse_click(56, 712.56, 'right') +- keyboard_type('hello@example.com') +- keyboard_type('System Diagnostics') +- keyboard_press('ControlOrMeta+v') +- keyboard_press('Escape') +- mouse_drag_and_drop(100, 200, 300, 400) + +**WRONG Examples (DO NOT DO THIS):** +- mouse_click(x='347, 192', y=192) ❌ x is a string with both coords +- mouse_click('347', '192') ❌ coordinates as strings +- "mouse_click(100, 200)" ❌ wrapped in quotes +- keyboard_press(Escape) ❌ string argument missing quotes +- keyboard_type(System Diagnostics) ❌ text argument missing quotes +""" + +simple_bgym_action_tool = { + "name": "perform_action", + "type": "function", + "description": f"""Return a string representation of a Python function call for browsergym actions. + You must return ONLY the function call string, exactly as it would appear in Python code.""", + "parameters": { + "type": "object", + "properties": { + "thought": { + "type": "string", + "description": "The agent's internal chain of thought for performing the action.", + }, + "action": { + "type": "string", + "description": "The Python function call string (e.g., 'mouse_click(100, 200)' or 'keyboard_type(\"hello\")')", + }, + }, + "required": ["thought", "action"], + }, +} + + +def action_from_generalized_bgym_action_tool( + response: LLMOutput, tool_name: str = "perform_action" +) -> tuple[str | None, str | None]: + """Extract the action string from the tool call in the LLM response.""" + action, think = None, None + if response.tool_calls is not None: + for tc in response.tool_calls.tool_calls: + if tc.name == tool_name: + action = tc.arguments.get("action") + think = tc.arguments.get("thought") + break + return action, think + + +@dataclass +class Block(ABC): + def _init(self): + """Initialize the block.""" + pass + + def make(self) -> "Block": + """Returns a copy so the init can start adding some stuff to `self` without changing the + original dataclass that should only contain a config. + The aim is avoid having 2 class definition for each block, e.g. Block and BlockArgs. + + Returns: + Block: A copy of the current block instance with initialization applied. + """ + block = self.__class__(**asdict(self)) + block._init() + return block + + @abstractmethod + def apply(self, llm, messages: list[MessageBuilder], **kwargs): + pass + + +@dataclass +class MsgGroup: + name: str = None + messages: list[MessageBuilder] = field(default_factory=list) + summary: MessageBuilder = None + + @property + def tool_summary(self) -> None: + return [msg for msg in self.messages if msg.role == "tool"] + + @property + def messages_without_images(self) -> list[MessageBuilder]: + _messages = deepcopy(self.messages) + for msg in _messages: + for content in msg.content: + if "image" in content: + content.pop("image") + content["text"] = "[Screenshot Placeholder]" + + return _messages + + +class StructuredDiscussion: + """ + A structured discussion that groups messages into named groups with a potential summary for each group. + + When the discussion is flattened, only the last `keep_last_n_obs` groups are kept in the final list, + the other groups are replaced by their summaries if they have one. + """ + + def __init__(self, keep_last_n_obs=None): + self.groups: list[MsgGroup] = [] + self.keep_last_n_obs: int | None = keep_last_n_obs + + def append(self, message: MessageBuilder): + """Append a message to the last group.""" + self.groups[-1].messages.append(message) + + def new_group(self, name: str = None): + """Start a new group of messages.""" + if name is None: + name = f"group_{len(self.groups)}" + self.groups.append(MsgGroup(name)) + + def flatten(self) -> list[MessageBuilder]: + """Flatten the groups into a single list of messages.""" + + keep_last_n_obs = self.keep_last_n_obs or len(self.groups) + messages = [] + for i, group in enumerate(self.groups): + is_tail = i >= len(self.groups) - keep_last_n_obs + + if not is_tail: + if group.summary is not None: + messages.append(group.summary) + else: + messages.extend(group.messages_without_images) + + else: + messages.extend(group.messages) + # Mark all summarized messages for caching + if i == len(self.groups) - keep_last_n_obs: + for msg in messages: # unset previous cache breakpoints + msg._cache_breakpoint = False + # set new cache breakpoint + messages[i].mark_all_previous_msg_for_caching() + return messages + + def set_last_summary(self, summary: MessageBuilder): + # append None to summaries until we reach the current group index + self.groups[-1].summary = summary + + def get_last_summary(self) -> MessageBuilder | None: + """Get the last summary message.""" + if len(self.groups) == 0: + return None + return self.groups[-1].summary + + def is_goal_set(self) -> bool: + """Check if the goal is set in the first group.""" + return len(self.groups) > 0 + + +SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal. +You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure. +""" + + +@dataclass +class Goal(Block): + """Block to add the goal to the messages.""" + + goal_as_system_msg: bool = True + + def apply( + self, llm, discussion: StructuredDiscussion, obs: dict, sys_msg: str = SYS_MSG + ) -> dict: + system_message = llm.msg.system().add_text(sys_msg) + discussion.append(system_message) + + if self.goal_as_system_msg: + goal_message = llm.msg.system() + else: + goal_message = llm.msg.user() + + goal_message.add_text("# Goal:\n") + for content in obs["goal_object"]: + if content["type"] == "text": + goal_message.add_text(content["text"]) + elif content["type"] == "image_url": + goal_message.add_image(content["image_url"]) + discussion.append(goal_message) + + +AXTREE_NOTE = """ +AXTree extracts most of the interactive elements of the DOM in a tree structure. It may also contain information that is not visible in the screenshot. +A line starting with [bid] is a node in the AXTree. It is a unique alpha-numeric identifier to be used when calling tools, e.g, click(bid="a253"). Make sure to include letters and numbers in the bid. +""" + + +@dataclass +class Obs(Block): + """Block to add the observation to the messages.""" + + use_last_error: bool = True + use_screenshot: bool = True + use_axtree: bool = False + use_dom: bool = False + use_som: bool = False + use_tabs: bool = False + overlay_mouse_action: bool = False + use_zoomed_webpage: bool = False + skip_preprocessing: bool = False + + def _init(self): + self._last_observation = None + + def apply( + self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput + ) -> dict: + obs_msg = llm.msg.user() + tool_calls = last_llm_output.tool_calls + # add the tool call response first in the observation + # to maintain continuity with last response. + if tool_calls: + for call in tool_calls: + call.response_text("See Observation") + tool_response = llm.msg.add_responded_tool_calls(tool_calls) + discussion.append(tool_response) + + if self.use_last_error: + if obs["last_action_error"] != "": + obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}") + + if self.use_screenshot: + if self.use_som: + screenshot = obs["screenshot_som"] + else: + screenshot = obs["screenshot"] + + if self.overlay_mouse_action and self._last_observation is not None: + self.overlay_last_screenshot_with_action( + discussion, obs["last_action"], self._last_observation + ) + + obs_msg.add_image(image_to_png_base64_url(screenshot)) + if self.use_axtree: + obs_msg.add_text(f"AXTree:\n{AXTREE_NOTE}\n{obs['axtree_txt']}") + if self.use_dom: + obs_msg.add_text(f"DOM:\n{obs['pruned_html']}") + if self.use_tabs: + obs_msg.add_text(_format_tabs(obs)) + + discussion.append(obs_msg) + self._last_observation = deepcopy(obs) + return obs_msg + + @staticmethod + def overlay_last_screenshot_with_action(discussion: StructuredDiscussion, action, obs): + """Update the last image with new_image_base64 overlayed with the action.""" + import base64 + from agentlab.analyze import overlay_utils + from PIL import Image + from io import BytesIO + + for msg_groups in reversed(discussion.groups): + for msg in reversed(msg_groups.messages): + for content in reversed(msg.content): + if "image" in content: + data_url = content["image"] + header, encoded = data_url.split(",", 1) + new_obs_properties = deepcopy(obs["extra_element_properties"]) + sc = Image.open(BytesIO(base64.b64decode(encoded))) + overlay_utils.annotate_action(sc, action, properties=new_obs_properties) + new_base64_image = image_to_png_base64_url(sc) + content["image"] = new_base64_image + return + + +def _format_tabs(obs): + """Format the open tabs in a llm-readable way.""" + prompt_pieces = ["Currently open tabs:"] + for page_index, (page_url, page_title) in enumerate( + zip(obs["open_pages_urls"], obs["open_pages_titles"]) + ): + active_or_not = " (active tab)" if page_index == obs["active_page_index"] else "" + prompt_piece = f"""\ +Tab {page_index}{active_or_not}: + Title: {page_title} + URL: {page_url} +""" + prompt_pieces.append(prompt_piece) + return "\n".join(prompt_pieces) + + +@dataclass +class GeneralHints(Block): + use_hints: bool = True + + def apply(self, llm, discussion: StructuredDiscussion) -> dict: + if not self.use_hints: + return + + hints = [] + + hints.append( + """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n""" + ) + # simulated a hint. + # hints.append( + # """Remember to submit the form once all the fields are filled out.\n""" + # ) + + discussion.append(llm.msg.user().add_text("\n".join(hints))) + + +@dataclass +class Summarizer(Block): + """Block to summarize the last action and the current state of the environment.""" + + do_summary: bool = False + high_details: bool = True + + def apply(self, llm, discussion: StructuredDiscussion) -> dict: + if not self.do_summary: + + return + + msg = llm.msg.user().add_text("""Summarize\n""") + + discussion.append(msg) + + summary_response = llm(APIPayload(messages=discussion.flatten())) + + summary_msg = llm.msg.assistant().add_text(summary_response.think) + discussion.append(summary_msg) + discussion.set_last_summary(summary_msg) + return summary_msg + + def apply_init(self, llm, discussion: StructuredDiscussion) -> dict: + """Initialize the summarizer block.""" + if not self.do_summary: + return + + system_msg = llm.msg.system() + if self.high_details: + # Add a system message to the LLM to indicate that it should summarize + system_msg.add_text( + """# Summarizer instructions:\nWhen asked to summarize, do the following: +1) Summarize the effect of the last action, with attention to details. +2) Give a semantic description of the current state of the environment, with attention to details. If there was a repeating mistake, mention the cause of it. +3) Reason about the overall task at a high level. +4) What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none. +5) Reason about the next action to take, based on the current state and the goal. +""" + ) + else: + system_msg.add_text( + """When asked to summarize, give a semantic description of the current state of the environment.""" + ) + discussion.append(system_msg) + + +@dataclass +class TaskHint(Block): + use_task_hint: bool = True + hint_db_rel_path: str = "hint_db.csv" + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" + top_n: int = 4 # Number of top hints to return when using embedding retrieval + embedder_model: str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints + embedder_server: str = "http://localhost:5000" + skip_hints_for_current_task: bool = False + skip_hints_for_current_goal: bool = False + + def _init(self): + """Initialize the block.""" + if self.use_task_hint: + self.hints_source = HintsSource( + hint_db_path=self.hint_db_rel_path, + hint_retrieval_mode=self.hint_retrieval_mode, + top_n=self.top_n, + embedder_model=self.embedder_model, + embedder_server=self.embedder_server, + skip_hints_for_current_task=self.skip_hints_for_current_task, + skip_hints_for_current_goal=self.skip_hints_for_current_goal, + ) + + def apply(self, llm, discussion: StructuredDiscussion, obs: dict, task_name: str) -> dict: + if not self.use_task_hint: + return {} + + try: + goal_text = obs["goal_object"][0]["text"] + except (KeyError, IndexError): + Warning("Goal text not found in observation") + goal_text = "" + task_hints = self.hints_source.choose_hints(llm, task_name, goal_text) + + hints = [] + for hint in task_hints: + hint = hint.strip() + if hint: + hints.append(f"- {hint}") + + if len(hints) > 0: + hints_str = ( + "\n# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(hints) + ) + msg = llm.msg.user().add_text(hints_str) + + discussion.append(msg) + + +@dataclass +class PromptConfig: + tag_screenshot: bool = True # Whether to tag the screenshot with the last action. + goal: Goal = None + obs: Obs = None + summarizer: Summarizer = None + general_hints: GeneralHints = None + task_hint: TaskHint = None + keep_last_n_obs: int = 1 + multiaction: bool = False + action_subsets: tuple[str] = None + use_noop_as_default_action: bool = False + use_generalized_bgym_action_tool: bool = True + + +@dataclass +class ToolUseAgentArgs(AgentArgs): + model_args: BaseModelArgs = None + config: PromptConfig = None + use_raw_page_output: bool = False # This attribute is used in loop.py to setup the env. + action_set: bgym.AbstractActionSet | None = None + + def __post_init__(self): + try: + self.agent_name = f"CUAv2-{self.model_args.model_name}".replace("/", "_") + if self.config.task_hint.use_task_hint: + if self.config.task_hint.hint_retrieval_mode == "direct": + self.agent_name += f"-direct-hint" + if self.config.task_hint.hint_retrieval_mode == "emb": + self.agent_name += f"-emb-hint" + if self.config.task_hint.hint_retrieval_mode == "llm": + self.agent_name += f"-llm-hint" + + except AttributeError: + pass + + def make_agent(self) -> bgym.Agent: + if self.config is None: + self.config = DEFAULT_PROMPT_CONFIG + return ToolUseAgent( + model_args=self.model_args, # type: ignore + config=self.config, + action_set=self.action_set, + ) + + def prepare(self): + return self.model_args.prepare_server() + + def close(self): + return self.model_args.close_server() + + def set_benchmark(self, benchmark: AgentLabBenchmark | BgymBenchmark, demo_mode: bool): + """Set benchmark specific flags.""" + benchmark_name = benchmark.name + if benchmark_name == "osworld": + self.config.obs.skip_preprocessing = True + + self.config.obs.use_tabs = benchmark.is_multi_tab + benchmark_action_set = ( + deepcopy(benchmark.high_level_action_set_args).make_action_set().action_set + ) + # these actions are added based on the benchmark action set + if "send_msg_to_user" in benchmark_action_set: + self.config.action_subsets += ("chat",) + if "report_infeasible" in benchmark_action_set: + self.config.action_subsets += ("infeas",) + + +class ToolUseAgent(bgym.Agent): + def __init__( + self, + model_args: LiteLLMModelArgs, + config: PromptConfig = None, + action_set: bgym.AbstractActionSet | None = None, + ): + self.model_args = model_args + self.config = config + self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet( + self.config.action_subsets, + multiaction=self.config.multiaction, # type: ignore + ) + if self.config.use_generalized_bgym_action_tool: + self.tools = [simple_bgym_action_tool] + else: + self.tools = self.action_set.to_tool_description(api=model_args.api) + + self.call_ids = [] + + self.llm = model_args.make_model() + self.msg_builder = model_args.get_message_builder() + self.llm.msg = self.msg_builder + + self.task_hint = self.config.task_hint.make() + self.obs_block = self.config.obs.make() + + self.discussion = StructuredDiscussion(self.config.keep_last_n_obs) + self.last_response: LLMOutput = LLMOutput() + self._responses: list[LLMOutput] = [] + + def obs_preprocessor(self, obs): + obs = copy(obs) + if self.config.obs.skip_preprocessing: + return obs + page = obs.pop("page", None) + if page is not None: + obs["screenshot"] = extract_screenshot(page) + else: + if self.config.obs.use_dom: + obs["dom_txt"] = flatten_dom_to_str( + obs["dom_object"], + extra_properties=obs["extra_element_properties"], + ) + obs["pruned_html"] = prune_html(obs["dom_txt"]) + + if self.config.obs.use_axtree: + obs["axtree_txt"] = flatten_axtree_to_str( + obs["axtree_object"], + extra_properties=obs["extra_element_properties"], + ) + + if self.config.obs.use_som: + obs["screenshot_som"] = overlay_som( + obs["screenshot"], extra_properties=obs["extra_element_properties"] + ) + if self.config.obs.use_zoomed_webpage: + pass + + return obs + + def set_task_name(self, task_name: str): + """Cheater function that is supposed to be called by loop.py before callling get_action""" + self.task_name = task_name + + @cost_tracker_decorator + def get_action(self, obs: Any) -> float: + self.llm.reset_stats() + if not self.discussion.is_goal_set(): + self.discussion.new_group("goal") + + if self.config.multiaction: + sys_msg = SYS_MSG + "\nYou can take multiple actions in a single step, if needed." + else: + sys_msg = SYS_MSG + "\nYou can only take one action at a time." + + sys_msg += ( + "\nAvailable browsergym actions that can be returned with get_action:\n" + + self.action_set.describe() + ) + sys_msg += ADDITIONAL_ACTION_INSTRUCTIONS + self.config.goal.apply(self.llm, self.discussion, obs, sys_msg) + + self.config.summarizer.apply_init(self.llm, self.discussion) + self.config.general_hints.apply(self.llm, self.discussion) + self.task_hint.apply(self.llm, self.discussion, obs, self.task_name) + + self.discussion.new_group() + + self.obs_block.apply(self.llm, self.discussion, obs, last_llm_output=self.last_response) + + self.config.summarizer.apply(self.llm, self.discussion) + + messages = self.discussion.flatten() + response: LLMOutput = self.llm( + APIPayload( + messages=messages, + tools=self.tools, + tool_choice="any", + cache_tool_definition=True, + cache_complete_prompt=False, + use_cache_breakpoints=True, + ) + ) + + if self.config.use_generalized_bgym_action_tool: + action, think = action_from_generalized_bgym_action_tool(response) + else: + action = response.action + think = response.think + + if action is None and self.config.use_noop_as_default_action: + action = "noop()" + + last_summary = self.discussion.get_last_summary() + if last_summary is not None: + think = last_summary.content[0]["text"] + "\n" + think + else: + # Add the think to the history when use_summarizer is False + if think is not None: + self.discussion.append(self.llm.msg.assistant().add_text(think)) + + self.discussion.new_group() + + self.last_response = response + self._responses.append(response) # may be useful for debugging + + tools_str = json.dumps(self.tools, indent=2) + tools_msg = MessageBuilder("tool_description").add_text(tools_str) + + # Adding these extra messages to visualize in gradio + messages.insert(0, tools_msg) # insert at the beginning of the message + # This avoids the assertion error with self.llm.user().add_responded_tool_calls(tool_calls) + msg = self.llm.msg("tool") + msg.responded_tool_calls = response.tool_calls + messages.append(msg) + + agent_info = bgym.AgentInfo( + think=think, + chat_messages=messages, + stats=self.llm.stats.stats_dict, + ) + return action, agent_info + + +CUA_PROMPT_CONFIG = PromptConfig( + tag_screenshot=True, + goal=Goal(goal_as_system_msg=True), + obs=Obs( + use_last_error=True, + use_screenshot=True, + use_axtree=False, + use_dom=False, + use_som=False, + use_tabs=False, + overlay_mouse_action=True, + ), + summarizer=Summarizer(do_summary=False), + general_hints=GeneralHints(use_hints=False), + task_hint=TaskHint(use_task_hint=False), + action_subsets=("coord",), + keep_last_n_obs=5, # no more than 20 screenshots for claude + multiaction=True, + use_noop_as_default_action=False, + use_generalized_bgym_action_tool=True, +) + + +def get_cua_like_agent_config_with_hint( + model_name: str, + hint_db_path: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", +) -> ToolUseAgentArgs: + config = deepcopy(CUA_PROMPT_CONFIG) + config.task_hint.use_task_hint = True + config.task_hint.hint_db_rel_path = hint_db_path + config.task_hint.hint_retrieval_mode = hint_retrieval_mode + return ToolUseAgentArgs( + model_args=LiteLLMModelArgs( + model_name=model_name, + max_new_tokens=2000, + temperature=None, # NONE for claude-4-5 to enable reasoning effort. + ), + config=config, + ) + + +def get_cua_like_agent_config_with_hint_skip_for_current_goal( + model_name: str, + hint_db_path: str, + hint_retrieval_mode: Literal["llm", "emb"] = "llm", +) -> ToolUseAgentArgs: + config = deepcopy(CUA_PROMPT_CONFIG) + config.task_hint.use_task_hint = True + config.task_hint.skip_hints_for_current_goal = True + config.task_hint.hint_db_rel_path = hint_db_path + config.task_hint.hint_retrieval_mode = hint_retrieval_mode + return ToolUseAgentArgs( + model_args=LiteLLMModelArgs( + model_name=model_name, + max_new_tokens=2000, + temperature=None, # NONE for claude-4-5 to enable reasoning effort. + ), + config=config, + ) + + +def get_cua_like_agent_config(model_name: str) -> ToolUseAgentArgs: + + return ToolUseAgentArgs( + model_args=LiteLLMModelArgs( + model_name=model_name, + max_new_tokens=2000, + temperature=None, + ), + config=CUA_PROMPT_CONFIG, + ) + + +CUA_LIKE_CLAUDE_4_SONNET = get_cua_like_agent_config("anthropic/claude-sonnet-4-20250514") + + +if __name__ == "__main__": + + from agentlab.agents.tool_use_agent.cua_like_agent import CUA_LIKE_CLAUDE_4_SONNET + from agentlab.experiments.study import Study + import bgym + import logging + + logging.getLogger().setLevel(logging.INFO) + os.environ["LITELLM_LOG"] = "WARNING" + + benchmark = "workarena_l1" + benchmark = bgym.DEFAULT_BENCHMARKS[benchmark](n_repeats=2) + benchmark = benchmark.subset_from_glob("task_name", "*create*") + for env_args in benchmark.env_args_list: + env_args.max_steps = 20 # increase the number of steps for coord agent testing + + agent_args = [CUA_LIKE_CLAUDE_4_SONNET] + study = Study(agent_args, benchmark, logging_level_stdout=logging.WARNING) + study.run( + n_jobs=5, + parallel_backend="ray", + strict_reproducibility=False, + n_relaunch=1, + ) diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 062d8ef3..a8c4a9e7 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -325,12 +325,17 @@ def _init(self): embedder_server=self.embedder_server, ) - def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: + def apply(self, llm, discussion: StructuredDiscussion, obs: dict, task_name: str) -> dict: if not self.use_task_hint: return {} - goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content]) - task_hints = self.hints_source.choose_hints(llm, task_name, goal) + # goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content]) + try: + goal_text = obs["goal_object"][0]["text"] + except (KeyError, IndexError): + Warning("Goal text not found in observation") + goal_text = "" + task_hints = self.hints_source.choose_hints(llm, task_name, goal_text) hints = [] for hint in task_hints: @@ -472,7 +477,7 @@ def get_action(self, obs: Any) -> float: self.config.summarizer.apply_init(self.llm, self.discussion) self.config.general_hints.apply(self.llm, self.discussion) - self.task_hint.apply(self.llm, self.discussion, self.task_name) + self.task_hint.apply(self.llm, self.discussion, obs=obs, task_name=self.task_name) self.discussion.new_group() diff --git a/src/agentlab/utils/hinting.py b/src/agentlab/utils/hinting.py index 506513d5..83e55efc 100644 --- a/src/agentlab/utils/hinting.py +++ b/src/agentlab/utils/hinting.py @@ -52,7 +52,6 @@ def __init__( self.hint_db_path, header=0, index_col=None, - dtype=str, converters={ "trace_paths_json": lambda x: json.loads(x) if pd.notna(x) else [], "source_trace_goals": lambda x: json.loads(x) if pd.notna(x) else [], diff --git a/uv.lock b/uv.lock index eaf9483c..fbfd4be8 100644 --- a/uv.lock +++ b/uv.lock @@ -32,43 +32,39 @@ dependencies = [ { name = "pyyaml" }, { name = "ray", extra = ["default"] }, { name = "requests" }, - { name = "safetensors" }, { name = "tiktoken" }, { name = "torch" }, - { name = "transformers" }, ] [package.optional-dependencies] +hint = [ + { name = "sentence-transformers" }, +] +transformers = [ + { name = "transformers" }, +] + +[package.dev-dependencies] dev = [ { name = "black", extra = ["jupyter"] }, { name = "blacken-docs" }, + { name = "darglint" }, { name = "flaky" }, + { name = "ipykernel" }, + { name = "pip" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-playwright" }, { name = "pytest-xdist" }, ] -[package.dev-dependencies] -dev = [ - { name = "darglint" }, - { name = "ipykernel" }, - { name = "pip" }, -] -hint = [ - { name = "sentence-transformers" }, -] - [package.metadata] requires-dist = [ { name = "anthropic", specifier = ">=0.62.0" }, - { name = "black", extras = ["jupyter"], marker = "extra == 'dev'", specifier = ">=24.2.0" }, - { name = "blacken-docs", marker = "extra == 'dev'" }, { name = "browsergym", specifier = ">=0.7.1" }, { name = "contexttimer" }, { name = "dask" }, { name = "distributed" }, - { name = "flaky", marker = "extra == 'dev'" }, { name = "gitpython" }, { name = "gradio", specifier = ">=5.5" }, { name = "gymnasium", specifier = ">=0.27" }, @@ -80,30 +76,32 @@ requires-dist = [ { name = "openai", specifier = ">=1.7,<2" }, { name = "pandas" }, { name = "pillow" }, - { name = "pre-commit", marker = "extra == 'dev'" }, { name = "pydantic", specifier = "~=2.9" }, - { name = "pytest", marker = "extra == 'dev'", specifier = "==7.3.2" }, - { name = "pytest-playwright", marker = "extra == 'dev'" }, - { name = "pytest-xdist", marker = "extra == 'dev'" }, { name = "python-dotenv", specifier = ">=1.1.1" }, { name = "python-slugify" }, { name = "pyyaml", specifier = ">=6" }, { name = "ray", extras = ["default"] }, { name = "requests" }, - { name = "safetensors", specifier = ">=0.4.0" }, + { name = "sentence-transformers", marker = "extra == 'hint'", specifier = ">=5.0.0" }, { name = "tiktoken" }, { name = "torch", specifier = ">=2.2.2" }, - { name = "transformers", specifier = ">=4.38.2" }, + { name = "transformers", marker = "extra == 'transformers'", specifier = ">=4.38.2" }, ] -provides-extras = ["dev"] +provides-extras = ["hint", "transformers"] [package.metadata.requires-dev] dev = [ + { name = "black", extras = ["jupyter"], specifier = ">=24.2.0" }, + { name = "blacken-docs" }, { name = "darglint", specifier = ">=1.8.1" }, + { name = "flaky" }, { name = "ipykernel", specifier = ">=6.30.1" }, { name = "pip", specifier = ">=25.2" }, + { name = "pre-commit" }, + { name = "pytest", specifier = "==7.3.2" }, + { name = "pytest-playwright" }, + { name = "pytest-xdist" }, ] -hint = [{ name = "sentence-transformers", specifier = ">=5.0.0" }] [[package]] name = "aiofiles" From 3b4acd0278dccfda84e060c98bea5ec987d6fa57 Mon Sep 17 00:00:00 2001 From: Nicolas Gontier Date: Fri, 12 Dec 2025 21:03:01 +0000 Subject: [PATCH 8/8] add pricing info for AnthropicChatModel --- src/agentlab/llm/chat_api.py | 93 +++++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 9c3b021b..9478d320 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -331,33 +331,32 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float res_think["log_probs"] = completion.choices[0].logprobs return res_think, res_action else: - return [ - self._build_think_action_pair(choice) - for choice in completion.choices - ] + return [self._build_think_action_pair(choice) for choice in completion.choices] - def _extract_thinking_content_from_response(self, response, wrap_tag="think") -> tuple[str, str]: + def _extract_thinking_content_from_response( + self, response, wrap_tag="think" + ) -> tuple[str, str]: """Extract reasoning and action content from an API response. - + Logic: - 1. If reasoning_content exists: use it as think, use content as action + 1. If reasoning_content exists: use it as think, use content as action (remove BEGIN/END FINAL RESPONSE tokens if present, add action tags) 2. If reasoning_content is empty: search content for last BEGIN/END FINAL RESPONSE block, use everything before as think, use content inside tags as action - + Args: response: The API response object. wrap_tag: Tag name to wrap reasoning content (default: "think"). - + Returns: tuple: (think_wrapped, action_wrapped) """ message = response.choices[0].message - msg_dict = message.to_dict() if hasattr(message, 'to_dict') else dict(message) - + msg_dict = message.to_dict() if hasattr(message, "to_dict") else dict(message) + reasoning = msg_dict.get("reasoning_content") or msg_dict.get("reasoning") or "" content = msg_dict.get("content", "") or msg_dict.get("text", "") or "" - + # Case 1: Explicit reasoning field from API if reasoning: think_wrapped = f"<{wrap_tag}>{reasoning}" @@ -365,14 +364,14 @@ def _extract_thinking_content_from_response(self, response, wrap_tag="think") -> action_text = self._remove_final_response_tokens(content) action_wrapped = f"{action_text}" return think_wrapped, action_wrapped - + # Case 2: No reasoning field - parse content for BEGIN/END FINAL RESPONSE if "[BEGIN FINAL RESPONSE]" in content and "[END FINAL RESPONSE]" in content: think_text, action_text = self._parse_apriel_format(content) think_wrapped = f"<{wrap_tag}>{think_text}" if think_text else "" action_wrapped = f"{action_text}" if action_text else "" return think_wrapped, action_wrapped - + # Case 3: No special format - return content as action return "", f"{content}" if content else "" @@ -383,7 +382,7 @@ def _remove_final_response_tokens(self, content: str) -> str: def _extract_last_action_from_tags(self, content: str) -> str: """Extract content from the LAST [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] block.""" - pattern = r'\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]' + pattern = r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]" matches = re.findall(pattern, content, re.DOTALL) return matches[-1].strip() if matches else "" @@ -392,20 +391,18 @@ def _parse_apriel_format(self, content: str) -> tuple[str, str]: last_begin = content.rfind("[BEGIN FINAL RESPONSE]") if last_begin == -1: return "", content - + reasoning = content[:last_begin].strip() if reasoning.startswith("Here are my reasoning steps:"): - reasoning = reasoning[len("Here are my reasoning steps:"):].strip() - + reasoning = reasoning[len("Here are my reasoning steps:") :].strip() + action = self._extract_last_action_from_tags(content) return reasoning, action def _build_think_action_pair(self, choice) -> tuple[AIMessage, AIMessage]: """Build (think, action) pair from a single choice.""" # Create minimal response-like object for the extraction method - mock_response = type('MockResponse', (), { - 'choices': [choice] - })() + mock_response = type("MockResponse", (), {"choices": [choice]})() think, action = self._extract_thinking_content_from_response(mock_response) return AIMessage(think or ""), AIMessage(action or "") @@ -575,12 +572,9 @@ def __init__( max_retry=4, min_retry_wait_time=60, ): - base_url = base_url or os.getenv( - "APRIEL_API_URL", - "" - ) + base_url = base_url or os.getenv("APRIEL_API_URL", "") api_key = api_key or os.getenv("APRIEL_API_KEY") - + super().__init__( model_name=model_name, api_key=api_key, @@ -597,7 +591,7 @@ def __init__( @dataclass class AprielModelArgs(BaseModelArgs): """Serializable args for Apriel models.""" - + base_url: str = None api_key: str = None @@ -619,6 +613,7 @@ def __init__( temperature=0.5, max_tokens=100, max_retry=4, + pricing_func=None, ): self.model_name = model_name self.temperature = temperature @@ -628,6 +623,22 @@ def __init__( api_key = api_key or os.getenv("ANTHROPIC_API_KEY") self.client = anthropic.Anthropic(api_key=api_key) + # Get pricing information + if pricing_func: + pricings = pricing_func() + try: + self.input_cost = float(pricings[model_name]["prompt"]) + self.output_cost = float(pricings[model_name]["completion"]) + except KeyError: + logging.warning( + f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community." + ) + self.input_cost = 0.0 + self.output_cost = 0.0 + else: + self.input_cost = 0.0 + self.output_cost = 0.0 + def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict: # Convert OpenAI format to Anthropic format system_message = None @@ -655,13 +666,28 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float response = self.client.messages.create(**kwargs) + usage = getattr(response, "usage", {}) + new_input_tokens = getattr(usage, "input_tokens", 0) + output_tokens = getattr(usage, "output_tokens", 0) + cache_read_tokens = getattr(usage, "cache_input_tokens", 0) + cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0) + cache_read_cost = ( + self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"] + ) + cache_write_cost = ( + self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"] + ) + cost = ( + new_input_tokens * self.input_cost + + output_tokens * self.output_cost + + cache_read_tokens * cache_read_cost + + cache_write_tokens * cache_write_cost + ) # Track usage if available - if hasattr(tracking.TRACKER, "instance"): - tracking.TRACKER.instance( - response.usage.input_tokens, - response.usage.output_tokens, - 0, # cost calculation would need pricing info - ) + if hasattr(tracking.TRACKER, "instance") and isinstance( + tracking.TRACKER.instance, tracking.LLMTracker + ): + tracking.TRACKER.instance(new_input_tokens, output_tokens, cost) return AIMessage(response.content[0].text) @@ -679,6 +705,7 @@ def make_model(self): model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, + pricing_func=partial(tracking.get_pricing_litellm, model_name=self.model_name), )