diff --git a/src/core/deviceio_base/cpp/inc/deviceio_base/opaque_data_channel_tracker_base.hpp b/src/core/deviceio_base/cpp/inc/deviceio_base/opaque_data_channel_tracker_base.hpp new file mode 100644 index 000000000..108216331 --- /dev/null +++ b/src/core/deviceio_base/cpp/inc/deviceio_base/opaque_data_channel_tracker_base.hpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tracker.hpp" + +#include +#include +#include + +namespace core +{ + +// Abstract base interface for opaque data channel tracker implementations. +// Returns raw bytes received from the XR_NV_opaque_data_channel extension. +class IOpaqueDataChannelTrackerImpl : public ITrackerImpl +{ +public: + virtual std::optional> get_latest_message() const = 0; +}; + +} // namespace core diff --git a/src/core/deviceio_trackers/cpp/CMakeLists.txt b/src/core/deviceio_trackers/cpp/CMakeLists.txt index 0ea801774..272672cc1 100644 --- a/src/core/deviceio_trackers/cpp/CMakeLists.txt +++ b/src/core/deviceio_trackers/cpp/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 cmake_minimum_required(VERSION 3.20) @@ -11,12 +11,14 @@ add_library(deviceio_trackers STATIC generic_3axis_pedal_tracker.cpp frame_metadata_tracker_oak.cpp full_body_tracker_pico.cpp + opaque_data_channel_tracker.cpp inc/deviceio_trackers/head_tracker.hpp inc/deviceio_trackers/hand_tracker.hpp inc/deviceio_trackers/controller_tracker.hpp inc/deviceio_trackers/full_body_tracker_pico.hpp inc/deviceio_trackers/generic_3axis_pedal_tracker.hpp inc/deviceio_trackers/frame_metadata_tracker_oak.hpp + inc/deviceio_trackers/opaque_data_channel_tracker.hpp ) target_include_directories(deviceio_trackers diff --git a/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/opaque_data_channel_tracker.hpp b/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/opaque_data_channel_tracker.hpp new file mode 100644 index 000000000..0e8ed876f --- /dev/null +++ b/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/opaque_data_channel_tracker.hpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace core +{ + +// Receives arbitrary bytes from a CloudXR opaque data channel identified by UUID. +// The channel is created and polled by the live implementation; this public +// tracker exposes the latest received message to Python DeviceIO source nodes. +class OpaqueDataChannelTracker : public ITracker +{ +public: + explicit OpaqueDataChannelTracker(std::array uuid); + + std::string_view get_name() const override; + + std::optional> get_latest_message(const ITrackerSession& session) const; + + const std::array& get_uuid() const; + +private: + static constexpr const char* TRACKER_NAME = "OpaqueDataChannelTracker"; + std::array uuid_; +}; + +} // namespace core diff --git a/src/core/deviceio_trackers/cpp/opaque_data_channel_tracker.cpp b/src/core/deviceio_trackers/cpp/opaque_data_channel_tracker.cpp new file mode 100644 index 000000000..bb059b5e7 --- /dev/null +++ b/src/core/deviceio_trackers/cpp/opaque_data_channel_tracker.cpp @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "inc/deviceio_trackers/opaque_data_channel_tracker.hpp" + +namespace core +{ + +OpaqueDataChannelTracker::OpaqueDataChannelTracker(std::array uuid) + : uuid_(uuid) +{ +} + +std::string_view OpaqueDataChannelTracker::get_name() const +{ + return TRACKER_NAME; +} + +std::optional> OpaqueDataChannelTracker::get_latest_message(const ITrackerSession& session) const +{ + return static_cast(session.get_tracker_impl(*this)).get_latest_message(); +} + +const std::array& OpaqueDataChannelTracker::get_uuid() const +{ + return uuid_; +} + +} // namespace core diff --git a/src/core/deviceio_trackers/python/tracker_bindings.cpp b/src/core/deviceio_trackers/python/tracker_bindings.cpp index 494e72233..4e6430ff3 100644 --- a/src/core/deviceio_trackers/python/tracker_bindings.cpp +++ b/src/core/deviceio_trackers/python/tracker_bindings.cpp @@ -7,10 +7,13 @@ #include #include #include +#include #include #include #include +#include + namespace py = pybind11; PYBIND11_MODULE(_deviceio_trackers, m) @@ -95,6 +98,30 @@ PYBIND11_MODULE(_deviceio_trackers, m) { return self.get_body_pose(session); }, py::arg("session"), "Get full body pose tracked state (data is None if inactive)"); + py::class_>( + m, "OpaqueDataChannelTracker") + .def(py::init([](py::bytes uuid_bytes) + { + std::string raw = uuid_bytes; + if (raw.size() != 16) + throw std::invalid_argument("UUID must be exactly 16 bytes"); + std::array arr; + std::memcpy(arr.data(), raw.data(), 16); + return std::make_shared(arr); + }), + py::arg("uuid"), "Construct with a 16-byte UUID identifying the data channel") + .def( + "get_latest_message", + [](const core::OpaqueDataChannelTracker& self, + const core::ITrackerSession& session) -> std::optional + { + auto msg = self.get_latest_message(session); + if (!msg) + return std::nullopt; + return py::bytes(reinterpret_cast(msg->data()), msg->size()); + }, + py::arg("session"), "Get the latest received message bytes, or None if no message this frame"); + m.attr("NUM_JOINTS") = static_cast(core::HandJoint_NUM_JOINTS); m.attr("JOINT_PALM") = static_cast(core::HandJoint_PALM); m.attr("JOINT_WRIST") = static_cast(core::HandJoint_WRIST); diff --git a/src/core/live_trackers/cpp/CMakeLists.txt b/src/core/live_trackers/cpp/CMakeLists.txt index a2e429215..20c3cb402 100644 --- a/src/core/live_trackers/cpp/CMakeLists.txt +++ b/src/core/live_trackers/cpp/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 cmake_minimum_required(VERSION 3.20) @@ -12,6 +12,7 @@ add_library(live_trackers STATIC live_full_body_tracker_pico_impl.cpp live_generic_3axis_pedal_tracker_impl.cpp live_frame_metadata_tracker_oak_impl.cpp + live_opaque_data_channel_tracker_impl.cpp inc/live_trackers/schema_tracker_base.hpp inc/live_trackers/schema_tracker.hpp inc/live_trackers/live_deviceio_factory.hpp @@ -21,6 +22,7 @@ add_library(live_trackers STATIC live_full_body_tracker_pico_impl.hpp live_generic_3axis_pedal_tracker_impl.hpp live_frame_metadata_tracker_oak_impl.hpp + live_opaque_data_channel_tracker_impl.hpp ) target_include_directories(live_trackers diff --git a/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp b/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp index 50b17fdc5..291aaa3ea 100644 --- a/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp +++ b/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 #pragma once @@ -32,6 +32,8 @@ class HandTracker; class IHandTrackerImpl; class HeadTracker; class IHeadTrackerImpl; +class OpaqueDataChannelTracker; +class IOpaqueDataChannelTrackerImpl; struct OpenXRSessionHandles; /** @@ -61,6 +63,8 @@ class LiveDeviceIOFactory const Generic3AxisPedalTracker* tracker); std::unique_ptr create_frame_metadata_tracker_oak_impl( const FrameMetadataTrackerOak* tracker); + std::unique_ptr create_opaque_data_channel_tracker_impl( + const OpaqueDataChannelTracker* tracker); private: bool should_record(const ITracker* tracker) const; diff --git a/src/core/live_trackers/cpp/live_deviceio_factory.cpp b/src/core/live_trackers/cpp/live_deviceio_factory.cpp index 1160f7996..89da51312 100644 --- a/src/core/live_trackers/cpp/live_deviceio_factory.cpp +++ b/src/core/live_trackers/cpp/live_deviceio_factory.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 #include "inc/live_trackers/live_deviceio_factory.hpp" @@ -9,6 +9,7 @@ #include "live_generic_3axis_pedal_tracker_impl.hpp" #include "live_hand_tracker_impl.hpp" #include "live_head_tracker_impl.hpp" +#include "live_opaque_data_channel_tracker_impl.hpp" #include #include @@ -16,6 +17,7 @@ #include #include #include +#include #include #include @@ -77,6 +79,12 @@ std::unique_ptr try_create_oak_impl(LiveDeviceIOFactory& factory, return typed ? factory.create_frame_metadata_tracker_oak_impl(typed) : nullptr; } +std::unique_ptr try_create_opaque_channel_impl(LiveDeviceIOFactory& factory, const ITracker& tracker) +{ + auto* typed = dynamic_cast(&tracker); + return typed ? factory.create_opaque_data_channel_tracker_impl(typed) : nullptr; +} + using CollectExtensionsFn = bool (*)(const ITracker&, std::set&); using TryCreateFn = std::unique_ptr (*)(LiveDeviceIOFactory&, const ITracker&); @@ -94,6 +102,7 @@ inline const TrackerDispatchEntry k_tracker_dispatch[] = { { &try_add_extensions, &try_create_full_body_pico_impl }, { &try_add_extensions, &try_create_generic_pedal_impl }, { &try_add_extensions, &try_create_oak_impl }, + { &try_add_extensions, &try_create_opaque_channel_impl }, }; } // namespace @@ -235,4 +244,10 @@ std::unique_ptr LiveDeviceIOFactory::create_frame_ return std::make_unique(handles_, tracker, std::move(channels)); } +std::unique_ptr LiveDeviceIOFactory::create_opaque_data_channel_tracker_impl( + const OpaqueDataChannelTracker* tracker) +{ + return std::make_unique(handles_, tracker); +} + } // namespace core diff --git a/src/core/live_trackers/cpp/live_opaque_data_channel_tracker_impl.cpp b/src/core/live_trackers/cpp/live_opaque_data_channel_tracker_impl.cpp new file mode 100644 index 000000000..0d85dc449 --- /dev/null +++ b/src/core/live_trackers/cpp/live_opaque_data_channel_tracker_impl.cpp @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "live_opaque_data_channel_tracker_impl.hpp" + +#include +#include + +#include +#include +#include + +namespace core +{ + +LiveOpaqueDataChannelTrackerImpl::LiveOpaqueDataChannelTrackerImpl(const OpenXRSessionHandles& handles, + const OpaqueDataChannelTracker* tracker) + : instance_(handles.instance) +{ + auto core_funcs = OpenXRCoreFunctions::load(handles.instance, handles.xrGetInstanceProcAddr); + + XrSystemId system_id; + XrSystemGetInfo system_info{ XR_TYPE_SYSTEM_GET_INFO }; + system_info.formFactor = XR_FORM_FACTOR_HEAD_MOUNTED_DISPLAY; + + XrResult result = core_funcs.xrGetSystem(handles.instance, &system_info, &system_id); + if (XR_FAILED(result)) + { + throw std::runtime_error("[OpaqueDataChannel] Failed to get OpenXR system: " + std::to_string(result)); + } + + loadExtensionFunction(handles.instance, handles.xrGetInstanceProcAddr, "xrCreateOpaqueDataChannelNV", + reinterpret_cast(&pfn_create_)); + loadExtensionFunction(handles.instance, handles.xrGetInstanceProcAddr, "xrDestroyOpaqueDataChannelNV", + reinterpret_cast(&pfn_destroy_)); + loadExtensionFunction(handles.instance, handles.xrGetInstanceProcAddr, "xrGetOpaqueDataChannelStateNV", + reinterpret_cast(&pfn_get_state_)); + loadExtensionFunction(handles.instance, handles.xrGetInstanceProcAddr, "xrReceiveOpaqueDataChannelNV", + reinterpret_cast(&pfn_receive_)); + loadExtensionFunction(handles.instance, handles.xrGetInstanceProcAddr, "xrShutdownOpaqueDataChannelNV", + reinterpret_cast(&pfn_shutdown_)); + + XrOpaqueDataChannelCreateInfoNV create_info{}; + create_info.type = XR_TYPE_OPAQUE_DATA_CHANNEL_CREATE_INFO_NV; + create_info.next = nullptr; + create_info.systemId = system_id; + + const auto& uuid = tracker->get_uuid(); + static_assert(sizeof(create_info.uuid.data) == 16); + std::memcpy(create_info.uuid.data, uuid.data(), 16); + + result = pfn_create_(handles.instance, &create_info, &channel_); + if (XR_FAILED(result)) + { + throw std::runtime_error("[OpaqueDataChannel] xrCreateOpaqueDataChannelNV failed: " + std::to_string(result)); + } + + std::cout << "OpaqueDataChannelTracker initialized (channel created, waiting for connection)" << std::endl; +} + +LiveOpaqueDataChannelTrackerImpl::~LiveOpaqueDataChannelTrackerImpl() +{ + if (channel_ != XR_NULL_HANDLE) + { + if (status_ == XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTED_NV) + { + pfn_shutdown_(channel_); + } + pfn_destroy_(channel_); + channel_ = XR_NULL_HANDLE; + } +} + +void LiveOpaqueDataChannelTrackerImpl::update(int64_t monotonic_time_ns) +{ + last_update_time_ = monotonic_time_ns; + latest_message_.reset(); + + if (channel_ == XR_NULL_HANDLE) + { + return; + } + + XrOpaqueDataChannelStateNV state{}; + state.type = XR_TYPE_OPAQUE_DATA_CHANNEL_STATE_NV; + state.next = nullptr; + + XrResult result = pfn_get_state_(channel_, &state); + if (XR_FAILED(result)) + { + return; + } + status_ = state.state; + + if (status_ != XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTED_NV && + status_ != XR_OPAQUE_DATA_CHANNEL_STATUS_SHUTTING_NV) + { + return; + } + + // Drain all queued messages, keeping only the latest. + // xrReceiveOpaqueDataChannelNV dequeues one message per call. + std::vector buffer; + while (true) + { + uint32_t byte_count = 0; + result = pfn_receive_(channel_, 0, &byte_count, nullptr); + if (result == XR_ERROR_CHANNEL_NOT_CONNECTED_NV || byte_count == 0) + { + break; + } + if (XR_FAILED(result)) + { + break; + } + + buffer.resize(byte_count); + result = pfn_receive_(channel_, byte_count, &byte_count, buffer.data()); + if (XR_FAILED(result)) + { + break; + } + + latest_message_ = buffer; + } +} + +std::optional> LiveOpaqueDataChannelTrackerImpl::get_latest_message() const +{ + return latest_message_; +} + +} // namespace core diff --git a/src/core/live_trackers/cpp/live_opaque_data_channel_tracker_impl.hpp b/src/core/live_trackers/cpp/live_opaque_data_channel_tracker_impl.hpp new file mode 100644 index 000000000..7cb5bd692 --- /dev/null +++ b/src/core/live_trackers/cpp/live_opaque_data_channel_tracker_impl.hpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace core +{ + +class OpaqueDataChannelTracker; + +class LiveOpaqueDataChannelTrackerImpl : public IOpaqueDataChannelTrackerImpl +{ +public: + static std::vector required_extensions() + { + return { XR_NV_OPAQUE_DATA_CHANNEL_EXTENSION_NAME }; + } + + LiveOpaqueDataChannelTrackerImpl(const OpenXRSessionHandles& handles, + const OpaqueDataChannelTracker* tracker); + ~LiveOpaqueDataChannelTrackerImpl(); + + LiveOpaqueDataChannelTrackerImpl(const LiveOpaqueDataChannelTrackerImpl&) = delete; + LiveOpaqueDataChannelTrackerImpl& operator=(const LiveOpaqueDataChannelTrackerImpl&) = delete; + LiveOpaqueDataChannelTrackerImpl(LiveOpaqueDataChannelTrackerImpl&&) = delete; + LiveOpaqueDataChannelTrackerImpl& operator=(LiveOpaqueDataChannelTrackerImpl&&) = delete; + + void update(int64_t monotonic_time_ns) override; + std::optional> get_latest_message() const override; + +private: + XrInstance instance_; + XrOpaqueDataChannelNV channel_{ XR_NULL_HANDLE }; + XrOpaqueDataChannelStatusNV status_{ XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTING_NV }; + + PFN_xrCreateOpaqueDataChannelNV pfn_create_{ nullptr }; + PFN_xrDestroyOpaqueDataChannelNV pfn_destroy_{ nullptr }; + PFN_xrGetOpaqueDataChannelStateNV pfn_get_state_{ nullptr }; + PFN_xrReceiveOpaqueDataChannelNV pfn_receive_{ nullptr }; + PFN_xrShutdownOpaqueDataChannelNV pfn_shutdown_{ nullptr }; + + std::optional> latest_message_; + int64_t last_update_time_ = 0; +}; + +} // namespace core diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py b/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py index dd4265f93..6200e5a2f 100644 --- a/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py @@ -9,6 +9,7 @@ from .controllers_source import ControllersSource from .pedals_source import Generic3AxisPedalSource from .full_body_source import FullBodySource +from .opaque_data_channel_source import OpaqueDataChannelSource from .deviceio_tensor_types import ( HeadPoseTrackedType, HandPoseTrackedType, @@ -29,6 +30,7 @@ "ControllersSource", "Generic3AxisPedalSource", "FullBodySource", + "OpaqueDataChannelSource", "HeadPoseTrackedType", "HandPoseTrackedType", "ControllerSnapshotTrackedType", diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/opaque_data_channel_source.py b/src/core/retargeting_engine/python/deviceio_source_nodes/opaque_data_channel_source.py new file mode 100644 index 000000000..de47cbbc9 --- /dev/null +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/opaque_data_channel_source.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Opaque Data Channel Source Node. + +Receives JSON teleop command messages from a CloudXR opaque data channel +and produces bool pulse signals for start/stop/reset commands. +""" + +import json +import logging +from typing import Any, TYPE_CHECKING + +from .interface import IDeviceIOSource +from ..interface.retargeter_core_types import RetargeterIO, RetargeterIOType +from ..interface.tensor_group import TensorGroup +from ..interface.tensor_group_type import OptionalType, TensorGroupType +from ..interface.tensor_type import TensorType + +if TYPE_CHECKING: + from isaacteleop.deviceio import ITracker + +logger = logging.getLogger(__name__) + + +class _RawBytesType(TensorType): + """Tensor type for raw bytes received from an opaque data channel.""" + + def _check_instance_compatibility(self, other: TensorType) -> bool: + return isinstance(other, _RawBytesType) + + def validate_value(self, value: Any) -> None: + if value is not None and not isinstance(value, (bytes, bytearray)): + raise TypeError( + f"Expected bytes or None for '{self.name}', got {type(value).__name__}" + ) + + +def _raw_bytes_group(name: str) -> TensorGroupType: + return TensorGroupType(name, [_RawBytesType(name)]) + + +class OpaqueDataChannelSource(IDeviceIOSource): + """Receives JSON teleop commands from a CloudXR opaque data channel. + + Parses ``teleop_command`` messages matching the WebXR client protocol:: + + {"type": "teleop_command", "message": {"command": "start teleop"}} + + and produces one-shot bool pulse outputs for ``start_command``, + ``stop_command``, and ``reset_command``. + + Inputs: + - "raw_message": Raw bytes from the opaque data channel (or None + when no message was received this frame). + + Outputs (Optional -- None when no matching command this frame): + - "start_command": bool pulse (True for one frame) + - "stop_command": bool pulse (True for one frame) + - "reset_command": bool pulse (True for one frame) + """ + + OUTPUT_START = "start_command" + OUTPUT_STOP = "stop_command" + OUTPUT_RESET = "reset_command" + + _INPUT_RAW = "raw_message" + + _COMMAND_MAP = { + "start teleop": OUTPUT_START, + "stop teleop": OUTPUT_STOP, + "reset teleop": OUTPUT_RESET, + } + + def __init__(self, uuid: bytes, name: str = "opaque_data_channel") -> None: + """Initialize with the UUID identifying the opaque data channel. + + Args: + uuid: 16-byte UUID matching the channel created by the runtime. + name: Unique node name for the retargeting graph. + """ + import isaacteleop.deviceio as deviceio + + self._tracker = deviceio.OpaqueDataChannelTracker(uuid) + super().__init__(name) + + def get_tracker(self) -> "ITracker": + return self._tracker + + def poll_tracker(self, deviceio_session: Any) -> RetargeterIO: + raw = self._tracker.get_latest_message(deviceio_session) + tg = TensorGroup(self.input_spec()[self._INPUT_RAW]) + tg[0] = raw + return {self._INPUT_RAW: tg} + + def input_spec(self) -> RetargeterIOType: + return {self._INPUT_RAW: _raw_bytes_group(self._INPUT_RAW)} + + def output_spec(self) -> RetargeterIOType: + from isaacteleop.teleop_session_manager.teleop_state_manager_types import ( + bool_signal, + ) + + return { + self.OUTPUT_START: OptionalType(bool_signal(self.OUTPUT_START)), + self.OUTPUT_STOP: OptionalType(bool_signal(self.OUTPUT_STOP)), + self.OUTPUT_RESET: OptionalType(bool_signal(self.OUTPUT_RESET)), + } + + def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + raw: bytes | None = inputs[self._INPUT_RAW][0] + + matched_output: str | None = None + if raw is not None: + matched_output = self._parse_command(raw) + + for key in (self.OUTPUT_START, self.OUTPUT_STOP, self.OUTPUT_RESET): + if key == matched_output: + outputs[key][0] = True + else: + outputs[key].set_none() + + @classmethod + def _parse_command(cls, raw: bytes) -> str | None: + """Parse raw bytes into a recognised output key, or None.""" + try: + payload = json.loads(raw) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.debug("Ignoring non-JSON opaque data channel message") + return None + + if not isinstance(payload, dict): + return None + + msg = payload.get("message") + if isinstance(msg, dict): + command = msg.get("command", "") + elif isinstance(msg, str): + command = msg + else: + return None + + return cls._COMMAND_MAP.get(command) diff --git a/src/core/retargeting_engine_tests/python/test_retargeter_reset.py b/src/core/retargeting_engine_tests/python/test_retargeter_reset.py new file mode 100644 index 000000000..1934190c0 --- /dev/null +++ b/src/core/retargeting_engine_tests/python/test_retargeter_reset.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for retargeter reset behaviour via ExecutionEvents. + +Verifies that stateful retargeters (GripperRetargeter, +LocomotionRootCmdRetargeter, Se3AbsRetargeter, Se3RelRetargeter) +correctly reinitialize their cross-step state when +``context.execution_events.reset`` is True. +""" + +import numpy as np +import numpy.testing as npt +import pytest + +from isaacteleop.retargeting_engine.interface import ( + ComputeContext, + ExecutionEvents, + ExecutionState, + OptionalTensorGroup, + TensorGroup, +) +from isaacteleop.retargeting_engine.interface.retargeter_core_types import GraphTime +from isaacteleop.retargeting_engine.interface.tensor_group_type import ( + OptionalTensorGroupType, +) + +from isaacteleop.retargeters import ( + GripperRetargeter, + GripperRetargeterConfig, + LocomotionRootCmdRetargeter, + LocomotionRootCmdRetargeterConfig, + Se3AbsRetargeter, + Se3RelRetargeter, + Se3RetargeterConfig, +) + + +def _make_context(*, reset: bool = False) -> ComputeContext: + return ComputeContext( + graph_time=GraphTime(sim_time_ns=0, real_time_ns=0), + execution_events=ExecutionEvents( + reset=reset, execution_state=ExecutionState.RUNNING + ), + ) + + +def _build_io(retargeter): + """Build inputs/outputs for a retargeter, using OptionalTensorGroup for optional specs.""" + inputs = {} + for k, v in retargeter.input_spec().items(): + if isinstance(v, OptionalTensorGroupType): + inputs[k] = OptionalTensorGroup(v) + else: + inputs[k] = TensorGroup(v) + outputs = {} + for k, v in retargeter.output_spec().items(): + if isinstance(v, OptionalTensorGroupType): + outputs[k] = OptionalTensorGroup(v) + else: + outputs[k] = TensorGroup(v) + return inputs, outputs + + +# --------------------------------------------------------------------------- +# LocomotionRootCmdRetargeter +# --------------------------------------------------------------------------- + + +class TestLocomotionRootCmdRetargeterReset: + """LocomotionRootCmdRetargeter must restore initial_hip_height on reset.""" + + @pytest.fixture() + def retargeter(self): + cfg = LocomotionRootCmdRetargeterConfig(initial_hip_height=0.72) + return LocomotionRootCmdRetargeter(cfg, name="loco") + + def test_reset_restores_hip_height(self, retargeter): + inputs, outputs = _build_io(retargeter) + + retargeter._hip_height = 0.95 + + retargeter.compute(inputs, outputs, _make_context(reset=True)) + + cmd = np.from_dlpack(outputs["root_command"][0]) + assert cmd[3] == pytest.approx(0.72), "hip_height should be reset to initial" + + def test_no_reset_preserves_hip_height(self, retargeter): + inputs, outputs = _build_io(retargeter) + + retargeter._hip_height = 0.95 + + retargeter.compute(inputs, outputs, _make_context(reset=False)) + + cmd = np.from_dlpack(outputs["root_command"][0]) + assert cmd[3] == pytest.approx(0.95), ( + "hip_height should not change without reset" + ) + + +# --------------------------------------------------------------------------- +# Se3AbsRetargeter +# --------------------------------------------------------------------------- + + +class TestSe3AbsRetargeterReset: + """Se3AbsRetargeter must reinitialize _last_pose on reset.""" + + @pytest.fixture() + def retargeter(self): + cfg = Se3RetargeterConfig(input_device="controller_right") + return Se3AbsRetargeter(cfg, name="se3abs") + + def test_reset_clears_last_pose(self, retargeter): + """After reset with no input, output should be identity pose.""" + inputs, outputs = _build_io(retargeter) + + retargeter._last_pose = np.array( + [1.0, 2.0, 3.0, 0.5, 0.5, 0.5, 0.5], dtype=np.float32 + ) + + retargeter.compute(inputs, outputs, _make_context(reset=True)) + + pose = np.from_dlpack(outputs["ee_pose"][0]) + identity = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], dtype=np.float32) + npt.assert_array_almost_equal(pose, identity) + + def test_no_reset_returns_stale_pose(self, retargeter): + """Without reset and no input, output should be the stale _last_pose.""" + inputs, outputs = _build_io(retargeter) + + stale = np.array([1.0, 2.0, 3.0, 0.5, 0.5, 0.5, 0.5], dtype=np.float32) + retargeter._last_pose = stale.copy() + + retargeter.compute(inputs, outputs, _make_context(reset=False)) + + pose = np.from_dlpack(outputs["ee_pose"][0]) + npt.assert_array_almost_equal(pose, stale) + + +# --------------------------------------------------------------------------- +# Se3RelRetargeter +# --------------------------------------------------------------------------- + + +class TestSe3RelRetargeterReset: + """Se3RelRetargeter must reinitialize all cross-step state on reset.""" + + @pytest.fixture() + def retargeter(self): + cfg = Se3RetargeterConfig(input_device="controller_right") + return Se3RelRetargeter(cfg, name="se3rel") + + def test_reset_restores_first_frame(self, retargeter): + retargeter._first_frame = False + retargeter._smoothed_delta_pos = np.array([1.0, 2.0, 3.0]) + retargeter._smoothed_delta_rot = np.array([0.1, 0.2, 0.3]) + + inputs, outputs = _build_io(retargeter) + retargeter.compute(inputs, outputs, _make_context(reset=True)) + + assert retargeter._first_frame is True + npt.assert_array_equal(retargeter._smoothed_delta_pos, np.zeros(3)) + npt.assert_array_equal(retargeter._smoothed_delta_rot, np.zeros(3)) + assert retargeter._previous_thumb_tip is None + assert retargeter._previous_index_tip is None + + def test_no_reset_preserves_state(self, retargeter): + stale_pos = np.array([1.0, 2.0, 3.0]) + stale_rot = np.array([0.1, 0.2, 0.3]) + stale_wrist = np.array([0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 1.0]) + + retargeter._first_frame = False + retargeter._smoothed_delta_pos = stale_pos.copy() + retargeter._smoothed_delta_rot = stale_rot.copy() + retargeter._previous_wrist = stale_wrist.copy() + + inputs, outputs = _build_io(retargeter) + retargeter.compute(inputs, outputs, _make_context(reset=False)) + + assert retargeter._first_frame is False + npt.assert_array_equal(retargeter._smoothed_delta_pos, stale_pos) + npt.assert_array_equal(retargeter._smoothed_delta_rot, stale_rot) + npt.assert_array_equal(retargeter._previous_wrist, stale_wrist) + + +# --------------------------------------------------------------------------- +# GripperRetargeter +# --------------------------------------------------------------------------- + + +class TestGripperRetargeterReset: + """GripperRetargeter must restore _previous_gripper_command on reset.""" + + @pytest.fixture() + def retargeter(self): + cfg = GripperRetargeterConfig(hand_side="right") + return GripperRetargeter(cfg, name="gripper") + + def test_reset_reopens_gripper(self, retargeter): + """After reset with no input, gripper should output open (1.0).""" + inputs, outputs = _build_io(retargeter) + + retargeter._previous_gripper_command = True # closed + + retargeter.compute(inputs, outputs, _make_context(reset=True)) + + cmd = outputs["gripper_command"][0] + assert cmd == pytest.approx(1.0), "gripper should be open after reset" + + def test_no_reset_preserves_closed_gripper(self, retargeter): + """Without reset, _previous_gripper_command stays True (closed).""" + inputs, outputs = _build_io(retargeter) + + retargeter._previous_gripper_command = True # closed + + retargeter.compute(inputs, outputs, _make_context(reset=False)) + + assert retargeter._previous_gripper_command is True, ( + "gripper state should stay closed without reset" + ) diff --git a/src/core/teleop_session_manager/python/__init__.py b/src/core/teleop_session_manager/python/__init__.py index 9175d71e2..06936b06c 100644 --- a/src/core/teleop_session_manager/python/__init__.py +++ b/src/core/teleop_session_manager/python/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from .teleop_session import TeleopSession @@ -15,6 +15,7 @@ DefaultTeleopStateManager, TwoButtonTeleopStateManager, ) +from .command_teleop_state_manager import CommandTeleopStateManager from .teleop_state_manager_types import ( bool_signal, teleop_state_channel, @@ -32,6 +33,7 @@ "TeleopStateManager", "DefaultTeleopStateManager", "TwoButtonTeleopStateManager", + "CommandTeleopStateManager", "bool_signal", "teleop_state_channel", "reset_event_channel", diff --git a/src/core/teleop_session_manager/python/command_teleop_state_manager.py b/src/core/teleop_session_manager/python/command_teleop_state_manager.py new file mode 100644 index 000000000..347b712ed --- /dev/null +++ b/src/core/teleop_session_manager/python/command_teleop_state_manager.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Command-driven teleop state manager. + +Translates explicit start/stop/reset command pulses (e.g. from a WebXR UI +via an opaque data channel) into ExecutionEvents. Unlike +DefaultTeleopStateManager which uses edge-detected button toggles, this +manager treats each command as a direct state transition. +""" + +from isaacteleop.retargeting_engine.interface import RetargeterIOType +from isaacteleop.retargeting_engine.interface.retargeter_core_types import ( + ComputeContext, + RetargeterIO, +) +from isaacteleop.retargeting_engine.interface.tensor_group_type import OptionalType +from isaacteleop.retargeting_engine.interface.execution_events import ( + ExecutionState, + ExecutionEvents, +) + +from .teleop_state_manager_retargeter import TeleopStateManager +from .teleop_state_manager_types import bool_signal + + +class CommandTeleopStateManager(TeleopStateManager): + """Teleop state manager driven by explicit start/stop/reset commands. + + All inputs are optional. When no command is received in a frame the + manager holds its current state. Initial state is STOPPED. + + Inputs (all OptionalType): + - start_command: pulse sets state to RUNNING. + - stop_command: pulse sets state to STOPPED. + - reset_command: pulse emits ``reset=True`` without changing state. + + Priority when multiple commands arrive in the same frame: + stop > reset > start. + """ + + INPUT_START = "start_command" + INPUT_STOP = "stop_command" + INPUT_RESET = "reset_command" + + def __init__(self, name: str) -> None: + self._state = ExecutionState.STOPPED + super().__init__(name=name) + + def input_spec(self) -> RetargeterIOType: + return { + self.INPUT_START: OptionalType(bool_signal(self.INPUT_START)), + self.INPUT_STOP: OptionalType(bool_signal(self.INPUT_STOP)), + self.INPUT_RESET: OptionalType(bool_signal(self.INPUT_RESET)), + } + + def _compute_execution_events( + self, inputs: RetargeterIO, context: ComputeContext + ) -> ExecutionEvents: + del context + + start = self._read_pulse(inputs, self.INPUT_START) + stop = self._read_pulse(inputs, self.INPUT_STOP) + reset = self._read_pulse(inputs, self.INPUT_RESET) + + if stop: + self._state = ExecutionState.STOPPED + elif start: + self._state = ExecutionState.RUNNING + + return ExecutionEvents(reset=reset, execution_state=self._state) + + @staticmethod + def _read_pulse(inputs: RetargeterIO, key: str) -> bool: + group = inputs[key] + if group.is_none: + return False + return bool(group[0]) diff --git a/src/core/teleop_session_manager_tests/python/test_command_state_manager.py b/src/core/teleop_session_manager_tests/python/test_command_state_manager.py new file mode 100644 index 000000000..58635dd62 --- /dev/null +++ b/src/core/teleop_session_manager_tests/python/test_command_state_manager.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for CommandTeleopStateManager and OpaqueDataChannelSource._parse_command. + +Tests state transitions, priority ordering, and JSON teleop command parsing +without requiring OpenXR hardware. +""" + +import json + +from isaacteleop.retargeting_engine.interface.tensor_group import ( + OptionalTensorGroup, +) +from isaacteleop.retargeting_engine.interface.tensor_group_type import OptionalType +from isaacteleop.retargeting_engine.interface.execution_events import ( + ExecutionState, + ExecutionEvents, +) +from isaacteleop.retargeting_engine.interface.retargeter_core_types import ( + ComputeContext, + RetargeterIO, +) +from isaacteleop.teleop_session_manager.teleop_state_manager_types import bool_signal +from isaacteleop.teleop_session_manager.command_teleop_state_manager import ( + CommandTeleopStateManager, +) +from isaacteleop.retargeting_engine.deviceio_source_nodes.opaque_data_channel_source import ( + OpaqueDataChannelSource, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_inputs( + start: bool | None = None, + stop: bool | None = None, + reset: bool | None = None, +) -> RetargeterIO: + """Build a RetargeterIO matching CommandTeleopStateManager.input_spec. + + None means the input is absent (OptionalTensorGroup.set_none()). + """ + spec = { + CommandTeleopStateManager.INPUT_START: OptionalType( + bool_signal(CommandTeleopStateManager.INPUT_START) + ), + CommandTeleopStateManager.INPUT_STOP: OptionalType( + bool_signal(CommandTeleopStateManager.INPUT_STOP) + ), + CommandTeleopStateManager.INPUT_RESET: OptionalType( + bool_signal(CommandTeleopStateManager.INPUT_RESET) + ), + } + inputs: RetargeterIO = {} + for key, tgt in spec.items(): + tg = OptionalTensorGroup(tgt) + val = {"start_command": start, "stop_command": stop, "reset_command": reset}[ + key + ] + if val is None: + tg.set_none() + else: + tg[0] = val + inputs[key] = tg + return inputs + + +def _make_context() -> ComputeContext: + return ComputeContext( + graph_time=0.0, + execution_events=ExecutionEvents( + reset=False, execution_state=ExecutionState.UNKNOWN + ), + ) + + +# --------------------------------------------------------------------------- +# CommandTeleopStateManager tests +# --------------------------------------------------------------------------- + + +class TestCommandTeleopStateManagerInitialState: + def test_initial_state_is_stopped(self): + sm = CommandTeleopStateManager("sm") + events = sm._compute_execution_events(_make_inputs(), _make_context()) + assert events.execution_state == ExecutionState.STOPPED + assert events.reset is False + + +class TestCommandTeleopStateManagerTransitions: + def test_start_transitions_to_running(self): + sm = CommandTeleopStateManager("sm") + events = sm._compute_execution_events(_make_inputs(start=True), _make_context()) + assert events.execution_state == ExecutionState.RUNNING + + def test_stop_transitions_to_stopped(self): + sm = CommandTeleopStateManager("sm") + sm._compute_execution_events(_make_inputs(start=True), _make_context()) + events = sm._compute_execution_events(_make_inputs(stop=True), _make_context()) + assert events.execution_state == ExecutionState.STOPPED + + def test_reset_does_not_change_state(self): + sm = CommandTeleopStateManager("sm") + sm._compute_execution_events(_make_inputs(start=True), _make_context()) + events = sm._compute_execution_events(_make_inputs(reset=True), _make_context()) + assert events.execution_state == ExecutionState.RUNNING + assert events.reset is True + + def test_no_inputs_holds_state(self): + sm = CommandTeleopStateManager("sm") + sm._compute_execution_events(_make_inputs(start=True), _make_context()) + events = sm._compute_execution_events(_make_inputs(), _make_context()) + assert events.execution_state == ExecutionState.RUNNING + assert events.reset is False + + +class TestCommandTeleopStateManagerPriority: + def test_stop_beats_start(self): + sm = CommandTeleopStateManager("sm") + events = sm._compute_execution_events( + _make_inputs(start=True, stop=True), _make_context() + ) + assert events.execution_state == ExecutionState.STOPPED + + def test_stop_beats_reset(self): + sm = CommandTeleopStateManager("sm") + sm._compute_execution_events(_make_inputs(start=True), _make_context()) + events = sm._compute_execution_events( + _make_inputs(stop=True, reset=True), _make_context() + ) + assert events.execution_state == ExecutionState.STOPPED + assert events.reset is True + + def test_reset_with_start(self): + sm = CommandTeleopStateManager("sm") + events = sm._compute_execution_events( + _make_inputs(start=True, reset=True), _make_context() + ) + assert events.execution_state == ExecutionState.RUNNING + assert events.reset is True + + +# --------------------------------------------------------------------------- +# OpaqueDataChannelSource._parse_command tests +# --------------------------------------------------------------------------- + + +class TestParseCommand: + def test_start_command(self): + raw = json.dumps( + {"type": "teleop_command", "message": {"command": "start teleop"}} + ).encode() + assert OpaqueDataChannelSource._parse_command(raw) == "start_command" + + def test_stop_command(self): + raw = json.dumps( + {"type": "teleop_command", "message": {"command": "stop teleop"}} + ).encode() + assert OpaqueDataChannelSource._parse_command(raw) == "stop_command" + + def test_reset_command(self): + raw = json.dumps( + {"type": "teleop_command", "message": {"command": "reset teleop"}} + ).encode() + assert OpaqueDataChannelSource._parse_command(raw) == "reset_command" + + def test_unknown_command_returns_none(self): + raw = json.dumps( + {"type": "teleop_command", "message": {"command": "fly away"}} + ).encode() + assert OpaqueDataChannelSource._parse_command(raw) is None + + def test_malformed_json_returns_none(self): + assert OpaqueDataChannelSource._parse_command(b"not json") is None + + def test_non_dict_payload_returns_none(self): + assert OpaqueDataChannelSource._parse_command(b'"just a string"') is None + + def test_flat_message_string(self): + """Handles legacy payload where message is a plain string.""" + raw = json.dumps({"type": "teleop_command", "message": "start teleop"}).encode() + assert OpaqueDataChannelSource._parse_command(raw) == "start_command" + + def test_empty_bytes(self): + assert OpaqueDataChannelSource._parse_command(b"") is None + + def test_missing_message_key(self): + raw = json.dumps({"type": "teleop_command"}).encode() + assert OpaqueDataChannelSource._parse_command(raw) is None diff --git a/src/retargeters/gripper_retargeter.py b/src/retargeters/gripper_retargeter.py index 5f0d86179..46ad24958 100644 --- a/src/retargeters/gripper_retargeter.py +++ b/src/retargeters/gripper_retargeter.py @@ -82,6 +82,8 @@ def output_spec(self) -> RetargeterIOType: def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: """Computes gripper command based on controller trigger (priority) or pinch distance (fallback).""" + if context.execution_events.reset: + self._previous_gripper_command = False gripper_out = outputs["gripper_command"] diff --git a/src/retargeters/locomotion_retargeter.py b/src/retargeters/locomotion_retargeter.py index 2cc7ba105..72ac7d81e 100644 --- a/src/retargeters/locomotion_retargeter.py +++ b/src/retargeters/locomotion_retargeter.py @@ -120,6 +120,9 @@ def output_spec(self) -> RetargeterIOType: def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: """Computes root command from controller inputs.""" + if context.execution_events.reset: + self._hip_height = self._config.initial_hip_height + left_thumbstick_x = 0.0 left_thumbstick_y = 0.0 right_thumbstick_x = 0.0 diff --git a/src/retargeters/se3_retargeter.py b/src/retargeters/se3_retargeter.py index 80acda760..b29807dd8 100644 --- a/src/retargeters/se3_retargeter.py +++ b/src/retargeters/se3_retargeter.py @@ -205,6 +205,11 @@ def output_spec(self) -> RetargeterIOType: } def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + if context.execution_events.reset: + self._last_pose = np.array( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], dtype=np.float32 + ) + ee_pose = outputs["ee_pose"] device_name = self._config.input_device inp = inputs[device_name] @@ -333,6 +338,14 @@ def output_spec(self) -> RetargeterIOType: } def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + if context.execution_events.reset: + self._smoothed_delta_pos = np.zeros(3) + self._smoothed_delta_rot = np.zeros(3) + self._previous_wrist = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) + self._previous_thumb_tip = None + self._previous_index_tip = None + self._first_frame = True + ee_delta = outputs["ee_delta"] device_name = self._config.input_device inp = inputs[device_name]