diff --git a/alf/environments/gym_pets/__init__.py b/alf/environments/gym_pets/__init__.py
new file mode 100644
index 000000000..f87e745d0
--- /dev/null
+++ b/alf/environments/gym_pets/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from gym.envs.registration import register
+
+register(
+ id='MBRLCartpole-v0',
+ entry_point='alf.environments.gym_pets.envs:CartpoleEnv')
+
+register(
+ id='MBRLReacher3D-v0',
+ entry_point='alf.environments.gym_pets.envs:Reacher3DEnv')
+
+register(
+ id='MBRLPusher-v0', entry_point='alf.environments.gym_pets.envs:PusherEnv')
+
+register(
+ id='MBRLHalfCheetah-v0',
+ entry_point='alf.environments.gym_pets.envs:HalfCheetahEnv')
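+
+# Illustrative usage only (not part of the registration itself): once this
+# package is imported, the IDs registered above can be created through the
+# standard Gym API, e.g.
+#
+#   import gym
+#   import alf.environments.gym_pets  # noqa: F401, runs the register() calls
+#
+#   env = gym.make('MBRLCartpole-v0')
+#   obs = env.reset()
+#   obs, reward, done, info = env.step(env.action_space.sample())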
diff --git a/alf/environments/gym_pets/envs/__init__.py b/alf/environments/gym_pets/envs/__init__.py
new file mode 100644
index 000000000..c659fd52c
--- /dev/null
+++ b/alf/environments/gym_pets/envs/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .cartpole import CartpoleEnv
+from .half_cheetah import HalfCheetahEnv
+from .pusher import PusherEnv
+from .reacher import Reacher3DEnv
diff --git a/alf/environments/gym_pets/envs/assets/cartpole.xml b/alf/environments/gym_pets/envs/assets/cartpole.xml
new file mode 100644
index 000000000..284a58c9a
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/cartpole.xml
@@ -0,0 +1,35 @@
+<!-- MuJoCo XML model for the PETS cartpole environment (model definition omitted). -->
diff --git a/alf/environments/gym_pets/envs/assets/half_cheetah.xml b/alf/environments/gym_pets/envs/assets/half_cheetah.xml
new file mode 100644
index 000000000..40a1cb62c
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/half_cheetah.xml
@@ -0,0 +1,95 @@
+<!-- MuJoCo XML model for the PETS half-cheetah environment (model definition omitted). -->
diff --git a/alf/environments/gym_pets/envs/assets/pusher.xml b/alf/environments/gym_pets/envs/assets/pusher.xml
new file mode 100644
index 000000000..9e81b01a6
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/pusher.xml
@@ -0,0 +1,101 @@
+<!-- MuJoCo XML model for the PETS pusher environment (model definition omitted). -->
diff --git a/alf/environments/gym_pets/envs/assets/reacher3d.xml b/alf/environments/gym_pets/envs/assets/reacher3d.xml
new file mode 100644
index 000000000..a51c71b93
--- /dev/null
+++ b/alf/environments/gym_pets/envs/assets/reacher3d.xml
@@ -0,0 +1,154 @@
+<!-- MuJoCo XML model for the PETS 3D reacher environment (model definition omitted). -->
diff --git a/alf/environments/gym_pets/envs/cartpole.py b/alf/environments/gym_pets/envs/cartpole.py
new file mode 100644
index 000000000..191ea7b61
--- /dev/null
+++ b/alf/environments/gym_pets/envs/cartpole.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class CartpoleEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+ PENDULUM_LENGTH = 0.6
+
+ def __init__(self):
+ utils.EzPickle.__init__(self)
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ mujoco_env.MujocoEnv.__init__(self,
+ '%s/assets/cartpole.xml' % dir_path, 2)
+
+ def step(self, a):
+ self.do_simulation(a, self.frame_skip)
+ ob = self._get_obs()
+
+ cost_lscale = CartpoleEnv.PENDULUM_LENGTH
+        reward = np.exp(
+            -np.sum(
+                np.square(
+                    self._get_ee_pos(ob) -
+                    np.array([0.0, CartpoleEnv.PENDULUM_LENGTH]))) /
+            (cost_lscale**2))
+ reward -= 0.01 * np.sum(np.square(a))
+
+ done = False
+ return ob, reward, done, {}
+
+ def reset_model(self):
+ qpos = self.init_qpos + np.random.normal(0, 0.1,
+ np.shape(self.init_qpos))
+ qvel = self.init_qvel + np.random.normal(0, 0.1,
+ np.shape(self.init_qvel))
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def _get_obs(self):
+ return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
+
+ @staticmethod
+ def _get_ee_pos(x):
+ x0, theta = x[0], x[1]
+ return np.array([
+ x0 - CartpoleEnv.PENDULUM_LENGTH * np.sin(theta),
+ -CartpoleEnv.PENDULUM_LENGTH * np.cos(theta)
+ ])
+
+ def viewer_setup(self):
+ v = self.viewer
+ v.cam.trackbodyid = 0
+ v.cam.distance = v.model.stat.extent
diff --git a/alf/environments/gym_pets/envs/half_cheetah.py b/alf/environments/gym_pets/envs/half_cheetah.py
new file mode 100644
index 000000000..7a3a58b06
--- /dev/null
+++ b/alf/environments/gym_pets/envs/half_cheetah.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+ def __init__(self):
+ self.prev_qpos = None
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ mujoco_env.MujocoEnv.__init__(
+ self, '%s/assets/half_cheetah.xml' % dir_path, 5)
+ utils.EzPickle.__init__(self)
+
+ def step(self, action):
+ self.prev_qpos = np.copy(self.data.qpos.flat)
+ self.do_simulation(action, self.frame_skip)
+ ob = self._get_obs()
+
+ reward_ctrl = -0.1 * np.square(action).sum()
+ reward_run = ob[0] - 0.0 * np.square(ob[2])
+ reward = reward_run + reward_ctrl
+
+ done = False
+ return ob, reward, done, {}
+
+ def _get_obs(self):
+ return np.concatenate([
+ (self.data.qpos.flat[:1] - self.prev_qpos[:1]) / self.dt,
+ self.data.qpos.flat[1:],
+ self.data.qvel.flat,
+ ])
+
+ def reset_model(self):
+ qpos = self.init_qpos + np.random.normal(
+ loc=0, scale=0.001, size=self.model.nq)
+ qvel = self.init_qvel + np.random.normal(
+ loc=0, scale=0.001, size=self.model.nv)
+ self.set_state(qpos, qvel)
+ self.prev_qpos = np.copy(self.data.qpos.flat)
+ return self._get_obs()
+
+ def viewer_setup(self):
+        self.viewer.cam.distance = self.model.stat.extent * 0.25
+ self.viewer.cam.elevation = -55
diff --git a/alf/environments/gym_pets/envs/pusher.py b/alf/environments/gym_pets/envs/pusher.py
new file mode 100644
index 000000000..c382bec2c
--- /dev/null
+++ b/alf/environments/gym_pets/envs/pusher.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+ def __init__(self):
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ mujoco_env.MujocoEnv.__init__(self, '%s/assets/pusher.xml' % dir_path,
+ 4)
+ utils.EzPickle.__init__(self)
+ self.reset_model()
+
+ def step(self, a):
+        obj_pos = self.get_body_com("object")
+ vec_1 = obj_pos - self.get_body_com("tips_arm")
+ vec_2 = obj_pos - self.get_body_com("goal")
+
+ reward_near = -np.sum(np.abs(vec_1))
+ reward_dist = -np.sum(np.abs(vec_2))
+ reward_ctrl = -np.square(a).sum()
+ reward = 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
+
+ self.do_simulation(a, self.frame_skip)
+ ob = self._get_obs()
+ done = False
+ return ob, reward, done, {}
+
+ def viewer_setup(self):
+ self.viewer.cam.trackbodyid = -1
+ self.viewer.cam.distance = 4.0
+
+ def reset_model(self):
+ qpos = self.init_qpos
+
+ self.goal_pos = np.asarray([0, 0])
+ self.cylinder_pos = np.array([-0.25, 0.15]) + np.random.normal(
+ 0, 0.025, [2])
+
+ qpos[-4:-2] = self.cylinder_pos
+ qpos[-2:] = self.goal_pos
+ qvel = self.init_qvel + self.np_random.uniform(
+ low=-0.005, high=0.005, size=self.model.nv)
+ qvel[-4:] = 0
+ self.set_state(qpos, qvel)
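+        # ``ac_goal_pos`` is read back by reward_function_for_pusher through
+        # common.get_gym_env_attr('ac_goal_pos')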
+ self.ac_goal_pos = self.get_body_com("goal")
+
+ return self._get_obs()
+
+ def _get_obs(self):
+ return np.concatenate([
+ self.data.qpos.flat[:7],
+ self.data.qvel.flat[:7],
+ self.get_body_com("tips_arm"),
+ self.get_body_com("object"),
+ ])
diff --git a/alf/environments/gym_pets/envs/reacher.py b/alf/environments/gym_pets/envs/reacher.py
new file mode 100644
index 000000000..764e7f313
--- /dev/null
+++ b/alf/environments/gym_pets/envs/reacher.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+
+class Reacher3DEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+ def __init__(self):
+ self.viewer = None
+ utils.EzPickle.__init__(self)
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ self.goal = np.zeros(3)
+ mujoco_env.MujocoEnv.__init__(
+ self, os.path.join(dir_path, 'assets/reacher3d.xml'), 2)
+
+ def step(self, a):
+ self.do_simulation(a, self.frame_skip)
+ ob = self._get_obs()
+ reward = -np.sum(np.square(self.get_EE_pos(ob[None]) - self.goal))
+ reward -= 0.01 * np.square(a).sum()
+ done = False
+ return ob, reward, done, dict(reward_dist=0, reward_ctrl=0)
+
+ def viewer_setup(self):
+ self.viewer.cam.trackbodyid = 1
+ self.viewer.cam.distance = 2.5
+ self.viewer.cam.elevation = -30
+ self.viewer.cam.azimuth = 270
+
+ def reset_model(self):
+ qpos, qvel = np.copy(self.init_qpos), np.copy(self.init_qvel)
+ qpos[-3:] += np.random.normal(loc=0, scale=0.1, size=[3])
+ qvel[-3:] = 0
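+        # ``goal`` is read back by reward_function_for_reacher through
+        # common.get_gym_env_attr('goal')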
+ self.goal = qpos[-3:]
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def _get_obs(self):
+ return np.concatenate([
+ self.data.qpos.flat,
+ self.data.qvel.flat[:-3],
+ ])
+
+    def get_EE_pos(self, states):
+        # Forward kinematics: recover the 3D end-effector position from the
+        # seven arm joint angles stored in the first seven state dimensions.
+        theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
+            states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], \
+            states[:, 4:5], states[:, 5:6], states[:, 6:]
+
+ rot_axis = np.concatenate([
+ np.cos(theta2) * np.cos(theta1),
+ np.cos(theta2) * np.sin(theta1), -np.sin(theta2)
+ ],
+ axis=1)
+ rot_perp_axis = np.concatenate(
+ [-np.sin(theta1),
+ np.cos(theta1),
+ np.zeros(theta1.shape)], axis=1)
+ cur_end = np.concatenate([
+ 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
+ 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) -
+ 0.188, -0.4 * np.sin(theta2)
+ ],
+ axis=1)
+
+ for length, hinge, roll in [(0.321, theta4, theta3),
+ (0.16828, theta6, theta5)]:
+ perp_all_axis = np.cross(rot_axis, rot_perp_axis)
+ x = np.cos(hinge) * rot_axis
+ y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
+ z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
+ new_rot_axis = x + y + z
+ new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
+ new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
+ rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
+ new_rot_perp_axis /= np.linalg.norm(
+ new_rot_perp_axis, axis=1, keepdims=True)
+ rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
+
+ return cur_end
diff --git a/alf/examples/mbrl_reward_functions.py b/alf/examples/mbrl_reward_functions.py
new file mode 100644
index 000000000..a2a23b65d
--- /dev/null
+++ b/alf/examples/mbrl_reward_functions.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gin
+import torch
+
+from alf.utils import common
+# Implement the reward functions for the desired environments in this module.
+
+
+@gin.configurable
+def reward_function_for_pendulum(obs, action):
+ """Function for computing reward for gym Pendulum environment. It takes
+ as input:
+ (1) observation (Tensor of shape [batch_size, observation_dim])
+ (2) action (Tensor of shape [batch_size, num_actions])
+ and returns a reward Tensor of shape [batch_size].
+ """
+
+ def _observation_cost(obs):
+ c_theta, s_theta, d_theta = obs[..., :1], obs[..., 1:2], obs[..., 2:3]
+ theta = torch.atan2(s_theta, c_theta)
+ cost = theta**2 + 0.1 * d_theta**2
+ cost = torch.sum(cost, dim=1)
+ cost = torch.where(
+ torch.isnan(cost), 1e6 * torch.ones_like(cost), cost)
+ return cost
+
+ def _action_cost(action):
+ return 0.001 * torch.sum(action**2, dim=-1)
+
+ cost = _observation_cost(obs) + _action_cost(action)
+ # negative cost as reward
+ reward = -cost
+ return reward
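+
+# A minimal usage sketch (illustrative shapes only; the algorithm supplies its
+# own batched tensors in practice):
+#
+#   obs = torch.randn(32, 3)     # [batch_size, observation_dim] for Pendulum
+#   action = torch.randn(32, 1)  # [batch_size, num_actions]
+#   reward = reward_function_for_pendulum(obs, action)  # shape: [batch_size]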
+
+
+@gin.configurable
+def reward_function_for_cartpole(obs, action):
+ """Function for computing reward for gym CartPole environment. It takes
+ as input:
+ (1) observation (Tensor of shape [batch_size, observation_dim])
+ (2) action (Tensor of shape [batch_size, num_actions])
+ and returns a reward Tensor of shape [batch_size].
+ """
+
+ def _observation_cost(obs):
+ x0, theta = obs[..., :1], obs[..., 1:2]
+ ee_pos = torch.cat(
+ (x0 - 0.6 * torch.sin(theta), -0.6 * torch.cos(theta)), dim=-1)
+        # mirrors the environment reward: exp(-||ee_pos - target||^2 / 0.6^2)
+        cost = (ee_pos - torch.as_tensor([0.0, 0.6]))**2
+        cost = -torch.exp(-torch.sum(cost, dim=-1) / (0.6**2))
+ cost = torch.where(
+ torch.isnan(cost), 1e6 * torch.ones_like(cost), cost)
+
+ return cost
+
+ def _action_cost(action):
+ cost = 0.01 * torch.sum(action**2, dim=-1)
+ return cost
+
+ cost = _observation_cost(obs) + _action_cost(action)
+ reward = -cost
+ return reward
+
+
+@gin.configurable
+def reward_function_for_halfcheetah(obs, action):
+    """Function for computing reward for the MBRLHalfCheetah environment. It takes
+ as input:
+ (1) observation (Tensor of shape [batch_size, observation_dim])
+ (2) action (Tensor of shape [batch_size, num_actions])
+ and returns a reward Tensor of shape [batch_size].
+ """
+
+ def _observation_cost(obs):
+ cost = -obs[..., 0]
+ return cost
+
+ def _action_cost(action):
+ cost = 0.1 * torch.sum(action**2, dim=-1)
+ return cost
+
+ cost = _observation_cost(obs) + _action_cost(action)
+ reward = -cost
+ return reward
+
+
+@gin.configurable
+def reward_function_for_pusher(obs, action):
+    """Function for computing reward for the MBRLPusher environment. It takes
+    as input:
+    (1) observation (Tensor of shape [batch_size, observation_dim])
+    (2) action (Tensor of shape [batch_size, num_actions])
+ and returns a reward Tensor of shape [batch_size].
+ """
+
+ def _observation_cost(obs):
+ to_w, og_w = 0.5, 1.25
+ tip_pos, obj_pos = obs[..., 14:17], obs[..., 17:20]
+ tip_obj_dist = torch.sum(torch.abs(tip_pos - obj_pos), dim=-1)
+ obj_goal_dist = torch.sum(
+ torch.abs(common.get_gym_env_attr('ac_goal_pos') - obj_pos),
+ dim=-1)
+ cost = to_w * tip_obj_dist + og_w * obj_goal_dist
+ cost = torch.where(
+ torch.isnan(cost), 1e6 * torch.ones_like(cost), cost)
+
+ return cost
+
+ def _action_cost(action):
+ cost = 0.1 * torch.sum(action**2, dim=-1)
+ return cost
+
+ cost = _observation_cost(obs) + _action_cost(action)
+ reward = -cost
+ return reward
+
+
+@gin.configurable
+def reward_function_for_reacher(obs, action):
+    """Function for computing reward for the MBRLReacher3D environment. It takes
+ as input:
+ (1) observation (Tensor of shape [batch_size, observation_dim])
+ (2) action (Tensor of shape [batch_size, num_actions])
+ and returns a reward Tensor of shape [batch_size].
+ """
+
+ def _observation_cost(obs):
+ theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
+ obs[..., :1], obs[..., 1:2], obs[..., 2:3], obs[..., 3:4], \
+ obs[..., 4:5], obs[..., 5:6], obs[..., 6:]
+ rot_axis = torch.cat(
+ (torch.cos(theta2) * torch.cos(theta1),
+ torch.cos(theta2) * torch.sin(theta1), -torch.sin(theta2)),
+ dim=-1)
+ rot_perp_axis = torch.cat(
+ (-torch.sin(theta1), torch.cos(theta1), torch.zeros_like(theta1)),
+ dim=-1)
+ cur_end = torch.cat((
+ 0.1 * torch.cos(theta1) + 0.4 * torch.cos(theta1) * torch.cos(theta2),
+ 0.1 * torch.sin(theta1) + 0.4 * torch.sin(theta1) * torch.cos(theta2) \
+ - 0.188,
+ -0.4 * torch.sin(theta2)), dim=-1)
+
+ for length, hinge, roll in [(0.321, theta4, theta3), \
+ (0.16828, theta6, theta5)]:
+            perp_all_axis = torch.cross(rot_axis, rot_perp_axis, dim=-1)
+            x = torch.cos(hinge) * rot_axis
+            y = torch.sin(hinge) * torch.sin(roll) * rot_perp_axis
+            z = -torch.sin(hinge) * torch.cos(roll) * perp_all_axis
+            new_rot_axis = x + y + z
+            new_rot_perp_axis = torch.cross(new_rot_axis, rot_axis, dim=-1)
+            # fall back to the previous perpendicular axis when the new one
+            # is numerically degenerate (norm close to zero)
+            new_rot_perp_axis = torch.where(
+                torch.norm(new_rot_perp_axis, dim=-1, keepdim=True) < 1e-30,
+                rot_perp_axis, new_rot_perp_axis)
+ new_rot_perp_axis /= torch.norm(
+ new_rot_perp_axis, dim=-1, keepdim=True)
+ rot_axis, rot_perp_axis, cur_end = \
+ new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
+
+ cost = torch.sum(
+ torch.square(cur_end - common.get_gym_env_attr('goal')), dim=-1)
+ return cost
+
+ def _action_cost(action):
+ cost = 0.01 * torch.sum(action**2, dim=-1)
+ return cost
+
+ cost = _observation_cost(obs) + _action_cost(action)
+ reward = -cost
+ return reward
diff --git a/alf/utils/common.py b/alf/utils/common.py
index 47f2cb60c..126a84855 100644
--- a/alf/utils/common.py
+++ b/alf/utils/common.py
@@ -33,6 +33,7 @@
from typing import Callable
import alf
+from alf.environments.parallel_environment import ParallelAlfEnvironment
import alf.nest as nest
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils.spec_utils import zeros_from_spec as zero_tensor_from_nested_spec
@@ -577,6 +578,26 @@ def get_vocab_size():
return 0
+@gin.configurable
+def get_gym_env_attr(attr):
+    """Get a specific attribute of the gym env wrapped inside the global
+    environment. Used for customized gym environments.
+
+    Args:
+        attr (str): the name of the attribute of the gym env.
+
+    Returns:
+        Tensor: the value of ``getattr(gym_env, attr)`` converted to a
+        float32 tensor.
+    """
+ assert _env
+ if isinstance(_env, ParallelAlfEnvironment):
+ gym_env = _env.envs[0].gym
+ else:
+ gym_env = _env._env.gym
+ assert hasattr(gym_env, attr)
+ return torch.as_tensor(getattr(gym_env, attr), dtype=torch.float32)
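+
+# For example, the reward functions in alf/examples/mbrl_reward_functions.py
+# read attributes set by the customized environments:
+#
+#   goal = get_gym_env_attr('goal')              # set by Reacher3DEnv
+#   goal_pos = get_gym_env_attr('ac_goal_pos')   # set by PusherEnv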
+
+
@gin.configurable
def active_action_target_entropy(active_action_portion=0.2, min_entropy=0.3):
"""Automatically compute target entropy given the action spec. Currently