|
27 | 27 | import numpy as np |
28 | 28 | import tensorflow.compat.v1 as tf |
29 | 29 | from tensorflow.io import gfile |
| 30 | +import tensorstore as ts |
30 | 31 |
|
31 | 32 |
|
32 | 33 | def create_filename_queue(coordinates_file_pattern, shuffle=True): |
@@ -323,21 +324,107 @@ def weighted_load_patch_coordinates( |
323 | 324 | ) |
324 | 325 |
|
325 | 326 |
|
326 | | -def load_from_numpylike(coordinates, volume_names, shape, volume_map, |
327 | | - name=None): |
def _filter_masked(item, volinfo_map_string: str):
  """Tests whether the voxel at the item's coordinate has a positive value.

  A single-voxel (1x1x1), single-channel int64 patch is read through
  `load_from_volume` at `item['coord']` from the volume `item['volname']`.

  NOTE(review): presumably the volumes referenced here are masks and this is
  used as a filtering predicate -- confirm against the caller.

  Args:
    item: mapping providing 'coord' and 'volname' entries, forwarded as-is to
      `load_from_volume`.
    volinfo_map_string: volume info map in the
      "volname1:volinfo_path1,volname2:volinfo_path2" format.

  Returns:
    The result of comparing element [0, 0, 0, 0, 0] of the loaded data
    against zero (true when that value is strictly positive).
  """
  single_voxel = load_from_volume(
      coord=item['coord'],
      volname=item['volname'],
      patch_size=(1, 1, 1),
      dtype=tf.int64,
      num_channels=1,
      volinfo_map_string=volinfo_map_string,
  )
  # The loaded data is indexed with five indices; only its first element is
  # inspected.
  first_value = single_voxel[0, 0, 0, 0, 0]
  return first_value > 0
| 337 | + |
| 338 | + |
def _parse_volinfo_map(volinfo_map_string: str) -> dict:
  """Parses "name1:path1,name2:path2" into a {name: path} dictionary."""
  volinfo_map = {}
  for pair in volinfo_map_string.split(','):
    if not pair.strip():
      # Tolerate empty entries, e.g. the one produced by a trailing comma.
      continue
    # Split on the first colon only, so that paths which themselves contain
    # colons (such as "gs://bucket/volinfo") are kept intact.
    name, separator, path = pair.partition(':')
    name = name.strip()
    path = path.strip()
    if not separator or not name or not path:
      raise ValueError(
          'Malformed volinfo map entry %r; expected "volname:volinfo_path".'
          % (pair,)
      )
    volinfo_map[name] = path
  return volinfo_map


