
Commit 9ad7832

committed: final!
1 parent 9852bc9 commit 9ad7832

File tree: 8 files changed (+407, -107 lines)


docs/source/reference/collectors_weightsync.rst

Lines changed: 2 additions & 2 deletions
@@ -49,7 +49,7 @@ Weight update schemes can be used outside of collectors for custom synchronizati
 The new simplified API provides four core methods for weight synchronization:
 
 - ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
-- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side
+- ``init_on_receiver(model_id, **kwargs)`` - Initialize on worker process side
 - ``get_sender()`` - Get the configured sender instance
 - ``get_receiver()`` - Get the configured receiver instance
 
@@ -85,7 +85,7 @@ Here's a basic example:
     # or sender.send_async(weights); sender.wait_async() # Asynchronous send
 
     # On the worker process side:
-    # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
+    # scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
     # receiver = scheme.get_receiver()
     # # Non-blocking check for new weights
     # if receiver.receive(timeout=0.001):
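
For readers tracking this rename, a minimal end-to-end sketch of the updated API follows. Only the scheme/sender/receiver calls mirror the documented snippet above; the import path, process wiring, and toy policy are illustrative assumptions, not part of the docs change.

# Hypothetical usage sketch; the import path and process wiring are assumed.
import torch.multiprocessing as mp
from torch import nn

from torchrl.weight_update.weight_sync_schemes import (  # path assumed
    MultiProcessWeightSyncScheme,
)


def worker(child_pipe):
    policy = nn.Linear(4, 2)
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    # Renamed in this commit: init_on_worker() -> init_on_receiver()
    scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
    receiver = scheme.get_receiver()
    # Non-blocking check for new weights, as in the doc example
    if receiver.receive(timeout=1.0):
        print("policy weights updated")


if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()
    proc = mp.Process(target=worker, args=(child_pipe,))
    proc.start()

    policy = nn.Linear(4, 2)
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
    sender = scheme.get_sender()
    sender.send(policy.state_dict())
    # or: sender.send_async(policy.state_dict()); sender.wait_async()
    proc.join(timeout=10.0)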

test/test_weightsync.py

Lines changed: 84 additions & 75 deletions
@@ -6,7 +6,9 @@
 
 import argparse
 import importlib.util
+
 import pickle
+import threading
 import time
 
 import pytest
@@ -26,12 +28,10 @@
     RayWeightSyncScheme,
     RPCWeightSyncScheme,
     SharedMemTransport,
-)
-from torchrl.weight_update.utils import _resolve_model
-from torchrl.weight_update.weight_sync_schemes import (
     SharedMemWeightSyncScheme,
     WeightStrategy,
 )
+from torchrl.weight_update.utils import _resolve_model
 
 _has_ray = importlib.util.find_spec("ray") is not None
 
@@ -43,7 +43,7 @@ def worker_update_policy(pipe, timeout=5.0):
     policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy)
     receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
@@ -62,7 +62,7 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
     policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy)
     receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
@@ -100,7 +100,7 @@ def test_mp_transport_basic(self):
         proc.start()
 
         test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        transport.send_weights("policy", test_weights)
+        transport.send_weights(test_weights)
 
         proc.join(timeout=10.0)
         assert not proc.is_alive()
@@ -113,7 +113,7 @@ def test_mp_transport_async(self):
         proc.start()
 
         test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        transport.send_weights_async("policy", test_weights)
+        transport.send_weights_async(test_weights)
         transport.wait_ack()
 
         proc.join(timeout=10.0)
@@ -124,13 +124,16 @@ def test_shared_mem_transport(self):
             {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
         ).share_memory_()
 
-        transport = SharedMemTransport({"policy": shared_buffer})
+        transport = SharedMemTransport()
+        transport.register_weights(
+            params_map={0: shared_buffer}, init_queues={0: mp.Queue()}
+        )
 
         new_weights = TensorDict(
             {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
         )
 
-        transport.send_weights("policy", new_weights)
+        transport.send_weights(new_weights)
 
         assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4))
         assert torch.allclose(shared_buffer["bias"], torch.ones(2))
@@ -255,7 +258,10 @@ def test_shared_mem_scheme(self):
             {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
         )
 
-        transport.send_weights("policy", new_weights)
+        transport.register_weights(
+            params_map={0: shared_buffer}, init_queues={0: mp.Queue()}
+        )
+        transport.send_weights(new_weights)
 
         assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4))
         assert torch.allclose(shared_buffer["bias"], torch.ones(2))
@@ -265,7 +271,7 @@ def test_no_weight_sync_scheme(self):
         transport = scheme.create_transport(None)
 
         weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        transport.send_weights("policy", weights)
+        transport.send_weights(weights)
 
     @classmethod
     def _worker_with_receive(cls, pipe, scheme):
@@ -274,7 +280,7 @@ def _worker_with_receive(cls, pipe, scheme):
         policy.weight.fill_(0.0)
         policy.bias.fill_(0.0)
 
-        scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+        scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy)
         receiver = scheme.get_receiver()
 
         # Non-blocking receive should return False when no data
@@ -354,7 +360,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
         collector.shutdown()
 
     def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):
-        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme = MultiProcessWeightSyncScheme()
 
         collector = MultiSyncDataCollector(
             create_env_fn=[
@@ -660,73 +666,76 @@ def test_multiprocess_scheme_serialize_after_sender_init(self):
         parent_pipe.close()
         child_pipe.close()
 
-    def test_shared_mem_scheme_serialize_before_init(self):
-        """Test that uninitialized SharedMemWeightSyncScheme can be pickled."""
-        scheme = SharedMemWeightSyncScheme(strategy="tensordict")
-
-        # Serialize and deserialize
-        pickled = pickle.dumps(scheme)
-        restored = pickle.loads(pickled)
-
-        # Check that configuration is preserved
-        assert restored.strategy == "tensordict"
-        assert restored._sender is None
-        assert restored._receiver is None
+    # Serialize and deserialize
+    @staticmethod
+    def _get_scheme_from_queue(q, scheme):
+        try:
+            restored = scheme
+            # Check that configuration is preserved but runtime state is cleared
+            assert restored.strategy == "tensordict"
+            assert restored._sender is None
+            assert not restored._initialized_on_sender
+
+            q.put("success")
+        except Exception as err:
+            q.put(f"failure: {err}")
+        finally:
+            q.close()
 
+    @pytest.mark.timeout(10)
     def test_shared_mem_scheme_serialize_after_init(self):
         """Test that initialized SharedMemWeightSyncScheme can be pickled."""
         parent_pipe, child_pipe = mp.Pipe()
+        q = mp.Queue()
+        try:
+            # Create shared buffer
+            shared_buffer = TensorDict(
+                {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
+            ).share_memory_()
+
+            scheme = SharedMemWeightSyncScheme()
+
+            def init_on_sender(scheme, pipe):
+                scheme.init_on_sender(params_map={0: shared_buffer})
+                scheme.synchronize_weights()
+                msg = pipe.recv()
+                assert msg == "registered"
+
+            def init_on_receiver(scheme: SharedMemWeightSyncScheme, child_pipe):
+                scheme.init_on_receiver(
+                    worker_idx=0, model=nn.Linear(4, 2, device="meta")
+                )
+                scheme.synchronize_weights()
+                child_pipe.send("registered")
+
+            future_sender = threading.Thread(
+                target=init_on_sender,
+                kwargs={"scheme": scheme, "pipe": parent_pipe},
+            )
+            future_receiver = threading.Thread(
+                target=init_on_receiver,
+                kwargs={"scheme": scheme, "child_pipe": child_pipe},
+            )
+            future_receiver.start()
+            future_sender.start()
+            future_receiver.join(timeout=10.0)
+            future_sender.join(timeout=10.0)
 
-        # Create shared buffer
-        shared_buffer = TensorDict(
-            {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
-        ).share_memory_()
-
-        scheme = SharedMemWeightSyncScheme(
-            strategy="tensordict",
-        )
-
-        def init_on_sender(scheme, child_pipe):
-            (model_id, data), msg = child_pipe.recv()
-            if msg == "register_shared_weights":
-                child_pipe.send((None, "registered"))
-            else:
-                raise ValueError(f"Expected 'register_shared_weights' but got {msg}")
-
-        # Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker
-        import threading
-
-        future_sender = threading.Thread(
-            target=scheme.init_on_sender,
-            kwargs={"model_id": "policy", "pipes": [parent_pipe]},
-        )
-        future_receiver = threading.Thread(
-            target=init_on_sender,
-            kwargs={"scheme": scheme, "child_pipe": child_pipe},
-        )
-        future_receiver.start()
-        future_sender.start()
-        future_receiver.join()
-        future_sender.join()
-
-        # Scheme now has _sender with non-serializable state
-        assert scheme._sender is not None
-
-        # Serialize and deserialize
-        pickled = pickle.dumps(scheme)
-        restored = pickle.loads(pickled)
-
-        # Check that configuration is preserved but runtime state is cleared
-        assert restored.strategy == "tensordict"
-        assert restored._sender is None
-        assert not restored._initialized_on_sender
-
-        # Note: policy_weights dict is preserved (but may need re-sharing)
-        assert "policy" in restored.policy_weights
+            # Scheme now has _sender with non-serializable state
+            assert scheme._sender is not None
 
-        # Clean up
-        parent_pipe.close()
-        child_pipe.close()
+            proc = mp.Process(target=self._get_scheme_from_queue, args=(q, scheme))
+            proc.start()
+            try:
+                msg = q.get(timeout=10.0)
+                assert msg == "success", msg
+            finally:
+                proc.join()
+        finally:
+            q.close()
+            # Clean up
+            parent_pipe.close()
            child_pipe.close()
 
     def test_no_weight_sync_scheme_serialize(self):
        """Test that NoWeightSyncScheme can be pickled."""
@@ -809,7 +818,7 @@ def test_scheme_reinitialization_after_unpickle(self):
        """Test that a scheme can be re-initialized after unpickling.
 
        This is the expected workflow: pickle a scheme, unpickle it in a worker,
-        then call init_on_worker() to establish new runtime resources.
+        then call init_on_receiver() to establish new runtime resources.
        """
        # Initialize and pickle a scheme
        parent_pipe, child_pipe = mp.Pipe()
torchrl/collectors/_runner.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ def _make_policy_factory(
 
     if weight_sync_scheme is not None:
         # Initialize the receiver on the worker side
-        weight_sync_scheme.init_on_worker(
+        weight_sync_scheme.init_on_receiver(
             model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe
         )
         # Get the receiver and synchronize initial weights
@@ -147,7 +147,7 @@ def _main_async_collector(
             inner_collector._weight_receivers[model_id] = receiver
         else:
             # Initialize receivers for other models
-            scheme.init_on_worker(model_id=model_id, context=inner_collector)
+            scheme.init_on_receiver(model_id=model_id, context=inner_collector)
             receiver = scheme.get_receiver()
             receiver.synchronize_weights(worker_idx=worker_idx)
             inner_collector._weight_receivers[model_id] = receiver
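
Both hunks above follow the same worker-side pattern: initialize the receiver under its new name, fetch it, then pull the initial weights. A condensed, hypothetical helper illustrating that pattern (the argument names come from the hunks; the helper itself is not part of the commit):

# Hypothetical helper condensing the worker-side pattern from the hunks above.
def _init_worker_receiver(scheme, policy, pipe, worker_idx):
    # Renamed entry point: init_on_worker() -> init_on_receiver()
    scheme.init_on_receiver(
        model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe
    )
    # Get the receiver and synchronize initial weights
    receiver = scheme.get_receiver()
    receiver.synchronize_weights(worker_idx=worker_idx)
    return receiver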

torchrl/weight_update/_mp.py

Lines changed: 46 additions & 4 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import weakref
-from typing import Any
+from typing import Any, overload
 
 from torchrl.weight_update.weight_sync_schemes import (
     TransportBackend,
@@ -22,7 +22,7 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme):
     Synchronization flow:
     - init_on_sender() creates a MPWeightSender and registers all worker pipes
     - synchronize_weights() triggers the initial weight distribution via pipes
-    - init_on_worker() creates a MPWeightReceiver that receives from its pipe
+    - init_on_receiver() creates a MPWeightReceiver that receives from its pipe
     - Subsequent updates use send() which extracts, sends, and waits for ACKs
 
     Args:
@@ -55,6 +55,27 @@ def synchronize_weights(self):
         )
         self._sender.synchronize_weights()
 
+    @overload
+    def init_on_sender(
+        self,
+        model_id: str,
+        context: Any,
+        **kwargs,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        model_id: str,
+        context: None = None,
+        *,
+        pipes: list = ...,
+        num_workers: int | None = None,
+        **kwargs,
+    ) -> None:
+        ...
+
     def init_on_sender(
         self,
         model_id: str,
@@ -93,7 +114,28 @@ def init_on_sender(
         self._sender = sender
         self._initialized_on_sender = True
 
-    def init_on_worker(
+    @overload
+    def init_on_receiver(
+        self,
+        model_id: str,
+        context: Any,
+        **kwargs,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_receiver(
+        self,
+        model_id: str,
+        context: None = None,
+        *,
+        pipe: Any = ...,
+        model: Any | None = None,
+        **kwargs,
+    ) -> None:
+        ...
+
+    def init_on_receiver(
         self,
         model_id: str,
         context: Any = None,
@@ -138,7 +180,7 @@ def create_transport(self, pipe: Any) -> TransportBackend:
         """Create an MPTransport using the provided pipe.
 
         Note:
-            This is used internally by init_on_sender/init_on_worker.
+            This is used internally by init_on_sender/init_on_receiver.
         """
         return MPTransport(pipe)
 
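The @overload stubs added to torchrl/weight_update/_mp.py document two calling conventions for init_on_sender()/init_on_receiver(): resolve resources from a context object, or pass pipes/pipe and model explicitly. A minimal, self-contained illustration of that typing pattern (a stand-in class, not the real scheme implementation):

# Standalone illustration of the @overload pattern used in _mp.py above;
# _SchemeSketch is a stand-in, not the actual MultiProcessWeightSyncScheme.
from __future__ import annotations

from typing import Any, overload


class _SchemeSketch:
    @overload
    def init_on_sender(self, model_id: str, context: Any, **kwargs) -> None:
        ...

    @overload
    def init_on_sender(
        self, model_id: str, context: None = None, *, pipes: list = ..., **kwargs
    ) -> None:
        ...

    def init_on_sender(self, model_id, context=None, **kwargs):
        # Single runtime implementation dispatches on whether a context was given
        if context is not None:
            pipes = getattr(context, "pipes", [])
        else:
            pipes = kwargs.get("pipes", [])
        print(f"registering {len(pipes)} pipe(s) for {model_id!r}")


_SchemeSketch().init_on_sender("policy", pipes=[])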