Skip to content

Commit dfc911a

Browse files
committed
amend
1 parent 2eef1e3 commit dfc911a

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

test/test_collector.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3953,14 +3953,22 @@ def test_weight_update(self, weight_updater):
39533953
policy_weights = TensorDict.from_module(policy)
39543954
kwargs = {}
39553955
if weight_updater == "scheme_shared":
3956-
kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
3956+
scheme = SharedMemWeightSyncScheme()
3957+
kwargs = {"weight_sync_schemes": {"policy": scheme}}
39573958
elif weight_updater == "scheme_pipe":
3958-
kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}}
3959+
scheme = MultiProcessWeightSyncScheme()
3960+
kwargs = {"weight_sync_schemes": {"policy": scheme}}
39593961
elif weight_updater == "weight_updater":
3962+
scheme = None
39603963
kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)}
39613964
else:
39623965
raise NotImplementedError
39633966

3967+
if scheme is not None:
3968+
scheme.init_on_sender(
3969+
model=policy_factory(), devices=[device] * 2, model_id="policy"
3970+
)
3971+
39643972
collector = MultiSyncDataCollector(
39653973
create_env_fn=[env_maker, env_maker],
39663974
policy_factory=policy_factory,
@@ -3973,6 +3981,8 @@ def test_weight_update(self, weight_updater):
39733981
storing_device="cpu",
39743982
**kwargs,
39753983
)
3984+
if weight_updater == "weight_updater":
3985+
assert collector._legacy_weight_updater
39763986

39773987
# When using policy_factory, must pass weights explicitly
39783988
collector.update_policy_weights_(policy_weights)

0 commit comments

Comments
 (0)