From 4b8ba1298ad0808509e6e0dc3a16a7b373add4c5 Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Fri, 27 Dec 2019 20:42:33 +0530 Subject: [PATCH 01/10] update gitignore --- .gitignore | 5 ++++- rl/test_env.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 rl/test_env.py diff --git a/.gitignore b/.gitignore index ba0430d..583c72e 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ -__pycache__/ \ No newline at end of file +__pycache__/ +checkpoints/ +logs/ +*.pluto \ No newline at end of file diff --git a/rl/test_env.py b/rl/test_env.py new file mode 100644 index 0000000..f759340 --- /dev/null +++ b/rl/test_env.py @@ -0,0 +1,27 @@ +import gym +from dopamine.discrete_domains.gym_lib import create_gym_environment + + +def test_snake_classic(dopamine=False): + if(dopamine): + env = create_gym_environment( + environment_name="gym_snake_classic:SnakeClassic", + version = 'v0' + ) + else: + env = gym.make("gym_snake_classic:SnakeClassic-v0") + env.reset() + if not dopamine: + env.render() + for _ in range(100): + action = 3 + state, reward, done, _ = env.step(action) + if not dopamine: + env.render() + if done: + break + print(f"Reward :{reward}") + + +if __name__ == "__main__": + test_snake_classic() \ No newline at end of file From a1cc07c685e48030c1a4405599562677450da4ea Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Fri, 27 Dec 2019 20:52:12 +0530 Subject: [PATCH 02/10] running --- .../gym_snake_classic/envs/snake_classic.py | 42 +++++++++---------- gym-snake/gym_snake_classic/envs/src/game.py | 1 - 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index 87c6792..a6a63e4 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -21,27 +21,23 @@ class SnakeClassicEnv(gym.Env): 3 : 'RIGHT' } - def __init__(self,rgb=True): + def __init__(self): self.temp_filename='_temp_window.jpg' width,height = (800,600) self.action_space = spaces.Discrete(4) - + self.n_steps = 0 + self.reward = 0 + self.prev_reward = -1 cfg = GameConfig(width = 800, height = 600, player = Snake, food = Food, player_size = (20,20), food_size = (20,20), - render = True, - rgb = rgb ) - if rgb: - self.observation_space = spaces.Box(low=0, high=255, shape= + + self.observation_space = spaces.Box(low=0, high=255, shape= (height, width, 3)) - else: - self.observation_space = spaces.Box(low=0, high=255, shape= - (height, width)) - self.snake_game = Game(cfg) self.snake_game.on_init() @@ -53,12 +49,6 @@ def env(self): def _observe(self): pygame.image.save(self.snake_game.window,self.temp_filename) obs = Image.open(self.temp_filename) - if not self.snake_game.config.rgb: - obs=obs.convert('LA') - # HACK to pass three channels to memory buffer in dopamine - # TODO figure out how to pass rgb directly - - return obs def step(self, action): @@ -69,25 +59,35 @@ def step(self, action): #TODO figure out a faster way obs = self._observe() - #reward - reward = self.get_reward() + #done done = self.snake_game.done + if done : + self.reward -= 100 + else: + if(self.prev_reward == self.reward): + self.reward -= 1 + else: + self.reward += 10 + self.prev_reward=self.reward + #info info = {} - return (obs,reward,done,info) + return (obs,self.reward,done,info) def take_action(self, action): act = self.ACTION_LOOKUP[action] + # print(f"action : {act}") self.snake_game.take_action(act) - + def get_reward(self): - return self.snake_game.score + return self.reward def reset(self): + self.n_steps=0 self.snake_game.reset() return self._observe() diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py index 7d493c9..b7d5393 100644 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -15,7 +15,6 @@ class GameConfig: player_size : Tuple[int] food_size : Tuple[int] render : bool = True - rgb : bool = True class Game: def __init__(self,config:GameConfig)->None: From 587ce438e0896008fe6966afa75ec6517e83f862 Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Fri, 27 Dec 2019 21:34:47 +0530 Subject: [PATCH 03/10] added show to game render function --- .../gym_snake_classic/envs/snake_classic.py | 4 +-- gym-snake/gym_snake_classic/envs/src/game.py | 5 +-- rl/agent.py | 17 +++++---- test.py | 18 ---------- test_env.py | 36 +++++++++++++++++++ 5 files changed, 51 insertions(+), 29 deletions(-) delete mode 100644 test.py create mode 100644 test_env.py diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index a6a63e4..9b95cb1 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -48,19 +48,19 @@ def env(self): def _observe(self): pygame.image.save(self.snake_game.window,self.temp_filename) - obs = Image.open(self.temp_filename) + obs = imread(self.temp_filename) return obs def step(self, action): self.take_action(action) self.snake_game.on_loop() + self.snake_game.on_render(show=False) #observation #TODO figure out a faster way obs = self._observe() - #done done = self.snake_game.done diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py index b7d5393..0a87616 100644 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -83,11 +83,12 @@ def on_loop(self): if( head.colliderect(_pos) ): self._running = False - def on_render(self): + def on_render(self,show=False): self.display.fill((255,255,255)) self.player.draw(self.display,self.config.player_size) self.food.draw(self.display,self.config.food_size) - pygame.display.flip() + if show: + pygame.display.flip() def on_cleanup(self): diff --git a/rl/agent.py b/rl/agent.py index 16ad6ac..a3274a2 100644 --- a/rl/agent.py +++ b/rl/agent.py @@ -13,6 +13,9 @@ GAMMA = 0.9 REPLAY_CAPACITY = 10000 BATCH_SIZE = 32 +sess = tf.Session() + + class SnakeDQNAgent(dqn_agent.DQNAgent): def __init__(self,*args,**kwargs): @@ -39,8 +42,6 @@ def _build_replay_buffer(self,use_staging): ) - - env = create_gym_environment( environment_name="gym_snake_classic:SnakeClassic", version = 'v0' @@ -52,8 +53,8 @@ def _build_replay_buffer(self,use_staging): batch_size=32, gamma=GAMMA, ) -print(env.action_space.n) -sess = tf.Session() + + @@ -80,11 +81,13 @@ def _env_fn(*args): checkpoint_file_prefix='ckpt', logging_file_prefix='log', log_every_n=10, - num_iterations=200, - training_steps=2500, - evaluation_steps=1250, + num_iterations=2000, + training_steps=25000, + evaluation_steps=12500, max_steps_per_episode=10000 ) + + runner.run_experiment() diff --git a/test.py b/test.py deleted file mode 100644 index 6e0c008..0000000 --- a/test.py +++ /dev/null @@ -1,18 +0,0 @@ -import gym - - -def test_snake_classic(): - env = gym.make("gym_snake_classic:SnakeClassic-v0") - env.reset() - env.render() - for _ in range(100): - action = 3 - state, reward, done, _ = env.step(action) - env.render() - if done: - break - print(f"Reward :{reward}") - - -if __name__ == "__main__": - test_snake_classic() \ No newline at end of file diff --git a/test_env.py b/test_env.py new file mode 100644 index 0000000..ee2ee6a --- /dev/null +++ b/test_env.py @@ -0,0 +1,36 @@ +import gym +import unittest +import numpy as np +from matplotlib import pyplot as plt +from PIL import Image + +class TestSnakeEnv(unittest.TestCase): + def setUp(self): + self.env = gym.make("gym_snake_classic:SnakeClassic-v0") + + + def test_0_snake_classic(self,render=True): + self.env.reset() + if render: + self.env.render() + for _ in range(100): + action = 3 + state, reward, done, _ = self.env.step(action) + if render: + self.env.render() + if done: + break + self.assertEqual(reward,-89) + + def test_1_observe(self): + img=self.env._observe() + uni = np.unique(img) + self.assertEqual(len(uni),157) + + +if __name__ == "__main__": + t = TestSnakeEnv() + t.setUp() + t.test_0_snake_classic() + t.test_1_observe() + # unittest.main() \ No newline at end of file From 84235099a46f94231be8b27dc9b74a25eb08af90 Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Fri, 27 Dec 2019 22:22:55 +0530 Subject: [PATCH 04/10] added rainbowAgent --- .../gym_snake_classic/envs/snake_classic.py | 3 +- rl/agent.py | 80 +++++++++++++---- rl/models.py | 87 ++++++++++++++----- rl/utils.py | 26 ++++++ 4 files changed, 152 insertions(+), 44 deletions(-) create mode 100644 rl/utils.py diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index 9b95cb1..7c4c07b 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -55,7 +55,7 @@ def step(self, action): self.take_action(action) self.snake_game.on_loop() - self.snake_game.on_render(show=False) + self.snake_game.on_render(show=True) #observation #TODO figure out a faster way obs = self._observe() @@ -80,7 +80,6 @@ def step(self, action): def take_action(self, action): act = self.ACTION_LOOKUP[action] - # print(f"action : {act}") self.snake_game.take_action(act) def get_reward(self): diff --git a/rl/agent.py b/rl/agent.py index a3274a2..5cc2d29 100644 --- a/rl/agent.py +++ b/rl/agent.py @@ -6,7 +6,9 @@ from dopamine.agents.dqn import dqn_agent from dopamine.discrete_domains.run_experiment import Runner from dopamine.discrete_domains.gym_lib import create_gym_environment -from models import SimpleDQNNetwork +from dopamine.agents.rainbow import rainbow_agent +from models import SimpleDQNNetwork,RainbowNetwork + STACK_SIZE = 4 @@ -15,7 +17,8 @@ BATCH_SIZE = 32 sess = tf.Session() - +#TODO +# * use prioritized buffer class SnakeDQNAgent(dqn_agent.DQNAgent): def __init__(self,*args,**kwargs): @@ -41,33 +44,74 @@ def _build_replay_buffer(self,use_staging): observation_dtype=self.observation_dtype.as_numpy_dtype ) +class SnakeRainbowAgent(rainbow_agent.RainbowAgent): + def __init__(self,*args,**kwargs): + super().__init__(*args,**kwargs) + + def _build_replay_buffer(self,use_staging): + + """Creates the replay buffer used by the agent. + Args: + use_staging: bool, if True, uses a staging area to prefetch data for + faster training. + Returns: + A WrapperReplayBuffer object. + """ + return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( + replay_capacity = REPLAY_CAPACITY, + batch_size = BATCH_SIZE, + observation_shape=self.observation_shape, + stack_size=self.stack_size, + use_staging=use_staging, + update_horizon=self.update_horizon, + gamma=self.gamma, + observation_dtype=self.observation_dtype.as_numpy_dtype + ) + + env = create_gym_environment( environment_name="gym_snake_classic:SnakeClassic", version = 'v0' ) -memory_buffer = prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( - observation_shape=env.observation_space.shape, - stack_size=STACK_SIZE, - replay_capacity=100, - batch_size=32, - gamma=GAMMA, - ) + def _agent_fn(sess,env,summary_writer): - AGENT = SnakeDQNAgent( - sess=sess, - num_actions = env.action_space.n, - observation_shape = env.observation_space.shape, - stack_size = STACK_SIZE, - network = SimpleDQNNetwork, - gamma=GAMMA, - tf_device = '/gpu:0' , - summary_writer=summary_writer + # AGENT = SnakeDQNAgent( + # sess=sess, + # num_actions = env.action_space.n, + # observation_shape = env.observation_space.shape, + # stack_size = STACK_SIZE, + # network = SimpleDQNNetwork, + # gamma=GAMMA, + # tf_device = '/gpu:0' , + # summary_writer=summary_writer + # ) + + AGENT = SnakeRainbowAgent( + sess=sess, + num_actions=env.action_space.n, + observation_shape=env.observation_space.shape, + stack_size=STACK_SIZE, + network=RainbowNetwork, + num_atoms=51, + vmax=10., + gamma=0.99, + update_horizon=1, + min_replay_history=20000, + update_period=4, + target_update_period=8000, + epsilon_fn=dqn_agent.linearly_decaying_epsilon, + epsilon_train=0.01, + epsilon_eval=0.001, + epsilon_decay_period=250000, + replay_scheme='prioritized', + tf_device='/gpu:*', + summary_writer=summary_writer, ) return AGENT diff --git a/rl/models.py b/rl/models.py index e6d20f7..8fb3bd4 100644 --- a/rl/models.py +++ b/rl/models.py @@ -1,10 +1,12 @@ +import numpy as np import tensorflow as tf + from collections import namedtuple -from tensorflow.contrib import layers as contrib_layers -from tensorflow.contrib import slim as contrib_slim +from utils import merge_last_two_dims DQNNetworkType = namedtuple('dqn_network', ['q_values']) - +RainbowNetworkType = namedtuple( + 'c51_network', ['q_values', 'logits', 'probabilities']) class SimpleDQNNetwork(tf.keras.Model): @@ -43,31 +45,65 @@ def call(self, state,): """ #TODO Make rgb proper #HACK pass - def infer_shape(x): - x = tf.convert_to_tensor(x) + - # If unknown rank, return dynamic shape - if x.shape.dims is None: - return tf.shape(x) + x = tf.cast(state, tf.float32) + x = tf.div(x, 255.) + x = merge_last_two_dims(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.flatten(x) + x = self.dense1(x) - static_shape = x.shape.as_list() - dynamic_shape = tf.shape(x) + return DQNNetworkType(self.dense2(x)) - ret = [] - for i in range(len(static_shape)): - dim = static_shape[i] - if dim is None: - dim = dynamic_shape[i] - ret.append(dim) - return ret - def merge_last_two_dims(tensor): - shape = infer_shape(tensor) - shape[-2] *= shape[-1] - shape.pop(-1) - return tf.reshape(tensor, shape) +class RainbowNetwork(tf.keras.Model): + def __init__(self, num_actions, num_atoms, support, name=None): + """Creates the layers used calculating return distributions. + Args: + num_actions: int, number of actions. + num_atoms: int, the number of buckets of the value function distribution. + support: tf.linspace, the support of the Q-value distribution. + name: str, used to crete scope for network parameters. + """ + super(RainbowNetwork, self).__init__(name=name) + activation_fn = tf.keras.activations.relu + self.num_actions = num_actions + self.num_atoms = num_atoms + self.support = support + self.kernel_initializer = tf.keras.initializers.VarianceScaling( + scale=1.0 / np.sqrt(3.0), mode='fan_in', distribution='uniform') + # Defining layers. + self.conv1 = tf.keras.layers.Conv2D( + 32, [8, 8], strides=4, padding='same', activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='Conv') + self.conv2 = tf.keras.layers.Conv2D( + 64, [4, 4], strides=2, padding='same', activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='Conv') + self.conv3 = tf.keras.layers.Conv2D( + 64, [3, 3], strides=1, padding='same', activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='Conv') + self.flatten = tf.keras.layers.Flatten() + self.dense1 = tf.keras.layers.Dense( + 512, activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='fully_connected') + self.dense2 = tf.keras.layers.Dense( + num_actions * num_atoms, kernel_initializer=self.kernel_initializer, + name='fully_connected') + def call(self, state): + """Creates the output tensor/op given the state tensor as input. + See https://www.tensorflow.org/api_docs/python/tf/keras/Model for more + information on this. Note that tf.keras.Model implements `call` which is + wrapped by `__call__` function by tf.keras.Model. + Args: + state: Tensor, input tensor. + Returns: + collections.namedtuple, output ops (graph mode) or output tensors (eager). + """ x = tf.cast(state, tf.float32) x = tf.div(x, 255.) x = merge_last_two_dims(x) @@ -76,5 +112,8 @@ def merge_last_two_dims(tensor): x = self.conv3(x) x = self.flatten(x) x = self.dense1(x) - - return DQNNetworkType(self.dense2(x)) + x = self.dense2(x) + logits = tf.reshape(x, [-1, self.num_actions, self.num_atoms]) + probabilities = tf.keras.activations.softmax(logits) + q_values = tf.reduce_sum(self.support * probabilities, axis=2) + return RainbowNetworkType(q_values, logits, probabilities) \ No newline at end of file diff --git a/rl/utils.py b/rl/utils.py new file mode 100644 index 0000000..d961ef1 --- /dev/null +++ b/rl/utils.py @@ -0,0 +1,26 @@ +import tensorflow as tf + +def infer_shape(x): + x = tf.convert_to_tensor(x) + + # If unknown rank, return dynamic shape + if x.shape.dims is None: + return tf.shape(x) + + static_shape = x.shape.as_list() + dynamic_shape = tf.shape(x) + + ret = [] + for i in range(len(static_shape)): + dim = static_shape[i] + if dim is None: + dim = dynamic_shape[i] + ret.append(dim) + + return ret + +def merge_last_two_dims(tensor): + shape = infer_shape(tensor) + shape[-2] *= shape[-1] + shape.pop(-1) + return tf.reshape(tensor, shape) \ No newline at end of file From 1b8e3c8c5c1a3b69ac6fb3e36d4ebc32ffe912ae Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Fri, 27 Dec 2019 23:36:12 +0530 Subject: [PATCH 05/10] added configs --- .../gym_snake_classic/envs/snake_classic.py | 31 +++++++++---------- rl/agent.py | 17 +++++----- rl/configs.py | 5 +++ 3 files changed, 29 insertions(+), 24 deletions(-) create mode 100644 rl/configs.py diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index 7c4c07b..353a16b 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -1,14 +1,17 @@ import os import gym +import configs import pygame -from PIL import Image import numpy as np -from matplotlib.image import imread -from gym import error, spaces, utils -from gym.utils import seeding -from gym_snake_classic.envs.src.game import Game,GameConfig + +from gym import spaces +from gym.utils import seeding +from matplotlib.image import imread from gym_snake_classic.envs.src.assets import Snake,Food +from gym_snake_classic.envs.src.game import Game,GameConfig + + class SnakeClassicEnv(gym.Env): metadata = {'render.modes':['human']} @@ -23,13 +26,13 @@ class SnakeClassicEnv(gym.Env): def __init__(self): self.temp_filename='_temp_window.jpg' - width,height = (800,600) + width,height = (400,300) self.action_space = spaces.Discrete(4) self.n_steps = 0 self.reward = 0 self.prev_reward = -1 - cfg = GameConfig(width = 800, - height = 600, + cfg = GameConfig(width = width, + height = height, player = Snake, food = Food, player_size = (20,20), @@ -55,24 +58,20 @@ def step(self, action): self.take_action(action) self.snake_game.on_loop() - self.snake_game.on_render(show=True) + self.snake_game.on_render(show=configs.SHOW) #observation #TODO figure out a faster way obs = self._observe() - #done done = self.snake_game.done if done : - self.reward -= 100 + self.reward -= 10 else: - if(self.prev_reward == self.reward): - self.reward -= 1 - else: - self.reward += 10 + if(not self.prev_reward == self.reward): + self.reward += 100 self.prev_reward=self.reward - #info info = {} return (obs,self.reward,done,info) diff --git a/rl/agent.py b/rl/agent.py index 5cc2d29..d61da3a 100644 --- a/rl/agent.py +++ b/rl/agent.py @@ -8,17 +8,18 @@ from dopamine.discrete_domains.gym_lib import create_gym_environment from dopamine.agents.rainbow import rainbow_agent from models import SimpleDQNNetwork,RainbowNetwork +import configs +STACK_SIZE = configs.STACK_SIZE +GAMMA = configs.GAMMA +REPLAY_CAPACITY = configs.REPLAY_CAPACITY +BATCH_SIZE = configs.BATCH_SIZE + -STACK_SIZE = 4 -GAMMA = 0.9 -REPLAY_CAPACITY = 10000 -BATCH_SIZE = 32 sess = tf.Session() -#TODO -# * use prioritized buffer + class SnakeDQNAgent(dqn_agent.DQNAgent): def __init__(self,*args,**kwargs): @@ -37,7 +38,7 @@ def _build_replay_buffer(self,use_staging): replay_capacity = REPLAY_CAPACITY, batch_size = BATCH_SIZE, observation_shape=self.observation_shape, - stack_size=self.stack_size, + stack_size=STACK_SIZE, use_staging=use_staging, update_horizon=self.update_horizon, gamma=self.gamma, @@ -61,7 +62,7 @@ def _build_replay_buffer(self,use_staging): replay_capacity = REPLAY_CAPACITY, batch_size = BATCH_SIZE, observation_shape=self.observation_shape, - stack_size=self.stack_size, + stack_size=STACK_SIZE, use_staging=use_staging, update_horizon=self.update_horizon, gamma=self.gamma, diff --git a/rl/configs.py b/rl/configs.py new file mode 100644 index 0000000..677970f --- /dev/null +++ b/rl/configs.py @@ -0,0 +1,5 @@ +STACK_SIZE = 4 +GAMMA = 0.99 +REPLAY_CAPACITY = 10000 +BATCH_SIZE = 128 +SHOW = False \ No newline at end of file From fe298aff99a5ecf1d8c36d9a70be149f227ef2fd Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Sat, 28 Dec 2019 09:53:37 +0530 Subject: [PATCH 06/10] log not working --- .../gym_snake_classic/envs/snake_classic.py | 10 +++++ .../gym_snake_classic/envs/src/assets.py | 5 ++- gym-snake/gym_snake_classic/envs/src/game.py | 41 ++++++++++++++++++- rl/agent.py | 2 +- rl/configs.py | 2 +- 5 files changed, 55 insertions(+), 5 deletions(-) diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index 353a16b..a3c5130 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -14,6 +14,10 @@ class SnakeClassicEnv(gym.Env): + """ + Gym Environment for classic snake game + """ + metadata = {'render.modes':['human']} reward_range = (-np.inf, np.inf) @@ -59,6 +63,7 @@ def step(self, action): self.take_action(action) self.snake_game.on_loop() self.snake_game.on_render(show=configs.SHOW) + #observation #TODO figure out a faster way obs = self._observe() @@ -68,6 +73,7 @@ def step(self, action): if done : self.reward -= 10 + self.reset() else: if(not self.prev_reward == self.reward): self.reward += 100 @@ -86,7 +92,11 @@ def get_reward(self): def reset(self): self.n_steps=0 + self.reward=0 self.snake_game.reset() + self.snake_game.on_loop() + self.snake_game.on_render(show=configs.SHOW) + return self._observe() diff --git a/gym-snake/gym_snake_classic/envs/src/assets.py b/gym-snake/gym_snake_classic/envs/src/assets.py index 111e0da..cc054f1 100644 --- a/gym-snake/gym_snake_classic/envs/src/assets.py +++ b/gym-snake/gym_snake_classic/envs/src/assets.py @@ -16,12 +16,13 @@ def __init__(self,length,window_size): self.reset() def reset(self): + self.direction=0 self.length=self.init_length self.x=[] self.y=[] for _ in range(self.length): - self.x.append(0) - self.y.append(0) + self.x.append(self.window_size[0]/2) + self.y.append(self.window_size[1]/2) def _update(self): diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py index 0a87616..3da1bfd 100644 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -26,6 +26,9 @@ def __init__(self,config:GameConfig)->None: def reset(self): self.player.reset() + self.spawn_food() + + self._running=True self.on_loop() def on_init(self): @@ -82,8 +85,9 @@ def on_loop(self): if( head.colliderect(_pos) ): self._running = False + self.reset() - def on_render(self,show=False): + def on_render(self,show=True): self.display.fill((255,255,255)) self.player.draw(self.display,self.config.player_size) self.food.draw(self.display,self.config.food_size) @@ -104,5 +108,40 @@ def take_action(self,act): if(act=='RIGHT'): self.player.moveRight() + # Not used in the env + def on_execute(self): + if (self.on_init() == False): + self._running = False + + while(1): + while(self._running): + pygame.event.pump() + keys = pygame.key.get_pressed() + + if(keys[K_RIGHT]): + self.player.moveRight() + + if(keys[K_LEFT]): + self.player.moveLeft() + + if(keys[K_UP]): + self.player.moveUp() + + if(keys[K_DOWN]): + self.player.moveDown() + + if(keys[K_ESCAPE]): + self._running=False + exit(0) + + self.on_loop() + if(self.config.render): + self.on_render() + time.sleep(50/1000.0) + + + print("Exiting") + self.on_cleanup() + diff --git a/rl/agent.py b/rl/agent.py index d61da3a..0430277 100644 --- a/rl/agent.py +++ b/rl/agent.py @@ -103,7 +103,7 @@ def _agent_fn(sess,env,summary_writer): vmax=10., gamma=0.99, update_horizon=1, - min_replay_history=20000, + min_replay_history=1000, update_period=4, target_update_period=8000, epsilon_fn=dqn_agent.linearly_decaying_epsilon, diff --git a/rl/configs.py b/rl/configs.py index 677970f..08f6d5b 100644 --- a/rl/configs.py +++ b/rl/configs.py @@ -2,4 +2,4 @@ GAMMA = 0.99 REPLAY_CAPACITY = 10000 BATCH_SIZE = 128 -SHOW = False \ No newline at end of file +SHOW = True \ No newline at end of file From bb9af19b1cbd79fe60718bc8f3ffd1f1e41878da Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Sat, 28 Dec 2019 15:49:34 +0530 Subject: [PATCH 07/10] working --- gym-snake/gym_snake_classic/envs/snake_classic.py | 9 ++++++--- gym-snake/gym_snake_classic/envs/src/game.py | 3 ++- rl/agent.py | 2 +- rl/configs.py | 3 ++- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index a3c5130..8ce8b50 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -30,7 +30,7 @@ class SnakeClassicEnv(gym.Env): def __init__(self): self.temp_filename='_temp_window.jpg' - width,height = (400,300) + width,height = (800,600) self.action_space = spaces.Discrete(4) self.n_steps = 0 self.reward = 0 @@ -73,14 +73,17 @@ def step(self, action): if done : self.reward -= 10 + reward = self.reward #reset changes reward self.reset() else: if(not self.prev_reward == self.reward): self.reward += 100 - self.prev_reward=self.reward + reward = self.reward + + self.prev_reward=reward #info info = {} - return (obs,self.reward,done,info) + return (obs,reward,done,info) def take_action(self, action): diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py index 3da1bfd..042692e 100644 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -84,8 +84,9 @@ def on_loop(self): _pos = pygame.Rect(_pos,(1,1)) if( head.colliderect(_pos) ): + # Dont call reset here,its called in env self._running = False - self.reset() + def on_render(self,show=True): self.display.fill((255,255,255)) diff --git a/rl/agent.py b/rl/agent.py index 0430277..8a268ac 100644 --- a/rl/agent.py +++ b/rl/agent.py @@ -103,7 +103,7 @@ def _agent_fn(sess,env,summary_writer): vmax=10., gamma=0.99, update_horizon=1, - min_replay_history=1000, + min_replay_history=configs.MIN_REPLAY_HISTORY, update_period=4, target_update_period=8000, epsilon_fn=dqn_agent.linearly_decaying_epsilon, diff --git a/rl/configs.py b/rl/configs.py index 08f6d5b..2911cf2 100644 --- a/rl/configs.py +++ b/rl/configs.py @@ -2,4 +2,5 @@ GAMMA = 0.99 REPLAY_CAPACITY = 10000 BATCH_SIZE = 128 -SHOW = True \ No newline at end of file +SHOW = True +MIN_REPLAY_HISTORY = 10000 \ No newline at end of file From 3a3318fcf67ac488ad2b72dc496acf3915c15385 Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Sat, 28 Dec 2019 18:26:23 +0530 Subject: [PATCH 08/10] working --- .../gym_snake_classic/envs/snake_classic.py | 28 +++++++++---------- rl/agent.py | 5 ++-- rl/configs.py | 10 +++++-- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index 8ce8b50..536e0f1 100644 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -42,44 +42,44 @@ def __init__(self): player_size = (20,20), food_size = (20,20), ) - + self.observation_space = spaces.Box(low=0, high=255, shape= (height, width, 3)) self.snake_game = Game(cfg) self.snake_game.on_init() - + @property def env(self): return self - + def _observe(self): pygame.image.save(self.snake_game.window,self.temp_filename) obs = imread(self.temp_filename) return obs def step(self, action): - + self.take_action(action) self.snake_game.on_loop() self.snake_game.on_render(show=configs.SHOW) - + #observation #TODO figure out a faster way obs = self._observe() - + #done done = self.snake_game.done - + if done : - self.reward -= 10 + self.reward -= 100 reward = self.reward #reset changes reward self.reset() else: if(not self.prev_reward == self.reward): - self.reward += 100 + self.reward += 10 reward = self.reward - + self.prev_reward=reward #info info = {} @@ -89,10 +89,10 @@ def step(self, action): def take_action(self, action): act = self.ACTION_LOOKUP[action] self.snake_game.take_action(act) - + def get_reward(self): return self.reward - + def reset(self): self.n_steps=0 self.reward=0 @@ -101,9 +101,7 @@ def reset(self): self.snake_game.on_render(show=configs.SHOW) return self._observe() - + def render(self,mode='human'): self.snake_game.on_render() - - \ No newline at end of file diff --git a/rl/agent.py b/rl/agent.py index 8a268ac..5ac7291 100644 --- a/rl/agent.py +++ b/rl/agent.py @@ -105,7 +105,7 @@ def _agent_fn(sess,env,summary_writer): update_horizon=1, min_replay_history=configs.MIN_REPLAY_HISTORY, update_period=4, - target_update_period=8000, + target_update_period=configs.TARGET_UPDATE_PERIOD, epsilon_fn=dqn_agent.linearly_decaying_epsilon, epsilon_train=0.01, epsilon_eval=0.001, @@ -113,6 +113,7 @@ def _agent_fn(sess,env,summary_writer): replay_scheme='prioritized', tf_device='/gpu:*', summary_writer=summary_writer, + summary_writing_frequency = config.SUMMARY_WRITING_FREQUENCY ) return AGENT @@ -120,7 +121,7 @@ def _env_fn(*args): return env runner = Runner( - base_dir = '_tmp_agent_dir/', + base_dir = configs.BASE_DIR, create_agent_fn = _agent_fn, create_environment_fn= _env_fn, checkpoint_file_prefix='ckpt', diff --git a/rl/configs.py b/rl/configs.py index 2911cf2..a245251 100644 --- a/rl/configs.py +++ b/rl/configs.py @@ -1,6 +1,10 @@ STACK_SIZE = 4 GAMMA = 0.99 REPLAY_CAPACITY = 10000 -BATCH_SIZE = 128 -SHOW = True -MIN_REPLAY_HISTORY = 10000 \ No newline at end of file +BATCH_SIZE = 32 +SHOW = False +MIN_REPLAY_HISTORY = 10000 #number of transitions that should be experienced + #before the agent begins training its value function +BASE_DIR = 'summaries/snake_classic/' +TARGET_UPDATE_PERIOD = 500 # update period for the target network +SUMMARY_WRITING_FREQUENCY = 500 \ No newline at end of file From 78a256fc984a907cbe547fc5dd21efcb82371b0e Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Mon, 30 Dec 2019 15:52:21 +0530 Subject: [PATCH 09/10] working --- .gitignore | 0 README.md | 0 __pycache__/assets.cpython-37.pyc | Bin __pycache__/window.cpython-36.pyc | Bin __pycache__/window.cpython-37.pyc | Bin gym-snake/gym_snake_classic.egg-info/PKG-INFO | 0 .../gym_snake_classic.egg-info/SOURCES.txt | 0 .../dependency_links.txt | 0 .../gym_snake_classic.egg-info/requires.txt | 0 .../gym_snake_classic.egg-info/top_level.txt | 0 gym-snake/gym_snake_classic/__init__.py | 0 gym-snake/gym_snake_classic/envs/__init__.py | 0 .../gym_snake_classic/envs/snake_classic.py | 25 +++-- .../gym_snake_classic/envs/src/assets.py | 0 gym-snake/gym_snake_classic/envs/src/game.py | 11 +- gym-snake/gym_snake_classic/envs/src/main.py | 4 +- gym-snake/setup.py | 0 rl/configs.py | 13 ++- rl/models.py | 3 +- rl/readme.md | 0 rl/requirements.txt | 0 rl/test_env.py | 0 rl/{agent.py => train.py} | 104 +++++++++--------- rl/utils.py | 0 snake_game/assets.py | 0 snake_game/game.py | 0 snake_game/main.py | 0 test_env.py | 0 28 files changed, 92 insertions(+), 68 deletions(-) mode change 100644 => 100755 .gitignore mode change 100644 => 100755 README.md mode change 100644 => 100755 __pycache__/assets.cpython-37.pyc mode change 100644 => 100755 __pycache__/window.cpython-36.pyc mode change 100644 => 100755 __pycache__/window.cpython-37.pyc mode change 100644 => 100755 gym-snake/gym_snake_classic.egg-info/PKG-INFO mode change 100644 => 100755 gym-snake/gym_snake_classic.egg-info/SOURCES.txt mode change 100644 => 100755 gym-snake/gym_snake_classic.egg-info/dependency_links.txt mode change 100644 => 100755 gym-snake/gym_snake_classic.egg-info/requires.txt mode change 100644 => 100755 gym-snake/gym_snake_classic.egg-info/top_level.txt mode change 100644 => 100755 gym-snake/gym_snake_classic/__init__.py mode change 100644 => 100755 gym-snake/gym_snake_classic/envs/__init__.py mode change 100644 => 100755 gym-snake/gym_snake_classic/envs/snake_classic.py mode change 100644 => 100755 gym-snake/gym_snake_classic/envs/src/assets.py mode change 100644 => 100755 gym-snake/gym_snake_classic/envs/src/game.py mode change 100644 => 100755 gym-snake/gym_snake_classic/envs/src/main.py mode change 100644 => 100755 gym-snake/setup.py mode change 100644 => 100755 rl/configs.py mode change 100644 => 100755 rl/models.py mode change 100644 => 100755 rl/readme.md mode change 100644 => 100755 rl/requirements.txt mode change 100644 => 100755 rl/test_env.py rename rl/{agent.py => train.py} (59%) mode change 100644 => 100755 mode change 100644 => 100755 rl/utils.py mode change 100644 => 100755 snake_game/assets.py mode change 100644 => 100755 snake_game/game.py mode change 100644 => 100755 snake_game/main.py mode change 100644 => 100755 test_env.py diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/__pycache__/assets.cpython-37.pyc b/__pycache__/assets.cpython-37.pyc old mode 100644 new mode 100755 diff --git a/__pycache__/window.cpython-36.pyc b/__pycache__/window.cpython-36.pyc old mode 100644 new mode 100755 diff --git a/__pycache__/window.cpython-37.pyc b/__pycache__/window.cpython-37.pyc old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/PKG-INFO b/gym-snake/gym_snake_classic.egg-info/PKG-INFO old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/SOURCES.txt b/gym-snake/gym_snake_classic.egg-info/SOURCES.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/dependency_links.txt b/gym-snake/gym_snake_classic.egg-info/dependency_links.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/requires.txt b/gym-snake/gym_snake_classic.egg-info/requires.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/top_level.txt b/gym-snake/gym_snake_classic.egg-info/top_level.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/__init__.py b/gym-snake/gym_snake_classic/__init__.py old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/envs/__init__.py b/gym-snake/gym_snake_classic/envs/__init__.py old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py old mode 100644 new mode 100755 index 536e0f1..876f731 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -7,10 +7,15 @@ from gym import spaces from gym.utils import seeding -from matplotlib.image import imread +from PIL import Image from gym_snake_classic.envs.src.assets import Snake,Food from gym_snake_classic.envs.src.game import Game,GameConfig +WIDTH = 400 +HEIGHT = 400 + +OBS_WIDTH=256 +OBS_HEIGHT=256 class SnakeClassicEnv(gym.Env): @@ -30,11 +35,10 @@ class SnakeClassicEnv(gym.Env): def __init__(self): self.temp_filename='_temp_window.jpg' - width,height = (800,600) + width,height = (WIDTH,HEIGHT) self.action_space = spaces.Discrete(4) self.n_steps = 0 self.reward = 0 - self.prev_reward = -1 cfg = GameConfig(width = width, height = height, player = Snake, @@ -44,10 +48,11 @@ def __init__(self): ) self.observation_space = spaces.Box(low=0, high=255, shape= - (height, width, 3)) + (OBS_HEIGHT, OBS_WIDTH, 3)) self.snake_game = Game(cfg) self.snake_game.on_init() + self.prev_length = self.snake_game.player.length @property def env(self): @@ -55,11 +60,13 @@ def env(self): def _observe(self): pygame.image.save(self.snake_game.window,self.temp_filename) - obs = imread(self.temp_filename) + obs = Image.open(self.temp_filename) + obs=obs.resize((OBS_HEIGHT,OBS_WIDTH),Image.BILINEAR) + obs=np.array(obs) return obs def step(self, action): - + self.take_action(action) self.snake_game.on_loop() self.snake_game.on_render(show=configs.SHOW) @@ -76,11 +83,11 @@ def step(self, action): reward = self.reward #reset changes reward self.reset() else: - if(not self.prev_reward == self.reward): - self.reward += 10 + if(not self.prev_length == self.snake_game.player.length): + self.reward += 1000 reward = self.reward - self.prev_reward=reward + self.prev_length=self.snake_game.player.length #info info = {} return (obs,reward,done,info) diff --git a/gym-snake/gym_snake_classic/envs/src/assets.py b/gym-snake/gym_snake_classic/envs/src/assets.py old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py old mode 100644 new mode 100755 index 042692e..8e95b74 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -6,6 +6,8 @@ from typing import Any,Tuple from random import randint +MAX_RUN = 5000 + @dataclass class GameConfig: height : int @@ -48,9 +50,16 @@ def spawn_food(self): step = self.food.step nx=randint(2,10)*step ny=randint(2,10)*step - while((nx,ny) in zip(self.player.x,self.player.y)): + count = 0 + while ( ((nx,ny) in zip(self.player.x,self.player.y)) ) and (count=MAX_RUN: + # Hack so that env pushes towards this + self.player.length+=1000 + print("Game Completed!!!") + self._running=False self.food.position=(nx,ny) @property diff --git a/gym-snake/gym_snake_classic/envs/src/main.py b/gym-snake/gym_snake_classic/envs/src/main.py old mode 100644 new mode 100755 index 795891d..6462c87 --- a/gym-snake/gym_snake_classic/envs/src/main.py +++ b/gym-snake/gym_snake_classic/envs/src/main.py @@ -2,8 +2,8 @@ from assets import Snake,Food -cfg = GameConfig(width = 800, - height = 600, +cfg = GameConfig(width = 400, + height = 400, player = Snake, food = Food, player_size = (20,20), diff --git a/gym-snake/setup.py b/gym-snake/setup.py old mode 100644 new mode 100755 diff --git a/rl/configs.py b/rl/configs.py old mode 100644 new mode 100755 index a245251..884b0b8 --- a/rl/configs.py +++ b/rl/configs.py @@ -1,10 +1,11 @@ -STACK_SIZE = 4 +STACK_SIZE = 2 GAMMA = 0.99 -REPLAY_CAPACITY = 10000 -BATCH_SIZE = 32 +REPLAY_CAPACITY = 7500 +BATCH_SIZE = 64 SHOW = False -MIN_REPLAY_HISTORY = 10000 #number of transitions that should be experienced +MIN_REPLAY_HISTORY = REPLAY_CAPACITY #number of transitions that should be experienced #before the agent begins training its value function BASE_DIR = 'summaries/snake_classic/' -TARGET_UPDATE_PERIOD = 500 # update period for the target network -SUMMARY_WRITING_FREQUENCY = 500 \ No newline at end of file +TARGET_UPDATE_PERIOD = 1000 # update period for the target network +SUMMARY_WRITING_FREQUENCY = 50 +EVAL_MODE=False \ No newline at end of file diff --git a/rl/models.py b/rl/models.py old mode 100644 new mode 100755 index 8fb3bd4..36430e0 --- a/rl/models.py +++ b/rl/models.py @@ -31,7 +31,7 @@ def __init__(self, num_actions, name=None): self.conv3 = tf.keras.layers.Conv2D(64, [3, 3], strides=1, padding='same', activation=activation_fn, name='Conv') self.flatten = tf.keras.layers.Flatten() - self.dense1 = tf.keras.layers.Dense(512, activation=activation_fn, + self.dense1 = tf.keras.layers.Dense(256, activation=activation_fn, name='fully_connected') self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected') @@ -55,6 +55,7 @@ def call(self, state,): x = self.conv3(x) x = self.flatten(x) x = self.dense1(x) + # x = self.dense1(x) return DQNNetworkType(self.dense2(x)) diff --git a/rl/readme.md b/rl/readme.md old mode 100644 new mode 100755 diff --git a/rl/requirements.txt b/rl/requirements.txt old mode 100644 new mode 100755 diff --git a/rl/test_env.py b/rl/test_env.py old mode 100644 new mode 100755 diff --git a/rl/agent.py b/rl/train.py old mode 100644 new mode 100755 similarity index 59% rename from rl/agent.py rename to rl/train.py index 5ac7291..cfbcaad --- a/rl/agent.py +++ b/rl/train.py @@ -1,14 +1,14 @@ import gym +import configs import numpy as np import tensorflow as tf -from dopamine.replay_memory import prioritized_replay_buffer,circular_replay_buffer +from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer from dopamine.agents.dqn import dqn_agent from dopamine.discrete_domains.run_experiment import Runner from dopamine.discrete_domains.gym_lib import create_gym_environment from dopamine.agents.rainbow import rainbow_agent -from models import SimpleDQNNetwork,RainbowNetwork -import configs +from models import SimpleDQNNetwork, RainbowNetwork STACK_SIZE = configs.STACK_SIZE @@ -20,12 +20,11 @@ sess = tf.Session() - class SnakeDQNAgent(dqn_agent.DQNAgent): - def __init__(self,*args,**kwargs): - super().__init__(*args,**kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def _build_replay_buffer(self,use_staging): + def _build_replay_buffer(self, use_staging): """Creates the replay buffer used by the agent. Args: @@ -35,21 +34,22 @@ def _build_replay_buffer(self,use_staging): A WrapperReplayBuffer object. """ return circular_replay_buffer.WrappedReplayBuffer( - replay_capacity = REPLAY_CAPACITY, - batch_size = BATCH_SIZE, + replay_capacity=REPLAY_CAPACITY, + batch_size=BATCH_SIZE, observation_shape=self.observation_shape, stack_size=STACK_SIZE, use_staging=use_staging, update_horizon=self.update_horizon, gamma=self.gamma, - observation_dtype=self.observation_dtype.as_numpy_dtype + observation_dtype=self.observation_dtype.as_numpy_dtype, ) + class SnakeRainbowAgent(rainbow_agent.RainbowAgent): - def __init__(self,*args,**kwargs): - super().__init__(*args,**kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def _build_replay_buffer(self,use_staging): + def _build_replay_buffer(self, use_staging): """Creates the replay buffer used by the agent. Args: @@ -59,38 +59,42 @@ def _build_replay_buffer(self,use_staging): A WrapperReplayBuffer object. """ return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( - replay_capacity = REPLAY_CAPACITY, - batch_size = BATCH_SIZE, + replay_capacity=REPLAY_CAPACITY, + batch_size=BATCH_SIZE, observation_shape=self.observation_shape, stack_size=STACK_SIZE, use_staging=use_staging, update_horizon=self.update_horizon, gamma=self.gamma, - observation_dtype=self.observation_dtype.as_numpy_dtype + observation_dtype=self.observation_dtype.as_numpy_dtype, ) - env = create_gym_environment( - environment_name="gym_snake_classic:SnakeClassic", - version = 'v0' - ) - - - + environment_name="gym_snake_classic:SnakeClassic", version="v0" +) - -def _agent_fn(sess,env,summary_writer): +def _agent_fn(sess, env, summary_writer): # AGENT = SnakeDQNAgent( # sess=sess, # num_actions = env.action_space.n, # observation_shape = env.observation_space.shape, # stack_size = STACK_SIZE, # network = SimpleDQNNetwork, - # gamma=GAMMA, - # tf_device = '/gpu:0' , - # summary_writer=summary_writer + # gamma=0.99, + # update_horizon=1, + # min_replay_history=configs.MIN_REPLAY_HISTORY, + # update_period=4, + # target_update_period=configs.TARGET_UPDATE_PERIOD, + # epsilon_fn=dqn_agent.linearly_decaying_epsilon, + # epsilon_train=0.01, + # epsilon_eval=0.001, + # epsilon_decay_period=250000, + # eval_mode=configs.EVAL_MODE , # True for training + # tf_device="/gpu:*", + # summary_writer=summary_writer, + # summary_writing_frequency=configs.SUMMARY_WRITING_FREQUENCY, # ) AGENT = SnakeRainbowAgent( @@ -100,7 +104,7 @@ def _agent_fn(sess,env,summary_writer): stack_size=STACK_SIZE, network=RainbowNetwork, num_atoms=51, - vmax=10., + vmax=10.0, gamma=0.99, update_horizon=1, min_replay_history=configs.MIN_REPLAY_HISTORY, @@ -110,31 +114,33 @@ def _agent_fn(sess,env,summary_writer): epsilon_train=0.01, epsilon_eval=0.001, epsilon_decay_period=250000, - replay_scheme='prioritized', - tf_device='/gpu:*', + eval_mode=False , # True for training + replay_scheme="prioritized", + tf_device="/gpu:*", summary_writer=summary_writer, - summary_writing_frequency = config.SUMMARY_WRITING_FREQUENCY + summary_writing_frequency=configs.SUMMARY_WRITING_FREQUENCY, ) return AGENT + def _env_fn(*args): return env -runner = Runner( - base_dir = configs.BASE_DIR, - create_agent_fn = _agent_fn, - create_environment_fn= _env_fn, - checkpoint_file_prefix='ckpt', - logging_file_prefix='log', - log_every_n=10, - num_iterations=2000, - training_steps=25000, - evaluation_steps=12500, - max_steps_per_episode=10000 - ) - - -runner.run_experiment() - - +runner = Runner( + base_dir=configs.BASE_DIR, + create_agent_fn=_agent_fn, + create_environment_fn=_env_fn, + checkpoint_file_prefix="ckpt", + logging_file_prefix="log", + log_every_n=10, + num_iterations=2000, + training_steps=25000, + evaluation_steps=12500, + max_steps_per_episode=10000, +) + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + runner.run_experiment() diff --git a/rl/utils.py b/rl/utils.py old mode 100644 new mode 100755 diff --git a/snake_game/assets.py b/snake_game/assets.py old mode 100644 new mode 100755 diff --git a/snake_game/game.py b/snake_game/game.py old mode 100644 new mode 100755 diff --git a/snake_game/main.py b/snake_game/main.py old mode 100644 new mode 100755 diff --git a/test_env.py b/test_env.py old mode 100644 new mode 100755 From 930c2be12518c086bbfaaec7797485b234b559fe Mon Sep 17 00:00:00 2001 From: Aghin Shah Alin Date: Wed, 1 Jan 2020 15:19:45 +0530 Subject: [PATCH 10/10] working --- README.md | 1 + agents.py | 95 ++++++++++++++++++ .../gym_snake_classic/envs/snake_classic.py | 14 ++- .../gym_snake_classic/envs/src/assets.py | 4 +- gym-snake/gym_snake_classic/envs/src/game.py | 8 +- rl/configs.py | 4 +- rl/train.py | 60 ++--------- rl/utils.py | 99 ++++++++++++++++++- snake_game/assets.py | 15 +-- snake_game/game.py | 17 ++-- 10 files changed, 235 insertions(+), 82 deletions(-) create mode 100755 agents.py diff --git a/README.md b/README.md index b1111aa..494d777 100755 --- a/README.md +++ b/README.md @@ -15,3 +15,4 @@ pip install gym-snake env = gym.make("gym_snake_classic:SnakeClassic-v0") ``` +* Size need to be change in params plotter utils \ No newline at end of file diff --git a/agents.py b/agents.py new file mode 100755 index 0000000..3260162 --- /dev/null +++ b/agents.py @@ -0,0 +1,95 @@ +import gym +import configs +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import slim as contrib_slim +from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer +from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer +from dopamine.agents.dqn import dqn_agent +from dopamine.discrete_domains.run_experiment import Runner +from dopamine.discrete_domains.gym_lib import create_gym_environment +from dopamine.discrete_domains import atari_lib +from dopamine.agents.rainbow import rainbow_agent +from models import SimpleDQNNetwork, RainbowNetwork + + +class SnakeDQNAgent(dqn_agent.DQNAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _build_replay_buffer(self, use_staging): + + """Creates the replay buffer used by the agent. + Args: + use_staging: bool, if True, uses a staging area to prefetch data for + faster training. + Returns: + A WrapperReplayBuffer object. + """ + return circular_replay_buffer.WrappedReplayBuffer( + replay_capacity=configs.REPLAY_CAPACITY, + batch_size=configs.BATCH_SIZE, + observation_shape=self.observation_shape, + stack_size=configs.STACK_SIZE, + use_staging=use_staging, + update_horizon=self.update_horizon, + gamma=self.gamma, + observation_dtype=self.observation_dtype.as_numpy_dtype, + ) + + +class SnakeRainbowAgent(rainbow_agent.RainbowAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.rewards = [] # For visualizer + + def step(self, reward, observation): + self.rewards.append(reward) + return super().step(reward, observation) + def get_rewards(self): + return [np.cumsum(self.rewards)] + + def _build_replay_buffer(self, use_staging): + + """Creates the replay buffer used by the agent. + Args: + use_staging: bool, if True, uses a staging area to prefetch data for + faster training. + Returns: + A WrapperReplayBuffer object. + """ + return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( + replay_capacity=configs.REPLAY_CAPACITY, + batch_size=configs.BATCH_SIZE, + observation_shape=self.observation_shape, + stack_size=configs.STACK_SIZE, + use_staging=use_staging, + update_horizon=self.update_horizon, + gamma=self.gamma, + observation_dtype=self.observation_dtype.as_numpy_dtype, + ) + + def reload_checkpoint(self, checkpoint_path, use_legacy_checkpoint=False): + if use_legacy_checkpoint: + variables_to_restore = atari_lib.maybe_transform_variable_names( + tf.all_variables(), legacy_checkpoint_load=True) + else: + global_vars = set([x.name for x in tf.global_variables()]) + ckpt_vars = [ + '{}:0'.format(name) + for name, _ in tf.train.list_variables(checkpoint_path) + ] + include_vars = list(global_vars.intersection(set(ckpt_vars))) + variables_to_restore = contrib_slim.get_variables_to_restore( + include=include_vars) + if variables_to_restore: + reloader = tf.train.Saver(var_list=variables_to_restore) + reloader.restore(self._sess, checkpoint_path) + tf.logging.info('Done restoring from %s', checkpoint_path) + else: + tf.logging.info('Nothing to restore!') + + def get_probabilities(self): + return self._sess.run(tf.squeeze(self._net_outputs.probabilities), + {self.state_ph: self.state}) \ No newline at end of file diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py index 876f731..6649628 100755 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -3,7 +3,7 @@ import configs import pygame import numpy as np - +import inspect from gym import spaces from gym.utils import seeding @@ -58,16 +58,17 @@ def __init__(self): def env(self): return self - def _observe(self): + def _observe(self,resize=True): pygame.image.save(self.snake_game.window,self.temp_filename) obs = Image.open(self.temp_filename) - obs=obs.resize((OBS_HEIGHT,OBS_WIDTH),Image.BILINEAR) + if resize: + obs=obs.resize((OBS_HEIGHT,OBS_WIDTH),Image.BILINEAR) obs=np.array(obs) return obs def step(self, action): - self.take_action(action) + self.reward-=1 self.snake_game.on_loop() self.snake_game.on_render(show=configs.SHOW) @@ -79,7 +80,7 @@ def step(self, action): done = self.snake_game.done if done : - self.reward -= 100 + self.reward -= 1000 reward = self.reward #reset changes reward self.reset() else: @@ -111,4 +112,7 @@ def reset(self): def render(self,mode='human'): + if mode=='rgb_array': + return self._observe(resize=False) self.snake_game.on_render() + diff --git a/gym-snake/gym_snake_classic/envs/src/assets.py b/gym-snake/gym_snake_classic/envs/src/assets.py index cc054f1..a510c81 100755 --- a/gym-snake/gym_snake_classic/envs/src/assets.py +++ b/gym-snake/gym_snake_classic/envs/src/assets.py @@ -45,8 +45,8 @@ def update(self): elif self.direction == 3: self.y[0] += self.step - self.x[0] %= self.window_size[0] - self.y[0] %= self.window_size[1] + self.x[0] %= self.window_size[0]-20 + self.y[0] %= self.window_size[1]-20 def moveRight(self): diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py index 8e95b74..d954b13 100755 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -48,13 +48,13 @@ def on_init(self): def spawn_food(self): step = self.food.step - nx=randint(2,10)*step - ny=randint(2,10)*step + nx=(randint(2,10)*step)%(self.window_size[0]-10) + ny=(randint(2,10)*step)%(self.window_size[1]-10) count = 0 while ( ((nx,ny) in zip(self.player.x,self.player.y)) ) and (count=MAX_RUN: # Hack so that env pushes towards this self.player.length+=1000 diff --git a/rl/configs.py b/rl/configs.py index 884b0b8..f45fcb8 100755 --- a/rl/configs.py +++ b/rl/configs.py @@ -1,4 +1,4 @@ -STACK_SIZE = 2 +STACK_SIZE = 4 GAMMA = 0.99 REPLAY_CAPACITY = 7500 BATCH_SIZE = 64 @@ -8,4 +8,4 @@ BASE_DIR = 'summaries/snake_classic/' TARGET_UPDATE_PERIOD = 1000 # update period for the target network SUMMARY_WRITING_FREQUENCY = 50 -EVAL_MODE=False \ No newline at end of file +EVAL_MODE=True \ No newline at end of file diff --git a/rl/train.py b/rl/train.py index cfbcaad..457c0e2 100755 --- a/rl/train.py +++ b/rl/train.py @@ -7,9 +7,12 @@ from dopamine.agents.dqn import dqn_agent from dopamine.discrete_domains.run_experiment import Runner from dopamine.discrete_domains.gym_lib import create_gym_environment +from dopamine.discrete_domains import atari_lib from dopamine.agents.rainbow import rainbow_agent from models import SimpleDQNNetwork, RainbowNetwork +from utils import SnakeRunner +from agents import SnakeRainbowAgent STACK_SIZE = configs.STACK_SIZE GAMMA = configs.GAMMA @@ -20,56 +23,6 @@ sess = tf.Session() -class SnakeDQNAgent(dqn_agent.DQNAgent): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _build_replay_buffer(self, use_staging): - - """Creates the replay buffer used by the agent. - Args: - use_staging: bool, if True, uses a staging area to prefetch data for - faster training. - Returns: - A WrapperReplayBuffer object. - """ - return circular_replay_buffer.WrappedReplayBuffer( - replay_capacity=REPLAY_CAPACITY, - batch_size=BATCH_SIZE, - observation_shape=self.observation_shape, - stack_size=STACK_SIZE, - use_staging=use_staging, - update_horizon=self.update_horizon, - gamma=self.gamma, - observation_dtype=self.observation_dtype.as_numpy_dtype, - ) - - -class SnakeRainbowAgent(rainbow_agent.RainbowAgent): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _build_replay_buffer(self, use_staging): - - """Creates the replay buffer used by the agent. - Args: - use_staging: bool, if True, uses a staging area to prefetch data for - faster training. - Returns: - A WrapperReplayBuffer object. - """ - return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( - replay_capacity=REPLAY_CAPACITY, - batch_size=BATCH_SIZE, - observation_shape=self.observation_shape, - stack_size=STACK_SIZE, - use_staging=use_staging, - update_horizon=self.update_horizon, - gamma=self.gamma, - observation_dtype=self.observation_dtype.as_numpy_dtype, - ) - - env = create_gym_environment( environment_name="gym_snake_classic:SnakeClassic", version="v0" ) @@ -114,7 +67,7 @@ def _agent_fn(sess, env, summary_writer): epsilon_train=0.01, epsilon_eval=0.001, epsilon_decay_period=250000, - eval_mode=False , # True for training + eval_mode=configs.EVAL_MODE , # True for training replay_scheme="prioritized", tf_device="/gpu:*", summary_writer=summary_writer, @@ -123,11 +76,12 @@ def _agent_fn(sess, env, summary_writer): return AGENT + def _env_fn(*args): return env -runner = Runner( +runner = SnakeRunner( base_dir=configs.BASE_DIR, create_agent_fn=_agent_fn, create_environment_fn=_env_fn, @@ -144,3 +98,5 @@ def _env_fn(*args): if __name__ == "__main__": tf.logging.set_verbosity(tf.logging.INFO) runner.run_experiment() + # runner.visualize(record_path = configs.BASE_DIR+'visualize/', + # num_global_steps=500) diff --git a/rl/utils.py b/rl/utils.py index d961ef1..0794a5d 100755 --- a/rl/utils.py +++ b/rl/utils.py @@ -1,4 +1,12 @@ import tensorflow as tf +import pygame +from dopamine.discrete_domains.run_experiment import Runner +from dopamine.discrete_domains import run_experiment +from dopamine.utils import agent_visualizer +from dopamine.utils import atari_plotter +from dopamine.utils import bar_plotter +from dopamine.utils import line_plotter +from dopamine.utils import plotter def infer_shape(x): x = tf.convert_to_tensor(x) @@ -23,4 +31,93 @@ def merge_last_two_dims(tensor): shape = infer_shape(tensor) shape[-2] *= shape[-1] shape.pop(-1) - return tf.reshape(tensor, shape) \ No newline at end of file + return tf.reshape(tensor, shape) + +class SnakeRunner(Runner): + def visualize(self, record_path, num_global_steps=500): + if not tf.gfile.Exists(record_path): + tf.gfile.MakeDirs(record_path) + self._agent.eval_mode = True + + # Set up the game playback rendering. + atari_params = { + 'environment': self._environment, + 'width': 400, + 'height': 400, + } + atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params) + # Plot the rewards received next to it. + reward_params = {'x': atari_plot.parameters['width'], + 'xlabel': 'Timestep', + 'ylabel': 'Reward', + 'title': 'Rewards', + 'get_line_data_fn': self._agent.get_rewards} + reward_plot = line_plotter.LinePlotter(parameter_dict=reward_params) + action_names = [ + 'Action {}'.format(x) for x in range(self._agent.num_actions)] + # Plot Q-values (DQN) or Q-value distributions (Rainbow). + q_params = {'x': atari_plot.parameters['width'] // 2, + 'y': atari_plot.parameters['height'], + 'legend': action_names} + if 'DQN' in self._agent.__class__.__name__: + q_params['xlabel'] = 'Timestep' + q_params['ylabel'] = 'Q-Value' + q_params['title'] = 'Q-Values' + q_params['get_line_data_fn'] = self._agent.get_q_values + q_plot = line_plotter.LinePlotter(parameter_dict=q_params) + else: + q_params['xlabel'] = 'Return' + q_params['ylabel'] = 'Return probability' + q_params['title'] = 'Return distribution' + q_params['get_bar_data_fn'] = self._agent.get_probabilities + q_plot = bar_plotter.BarPlotter(parameter_dict=q_params) + screen_width = ( + atari_plot.parameters['width'] + reward_plot.parameters['width']) + screen_height = ( + atari_plot.parameters['height'] + q_plot.parameters['height']) + # Dimensions need to be divisible by 2: + if screen_width % 2 > 0: + screen_width += 1 + if screen_height % 2 > 0: + screen_height += 1 + visualizer = agent_visualizer.AgentVisualizer( + record_path=record_path, plotters=[atari_plot, reward_plot, q_plot], + screen_width=screen_width, screen_height=screen_height) + + global_step = 0 + while global_step < num_global_steps: + initial_observation = self._environment.reset() + action = self._agent.begin_episode(initial_observation) + while True: + observation, reward, is_terminal, _ = self._environment.step(action) + global_step += 1 + visualizer.visualize() + if self._environment.game_over or global_step >= num_global_steps: + break + elif is_terminal: + self._agent.end_episode(reward) + action = self._agent.begin_episode(observation) + else: + action = self._agent.step(reward, observation) + self._end_episode(reward) + visualizer.generate_video() + +class SnakePlotter(plotter.Plotter): + def __init__(self, parameter_dict=None): + super().__init__(parameter_dict) + assert 'environment' in self.parameters + self.game_surface = pygame.Surface((self.parameters['width'], + self.parameters['height'])) + + def draw(self): + """Render the Atari 2600 frame. + + Returns: + object to be rendered by AgentVisualizer. + """ + environment = self.parameters['environment'] + obs = environment.render(mode='rgb_array').astype(np.int32) + + return pygame.transform.scale(self.game_surface, + (self.parameters['width'], + self.parameters['height'])) \ No newline at end of file diff --git a/snake_game/assets.py b/snake_game/assets.py index 67b7fcf..a26155a 100755 --- a/snake_game/assets.py +++ b/snake_game/assets.py @@ -36,8 +36,8 @@ def update(self): elif self.direction == 3: self.y[0] += self.step - self.x[0] %= self.window_size[0] - self.y[0] %= self.window_size[1] + self.x[0] %= self.window_size[0]-20 + self.y[0] %= self.window_size[1]-256 def moveRight(self): @@ -79,16 +79,17 @@ class Food: x,y=(0,0) step=44 - def __init__(self,x,y): - self.x = x*self.step - self.y = y*self.step + def __init__(self,x,y,window_size): + self.window_size=window_size + self.x = (x*self.step)%window_size[0] + self.y = (y*self.step)%window_size[1] @property def position(self): return (self.x,self.y) @position.setter def position(self,value): - self.x=value[0] - self.y=value[1] + self.x=(value[0])%(window_size[0]-5) + self.y=(value[1])%(window_size[1]-5) def draw(self,surface,food_size): pygame.draw.rect(surface,(255,153,51), diff --git a/snake_game/game.py b/snake_game/game.py index d45bf0f..f818f66 100755 --- a/snake_game/game.py +++ b/snake_game/game.py @@ -22,7 +22,7 @@ def __init__(self,config:GameConfig)->None: self.window_size = (self.config.width,self.config.height) self.player = self.config.player(5,self.window_size) self._running = True - self.food = self.config.food(5,5) # setting init position + self.food = self.config.food(5,5,window_size) # setting init position def on_init(self): pygame.init() @@ -33,20 +33,20 @@ def on_init(self): pygame.display.set_caption('Snake') self._running=True self.snake_body = pygame.Surface( self.config.player_size ) - + self.food_img = pygame.Surface( self.config.food_size ) def on_event(self,event): if event.type == QUIT: self._running=False - + def spawn_food(self): step = self.food.step - nx=randint(2,10)*step - ny=randint(2,10)*step + nx=(randint(2,10)*step)%self.window_size[0] + ny=(randint(2,10)*step)%self.window_size[1] while((nx,ny) in zip(self.player.x,self.player.y)): - nx=randint(2,10)*step - ny=randint(2,10)*step + nx=(randint(2,10)*step)%self.window_size[0] + ny=(randint(2,10)*step)%self.window_size[1] self.food.position=(nx,ny) @@ -56,8 +56,7 @@ def on_loop(self): # check collison with food head = self.snake_body.get_rect(topleft=self.player.position) food_pos = self.food_img.get_rect(topleft=self.food.position) - - + if( head.colliderect(food_pos) ): self.player.length = self.player.length+1 self.player.eat(food_pos)