-
Notifications
You must be signed in to change notification settings - Fork 47
Expand file tree
/
Copy pathapi.py
More file actions
680 lines (624 loc) · 31.2 KB
/
Copy pathapi.py
File metadata and controls
680 lines (624 loc) · 31.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
"""
FlashRT — Public API.
3 lines of code to run VLA inference:
import flash_rt
model = flash_rt.load_model(
checkpoint="/path/to/checkpoint",
framework="torch",
autotune=3,
)
actions = model.predict(images=[base_img, wrist_img],
prompt="pick up the red block")
# actions: np.ndarray (10, 7)
"""
import logging
import os
# Silence ``torch_xla``'s "Defaulting to PJRT_DEVICE=CPU" warning that
# fires when openpi (pulled in by the Pi0.5 torch frontend for the
# PaligemmaTokenizer) drags transformers→accelerate→torch_xla. We don't
# use XLA on the torch path, so the warning is pure noise. ``setdefault``
# preserves any value the user has already configured.
os.environ.setdefault("PJRT_DEVICE", "CUDA")
import numpy as np
logger = logging.getLogger(__name__)
class VLAModel:
"""Unified VLA inference model. Wraps ThorPipelineTorch or ThorPipelineJax."""
def __init__(self, pipe, framework: str):
self._pipe = pipe
self._framework = framework
self._current_prompt = None
self._current_prompt_state = None
# rtx Pi0.5 (RtxTorchPi05) requires an explicit
# ``calibrate_with_real_data([obs])`` call before the first
# ``infer()``; Thor / rtx GROOT lazy-calibrate inside ``infer()``.
# Track whether we still need to bootstrap calibration so that
# first predict() can call it exactly once.
self._needs_real_data_calibration = (
hasattr(pipe, "calibrate_with_real_data")
and hasattr(pipe, "calibrated")
)
@staticmethod
def _snapshot_prompt_state(state):
if state is None:
return None
try:
return np.asarray(state).copy()
except Exception:
return state
@staticmethod
def _prompt_state_equal(a, b) -> bool:
if a is None or b is None:
return a is b
try:
return np.array_equal(np.asarray(a), np.asarray(b))
except Exception:
return a is b
def predict(self, images, prompt=None, state=None):
"""Run inference.
Args:
images: list of numpy arrays (224,224,3) uint8 or float16.
Or a dict with 'image'/'wrist_image' keys.
prompt: text prompt. Only needed on first call or when changing prompt.
If None, reuses the last prompt.
state: optional robot state array. It is forwarded to
set_prompt() for frontends that encode state in prompt
tokens, and attached to the observation for frontends that
consume state during infer().
Returns:
np.ndarray: actions
"""
if prompt is None and self._current_prompt is None:
raise ValueError("prompt is required on first call")
prompt_for_call = self._current_prompt if prompt is None else prompt
prompt_changed = prompt is not None and prompt != self._current_prompt
prompt_state_changed = False
if hasattr(self._pipe, 'set_prompt'):
import inspect
sig = inspect.signature(self._pipe.set_prompt)
prompt_accepts_state = 'state' in sig.parameters
if prompt_accepts_state:
prompt_state_changed = not self._prompt_state_equal(
self._current_prompt_state, state)
else:
sig = None
prompt_accepts_state = False
if prompt_changed or prompt_state_changed:
if hasattr(self._pipe, 'set_prompt'):
if prompt_accepts_state:
self._pipe.set_prompt(prompt_for_call, state=state)
else:
self._pipe.set_prompt(prompt_for_call)
self._current_prompt = prompt_for_call
self._current_prompt_state = self._snapshot_prompt_state(state)
if isinstance(images, dict):
obs = dict(images)
elif isinstance(images, (list, tuple)):
if len(images) == 0:
raise ValueError("images list must have at least one frame")
# Use the "images" list form so backends that support
# variable num_views (rtx Pi0.5, etc.) don't choke on the
# 1-view case. Also populate the legacy image / wrist_image
# / wrist_image_right keys so Thor-style backends that only
# read those still see the right frames.
obs = {'images': list(images), 'image': images[0]}
if len(images) >= 2:
obs['wrist_image'] = images[1]
if len(images) >= 3:
obs['wrist_image_right'] = images[2]
else:
raise ValueError("images must be a list of numpy arrays or a dict")
if state is not None and "state" not in obs:
obs["state"] = state
# RTX Pi0.5 can swap in a different cached pipeline when a changing
# state prompt hits a new token length. Re-check that frontend's
# calibration flag instead of relying only on the first-call latch.
needs_real_data_calibration = self._needs_real_data_calibration
if (hasattr(self._pipe, "_prompt_pipeline_cache")
and not getattr(self._pipe, "calibrated", False)):
needs_real_data_calibration = True
if (needs_real_data_calibration
and hasattr(self._pipe, "calibrate_with_real_data")):
self._pipe.calibrate_with_real_data([obs])
self._needs_real_data_calibration = False
result = self._pipe.infer(obs)
return result['actions']
def warm_state_prompt_buckets(self, images, prompt, states):
"""Pre-build Pi0.5 state-prompt runtime buckets.
Pi0.5 encodes robot state in the text prompt. Different state
values can tokenize to different lengths; warming representative
states up front prevents the control loop from paying graph
capture/autotune the first time each length appears.
"""
if not hasattr(self._pipe, "warm_state_prompt_buckets"):
raise NotImplementedError(
"This frontend does not expose state prompt bucket warmup.")
if isinstance(images, dict):
obs = dict(images)
elif isinstance(images, (list, tuple)):
if len(images) == 0:
raise ValueError("images list must have at least one frame")
obs = {"images": list(images), "image": images[0]}
if len(images) >= 2:
obs["wrist_image"] = images[1]
if len(images) >= 3:
obs["wrist_image_right"] = images[2]
else:
raise ValueError("images must be a list of numpy arrays or a dict")
lengths = self._pipe.warm_state_prompt_buckets(prompt, states, obs)
self._needs_real_data_calibration = False
self._current_prompt = None
self._current_prompt_state = None
return lengths
def set_prompt(self, *args, **kwargs):
"""Delegate prompt setup to the selected frontend."""
if not hasattr(self._pipe, "set_prompt"):
raise NotImplementedError(
"This frontend does not expose set_prompt().")
result = self._pipe.set_prompt(*args, **kwargs)
if "prompt" in kwargs:
self._current_prompt = kwargs["prompt"]
elif args and isinstance(args[0], str):
self._current_prompt = args[0]
try:
import inspect
sig = inspect.signature(self._pipe.set_prompt)
params = list(sig.parameters)
if "state" in sig.parameters:
state_pos = params.index("state")
if "state" in kwargs:
state = kwargs["state"]
elif len(args) > state_pos:
state = args[state_pos]
else:
state = None
self._current_prompt_state = self._snapshot_prompt_state(state)
except (TypeError, ValueError):
pass
return result
def infer(self, *args, **kwargs):
"""Delegate inference to the selected frontend."""
if not hasattr(self._pipe, "infer"):
raise NotImplementedError(
"This frontend does not expose infer().")
return self._pipe.infer(*args, **kwargs)
def calibrate(
self,
observations,
*,
percentile: float = 99.9,
max_samples: int | None = None,
verbose: bool = False,
) -> None:
"""Unified calibration entry point.
Args:
observations: single dict or iterable of dicts. N=1 triggers
the single-frame calibration path (back-compatible). Frontends
that document N>=2 support run dataset calibration with
percentile-clipped amax reduction; unsupported frontends raise
a clear NotImplementedError from their calibrate() method.
percentile: percentile for multi-sample amax reduction. 99.9
by default; 100.0 == traditional max.
max_samples: optional cap.
verbose: log dispersion summary after reduction.
See ``docs/calibration.md`` for full guidance.
"""
if not hasattr(self._pipe, "calibrate"):
raise NotImplementedError(
"This frontend does not expose a public calibrate() API. "
"Upgrade to a recent version of FlashRT that includes "
"the unified calibration interface.")
self._pipe.calibrate(
observations,
percentile=percentile,
max_samples=max_samples,
verbose=verbose,
)
# Any lazy-bootstrap was just handled explicitly — prevent
# predict() from double-triggering it.
self._needs_real_data_calibration = False
@property
def precision_spec(self):
"""Return the :class:`ModelPrecisionSpec` captured at calibration
time, or None if the frontend does not surface it yet."""
return getattr(self._pipe, "precision_spec", None)
def recalibrate(self):
"""Force recalibration on next set_prompt().
Use after fine-tuning or switching deployment domains.
Clears calibration cache (and weight cache for JAX).
"""
from flash_rt.core.quant.calibrator import clear_calibration
clear_calibration(self._pipe._checkpoint_path)
if self._framework == "jax":
from flash_rt.core.weights.weight_cache import clear_weight_cache
clear_weight_cache(self._pipe._checkpoint_path)
self._pipe.calibrated = False
self._pipe._real_data_calibrated = False
self._current_prompt = None # force re-set_prompt
logger.info("Caches cleared. Next predict() will recalibrate.")
@property
def framework(self):
return self._framework
@property
def prompt(self):
return self._current_prompt
def load_model(checkpoint, framework="torch", num_views=2, autotune=3,
recalibrate=False, weight_cache=True, config="pi05", device=None,
decode_cuda_graph=False, decode_graph_steps=80,
max_decode_steps=256,
hardware="auto",
embodiment_tag=None,
action_horizon=None,
use_fp4=False,
fp4_layers=None,
use_awq=None,
awq_alpha=0.5,
use_p1_split_gu=None,
num_steps=None,
vision_pool_factor=None,
vision_num_layers=None,
cache_frames=None,
use_fp16=False,
use_fp8=True,
state_prompt_mode="exact",
state_prompt_fixed_max_len=None):
"""Load a FlashRT model.
Args:
checkpoint: path to checkpoint directory.
- torch: safetensors directory
- jax: Orbax checkpoint directory
framework: "torch" or "jax"
num_views: number of camera views (default 2)
autotune: CUDA Graph autotune intensity.
0 or False = off (fastest startup, ~2ms slower inference risk)
3 = default (Torch finds fast graph on trial 0-1)
5+ = thorough (JAX may need more trials for fast graph)
True = same as 3
recalibrate: if True, ignore cached calibration (and weight cache for JAX)
and force fresh FP8 quantization + calibration.
weight_cache: if True (default), cache FP8-quantized weights to disk
after first load. Only affects JAX.
config: model config name: "pi05", "pi0", "groot", "groot_n17",
"pi0fast", "motus", "wan22_ti2v_5b", "cosmos3_video".
"cosmos3_video" is a non-VLA text2video denoise model: drive it with
set_prompt(ref=<reference dump>) + infer(...), not predict().
device: ignored (auto-detects GPU). Reserved for future multi-GPU.
decode_cuda_graph: Pi0-FAST only. Capture action-phase decode as CUDA
Graph for max throughput (trades startup time for per-token speed).
decode_graph_steps: Pi0-FAST only. Number of action tokens to capture
in the decode graph (default 80).
hardware: GPU backend selection. ``"auto"`` (default) detects the
current CUDA device via compute capability and picks the
best-matching backend:
SM110 (Jetson Thor) → ``flash_rt.hardware.thor.*``
SM120 (RTX 5090) → ``flash_rt.hardware.rtx.*``
(falls back to Thor classes for models
without an rtx-specific implementation —
those classes have SM120 runtime forks
where needed, e.g. Pi0-FAST.)
SM89 (RTX 4090) → ``flash_rt.hardware.rtx.*``
SM87 (Jetson Orin) → ``flash_rt.hardware.rtx.*`` (experimental,
Pi0.5 torch only; BF16 default, INT8
via Orin env flags)
Pass ``"thor"`` / ``"rtx_sm120"`` / ``"rtx_sm89"`` /
``"rtx_sm87"`` explicitly to
force a specific backend (useful for cross-hardware debugging).
embodiment_tag: GROOT only. Per-embodiment MLP slot to load. Passing
``None`` uses the backend default (``"new_embodiment"`` — unfit
for the base 3B checkpoint demo; see below). The GR00T-N1.6-3B
base checkpoint is only actually trained on a subset of its 32
slots. For a working demo pick one of ``"gr1"``,
``"robocasa_panda_omron"``, or ``"behavior_r1_pro"``. Any other
tag prints a warning and emits noise-like actions.
action_horizon: GROOT only. Number of action steps to generate per
inference (default = ``ACTION_HORIZON_MAX`` = 50). Set to a
smaller value (e.g. 16 for LIBERO) to reduce DiT compute.
use_fp4: Pi0.5 torch and JAX on Thor. If True, enable NVFP4
quantization on the selected encoder FFN layers (Gate+Up + Down
GEMMs). Requires SM100+ GPU (Thor SM110) and the flash_rt_fp4
extension. Uses the FP8 route with a warning if the extension is
unavailable. Default False (production FP8 baseline).
Torch uses safetensors checkpoints; JAX uses Orbax checkpoints.
Validated on LIBERO Spatial for the torch path: 491/500 = 98.2%
(matches baseline). JAX FP4 has Thor precision / replay-latency
validation against a same-origin PyTorch reference.
fp4_layers: Tuple of encoder layer indices to FP4-quantize (only
applies when use_fp4=True). ``None`` resolves to the production
preset, full 18 encoder FFN layers with AWQ + P1 split-GU.
Explicit tuples override the preset; `(7, 8, 9)` is the
conservative middle-FFN subset.
use_fp8: Enable FP8 execution where the selected frontend supports
an FP8/BF16 switch. Defaults to True to preserve existing
performance-oriented behavior.
use_fp16: Opt-in non-quantized full-FP16 path for Pi0.5 on RTX
SM120/SM89, GROOT N1.6 on Thor/RTX SM120, and GROOT N1.7 on
Thor/RTX SM120/SM89. Only valid with ``use_fp8=False``; an A/B
reference against the quantized default. On GROOT N1.7 the
default is FP8 (FP8 backbone + bf16 DiT), so ``use_fp8=False``
without ``use_fp16=True`` raises.
num_steps: Pi0/Pi0.5 torch only when supported. Number of
flow-matching ODE steps. ``None`` uses the frontend default.
vision_pool_factor: Pi0.5 torch RTX/Orin only. Spatial pooling factor
for vision tokens; valid values are 1, 2, or 4. ``None`` keeps
the frontend default.
vision_num_layers: Pi0.5 torch RTX/Orin only. Number of SigLIP vision
layers to execute; valid range is 1-27. ``None`` keeps the
frontend default.
cache_frames: Pi0.5 torch RTX/Orin only. Temporal K/V reuse period.
1 runs the full vision+encoder+decoder path on every frame; 2
alternates full and decoder-only frames. ``None`` keeps the
frontend default.
state_prompt_mode: Pi0.5 RTX/Thor only. How the variable-length
state-in-prompt is mapped to CUDA graphs:
``"exact"`` (default) — graph shape tracks the exact prompt
length. RTX caches recurring lengths and can front-load them
with ``warm_state_prompt_buckets()``; Thor reuses same-length
updates and recaptures when the exact length changes.
``"fixed"`` — ONE graph at the max prompt length serves every
length (padded prefix masked by a device-side valid length +
decoder K/V appended at the valid offset); a changing length
never re-captures and no warmup is needed. RTX uses the
vendored bf16 FA2 path; Thor uses its cuBLAS-decomposed
attention path.
Cost: every inference runs at the padded max length, so it is
~1 ms slower than a warmed ``"exact"`` graph (split-KV decoder
joint-attention keeps the padding overhead small on RTX; Thor
uses its cuBLAS-decomposed attention path with device-side
valid-length masking). Prefer ``"fixed"`` when the state-token
length drifts and you'd rather not enumerate/warm lengths;
prefer ``"exact"`` + warmup for absolute peak latency at known
lengths.
Env override: ``FLASHRT_PI05_STATE_PROMPT_MODE``.
state_prompt_fixed_max_len: Pi0.5 Thor fixed mode only. Padded
state-prompt token capacity used when ``state_prompt_mode="fixed"``.
``None`` keeps the frontend default (200 tokens). Lower this when
the serving stack can bound the live state prompt length; for
example, a cap near the actual length (120 vs. 117 tokens) measured
about a 1 ms normal overhead on Thor versus warmed exact mode.
Env override: ``FLASHRT_PI05_STATE_PROMPT_FIXED_MAX_LEN``.
Returns:
VLAModel instance with .predict() method.
"""
if config not in ("pi05", "groot", "groot_n17", "pi0", "pi0fast",
"motus", "wan22_ti2v_5b", "cosmos3_video", "nexn2"):
raise ValueError(
f"Unknown config: {config}. "
f"Supported: pi05, groot, groot_n17, pi0, pi0fast, motus, "
f"wan22_ti2v_5b, cosmos3_video, nexn2")
if framework not in ("torch", "jax"):
raise ValueError(
f"Unknown framework: {framework}. Supported: torch, jax")
# When use_fp4=True, the default resolves to the best-known production
# FP4 config (full 18 encoder FFN layers + AWQ + P1 split-GU). Passing
# any sub-flag explicitly overrides the preset; None means "use preset".
if use_fp4:
if fp4_layers is None:
fp4_layers = tuple(range(18))
if use_awq is None:
use_awq = True
if use_p1_split_gu is None:
use_p1_split_gu = True
else:
if fp4_layers is None:
fp4_layers = (7, 8, 9)
if use_awq is None:
use_awq = False
if use_p1_split_gu is None:
use_p1_split_gu = False
# Nex-N2-mini (qwen3_5_moe) is a text LLM, not a VLA: its frontend exposes
# infer()->logits / generate() rather than the predict(images, ...) surface
# that load_model's VLAModel wraps. It is registered in _PIPELINE_MAP for
# discoverability but constructed directly (checked before GPU detection so
# the redirect fires on any machine).
if config == "nexn2":
raise NotImplementedError(
"config='nexn2' is a text LLM and is not served through "
"load_model's VLA wrapper. Construct it directly:\n"
" from flash_rt.frontends.torch.nexn2_rtx import "
"Nexn2TorchFrontendRtx\n"
"See docs/nexn2_usage.md.")
from flash_rt.hardware import detect_arch, resolve_pipeline_class
arch = detect_arch() if hardware == "auto" else hardware
if recalibrate:
from flash_rt.core.quant.calibrator import clear_calibration
try:
clear_calibration(checkpoint)
except FileNotFoundError:
pass
if framework == "jax":
from flash_rt.core.weights.weight_cache import clear_weight_cache
try:
clear_weight_cache(checkpoint)
except FileNotFoundError:
pass
logger.info("Caches cleared for %s", checkpoint)
if framework == "jax":
os.environ.setdefault(
"XLA_FLAGS",
"--xla_gpu_enable_triton_gemm=false --xla_gpu_autotune_level=0")
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
if use_fp16:
if use_fp8:
raise ValueError("use_fp16=True requires use_fp8=False")
if (config, framework, arch) not in {
("pi05", "torch", "rtx_sm120"),
("pi05", "torch", "rtx_sm89"),
("groot", "torch", "thor"),
("groot", "torch", "rtx_sm120"),
("groot_n17", "torch", "thor"),
("groot_n17", "torch", "rtx_sm120"),
("groot_n17", "torch", "rtx_sm89"),
}:
raise ValueError(
"use_fp16=True is currently experimental and only supports "
"('pi05', 'torch', 'rtx_sm120'/'rtx_sm89'), "
"('groot', 'torch', 'thor'/'rtx_sm120'), and "
"('groot_n17', 'torch', 'thor'/'rtx_sm120'/'rtx_sm89')")
pipe_cls = resolve_pipeline_class(config, framework, arch)
# GROOT N1.7 on RTX defaults to the framework-conforming FP8 frontend.
# rtx_sm120 keeps the historical shared-base registration and is refined
# here to the explicit FP8 production frontend. rtx_sm89 is already
# registered directly to its dedicated FP8 frontend in _PIPELINE_MAP.
if config == "groot_n17" and framework == "torch" \
and arch in ("rtx_sm120", "rtx_sm89") and not use_fp16:
if not use_fp8:
raise ValueError(
"GROOT N1.7 on RTX defaults to FP8; there is no separate "
"non-FP16 BF16 fallback. For the non-quantized full-FP16 "
"reference pass use_fp16=True, use_fp8=False.")
if arch == "rtx_sm120":
from flash_rt.frontends.torch.groot_n17_rtx_fp8 import (
GrootN17TorchFrontendRtxFP8,
)
pipe_cls = GrootN17TorchFrontendRtxFP8
# GROOT N1.7 on Thor (SM110) runs the FP8 backbone (+ bf16 DiT) by
# default. There is no BF16-only fallback; the non-quantized reference is
# the explicit full-FP16 path (use_fp16=True with use_fp8=False), so a
# bare use_fp8=False is rejected rather than silently ignored.
if config == "groot_n17" and framework == "torch" and arch == "thor" \
and not use_fp16 and not use_fp8:
raise ValueError(
"GROOT N1.7 on Thor defaults to FP8; there is no BF16-only "
"fallback. For the non-quantized full-FP16 reference pass "
"use_fp16=True together with use_fp8=False.")
if use_fp16:
# GROOT N1.6 Thor full-FP16 reference: the same fully-kernelized,
# CUDA-graph pipeline as the FP8 production frontend, with the GEMMs run
# in FP16 instead of per-tensor FP8 (an A/B accuracy baseline).
if config == "groot" and framework == "torch" and arch == "thor":
from flash_rt.frontends.torch.groot_thor_fp16 import (
GrootTorchFrontendThorFP16,
)
pipe_cls = GrootTorchFrontendThorFP16
elif config == "groot_n17" and framework == "torch" and arch == "thor":
# N1.7 Thor full-FP16 reference (no FP8): ViT / DeepStack / LLM /
# VL-self-attn run fp16_nn on the shadow weights, and the DiT
# action head runs the bf16 (non-FP8) graph (_DIT_USE_FP8=False).
from flash_rt.frontends.torch.groot_n17_thor_fp16 import (
GrootN17TorchFrontendThorFP16,
)
pipe_cls = GrootN17TorchFrontendThorFP16
else:
if config == "pi05":
from flash_rt.frontends.torch.pi05_rtx_fp16 import (
Pi05TorchFrontendRtxFP16,
)
pipe_cls = Pi05TorchFrontendRtxFP16
elif config == "groot":
from flash_rt.frontends.torch.groot_rtx_fp16 import (
GrootTorchFrontendRtxFP16,
)
pipe_cls = GrootTorchFrontendRtxFP16
else: # config == "groot_n17"
if arch == "rtx_sm89":
from flash_rt.frontends.torch.groot_n17_rtx_sm89_fp16 import (
GrootN17TorchFrontendRtxSm89FP16,
)
pipe_cls = GrootN17TorchFrontendRtxSm89FP16
else:
from flash_rt.frontends.torch.groot_n17_rtx_fp16 import (
GrootN17TorchFrontendRtxFP16,
)
pipe_cls = GrootN17TorchFrontendRtxFP16
# ── FP4 routing (Pi0.5 torch + Pi0.5 JAX on Thor) ──
if use_fp4:
if config != "pi05" or framework not in ("torch", "jax") or arch != "thor":
logger.warning(
"use_fp4=True is only supported for config='pi05' with "
"framework in ('torch', 'jax') on Thor; got config='%s' "
"framework='%s' hardware='%s'. Falling back to FP8.",
config, framework, arch)
use_fp4 = False
else:
try:
import flash_rt.flash_rt_fp4 as _fvk_fp4
if not _fvk_fp4.has_nvfp4():
logger.warning(
"flash_rt_fp4 loaded but has_nvfp4()=False (SM100+ required). "
"Falling back to FP8.")
use_fp4 = False
except ImportError:
logger.warning(
"flash_rt_fp4 extension not available. Falling back to FP8.")
use_fp4 = False
if use_fp4:
if framework == "torch":
from flash_rt.frontends.torch.pi05_thor_fp4 import (
Pi05TorchFrontendThorFP4,
)
pipe_cls = Pi05TorchFrontendThorFP4
else: # framework == "jax"
from flash_rt.frontends.jax.pi05_thor_fp4 import (
Pi05JaxFrontendThorFP4,
)
pipe_cls = Pi05JaxFrontendThorFP4
logger.info(
"FP4 enabled (framework=%s): encoder FFN layers %s",
framework, sorted(fp4_layers))
# Build the kwarg set per-model so we only pass args the target class
# actually accepts. Keeps the dispatch table simple while still letting
# users specify groot/pi0fast knobs.
import inspect
sig = inspect.signature(pipe_cls)
kwargs: dict = {"num_views": num_views}
if "hardware" in sig.parameters:
kwargs["hardware"] = arch
if "use_fp8" in sig.parameters:
kwargs["use_fp8"] = use_fp8
if config == "pi0fast":
kwargs.update(
autotune=autotune,
decode_cuda_graph=decode_cuda_graph,
decode_graph_steps=decode_graph_steps,
max_decode_steps=max_decode_steps,
)
elif config in ("groot", "groot_n17"):
# rtx-side GROOT accepts embodiment_tag + action_horizon; Thor-side
# GROOT accepts embodiment_tag + autotune. Feature-detect via the
# concrete class signature so one call site works for both.
if "autotune" in sig.parameters:
kwargs["autotune"] = autotune
if "embodiment_tag" in sig.parameters and embodiment_tag is not None:
kwargs["embodiment_tag"] = embodiment_tag
if "action_horizon" in sig.parameters and action_horizon is not None:
kwargs["action_horizon"] = action_horizon
elif config == "wan22_ti2v_5b":
if "autotune" in sig.parameters:
kwargs["autotune"] = autotune
else:
# pi05, pi0 — both Thor and rtx variants take (checkpoint, num_views, autotune)
# or (checkpoint, num_views). Feature-detect.
if "autotune" in sig.parameters:
kwargs["autotune"] = autotune
if "weight_cache" in sig.parameters:
kwargs["weight_cache"] = weight_cache
# Orin-specific performance parameters (passed only when accepted and set).
if num_steps is not None and "num_steps" in sig.parameters:
kwargs["num_steps"] = num_steps
if vision_pool_factor is not None and "vision_pool_factor" in sig.parameters:
kwargs["vision_pool_factor"] = vision_pool_factor
if vision_num_layers is not None and "vision_num_layers" in sig.parameters:
kwargs["vision_num_layers"] = vision_num_layers
if cache_frames is not None and "cache_frames" in sig.parameters:
kwargs["cache_frames"] = cache_frames
# Pi0.5 state-in-prompt graph strategy: "exact" (default, per-length
# capture) / "fixed" (opt-in, one graph). Forwarded only if accepted.
if "state_prompt_mode" in sig.parameters:
kwargs["state_prompt_mode"] = state_prompt_mode
if (state_prompt_fixed_max_len is not None and
"state_prompt_fixed_max_len" in sig.parameters):
kwargs["state_prompt_fixed_max_len"] = state_prompt_fixed_max_len
# FP4 frontend accepts these extra kwargs (only set when the class
# actually accepts them — base class ignores, FP4 subclass uses).
if use_fp4 and "use_fp4_encoder_ffn" in sig.parameters:
kwargs["use_fp4_encoder_ffn"] = True
kwargs["fp4_layers"] = fp4_layers
if "use_awq" in sig.parameters:
kwargs["use_awq"] = bool(use_awq)
kwargs["awq_alpha"] = float(awq_alpha)
if "use_p1_split_gu" in sig.parameters:
kwargs["use_p1_split_gu"] = bool(use_p1_split_gu)
pipe = pipe_cls(checkpoint, **kwargs)
logger.info(
"Model loaded: config=%s, framework=%s, arch=%s, class=%s",
config, framework, arch, pipe_cls.__name__)
return VLAModel(pipe, framework)