Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ cufile.log

# annotation tool
/annotation_tool/application
/annotation_tool/Dockerfile
/annotation_tool/Dockerfile
.idea/
50 changes: 42 additions & 8 deletions cellvit/inference/inference_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def process_wsi(
]

call_ids = []
MAX_IN_FLIGHT = 4
inference_results = []

self.logger.info("Extracting cells using CellViT...")
with torch.no_grad():
Expand All @@ -158,15 +160,39 @@ def process_wsi(
else:
predictions = self.model.forward(patches, retrieve_tokens=True)
predictions = self.apply_softmax_reorder(predictions)
call_id = batch_actor.convert_batch_to_graph_nodes.remote(
predictions, metadata
)
# ==================== {original code start} ====================
# call_id = batch_actor.convert_batch_to_graph_nodes.remote(
# predictions, metadata
# )
# call_ids.append(call_id)
# pbar.update(1)
# pbar.total = len(wsi_inference_dataloader)
# ==================== {original code end} ====================
# ==================== {fix code start} ====================
call_id = batch_actor.convert_batch_to_graph_nodes.remote(predictions, metadata)
call_ids.append(call_id)
if len(call_ids) >= MAX_IN_FLIGHT:
ready_ids, call_ids = ray.wait(call_ids, num_returns=1, timeout=None)
ready_results = ray.get(ready_ids)
inference_results.extend(ready_results)
ray.internal.free(ready_ids)
del ready_results, ready_ids
# ==================== {fix code end} ====================
pbar.update(1)
pbar.total = len(wsi_inference_dataloader)

self.logger.info("Waiting for final batches to be processed...")
inference_results = [ray.get(call_id) for call_id in call_ids]
# ==================== {original code start} ====================
# inference_results = [ray.get(call_id) for call_id in call_ids]
# ==================== {original code end} ====================
# ==================== {fix code start} ====================
while call_ids:
ready_ids, call_ids = ray.wait(call_ids, num_returns=min(8, len(call_ids)), timeout=None)
ready_results = ray.get(ready_ids)
inference_results.extend(ready_results)
ray.internal.free(ready_ids)
del ready_results, ready_ids
# ==================== {fix code end} ====================
del pbar
[ray.kill(batch_actor) for batch_actor in batch_pooling_actors]

Expand All @@ -191,10 +217,18 @@ def process_wsi(
batch_cell_tokens,
batch_cell_positions,
) = batch_results
cell_dict_wsi = cell_dict_wsi + batch_complete_dict
cell_dict_detection = cell_dict_detection + batch_detection
graph_data["cell_tokens"] = graph_data["cell_tokens"] + batch_cell_tokens
graph_data["positions"] = graph_data["positions"] + batch_cell_positions
# ==================== {original code start} ====================
# cell_dict_wsi = cell_dict_wsi + batch_complete_dict
# cell_dict_detection = cell_dict_detection + batch_detection
# graph_data["cell_tokens"] = graph_data["cell_tokens"] + batch_cell_tokens
# graph_data["positions"] = graph_data["positions"] + batch_cell_positions
# ==================== {original code end} ====================
# ==================== {fix code start} ====================
cell_dict_wsi.extend(batch_complete_dict)
cell_dict_detection.extend(batch_detection)
graph_data["cell_tokens"].extend(batch_cell_tokens)
graph_data["positions"].extend(batch_cell_positions)
# ==================== {fix code end} ====================

# cleaning overlapping cells
if len(cell_dict_wsi) == 0:
Expand Down