diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 index ba0430d..583c72e --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ -__pycache__/ \ No newline at end of file +__pycache__/ +checkpoints/ +logs/ +*.pluto \ No newline at end of file diff --git a/README.md b/README.md old mode 100644 new mode 100755 index b1111aa..494d777 --- a/README.md +++ b/README.md @@ -15,3 +15,4 @@ pip install gym-snake env = gym.make("gym_snake_classic:SnakeClassic-v0") ``` +* Size need to be change in params plotter utils \ No newline at end of file diff --git a/__pycache__/assets.cpython-37.pyc b/__pycache__/assets.cpython-37.pyc old mode 100644 new mode 100755 diff --git a/__pycache__/window.cpython-36.pyc b/__pycache__/window.cpython-36.pyc old mode 100644 new mode 100755 diff --git a/__pycache__/window.cpython-37.pyc b/__pycache__/window.cpython-37.pyc old mode 100644 new mode 100755 diff --git a/agents.py b/agents.py new file mode 100755 index 0000000..3260162 --- /dev/null +++ b/agents.py @@ -0,0 +1,95 @@ +import gym +import configs +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import slim as contrib_slim +from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer +from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer +from dopamine.agents.dqn import dqn_agent +from dopamine.discrete_domains.run_experiment import Runner +from dopamine.discrete_domains.gym_lib import create_gym_environment +from dopamine.discrete_domains import atari_lib +from dopamine.agents.rainbow import rainbow_agent +from models import SimpleDQNNetwork, RainbowNetwork + + +class SnakeDQNAgent(dqn_agent.DQNAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _build_replay_buffer(self, use_staging): + + """Creates the replay buffer used by the agent. + Args: + use_staging: bool, if True, uses a staging area to prefetch data for + faster training. + Returns: + A WrapperReplayBuffer object. + """ + return circular_replay_buffer.WrappedReplayBuffer( + replay_capacity=configs.REPLAY_CAPACITY, + batch_size=configs.BATCH_SIZE, + observation_shape=self.observation_shape, + stack_size=configs.STACK_SIZE, + use_staging=use_staging, + update_horizon=self.update_horizon, + gamma=self.gamma, + observation_dtype=self.observation_dtype.as_numpy_dtype, + ) + + +class SnakeRainbowAgent(rainbow_agent.RainbowAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.rewards = [] # For visualizer + + def step(self, reward, observation): + self.rewards.append(reward) + return super().step(reward, observation) + def get_rewards(self): + return [np.cumsum(self.rewards)] + + def _build_replay_buffer(self, use_staging): + + """Creates the replay buffer used by the agent. + Args: + use_staging: bool, if True, uses a staging area to prefetch data for + faster training. + Returns: + A WrapperReplayBuffer object. + """ + return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( + replay_capacity=configs.REPLAY_CAPACITY, + batch_size=configs.BATCH_SIZE, + observation_shape=self.observation_shape, + stack_size=configs.STACK_SIZE, + use_staging=use_staging, + update_horizon=self.update_horizon, + gamma=self.gamma, + observation_dtype=self.observation_dtype.as_numpy_dtype, + ) + + def reload_checkpoint(self, checkpoint_path, use_legacy_checkpoint=False): + if use_legacy_checkpoint: + variables_to_restore = atari_lib.maybe_transform_variable_names( + tf.all_variables(), legacy_checkpoint_load=True) + else: + global_vars = set([x.name for x in tf.global_variables()]) + ckpt_vars = [ + '{}:0'.format(name) + for name, _ in tf.train.list_variables(checkpoint_path) + ] + include_vars = list(global_vars.intersection(set(ckpt_vars))) + variables_to_restore = contrib_slim.get_variables_to_restore( + include=include_vars) + if variables_to_restore: + reloader = tf.train.Saver(var_list=variables_to_restore) + reloader.restore(self._sess, checkpoint_path) + tf.logging.info('Done restoring from %s', checkpoint_path) + else: + tf.logging.info('Nothing to restore!') + + def get_probabilities(self): + return self._sess.run(tf.squeeze(self._net_outputs.probabilities), + {self.state_ph: self.state}) \ No newline at end of file diff --git a/gym-snake/gym_snake_classic.egg-info/PKG-INFO b/gym-snake/gym_snake_classic.egg-info/PKG-INFO old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/SOURCES.txt b/gym-snake/gym_snake_classic.egg-info/SOURCES.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/dependency_links.txt b/gym-snake/gym_snake_classic.egg-info/dependency_links.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/requires.txt b/gym-snake/gym_snake_classic.egg-info/requires.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic.egg-info/top_level.txt b/gym-snake/gym_snake_classic.egg-info/top_level.txt old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/__init__.py b/gym-snake/gym_snake_classic/__init__.py old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/envs/__init__.py b/gym-snake/gym_snake_classic/envs/__init__.py old mode 100644 new mode 100755 diff --git a/gym-snake/gym_snake_classic/envs/snake_classic.py b/gym-snake/gym_snake_classic/envs/snake_classic.py old mode 100644 new mode 100755 index 87c6792..6649628 --- a/gym-snake/gym_snake_classic/envs/snake_classic.py +++ b/gym-snake/gym_snake_classic/envs/snake_classic.py @@ -1,16 +1,28 @@ import os import gym +import configs import pygame -from PIL import Image import numpy as np -from matplotlib.image import imread -from gym import error, spaces, utils -from gym.utils import seeding +import inspect -from gym_snake_classic.envs.src.game import Game,GameConfig +from gym import spaces +from gym.utils import seeding +from PIL import Image from gym_snake_classic.envs.src.assets import Snake,Food +from gym_snake_classic.envs.src.game import Game,GameConfig + +WIDTH = 400 +HEIGHT = 400 + +OBS_WIDTH=256 +OBS_HEIGHT=256 + class SnakeClassicEnv(gym.Env): + """ + Gym Environment for classic snake game + """ + metadata = {'render.modes':['human']} reward_range = (-np.inf, np.inf) @@ -21,60 +33,62 @@ class SnakeClassicEnv(gym.Env): 3 : 'RIGHT' } - def __init__(self,rgb=True): + def __init__(self): self.temp_filename='_temp_window.jpg' - width,height = (800,600) + width,height = (WIDTH,HEIGHT) self.action_space = spaces.Discrete(4) - - cfg = GameConfig(width = 800, - height = 600, + self.n_steps = 0 + self.reward = 0 + cfg = GameConfig(width = width, + height = height, player = Snake, food = Food, player_size = (20,20), food_size = (20,20), - render = True, - rgb = rgb ) - if rgb: - self.observation_space = spaces.Box(low=0, high=255, shape= - (height, width, 3)) - else: - self.observation_space = spaces.Box(low=0, high=255, shape= - (height, width)) + self.observation_space = spaces.Box(low=0, high=255, shape= + (OBS_HEIGHT, OBS_WIDTH, 3)) self.snake_game = Game(cfg) self.snake_game.on_init() - + self.prev_length = self.snake_game.player.length + @property def env(self): return self - - def _observe(self): + + def _observe(self,resize=True): pygame.image.save(self.snake_game.window,self.temp_filename) obs = Image.open(self.temp_filename) - if not self.snake_game.config.rgb: - obs=obs.convert('LA') - # HACK to pass three channels to memory buffer in dopamine - # TODO figure out how to pass rgb directly - - + if resize: + obs=obs.resize((OBS_HEIGHT,OBS_WIDTH),Image.BILINEAR) + obs=np.array(obs) return obs def step(self, action): - self.take_action(action) + self.reward-=1 self.snake_game.on_loop() + self.snake_game.on_render(show=configs.SHOW) + #observation #TODO figure out a faster way obs = self._observe() - - #reward - reward = self.get_reward() - + #done done = self.snake_game.done - + + if done : + self.reward -= 1000 + reward = self.reward #reset changes reward + self.reset() + else: + if(not self.prev_length == self.snake_game.player.length): + self.reward += 1000 + reward = self.reward + + self.prev_length=self.snake_game.player.length #info info = {} return (obs,reward,done,info) @@ -85,14 +99,20 @@ def take_action(self, action): self.snake_game.take_action(act) def get_reward(self): - return self.snake_game.score - + return self.reward + def reset(self): + self.n_steps=0 + self.reward=0 self.snake_game.reset() + self.snake_game.on_loop() + self.snake_game.on_render(show=configs.SHOW) + return self._observe() - + def render(self,mode='human'): + if mode=='rgb_array': + return self._observe(resize=False) self.snake_game.on_render() - - \ No newline at end of file + diff --git a/gym-snake/gym_snake_classic/envs/src/assets.py b/gym-snake/gym_snake_classic/envs/src/assets.py old mode 100644 new mode 100755 index 111e0da..a510c81 --- a/gym-snake/gym_snake_classic/envs/src/assets.py +++ b/gym-snake/gym_snake_classic/envs/src/assets.py @@ -16,12 +16,13 @@ def __init__(self,length,window_size): self.reset() def reset(self): + self.direction=0 self.length=self.init_length self.x=[] self.y=[] for _ in range(self.length): - self.x.append(0) - self.y.append(0) + self.x.append(self.window_size[0]/2) + self.y.append(self.window_size[1]/2) def _update(self): @@ -44,8 +45,8 @@ def update(self): elif self.direction == 3: self.y[0] += self.step - self.x[0] %= self.window_size[0] - self.y[0] %= self.window_size[1] + self.x[0] %= self.window_size[0]-20 + self.y[0] %= self.window_size[1]-20 def moveRight(self): diff --git a/gym-snake/gym_snake_classic/envs/src/game.py b/gym-snake/gym_snake_classic/envs/src/game.py old mode 100644 new mode 100755 index 7d493c9..d954b13 --- a/gym-snake/gym_snake_classic/envs/src/game.py +++ b/gym-snake/gym_snake_classic/envs/src/game.py @@ -6,6 +6,8 @@ from typing import Any,Tuple from random import randint +MAX_RUN = 5000 + @dataclass class GameConfig: height : int @@ -15,7 +17,6 @@ class GameConfig: player_size : Tuple[int] food_size : Tuple[int] render : bool = True - rgb : bool = True class Game: def __init__(self,config:GameConfig)->None: @@ -27,6 +28,9 @@ def __init__(self,config:GameConfig)->None: def reset(self): self.player.reset() + self.spawn_food() + + self._running=True self.on_loop() def on_init(self): @@ -44,11 +48,18 @@ def on_init(self): def spawn_food(self): step = self.food.step - nx=randint(2,10)*step - ny=randint(2,10)*step - while((nx,ny) in zip(self.player.x,self.player.y)): - nx=randint(2,10)*step - ny=randint(2,10)*step + nx=(randint(2,10)*step)%(self.window_size[0]-10) + ny=(randint(2,10)*step)%(self.window_size[1]-10) + count = 0 + while ( ((nx,ny) in zip(self.player.x,self.player.y)) ) and (count=MAX_RUN: + # Hack so that env pushes towards this + self.player.length+=1000 + print("Game Completed!!!") + self._running=False self.food.position=(nx,ny) @property @@ -82,13 +93,16 @@ def on_loop(self): _pos = pygame.Rect(_pos,(1,1)) if( head.colliderect(_pos) ): + # Dont call reset here,its called in env self._running = False + - def on_render(self): + def on_render(self,show=True): self.display.fill((255,255,255)) self.player.draw(self.display,self.config.player_size) self.food.draw(self.display,self.config.food_size) - pygame.display.flip() + if show: + pygame.display.flip() def on_cleanup(self): @@ -104,5 +118,40 @@ def take_action(self,act): if(act=='RIGHT'): self.player.moveRight() + # Not used in the env + def on_execute(self): + if (self.on_init() == False): + self._running = False + + while(1): + while(self._running): + pygame.event.pump() + keys = pygame.key.get_pressed() + + if(keys[K_RIGHT]): + self.player.moveRight() + + if(keys[K_LEFT]): + self.player.moveLeft() + + if(keys[K_UP]): + self.player.moveUp() + + if(keys[K_DOWN]): + self.player.moveDown() + + if(keys[K_ESCAPE]): + self._running=False + exit(0) + + self.on_loop() + if(self.config.render): + self.on_render() + time.sleep(50/1000.0) + + + print("Exiting") + self.on_cleanup() + diff --git a/gym-snake/gym_snake_classic/envs/src/main.py b/gym-snake/gym_snake_classic/envs/src/main.py old mode 100644 new mode 100755 index 795891d..6462c87 --- a/gym-snake/gym_snake_classic/envs/src/main.py +++ b/gym-snake/gym_snake_classic/envs/src/main.py @@ -2,8 +2,8 @@ from assets import Snake,Food -cfg = GameConfig(width = 800, - height = 600, +cfg = GameConfig(width = 400, + height = 400, player = Snake, food = Food, player_size = (20,20), diff --git a/gym-snake/setup.py b/gym-snake/setup.py old mode 100644 new mode 100755 diff --git a/rl/agent.py b/rl/agent.py deleted file mode 100644 index 16ad6ac..0000000 --- a/rl/agent.py +++ /dev/null @@ -1,91 +0,0 @@ -import gym -import numpy as np -import tensorflow as tf - -from dopamine.replay_memory import prioritized_replay_buffer,circular_replay_buffer -from dopamine.agents.dqn import dqn_agent -from dopamine.discrete_domains.run_experiment import Runner -from dopamine.discrete_domains.gym_lib import create_gym_environment -from models import SimpleDQNNetwork - - -STACK_SIZE = 4 -GAMMA = 0.9 -REPLAY_CAPACITY = 10000 -BATCH_SIZE = 32 - -class SnakeDQNAgent(dqn_agent.DQNAgent): - def __init__(self,*args,**kwargs): - super().__init__(*args,**kwargs) - - def _build_replay_buffer(self,use_staging): - - """Creates the replay buffer used by the agent. - Args: - use_staging: bool, if True, uses a staging area to prefetch data for - faster training. - Returns: - A WrapperReplayBuffer object. - """ - return circular_replay_buffer.WrappedReplayBuffer( - replay_capacity = REPLAY_CAPACITY, - batch_size = BATCH_SIZE, - observation_shape=self.observation_shape, - stack_size=self.stack_size, - use_staging=use_staging, - update_horizon=self.update_horizon, - gamma=self.gamma, - observation_dtype=self.observation_dtype.as_numpy_dtype - ) - - - - -env = create_gym_environment( - environment_name="gym_snake_classic:SnakeClassic", - version = 'v0' - ) -memory_buffer = prioritized_replay_buffer.WrappedPrioritizedReplayBuffer( - observation_shape=env.observation_space.shape, - stack_size=STACK_SIZE, - replay_capacity=100, - batch_size=32, - gamma=GAMMA, - ) -print(env.action_space.n) -sess = tf.Session() - - - -def _agent_fn(sess,env,summary_writer): - AGENT = SnakeDQNAgent( - sess=sess, - num_actions = env.action_space.n, - observation_shape = env.observation_space.shape, - stack_size = STACK_SIZE, - network = SimpleDQNNetwork, - gamma=GAMMA, - tf_device = '/gpu:0' , - summary_writer=summary_writer - ) - return AGENT - -def _env_fn(*args): - return env - -runner = Runner( - base_dir = '_tmp_agent_dir/', - create_agent_fn = _agent_fn, - create_environment_fn= _env_fn, - checkpoint_file_prefix='ckpt', - logging_file_prefix='log', - log_every_n=10, - num_iterations=200, - training_steps=2500, - evaluation_steps=1250, - max_steps_per_episode=10000 - ) -runner.run_experiment() - - - diff --git a/rl/configs.py b/rl/configs.py new file mode 100755 index 0000000..f45fcb8 --- /dev/null +++ b/rl/configs.py @@ -0,0 +1,11 @@ +STACK_SIZE = 4 +GAMMA = 0.99 +REPLAY_CAPACITY = 7500 +BATCH_SIZE = 64 +SHOW = False +MIN_REPLAY_HISTORY = REPLAY_CAPACITY #number of transitions that should be experienced + #before the agent begins training its value function +BASE_DIR = 'summaries/snake_classic/' +TARGET_UPDATE_PERIOD = 1000 # update period for the target network +SUMMARY_WRITING_FREQUENCY = 50 +EVAL_MODE=True \ No newline at end of file diff --git a/rl/models.py b/rl/models.py old mode 100644 new mode 100755 index e6d20f7..36430e0 --- a/rl/models.py +++ b/rl/models.py @@ -1,10 +1,12 @@ +import numpy as np import tensorflow as tf + from collections import namedtuple -from tensorflow.contrib import layers as contrib_layers -from tensorflow.contrib import slim as contrib_slim +from utils import merge_last_two_dims DQNNetworkType = namedtuple('dqn_network', ['q_values']) - +RainbowNetworkType = namedtuple( + 'c51_network', ['q_values', 'logits', 'probabilities']) class SimpleDQNNetwork(tf.keras.Model): @@ -29,7 +31,7 @@ def __init__(self, num_actions, name=None): self.conv3 = tf.keras.layers.Conv2D(64, [3, 3], strides=1, padding='same', activation=activation_fn, name='Conv') self.flatten = tf.keras.layers.Flatten() - self.dense1 = tf.keras.layers.Dense(512, activation=activation_fn, + self.dense1 = tf.keras.layers.Dense(256, activation=activation_fn, name='fully_connected') self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected') @@ -43,31 +45,66 @@ def call(self, state,): """ #TODO Make rgb proper #HACK pass - def infer_shape(x): - x = tf.convert_to_tensor(x) + - # If unknown rank, return dynamic shape - if x.shape.dims is None: - return tf.shape(x) + x = tf.cast(state, tf.float32) + x = tf.div(x, 255.) + x = merge_last_two_dims(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.flatten(x) + x = self.dense1(x) + # x = self.dense1(x) - static_shape = x.shape.as_list() - dynamic_shape = tf.shape(x) + return DQNNetworkType(self.dense2(x)) - ret = [] - for i in range(len(static_shape)): - dim = static_shape[i] - if dim is None: - dim = dynamic_shape[i] - ret.append(dim) - return ret - def merge_last_two_dims(tensor): - shape = infer_shape(tensor) - shape[-2] *= shape[-1] - shape.pop(-1) - return tf.reshape(tensor, shape) +class RainbowNetwork(tf.keras.Model): + def __init__(self, num_actions, num_atoms, support, name=None): + """Creates the layers used calculating return distributions. + Args: + num_actions: int, number of actions. + num_atoms: int, the number of buckets of the value function distribution. + support: tf.linspace, the support of the Q-value distribution. + name: str, used to crete scope for network parameters. + """ + super(RainbowNetwork, self).__init__(name=name) + activation_fn = tf.keras.activations.relu + self.num_actions = num_actions + self.num_atoms = num_atoms + self.support = support + self.kernel_initializer = tf.keras.initializers.VarianceScaling( + scale=1.0 / np.sqrt(3.0), mode='fan_in', distribution='uniform') + # Defining layers. + self.conv1 = tf.keras.layers.Conv2D( + 32, [8, 8], strides=4, padding='same', activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='Conv') + self.conv2 = tf.keras.layers.Conv2D( + 64, [4, 4], strides=2, padding='same', activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='Conv') + self.conv3 = tf.keras.layers.Conv2D( + 64, [3, 3], strides=1, padding='same', activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='Conv') + self.flatten = tf.keras.layers.Flatten() + self.dense1 = tf.keras.layers.Dense( + 512, activation=activation_fn, + kernel_initializer=self.kernel_initializer, name='fully_connected') + self.dense2 = tf.keras.layers.Dense( + num_actions * num_atoms, kernel_initializer=self.kernel_initializer, + name='fully_connected') + def call(self, state): + """Creates the output tensor/op given the state tensor as input. + See https://www.tensorflow.org/api_docs/python/tf/keras/Model for more + information on this. Note that tf.keras.Model implements `call` which is + wrapped by `__call__` function by tf.keras.Model. + Args: + state: Tensor, input tensor. + Returns: + collections.namedtuple, output ops (graph mode) or output tensors (eager). + """ x = tf.cast(state, tf.float32) x = tf.div(x, 255.) x = merge_last_two_dims(x) @@ -76,5 +113,8 @@ def merge_last_two_dims(tensor): x = self.conv3(x) x = self.flatten(x) x = self.dense1(x) - - return DQNNetworkType(self.dense2(x)) + x = self.dense2(x) + logits = tf.reshape(x, [-1, self.num_actions, self.num_atoms]) + probabilities = tf.keras.activations.softmax(logits) + q_values = tf.reduce_sum(self.support * probabilities, axis=2) + return RainbowNetworkType(q_values, logits, probabilities) \ No newline at end of file diff --git a/rl/readme.md b/rl/readme.md old mode 100644 new mode 100755 diff --git a/rl/requirements.txt b/rl/requirements.txt old mode 100644 new mode 100755 diff --git a/rl/test_env.py b/rl/test_env.py new file mode 100755 index 0000000..f759340 --- /dev/null +++ b/rl/test_env.py @@ -0,0 +1,27 @@ +import gym +from dopamine.discrete_domains.gym_lib import create_gym_environment + + +def test_snake_classic(dopamine=False): + if(dopamine): + env = create_gym_environment( + environment_name="gym_snake_classic:SnakeClassic", + version = 'v0' + ) + else: + env = gym.make("gym_snake_classic:SnakeClassic-v0") + env.reset() + if not dopamine: + env.render() + for _ in range(100): + action = 3 + state, reward, done, _ = env.step(action) + if not dopamine: + env.render() + if done: + break + print(f"Reward :{reward}") + + +if __name__ == "__main__": + test_snake_classic() \ No newline at end of file diff --git a/rl/train.py b/rl/train.py new file mode 100755 index 0000000..457c0e2 --- /dev/null +++ b/rl/train.py @@ -0,0 +1,102 @@ +import gym +import configs +import numpy as np +import tensorflow as tf + +from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer +from dopamine.agents.dqn import dqn_agent +from dopamine.discrete_domains.run_experiment import Runner +from dopamine.discrete_domains.gym_lib import create_gym_environment +from dopamine.discrete_domains import atari_lib +from dopamine.agents.rainbow import rainbow_agent +from models import SimpleDQNNetwork, RainbowNetwork +from utils import SnakeRunner + +from agents import SnakeRainbowAgent + +STACK_SIZE = configs.STACK_SIZE +GAMMA = configs.GAMMA +REPLAY_CAPACITY = configs.REPLAY_CAPACITY +BATCH_SIZE = configs.BATCH_SIZE + + +sess = tf.Session() + + +env = create_gym_environment( + environment_name="gym_snake_classic:SnakeClassic", version="v0" +) + + +def _agent_fn(sess, env, summary_writer): + # AGENT = SnakeDQNAgent( + # sess=sess, + # num_actions = env.action_space.n, + # observation_shape = env.observation_space.shape, + # stack_size = STACK_SIZE, + # network = SimpleDQNNetwork, + # gamma=0.99, + # update_horizon=1, + # min_replay_history=configs.MIN_REPLAY_HISTORY, + # update_period=4, + # target_update_period=configs.TARGET_UPDATE_PERIOD, + # epsilon_fn=dqn_agent.linearly_decaying_epsilon, + # epsilon_train=0.01, + # epsilon_eval=0.001, + # epsilon_decay_period=250000, + # eval_mode=configs.EVAL_MODE , # True for training + # tf_device="/gpu:*", + # summary_writer=summary_writer, + # summary_writing_frequency=configs.SUMMARY_WRITING_FREQUENCY, + # ) + + AGENT = SnakeRainbowAgent( + sess=sess, + num_actions=env.action_space.n, + observation_shape=env.observation_space.shape, + stack_size=STACK_SIZE, + network=RainbowNetwork, + num_atoms=51, + vmax=10.0, + gamma=0.99, + update_horizon=1, + min_replay_history=configs.MIN_REPLAY_HISTORY, + update_period=4, + target_update_period=configs.TARGET_UPDATE_PERIOD, + epsilon_fn=dqn_agent.linearly_decaying_epsilon, + epsilon_train=0.01, + epsilon_eval=0.001, + epsilon_decay_period=250000, + eval_mode=configs.EVAL_MODE , # True for training + replay_scheme="prioritized", + tf_device="/gpu:*", + summary_writer=summary_writer, + summary_writing_frequency=configs.SUMMARY_WRITING_FREQUENCY, + ) + return AGENT + + + +def _env_fn(*args): + return env + + +runner = SnakeRunner( + base_dir=configs.BASE_DIR, + create_agent_fn=_agent_fn, + create_environment_fn=_env_fn, + checkpoint_file_prefix="ckpt", + logging_file_prefix="log", + log_every_n=10, + num_iterations=2000, + training_steps=25000, + evaluation_steps=12500, + max_steps_per_episode=10000, +) + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + runner.run_experiment() + # runner.visualize(record_path = configs.BASE_DIR+'visualize/', + # num_global_steps=500) diff --git a/rl/utils.py b/rl/utils.py new file mode 100755 index 0000000..0794a5d --- /dev/null +++ b/rl/utils.py @@ -0,0 +1,123 @@ +import tensorflow as tf +import pygame +from dopamine.discrete_domains.run_experiment import Runner +from dopamine.discrete_domains import run_experiment +from dopamine.utils import agent_visualizer +from dopamine.utils import atari_plotter +from dopamine.utils import bar_plotter +from dopamine.utils import line_plotter +from dopamine.utils import plotter + +def infer_shape(x): + x = tf.convert_to_tensor(x) + + # If unknown rank, return dynamic shape + if x.shape.dims is None: + return tf.shape(x) + + static_shape = x.shape.as_list() + dynamic_shape = tf.shape(x) + + ret = [] + for i in range(len(static_shape)): + dim = static_shape[i] + if dim is None: + dim = dynamic_shape[i] + ret.append(dim) + + return ret + +def merge_last_two_dims(tensor): + shape = infer_shape(tensor) + shape[-2] *= shape[-1] + shape.pop(-1) + return tf.reshape(tensor, shape) + +class SnakeRunner(Runner): + def visualize(self, record_path, num_global_steps=500): + if not tf.gfile.Exists(record_path): + tf.gfile.MakeDirs(record_path) + self._agent.eval_mode = True + + # Set up the game playback rendering. + atari_params = { + 'environment': self._environment, + 'width': 400, + 'height': 400, + } + atari_plot = atari_plotter.AtariPlotter(parameter_dict=atari_params) + # Plot the rewards received next to it. + reward_params = {'x': atari_plot.parameters['width'], + 'xlabel': 'Timestep', + 'ylabel': 'Reward', + 'title': 'Rewards', + 'get_line_data_fn': self._agent.get_rewards} + reward_plot = line_plotter.LinePlotter(parameter_dict=reward_params) + action_names = [ + 'Action {}'.format(x) for x in range(self._agent.num_actions)] + # Plot Q-values (DQN) or Q-value distributions (Rainbow). + q_params = {'x': atari_plot.parameters['width'] // 2, + 'y': atari_plot.parameters['height'], + 'legend': action_names} + if 'DQN' in self._agent.__class__.__name__: + q_params['xlabel'] = 'Timestep' + q_params['ylabel'] = 'Q-Value' + q_params['title'] = 'Q-Values' + q_params['get_line_data_fn'] = self._agent.get_q_values + q_plot = line_plotter.LinePlotter(parameter_dict=q_params) + else: + q_params['xlabel'] = 'Return' + q_params['ylabel'] = 'Return probability' + q_params['title'] = 'Return distribution' + q_params['get_bar_data_fn'] = self._agent.get_probabilities + q_plot = bar_plotter.BarPlotter(parameter_dict=q_params) + screen_width = ( + atari_plot.parameters['width'] + reward_plot.parameters['width']) + screen_height = ( + atari_plot.parameters['height'] + q_plot.parameters['height']) + # Dimensions need to be divisible by 2: + if screen_width % 2 > 0: + screen_width += 1 + if screen_height % 2 > 0: + screen_height += 1 + visualizer = agent_visualizer.AgentVisualizer( + record_path=record_path, plotters=[atari_plot, reward_plot, q_plot], + screen_width=screen_width, screen_height=screen_height) + + global_step = 0 + while global_step < num_global_steps: + initial_observation = self._environment.reset() + action = self._agent.begin_episode(initial_observation) + while True: + observation, reward, is_terminal, _ = self._environment.step(action) + global_step += 1 + visualizer.visualize() + if self._environment.game_over or global_step >= num_global_steps: + break + elif is_terminal: + self._agent.end_episode(reward) + action = self._agent.begin_episode(observation) + else: + action = self._agent.step(reward, observation) + self._end_episode(reward) + visualizer.generate_video() + +class SnakePlotter(plotter.Plotter): + def __init__(self, parameter_dict=None): + super().__init__(parameter_dict) + assert 'environment' in self.parameters + self.game_surface = pygame.Surface((self.parameters['width'], + self.parameters['height'])) + + def draw(self): + """Render the Atari 2600 frame. + + Returns: + object to be rendered by AgentVisualizer. + """ + environment = self.parameters['environment'] + obs = environment.render(mode='rgb_array').astype(np.int32) + + return pygame.transform.scale(self.game_surface, + (self.parameters['width'], + self.parameters['height'])) \ No newline at end of file diff --git a/snake_game/assets.py b/snake_game/assets.py old mode 100644 new mode 100755 index 67b7fcf..a26155a --- a/snake_game/assets.py +++ b/snake_game/assets.py @@ -36,8 +36,8 @@ def update(self): elif self.direction == 3: self.y[0] += self.step - self.x[0] %= self.window_size[0] - self.y[0] %= self.window_size[1] + self.x[0] %= self.window_size[0]-20 + self.y[0] %= self.window_size[1]-256 def moveRight(self): @@ -79,16 +79,17 @@ class Food: x,y=(0,0) step=44 - def __init__(self,x,y): - self.x = x*self.step - self.y = y*self.step + def __init__(self,x,y,window_size): + self.window_size=window_size + self.x = (x*self.step)%window_size[0] + self.y = (y*self.step)%window_size[1] @property def position(self): return (self.x,self.y) @position.setter def position(self,value): - self.x=value[0] - self.y=value[1] + self.x=(value[0])%(window_size[0]-5) + self.y=(value[1])%(window_size[1]-5) def draw(self,surface,food_size): pygame.draw.rect(surface,(255,153,51), diff --git a/snake_game/game.py b/snake_game/game.py old mode 100644 new mode 100755 index d45bf0f..f818f66 --- a/snake_game/game.py +++ b/snake_game/game.py @@ -22,7 +22,7 @@ def __init__(self,config:GameConfig)->None: self.window_size = (self.config.width,self.config.height) self.player = self.config.player(5,self.window_size) self._running = True - self.food = self.config.food(5,5) # setting init position + self.food = self.config.food(5,5,window_size) # setting init position def on_init(self): pygame.init() @@ -33,20 +33,20 @@ def on_init(self): pygame.display.set_caption('Snake') self._running=True self.snake_body = pygame.Surface( self.config.player_size ) - + self.food_img = pygame.Surface( self.config.food_size ) def on_event(self,event): if event.type == QUIT: self._running=False - + def spawn_food(self): step = self.food.step - nx=randint(2,10)*step - ny=randint(2,10)*step + nx=(randint(2,10)*step)%self.window_size[0] + ny=(randint(2,10)*step)%self.window_size[1] while((nx,ny) in zip(self.player.x,self.player.y)): - nx=randint(2,10)*step - ny=randint(2,10)*step + nx=(randint(2,10)*step)%self.window_size[0] + ny=(randint(2,10)*step)%self.window_size[1] self.food.position=(nx,ny) @@ -56,8 +56,7 @@ def on_loop(self): # check collison with food head = self.snake_body.get_rect(topleft=self.player.position) food_pos = self.food_img.get_rect(topleft=self.food.position) - - + if( head.colliderect(food_pos) ): self.player.length = self.player.length+1 self.player.eat(food_pos) diff --git a/snake_game/main.py b/snake_game/main.py old mode 100644 new mode 100755 diff --git a/test.py b/test.py deleted file mode 100644 index 6e0c008..0000000 --- a/test.py +++ /dev/null @@ -1,18 +0,0 @@ -import gym - - -def test_snake_classic(): - env = gym.make("gym_snake_classic:SnakeClassic-v0") - env.reset() - env.render() - for _ in range(100): - action = 3 - state, reward, done, _ = env.step(action) - env.render() - if done: - break - print(f"Reward :{reward}") - - -if __name__ == "__main__": - test_snake_classic() \ No newline at end of file diff --git a/test_env.py b/test_env.py new file mode 100755 index 0000000..ee2ee6a --- /dev/null +++ b/test_env.py @@ -0,0 +1,36 @@ +import gym +import unittest +import numpy as np +from matplotlib import pyplot as plt +from PIL import Image + +class TestSnakeEnv(unittest.TestCase): + def setUp(self): + self.env = gym.make("gym_snake_classic:SnakeClassic-v0") + + + def test_0_snake_classic(self,render=True): + self.env.reset() + if render: + self.env.render() + for _ in range(100): + action = 3 + state, reward, done, _ = self.env.step(action) + if render: + self.env.render() + if done: + break + self.assertEqual(reward,-89) + + def test_1_observe(self): + img=self.env._observe() + uni = np.unique(img) + self.assertEqual(len(uni),157) + + +if __name__ == "__main__": + t = TestSnakeEnv() + t.setUp() + t.test_0_snake_classic() + t.test_1_observe() + # unittest.main() \ No newline at end of file