diff --git a/README.md b/README.md index 9499ccd..249503b 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ ONNX-based inference system for IFCB (Imaging FlowCytobot) bin data. This tool p | `[cuda]` | `onnxruntime-gpu` | GPU inference via CUDA | | `[torch]` | PyTorch + torchvision | Faster/more flexible data loading, but more dependancies | | `[cuda,torch]` | Both of the above | Full-featured install | +| `[embeddings]` | pyarrow | Writing embedding vectors as Parquet (see [Embeddings](#embeddings)) | | `[dev]` | pytest, black, isort, flake8 | Development and testing | - One of `[cpu]` or `[cuda]` must be used to have the appropriate onnxruntime. They are mutually exclusive. If neither are included, at install, `ifcb-infer` will be unable to run. If in doubt, use `[cuda]`. @@ -78,6 +79,10 @@ ifcb-infer [OPTIONS] MODEL BINS [BINS ...] Tokens: {MODEL_NAME}, {RUN_DATE}, {SUBPATH} (relative dir), {BIN} (bin name) --cpuonly Force CPU inference even if CUDA is available --notorch Use non-PyTorch data loader even if torch is installed +--embeddings Also emit penultimate-layer embedding vectors (see Embeddings) +--embeddings-only Emit only embeddings, skip the score CSV (implies --embeddings) +--embeddings-outfile PATTERN Embedding filename pattern. Same tokens as --outfile. + Default: {MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet ``` - By default, CUDA is used automatically when available/installed and otherwise falls back to using CPU. @@ -174,6 +179,36 @@ outputs/ └── OTZ/2019/D20190723/D20190723T171832_IFCB127.csv ``` +## Embeddings + +In addition to class scores, `ifcb-infer` can emit the CNN's **penultimate-layer embedding** — the global-pooled feature vector that feeds the classification head. No retraining is needed: the embedding is an intermediate activation the trained model already computes on every forward pass; it just needs to be surfaced as a model output. + +This is a two-step workflow: + +**1. 
One-time graph surgery.** ONNX Runtime only returns tensors declared in the model's graph outputs. Add the embedding tensor as a second output: + +```bash +python -m ifcb_infer.add_embedding_output classifier.onnx classifier_emb.onnx +``` + +The embedding tensor is auto-detected as the data input of the final `Gemm`/`MatMul` (the classification head). For a non-standard architecture, override it with `--tensor-name`. The resulting model returns `[scores, embedding]` from one forward pass and is otherwise identical to the original. + +**2. Run inference with `--embeddings`** against the surgically modified model: + +```bash +# install the extra once: pip install -e ".[embeddings]" +ifcb-infer --embeddings --classes classes.txt classifier_emb.onnx example-data/bins/ +``` + +Each bin gets, alongside its `.csv` of scores, an `.emb.parquet` file with one row per ROI: + +| Column | Type | Notes | +|---|---|---| +| `pid` | string | ROI identifier (aligned with the score CSV) | +| `embedding` | `fixed_size_list` | the feature vector (e.g. length 2048 for InceptionV3) | + +Embeddings are stored as **float16** (rather than float32) to halve on-disk size — ample precision for similarity, clustering, and visualization. The output path follows `--embeddings-outfile` (same tokens as `--outfile`). Use `--embeddings-only` to skip writing the score CSV. Bins whose expected output file already exists are skipped; that check looks at the score CSV by default and at the embedding file only with `--embeddings-only`, so use `--embeddings-only` to add embeddings for bins that already have a score CSV. Running `--embeddings` against an unmodified (single-output) model raises an error pointing back to step 1. + ## Container Use The Dockerfile installs with `[cuda,torch]` for full GPU support. 
diff --git a/pyproject.toml b/pyproject.toml index b8af0e6..fe323fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ torch = [ "torchvision", "humanize", ] +embeddings = [ + "pyarrow", +] dev = [ "pytest~=8.3.4", "pytest-mock~=3.14.0", diff --git a/src/ifcb_infer/add_embedding_output.py b/src/ifcb_infer/add_embedding_output.py new file mode 100644 index 0000000..fe5b661 --- /dev/null +++ b/src/ifcb_infer/add_embedding_output.py @@ -0,0 +1,111 @@ +import argparse + +import onnx +from onnx import TensorProto, helper, shape_inference + +# Ops that act as the final classification head; their data input is the +# penultimate-layer embedding (the global-pooled feature vector). +_HEAD_OPS = ("Gemm", "MatMul") + + +def detect_embedding_tensor(graph): + """Return the name of the embedding tensor: the data input of the last + classification-head op (Gemm/MatMul), i.e. its input that is not a weight + or bias initializer.""" + initializers = {init.name for init in graph.initializer} + for node in reversed(graph.node): + if node.op_type in _HEAD_OPS: + data_inputs = [name for name in node.input if name not in initializers] + if not data_inputs: + continue + return data_inputs[0] + raise ValueError( + f"Could not find a classification-head op ({'/'.join(_HEAD_OPS)}) to " + "auto-detect the embedding tensor. Pass --tensor-name explicitly." + ) + + +def _infer_embedding_dim(graph, tensor_name): + """Best-effort embedding dimension D from the head op's weight initializer + (its second dim). Returns None if it cannot be determined.""" + initializers = {init.name: init for init in graph.initializer} + for node in graph.node: + if node.op_type in _HEAD_OPS and tensor_name in node.input: + for name in node.input: + if name in initializers: + dims = initializers[name].dims + if len(dims) == 2: + # Gemm fc.weight is [num_classes, D]; D is the input width. 
+ return dims[1] + return None + + +def add_embedding_output(model_path: str, output_path: str, tensor_name=None): + """Add the penultimate-layer embedding tensor to an ONNX model's declared + graph outputs, producing a dual-head model whose ``session.run(None, ...)`` + returns ``[logits, embedding]``. + + Parameters: + model_path (str): Path to the input ONNX model. + output_path (str): Path to save the modified ONNX model. + tensor_name (str or None): Name of the embedding tensor. If None, it is + auto-detected as the data input of the final Gemm/MatMul. + """ + model = onnx.load(model_path) + graph = model.graph + + if tensor_name is None: + tensor_name = detect_embedding_tensor(graph) + print(f"Embedding tensor: '{tensor_name}'") + + existing_outputs = {out.name for out in graph.output} + if tensor_name in existing_outputs: + print(f"'{tensor_name}' is already a graph output; saving unchanged copy.") + onnx.save(model, output_path) + print(f"model saved to: {output_path}") + return + + # Prefer shape-inferred value_info so the output carries shape/type metadata. + inferred = shape_inference.infer_shapes(model) + value_info = None + for vi in inferred.graph.value_info: + if vi.name == tensor_name: + value_info = vi + break + + if value_info is None: + dim = _infer_embedding_dim(graph, tensor_name) + shape = ["batch_size", dim] if dim is not None else ["batch_size", 0] + value_info = helper.make_tensor_value_info( + tensor_name, TensorProto.FLOAT, shape + ) + print(f"No inferred value_info; declared shape {shape}.") + + graph.output.append(value_info) + onnx.save(model, output_path) + print( + f"Added embedding output '{tensor_name}'. " + f"graph.output is now: {[out.name for out in graph.output]}" + ) + print(f"model saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Add the penultimate-layer embedding tensor to an ONNX " + "model's graph outputs (dual-head: [logits, embedding])." 
+ ) + parser.add_argument("model_path", help="Path to input ONNX model") + parser.add_argument("output_path", help="Path to save the modified ONNX model") + parser.add_argument( + "--tensor-name", + default=None, + help="Embedding tensor name. Default: auto-detect (data input of the " + "final Gemm/MatMul).", + ) + args = parser.parse_args() + add_embedding_output(args.model_path, args.output_path, args.tensor_name) + + +if __name__ == "__main__": + main() diff --git a/src/ifcb_infer/cli.py b/src/ifcb_infer/cli.py index 3622e8b..626cdc0 100644 --- a/src/ifcb_infer/cli.py +++ b/src/ifcb_infer/cli.py @@ -54,6 +54,23 @@ def argparse_init(parser=None): action="store_true", help="Skip softmax normalization check on model output", ) + parser.add_argument( + "--embeddings", + action="store_true", + help="Emit penultimate-layer embedding vectors. Requires a MODEL whose " + "graph exposes the embedding tensor (see add_embedding_output).", + ) + parser.add_argument( + "--embeddings-only", + action="store_true", + help="Skip writing the score CSV; write only embeddings. Implies --embeddings.", + ) + parser.add_argument( + "--embeddings-outfile", + default="{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet", + help="Embedding output filename pattern. Same tokens as --outfile. 
" + 'Default is "{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet"', + ) return parser @@ -72,6 +89,9 @@ def argparse_runtime_args(args): args.model_name = os.path.splitext(os.path.basename(args.MODEL))[0] + if getattr(args, "embeddings_only", False): + args.embeddings = True + gpu_str = os.environ.get("CUDA_VISIBLE_DEVICES", "") args.gpus = [int(gpu) for gpu in gpu_str.split(",") if gpu.strip()] @@ -145,11 +165,11 @@ def pad_batch(batch: np.ndarray, target_batch_size: int): return np.concatenate([batch, pad], axis=0) -def get_output_path(args, bin_id, bin_relative_path=None): +def _format_output_path(args, outfile, bin_id, bin_relative_path=None): full_subpath = bin_relative_path if bin_relative_path is not None else bin_id subpath_dir = os.path.dirname(full_subpath) bin_name = os.path.basename(full_subpath) - outpath = os.path.join(args.outdir, args.outfile) + outpath = os.path.join(args.outdir, outfile) result = outpath.format( RUN_DATE=args.run_date_str, MODEL_NAME=args.model_name, @@ -159,6 +179,14 @@ def get_output_path(args, bin_id, bin_relative_path=None): return os.path.normpath(result) +def get_output_path(args, bin_id, bin_relative_path=None): + return _format_output_path(args, args.outfile, bin_id, bin_relative_path) + + +def get_embedding_output_path(args, bin_id, bin_relative_path=None): + return _format_output_path(args, args.embeddings_outfile, bin_id, bin_relative_path) + + def write_output(args, bin_id, pids, score_matrix, bin_relative_path=None): outpath = get_output_path(args, bin_id, bin_relative_path) os.makedirs(os.path.dirname(outpath), exist_ok=True) @@ -175,6 +203,42 @@ def write_output(args, bin_id, pids, score_matrix, bin_relative_path=None): print(f"Warning: No data processed for bin {bin_id}") +def resolve_emit_embeddings(args, ort_session): + """Decide whether to emit embeddings for this run, validating that the model + actually exposes the embedding output when --embeddings was requested.""" + emit = args.embeddings and 
len(ort_session.get_outputs()) > 1 + if args.embeddings and not emit: + raise ValueError( + "--embeddings requested but MODEL exposes a single output. Run " + "`python -m ifcb_infer.add_embedding_output` to add the embedding " + "tensor to the model's graph outputs." + ) + return emit + + +def write_embeddings(args, bin_id, pids, embedding_matrix, bin_relative_path=None): + if embedding_matrix is None: + print(f"Warning: No embeddings processed for bin {bin_id}") + return + + # Imported lazily so non-embedding runs don't require pyarrow. + import pyarrow as pa + import pyarrow.parquet as pq + + outpath = get_embedding_output_path(args, bin_id, bin_relative_path) + os.makedirs(os.path.dirname(outpath), exist_ok=True) + + embedding_matrix = np.ascontiguousarray(embedding_matrix.astype(np.float16)) + n_rows, dim = embedding_matrix.shape + embedding_col = pa.FixedSizeListArray.from_arrays( + pa.array(embedding_matrix.reshape(-1), type=pa.float16()), dim + ) + table = pa.table( + {"pid": pa.array(list(pids), type=pa.string()), "embedding": embedding_col} + ) + pq.write_table(table, outpath) + + def main(): # ort.preload_dlls(directory="") useful for TRT on Windows diff --git a/src/ifcb_infer/sanstorch.py b/src/ifcb_infer/sanstorch.py index ad6b6e9..8e428e5 100644 --- a/src/ifcb_infer/sanstorch.py +++ b/src/ifcb_infer/sanstorch.py @@ -4,7 +4,15 @@ import onnxruntime as ort from tqdm import tqdm -from ifcb_infer.cli import get_output_path, get_providers, pad_batch, write_output +from ifcb_infer.cli import ( + get_embedding_output_path, + get_output_path, + get_providers, + pad_batch, + resolve_emit_embeddings, + write_embeddings, + write_output, +) from ifcb_infer.datasets import IfcbBinDataset, IfcbBinImageTransformer, MyDataLoader @@ -15,6 +23,8 @@ def main(args): args.MODEL, sess_options=sess_options, providers=providers ) + emit_embeddings = resolve_emit_embeddings(args, ort_session) + input0 = ort_session.get_inputs()[0] model_batch = input0.shape[0] img_size = 
input0.shape[-1] @@ -42,6 +52,7 @@ def main(args): for bin_accessor in pbar: img_pids = [] score_matrix = None + embedding_matrix = None bin_relative_path = None input_dir = args.bin_to_input_dir.get(bin_accessor) @@ -61,7 +72,12 @@ def main(args): dataloader = MyDataLoader(dataset, inference_batchsize, transformer) bin_pid = dataset.pid - expected_output_path = get_output_path(args, bin_pid, bin_relative_path) + if args.embeddings_only: + expected_output_path = get_embedding_output_path( + args, bin_pid, bin_relative_path + ) + else: + expected_output_path = get_output_path(args, bin_pid, bin_relative_path) if os.path.exists(expected_output_path): pbar.set_description( f"batchsize={inference_batchsize} (skipping {bin_pid})" @@ -86,6 +102,21 @@ def main(args): score_matrix = np.concatenate( [score_matrix, batch_score_matrix], axis=0 ) + + if emit_embeddings: + batch_embedding_matrix = outputs[1] + if embedding_matrix is None: + embedding_matrix = batch_embedding_matrix + else: + embedding_matrix = np.concatenate( + [embedding_matrix, batch_embedding_matrix], axis=0 + ) + img_pids.extend(batch_pids) - write_output(args, bin_pid, img_pids, score_matrix, bin_relative_path) + if not args.embeddings_only: + write_output(args, bin_pid, img_pids, score_matrix, bin_relative_path) + if emit_embeddings: + write_embeddings( + args, bin_pid, img_pids, embedding_matrix, bin_relative_path + ) diff --git a/src/ifcb_infer/withtorch.py b/src/ifcb_infer/withtorch.py index 5c5799f..f96c38d 100644 --- a/src/ifcb_infer/withtorch.py +++ b/src/ifcb_infer/withtorch.py @@ -7,7 +7,15 @@ from torchvision.transforms import v2 from tqdm import tqdm -from ifcb_infer.cli import get_output_path, get_providers, pad_batch, write_output +from ifcb_infer.cli import ( + get_embedding_output_path, + get_output_path, + get_providers, + pad_batch, + resolve_emit_embeddings, + write_embeddings, + write_output, +) from ifcb_infer.datasets_torch import IfcbBinsDataset @@ -18,6 +26,8 @@ def main(args): 
args.MODEL, sess_options=sess_options, providers=providers ) + emit_embeddings = resolve_emit_embeddings(args, ort_session) + input0 = ort_session.get_inputs()[0] model_batch = input0.shape[0] img_size = input0.shape[-1] @@ -55,6 +65,7 @@ def main(args): for bin_accessor in pbar: img_pids = [] score_matrix = None + embedding_matrix = None root_dir = os.path.dirname(bin_accessor) bin_id = os.path.basename(bin_accessor) @@ -89,7 +100,12 @@ def main(args): else: bin_pid = bin_id # fallback for empty binfilesets - expected_output_path = get_output_path(args, bin_pid, bin_relative_path) + if args.embeddings_only: + expected_output_path = get_embedding_output_path( + args, bin_pid, bin_relative_path + ) + else: + expected_output_path = get_output_path(args, bin_pid, bin_relative_path) if os.path.exists(expected_output_path): pbar.set_description( f"batchsize={inference_batchsize} (skipping {bin_pid})" @@ -115,6 +131,21 @@ def main(args): score_matrix = np.concatenate( [score_matrix, batch_score_matrix], axis=0 ) + + if emit_embeddings: + batch_embedding_matrix = outputs[1] + if embedding_matrix is None: + embedding_matrix = batch_embedding_matrix + else: + embedding_matrix = np.concatenate( + [embedding_matrix, batch_embedding_matrix], axis=0 + ) + img_pids.extend(batch_pids) - write_output(args, bin_pid, img_pids, score_matrix, bin_relative_path) + if not args.embeddings_only: + write_output(args, bin_pid, img_pids, score_matrix, bin_relative_path) + if emit_embeddings: + write_embeddings( + args, bin_pid, img_pids, embedding_matrix, bin_relative_path + ) diff --git a/tests/test_infer_functions.py b/tests/test_infer_functions.py index b048c97..da13c9c 100644 --- a/tests/test_infer_functions.py +++ b/tests/test_infer_functions.py @@ -10,7 +10,10 @@ argparse_init, argparse_runtime_args, ensure_softmax, + get_embedding_output_path, get_output_path, + resolve_emit_embeddings, + write_embeddings, ) # The torch and notorch variants share a single argparse/runtime 
implementation. @@ -329,5 +332,188 @@ def test_negative_values_trigger_softmax(self): np.testing.assert_allclose(result.sum(axis=1), [1.0], atol=1e-6) +def _build_tiny_classifier(path, in_dim=4, n_classes=3): + """Build a minimal ONNX classifier: data -> Relu (the embedding) -> Gemm. + + The Relu output is the penultimate tensor that add_embedding_output should + auto-detect as the embedding. + """ + import onnx + from onnx import TensorProto, helper, numpy_helper + + relu = helper.make_node("Relu", ["data"], ["feat"], name="relu") + # transB=1: weight is [n_classes, in_dim] (like torch fc.weight), so the + # Gemm computes feat[batch,in_dim] @ W^T -> [batch, n_classes]. + gemm = helper.make_node("Gemm", ["feat", "W", "b"], ["output"], name="fc", transB=1) + + w = numpy_helper.from_array( + np.ones((n_classes, in_dim), dtype=np.float32), name="W" + ) + b = numpy_helper.from_array(np.zeros((n_classes,), dtype=np.float32), name="b") + + inp = helper.make_tensor_value_info( + "data", TensorProto.FLOAT, ["batch_size", in_dim] + ) + out = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, ["batch_size", n_classes] + ) + graph = helper.make_graph([relu, gemm], "tiny", [inp], [out], [w, b]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + # Pin a conservative IR version: newer onnx defaults to IR 13, but the + # onnxruntime on CI may only support up to IR 11. Opset 13 needs IR >= 7. 
+ model.ir_version = 10 + onnx.checker.check_model(model) + onnx.save(model, path) + + +class TestAddEmbeddingOutput: + """Test the ONNX graph surgery in add_embedding_output.""" + + def test_auto_detect_embedding_tensor(self, tmp_path): + import onnx + + from ifcb_infer.add_embedding_output import detect_embedding_tensor + + src = str(tmp_path / "tiny.onnx") + _build_tiny_classifier(src) + model = onnx.load(src) + assert detect_embedding_tensor(model.graph) == "feat" + + def test_adds_second_output_and_runs(self, tmp_path): + import onnxruntime as ort + + from ifcb_infer.add_embedding_output import add_embedding_output + + src = str(tmp_path / "tiny.onnx") + dst = str(tmp_path / "tiny_emb.onnx") + _build_tiny_classifier(src, in_dim=4, n_classes=3) + add_embedding_output(src, dst) + + sess = ort.InferenceSession(dst, providers=["CPUExecutionProvider"]) + out_names = [o.name for o in sess.get_outputs()] + assert len(out_names) == 2 + assert out_names[1] == "feat" + + x = np.array([[1.0, -2.0, 3.0, 4.0]], dtype=np.float32) + logits, emb = sess.run(None, {"data": x}) + assert logits.shape == (1, 3) + assert emb.shape == (1, 4) + # embedding is Relu(data): negatives clamped to 0 + np.testing.assert_array_equal(emb, np.array([[1.0, 0.0, 3.0, 4.0]])) + + def test_idempotent_when_already_output(self, tmp_path): + import onnx + + from ifcb_infer.add_embedding_output import add_embedding_output + + src = str(tmp_path / "tiny.onnx") + dst = str(tmp_path / "tiny_emb.onnx") + dst2 = str(tmp_path / "tiny_emb2.onnx") + _build_tiny_classifier(src) + add_embedding_output(src, dst) + # Re-running on the already-modified model must not add a duplicate. 
+ add_embedding_output(dst, dst2, tensor_name="feat") + model = onnx.load(dst2) + assert [o.name for o in model.graph.output] == ["output", "feat"] + + def test_explicit_tensor_name_override(self, tmp_path): + import onnx + + from ifcb_infer.add_embedding_output import add_embedding_output + + src = str(tmp_path / "tiny.onnx") + dst = str(tmp_path / "tiny_emb.onnx") + _build_tiny_classifier(src) + add_embedding_output(src, dst, tensor_name="feat") + model = onnx.load(dst) + assert "feat" in [o.name for o in model.graph.output] + + +class _FakeSession: + def __init__(self, n_outputs): + self._outs = [type("O", (), {"name": f"out{i}"})() for i in range(n_outputs)] + + def get_outputs(self): + return self._outs + + +class TestResolveEmitEmbeddings: + def test_off_by_default(self): + args = type("Args", (), {"embeddings": False})() + assert resolve_emit_embeddings(args, _FakeSession(1)) is False + assert resolve_emit_embeddings(args, _FakeSession(2)) is False + + def test_on_with_dual_output_model(self): + args = type("Args", (), {"embeddings": True})() + assert resolve_emit_embeddings(args, _FakeSession(2)) is True + + def test_raises_on_single_output_model(self): + args = type("Args", (), {"embeddings": True})() + with pytest.raises(ValueError, match="single output"): + resolve_emit_embeddings(args, _FakeSession(1)) + + +class TestWriteEmbeddings: + def setup_method(self): + self.args = type("Args", (), {})() + self.args.outdir = "./outputs" + self.args.run_date_str = "2025-01-15" + self.args.model_name = "test_model" + self.args.embeddings_outfile = "{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet" + + def test_embedding_output_path(self): + result = get_embedding_output_path( + self.args, "test_bin", "MVCO/2023/D20230108/test_bin" + ) + assert result == os.path.normpath( + "./outputs/test_model/MVCO/2023/D20230108/test_bin.emb.parquet" + ) + + def test_writes_parquet_float16_with_pids(self, tmp_path): + pytest.importorskip("pyarrow") + import pyarrow.parquet as pq + + 
self.args.outdir = str(tmp_path) + pids = ["pidA", "pidB"] + emb = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + write_embeddings(self.args, "test_bin", pids, emb) + + outpath = get_embedding_output_path(self.args, "test_bin") + assert os.path.exists(outpath) + table = pq.read_table(outpath) + assert table.column_names == ["pid", "embedding"] + assert table.column("pid").to_pylist() == pids + emb_back = np.array(table.column("embedding").to_pylist()) + assert emb_back.shape == (2, 3) + assert ( + table.schema.field("embedding").type.value_type + == __import__("pyarrow").float16() + ) + np.testing.assert_array_equal(emb_back, emb) + + def test_none_matrix_writes_nothing(self, tmp_path): + self.args.outdir = str(tmp_path) + write_embeddings(self.args, "test_bin", [], None) + assert not os.path.exists(get_embedding_output_path(self.args, "test_bin")) + + +class TestEmbeddingArgparse: + def test_embeddings_flags_default_off(self): + parser = argparse_init() + args = parser.parse_args(["model.onnx", "bins/"]) + assert args.embeddings is False + assert args.embeddings_only is False + assert args.embeddings_outfile == "{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet" + + def test_embeddings_only_implies_embeddings(self, mocker): + mock_datetime = mocker.patch("datetime.datetime") + mock_datetime.now.return_value.isoformat.return_value = "2025-01-15T14:30:45" + parser = argparse_init() + args = parser.parse_args(["--embeddings-only", "model.onnx", "bins/"]) + args.BINS = [] + argparse_runtime_args(args) + assert args.embeddings is True + + if __name__ == "__main__": pytest.main([__file__, "-v"])