diff --git a/.gitignore b/.gitignore
index 9507f985..3ecb59e6 100644
Binary files a/.gitignore and b/.gitignore differ
diff --git a/SETUP_GUIDE.md b/SETUP_GUIDE.md
new file mode 100644
index 00000000..753ef6f5
--- /dev/null
+++ b/SETUP_GUIDE.md
@@ -0,0 +1,333 @@
+# AgentBench Setup Guide (Azure OpenAI + WSL2)
+
+This guide walks you through setting up and running AgentBench with Azure OpenAI on Windows using WSL2.
+
+## Prerequisites
+- Windows 10/11 with WSL2 enabled
+- Ubuntu installed as your WSL2 distro (these instructions assume Ubuntu)
+- Docker Desktop for Windows with WSL2 integration enabled
+- An Azure OpenAI resource with a deployed model (e.g., `gpt-4o-mini`)
+
+---
+
+## Step 1: Enable WSL2 Integration in Docker Desktop
+
+1. Open **Docker Desktop**
+2. Go to **Settings** → **Resources** → **WSL Integration**
+3. Enable integration with your WSL2 distro (e.g., Ubuntu)
+4. Click **Apply & Restart**
+
+---
+
+## Step 2: Clone the Repository
+
+```bash
+# In WSL2 terminal
+cd ~
+git clone https://github.com/Jay-Dev01/AgentBench.git
+cd AgentBench
+git checkout ubuntu-azure-setup
+```
+
+---
+
+## Step 3: Set Up Python Environment
+
+```bash
+# Install Python 3.11 if not available
+sudo apt update
+sudo apt install -y python3.11 python3.11-venv python3.11-dev
+
+# Create virtual environment
+python3.11 -m venv venv
+source venv/bin/activate
+
+# Install dependencies
+pip install --upgrade pip
+pip install -r requirements.txt
+```
+
+---
+
+## Step 4: Configure Azure OpenAI API Key
+
+Set your Azure OpenAI API key as an environment variable:
+
+```bash
+export AZURE_OPENAI_API_KEY="your-azure-api-key-here"
+```
+
+To make it persistent, add it to your `~/.bashrc`:
+
+```bash
+echo 'export AZURE_OPENAI_API_KEY="your-azure-api-key-here"' >> ~/.bashrc
+source ~/.bashrc
+```
+
+### Finding Your Azure OpenAI Credentials
+
+1. Go to [Azure Portal](https://portal.azure.com)
+2. Navigate to your **Azure OpenAI resource**
+3. Click **Keys and Endpoint**
+4. Copy **Key 1** or **Key 2**
+
+The configuration file (`configs/agents/openai-chat.yaml`) is already set up to use:
+- **Endpoint**: `https://algoverse-ab.openai.azure.com/`
+- **Deployment**: `gpt-4o-mini`
+- **API Version**: `2024-08-01-preview`
+
+If your Azure resource is different, update the URL in `configs/agents/openai-chat.yaml`.
+
+---
+
+## Step 5: Start Docker Services
+
+```bash
+cd ~/AgentBench/extra
+
+# Start the controller, redis, and alfworld worker
+docker compose up -d controller redis alfworld-std
+
+# Wait for services to initialize (~30-60 seconds)
+
+# Verify services are running
+docker compose ps
+
+# Check that the worker registered
+curl http://localhost:5020/api/list_workers
+```
+
+You should see output showing `alfworld-std` with workers registered.
+
+### Verify Direct Worker Access
+
+```bash
+curl http://localhost:5021/api/get_sessions
+```
+
+This should return `[]` or a list of sessions.
+
+---
+
+## Step 6: Run the Benchmark
+
+```bash
+cd ~/AgentBench
+source venv/bin/activate
+
+# Make sure API key is set
+echo $AZURE_OPENAI_API_KEY
+
+# Run the assigner
+python -m src.assigner
+```
+
+### Expected Output
+
+```
+TaskClient created: alfworld-std (http://localhost:5020/api)
+  -> Using direct worker address: http://localhost:5021/api
+Message: 109 samples remaining.
+Agent "gpt-4o-mini" needs to run 1 tasks with total 109 samples:
+  Task "alfworld-std": 109
+Running Count: 0
+Assigned gpt-4o-mini/alfworld-std#108
+...
+```
+
+The benchmark will run through 109 ALFWorld tasks. Results are saved to the `outputs/` directory.
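+
+### Optional: Sanity-Check the Azure Endpoint
+
+If the assigner fails immediately with agent errors, you can test the Azure OpenAI deployment outside AgentBench. The snippet below is a minimal sketch that mirrors the request defined in `configs/agents/openai-chat.yaml` (same endpoint, deployment, API version, and `api-key` header); adjust the URL if your resource differs:
+
+```python
+import os
+import requests
+
+# Endpoint, deployment, and API version copied from configs/agents/openai-chat.yaml
+url = ("https://algoverse-ab.openai.azure.com/openai/deployments/"
+       "gpt-4o-mini/chat/completions?api-version=2024-08-01-preview")
+headers = {
+    "Content-Type": "application/json",
+    "api-key": os.environ["AZURE_OPENAI_API_KEY"],
+}
+body = {"temperature": 0, "messages": [{"role": "user", "content": "Say 'ready'."}]}
+
+resp = requests.post(url, headers=headers, json=body, timeout=30)
+resp.raise_for_status()  # a non-200 here means a key, endpoint, or quota problem
+print(resp.json()["choices"][0]["message"]["content"])
+```
+
+Typically a `401` means a bad key, a `404` a wrong deployment name or API version, and a `429` rate limiting (see Troubleshooting below).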
+
+---
+
+## Troubleshooting
+
+### Rate Limit Errors
+
+Concurrency is already set to 1 in `configs/assignments/default.yaml` to minimize rate limiting. If you still see `RateLimitReached` errors, you can:
+
+1. Wait and retry (the error message tells you how long)
+2. Increase your Azure quota at [aka.ms/oai/quotaincrease](https://aka.ms/oai/quotaincrease)
+
+### Connection Refused
+
+If you get connection errors:
+
+```bash
+# Check Docker services are running
+docker compose ps
+
+# Check controller logs
+docker logs agentrl-controller --tail 50
+
+# Check worker logs
+docker logs agentbench-fc-alfworld-std-1 --tail 50
+
+# Restart services
+docker compose down
+docker compose up -d controller redis alfworld-std
+```
+
+### Worker Not Registering
+
+If `curl http://localhost:5020/api/list_workers` shows empty workers:
+
+```bash
+# Check worker logs for errors
+docker logs agentbench-fc-alfworld-std-1 --tail 100
+
+# Rebuild and restart
+docker compose down
+docker compose build alfworld-std
+docker compose up -d controller redis alfworld-std
+```
+
+---
+
+## Configuration Files
+
+| File | Purpose |
+|------|---------|
+| `configs/agents/openai-chat.yaml` | Azure OpenAI endpoint and API key |
+| `configs/agents/api_agents.yaml` | Agent definitions (gpt-4o-mini) |
+| `configs/assignments/default.yaml` | Task assignments and concurrency |
+| `configs/assignments/definition.yaml` | Controller address (port 5020) |
+| `extra/docker-compose.yml` | Docker service definitions |
+
+---
+
+## Architecture Overview
+
+```
+┌─────────────────┐      ┌──────────────────┐      ┌─────────────────┐
+│     Python      │      │    Controller    │      │    ALFWorld     │
+│    Assigner     │─────▶│   (port 5020)    │─────▶│     Worker      │
+│                 │      │                  │      │   (port 5021)   │
+└─────────────────┘      └──────────────────┘      └─────────────────┘
+         │                                                  │
+         │    (direct communication - bypasses controller)  │
+         └──────────────────────────────────────────────────┘
+         │
+         ▼
+┌─────────────────┐
+│  Azure OpenAI   │
+│  (gpt-4o-mini)  │
+└─────────────────┘
+```
+
+**Note:** The Python client talks directly to the worker (port 5021) because the controller has a bug that prevents proper `/interact` forwarding.
+
+---
+
+## Stopping Services
+
+```bash
+cd ~/AgentBench/extra
+docker compose down
+```
+
+---
+
+## Running Different Tasks
+
+All tasks are pre-configured. To switch between tasks:
+
+### Option 1: Use the run script
+
+```bash
+chmod +x run_task.sh
+./run_task.sh alfworld-std    # or dbbench-std, os-std, kg-std, webshop-std
+```
+
+### Option 2: Manual setup
+
+#### 1. Edit `configs/assignments/default.yaml`
+
+Uncomment the task you want to run:
+
+```yaml
+  task:
+    # - alfworld-std   # House-holding tasks
+    - dbbench-std      # Database tasks (uncomment this one)
+    # - os-std         # OS interaction tasks
+    # - kg-std         # Knowledge graph tasks
+    # - webshop-std    # Web shopping tasks
+```
+
+#### 2. Start the Docker service
+
+```bash
+cd ~/AgentBench/extra
+
+# For alfworld (house-holding)
+docker compose up -d controller redis alfworld-std
+
+# For dbbench (database)
+docker compose up -d controller redis dbbench-std
+
+# For os-std (OS interaction) - requires building images first
+docker compose up -d controller redis os_interaction-std
+
+# For kg-std (knowledge graph) - requires freebase data
+docker compose up -d controller redis knowledgegraph-std freebase
+
+# For webshop (web shopping) - requires ~16GB RAM
+docker compose up -d controller redis webshop-std
+```
+
+#### 3. 
Run the assigner + +```bash +cd ~/AgentBench +source venv/bin/activate +python -m src.assigner +``` + +--- + +## Task-Specific Requirements + +### OS Interaction (os-std) + +Build the required Docker images first: + +```bash +cd ~/AgentBench +docker build -t local-os/default -f data/os_interaction/res/dockerfiles/default data/os_interaction/res/dockerfiles +docker build -t local-os/packages -f data/os_interaction/res/dockerfiles/packages data/os_interaction/res/dockerfiles +docker build -t local-os/ubuntu -f data/os_interaction/res/dockerfiles/ubuntu data/os_interaction/res/dockerfiles +``` + +### Knowledge Graph (kg-std) + +Requires Freebase data: + +1. Download data from [Freebase-Setup](https://github.com/dki-lab/Freebase-Setup) +2. Extract and place at `./extra/virtuoso_db/virtuoso.db` +3. Start with: `docker compose up -d controller redis knowledgegraph-std freebase` + +### WebShop (webshop-std) + +- Requires ~16GB RAM +- Takes ~3 minutes to start +- Start with: `docker compose up -d controller redis webshop-std` + +--- + +## Port Mapping Reference + +| Task | Host Port | Worker Port | +|------|-----------|-------------| +| Controller | 5020 | 5020 | +| alfworld-std | 5021 | 5021 | +| dbbench-std | 5022 | 5021 | +| os-std | 5023 | 5021 | +| kg-std | 5024 | 5021 | +| webshop-std | 5025 | 5021 | + +--- + +## License + +Apache-2.0 - See [LICENSE](LICENSE) for details. + diff --git a/configs/agents/api_agents.yaml b/configs/agents/api_agents.yaml index 30a46f4d..4e5c44ac 100644 --- a/configs/agents/api_agents.yaml +++ b/configs/agents/api_agents.yaml @@ -1,9 +1,8 @@ -gpt-3.5-turbo-0613: +gpt-4o-mini: import: "./openai-chat.yaml" parameters: - name: "gpt-3.5-turbo-0613" + name: "gpt-4o-mini" body: - model: "gpt-3.5-turbo-0613" max_tokens: 512 text-davinci-003: diff --git a/configs/agents/openai-chat.yaml b/configs/agents/openai-chat.yaml index 53eff77d..86c1eca5 100644 --- a/configs/agents/openai-chat.yaml +++ b/configs/agents/openai-chat.yaml @@ -1,13 +1,11 @@ module: src.client.agents.HTTPAgent parameters: - url: https://api.openai.com/v1/chat/completions + url: https://algoverse-ab.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview headers: Content-Type: application/json - Authorization: Bearer <% PUT-YOUR-OPENAI-KEY-HERE %> + api-key: ${AZURE_OPENAI_API_KEY} body: temperature: 0 prompter: - name: role_content_dict - args: - agent_role: assistant - return_format: "{response[choices][0][message][content]}" + name: openai_passthrough + return_format: openai_chat diff --git a/configs/assignments/default.yaml b/configs/assignments/default.yaml index aec7229a..ced9e912 100644 --- a/configs/assignments/default.yaml +++ b/configs/assignments/default.yaml @@ -2,16 +2,23 @@ import: definition.yaml concurrency: task: - dbbench-std: 5 - os-std: 5 + alfworld-std: 1 + dbbench-std: 1 + os-std: 1 + kg-std: 1 + webshop-std: 1 agent: - gpt-3.5-turbo-0613: 5 + gpt-4o-mini: 1 assignments: # List[Assignment] | Assignment - - agent: # "task": List[str] | str , "agent": List[str] | str - - gpt-3.5-turbo-0613 + - agent: + - gpt-4o-mini task: - - dbbench-std - - os-std + # ===== UNCOMMENT THE TASK(S) YOU WANT TO RUN ===== + - alfworld-std # House-holding tasks (ALFWorld) + # - dbbench-std # Database tasks + # - os-std # OS interaction tasks + # - kg-std # Knowledge graph tasks (requires freebase) + # - webshop-std # Web shopping tasks (requires ~16GB RAM) output: "outputs/{TIMESTAMP}" diff --git a/configs/assignments/definition.yaml 
b/configs/assignments/definition.yaml
index 6b00cac0..58a5a66c 100644
--- a/configs/assignments/definition.yaml
+++ b/configs/assignments/definition.yaml
@@ -3,7 +3,7 @@ definition:
     overwrite:
       module: src.client.TaskClient
       parameters:
-        controller_address: "http://localhost:5000/api"
+        controller_address: "http://localhost:5020/api"
     import: ../tasks/task_assembly.yaml
   agent:
     import:
diff --git a/extra/docker-compose.yml b/extra/docker-compose.yml
index e9399e81..9947b43e 100644
--- a/extra/docker-compose.yml
+++ b/extra/docker-compose.yml
@@ -5,7 +5,8 @@ services:
   controller:
     image: jingbh/agentrl-controller:latest
     container_name: agentrl-controller
-    network_mode: host
+    ports:
+      - "5020:5020"
     command:
       - controller
@@ -13,7 +14,9 @@
     build:
       context: ..
       dockerfile: src/server/tasks/alfworld/Dockerfile
-    command: --controller http://172.17.0.1:5020/api alfworld-std
+    command: --controller http://controller:5020/api alfworld-std
+    ports:
+      - "5021:5021"
     deploy:
       mode: replicated
       replicas: 1
@@ -24,7 +27,9 @@
     build:
       context: ..
       dockerfile: src/server/tasks/dbbench/Dockerfile
-    command: --controller http://172.17.0.1:5020/api dbbench-std
+    command: --controller http://controller:5020/api dbbench-std
+    ports:
+      - "5022:5021"
     volumes:
       - /var/run/docker.sock:/var/run/docker.sock
     environment:
@@ -39,7 +44,9 @@
     build:
       context: ..
       dockerfile: src/server/tasks/knowledgegraph/Dockerfile
-    command: --controller http://172.17.0.1:5020/api kg-std
+    command: --controller http://controller:5020/api kg-std
+    ports:
+      - "5024:5021"
     environment:
      - KG_STD_PARAMETERS_ENV_OPTIONS_URLS_KG=http://freebase:3001/sparql
     deploy:
@@ -53,7 +60,9 @@
     build:
       context: ..
       dockerfile: src/server/tasks/os_interaction/Dockerfile
-    command: --controller http://172.17.0.1:5020/api os-std
+    command: --controller http://controller:5020/api os-std
+    ports:
+      - "5023:5021"
     volumes:
       - /var/run/docker.sock:/var/run/docker.sock
     environment:
@@ -68,7 +77,9 @@
     build:
       context: ..
       dockerfile: src/server/tasks/webshop/Dockerfile
-    command: --controller http://172.17.0.1:5020/api webshop-std
+    command: --controller http://controller:5020/api webshop-std
+    ports:
+      - "5025:5021"
     deploy:
       mode: replicated
       replicas: 1
@@ -86,4 +97,5 @@
   redis:
     image: redis:7
     container_name: redis
-    network_mode: host
+    ports:
+      - "6379:6379"
diff --git a/run_task.sh b/run_task.sh
new file mode 100644
index 00000000..55e7737c
--- /dev/null
+++ b/run_task.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+# AgentBench Task Runner
+# Usage: ./run_task.sh [task]
+# Example: ./run_task.sh alfworld-std
+
+TASK=${1:-alfworld-std}
+
+echo "=========================================="
+echo "AgentBench Task Runner"
+echo "=========================================="
+echo "Task: $TASK"
+echo ""
+
+# Map task names to docker service names
+case $TASK in
+    alfworld-std|alfworld)
+        SERVICE="alfworld-std"
+        TASK_NAME="alfworld-std"
+        ;;
+    dbbench-std|dbbench|db)
+        SERVICE="dbbench-std"
+        TASK_NAME="dbbench-std"
+        ;;
+    os-std|os|os_interaction)
+        SERVICE="os_interaction-std"
+        TASK_NAME="os-std"
+        ;;
+    kg-std|kg|knowledgegraph)
+        SERVICE="knowledgegraph-std"
+        TASK_NAME="kg-std"
+        echo "WARNING: kg-std requires freebase data. See README for setup."
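+        # Freebase download/placement steps are also in SETUP_GUIDE.md under "Knowledge Graph (kg-std)"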
+ ;; + webshop-std|webshop) + SERVICE="webshop-std" + TASK_NAME="webshop-std" + echo "WARNING: webshop requires ~16GB RAM" + ;; + *) + echo "Unknown task: $TASK" + echo "" + echo "Available tasks:" + echo " alfworld-std - House-holding tasks (ALFWorld)" + echo " dbbench-std - Database tasks" + echo " os-std - OS interaction tasks" + echo " kg-std - Knowledge graph tasks (requires freebase)" + echo " webshop-std - Web shopping tasks (requires ~16GB RAM)" + exit 1 + ;; +esac + +# Get script directory +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +echo "Step 1: Updating config to use $TASK_NAME..." +# Update the default.yaml to use the selected task +sed -i "s/^ - alfworld-std/ # - alfworld-std/" configs/assignments/default.yaml +sed -i "s/^ - dbbench-std/ # - dbbench-std/" configs/assignments/default.yaml +sed -i "s/^ - os-std/ # - os-std/" configs/assignments/default.yaml +sed -i "s/^ - kg-std/ # - kg-std/" configs/assignments/default.yaml +sed -i "s/^ - webshop-std/ # - webshop-std/" configs/assignments/default.yaml +sed -i "s/^ # - $TASK_NAME/ - $TASK_NAME/" configs/assignments/default.yaml + +echo "Step 2: Starting Docker services..." +cd extra +docker compose up -d controller redis $SERVICE + +echo "" +echo "Step 3: Waiting for services to start (15 seconds)..." +sleep 15 + +echo "" +echo "Step 4: Checking worker registration..." +curl -s http://localhost:5020/api/list_workers | python3 -m json.tool 2>/dev/null || curl -s http://localhost:5020/api/list_workers + +echo "" +echo "==========================================" +echo "Ready to run! Execute:" +echo " cd $SCRIPT_DIR" +echo " source venv/bin/activate" +echo " python -m src.assigner" +echo "==========================================" + diff --git a/src/client/agents/http_agent.py b/src/client/agents/http_agent.py index 677f668c..37cb0bb8 100644 --- a/src/client/agents/http_agent.py +++ b/src/client/agents/http_agent.py @@ -95,6 +95,24 @@ def prompter(messages: List[Dict[str, str]]): return prompter + @staticmethod + def openai_passthrough(): + """Pass through OpenAI-format messages and tools directly.""" + def prompter(data): + # If it's already in OpenAI format (dict with messages/tools), pass through + if isinstance(data, dict) and "messages" in data: + result = {"messages": data["messages"]} + if "tools" in data: + result["tools"] = data["tools"] + return result + # Otherwise, convert from old format + prompt = [] + for item in data: + role = "assistant" if item.get("role") == "agent" else item.get("role", "user") + prompt.append({"role": role, "content": item.get("content", "")}) + return {"messages": prompt} + return prompter + @staticmethod def prompt_string( prefix: str = "", @@ -185,7 +203,7 @@ def __init__( def _handle_history(self, history: List[dict]) -> Dict[str, Any]: return self.prompter(history) - def inference(self, history: List[dict]) -> str: + def inference(self, history) -> str: for _ in range(3): try: body = self.body.copy() @@ -210,6 +228,23 @@ def inference(self, history: List[dict]) -> str: pass else: resp = resp.json() + # Handle OpenAI chat completion format + if self.return_format == "openai_chat": + if "choices" in resp and resp["choices"]: + message = resp["choices"][0].get("message", {}) + # Check for tool calls first + if "tool_calls" in message and message["tool_calls"]: + tool_call = message["tool_calls"][0] + # Return the action from the tool call arguments + import json + try: + args = json.loads(tool_call["function"]["arguments"]) + return args.get("action", 
"") + except: + return tool_call["function"]["arguments"] + # Otherwise return content + return message.get("content", "") + return "" return self.return_format.format(response=resp) time.sleep(_ + 2) raise Exception("Failed.") diff --git a/src/client/task.py b/src/client/task.py index ff6969ba..a305eb61 100644 --- a/src/client/task.py +++ b/src/client/task.py @@ -1,4 +1,6 @@ import enum +import json +import random import requests @@ -15,13 +17,28 @@ class TaskError(enum.Enum): NOT_AVAILABLE = "NOT_AVAILABLE" +# Direct worker address mapping (bypassing buggy controller) +# Add port mappings in docker-compose.yml for each worker you want to use +WORKER_ADDRESSES = { + "alfworld-std": "http://localhost:5021/api", + "dbbench-std": "http://localhost:5022/api", + "os-std": "http://localhost:5023/api", + "kg-std": "http://localhost:5024/api", + "webshop-std": "http://localhost:5025/api", +} + + class TaskClient: def __init__( self, name: str, controller_address: str = "http://localhost:5000/api", *_, **__, ) -> None: self.name = name self.controller_address = controller_address + # Use direct worker address if available + self.worker_address = WORKER_ADDRESSES.get(name) print("TaskClient created: {} ({})".format(name, controller_address)) + if self.worker_address: + print(f" -> Using direct worker address: {self.worker_address}") def get_indices(self) -> List[SampleIndex]: result = requests.get( @@ -47,82 +64,234 @@ def get_concurrency(self) -> int: return 0 concurrency = 0 for worker in result[self.name]["workers"].values(): - if worker["status"] == WorkerStatus.ALIVE: + # API returns status as string "ALIVE", not the IntEnum value + if worker["status"] == "ALIVE": concurrency += worker["capacity"] - worker["current"] return concurrency def run_sample(self, index: SampleIndex, agent: AgentClient) -> TaskClientOutput: + # Use direct worker communication if available (bypasses buggy controller) + if self.worker_address: + return self._run_sample_direct(index, agent) + else: + return self._run_sample_via_controller(index, agent) + + def _run_sample_direct(self, index: SampleIndex, agent: AgentClient) -> TaskClientOutput: + """Run sample by talking directly to the worker (new OpenAI-style API).""" + # Generate a unique session ID + sid = random.randint(10000, 99999) + + # Track conversation history for output + history = [] + conversation_messages = [] # Full conversation for agent + try: - result = requests.post( - self.controller_address + "/start_sample", - json=StartSampleRequest(name=self.name, index=index).dict(), + # Start sample on worker directly + response = requests.post( + self.worker_address + "/start_sample", + json={"index": index, "session_id": sid}, + timeout=60 ) except Exception as e: return TaskClientOutput(error=TaskError.NETWORK_ERROR.value, info=str(e)) - if result.status_code == 406: + + if response.status_code != 200: return TaskClientOutput( - error=TaskError.NOT_AVAILABLE.value, info=result.text + error=TaskError.START_FAILED.value, info=response.text ) - if result.status_code != 200: - return TaskClientOutput( - error=TaskError.START_FAILED.value, info=result.text - ) - result = result.json() - sid = result["session_id"] - latest_result = result - while SampleStatus(result["output"]["status"]) == SampleStatus.RUNNING: + + result = response.json() + + # Extract messages and tools from response + messages = result.get("messages", []) + tools = result.get("tools", []) + conversation_messages = messages.copy() + + # Record initial messages in history + for msg in messages: + 
content = msg.get("content", "") + if msg["role"] in ("system", "user"): + history.append(ChatHistoryItem(role="user", content=content or "")) + + max_turns = 50 + turn = 0 + + while result.get("status") == "running" and not result.get("finish", False) and turn < max_turns: + turn += 1 + try: - content = agent.inference(result["output"]["history"]) - response = AgentOutput(content=content) + # Prepare input for agent - full conversation with tools + agent_input = { + "messages": conversation_messages, + "tools": tools + } + content = agent.inference(agent_input) except AgentContextLimitException: - response = AgentOutput(status=AgentOutputStatus.AGENT_CONTEXT_LIMIT) + # Cancel and return + try: + requests.post(self.worker_address + "/cancel", json={"session_id": sid}, timeout=10) + except: + pass + return TaskClientOutput( + output=TaskOutput(status=SampleStatus.AGENT_CONTEXT_LIMIT, history=history) + ) except Exception as e: - if hasattr(agent, "model_name"): - model_name = agent.model_name - elif hasattr(agent, "name"): - model_name = agent.name - else: - model_name = agent.__class__.__name__ + model_name = getattr(agent, "model_name", None) or getattr(agent, "name", None) or agent.__class__.__name__ print(f"ERROR: {model_name}/{self.name} agent error", e) - requests.post( - self.controller_address + "/cancel", - json=CancelRequest(session_id=sid).dict(), - ) + try: + requests.post(self.worker_address + "/cancel", json={"session_id": sid}, timeout=10) + except: + pass return TaskClientOutput( error=TaskError.AGENT_FAILED.value, info=str(e), - output=latest_result, + output=TaskOutput(status=SampleStatus.TASK_ERROR, history=history), ) - + + # Record agent response in history + history.append(ChatHistoryItem(role="agent", content=content or "")) + + # Build assistant message with tool call for the worker + # The worker expects OpenAI-style tool calls + tool_call_id = f"call_{turn}" + assistant_message = { + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": "take_action", + "arguments": json.dumps({"action": content}) + } + }] + } + + # Add to conversation + conversation_messages.append(assistant_message) + try: - result = requests.post( - self.controller_address + "/interact", - json=InteractRequest( - session_id=sid, - agent_response=response, - ).dict(), + # Send interact request to worker + response = requests.post( + self.worker_address + "/interact", + json={ + "session_id": sid, + "messages": [assistant_message] + }, + timeout=60 ) except Exception as e: return TaskClientOutput( error=TaskError.NETWORK_ERROR.value, info=str(e), - output=latest_result, - ) - if result.status_code != 200: - requests.post( - self.controller_address + "/cancel", - json=CancelRequest(session_id=sid).dict(), + output=TaskOutput(status=SampleStatus.RUNNING, history=history), ) + + if response.status_code != 200: return TaskClientOutput( error=TaskError.INTERACT_FAILED.value, - info=result.text, - output=latest_result, + info=response.text, + output=TaskOutput(status=SampleStatus.RUNNING, history=history), ) - - result = result.json() - latest_result = result - # TODO: check this type and check where history is - return TaskClientOutput(output=result["output"]) + + result = response.json() + + # Extract environment output + env_out = result.get("env_out", result) + new_messages = env_out.get("messages", []) + + # Add tool response to conversation + for msg in new_messages: + conversation_messages.append(msg) + content = 
msg.get("content", "") + if content: + history.append(ChatHistoryItem(role="user", content=content)) + + # Update status from env_out + result["status"] = env_out.get("status", "running") + result["finish"] = env_out.get("finish", False) + + # Determine final status + if result.get("finish", False) or result.get("status") == "completed": + final_status = SampleStatus.COMPLETED + elif turn >= max_turns: + final_status = SampleStatus.TASK_LIMIT_REACHED + else: + final_status = SampleStatus.COMPLETED + + return TaskClientOutput(output=TaskOutput( + status=final_status, + history=history, + result=result.get("metric", {}).get("score", 0) + )) + + def _run_sample_via_controller(self, index: SampleIndex, agent: AgentClient) -> TaskClientOutput: + """Original method - run sample via controller (has bugs with current controller).""" + try: + response = requests.post( + self.controller_address + "/start_sample", + json=StartSampleRequest(name=self.name, index=index).dict(), + ) + except Exception as e: + return TaskClientOutput(error=TaskError.NETWORK_ERROR.value, info=str(e)) + if response.status_code == 406: + return TaskClientOutput( + error=TaskError.NOT_AVAILABLE.value, info=response.text + ) + if response.status_code != 200: + return TaskClientOutput( + error=TaskError.START_FAILED.value, info=response.text + ) + + result = response.json() + sid = response.headers.get("Session_id") or response.headers.get("session_id") or result.get("session_id") + if sid is None: + return TaskClientOutput( + error=TaskError.START_FAILED.value, + info=f"No session_id in response" + ) + sid = int(sid) + + history = [] + max_turns = 50 + turn = 0 + + while turn < max_turns: + turn += 1 + if "output" in result and "status" in result["output"]: + if SampleStatus(result["output"]["status"]) != SampleStatus.RUNNING: + break + elif "status" in result: + if result["status"] not in ("running", "RUNNING"): + break + + try: + agent_input = result.get("messages", result.get("output", {}).get("history", [])) + content = agent.inference(agent_input) + except AgentContextLimitException: + return TaskClientOutput(output=TaskOutput(status=SampleStatus.AGENT_CONTEXT_LIMIT, history=history)) + except Exception as e: + return TaskClientOutput(error=TaskError.AGENT_FAILED.value, info=str(e), output=TaskOutput(status=SampleStatus.TASK_ERROR, history=history)) + + history.append(ChatHistoryItem(role="agent", content=content or "")) + + try: + response = requests.post( + self.controller_address + "/interact", + json={"session_id": sid, "agent_response": {"content": content, "status": "normal"}}, + headers={"Session_id": str(sid)} + ) + except Exception as e: + return TaskClientOutput(error=TaskError.NETWORK_ERROR.value, info=str(e), output=TaskOutput(status=SampleStatus.RUNNING, history=history)) + + if response.status_code != 200: + return TaskClientOutput(error=TaskError.INTERACT_FAILED.value, info=response.text, output=TaskOutput(status=SampleStatus.RUNNING, history=history)) + + result = response.json() + + if "output" in result: + return TaskClientOutput(output=result["output"]) + return TaskClientOutput(output=TaskOutput(status=SampleStatus.COMPLETED, history=history, result=result.get("result"))) def calculate_overall(self, results: List[TaskOutput]) -> JSONSerializable: statistics = {s: 0 for s in SampleStatus} diff --git a/src/configs.py b/src/configs.py index 1125887d..c8ec6557 100644 --- a/src/configs.py +++ b/src/configs.py @@ -1,11 +1,31 @@ import json import os +import re from copy import deepcopy from typing import 
Any, Dict, Set import yaml +def substitute_env_vars(value): + """Substitute ${VAR_NAME} patterns with environment variables.""" + if isinstance(value, str): + pattern = r'\$\{([^}]+)\}' + def replacer(match): + var_name = match.group(1) + env_value = os.environ.get(var_name) + if env_value is None: + print(f"Warning: Environment variable {var_name} not set") + return match.group(0) + return env_value + return re.sub(pattern, replacer, value) + elif isinstance(value, dict): + return {k: substitute_env_vars(v) for k, v in value.items()} + elif isinstance(value, list): + return [substitute_env_vars(v) for v in value] + return value + + def deep_merge(base_item, new_item): if isinstance(base_item, dict) and isinstance(new_item, dict): ret = deepcopy(base_item) @@ -51,7 +71,8 @@ def load_from(self, path) -> Dict: raise e self.loading.remove(path) self.loaded[path] = config - return self.parse_default_and_overwrite(deepcopy(config)) + result = self.parse_default_and_overwrite(deepcopy(config)) + return substitute_env_vars(result) def parse_imports(self, path, raw_config): raw_config = deepcopy(raw_config)
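
---

The `substitute_env_vars` helper added to `src/configs.py` above is what lets `configs/agents/openai-chat.yaml` reference `${AZURE_OPENAI_API_KEY}` at config-load time. The standalone sketch below (illustrative only, not part of the patch; warning print omitted) replicates the helper's recursion to show what substitution does and does not touch:

```python
import os
import re

def substitute_env_vars(value):
    # Same recursion as the src/configs.py helper: strings are scanned for
    # ${VAR}, dicts and lists are walked, everything else passes through.
    if isinstance(value, str):
        return re.sub(r"\$\{([^}]+)\}",
                      lambda m: os.environ.get(m.group(1), m.group(0)), value)
    if isinstance(value, dict):
        return {k: substitute_env_vars(v) for k, v in value.items()}
    if isinstance(value, list):
        return [substitute_env_vars(v) for v in value]
    return value

os.environ["AZURE_OPENAI_API_KEY"] = "test-key"
config = {
    "headers": {"api-key": "${AZURE_OPENAI_API_KEY}"},  # substituted
    "unset": "${NOT_A_REAL_VAR}",  # left as-is (real helper also prints a warning)
    "port": 5020,                  # non-strings untouched
}
print(substitute_env_vars(config))
# {'headers': {'api-key': 'test-key'}, 'unset': '${NOT_A_REAL_VAR}', 'port': 5020}
```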