Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions dimos/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,6 @@ def monitor_threads(request):
# HuggingFace safetensors conversion thread - no user cleanup API
# https://github.com/huggingface/transformers/issues/29513
"Thread-auto_conversion",
# rpyc spawns per-call response threads inside the connection's
# SpawnThread protocol. They get cleaned up when the connection
# is closed (in the rpyc client's teardown), not per-test.
"RpycSpawnThread-",
]
new_threads = [
t
Expand Down
16 changes: 16 additions & 0 deletions dimos/core/coordination/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def create(cls, module: type[ModuleBase], kwargs: dict[str, Any]) -> Self:
)


# These fields cannot be pickled.
_PROXY_FIELDS = ("transport_map", "global_config_overrides", "remapping_map")


@dataclass(frozen=True)
class Blueprint:
blueprints: tuple[BlueprintAtom, ...]
Expand All @@ -155,6 +159,18 @@ class Blueprint:
requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple)
configurator_checks: "tuple[SystemConfigurator, ...]" = field(default_factory=tuple)

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
state.pop("active_blueprints", None) # recomputable cached_property
for name in _PROXY_FIELDS:
state[name] = dict(state[name])
return state

def __setstate__(self, state: dict[str, Any]) -> None:
for name in _PROXY_FIELDS:
state[name] = MappingProxyType(state[name])
self.__dict__.update(state)

@classmethod
def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint":
blueprint = BlueprintAtom.create(module, kwargs)
Expand Down
88 changes: 88 additions & 0 deletions dimos/core/coordination/coordinator_rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from dimos.protocol.rpc.pubsubrpc import LCMRPC
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
from dimos.protocol.rpc.spec import RPCInspectable

logger = setup_logger()


class CoordinatorRPC:
"""Owns the LCM RPC connection to the singleton Coordinator service."""

NAME = "Coordinator"

def __init__(self, rpc: LCMRPC) -> None:
self._rpc = rpc

@classmethod
def serve(cls, coordinator: RPCInspectable) -> CoordinatorRPC:
"""Publish `coordinator`'s @rpc methods under the `Coordinator/` prefix."""
cls._ensure_no_existing_service()
rpc = LCMRPC()
rpc.serve_module_rpc(coordinator, name=cls.NAME)
rpc.start()
return cls(rpc)

@classmethod
def connect(cls, *, timeout: float) -> CoordinatorRPC:
"""Attach to a running Coordinator, raising `TimeoutError` if none answers."""
rpc = LCMRPC()
rpc.start()
client = cls(rpc)
try:
client.call("ping", rpc_timeout=timeout)
except BaseException:
rpc.stop()
raise
return client

def call(self, method: str, *args: Any, rpc_timeout: float | None = None, **kwargs: Any) -> Any:
"""Invoke `Coordinator/<method>` and return its result."""
result, _unsub = self._rpc.call_sync(
f"{self.NAME}/{method}",
([*args], kwargs),
rpc_timeout=rpc_timeout,
)
return result

@property
def rpc(self) -> LCMRPC:
return self._rpc

def stop(self) -> None:
try:
self._rpc.stop()
except Exception:
logger.error("Error closing Coordinator RPC service", exc_info=True)

@classmethod
def _ensure_no_existing_service(cls) -> None:
probe = LCMRPC()
probe.start()
try:
try:
probe.call_sync(f"{cls.NAME}/ping", ([], {}), rpc_timeout=0.5)
except TimeoutError:
return
raise RuntimeError(f"another {cls.NAME} service is already running on this LCM bus")
finally:
probe.stop()
69 changes: 49 additions & 20 deletions dimos/core/coordination/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,23 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Mapping, MutableMapping
from collections.abc import Callable, Mapping, MutableMapping
import importlib
import inspect
import shutil
import sys
import threading
from typing import TYPE_CHECKING, Any, cast

