diff --git a/.gitignore b/.gitignore
index d26f3ce..0ffbeb4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,4 +48,5 @@ cufile.log
 
 # annotation tool
 /annotation_tool/application
-/annotation_tool/Dockerfile
\ No newline at end of file
+/annotation_tool/Dockerfile
+.idea/
diff --git a/cellvit/inference/inference_memory.py b/cellvit/inference/inference_memory.py
index a411724..d5d00e7 100644
--- a/cellvit/inference/inference_memory.py
+++ b/cellvit/inference/inference_memory.py
@@ -141,6 +141,8 @@ def process_wsi(
         ]
         call_ids = []
+        MAX_IN_FLIGHT = 4
+        inference_results = []
         self.logger.info("Extracting cells using CellViT...")
         with torch.no_grad():
@@ -158,15 +160,39 @@
                 else:
                     predictions = self.model.forward(patches, retrieve_tokens=True)
                 predictions = self.apply_softmax_reorder(predictions)
-                call_id = batch_actor.convert_batch_to_graph_nodes.remote(
-                    predictions, metadata
-                )
+                # ==================== {origin code start} ====================
+                # call_id = batch_actor.convert_batch_to_graph_nodes.remote(
+                #     predictions, metadata
+                # )
+                # call_ids.append(call_id)
+                # pbar.update(1)
+                # pbar.total = len(wsi_inference_dataloader)
+                # ==================== {origin code end} ====================
+                # ==================== {fix code start} ====================
+                call_id = batch_actor.convert_batch_to_graph_nodes.remote(predictions, metadata)
                 call_ids.append(call_id)
+                if len(call_ids) >= MAX_IN_FLIGHT:
+                    ready_ids, call_ids = ray.wait(call_ids, num_returns=1, timeout=None)
+                    ready_results = ray.get(ready_ids)
+                    inference_results.extend(ready_results)
+                    ray.internal.free(ready_ids)
+                    del ready_results, ready_ids
+                # ==================== {fix code end} ====================
                 pbar.update(1)
                 pbar.total = len(wsi_inference_dataloader)
         self.logger.info("Waiting for final batches to be processed...")
-        inference_results = [ray.get(call_id) for call_id in call_ids]
+        # ==================== {origin code start} ====================
+        # inference_results = [ray.get(call_id) for call_id in call_ids]
+        # ==================== {origin code end} ====================
+        # ==================== {fix code start} ====================
+        while call_ids:
+            ready_ids, call_ids = ray.wait(call_ids, num_returns=min(8, len(call_ids)), timeout=None)
+            ready_results = ray.get(ready_ids)
+            inference_results.extend(ready_results)
+            ray.internal.free(ready_ids)
+            del ready_results, ready_ids
+        # ==================== {fix code end} ====================
         del pbar
         [ray.kill(batch_actor) for batch_actor in batch_pooling_actors]
@@ -191,10 +217,18 @@ def process_wsi(
                 batch_cell_tokens,
                 batch_cell_positions,
             ) = batch_results
-            cell_dict_wsi = cell_dict_wsi + batch_complete_dict
-            cell_dict_detection = cell_dict_detection + batch_detection
-            graph_data["cell_tokens"] = graph_data["cell_tokens"] + batch_cell_tokens
-            graph_data["positions"] = graph_data["positions"] + batch_cell_positions
+            # ==================== {origin code start} ====================
+            # cell_dict_wsi = cell_dict_wsi + batch_complete_dict
+            # cell_dict_detection = cell_dict_detection + batch_detection
+            # graph_data["cell_tokens"] = graph_data["cell_tokens"] + batch_cell_tokens
+            # graph_data["positions"] = graph_data["positions"] + batch_cell_positions
+            # ==================== {origin code end} ====================
+            # ==================== {fix code start} ====================
+            cell_dict_wsi.extend(batch_complete_dict)
+            cell_dict_detection.extend(batch_detection)
+            graph_data["cell_tokens"].extend(batch_cell_tokens)
+            graph_data["positions"].extend(batch_cell_positions)
+            # ==================== {fix code end} ====================
         # cleaning overlapping cells
         if len(cell_dict_wsi) == 0: