-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotting_script.py
More file actions
29 lines (19 loc) · 808 Bytes
/
plotting_script.py
File metadata and controls
29 lines (19 loc) · 808 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
def smooth_rewards(all_rewards, smoothing_window=5):
    """Return the rolling-smoothed mean and std of rewards across runs.

    Parameters
    ----------
    all_rewards : array_like
        Rewards with runs stacked along axis 0 and episodes along axis 1
        (the original script averaged "5 runs" over axis 0).  A 1-D array
        is treated as a single run.
    smoothing_window : int
        Width of the trailing rolling window applied to both curves.

    Returns
    -------
    tuple of (pandas.Series, pandas.Series)
        Smoothed per-episode mean and smoothed per-episode std.  Because
        ``min_periods == smoothing_window``, the first
        ``smoothing_window - 1`` entries of each are NaN.
    """
    # A lone run (1-D) becomes one row, so the mean is that run and std is 0.
    rewards = np.atleast_2d(np.asarray(all_rewards, dtype=float))
    mean = pd.Series(np.mean(rewards, axis=0))
    std = pd.Series(np.std(rewards, axis=0))
    # Smooth the std with the same window as the mean so the shaded band
    # lines up with the plotted curve instead of jittering around it.
    smoothed_mean = mean.rolling(smoothing_window, min_periods=smoothing_window).mean()
    smoothed_std = std.rolling(smoothing_window, min_periods=smoothing_window).mean()
    return smoothed_mean, smoothed_std


def plot_rewards(mean, std, color='red', title='Q_learning in HIV Drug Scheduling'):
    """Plot a mean reward curve with a +/- one-std shaded band on the current axes."""
    episodes = range(len(mean))
    plt.plot(episodes, mean, color=color)
    plt.fill_between(episodes, mean + std, mean - std,
                     alpha=0.2, edgecolor=color, facecolor=color)
    plt.xlabel('Number of episodes')
    plt.ylabel('Average Cumulative Reward')
    plt.title(title)


def main(path='All_rewards_Q_learning.npy', smoothing_window=5):
    """Load the saved Q-learning rewards, then plot and show the smoothed curve."""
    q_learning = np.load(path)
    smoothed_mean, smoothed_std = smooth_rewards(q_learning, smoothing_window)
    plot_rewards(smoothed_mean, smoothed_std)
    plt.show()


if __name__ == '__main__':
    main()