from dimos.core.coordination.rpyc_server import RpycServer
from dimos.core.coordination.coordinator_rpc import CoordinatorRPC
from dimos.core.coordination.worker_manager import WorkerManager
from dimos.core.coordination.worker_manager_docker import WorkerManagerDocker
from dimos.core.coordination.worker_manager_python import WorkerManagerPython
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import ModuleBase, ModuleSpec
from dimos.core.resource import Resource
from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport
from dimos.porcelain.remote_module_source import ModuleDescriptor
from dimos.spec.utils import is_spec, spec_annotation_compliance, spec_structural_compliance
from dimos.utils.generic import short_id
from dimos.utils.logging_config import setup_logger
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(
self._module_transports: dict[type[ModuleBase], dict[str, PubSubTransport[Any]]] = {}
self._started = False
self._modules_lock = threading.RLock()
self._rpyc = RpycServer(self)
self._coordinator_rpc: CoordinatorRPC | None = None

def start(self) -> None:
from dimos.core.o3dpickle import register_picklers
Expand All @@ -74,7 +75,9 @@ def start(self) -> None:
self._started = True

def stop(self) -> None:
self._rpyc.stop()
if self._coordinator_rpc is not None:
self._coordinator_rpc.stop()
self._coordinator_rpc = None

for module_class, module in reversed(self._deployed_modules.items()):
logger.info("Stopping module...", module=module_class.__name__)
Expand All @@ -92,25 +95,50 @@ def _stop_manager(m: WorkerManager) -> None:

safe_thread_map(tuple(self._managers.values()), _stop_manager)

def start_rpyc_service(self) -> int:
return self._rpyc.start()
def start_rpc_service(self) -> None:
"""Expose the coordinator's API as @rpc methods over LCM."""
if self._coordinator_rpc is not None:
return
self._coordinator_rpc = CoordinatorRPC.serve(self)

def list_module_names(self) -> list[str]:
@property
def rpcs(self) -> dict[str, Callable[..., Any]]:
"""Methods exposed via the Coordinator @rpc service."""
return {
"ping": self.ping,
"list_modules": self.list_modules,
"load_blueprint_by_name": self.load_blueprint_by_name,
"load_blueprint": self.load_blueprint,
"restart_module_by_class_name": self.restart_module_by_class_name,
}

def ping(self) -> str:
"""Used by clients to check if the coordinator is alive and responsive."""
return "pong"

def list_modules(self) -> list[ModuleDescriptor]:
with self._modules_lock:
return [cls.__name__ for cls in self._deployed_modules]
descriptors: list[ModuleDescriptor] = []
for cls in self._deployed_modules:
qualified = f"{cls.__module__}.{cls.__name__}"
descriptors.append(
ModuleDescriptor(
class_name=cls.__name__,
qualified_path=qualified,
rpc_names=list(cls.rpcs.keys()),
)
)
return descriptors

def get_module_endpoint(self, class_name: str) -> tuple[str, int, int]:
"""Return (host, worker_rpyc_port, module_id) for the given class name.
def load_blueprint_by_name(self, name: str) -> None:
# Avoid circular import.
from dimos.robot.get_all_blueprints import get_by_name

Lazily starts the worker-side RPyC server on first use.
"""
self.load_blueprint(get_by_name(name))

def list_module_names(self) -> list[str]:
with self._modules_lock:
for cls, proxy in self._deployed_modules.items():
if cls.__name__ == class_name:
actor = cast("ModuleProxy", proxy).actor_instance
port = actor.start_rpyc()
return ("localhost", int(port), int(actor._module_id))
raise KeyError(class_name)
return [cls.__name__ for cls in self._deployed_modules]

def health_check(self) -> bool:
return all(m.health_check() for m in self._managers.values())
Expand Down Expand Up @@ -404,11 +432,12 @@ def restart_module_by_class_name(
class_name: str,
*,
reload_source: bool = True,
) -> ModuleProxyProtocol:
) -> None:
with self._modules_lock:
for cls in self._deployed_modules:
if cls.__name__ == class_name:
return self._restart_module(cls, reload_source=reload_source)
self._restart_module(cls, reload_source=reload_source)
return
raise ValueError(f"No deployed module with class name {class_name!r}")

def restart_module(
Expand Down
44 changes: 0 additions & 44 deletions dimos/core/coordination/python_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,15 @@
import signal
import sys
import threading
import time
import traceback
from typing import TYPE_CHECKING, Any

from rpyc.utils.server import ThreadedServer

from dimos.core.coordination.rpyc_services import WorkerRpycService
from dimos.core.coordination.worker_messages import (
CallMethodRequest,
DeployModuleRequest,
GetAttrRequest,
SetRefRequest,
ShutdownRequest,
StartRpycRequest,
SuppressConsoleRequest,
UndeployModuleRequest,
WorkerRequest,
Expand Down Expand Up @@ -135,10 +130,6 @@ def set_ref(self, ref: Any) -> ActorFuture:
result = self._send_request_to_worker(SetRefRequest(module_id=self._module_id, ref=ref))
return ActorFuture(result)

def start_rpyc(self) -> int:
port: int = self._send_request_to_worker(StartRpycRequest())
return port

def __getattr__(self, name: str) -> Any:
"""Proxy attribute access to the worker process."""
if name.startswith("_"):
Expand Down Expand Up @@ -332,8 +323,6 @@ def _suppress_console_output() -> None:
class _WorkerState:
instances: dict[int, Any]
worker_id: int
rpyc_server: ThreadedServer | None = None
rpyc_thread: threading.Thread | None = None
should_stop: bool = False


Expand Down Expand Up @@ -405,40 +394,7 @@ def _handle_request(request: Any, state: _WorkerState) -> WorkerResponse:
_suppress_console_output()
return WorkerResponse(result=True)

case StartRpycRequest():
if state.rpyc_server is not None:
return WorkerResponse(result=state.rpyc_server.port)
WorkerRpycService._instances = state.instances
state.rpyc_server = ThreadedServer(
WorkerRpycService,
port=0,
hostname=global_config.listen_host,
protocol_config={
"allow_all_attrs": True,
"allow_public_attrs": True,
"allow_pickle": True,
},
)
# `ThreadedServer.__init__` binds the socket but `listen()` only
# runs once `start()` executes on the thread, which sets
# `active=True` immediately after. Wait on that flag so callers
# never see a Connection refused before the accept loop is live.
state.rpyc_thread = threading.Thread(target=state.rpyc_server.start, daemon=True)
state.rpyc_thread.start()
deadline = time.monotonic() + 5.0
while not state.rpyc_server.active:
if not state.rpyc_thread.is_alive():
raise RuntimeError("rpyc server thread died before listening")
if time.monotonic() > deadline:
raise RuntimeError("rpyc server failed to start listening within 5s")
time.sleep(0.001)
return WorkerResponse(result=state.rpyc_server.port)

case ShutdownRequest():
if state.rpyc_server is not None:
state.rpyc_server.close()
if state.rpyc_thread is not None:
state.rpyc_thread.join(timeout=5)
state.should_stop = True
return WorkerResponse(result=True)

Expand Down
Loading
Loading