Skip to content

Commit 7bb25d1

Browse files
committed
fix: Fix when the Location header is used.
1 parent d0819aa commit 7bb25d1

3 files changed

Lines changed: 120 additions & 29 deletions

File tree

autogen.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@
2020
OPENAPI_URL = "https://api.vantage.sh/v2/oas_v3.json"
2121
OUTPUT_DIR = Path(__file__).parent / "src" / "vantage"
2222

23+
# Maps a substring found in a response description to the internal client method
24+
# that should handle it, and the Python return type to emit.
25+
# Checked against each HTTP status code's description during endpoint parsing.
26+
RESPONSE_HANDLERS: list[tuple[str, str, str]] = [
27+
(
28+
"will be available at the location specified in the Location header",
29+
"_request_for_location",
30+
"str",
31+
),
32+
]
33+
2334

2435
@dataclass
2536
class Parameter:
@@ -48,6 +59,8 @@ class Endpoint:
4859
request_body_type: str | None = None
4960
response_type: str | None = None
5061
is_multipart: bool = False
62+
response_handler: str | None = None # internal client method to call, if not the default
63+
response_handler_return_type: str | None = None
5164

5265

5366
@dataclass
@@ -292,13 +305,27 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
292305

293306
response_type = extract_response_type(spec.get("responses", {}), schemas)
294307

308+
description = spec.get("description")
309+
310+
response_handler = None
311+
response_handler_return_type = None
312+
for resp_desc in spec.get("responses", {}).values():
313+
text = resp_desc.get("description", "")
314+
for phrase, handler, return_type in RESPONSE_HANDLERS:
315+
if phrase in text:
316+
response_handler = handler
317+
response_handler_return_type = return_type
318+
break
319+
if response_handler:
320+
break
321+
295322
endpoints.append(
296323
Endpoint(
297324
path=path,
298325
method=method.upper(),
299326
operation_id=operation_id,
300327
summary=spec.get("summary"),
301-
description=spec.get("description"),
328+
description=description,
302329
deprecated=spec.get("deprecated", False),
303330
parameters=parameters,
304331
request_body_required=request_body.get("required", False)
@@ -307,6 +334,8 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
307334
request_body_type=body_type,
308335
response_type=response_type,
309336
is_multipart=is_multipart,
337+
response_handler=response_handler,
338+
response_handler_return_type=response_handler_return_type,
310339
)
311340
)
312341

@@ -523,6 +552,18 @@ def generate_pydantic_models(schema: dict[str, Any]) -> str:
523552
return "\n".join(lines)
524553

525554

555+
def _collect_handler_routes(resources: dict[str, Resource]) -> dict[str, list[tuple[str, str]]]:
556+
"""Scan all endpoints and group (method, path) pairs by their response_handler."""
557+
handler_routes: dict[str, list[tuple[str, str]]] = {}
558+
for resource in resources.values():
559+
for endpoint in resource.endpoints:
560+
if endpoint.response_handler:
561+
handler_routes.setdefault(endpoint.response_handler, []).append(
562+
(endpoint.method, endpoint.path)
563+
)
564+
return handler_routes
565+
566+
526567
def generate_sync_client(resources: dict[str, Resource]) -> str:
527568
"""Generate synchronous client code."""
528569
lines = [
@@ -618,13 +659,30 @@ def generate_sync_client(resources: dict[str, Resource]) -> str:
618659
" body=response.text,",
619660
" )",
620661
"",
662+
]
663+
)
664+
665+
# Inject generated routing: one if-block per handler, checking (method, path)
666+
handler_routes = _collect_handler_routes(resources)
667+
for handler, routes in sorted(handler_routes.items()):
668+
route_set = "{" + ", ".join(f'("{m}", "{p}")' for m, p in sorted(routes)) + "}"
669+
lines.append(f" if (method, path) in {route_set}:")
670+
lines.append(f" return self.{handler}(response)")
671+
lines.append("")
672+
673+
lines.extend(
674+
[
621675
" try:",
622676
" data = response.json()",
623677
" except Exception:",
624678
" data = None",
625679
"",
626680
" return data",
627681
"",
682+
" def _request_for_location(self, response: Any) -> str:",
683+
' """Extract the Location header from a response."""',
684+
' return response.headers["Location"]',
685+
"",
628686
"",
629687
]
630688
)
@@ -699,7 +757,10 @@ def generate_sync_method(endpoint: Endpoint, method_name: str) -> list[str]:
699757

