LSTMWrapper reinitializing all weights. #459

@enriquezaf

Hi,

In `LSTMWrapper` (models.py), the initialization loop runs over every named parameter of the wrapper:

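```python
# Runs over every parameter reachable from the wrapper, including the wrapped policy
for name, param in self.named_parameters():
    if 'layer_norm' in name:
        continue  # leave LayerNorm parameters untouched
    if "bias" in name:
        nn.init.constant_(param, 0)  # every bias is reset to 0
    elif "weight" in name and param.ndim >= 2:
        nn.init.orthogonal_(param, 1.0)  # every matrix is re-initialized with gain 1.0
```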
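`nn.Module.named_parameters()` defaults to `recurse=True`, so this loop visits the wrapped policy's layers as well as the wrapper's own. A minimal reproduction, assuming a toy policy and wrapper (`ToyPolicy` and `ToyLSTMWrapper` are hypothetical stand-ins, not PufferLib classes; sizes are illustrative):

```python
import torch
import torch.nn as nn

class ToyPolicy(nn.Module):
    """Hypothetical env policy with a small-gain actor head (cf. g2048)."""
    def __init__(self, hidden_size=64, num_atns=4):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_atns),
        )
        nn.init.orthogonal_(self.decoder[2].weight, 0.01)
        nn.init.constant_(self.decoder[2].bias, 0.0)

class ToyLSTMWrapper(nn.Module):
    """Hypothetical wrapper that copies the init loop from models.py."""
    def __init__(self, policy, input_size=64, hidden_size=64):
        super().__init__()
        self.policy = policy  # registered as a submodule
        self.lstm = nn.LSTM(input_size, hidden_size)
        # named_parameters() recurses into self.policy, so the head is visited too
        for name, param in self.named_parameters():
            if 'layer_norm' in name:
                continue
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name and param.ndim >= 2:
                nn.init.orthogonal_(param, 1.0)

torch.manual_seed(0)
policy = ToyPolicy()
head = policy.decoder[2].weight
print(head.norm(dim=1))  # each row has norm 0.01 after the policy's own init
wrapper = ToyLSTMWrapper(policy)
names = [n for n, _ in wrapper.named_parameters()]
print('policy.decoder.2.weight' in names)
print(head.norm(dim=1))  # each row now has norm 1.0: the head was overwritten
```

The parameter name `policy.decoder.2.weight` matches the layer name in the log output of this issue.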

Example: the g2048 env (torch.py) initializes the actor head with std=0.01:

```python
self.decoder = torch.nn.Sequential(
    pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
    nn.GELU(),
    pufferlib.pytorch.layer_init(nn.Linear(hidden_size, num_atns), std=0.01),
)
```

The first 5 values of that layer's weight at each stage:

```
Layer: policy.decoder.2.weight
Initializing at layer_init, std=0.01:
    Before:  tensor([-0.0307,  0.0421, -0.0238, -0.0417,  0.0277], grad_fn=<SliceBackward0>)
    After:   tensor([-9.2381e-05, -4.8542e-04,  8.5702e-05, -7.5245e-05, -1.5632e-04], grad_fn=<SliceBackward0>)

Initializing at LSTMWrapper, std=1.0:
    Before:  tensor([-9.2381e-05, -4.8542e-04,  8.5702e-05, -7.5245e-05, -1.5632e-04], grad_fn=<SliceBackward0>)
    After:   tensor([-0.0174,  0.0143,  0.0315,  0.0548, -0.0340], grad_fn=<SliceBackward0>)

Values at torch.py:    tensor([-9.2381e-05, -4.8542e-04,  8.5702e-05, -7.5245e-05, -1.5632e-04], grad_fn=<SliceBackward0>)
Values at models.py:   tensor([-0.0174,  0.0143,  0.0315,  0.0548, -0.0340], grad_fn=<SliceBackward0>)
Values at pufferl.py:  tensor([-0.0174,  0.0143,  0.0315,  0.0548, -0.0340], device='cuda:0', grad_fn=<SliceBackward0>)
```

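The head was reinitialized by the `LSTMWrapper` loop with gain 1.0, overwriting the std=0.01 set in torch.py, and the overwritten values are the ones that reach pufferl.py.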
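One possible fix, sketched under assumptions: restrict the loop to the recurrent layer's own parameters so the wrapped policy keeps its init. `ToyPolicy`, `ScopedLSTMWrapper`, and the `self.lstm` attribute are hypothetical names; the real wrapper's attributes may differ.

```python
import torch
import torch.nn as nn

class ToyPolicy(nn.Module):
    """Hypothetical env policy with a small-gain actor head (cf. g2048)."""
    def __init__(self, hidden_size=64, num_atns=4):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_atns),
        )
        nn.init.orthogonal_(self.decoder[2].weight, 0.01)
        nn.init.constant_(self.decoder[2].bias, 0.0)

class ScopedLSTMWrapper(nn.Module):
    """Hypothetical wrapper that only initializes its own LSTM."""
    def __init__(self, policy, input_size=64, hidden_size=64):
        super().__init__()
        self.policy = policy
        self.lstm = nn.LSTM(input_size, hidden_size)
        # Iterate over the LSTM's parameters only, not self.named_parameters(),
        # so the policy's layers keep the init chosen by the env author.
        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name and param.ndim >= 2:
                nn.init.orthogonal_(param, 1.0)

torch.manual_seed(0)
policy = ToyPolicy()
wrapper = ScopedLSTMWrapper(policy)
print(policy.decoder[2].weight.norm(dim=1))  # rows still have norm 0.01
```

An alternative with the same effect is to run the original loop after the LSTM is created but before `self.policy` is assigned, since `named_parameters()` only yields submodules registered at the time of the call. Note that the current loop also zeroes every bias it visits, so a non-default `bias_const` passed to `layer_init` is overwritten as well.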

For reference, the default `std` in `layer_init` (pytorch.py) is sqrt(2); it is passed to `orthogonal_` as the gain:

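```python
import numpy as np
import torch

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """CleanRL's default layer initialization"""
    # Orthogonal weights scaled by `std` (used as the gain), constant bias
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
```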
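To see what the gain controls: for a head weight of shape `(num_atns, hidden_size)` with fewer rows than columns, `orthogonal_` produces orthonormal rows scaled by the gain, so `W @ W.T == gain**2 * I`, each row has L2 norm equal to the gain, and the per-entry spread is roughly `gain / sqrt(hidden_size)`. A minimal check, with illustrative sizes that are assumptions rather than g2048's actual values:

```python
import torch
import torch.nn as nn

def init_head(gain, rows=4, cols=128):
    """Orthogonally initialize a (rows, cols) matrix with the given gain."""
    w = torch.empty(rows, cols)
    nn.init.orthogonal_(w, gain)
    return w

torch.manual_seed(0)
for gain in (0.01, 1.0, 2 ** 0.5):
    w = init_head(gain)
    # Rows are orthonormal up to the gain: W @ W.T == gain**2 * I
    is_scaled_orthonormal = torch.allclose(w @ w.T, gain ** 2 * torch.eye(4), atol=1e-4)
    print(gain, is_scaled_orthonormal, w.std().item(), gain / 128 ** 0.5)
```

So overwriting a std=0.01 head with gain 1.0 makes its initial logits about 100 times larger, and overwriting the sqrt(2) default shrinks the weights by a factor of about 1.41.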

Environments affected, i.e. those whose policies initialize weights with the sqrt(2) default or a user-defined std that `LSTMWrapper` later overwrites with gain 1.0:

nmmo3
towerclimb
impulsewars
atari
cogames
craftax
gpudrive
gvgai
kinetix
mani_skill
metta
mujoco
nethack
trade_sim
tribal_village
asteroids
battle
blastar
boids
breakout
cartpole
checkers
connect4
convert
convert_circle
drive
drone
enduro
freeway
g2048
go
grid
memory
moba
oldgrid
onlyfish
pacman
pong
pysquared
rware
sanity
shared_pool
snake
squared
target
terraform
tetris
tmaze
trash_pickup
tripletriad
whisker_racer
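To check whether a given policy is affected, one option is to snapshot its parameters before wrapping and list the ones that changed afterwards. This is a minimal sketch, assuming the policy is built before it is passed to the wrapper and stored under the `policy.` prefix; `snapshot`, `changed_parameters`, and `ToyWrapper` are hypothetical helpers, not part of PufferLib.

```python
import torch
import torch.nn as nn

def snapshot(module: nn.Module) -> dict:
    # state_dict() shares storage with the live parameters, so clone it
    return {k: v.detach().clone() for k, v in module.state_dict().items()}

def changed_parameters(before: dict, wrapper: nn.Module, prefix: str = 'policy.') -> list:
    """Return wrapper parameter names (under `prefix`) whose values differ from `before`."""
    changed = []
    for name, param in wrapper.named_parameters():
        key = name[len(prefix):] if name.startswith(prefix) else None
        if key in before and not torch.equal(before[key], param.detach()):
            changed.append(name)
    return changed

class ToyWrapper(nn.Module):
    """Hypothetical wrapper reproducing the models.py init loop."""
    def __init__(self, policy):
        super().__init__()
        self.policy = policy
        for name, param in self.named_parameters():
            if 'layer_norm' in name:
                continue
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name and param.ndim >= 2:
                nn.init.orthogonal_(param, 1.0)

torch.manual_seed(0)
policy = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 2))
nn.init.orthogonal_(policy[2].weight, 0.01)
before = snapshot(policy)
wrapper = ToyWrapper(policy)
changed = changed_parameters(before, wrapper)
print(changed)  # every weight and bias of the toy policy is reported
```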
