
Commit ae3377b

Jin Xu authored and copybara-github committed

Migrate load_from_volume into inputs.py

PiperOrigin-RevId: 646187131

1 parent 1e95fbf commit ae3377b

File tree

1 file changed: +93 -6 lines changed


ffn/training/inputs.py

Lines changed: 93 additions & 6 deletions
@@ -27,6 +27,7 @@
 import numpy as np
 import tensorflow.compat.v1 as tf
 from tensorflow.io import gfile
+import tensorstore as ts


 def create_filename_queue(coordinates_file_pattern, shuffle=True):
@@ -323,21 +324,107 @@ def weighted_load_patch_coordinates(
   )


-def load_from_numpylike(coordinates, volume_names, shape, volume_map,
-                        name=None):
+def _filter_masked(item, volinfo_map_string: str):
+  mask_value = load_from_volume(
+      item['coord'],
+      item['volname'],
+      patch_size=(1, 1, 1),
+      dtype=tf.int64,
+      num_channels=1,
+      volinfo_map_string=volinfo_map_string,
+  )
+  return mask_value[0, 0, 0, 0, 0] > 0
+
+
+def load_from_volume(
+    coord, volname, patch_size, dtype, num_channels, volinfo_map_string: str
+):
+  """Loads data from a volume using TensorStore.
+
+  Args:
+    coord: The coordinates to load from.
+    volname: The name of the volume.
+    patch_size: The size of the patch to load.
+    dtype: The data type of the volume.
+    num_channels: The number of channels in the volume.
+    volinfo_map_string: A string representation of the volume info map with the
+      format "volname1:volinfo_path1,volname2:volinfo_path2".
+
+  Returns:
+    A tensor containing the loaded data.
+  """
+  if num_channels != 1:
+    raise ValueError('Only num_channels=1 is currently supported.')
+
+  volinfo_map = {}
+  for pair in volinfo_map_string.split(','):
+    name, path = pair.split(':')
+    volinfo_map[name.strip()] = path.strip()
+
+  def _load_single_volume(inputs):
+    coord, volinfo_path = inputs
+    print('volinfo_path:', volinfo_path)
+    print('coord:', coord)
+    volinfo_path = volinfo_path.numpy().decode('utf-8')
+    coord = coord.numpy()
+    spec = {'driver': 'volumestore', 'volinfo_path': volinfo_path}
+
+    store = ts.open(spec, open=True).result()
+
+    start_coord = [max(0, c - (p // 2)) for c, p in zip(coord, patch_size)]
+    stop_coord = [
+        min(store.shape[i], c + (p // 2) + (p % 2))
+        for i, (c, p) in enumerate(zip(coord, patch_size))
+    ]
+
+    data = (
+        store[
+            start_coord[0] : stop_coord[0],
+            start_coord[1] : stop_coord[1],
+            start_coord[2] : stop_coord[2],
+        ]
+        .read()
+        .result()
+    )
+
+    data = data[:, :, :, 0].transpose(2, 1, 0).astype(dtype.as_numpy_dtype)
+    data = data[..., tf.newaxis]
+    return data
+
+  patch_size = list(patch_size)
+  # Convert lists to tensors for tf.map_fn
+  coords_tensor = tf.convert_to_tensor(coord)
+  volinfo_paths_tensor = tf.convert_to_tensor(
+      [volinfo_map[v].encode('utf-8') for v in volname], dtype=tf.string
+  )
+
+  # Use tf.map_fn to process each volume
+  data_tensor = tf.map_fn(
+      _load_single_volume,
+      (coords_tensor, volinfo_paths_tensor),
+      fn_output_signature=dtype,
+      dtype=dtype,
+  )
+
+  return data_tensor
+
+
+def load_from_numpylike(
+    coordinates, volume_names, shape, volume_map, name=None
+):
   """TensorFlow Python op that loads data from Numpy-like volumes.

   The volume object must support Numpy-like indexing, as well as shape, ndim,
   and dtype properties. The volume can be 3d or 4d.

   Args:
-    coordinates: tensor of shape [1, 3] containing XYZ coordinates of the
-        center of the subvolume to load.
+    coordinates: tensor of shape [1, 3] containing XYZ coordinates of the center
+      of the subvolume to load.
     volume_names: tensor of shape [1] containing names of volumes to load data
-        from.
+      from.
     shape: a 3-sequence giving the XYZ shape of the data to load.
     volume_map: a dictionary mapping volume names to volume objects. See above
-        for API requirements of the Numpy-like volume objects.
+      for API requirements of the Numpy-like volume objects.
     name: the op name.

   Returns:
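
For reference, a minimal usage sketch of the two helpers added above; this is not part of the commit. The volume names ('em', 'mask') and the .volinfo paths are placeholders, and eager execution is assumed because _load_single_volume calls .numpy() on its inputs, so this sketch will not work inside a traced tf.function or a tf.data pipeline without extra wrapping (e.g. tf.py_function).

# Minimal usage sketch (not part of this commit). Volume names and .volinfo
# paths are placeholders; eager execution (the TF2 default) is assumed.
import tensorflow.compat.v1 as tf

from ffn.training import inputs

# Hypothetical "volname:volinfo_path" map, matching the format documented in
# load_from_volume's docstring.
volinfo_map_string = 'em:/path/to/em.volinfo,mask:/path/to/mask.volinfo'

# Load a 33^3 patch centered on one XYZ coordinate of the 'em' volume.
patch = inputs.load_from_volume(
    coord=[[100, 200, 300]],
    volname=['em'],
    patch_size=(33, 33, 33),
    dtype=tf.uint8,
    num_channels=1,
    volinfo_map_string=volinfo_map_string,
)
# Away from volume boundaries, tf.map_fn stacks the per-coordinate patches
# into a tensor of shape [1, 33, 33, 33, 1].
print(patch.shape)

# Check the single-voxel mask value the same way the training input pipeline
# could via _filter_masked.
item = {'coord': [[100, 200, 300]], 'volname': ['mask']}
keep = inputs._filter_masked(item, volinfo_map_string=volinfo_map_string)
print(bool(keep))  # True when the mask voxel at the center is positive.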
