|
26 | 26 | read_from_path, |
27 | 27 | write_to_path, |
28 | 28 | ) |
| 29 | +from tesseract_core.runtime.finite_differences import ( |
| 30 | + check_gradients as check_gradients_, |
| 31 | +) |
29 | 32 | from tesseract_core.runtime.serve import create_rest_api |
30 | 33 | from tesseract_core.runtime.serve import serve as serve_ |
31 | 34 |
|
@@ -152,6 +155,154 @@ def check() -> None: |
152 | 155 | typer.echo("✅ Tesseract API check successful ✅") |
153 | 156 |
|
154 | 157 |
|
| 158 | +@tesseract_runtime.command() |
| 159 | +@click.argument( |
| 160 | + "payload", |
| 161 | + type=click.STRING, |
| 162 | + required=True, |
| 163 | + metavar="JSON_PAYLOAD", |
| 164 | + callback=_parse_arg_callback, |
| 165 | +) |
| 166 | +@click.option( |
| 167 | + "--endpoints", |
| 168 | + type=click.STRING, |
| 169 | + required=False, |
| 170 | + multiple=True, |
| 171 | + help="Endpoints to check gradients for (default: check all).", |
| 172 | +) |
| 173 | +@click.option( |
| 174 | + "--input-paths", |
| 175 | + type=click.STRING, |
| 176 | + required=False, |
| 177 | + multiple=True, |
| 178 | + help="Paths to differentiable inputs to check gradients for (default: check all).", |
| 179 | +) |
| 180 | +@click.option( |
| 181 | + "--output-paths", |
| 182 | + type=click.STRING, |
| 183 | + required=False, |
| 184 | + multiple=True, |
| 185 | + help="Paths to differentiable outputs to check gradients for (default: check all).", |
| 186 | +) |
| 187 | +@click.option( |
| 188 | + "--eps", |
| 189 | + type=click.FLOAT, |
| 190 | + required=False, |
| 191 | + help="Step size for finite differences.", |
| 192 | + default=1e-4, |
| 193 | + show_default=True, |
| 194 | +) |
| 195 | +@click.option( |
| 196 | + "--rtol", |
| 197 | + type=click.FLOAT, |
| 198 | + required=False, |
| 199 | + help="Relative tolerance when comparing finite differences to gradients.", |
| 200 | + default=0.1, |
| 201 | + show_default=True, |
| 202 | +) |
| 203 | +@click.option( |
| 204 | + "--max-evals", |
| 205 | + type=click.INT, |
| 206 | + required=False, |
| 207 | + help="Maximum number of evaluations per input.", |
| 208 | + default=1000, |
| 209 | + show_default=True, |
| 210 | +) |
| 211 | +@click.option( |
| 212 | + "--max-failures", |
| 213 | + type=click.INT, |
| 214 | + required=False, |
| 215 | + help="Maximum number of failures to report per endpoint.", |
| 216 | + default=10, |
| 217 | + show_default=True, |
| 218 | +) |
| 219 | +@click.option( |
| 220 | + "--seed", |
| 221 | + type=click.INT, |
| 222 | + required=False, |
| 223 | + help="Seed for random number generator. If not set, a random seed is used.", |
| 224 | + default=None, |
| 225 | +) |
| 226 | +@click.option( |
| 227 | + "--show-progress", |
| 228 | + is_flag=True, |
| 229 | + default=True, |
| 230 | + help="Show progress bar.", |
| 231 | +) |
| 232 | +def check_gradients( |
| 233 | + payload, |
| 234 | + input_paths, |
| 235 | + output_paths, |
| 236 | + endpoints, |
| 237 | + eps, |
| 238 | + rtol, |
| 239 | + max_evals, |
| 240 | + max_failures, |
| 241 | + seed, |
| 242 | + show_progress, |
| 243 | +) -> None: |
| 244 | + """Check gradients of endpoints against a finite difference approximation. |
| 245 | +
|
| 246 | + This is an automated way to check the correctness of the gradients of the different AD endpoints |
| 247 | + (jacobian, jacobian_vector_product, vector_jacobian_product) of a ``tesseract_api.py`` module. |
| 248 | + It will sample random indices and compare the gradients computed by the AD endpoints with the |
| 249 | + finite difference approximation. |
| 250 | +
|
| 251 | + Warning: |
| 252 | + Finite differences are not exact and the comparison is done with a tolerance. This means |
| 253 | + that the check may fail even if the gradients are correct, and vice versa. |
| 254 | +
|
| 255 | + Finite difference approximations are sensitive to numerical precision. When finite differences |
| 256 | + are reported incorrectly as 0.0, it is likely that the chosen `eps` is too small, especially for |
| 257 | + inputs that do not use float64 precision. |
| 258 | + """ |
| 259 | + api_module = get_tesseract_api() |
| 260 | + inputs, base_dir = payload |
| 261 | + |
| 262 | + result_iter = check_gradients_( |
| 263 | + api_module, |
| 264 | + inputs, |
| 265 | + base_dir=base_dir, |
| 266 | + input_paths=input_paths, |
| 267 | + output_paths=output_paths, |
| 268 | + endpoints=endpoints, |
| 269 | + max_evals=max_evals, |
| 270 | + eps=eps, |
| 271 | + rtol=rtol, |
| 272 | + seed=seed, |
| 273 | + show_progress=show_progress, |
| 274 | + ) |
| 275 | + |
| 276 | + failed = False |
| 277 | + for endpoint, failures, num_evals in result_iter: |
| 278 | + if not failures: |
| 279 | + typer.echo( |
| 280 | + f"✅ Gradient check for {endpoint} passed ✅ ({len(failures)} failures / {num_evals} checks)" |
| 281 | + ) |
| 282 | + else: |
| 283 | + failed = True |
| 284 | + typer.echo() |
| 285 | + typer.echo( |
| 286 | + f"⚠️ Gradient check for {endpoint} failed ⚠️ ({len(failures)} failures / {num_evals} checks)" |
| 287 | + ) |
| 288 | + printed_failures = min(len(failures), max_failures) |
| 289 | + typer.echo(f"First {printed_failures} failures:") |
| 290 | + for failure in failures[:printed_failures]: |
| 291 | + typer.echo( |
| 292 | + f" Input path: '{failure.in_path}', Output path: '{failure.out_path}', Index: {failure.idx}" |
| 293 | + ) |
| 294 | + if failure.exception: |
| 295 | + typer.echo(f" Encountered exception: {failure.exception}") |
| 296 | + else: |
| 297 | + typer.echo(f" {endpoint} value: {failure.grad_val}") |
| 298 | + typer.echo(f" Finite difference value: {failure.ref_val}") |
| 299 | + typer.echo() |
| 300 | + |
| 301 | + if failed: |
| 302 | + typer.echo("❌ Some gradient checks failed ❌") |
| 303 | + sys.exit(1) |
| 304 | + |
| 305 | + |
155 | 306 | @tesseract_runtime.command() |
156 | 307 | @click.option("-p", "--port", default=8000, help="Port number") |
157 | 308 | @click.option("-h", "--host", default="0.0.0.0", help="Host IP address") |
|
0 commit comments