@@ -3953,14 +3953,22 @@ def test_weight_update(self, weight_updater):
39533953 policy_weights = TensorDict .from_module (policy )
39543954 kwargs = {}
39553955 if weight_updater == "scheme_shared" :
3956- kwargs = {"weight_sync_schemes" : {"policy" : SharedMemWeightSyncScheme ()}}
3956+ scheme = SharedMemWeightSyncScheme ()
3957+ kwargs = {"weight_sync_schemes" : {"policy" : scheme }}
39573958 elif weight_updater == "scheme_pipe" :
3958- kwargs = {"weight_sync_schemes" : {"policy" : MultiProcessWeightSyncScheme ()}}
3959+ scheme = MultiProcessWeightSyncScheme ()
3960+ kwargs = {"weight_sync_schemes" : {"policy" : scheme }}
39593961 elif weight_updater == "weight_updater" :
3962+ scheme = None
39603963 kwargs = {"weight_updater" : self .MPSWeightUpdaterBase (policy_weights , 2 )}
39613964 else :
39623965 raise NotImplementedError
39633966
3967+ if scheme is not None :
3968+ scheme .init_on_sender (
3969+ model = policy_factory (), devices = [device ] * 2 , model_id = "policy"
3970+ )
3971+
39643972 collector = MultiSyncDataCollector (
39653973 create_env_fn = [env_maker , env_maker ],
39663974 policy_factory = policy_factory ,
@@ -3973,6 +3981,8 @@ def test_weight_update(self, weight_updater):
39733981 storing_device = "cpu" ,
39743982 ** kwargs ,
39753983 )
3984+ if weight_updater == "weight_updater" :
3985+ assert collector ._legacy_weight_updater
39763986
39773987 # When using policy_factory, must pass weights explicitly
39783988 collector .update_policy_weights_ (policy_weights )
0 commit comments