
Torchax registering jax as PyTorch accelerator #9730

@ajakovljevicTT

Description

🐛 Bug

When torchax is imported, it registers "jax" as a PyTorch accelerator, which causes torch.compile(backend='inductor') to fail for functions/models with no tensor inputs.

torchax/__init__.py:82 unconditionally calls:
torch.utils.rename_privateuse1_backend('jax')

This makes torch.accelerator.is_available() return True.

In torch/_inductor/codecache.py:812-813, FxGraphHashDetails.__init__ does:

if no_tensor_inputs and torch.accelerator.is_available():
    self.default_cuda_device_index = torch.accelerator.current_device_index()

Since PyTorch isn't actually linked with jax device support, current_device_index() raises:
RuntimeError: PyTorch is not linked with support for jax devices
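The failure mode can be illustrated without torch at all. The stand-in functions below are hypothetical, not the real torch.accelerator API: registration alone makes the availability check pass, while actually querying the device raises, mirroring what current_device_index() does for the renamed privateuse1 backend. The guarded variant sketches one possible fix, treating "registered but unusable" as "no accelerator":

```python
# Stand-ins (assumptions, not real torch APIs) mirroring the bug:
# is_available() is True merely because a backend name was registered,
# but querying the device index raises at runtime.
def is_available():
    return True  # registration alone flips this to True

def current_device_index():
    raise RuntimeError("PyTorch is not linked with support for jax devices")

def default_device_index():
    # Guarded variant: fall back instead of crashing when the backend
    # is registered but has no real device support linked in.
    if is_available():
        try:
            return current_device_index()
        except RuntimeError:
            return None
    return None

print(default_device_index())  # None instead of RuntimeError
```

A fix along these lines could live either in FxGraphHashDetails.__init__ (catch the RuntimeError) or in torchax (avoid registering the backend unconditionally on import); which is preferable is up to the maintainers.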

torchax gets imported via torch_xla/distributed/spmd/xla_sharding.py:725, which calls maybe_get_torchax() inside mark_sharding().

To Reproduce

Here is a small Python repro:

import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Setup SPMD mesh and mark tensor sharding (standard torch_xla SPMD usage)
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('data',))
t = torch.randn(4, 4).to('xla')
xs.mark_sharding(t, mesh, (0, None))

# Now inductor fails for no-tensor-input functions
@torch.compile(backend='inductor')
def make_grid():
    return torch.zeros(3, 3)

make_grid()  # RuntimeError: PyTorch is not linked with support for jax devices

This will fail with the following error:

  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 2196, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 2171, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 2392, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2681, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 117, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1106, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 242, in aot_stage2_compile
    return aot_stage2_inference(aot_state, aot_graph_capture)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 315, in aot_stage2_inference
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/schemas.py", line 1251, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2558, in fw_compiler_base
    return compile_fx_forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2275, in compile_fx_forward
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 782, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_aot.py", line 144, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 900, in _compile_fx_inner
    (key_info, cache_info) = FxGraphCache.prepare_key(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 1489, in prepare_key
    key, debug_lines = compiled_fx_graph_hash(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 955, in compiled_fx_graph_hash
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 845, in __init__
    self.default_cuda_device_index = torch.accelerator.current_device_index()
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/accelerator/__init__.py", line 132, in current_device_index
    return torch._C._accelerator_getDeviceIndex()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: PyTorch is not linked with support for jax devices
