diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 0f8df2917a92..e25efba23d9e 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -63,7 +63,7 @@ def initialize_tensor_parallelism(
             local_rank = int(os.environ["LOCAL_RANK"])
             world_size = int(os.environ["WORLD_SIZE"])

-            backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
+            backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl", "neuron": "neuron"}
             backend = backend_map.get(device_type)

             torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
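
A minimal standalone sketch (not part of the diff) of how the backend lookup resolves after this change, assuming the accelerator type reported on an AWS Neuron host is "neuron"; the device_type value and the commented launch step are illustrative assumptions, not behavior confirmed by this patch:

# Resolve the process-group backend the same way the updated mapping would.
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl", "neuron": "neuron"}

device_type = "neuron"  # assumed value of the detected accelerator type on a Neuron host
backend = backend_map.get(device_type)
assert backend == "neuron"

# Under torchrun (which sets RANK/WORLD_SIZE), the existing call would then be:
# torch.distributed.init_process_group(
#     backend=backend,
#     rank=int(os.environ["RANK"]),
#     world_size=int(os.environ["WORLD_SIZE"]),
# )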