
Commit e200eb8

Update to pytorch 0.4
1 parent eb26e29

3 files changed: 15 additions, 11 deletions

main.py

Lines changed: 7 additions & 2 deletions
@@ -97,7 +97,7 @@ def get_value_loss(flat_params):
         for param in value_net.parameters():
             value_loss += param.pow(2).sum() * args.l2_reg
         value_loss.backward()
-        return (value_loss.data.double().numpy()[0], get_flat_grad_from(value_net).data.double().numpy())
+        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())
 
     flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
     set_flat_params_to(value_net, torch.Tensor(flat_params))
@@ -108,7 +108,12 @@ def get_value_loss(flat_params):
     fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()
 
     def get_loss(volatile=False):
-        action_means, action_log_stds, action_stds = policy_net(Variable(states, volatile=volatile))
+        if volatile:
+            with torch.no_grad():
+                action_means, action_log_stds, action_stds = policy_net(Variable(states))
+        else:
+            action_means, action_log_stds, action_stds = policy_net(Variable(states))
+
         log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
         action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
         return action_loss.mean()
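Note: the two hunks above follow the PyTorch 0.4 migration. A scalar loss is now a 0-dimensional tensor, so the trailing [0] index is dropped before the value is handed to scipy.optimize.fmin_l_bfgs_b, and volatile=True is replaced by the torch.no_grad() context manager. A minimal standalone sketch of these idioms (not part of the commit; the variable names are illustrative only):

# Sketch of the PyTorch 0.4 idioms used above (standalone, not repository code).
import torch

x = torch.ones(3, requires_grad=True)
loss = (x * 2).sum()           # a scalar loss is a 0-dimensional tensor in 0.4
value = loss.item()            # extract the Python number instead of indexing with [0]

with torch.no_grad():          # replaces Variable(..., volatile=True)
    y = (x * 2).sum()          # forward pass without building the autograd graph
print(value, y.requires_grad)  # 6.0 False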

models.py

Lines changed: 4 additions & 5 deletions
@@ -1,7 +1,6 @@
 import torch
 import torch.autograd as autograd
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class Policy(nn.Module):
@@ -21,8 +20,8 @@ def __init__(self, num_inputs, num_outputs):
         self.final_value = 0
 
     def forward(self, x):
-        x = F.tanh(self.affine1(x))
-        x = F.tanh(self.affine2(x))
+        x = torch.tanh(self.affine1(x))
+        x = torch.tanh(self.affine2(x))
 
         action_mean = self.action_mean(x)
         action_log_std = self.action_log_std.expand_as(action_mean)
@@ -41,8 +40,8 @@ def __init__(self, num_inputs):
         self.value_head.bias.data.mul_(0.0)
 
     def forward(self, x):
-        x = F.tanh(self.affine1(x))
-        x = F.tanh(self.affine2(x))
+        x = torch.tanh(self.affine1(x))
+        x = torch.tanh(self.affine2(x))
 
         state_values = self.value_head(x)
         return state_values
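Note: in PyTorch 0.4 the functional form torch.nn.functional.tanh is deprecated in favor of torch.tanh, which is why the now-unused F import is removed. A minimal sketch of the replacement, with a hypothetical layer standing in for affine1 (not part of the commit):

# Sketch of the torch.tanh idiom; "layer" is a stand-in for self.affine1.
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
x = torch.randn(1, 4)
h = torch.tanh(layer(x))   # preferred in 0.4+; F.tanh(...) emits a deprecation warning
print(h.shape)             # torch.Size([1, 2])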

trpo.py

Lines changed: 4 additions & 4 deletions
@@ -32,18 +32,18 @@ def linesearch(model,
                max_backtracks=10,
                accept_ratio=.1):
     fval = f(True).data
-    print("fval before", fval[0])
+    print("fval before", fval.item())
     for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
         xnew = x + stepfrac * fullstep
         set_flat_params_to(model, xnew)
         newfval = f(True).data
         actual_improve = fval - newfval
         expected_improve = expected_improve_rate * stepfrac
         ratio = actual_improve / expected_improve
-        print("a/e/r", actual_improve[0], expected_improve[0], ratio[0])
+        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())
 
-        if ratio[0] > accept_ratio and actual_improve[0] > 0:
-            print("fval after", newfval[0])
+        if ratio.item() > accept_ratio and actual_improve.item() > 0:
+            print("fval after", newfval.item())
             return True, xnew
     return False, x
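Note: indexing a 0-dimensional tensor with [0] is rejected in PyTorch 0.4, so the line search switches to .item() for reading scalars out of tensors. A minimal standalone sketch (not part of the commit):

# Sketch: extracting a Python scalar from a 0-dim tensor in PyTorch 0.4.
import torch

fval = torch.tensor(1.5)   # 0-dim tensor, like f(True).data above
print(fval.item())         # 1.5; fval[0] raises an IndexError on a 0-dim tensor in 0.4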
