From 0d6ec557b1eaae0053e354f7ffe99fd0b69606c6 Mon Sep 17 00:00:00 2001
From: runjerry
Date: Mon, 11 Jan 2021 12:57:08 -0800
Subject: [PATCH] add PETS environments and reward functions

---
 alf/environments/gym_pets/__init__.py         |  30 +++
 alf/environments/gym_pets/envs/__init__.py    |  18 ++
 .../gym_pets/envs/assets/cartpole.xml         |  35 ++++
 .../gym_pets/envs/assets/half_cheetah.xml     |  95 +++++++++
 .../gym_pets/envs/assets/pusher.xml           | 101 ++++++++++
 .../gym_pets/envs/assets/reacher3d.xml        | 154 ++++++++++++++
 alf/environments/gym_pets/envs/cartpole.py    |  72 +++++++
 .../gym_pets/envs/half_cheetah.py             |  64 ++++++
 alf/environments/gym_pets/envs/pusher.py      |  76 +++++++
 alf/environments/gym_pets/envs/reacher.py     |  97 +++++++++
 alf/examples/mbrl_reward_functions.py         | 188 ++++++++++++++++++
 alf/utils/common.py                           |  21 ++
 12 files changed, 951 insertions(+)
 create mode 100644 alf/environments/gym_pets/__init__.py
 create mode 100644 alf/environments/gym_pets/envs/__init__.py
 create mode 100644 alf/environments/gym_pets/envs/assets/cartpole.xml
 create mode 100644 alf/environments/gym_pets/envs/assets/half_cheetah.xml
 create mode 100644 alf/environments/gym_pets/envs/assets/pusher.xml
 create mode 100644 alf/environments/gym_pets/envs/assets/reacher3d.xml
 create mode 100644 alf/environments/gym_pets/envs/cartpole.py
 create mode 100644 alf/environments/gym_pets/envs/half_cheetah.py
 create mode 100644 alf/environments/gym_pets/envs/pusher.py
 create mode 100644 alf/environments/gym_pets/envs/reacher.py
 create mode 100644 alf/examples/mbrl_reward_functions.py

diff --git a/alf/environments/gym_pets/__init__.py b/alf/environments/gym_pets/__init__.py
new file mode 100644
index 000000000..f87e745d0
--- /dev/null
+++ b/alf/environments/gym_pets/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from gym.envs.registration import register
+
+register(
+    id='MBRLCartpole-v0',
+    entry_point='alf.environments.gym_pets.envs:CartpoleEnv')
+
+register(
+    id='MBRLReacher3D-v0',
+    entry_point='alf.environments.gym_pets.envs:Reacher3DEnv')
+
+register(
+    id='MBRLPusher-v0', entry_point='alf.environments.gym_pets.envs:PusherEnv')
+
+register(
+    id='MBRLHalfCheetah-v0',
+    entry_point='alf.environments.gym_pets.envs:HalfCheetahEnv')
diff --git a/alf/environments/gym_pets/envs/__init__.py b/alf/environments/gym_pets/envs/__init__.py
new file mode 100644
index 000000000..c659fd52c
--- /dev/null
+++ b/alf/environments/gym_pets/envs/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .cartpole import CartpoleEnv
+from .half_cheetah import HalfCheetahEnv
+from .pusher import PusherEnv
+from .reacher import Reacher3DEnv
diff --git a/alf/environments/gym_pets/envs/assets/cartpole.xml b/alf/environments/gym_pets/envs/assets/cartpole.xml
new file mode 100644
index 000000000..284a58c9a
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/cartpole.xml
@@ -0,0 +1,35 @@
diff --git a/alf/environments/gym_pets/envs/assets/half_cheetah.xml b/alf/environments/gym_pets/envs/assets/half_cheetah.xml
new file mode 100644
index 000000000..40a1cb62c
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/half_cheetah.xml
@@ -0,0 +1,95 @@
diff --git a/alf/environments/gym_pets/envs/assets/pusher.xml b/alf/environments/gym_pets/envs/assets/pusher.xml
new file mode 100644
index 000000000..9e81b01a6
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/pusher.xml
@@ -0,0 +1,101 @@
diff --git a/alf/environments/gym_pets/envs/assets/reacher3d.xml b/alf/environments/gym_pets/envs/assets/reacher3d.xml
new file mode 100644
index 000000000..a51c71b93
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/reacher3d.xml
@@ -0,0 +1,154 @@
diff --git a/alf/environments/gym_pets/envs/cartpole.py b/alf/environments/gym_pets/envs/cartpole.py
new file mode 100644
index 000000000..191ea7b61
--- /dev/null
+++ b/alf/environments/gym_pets/envs/cartpole.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class CartpoleEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+    PENDULUM_LENGTH = 0.6
+
+    def __init__(self):
+        utils.EzPickle.__init__(self)
+        dir_path = os.path.dirname(os.path.realpath(__file__))
+        mujoco_env.MujocoEnv.__init__(
+            self, '%s/assets/cartpole.xml' % dir_path, 2)
+
+    def step(self, a):
+        self.do_simulation(a, self.frame_skip)
+        ob = self._get_obs()
+
+        cost_lscale = CartpoleEnv.PENDULUM_LENGTH
+        reward = np.exp(
+            -np.sum(
+                np.square(self._get_ee_pos(ob) -
+                          np.array([0.0, CartpoleEnv.PENDULUM_LENGTH]))) /
+            (cost_lscale**2))
+        reward -= 0.01 * np.sum(np.square(a))
+
+        done = False
+        return ob, reward, done, {}
+
+    def reset_model(self):
+        qpos = self.init_qpos + np.random.normal(
+            0, 0.1, np.shape(self.init_qpos))
+        qvel = self.init_qvel + np.random.normal(
+            0, 0.1, np.shape(self.init_qvel))
+        self.set_state(qpos, qvel)
+        return self._get_obs()
+
+    def _get_obs(self):
+        return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
+
+    @staticmethod
+    def _get_ee_pos(x):
+        x0, theta = x[0], x[1]
+        return np.array([
+            x0 - CartpoleEnv.PENDULUM_LENGTH * np.sin(theta),
+            -CartpoleEnv.PENDULUM_LENGTH * np.cos(theta)
+        ])
+
+    def viewer_setup(self):
+        v = self.viewer
+        v.cam.trackbodyid = 0
+        v.cam.distance = v.model.stat.extent
diff --git a/alf/environments/gym_pets/envs/half_cheetah.py b/alf/environments/gym_pets/envs/half_cheetah.py
new file mode 100644
index 000000000..7a3a58b06
--- /dev/null
+++ b/alf/environments/gym_pets/envs/half_cheetah.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+    def __init__(self):
+        self.prev_qpos = None
+        dir_path = os.path.dirname(os.path.realpath(__file__))
+        mujoco_env.MujocoEnv.__init__(
+            self, '%s/assets/half_cheetah.xml' % dir_path, 5)
+        utils.EzPickle.__init__(self)
+
+    def step(self, action):
+        self.prev_qpos = np.copy(self.data.qpos.flat)
+        self.do_simulation(action, self.frame_skip)
+        ob = self._get_obs()
+
+        reward_ctrl = -0.1 * np.square(action).sum()
+        reward_run = ob[0] - 0.0 * np.square(ob[2])
+        reward = reward_run + reward_ctrl
+
+        done = False
+        return ob, reward, done, {}
+
+    def _get_obs(self):
+        return np.concatenate([
+            (self.data.qpos.flat[:1] - self.prev_qpos[:1]) / self.dt,
+            self.data.qpos.flat[1:],
+            self.data.qvel.flat,
+        ])
+
+    def reset_model(self):
+        qpos = self.init_qpos + np.random.normal(
+            loc=0, scale=0.001, size=self.model.nq)
+        qvel = self.init_qvel + np.random.normal(
+            loc=0, scale=0.001, size=self.model.nv)
+        self.set_state(qpos, qvel)
+        self.prev_qpos = np.copy(self.data.qpos.flat)
+        return self._get_obs()
+
+    def viewer_setup(self):
+        self.viewer.cam.distance = self.model.stat.extent * 0.25
+        self.viewer.cam.elevation = -55
diff --git a/alf/environments/gym_pets/envs/pusher.py b/alf/environments/gym_pets/envs/pusher.py
new file mode 100644
index 000000000..c382bec2c
--- /dev/null
+++ b/alf/environments/gym_pets/envs/pusher.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+    def __init__(self):
+        dir_path = os.path.dirname(os.path.realpath(__file__))
+        mujoco_env.MujocoEnv.__init__(
+            self, '%s/assets/pusher.xml' % dir_path, 4)
+        utils.EzPickle.__init__(self)
+        self.reset_model()
+
+    def step(self, a):
+        obj_pos = self.get_body_com("object")
+        vec_1 = obj_pos - self.get_body_com("tips_arm")
+        vec_2 = obj_pos - self.get_body_com("goal")
+
+        reward_near = -np.sum(np.abs(vec_1))
+        reward_dist = -np.sum(np.abs(vec_2))
+        reward_ctrl = -np.square(a).sum()
+        reward = 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
+
+        self.do_simulation(a, self.frame_skip)
+        ob = self._get_obs()
+        done = False
+        return ob, reward, done, {}
+
+    def viewer_setup(self):
+        self.viewer.cam.trackbodyid = -1
+        self.viewer.cam.distance = 4.0
+
+    def reset_model(self):
+        qpos = self.init_qpos
+
+        self.goal_pos = np.asarray([0, 0])
+        self.cylinder_pos = np.array([-0.25, 0.15]) + np.random.normal(
+            0, 0.025, [2])
+
+        qpos[-4:-2] = self.cylinder_pos
+        qpos[-2:] = self.goal_pos
+        qvel = self.init_qvel + self.np_random.uniform(
+            low=-0.005, high=0.005, size=self.model.nv)
+        qvel[-4:] = 0
+        self.set_state(qpos, qvel)
+        self.ac_goal_pos = self.get_body_com("goal")
+
+        return self._get_obs()
+
+    def _get_obs(self):
+        return np.concatenate([
+            self.data.qpos.flat[:7],
+            self.data.qvel.flat[:7],
+            self.get_body_com("tips_arm"),
+            self.get_body_com("object"),
+        ])
diff --git a/alf/environments/gym_pets/envs/reacher.py b/alf/environments/gym_pets/envs/reacher.py
new file mode 100644
index 000000000..764e7f313
--- /dev/null
+++ b/alf/environments/gym_pets/envs/reacher.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class Reacher3DEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+    def __init__(self):
+        self.viewer = None
+        utils.EzPickle.__init__(self)
+        dir_path = os.path.dirname(os.path.realpath(__file__))
+        self.goal = np.zeros(3)
+        mujoco_env.MujocoEnv.__init__(
+            self, os.path.join(dir_path, 'assets/reacher3d.xml'), 2)
+
+    def step(self, a):
+        self.do_simulation(a, self.frame_skip)
+        ob = self._get_obs()
+        reward = -np.sum(np.square(self.get_EE_pos(ob[None]) - self.goal))
+        reward -= 0.01 * np.square(a).sum()
+        done = False
+        return ob, reward, done, dict(reward_dist=0, reward_ctrl=0)
+
+    def viewer_setup(self):
+        self.viewer.cam.trackbodyid = 1
+        self.viewer.cam.distance = 2.5
+        self.viewer.cam.elevation = -30
+        self.viewer.cam.azimuth = 270
+
+    def reset_model(self):
+        qpos, qvel = np.copy(self.init_qpos), np.copy(self.init_qvel)
+        qpos[-3:] += np.random.normal(loc=0, scale=0.1, size=[3])
+        qvel[-3:] = 0
+        self.goal = qpos[-3:]
+        self.set_state(qpos, qvel)
+        return self._get_obs()
+
+    def _get_obs(self):
+        return np.concatenate([
+            self.data.qpos.flat,
+            self.data.qvel.flat[:-3],
+        ])
+
+    def get_EE_pos(self, states):
+        theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
+            states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], \
+            states[:, 4:5], states[:, 5:6], states[:, 6:]
+
+        rot_axis = np.concatenate([
+            np.cos(theta2) * np.cos(theta1),
+            np.cos(theta2) * np.sin(theta1), -np.sin(theta2)
+        ], axis=1)
+        rot_perp_axis = np.concatenate(
+            [-np.sin(theta1), np.cos(theta1),
+             np.zeros(theta1.shape)], axis=1)
+        cur_end = np.concatenate([
+            0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
+            0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) -
+            0.188, -0.4 * np.sin(theta2)
+        ], axis=1)
+
+        for length, hinge, roll in [(0.321, theta4, theta3),
+                                    (0.16828, theta6, theta5)]:
+            perp_all_axis = np.cross(rot_axis, rot_perp_axis)
+            x = np.cos(hinge) * rot_axis
+            y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
+            z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
+            new_rot_axis = x + y + z
+            new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
+            new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
+                rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
+            new_rot_perp_axis /= np.linalg.norm(
+                new_rot_perp_axis, axis=1, keepdims=True)
+            rot_axis, rot_perp_axis, cur_end = \
+                new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
+
+        return cur_end
diff --git a/alf/examples/mbrl_reward_functions.py b/alf/examples/mbrl_reward_functions.py
new file mode 100644
index 000000000..a2a23b65d
--- /dev/null
+++ b/alf/examples/mbrl_reward_functions.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gin
+import torch
+
+from alf.utils import common
+# Implement the respective reward functions for desired environments here.
+
+
+@gin.configurable
+def reward_function_for_pendulum(obs, action):
+    """Function for computing reward for gym Pendulum environment. It takes
+    as input:
+    (1) observation (Tensor of shape [batch_size, observation_dim])
+    (2) action (Tensor of shape [batch_size, num_actions])
+    and returns a reward Tensor of shape [batch_size].
+    """
+
+    def _observation_cost(obs):
+        c_theta, s_theta, d_theta = obs[..., :1], obs[..., 1:2], obs[..., 2:3]
+        theta = torch.atan2(s_theta, c_theta)
+        cost = theta**2 + 0.1 * d_theta**2
+        cost = torch.sum(cost, dim=1)
+        cost = torch.where(
+            torch.isnan(cost), 1e6 * torch.ones_like(cost), cost)
+        return cost
+
+    def _action_cost(action):
+        return 0.001 * torch.sum(action**2, dim=-1)
+
+    cost = _observation_cost(obs) + _action_cost(action)
+    # negative cost as reward
+    reward = -cost
+    return reward
+
+
+@gin.configurable
+def reward_function_for_cartpole(obs, action):
+    """Function for computing reward for gym CartPole environment. It takes
+    as input:
+    (1) observation (Tensor of shape [batch_size, observation_dim])
+    (2) action (Tensor of shape [batch_size, num_actions])
+    and returns a reward Tensor of shape [batch_size].
+    """
+
+    def _observation_cost(obs):
+        x0, theta = obs[..., :1], obs[..., 1:2]
+        ee_pos = torch.cat(
+            (x0 - 0.6 * torch.sin(theta), -0.6 * torch.cos(theta)), dim=-1)
+        cost = (ee_pos - torch.as_tensor([.0, .6]))**2
+        # Negated exponential of the squared end-effector distance to the
+        # upright position, mirroring the CartpoleEnv reward exp(-d^2 / l^2).
+        cost = -torch.exp(-torch.sum(cost, dim=-1) / (0.6**2))
+        cost = torch.where(
+            torch.isnan(cost), 1e6 * torch.ones_like(cost), cost)
+
+        return cost
+
+    def _action_cost(action):
+        cost = 0.01 * torch.sum(action**2, dim=-1)
+        return cost
+
+    cost = _observation_cost(obs) + _action_cost(action)
+    reward = -cost
+    return reward
+
+
+@gin.configurable
+def reward_function_for_halfcheetah(obs, action):
+    """Function for computing reward for gym HalfCheetah environment. It takes
+    as input:
+    (1) observation (Tensor of shape [batch_size, observation_dim])
+    (2) action (Tensor of shape [batch_size, num_actions])
+    and returns a reward Tensor of shape [batch_size].
+    """
+
+    def _observation_cost(obs):
+        cost = -obs[..., 0]
+        return cost
+
+    def _action_cost(action):
+        cost = 0.1 * torch.sum(action**2, dim=-1)
+        return cost
+
+    cost = _observation_cost(obs) + _action_cost(action)
+    reward = -cost
+    return reward
+
+
+@gin.configurable
+def reward_function_for_pusher(obs, action):
+    """Function for computing reward for gym Pusher environment. It takes
+    as input:
+    (1) observation (Tensor of shape [batch_size, observation_dim])
+    (2) action (Tensor of shape [batch_size, num_actions])
+    and returns a reward Tensor of shape [batch_size].
+    """
+
+    def _observation_cost(obs):
+        to_w, og_w = 0.5, 1.25
+        tip_pos, obj_pos = obs[..., 14:17], obs[..., 17:20]
+        tip_obj_dist = torch.sum(torch.abs(tip_pos - obj_pos), dim=-1)
+        obj_goal_dist = torch.sum(
+            torch.abs(common.get_gym_env_attr('ac_goal_pos') - obj_pos),
+            dim=-1)
+        cost = to_w * tip_obj_dist + og_w * obj_goal_dist
+        cost = torch.where(
+            torch.isnan(cost), 1e6 * torch.ones_like(cost), cost)
+
+        return cost
+
+    def _action_cost(action):
+        cost = 0.1 * torch.sum(action**2, dim=-1)
+        return cost
+
+    cost = _observation_cost(obs) + _action_cost(action)
+    reward = -cost
+    return reward
+
+
+@gin.configurable
+def reward_function_for_reacher(obs, action):
+    """Function for computing reward for gym Reacher3D environment. It takes
+    as input:
+    (1) observation (Tensor of shape [batch_size, observation_dim])
+    (2) action (Tensor of shape [batch_size, num_actions])
+    and returns a reward Tensor of shape [batch_size].
+    """
+
+    def _observation_cost(obs):
+        theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
+            obs[..., :1], obs[..., 1:2], obs[..., 2:3], obs[..., 3:4], \
+            obs[..., 4:5], obs[..., 5:6], obs[..., 6:]
+        rot_axis = torch.cat(
+            (torch.cos(theta2) * torch.cos(theta1),
+             torch.cos(theta2) * torch.sin(theta1), -torch.sin(theta2)),
+            dim=-1)
+        rot_perp_axis = torch.cat(
+            (-torch.sin(theta1), torch.cos(theta1), torch.zeros_like(theta1)),
+            dim=-1)
+        cur_end = torch.cat((
+            0.1 * torch.cos(theta1) + 0.4 * torch.cos(theta1) * torch.cos(theta2),
+            0.1 * torch.sin(theta1) + 0.4 * torch.sin(theta1) * torch.cos(theta2) \
+            - 0.188,
+            -0.4 * torch.sin(theta2)), dim=-1)
+
+        for length, hinge, roll in [(0.321, theta4, theta3),
+                                    (0.16828, theta6, theta5)]:
+            perp_all_axis = torch.cross(rot_axis, rot_perp_axis)
+            x = torch.cos(hinge) * rot_axis
+            y = torch.sin(hinge) * torch.sin(roll) * rot_perp_axis
+            z = -torch.sin(hinge) * torch.cos(roll) * perp_all_axis
+            new_rot_axis = x + y + z
+            new_rot_perp_axis = torch.cross(new_rot_axis, rot_axis)
+            tmp_rot_perp_axis = torch.where(
+                torch.lt(torch.norm(new_rot_perp_axis, dim=-1), 1e-30),
+                rot_perp_axis.permute(-1,
+                                      *list(range(rot_perp_axis.ndim - 1))),
+                new_rot_perp_axis.permute(
+                    -1, *list(range(new_rot_perp_axis.ndim - 1))))
+            new_rot_perp_axis = tmp_rot_perp_axis.permute(
+                *list(range(1, tmp_rot_perp_axis.ndim)), 0)
+            new_rot_perp_axis /= torch.norm(
+                new_rot_perp_axis, dim=-1, keepdim=True)
+            rot_axis, rot_perp_axis, cur_end = \
+                new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
+
+        cost = torch.sum(
+            torch.square(cur_end - common.get_gym_env_attr('goal')), dim=-1)
+        return cost
+
+    def _action_cost(action):
+        cost = 0.01 * torch.sum(action**2, dim=-1)
+        return cost
+
+    cost = _observation_cost(obs) + _action_cost(action)
+    reward = -cost
+    return reward
diff --git a/alf/utils/common.py b/alf/utils/common.py
index 47f2cb60c..126a84855 100644
--- a/alf/utils/common.py
+++ b/alf/utils/common.py
@@ -33,6 +33,7 @@
 from typing import Callable
 
 import alf
+from alf.environments.parallel_environment import ParallelAlfEnvironment
 import alf.nest as nest
 from alf.tensor_specs import TensorSpec, BoundedTensorSpec
 from alf.utils.spec_utils import zeros_from_spec as zero_tensor_from_nested_spec
@@ -577,6 +578,26 @@ def get_vocab_size():
     return 0
 
 
+@gin.configurable
+def get_gym_env_attr(attr):
+    """Get specific attr of gym env wrapped in the global environment. Used for
+    customized gym environments.
+
+    Args:
+        attr (str): the attribute of the gym env.
+
+    Returns:
+        gym_env.attr
+    """
+    assert _env
+    if isinstance(_env, ParallelAlfEnvironment):
+        gym_env = _env.envs[0].gym
+    else:
+        gym_env = _env._env.gym
+    assert hasattr(gym_env, attr)
+    return torch.as_tensor(getattr(gym_env, attr), dtype=torch.float32)
+
+
 @gin.configurable
 def active_action_target_entropy(active_action_portion=0.2, min_entropy=0.3):
     """Automatically compute target entropy given the action spec. Currently
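
A minimal usage sketch of the pieces added in this patch, assuming MuJoCo and
mujoco-py are installed. Importing alf.environments.gym_pets runs the register()
calls above, and the reward functions operate on batched tensors, e.g. on
model-predicted observations and candidate actions during planning:

    import gym
    import torch

    import alf.environments.gym_pets  # registers the MBRL* environment ids
    from alf.examples.mbrl_reward_functions import reward_function_for_cartpole

    env = gym.make('MBRLCartpole-v0')
    obs = env.reset()  # qpos (2) followed by qvel (2)

    # Batch of one observation/action pair; shapes follow the docstrings:
    # obs [batch_size, observation_dim], action [batch_size, num_actions].
    obs_batch = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
    action_batch = torch.zeros(1, env.action_space.shape[0])
    reward = reward_function_for_cartpole(obs_batch, action_batch)  # shape [1]

reward_function_for_pusher and reward_function_for_reacher additionally read the
goal position from the live environment through common.get_gym_env_attr(), so
they assume the corresponding MBRL environment has already been created as the
global ALF environment before they are called.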