Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ setup-env:
debug:
FLASK_APP=policyengine_api.api FLASK_DEBUG=1 flask run --without-threads

debug-asgi:
FLASK_DEBUG=1 uvicorn policyengine_api.asgi:app --reload --port 8000

test-env-vars:
pytest tests/env_variables

Expand Down
1 change: 1 addition & 0 deletions changelog.d/fastapi-shell.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a FastAPI ASGI compatibility shell that can serve the existing Flask API through WSGI fallback.
19 changes: 19 additions & 0 deletions docs/engineering/skills/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,24 @@ python scripts/export_migration_contracts.py
python -m pytest tests/contract tests/unit/test_migration_flags.py tests/unit/test_migration_contract_artifacts.py tests/unit/test_capture_migration_baseline.py tests/unit/routes/test_migration_context_logging.py -q
```

For PR 2 FastAPI shell or Flask fallback changes, verify the ASGI entrypoint and
the v1 route contracts together:

```bash
FLASK_DEBUG=1 python -m pytest tests/unit/test_asgi_factory.py tests/contract/test_v1_route_contracts.py tests/unit/routes/test_migration_context_logging.py -q
```

If the change touches service compatibility behavior used by migrated or
candidate endpoints, add the relevant focused service tests. For budget-window
simulation compatibility, run:

```bash
FLASK_DEBUG=1 python -m pytest tests/unit/services/test_economy_service.py::TestEconomyService::TestGetBudgetWindowEconomicImpact -q
```

Regenerate and review `docs/engineering/generated/migration_contracts.md` when
route inventory, migration registry flags, or v1 contract expectations change.
FastAPI shell-only fallback changes should not change the route catalog.

Run `ruff format --check` and `ruff check` on changed Python files before
handoff.
41 changes: 41 additions & 0 deletions docs/migration-pr2-fastapi-shell-runbook.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# PR 2 FastAPI Shell Runbook

PR 2 adds an ASGI FastAPI shell around the existing Flask API. It is a
compatibility step only.

## Included

- Native FastAPI `GET /health`.
- Flask fallback for all existing API v1 routes through WSGI middleware.
- ASGI parity tests for current app-v2 contract routes.
- Local Uvicorn run command.

## Not Included

- No production traffic shift.
- No Cloud Run deployment.
- No native FastAPI route migration beyond `GET /health`.
- No Supabase, Alembic, SQLAlchemy, or Modal compute changes.

## Local Smoke

Run:

```bash
FLASK_DEBUG=1 uvicorn policyengine_api.asgi:app --port 8000
```

Smoke-check:

```bash
curl -i http://localhost:8000/health
curl -i http://localhost:8000/readiness-check
curl -i http://localhost:8000/liveness-check
curl -i http://localhost:8000/zz/metadata
```

Expected behavior:

- `/health` returns FastAPI JSON: `{"status":"healthy"}`.
- `/readiness-check` and `/liveness-check` return existing Flask text `OK`.
- Existing v1 routes continue to use Flask fallback behavior.
9 changes: 9 additions & 0 deletions policyengine_api/asgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""ASGI entrypoint for the Stage 2 FastAPI compatibility shell."""

from __future__ import annotations

from policyengine_api.api import app as flask_app
from policyengine_api.asgi_factory import create_asgi_app


app = application = create_asgi_app(flask_app)
52 changes: 52 additions & 0 deletions policyengine_api/asgi_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""FastAPI shell for serving the existing Flask API through ASGI."""

from __future__ import annotations

from typing import Literal

from a2wsgi import WSGIMiddleware
from fastapi import FastAPI
from pydantic import BaseModel

from policyengine_api.constants import VERSION


class HealthResponse(BaseModel):
status: Literal["healthy"]


def _add_vary_origin(response) -> None:
vary = response.headers.get("Vary")
if vary is None:
response.headers["Vary"] = "Origin"
return
if "origin" not in {value.strip().lower() for value in vary.split(",")}:
response.headers["Vary"] = f"{vary}, Origin"


def create_asgi_app(wsgi_app) -> FastAPI:
"""Create the Stage 2 FastAPI shell around the existing Flask app."""

app = FastAPI(
title="PolicyEngine API",
version=VERSION,
docs_url=None,
redoc_url=None,
openapi_url=None,
)

@app.middleware("http")
async def add_cors_for_native_routes(request, call_next):
response = await call_next(request)
origin = request.headers.get("origin")
if origin and "access-control-allow-origin" not in response.headers:
response.headers["Access-Control-Allow-Origin"] = origin
_add_vary_origin(response)
return response

@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
return HealthResponse(status="healthy")

app.mount("/", WSGIMiddleware(wsgi_app))
return app
34 changes: 34 additions & 0 deletions policyengine_api/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from policyengine_api.utils import budget_window as budget_window_utils
from policyengine.simulation import SimulationOptions
from policyengine.utils.data.datasets import get_default_dataset
import httpx
import json
import datetime
import hashlib
Expand Down Expand Up @@ -77,6 +78,7 @@ class ImpactStatus(Enum):
BUDGET_WINDOW_MAX_ACTIVE_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_ACTIVE_YEARS
BUDGET_WINDOW_MAX_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_YEARS
BUDGET_WINDOW_MAX_END_YEAR = budget_window_utils.BUDGET_WINDOW_MAX_END_YEAR
BUDGET_WINDOW_SUBMISSION_VALIDATION_ERROR_STATUS_CODES = {400, 422}


class EconomicImpactSetupOptions(BaseModel):
Expand Down Expand Up @@ -348,6 +350,18 @@ def get_budget_window_economic_impact(
budget_window_cache.store_batch_job_id(
cache_key, batch_execution.batch_job_id
)
except httpx.HTTPStatusError as error:
budget_window_cache.clear_starting_claim(cache_key, claim_token)
if (
error.response.status_code
in BUDGET_WINDOW_SUBMISSION_VALIDATION_ERROR_STATUS_CODES
):
return BudgetWindowEconomicImpactResult.failed(
self._build_budget_window_submission_error_message(error),
queued_years=years,
cache_status=cache_status,
)
raise
except Exception:
budget_window_cache.clear_starting_claim(cache_key, claim_token)
raise
Expand Down Expand Up @@ -443,6 +457,26 @@ def _start_budget_window_batch(

return simulation_api.run_budget_window_batch(sim_params)

def _build_budget_window_submission_error_message(
self, error: httpx.HTTPStatusError
) -> str:
try:
response_json = error.response.json()
except ValueError:
response_json = None

if isinstance(response_json, dict):
for key in ("detail", "message", "error"):
value = response_json.get(key)
if value:
return str(value)

response_text = error.response.text.strip()
if response_text:
return response_text

return str(error)

def _get_budget_window_result_from_batch_job_id(
self,
*,
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ classifiers = [
"License :: OSI Approved :: GNU Affero General Public License v3",
]
dependencies = [
"a2wsgi>=1.10,<2",
"anthropic",
"assertpy",
"click>=8,<9",
"cloud-sql-python-connector",
"faiss-cpu",
"fastapi>=0.115,<1",
"flask>=3,<4",
"flask-cors>=5,<6",
"Flask-Caching>=2,<3",
Expand All @@ -50,6 +52,7 @@ dependencies = [
"rq",
"sqlalchemy>=2,<3",
"streamlit",
"uvicorn[standard]>=0.32,<1",
"werkzeug",
]

Expand Down
84 changes: 84 additions & 0 deletions tests/contract/clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping, Protocol

from fastapi.testclient import TestClient
from flask import Flask

from policyengine_api.asgi_factory import create_asgi_app


@dataclass(frozen=True)
class ContractResponse:
status_code: int
body: bytes
headers: Mapping[str, str]
content_type: str | None

@property
def data(self) -> bytes:
return self.body


class ContractClient(Protocol):
def open(
self,
path: str,
*,
method: str,
json: dict | None = None,
headers: dict | None = None,
) -> ContractResponse: ...


class FlaskContractClient:
def __init__(self, app: Flask):
self._client = app.test_client()

def open(
self,
path: str,
*,
method: str,
json: dict | None = None,
headers: dict | None = None,
) -> ContractResponse:
response = self._client.open(
path,
method=method,
json=json,
headers=headers,
)
return ContractResponse(
status_code=response.status_code,
body=response.data,
headers=dict(response.headers),
content_type=response.content_type,
)


class ASGIContractClient:
def __init__(self, app: Flask):
self._client = TestClient(create_asgi_app(app))

def open(
self,
path: str,
*,
method: str,
json: dict | None = None,
headers: dict | None = None,
) -> ContractResponse:
response = self._client.request(
method,
path,
json=json,
headers=headers,
)
return ContractResponse(
status_code=response.status_code,
body=response.content,
headers=dict(response.headers),
content_type=response.headers.get("content-type"),
)
37 changes: 27 additions & 10 deletions tests/contract/test_v1_route_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from policyengine_api.routes.policy_routes import policy_bp
from policyengine_api.routes.report_output_routes import report_output_bp
from policyengine_api.routes.simulation_routes import simulation_bp
from tests.contract.clients import (
ASGIContractClient,
ContractClient,
FlaskContractClient,
)
from tests.contract.helpers import (
assert_field_path_exists,
assert_subset,
Expand Down Expand Up @@ -121,7 +126,7 @@ def _load_contract_economy_blueprint():
)


def _client():
def create_contract_flask_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.register_blueprint(_load_contract_metadata_blueprint())
Expand All @@ -141,7 +146,17 @@ def liveness_check():
def readiness_check():
return Response("OK", status=200, mimetype="text/plain")

return app.test_client()
return app


@pytest.fixture(params=("flask_direct", "fastapi_fallback"))
def contract_client(request) -> ContractClient:
app = create_contract_flask_app()
if request.param == "flask_direct":
return FlaskContractClient(app)
if request.param == "fastapi_fallback":
return ASGIContractClient(app)
raise AssertionError(f"Unknown contract client: {request.param}")


def _resolved_path(path: str) -> str:
Expand Down Expand Up @@ -375,9 +390,12 @@ def _expected_subset(contract: ContractRequest) -> dict:
APP_V2_ROUTE_CONTRACTS,
ids=lambda contract: f"{contract.method} {contract.path}",
)
def test_app_v2_api_v1_route_contract(contract):
def test_app_v2_api_v1_route_contract(
contract: ContractRequest,
contract_client: ContractClient,
):
with _patched_route_dependencies():
response = _client().open(
response = contract_client.open(
_resolved_path(contract.path),
method=contract.method,
json=_json_payload(contract),
Expand All @@ -390,10 +408,9 @@ def test_app_v2_api_v1_route_contract(contract):
assert_field_path_exists(payload, field_path)


def test_health_routes_contract():
client = _client()
liveness = client.get("/liveness-check")
readiness = client.get("/readiness-check")
def test_health_routes_contract(contract_client: ContractClient):
liveness = contract_client.open("/liveness-check", method="GET")
readiness = contract_client.open("/readiness-check", method="GET")

assert liveness.status_code == 200
assert liveness.data == b"OK"
Expand All @@ -403,8 +420,8 @@ def test_health_routes_contract():
assert "text/plain" in readiness.content_type


def test_invalid_country_contract():
response = _client().get("/zz/metadata")
def test_invalid_country_contract(contract_client: ContractClient):
response = contract_client.open("/zz/metadata", method="GET")

assert response.status_code == 400
assert_subset(
Expand Down
Loading
Loading