Skip to content

Commit b587f4d

Browse files
dionhaefnerxalelax
andauthored
feat: allow creating tesseract objects from python modules (#122)
#### Relevant issue or PR n/a #### Description of changes - Introduce `Tesseract.from_tesseract_api` to create Tesseract objects directly from importable `tesseract_api.py` files. (This is mainly useful for debugging.) - Add `from_url` classmethod (same as `Tesseract.__init__`) for symmetry. - Make `Tesseract.serve` and `Tesseract.teardown` public to give users the option to avoid excessive context managers. #### Testing done CI, old and new #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: Dion Häfner <dion.haefner@simulation.science> --------- Co-authored-by: Alessandro Angioi <alessandro.angioi@simulation.science>
1 parent 578f929 commit b587f4d

File tree

11 files changed

+331
-81
lines changed

11 files changed

+331
-81
lines changed

inject_runtime_pyproject.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@
4949

5050

5151
class RuntimeDepenencyHook(MetadataHookInterface):
52+
"""Injects runtime dependencies and scripts from a separate pyproject.toml file."""
53+
5254
PLUGIN_NAME = "runtime-deps"
5355

54-
def update(self, metadata):
56+
def update(self, metadata: dict) -> dict:
57+
"""Update the metadata with runtime dependencies and scripts."""
5558
runtime_metadata = toml.load(Path(self.root) / RUNTIME_PYPROJECT_PATH)
5659
metadata["optional-dependencies"] = {
5760
**BASE_OPTIONAL_DEPS,

ruff.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ ignore = [
4646

4747
[lint.extend-per-file-ignores]
4848
# Ignore missing docstrings and type annotations for selected directories
49-
"./*.py" = ["D101", "D102", "D103", "D106", "ANN"]
49+
"docs/*" = ["D101", "D102", "D103", "D106", "ANN"]
5050
"tests/*" = ["D101", "D102", "D103", "D106", "ANN"]
5151
"examples/**/*" = ["D101", "D102", "D103", "D106", "ANN"]
5252
"tesseract_core/sdk/templates/*" = ["D101", "D102", "D103", "D106", "ANN"]

tesseract_core/runtime/cli.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,16 @@ def check() -> None:
230230
help="Show progress bar.",
231231
)
232232
def check_gradients(
233-
payload,
234-
input_paths,
235-
output_paths,
236-
endpoints,
237-
eps,
238-
rtol,
239-
max_evals,
240-
max_failures,
241-
seed,
242-
show_progress,
233+
payload: tuple[dict[str, Any], Optional[Path]],
234+
input_paths: list[str],
235+
output_paths: list[str],
236+
endpoints: list[str],
237+
eps: float,
238+
rtol: float,
239+
max_evals: int,
240+
max_failures: int,
241+
seed: Optional[int],
242+
show_progress: bool,
243243
) -> None:
244244
"""Check gradients of endpoints against a finite difference approximation.
245245

tesseract_core/runtime/finite_differences.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import traceback
5-
from collections.abc import Sequence
5+
from collections.abc import Iterator, Sequence
66
from functools import wraps
77
from pathlib import Path
88
from types import ModuleType
@@ -29,6 +29,17 @@
2929

3030

3131
class GradientCheckResult(NamedTuple):
32+
"""Result of a gradient check (Jacobian row).
33+
34+
Attributes:
35+
in_path: The input path of the gradient check.
36+
out_path: The output path of the gradient check.
37+
idx: The row index of the gradient check.
38+
grad_val: The value of the gradient at the given index.
39+
ref_val: The value of the reference gradient at the given index.
40+
exception: The exception raised during the gradient check, if any.
41+
"""
42+
3243
in_path: str
3344
out_path: str
3445
idx: tuple[int, ...]
@@ -150,7 +161,7 @@ def _jacobian_via_apply(
150161
apply_fn = endpoints_func["apply"]
151162
ApplySchema = get_input_schema(apply_fn)
152163

153-
def _perturbed_apply(inputs, eps):
164+
def _perturbed_apply(inputs: dict[str, Any], eps: float) -> dict[str, Any]:
154165
input_val = get_at_path(inputs, input_path).copy()
155166
if input_idx:
156167
# array
@@ -183,7 +194,7 @@ def _jacobian_via_jacobian(
183194
"""Compute a Jacobian row using the jacobian endpoint."""
184195
jac_fn = endpoints_func["jacobian"]
185196

186-
def _jacobian(inputs):
197+
def _jacobian(inputs: dict[str, Any]) -> dict[str, Any]:
187198
JacSchema = get_input_schema(jac_fn)
188199
return jac_fn(
189200
JacSchema.model_validate(
@@ -406,7 +417,7 @@ def check_gradients(
406417
rtol: float = 0.1,
407418
seed: Optional[int] = None,
408419
show_progress: bool = True,
409-
):
420+
) -> Iterator[tuple[str, list[GradientCheckResult], int]]:
410421
"""Check gradients of endpoints against a finite difference approximation.
411422
412423
Args:

tesseract_core/runtime/schema_generation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def _is_regex_pattern(pattern: str) -> bool:
411411

412412

413413
def input_path_validator(path: str, info: ValidationInfo) -> str:
414+
"""Validate that the given path points to a valid input key."""
414415
if "[" in path or "{" in path:
415416
try:
416417
get_at_path(info.data["inputs"], path)

tesseract_core/runtime/schema_types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
EllipsisType = type(Ellipsis)
3737

3838

39-
def _ensure_valid_shapedtype(expected_shape, expected_dtype) -> tuple:
39+
def _ensure_valid_shapedtype(expected_shape: Any, expected_dtype: Any) -> tuple:
4040
if not isinstance(expected_shape, (tuple, EllipsisType)):
4141
raise ValueError(
4242
"Shape in Array[<shape>, <dtype>] must be a tuple or '...' (ellipsis)"
@@ -79,7 +79,8 @@ def __class_getitem__(
7979
) -> AnnotatedType:
8080
expected_shape, expected_dtype = _ensure_valid_shapedtype(*key)
8181

82-
def validate(shapedtype):
82+
def validate(shapedtype: ShapeDType) -> ShapeDType:
83+
"""Validator to check if the shape and dtype match the expected values."""
8384
if isinstance(shapedtype, ShapeDType):
8485
shape = shapedtype.shape
8586
if expected_shape is Ellipsis:
@@ -95,6 +96,7 @@ def validate(shapedtype):
9596

9697
@classmethod
9798
def from_array_annotation(cls, obj: AnnotatedType) -> AnnotatedType:
99+
"""Create a ShapeDType from an array annotation."""
98100
shape = obj.__metadata__[0].expected_shape
99101
dtype = obj.__metadata__[0].expected_dtype
100102
return cls[shape, dtype]

tesseract_core/sdk/api_parse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,17 @@ def assert_relative_path(value: str) -> str:
7474

7575

7676
class PipRequirements(BaseModel):
77+
"""Configuration options for Python environments built via pip."""
78+
7779
provider: Literal["python-pip"]
7880
_filename: Literal["tesseract_requirements.txt"] = "tesseract_requirements.txt"
7981
_build_script: Literal["build_pip_venv.sh"] = "build_pip_venv.sh"
8082
model_config: ConfigDict = ConfigDict(extra="forbid")
8183

8284

8385
class CondaRequirements(BaseModel):
86+
"""Configuration options for Python environments built via conda."""
87+
8488
provider: Literal["conda"]
8589
_filename: Literal["tesseract_environment.yaml"] = "tesseract_environment.yaml"
8690
_build_script: Literal["build_conda_venv.sh"] = "build_conda_venv.sh"

tesseract_core/sdk/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def get_runtime_dir() -> Path:
243243
return Path(tesseract_core.__file__).parent / "runtime"
244244

245245

246-
def get_template_dir():
246+
def get_template_dir() -> Path:
247247
"""Get the template directory for the Tesseract runtime."""
248248
import tesseract_core
249249

@@ -345,7 +345,7 @@ def prepare_build_context(
345345
for dependency in remote_dependencies:
346346
f.write(f"{dependency}\n")
347347

348-
def _ignore_pycache(_, names: list[str]) -> list[str]:
348+
def _ignore_pycache(_: Any, names: list[str]) -> list[str]:
349349
ignore = []
350350
if "__pycache__" in names:
351351
ignore.append("__pycache__")

0 commit comments

Comments
 (0)