35 changes: 35 additions & 0 deletions README.md
@@ -22,6 +22,7 @@ ONNX-based inference system for IFCB (Imaging FlowCytobot) bin data. This tool p
| `[cuda]` | `onnxruntime-gpu` | GPU inference via CUDA |
| `[torch]` | PyTorch + torchvision | Faster/more flexible data loading, but more dependencies |
| `[cuda,torch]` | Both of the above | Full-featured install |
| `[embeddings]` | pyarrow | Writing embedding vectors as Parquet (see [Embeddings](#embeddings)) |
| `[dev]` | pytest, black, isort, flake8 | Development and testing |

- Exactly one of `[cpu]` or `[cuda]` must be installed to provide the appropriate onnxruntime build; the two are mutually exclusive. If neither is included at install time, `ifcb-infer` will be unable to run. If in doubt, use `[cuda]`.
@@ -78,6 +79,10 @@ ifcb-infer [OPTIONS] MODEL BINS [BINS ...]
                                Tokens: {MODEL_NAME}, {RUN_DATE}, {SUBPATH} (relative dir), {BIN} (bin name)
  --cpuonly                     Force CPU inference even if CUDA is available
  --notorch                     Use non-PyTorch data loader even if torch is installed
  --embeddings                  Also emit penultimate-layer embedding vectors (see Embeddings)
  --embeddings-only             Emit only embeddings, skip the score CSV (implies --embeddings)
  --embeddings-outfile PATTERN  Embedding filename pattern. Same tokens as --outfile.
                                Default: {MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet
```

- By default, CUDA is used automatically when available/installed and otherwise falls back to using CPU.
@@ -174,6 +179,36 @@ outputs/
└── OTZ/2019/D20190723/D20190723T171832_IFCB127.csv
```
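
The output tree above is produced by expanding a filename pattern, and `--embeddings-outfile` accepts the same tokens as `--outfile`. As a rough illustration of how such a pattern expands, the sketch below mimics the substitution with `str.format`. The helper name `expand_pattern` and the sample values are hypothetical; only the token names and the default embedding pattern come from the option help.

```python
import os


def expand_pattern(pattern, outdir, model_name, run_date, bin_relative_path):
    """Expand {MODEL_NAME}, {RUN_DATE}, {SUBPATH} and {BIN} in a filename pattern.

    Hypothetical helper for illustration; not part of ifcb-infer's public API.
    """
    subpath = os.path.dirname(bin_relative_path)    # relative directory of the bin
    bin_name = os.path.basename(bin_relative_path)  # bin name without directories
    path = os.path.join(outdir, pattern).format(
        MODEL_NAME=model_name, RUN_DATE=run_date, SUBPATH=subpath, BIN=bin_name
    )
    return os.path.normpath(path)


default_pattern = "{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet"
result = expand_pattern(
    default_pattern,
    outdir="outputs",
    model_name="classifier_emb",  # assumed: MODEL file name without extension
    run_date="2024-01-01",        # unused by the default pattern
    bin_relative_path="OTZ/2019/D20190723/D20190723T171832_IFCB127",
)
print(result)
```

Tokens that a pattern does not mention (here `{RUN_DATE}`) are simply ignored, so the score CSV and the embedding file can use different layouts under the same output directory.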

## Embeddings

In addition to class scores, `ifcb-infer` can emit the CNN's **penultimate-layer embedding** — the global-pooled feature vector that feeds the classification head. No retraining is needed: the embedding is an intermediate activation the trained model already computes on every forward pass; it just needs to be surfaced as a model output.

This is a two-step workflow:

**1. One-time graph surgery.** ONNX Runtime only returns tensors declared in the model's graph outputs. Add the embedding tensor as a second output:

```bash
python -m ifcb_infer.add_embedding_output classifier.onnx classifier_emb.onnx
```

The embedding tensor is auto-detected as the data input of the final `Gemm`/`MatMul` (the classification head). For a non-standard architecture, override it with `--tensor-name`. The resulting model returns `[scores, embedding]` from one forward pass and is otherwise identical to the original.
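
The detection rule can be sketched on a toy graph without loading a real model. The `Node` tuple below is a hypothetical stand-in for an ONNX node, and the tensor names (`features`, `fc.weight`, and so on) are made up for illustration; only the rule itself (last `Gemm`/`MatMul`, first input that is not an initializer) is taken from the description above.

```python
from typing import NamedTuple, Optional


class Node(NamedTuple):
    """Minimal stand-in for an ONNX graph node (hypothetical, for illustration)."""

    op_type: str
    inputs: tuple
    output: str


def find_embedding_tensor(nodes, initializer_names) -> Optional[str]:
    """Return the non-initializer input of the last Gemm/MatMul, if any."""
    for node in reversed(nodes):
        if node.op_type in ("Gemm", "MatMul"):
            data_inputs = [n for n in node.inputs if n not in initializer_names]
            if data_inputs:
                return data_inputs[0]
    return None


# Tail of a typical CNN classifier: pool -> flatten -> fully connected -> softmax.
nodes = [
    Node("GlobalAveragePool", ("conv_out",), "pooled"),
    Node("Flatten", ("pooled",), "features"),
    Node("Gemm", ("features", "fc.weight", "fc.bias"), "logits"),
    Node("Softmax", ("logits",), "scores"),
]
initializers = {"fc.weight", "fc.bias"}

embedding_name = find_embedding_tensor(nodes, initializers)
print(embedding_name)  # features
```

If the last `Gemm`/`MatMul` in a model is not the classification head (for example, an attention block after the pooled features), this heuristic would pick the wrong tensor, which is the case `--tensor-name` is meant to cover.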

**2. Run inference with `--embeddings`** against the surgically modified model:

```bash
# install the extra once: pip install -e ".[embeddings]"
ifcb-infer --embeddings --classes classes.txt classifier_emb.onnx example-data/bins/
```

Each bin gets, alongside its `.csv` of scores, an `.emb.parquet` file with one row per ROI:

| Column | Type | Notes |
|---|---|---|
| `pid` | string | ROI identifier (aligned with the score CSV) |
| `embedding` | `fixed_size_list<float16>` | the feature vector (e.g. length 2048 for InceptionV3) |

Embeddings are stored at **float16** to halve on-disk size — ample precision for similarity, clustering, and visualization. The output path follows `--embeddings-outfile` (same tokens as `--outfile`). Use `--embeddings-only` to skip writing the score CSV. Running `--embeddings` against an unmodified (single-output) model raises an error pointing back to step 1.
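
To get a feel for the float16 trade-off, the following sketch rounds synthetic vectors to half precision using only the standard library and compares cosine similarity before and after. The vectors are random stand-ins, not real model embeddings, so the measured error is only indicative; the helper names are illustrative.

```python
import math
import random
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def cosine(a, b) -> float:
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


rng = random.Random(0)
dim = 2048  # e.g. the InceptionV3 embedding length mentioned above
# Synthetic stand-ins for pooled CNN features (non-negative, below 1).
a = [rng.random() for _ in range(dim)]
b = [rng.random() for _ in range(dim)]

a16 = [to_float16(x) for x in a]
b16 = [to_float16(x) for x in b]

error = abs(cosine(a, b) - cosine(a16, b16))

bytes_f32 = dim * struct.calcsize("<f")  # 4 bytes per value
bytes_f16 = dim * struct.calcsize("<e")  # 2 bytes per value
print(f"cosine error: {error:.2e}, bytes per vector: {bytes_f32} -> {bytes_f16}")
```

Half precision keeps about three significant decimal digits per component, so workloads that depend on very small differences between near-identical vectors may want to check the error on their own data first.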

## Container Use

The Dockerfile installs with `[cuda,torch]` for full GPU support.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -26,6 +26,9 @@ torch = [
    "torchvision",
    "humanize",
]
embeddings = [
    "pyarrow",
]
dev = [
    "pytest~=8.3.4",
    "pytest-mock~=3.14.0",
111 changes: 111 additions & 0 deletions src/ifcb_infer/add_embedding_output.py
@@ -0,0 +1,111 @@
import argparse

import onnx
from onnx import TensorProto, helper, shape_inference

# Ops that act as the final classification head; their data input is the
# penultimate-layer embedding (the global-pooled feature vector).
_HEAD_OPS = ("Gemm", "MatMul")


def detect_embedding_tensor(graph):
    """Return the name of the embedding tensor: the data input of the last
    classification-head op (Gemm/MatMul), i.e. its input that is not a weight
    or bias initializer."""
    initializers = {init.name for init in graph.initializer}
    for node in reversed(graph.node):
        if node.op_type in _HEAD_OPS:
            data_inputs = [name for name in node.input if name not in initializers]
            if not data_inputs:
                continue
            return data_inputs[0]
    raise ValueError(
        f"Could not find a classification-head op ({'/'.join(_HEAD_OPS)}) to "
        "auto-detect the embedding tensor. Pass --tensor-name explicitly."
    )


def _infer_embedding_dim(graph, tensor_name):
    """Best-effort embedding dimension D from the head op's weight initializer.
    Returns None if it cannot be determined."""
    initializers = {init.name: init for init in graph.initializer}
    for node in graph.node:
        if node.op_type in _HEAD_OPS and tensor_name in node.input:
            # Gemm with transB=1 (e.g. an exported fc layer) stores the weight
            # as [num_classes, D]; Gemm with transB=0 and MatMul use [D, num_classes].
            trans_b = any(
                attr.name == "transB" and attr.i == 1 for attr in node.attribute
            )
            for name in node.input:
                if name in initializers:
                    dims = initializers[name].dims
                    if len(dims) == 2:
                        return dims[1] if trans_b else dims[0]
    return None


def add_embedding_output(model_path: str, output_path: str, tensor_name=None):
    """Add the penultimate-layer embedding tensor to an ONNX model's declared
    graph outputs, producing a dual-head model whose ``session.run(None, ...)``
    returns ``[logits, embedding]``.

    Parameters:
        model_path (str): Path to the input ONNX model.
        output_path (str): Path to save the modified ONNX model.
        tensor_name (str or None): Name of the embedding tensor. If None, it is
            auto-detected as the data input of the final Gemm/MatMul.
    """
    model = onnx.load(model_path)
    graph = model.graph

    if tensor_name is None:
        tensor_name = detect_embedding_tensor(graph)
        print(f"Embedding tensor: '{tensor_name}'")

    existing_outputs = {out.name for out in graph.output}
    if tensor_name in existing_outputs:
        print(f"'{tensor_name}' is already a graph output; saving unchanged copy.")
        onnx.save(model, output_path)
        print(f"model saved to: {output_path}")
        return

    # Prefer shape-inferred value_info so the output carries shape/type metadata.
    inferred = shape_inference.infer_shapes(model)
    value_info = None
    for vi in inferred.graph.value_info:
        if vi.name == tensor_name:
            value_info = vi
            break

    if value_info is None:
        dim = _infer_embedding_dim(graph, tensor_name)
        # A None entry leaves the width unspecified instead of declaring a
        # zero-width output when the dimension cannot be determined.
        shape = ["batch_size", dim]
        value_info = helper.make_tensor_value_info(
            tensor_name, TensorProto.FLOAT, shape
        )
        print(f"No inferred value_info; declared shape {shape}.")

    graph.output.append(value_info)
    onnx.save(model, output_path)
    print(
        f"Added embedding output '{tensor_name}'. "
        f"graph.output is now: {[out.name for out in graph.output]}"
    )
    print(f"model saved to: {output_path}")


def main():
    parser = argparse.ArgumentParser(
        description="Add the penultimate-layer embedding tensor to an ONNX "
        "model's graph outputs (dual-head: [logits, embedding])."
    )
    parser.add_argument("model_path", help="Path to input ONNX model")
    parser.add_argument("output_path", help="Path to save the modified ONNX model")
    parser.add_argument(
        "--tensor-name",
        default=None,
        help="Embedding tensor name. Default: auto-detect (data input of the "
        "final Gemm/MatMul).",
    )
    args = parser.parse_args()
    add_embedding_output(args.model_path, args.output_path, args.tensor_name)


if __name__ == "__main__":
    main()
68 changes: 66 additions & 2 deletions src/ifcb_infer/cli.py
@@ -54,6 +54,23 @@ def argparse_init(parser=None):
        action="store_true",
        help="Skip softmax normalization check on model output",
    )
    parser.add_argument(
        "--embeddings",
        action="store_true",
        help="Emit penultimate-layer embedding vectors. Requires a MODEL whose "
        "graph exposes the embedding tensor (see add_embedding_output).",
    )
    parser.add_argument(
        "--embeddings-only",
        action="store_true",
        help="Skip writing the score CSV; write only embeddings. Implies --embeddings.",
    )
    parser.add_argument(
        "--embeddings-outfile",
        default="{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet",
        help="Embedding output filename pattern. Same tokens as --outfile. "
        'Default is "{MODEL_NAME}/{SUBPATH}/{BIN}.emb.parquet"',
    )

    return parser

@@ -72,6 +89,9 @@ def argparse_runtime_args(args):

    args.model_name = os.path.splitext(os.path.basename(args.MODEL))[0]

    if getattr(args, "embeddings_only", False):
        args.embeddings = True

    gpu_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    args.gpus = [int(gpu) for gpu in gpu_str.split(",") if gpu.strip()]

@@ -145,11 +165,11 @@ def pad_batch(batch: np.ndarray, target_batch_size: int):
     return np.concatenate([batch, pad], axis=0)


-def get_output_path(args, bin_id, bin_relative_path=None):
+def _format_output_path(args, outfile, bin_id, bin_relative_path=None):
     full_subpath = bin_relative_path if bin_relative_path is not None else bin_id
     subpath_dir = os.path.dirname(full_subpath)
     bin_name = os.path.basename(full_subpath)
-    outpath = os.path.join(args.outdir, args.outfile)
+    outpath = os.path.join(args.outdir, outfile)
     result = outpath.format(
         RUN_DATE=args.run_date_str,
         MODEL_NAME=args.model_name,
@@ -159,6 +179,14 @@ def get_output_path(args, bin_id, bin_relative_path=None):
    return os.path.normpath(result)


def get_output_path(args, bin_id, bin_relative_path=None):
    return _format_output_path(args, args.outfile, bin_id, bin_relative_path)


def get_embedding_output_path(args, bin_id, bin_relative_path=None):
    return _format_output_path(args, args.embeddings_outfile, bin_id, bin_relative_path)


def write_output(args, bin_id, pids, score_matrix, bin_relative_path=None):
    outpath = get_output_path(args, bin_id, bin_relative_path)
    os.makedirs(os.path.dirname(outpath), exist_ok=True)
@@ -175,6 +203,42 @@ def write_output(args, bin_id, pids, score_matrix, bin_relative_path=None):
print(f"Warning: No data processed for bin {bin_id}")


def resolve_emit_embeddings(args, ort_session):
    """Decide whether to emit embeddings for this run, validating that the model
    actually exposes the embedding output when --embeddings was requested."""
    emit = args.embeddings and len(ort_session.get_outputs()) > 1
    if args.embeddings and not emit:
        raise ValueError(
            "--embeddings requested but MODEL exposes a single output. Run "
            "`python -m ifcb_infer.add_embedding_output` to add the embedding "
            "tensor to the model's graph outputs."
        )
    return emit
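
The three cases this check distinguishes can be exercised without a real model. The sketch below repeats the same logic against `FakeSession`, a hypothetical stand-in for `onnxruntime.InferenceSession` that only implements `get_outputs()`; the output names are made up.

```python
from types import SimpleNamespace


class FakeSession:
    """Hypothetical stand-in for onnxruntime.InferenceSession."""

    def __init__(self, output_names):
        self._outputs = [SimpleNamespace(name=n) for n in output_names]

    def get_outputs(self):
        return self._outputs


def resolve_emit_embeddings(args, ort_session):
    # Same logic as the function in the diff, repeated so the sketch is self-contained.
    emit = args.embeddings and len(ort_session.get_outputs()) > 1
    if args.embeddings and not emit:
        raise ValueError("--embeddings requested but MODEL exposes a single output.")
    return emit


single = FakeSession(["scores"])
dual = FakeSession(["scores", "embedding"])

# Dual-output model, flag off: embeddings are ignored.
plain_run = resolve_emit_embeddings(SimpleNamespace(embeddings=False), dual)
# Dual-output model, flag on: embeddings are emitted.
embed_run = resolve_emit_embeddings(SimpleNamespace(embeddings=True), dual)

# Single-output model, flag on: the run is rejected up front.
try:
    resolve_emit_embeddings(SimpleNamespace(embeddings=True), single)
    rejected = False
except ValueError:
    rejected = True

print(plain_run, embed_run, rejected)  # False True True
```

Note that the check only counts outputs: any model with two or more outputs passes, and the second one is assumed to be the embedding.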


def write_embeddings(args, bin_id, pids, embedding_matrix, bin_relative_path=None):
    if embedding_matrix is None:
        print(f"Warning: No embeddings processed for bin {bin_id}")
        return

    # Imported lazily so non-embedding runs don't require pyarrow.
    import pyarrow as pa
    import pyarrow.parquet as pq

    outpath = get_embedding_output_path(args, bin_id, bin_relative_path)
    os.makedirs(os.path.dirname(outpath), exist_ok=True)

    embedding_matrix = np.ascontiguousarray(embedding_matrix.astype(np.float16))
    n_rows, dim = embedding_matrix.shape
    embedding_col = pa.FixedSizeListArray.from_arrays(
        pa.array(embedding_matrix.reshape(-1), type=pa.float16()), dim
    )
    table = pa.table(
        {"pid": pa.array(list(pids), type=pa.string()), "embedding": embedding_col}
    )
    pq.write_table(table, outpath)
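
The column layout relies on the matrix being flattened in row-major order, so that every run of `dim` consecutive values is one ROI's vector. A plain-Python sketch of that correspondence follows; the helper name, pids, and values are illustrative only and no pyarrow call is involved.

```python
def to_fixed_size_rows(flat_values, dim):
    """Split a row-major flat buffer into rows of exactly `dim` values."""
    if dim <= 0 or len(flat_values) % dim != 0:
        raise ValueError("flat buffer length must be a positive multiple of dim")
    return [flat_values[i : i + dim] for i in range(0, len(flat_values), dim)]


# Two ROIs with a 3-dimensional embedding each (toy values).
matrix = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
pids = ["bin_00001", "bin_00002"]  # hypothetical ROI identifiers

flat = [value for row in matrix for value in row]  # row-major, like reshape(-1)
rows = to_fixed_size_rows(flat, dim=3)

for pid, row in zip(pids, rows):
    print(pid, row)
```

Because row `i` is always `flat[i * dim:(i + 1) * dim]`, the `pid` column and the `embedding` column stay aligned as long as both are built in the same ROI order.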


def main():
    # ort.preload_dlls(directory="") useful for TRT on Windows

37 changes: 34 additions & 3 deletions src/ifcb_infer/sanstorch.py
@@ -4,7 +4,15 @@
 import onnxruntime as ort
 from tqdm import tqdm

-from ifcb_infer.cli import get_output_path, get_providers, pad_batch, write_output
+from ifcb_infer.cli import (
+    get_embedding_output_path,
+    get_output_path,
+    get_providers,
+    pad_batch,
+    resolve_emit_embeddings,
+    write_embeddings,
+    write_output,
+)
 from ifcb_infer.datasets import IfcbBinDataset, IfcbBinImageTransformer, MyDataLoader


@@ -15,6 +23,8 @@ def main(args):
        args.MODEL, sess_options=sess_options, providers=providers
    )

    emit_embeddings = resolve_emit_embeddings(args, ort_session)

    input0 = ort_session.get_inputs()[0]
    model_batch = input0.shape[0]
    img_size = input0.shape[-1]
@@ -42,6 +52,7 @@ def main(args):
    for bin_accessor in pbar:
        img_pids = []
        score_matrix = None
        embedding_matrix = None

        bin_relative_path = None
        input_dir = args.bin_to_input_dir.get(bin_accessor)
@@ -61,7 +72,12 @@ def main(args):
         dataloader = MyDataLoader(dataset, inference_batchsize, transformer)
         bin_pid = dataset.pid

-        expected_output_path = get_output_path(args, bin_pid, bin_relative_path)
+        if args.embeddings_only:
+            expected_output_path = get_embedding_output_path(
+                args, bin_pid, bin_relative_path
+            )
+        else:
+            expected_output_path = get_output_path(args, bin_pid, bin_relative_path)
         if os.path.exists(expected_output_path):
             pbar.set_description(
                 f"batchsize={inference_batchsize} (skipping {bin_pid})"
@@ -86,6 +102,21 @@ def main(args):
                 score_matrix = np.concatenate(
                     [score_matrix, batch_score_matrix], axis=0
                 )
+
+            if emit_embeddings:
+                batch_embedding_matrix = outputs[1]
+                if embedding_matrix is None:
+                    embedding_matrix = batch_embedding_matrix
+                else:
+                    embedding_matrix = np.concatenate(
+                        [embedding_matrix, batch_embedding_matrix], axis=0
+                    )
+
             img_pids.extend(batch_pids)

-        write_output(args, bin_pid, img_pids, score_matrix, bin_relative_path)
+        if not args.embeddings_only:
+            write_output(args, bin_pid, img_pids, score_matrix, bin_relative_path)
+        if emit_embeddings:
+            write_embeddings(
+                args, bin_pid, img_pids, embedding_matrix, bin_relative_path
+            )