-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvis.py
More file actions
executable file
·209 lines (186 loc) · 6.98 KB
/
vis.py
File metadata and controls
executable file
·209 lines (186 loc) · 6.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
Visualize generated and reconstructed pressure fields, sound speed, and sensor locations in 3D using vedo.
Usage:
python vis.py [data_path] # leave data_path empty to use the default path in params.yaml
"""
import os
import jax.numpy as jnp
import numpy as np
import vedo
from vedo import Plotter, Volume, Text2D, Points, Slider2D, color_map
import glob
import util as u
import argparse
# Select vedo's standard VTK backend (an interactive render window) instead of
# one of vedo's notebook backends.
vedo.settings.default_backend = "vtk"
class VolumeVisualizer:
    """Interactive vedo viewer for one dataset directory.

    For the current ``file_index`` the scene shows, read from ``data_path``:

    * ``p0/<file>.npy``         -- isosurface at level 0.5
    * ``p_r/<file>_<iter>.npy`` -- volume for the selected iteration
    * ``c/<file>.npy``          -- sound-speed volume, thresholdable by a slider
    * ``sensors/<file>.npy``    -- sensor positions drawn as points

    ``<`` / ``>`` buttons step through files, a slider steps through
    iterations, and a button toggles the visibility of ``c``.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.file_index = 0
        self.plotter = Plotter(axes=6, bg="white", size=(1600, 800))
        self.max_iteration_index = 0
        self.update_max_indices()
        # Open on the last available iteration of file 0.
        self.iteration_index = max(self.max_iteration_index, 0)
        # Untouched copies of the loaded volumes; apply_threshold always starts
        # from these so successive thresholds never accumulate.
        self.c_vol_original = Volume()
        self.p_r_original = Volume()
        # Volume currently on screen as the result of the last thresholding.
        self.last_thresholded_item = None
        # The c volume as loaded for the current file (before thresholding).
        self.c_vol = None
        # Whether c is shown; mirrors the "c is visible/hidden" button state.
        self.c_visible = True
        # Sliders are created on the first load_data() call and reused after
        # that, because Plotter.clear() does not remove widgets.
        self.iteration_slider = None
        self.threshold_slider = None
        self.items = self.load_data()

    def update_max_indices(self):
        """Refresh ``max_file_index`` and, for the current file, ``max_iteration_index``.

        Only ``.npy`` files are counted. Both values are -1 when nothing
        matches. NOTE(review): assumes files are numbered contiguously from 0.
        """
        p0_files = glob.glob(os.path.join(self.data_path, "p0", "*.npy"))
        self.max_file_index = len(p0_files) - 1
        p_r_files = glob.glob(
            os.path.join(self.data_path, "p_r", f"{self.file_index}_*.npy")
        )
        self.max_iteration_index = len(p_r_files) - 1

    def _c_bounds(self):
        """Return ``(lower, upper) = u.C -/+ u.C_VARIATION_AMPLITUDE / 2``."""
        lower = u.C - u.C_VARIATION_AMPLITUDE / 2
        upper = u.C + u.C_VARIATION_AMPLITUDE / 2
        return lower, upper

    def _style_c_volume(self, c_vol):
        """Apply the shared c colour map, opacity and scalar bar; name it "c"; return it."""
        lower, upper = self._c_bounds()
        colors = [
            (lower - 1, [0.0, 1.0, 0.0]),
            (u.C, [1.0, 1.0, 1.0]),
            (upper + 1, [1.0, 0.0, 1.0]),
        ]
        alpha = [0.01, 0.0, 0.01]
        c_vol.cmap(colors, alpha=alpha, vmin=lower - 1, vmax=upper + 1).add_scalarbar(
            "c", pos=(0.775 - 0.1, 0.05)
        )
        c_vol.name = "c"
        return c_vol

    def _sync_iteration_slider(self):
        """Create the iteration slider once; afterwards fit it to the current file."""
        if self.iteration_slider is None:
            self.iteration_slider = self.plotter.add_slider(
                self.change_iteration_index,
                0,
                self.max_iteration_index,
                value=self.iteration_index,
                title="Iteration Index",
                pos="top-right",
            )
            return
        # Files can have different iteration counts. Update the range before
        # the value so the value is not clamped to the previous file's range.
        self.iteration_slider.GetRepresentation().SetMaximumValue(
            self.max_iteration_index
        )
        self.iteration_slider.value = self.iteration_index

    def _sync_threshold_slider(self):
        """Create the c-threshold slider once; afterwards reset it to its start position."""
        lower, upper = self._c_bounds()
        if self.threshold_slider is None:
            self.threshold_slider = self.plotter.add_slider(
                # self.c_vol is looked up at call time, so the slider always
                # targets the currently loaded file.
                lambda widget, event: self.apply_threshold(widget, event, self.c_vol),
                lower,
                upper,
                title="c Threshold",
                value=upper,
                pos=[(0.5, 0.12), (0.7, 0.12)],
            )
            return
        # A freshly loaded c volume has no threshold applied; move the handle back.
        self.threshold_slider.value = upper

    def load_data(self):
        """Clear the scene and rebuild it for the current file and iteration.

        Returns:
            list: ``[p0 isosurface, p_r volume, c volume, sensor points]``.
        """
        self.plotter.clear()
        items = []

        # p0: faint green isosurface at level 0.5.
        p0_file = os.path.join(self.data_path, "p0", f"{self.file_index}.npy")
        p0_vol = self.load_volume(p0_file).isosurface(0.5).alpha(0.1).cmap("Greens")
        items.append(p0_vol)

        # p_r: blue at the minimum, white at 0, red at the maximum.
        p_r_file = os.path.join(
            self.data_path, "p_r", f"{self.file_index}_{self.iteration_index}.npy"
        )
        p_r_vol = self.load_volume(p_r_file)
        vrange = p_r_vol.scalar_range()
        colors = [
            (vrange[0], [0.0, 0.0, 1.0]),
            (0, [1.0, 1.0, 1.0]),
            (vrange[1], [1.0, 0.0, 0.0]),
        ]
        alpha = [1.0, 0.0, 1.0]
        p_r_vol.cmap(colors, alpha=alpha).add_scalarbar("p_r").mode(1).alpha(0.1)
        p_r_vol.name = "p_r"
        self.p_r_original = p_r_vol.copy()
        items.append(p_r_vol)
        self._sync_iteration_slider()

        # c: sound speed, styled exactly like the thresholded versions.
        c_file = os.path.join(self.data_path, "c", f"{self.file_index}.npy")
        c_vol = self._style_c_volume(self.load_volume(c_file))
        self.c_vol = c_vol
        self.c_vol_original = c_vol.copy()
        self.last_thresholded_item = c_vol
        items.append(c_vol)
        self._sync_threshold_slider()

        sensors_file = os.path.join(self.data_path, "sensors", f"{self.file_index}.npy")
        items.append(self.load_sensors(sensors_file))

        self.plotter.add(items)
        # Keep a hidden c hidden across reloads so the button label stays true.
        if not self.c_visible:
            c_vol.off()
        return items

    def load_volume(self, file_path):
        """Load a ``.npy`` array as a vedo Volume; return an empty Volume if the file is missing."""
        if not os.path.exists(file_path):
            print(f"File {file_path} does not exist")
            return Volume()
        # The array only feeds VTK, so plain NumPy is enough. jnp.load would
        # add a JAX device round-trip and canonicalise dtypes (e.g. float64 ->
        # float32 when JAX x64 mode is off).
        return Volume(np.load(file_path))

    def load_sensors(self, file_path):
        """Load sensor coordinates as vedo Points; return empty Points if the file is missing.

        The caller (load_data) adds the returned Points to the plotter.
        """
        if not os.path.exists(file_path):
            print(f"File {file_path} does not exist")
            return Points()
        # Transposed so each row becomes one point; assumes the file stores
        # (3, n_sensors) -- TODO confirm against the data writer.
        return Points(np.load(file_path).T)

    def change_file_index(self, widget, event, increment):
        """Button callback: step ``increment`` files (clamped to range) and reload."""
        self.file_index = int(
            np.clip(self.file_index + increment, 0, max(self.max_file_index, 0))
        )
        self.update_max_indices()
        # The new file may have fewer iterations; keep the index loadable.
        self.iteration_index = int(
            np.clip(self.iteration_index, 0, max(self.max_iteration_index, 0))
        )
        self.items = self.load_data()
        self.plotter.render()

    def change_iteration_index(self, widget, event):
        """Slider callback: snap to an integer iteration and reload when it changed.

        The callback can fire many times per drag. The full reload is skipped
        while the integer index stays the same.
        """
        new_index = int(widget.value)
        widget.value = new_index
        if new_index == self.iteration_index:
            return
        self.iteration_index = new_index
        self.items = self.load_data()
        self.plotter.render()

    def apply_threshold(self, widget, event, item):
        """Replace the last thresholded volume with a freshly thresholded copy of ``item``'s data.

        ``item.name`` selects the source: "p_r" uses ``p_r_original`` (values
        below the slider set to 0). "c" uses ``c_vol_original`` (values below
        the slider set to ``u.C``, then restyled). vedo's
        ``threshold(below=..., replace=0)`` is applied afterwards in both cases.
        """
        threshold = widget.value
        if item.name == "p_r":
            # tonumpy() is a read-write view of the voxel buffer; copy it so
            # the stored original is never modified.
            p_r_data = self.p_r_original.tonumpy().copy()
            p_r_data[p_r_data < threshold] = 0
            item = Volume(p_r_data)
        elif item.name == "c":
            c_data = self.c_vol_original.tonumpy().copy()
            c_data[c_data < threshold] = u.C
            item = self._style_c_volume(Volume(c_data))
        item.threshold(below=threshold, replace=0)
        self.plotter.remove(self.last_thresholded_item)
        self.plotter.add(item)
        if item.name == "c" and not self.c_visible:
            item.off()
        self.plotter.render()
        self.last_thresholded_item = item

    def show_hide_c(self, widget, event):
        """Button callback: toggle visibility of the c volume currently on screen."""
        self.c_visible = not self.c_visible
        # The c volume on screen is the latest thresholded one (or the loaded
        # one if no threshold was applied since the last reload).
        c_vol = self.last_thresholded_item
        if c_vol is not None:
            if self.c_visible:
                c_vol.on()
            else:
                c_vol.off()
        widget.switch()
        self.plotter.render()

    def show(self):
        """Add the file-navigation and c-visibility buttons, then open the window."""
        self.plotter.add_button(
            lambda widget, event: self.change_file_index(widget, event, -1),
            states=("<",),
            pos=(0.02, 0.05),
        )
        self.plotter.add_button(
            lambda widget, event: self.change_file_index(widget, event, 1),
            states=(">",),
            pos=(0.06, 0.05),
        )
        self.plotter.add_button(
            self.show_hide_c,
            states=("c is visible", "c is hidden"),
            pos=(0.06, 0.1),
        )
        self.plotter.show()
if __name__ == "__main__":
    # Optional positional data directory; when omitted, util's DATA_PATH is used.
    cli = argparse.ArgumentParser()
    cli.add_argument("data_path", nargs="?", type=str, default=None, help="data path")
    options = cli.parse_args()
    if options.data_path is not None:
        # Rebind the attribute on the util module itself so every later read
        # of u.DATA_PATH sees the override.
        u.DATA_PATH = options.data_path
    VolumeVisualizer(u.DATA_PATH).show()