144 changes: 77 additions & 67 deletions drqn.py
@@ -1,86 +1,96 @@
import tensorflow as tf
import numpy as np
from collections import deque


class QNetwork:
def __init__(self, learning_rate=0.01, state_size=4,
action_size=2, hidden_size=10, step_size=1 ,
"""
DRQN with a dueling output head.

- LSTM encoder over the history of states
- Two fully connected streams:
* Value stream V(s)
* Advantage stream A(s, a)
Combined as Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))
"""

def __init__(self, learning_rate=0.01, state_size=4,
action_size=2, hidden_size=10, step_size=1,
name='QNetwork'):

with tf.variable_scope(name):
self.inputs_ = tf.placeholder(tf.float32, [None,step_size, state_size], name='inputs_')
self.actions_ = tf.placeholder(tf.int32, [None], name='actions')

# Disable eager execution so the TF1-style placeholders and Session below keep working
tf.compat.v1.disable_eager_execution()

with tf.compat.v1.variable_scope(name):
# Input: a sequence of shape [batch, step_size, state_size]
self.inputs_ = tf.compat.v1.placeholder(
tf.float32, [None, step_size, state_size], name='inputs_')
self.actions_ = tf.compat.v1.placeholder(
tf.int32, [None], name='actions')
one_hot_actions = tf.one_hot(self.actions_, action_size)


self.targetQs_ = tf.placeholder(tf.float32, [None], name='target')
##########################################

self.lstm = tf.contrib.rnn.BasicLSTMCell(hidden_size)

self.lstm_out, self.state = tf.nn.dynamic_rnn(self.lstm,self.inputs_,dtype=tf.float32)

self.reduced_out = self.lstm_out[:,-1,:]
self.reduced_out = tf.reshape(self.reduced_out,shape=[-1,hidden_size])

#########################################

#self.w1 = tf.Variable(tf.random_uniform([state_size,hidden_size]))
#self.b1 = tf.Variable(tf.constant(0.1,shape=[hidden_size]))
#self.h1 = tf.matmul(self.inputs_,self.w1) + self.b1
#self.h1 = tf.nn.relu(self.h1)
#self.h1 = tf.contrib.layers.layer_norm(self.h1)
#'''

self.w2 = tf.Variable(tf.random_uniform([hidden_size,hidden_size]))
self.b2 = tf.Variable(tf.constant(0.1,shape=[hidden_size]))
self.h2 = tf.matmul(self.reduced_out,self.w2) + self.b2
self.h2 = tf.nn.relu(self.h2)
self.h2 = tf.contrib.layers.layer_norm(self.h2)

self.w3 = tf.Variable(tf.random_uniform([hidden_size,action_size]))
self.b3 = tf.Variable(tf.constant(0.1,shape=[action_size]))
self.output = tf.matmul(self.h2,self.w3) + self.b3


#self.output = tf.contrib.layers.layer_norm(self.output)


'''
self.fc1 = tf.contrib.layers.fully_connected(self.inputs_, hidden_size)
self.fc2 = tf.contrib.layers.fully_connected(self.fc1, hidden_size)


self.output = tf.contrib.layers.fully_connected(self.fc2, action_size,activation_fn=None)

'''
self.Q = tf.reduce_sum(tf.multiply(self.output, one_hot_actions), axis=1)

self.loss = tf.reduce_mean(tf.square(self.targetQs_ - self.Q))
self.opt = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)

self.targetQs_ = tf.compat.v1.placeholder(
tf.float32, [None], name='target')

# Build the LSTM with the TF1 rnn_cell API
lstm_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(hidden_size)
self.lstm_out, self.state = tf.compat.v1.nn.dynamic_rnn(
lstm_cell, self.inputs_, dtype=tf.float32)

# Keep only the output of the last time step in the sequence
self.reduced_out = self.lstm_out[:, -1, :]
self.reduced_out = tf.reshape(
self.reduced_out, shape=[-1, hidden_size])

# Shared feed-forward layer
self.w2 = tf.Variable(tf.random.uniform([hidden_size, hidden_size]))
self.b2 = tf.Variable(tf.constant(0.1, shape=[hidden_size]))
self.h2 = tf.nn.relu(tf.matmul(self.reduced_out, self.w2) + self.b2)

# ---------------------
# Dueling architecture
# ---------------------
# Value stream V(s)
self.w_value = tf.Variable(tf.random.uniform([hidden_size, 1]))
self.b_value = tf.Variable(tf.constant(0.1, shape=[1]))
self.value = tf.matmul(self.h2, self.w_value) + self.b_value # [batch, 1]

# Advantage stream A(s, a)
self.w_adv = tf.Variable(tf.random.uniform([hidden_size, action_size]))
self.b_adv = tf.Variable(tf.constant(0.1, shape=[action_size]))
self.advantage = tf.matmul(self.h2, self.w_adv) + self.b_adv # [batch, action_size]

# Combine into Q values
adv_mean = tf.reduce_mean(self.advantage, axis=1, keepdims=True)
self.output = self.value + (self.advantage - adv_mean)

# Q values and loss
self.Q = tf.reduce_sum(tf.multiply(self.output, one_hot_actions), axis=1)
self.loss = tf.reduce_mean(tf.square(self.targetQs_ - self.Q))
self.opt = tf.compat.v1.train.AdamOptimizer(
learning_rate).minimize(self.loss)
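For intuition, a quick numeric check of the dueling combination above (illustrative only, not part of the patch): with a hypothetical V(s) = 1.0 and A(s, .) = [2.0, 0.0], the mean advantage is 1.0, so Q(s, .) = [2.0, 0.0]; the advantage stream sets the relative ordering of the actions while the value stream sets the overall level.

import numpy as np

value = np.array([[1.0]])           # hypothetical V(s), shape [1, 1]
advantage = np.array([[2.0, 0.0]])  # hypothetical A(s, a), shape [1, 2]
q = value + (advantage - advantage.mean(axis=1, keepdims=True))
print(q)  # [[2. 0.]]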

from collections import deque

class Memory():
"""
Simple uniform experience replay (no priority weights yet), kept minimal so it can later be extended to prioritized experience replay (PER).
"""

def __init__(self, max_size=1000):
self.buffer = deque(maxlen=max_size)

def add(self, experience):
self.buffer.append(experience)
def sample(self, batch_size,step_size):
idx = np.random.choice(np.arange(len(self.buffer)-step_size),
size=batch_size, replace=False)

res = []

def sample(self, batch_size, step_size):
idx = np.random.choice(
np.arange(len(self.buffer) - step_size),
size=batch_size, replace=False)

res = []
for i in idx:
temp_buffer = []
temp_buffer = []
for j in range(step_size):
temp_buffer.append(self.buffer[i+j])
temp_buffer.append(self.buffer[i + j])
res.append(temp_buffer)
return res


return res
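For illustration only (not part of the patch), a minimal sketch of how the windowed sampling in Memory might be exercised, assuming drqn.py and its TensorFlow/NumPy imports are available and that each stored experience is a (state, action, reward, next_state) tuple; the values below are placeholders.

from drqn import Memory

memory = Memory(max_size=1000)
for t in range(50):
    memory.add(([0.0] * 4, 0, 0.0, [0.0] * 4))  # toy experience tuple

batch = memory.sample(batch_size=8, step_size=5)
# batch is a list of 8 windows, each holding 5 consecutive experiences
print(len(batch), len(batch[0]))  # 8 5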
90 changes: 73 additions & 17 deletions train.py
@@ -5,13 +5,38 @@
import matplotlib.pyplot as plt
from collections import deque
import os
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import time

TIME_SLOTS = 100000 # number of time-slots to run simulation
NUM_CHANNELS = 2 # Total number of channels
NUM_USERS = 3 # Total number of users
ATTEMPT_PROB = 1 # attempt probability of ALOHA based models
import argparse
import csv

# -----------------------------
# Configuration (with CLI args)
# -----------------------------

DEFAULT_TIME_SLOTS = 100000 # default number of time-slots
DEFAULT_NUM_CHANNELS = 2 # default number of channels
DEFAULT_NUM_USERS = 3 # default number of users
DEFAULT_ATTEMPT_PROB = 1 # default attempt probability of ALOHA based models

parser = argparse.ArgumentParser(
description="Deep Multi-User RL for Dynamic Spectrum Access")
parser.add_argument("--time-slots", type=int, default=DEFAULT_TIME_SLOTS,
help="number of time-slots to run simulation")
parser.add_argument("--num-channels", type=int, default=DEFAULT_NUM_CHANNELS,
help="total number of channels")
parser.add_argument("--num-users", type=int, default=DEFAULT_NUM_USERS,
help="total number of users")
parser.add_argument("--attempt-prob", type=float, default=DEFAULT_ATTEMPT_PROB,
help="attempt probability of ALOHA-based models")

args, _ = parser.parse_known_args()

TIME_SLOTS = args.time_slots
NUM_CHANNELS = args.num_channels
NUM_USERS = args.num_users
ATTEMPT_PROB = args.attempt_prob
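With these flags the simulation can be launched as, for example, python train.py --time-slots 50000 --num-users 4 --num-channels 2; since parse_known_args is used above, any unrecognized arguments (for instance ones injected by a notebook kernel) are ignored rather than raising an error.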

#Creates a one-hot vector of length len with a 1 at index num
def one_hot(num,len):
@@ -53,13 +78,16 @@ def state_generator(action,obs):
beta = 1 # Annealing constant for Monte Carlo

# resetting the default tensorflow computational graph
tf.reset_default_graph()
tf.compat.v1.reset_default_graph()

#initializing the environment
env = env_network(NUM_USERS,NUM_CHANNELS,ATTEMPT_PROB)

#initializing deep Q network
mainQN = QNetwork(name='main',hidden_size=hidden_size,learning_rate=learning_rate,step_size=step_size,state_size=state_size,action_size=action_size)
#initializing deep Q network (online and target for Double DQN)
mainQN = QNetwork(name='main',hidden_size=hidden_size,learning_rate=learning_rate,
step_size=step_size,state_size=state_size,action_size=action_size)
targetQN = QNetwork(name='target',hidden_size=hidden_size,learning_rate=learning_rate,
step_size=step_size,state_size=state_size,action_size=action_size)

# experience replay buffer (deque) from which each batch is sampled and fed to the neural network for training
memory = Memory(max_size=memory_size)
@@ -214,13 +242,22 @@ def get_next_states_user(batch):
interval = 1 # debug interval

# saver object to save the checkpoints of the DQN to disk
saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()

#initializing the session
sess = tf.Session()
sess = tf.compat.v1.Session()

#initializing all the tensorflow variables
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())

# -----------------------------
# Double DQN target update ops
# -----------------------------
main_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='main')
target_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='target')
update_target_ops = [t.assign(m) for m, t in zip(main_vars, target_vars)]
sess.run(update_target_ops)
TARGET_UPDATE_INTERVAL = 1000


#list of total rewards
@@ -374,12 +411,16 @@ def get_next_states_user(batch):
rewards = np.reshape(rewards,[-1,rewards.shape[2]])
next_states = np.reshape(next_states,[-1,next_states.shape[2],next_states.shape[3]])

# creating target vector (possible best action)
target_Qs = sess.run(mainQN.output,feed_dict={mainQN.inputs_:next_states})

# creating target vector using Double DQN:
# use main network to select actions, target network to evaluate them
next_Q_main = sess.run(mainQN.output, feed_dict={mainQN.inputs_: next_states})
best_actions = np.argmax(next_Q_main, axis=1)
next_Q_target = sess.run(targetQN.output, feed_dict={targetQN.inputs_: next_states})

# Q_target = reward + gamma * Q_next
targets = rewards[:,-1] + gamma * np.max(target_Qs,axis=1)
# Q_target = reward + gamma * Q_target(s', argmax_a Q_main(s', a))
batch_indices = np.arange(next_Q_target.shape[0])
next_best_q = next_Q_target[batch_indices, best_actions]
targets = rewards[:, -1] + gamma * next_best_q

# calculating the loss and training with the Adam optimizer
loss, _ = sess.run([mainQN.loss,mainQN.opt],
@@ -390,6 +431,9 @@

# Training block ends
########################################################################################
# Periodically update target network parameters from main network
if time_step % TARGET_UPDATE_INTERVAL == 0:
sess.run(update_target_ops)

if time_step %5000 == 4999:
plt.figure(1)
@@ -414,6 +458,18 @@
cum_r = [0]
cum_collision = [0]
saver.save(sess,'checkpoints/dqn_multi-user.ckpt')
# save a one-line CSV summary for this window
with open('results_summary.csv', 'a', newline='') as f:
writer = csv.writer(f)
writer.writerow([
TIME_SLOTS,
NUM_CHANNELS,
NUM_USERS,
ATTEMPT_PROB,
time_step + 1,
cum_r[-1],
cum_collision[-1]
])
#print time_step,loss , sum(reward) , Qs

print ("*************************************************")
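For illustration only (not part of the patch): the summary rows appended to results_summary.csv above carry no header line, so the column order has to be supplied when reading them back. A minimal sketch using only the standard library, assuming the file sits in the working directory:

import csv

columns = ["time_slots", "num_channels", "num_users", "attempt_prob",
           "time_step", "cumulative_reward", "cumulative_collisions"]

with open("results_summary.csv", newline="") as f:
    for row in csv.reader(f):
        print(dict(zip(columns, row)))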