@@ -6,7 +6,9 @@
 
 import argparse
 import importlib.util
+
 import pickle
+import threading
 import time
 
 import pytest
@@ -26,12 +28,10 @@
     RayWeightSyncScheme,
     RPCWeightSyncScheme,
     SharedMemTransport,
-)
-from torchrl.weight_update.utils import _resolve_model
-from torchrl.weight_update.weight_sync_schemes import (
     SharedMemWeightSyncScheme,
     WeightStrategy,
 )
+from torchrl.weight_update.utils import _resolve_model
 
 _has_ray = importlib.util.find_spec("ray") is not None
 
@@ -43,7 +43,7 @@ def worker_update_policy(pipe, timeout=5.0):
     policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy)
     receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
@@ -62,7 +62,7 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
     policy.bias.fill_(0.0)
 
     scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
-    scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+    scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy)
     receiver = scheme.get_receiver()
 
     if receiver._transport.pipe.poll(timeout):
@@ -100,7 +100,7 @@ def test_mp_transport_basic(self):
         proc.start()
 
         test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        transport.send_weights("policy", test_weights)
+        transport.send_weights(test_weights)
 
         proc.join(timeout=10.0)
         assert not proc.is_alive()
@@ -113,7 +113,7 @@ def test_mp_transport_async(self):
         proc.start()
 
         test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        transport.send_weights_async("policy", test_weights)
+        transport.send_weights_async(test_weights)
         transport.wait_ack()
 
         proc.join(timeout=10.0)
@@ -124,13 +124,16 @@ def test_shared_mem_transport(self):
124124 {"weight" : torch .zeros (2 , 4 ), "bias" : torch .zeros (2 )}, batch_size = []
125125 ).share_memory_ ()
126126
127- transport = SharedMemTransport ({"policy" : shared_buffer })
127+ transport = SharedMemTransport ()
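+        # Register the shared buffer for worker 0, together with its init
+        # queue, before any weights are sent through the transport.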
+        transport.register_weights(
+            params_map={0: shared_buffer}, init_queues={0: mp.Queue()}
+        )
 
         new_weights = TensorDict(
             {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
         )
 
-        transport.send_weights("policy", new_weights)
+        transport.send_weights(new_weights)
 
         assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4))
         assert torch.allclose(shared_buffer["bias"], torch.ones(2))
@@ -255,7 +258,10 @@ def test_shared_mem_scheme(self):
255258 {"weight" : torch .ones (2 , 4 ), "bias" : torch .ones (2 )}, batch_size = []
256259 )
257260
258- transport .send_weights ("policy" , new_weights )
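+        # The buffer must be registered with the transport before weights can
+        # be broadcast into it.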
+        transport.register_weights(
+            params_map={0: shared_buffer}, init_queues={0: mp.Queue()}
+        )
+        transport.send_weights(new_weights)
 
         assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4))
         assert torch.allclose(shared_buffer["bias"], torch.ones(2))
@@ -265,7 +271,7 @@ def test_no_weight_sync_scheme(self):
         transport = scheme.create_transport(None)
 
         weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
-        transport.send_weights("policy", weights)
+        transport.send_weights(weights)
 
     @classmethod
     def _worker_with_receive(cls, pipe, scheme):
@@ -274,7 +280,7 @@ def _worker_with_receive(cls, pipe, scheme):
         policy.weight.fill_(0.0)
         policy.bias.fill_(0.0)
 
-        scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
+        scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy)
         receiver = scheme.get_receiver()
 
         # Non-blocking receive should return False when no data
@@ -354,7 +360,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
         collector.shutdown()
 
     def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):
-        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        scheme = MultiProcessWeightSyncScheme()
 
         collector = MultiSyncDataCollector(
             create_env_fn=[
@@ -660,73 +666,76 @@ def test_multiprocess_scheme_serialize_after_sender_init(self):
         parent_pipe.close()
         child_pipe.close()
 
-    def test_shared_mem_scheme_serialize_before_init(self):
-        """Test that uninitialized SharedMemWeightSyncScheme can be pickled."""
-        scheme = SharedMemWeightSyncScheme(strategy="tensordict")
-
-        # Serialize and deserialize
-        pickled = pickle.dumps(scheme)
-        restored = pickle.loads(pickled)
-
-        # Check that configuration is preserved
-        assert restored.strategy == "tensordict"
-        assert restored._sender is None
-        assert restored._receiver is None
+    # Helper run in a child process: transferring the scheme to the process
+    # pickles and unpickles it, so only its configuration should survive.
+    @staticmethod
+    def _get_scheme_from_queue(q, scheme):
+        try:
+            restored = scheme
+            # Check that configuration is preserved but runtime state is cleared
+            assert restored.strategy == "tensordict"
+            assert restored._sender is None
+            assert not restored._initialized_on_sender
+
+            q.put("success")
+        except Exception as err:
+            q.put(f"failure: {err}")
+        finally:
+            q.close()
 
+    @pytest.mark.timeout(10)
     def test_shared_mem_scheme_serialize_after_init(self):
         """Test that initialized SharedMemWeightSyncScheme can be pickled."""
         parent_pipe, child_pipe = mp.Pipe()
+        q = mp.Queue()
+        try:
+            # Create shared buffer
+            shared_buffer = TensorDict(
+                {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
+            ).share_memory_()
+
+            scheme = SharedMemWeightSyncScheme()
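+            # No strategy argument is needed: the scheme defaults to
+            # "tensordict", which _get_scheme_from_queue asserts on the
+            # restored copy.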
+
+            def init_on_sender(scheme, pipe):
+                scheme.init_on_sender(params_map={0: shared_buffer})
+                scheme.synchronize_weights()
+                msg = pipe.recv()
+                assert msg == "registered"
+
+            def init_on_receiver(scheme: SharedMemWeightSyncScheme, child_pipe):
+                scheme.init_on_receiver(
+                    worker_idx=0, model=nn.Linear(4, 2, device="meta")
+                )
+                scheme.synchronize_weights()
+                child_pipe.send("registered")
+
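+            # Sender and receiver initialization run in two threads because the
+            # sender side waits for an acknowledgement from the worker side.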
+            future_sender = threading.Thread(
+                target=init_on_sender,
+                kwargs={"scheme": scheme, "pipe": parent_pipe},
+            )
+            future_receiver = threading.Thread(
+                target=init_on_receiver,
+                kwargs={"scheme": scheme, "child_pipe": child_pipe},
+            )
+            future_receiver.start()
+            future_sender.start()
+            future_receiver.join(timeout=10.0)
+            future_sender.join(timeout=10.0)
 
-        # Create shared buffer
-        shared_buffer = TensorDict(
-            {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[]
-        ).share_memory_()
-
-        scheme = SharedMemWeightSyncScheme(
-            strategy="tensordict",
-        )
-
-        def init_on_sender(scheme, child_pipe):
-            (model_id, data), msg = child_pipe.recv()
-            if msg == "register_shared_weights":
-                child_pipe.send((None, "registered"))
-            else:
-                raise ValueError(f"Expected 'register_shared_weights' but got {msg}")
-
-        # Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker
-        import threading
-
-        future_sender = threading.Thread(
-            target=scheme.init_on_sender,
-            kwargs={"model_id": "policy", "pipes": [parent_pipe]},
-        )
-        future_receiver = threading.Thread(
-            target=init_on_sender,
-            kwargs={"scheme": scheme, "child_pipe": child_pipe},
-        )
-        future_receiver.start()
-        future_sender.start()
-        future_receiver.join()
-        future_sender.join()
-
-        # Scheme now has _sender with non-serializable state
-        assert scheme._sender is not None
-
-        # Serialize and deserialize
-        pickled = pickle.dumps(scheme)
-        restored = pickle.loads(pickled)
-
-        # Check that configuration is preserved but runtime state is cleared
-        assert restored.strategy == "tensordict"
-        assert restored._sender is None
-        assert not restored._initialized_on_sender
-
-        # Note: policy_weights dict is preserved (but may need re-sharing)
-        assert "policy" in restored.policy_weights
+            # Scheme now has _sender with non-serializable state
+            assert scheme._sender is not None
 
-        # Clean up
-        parent_pipe.close()
-        child_pipe.close()
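+            # Spawning a process pickles the scheme implicitly; the child
+            # asserts that its runtime-only state was cleared in transit.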
+            proc = mp.Process(target=self._get_scheme_from_queue, args=(q, scheme))
+            proc.start()
+            try:
+                msg = q.get(timeout=10.0)
+                assert msg == "success", msg
+            finally:
+                proc.join()
+        finally:
+            q.close()
+            # Clean up
+            parent_pipe.close()
+            child_pipe.close()
 
     def test_no_weight_sync_scheme_serialize(self):
         """Test that NoWeightSyncScheme can be pickled."""
@@ -809,7 +818,7 @@ def test_scheme_reinitialization_after_unpickle(self):
809818 """Test that a scheme can be re-initialized after unpickling.
810819
811820 This is the expected workflow: pickle a scheme, unpickle it in a worker,
812- then call init_on_worker () to establish new runtime resources.
821+ then call init_on_receiver () to establish new runtime resources.
813822 """
814823 # Initialize and pickle a scheme
815824 parent_pipe , child_pipe = mp .Pipe ()