diff --git a/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py b/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py index f13e5cc..4ee9644 100644 --- a/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py +++ b/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py @@ -1,79 +1,78 @@ from collections.abc import Callable, Coroutine -from typing import Any, TypeVar +from typing import Any, Generic, Protocol from fastapi import Request from returns.maybe import Maybe from returns.result import ResultE -from stapi_pydantic import ( - OpportunitySearchRecord, - OpportunitySearchStatus, - Order, - OrderStatus, -) - -GetOrders = Callable[ - [str | None, int, Request], - Coroutine[Any, Any, ResultE[tuple[list[Order[OrderStatus]], Maybe[str], Maybe[int]]]], -] -""" -Type alias for an async function that returns a list of existing Orders. +from stapi_pydantic import OpportunitySearchRecord, OpportunitySearchStatus, Order, OrderStatusBound -Args: - next (str | None): A pagination token. - limit (int): The maximum number of orders to return in a page. - request (Request): FastAPI's Request object. -Returns: - A tuple containing a list of orders and a pagination token. +class GetOrders(Protocol, Generic[OrderStatusBound]): + """Callable class wrapping an async method that returns a list of Orders. - - Should return returns.result.Success[tuple[list[Order], returns.maybe.Some[str]]] - if including a pagination token - - Should return returns.result.Success[tuple[list[Order], returns.maybe.Nothing]] - if not including a pagination token - - Returning returns.result.Failure[Exception] will result in a 500. -""" + Args: + next (str | None): A pagination token. + limit (int): The maximum number of orders to return in a page. + request (Request): FastAPI's Request object. -GetOrder = Callable[[str, Request], Coroutine[Any, Any, ResultE[Maybe[Order[OrderStatus]]]]] -""" -Type alias for an async function that gets details for the order with `order_id`. + Returns: + A tuple containing a list of orders and a pagination token. -Args: - order_id (str): The order ID. - request (Request): FastAPI's Request object. + - Should return returns.result.Success[tuple[list[Order], returns.maybe.Some[str]]] + if including a pagination token + - Should return returns.result.Success[tuple[list[Order], returns.maybe.Nothing]] + if not including a pagination token + - Returning returns.result.Failure[Exception] will result in a 500. + """ -Returns: - - Should return returns.result.Success[returns.maybe.Some[Order]] if order is found. - - Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied. - - Returning returns.result.Failure[Exception] will result in a 500. -""" + async def __call__( + self, + next: str | None, + limit: int, + request: Request, + ) -> ResultE[tuple[list[Order[OrderStatusBound]], Maybe[str], Maybe[int]]]: ... -T = TypeVar("T", bound=OrderStatus) +class GetOrder(Protocol, Generic[OrderStatusBound]): + """Callable class wrapping an async method that gets details for the order with `order_id`. + Args: + order_id (str): The order ID. + request (Request): FastAPI's Request object. -GetOrderStatuses = Callable[ - [str, str | None, int, Request], - Coroutine[Any, Any, ResultE[Maybe[tuple[list[T], Maybe[str]]]]], -] -""" -Type alias for an async function that gets statuses for the order with `order_id`. + Returns: + - Should return returns.result.Success[returns.maybe.Some[Order]] if order is found. + - Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied. + - Returning returns.result.Failure[Exception] will result in a 500. + """ -Args: - order_id (str): The order ID. - next (str | None): A pagination token. - limit (int): The maximum number of statuses to return in a page. - request (Request): FastAPI's Request object. + async def __call__(self, order_id: str, request: Request) -> ResultE[Maybe[Order[OrderStatusBound]]]: ... -Returns: - A tuple containing a list of order statuses and a pagination token. - - Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Some[str]]] - if order is found and including a pagination token. - - Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Nothing]]] - if order is found and not including a pagination token. - - Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied. - - Returning returns.result.Failure[Exception] will result in a 500. -""" +class GetOrderStatuses(Protocol, Generic[OrderStatusBound]): + """Callable class wrapping an async method that gets statuses for the order with `order_id`. + + Args: + order_id (str): The order ID. + next (str | None): A pagination token. + limit (int): The maximum number of statuses to return in a page. + request (Request): FastAPI's Request object. + + Returns: + A tuple containing a list of order statuses and a pagination token. + + - Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Some[str]]] + if order is found and including a pagination token. + - Should return returns.result.Success[returns.maybe.Some[tuple[list[OrderStatus], returns.maybe.Nothing]]] + if order is found and not including a pagination token. + - Should return returns.result.Success[returns.maybe.Nothing] if the order is not found or if access is denied. + - Returning returns.result.Failure[Exception] will result in a 500. + """ + + async def __call__( + self, order_id: str, _next: str | None, limit: int, request: Request + ) -> ResultE[Maybe[tuple[list[OrderStatusBound], Maybe[str]]]]: ... + GetOpportunitySearchRecords = Callable[ [str | None, int, Request], diff --git a/stapi-fastapi/src/stapi_fastapi/routers/product_router.py b/stapi-fastapi/src/stapi_fastapi/routers/product_router.py index 430ae00..a8c14d6 100644 --- a/stapi-fastapi/src/stapi_fastapi/routers/product_router.py +++ b/stapi-fastapi/src/stapi_fastapi/routers/product_router.py @@ -50,7 +50,8 @@ from stapi_fastapi.routers.utils import json_link if TYPE_CHECKING: - from stapi_fastapi.routers import RootRouter + from stapi_fastapi.routers.root_router import ConformancesSupport, RootProvider + logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ def get_prefer(prefer: str | None = Header(None)) -> str | None: return Prefer(prefer) -def build_conformances(product: Product, root_router: RootRouter) -> list[str]: +def build_conformances(product: Product, conformances_support: ConformancesSupport) -> list[str]: # FIXME we can make this check more robust if not any(conformance.startswith("https://geojson.org/schema/") for conformance in product.conformsTo): raise ValueError("product conformance does not contain at least one geojson conformance") @@ -78,7 +79,7 @@ def build_conformances(product: Product, root_router: RootRouter) -> list[str]: if product.supports_opportunity_search: conformances.add(PRODUCT_CONFORMACES.opportunities) - if product.supports_async_opportunity_search and root_router.supports_async_opportunity_search: + if product.supports_async_opportunity_search and conformances_support.supports_async_opportunity_search: conformances.add(PRODUCT_CONFORMACES.opportunities) conformances.add(PRODUCT_CONFORMACES.opportunities_async) @@ -90,20 +91,21 @@ class ProductRouter(StapiFastapiBaseRouter): def __init__( # noqa self, product: Product, - root_router: RootRouter, + root_provider: RootProvider, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.product = product - self.root_router = root_router - self.conformances = build_conformances(product, root_router) + self.root_provider = root_provider + self.conformances_support: ConformancesSupport = root_provider + self.conformances = build_conformances(product, root_provider) self.add_api_route( path="", endpoint=self.get_product, - name=f"{self.root_router.name}:{self.product.id}:{GET_PRODUCT}", + name=f"{self.root_provider.name}:{self.product.id}:{GET_PRODUCT}", methods=["GET"], summary="Retrieve this product", tags=["Products"], @@ -112,7 +114,7 @@ def __init__( # noqa self.add_api_route( path="/conformance", endpoint=self.get_product_conformance, - name=f"{self.root_router.name}:{self.product.id}:{CONFORMANCE}", + name=f"{self.root_provider.name}:{self.product.id}:{CONFORMANCE}", methods=["GET"], summary="Get conformance urls for the product", tags=["Products"], @@ -121,7 +123,7 @@ def __init__( # noqa self.add_api_route( path="/queryables", endpoint=self.get_product_queryables, - name=f"{self.root_router.name}:{self.product.id}:{GET_QUERYABLES}", + name=f"{self.root_provider.name}:{self.product.id}:{GET_QUERYABLES}", methods=["GET"], summary="Get queryables for the product", tags=["Products"], @@ -130,7 +132,7 @@ def __init__( # noqa self.add_api_route( path="/order-parameters", endpoint=self.get_product_order_parameters, - name=f"{self.root_router.name}:{self.product.id}:{GET_ORDER_PARAMETERS}", + name=f"{self.root_provider.name}:{self.product.id}:{GET_ORDER_PARAMETERS}", methods=["GET"], summary="Get order parameters for the product", tags=["Products"], @@ -157,7 +159,7 @@ async def _create_order( self.add_api_route( path="/orders", endpoint=_create_order, - name=f"{self.root_router.name}:{self.product.id}:{CREATE_ORDER}", + name=f"{self.root_provider.name}:{self.product.id}:{CREATE_ORDER}", methods=["POST"], response_class=GeoJSONResponse, status_code=status.HTTP_201_CREATED, @@ -166,12 +168,13 @@ async def _create_order( ) if product.supports_opportunity_search or ( - self.product.supports_async_opportunity_search and self.root_router.supports_async_opportunity_search + self.product.supports_async_opportunity_search + and self.conformances_support.supports_async_opportunity_search ): self.add_api_route( path="/opportunities", endpoint=self.search_opportunities, - name=f"{self.root_router.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}", + name=f"{self.root_provider.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}", methods=["POST"], response_class=GeoJSONResponse, # unknown why mypy can't see the queryables property on Product, ignoring @@ -189,11 +192,11 @@ async def _create_order( tags=["Products"], ) - if product.supports_async_opportunity_search and root_router.supports_async_opportunity_search: + if product.supports_async_opportunity_search and self.conformances_support.supports_async_opportunity_search: self.add_api_route( path="/opportunities/{opportunity_collection_id}", endpoint=self.get_opportunity_collection, - name=f"{self.root_router.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}", + name=f"{self.root_provider.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}", methods=["GET"], response_class=GeoJSONResponse, summary="Get an Opportunity Collection by ID", @@ -202,17 +205,20 @@ async def _create_order( def get_product(self, request: Request) -> ProductPydantic: links = [ - json_link("self", self.url_for(request, f"{self.root_router.name}:{self.product.id}:{GET_PRODUCT}")), - json_link("conformance", self.url_for(request, f"{self.root_router.name}:{self.product.id}:{CONFORMANCE}")), + json_link("self", self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{GET_PRODUCT}")), + json_link( + "conformance", self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{CONFORMANCE}") + ), json_link( - "queryables", self.url_for(request, f"{self.root_router.name}:{self.product.id}:{GET_QUERYABLES}") + "queryables", + self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{GET_QUERYABLES}"), ), json_link( "order-parameters", - self.url_for(request, f"{self.root_router.name}:{self.product.id}:{GET_ORDER_PARAMETERS}"), + self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{GET_ORDER_PARAMETERS}"), ), Link( - href=self.url_for(request, f"{self.root_router.name}:{self.product.id}:{CREATE_ORDER}"), + href=self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{CREATE_ORDER}"), rel="create-order", type=TYPE_JSON, method="POST", @@ -220,12 +226,13 @@ def get_product(self, request: Request) -> ProductPydantic: ] if self.product.supports_opportunity_search or ( - self.product.supports_async_opportunity_search and self.root_router.supports_async_opportunity_search + self.product.supports_async_opportunity_search + and self.conformances_support.supports_async_opportunity_search ): links.append( json_link( "opportunities", - self.url_for(request, f"{self.root_router.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}"), + self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{SEARCH_OPPORTUNITIES}"), ), ) @@ -243,7 +250,8 @@ async def search_opportunities( """ # sync if not ( - self.root_router.supports_async_opportunity_search and self.product.supports_async_opportunity_search + self.product.supports_async_opportunity_search + and self.conformances_support.supports_async_opportunity_search ) or (prefer is Prefer.wait and self.product.supports_opportunity_search): return await self.search_opportunities_sync( search, @@ -298,7 +306,7 @@ async def search_opportunities_sync( case x: raise AssertionError(f"Expected code to be unreachable {x}") - if prefer is Prefer.wait and self.root_router.supports_async_opportunity_search: + if prefer is Prefer.wait and self.conformances_support.supports_async_opportunity_search: response.headers["Preference-Applied"] = "wait" return OpportunityCollection(features=features, links=links) @@ -311,10 +319,12 @@ async def search_opportunities_async( ) -> JSONResponse: match await self.product.search_opportunities_async(self, search, request): case Success(search_record): - search_record.links.append(self.root_router.opportunity_search_record_self_link(search_record, request)) + search_record.links.append( + self.root_provider.opportunity_search_record_self_link(search_record, request) + ) headers = {} headers["Location"] = str( - self.root_router.generate_opportunity_search_record_href(request, search_record.id) + self.root_provider.generate_opportunity_search_record_href(request, search_record.id) ) if prefer is not None: headers["Preference-Applied"] = "respond-async" @@ -365,8 +375,8 @@ async def create_order(self, payload: OrderPayload, request: Request, response: request, ): case Success(order): - order.links.extend(self.root_router.order_links(order, request)) - location = str(self.root_router.generate_order_href(request, order.id)) + order.links.extend(self.root_provider.order_links(order, request)) + location = str(self.root_provider.generate_order_href(request, order.id)) response.headers["Location"] = location return order # type: ignore case Failure(e) if isinstance(e, QueryablesError): @@ -385,7 +395,7 @@ async def create_order(self, payload: OrderPayload, request: Request, response: def order_link(self, request: Request, opp_req: OpportunityPayload) -> Link: return Link( - href=self.url_for(request, f"{self.root_router.name}:{self.product.id}:{CREATE_ORDER}"), + href=self.url_for(request, f"{self.root_provider.name}:{self.product.id}:{CREATE_ORDER}"), rel="create-order", type=TYPE_JSON, method="POST", @@ -420,7 +430,7 @@ async def get_opportunity_collection( "self", self.url_for( request, - f"{self.root_router.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}", + f"{self.root_provider.name}:{self.product.id}:{GET_OPPORTUNITY_COLLECTION}", opportunity_collection_id=opportunity_collection_id, ), ), diff --git a/stapi-fastapi/src/stapi_fastapi/routers/root_router.py b/stapi-fastapi/src/stapi_fastapi/routers/root_router.py index c33abc1..12a1639 100644 --- a/stapi-fastapi/src/stapi_fastapi/routers/root_router.py +++ b/stapi-fastapi/src/stapi_fastapi/routers/root_router.py @@ -1,6 +1,7 @@ import logging import traceback -from typing import Any +from abc import abstractmethod +from typing import Any, Generic, Protocol from fastapi import HTTPException, Request, status from fastapi.datastructures import URL @@ -15,6 +16,7 @@ Order, OrderCollection, OrderStatus, + OrderStatusBound, OrderStatuses, ProductsCollection, RootResponse, @@ -51,12 +53,38 @@ logger = logging.getLogger(__name__) -class RootRouter(StapiFastapiBaseRouter): +class ConformancesSupport(Protocol): + @property + @abstractmethod + def supports_async_opportunity_search(self) -> bool: ... + + +class RootProvider(ConformancesSupport): + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def opportunity_search_record_self_link( + self, opportunity_search_record: OpportunitySearchRecord, request: Request + ) -> Link: ... + + @abstractmethod + def generate_opportunity_search_record_href(self, request: Request, search_record_id: str) -> URL: ... + + @abstractmethod + def order_links(self, order: Order[Any], request: Request) -> list[Link]: ... + + @abstractmethod + def generate_order_href(self, request: Request, order_id: str) -> URL: ... + + +class RootRouter(StapiFastapiBaseRouter, RootProvider, Generic[OrderStatusBound]): def __init__( self, - get_orders: GetOrders, - get_order: GetOrder, - get_order_statuses: GetOrderStatuses | None = None, # type: ignore + get_orders: GetOrders[OrderStatusBound], + get_order: GetOrder[OrderStatusBound], + get_order_statuses: GetOrderStatuses[OrderStatusBound] | None = None, get_opportunity_search_records: GetOpportunitySearchRecords | None = None, get_opportunity_search_record: GetOpportunitySearchRecord | None = None, get_opportunity_search_record_statuses: GetOpportunitySearchRecordStatuses | None = None, @@ -77,7 +105,7 @@ def __init__( self.__get_opportunity_search_records = get_opportunity_search_records self.__get_opportunity_search_record = get_opportunity_search_record self.__get_opportunity_search_record_statuses = get_opportunity_search_record_statuses - self.name = name + self._name = name self.openapi_endpoint_name = openapi_endpoint_name self.docs_endpoint_name = docs_endpoint_name self.product_ids: list[str] = [] @@ -173,6 +201,10 @@ def __init__( self.conformances = list(_conformances) + @property + def name(self) -> str: + return self._name + def get_root(self, request: Request) -> RootResponse: links = [ json_link( @@ -240,7 +272,7 @@ def get_products(self, request: Request, next: str | None = None, limit: int = 1 async def get_orders( # noqa: C901 self, request: Request, next: str | None = None, limit: int = 10 - ) -> OrderCollection[OrderStatus]: + ) -> OrderCollection[OrderStatusBound]: links: list[Link] = [] orders_count: int | None = None match await self._get_orders(next, limit, request): @@ -271,7 +303,7 @@ async def get_orders( # noqa: C901 case _: raise AssertionError("Expected code to be unreachable") - return OrderCollection( + return OrderCollection[OrderStatusBound]( features=orders, links=links, number_matched=orders_count, @@ -306,7 +338,7 @@ async def get_order_statuses( request: Request, next: str | None = None, limit: int = 10, - ) -> OrderStatuses: # type: ignore + ) -> OrderStatuses[OrderStatusBound]: links: list[Link] = [] match await self._get_order_statuses(order_id, next, limit, request): case Success(Some((statuses, maybe_pagination_token))): @@ -350,7 +382,7 @@ def generate_order_href(self, request: Request, order_id: str) -> URL: def generate_order_statuses_href(self, request: Request, order_id: str) -> URL: return self.url_for(request, f"{self.name}:{LIST_ORDER_STATUSES}", order_id=order_id) - def order_links(self, order: Order[OrderStatus], request: Request) -> list[Link]: + def order_links(self, order: Order[Any], request: Request) -> list[Link]: return [ Link( href=self.generate_order_href(request, order.id), @@ -464,7 +496,7 @@ def opportunity_search_record_self_link( return json_link("self", self.generate_opportunity_search_record_href(request, opportunity_search_record.id)) @property - def _get_order_statuses(self) -> GetOrderStatuses: # type: ignore + def _get_order_statuses(self) -> GetOrderStatuses[OrderStatusBound]: if not self.__get_order_statuses: raise AttributeError("Root router does not support order status history") return self.__get_order_statuses diff --git a/stapi-pydantic/src/stapi_pydantic/__init__.py b/stapi-pydantic/src/stapi_pydantic/__init__.py index 44ecbd0..4fd8e17 100644 --- a/stapi-pydantic/src/stapi_pydantic/__init__.py +++ b/stapi-pydantic/src/stapi_pydantic/__init__.py @@ -21,6 +21,7 @@ OrderProperties, OrderSearchParameters, OrderStatus, + OrderStatusBound, OrderStatusCode, OrderStatuses, ) @@ -52,6 +53,7 @@ "OrderStatus", "OrderStatusCode", "OrderStatuses", + "OrderStatusBound", "Prefer", "Product", "ProductsCollection", diff --git a/stapi-pydantic/src/stapi_pydantic/order.py b/stapi-pydantic/src/stapi_pydantic/order.py index 159b341..94f7eae 100644 --- a/stapi-pydantic/src/stapi_pydantic/order.py +++ b/stapi-pydantic/src/stapi_pydantic/order.py @@ -72,11 +72,11 @@ def new( ) -T = TypeVar("T", bound=OrderStatus) +OrderStatusBound = TypeVar("OrderStatusBound", bound=OrderStatus) -class OrderStatuses(BaseModel, Generic[T]): - statuses: list[T] +class OrderStatuses(BaseModel, Generic[OrderStatusBound]): + statuses: list[OrderStatusBound] links: list[Link] = Field(default_factory=list) @@ -87,10 +87,10 @@ class OrderSearchParameters(BaseModel): filter: CQL2Filter | None = None # type: ignore [type-arg] -class OrderProperties(BaseModel, Generic[T]): +class OrderProperties(BaseModel, Generic[OrderStatusBound]): product_id: str created: AwareDatetime - status: T + status: OrderStatusBound search_parameters: OrderSearchParameters opportunity_properties: dict[str, Any] @@ -100,7 +100,7 @@ class OrderProperties(BaseModel, Generic[T]): # derived from geojson_pydantic.Feature -class Order(_GeoJsonBase, Generic[T]): +class Order(_GeoJsonBase, Generic[OrderStatusBound]): # We need to enforce that orders have an id defined, as that is required to # retrieve them via the API id: StrictStr @@ -109,7 +109,7 @@ class Order(_GeoJsonBase, Generic[T]): stapi_version: str = STAPI_VERSION geometry: Geometry = Field(...) - properties: OrderProperties[T] = Field(...) + properties: OrderProperties[OrderStatusBound] = Field(...) links: list[Link] = Field(default_factory=list) @@ -125,15 +125,15 @@ def set_geometry(cls, geometry: Any) -> Any: # derived from geojson_pydantic.FeatureCollection -class OrderCollection(_GeoJsonBase, Generic[T]): +class OrderCollection(_GeoJsonBase, Generic[OrderStatusBound]): type: Literal["FeatureCollection"] = "FeatureCollection" - features: list[Order[T]] + features: list[Order[OrderStatusBound]] links: list[Link] = Field(default_factory=list) number_matched: int | None = Field( serialization_alias="numberMatched", default=None, exclude_if=lambda x: x is None ) - def __iter__(self) -> Iterator[Order[T]]: # type: ignore [override] + def __iter__(self) -> Iterator[Order[OrderStatusBound]]: # type: ignore [override] """iterate over features""" return iter(self.features) @@ -141,7 +141,7 @@ def __len__(self) -> int: """return features length""" return len(self.features) - def __getitem__(self, index: int) -> Order[T]: + def __getitem__(self, index: int) -> Order[OrderStatusBound]: """get feature at a given index""" return self.features[index] diff --git a/stapi-pydantic/tests/test_json_schema.py b/stapi-pydantic/tests/test_json_schema.py index e9bc38e..a787345 100644 --- a/stapi-pydantic/tests/test_json_schema.py +++ b/stapi-pydantic/tests/test_json_schema.py @@ -1,5 +1,5 @@ from pydantic import TypeAdapter -from stapi_pydantic.datetime_interval import DatetimeInterval +from stapi_pydantic import DatetimeInterval def test_datetime_interval() -> None: