Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
__pycache__/
__pycache__/
checkpoints/
logs/
*.pluto
1 change: 1 addition & 0 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ pip install gym-snake
env = gym.make("gym_snake_classic:SnakeClassic-v0")
```

* Size needs to be changed in params plotter utils
Empty file modified __pycache__/assets.cpython-37.pyc
100644 → 100755
Empty file.
Empty file modified __pycache__/window.cpython-36.pyc
100644 → 100755
Empty file.
Empty file modified __pycache__/window.cpython-37.pyc
100644 → 100755
Empty file.
95 changes: 95 additions & 0 deletions agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import gym
import configs
import numpy as np
import tensorflow as tf

from tensorflow.contrib import slim as contrib_slim
from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer
from dopamine.replay_memory import prioritized_replay_buffer, circular_replay_buffer
from dopamine.agents.dqn import dqn_agent
from dopamine.discrete_domains.run_experiment import Runner
from dopamine.discrete_domains.gym_lib import create_gym_environment
from dopamine.discrete_domains import atari_lib
from dopamine.agents.rainbow import rainbow_agent
from models import SimpleDQNNetwork, RainbowNetwork


class SnakeDQNAgent(dqn_agent.DQNAgent):
    """DQN agent whose replay buffer is configured from the `configs` module."""

    def __init__(self, *args, **kwargs):
        # Pure pass-through: every argument is forwarded to DQNAgent unchanged.
        super().__init__(*args, **kwargs)

    def _build_replay_buffer(self, use_staging):
        """Creates the circular replay buffer used by the agent.

        Capacity, batch size and stack size are read from `configs`; the
        remaining settings are taken from attributes of this agent.

        Args:
          use_staging: bool, if True, uses a staging area to prefetch data for
            faster training.

        Returns:
          A WrappedReplayBuffer object.
        """
        buffer_kwargs = dict(
            replay_capacity=configs.REPLAY_CAPACITY,
            batch_size=configs.BATCH_SIZE,
            observation_shape=self.observation_shape,
            stack_size=configs.STACK_SIZE,
            use_staging=use_staging,
            update_horizon=self.update_horizon,
            gamma=self.gamma,
            observation_dtype=self.observation_dtype.as_numpy_dtype,
        )
        return circular_replay_buffer.WrappedReplayBuffer(**buffer_kwargs)


class SnakeRainbowAgent(rainbow_agent.RainbowAgent):
    """Rainbow agent that records rewards and exposes its output distribution.

    The replay buffer is configured from the `configs` module, every reward
    passed to `step` is stored for a visualizer, and `get_probabilities`
    evaluates the network's probability output for the current state.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rewards = []  # For visualizer

    def step(self, reward, observation):
        """Records `reward`, then delegates to RainbowAgent.step.

        Args:
          reward: reward received for the previous action.
          observation: latest observation from the environment.

        Returns:
          Whatever RainbowAgent.step returns for these arguments.
        """
        self.rewards.append(reward)
        return super().step(reward, observation)

    def get_rewards(self):
        """Returns a one-element list holding the cumulative sum of all rewards.

        The rewards are every value passed to `step` since construction; this
        class never clears the list.
        """
        return [np.cumsum(self.rewards)]

    def _build_replay_buffer(self, use_staging):
        """Creates the prioritized replay buffer used by the agent.

        Args:
          use_staging: bool, if True, uses a staging area to prefetch data for
            faster training.

        Returns:
          A WrappedPrioritizedReplayBuffer object.
        """
        return prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
            replay_capacity=configs.REPLAY_CAPACITY,
            batch_size=configs.BATCH_SIZE,
            observation_shape=self.observation_shape,
            stack_size=configs.STACK_SIZE,
            use_staging=use_staging,
            update_horizon=self.update_horizon,
            gamma=self.gamma,
            observation_dtype=self.observation_dtype.as_numpy_dtype,
        )

    def reload_checkpoint(self, checkpoint_path, use_legacy_checkpoint=False):
        """Restores the variables found in a checkpoint into the session.

        Args:
          checkpoint_path: path of the checkpoint to restore from.
          use_legacy_checkpoint: bool, if True, the variable list is produced
            by `atari_lib.maybe_transform_variable_names` with
            `legacy_checkpoint_load=True`; otherwise only global variables
            whose names also appear in the checkpoint are restored.
        """
        if use_legacy_checkpoint:
            # tf.global_variables() replaces the deprecated tf.all_variables()
            # alias and matches the call used in the non-legacy branch below.
            variables_to_restore = atari_lib.maybe_transform_variable_names(
                tf.global_variables(), legacy_checkpoint_load=True)
        else:
            global_vars = {x.name for x in tf.global_variables()}
            # Checkpoint entries carry no ':0' output suffix; add it so the
            # names are comparable with the graph variable names.
            ckpt_vars = {
                '{}:0'.format(name)
                for name, _ in tf.train.list_variables(checkpoint_path)
            }
            include_vars = list(global_vars.intersection(ckpt_vars))
            variables_to_restore = contrib_slim.get_variables_to_restore(
                include=include_vars)
        if variables_to_restore:
            reloader = tf.train.Saver(var_list=variables_to_restore)
            reloader.restore(self._sess, checkpoint_path)
            tf.logging.info('Done restoring from %s', checkpoint_path)
        else:
            tf.logging.info('Nothing to restore!')

    def get_probabilities(self):
        """Evaluates the network's probabilities for the current state.

        Returns:
          The value of `self._net_outputs.probabilities` for `self.state`,
          with all size-1 dimensions removed.
        """
        # Squeeze the fetched array with NumPy instead of wrapping the tensor
        # in tf.squeeze: the latter adds a new op to the graph on every call,
        # so the graph grows without bound (and fails if it is finalized).
        probabilities = self._sess.run(
            self._net_outputs.probabilities, {self.state_ph: self.state})
        return np.squeeze(probabilities)
Empty file modified gym-snake/gym_snake_classic.egg-info/PKG-INFO
100644 → 100755
Empty file.
Empty file modified gym-snake/gym_snake_classic.egg-info/SOURCES.txt
100644 → 100755
Empty file.
Empty file modified gym-snake/gym_snake_classic.egg-info/dependency_links.txt
100644 → 100755
Empty file.
Empty file modified gym-snake/gym_snake_classic.egg-info/requires.txt
100644 → 100755
Empty file.
Empty file modified gym-snake/gym_snake_classic.egg-info/top_level.txt
100644 → 100755
Empty file.
Empty file modified gym-snake/gym_snake_classic/__init__.py
100644 → 100755
Empty file.
Empty file modified gym-snake/gym_snake_classic/envs/__init__.py
100644 → 100755
Empty file.
96 changes: 58 additions & 38 deletions gym-snake/gym_snake_classic/envs/snake_classic.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
import os
import gym
import configs
import pygame
from PIL import Image
import numpy as np
from matplotlib.image import imread
from gym import error, spaces, utils
from gym.utils import seeding
import inspect

from gym_snake_classic.envs.src.game import Game,GameConfig
from gym import spaces
from gym.utils import seeding
from PIL import Image
from gym_snake_classic.envs.src.assets import Snake,Food
from gym_snake_classic.envs.src.game import Game,GameConfig

WIDTH = 400
HEIGHT = 400

OBS_WIDTH=256
OBS_HEIGHT=256


class SnakeClassicEnv(gym.Env):
"""
Gym Environment for classic snake game
"""

metadata = {'render.modes':['human']}
reward_range = (-np.inf, np.inf)

Expand All @@ -21,60 +33,62 @@ class SnakeClassicEnv(gym.Env):
3 : 'RIGHT'
}

def __init__(self,rgb=True):
def __init__(self):
self.temp_filename='_temp_window.jpg'
width,height = (800,600)
width,height = (WIDTH,HEIGHT)
self.action_space = spaces.Discrete(4)

cfg = GameConfig(width = 800,
height = 600,
self.n_steps = 0
self.reward = 0
cfg = GameConfig(width = width,
height = height,
player = Snake,
food = Food,
player_size = (20,20),
food_size = (20,20),
render = True,
rgb = rgb
)
if rgb:
self.observation_space = spaces.Box(low=0, high=255, shape=
(height, width, 3))
else:
self.observation_space = spaces.Box(low=0, high=255, shape=
(height, width))

self.observation_space = spaces.Box(low=0, high=255, shape=
(OBS_HEIGHT, OBS_WIDTH, 3))

self.snake_game = Game(cfg)
self.snake_game.on_init()

self.prev_length = self.snake_game.player.length

@property
def env(self):
return self
def _observe(self):

def _observe(self,resize=True):
pygame.image.save(self.snake_game.window,self.temp_filename)
obs = Image.open(self.temp_filename)
if not self.snake_game.config.rgb:
obs=obs.convert('LA')
# HACK to pass three channels to memory buffer in dopamine
# TODO figure out how to pass rgb directly


if resize:
obs=obs.resize((OBS_HEIGHT,OBS_WIDTH),Image.BILINEAR)
obs=np.array(obs)
return obs

def step(self, action):

self.take_action(action)
self.reward-=1
self.snake_game.on_loop()
self.snake_game.on_render(show=configs.SHOW)

#observation
#TODO figure out a faster way
obs = self._observe()

#reward
reward = self.get_reward()


#done
done = self.snake_game.done


if done :
self.reward -= 1000
reward = self.reward #reset changes reward
self.reset()
else:
if(not self.prev_length == self.snake_game.player.length):
self.reward += 1000
reward = self.reward

self.prev_length=self.snake_game.player.length
#info
info = {}
return (obs,reward,done,info)
Expand All @@ -85,14 +99,20 @@ def take_action(self, action):
self.snake_game.take_action(act)

def get_reward(self):
return self.snake_game.score
return self.reward

def reset(self):
self.n_steps=0
self.reward=0
self.snake_game.reset()
self.snake_game.on_loop()
self.snake_game.on_render(show=configs.SHOW)

return self._observe()


def render(self,mode='human'):
if mode=='rgb_array':
return self._observe(resize=False)
self.snake_game.on_render()



9 changes: 5 additions & 4 deletions gym-snake/gym_snake_classic/envs/src/assets.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ def __init__(self,length,window_size):
self.reset()

def reset(self):
self.direction=0
self.length=self.init_length
self.x=[]
self.y=[]
for _ in range(self.length):
self.x.append(0)
self.y.append(0)
self.x.append(self.window_size[0]/2)
self.y.append(self.window_size[1]/2)


def _update(self):
Expand All @@ -44,8 +45,8 @@ def update(self):
elif self.direction == 3:
self.y[0] += self.step

self.x[0] %= self.window_size[0]
self.y[0] %= self.window_size[1]
self.x[0] %= self.window_size[0]-20
self.y[0] %= self.window_size[1]-20


def moveRight(self):
Expand Down
Loading