700758
# Method signature
701759
param_str = ", ".join(["self"] + params) if params else "self"
702-
return_type = endpoint.response_type or "None"
760+
if endpoint.response_handler:
761+
return_type = endpoint.response_handler_return_type or "Any"
762+
else:
763+
return_type = endpoint.response_type or "None"
703764
lines.append(f" def {method_name}({param_str}) -> {return_type}:")
704765

705766
# Docstring
@@ -740,7 +801,11 @@ def generate_sync_method(endpoint: Endpoint, method_name: str) -> list[str]:
740801
lines.append(" body_data = None")
741802

742803
# Make request and coerce response payload into typed models where possible
743-
if endpoint.response_type is None:
804+
if endpoint.response_handler:
805+
lines.append(
806+
f' return self._client.request("{endpoint.method}", path, params=params, body=body_data)'
807+
)
808+
elif endpoint.response_type is None:
744809
lines.append(
745810
f' self._client.request("{endpoint.method}", path, params=params, body=body_data)'
746811
)
@@ -849,13 +914,30 @@ def generate_async_client(resources: dict[str, Resource]) -> str:
849914
" body=response.text,",
850915
" )",
851916
"",
917+
]
918+
)
919+
920+
# Inject generated routing: one if-block per handler, checking (method, path)
921+
handler_routes = _collect_handler_routes(resources)
922+
for handler, routes in sorted(handler_routes.items()):
923+
route_set = "{" + ", ".join(f'("{m}", "{p}")' for m, p in sorted(routes)) + "}"
924+
lines.append(f" if (method, path) in {route_set}:")
925+
lines.append(f" return self.{handler}(response)")
926+
lines.append("")
927+
928+
lines.extend(
929+
[
852930
" try:",
853931
" data = response.json()",
854932
" except Exception:",
855933
" data = None",
856934
"",
857935
" return data",
858936
"",
937+
" def _request_for_location(self, response: Any) -> str:",
938+
' """Extract the Location header from a response."""',
939+
' return response.headers["Location"]',
940+
"",
859941
"",
860942
]
861943
)
@@ -930,7 +1012,10 @@ def generate_async_method(endpoint: Endpoint, method_name: str) -> list[str]:
9301012

9311013
# Method signature
9321014
param_str = ", ".join(["self"] + params) if params else "self"
933-
return_type = endpoint.response_type or "None"
1015+
if endpoint.response_handler:
1016+
return_type = endpoint.response_handler_return_type or "Any"
1017+
else:
1018+
return_type = endpoint.response_type or "None"
9341019
lines.append(f" async def {method_name}({param_str}) -> {return_type}:")
9351020

9361021
# Docstring
@@ -971,7 +1056,11 @@ def generate_async_method(endpoint: Endpoint, method_name: str) -> list[str]:
9711056
lines.append(" body_data = None")
9721057

9731058
# Make request and coerce response payload into typed models where possible
974-
if endpoint.response_type is None:
1059+
if endpoint.response_handler:
1060+
lines.append(
1061+
f' return await self._client.request("{endpoint.method}", path, params=params, body=body_data)'
1062+
)
1063+
elif endpoint.response_type is None:
9751064
lines.append(
9761065
f' await self._client.request("{endpoint.method}", path, params=params, body=body_data)'
9771066
)

src/vantage/_async/client.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,20 @@ async def request(
123123
body=response.text,
124124
)
125125

126+
if (method, path) in {("POST", "/costs/data_exports"), ("POST", "/kubernetes_efficiency_reports/data_exports"), ("POST", "/unit_costs/data_exports")}:
127+
return self._request_for_location(response)
128+
126129
try:
127130
data = response.json()
128131
except Exception:
129132
data = None
130133

131134
return data
132135

136+
def _request_for_location(self, response: Any) -> str:
137+
"""Extract the Location header from a response."""
138+
return response.headers["Location"]
139+
133140

134141
class AccessGrantsAsyncApi:
135142
"""Async API methods for access_grants resource."""
@@ -1111,7 +1118,7 @@ class CostsAsyncApi:
11111118
def __init__(self, client: AsyncClient) -> None:
11121119
self._client = client
11131120

