🐛 Bug
When torchax is imported, it registers "jax" as a PyTorch accelerator backend, which causes `torch.compile(backend='inductor')` to fail for functions/models with no tensor inputs.

`torchax/__init__.py:82` unconditionally calls:

```python
torch.utils.rename_privateuse1_backend('jax')
```

This makes `torch.accelerator.is_available()` return True.
In `torch/_inductor/codecache.py:812-813`, `FxGraphHashDetails.__init__` does:

```python
if no_tensor_inputs and torch.accelerator.is_available():
    self.default_cuda_device_index = torch.accelerator.current_device_index()
```
Since PyTorch isn't actually linked with jax device support, `current_device_index()` raises:

```
RuntimeError: PyTorch is not linked with support for jax devices
```
torchax is imported via `torch_xla/distributed/spmd/xla_sharding.py:725`, where `mark_sharding()` calls `maybe_get_torchax()`.
To Reproduce
A minimal Python repro:

```python
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Set up an SPMD mesh and mark tensor sharding (standard torch_xla SPMD usage).
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (num_devices,), ('data',))
t = torch.randn(4, 4).to('xla')
xs.mark_sharding(t, mesh, (0, None))

# Now inductor fails for functions with no tensor inputs.
@torch.compile(backend='inductor')
def make_grid():
    return torch.zeros(3, 3)

make_grid()  # RuntimeError: PyTorch is not linked with support for jax devices
```
This will fail with the following error:

```
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 2196, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 2171, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 2392, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2681, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 117, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1106, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 242, in aot_stage2_compile
    return aot_stage2_inference(aot_state, aot_graph_capture)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 315, in aot_stage2_inference
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/schemas.py", line 1251, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2558, in fw_compiler_base
    return compile_fx_forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2275, in compile_fx_forward
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 782, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_aot.py", line 144, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 900, in _compile_fx_inner
    (key_info, cache_info) = FxGraphCache.prepare_key(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 1489, in prepare_key
    key, debug_lines = compiled_fx_graph_hash(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 955, in compiled_fx_graph_hash
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/codecache.py", line 845, in __init__
    self.default_cuda_device_index = torch.accelerator.current_device_index()
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/accelerator/__init__.py", line 132, in current_device_index
    return torch._C._accelerator_getDeviceIndex()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: PyTorch is not linked with support for jax devices
```