Skip to content

Commit 374e71f

Browse files
feat: Add run_id parameter to Tesseract.apply and friends (#352)
<!-- Please use a PR title that conforms to *conventional commits*: "<commit_type>: Describe your change"; for example: "fix: prevent race condition". Some other commit types are: fix, feat, ci, doc, refactor... For a full list of commit types visit https://www.conventionalcommits.org/en/v1.0.0/ --> #### Relevant issue or PR <!-- If the changes resolve an issue or follow some other PR, link to them here. Only link something if it is directly relevant. --> #### Description of changes Adds an optional run_id parameter to `Tesseract.apply`, jacobian, vjp and jvp methods. If passed, this can be used to correlate the call with the outputs under `output_path/run_{run_id}`. #### Testing done <!-- Describe how the changes were tested; e.g., "CI passes", "Tested manually in stagingrepo#123", screenshots of a terminal session that verify the changes, or any other evidence of testing the changes. -->
1 parent fbac668 commit 374e71f

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

tesseract_core/sdk/tesseract.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -310,17 +310,19 @@ def available_endpoints(self) -> list[str]:
310310
return [endpoint.lstrip("/") for endpoint in self.openapi_schema["paths"]]
311311

312312
@requires_client
313-
def apply(self, inputs: dict) -> dict:
313+
def apply(self, inputs: dict, run_id: str | None = None) -> dict:
314314
"""Run apply endpoint.
315315
316316
Args:
317317
inputs: a dictionary with the inputs.
318+
run_id: a string to identify the run. Run outputs will be located
319+
in a directory suffixed with this id.
318320
319321
Returns:
320322
dictionary with the results.
321323
"""
322324
payload = {"inputs": inputs}
323-
return self._client.run_tesseract("apply", payload)
325+
return self._client.run_tesseract("apply", payload, run_id)
324326

325327
@requires_client
326328
def abstract_eval(self, abstract_inputs: dict) -> dict:
@@ -346,14 +348,20 @@ def health(self) -> dict:
346348

347349
@requires_client
348350
def jacobian(
349-
self, inputs: dict, jac_inputs: list[str], jac_outputs: list[str]
351+
self,
352+
inputs: dict,
353+
jac_inputs: list[str],
354+
jac_outputs: list[str],
355+
run_id: str | None = None,
350356
) -> dict:
351357
"""Calculate the Jacobian of (some of the) outputs w.r.t. (some of the) inputs.
352358
353359
Args:
354360
inputs: a dictionary with the inputs.
355361
jac_inputs: Inputs with respect to which derivatives will be calculated.
356362
jac_outputs: Outputs which will be differentiated.
363+
run_id: a string to identify the run. Run outputs will be located
364+
in a directory suffixed with this id.
357365
358366
Returns:
359367
dictionary with the results.
@@ -366,7 +374,7 @@ def jacobian(
366374
"jac_inputs": jac_inputs,
367375
"jac_outputs": jac_outputs,
368376
}
369-
return self._client.run_tesseract("jacobian", payload)
377+
return self._client.run_tesseract("jacobian", payload, run_id)
370378

371379
@requires_client
372380
def jacobian_vector_product(
@@ -375,6 +383,7 @@ def jacobian_vector_product(
375383
jvp_inputs: list[str],
376384
jvp_outputs: list[str],
377385
tangent_vector: dict,
386+
run_id: str | None = None,
378387
) -> dict:
379388
"""Calculate the Jacobian Vector Product (JVP) of (some of the) outputs w.r.t. (some of the) inputs.
380389
@@ -383,6 +392,8 @@ def jacobian_vector_product(
383392
jvp_inputs: Inputs with respect to which derivatives will be calculated.
384393
jvp_outputs: Outputs which will be differentiated.
385394
tangent_vector: Element of the tangent space to multiply with the Jacobian.
395+
run_id: a string to identify the run. Run outputs will be located
396+
in a directory suffixed with this id.
386397
387398
Returns:
388399
dictionary with the results.
@@ -398,7 +409,7 @@ def jacobian_vector_product(
398409
"jvp_outputs": jvp_outputs,
399410
"tangent_vector": tangent_vector,
400411
}
401-
return self._client.run_tesseract("jacobian_vector_product", payload)
412+
return self._client.run_tesseract("jacobian_vector_product", payload, run_id)
402413

403414
@requires_client
404415
def vector_jacobian_product(
@@ -407,6 +418,7 @@ def vector_jacobian_product(
407418
vjp_inputs: list[str],
408419
vjp_outputs: list[str],
409420
cotangent_vector: dict,
421+
run_id: str | None = None,
410422
) -> dict:
411423
"""Calculate the Vector Jacobian Product (VJP) of (some of the) outputs w.r.t. (some of the) inputs.
412424
@@ -415,6 +427,8 @@ def vector_jacobian_product(
415427
vjp_inputs: Inputs with respect to which derivatives will be calculated.
416428
vjp_outputs: Outputs which will be differentiated.
417429
cotangent_vector: Element of the cotangent space to multiply with the Jacobian.
430+
run_id: a string to identify the run. Run outputs will be located
431+
in a directory suffixed with this id.
418432
419433
420434
Returns:
@@ -431,7 +445,7 @@ def vector_jacobian_product(
431445
"vjp_outputs": vjp_outputs,
432446
"cotangent_vector": cotangent_vector,
433447
}
434-
return self._client.run_tesseract("vector_jacobian_product", payload)
448+
return self._client.run_tesseract("vector_jacobian_product", payload, run_id)
435449

436450

437451
def _tree_map(func: Callable, tree: Any, is_leaf: Callable | None = None) -> Any:
@@ -506,7 +520,11 @@ def url(self) -> str:
506520
return self._url
507521

508522
def _request(
509-
self, endpoint: str, method: str = "GET", payload: dict | None = None
523+
self,
524+
endpoint: str,
525+
method: str = "GET",
526+
payload: dict | None = None,
527+
run_id: str | None = None,
510528
) -> dict:
511529
url = f"{self.url}/{endpoint.lstrip('/')}"
512530

@@ -517,7 +535,10 @@ def _request(
517535
else:
518536
encoded_payload = None
519537

520-
response = requests.request(method=method, url=url, json=encoded_payload)
538+
params = {"run_id": run_id} if run_id is not None else {}
539+
response = requests.request(
540+
method=method, url=url, json=encoded_payload, params=params
541+
)
521542

522543
if response.status_code == requests.codes.unprocessable_entity:
523544
# Try and raise a more helpful error if the response is a Pydantic error
@@ -570,12 +591,16 @@ def _request(
570591

571592
return data
572593

573-
def run_tesseract(self, endpoint: str, payload: dict | None = None) -> dict:
594+
def run_tesseract(
595+
self, endpoint: str, payload: dict | None = None, run_id: str | None = None
596+
) -> dict:
574597
"""Run a Tesseract endpoint.
575598
576599
Args:
577600
endpoint: The endpoint to run.
578601
payload: The payload to send to the endpoint.
602+
run_id: a string to identify the run. Run outputs will be located
603+
in a directory suffixed with this id.
579604
580605
Returns:
581606
The loaded JSON response from the endpoint, with decoded arrays.
@@ -591,7 +616,7 @@ def run_tesseract(self, endpoint: str, payload: dict | None = None) -> dict:
591616
if endpoint == "openapi_schema":
592617
endpoint = "openapi.json"
593618

594-
return self._request(endpoint, method, payload)
619+
return self._request(endpoint, method, payload, run_id)
595620

596621

597622
class LocalClient:
@@ -606,12 +631,15 @@ def __init__(self, tesseract_api: ModuleType) -> None:
606631
}
607632
self._openapi_schema = create_rest_api(tesseract_api).openapi()
608633

609-
def run_tesseract(self, endpoint: str, payload: dict | None = None) -> dict:
634+
def run_tesseract(
635+
self, endpoint: str, payload: dict | None = None, run_id: str | None = None
636+
) -> dict:
610637
"""Run a Tesseract endpoint.
611638
612639
Args:
613640
endpoint: The endpoint to run.
614641
payload: The payload to send to the endpoint.
642+
run_id: a string to identify the run.
615643
616644
Returns:
617645
The loaded JSON response from the endpoint, with decoded arrays.

tests/sdk_tests/test_tesseract.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,11 @@ def test_serve_lifecycle(mock_serving, mock_clients):
130130
pass
131131

132132

133-
def test_HTTPClient_run_tesseract(mocker):
133+
@pytest.mark.parametrize(
134+
"run_id",
135+
[None, "fizzbuzz"],
136+
)
137+
def test_HTTPClient_run_tesseract(mocker, run_id):
134138
mock_response = mocker.Mock()
135139
mock_response.json.return_value = {"result": [4, 4, 4]}
136140
mock_response.raise_for_status = mocker.Mock()
@@ -142,13 +146,15 @@ def test_HTTPClient_run_tesseract(mocker):
142146

143147
client = HTTPClient("somehost")
144148

145-
out = client.run_tesseract("apply", {"inputs": {"a": 1}})
149+
out = client.run_tesseract("apply", {"inputs": {"a": 1}}, run_id=run_id)
146150

147151
assert out == {"result": [4, 4, 4]}
152+
expected_params = {} if run_id is None else {"run_id": run_id}
148153
mocked_request.assert_called_with(
149154
method="POST",
150155
url="http://somehost/apply",
151156
json={"inputs": {"a": 1}},
157+
params=expected_params,
152158
)
153159

154160

0 commit comments

Comments
 (0)