Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 57 additions & 58 deletions stapi-fastapi/src/stapi_fastapi/backends/root_backend.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,78 @@
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar
from typing import Any, Generic, Protocol

from fastapi import Request
from returns.maybe import Maybe
from returns.result import ResultE
from stapi_pydantic import (
OpportunitySearchRecord,
OpportunitySearchStatus,
Order,
OrderStatus,
)

GetOrders = Callable[
[str | None, int, Request],
Coroutine[Any, Any, ResultE[tuple[list[Order[OrderStatus]], Maybe[str], Maybe[int]]]],
]
"""
Type alias for an async function that returns a list of existing Orders.
from stapi_pydantic import OpportunitySearchRecord, OpportunitySearchStatus, Order, OrderStatusBound

Args:
next (str | None): A pagination token.
limit (int): The maximum number of orders to return in a page.
request (Request): FastAPI's Request object.

Returns:
A tuple containing a list of orders and a pagination token.
class GetOrders(Protocol, Generic[OrderStatusBound]):
"""Callable class wrapping an async method that returns a list of Orders.

- Should return returns.result.Success[tuple[list[Order], returns.maybe.Some[str]]]
if including a pagination token
- Should return returns.result.Success[tuple[list[Order], returns.maybe.Nothing]]
if not including a pagination token
- Returning returns.result.Failure[Exception] will result in a 500.
"""
Args:
next (str | None): A pagination token.
limit (int): The maximum number of orders to return in a page.
request (Request): FastAPI's Request object.

GetOrder = Callable[[str, Request], Coroutine[Any, Any, ResultE[Maybe[Order[OrderStatus]]]]]
"""
Type alias for an async function that gets details for the order with `order_id`.
Returns:
A tuple containing a list of orders and a pagination token.

Args:
order_id (str): The order ID.
request (Request): FastAPI's Request object.
- Should return returns.result.Success[tuple[list[Order], returns.maybe.Some[str]]]
if including a pagination token
- Should return returns.result.Success[tuple[list[Order], returns.maybe.Nothing]]
if not including a pagination token
- Returning returns.result.Failure[Exception] will result in a 500.
"""

Returns:
- Should return returns.result.Success[returns.maybe.Some[Order]] if order is found.
- Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied.
- Returning returns.result.Failure[Exception] will result in a 500.
"""
async def __call__(
self,
next: str | None,
limit: int,
request: Request,
) -> ResultE[tuple[list[Order[OrderStatusBound]], Maybe[str], Maybe[int]]]: ...


T = TypeVar("T", bound=OrderStatus)
class GetOrder(Protocol, Generic[OrderStatusBound]):
"""Callable class wrapping an async method that gets details for the order with `order_id`.

Args:
order_id (str): The order ID.
request (Request): FastAPI's Request object.

GetOrderStatuses = Callable[
[str, str | None, int, Request],
Coroutine[Any, Any, ResultE[Maybe[tuple[list[T], Maybe[str]]]]],
]
"""
Type alias for an async function that gets statuses for the order with `order_id`.
Returns:
- Should return returns.result.Success[returns.maybe.Some[Order]] if order is found.
- Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied.
- Returning returns.result.Failure[Exception] will result in a 500.
"""

Args:
order_id (str): The order ID.
next (str | None): A pagination token.
limit (int): The maximum number of statuses to return in a page.
request (Request): FastAPI's Request object.
async def __call__(self, order_id: str, request: Request) -> ResultE[Maybe[Order[OrderStatusBound]]]: ...

Returns:
A tuple containing a list of order statuses and a pagination token.

- Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Some[str]]]
if order is found and including a pagination token.
- Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Nothing]]]
if order is found and not including a pagination token.
- Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied.
- Returning returns.result.Failure[Exception] will result in a 500.
"""
class GetOrderStatuses(Protocol, Generic[OrderStatusBound]):
"""Callable class wrapping an async method that gets statuses for the order with `order_id`.

Args:
order_id (str): The order ID.
next (str | None): A pagination token.
limit (int): The maximum number of statuses to return in a page.
request (Request): FastAPI's Request object.

Returns:
A tuple containing a list of order statuses and a pagination token.

- Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Some[str]]]
if order is found and including a pagination token.
- Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Nothing]]]
if order is found and not including a pagination token.
- Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied.
- Returning returns.result.Failure[Exception] will result in a 500.
"""

async def __call__(
self, order_id: str, _next: str | None, limit: int, request: Request
) -> ResultE[Maybe[tuple[list[OrderStatusBound], Maybe[str]]]]: ...


