Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
8b5fc4d
huge refactor
vmoens Nov 11, 2025
52aa42a
fix test
vmoens Nov 11, 2025
a9a986a
refactor
vmoens Nov 12, 2025
d2a5891
use id(weight)
vmoens Nov 12, 2025
e232631
clone the state_dict
vmoens Nov 12, 2025
3fd0d8e
address device mismatch
vmoens Nov 13, 2025
31098bb
fix policy with device
vmoens Nov 13, 2025
b9eb2e8
no TD state_dict
vmoens Nov 13, 2025
0af42e5
fix legacy code
vmoens Nov 13, 2025
066ae5b
fix state dict device
vmoens Nov 13, 2025
eeedff3
fix unwanted model_id
vmoens Nov 13, 2025
9852bc9
final?
vmoens Nov 13, 2025
9ad7832
final!
vmoens Nov 14, 2025
2eef1e3
fixes
vmoens Nov 15, 2025
dfc911a
amend
vmoens Nov 15, 2025
ba2bea7
amend
vmoens Nov 18, 2025
9bdff0e
intermediate-fix
vmoens Nov 25, 2025
f605a86
intermediate
vmoens Dec 1, 2025
3f5d46b
partial
vmoens Dec 1, 2025
0ca9928
fixes
vmoens Dec 1, 2025
1879704
fixes
vmoens Dec 6, 2025
22bbc33
fixes
vmoens Dec 6, 2025
d018763
amend
vmoens Dec 6, 2025
5562201
amend
vmoens Dec 6, 2025
512a5ed
edit
vmoens Dec 7, 2025
c8c24a2
edits
vmoens Dec 7, 2025
e7d5579
amend
vmoens Dec 7, 2025
dd66ea5
amend
vmoens Dec 7, 2025
6aabf2a
edits
vmoens Dec 7, 2025
52538db
edits
vmoens Dec 7, 2025
0686b28
edits
vmoens Dec 7, 2025
2768abb
edits
vmoens Dec 7, 2025
a496e3e
edits
vmoens Dec 7, 2025
ef51447
edits
vmoens Dec 7, 2025
c32d263
edits
vmoens Dec 8, 2025
238f50a
edits
vmoens Dec 8, 2025
c8be973
edits
vmoens Dec 8, 2025
f12514e
lint
vmoens Dec 8, 2025
786a6e0
edits
vmoens Dec 8, 2025
d597f8f
edits
vmoens Dec 8, 2025
1d64492
edits
vmoens Dec 8, 2025
15d2b17
edits
vmoens Dec 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export SDL_VIDEODRIVER=dummy
# legacy from bash scripts: remove?
conda env config vars set \
MAX_IDLE_COUNT=1000 \
-MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG TOKENIZERS_PARALLELISM=true
+MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=INFO TOKENIZERS_PARALLELISM=true

pip3 install pip --upgrade
pip install virtualenv
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

if __name__ == "__main__":
avail_devices = ("cpu",)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/storage/benchmark_sample_latency_over_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self, capacity: int):
rank = args.rank
storage_type = args.storage

-torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}")
+torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/test_collectors_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy


def single_collector_setup():
Expand Down
632 changes: 470 additions & 162 deletions docs/source/reference/collectors_weightsync.rst

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion docs/source/reference/envs_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ Helpers
:toctree: generated/
:template: rl_template_fun.rst

-RandomPolicy
check_env_specs
exploration_type
get_available_libraries
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/modules_actors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ TensorDictModules and SafeModules
SafeModule
SafeSequential
TanhModule
+RandomPolicy

Probabilistic actors
--------------------
Expand Down
2 changes: 1 addition & 1 deletion examples/collectors/multi_weight_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
-from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme
+from torchrl.weight_update import MultiProcessWeightSyncScheme


def make_module():
Expand Down
2 changes: 1 addition & 1 deletion examples/collectors/weight_sync_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def example_multi_collector_shared_memory():
env.close()

# Shared memory is more efficient for frequent updates
-scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+scheme = SharedMemWeightSyncScheme(strategy="tensordict")

print("Creating multi-collector with shared memory...")
collector = MultiSyncDataCollector(
Expand Down
207 changes: 0 additions & 207 deletions examples/collectors/weight_sync_standalone.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def main():
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.data import Bounded
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/delayed_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def main():
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
from torchrl.data import Bounded
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
device_str = "device" if num_workers == 1 else "devices"
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torchrl.collectors.distributed import RPCDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torchrl.collectors.distributed import DistributedSyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

parser = ArgumentParser()
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(self, capacity: int):
if __name__ == "__main__":
args = parser.parse_args()
rank = args.rank
-torchrl_logger.info(f"Rank: {rank}")
+torchrl_logger.debug(f"RANK: {rank}")

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
Expand Down
2 changes: 0 additions & 2 deletions sota-implementations/expert-iteration/ei_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import time

from typing import Any, Literal

import torch
Expand Down Expand Up @@ -612,7 +611,6 @@ def get_wandb_run_id(wandb_logger):
"""
try:
# Wait a bit for wandb to initialize
-import time

max_attempts = 10
for attempt in range(max_attempts):
Expand Down
Loading
Loading