-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnice_plot.py
More file actions
73 lines (61 loc) · 2.28 KB
/
nice_plot.py
File metadata and controls
73 lines (61 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import matplotlib.pyplot as plt
from pylab import genfromtxt
import sys
import argparse
import numpy as np
# Input format (the file named by the --file command-line argument):
# tab-separated text whose first row holds the column names and whose
# remaining rows hold numbers, for example
#
#     x    f(x)    g(x)    h(x)
#
# Single-letter matplotlib colour codes: blue, red, yellow, black.
# NOTE(review): currently not referenced by plot().
COLORS = list("bryk")
# Canned column sets for the non-custom modes: (columns, y-axis label).
# The first column is the x axis; every later column is drawn as one line.
_MODE_COLUMNS = {
    "iou": (["Epochs", "Train iou", "Val iou"], "IoU"),
    "loss": (["Epochs", "Train Loss", "Val Loss"], "Loss"),
    "size": (["Training Set Size", "Train Loss", "Val Loss"], "Loss"),
}

# Header spellings normalised before matching and display, so files written
# with a lowercase "iou" still match the canned columns and show "IoU".
_LABEL_FIXES = {"Val iou": "Val IoU", "Train iou": "Train IoU"}


def _parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Exits with a usage error when
        ``--mode custom`` is given without at least two ``--cols``.
    """
    parser = argparse.ArgumentParser(description="Make Nice Plots")
    parser.add_argument("--mode", default=None, required=True,
                        choices=sorted(_MODE_COLUMNS) + ["custom"],
                        help="Which plots iou, loss, size, custom")
    parser.add_argument("--title", default=None,
                        help="title")
    parser.add_argument("--file", default=None, required=True,
                        help="Input plot")
    parser.add_argument("--show", action="store_true",
                        help="Shows plot instead of saving it")
    parser.add_argument("--cols", default=None, nargs="*",
                        help="custom mode only: x column followed by y columns")
    args = parser.parse_args(argv)
    # Custom mode needs an x column plus at least one y column; without this
    # check the plotting code would fail later with an unrelated error.
    if args.mode == "custom" and len(args.cols or []) < 2:
        parser.error("--mode custom needs --cols X_COLUMN Y_COLUMN [Y_COLUMN ...]")
    return args


def _clean_names(names):
    """Return *names* as a string array with ``_LABEL_FIXES`` renames applied."""
    return np.array([_LABEL_FIXES.get(name, name) for name in names], dtype=str)


def _read_table(file_name):
    """Read a tab-separated file whose first row is a header.

    Returns:
        ``(labels, data)``: a 1-D array of cleaned header names and a 2-D
        float array with one row per data line and one column per header.
    """
    raw_labels = np.genfromtxt(file_name, delimiter="\t", max_rows=1, dtype=str)
    labels = _clean_names(np.atleast_1d(raw_labels))
    data = np.genfromtxt(file_name, delimiter="\t", skip_header=1)
    if data.ndim == 1:
        # genfromtxt collapses a single data row (or a single column) to 1-D;
        # restore the (rows, columns) layout so data[:, i] keeps working.
        data = data.reshape(-1, labels.size)
    return labels, data


def _column_index(labels, name):
    """Return the position of header *name* in *labels*.

    Raises:
        IndexError: if no header matches *name*.
    """
    matches = np.flatnonzero(labels == name)
    if matches.size == 0:
        raise IndexError("column {!r} not found; available columns: {}".format(
            name, ", ".join(labels)))
    return int(matches[0])


def plot(argv=None):
    """Plot columns of a tab-separated results file (command-line entry point).

    Options are read from *argv* (``sys.argv[1:]`` when ``None``):

    * ``--file``  input file; first row holds tab-separated column names,
      the remaining rows hold numbers.
    * ``--mode``  ``iou``, ``loss`` or ``size`` select a canned set of
      columns; ``custom`` plots ``--cols`` (x column first, then y columns).
    * ``--title`` optional figure title.
    * ``--show``  display the figure instead of saving it.

    Unless ``--show`` is given, the figure is written to
    ``<input file name without extension>_<mode>.png`` in the current
    working directory.

    Raises:
        IndexError: if a requested column is missing from the file header.
    """
    import os.path

    args = _parse_args(argv)
    if args.mode == "custom":
        cols, y_label = list(args.cols), None
    else:
        cols, y_label = _MODE_COLUMNS[args.mode]
    cols = _clean_names(cols)
    labels, data = _read_table(args.file)

    # Every line shares the same x column, so resolve it once.
    x_idx = _column_index(labels, cols[0])
    for name in cols[1:]:
        y_idx = _column_index(labels, name)
        plt.plot(data[:, x_idx], data[:, y_idx], label=labels[y_idx], linewidth=4.0)

    plt.legend(loc=0)  # loc=0 is matplotlib's "best" placement
    plt.xlabel(labels[x_idx], fontsize=16)
    if y_label is not None:
        plt.ylabel(y_label, fontsize=16)
    if args.mode == "size":
        # ``bottom`` is the supported keyword; ``ymin`` is deprecated.
        plt.ylim(bottom=0)
    if args.title:
        plt.suptitle(args.title, fontsize=20)

    # Either show or save, never both. Saving after show() can write an
    # image of a figure that has already been displayed and dismissed.
    if args.show:
        plt.show()
    else:
        base_name = os.path.splitext(os.path.basename(args.file))[0]
        plt.savefig(base_name + "_" + args.mode + ".png")
if __name__ == "__main__":
    # Script entry point: plot() parses sys.argv and draws the figure.
    plot()