1114-
async def create_export(self, body: CreateCostExport, *, groupings: Optional[List[str]] = None) -> None:
1121+
async def create_export(self, body: CreateCostExport, *, groupings: Optional[List[str]] = None) -> str:
11151122
"""
11161123
Generate cost data export
11171124
@@ -1122,7 +1129,7 @@ async def create_export(self, body: CreateCostExport, *, groupings: Optional[Lis
11221129
"groupings": groupings,
11231130
}
11241131
body_data = body.model_dump(by_alias=True, exclude_none=True) if hasattr(body, 'model_dump') else body
1125-
await self._client.request("POST", path, params=params, body=body_data)
1132+
return await self._client.request("POST", path, params=params, body=body_data)
11261133

11271134
async def list(self, *, cost_report_token: Optional[str] = None, filter: Optional[str] = None, workspace_token: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, groupings: Optional[List[str]] = None, order: Optional[str] = None, limit: Optional[int] = None, page: Optional[int] = None, date_bin: Optional[str] = None, settings_include_credits: Optional[bool] = None, settings_include_refunds: Optional[bool] = None, settings_include_discounts: Optional[bool] = None, settings_include_tax: Optional[bool] = None, settings_amortize: Optional[bool] = None, settings_unallocated: Optional[bool] = None, settings_aggregate_by: Optional[str] = None, settings_show_previous_period: Optional[bool] = None) -> Costs:
11281135
"""
@@ -1772,7 +1779,7 @@ async def create(self, body: CreateKubernetesEfficiencyReport) -> KubernetesEffi
17721779
return KubernetesEfficiencyReport.model_validate(data)
17731780
return data
17741781

1775-
async def create_export(self, body: CreateKubernetesEfficiencyReportExport, *, groupings: Optional[List[str]] = None) -> DataExport:
1782+
async def create_export(self, body: CreateKubernetesEfficiencyReportExport, *, groupings: Optional[List[str]] = None) -> str:
17761783
"""
17771784
Generate Kubernetes efficiency data export
17781785
@@ -1783,10 +1790,7 @@ async def create_export(self, body: CreateKubernetesEfficiencyReportExport, *, g
17831790
"groupings": groupings,
17841791
}
17851792
body_data = body.model_dump(by_alias=True, exclude_none=True) if hasattr(body, 'model_dump') else body
1786-
data = await self._client.request("POST", path, params=params, body=body_data)
1787-
if isinstance(data, dict):
1788-
return DataExport.model_validate(data)
1789-
return data
1793+
return await self._client.request("POST", path, params=params, body=body_data)
17901794

17911795
async def get(self, kubernetes_efficiency_report_token: str) -> KubernetesEfficiencyReport:
17921796
"""
@@ -2858,7 +2862,7 @@ class UnitCostsAsyncApi:
28582862
def __init__(self, client: AsyncClient) -> None:
28592863
self._client = client
28602864

2861-
async def create_export(self, body: CreateUnitCostsExport) -> DataExport:
2865+
async def create_export(self, body: CreateUnitCostsExport) -> str:
28622866
"""
28632867
Generate data export of unit costs
28642868
@@ -2867,10 +2871,7 @@ async def create_export(self, body: CreateUnitCostsExport) -> DataExport:
28672871
path = "/unit_costs/data_exports"
28682872
params = None
28692873
body_data = body.model_dump(by_alias=True, exclude_none=True) if hasattr(body, 'model_dump') else body
2870-
data = await self._client.request("POST", path, params=params, body=body_data)
2871-
if isinstance(data, dict):
2872-
return DataExport.model_validate(data)
2873-
return data
2874+
return await self._client.request("POST", path, params=params, body=body_data)
28742875

28752876
async def list(self, *, cost_report_token: str, start_date: Optional[str] = None, end_date: Optional[str] = None, date_bin: Optional[str] = None, order: Optional[str] = None, limit: Optional[int] = None, page: Optional[int] = None) -> UnitCosts:
28762877
"""

src/vantage/_sync/client.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,20 @@ def request(
123123
body=response.text,
124124
)
125125

126+
if (method, path) in {("POST", "/costs/data_exports"), ("POST", "/kubernetes_efficiency_reports/data_exports"), ("POST", "/unit_costs/data_exports")}:
127+
return self._request_for_location(response)
128+
126129
try:
127130
data = response.json()
128131
except Exception:
129132
data = None
130133

131134
return data
132135

