Skip to content

Commit f34f6ac

Browse files
author
Donglai Wei
committed
fix formatting
1 parent b264e7c commit f34f6ac

File tree

9 files changed

+45
-45
lines changed

9 files changed

+45
-45
lines changed

connectomics/data/augment/monai_transforms.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ def _apply_misalignment_translation(
9999

100100
output = np.zeros(out_shape, img.dtype)
101101
if mode == "slip":
102-
output = img[:, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
103-
output[idx] = img[idx, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
102+
output = img[:, y0: y0 + out_shape[1], x0: x0 + out_shape[2]]
103+
output[idx] = img[idx, y1: y1 + out_shape[1], x1: x1 + out_shape[2]]
104104
else:
105-
output[:idx] = img[:idx, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
106-
output[idx:] = img[idx:, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
105+
output[:idx] = img[:idx, y0: y0 + out_shape[1], x0: x0 + out_shape[2]]
106+
output[idx:] = img[idx:, y1: y1 + out_shape[1], x1: x1 + out_shape[2]]
107107

108108
if is_tensor:
109109
output = torch.from_numpy(output).to(device)
@@ -299,7 +299,7 @@ def _apply_missing_parts(
299299
x_start = self.R.randint(0, img.shape[2] - hole_w + 1)
300300

301301
# Create hole (set to 0 or mean value)
302-
img[section_idx, y_start : y_start + hole_h, x_start : x_start + hole_w] = 0
302+
img[section_idx, y_start: y_start + hole_h, x_start: x_start + hole_w] = 0
303303

304304
return img
305305

@@ -452,24 +452,24 @@ def _apply_cut_noise(
452452
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
453453
region = img[
454454
:,
455-
z_start : z_start + z_len,
456-
y_start : y_start + y_len,
457-
x_start : x_start + x_len,
455+
z_start: z_start + z_len,
456+
y_start: y_start + y_len,
457+
x_start: x_start + x_len,
458458
]
459459
noisy_region = np.clip(region + noise, 0, 1)
460460
img[
461461
:,
462-
z_start : z_start + z_len,
463-
y_start : y_start + y_len,
464-
x_start : x_start + x_len,
462+
z_start: z_start + z_len,
463+
y_start: y_start + y_len,
464+
x_start: x_start + x_len,
465465
] = noisy_region
466466
else:
467467
# (C, H, W) - 2D with channels
468468
noise_shape = (img.shape[0], y_len, x_len)
469469
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
470-
region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
470+
region = img[:, y_start: y_start + y_len, x_start: x_start + x_len]
471471
noisy_region = np.clip(region + noise, 0, 1)
472-
img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
472+
img[:, y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
473473
elif img.ndim == 3:
474474
# 3D case: (Z, Y, X) or (C, H, W)
475475
# Heuristic: if first dim is small (<=4), assume it's channel (2D with channels)
@@ -478,29 +478,29 @@ def _apply_cut_noise(
478478
# (C, H, W) - 2D with channels
479479
noise_shape = (img.shape[0], y_len, x_len)
480480
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
481-
region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
481+
region = img[:, y_start: y_start + y_len, x_start: x_start + x_len]
482482
noisy_region = np.clip(region + noise, 0, 1)
483-
img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
483+
img[:, y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
484484
else:
485485
# (Z, Y, X) - 3D
486486
z_len = max(1, int(self.length_ratio * img.shape[0])) # Ensure at least 1
487487
z_start = self.R.randint(0, max(1, img.shape[0] - z_len + 1))
488488
noise_shape = (z_len, y_len, x_len)
489489
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
490490
region = img[
491-
z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
491+
z_start: z_start + z_len, y_start: y_start + y_len, x_start: x_start + x_len
492492
]
493493
noisy_region = np.clip(region + noise, 0, 1)
494494
img[
495-
z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
495+
z_start: z_start + z_len, y_start: y_start + y_len, x_start: x_start + x_len
496496
] = noisy_region
497497
else:
498498
# 2D case: (H, W)
499499
noise_shape = (y_len, x_len)
500500
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
501-
region = img[y_start : y_start + y_len, x_start : x_start + x_len]
501+
region = img[y_start: y_start + y_len, x_start: x_start + x_len]
502502
noisy_region = np.clip(region + noise, 0, 1)
503-
img[y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
503+
img[y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
504504

505505
if is_tensor:
506506
img = torch.from_numpy(img).to(device)
@@ -886,7 +886,7 @@ def _find_best_paste(
886886
neuron_tensor.flip(0) if neuron_tensor.ndim == 3 else neuron_tensor.flip(1)
887887
)
888888

889-
label_paste = labels[best_idx : best_idx + 1]
889+
label_paste = labels[best_idx: best_idx + 1]
890890

891891
if best_angle != 0:
892892
label_paste = self._rotate_3d(label_paste, best_angle)

connectomics/data/io/io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def read_image_as_volume(filename: str, drop_channel: bool = False) -> np.ndarra
171171
Raises:
172172
ValueError: If file format is not supported
173173
"""
174-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
174+
image_suffix = filename[filename.rfind(".") + 1:].lower()
175175
if image_suffix not in SUPPORTED_IMAGE_FORMATS:
176176
raise ValueError(
177177
f"Unsupported format: {image_suffix}. Supported formats: {SUPPORTED_IMAGE_FORMATS}"
@@ -281,7 +281,7 @@ def read_volume(
281281
if filename.endswith(".nii.gz"):
282282
image_suffix = "nii.gz"
283283
else:
284-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
284+
image_suffix = filename[filename.rfind(".") + 1:].lower()
285285

286286
if image_suffix in ["h5", "hdf5"]:
287287
data = read_hdf5(filename, dataset)
@@ -420,7 +420,7 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
420420
if filename.endswith(".nii.gz"):
421421
image_suffix = "nii.gz"
422422
else:
423-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
423+
image_suffix = filename[filename.rfind(".") + 1:].lower()
424424

425425
if image_suffix in ["h5", "hdf5"]:
426426
# HDF5: Read shape from metadata (no data loading)

connectomics/data/io/tiles.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,22 @@ def reconstruct_volume_from_tiles(
160160
if is_image: # Image data
161161
result[
162162
z - z0,
163-
y_actual_start - y0 : y_actual_end - y0,
164-
x_actual_start - x0 : x_actual_end - x0,
163+
y_actual_start - y0: y_actual_end - y0,
164+
x_actual_start - x0: x_actual_end - x0,
165165
] = patch[
166-
y_actual_start - y_patch_start : y_actual_end - y_patch_start,
167-
x_actual_start - x_patch_start : x_actual_end - x_patch_start,
166+
y_actual_start - y_patch_start: y_actual_end - y_patch_start,
167+
x_actual_start - x_patch_start: x_actual_end - x_patch_start,
168168
0,
169169
]
170170
else: # Label data
171171
result[
172172
z - z0,
173-
y_actual_start - y0 : y_actual_end - y0,
174-
x_actual_start - x0 : x_actual_end - x0,
173+
y_actual_start - y0: y_actual_end - y0,
174+
x_actual_start - x0: x_actual_end - x0,
175175
] = rgb_to_seg(
176176
patch[
177-
y_actual_start - y_patch_start : y_actual_end - y_patch_start,
178-
x_actual_start - x_patch_start : x_actual_end - x_patch_start,
177+
y_actual_start - y_patch_start: y_actual_end - y_patch_start,
178+
x_actual_start - x_patch_start: x_actual_end - x_patch_start,
179179
]
180180
)
181181

connectomics/data/process/crop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def crop_volume(data, sz, st=(0, 0, 0)):
1616
st = np.array(st).astype(np.int32)
1717

1818
if data.ndim == 3:
19-
return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
19+
return data[st[0]: st[0] + sz[0], st[1]: st[1] + sz[1], st[2]: st[2] + sz[2]]
2020
else: # crop spatial dimensions
21-
return data[:, st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
21+
return data[:, st[0]: st[0] + sz[0], st[1]: st[1] + sz[1], st[2]: st[2] + sz[2]]
2222

2323

2424
def get_valid_pos_torch(mask, vol_sz, valid_ratio):
@@ -64,9 +64,9 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
6464
if len(vol_sz) == 3:
6565
mask_sum = (
6666
mask_sum[
67-
pad_sz_pre[0] : pad_sz_post[0],
68-
pad_sz_pre[1] : pad_sz_post[1],
69-
pad_sz_pre[2] : pad_sz_post[2],
67+
pad_sz_pre[0]: pad_sz_post[0],
68+
pad_sz_pre[1]: pad_sz_post[1],
69+
pad_sz_pre[2]: pad_sz_post[2],
7070
]
7171
>= valid_thres
7272
)
@@ -86,7 +86,7 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
8686
)
8787
else:
8888
mask_sum = (
89-
mask_sum[pad_sz_pre[0] : pad_sz_post[0], pad_sz_pre[1] : pad_sz_post[1]] >= valid_thres
89+
mask_sum[pad_sz_pre[0]: pad_sz_post[0], pad_sz_pre[1]: pad_sz_post[1]] >= valid_thres
9090
)
9191
if mask_sum.max() > 0:
9292
yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), np.arange(mask_sum.shape[1]))

connectomics/decoding/postprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def apply_binary_postprocessing(
360360

361361
if len(sizes) > cc_config.top_k:
362362
# Get indices of top-k largest components
363-
top_k_indices = np.argsort(sizes)[-cc_config.top_k :]
363+
top_k_indices = np.argsort(sizes)[-cc_config.top_k:]
364364
top_k_labels = label_ids[top_k_indices]
365365

366366
# Create mask keeping only top-k

connectomics/inference/tta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def apply_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
117117
pass # Keep all channels
118118
elif isinstance(tta_channel, int):
119119
if tta_channel != -1:
120-
tensor = tensor[:, tta_channel : tta_channel + 1, ...]
120+
tensor = tensor[:, tta_channel: tta_channel + 1, ...]
121121
elif isinstance(tta_channel, (list, tuple, Sequence)):
122122
# Convert to list of integers (handle both int and string numbers
123123
# from OmegaConf)

connectomics/models/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def load_external_weights(model, cfg):
6565
stripped_count = 0
6666
for key, value in state_dict.items():
6767
if key.startswith(key_prefix):
68-
new_key = key[len(key_prefix) :]
68+
new_key = key[len(key_prefix):]
6969
new_state_dict[new_key] = value
7070
stripped_count += 1
7171
else:

connectomics/training/deep_supervision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def compute_multitask_loss(
123123
num_label_channels = 1
124124

125125
# Extract label channels
126-
task_label = labels[:, label_ch_offset : label_ch_offset + num_label_channels, ...]
126+
task_label = labels[:, label_ch_offset: label_ch_offset + num_label_channels, ...]
127127
label_ch_offset += num_label_channels
128128

129129
# Apply specified losses for this task

connectomics/utils/visualizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def _log_multi_channel_viz(
303303

304304
# Show each output channel
305305
for i in range(min(output.shape[1], 12)): # Increased limit to 12 channels
306-
channel_img = output[:, i : i + 1].repeat(1, 3, 1, 1) # Convert to RGB
306+
channel_img = output[:, i: i + 1].repeat(1, 3, 1, 1) # Convert to RGB
307307
writer.add_image(
308308
f"{prefix}/output_channel_{i}",
309309
vutils.make_grid(
@@ -317,7 +317,7 @@ def _log_multi_channel_viz(
317317

318318
# Show each label channel
319319
for i in range(min(label.shape[1], 12)): # Increased limit to 12 channels
320-
channel_img = label[:, i : i + 1].repeat(1, 3, 1, 1) # Convert to RGB
320+
channel_img = label[:, i: i + 1].repeat(1, 3, 1, 1) # Convert to RGB
321321
writer.add_image(
322322
f"{prefix}/label_channel_{i}",
323323
vutils.make_grid(
@@ -333,7 +333,7 @@ def _log_multi_channel_viz(
333333
if mask is not None and mask.numel() > 0:
334334
for i in range(min(mask.shape[1], 12)): # Show up to 12 mask channels
335335
# Show mask in cyan for better visibility
336-
mask_channel = mask[:, i : i + 1]
336+
mask_channel = mask[:, i: i + 1]
337337
mask_rgb = torch.cat(
338338
[
339339
torch.zeros_like(mask_channel), # R=0
@@ -393,7 +393,7 @@ def _normalize(self, tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
393393
# Multi-channel: normalize each channel independently
394394
normalized = []
395395
for c in range(tensor.shape[1]):
396-
channel = tensor[:, c : c + 1]
396+
channel = tensor[:, c: c + 1]
397397
min_val = channel.min()
398398
max_val = channel.max()
399399
if max_val > min_val:

0 commit comments

Comments
 (0)