Hi,
In the LSTMWrapper (models.py), the following loop reinitializes the wrapper's parameters:
```python
for name, param in self.named_parameters():
    if 'layer_norm' in name:
        continue
    if "bias" in name:
        nn.init.constant_(param, 0)
    elif "weight" in name and param.ndim >= 2:
        nn.init.orthogonal_(param, 1.0)
```
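`named_parameters()` recurses into submodules by default, so if the policy is already registered on the wrapper, its layers match the `"weight"`/`"bias"` filters too. The sketch below lists which names the loop would touch. `ToyPolicy`, `ToyWrapper` and `names_reinitialized` are hypothetical stand-ins, not PufferLib classes; the sizes are arbitrary.

```python
import torch.nn as nn


class ToyPolicy(nn.Module):
    """Hypothetical stand-in for an env policy with a decoder like g2048's."""

    def __init__(self, hidden_size=128, num_atns=4):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_atns),
        )


class ToyWrapper(nn.Module):
    """Hypothetical wrapper that registers the policy plus its own LSTM cell."""

    def __init__(self, policy, hidden_size=128):
        super().__init__()
        self.policy = policy
        self.cell = nn.LSTMCell(hidden_size, hidden_size)


def names_reinitialized(module):
    """Return the parameter names the LSTMWrapper-style loop would overwrite."""
    selected = []
    for name, param in module.named_parameters():  # recurses into self.policy
        if 'layer_norm' in name:
            continue
        if "bias" in name:
            selected.append(name)
        elif "weight" in name and param.ndim >= 2:
            selected.append(name)
    return selected


print(names_reinitialized(ToyWrapper(ToyPolicy())))
```

The printed list includes `policy.decoder.2.weight`, the same name as the actor head in the log below, alongside the LSTM cell's own weights.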
Example: the g2048 env (torch.py) should initialize the actor head with a std of 0.01:
```python
self.decoder = torch.nn.Sequential(
    pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
    nn.GELU(),
    pufferlib.pytorch.layer_init(nn.Linear(hidden_size, num_atns), std=0.01),
)
```
Output of the first 5 weight values of that layer:
```
Layer: policy.decoder.2.weight
Initializing at layer_init, std=0.01:
Before: tensor([-0.0307, 0.0421, -0.0238, -0.0417, 0.0277], grad_fn=<SliceBackward0>)
After: tensor([-9.2381e-05, -4.8542e-04, 8.5702e-05, -7.5245e-05, -1.5632e-04], grad_fn=<SliceBackward0>)
Initializing at LSTMWrapper, std=1.0:
Before: tensor([-9.2381e-05, -4.8542e-04, 8.5702e-05, -7.5245e-05, -1.5632e-04], grad_fn=<SliceBackward0>)
After: tensor([-0.0174, 0.0143, 0.0315, 0.0548, -0.0340], grad_fn=<SliceBackward0>)
Values at torch.py: tensor([-9.2381e-05, -4.8542e-04, 8.5702e-05, -7.5245e-05, -1.5632e-04], grad_fn=<SliceBackward0>)
Values at models.py: tensor([-0.0174, 0.0143, 0.0315, 0.0548, -0.0340], grad_fn=<SliceBackward0>)
Values at pufferl.py: tensor([-0.0174, 0.0143, 0.0315, 0.0548, -0.0340], device='cuda:0', grad_fn=<SliceBackward0>)
```
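The head is reinitialized with the std (gain) of 1.0 hardcoded in LSTMWrapper, overwriting the std of 0.01 set by layer_init, and the overwritten values are the ones that reach pufferl.py.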
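A minimal, self-contained reproduction of the overwrite, assuming a bare `nn.Module` as a hypothetical wrapper and arbitrary sizes (`hidden_size=128`, `num_atns=4`). `layer_init` copies the logic quoted below from pytorch.py, with `math.sqrt` in place of `np.sqrt` to avoid a NumPy import.

```python
import math

import torch
import torch.nn as nn


def layer_init(layer, std=math.sqrt(2), bias_const=0.0):
    # Same logic as layer_init in pytorch.py (quoted in this issue)
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


torch.manual_seed(0)
hidden_size, num_atns = 128, 4  # hypothetical sizes
head = layer_init(nn.Linear(hidden_size, num_atns), std=0.01)
rms_before = head.weight.pow(2).mean().sqrt().item()

# Hypothetical wrapper: registers the head, then runs the same loop as LSTMWrapper
wrapper = nn.Module()
wrapper.head = head
for name, param in wrapper.named_parameters():
    if 'layer_norm' in name:
        continue
    if "bias" in name:
        nn.init.constant_(param, 0)
    elif "weight" in name and param.ndim >= 2:
        nn.init.orthogonal_(param, 1.0)

rms_after = head.weight.pow(2).mean().sqrt().item()
print(f"RMS before wrapper: {rms_before:.6f}, after wrapper: {rms_after:.6f}")
```

The RMS grows by a factor of 100 (1.0 / 0.01), the same order of jump seen between the "Before" and "After" values logged at LSTMWrapper.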
The default std in layer_init (pytorch.py) is sqrt(2):
```python
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """CleanRL's default layer initialization"""
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
```
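`layer_init` passes `std` to `orthogonal_` as its gain. For a `(rows, cols)` weight, the smaller dimension's rows (or columns) come out orthonormal, so the sum of squared entries is `gain**2 * min(rows, cols)` and the RMS of the entries is `gain / sqrt(max(rows, cols))`. The sketch below checks this for a few arbitrary shapes; `orthogonal_rms` is a hypothetical helper, not part of PufferLib.

```python
import math

import torch


def orthogonal_rms(rows, cols, gain):
    """RMS of a (rows, cols) tensor after torch.nn.init.orthogonal_ with this gain."""
    w = torch.empty(rows, cols)
    torch.nn.init.orthogonal_(w, gain)
    return w.pow(2).mean().sqrt().item()


# Shapes and gains are illustrative: (num_atns, hidden_size)-style heads and a square layer
for rows, cols, gain in [(4, 128, 0.01), (4, 128, 1.0), (128, 128, math.sqrt(2))]:
    predicted = gain / math.sqrt(max(rows, cols))
    print(f"{rows}x{cols} gain={gain:.4f}: measured={orthogonal_rms(rows, cols, gain):.6f} "
          f"predicted={predicted:.6f}")
```

So with the same layer shape, a gain of sqrt(2) gives weights about 1.41x larger than a gain of 1.0, and a gain of 0.01 gives weights 100x smaller; the LSTMWrapper loop replaces whichever of these the env chose with the gain-1.0 scale.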
Environments affected (weights initialized with the default std of sqrt(2) or a user-defined std, then reinitialized with a std of 1.0):
- nmmo3
- towerclimb
- impulsewars
- atari
- cogames
- craftax
- gpudrive
- gvgai
- kinetix
- mani_skill
- metta
- mujoco
- nethack
- trade_sim
- tribal_village
- asteroids
- battle
- blastar
- boids
- breakout
- cartpole
- checkers
- connect4
- convert
- convert_circle
- drive
- drone
- enduro
- freeway
- g2048
- go
- grid
- memory
- moba
- oldgrid
- onlyfish
- pacman
- pong
- pysquared
- rware
- sanity
- shared_pool
- snake
- squared
- target
- terraform
- tetris
- tmaze
- trash_pickup
- tripletriad
- whisker_racer
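One possible direction, sketched under assumptions: scope the reinitialization to the parameters the wrapper creates itself (here a hypothetical `self.cell = nn.LSTMCell(...)`), so each env's `layer_init` choices survive. `ScopedLSTMWrapper` and its attribute names are illustrative only and are not the actual LSTMWrapper code. The `layer_norm` check is dropped because `nn.LSTMCell` has no layer-norm parameters.

```python
import math

import torch
import torch.nn as nn


def layer_init(layer, std=math.sqrt(2), bias_const=0.0):
    # Same logic as layer_init in pytorch.py (math.sqrt instead of np.sqrt)
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class ScopedLSTMWrapper(nn.Module):
    """Hypothetical wrapper that only initializes the LSTM cell it owns."""

    def __init__(self, policy, input_size=128, hidden_size=128):
        super().__init__()
        self.policy = policy
        self.cell = nn.LSTMCell(input_size, hidden_size)
        # Iterate over the cell only, not self.named_parameters(),
        # so the policy's own initialization is left untouched
        for name, param in self.cell.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name and param.ndim >= 2:
                nn.init.orthogonal_(param, 1.0)


head = layer_init(nn.Linear(128, 4), std=0.01)
before = head.weight.detach().clone()
wrapper = ScopedLSTMWrapper(nn.Sequential(head))
print(torch.equal(before, wrapper.policy[0].weight))  # prints True: head keeps std=0.01
```

Running the existing loop before `self.policy` is assigned would likely have a similar effect, but iterating over the cell's parameters makes the scope explicit.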