From 06af0ca49081e0f39c4ed428951cde5780ffab5f Mon Sep 17 00:00:00 2001 From: Jaden Fix Date: Thu, 11 Jun 2026 09:39:54 -0700 Subject: [PATCH] codegen: support body/manual operation kinds and nested namespaces MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Groundwork for migrating the hand-written facades (agents.py, policies.py, users.py) to contract-generated code: - kind `body`: any single HTTP call — path/query/body parameter locations, UUID coercion, pass_unset_when_none, or_empty list/dict defaults, inject_organization_id body insertion, return_shape list with typed error label, refetch_with_retrieve (create-then-retrieve), no-content returns. Conventions mirror the hand-written facades exactly (request_raw/from_dict for GETs, request_json/.parsed for mutations). - kind `manual`: never generated; codegen asserts the method exists in the hand-written facade and fails loudly listing all missing methods. - namespaces: generated nested classes proxy _raw/_org_id to the parent, wired via ctor + property, mirroring AgentVersionsAPI. - split mode: APIs with manual ops emit base classes to src/roe/api/_{api}_generated.py (public class stays hand-written); all-manual APIs (users) are parity-checked only; whole-file APIs (discovery, tables) are byte-identical to before. scripts/fixtures/candidate_wrappers.yml is the migration parity fixture: every signature derived from the hand-written facades; AST comparison of generated vs hand-written shows 22/22 body methods signature-identical. Production no-op: regenerating from the committed openapi/wrappers.yml produces zero diff. Co-Authored-By: Claude Fable 5 --- scripts/fixtures/candidate_wrappers.yml | 566 ++++++++++++++++++++++++ scripts/generate_wrappers.py | 498 ++++++++++++++++++++- tests/unit/test_generate_wrappers.py | 69 +++ 3 files changed, 1114 insertions(+), 19 deletions(-) create mode 100644 scripts/fixtures/candidate_wrappers.yml create mode 100644 tests/unit/test_generate_wrappers.py diff --git a/scripts/fixtures/candidate_wrappers.yml b/scripts/fixtures/candidate_wrappers.yml new file mode 100644 index 0000000..bd7fc9e --- /dev/null +++ b/scripts/fixtures/candidate_wrappers.yml @@ -0,0 +1,566 @@ +# MIGRATION PARITY FIXTURE — not consumed by scripts/generate-sdk in production. +# +# Candidate openapi/wrappers.yml for migrating the hand-written facades +# (src/roe/api/agents.py, policies.py, users.py) to contract-generated code. +# Every signature, annotation, default, and behavior below was derived from +# the hand-written facades as of this commit; swapping this file into +# openapi/wrappers.yml and running scripts/generate_wrappers.py must +# reproduce the existing public surface exactly. +# +# - APIs with at least one `manual` operation and at least one generated +# operation (agents, policies) emit base classes to +# src/roe/api/_{api}_generated.py; the hand-written facade keeps ownership +# of the public class and the manual methods (codegen asserts they exist). +# - All-manual APIs (users) are parity-checked only; nothing is emitted. +# - discovery/tables are verbatim copies of the committed openapi/wrappers.yml +# so the swapped-in contract regenerates the whole tree. + +apis: + discovery: + class_name: DiscoveryAPI + docstring: API for discovering valid agent engine types and model IDs. + operations: + - kind: simple + method_name: list_agent_engine_types + docstring: Return production engine_class_id values accepted by agent creation. + endpoint_module: roe._generated.api.discovery.discovery_agent_engine_types_list + return_type: AgentEngineTypeList + return_import: roe._generated.models.agent_engine_type_list.AgentEngineTypeList + empty_response_message: agent engine discovery returned an empty response + - kind: simple + method_name: list_supported_models + docstring: Return non-deprecated model IDs accepted in engine_config.model. + endpoint_module: roe._generated.api.discovery.discovery_supported_models_list + return_type: SupportedLLMModelList + return_import: roe._generated.models.supported_llm_model_list.SupportedLLMModelList + empty_response_message: model discovery returned an empty response + parameters: + - name: capability + annotation: str | None + default: null + pass_unset_when_none: true + tables: + class_name: TablesAPI + docstring: API for uploading CSV files into Roe tables. + operations: + - kind: table_upload + method_name: upload + docstring: Upload a CSV file and create a Roe table. + endpoint_module: roe._generated.api.tables.upload_table + return_type: TableUploadResponse + return_import: roe._generated.models.table_upload_response.TableUploadResponse + body_type: TableUploadRequest + body_import: roe._generated.models.table_upload_request.TableUploadRequest + empty_response_message: table upload returned an empty response + + agents: + class_name: AgentsAPI + docstring: API for managing and running agents. + operations: + - kind: body + method_name: list + docstring: '' + method: GET + path: /v1/agents/ + endpoint_module: roe._generated.api.agents.agents_list + return_type: PaginatedBaseAgentList + return_import: roe._generated.models.paginated_base_agent_list.PaginatedBaseAgentList + parameters: + - name: page + location: query + annotation: int | None + default: null + pass_unset_when_none: true + - name: page_size + location: query + annotation: int | None + default: null + pass_unset_when_none: true + - kind: body + method_name: retrieve + docstring: '' + method: GET + path: /v1/agents/{agent_id}/ + endpoint_module: roe._generated.api.agents.agents_retrieve + return_type: BaseAgent + return_import: roe._generated.models.base_agent.BaseAgent + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: create + docstring: '' + method: POST + path: /v1/agents/ + endpoint_module: roe._generated.api.agents.agents_create + return_type: BaseAgent + return_import: roe._generated.models.base_agent.BaseAgent + body_type: BaseAgentCreateRequest + body_import: roe._generated.models.base_agent_create_request.BaseAgentCreateRequest + inject_organization_id: true + parameters: + - name: name + location: body + annotation: str + - name: engine_class_id + location: body + annotation: str + - name: input_definitions + location: body + annotation: list[dict[str, Any]] | None + default: null + or_empty: list + - name: engine_config + location: body + annotation: dict[str, Any] | None + default: null + or_empty: dict + - name: version_name + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - name: description + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - kind: body + method_name: update + docstring: Update an agent via PATCH (partial update). + method: PATCH + path: /v1/agents/{agent_id}/ + endpoint_module: roe._generated.api.agents.agents_partial_update + return_type: BaseAgent + return_import: roe._generated.models.base_agent.BaseAgent + body_type: PatchedBaseAgentUpdateRequest + body_import: roe._generated.models.patched_base_agent_update_request.PatchedBaseAgentUpdateRequest + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - name: name + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - name: disable_cache + location: body + annotation: bool | None + default: null + pass_unset_when_none: true + - name: cache_failed_jobs + location: body + annotation: bool | None + default: null + pass_unset_when_none: true + - kind: body + method_name: delete + docstring: '' + method: DELETE + path: /v1/agents/{agent_id}/ + endpoint_module: roe._generated.api.agents.agents_destroy + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: duplicate + docstring: |- + Duplicate an agent. Returns the resulting ``AgentVersion``. + + The endpoint historically returned a JSON object with a ``base_agent`` + key wrapping the new agent; the generated client now models the + response as ``AgentVersion`` directly. Callers wanting the new base + agent should read ``result.base_agent`` (already populated). + method: POST + path: /v1/agents/{agent_id}/duplicate/ + endpoint_module: roe._generated.api.agents.agents_duplicate_create + return_type: AgentVersion + return_import: roe._generated.models.agent_version.AgentVersion + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + # Bespoke run flows: dynamic **inputs packing via _build_aer/call_dynamic, + # Job/JobBatch polling wrappers, and client-side batch chunking. + - kind: manual + method_name: run + docstring: Run an agent asynchronously and return a Job handle. + method: POST + path: /v1/agents/run/{agent_id}/async/ + - kind: manual + method_name: run_many + docstring: Run an agent over a batch of inputs and return a JobBatch handle. + method: POST + path: /v1/agents/run/{agent_id}/async/many/ + - kind: manual + method_name: run_sync + docstring: Run an agent synchronously and return its output data. + method: POST + path: /v1/agents/run/{agent_id}/ + - kind: manual + method_name: run_version + docstring: Run a specific agent version asynchronously and return a Job handle. + method: POST + path: /v1/agents/run/{agent_id}/versions/{agent_version_id}/async/ + - kind: manual + method_name: run_version_sync + docstring: Run a specific agent version synchronously and return its output data. + method: POST + path: /v1/agents/run/{agent_id}/versions/{agent_version_id}/ + namespaces: + versions: + class_name: AgentVersionsAPI + attr: versions + docstring: Nested API for agent version operations. + operations: + - kind: body + method_name: list + docstring: '' + method: GET + path: /v1/agents/{agent_id}/versions/ + endpoint_module: roe._generated.api.agents.agents_versions_list + return_type: AgentVersion + return_import: roe._generated.models.agent_version.AgentVersion + return_shape: list + list_error_label: agent versions + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: retrieve + docstring: '' + method: GET + path: /v1/agents/{agent_id}/versions/{agent_version_id}/ + endpoint_module: roe._generated.api.agents.agents_versions_retrieve + return_type: AgentVersion + return_import: roe._generated.models.agent_version.AgentVersion + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - name: version_id + location: path + annotation: str + coerce: uuid + - name: get_supports_eval + location: query + annotation: bool | None + default: null + pass_unset_when_none: true + - kind: body + method_name: retrieve_current + docstring: '' + method: GET + path: /v1/agents/{agent_id}/versions/current/ + endpoint_module: roe._generated.api.agents.agents_versions_current_retrieve + return_type: AgentVersion + return_import: roe._generated.models.agent_version.AgentVersion + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: create + docstring: '' + method: POST + path: /v1/agents/{agent_id}/versions/ + endpoint_module: roe._generated.api.agents.agents_versions_create + return_type: AgentVersion + return_import: roe._generated.models.agent_version.AgentVersion + body_type: AgentVersionCreateRequest + body_import: roe._generated.models.agent_version_create_request.AgentVersionCreateRequest + refetch_with_retrieve: true + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - name: input_definitions + location: body + annotation: list[dict[str, Any]] | None + default: null + or_empty: list + - name: engine_config + location: body + annotation: dict[str, Any] | None + default: null + or_empty: dict + - name: version_name + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - name: description + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - kind: body + method_name: update + docstring: Update an agent version via PATCH (partial update). + method: PATCH + path: /v1/agents/{agent_id}/versions/{agent_version_id}/ + endpoint_module: roe._generated.api.agents.agents_versions_partial_update + body_type: PatchedPatchedAgentVersionUpdateRequestRequest + body_import: roe._generated.models.patched_patched_agent_version_update_request_request.PatchedPatchedAgentVersionUpdateRequestRequest + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - name: version_id + location: path + annotation: str + coerce: uuid + - name: version_name + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - name: description + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - kind: body + method_name: delete + docstring: '' + method: DELETE + path: /v1/agents/{agent_id}/versions/{agent_version_id}/ + endpoint_module: roe._generated.api.agents.agents_versions_destroy + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - name: version_id + location: path + annotation: str + coerce: uuid + jobs: + class_name: AgentJobsAPI + attr: jobs + docstring: Nested API for agent job operations. + operations: + - kind: body + method_name: retrieve_status + docstring: '' + method: GET + path: /v1/agents/jobs/{job_id}/status/ + endpoint_module: roe._generated.api.agents.agents_jobs_status_retrieve + return_type: AgentJobStatus + return_import: roe._generated.models.agent_job_status.AgentJobStatus + parameters: + - name: job_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: retrieve_result + docstring: '' + method: GET + path: /v1/agents/jobs/{agent_job_id}/result/ + endpoint_module: roe._generated.api.agents.agents_jobs_result_retrieve + return_type: AgentJobResultResponse + return_import: roe._generated.models.agent_job_result_response.AgentJobResultResponse + parameters: + - name: job_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: cancel + docstring: '' + method: POST + path: /v1/agents/jobs/{job_id}/cancel/ + endpoint_module: roe._generated.api.agents.agents_jobs_cancel_create + parameters: + - name: job_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: cancel_all + docstring: '' + method: POST + path: /v1/agents/{agent_id}/jobs/cancel-all/ + endpoint_module: roe._generated.api.agents.agents_jobs_cancel_all_create + parameters: + - name: agent_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: delete_data + docstring: '' + method: POST + path: /v1/agents/jobs/{job_id}/delete-data/ + endpoint_module: roe._generated.api.agents.agents_jobs_delete_data_create + return_type: AgentJobDeleteDataResponse + return_import: roe._generated.models.agent_job_delete_data_response.AgentJobDeleteDataResponse + parameters: + - name: job_id + location: path + annotation: str + coerce: uuid + # Bespoke: 1000-id chunking with inter-chunk delay; results-shape + # normalization; raw httpx download with query injection. + - kind: manual + method_name: retrieve_status_many + docstring: Retrieve statuses for many jobs (chunked batch calls). + method: POST + path: /v1/agents/jobs/statuses/ + - kind: manual + method_name: retrieve_result_many + docstring: Retrieve results for many jobs (chunked batch calls). + method: POST + path: /v1/agents/jobs/results/ + - kind: manual + method_name: download_reference + docstring: Download a binary reference produced by an agent job. + method: GET + path: /v1/agents/jobs/{agent_job_id}/references/{resource_id}/ + + policies: + class_name: PoliciesAPI + docstring: API for managing policies used by agentic workflows. + operations: + - kind: body + method_name: list + docstring: List policies in the organization. + method: GET + path: /v1/policies/ + endpoint_module: roe._generated.api.policies.policies_list + return_type: PaginatedPolicyList + return_import: roe._generated.models.paginated_policy_list.PaginatedPolicyList + parameters: + - name: page + location: query + annotation: int | None + default: null + pass_unset_when_none: true + - name: page_size + location: query + annotation: int | None + default: null + pass_unset_when_none: true + - kind: body + method_name: retrieve + docstring: Retrieve a specific policy by ID. + method: GET + path: /v1/policies/{id}/ + endpoint_module: roe._generated.api.policies.policies_retrieve + return_type: Policy + return_import: roe._generated.models.policy.Policy + parameters: + - name: policy_id + location: path + annotation: str + coerce: uuid + - kind: body + method_name: create + docstring: Create a new policy with an initial version. + method: POST + path: /v1/policies/ + endpoint_module: roe._generated.api.policies.policies_create + return_type: CreatePolicy + return_import: roe._generated.models.create_policy.CreatePolicy + body_type: CreatePolicyRequest + body_import: roe._generated.models.create_policy_request.CreatePolicyRequest + parameters: + - name: name + location: body + annotation: str + - name: content + location: body + annotation: dict[str, Any] + - name: description + location: body + annotation: str + default: '' + - name: version_name + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - kind: body + method_name: update + docstring: Update a policy's metadata via PATCH (partial update). + method: PATCH + path: /v1/policies/{id}/ + endpoint_module: roe._generated.api.policies.policies_partial_update + return_type: Any + body_type: PatchedUpdatePolicyRequest + body_import: roe._generated.models.patched_update_policy_request.PatchedUpdatePolicyRequest + parameters: + - name: policy_id + location: path + annotation: str + coerce: uuid + - name: name + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - name: description + location: body + annotation: str | None + default: null + pass_unset_when_none: true + - kind: body + method_name: delete + docstring: Delete a policy and all its versions. + method: DELETE + path: /v1/policies/{id}/ + endpoint_module: roe._generated.api.policies.policies_destroy + parameters: + - name: policy_id + location: path + annotation: str + coerce: uuid + namespaces: + versions: + class_name: PolicyVersionsAPI + attr: versions + docstring: Nested API for policy version operations. + operations: + # Bespoke: every method routes through _parse_policy_version + # (base_version_id zero-UUID normalization); list custom-constructs + # PaginatedPolicyVersionList; create refetches via resp.parsed. + - kind: manual + method_name: list + docstring: List versions of a policy. + method: GET + path: /v1/policies/{policy_id}/versions/ + - kind: manual + method_name: retrieve + docstring: Retrieve a specific version of a policy. + method: GET + path: /v1/policies/{policy_id}/versions/{version_id}/ + - kind: manual + method_name: create + docstring: Create a new policy version (auto-set as current). + method: POST + path: /v1/policies/{policy_id}/versions/ + + users: + class_name: UsersAPI + docstring: API for retrieving information about the authenticated user. + operations: + # Bespoke: tolerant manual JSON parsing of an undocumented response body. + - kind: manual + method_name: me + docstring: Return the currently-authenticated user. + method: GET + path: /v1/users/current_user/ diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 76a6999..ec7e680 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -24,6 +24,12 @@ README_BLOCK_START = "" README_BLOCK_END = "" +# Suffix appended to contract class names when emitting base-class modules +# (split mode) for APIs that also contain hand-written ``manual`` operations. +GENERATED_CLASS_SUFFIX = "Generated" + +_ANY_RE = re.compile(r"\bAny\b") + HEADER = ( '"""Auto-generated friendly API facades for the Roe SDK."""\n' "\n" @@ -34,6 +40,20 @@ "\n" ) +BASE_HEADER = ( + '"""Auto-generated base classes for hand-written Roe SDK API facades."""\n' + "\n" + "# Generated by scripts/generate-sdk from openapi/wrappers.yml.\n" + "# Do not edit by hand.\n" + "\n" + "from __future__ import annotations\n" + "\n" +) + + +class ManualWrapperParityError(ValueError): + """Raised when a ``manual`` operation has no matching hand-written method.""" + def _load_contract() -> dict[str, Any]: yaml = YAML(typ="safe") @@ -146,6 +166,96 @@ def _default_expr(value: Any) -> str: return repr(value) +def _docstring_block(docstring: str, indent: str) -> str: + """Render a (possibly multi-line) docstring at the given indentation.""" + lines = docstring.splitlines() or [""] + if len(lines) == 1: + return f'{indent}"""{lines[0]}"""\n' + rendered = [f'{indent}"""{lines[0]}\n'] + for line in lines[1:]: + rendered.append(f"{indent}{line}\n" if line else "\n") + rendered.append(f'{indent}"""\n') + return "".join(rendered) + + +def _operations(spec: dict[str, Any]) -> list[dict[str, Any]]: + return spec.get("operations") or [] + + +def _namespaces(spec: dict[str, Any]) -> dict[str, dict[str, Any]]: + return spec.get("namespaces") or {} + + +def _body_kind_ops(spec: dict[str, Any]) -> list[dict[str, Any]]: + return [op for op in _operations(spec) if op.get("kind") == "body"] + + +def _spec_has_manual(spec: dict[str, Any]) -> bool: + if any(op.get("kind") == "manual" for op in _operations(spec)): + return True + return any( + op.get("kind") == "manual" + for ns_spec in _namespaces(spec).values() + for op in _operations(ns_spec) + ) + + +def _has_generated_ops(spec: dict[str, Any]) -> bool: + if any(op.get("kind") != "manual" for op in _operations(spec)): + return True + return any( + op.get("kind") != "manual" + for ns_spec in _namespaces(spec).values() + for op in _operations(ns_spec) + ) + + +def _collect_manual_method_names(spec: dict[str, Any]) -> list[str]: + names = [ + op["method_name"] for op in _operations(spec) if op.get("kind") == "manual" + ] + for ns_spec in _namespaces(spec).values(): + names.extend( + op["method_name"] + for op in _operations(ns_spec) + if op.get("kind") == "manual" + ) + return names + + +def _check_manual_parity(apis: dict[str, dict[str, Any]], api_dir: Path) -> None: + """Assert every ``manual`` operation exists in the hand-written facade. + + Manual operations are never generated; the hand-written module + ``src/roe/api/{api}.py`` must define each one. Collects every missing + method across all APIs and fails loudly with the full list. + """ + problems: list[str] = [] + for api_name, spec in sorted(apis.items()): + manual_names = _collect_manual_method_names(spec) + if not manual_names: + continue + facade = api_dir / f"{api_name}.py" + if not facade.exists(): + problems.append( + f"{api_name}: hand-written facade {facade} is missing " + f"(manual operations: {', '.join(manual_names)})" + ) + continue + text = facade.read_text(encoding="utf-8") + for name in manual_names: + if f"def {name}(" not in text: + problems.append( + f"{api_name}: manual operation {name!r} not found " + f"(expected `def {name}(` in {facade})" + ) + if problems: + raise ManualWrapperParityError( + "Manual wrapper parity check failed:\n" + + "\n".join(f" - {problem}" for problem in problems) + ) + + def _simple_method(operation: dict[str, Any]) -> str: method_name = operation["method_name"] endpoint_name = _module_import_parts(operation["endpoint_module"])[1] @@ -294,27 +404,276 @@ def _as_generated_file( ''' -def _render_api_module(api_name: str, spec: dict[str, Any]) -> str: - class_name = spec["class_name"] - operations = spec.get("operations") or [] +def _body_kind_method(operation: dict[str, Any]) -> str: + """Render a ``kind: body`` operation as a facade method. + + Conventions mirror the hand-written facades (agents.py / policies.py): + + - no ``body_type``: ``request_raw(self._raw, endpoint, *path_args, + **query, organization_id=self._org_id)``; ``ReturnType.from_dict`` + for object returns, a guarded list comprehension for + ``return_shape: list``, implicit ``None`` when no ``return_type``. + - with ``body_type``: build the request model from body-location + parameters and dispatch through ``request_json`` (forced-JSON body), + returning ``resp.parsed`` (or nothing when no ``return_type``). + - ``refetch_with_retrieve``: POST via ``request_raw`` with the body, + extract the created ``id`` and return ``self.retrieve(...)``. + """ + method_name = operation["method_name"] + endpoint_name = _module_import_parts(operation["endpoint_module"])[1] + params = operation.get("parameters") or [] + docstring = operation.get("docstring") or "" + return_type = operation.get("return_type") + return_shape = operation.get("return_shape", "object") + body_type = operation.get("body_type") + refetch = bool(operation.get("refetch_with_retrieve")) + empty_message = operation.get("empty_response_message") + + path_params = [p for p in params if p.get("location") == "path"] + query_params = [p for p in params if p.get("location") == "query"] + body_params = [p for p in params if p.get("location") == "body"] + + if return_shape not in ("object", "list"): + raise ValueError( + f"Unsupported return_shape {return_shape!r} in operation {method_name!r}" + ) + if return_shape == "list": + if return_type is None or body_type is not None: + raise ValueError( + f"return_shape: list requires a return_type and no body_type " + f"(operation {method_name!r})" + ) + return_annotation = f"list[{return_type}]" + elif return_type is not None: + return_annotation = return_type + else: + return_annotation = "None" + if refetch and (body_type is None or return_type is None): + raise ValueError( + f"refetch_with_retrieve requires body_type and return_type " + f"(operation {method_name!r})" + ) + + signature_parts = ["self"] + for param in params: + part = f"{param['name']}: {param['annotation']}" + if "default" in param: + part = f"{part} = {_default_expr(param['default'])}" + signature_parts.append(part) + if len(signature_parts) == 1: + lines = [f" def {method_name}(self) -> {return_annotation}:\n"] + else: + formatted_signature = ",\n".join(f" {part}" for part in signature_parts) + lines = [ + f" def {method_name}(\n{formatted_signature},\n" + f" ) -> {return_annotation}:\n" + ] + if docstring: + lines.append(_docstring_block(docstring, " ")) + + def _path_expr(param: dict[str, Any]) -> str: + if param.get("coerce") == "uuid": + return f"UUID(str({param['name']}))" + return param["name"] + + call_args = ["self._raw", endpoint_name] + call_args.extend(_path_expr(param) for param in path_params) + if body_type: + call_args.append("body=body") + for param in query_params: + wire = param.get("wire_name", param["name"]) + name = param["name"] + if param.get("pass_unset_when_none"): + call_args.append(f"{wire}={name} if {name} is not None else UNSET") + else: + call_args.append(f"{wire}={name}") + call_args.append("organization_id=self._org_id") + + def _request_call(prefix: str, func: str) -> str: + joined = ",\n".join(f" {arg}" for arg in call_args) + return f" {prefix}{func}(\n{joined},\n )\n" + + if body_type: + body_args: list[str] = [] + for param in body_params: + wire = param.get("wire_name", param["name"]) + name = param["name"] + or_empty = param.get("or_empty") + if or_empty == "list": + body_args.append(f"{wire}={name} or []") + elif or_empty == "dict": + body_args.append(f"{wire}={name} or {{}}") + elif or_empty is not None: + raise ValueError( + f"Unsupported or_empty {or_empty!r} in operation {method_name!r}" + ) + elif param.get("pass_unset_when_none"): + body_args.append(f"{wire}={name} if {name} is not None else UNSET") + else: + body_args.append(f"{wire}={name}") + if operation.get("inject_organization_id"): + required_count = sum(1 for p in body_params if "default" not in p) + body_args.insert(required_count, "organization_id=self._org_id") + joined_body = ",\n".join(f" {arg}" for arg in body_args) + lines.append(f" body = {body_type}(\n{joined_body},\n )\n") + + if refetch: + refetch_args = ", ".join(param["name"] for param in path_params) + lines.append(_request_call("response = ", "request_raw")) + lines.append( + " data = response.json()\n" + ' version_id = data.get("id") if isinstance(data, dict) else None\n' + " if version_id is None:\n" + " raise RoeAPIException(\n" + ' f"Unexpected response from server: status={response.status_code}"\n' + " )\n" + " # POST returns a partial create payload; re-fetch to get the full version.\n" + f" return self.retrieve({refetch_args}, str(version_id))\n" + ) + elif body_type: + if return_type is None: + lines.append(_request_call("", "request_json")) + else: + lines.append(_request_call("resp = ", "request_json")) + if empty_message: + lines.append( + " if resp.parsed is None:\n" + f' raise RoeAPIException("{empty_message}")\n' + " return resp.parsed\n" + ) + elif return_type == "Any": + lines.append(" return resp.parsed\n") + else: + lines.append( + " return resp.parsed # type: ignore[return-value]\n" + ) + elif return_shape == "list": + label = operation["list_error_label"] + lines.append(_request_call("response = ", "request_raw")) + lines.append( + " data = response.json()\n" + " if not isinstance(data, list):\n" + " raise RoeAPIException(\n" + f' f"{label} returned unexpected response shape: {{data!r}}"\n' + " )\n" + f" return [{return_type}.from_dict(item) for item in data]\n" + ) + elif return_type is not None: + lines.append(_request_call("response = ", "request_raw")) + lines.append(f" return {return_type}.from_dict(response.json())\n") + else: + lines.append(_request_call("", "request_raw")) + + return "".join(lines) + + +def _render_namespace_class( + api_name: str, parent_class: str, ns_spec: dict[str, Any], suffix: str +) -> str: + """Render a nested namespace class proxying ``_raw``/``_org_id`` to its parent.""" + class_name = ns_spec["class_name"] + suffix + parent_arg = f"{api_name}_api" + methods = [ + _body_kind_method(op) for op in _operations(ns_spec) if op.get("kind") == "body" + ] + + lines = [f"class {class_name}:\n"] + if ns_spec.get("docstring"): + lines.append(_docstring_block(ns_spec["docstring"], " ")) + lines.append("\n") + lines.append(f' def __init__(self, {parent_arg}: "{parent_class}"):\n') + lines.append(f" self._{parent_arg} = {parent_arg}\n") + lines.append("\n") + lines.append(" @property\n") + lines.append(" def _raw(self) -> AuthenticatedClient:\n") + lines.append(f" return self._{parent_arg}._raw\n") + lines.append("\n") + lines.append(" @property\n") + lines.append(" def _org_id(self) -> UUID:\n") + lines.append( + f" return UUID(str(self._{parent_arg}.config.organization_id))\n" + ) + lines.append("\n") + lines.append("\n".join(methods)) + return "".join(lines) + + +def _render_module(api_name: str, spec: dict[str, Any], *, split: bool) -> str: + """Render one API module. + + ``split=False`` emits the whole-file facade (today's discovery/tables + behavior). ``split=True`` emits base classes (``*Generated`` suffix) for + APIs that also have hand-written ``manual`` operations; the hand-written + facade later subclasses them. + """ + suffix = GENERATED_CLASS_SUFFIX if split else "" + class_name = spec["class_name"] + suffix + operations = _operations(spec) + namespaces = _namespaces(spec) + # Namespaces whose body operations get generated code. A namespace with + # zero body operations is fully hand-written: no class, no construction. + generated_namespaces = { + ns_name: ns_spec + for ns_name, ns_spec in namespaces.items() + if _body_kind_ops(ns_spec) + } endpoint_imports: dict[str, list[str]] = defaultdict(list) model_imports: dict[str, list[str]] = defaultdict(list) needs_unset = False needs_roe_api_exception = False + needs_translate_response = False needs_table_upload_helpers = False + needs_any = False + needs_uuid = False + request_helpers: set[str] = set() - methods: list[str] = [] - for operation in operations: + def _scan_body_operation(operation: dict[str, Any]) -> None: + nonlocal needs_unset, needs_roe_api_exception, needs_any, needs_uuid package, endpoint_name = _module_import_parts(operation["endpoint_module"]) endpoint_imports[package].append(endpoint_name) + if operation.get("return_import"): + return_module, return_class = _class_import_parts( + operation["return_import"] + ) + model_imports[return_module].append(return_class) + if operation.get("body_import"): + body_module, body_class = _class_import_parts(operation["body_import"]) + model_imports[body_module].append(body_class) - return_module, return_class = _class_import_parts(operation["return_import"]) - model_imports[return_module].append(return_class) + params = operation.get("parameters") or [] + annotations = [param["annotation"] for param in params] + if operation.get("return_type"): + annotations.append(operation["return_type"]) + if any(_ANY_RE.search(annotation) for annotation in annotations): + needs_any = True + if any(param.get("pass_unset_when_none") for param in params): + needs_unset = True + if any(param.get("coerce") == "uuid" for param in params): + needs_uuid = True + if ( + operation.get("return_shape") == "list" + or operation.get("refetch_with_retrieve") + or (operation.get("body_type") and operation.get("empty_response_message")) + ): + needs_roe_api_exception = True + if operation.get("body_type") and not operation.get("refetch_with_retrieve"): + request_helpers.add("request_json") + else: + request_helpers.add("request_raw") + methods: list[str] = [] + for operation in operations: kind = operation.get("kind", "simple") if kind == "simple": + package, endpoint_name = _module_import_parts(operation["endpoint_module"]) + endpoint_imports[package].append(endpoint_name) + return_module, return_class = _class_import_parts( + operation["return_import"] + ) + model_imports[return_module].append(return_class) needs_roe_api_exception = True + needs_translate_response = True if any( param.get("pass_unset_when_none") for param in operation.get("parameters") or [] @@ -322,16 +681,50 @@ def _render_api_module(api_name: str, spec: dict[str, Any]) -> str: needs_unset = True methods.append(_simple_method(operation)) elif kind == "table_upload": + package, endpoint_name = _module_import_parts(operation["endpoint_module"]) + endpoint_imports[package].append(endpoint_name) + return_module, return_class = _class_import_parts( + operation["return_import"] + ) + model_imports[return_module].append(return_class) needs_roe_api_exception = True + needs_translate_response = True needs_table_upload_helpers = True needs_unset = True body_module, body_class = _class_import_parts(operation["body_import"]) model_imports[body_module].append(body_class) methods.append(_table_upload_method(operation)) + elif kind == "body": + _scan_body_operation(operation) + methods.append(_body_kind_method(operation)) + elif kind == "manual": + continue else: raise ValueError(f"Unsupported wrapper kind {kind!r} in {api_name}") - lines = [HEADER] + api_body_ops = _body_kind_ops(spec) + if api_body_ops: + # Every generated body method reads ``self._org_id``. + needs_uuid = True + namespace_blocks: list[str] = [] + for ns_name, ns_spec in namespaces.items(): + for operation in _operations(ns_spec): + ns_kind = operation.get("kind") + if ns_kind == "body": + _scan_body_operation(operation) + elif ns_kind != "manual": + raise ValueError( + f"Unsupported wrapper kind {ns_kind!r} in namespace " + f"{api_name}.{ns_name} (only body/manual are allowed)" + ) + if ns_name in generated_namespaces: + # Namespace classes proxy ``_org_id`` via the parent config. + needs_uuid = True + namespace_blocks.append( + _render_namespace_class(api_name, class_name, ns_spec, suffix) + ) + + lines = [BASE_HEADER if split else HEADER] if needs_table_upload_helpers: lines.append("from io import BytesIO\n") lines.append("import mimetypes\n") @@ -339,6 +732,15 @@ def _render_api_module(api_name: str, spec: dict[str, Any]) -> str: lines.append("from typing import BinaryIO\n") lines.append("from uuid import UUID\n") lines.append("\n") + else: + stdlib_lines: list[str] = [] + if needs_any: + stdlib_lines.append("from typing import Any\n") + if needs_uuid: + stdlib_lines.append("from uuid import UUID\n") + if stdlib_lines: + lines.extend(stdlib_lines) + lines.append("\n") for package, names in sorted(endpoint_imports.items()): unique_names = sorted(set(names)) @@ -358,23 +760,46 @@ def _render_api_module(api_name: str, spec: dict[str, Any]) -> str: elif needs_unset: lines.append("from roe._generated.types import UNSET\n") lines.append("from roe.config import RoeConfig\n") - if needs_roe_api_exception: + if needs_translate_response and needs_roe_api_exception: lines.append("from roe.exceptions import RoeAPIException, translate_response\n") - else: + elif needs_translate_response: lines.append("from roe.exceptions import translate_response\n") + elif needs_roe_api_exception: + lines.append("from roe.exceptions import RoeAPIException\n") if needs_table_upload_helpers: lines.append("from roe.models import FileUpload\n") + if request_helpers: + joined_helpers = ", ".join(sorted(request_helpers)) + lines.append(f"from roe.utils.generated_request import {joined_helpers}\n") lines.append("\n\n") + for block in namespace_blocks: + lines.append(block.rstrip() + "\n") + lines.append("\n\n") + lines.append(f"class {class_name}:\n") - lines.append(f' """{spec.get("docstring", "")}"""\n') + lines.append(_docstring_block(spec.get("docstring", ""), " ")) lines.append("\n") lines.append( " def __init__(self, config: RoeConfig, raw_client: AuthenticatedClient):\n" ) lines.append(" self.config = config\n") lines.append(" self._raw = raw_client\n") + for ns_name, ns_spec in generated_namespaces.items(): + attr = ns_spec["attr"] + lines.append(f" self._{attr} = {ns_spec['class_name']}{suffix}(self)\n") lines.append("\n") + if api_body_ops: + lines.append(" @property\n") + lines.append(" def _org_id(self) -> UUID:\n") + lines.append(" return UUID(str(self.config.organization_id))\n") + lines.append("\n") + for ns_name, ns_spec in generated_namespaces.items(): + attr = ns_spec["attr"] + lines.append(" @property\n") + lines.append(f" def {attr}(self) -> {ns_spec['class_name']}{suffix}:\n") + lines.append(f" return self._{attr}\n") + lines.append("\n") lines.append("\n".join(methods)) if needs_table_upload_helpers: @@ -400,6 +825,9 @@ def _render_registry(apis: dict[str, dict[str, Any]]) -> str: "from __future__ import annotations\n", "\n", ] + if not apis: + lines.append("\nGENERATED_API_CLASSES: dict[str, type] = {}\n") + return "".join(lines) for api_name, spec in sorted(apis.items()): lines.append(f"from roe.api.{api_name} import {spec['class_name']}\n") lines.append("\n\n") @@ -410,22 +838,54 @@ def _render_registry(apis: dict[str, dict[str, Any]]) -> str: return "".join(lines) +def _generate_modules( + apis: dict[str, dict[str, Any]], api_dir: Path +) -> tuple[dict[str, dict[str, Any]], list[str]]: + """Write API modules under ``api_dir``. + + Returns ``(whole_file_apis, split_api_names)``. APIs with at least one + ``manual`` operation (at the API level or inside a namespace) are emitted + as base-class modules ``_{api}_generated.py``; the rest keep today's + whole-file behavior and are returned for registry rendering. + """ + _check_manual_parity(apis, api_dir) + api_dir.mkdir(parents=True, exist_ok=True) + + whole_file_apis: dict[str, dict[str, Any]] = {} + split_api_names: list[str] = [] + for api_name, spec in sorted(apis.items()): + if _spec_has_manual(spec): + # All-manual APIs (e.g. users) are parity-checked only; emitting an + # empty base module would add noise without value. + if _has_generated_ops(spec): + target = api_dir / f"_{api_name}_generated.py" + target.write_text(_render_module(api_name, spec, split=True)) + split_api_names.append(api_name) + else: + target = api_dir / f"{api_name}.py" + target.write_text(_render_module(api_name, spec, split=False)) + whole_file_apis[api_name] = spec + return whole_file_apis, split_api_names + + def main() -> None: contract = _load_contract() apis = contract["apis"] - API_DIR.mkdir(parents=True, exist_ok=True) - for api_name, spec in sorted(apis.items()): - target = API_DIR / f"{api_name}.py" - target.write_text(_render_api_module(api_name, spec)) - - REGISTRY_PATH.write_text(_render_registry(apis)) + whole_file_apis, split_api_names = _generate_modules(apis, API_DIR) + REGISTRY_PATH.write_text(_render_registry(whole_file_apis)) _sync_readme_release_banner() _sync_readme_block() - print( - f"Generated {len(apis)} friendly API wrapper modules from " + message = ( + f"Generated {len(whole_file_apis)} friendly API wrapper modules from " f"{CONTRACT_PATH.relative_to(ROOT_DIR)}" ) + if split_api_names: + message += ( + f" (+{len(split_api_names)} base-class modules: " + f"{', '.join(split_api_names)})" + ) + print(message) if __name__ == "__main__": diff --git a/tests/unit/test_generate_wrappers.py b/tests/unit/test_generate_wrappers.py new file mode 100644 index 0000000..0858512 --- /dev/null +++ b/tests/unit/test_generate_wrappers.py @@ -0,0 +1,69 @@ +"""Tests for the body/manual/namespace paths in scripts/generate_wrappers.py.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + +pytest.importorskip("ruamel.yaml", reason="requires the codegen dependency group") +from ruamel.yaml import YAML # noqa: E402 + +ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(ROOT / "scripts")) + +import generate_wrappers as gw # noqa: E402 + +FIXTURE = ROOT / "scripts" / "fixtures" / "candidate_wrappers.yml" +API_DIR = ROOT / "src" / "roe" / "api" + + +def _fixture_apis() -> dict: + return YAML(typ="safe").load(FIXTURE.read_text())["apis"] + + +def test_candidate_fixture_passes_manual_parity() -> None: + gw._check_manual_parity(_fixture_apis(), API_DIR) + + +def test_manual_parity_fails_loudly_on_missing_method() -> None: + apis = _fixture_apis() + apis["agents"]["operations"].append( + {"kind": "manual", "method_name": "definitely_missing", "docstring": ""} + ) + with pytest.raises(gw.ManualWrapperParityError, match="definitely_missing"): + gw._check_manual_parity(apis, API_DIR) + + +def test_split_module_renders_body_methods_and_namespaces() -> None: + apis = _fixture_apis() + module = gw._render_module("agents", apis["agents"], split=True) + assert "class AgentsAPIGenerated:" in module + assert "class AgentVersionsAPIGenerated:" in module + assert "class AgentJobsAPIGenerated:" in module + # body conventions + assert "return PaginatedBaseAgentList.from_dict(response.json())" in module + assert "organization_id=self._org_id" in module + assert "UUID(str(agent_id))" in module + assert "input_definitions=input_definitions or []" in module + # refetch-with-retrieve (versions.create) + assert "return self.retrieve(agent_id, str(version_id))" in module + # manual methods never generated + assert "def run(" not in module + assert "def retrieve_status_many(" not in module + + +def test_all_manual_api_emits_nothing(tmp_path: Path) -> None: + apis = { + "users": { + "class_name": "UsersAPI", + "docstring": "API for users.", + "operations": [{"kind": "manual", "method_name": "me", "docstring": ""}], + } + } + (tmp_path / "users.py").write_text("def me(self):\n pass\n") + whole, split = gw._generate_modules(apis, tmp_path) + assert whole == {} + assert split == [] + assert not (tmp_path / "_users_generated.py").exists()