Skip to content

Commit 2f2b55b

Browse files
authored
feat: Add precheck endpoint (#15)
1 parent fa0a418 commit 2f2b55b

File tree

3 files changed

+68
-5
lines changed

3 files changed

+68
-5
lines changed

unstructured_platform_plugins/etl_uvicorn/api_generator.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,25 @@ async def invoke_func(func: Callable, kwargs: Optional[dict[str, Any]] = None) -
3434
return func(**kwargs)
3535

3636

37+
def check_precheck_func(precheck_func: Callable):
38+
sig = inspect.signature(precheck_func)
39+
inputs = sig.parameters.values()
40+
outputs = sig.return_annotation
41+
if len(inputs) == 1:
42+
i = inputs[0]
43+
if i.name != "usage" or i.annotation is list:
44+
raise ValueError("the only input available for precheck is usage which must be a list")
45+
if outputs not in [None, sig.empty]:
46+
raise ValueError(f"no output should exist for precheck function, found: {outputs}")
47+
48+
3749
def generate_fast_api(
3850
app: str,
3951
method_name: Optional[str] = None,
4052
id_str: Optional[str] = None,
4153
id_method: Optional[str] = None,
54+
precheck_str: Optional[str] = None,
55+
precheck_method: Optional[str] = None,
4256
) -> FastAPI:
4357
instance = import_from_string(app)
4458
func = get_func(instance, method_name)
@@ -49,6 +63,16 @@ def generate_fast_api(
4963
plugin_id = hashlib.sha256(
5064
json.dumps(get_schema_dict(func), sort_keys=True).encode()
5165
).hexdigest()[:32]
66+
67+
precheck_func = None
68+
if precheck_str:
69+
precheck_instance = import_from_string(precheck_str)
70+
precheck_func = get_func(precheck_instance, precheck_method)
71+
elif precheck_method:
72+
precheck_func = get_func(instance, precheck_method)
73+
if precheck_func is not None:
74+
check_precheck_func(precheck_func=precheck_func)
75+
5276
logger.debug(f"set static id response to: {plugin_id}")
5377

5478
fastapi_app = FastAPI()
@@ -66,9 +90,8 @@ class InvokeResponse(BaseModel):
6690

6791
logging.getLogger("etl_uvicorn.fastapi")
6892

69-
usage: list[UsageData] = []
70-
7193
async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> InvokeResponse:
94+
usage: list[UsageData] = []
7295
request_dict = kwargs if kwargs else {}
7396
if "usage" in inspect.signature(func).parameters:
7497
request_dict["usage"] = usage
@@ -114,12 +137,29 @@ class SchemaOutputResponse(BaseModel):
114137
async def docs_redirect():
115138
return RedirectResponse("/docs")
116139

140+
class InvokePrecheckResponse(BaseModel):
141+
usage: list[UsageData]
142+
status_code: int
143+
status_code_text: Optional[str] = None
144+
117145
@fastapi_app.get("/schema")
118146
async def get_schema() -> SchemaOutputResponse:
119147
schema = get_schema_dict(func)
120148
resp = SchemaOutputResponse(inputs=schema["inputs"], outputs=schema["outputs"])
121149
return resp
122150

151+
@fastapi_app.get("/precheck")
152+
async def run_precheck() -> InvokePrecheckResponse:
153+
if precheck_func:
154+
fn_response = await wrap_fn(func=precheck_func)
155+
return InvokePrecheckResponse(
156+
status_code=fn_response.status_code,
157+
status_code_text=fn_response.status_code_text,
158+
usage=fn_response.usage,
159+
)
160+
else:
161+
return InvokePrecheckResponse(status_code=status.HTTP_200_OK, usage=[])
162+
123163
@fastapi_app.get("/id")
124164
async def get_id() -> str:
125165
return plugin_id

unstructured_platform_plugins/etl_uvicorn/main.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def api_wrapper(
3333
method_name: Optional[str] = None,
3434
plugin_id: Optional[str] = None,
3535
plugin_id_method: Optional[str] = None,
36+
precheck_app: Optional[str] = None,
37+
precheck_app_method: Optional[str] = None,
3638
**kwargs,
3739
):
3840
# Make sure logging is configured before the call to run() so any setup has the same format
@@ -44,7 +46,12 @@ def api_wrapper(
4446
)
4547
config.configure_logging()
4648
fastapi_app = generate_fast_api(
47-
app=app, method_name=method_name, id_str=plugin_id, id_method=plugin_id_method
49+
app=app,
50+
method_name=method_name,
51+
id_str=plugin_id,
52+
id_method=plugin_id_method,
53+
precheck_str=precheck_app,
54+
precheck_method=precheck_app_method,
4855
)
4956
# Explicitly map values that are manipulated in the original
5057
# call to run(), preventing **kwargs reference
@@ -86,6 +93,22 @@ def api_wrapper(
8693
help="If plugin id reference is a class, what method to wrap. "
8794
"Will fall back to __call__ if none is provided.",
8895
),
96+
click.Option(
97+
["--precheck-app"],
98+
required=False,
99+
type=str,
100+
default=None,
101+
help="If provided, must point to code to run for precheck",
102+
),
103+
click.Option(
104+
["--precheck-app-method"],
105+
required=False,
106+
type=str,
107+
default=None,
108+
help="If provided, points to a method to call on a class. "
109+
"If precheck-app not provided, assumes method "
110+
"lives on main class passes in.",
111+
),
89112
]
90113
)
91114
return cmd

unstructured_platform_plugins/etl_uvicorn/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def get_input_schema(func: Callable, omit: Optional[list[str]] = None) -> dict:
6464
def get_output_sig(func: Callable) -> Optional[Any]:
6565
inspect.signature(func)
6666
type_hints = get_type_hints(func)
67-
return_typing = type_hints["return"]
67+
return_typing = type_hints.get("return")
6868
outputs = return_typing if return_typing is not NoneType else None
6969
return outputs
7070

@@ -85,7 +85,7 @@ def map_inputs(func: Callable, raw_inputs: dict[str, Any]) -> dict[str, Any]:
8585
# types expected by the function when being invoked
8686
raw_inputs = raw_inputs.copy()
8787
type_info = get_type_hints(func)
88-
type_info.pop("return")
88+
type_info.pop("return", None)
8989
for field_name, type_data in type_info.items():
9090
if field_name not in raw_inputs:
9191
continue

0 commit comments

Comments
 (0)