def load_from_volume(
    coord, volname, patch_size, dtype, num_channels, volinfo_map_string: str
):
  """Loads data from a volume using TensorStore.

  For every (coordinate, volume name) pair, the volume is opened with the
  'volumestore' TensorStore driver and a patch centered on the coordinate is
  read. The requested extent is clamped to [0, store.shape[i]) along each of
  the first three store dimensions.

  NOTE(review): because of the clamping, patches that touch the volume
  boundary come out smaller than `patch_size`; nothing here pads them --
  confirm that callers only request coordinates away from the boundary.

  NOTE(review): the per-element loader calls `.numpy()` on its inputs and
  `volname` is iterated in Python, which looks like it requires eager
  tensors / Python sequences -- confirm this is never traced in graph mode.

  Args:
    coord: The coordinates to load from; one 3-element coordinate per patch,
      in the same axis order as the first three dimensions of the store.
    volname: The names of the volumes, one per coordinate. Each name must be
      a key of the volume info map.
    patch_size: The size of the patch to load (3-sequence, same axis order as
      `coord`).
    dtype: The TensorFlow data type the loaded data is cast to.
    num_channels: The number of channels in the volume. Must be 1.
    volinfo_map_string: A string representation of the volume info map with
      the format "volname1:volinfo_path1,volname2:volinfo_path2". Paths may
      contain colons; only the first colon of an entry separates the name
      from the path.

  Returns:
    A tensor containing the loaded data: one leading axis over the elements
    of `coord`, the three spatial axes in reversed order relative to the
    store, and a trailing singleton channel axis.

  Raises:
    ValueError: If `num_channels` is not 1 or `volinfo_map_string` contains a
      malformed entry.
    KeyError: If a name in `volname` is missing from the volume info map.
  """
  if num_channels != 1:
    raise ValueError('Only num_channels=1 is currently supported.')

  volinfo_map = _parse_volinfo_map(volinfo_map_string)
  patch_size = list(patch_size)

  def _load_single_volume(inputs):
    """Reads the patch for one (coordinate, volinfo path) pair."""
    center_tensor, path_tensor = inputs
    volinfo_path = path_tensor.numpy().decode('utf-8')
    center = center_tensor.numpy()
    spec = {'driver': 'volumestore', 'volinfo_path': volinfo_path}

    store = ts.open(spec, open=True).result()

    # Center the patch on `center`; for odd sizes the extra voxel goes on the
    # upper side. Both ends are clamped to the bounds of the store.
    start_coord = [max(0, c - (p // 2)) for c, p in zip(center, patch_size)]
    stop_coord = [
        min(store.shape[i], c + (p // 2) + (p % 2))
        for i, (c, p) in enumerate(zip(center, patch_size))
    ]

    data = (
        store[
            start_coord[0] : stop_coord[0],
            start_coord[1] : stop_coord[1],
            start_coord[2] : stop_coord[2],
        ]
        .read()
        .result()
    )

    # Keep the first entry of the 4th store dimension, reverse the order of
    # the spatial axes and cast to the requested dtype.
    data = data[:, :, :, 0].transpose(2, 1, 0).astype(dtype.as_numpy_dtype)
    # Re-add a trailing singleton channel axis.
    return data[..., tf.newaxis]

  # Convert lists to tensors for tf.map_fn.
  coords_tensor = tf.convert_to_tensor(coord)
  volinfo_paths_tensor = tf.convert_to_tensor(
      [volinfo_map[v].encode('utf-8') for v in volname], dtype=tf.string
  )

  # Use tf.map_fn to process each (coordinate, volume) pair. Only
  # `fn_output_signature` is passed: `dtype` is its deprecated alias and is
  # redundant when both are given.
  return tf.map_fn(
      _load_single_volume,
      (coords_tensor, volinfo_paths_tensor),
      fn_output_signature=dtype,
  )
| 410 | + |
| 411 | + |
| 412 | +def load_from_numpylike( |
| 413 | + coordinates, volume_names, shape, volume_map, name=None |
| 414 | +): |
328 | 415 | """TensorFlow Python op that loads data from Numpy-like volumes. |
329 | 416 |
|
330 | 417 | The volume object must support Numpy-like indexing, as well as shape, ndim, |
331 | 418 | and dtype properties. The volume can be 3d or 4d. |
332 | 419 |
|
333 | 420 | Args: |
334 | | - coordinates: tensor of shape [1, 3] containing XYZ coordinates of the |
335 | | - center of the subvolume to load. |
| 421 | + coordinates: tensor of shape [1, 3] containing XYZ coordinates of the center |
| 422 | + of the subvolume to load. |
336 | 423 | volume_names: tensor of shape [1] containing names of volumes to load data |
337 | | - from. |
| 424 | + from. |
338 | 425 | shape: a 3-sequence giving the XYZ shape of the data to load. |
339 | 426 | volume_map: a dictionary mapping volume names to volume objects. See above |
340 | | - for API requirements of the Numpy-like volume objects. |
| 427 | + for API requirements of the Numpy-like volume objects. |
341 | 428 | name: the op name. |
342 | 429 |
|
343 | 430 | Returns: |
|
0 commit comments