136+
def _request_for_location(self, response: Any) -> str:
137+
"""Extract the Location header from a response."""
138+
return response.headers["Location"]
139+
133140

134141
class AccessGrantsApi:
135142
"""API methods for access_grants resource."""
@@ -1111,7 +1118,7 @@ class CostsApi:
11111118
def __init__(self, client: SyncClient) -> None:
11121119
self._client = client
11131120

1114-
def create_export(self, body: CreateCostExport, *, groupings: Optional[List[str]] = None) -> None:
1121+
def create_export(self, body: CreateCostExport, *, groupings: Optional[List[str]] = None) -> str:
11151122
"""
11161123
Generate cost data export
11171124
@@ -1122,7 +1129,7 @@ def create_export(self, body: CreateCostExport, *, groupings: Optional[List[str]
11221129
"groupings": groupings,
11231130
}
11241131
body_data = body.model_dump(by_alias=True, exclude_none=True) if hasattr(body, 'model_dump') else body
1125-
self._client.request("POST", path, params=params, body=body_data)
1132+
return self._client.request("POST", path, params=params, body=body_data)
11261133

11271134
def list(self, *, cost_report_token: Optional[str] = None, filter: Optional[str] = None, workspace_token: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, groupings: Optional[List[str]] = None, order: Optional[str] = None, limit: Optional[int] = None, page: Optional[int] = None, date_bin: Optional[str] = None, settings_include_credits: Optional[bool] = None, settings_include_refunds: Optional[bool] = None, settings_include_discounts: Optional[bool] = None, settings_include_tax: Optional[bool] = None, settings_amortize: Optional[bool] = None, settings_unallocated: Optional[bool] = None, settings_aggregate_by: Optional[str] = None, settings_show_previous_period: Optional[bool] = None) -> Costs:
11281135
"""
@@ -1772,7 +1779,7 @@ def create(self, body: CreateKubernetesEfficiencyReport) -> KubernetesEfficiency
17721779
return KubernetesEfficiencyReport.model_validate(data)
17731780
return data
17741781

1775-
def create_export(self, body: CreateKubernetesEfficiencyReportExport, *, groupings: Optional[List[str]] = None) -> DataExport:
1782+
def create_export(self, body: CreateKubernetesEfficiencyReportExport, *, groupings: Optional[List[str]] = None) -> str:
17761783
"""
17771784
Generate Kubernetes efficiency data export
17781785
@@ -1783,10 +1790,7 @@ def create_export(self, body: CreateKubernetesEfficiencyReportExport, *, groupin
17831790
"groupings": groupings,
17841791
}
17851792
body_data = body.model_dump(by_alias=True, exclude_none=True) if hasattr(body, 'model_dump') else body
1786-
data = self._client.request("POST", path, params=params, body=body_data)
1787-
if isinstance(data, dict):
1788-
return DataExport.model_validate(data)
1789-
return data
1793+
return self._client.request("POST", path, params=params, body=body_data)
17901794

17911795
def get(self, kubernetes_efficiency_report_token: str) -> KubernetesEfficiencyReport:
17921796
"""
@@ -2858,7 +2862,7 @@ class UnitCostsApi:
28582862
def __init__(self, client: SyncClient) -> None:
28592863
self._client = client
28602864

2861-
def create_export(self, body: CreateUnitCostsExport) -> DataExport:
2865+
def create_export(self, body: CreateUnitCostsExport) -> str:
28622866
"""
28632867
Generate data export of unit costs
28642868
@@ -2867,10 +2871,7 @@ def create_export(self, body: CreateUnitCostsExport) -> DataExport:
28672871
path = "/unit_costs/data_exports"
28682872
params = None
28692873
body_data = body.model_dump(by_alias=True, exclude_none=True) if hasattr(body, 'model_dump') else body
2870-
data = self._client.request("POST", path, params=params, body=body_data)
2871-
if isinstance(data, dict):
2872-
return DataExport.model_validate(data)
2873-
return data
2874+
return self._client.request("POST", path, params=params, body=body_data)
28742875

28752876
def list(self, *, cost_report_token: str, start_date: Optional[str] = None, end_date: Optional[str] = None, date_bin: Optional[str] = None, order: Optional[str] = None, limit: Optional[int] = None, page: Optional[int] = None) -> UnitCosts:
28762877
"""

0 commit comments

Comments
 (0)