From 5dc370f09586ac6e1eaa85bdcb9b7fb2cd9c79eb Mon Sep 17 00:00:00 2001 From: dayshah Date: Sun, 30 Nov 2025 00:56:57 -0800 Subject: [PATCH 01/18] [core][rdt] Initial refactor for bring your own transport Signed-off-by: dayshah --- python/ray/_private/serialization.py | 2 +- .../ray/experimental/collective/__init__.py | 2 - .../ray/experimental/collective/collective.py | 2 +- .../collective/nixl_tensor_transport.py | 197 -------------- python/ray/experimental/collective/util.py | 68 ----- .../collective_tensor_transport.py | 82 +++--- .../gpu_object_manager/gpu_object_manager.py | 54 ++-- .../gpu_object_manager/gpu_object_store.py | 38 ++- .../nixl_tensor_transport.py | 242 ++++++++++++++++++ .../tensor_transport_manager.py | 41 +-- .../experimental/gpu_object_manager/types.py | 30 +++ .../experimental/gpu_object_manager/util.py | 67 +++++ .../data_parallel/dp_server.py | 2 +- python/ray/util/collective/collective.py | 64 ++--- .../collective_group/nixl_backend.py | 147 ----------- python/ray/util/collective/types.py | 74 +----- 16 files changed, 474 insertions(+), 638 deletions(-) delete mode 100644 python/ray/experimental/collective/nixl_tensor_transport.py delete mode 100644 python/ray/experimental/collective/util.py rename python/ray/experimental/{collective => gpu_object_manager}/collective_tensor_transport.py (70%) create mode 100644 python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py rename python/ray/experimental/{collective => gpu_object_manager}/tensor_transport_manager.py (79%) create mode 100644 python/ray/experimental/gpu_object_manager/types.py create mode 100644 python/ray/experimental/gpu_object_manager/util.py delete mode 100644 python/ray/util/collective/collective_group/nixl_backend.py diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index b2ab8a50bf6a..6a9f45a7f1f1 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -722,7 +722,7 @@ def store_gpu_objects(self, obj_id: str, tensors: List["torch.Tensor"]): obj_id is not None ), "`obj_id` is required, and it is the key to retrieve corresponding tensors from the GPU object store." # Regardless of whether `tensors` is empty, we always store the GPU object - # in the GPU object store. This ensures that `get_tensor_transport_metadata` is not + # in the GPU object store. This ensures that `__ray_get_tensor_transport_metadata__` is not # blocked indefinitely. 
worker = ray._private.worker.global_worker gpu_object_manager = worker.gpu_object_manager diff --git a/python/ray/experimental/collective/__init__.py b/python/ray/experimental/collective/__init__.py index 42289cee1653..a6fe61ea442a 100644 --- a/python/ray/experimental/collective/__init__.py +++ b/python/ray/experimental/collective/__init__.py @@ -9,7 +9,6 @@ allreduce, reducescatter, ) -from ray.experimental.collective.util import get_tensor_transport_manager __all__ = [ "allgather", @@ -19,5 +18,4 @@ "create_collective_group", "destroy_collective_group", "destroy_all_collective_groups", - "get_tensor_transport_manager", ] diff --git a/python/ray/experimental/collective/collective.py b/python/ray/experimental/collective/collective.py index 3a31128088a1..1872ace6fd8e 100644 --- a/python/ray/experimental/collective/collective.py +++ b/python/ray/experimental/collective/collective.py @@ -5,8 +5,8 @@ import ray import ray.experimental.internal_kv as internal_kv from ray.experimental.collective.communicator import CommunicatorHandle -from ray.experimental.collective.util import get_address_and_port from ray.util.annotations import PublicAPI +from ray.util.collective.collective import get_address_and_port from ray.util.collective.collective_group.torch_gloo_collective_group import ( get_master_address_metadata_key, ) diff --git a/python/ray/experimental/collective/nixl_tensor_transport.py b/python/ray/experimental/collective/nixl_tensor_transport.py deleted file mode 100644 index 6ab5b9d03ee3..000000000000 --- a/python/ray/experimental/collective/nixl_tensor_transport.py +++ /dev/null @@ -1,197 +0,0 @@ -from typing import TYPE_CHECKING, List, Optional - -import ray -from ray.experimental.collective.tensor_transport_manager import ( - TensorTransportManager, -) -from ray.util.collective.types import ( - NIXL_GROUP_NAME, - Backend, - NixlCommunicatorMetadata, - NixlTransportMetadata, -) - -if TYPE_CHECKING: - import torch - - -class NixlTensorTransport(TensorTransportManager): - @property - def tensor_transport_backend(self) -> Backend: - return Backend.NIXL - - @staticmethod - def is_one_sided() -> bool: - return True - - @staticmethod - def can_abort_transport() -> bool: - return True - - def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: - def __ray_actor_has_tensor_transport__( - self: "ray.actor.ActorHandle", - ) -> bool: - try: - from ray.util.collective.collective import get_group_handle - - nixl_backend = get_group_handle(NIXL_GROUP_NAME) - return nixl_backend is not None - except Exception: - return False - - return ray.get( - actor.__ray_call__.options(concurrency_group="_ray_system").remote( - __ray_actor_has_tensor_transport__ - ) - ) - - @staticmethod - def extract_tensor_transport_metadata( - obj_id: str, - gpu_object: List["torch.Tensor"], - ) -> NixlTransportMetadata: - from ray._private.worker import global_worker - from ray.util.collective.collective import get_group_handle - from ray.util.collective.collective_group.nixl_backend import NixlBackend - from ray.util.collective.types import NixlTransportMetadata - - gpu_object_store = global_worker.gpu_object_manager.gpu_object_store - nixl_backend: NixlBackend = get_group_handle(NIXL_GROUP_NAME) - device = None - tensor_meta = [] - duplicate_meta = gpu_object_store.record_and_get_meta_if_duplicate( - obj_id, gpu_object - ) - if duplicate_meta is not None: - return duplicate_meta - if gpu_object: - reg_descs, serialized_descs, agent_meta = nixl_backend.get_nixl_metadata( - gpu_object - ) - # We assume all 
tensors in one GPU object have the same device type. - device = gpu_object[0].device - for t in gpu_object: - if t.device.type != device.type: - raise ValueError( - "All tensors in an RDT object must have the same device type." - ) - tensor_meta.append((t.shape, t.dtype)) - else: - reg_descs, serialized_descs, agent_meta = None, None, None - ret = NixlTransportMetadata( - tensor_meta=tensor_meta, - tensor_device=device, - nixl_reg_descs=reg_descs, - nixl_serialized_descs=serialized_descs, - nixl_agent_meta=agent_meta, - ) - gpu_object_store.record_managed_meta_nixl(obj_id, ret) - return ret - - @staticmethod - def get_tensor_transport_metadata( - src_actor: "ray.actor.ActorHandle", - obj_id: str, - ) -> NixlTransportMetadata: - def __ray_get_tensor_transport_metadata__( - self: "ray.actor.ActorHandle", - obj_id: str, - ) -> NixlTransportMetadata: - - from ray._private.worker import global_worker - - gpu_object_manager = global_worker.gpu_object_manager - gpu_object_store = gpu_object_manager.gpu_object_store - # NOTE: We do not specify a timeout here because the user task that returns - # it could take arbitrarily long and we don't want to trigger a spurious - # timeout. - gpu_object = gpu_object_store.wait_and_get_object(obj_id) - return NixlTensorTransport.extract_tensor_transport_metadata( - obj_id, gpu_object - ) - - # Submit a Ray actor task to the source actor to get the tensor metadata. - # The metadata is a list of tuples, where each tuple contains the shape and dtype - # of a tensor in the GPU object store. This function returns an ObjectRef that - # points to the tensor metadata. - # NOTE(swang): We put this task on the background thread to avoid tasks - # executing on the main thread blocking this task. - - return src_actor.__ray_call__.options(concurrency_group="_ray_system").remote( - __ray_get_tensor_transport_metadata__, obj_id - ) - - @staticmethod - def get_communicator_metadata( - src_actor: "ray.actor.ActorHandle", - dst_actor: "ray.actor.ActorHandle", - backend: Optional[str] = None, - ) -> NixlCommunicatorMetadata: - - communicator_metadata = NixlCommunicatorMetadata( - communicator_name=NIXL_GROUP_NAME, - ) - - return communicator_metadata - - @staticmethod - def recv_multiple_tensors( - tensors, - obj_id: str, - tensor_transport_metadata: NixlTransportMetadata, - communicator_metadata: NixlCommunicatorMetadata, - ): - from ray.util.collective import types - from ray.util.collective.collective import get_group_handle - - if tensors: - g = get_group_handle(communicator_metadata.communicator_name) - - assert isinstance( - tensor_transport_metadata, types.NixlTransportMetadata - ), "metadata must be a NixlTransportMetadata object for NIXL transport" - assert isinstance( - communicator_metadata, types.NixlCommunicatorMetadata - ), "metadata must be a NixlCommunicatorMetadata object for NIXL transport" - - g.recv( - tensors, - obj_id, - tensor_transport_metadata.nixl_serialized_descs, - tensor_transport_metadata.nixl_agent_meta, - ) - - @staticmethod - def send_multiple_tensors( - tensors: List["torch.Tensor"], - communicator_metadata: NixlCommunicatorMetadata, - device: "torch.device", - ): - raise NotImplementedError( - "NIXL transport does not support send_multiple_tensors, since it is a one-sided transport." 
- ) - - @staticmethod - def garbage_collect(obj_id: str, tensor_transport_meta: NixlTransportMetadata): - from ray._private.worker import global_worker - from ray.util.collective.collective import get_group_handle - - gpu_object_store = global_worker.gpu_object_manager.gpu_object_store - count = gpu_object_store.remove_managed_meta_nixl(obj_id) - if count == 0: - descs = tensor_transport_meta.nixl_reg_descs - if descs is not None: - nixl_backend = get_group_handle(NIXL_GROUP_NAME) - nixl_backend.deregister_memory(descs) - - @staticmethod - def abort_transport( - obj_id: str, - communicator_metadata: NixlCommunicatorMetadata, - ): - from ray.util.collective.collective import get_group_handle - - g = get_group_handle(communicator_metadata.communicator_name) - if g: - g.abort(obj_id) diff --git a/python/ray/experimental/collective/util.py b/python/ray/experimental/collective/util.py deleted file mode 100644 index 241a95e890af..000000000000 --- a/python/ray/experimental/collective/util.py +++ /dev/null @@ -1,68 +0,0 @@ -import socket -from typing import TYPE_CHECKING, Tuple - -import ray -from ray._common.network_utils import find_free_port, is_ipv6 -from ray.experimental.collective.collective_tensor_transport import ( - CollectiveTensorTransport, -) -from ray.experimental.collective.nixl_tensor_transport import NixlTensorTransport -from ray.experimental.collective.tensor_transport_manager import TensorTransportManager -from ray.util.collective.types import Backend - -if TYPE_CHECKING: - import torch - -# Singleton instances for tensor transport managers -_nixl_tensor_transport_manager = None -_gloo_tensor_transport_manager = None -_nccl_tensor_transport_manager = None - - -def get_tensor_transport_manager( - tensor_transport: Backend, -) -> "TensorTransportManager": - """Get the tensor transport manager for the given tensor transport protocol. - - Args: - tensor_transport: The tensor transport protocol to use for the GPU object. - - Returns: - TensorTransportManager: The tensor transport manager for the given tensor transport protocol. 
- """ - if tensor_transport == Backend.NIXL: - global _nixl_tensor_transport_manager - if _nixl_tensor_transport_manager is None: - _nixl_tensor_transport_manager = NixlTensorTransport() - return _nixl_tensor_transport_manager - elif tensor_transport == Backend.TORCH_GLOO: - global _gloo_tensor_transport_manager - if _gloo_tensor_transport_manager is None: - _gloo_tensor_transport_manager = CollectiveTensorTransport(tensor_transport) - return _gloo_tensor_transport_manager - elif tensor_transport == Backend.NCCL: - global _nccl_tensor_transport_manager - if _nccl_tensor_transport_manager is None: - _nccl_tensor_transport_manager = CollectiveTensorTransport(tensor_transport) - return _nccl_tensor_transport_manager - else: - raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}") - - -def device_match_transport(device: "torch.device", tensor_transport: Backend) -> bool: - """Check if the device matches the transport.""" - if tensor_transport == Backend.NIXL: - return device.type == "cuda" or device.type == "cpu" - elif tensor_transport == Backend.TORCH_GLOO: - return device.type == "cpu" - elif tensor_transport == Backend.NCCL: - return device.type == "cuda" - else: - raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}") - - -def get_address_and_port() -> Tuple[str, int]: - """Returns the IP address and a free port on this node.""" - addr = ray.util.get_node_ip_address() - port = find_free_port(socket.AF_INET6 if is_ipv6(addr) else socket.AF_INET) - return addr, port diff --git a/python/ray/experimental/collective/collective_tensor_transport.py b/python/ray/experimental/gpu_object_manager/collective_tensor_transport.py similarity index 70% rename from python/ray/experimental/collective/collective_tensor_transport.py rename to python/ray/experimental/gpu_object_manager/collective_tensor_transport.py index 78f97d64c87a..954ce67e01de 100644 --- a/python/ray/experimental/collective/collective_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/collective_tensor_transport.py @@ -1,25 +1,43 @@ +from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional import ray -from ray.experimental.collective.tensor_transport_manager import ( +from ray.experimental.gpu_object_manager.tensor_transport_manager import ( TensorTransportManager, ) -from ray.util.collective.types import ( - Backend, - CollectiveCommunicatorMetadata, - CollectiveTransportMetadata, +from ray.experimental.gpu_object_manager.types import ( + CommunicatorMetadata, + TensorTransportMetadata, ) if TYPE_CHECKING: import torch +@dataclass +class CollectiveTransportMetadata(TensorTransportMetadata): + """Metadata for tensors stored in the GPU object store for collective transport.""" + + +@dataclass +class CollectiveCommunicatorMetadata(CommunicatorMetadata): + """Metadata for the collective communicator (e.g. NCCL, GLOO). + + Args: + src_rank: The rank of the source actor. + dst_rank: The rank of the destination actor. 
+ """ + + src_rank: Optional[int] = None + dst_rank: Optional[int] = None + + class CollectiveTensorTransport(TensorTransportManager): - def __init__(self, tensor_transport_backend: Backend): + def __init__(self, tensor_transport_backend: str): self._tensor_transport_backend = tensor_transport_backend @property - def tensor_transport_backend(self) -> Backend: + def tensor_transport_backend(self) -> str: return self._tensor_transport_backend @staticmethod @@ -38,8 +56,8 @@ def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: ) return len(communicators) > 0 - @staticmethod def extract_tensor_transport_metadata( + self, obj_id: str, gpu_object: List["torch.Tensor"], ) -> CollectiveTransportMetadata: @@ -58,40 +76,8 @@ def extract_tensor_transport_metadata( tensor_device=device, ) - @staticmethod - def get_tensor_transport_metadata( - src_actor: "ray.actor.ActorHandle", - obj_id: str, - ) -> CollectiveTransportMetadata: - def __ray_get_tensor_transport_metadata__( - self: "ray.actor.ActorHandle", - obj_id: str, - ) -> CollectiveTransportMetadata: - - from ray._private.worker import global_worker - - gpu_object_store = global_worker.gpu_object_manager.gpu_object_store - # NOTE: We do not specify a timeout here because the user task that returns - # it could take arbitrarily long and we don't want to trigger a spurious - # timeout. - gpu_object = gpu_object_store.wait_and_get_object(obj_id) - return CollectiveTensorTransport.extract_tensor_transport_metadata( - obj_id, gpu_object - ) - - # Submit a Ray actor task to the source actor to get the tensor metadata. - # The metadata is a list of tuples, where each tuple contains the shape and dtype - # of a tensor in the GPU object store. This function returns an ObjectRef that - # points to the tensor metadata. - # NOTE(swang): We put this task on the background thread to avoid tasks - # executing on the main thread blocking this task. 
- - return src_actor.__ray_call__.options(concurrency_group="_ray_system").remote( - __ray_get_tensor_transport_metadata__, obj_id - ) - - @staticmethod def get_communicator_metadata( + self, src_actor: "ray.actor.ActorHandle", dst_actor: "ray.actor.ActorHandle", backend: Optional[str] = None, @@ -138,21 +124,20 @@ def get_communicator_metadata( ) return communicator_metadata - @staticmethod def recv_multiple_tensors( + self, tensors, obj_id: str, tensor_transport_metadata: CollectiveTransportMetadata, communicator_metadata: CollectiveCommunicatorMetadata, ): - from ray.util.collective import types from ray.util.collective.collective import recv assert isinstance( - tensor_transport_metadata, types.CollectiveTransportMetadata + tensor_transport_metadata, CollectiveTransportMetadata ), "metadata must be a CollectiveTransportMetadata object for non-NIXL transport" assert isinstance( - communicator_metadata, types.CollectiveCommunicatorMetadata + communicator_metadata, CollectiveCommunicatorMetadata ), "metadata must be a CollectiveCommunicatorMetadata object for non-NIXL transport" for tensor in tensors: @@ -162,8 +147,8 @@ def recv_multiple_tensors( communicator_metadata.communicator_name, ) - @staticmethod def send_multiple_tensors( + self, tensors: List["torch.Tensor"], tensor_transport_metadata: CollectiveTransportMetadata, communicator_metadata: CollectiveCommunicatorMetadata, @@ -183,14 +168,13 @@ def send_multiple_tensors( communicator_metadata.communicator_name, ) - @staticmethod def garbage_collect( - obj_id: str, tensor_transport_meta: CollectiveTransportMetadata + self, obj_id: str, tensor_transport_meta: CollectiveTransportMetadata ): pass - @staticmethod def abort_transport( + self, obj_id: str, communicator_metadata: CollectiveCommunicatorMetadata, ): diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index 788ee9ece376..c98c2d68dd26 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -9,6 +9,7 @@ from ray._private import ray_constants from ray._private.custom_types import TensorTransportEnum from ray._raylet import ObjectRef +from ray.util.annotations import PublicAPI if TYPE_CHECKING: import torch @@ -16,7 +17,10 @@ from ray.experimental.gpu_object_manager.gpu_object_store import ( GPUObjectStore, ) - from ray.util.collective.types import CommunicatorMetadata, TensorTransportMetadata + from ray.experimental.gpu_object_manager.types import ( + CommunicatorMetadata, + TensorTransportMetadata, + ) logger = logging.getLogger(__name__) @@ -28,8 +32,6 @@ # of a tensor in the GPU object store. class GPUObjectMeta(NamedTuple): src_actor: "ray.actor.ActorHandle" - # Must be a valid backend name as defined in - # `ray.util.collective.types.Backend`. tensor_transport_backend: str tensor_transport_meta: "TensorTransportMetadata" # sent_dest_actors tracks the set of actor IDs that this object has been sent to. @@ -51,8 +53,7 @@ class TransferMetadata(NamedTuple): timeout: float -# TODO(swang): Uncomment and add an API docs page and example usage. -# @PublicAPI(stability="alpha") +@PublicAPI(stability="alpha") def wait_tensor_freed(tensor: "torch.Tensor", timeout: Optional[float] = None): """ Wait for the tensor to be freed. @@ -181,14 +182,16 @@ def _abort_transport( Cleans up the ref_info_map, kill the src and dst actors, and destroy the collective group if necessary. 
""" - from ray.experimental.collective import ( - destroy_collective_group, - get_tensor_transport_manager, + from ray.experimental.collective import destroy_collective_group + from ray.experimental.gpu_object_manager.collective_tensor_transport import ( + CollectiveCommunicatorMetadata, ) from ray.experimental.gpu_object_manager.gpu_object_store import ( __ray_abort_transport__, ) - from ray.util.collective.types import CollectiveCommunicatorMetadata + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) ref_info = ref_info_map.pop(failed_ref.hex(), None) if ref_info is None: @@ -292,8 +295,8 @@ def add_gpu_object_ref( tensor_transport: The tensor transport protocol to use for the GPU object. tensor_transport_meta: The tensor transport metadata that is pre-computed. """ - from ray.experimental.collective import get_tensor_transport_manager from ray.experimental.gpu_object_manager.gpu_object_store import ( + __ray_get_tensor_transport_metadata__, _tensor_transport_to_collective_backend, ) @@ -301,12 +304,17 @@ def add_gpu_object_ref( tensor_transport ) obj_id = obj_ref.hex() - tensor_transport_manager = get_tensor_transport_manager( - tensor_transport_backend - ) if not tensor_transport_meta: - tensor_meta = tensor_transport_manager.get_tensor_transport_metadata( - src_actor, obj_id + # Submit a Ray actor task to the source actor to get the tensor metadata. + # The metadata is a list of tuples, where each tuple contains the shape and dtype + # of a tensor in the GPU object store. This function returns an ObjectRef that + # points to the tensor metadata. + # NOTE(swang): We put this task on the background thread to avoid tasks + # executing on the main thread blocking this task. + tensor_meta = src_actor.__ray_call__.options( + concurrency_group="_ray_system" + ).remote( + __ray_get_tensor_transport_metadata__, obj_id, tensor_transport_backend ) else: tensor_meta = tensor_transport_meta @@ -342,10 +350,12 @@ def _fetch_object( Returns: None """ - from ray.experimental.collective import get_tensor_transport_manager from ray.experimental.gpu_object_manager.gpu_object_store import ( __ray_fetch_gpu_object__, ) + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) if tensor_transport not in [ TensorTransportEnum.OBJECT_STORE, @@ -425,11 +435,13 @@ def trigger_out_of_band_tensor_transfer( if self.is_managed_object(arg.hex()): gpu_object_refs.add(arg) if gpu_object_refs: - from ray.experimental.collective import get_tensor_transport_manager from ray.experimental.gpu_object_manager.gpu_object_store import ( __ray_recv__, __ray_send__, ) + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) # Count the number of readers for each GPU object. for obj_ref in gpu_object_refs: @@ -604,10 +616,12 @@ def actor_has_tensor_transport( """ # Import get_collective_groups here to avoid dependency on # collective libraries for default Ray installation. - from ray.experimental.collective import get_tensor_transport_manager from ray.experimental.gpu_object_manager.gpu_object_store import ( _tensor_transport_to_collective_backend, ) + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) tensor_transport_backend = _tensor_transport_to_collective_backend( tensor_transport @@ -632,10 +646,12 @@ def put_object( tensors: The tensors to put into the GPU object manager. 
""" - from ray.experimental.collective import get_tensor_transport_manager from ray.experimental.gpu_object_manager.gpu_object_store import ( _tensor_transport_to_collective_backend, ) + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) tensor_transport_backend = _tensor_transport_to_collective_backend( tensor_transport diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_store.py b/python/ray/experimental/gpu_object_manager/gpu_object_store.py index 280b3452fe9a..f4ed10500dec 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_store.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_store.py @@ -7,13 +7,14 @@ import ray.util.collective as collective from ray._private.custom_types import TensorTransportEnum from ray._raylet import ObjectRef -from ray.experimental.collective import get_tensor_transport_manager -from ray.experimental.collective.util import device_match_transport -from ray.util.collective.types import ( - Backend, +from ray.experimental.gpu_object_manager.types import ( CommunicatorMetadata, TensorTransportMetadata, ) +from ray.experimental.gpu_object_manager.util import ( + device_match_transport, + get_tensor_transport_manager, +) try: import torch @@ -24,15 +25,15 @@ ) TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND = { - TensorTransportEnum.NCCL: Backend.NCCL, - TensorTransportEnum.GLOO: Backend.TORCH_GLOO, - TensorTransportEnum.NIXL: Backend.NIXL, + TensorTransportEnum.NCCL: "nccl", + TensorTransportEnum.GLOO: "torch_gloo", + TensorTransportEnum.NIXL: "nixl", } def _tensor_transport_to_collective_backend( tensor_transport: TensorTransportEnum, -) -> Backend: +) -> str: try: return TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND[tensor_transport] except KeyError: @@ -41,6 +42,24 @@ def _tensor_transport_to_collective_backend( ) +def __ray_get_tensor_transport_metadata__( + self, obj_id: str, backend: str +) -> TensorTransportMetadata: + """Helper function that runs on the src actor to get transport metadata.""" + from ray._private.worker import global_worker + + gpu_object_store = global_worker.gpu_object_manager.gpu_object_store + # NOTE: We do not specify a timeout here because the user task that returns + # it could take arbitrarily long and we don't want to trigger a spurious + # timeout. 
+ gpu_object = gpu_object_store.wait_and_get_object(obj_id) + + tensor_transport_manager = get_tensor_transport_manager(backend) + return tensor_transport_manager.extract_tensor_transport_metadata( + obj_id, gpu_object + ) + + def __ray_send__( self, obj_id: str, @@ -129,12 +148,11 @@ def __ray_abort_transport__(self, obj_id: str, communicator_meta: CommunicatorMe def __ray_free__( self, obj_id: str, - tensor_transport_backend: Backend, + tensor_transport_backend: str, tensor_transport_meta: TensorTransportMetadata, ): try: from ray._private.worker import global_worker - from ray.experimental.collective import get_tensor_transport_manager tensor_transport_manager = get_tensor_transport_manager( tensor_transport_backend diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py new file mode 100644 index 000000000000..31ce49d9e918 --- /dev/null +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -0,0 +1,242 @@ +import threading +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional + +import ray +from ray.experimental.gpu_object_manager.tensor_transport_manager import ( + TensorTransportManager, +) +from ray.experimental.gpu_object_manager.types import ( + CommunicatorMetadata, + TensorTransportMetadata, +) + +if TYPE_CHECKING: + import torch + + +@dataclass +class NixlCommunicatorMetadata(CommunicatorMetadata): + """Metadata for the NIXL communicator.""" + + +@dataclass +class NixlTransportMetadata(TensorTransportMetadata): + """Metadata for tensors stored in the GPU object store for NIXL transport. + + Args: + nixl_serialized_descs: Serialized tensor descriptors for NIXL transport. + nixl_agent_meta: The additional metadata of the remote NIXL agent. + """ + + nixl_reg_descs: Optional[Any] = None + nixl_serialized_descs: Optional[bytes] = None + nixl_agent_meta: Optional[bytes] = None + + __eq__ = object.__eq__ + __hash__ = object.__hash__ + + +class NixlTensorTransport(TensorTransportManager): + def __init__(self, tensor_transport_backend: str): + """ + Creates a NIXL agent with UCX backend. + """ + from nixl._api import nixl_agent, nixl_agent_config + + agent_config = nixl_agent_config(backends=["UCX"]) + ctx = ray.get_runtime_context() + actor_id = ctx.get_actor_id() + if actor_id is None: + # If the actor id is None, it means the current process is a driver. + import uuid + + actor_id = f"RAY-DRIVER-{uuid.uuid4()}" + self._nixl_agent = nixl_agent(actor_id, agent_config) + self._aborted_transfer_obj_ids = set() + self._aborted_transfer_obj_ids_lock = threading.Lock() + + @property + def tensor_transport_backend(self) -> str: + return "nixl" + + @staticmethod + def is_one_sided() -> bool: + return True + + @staticmethod + def can_abort_transport() -> bool: + return True + + def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: + # TODO(dayshah): This is called on a .remote RDT call, so it's quite expensive. 
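+        # The probe below runs inside the target actor (dispatched onto its
+        # "_ray_system" concurrency group via __ray_call__), so the result
+        # reflects that actor's environment rather than the caller's.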
+ + def __ray_actor_has_tensor_transport__( + self: "ray.actor.ActorHandle", + ) -> bool: + # Check if nixl is installed + try: + pass + except Exception: + return False + + return ray.get( + actor.__ray_call__.options(concurrency_group="_ray_system").remote( + __ray_actor_has_tensor_transport__ + ) + ) + + def extract_tensor_transport_metadata( + self, + obj_id: str, + gpu_object: List["torch.Tensor"], + ) -> NixlTransportMetadata: + from ray._private.worker import global_worker + + gpu_object_store = global_worker.gpu_object_manager.gpu_object_store + device = None + tensor_meta = [] + duplicate_meta = gpu_object_store.record_and_get_meta_if_duplicate( + obj_id, gpu_object + ) + if duplicate_meta is not None: + return duplicate_meta + + if gpu_object: + reg_descs = self._nixl_agent.register_memory(gpu_object) + serialized_descs = self._nixl_agent.get_serialized_descs(reg_descs.trim()) + agent_meta = self._nixl_agent.get_agent_metadata() + # We assume all tensors in one GPU object have the same device type. + device = gpu_object[0].device + for t in gpu_object: + if t.device.type != device.type: + raise ValueError( + "All tensors in an RDT object must have the same device type." + ) + tensor_meta.append((t.shape, t.dtype)) + else: + reg_descs, serialized_descs, agent_meta = None, None, None + + ret = NixlTransportMetadata( + tensor_meta=tensor_meta, + tensor_device=device, + nixl_reg_descs=reg_descs, + nixl_serialized_descs=serialized_descs, + nixl_agent_meta=agent_meta, + ) + gpu_object_store.record_managed_meta_nixl(obj_id, ret) + return ret + + def get_communicator_metadata( + self, + src_actor: "ray.actor.ActorHandle", + dst_actor: "ray.actor.ActorHandle", + backend: Optional[str] = None, + ) -> NixlCommunicatorMetadata: + return NixlCommunicatorMetadata() + + def recv_multiple_tensors( + self, + tensors, + obj_id: str, + tensor_transport_metadata: NixlTransportMetadata, + communicator_metadata: NixlCommunicatorMetadata, + ): + if tensors is None: + return + + assert isinstance( + tensor_transport_metadata, NixlTransportMetadata + ), "metadata must be a NixlTransportMetadata object for NIXL transport" + assert isinstance( + communicator_metadata, NixlCommunicatorMetadata + ), "metadata must be a NixlCommunicatorMetadata object for NIXL transport" + + nixl_serialized_descs = tensor_transport_metadata.nixl_serialized_descs + remote_nixl_agent_meta = tensor_transport_metadata.nixl_agent_meta + + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError(f"NIXL transfer aborted for object id: {obj_id}") + + local_descs = None + remote_name = None + xfer_handle = None + try: + nixl_agent = self._nixl_agent + remote_descs = nixl_agent.deserialize_descs(nixl_serialized_descs) + local_descs = nixl_agent.register_memory(tensors) + remote_name = nixl_agent.add_remote_agent(remote_nixl_agent_meta) + + xfer_handle = nixl_agent.initialize_xfer( + # "UUID" here is just a placeholder, can be any bytes, but without it, + # nixl will fail to transfer multiple times. + "READ", + local_descs.trim(), + remote_descs, + remote_name, + "UUID", + ) + + state = nixl_agent.transfer(xfer_handle) + if state == "ERR": + raise RuntimeError("NIXL transfer got to Error state.") + # Since current nixl does not provide a better way, we need to check the state of + # the transfer continuously. 
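+            # check_xfer_state returns "PROC" while the transfer is still in
+            # flight (which is also where an abort requested for `obj_id` is
+            # honored), "DONE" once it completes, and "ERR" on failure.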
+ while True: + state = nixl_agent.check_xfer_state(xfer_handle) + if state == "ERR": + raise RuntimeError("NIXL transfer got to Error state.") + if state == "PROC": + with self._aborted_transfer_obj_ids_lock: + if obj_id in self._aborted_transfer_obj_ids: + self._aborted_transfer_obj_ids.remove(obj_id) + raise RuntimeError( + f"NIXL transfer aborted for object id: {obj_id}" + ) + time.sleep(0.001) # Avoid busy waiting + elif state == "DONE": + break + finally: + # We could raise errors or NIXL could raise errors like NIXL_ERR_REMOTE_DISCONNECT, + # so doing best effort cleanup. + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.discard(obj_id) + if xfer_handle: + nixl_agent.release_xfer_handle(xfer_handle) + if remote_name: + nixl_agent.remove_remote_agent(remote_name) + if local_descs: + nixl_agent.deregister_memory(local_descs) + + def send_multiple_tensors( + self, + tensors: List["torch.Tensor"], + communicator_metadata: NixlCommunicatorMetadata, + device: "torch.device", + ): + raise NotImplementedError( + "NIXL transport does not support send_multiple_tensors, since it is a one-sided transport." + ) + + def garbage_collect( + self, obj_id: str, tensor_transport_meta: NixlTransportMetadata + ): + from ray._private.worker import global_worker + + gpu_object_store = global_worker.gpu_object_manager.gpu_object_store + count = gpu_object_store.remove_managed_meta_nixl(obj_id) + if count == 0: + descs = tensor_transport_meta.nixl_reg_descs + if descs is not None: + self._nixl_agent.deregister_memory(descs) + + def abort_transport( + self, + obj_id: str, + communicator_metadata: NixlCommunicatorMetadata, + ): + with self._aborted_transfer_obj_ids_lock: + self._aborted_transfer_obj_ids.add(obj_id) diff --git a/python/ray/experimental/collective/tensor_transport_manager.py b/python/ray/experimental/gpu_object_manager/tensor_transport_manager.py similarity index 79% rename from python/ray/experimental/collective/tensor_transport_manager.py rename to python/ray/experimental/gpu_object_manager/tensor_transport_manager.py index 7f0210b7d9ca..85fd9e54accb 100644 --- a/python/ray/experimental/collective/tensor_transport_manager.py +++ b/python/ray/experimental/gpu_object_manager/tensor_transport_manager.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING, List, Optional import ray -from ray.util.collective.types import ( - Backend, +from ray.experimental.gpu_object_manager.types import ( CommunicatorMetadata, TensorTransportMetadata, ) @@ -15,11 +14,11 @@ class TensorTransportManager(ABC): @property @abstractmethod - def tensor_transport_backend(self) -> Backend: + def tensor_transport_backend(self) -> str: """The tensor transport backend, e.g., NCCL. Returns: - Backend: The backend of the tensor transport. + str: The backend of the tensor transport. """ @staticmethod @@ -53,28 +52,9 @@ def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: bool: True if the actor has the tensor transport available, False otherwise. """ - @staticmethod - @abstractmethod - def get_tensor_transport_metadata( - src_actor: "ray.actor.ActorHandle", - obj_id: str, - ) -> TensorTransportMetadata: - """ - Get the tensor transport metadata for the GPU object. - This function retrieves metadata about tensors stored in the GPU object store, - including their shapes, dtypes, and any transport-specific metadata, e.g., NIXL descriptors. - - Args: - src_actor: The actor that runs this function. 
- obj_id: The ID of the GPU object to get metadata for - - Returns: - TensorTransportMetadata: A named tuple containing the tensor metadata. - """ - - @staticmethod @abstractmethod def extract_tensor_transport_metadata( + self, obj_id: str, gpu_object: List["torch.Tensor"], ) -> TensorTransportMetadata: @@ -89,9 +69,9 @@ def extract_tensor_transport_metadata( TensorTransportMetadata: The tensor transport metadata. """ - @staticmethod @abstractmethod def get_communicator_metadata( + self, src_actor: "ray.actor.ActorHandle", dst_actor: "ray.actor.ActorHandle", backend: Optional[str] = None, @@ -109,9 +89,9 @@ def get_communicator_metadata( CommunicatorMetadata: The communicator metadata. """ - @staticmethod @abstractmethod def recv_multiple_tensors( + self, tensors: List["torch.Tensor"], obj_id: str, tensor_transport_metadata: TensorTransportMetadata, @@ -128,9 +108,9 @@ def recv_multiple_tensors( """ - @staticmethod @abstractmethod def send_multiple_tensors( + self, tensors: List["torch.Tensor"], communicator_metadata: CommunicatorMetadata, ): @@ -142,9 +122,10 @@ def send_multiple_tensors( communicator_metadata: The communicator metadata for the send/recv operation. """ - @staticmethod @abstractmethod - def garbage_collect(obj_id: str, tensor_transport_meta: TensorTransportMetadata): + def garbage_collect( + self, obj_id: str, tensor_transport_meta: TensorTransportMetadata + ): """ Garbage collect for the tensor transport after the GPU object is freed. @@ -153,9 +134,9 @@ def garbage_collect(obj_id: str, tensor_transport_meta: TensorTransportMetadata) tensor_transport_meta: The tensor transport metadata. """ - @staticmethod @abstractmethod def abort_transport( + self, obj_id: str, communicator_metadata: CommunicatorMetadata, ): diff --git a/python/ray/experimental/gpu_object_manager/types.py b/python/ray/experimental/gpu_object_manager/types.py new file mode 100644 index 000000000000..5ac454d30a52 --- /dev/null +++ b/python/ray/experimental/gpu_object_manager/types.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +if TYPE_CHECKING: + import torch + + +@dataclass +class CommunicatorMetadata: + """Metadata for the communicator. + + Args: + communicator_name: The name of the communicator. + """ + + communicator_name: str = "" + + +@dataclass +class TensorTransportMetadata: + """Metadata for tensors stored in the GPU object store. + + Args: + tensor_meta: A list of tuples, each containing the shape and dtype of a tensor. + tensor_device: The device of the tensor. Currently, we require all tensors in the + list have the same device type. 
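+
+    For example, a single 2x3 float32 CPU tensor would be described by
+    tensor_meta=[(torch.Size([2, 3]), torch.float32)] and
+    tensor_device=torch.device("cpu").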
+ """ + + tensor_meta: List[Tuple["torch.Size", "torch.dtype"]] + tensor_device: Optional["torch.device"] = None diff --git a/python/ray/experimental/gpu_object_manager/util.py b/python/ray/experimental/gpu_object_manager/util.py new file mode 100644 index 000000000000..f759f81d056a --- /dev/null +++ b/python/ray/experimental/gpu_object_manager/util.py @@ -0,0 +1,67 @@ +from typing import TYPE_CHECKING + +from ray.experimental.gpu_object_manager.collective_tensor_transport import ( + CollectiveTensorTransport, +) +from ray.experimental.gpu_object_manager.nixl_tensor_transport import ( + NixlTensorTransport, +) +from ray.experimental.gpu_object_manager.tensor_transport_manager import ( + TensorTransportManager, +) + +if TYPE_CHECKING: + import torch + + +# Class definitions for transport managers +transport_manager_classes: dict[str, TensorTransportManager] = { + "nixl": NixlTensorTransport, + "torch_gloo": CollectiveTensorTransport, + "nccl": CollectiveTensorTransport, +} + +transport_devices = { + "nixl": ["cuda", "cpu"], + "torch_gloo": ["cpu"], + "nccl": ["cuda"], +} + + +# Singleton instances of transport managers +transport_managers = {} + + +def get_tensor_transport_manager( + transport_name: str, +) -> "TensorTransportManager": + """Get the tensor transport manager for the given tensor transport protocol. + + Args: + transport_name: The tensor transport protocol to use for the GPU object. + + Returns: + TensorTransportManager: The tensor transport manager for the given tensor transport protocol. + """ + global transport_manager_classes + global transport_managers + + if transport_name in transport_managers: + return transport_managers[transport_name] + + if transport_name not in transport_manager_classes: + raise ValueError(f"Unsupported tensor transport protocol: {transport_name}") + + transport_managers[transport_name] = transport_manager_classes[transport_name]( + transport_name + ) + return transport_managers[transport_name] + + +def device_match_transport(device: "torch.device", tensor_transport: str) -> bool: + """Check if the device matches the transport.""" + + if tensor_transport not in transport_devices: + raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}") + + return device.type in transport_devices[tensor_transport] diff --git a/python/ray/llm/_internal/serve/serving_patterns/data_parallel/dp_server.py b/python/ray/llm/_internal/serve/serving_patterns/data_parallel/dp_server.py index 57e6cef92d88..93f31cdbd94c 100644 --- a/python/ray/llm/_internal/serve/serving_patterns/data_parallel/dp_server.py +++ b/python/ray/llm/_internal/serve/serving_patterns/data_parallel/dp_server.py @@ -1,11 +1,11 @@ import logging import time -from ray.experimental.collective.util import get_address_and_port from ray.llm._internal.serve.core.configs.llm_config import LLMConfig from ray.llm._internal.serve.core.server.llm_server import LLMServer from ray.runtime_context import get_runtime_context from ray.serve.handle import DeploymentHandle +from ray.util.collective.collective import get_address_and_port logger = logging.getLogger(__name__) diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 5e1681015c25..1476512a2f1b 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -2,18 +2,17 @@ import logging import os +import socket import threading import time -from typing import List +from typing import List, Tuple import numpy as np import ray import 
ray.experimental.internal_kv as _internal_kv from . import types -from ray.experimental.collective.util import ( - get_address_and_port as _get_address_and_port, -) +from ray._common.network_utils import find_free_port, is_ipv6 from ray.util.collective.collective_group.torch_gloo_collective_group import ( get_master_address_metadata_key as _get_master_addr_key, ) @@ -39,13 +38,6 @@ except ImportError: _TORCH_DISTRIBUTED_AVAILABLE = False -try: - from ray.util.collective.collective_group.nixl_backend import NixlBackend - - _NIXL_AVAILABLE = True -except ImportError: - _NIXL_AVAILABLE = False - def nccl_available(): global _LOG_NCCL_WARNING @@ -69,8 +61,11 @@ def torch_distributed_available(): return _TORCH_DISTRIBUTED_AVAILABLE -def nixl_available(): - return _NIXL_AVAILABLE +def get_address_and_port() -> Tuple[str, int]: + """Returns the IP address and a free port on this node.""" + addr = ray.util.get_node_ip_address() + port = find_free_port(socket.AF_INET6 if is_ipv6(addr) else socket.AF_INET) + return addr, port class GroupManager(object): @@ -99,7 +94,7 @@ def create_collective_group( # Rendezvous: ensure a MASTER_ADDR:MASTER_PORT is published in internal_kv. metadata_key = _get_master_addr_key(group_name) if rank == 0: - addr, port = _get_address_and_port() + addr, port = get_address_and_port() _internal_kv._internal_kv_put(metadata_key, f"{addr}:{port}") else: # Wait until rank 0 publishes the metadata or timeout. @@ -124,10 +119,6 @@ def create_collective_group( _check_backend_availability(backend) logger.debug("Creating NCCL group: '{}'...".format(group_name)) g = NCCLGroup(world_size, rank, group_name) - elif backend == types.Backend.NIXL: - _check_backend_availability(backend) - logger.debug("Creating NIXL Backend: '{}'...".format(group_name)) - g = NixlBackend() else: raise RuntimeError(f"Unexpected backend: {backend}") @@ -758,32 +749,26 @@ def get_group_handle(group_name: str = "default"): Returns: The collective group handle. """ - if group_name != types.NIXL_GROUP_NAME: - _check_inside_actor() + _check_inside_actor() global _group_mgr global _group_mgr_lock with _group_mgr_lock: if not _group_mgr.is_group_exist(group_name): # try loading from remote info store try: - if group_name == types.NIXL_GROUP_NAME: - _group_mgr.create_collective_group( - types.Backend.NIXL, None, None, group_name, None - ) - else: - # if the information is stored in an Info object, - # get and create the group. - name = "info_" + group_name - mgr = ray.get_actor(name=name) - ids, world_size, rank, backend, gloo_timeout = ray.get( - mgr.get_info.remote() - ) - worker = ray._private.worker.global_worker - id_ = worker.core_worker.get_actor_id() - r = rank[ids.index(id_)] - _group_mgr.create_collective_group( - backend, world_size, r, group_name, gloo_timeout - ) + # if the information is stored in an Info object, + # get and create the group. 
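+                # "info_" + group_name is a named actor that holds the group's
+                # membership info (actor ids, world size, per-actor ranks,
+                # backend, and gloo timeout); it is fetched below to lazily
+                # create the local group handle on this actor.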
+ name = "info_" + group_name + mgr = ray.get_actor(name=name) + ids, world_size, rank, backend, gloo_timeout = ray.get( + mgr.get_info.remote() + ) + worker = ray._private.worker.global_worker + id_ = worker.core_worker.get_actor_id() + r = rank[ids.index(id_)] + _group_mgr.create_collective_group( + backend, world_size, r, group_name, gloo_timeout + ) except ValueError as exc: # check if this group is initialized using options() if ( @@ -834,9 +819,6 @@ def _check_backend_availability(backend: types.Backend): elif backend == types.Backend.TORCH_GLOO: if not torch_distributed_available(): raise RuntimeError("torch.distributed is not available.") - elif backend == types.Backend.NIXL: - if not nixl_available(): - raise RuntimeError("NIXL is not available.") def _check_inside_actor(): diff --git a/python/ray/util/collective/collective_group/nixl_backend.py b/python/ray/util/collective/collective_group/nixl_backend.py deleted file mode 100644 index beff753b055a..000000000000 --- a/python/ray/util/collective/collective_group/nixl_backend.py +++ /dev/null @@ -1,147 +0,0 @@ -import threading -import time -from typing import TYPE_CHECKING, Any, List, Tuple - -from nixl._api import nixl_agent, nixl_agent_config - -import ray -from ray.util.collective.types import Backend - -if TYPE_CHECKING: - import torch - - -class NixlBackend: - """Backend implementation for NIXL tensor transport. - - This class provides functionality for transferring tensors using NIXL. It handles - initialization of the NIXL agent, receiving tensors, and managing NIXL metadata. - """ - - def __init__(self): - """Initialize the NIXL backend. - - Creates a NIXL agent with UCX backend. - """ - agent_config = nixl_agent_config(backends=["UCX"]) - ctx = ray.get_runtime_context() - actor_id = ctx.get_actor_id() - if actor_id is None: - # If the actor id is None, it means the current process is a driver. - import uuid - - actor_id = f"RAY-DRIVER-{uuid.uuid4()}" - self._nixl_agent = nixl_agent(actor_id, agent_config) - self._aborted_transfer_obj_ids = set() - self._aborted_transfer_obj_ids_lock = threading.Lock() - - @classmethod - def backend(cls): - """Get the backend type. - - Returns: - Backend.NIXL: The backend type enum value for NIXL. - """ - return Backend.NIXL - - def recv( - self, - tensors: List["torch.Tensor"], - obj_id: str, - nixl_serialized_descs: bytes, - remote_nixl_agent_meta: bytes, - ): - """Receive tensors from a remote NIXL agent. - - Args: - tensors: List of tensors to receive into. - obj_id: The object ID for related GPU object. - nixl_serialized_descs: Serialized NIXL descriptors for the remote tensors. - remote_nixl_agent_meta: Metadata about the remote NIXL agent. - - Raises: - RuntimeError: If the NIXL transfer enters an error state. - """ - with self._aborted_transfer_obj_ids_lock: - if obj_id in self._aborted_transfer_obj_ids: - self._aborted_transfer_obj_ids.remove(obj_id) - raise RuntimeError(f"NIXL transfer aborted for object id: {obj_id}") - - local_descs = None - remote_name = None - xfer_handle = None - try: - nixl_agent = self._nixl_agent - remote_descs = nixl_agent.deserialize_descs(nixl_serialized_descs) - local_descs = nixl_agent.register_memory(tensors) - remote_name = nixl_agent.add_remote_agent(remote_nixl_agent_meta) - - xfer_handle = nixl_agent.initialize_xfer( - # "UUID" here is just a placeholder, can be any bytes, but without it, - # nixl will fail to transfer multiple times. 
- "READ", - local_descs.trim(), - remote_descs, - remote_name, - "UUID", - ) - - state = nixl_agent.transfer(xfer_handle) - if state == "ERR": - raise RuntimeError("NIXL transfer got to Error state.") - # Since current nixl does not provide a better way, we need to check the state of - # the transfer continuously. - while True: - state = nixl_agent.check_xfer_state(xfer_handle) - if state == "ERR": - raise RuntimeError("NIXL transfer got to Error state.") - if state == "PROC": - with self._aborted_transfer_obj_ids_lock: - if obj_id in self._aborted_transfer_obj_ids: - self._aborted_transfer_obj_ids.remove(obj_id) - raise RuntimeError( - f"NIXL transfer aborted for object id: {obj_id}" - ) - time.sleep(0.001) # Avoid busy waiting - elif state == "DONE": - break - finally: - # We could raise errors or NIXL could raise errors like NIXL_ERR_REMOTE_DISCONNECT, - # so doing best effort cleanup. - with self._aborted_transfer_obj_ids_lock: - self._aborted_transfer_obj_ids.discard(obj_id) - if xfer_handle: - nixl_agent.release_xfer_handle(xfer_handle) - if remote_name: - nixl_agent.remove_remote_agent(remote_name) - if local_descs: - nixl_agent.deregister_memory(local_descs) - - def get_nixl_metadata( - self, tensors: List["torch.Tensor"] - ) -> Tuple[Any, bytes, bytes]: - """Get NIXL metadata for a set of tensors. - - Args: - tensors: List of tensors to get metadata for. - - Returns: - tuple: A tuple containing: - - Serialized NIXL descriptors for the tensors - - Metadata about this NIXL agent - """ - nixl_agent = self._nixl_agent - reg_descs = nixl_agent.register_memory(tensors) - xfer_descs = reg_descs.trim() - return ( - reg_descs, - nixl_agent.get_serialized_descs(xfer_descs), - nixl_agent.get_agent_metadata(), - ) - - def deregister_memory(self, descs: Any): - self._nixl_agent.deregister_memory(descs) - - def abort(self, obj_id: str): - with self._aborted_transfer_obj_ids_lock: - self._aborted_transfer_obj_ids.add(obj_id) diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index c0f395d6e5d7..bf357bbd24c6 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -3,16 +3,14 @@ from dataclasses import dataclass from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Any, List, Optional, Tuple - -from numpy import int32 +from typing import TYPE_CHECKING _NUMPY_AVAILABLE = True _TORCH_AVAILABLE = True _CUPY_AVAILABLE = True if TYPE_CHECKING: - import torch + pass try: import torch as th # noqa: F401 @@ -57,71 +55,6 @@ def __new__(cls, name: str): return backend -@dataclass -class TensorTransportMetadata: - """Metadata for tensors stored in the GPU object store. - - Args: - tensor_meta: A list of tuples, each containing the shape and dtype of a tensor. - tensor_device: The device of the tensor. Currently, we require all tensors in the - list have the same device type. - """ - - tensor_meta: List[Tuple["torch.Size", "torch.dtype"]] - tensor_device: Optional["torch.device"] = None - - -@dataclass -class NixlTransportMetadata(TensorTransportMetadata): - """Metadata for tensors stored in the GPU object store for NIXL transport. - - Args: - nixl_serialized_descs: Serialized tensor descriptors for NIXL transport. - nixl_agent_meta: The additional metadata of the remote NIXL agent. 
- """ - - nixl_reg_descs: Optional[Any] = None - nixl_serialized_descs: Optional[bytes] = None - nixl_agent_meta: Optional[bytes] = None - - __eq__ = object.__eq__ - __hash__ = object.__hash__ - - -@dataclass -class CollectiveTransportMetadata(TensorTransportMetadata): - """Metadata for tensors stored in the GPU object store for collective transport.""" - - -@dataclass -class CommunicatorMetadata: - """Metadata for the communicator. - - Args: - communicator_name: The name of the communicator. - """ - - communicator_name: str = "" - - -@dataclass -class CollectiveCommunicatorMetadata(CommunicatorMetadata): - """Metadata for the collective communicator (e.g. NCCL, GLOO). - - Args: - src_rank: The rank of the source actor. - dst_rank: The rank of the destination actor. - """ - - src_rank: Optional[int32] = None - dst_rank: Optional[int32] = None - - -@dataclass -class NixlCommunicatorMetadata(CommunicatorMetadata): - """Metadata for the NIXL communicator.""" - - class ReduceOp(Enum): SUM = 0 PRODUCT = 1 @@ -131,9 +64,6 @@ class ReduceOp(Enum): unset_timeout_ms = timedelta(milliseconds=-1) -# This is used to identify the collective group for NIXL. -NIXL_GROUP_NAME = "ray_internal_nixl_group" - @dataclass class AllReduceOptions: From 4fa48a2016b43b4df87b469d2e75868c5f0e8cb2 Mon Sep 17 00:00:00 2001 From: dayshah Date: Tue, 2 Dec 2025 13:22:39 -0800 Subject: [PATCH 02/18] fix nixl Signed-off-by: dayshah --- .../gpu_object_manager/nixl_tensor_transport.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index 31ce49d9e918..c104a75f9c53 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -77,7 +77,12 @@ def __ray_actor_has_tensor_transport__( ) -> bool: # Check if nixl is installed try: - pass + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) + + get_tensor_transport_manager("nixl") + return True except Exception: return False From 0386f17fcc055dd6629b44721d8f76919933065f Mon Sep 17 00:00:00 2001 From: dayshah Date: Wed, 3 Dec 2025 23:14:10 -0800 Subject: [PATCH 03/18] if not tensors Signed-off-by: dayshah --- .../experimental/gpu_object_manager/nixl_tensor_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index c104a75f9c53..c8ec34e9f39e 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -148,7 +148,7 @@ def recv_multiple_tensors( tensor_transport_metadata: NixlTransportMetadata, communicator_metadata: NixlCommunicatorMetadata, ): - if tensors is None: + if not tensors: return assert isinstance( From 4abde04fa7bcb6689b17abd47a198c40b0a89d66 Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 19:17:17 +0000 Subject: [PATCH 04/18] remove usage of get_group_handle in base rdt Signed-off-by: dayshah --- .../collective_tensor_transport.py | 1 + .../gpu_object_manager/gpu_object_manager.py | 10 +++++++++- .../gpu_object_manager/gpu_object_store.py | 14 +++++--------- .../ray/experimental/gpu_object_manager/types.py | 8 +------- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git 
a/python/ray/experimental/gpu_object_manager/collective_tensor_transport.py b/python/ray/experimental/gpu_object_manager/collective_tensor_transport.py index 954ce67e01de..019e12656274 100644 --- a/python/ray/experimental/gpu_object_manager/collective_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/collective_tensor_transport.py @@ -28,6 +28,7 @@ class CollectiveCommunicatorMetadata(CommunicatorMetadata): dst_rank: The rank of the destination actor. """ + communicator_name: str = "" src_rank: Optional[int] = None dst_rank: Optional[int] = None diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index c98c2d68dd26..f0ea19a91333 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -212,6 +212,7 @@ def _abort_transport( __ray_abort_transport__, ref_info.obj_id, ref_info.communicator_meta, + ref_info.backend, ) ref_info.dst_actor.__ray_call__.options( concurrency_group="_ray_system_error" @@ -219,6 +220,7 @@ def _abort_transport( __ray_abort_transport__, ref_info.obj_id, ref_info.communicator_meta, + ref_info.backend, ) logger.info( "RDT transfer with src actor %s and dst actor %s failed due to %s.", @@ -401,7 +403,11 @@ def _fetch_object( None, None, tensor_transport_backend ) __ray_recv__( - None, obj_id, [gpu_object_meta.tensor_transport_meta], communicator_meta + None, + obj_id, + [gpu_object_meta.tensor_transport_meta], + communicator_meta, + tensor_transport_backend, ) def trigger_out_of_band_tensor_transfer( @@ -506,6 +512,7 @@ def trigger_out_of_band_tensor_transfer( obj_id, tensor_transport_meta, communicator_meta, + gpu_object_meta.tensor_transport_backend, ) # Receive tensors from the source rank and store them in the @@ -521,6 +528,7 @@ def trigger_out_of_band_tensor_transfer( obj_id, [tensor_transport_meta], communicator_meta, + gpu_object_meta.tensor_transport_backend, ) self._unmonitored_transfers.put( diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_store.py b/python/ray/experimental/gpu_object_manager/gpu_object_store.py index f4ed10500dec..e7db5cf3606f 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_store.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_store.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Set, Union import ray -import ray.util.collective as collective from ray._private.custom_types import TensorTransportEnum from ray._raylet import ObjectRef from ray.experimental.gpu_object_manager.types import ( @@ -65,6 +64,7 @@ def __ray_send__( obj_id: str, tensor_transport_meta: TensorTransportMetadata, communicator_meta: CommunicatorMetadata, + backend: str, ): """Helper function that runs on the src actor to send tensors to the dst actor.""" from ray._private.worker import global_worker @@ -76,8 +76,6 @@ def __ray_send__( tensors = gpu_object_store.get_object(obj_id) - backend = collective.get_group_handle(communicator_meta.communicator_name).backend() - tensor_transport_manager = get_tensor_transport_manager(backend) if tensors and not device_match_transport(tensors[0].device, backend): raise ValueError( @@ -95,6 +93,7 @@ def __ray_recv__( obj_id: str, tensor_transport_meta: List[Union[ObjectRef, TensorTransportMetadata]], communicator_meta: CommunicatorMetadata, + backend: str, ): """Helper function that runs on the dst actor to receive tensors from the src actor.""" from 
ray._private.worker import global_worker @@ -109,10 +108,6 @@ def __ray_recv__( device = tensor_transport_meta.tensor_device tensor_meta = tensor_transport_meta.tensor_meta - backend = collective.get_group_handle( - communicator_meta.communicator_name - ).backend() - if tensor_meta and not device_match_transport(device, backend): raise ValueError( f"Tensor transport backend {backend} does not support tensor transfer on device {device}." @@ -138,9 +133,10 @@ def __ray_recv__( gpu_object_store.add_object(obj_id, e, is_primary=False) -def __ray_abort_transport__(self, obj_id: str, communicator_meta: CommunicatorMetadata): +def __ray_abort_transport__( + self, obj_id: str, communicator_meta: CommunicatorMetadata, backend: str +): """Helper function that can run on an actor doing a send or recv to abort the transport.""" - backend = collective.get_group_handle(communicator_meta.communicator_name).backend() tensor_transport_manager = get_tensor_transport_manager(backend) tensor_transport_manager.abort_transport(obj_id, communicator_meta) diff --git a/python/ray/experimental/gpu_object_manager/types.py b/python/ray/experimental/gpu_object_manager/types.py index 5ac454d30a52..14b5fbacd089 100644 --- a/python/ray/experimental/gpu_object_manager/types.py +++ b/python/ray/experimental/gpu_object_manager/types.py @@ -7,13 +7,7 @@ @dataclass class CommunicatorMetadata: - """Metadata for the communicator. - - Args: - communicator_name: The name of the communicator. - """ - - communicator_name: str = "" + """Metadata for the communicator.""" @dataclass From 99f1d45dde03711b167f1e080b94f3f38498ad85 Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 19:36:48 +0000 Subject: [PATCH 05/18] fix Signed-off-by: dayshah --- .../ray/experimental/gpu_object_manager/gpu_object_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index f0ea19a91333..1876ebeba311 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -242,9 +242,9 @@ def _abort_transport( # isinstance does an implicit cast and makes communicator_name inaccessible # so we have to get communicator_name before the cast. 
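The hunk below moves the `communicator_name` read inside the `isinstance` branch: after the metadata split earlier in this series, the field is defined only on `CollectiveCommunicatorMetadata`, not on the base `CommunicatorMetadata`. A minimal standalone sketch of that access pattern (the stand-in dataclasses and the `group_name_for_cleanup` helper are illustrative, not Ray code):

from dataclasses import dataclass
from typing import Optional


@dataclass
class CommunicatorMetadata:
    """Base metadata; carries no communicator name."""


@dataclass
class CollectiveCommunicatorMetadata(CommunicatorMetadata):
    communicator_name: str = ""


def group_name_for_cleanup(meta: CommunicatorMetadata) -> Optional[str]:
    # Only the collective variant knows its group name, so narrow the type
    # with isinstance() before touching the field.
    if isinstance(meta, CollectiveCommunicatorMetadata):
        return meta.communicator_name
    return None


assert group_name_for_cleanup(CollectiveCommunicatorMetadata("g1")) == "g1"
assert group_name_for_cleanup(CommunicatorMetadata()) is None
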
- collective_group_name = ref_info.communicator_meta.communicator_name if isinstance(ref_info.communicator_meta, CollectiveCommunicatorMetadata): try: + collective_group_name = ref_info.communicator_meta.communicator_name destroy_collective_group(collective_group_name) logger.error( "Destroyed collective group %s due to a hanging/failed RDT transfer", From e579076ef85fd619691502d763f2a8a998fbb077 Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 19:50:05 +0000 Subject: [PATCH 06/18] fix Signed-off-by: dayshah --- .../experimental/gpu_object_manager/nixl_tensor_transport.py | 2 +- .../experimental/gpu_object_manager/tensor_transport_manager.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index c8ec34e9f39e..582f3457cc52 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -219,8 +219,8 @@ def recv_multiple_tensors( def send_multiple_tensors( self, tensors: List["torch.Tensor"], + tensor_transport_metadata: NixlTransportMetadata, communicator_metadata: NixlCommunicatorMetadata, - device: "torch.device", ): raise NotImplementedError( "NIXL transport does not support send_multiple_tensors, since it is a one-sided transport." diff --git a/python/ray/experimental/gpu_object_manager/tensor_transport_manager.py b/python/ray/experimental/gpu_object_manager/tensor_transport_manager.py index 85fd9e54accb..8355f637af31 100644 --- a/python/ray/experimental/gpu_object_manager/tensor_transport_manager.py +++ b/python/ray/experimental/gpu_object_manager/tensor_transport_manager.py @@ -112,6 +112,7 @@ def recv_multiple_tensors( def send_multiple_tensors( self, tensors: List["torch.Tensor"], + tensor_transport_metadata: TensorTransportMetadata, communicator_metadata: CommunicatorMetadata, ): """ @@ -119,6 +120,7 @@ def send_multiple_tensors( Args: tensors: The tensors to send. + tensor_transport_metadata: The tensor transport metadata for the RDT object. communicator_metadata: The communicator metadata for the send/recv operation. 
""" From fb8a99ff17fb49175169d2da912a20a6809d5e6e Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 21:21:37 +0000 Subject: [PATCH 07/18] get_tensor_transport_manager under lock Signed-off-by: dayshah --- .../experimental/gpu_object_manager/util.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/ray/experimental/gpu_object_manager/util.py b/python/ray/experimental/gpu_object_manager/util.py index f759f81d056a..1ae94893c3b3 100644 --- a/python/ray/experimental/gpu_object_manager/util.py +++ b/python/ray/experimental/gpu_object_manager/util.py @@ -1,3 +1,4 @@ +import threading from typing import TYPE_CHECKING from ray.experimental.gpu_object_manager.collective_tensor_transport import ( @@ -31,6 +32,8 @@ # Singleton instances of transport managers transport_managers = {} +transport_managers_lock = threading.Lock() + def get_tensor_transport_manager( transport_name: str, @@ -45,17 +48,19 @@ def get_tensor_transport_manager( """ global transport_manager_classes global transport_managers + global transport_managers_lock - if transport_name in transport_managers: - return transport_managers[transport_name] + with transport_managers_lock: + if transport_name in transport_managers: + return transport_managers[transport_name] - if transport_name not in transport_manager_classes: - raise ValueError(f"Unsupported tensor transport protocol: {transport_name}") + if transport_name not in transport_manager_classes: + raise ValueError(f"Unsupported tensor transport protocol: {transport_name}") - transport_managers[transport_name] = transport_manager_classes[transport_name]( - transport_name - ) - return transport_managers[transport_name] + transport_managers[transport_name] = transport_manager_classes[transport_name]( + transport_name + ) + return transport_managers[transport_name] def device_match_transport(device: "torch.device", tensor_transport: str) -> bool: From 8ba388a0dd80d09182cc49b223a0dcc1af2683ec Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 22:08:02 +0000 Subject: [PATCH 08/18] lazy nixl agent init Signed-off-by: dayshah --- .../nixl_tensor_transport.py | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index 582f3457cc52..cba65dd152c3 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -40,20 +40,8 @@ class NixlTransportMetadata(TensorTransportMetadata): class NixlTensorTransport(TensorTransportManager): def __init__(self, tensor_transport_backend: str): - """ - Creates a NIXL agent with UCX backend. - """ - from nixl._api import nixl_agent, nixl_agent_config - - agent_config = nixl_agent_config(backends=["UCX"]) - ctx = ray.get_runtime_context() - actor_id = ctx.get_actor_id() - if actor_id is None: - # If the actor id is None, it means the current process is a driver. - import uuid - - actor_id = f"RAY-DRIVER-{uuid.uuid4()}" - self._nixl_agent = nixl_agent(actor_id, agent_config) + # This is lazily initialized because it requires NIXL to actually be installed and we want to allow an owner that is just coordinating to not need to have NIXL installed. 
+ self._nixl_agent = None self._aborted_transfer_obj_ids = set() self._aborted_transfer_obj_ids_lock = threading.Lock() @@ -69,6 +57,26 @@ def is_one_sided() -> bool: def can_abort_transport() -> bool: return True + def get_nixl_agent(self): + """ + Creates a NIXL agent with UCX backend if not already created. + """ + if self._nixl_agent is not None: + return self._nixl_agent + + from nixl._api import nixl_agent, nixl_agent_config + + agent_config = nixl_agent_config(backends=["UCX"]) + ctx = ray.get_runtime_context() + actor_id = ctx.get_actor_id() + if actor_id is None: + # If the actor id is None, it means the current process is a driver. + import uuid + + actor_id = f"RAY-DRIVER-{uuid.uuid4()}" + self._nixl_agent = nixl_agent(actor_id, agent_config) + return self._nixl_agent + def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: # TODO(dayshah): This is called on a .remote RDT call, so it's quite expensive. @@ -109,9 +117,10 @@ def extract_tensor_transport_metadata( return duplicate_meta if gpu_object: - reg_descs = self._nixl_agent.register_memory(gpu_object) - serialized_descs = self._nixl_agent.get_serialized_descs(reg_descs.trim()) - agent_meta = self._nixl_agent.get_agent_metadata() + nixl_agent = self.get_nixl_agent() + reg_descs = nixl_agent.register_memory(gpu_object) + serialized_descs = nixl_agent.get_serialized_descs(reg_descs.trim()) + agent_meta = nixl_agent.get_agent_metadata() # We assume all tensors in one GPU object have the same device type. device = gpu_object[0].device for t in gpu_object: @@ -170,7 +179,7 @@ def recv_multiple_tensors( remote_name = None xfer_handle = None try: - nixl_agent = self._nixl_agent + nixl_agent = self.get_nixl_agent() remote_descs = nixl_agent.deserialize_descs(nixl_serialized_descs) local_descs = nixl_agent.register_memory(tensors) remote_name = nixl_agent.add_remote_agent(remote_nixl_agent_meta) @@ -236,7 +245,7 @@ def garbage_collect( if count == 0: descs = tensor_transport_meta.nixl_reg_descs if descs is not None: - self._nixl_agent.deregister_memory(descs) + self.get_nixl_agent().deregister_memory(descs) def abort_transport( self, From 02d408c828ede41a5e5d103bbe2b3dbbaf39ca56 Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 22:22:33 +0000 Subject: [PATCH 09/18] actor has nixl fix Signed-off-by: dayshah --- .../gpu_object_manager/nixl_tensor_transport.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index cba65dd152c3..32c18c40026e 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -85,11 +85,7 @@ def __ray_actor_has_tensor_transport__( ) -> bool: # Check if nixl is installed try: - from ray.experimental.gpu_object_manager.util import ( - get_tensor_transport_manager, - ) - - get_tensor_transport_manager("nixl") + self.get_nixl_agent() return True except Exception: return False From ed3b576c674c17b518bda3551d5984588000a646 Mon Sep 17 00:00:00 2001 From: dayshah Date: Thu, 4 Dec 2025 22:30:03 +0000 Subject: [PATCH 10/18] fix nixl has actor Signed-off-by: dayshah --- .../gpu_object_manager/nixl_tensor_transport.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index 
32c18c40026e..05dd4eb6a182 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -85,7 +85,11 @@ def __ray_actor_has_tensor_transport__( ) -> bool: # Check if nixl is installed try: - self.get_nixl_agent() + from ray.experimental.gpu_object_manager.util import ( + get_tensor_transport_manager, + ) + + get_tensor_transport_manager("nixl").get_nixl_agent() return True except Exception: return False From 42e1f85e41699f8c760af1326d651cc80b680159 Mon Sep 17 00:00:00 2001 From: dayshah Date: Sat, 6 Dec 2025 02:33:08 +0000 Subject: [PATCH 11/18] [core][rdt] Rework passing transport through for bring your own transport Signed-off-by: dayshah --- python/ray/_private/custom_types.py | 4 +- python/ray/_private/worker.py | 59 +++++--------- python/ray/_raylet.pyx | 4 +- python/ray/actor.py | 45 ++++++----- .../ray/experimental/collective/collective.py | 44 ++++------- .../gpu_object_manager/gpu_object_manager.py | 77 ++++++++----------- .../gpu_object_manager/gpu_object_store.py | 18 ----- .../nixl_tensor_transport.py | 7 +- .../experimental/gpu_object_manager/util.py | 12 +-- .../gpu_objects/test_gpu_objects_gloo.py | 3 +- python/ray/util/collective/types.py | 9 +-- src/ray/protobuf/common.proto | 8 +- 12 files changed, 112 insertions(+), 178 deletions(-) diff --git a/python/ray/_private/custom_types.py b/python/ray/_private/custom_types.py index fc5e3a5622fb..76a29bb192bd 100644 --- a/python/ray/_private/custom_types.py +++ b/python/ray/_private/custom_types.py @@ -124,9 +124,7 @@ # See `common.proto` for more details. class TensorTransportEnum(Enum): OBJECT_STORE = TensorTransport.Value("OBJECT_STORE") - NCCL = TensorTransport.Value("NCCL") - GLOO = TensorTransport.Value("GLOO") - NIXL = TensorTransport.Value("NIXL") + DIRECT_TRANSPORT = TensorTransport.Value("DIRECT_TRANSPORT") @classmethod def from_str(cls, name: str) -> "TensorTransportEnum": diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 346e2e2a0e5e..b07ec0b9b928 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -801,7 +801,7 @@ def put_object( value: Any, owner_address: Optional[str] = None, _is_experimental_channel: bool = False, - _tensor_transport: str = "object_store", + _tensor_transport: str = "OBJECT_STORE", ): """Put value in the local object store. @@ -835,18 +835,16 @@ def put_object( "ray.ObjectRef in a list and call 'put' on it." ) tensors = None - tensor_transport: TensorTransportEnum = TensorTransportEnum.from_str( - _tensor_transport - ) + tensor_transport = _tensor_transport.upper() if tensor_transport not in [ - TensorTransportEnum.OBJECT_STORE, - TensorTransportEnum.NIXL, + "OBJECT_STORE", + "NIXL", ]: raise ValueError( "Currently, Ray Direct Transport only supports 'object_store' and 'nixl' for tensor transport in ray.put()." ) try: - if tensor_transport != TensorTransportEnum.OBJECT_STORE: + if tensor_transport != "OBJECT_STORE": ( serialized_value, tensors, @@ -867,19 +865,24 @@ def put_object( # object. Instead, clients will keep the object pinned. 
pin_object = not _is_experimental_channel + tensor_transport_enum = TensorTransportEnum.OBJECT_STORE + if tensor_transport != "OBJECT_STORE": + tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT + # This *must* be the first place that we construct this python # ObjectRef because an entry with 0 local references is created when # the object is Put() in the core worker, expecting that this python # reference will be created. If another reference is created and # removed before this one, it will corrupt the state in the # reference counter. + ret = self.core_worker.put_object( serialized_value, pin_object=pin_object, owner_address=owner_address, inline_small_object=True, _is_experimental_channel=_is_experimental_channel, - tensor_transport_val=tensor_transport.value, + tensor_transport_val=tensor_transport_enum.value, ) if tensors: self.gpu_object_manager.put_object(ret, tensor_transport, tensors) @@ -896,43 +899,26 @@ def deserialize_objects( self, serialized_objects, object_refs, - tensor_transport_hint: Optional[TensorTransportEnum] = None, + tensor_transport_hint: Optional[str] = None, ): gpu_objects: Dict[str, List["torch.Tensor"]] = {} for obj_ref, (_, _, tensor_transport) in zip(object_refs, serialized_objects): - # TODO: Here tensor_transport_hint is set by the user in ray.get(), tensor_transport is set - # in serialize_objects by ray.method(tensor_transport="xxx"), and obj_ref.tensor_transport() - # is set by ray.put(). We may clean up this logic in the future. if ( tensor_transport is None or tensor_transport == TensorTransportEnum.OBJECT_STORE - ) and ( - obj_ref is None - or obj_ref.tensor_transport() == TensorTransportEnum.OBJECT_STORE.value ): # The object is not a gpu object, so we cannot use other external transport to # fetch it. continue - # If the object is a gpu object, we can choose to use the object store or other external - # transport to fetch it. The `tensor_transport_hint` has the highest priority, then the - # tensor_transport in obj_ref.tensor_transport(), then the tensor_transport in serialize_objects, - # then the default value `OBJECT_STORE`. - chosen_tensor_transport = ( - tensor_transport_hint - or ( - TensorTransportEnum(obj_ref.tensor_transport()) if obj_ref else None - ) - or tensor_transport - or TensorTransportEnum.OBJECT_STORE - ) - object_id = obj_ref.hex() if object_id not in gpu_objects: # If using a non-object store transport, then tensors will be sent # out-of-band. Get them before deserializing the object store data. + # The user can choose OBJECT_STORE as the hint to fetch the RDT object + # through the object store. gpu_objects[object_id] = self.gpu_object_manager.get_gpu_object( - object_id, tensor_transport=chosen_tensor_transport + object_id, tensor_transport=tensor_transport_hint ) # Function actor manager or the import thread may call pickle.loads @@ -983,16 +969,6 @@ def get_objects( f"Attempting to call `get` on the value {object_ref}, " "which is not an ray.ObjectRef." ) - tensor_transport: TensorTransportEnum = ( - TensorTransportEnum.from_str(_tensor_transport) - if _tensor_transport is not None - else None - ) - assert tensor_transport in [ - TensorTransportEnum.OBJECT_STORE, - TensorTransportEnum.NIXL, - None, - ], "Currently, RDT only supports 'object_store' and 'nixl' for tensor transport in ray.get()." 
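After this hunk, `ray.put()` and `ray.get()` take plain transport strings. A hedged usage sketch mirroring the validation added here and the checks in test_gpu_objects_gloo.py; it assumes the nixl package is installed on the driver, otherwise only "object_store" applies:

import ray
import torch

ray.init()

# ray.put() accepts a one-sided transport; the name is case-insensitive and is
# normalized to upper case ("nixl" -> "NIXL").
ref = ray.put(torch.zeros(8), _tensor_transport="nixl")

# With no hint, ray.get() fetches over the object's own transport backend; an
# explicit hint of "object_store" routes the fetch through the object store instead.
tensor = ray.get(ref)
assert torch.equal(tensor, torch.zeros(8))
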
timeout_ms = ( int(timeout * 1000) if timeout is not None and timeout != -1 else -1 ) @@ -1004,7 +980,7 @@ def get_objects( ) debugger_breakpoint = b"" - for data, metadata, _ in serialized_objects: + for _, metadata, _ in serialized_objects: if metadata: metadata_fields = metadata.split(b",") if len(metadata_fields) >= 2 and metadata_fields[1].startswith( @@ -1016,6 +992,9 @@ def get_objects( if skip_deserialization: return None, debugger_breakpoint + tensor_transport = ( + _tensor_transport.upper() if _tensor_transport is not None else None + ) values = self.deserialize_objects( serialized_objects, object_refs, tensor_transport_hint=tensor_transport ) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 631eca7631b4..fc2a816a4194 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -2874,7 +2874,7 @@ cdef class CoreWorker: c_object_ids, timeout_ms, results) check_status(op_status) - return RayObjectsToSerializedRayObjects(results) + return RayObjectsToSerializedRayObjects(results, object_refs) def get_if_local(self, object_refs): """Get objects from local plasma store directly @@ -2886,7 +2886,7 @@ cdef class CoreWorker: check_status( CCoreWorkerProcess.GetCoreWorker().GetIfLocal( c_object_ids, &results)) - return RayObjectsToSerializedRayObjects(results) + return RayObjectsToSerializedRayObjects(results, object_refs) def object_exists(self, ObjectRef object_ref, memory_store_only=False): cdef: diff --git a/python/ray/actor.py b/python/ray/actor.py index 851330bf084d..2d0410cf0c0b 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -473,9 +473,7 @@ def annotate_method(method: Callable[_P, _Ret]): if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None: method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: - method.__ray_tensor_transport__ = TensorTransportEnum.from_str( - kwargs["tensor_transport"] - ) + method.__ray_tensor_transport__ = kwargs["tensor_transport"].upper() return method # Check if decorator is called without parentheses (args[0] would be the function) @@ -521,7 +519,7 @@ def __init__( enable_task_events: bool, decorator: Optional[Any] = None, signature: Optional[List[inspect.Parameter]] = None, - tensor_transport: Optional[TensorTransportEnum] = None, + tensor_transport: Optional[str] = None, ): """Initialize an _ActorMethodMetadata. @@ -599,7 +597,7 @@ def __init__( enable_task_events: bool, decorator=None, signature: Optional[List[inspect.Parameter]] = None, - tensor_transport: Optional[TensorTransportEnum] = None, + tensor_transport: Optional[str] = None, ): """Initialize an ActorMethod. @@ -649,10 +647,10 @@ def __init__( # and return the resulting ObjectRefs. self._decorator = decorator - # If the task call doesn't specify a tensor transport option, use `_tensor_transport` + # If the task call doesn't specify a tensor transport option, use `OBJECT_STORE` # as the default transport for this actor method. 
if tensor_transport is None: - tensor_transport = TensorTransportEnum.OBJECT_STORE + tensor_transport = "OBJECT_STORE" self._tensor_transport = tensor_transport def __call__(self, *args, **kwargs): @@ -695,7 +693,7 @@ def options(self, **options): tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: - options["tensor_transport"] = TensorTransportEnum.from_str(tensor_transport) + options["tensor_transport"] = tensor_transport.upper() class FuncWrapper: def remote(self, *args, **kwargs): @@ -800,7 +798,7 @@ def _remote( concurrency_group=None, _generator_backpressure_num_objects=None, enable_task_events=None, - tensor_transport: Optional[TensorTransportEnum] = None, + tensor_transport: Optional[str] = None, ): if num_returns is None: num_returns = self._num_returns @@ -820,15 +818,15 @@ def _remote( if tensor_transport is None: tensor_transport = self._tensor_transport - if tensor_transport != TensorTransportEnum.OBJECT_STORE: + if tensor_transport != "OBJECT_STORE": if num_returns != 1: raise ValueError( - f"Currently, methods with tensor_transport={tensor_transport.name} only support 1 return value. " + f"Currently, methods with tensor_transport={tensor_transport} only support 1 return value. " "Please make sure the actor method is decorated with `@ray.method(num_returns=1)` (the default)." ) if not self._actor._ray_enable_tensor_transport: raise ValueError( - f'Currently, methods with .options(tensor_transport="{tensor_transport.name}") are not supported when enable_tensor_transport=False. ' + f'Currently, methods with .options(tensor_transport="{tensor_transport}") are not supported when enable_tensor_transport=False. ' "Please set @ray.remote(enable_tensor_transport=True) on the actor class definition." ) gpu_object_manager = ray._private.worker.global_worker.gpu_object_manager @@ -836,7 +834,7 @@ def _remote( self._actor, tensor_transport ): raise ValueError( - f'{self._actor} does not have tensor transport {tensor_transport.name} available. If using a collective-based transport ("nccl" or "gloo"), please create a communicator with ' + f'{self._actor} does not have tensor transport {tensor_transport} available. If using a collective-based transport ("nccl" or "gloo"), please create a communicator with ' "`ray.experimental.collective.create_collective_group` " "before calling actor tasks with non-default tensor_transport." ) @@ -876,7 +874,7 @@ def invocation(args, kwargs): invocation = self._decorator(invocation) object_refs = invocation(args, kwargs) - if tensor_transport != TensorTransportEnum.OBJECT_STORE: + if tensor_transport != "OBJECT_STORE": # Currently, we only support transfer tensor out-of-band when # num_returns is 1. assert isinstance(object_refs, ObjectRef) @@ -979,14 +977,12 @@ def create( self.enable_task_events = {} self.generator_backpressure_num_objects = {} self.concurrency_group_for_methods = {} - self.method_name_to_tensor_transport: Dict[str, TensorTransportEnum] = {} + self.method_name_to_tensor_transport: Dict[str, str] = {} # Check whether any actor methods specify a non-default tensor transport. 
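With the per-method setting now stored as an upper-cased string, a typical caller looks like the sketch below; the `Worker` class and tensor values are illustrative, while the decorator, `.options()`, and `create_collective_group` calls mirror the code, tests, and error messages in this series:

import ray
import torch
from ray.experimental.collective import create_collective_group


@ray.remote(enable_tensor_transport=True)
class Worker:
    @ray.method(tensor_transport="gloo")  # normalized to "GLOO" internally
    def produce(self) -> torch.Tensor:
        return torch.ones(4)

    def consume(self, t: torch.Tensor) -> float:
        return t.sum().item()


ray.init()
sender, receiver = Worker.remote(), Worker.remote()
# Collective transports still need a communicator before the first RDT call;
# "torch_gloo" remains accepted as an alias for GLOO after this series.
create_collective_group([sender, receiver], backend="gloo")

ref = sender.produce.remote()                  # tensor stays on the sender
print(ray.get(receiver.consume.remote(ref)))   # moved sender -> receiver out of band
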
self.has_tensor_transport_methods = any( - getattr( - method, "__ray_tensor_transport__", TensorTransportEnum.OBJECT_STORE - ) - != TensorTransportEnum.OBJECT_STORE + getattr(method, "__ray_tensor_transport__", "OBJECT_STORE") + != "OBJECT_STORE" for _, method in actor_methods ) @@ -1941,7 +1937,7 @@ def __init__( method_generator_backpressure_num_objects: Dict[str, int], method_enable_task_events: Dict[str, bool], enable_tensor_transport: bool, - method_name_to_tensor_transport: Dict[str, TensorTransportEnum], + method_name_to_tensor_transport: Dict[str, str], actor_method_cpus: int, actor_creation_function_descriptor, cluster_and_job, @@ -1968,7 +1964,7 @@ def __init__( this actor. If True, then methods can be called with .options(tensor_transport=...) to specify a non-default tensor transport. - method_name_to_tensor_transport: Dictionary mapping method names to their tensor transport settings. + method_name_to_tensor_transport: Dictionary mapping method names to their tensor transport type. actor_method_cpus: The number of CPUs required by actor methods. actor_creation_function_descriptor: The function descriptor for actor creation. cluster_and_job: The cluster and job information. @@ -2079,7 +2075,7 @@ def _actor_method_call( concurrency_group_name: Optional[str] = None, generator_backpressure_num_objects: Optional[int] = None, enable_task_events: Optional[bool] = None, - tensor_transport: Optional[TensorTransportEnum] = None, + tensor_transport: Optional[str] = None, ): """Method execution stub for an actor handle. @@ -2157,6 +2153,9 @@ def _actor_method_call( if generator_backpressure_num_objects is None: generator_backpressure_num_objects = -1 + tensor_transport_enum = TensorTransportEnum.OBJECT_STORE + if tensor_transport is not None and tensor_transport != "OBJECT_STORE": + tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, self._ray_actor_id, @@ -2171,7 +2170,7 @@ def _actor_method_call( concurrency_group_name if concurrency_group_name is not None else b"", generator_backpressure_num_objects, enable_task_events, - tensor_transport.value, + tensor_transport_enum.value, ) if num_returns == STREAMING_GENERATOR_RETURN: diff --git a/python/ray/experimental/collective/collective.py b/python/ray/experimental/collective/collective.py index 1872ace6fd8e..7a3298872942 100644 --- a/python/ray/experimental/collective/collective.py +++ b/python/ray/experimental/collective/collective.py @@ -43,7 +43,7 @@ def remove_remote_communicator(self, name: str): def get_collective_groups( self, actors: Optional[List[ray.actor.ActorHandle]] = None, - backend: Optional[str] = None, + backend: Optional[Backend] = None, ): """ Get the collective groups that the given actors are a subset of. Filter by @@ -62,28 +62,6 @@ def get_collective_groups( return collectives -def _do_init_collective_group( - self, - world_size: int, - rank: int, - backend: str = Backend.NCCL, - name: str = "default", -): - """Helper method that runs as a task on a remote actor to create a - collective group. - """ - ray.util.collective.init_collective_group( - world_size, rank, backend, group_name=name - ) - - -def _do_destroy_collective_group(self, name): - """Helper method that runs as a task on a remote actor to destroy a - collective group. 
- """ - ray.util.collective.destroy_collective_group(name) - - @PublicAPI(stability="alpha") def get_collective_groups( actors: List[ray.actor.ActorHandle], backend: Optional[str] = None @@ -102,7 +80,7 @@ def get_collective_groups( A list of communicator handles that the actors are a subset of. """ manager = RemoteCommunicatorManager.get() - return manager.get_collective_groups(actors, backend) + return manager.get_collective_groups(actors, Backend(backend)) @PublicAPI(stability="alpha") @@ -163,10 +141,16 @@ def create_collective_group( metadata_key = get_master_address_metadata_key(name) internal_kv._internal_kv_put(metadata_key, f"{master_addr}:{master_port}") + def _do_init_collective_group(self, rank: int): + ray.util.collective.init_collective_group( + world_size, rank, backend, group_name=name + ) + try: init_tasks = [ actor.__ray_call__.remote( - _do_init_collective_group, world_size, rank, backend, name + _do_init_collective_group, + rank, ) for rank, actor in enumerate(actors) ] @@ -178,8 +162,8 @@ def create_collective_group( internal_kv._internal_kv_del(metadata_key) # Group was successfully created. - # Register GLOO groups under TORCH_GLOO since GLOO uses torch.distributed. - registration_backend = Backend.TORCH_GLOO if backend == Backend.GLOO else backend + # Register TORCH_GLOO groups under GLOO since GLOO and TORCH_GLOO are the same now, both using torch.distributed. + registration_backend = Backend.GLOO if backend == Backend.TORCH_GLOO else backend comm = CommunicatorHandle(actors, name, registration_backend) manager.add_remote_communicator(comm) return comm @@ -206,9 +190,13 @@ def destroy_collective_group(group_or_name: Union[CommunicatorHandle, str]): manager = RemoteCommunicatorManager.get() group = manager.remove_remote_communicator(name) if group is not None: + + def _do_destroy_collective_group(self): + ray.util.collective.destroy_collective_group(name) + destroy_tasks = [ actor.__ray_call__.options(concurrency_group="_ray_system").remote( - _do_destroy_collective_group, name + _do_destroy_collective_group ) for actor in group.actors ] diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index 1876ebeba311..e48d7dcd678f 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -7,7 +7,6 @@ import ray from ray._private import ray_constants -from ray._private.custom_types import TensorTransportEnum from ray._raylet import ObjectRef from ray.util.annotations import PublicAPI @@ -27,7 +26,6 @@ # GPUObjectMeta is a named tuple containing the source actor, tensor transport # backend, tensor metadata, and other information that needs to be recorded. # - The tensor transport backend is the backend used to transport the tensors. -# Currently, the supported backends are "nccl" and "torch_gloo". # - The tensor metadata is a list of tuples, each containing the shape and dtype # of a tensor in the GPU object store. class GPUObjectMeta(NamedTuple): @@ -284,7 +282,7 @@ def add_gpu_object_ref( self, obj_ref: ObjectRef, src_actor: "ray.actor.ActorHandle", - tensor_transport: TensorTransportEnum, + tensor_transport: str, tensor_transport_meta: Optional["TensorTransportMetadata"] = None, ): """Add a GPU object reference to the GPU object manager. 
This should be @@ -299,12 +297,8 @@ def add_gpu_object_ref( """ from ray.experimental.gpu_object_manager.gpu_object_store import ( __ray_get_tensor_transport_metadata__, - _tensor_transport_to_collective_backend, ) - tensor_transport_backend = _tensor_transport_to_collective_backend( - tensor_transport - ) obj_id = obj_ref.hex() if not tensor_transport_meta: # Submit a Ray actor task to the source actor to get the tensor metadata. @@ -315,14 +309,12 @@ def add_gpu_object_ref( # executing on the main thread blocking this task. tensor_meta = src_actor.__ray_call__.options( concurrency_group="_ray_system" - ).remote( - __ray_get_tensor_transport_metadata__, obj_id, tensor_transport_backend - ) + ).remote(__ray_get_tensor_transport_metadata__, obj_id, tensor_transport) else: tensor_meta = tensor_transport_meta self.managed_gpu_object_metadata[obj_id] = GPUObjectMeta( src_actor=src_actor, - tensor_transport_backend=tensor_transport_backend, + tensor_transport_backend=tensor_transport, tensor_transport_meta=tensor_meta, sent_dest_actors=set(), sent_to_src_actor_and_others_warned=False, @@ -335,7 +327,7 @@ def _get_gpu_object_metadata(self, obj_ref: ObjectRef) -> GPUObjectMeta: def _fetch_object( self, obj_id: str, - tensor_transport: TensorTransportEnum = TensorTransportEnum.OBJECT_STORE, + tensor_transport: Optional[str], ): """ Fetches the GPU object from the source actor's GPU object store via the object store @@ -348,6 +340,9 @@ def _fetch_object( Args: obj_id: The object ID of the GPU object. tensor_transport: The tensor transport to use to fetch the GPU object. + This should either be object store or the actual tensor transport for the RDT object. + If this is None, the tensor transport backend of the RDT object will be used. + Note that NIXL is the only tensor transport that is supported for this right now. Returns: None @@ -359,23 +354,35 @@ def _fetch_object( get_tensor_transport_manager, ) - if tensor_transport not in [ - TensorTransportEnum.OBJECT_STORE, - TensorTransportEnum.NIXL, - ]: - raise ValueError( - f"Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, got {tensor_transport}, please specify the correct tensor transport in ray.get()." - ) - if self.gpu_object_store.has_object(obj_id): return + gpu_object_meta = self.managed_gpu_object_metadata[obj_id] - src_actor = gpu_object_meta.src_actor tensor_transport_backend = gpu_object_meta.tensor_transport_backend + if tensor_transport is None: + tensor_transport = tensor_transport_backend + + if tensor_transport not in ["OBJECT_STORE", "NIXL"]: + raise ValueError( + "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, " + f"got {tensor_transport}, please specify the correct tensor transport in ray.get()." + ) + + if ( + tensor_transport != "OBJECT_STORE" + and tensor_transport != tensor_transport_backend + ): + raise ValueError( + f"Got {tensor_transport} and object had tensor transport backend {tensor_transport_backend}, " + "please specify the correct tensor transport in ray.get()." 
+ ) + + src_actor = gpu_object_meta.src_actor tensor_transport_manager = get_tensor_transport_manager( tensor_transport_backend ) - if tensor_transport == TensorTransportEnum.OBJECT_STORE: + + if tensor_transport == "OBJECT_STORE": tensors = ray.get( src_actor.__ray_call__.options(concurrency_group="_ray_system").remote( __ray_fetch_gpu_object__, obj_id @@ -552,7 +559,7 @@ def trigger_out_of_band_tensor_transfer( def get_gpu_object( self, object_id: str, - tensor_transport: TensorTransportEnum = TensorTransportEnum.OBJECT_STORE, + tensor_transport: Optional[str], ) -> List["torch.Tensor"]: """ Get the GPU object for a given object ID. @@ -610,7 +617,7 @@ def free_object_primary_copy(self, object_id: str): ) def actor_has_tensor_transport( - self, actor: "ray.actor.ActorHandle", tensor_transport: TensorTransportEnum + self, actor: "ray.actor.ActorHandle", tensor_transport: str ): """ Check if the actor has a communicator for the given tensor transport backend. @@ -622,27 +629,17 @@ def actor_has_tensor_transport( Returns: True if the actor has a communicator for the given tensor transport backend, False otherwise. """ - # Import get_collective_groups here to avoid dependency on - # collective libraries for default Ray installation. - from ray.experimental.gpu_object_manager.gpu_object_store import ( - _tensor_transport_to_collective_backend, - ) from ray.experimental.gpu_object_manager.util import ( get_tensor_transport_manager, ) - tensor_transport_backend = _tensor_transport_to_collective_backend( - tensor_transport - ) - tensor_transport_manager = get_tensor_transport_manager( - tensor_transport_backend - ) + tensor_transport_manager = get_tensor_transport_manager(tensor_transport) return tensor_transport_manager.actor_has_tensor_transport(actor) def put_object( self, obj_ref: ObjectRef, - tensor_transport: TensorTransportEnum, + tensor_transport: str, tensors: List["torch.Tensor"], ): """ @@ -654,17 +651,11 @@ def put_object( tensors: The tensors to put into the GPU object manager. """ - from ray.experimental.gpu_object_manager.gpu_object_store import ( - _tensor_transport_to_collective_backend, - ) from ray.experimental.gpu_object_manager.util import ( get_tensor_transport_manager, ) - tensor_transport_backend = _tensor_transport_to_collective_backend( - tensor_transport - ) - transport_manager = get_tensor_transport_manager(tensor_transport_backend) + transport_manager = get_tensor_transport_manager(tensor_transport) tensor_transport_meta = transport_manager.extract_tensor_transport_metadata( obj_ref.hex(), tensors ) diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_store.py b/python/ray/experimental/gpu_object_manager/gpu_object_store.py index e7db5cf3606f..44d5dcf59e86 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_store.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_store.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Set, Union import ray -from ray._private.custom_types import TensorTransportEnum from ray._raylet import ObjectRef from ray.experimental.gpu_object_manager.types import ( CommunicatorMetadata, @@ -23,23 +22,6 @@ "Please install torch with 'pip install torch' to use this feature." 
) -TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND = { - TensorTransportEnum.NCCL: "nccl", - TensorTransportEnum.GLOO: "torch_gloo", - TensorTransportEnum.NIXL: "nixl", -} - - -def _tensor_transport_to_collective_backend( - tensor_transport: TensorTransportEnum, -) -> str: - try: - return TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND[tensor_transport] - except KeyError: - raise ValueError( - f"Invalid tensor transport {tensor_transport.name}, must be one of {list(TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND.keys())}." - ) - def __ray_get_tensor_transport_metadata__( self, obj_id: str, backend: str diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index 05dd4eb6a182..5bfd3208b5ac 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -47,7 +47,7 @@ def __init__(self, tensor_transport_backend: str): @property def tensor_transport_backend(self) -> str: - return "nixl" + return "NIXL" @staticmethod def is_one_sided() -> bool: @@ -79,6 +79,7 @@ def get_nixl_agent(self): def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: # TODO(dayshah): This is called on a .remote RDT call, so it's quite expensive. + print("checking nixl tensor transport") def __ray_actor_has_tensor_transport__( self: "ray.actor.ActorHandle", @@ -89,9 +90,11 @@ def __ray_actor_has_tensor_transport__( get_tensor_transport_manager, ) - get_tensor_transport_manager("nixl").get_nixl_agent() + get_tensor_transport_manager("NIXL").get_nixl_agent() + print("nixl found") return True except Exception: + print("nixl not found") return False return ray.get( diff --git a/python/ray/experimental/gpu_object_manager/util.py b/python/ray/experimental/gpu_object_manager/util.py index 1ae94893c3b3..08381a216897 100644 --- a/python/ray/experimental/gpu_object_manager/util.py +++ b/python/ray/experimental/gpu_object_manager/util.py @@ -17,15 +17,15 @@ # Class definitions for transport managers transport_manager_classes: dict[str, TensorTransportManager] = { - "nixl": NixlTensorTransport, - "torch_gloo": CollectiveTensorTransport, - "nccl": CollectiveTensorTransport, + "NIXL": NixlTensorTransport, + "GLOO": CollectiveTensorTransport, + "NCCL": CollectiveTensorTransport, } transport_devices = { - "nixl": ["cuda", "cpu"], - "torch_gloo": ["cpu"], - "nccl": ["cuda"], + "NIXL": ["cuda", "cpu"], + "GLOO": ["cpu"], + "NCCL": ["cuda"], } diff --git a/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py b/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py index efd5dac0f39e..25ae8a25ac8b 100644 --- a/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py +++ b/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py @@ -10,7 +10,6 @@ import ray from ray._common.test_utils import SignalActor, wait_for_condition -from ray._private.custom_types import TensorTransportEnum from ray.experimental.collective import create_collective_group # tensordict is not supported on macos ci, so we skip the tests @@ -519,7 +518,7 @@ def test_trigger_out_of_band_tensor_transfer(ray_start_regular): assert torch.equal(ret_val_src[0], tensor) gpu_object_manager = ray._private.worker.global_worker.gpu_object_manager - gpu_object_manager.add_gpu_object_ref(gpu_ref, src_actor, TensorTransportEnum.GLOO) + gpu_object_manager.add_gpu_object_ref(gpu_ref, src_actor, "GLOO") # Trigger out-of-band tensor transfer from src_actor to dst_actor. 
task_args = (gpu_ref,) diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index bf357bbd24c6..c4630f64b66b 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -34,14 +34,13 @@ def torch_available(): class Backend(object): """A class to represent different backends.""" - NCCL = "nccl" - MPI = "mpi" + NCCL = "NCCL" + MPI = "MPI" # `pygloo` is deprecated. Use gloo through torch.distributed for both # `GLOO` and `TORCH_GLOO`. - GLOO = "gloo" + GLOO = "GLOO" # Use gloo through torch.distributed. - TORCH_GLOO = "torch_gloo" - NIXL = "nixl" + TORCH_GLOO = "TORCH_GLOO" UNRECOGNIZED = "unrecognized" def __new__(cls, name: str): diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 82d6b66f1c92..1bb541e7ec4c 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -732,12 +732,8 @@ message ObjectReferenceCount { enum TensorTransport { // Use the default object store for tensor transport. OBJECT_STORE = 0; - // Use NCCL for tensor transport. - NCCL = 1; - // Use GLOO for tensor transport. - GLOO = 2; - // Use NIXL for tensor transport. - NIXL = 3; + // Use Ray Direct Transport to transfer tensors directly between workers. + DIRECT_TRANSPORT = 1; } // Argument in the task. From 6034d578b4d711e3c67a1b87fe9688381ba37e95 Mon Sep 17 00:00:00 2001 From: dayshah Date: Sat, 6 Dec 2025 03:40:45 +0000 Subject: [PATCH 12/18] remove prints Signed-off-by: dayshah --- .../experimental/gpu_object_manager/nixl_tensor_transport.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index 5bfd3208b5ac..69e40afaf25f 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -91,10 +91,8 @@ def __ray_actor_has_tensor_transport__( ) get_tensor_transport_manager("NIXL").get_nixl_agent() - print("nixl found") return True except Exception: - print("nixl not found") return False return ray.get( From d0286f7d009e4be1a4e262293f09356de0414b08 Mon Sep 17 00:00:00 2001 From: dayshah Date: Sat, 6 Dec 2025 03:41:06 +0000 Subject: [PATCH 13/18] remove prints Signed-off-by: dayshah --- .../experimental/gpu_object_manager/nixl_tensor_transport.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py index 69e40afaf25f..52aad44c61c2 100644 --- a/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py +++ b/python/ray/experimental/gpu_object_manager/nixl_tensor_transport.py @@ -79,8 +79,6 @@ def get_nixl_agent(self): def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool: # TODO(dayshah): This is called on a .remote RDT call, so it's quite expensive. 
- print("checking nixl tensor transport") - def __ray_actor_has_tensor_transport__( self: "ray.actor.ActorHandle", ) -> bool: From fb6e95c57f3e56f82848d5c82d0d644c0eb3fd61 Mon Sep 17 00:00:00 2001 From: dayshah Date: Sun, 7 Dec 2025 21:22:10 +0000 Subject: [PATCH 14/18] fix tests Signed-off-by: dayshah --- .../ray-core/doc_code/direct_transport_gloo.py | 2 +- python/ray/_private/worker.py | 16 ++++++++++++++++ python/ray/actor.py | 16 ++++++++++++++++ python/ray/experimental/collective/collective.py | 7 +++---- python/ray/includes/common.pxd | 4 +--- .../tests/gpu_objects/test_gpu_objects_gloo.py | 11 +++++++++++ python/ray/util/collective/collective.py | 7 +------ .../torch_gloo_collective_group.py | 2 +- python/ray/util/collective/types.py | 16 +++++++--------- .../tests/dependency_resolver_test.cc | 2 +- src/ray/core_worker/tests/task_manager_test.cc | 4 ++-- 11 files changed, 60 insertions(+), 27 deletions(-) diff --git a/doc/source/ray-core/doc_code/direct_transport_gloo.py b/doc/source/ray-core/doc_code/direct_transport_gloo.py index 866b827db960..398cb6add35c 100644 --- a/doc/source/ray-core/doc_code/direct_transport_gloo.py +++ b/doc/source/ray-core/doc_code/direct_transport_gloo.py @@ -140,7 +140,7 @@ def sum(self, tensor: torch.Tensor): ray.get(tensor) assert ( - "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, got TensorTransportEnum.GLOO, please specify the correct tensor transport in ray.get()" + "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, got GLOO, please specify the correct tensor transport in ray.get()." in str(e.value) ) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index b07ec0b9b928..aec248c87499 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -836,6 +836,14 @@ def put_object( ) tensors = None tensor_transport = _tensor_transport.upper() + if tensor_transport != "OBJECT_STORE": + from ray.experimental.gpu_object_manager.util import ( + transport_manager_classes, + ) + + if tensor_transport not in transport_manager_classes: + raise ValueError(f"Invalid tensor transport: {tensor_transport}") + if tensor_transport not in [ "OBJECT_STORE", "NIXL", @@ -995,6 +1003,14 @@ def get_objects( tensor_transport = ( _tensor_transport.upper() if _tensor_transport is not None else None ) + if tensor_transport is not None and tensor_transport != "OBJECT_STORE": + from ray.experimental.gpu_object_manager.util import ( + transport_manager_classes, + ) + + if tensor_transport not in transport_manager_classes: + raise ValueError(f"Invalid tensor transport: {tensor_transport}") + values = self.deserialize_objects( serialized_objects, object_refs, tensor_transport_hint=tensor_transport ) diff --git a/python/ray/actor.py b/python/ray/actor.py index 2d0410cf0c0b..26759fed908d 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -474,6 +474,15 @@ def annotate_method(method: Callable[_P, _Ret]): method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: method.__ray_tensor_transport__ = kwargs["tensor_transport"].upper() + tensor_transport = method.__ray_tensor_transport__ + if tensor_transport != "OBJECT_STORE": + from ray.experimental.gpu_object_manager.util import ( + transport_manager_classes, + ) + + if tensor_transport not in transport_manager_classes: + raise ValueError(f"Invalid tensor transport: {tensor_transport}") + return method # Check if decorator is called without parentheses (args[0] would be the function) @@ -694,6 
+703,13 @@ def options(self, **options): tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: options["tensor_transport"] = tensor_transport.upper() + if tensor_transport != "OBJECT_STORE": + from ray.experimental.gpu_object_manager.util import ( + transport_manager_classes, + ) + + if tensor_transport not in transport_manager_classes: + raise ValueError(f"Invalid tensor transport: {tensor_transport}") class FuncWrapper: def remote(self, *args, **kwargs): diff --git a/python/ray/experimental/collective/collective.py b/python/ray/experimental/collective/collective.py index 7a3298872942..d66542d44c7e 100644 --- a/python/ray/experimental/collective/collective.py +++ b/python/ray/experimental/collective/collective.py @@ -80,7 +80,8 @@ def get_collective_groups( A list of communicator handles that the actors are a subset of. """ manager = RemoteCommunicatorManager.get() - return manager.get_collective_groups(actors, Backend(backend)) + backend = Backend(backend) if backend is not None else None + return manager.get_collective_groups(actors, backend) @PublicAPI(stability="alpha") @@ -162,9 +163,7 @@ def _do_init_collective_group(self, rank: int): internal_kv._internal_kv_del(metadata_key) # Group was successfully created. - # Register TORCH_GLOO groups under GLOO since GLOO and TORCH_GLOO are the same now, both using torch.distributed. - registration_backend = Backend.GLOO if backend == Backend.TORCH_GLOO else backend - comm = CommunicatorHandle(actors, name, registration_backend) + comm = CommunicatorHandle(actors, name, backend) manager.add_remote_communicator(comm) return comm diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 1a5c460b950c..e290b7c3b050 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -276,9 +276,7 @@ cdef extern from "src/ray/protobuf/common.pb.h" nogil: cdef extern from "src/ray/protobuf/common.pb.h" nogil: cdef CTensorTransport TENSOR_TRANSPORT_OBJECT_STORE "ray::rpc::TensorTransport::OBJECT_STORE" - cdef CTensorTransport TENSOR_TRANSPORT_NCCL "ray::rpc::TensorTransport::NCCL" - cdef CTensorTransport TENSOR_TRANSPORT_GLOO "ray::rpc::TensorTransport::GLOO" - cdef CTensorTransport TENSOR_TRANSPORT_NIXL "ray::rpc::TensorTransport::NIXL" + cdef CTensorTransport TENSOR_TRANSPORT_DIRECT_TRANSPORT "ray::rpc::TensorTransport::DIRECT_TRANSPORT" cdef extern from "src/ray/protobuf/common.pb.h" nogil: cdef CPlacementStrategy PLACEMENT_STRATEGY_PACK \ diff --git a/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py b/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py index 25ae8a25ac8b..28bf9d72efb1 100644 --- a/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py +++ b/python/ray/tests/gpu_objects/test_gpu_objects_gloo.py @@ -568,6 +568,17 @@ class InvalidActor: def echo(self, data): return data + actor = GPUTestActor.remote() + with pytest.raises(ValueError, match="Invalid tensor transport"): + actor.double.options(tensor_transport="invalid").remote(torch.randn((1,))) + + with pytest.raises(ValueError, match="Invalid tensor transport"): + ray.put(torch.randn((1,)), _tensor_transport="invalid") + + valid_ref = actor.double.remote(torch.randn((1,))) + with pytest.raises(ValueError, match="Invalid tensor transport"): + ray.get(valid_ref, _tensor_transport="invalid") + @pytest.mark.skipif( not support_tensordict, diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 1476512a2f1b..8803da0219eb 100644 --- 
a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -88,9 +88,7 @@ def create_collective_group( metadata as well. """ backend = types.Backend(backend) - if backend == types.Backend.MPI: - raise RuntimeError("Ray does not support MPI.") - elif backend == types.Backend.GLOO or backend == types.Backend.TORCH_GLOO: + if backend == types.Backend.GLOO: # Rendezvous: ensure a MASTER_ADDR:MASTER_PORT is published in internal_kv. metadata_key = _get_master_addr_key(group_name) if rank == 0: @@ -816,9 +814,6 @@ def _check_backend_availability(backend: types.Backend): elif backend == types.Backend.NCCL: if not nccl_available(): raise RuntimeError("NCCL is not available.") - elif backend == types.Backend.TORCH_GLOO: - if not torch_distributed_available(): - raise RuntimeError("torch.distributed is not available.") def _check_inside_actor(): diff --git a/python/ray/util/collective/collective_group/torch_gloo_collective_group.py b/python/ray/util/collective/collective_group/torch_gloo_collective_group.py index d2314c5ea54a..cf06728739c3 100644 --- a/python/ray/util/collective/collective_group/torch_gloo_collective_group.py +++ b/python/ray/util/collective/collective_group/torch_gloo_collective_group.py @@ -106,7 +106,7 @@ def destroy_group(self): @classmethod def backend(cls): """The backend of this collective group.""" - return Backend.TORCH_GLOO + return Backend.GLOO def _check_tensor_input(self, tensor: List["torch.Tensor"]) -> "torch.Tensor": """ray.util.collective wraps tensor arguments in a list. diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index c4630f64b66b..23d43cdae005 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -35,22 +35,20 @@ class Backend(object): """A class to represent different backends.""" NCCL = "NCCL" - MPI = "MPI" - # `pygloo` is deprecated. Use gloo through torch.distributed for both - # `GLOO` and `TORCH_GLOO`. GLOO = "GLOO" - # Use gloo through torch.distributed. - TORCH_GLOO = "TORCH_GLOO" UNRECOGNIZED = "unrecognized" def __new__(cls, name: str): - backend = getattr(Backend, name.upper(), Backend.UNRECOGNIZED) + upper_name = name.upper() + backend = getattr(Backend, upper_name, Backend.UNRECOGNIZED) if backend == Backend.UNRECOGNIZED: + if upper_name == "TORCH_GLOO": + return Backend.GLOO raise ValueError( - "Unrecognized backend: '{}'. Only NCCL is supported".format(name) + "Unrecognized backend: '{}'. 
Only NCCL and GLOO are supported".format( + name + ) ) - if backend == Backend.MPI: - raise RuntimeError("Ray does not support MPI backend.") return backend diff --git a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc index 28c87febf837..83f5c691dfa0 100644 --- a/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc +++ b/src/ray/core_worker/task_submission/tests/dependency_resolver_test.cc @@ -465,7 +465,7 @@ TEST(LocalDependencyResolverTest, TestMixedTensorTransport) { LocalDependencyResolver resolver( *store, *task_manager, actor_creator, [&](const ObjectID &object_id) { if (object_id == obj1) { - return rpc::TensorTransport::NCCL; + return rpc::TensorTransport::DIRECT_TRANSPORT; } return rpc::TensorTransport::OBJECT_STORE; }); diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc index ec4cc29f3d7a..dfcf9c42ef29 100644 --- a/src/ray/core_worker/tests/task_manager_test.cc +++ b/src/ray/core_worker/tests/task_manager_test.cc @@ -64,7 +64,7 @@ TaskSpecification CreateTaskHelper(uint64_t num_returns, if (enable_tensor_transport) { // Currently, only actors support transferring tensors out-of-band. task.GetMutableMessage().set_type(TaskType::ACTOR_TASK); - tensor_transport = rpc::TensorTransport::NCCL; + tensor_transport = rpc::TensorTransport::DIRECT_TRANSPORT; } task.GetMutableMessage().set_tensor_transport(tensor_transport); @@ -2807,7 +2807,7 @@ TEST_F(TaskManagerTest, TestGPUObjectTaskSuccess) { ObjectID gpu_obj_ref = ObjectID::FromRandom(); auto *arg = spec.GetMutableMessage().add_args(); arg->set_is_inlined(false); - arg->set_tensor_transport(rpc::TensorTransport::NCCL); + arg->set_tensor_transport(rpc::TensorTransport::DIRECT_TRANSPORT); arg->mutable_object_ref()->set_object_id(gpu_obj_ref.Binary()); // `gpu_obj_ref` should have a local reference when the sender actor From 21da299571322e39360097551088a466dacbfcc4 Mon Sep 17 00:00:00 2001 From: dayshah Date: Sun, 7 Dec 2025 22:21:14 +0000 Subject: [PATCH 15/18] fix upper Signed-off-by: dayshah --- python/ray/actor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/ray/actor.py b/python/ray/actor.py index 26759fed908d..c71a37dbc548 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -473,8 +473,7 @@ def annotate_method(method: Callable[_P, _Ret]): if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None: method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: - method.__ray_tensor_transport__ = kwargs["tensor_transport"].upper() - tensor_transport = method.__ray_tensor_transport__ + tensor_transport = kwargs["tensor_transport"].upper() if tensor_transport != "OBJECT_STORE": from ray.experimental.gpu_object_manager.util import ( transport_manager_classes, @@ -483,6 +482,8 @@ def annotate_method(method: Callable[_P, _Ret]): if tensor_transport not in transport_manager_classes: raise ValueError(f"Invalid tensor transport: {tensor_transport}") + method.__ray_tensor_transport__ = tensor_transport + return method # Check if decorator is called without parentheses (args[0] would be the function) @@ -702,7 +703,7 @@ def options(self, **options): tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: - options["tensor_transport"] = tensor_transport.upper() + tensor_transport = tensor_transport.upper() if tensor_transport != "OBJECT_STORE": 
from ray.experimental.gpu_object_manager.util import ( transport_manager_classes, @@ -711,6 +712,8 @@ def options(self, **options): if tensor_transport not in transport_manager_classes: raise ValueError(f"Invalid tensor transport: {tensor_transport}") + options["tensor_transport"] = tensor_transport + class FuncWrapper: def remote(self, *args, **kwargs): return func_cls._remote(args=args, kwargs=kwargs, **options) From 88b1fc0968261417dbc601d7827f737257a8ace8 Mon Sep 17 00:00:00 2001 From: dayshah Date: Mon, 8 Dec 2025 22:27:03 +0000 Subject: [PATCH 16/18] util funcs Signed-off-by: dayshah --- python/ray/_private/worker.py | 27 +++++++------------ python/ray/actor.py | 22 ++++++--------- .../gpu_object_manager/gpu_object_manager.py | 8 +++--- .../experimental/gpu_object_manager/util.py | 19 +++++++++++++ 4 files changed, 39 insertions(+), 37 deletions(-) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index aec248c87499..0c2eed3fbe0a 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -836,21 +836,13 @@ def put_object( ) tensors = None tensor_transport = _tensor_transport.upper() - if tensor_transport != "OBJECT_STORE": - from ray.experimental.gpu_object_manager.util import ( - transport_manager_classes, - ) - - if tensor_transport not in transport_manager_classes: - raise ValueError(f"Invalid tensor transport: {tensor_transport}") + from ray.experimental.gpu_object_manager.util import ( + validate_one_sided, + validate_tensor_transport, + ) - if tensor_transport not in [ - "OBJECT_STORE", - "NIXL", - ]: - raise ValueError( - "Currently, Ray Direct Transport only supports 'object_store' and 'nixl' for tensor transport in ray.put()." - ) + validate_tensor_transport(tensor_transport) + validate_one_sided(tensor_transport, "ray.put") try: if tensor_transport != "OBJECT_STORE": ( @@ -1003,13 +995,12 @@ def get_objects( tensor_transport = ( _tensor_transport.upper() if _tensor_transport is not None else None ) - if tensor_transport is not None and tensor_transport != "OBJECT_STORE": + if tensor_transport is not None: from ray.experimental.gpu_object_manager.util import ( - transport_manager_classes, + validate_tensor_transport, ) - if tensor_transport not in transport_manager_classes: - raise ValueError(f"Invalid tensor transport: {tensor_transport}") + validate_tensor_transport(tensor_transport) values = self.deserialize_objects( serialized_objects, object_refs, tensor_transport_hint=tensor_transport diff --git a/python/ray/actor.py b/python/ray/actor.py index c71a37dbc548..b910b67cc335 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -474,14 +474,11 @@ def annotate_method(method: Callable[_P, _Ret]): method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: tensor_transport = kwargs["tensor_transport"].upper() - if tensor_transport != "OBJECT_STORE": - from ray.experimental.gpu_object_manager.util import ( - transport_manager_classes, - ) - - if tensor_transport not in transport_manager_classes: - raise ValueError(f"Invalid tensor transport: {tensor_transport}") + from ray.experimental.gpu_object_manager.util import ( + validate_tensor_transport, + ) + validate_tensor_transport(tensor_transport) method.__ray_tensor_transport__ = tensor_transport return method @@ -704,14 +701,11 @@ def options(self, **options): tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: tensor_transport = tensor_transport.upper() - if tensor_transport != "OBJECT_STORE": - 
from ray.experimental.gpu_object_manager.util import ( - transport_manager_classes, - ) - - if tensor_transport not in transport_manager_classes: - raise ValueError(f"Invalid tensor transport: {tensor_transport}") + from ray.experimental.gpu_object_manager.util import ( + validate_tensor_transport, + ) + validate_tensor_transport(tensor_transport) options["tensor_transport"] = tensor_transport class FuncWrapper: diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index e48d7dcd678f..7fd5b0b58eda 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -362,11 +362,9 @@ def _fetch_object( if tensor_transport is None: tensor_transport = tensor_transport_backend - if tensor_transport not in ["OBJECT_STORE", "NIXL"]: - raise ValueError( - "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, " - f"got {tensor_transport}, please specify the correct tensor transport in ray.get()." - ) + from ray.experimental.gpu_object_manager.util import validate_one_sided + + validate_one_sided(tensor_transport, "ray.get") if ( tensor_transport != "OBJECT_STORE" diff --git a/python/ray/experimental/gpu_object_manager/util.py b/python/ray/experimental/gpu_object_manager/util.py index 08381a216897..f9c931aad8be 100644 --- a/python/ray/experimental/gpu_object_manager/util.py +++ b/python/ray/experimental/gpu_object_manager/util.py @@ -70,3 +70,22 @@ def device_match_transport(device: "torch.device", tensor_transport: str) -> boo raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}") return device.type in transport_devices[tensor_transport] + + +def validate_tensor_transport(tensor_transport: str): + if ( + tensor_transport != "OBJECT_STORE" + and tensor_transport not in transport_manager_classes + ): + raise ValueError(f"Invalid tensor transport: {tensor_transport}") + + +def validate_one_sided(tensor_transport: str, ray_usage_func: str): + if ( + tensor_transport != "OBJECT_STORE" + and not transport_manager_classes[tensor_transport].is_one_sided() + ): + raise ValueError( + f"Trying to use two-sided tensor transport: {tensor_transport} for {ray_usage_func}. " + "This is only supported for one-sided transports such as NIXL or the OBJECT_STORE." + ) From 62369d92b6ecf52c3f758608f8a2acfc77f9283a Mon Sep 17 00:00:00 2001 From: dayshah Date: Mon, 8 Dec 2025 22:49:10 +0000 Subject: [PATCH 17/18] fix assert Signed-off-by: dayshah --- doc/source/ray-core/doc_code/direct_transport_gloo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/ray-core/doc_code/direct_transport_gloo.py b/doc/source/ray-core/doc_code/direct_transport_gloo.py index 398cb6add35c..392a66a510f2 100644 --- a/doc/source/ray-core/doc_code/direct_transport_gloo.py +++ b/doc/source/ray-core/doc_code/direct_transport_gloo.py @@ -140,7 +140,7 @@ def sum(self, tensor: torch.Tensor): ray.get(tensor) assert ( - "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, got GLOO, please specify the correct tensor transport in ray.get()." + "Trying to use two-sided tensor transport: GLOO for ray.get. This is only supported for one-sided transports such as NIXL or the OBJECT_STORE." 
in str(e.value) ) From 3a6aed3c787d756e505890e6c2f3888c33081da1 Mon Sep 17 00:00:00 2001 From: dayshah Date: Mon, 8 Dec 2025 17:31:53 -0800 Subject: [PATCH 18/18] address comments Signed-off-by: dayshah --- python/ray/_private/worker.py | 24 +++++++-------- python/ray/actor.py | 30 +++++++++++-------- .../gpu_object_manager/gpu_object_manager.py | 9 +++--- .../experimental/gpu_object_manager/util.py | 11 +++++-- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 0c2eed3fbe0a..443d676a455b 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -801,7 +801,7 @@ def put_object( value: Any, owner_address: Optional[str] = None, _is_experimental_channel: bool = False, - _tensor_transport: str = "OBJECT_STORE", + _tensor_transport: str = TensorTransportEnum.OBJECT_STORE.name, ): """Put value in the local object store. @@ -835,16 +835,15 @@ def put_object( "ray.ObjectRef in a list and call 'put' on it." ) tensors = None - tensor_transport = _tensor_transport.upper() from ray.experimental.gpu_object_manager.util import ( + normalize_and_validate_tensor_transport, validate_one_sided, - validate_tensor_transport, ) - validate_tensor_transport(tensor_transport) + tensor_transport = normalize_and_validate_tensor_transport(_tensor_transport) validate_one_sided(tensor_transport, "ray.put") try: - if tensor_transport != "OBJECT_STORE": + if tensor_transport != TensorTransportEnum.OBJECT_STORE.name: ( serialized_value, tensors, @@ -866,7 +865,7 @@ def put_object( pin_object = not _is_experimental_channel tensor_transport_enum = TensorTransportEnum.OBJECT_STORE - if tensor_transport != "OBJECT_STORE": + if tensor_transport != TensorTransportEnum.OBJECT_STORE.name: tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT # This *must* be the first place that we construct this python @@ -992,18 +991,17 @@ def get_objects( if skip_deserialization: return None, debugger_breakpoint - tensor_transport = ( - _tensor_transport.upper() if _tensor_transport is not None else None - ) - if tensor_transport is not None: + if _tensor_transport is not None: from ray.experimental.gpu_object_manager.util import ( - validate_tensor_transport, + normalize_and_validate_tensor_transport, ) - validate_tensor_transport(tensor_transport) + _tensor_transport = normalize_and_validate_tensor_transport( + _tensor_transport + ) values = self.deserialize_objects( - serialized_objects, object_refs, tensor_transport_hint=tensor_transport + serialized_objects, object_refs, tensor_transport_hint=_tensor_transport ) if not return_exceptions: # Raise exceptions instead of returning them to the user. 
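The put_object and get_objects call sites above both funnel through the shared helpers in gpu_object_manager/util.py: normalize the caller-supplied transport name, check it against transport_manager_classes, and, for ray.put()/ray.get(), reject transports that are not one-sided. As a rough standalone sketch of that pattern (illustrative only: the stub registry below stands in for the real manager classes, and the plain OBJECT_STORE string stands in for TensorTransportEnum.OBJECT_STORE.name), the flow looks roughly like this:

    # Standalone sketch of the shared validation helpers (assumptions noted in
    # comments; the real implementations live in gpu_object_manager/util.py).
    from dataclasses import dataclass

    OBJECT_STORE = "OBJECT_STORE"  # stand-in for TensorTransportEnum.OBJECT_STORE.name


    @dataclass
    class _StubTransportManager:
        one_sided: bool

        def is_one_sided(self) -> bool:
            return self.one_sided


    # Assumed registry shape: transport name -> manager exposing is_one_sided().
    transport_manager_classes = {
        "NIXL": _StubTransportManager(one_sided=True),
        "GLOO": _StubTransportManager(one_sided=False),
    }


    def normalize_and_validate_tensor_transport(tensor_transport: str) -> str:
        # Accept any casing from the caller ("nixl", "Nixl", ...) and normalize
        # once, so every later comparison uses the canonical upper-case name.
        tensor_transport = tensor_transport.upper()
        if (
            tensor_transport != OBJECT_STORE
            and tensor_transport not in transport_manager_classes
        ):
            raise ValueError(f"Invalid tensor transport: {tensor_transport}")
        return tensor_transport


    def validate_one_sided(tensor_transport: str, ray_usage_func: str) -> None:
        # ray.get()/ray.put() have no peer task to run the other half of a
        # collective, so only one-sided transports (or the object store) work.
        if (
            tensor_transport != OBJECT_STORE
            and not transport_manager_classes[tensor_transport].is_one_sided()
        ):
            raise ValueError(
                f"Trying to use two-sided tensor transport: {tensor_transport} "
                f"for {ray_usage_func}."
            )


    if __name__ == "__main__":
        t = normalize_and_validate_tensor_transport("nixl")  # -> "NIXL"
        validate_one_sided(t, "ray.get")  # passes: NIXL is one-sided
        try:
            validate_one_sided("GLOO", "ray.get")  # GLOO is collective-only
        except ValueError as err:
            print(err)

Normalizing once at the API boundary lets the rest of worker.py and actor.py compare plain strings against TensorTransportEnum.OBJECT_STORE.name, and the one-sided check is what produces the GLOO error asserted in direct_transport_gloo.py above.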
diff --git a/python/ray/actor.py b/python/ray/actor.py index b910b67cc335..b6f0024aa40c 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -473,12 +473,12 @@ def annotate_method(method: Callable[_P, _Ret]): if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None: method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: - tensor_transport = kwargs["tensor_transport"].upper() + tensor_transport = kwargs["tensor_transport"] from ray.experimental.gpu_object_manager.util import ( - validate_tensor_transport, + normalize_and_validate_tensor_transport, ) - validate_tensor_transport(tensor_transport) + tensor_transport = normalize_and_validate_tensor_transport(tensor_transport) method.__ray_tensor_transport__ = tensor_transport return method @@ -657,7 +657,7 @@ def __init__( # If the task call doesn't specify a tensor transport option, use `OBJECT_STORE` # as the default transport for this actor method. if tensor_transport is None: - tensor_transport = "OBJECT_STORE" + tensor_transport = TensorTransportEnum.OBJECT_STORE.name self._tensor_transport = tensor_transport def __call__(self, *args, **kwargs): @@ -700,12 +700,11 @@ def options(self, **options): tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: - tensor_transport = tensor_transport.upper() from ray.experimental.gpu_object_manager.util import ( - validate_tensor_transport, + normalize_and_validate_tensor_transport, ) - validate_tensor_transport(tensor_transport) + tensor_transport = normalize_and_validate_tensor_transport(tensor_transport) options["tensor_transport"] = tensor_transport class FuncWrapper: @@ -831,7 +830,7 @@ def _remote( if tensor_transport is None: tensor_transport = self._tensor_transport - if tensor_transport != "OBJECT_STORE": + if tensor_transport != TensorTransportEnum.OBJECT_STORE.name: if num_returns != 1: raise ValueError( f"Currently, methods with tensor_transport={tensor_transport} only support 1 return value. " @@ -887,7 +886,7 @@ def invocation(args, kwargs): invocation = self._decorator(invocation) object_refs = invocation(args, kwargs) - if tensor_transport != "OBJECT_STORE": + if tensor_transport != TensorTransportEnum.OBJECT_STORE.name: # Currently, we only support transfer tensor out-of-band when # num_returns is 1. assert isinstance(object_refs, ObjectRef) @@ -994,8 +993,12 @@ def create( # Check whether any actor methods specify a non-default tensor transport. 
self.has_tensor_transport_methods = any( - getattr(method, "__ray_tensor_transport__", "OBJECT_STORE") - != "OBJECT_STORE" + getattr( + method, + "__ray_tensor_transport__", + TensorTransportEnum.OBJECT_STORE.name, + ) + != TensorTransportEnum.OBJECT_STORE.name for _, method in actor_methods ) @@ -2167,7 +2170,10 @@ def _actor_method_call( generator_backpressure_num_objects = -1 tensor_transport_enum = TensorTransportEnum.OBJECT_STORE - if tensor_transport is not None and tensor_transport != "OBJECT_STORE": + if ( + tensor_transport is not None + and tensor_transport != TensorTransportEnum.OBJECT_STORE.name + ): tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, diff --git a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py index 7fd5b0b58eda..ef18b07ca711 100644 --- a/python/ray/experimental/gpu_object_manager/gpu_object_manager.py +++ b/python/ray/experimental/gpu_object_manager/gpu_object_manager.py @@ -7,6 +7,7 @@ import ray from ray._private import ray_constants +from ray._private.custom_types import TensorTransportEnum from ray._raylet import ObjectRef from ray.util.annotations import PublicAPI @@ -340,9 +341,9 @@ def _fetch_object( Args: obj_id: The object ID of the GPU object. tensor_transport: The tensor transport to use to fetch the GPU object. - This should either be object store or the actual tensor transport for the RDT object. + This should either be object store or the tensor transport for the RDT object. If this is None, the tensor transport backend of the RDT object will be used. - Note that NIXL is the only tensor transport that is supported for this right now. + Note that NIXL is the only supported RDT tensor transport right now. 
Returns: None @@ -367,7 +368,7 @@ def _fetch_object( validate_one_sided(tensor_transport, "ray.get") if ( - tensor_transport != "OBJECT_STORE" + tensor_transport != TensorTransportEnum.OBJECT_STORE.name and tensor_transport != tensor_transport_backend ): raise ValueError( @@ -380,7 +381,7 @@ def _fetch_object( tensor_transport_backend ) - if tensor_transport == "OBJECT_STORE": + if tensor_transport == TensorTransportEnum.OBJECT_STORE.name: tensors = ray.get( src_actor.__ray_call__.options(concurrency_group="_ray_system").remote( __ray_fetch_gpu_object__, obj_id diff --git a/python/ray/experimental/gpu_object_manager/util.py b/python/ray/experimental/gpu_object_manager/util.py index f9c931aad8be..7bb5b8e8c363 100644 --- a/python/ray/experimental/gpu_object_manager/util.py +++ b/python/ray/experimental/gpu_object_manager/util.py @@ -1,6 +1,7 @@ import threading from typing import TYPE_CHECKING +from ray._private.custom_types import TensorTransportEnum from ray.experimental.gpu_object_manager.collective_tensor_transport import ( CollectiveTensorTransport, ) @@ -72,17 +73,21 @@ def device_match_transport(device: "torch.device", tensor_transport: str) -> boo return device.type in transport_devices[tensor_transport] -def validate_tensor_transport(tensor_transport: str): +def normalize_and_validate_tensor_transport(tensor_transport: str) -> str: + tensor_transport = tensor_transport.upper() + if ( - tensor_transport != "OBJECT_STORE" + tensor_transport != TensorTransportEnum.OBJECT_STORE.name and tensor_transport not in transport_manager_classes ): raise ValueError(f"Invalid tensor transport: {tensor_transport}") + return tensor_transport + def validate_one_sided(tensor_transport: str, ray_usage_func: str): if ( - tensor_transport != "OBJECT_STORE" + tensor_transport != TensorTransportEnum.OBJECT_STORE.name and not transport_manager_classes[tensor_transport].is_one_sided() ): raise ValueError(