hyper-optimized AlphaZero implementation with Ray + Cython for speed
train an agent that beats random actions and pure MCTS in 2 minutes
- train.py: distributed training with Ray
- ctree/: MCTS nodes in Cython (node.py = pure Python)
- mcts.py: MCTS playouts
- network.py: neural net stuff
- board.py: Gomoku board
- ray distributed parts (train.py), sketched in the example after this list:
  - one distributed replay buffer
  - N actors with the 'best model' weights, which self-play games and store the data in the replay buffer
  - M 'candidate models', which pull from the replay buffer and train
    - each iteration they play against the 'best model'; if they win, the 'best model' weights are updated
    - includes write/evaluation locks on 'best weights'
  - one best-model weights store (PS / parameter server)
    - stores the best weights, which are retrieved by the self-play actors and updated when a candidate wins
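a minimal sketch of how these Ray pieces might fit together. actor and function names (`ReplayBuffer`, `ParameterServer`, `self_play`, `train_candidate`) and the game/net stubs are illustrative, not the exact API in train.py:

```python
import random
import ray

def play_one_game(weights):
    # placeholder for an MCTS self-play game; returns (state, pi, z) tuples
    return [("state", "pi", random.choice([-1, 1]))]

def sgd_step(weights, batch):
    # placeholder for one optimizer step on a sampled batch
    return weights

def candidate_beats_best(candidate, best):
    # placeholder head-to-head evaluation (e.g. win rate over k games)
    return random.random() > 0.5

@ray.remote
class ReplayBuffer:
    """One distributed replay buffer shared by all self-play actors."""
    def __init__(self, capacity=50_000):
        self.capacity, self.data = capacity, []

    def add(self, samples):
        self.data.extend(samples)
        del self.data[:-self.capacity]  # drop oldest entries past capacity

    def sample(self, n):
        return random.sample(self.data, min(n, len(self.data)))

@ray.remote
class ParameterServer:
    """Holds the 'best model' weights. Ray actors process one method
    call at a time, which serves as the write/evaluation lock."""
    def __init__(self, weights):
        self.weights = weights

    def get_weights(self):
        return self.weights

    def set_weights(self, weights):
        self.weights = weights

@ray.remote
def self_play(ps, buffer, games):
    for _ in range(games):
        w = ray.get(ps.get_weights.remote())  # always play with best weights
        buffer.add.remote(play_one_game(w))

@ray.remote
def train_candidate(ps, buffer, steps):
    w = ray.get(ps.get_weights.remote())
    for _ in range(steps):
        w = sgd_step(w, ray.get(buffer.sample.remote(512)))
    best = ray.get(ps.get_weights.remote())
    if candidate_beats_best(w, best):         # promote candidate on a win
        ps.set_weights.remote(w)

if __name__ == "__main__":
    ray.init()
    buffer = ReplayBuffer.remote()
    ps = ParameterServer.remote(weights={})
    workers = [self_play.remote(ps, buffer, 10) for _ in range(4)]          # N actors
    trainers = [train_candidate.remote(ps, buffer, 100) for _ in range(2)]  # M candidates
    ray.get(workers + trainers)
```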
- cython impl:
  - ctree/: C++/Cython MCTS
  - node.py: pure Python MCTS (see the sketch below)
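a minimal pure-Python sketch of the node structure that node.py keeps and ctree/ compiles down with Cython. field and method names here are illustrative guesses, not the repo's actual API:

```python
import math

class Node:
    """One MCTS tree node: tracks visit counts, value sum, and the
    network prior, and supports select/expand/backup."""
    def __init__(self, prior, parent=None):
        self.parent = parent
        self.prior = prior        # P(s, a) from the policy network
        self.children = {}        # action -> Node
        self.visit_count = 0
        self.value_sum = 0.0

    def value(self):
        # mean value of this node over all visits
        return self.value_sum / self.visit_count if self.visit_count else 0.0

    def ucb(self, c_puct=5.0):
        # PUCT score used to pick which child to descend into
        u = c_puct * self.prior * math.sqrt(self.parent.visit_count) / (1 + self.visit_count)
        return self.value() + u

    def select(self):
        # pick the (action, child) with the highest PUCT score
        return max(self.children.items(), key=lambda kv: kv[1].ucb())

    def expand(self, action_priors):
        # attach one child per legal action, seeded with network priors
        for action, p in action_priors:
            if action not in self.children:
                self.children[action] = Node(p, parent=self)

    def backup(self, leaf_value):
        # propagate the leaf evaluation to the root, flipping the sign
        # at each level since players alternate
        node, v = self, leaf_value
        while node is not None:
            node.visit_count += 1
            node.value_sum += v
            v = -v
            node = node.parent
```

the Cython version in ctree/ can keep the same structure with `cdef`-typed fields to avoid Python attribute lookups in the hot select/backup loops.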
-- todos --
- jax network impl
- tpu + gpu support
- saving/loading model weights