GetOpportunitySearchRecords = Callable[
[str | None, int, Request],
Expand Down
70 changes: 40 additions & 30 deletions stapi-fastapi/src/stapi_fastapi/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
from stapi_fastapi.routers.utils import json_link

if TYPE_CHECKING:
from stapi_fastapi.routers import RootRouter
from stapi_fastapi.routers.root_router import ConformancesSupport, RootProvider


logger = logging.getLogger(__name__)

Expand All @@ -68,7 +69,7 @@ def get_prefer(prefer: str | None = Header(None)) -> str | None:
return Prefer(prefer)


def build_conformances(product: Product, root_router: RootRouter) -> list[str]:
def build_conformances(product: Product, conformances_support: ConformancesSupport) -> list[str]:
# FIXME we can make this check more robust
if not any(conformance.startswith("https://geojson.org/schema/") for conformance in product.conformsTo):
raise ValueError("product conformance does not contain at least one geojson conformance")
Expand All @@ -78,7 +79,7 @@ def build_conformances(product: Product, root_router: RootRouter) -> list[str]:
if product.supports_opportunity_search:
conformances.add(PRODUCT_CONFORMACES.opportunities)

if product.supports_async_opportunity_search and root_router.supports_async_opportunity_search:
if product.supports_async_opportunity_search and conformances_support.supports_async_opportunity_search:
conformances.add(PRODUCT_CONFORMACES.opportunities)
conformances.add(PRODUCT_CONFORMACES.opportunities_async)

Expand All @@ -90,20 +91,21 @@ class ProductRouter(StapiFastapiBaseRouter):
def __init__( # noqa
self,
product: Product,
root_router: RootRouter,
root_provider: RootProvider,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)

self.product = product
self.root_router = root_router
self.conformances = build_conformances(product, root_router)
self.root_provider = root_provider
self.conformances_support: ConformancesSupport = root_provider
self.conformances = build_conformances(product, root_provider)

self.add_api_route(
path="",
endpoint=self.get_product,
name=f"{self.root_router.name}:{self.product.id}:{GET_PRODUCT}",
name=f"{self.root_provider.name}:{self.product.id}:{GET_PRODUCT}",
methods=["GET"],
summary="Retrieve this product",
tags=["Products"],
Expand All @@ -112,7 +114,7 @@ def __init__( # noqa
self.add_api_route(
path="/conformance",
endpoint=self.get_product_conformance,
name=f"{self.root_router.name}:{self.product.id}:{CONFORMANCE}",
name=f"{self.root_provider.name}:{self.product.id}:{CONFORMANCE}",
methods=["GET"],
summary="Get conformance urls for the product",
tags=["Products"],
Expand All @@ -121,7 +123,7 @@ def __init__( # noqa
self.add_api_route(
path="/queryables",
endpoint=self.get_product_queryables,
name=f"{self.root_router.name}:{self.product.id}:{GET_QUERYABLES}",
name=f"{self.root_provider.name}:{self.product.id}:{GET_QUERYABLES}",
methods=["GET"],
summary="Get queryables for the product",
tags=["Products"],
Expand All @@ -130,7 +132,7 @@ def __init__( # noqa
self.add_api_route(
path="/order-parameters",
endpoint=self.get_product_order_parameters,
name=f"{self.root_router.name}:{self.product.id}:{GET_ORDER_PARAMETERS}",
name=f"{self.root_provider.name}:{self.product.id}:{GET_ORDER_PARAMETERS}",
methods=["GET"],
summary="Get order parameters for the product",
tags=["Products"],
Expand All @@ -157,7 +159,7 @@ async def _create_order(
self.add_api_route(
path="/orders",
endpoint=_create_order,
name=f"{self.root_router.name}:{self.product.id}:{CREATE_ORDER}",
name=f"{self.root_provider.name}:{self.product.id}:{CREATE_ORDER}",
methods=["POST"],
response_class=GeoJSONResponse,
status_code=status.HTTP_201_CREATED,
Expand All @@ -166,12 +168,13 @@ async def _create_order(
)

if product.supports_opportunity_search or (
self.product.supports_async_opportunity_search and self.root_router.supports_async_opportunity_search
self.product.supports_async_opportunity_search
and self.conformances_support.supports_async_opportunity_search
):
self.add_api_route(
path="/opportunities",
endpoint=self.search_opportunities,
name=f"{self.root_router.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}",
name=f"{self.root_provider.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}",
methods=["POST"],
response_class=GeoJSONResponse,
# unknown why mypy can't see the queryables property on Product, ignoring
Expand All @@ -189,11 +192,11 @@ async def _create_order(
tags=["Products"],
)

if product.supports_async_opportunity_search and root_router.supports_async_opportunity_search:
if product.supports_async_opportunity_search and self.conformances_support.supports_async_opportunity_search:
self.add_api_route(
path="/opportunities/{opportunity_collection_id}",
endpoint=self.get_opportunity_collection,
name=f"{self.root_router.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}",
name=f"{self.root_provider.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}",
methods=["GET"],
response_class=GeoJSONResponse,
summary="Get an Opportunity Collection by ID",
Expand All @@ -202,30 +205,34 @@ async def _create_order(

def get_product(self, request: Request) -> ProductPydantic:
links = [
json_link("self", self.url_for(request, f"{self.root_router.name}:{self.product.id}:{GET_PRODUCT}")),
json_link("conformance", self.url_for(request, f"{self.root_router.name}:{self.product.id}:{CONFORMANCE}")),
json_link("self", self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{GET_PRODUCT}")),
json_link(
"conformance", self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{CONFORMANCE}")
),
json_link(
"queryables", self.url_for(request, f"{self.root_router.name}:{self.product.id}:{GET_QUERYABLES}")
"queryables",
self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{GET_QUERYABLES}"),
),
json_link(
"order-parameters",
self.url_for(request, f"{self.root_router.name}:{self.product.id}:{GET_ORDER_PARAMETERS}"),
self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{GET_ORDER_PARAMETERS}"),
),
Link(
href=self.url_for(request, f"{self.root_router.name}:{self.product.id}:{CREATE_ORDER}"),
href=self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{CREATE_ORDER}"),
rel="create-order",
type=TYPE_JSON,
method="POST",
),
]

if self.product.supports_opportunity_search or (
self.product.supports_async_opportunity_search and self.root_router.supports_async_opportunity_search
self.product.supports_async_opportunity_search
and self.conformances_support.supports_async_opportunity_search
):
links.append(
json_link(
"opportunities",
self.url_for(request, f"{self.root_router.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}"),
self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}"),
),
)

Expand All @@ -243,7 +250,8 @@ async def search_opportunities(
"""
# sync
if not (
self.root_router.supports_async_opportunity_search and self.product.supports_async_opportunity_search
self.product.supports_async_opportunity_search
and self.conformances_support.supports_async_opportunity_search
) or (prefer is Prefer.wait and self.product.supports_opportunity_search):
return await self.search_opportunities_sync(
search,
Expand Down Expand Up @@ -298,7 +306,7 @@ async def search_opportunities_sync(
case x:
raise AssertionError(f"Expected code to be unreachable {x}")

if prefer is Prefer.wait and self.root_router.supports_async_opportunity_search:
if prefer is Prefer.wait and self.conformances_support.supports_async_opportunity_search:
response.headers["Preference-Applied"] = "wait"

return OpportunityCollection(features=features, links=links)
Expand All @@ -311,10 +319,12 @@ async def search_opportunities_async(
) -> JSONResponse:
match await self.product.search_opportunities_async(self, search, request):
case Success(search_record):
search_record.links.append(self.root_router.opportunity_search_record_self_link(search_record, request))
search_record.links.append(
self.root_provider.opportunity_search_record_self_link(search_record, request)
)
headers = {}
headers["Location"] = str(
self.root_router.generate_opportunity_search_record_href(request, search_record.id)
self.root_provider.generate_opportunity_search_record_href(request, search_record.id)
)
if prefer is not None:
headers["Preference-Applied"] = "respond-async"
Expand Down Expand Up @@ -365,8 +375,8 @@ async def create_order(self, payload: OrderPayload, request: Request, response:
request,
):
case Success(order):
order.links.extend(self.root_router.order_links(order, request))
location = str(self.root_router.generate_order_href(request, order.id))
order.links.extend(self.root_provider.order_links(order, request))
location = str(self.root_provider.generate_order_href(request, order.id))
response.headers["Location"] = location
return order # type: ignore
case Failure(e) if isinstance(e, QueryablesError):
Expand All @@ -385,7 +395,7 @@ async def create_order(self, payload: OrderPayload, request: Request, response:

def order_link(self, request: Request, opp_req: OpportunityPayload) -> Link:
return Link(
href=self.url_for(request, f"{self.root_router.name}:{self.product.id}:{CREATE_ORDER}"),
href=self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{CREATE_ORDER}"),
rel="create-order",
type=TYPE_JSON,
method="POST",
Expand Down Expand Up @@ -420,7 +430,7 @@ async def get_opportunity_collection(
"self",
self.url_for(
request,
f"{self.root_router.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}",
f"{self.root_provider.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}",
opportunity_collection_id=opportunity_collection_id,
),
),
Expand Down
Loading