@@ -97,7 +97,7 @@ def get_value_loss(flat_params):
9797 for param in value_net .parameters ():
9898 value_loss += param .pow (2 ).sum () * args .l2_reg
9999 value_loss .backward ()
100- return (value_loss .data .double ().numpy ()[ 0 ] , get_flat_grad_from (value_net ).data .double ().numpy ())
100+ return (value_loss .data .double ().numpy (), get_flat_grad_from (value_net ).data .double ().numpy ())
101101
102102 flat_params , _ , opt_info = scipy .optimize .fmin_l_bfgs_b (get_value_loss , get_flat_params_from (value_net ).double ().numpy (), maxiter = 25 )
103103 set_flat_params_to (value_net , torch .Tensor (flat_params ))
@@ -108,7 +108,12 @@ def get_value_loss(flat_params):
108108 fixed_log_prob = normal_log_density (Variable (actions ), action_means , action_log_stds , action_stds ).data .clone ()
109109
def get_loss(volatile=False):
    """Surrogate policy loss for TRPO: -E[advantage * ratio].

    Computes the importance-sampling ratio exp(log_prob - fixed_log_prob)
    between the current policy and the policy snapshot taken when
    ``fixed_log_prob`` was recorded, weights it by the (negated) advantages,
    and returns the mean.

    Reads the enclosing scope's ``policy_net``, ``states``, ``actions``,
    ``advantages``, and ``fixed_log_prob``.

    Args:
        volatile: If True, run the forward pass without building an autograd
            graph (evaluation only); if False, gradients flow as usual.

    Returns:
        A scalar tensor holding the mean surrogate loss.
    """
    # torch.set_grad_enabled(not volatile) replaces the original duplicated
    # if/else branches: identical forward call, with autograd tracking
    # toggled by the flag. Behavior is unchanged in both modes.
    with torch.set_grad_enabled(not volatile):
        action_means, action_log_stds, action_stds = policy_net(Variable(states))

    log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
    # Negative sign: optimizers minimize, but we want to maximize the
    # advantage-weighted probability ratio.
    action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
    return action_loss.mean()
0 commit comments