Skip to content

Commit d18ee0c

Browse files
doc: Updated comments in jax recipe and docs on Differentiable flag (#65)
#### Description of changes Explained how to specify static numerical values in Jax recipe. Added clarifications to docs warning on Differentiable. #### Testing done N/A #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: [Jonathan Brodrick] <jonathan.brodrick@simulation.science> --------- Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
1 parent 3c82729 commit d18ee0c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

docs/content/creating-tesseracts/create.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,10 @@ class InputSchema(BaseModel):
212212

213213
A key feature of Tesseracts is their ability to expose endpoints for calculating various kinds of derivatives when the operation they implement is differentiable, which in turn makes it possible to combine multiple Tesseracts into automatically differentiable workflows! This is advantageous in multiple contexts: shape optimization, model calibration, and so on.
214214

215-
In order to make your Tesseract differentiable with respect to some parameter, you need
216-
to mark that parameter with the {py:class}`tesseract_core.runtime.Differentiable` type annotation.
217-
The outputs that can be differentiated should be marked as `Differentiable` as well.
218-
219-
In other words, all outputs marked as `Differentiable` will be considered differentiable with respect to
215+
Keeping with one of Tesseract's key foci being *validation*, the type annotation {py:class}`tesseract_core.runtime.Differentiable` is introduced to mark outputs that can be differentiated, and inputs that can be differentiated with respect to.
216+
All outputs marked as `Differentiable` will be considered differentiable with respect to
220217
all inputs marked as `Differentiable`.
218+
Attempting to differentiate (with respect to) an output/input (e.g. by passing `jac_inputs=["non_differentiable_arg"]` to the `jacobian` endpoint) will raise a validation error even before the endpoint is invoked.
221219

222220
For example:
223221

@@ -239,7 +237,7 @@ Here, it will be possible in principle to differentiate `a` in the Tesseract's o
239237
parameter `x` and with respect to each of the components of the matrix `r` -- but not with respect to `s`.
240238

241239
```{warning}
242-
Differentiable can only be used on {py:class}`tesseract_core.runtime.Array` types, which includes aliases for
240+
`Differentiable` can only be used on {py:class}`tesseract_core.runtime.Array` types, which includes aliases for
243241
rank 0 tensors like {py:class}`Float64 <tesseract_core.runtime.Float64>`. Do not use it on
244242
Python base types -- things like `Differentiable[float]` will trigger errors.
245243
```

tesseract_core/sdk/templates/jax/tesseract_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
#
1919

2020

21+
# Note: This template uses equinox filter_jit to automatically treat non-array
22+
# inputs/outputs as static. As Tesseract scalar objects (e.g. Float32) are
23+
# essentially just wrappers around numpy 0D arrays, they will be considered to
24+
# be dynamic and will be traced by JAX.
25+
# If you want to treat numerical values as scalar you will need to use
26+
# built-in Python types (e.g. float, int) instead of Float32.
27+
28+
2129
class InputSchema(BaseModel):
2230
example: Differentiable[Float32]
2331

0 commit comments

Comments
 (0)