@rka97 rka97 commented Jan 27, 2026

Migrate workloads from V100s to A100s. This PR:

  • Updates the runtime budgets for the workloads
  • Sets the default matmul precision to TF32 for JAX and PyTorch
  • Tunes the number of workers for the ImageNet PyTorch input pipeline
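For reference, a minimal sketch of how TF32 matmul defaults are typically enabled in the two frameworks (the exact call sites and flags used in this PR may differ):

```python
# Sketch only: typical framework-level switches for defaulting
# float32 matmuls to TF32 on Ampere GPUs. The PR's actual code
# may set these elsewhere or via different configuration paths.

# JAX: set the default matmul precision globally.
import jax
jax.config.update("jax_default_matmul_precision", "tensorfloat32")

# PyTorch: allow TF32 in cuBLAS matmuls and cuDNN convolutions.
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Equivalent high-level switch (PyTorch >= 1.12):
torch.set_float32_matmul_precision("high")
```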

Timing comparisons:


| Workload | step_hint | JAX step time (s) | JAX runtime (h) | PyTorch TF32 step time (s) | PyTorch TF32 runtime (h) | Reference (h) |
|---|---|---|---|---|---|---|
| criteo1tb_jax | 10666 | 0.835 | 2.473 | 0.942 | 2.792 | 2.14 |
| fastmri_jax | 18094 | 0.277 | 1.395 | 0.376 | 1.891 | 1.23 |
| imagenet_resnet_jax | 195999 | 0.259 | 14.079 | 0.300 | 16.344 | 18.38 |
| imagenet_vit_jax | 167999 | 0.386 | 18.003 | 0.321 | 14.971 | 19.38 |
| librispeech_conformer_jax | 76000 | 0.599 | 12.637 | 0.418 | 8.824 | 16.12 |
| librispeech_deepspeech_jax | 38400 | 0.981 | 10.462 | 0.557 | 5.940 | 12.33 |
| lm_jax | 72000 | 0.448 | 8.958 | 0.409 | 8.182 | |
| ogbg_jax | 52000 | 0.220 | 3.178 | 0.238 | 3.442 | 3.34 |
| wmt_jax | 120000 | 0.139 | 4.645 | 0.155 | 5.178 | 12.04 |
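The runtime columns follow directly from step_hint × step time (e.g. for criteo1tb_jax, 10666 steps × 0.835 s ≈ 2.47 h, matching the JAX runtime column up to step-time rounding). A small sketch of the conversion:

```python
def runtime_hours(step_hint: int, step_time_s: float) -> float:
    """Total runtime in hours given a step budget and per-step time in seconds."""
    return step_hint * step_time_s / 3600.0

# criteo1tb_jax, JAX backend: reproduces the 2.473 h runtime column.
print(round(runtime_hours(10666, 0.835), 2))  # → 2.47
```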

priyakasimbeg and others added 11 commits November 6, 2025 03:53
- Introduced DTYPE enum to standardize data types (FLOAT32, FLOAT16, BFLOAT16) for JAX and PyTorch.
- Updated input pipelines and model definitions in CIFAR and ImageNet workloads to utilize mixed precision.
- Implemented casting policies for parameters and inputs using jmp and torch.autocast.
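The commit list introduces a DTYPE enum; a minimal sketch of what such an enum might look like (member names come from the commit message, the string values are illustrative):

```python
from enum import Enum

class DTYPE(Enum):
    """Standardized data-type names shared across JAX and PyTorch workloads.
    Sketch based on the commit message; the real enum may additionally carry
    framework-specific dtype mappings."""
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"

print(DTYPE.BFLOAT16.value)  # → bfloat16
```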
@rka97 rka97 requested a review from a team as a code owner January 27, 2026 20:54
@github-actions

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@priyakasimbeg

Closing this and opening a new PR from the a100 branch, which has the changes from this source branch merged in.

@github-actions github-actions bot locked and limited conversation to collaborators Jan 29, 2026