diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 35f3579..ac9e2b8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: jobs: - test: + lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -24,8 +24,27 @@ jobs: - name: Install dependencies run: uv sync --locked --group dev - - name: Lint - run: uv run ruff check . + - name: Run Ruff Linter + run: uv run ruff check --output-format=github . + + - name: Run Ruff Formatter Check + run: uv run ruff format --check --diff . + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up uv and Python + uses: astral-sh/setup-uv@v6 + with: + python-version: "3.11.9" + enable-cache: true + cache-dependency-glob: uv.lock + activate-environment: true + + - name: Install dependencies + run: uv sync --locked --group dev - name: Test run: uv run pytest diff --git a/README.md b/README.md index 9372872..7f11788 100644 --- a/README.md +++ b/README.md @@ -720,6 +720,16 @@ client.agents.jobs.delete_data(job_id) | AML Investigation | `AMLInvestigationEngine` | | Fraud Investigation | `FraudInvestigationEngine` | +## Development + +Before opening a PR, format and lint the codebase by running: + +```bash +./roe-cli format +``` + +CI runs the same checks (`ruff check` and `ruff format --check`) on every pull request and on merges to `main`, and they must pass before a PR can be merged. + ## Links - [Roe](https://www.roe-ai.com/) diff --git a/examples/create_aml_agent.py b/examples/create_aml_agent.py index 7b81025..e6d852a 100644 --- a/examples/create_aml_agent.py +++ b/examples/create_aml_agent.py @@ -36,7 +36,9 @@ def main(): "description": "Funds transferred through multiple accounts to obscure origin", "flag": "RED_FLAG", "sub_rules": [ - {"title": "Cross-border wire transfers with no business purpose"}, + { + "title": "Cross-border wire transfers with no business purpose" + }, {"title": "Shell company intermediaries"}, ], }, @@ -47,9 +49,18 @@ def main(): "instructions": "Investigate the alert by analyzing transaction patterns, account relationships, and customer profile. Use available data sources to corroborate or refute each finding.", "dispositions": { "classifications": [ - {"name": "Suspicious", "description": "Activity warrants SAR filing"}, - {"name": "Not Suspicious", "description": "Activity has legitimate business explanation"}, - {"name": "Needs Escalation", "description": "Requires senior BSA analyst review"}, + { + "name": "Suspicious", + "description": "Activity warrants SAR filing", + }, + { + "name": "Not Suspicious", + "description": "Activity has legitimate business explanation", + }, + { + "name": "Needs Escalation", + "description": "Requires senior BSA analyst review", + }, ] }, "summary_template": { @@ -69,7 +80,11 @@ def main(): name="AML Investigation Agent", engine_class_id="AMLInvestigationEngine", input_definitions=[ - {"key": "alert_data", "data_type": "text/plain", "description": "Alert data and context for AML investigation"}, + { + "key": "alert_data", + "data_type": "text/plain", + "description": "Alert data and context for AML investigation", + }, ], engine_config={ "policy_version_id": str(policy.current_version_id), @@ -128,9 +143,18 @@ def main(): "instructions": "Investigate the alert by analyzing transaction patterns, account relationships, and customer profile. Pay special attention to potential smurfing networks.", "dispositions": { "classifications": [ - {"name": "Suspicious", "description": "Activity warrants SAR filing"}, - {"name": "Not Suspicious", "description": "Activity has legitimate business explanation"}, - {"name": "Needs Escalation", "description": "Requires senior BSA analyst review"}, + { + "name": "Suspicious", + "description": "Activity warrants SAR filing", + }, + { + "name": "Not Suspicious", + "description": "Activity has legitimate business explanation", + }, + { + "name": "Needs Escalation", + "description": "Requires senior BSA analyst review", + }, ] }, }, @@ -142,7 +166,11 @@ def main(): client.agents.versions.create( agent_id=str(agent.id), input_definitions=[ - {"key": "alert_data", "data_type": "text/plain", "description": "Alert data and context for AML investigation"}, + { + "key": "alert_data", + "data_type": "text/plain", + "description": "Alert data and context for AML investigation", + }, ], engine_config={ "policy_version_id": str(new_version.id), diff --git a/examples/create_product_compliance_agent.py b/examples/create_product_compliance_agent.py index 9e1dc62..f1eb3e1 100644 --- a/examples/create_product_compliance_agent.py +++ b/examples/create_product_compliance_agent.py @@ -63,9 +63,18 @@ def main(): "instructions": "Analyze the product listing against each category. Flag any rule violations with evidence from the listing text.", "dispositions": { "classifications": [ - {"name": "Compliant", "description": "Product meets all policy requirements"}, - {"name": "Non-Compliant", "description": "Product violates one or more policy rules"}, - {"name": "Needs Review", "description": "Insufficient information to determine compliance"}, + { + "name": "Compliant", + "description": "Product meets all policy requirements", + }, + { + "name": "Non-Compliant", + "description": "Product violates one or more policy rules", + }, + { + "name": "Needs Review", + "description": "Insufficient information to determine compliance", + }, ] }, }, @@ -77,7 +86,11 @@ def main(): name="Product Compliance Checker", engine_class_id="ProductPolicyEngine", input_definitions=[ - {"key": "product_listings", "data_type": "text/plain", "description": "Product listing to analyze"}, + { + "key": "product_listings", + "data_type": "text/plain", + "description": "Product listing to analyze", + }, ], engine_config={ "policy_version_id": str(policy.current_version_id), diff --git a/examples/manage_policies.py b/examples/manage_policies.py index 532b0bd..97706fb 100644 --- a/examples/manage_policies.py +++ b/examples/manage_policies.py @@ -49,8 +49,14 @@ def main(): "instructions": "Investigate the alert against each category. Gather evidence from available data sources. Cite specific transactions and patterns.", "dispositions": { "classifications": [ - {"name": "Fraudulent", "description": "Confirmed fraud indicators found"}, - {"name": "Legitimate", "description": "Activity has legitimate explanation"}, + { + "name": "Fraudulent", + "description": "Confirmed fraud indicators found", + }, + { + "name": "Legitimate", + "description": "Activity has legitimate explanation", + }, {"name": "Escalate", "description": "Needs senior analyst review"}, ] }, @@ -109,8 +115,14 @@ def main(): "instructions": "Investigate the alert against each category. Gather evidence from available data sources. Cite specific transactions and patterns. Pay special attention to geographic indicators.", "dispositions": { "classifications": [ - {"name": "Fraudulent", "description": "Confirmed fraud indicators found"}, - {"name": "Legitimate", "description": "Activity has legitimate explanation"}, + { + "name": "Fraudulent", + "description": "Confirmed fraud indicators found", + }, + { + "name": "Legitimate", + "description": "Activity has legitimate explanation", + }, {"name": "Escalate", "description": "Needs senior analyst review"}, ] }, diff --git a/pyproject.toml b/pyproject.toml index a71a86a..4f1dee9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dev = [ codegen = [ "openapi-python-client>=0.28,<0.29", "ruamel-yaml>=0.18.15", + "ruff>=0.12.10", ] [tool.pytest.ini_options] diff --git a/roe-cli b/roe-cli new file mode 100755 index 0000000..cc88d04 --- /dev/null +++ b/roe-cli @@ -0,0 +1,62 @@ +#!/bin/bash + +# roe-cli - Development CLI for the Roe Python SDK +# Usage: ./roe-cli [options] + +set -e + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) + +show_help() { + cat << EOF +roe-cli - Development CLI for the Roe Python SDK + +Usage: ./roe-cli + +Commands: + format Run code formatter and linter (Ruff) + help Show this help message + +Examples: + ./roe-cli format # Format and lint the codebase +EOF +} + +cmd_format() { + cd "$SCRIPT_DIR" + + echo "Running Ruff linter..." + uv run --group dev ruff check --fix . + + echo "Running Ruff formatter..." + uv run --group dev ruff format . + + echo "✨ Code is properly formatted and linted!" +} + +main() { + if [[ $# -eq 0 ]]; then + show_help + exit 1 + fi + + local command="$1" + shift + + case "$command" in + format) + cmd_format "$@" + ;; + help|--help|-h) + show_help + ;; + *) + echo "Unknown command: $command" >&2 + echo "" + show_help + exit 1 + ;; + esac +} + +main "$@" diff --git a/scripts/generate-sdk b/scripts/generate-sdk index 234013e..ff0734a 100755 --- a/scripts/generate-sdk +++ b/scripts/generate-sdk @@ -18,3 +18,7 @@ rm -rf src/roe/_generated mv "$TMP_DIR/generated" src/roe/_generated uv run --group codegen python scripts/generate_wrappers.py + +# Keep generated wrappers consistent with `./roe-cli format` so formatting the +# repo never introduces codegen drift. +uv run --group codegen ruff format src/roe/api diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 76a6999..13e9604 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -73,9 +73,7 @@ def _replace_marked_block( ) updated, count = pattern.subn(block, readme, count=1) if count != 1: - raise ValueError( - f"{README_PATH} must contain {start_marker} and {end_marker}" - ) + raise ValueError(f"{README_PATH} must contain {start_marker} and {end_marker}") return updated @@ -88,7 +86,9 @@ def _load_project_version() -> str: def _load_release_marker() -> str: - marker = (ROOT_DIR / ".roe-main-release-version").read_text(encoding="utf-8").strip() + marker = ( + (ROOT_DIR / ".roe-main-release-version").read_text(encoding="utf-8").strip() + ) if not marker: raise ValueError(".roe-main-release-version must not be empty") return marker diff --git a/src/roe/api/agents.py b/src/roe/api/agents.py index 77b1878..f2449c0 100644 --- a/src/roe/api/agents.py +++ b/src/roe/api/agents.py @@ -509,9 +509,7 @@ def run_many( if not is_first_chunk: time.sleep(self.config.batch_chunk_delay) is_first_chunk = False - body = AgentRunAsyncManyRequest( - inputs=[_build_aer(item) for item in chunk] - ) + body = AgentRunAsyncManyRequest(inputs=[_build_aer(item) for item in chunk]) if metadata is not None: body.additional_properties["metadata"] = metadata response = request_raw( diff --git a/src/roe/config.py b/src/roe/config.py index 205ea04..e857963 100644 --- a/src/roe/config.py +++ b/src/roe/config.py @@ -52,8 +52,10 @@ def from_env( base_url = base_url or os.getenv("ROE_BASE_URL", "https://api.roe-ai.com") timeout = timeout or float(os.getenv("ROE_TIMEOUT", "60.0")) max_retries = max_retries or int(os.getenv("ROE_MAX_RETRIES", "3")) - batch_chunk_delay = batch_chunk_delay if batch_chunk_delay is not None else float( - os.getenv("ROE_BATCH_CHUNK_DELAY", "10.0") + batch_chunk_delay = ( + batch_chunk_delay + if batch_chunk_delay is not None + else float(os.getenv("ROE_BATCH_CHUNK_DELAY", "10.0")) ) if not api_key: diff --git a/src/roe/exceptions.py b/src/roe/exceptions.py index 035755c..ef91446 100644 --- a/src/roe/exceptions.py +++ b/src/roe/exceptions.py @@ -136,4 +136,9 @@ def translate_response(response: Any) -> None: raw_headers = getattr(response, "headers", None) cls = get_exception_for_status_code(status_code) - raise cls(message=message, status_code=status_code, response=error_data, headers=raw_headers) + raise cls( + message=message, + status_code=status_code, + response=error_data, + headers=raw_headers, + ) diff --git a/src/roe/models/job.py b/src/roe/models/job.py index ccbf228..3d3fd80 100644 --- a/src/roe/models/job.py +++ b/src/roe/models/job.py @@ -117,7 +117,9 @@ def wait( while True: status = self.retrieve_status() error_message = ( - None if isinstance(status.error_message, Unset) else status.error_message + None + if isinstance(status.error_message, Unset) + else status.error_message ) if status.status in _TERMINAL_STATUSES: @@ -203,9 +205,7 @@ def wait( while len(self._completed_jobs) < len(self._job_ids): pending_job_ids = [ - job_id - for job_id in self._job_ids - if job_id not in self._completed_jobs + job_id for job_id in self._job_ids if job_id not in self._completed_jobs ] if not pending_job_ids: diff --git a/src/roe/utils/transport.py b/src/roe/utils/transport.py index 5982095..d18fce6 100644 --- a/src/roe/utils/transport.py +++ b/src/roe/utils/transport.py @@ -62,7 +62,10 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: time.sleep(wait_time) continue - if not _should_retry_status(response.status_code) or attempt >= self.max_retries: + if ( + not _should_retry_status(response.status_code) + or attempt >= self.max_retries + ): return response wait_time = min(2**attempt, 10) diff --git a/uv.lock b/uv.lock index c39f88e..c14b7c4 100644 --- a/uv.lock +++ b/uv.lock @@ -467,6 +467,7 @@ dependencies = [ codegen = [ { name = "openapi-python-client" }, { name = "ruamel-yaml" }, + { name = "ruff" }, ] dev = [ { name = "pytest" }, @@ -485,6 +486,7 @@ requires-dist = [ codegen = [ { name = "openapi-python-client", specifier = ">=0.28,<0.29" }, { name = "ruamel-yaml", specifier = ">=0.18.15" }, + { name = "ruff", specifier = ">=0.12.10" }, ] dev = [ { name = "pytest", specifier = ">=8.3.0" },