diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index c070c24..3604ce8 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -132,7 +132,6 @@ def tile_bag_dataloader( collate_fn=collate_fn, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, persistent_workers=(num_workers > 0), - pin_memory=torch.cuda.is_available(), ) return ( @@ -416,7 +415,6 @@ def create_dataloader( num_workers=num_workers, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, persistent_workers=(num_workers > 0), - pin_memory=torch.cuda.is_available(), ) return dl, categories or [] else: diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 035769f..6d81a8f 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -342,6 +342,8 @@ def deploy_categorical_model_( patient_label=patient_label, ground_truth_label=ground_truth_label, cut_off=cut_off, + time_label=time_label, + status_label=status_label, ).to_csv(output_dir / f"patient-preds-{model_i}.csv", index=False) else: df_builder( @@ -351,6 +353,8 @@ def deploy_categorical_model_( patient_label=patient_label, ground_truth_label=ground_truth_label, cut_off=cut_off, + time_label=time_label, + status_label=status_label, ).to_csv(output_dir / "patient-preds.csv", index=False) if task == "classification": @@ -641,6 +645,8 @@ def _to_survival_prediction_df( ], predictions: Mapping[PatientId, torch.Tensor], patient_label: PandasLabel, + time_label: PandasLabel = "time", + status_label: PandasLabel = "event", cut_off: float | None = None, **kwargs, ) -> pd.DataFrame: @@ -671,9 +677,9 @@ def _to_survival_prediction_df( # call .split on ground-truth values — assume structured input. If # the value is not a 2-tuple/list, treat both fields as unknown. if isinstance(gt, (tuple, list)) and len(gt) == 2: - row["time"], row["event"] = gt + row[time_label], row[status_label] = gt else: - row["time"], row["event"] = None, None + row[time_label], row[status_label] = None, None rows.append(row) diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index b2daa38..95564b9 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -311,14 +311,12 @@ def extract_( default_slide_mpp=default_slide_mpp, ) # Parallelism is implemented in the dataset iterator already, so one worker is enough! - # pin_memory speeds up CPU→GPU DMA for tile batches. # num_workers=1 is intentional: WSI read parallelism is inside _supertiles. dl = DataLoader( ds, batch_size=64, num_workers=1, drop_last=False, - pin_memory=torch.cuda.is_available(), ) feats, xs_um, ys_um = [], [], []