From 81c946d0449c7be05f6a07ae2c7fdd0409f6983e Mon Sep 17 00:00:00 2001 From: afloresep Date: Tue, 31 Mar 2026 22:16:15 +0200 Subject: [PATCH 1/5] add NLP example --- examples/word_embeddings_50k.py | 446 ++++++++++++++++++++++++++++++++ 1 file changed, 446 insertions(+) create mode 100644 examples/word_embeddings_50k.py diff --git a/examples/word_embeddings_50k.py b/examples/word_embeddings_50k.py new file mode 100644 index 0000000..0010829 --- /dev/null +++ b/examples/word_embeddings_50k.py @@ -0,0 +1,446 @@ +"""50K word embedding TMAP — semantic map of English vocabulary. + +Embed 50,000 English nouns (from WordNet) with a sentence-transformer, +build a TMAP with cosine metric, and explore the natural semantic +organization of the English language. + +WordNet provides automatic semantic categories (animal, food, artifact, +plant, person, etc.) for coloring the map. + +Outputs +------- +examples/word_embeddings_50k_tmap.html Interactive TMAP +examples/word_embeddings_50k_report.txt Analysis report + +Usage +----- + python examples/word_embeddings_50k.py + python examples/word_embeddings_50k.py --n 50000 + python examples/word_embeddings_50k.py --serve + +Requirements +------------ + pip install sentence-transformers nltk +""" + +from __future__ import annotations + +import argparse +import time +from pathlib import Path + +import numpy as np + +from tmap import TMAP +from tmap.graph.analysis import ( + boundary_edges, + confusion_matrix_from_tree, + subtree_purity, +) + +CACHE_DIR = Path(__file__).parent / "data" / "word50k_cache" +OUTPUT_DIR = Path(__file__).parent + +# Readable labels for WordNet lexnames +LEXNAME_LABELS = { + # Nouns + "noun.Tops": "general", + "noun.act": "action", + "noun.animal": "animal", + "noun.artifact": "artifact", + "noun.attribute": "attribute", + "noun.body": "body", + "noun.cognition": "cognition", + "noun.communication": "communication", + "noun.event": "event", + "noun.feeling": "feeling", + "noun.food": "food", + "noun.group": 
"group", + "noun.location": "location", + "noun.motive": "motive", + "noun.object": "object", + "noun.person": "person", + "noun.phenomenon": "phenomenon", + "noun.plant": "plant", + "noun.possession": "possession", + "noun.process": "process", + "noun.quantity": "quantity", + "noun.relation": "relation", + "noun.shape": "shape", + "noun.state": "state", + "noun.substance": "substance", + "noun.time": "time", + # Verbs + "verb.body": "v:body", + "verb.change": "v:change", + "verb.cognition": "v:thinking", + "verb.communication": "v:speaking", + "verb.competition": "v:competing", + "verb.consumption": "v:consuming", + "verb.contact": "v:touching", + "verb.creation": "v:creating", + "verb.emotion": "v:emotion", + "verb.motion": "v:motion", + "verb.perception": "v:perception", + "verb.possession": "v:having", + "verb.social": "v:social", + "verb.stative": "v:being", + "verb.weather": "v:weather", + # Adjectives & Adverbs + "adj.all": "adjective", + "adj.pert": "adj:relational", + "adj.ppl": "adj:participle", + "adv.all": "adverb", +} + + +# --------------------------------------------------------------------------- +# 1. Vocabulary from WordNet +# --------------------------------------------------------------------------- + + +def load_wordnet_vocabulary(n: int) -> tuple[list[str], list[str]]: + """Extract n single words from WordNet (all POS) with semantic categories.""" + import nltk + + nltk.download("wordnet", quiet=True) + nltk.download("omw-1.4", quiet=True) + from nltk.corpus import wordnet as wn + + print(f"Extracting vocabulary from WordNet (target: {n:,} words)...") + t0 = time.time() + + # Collect (word, category) pairs. For words in multiple categories, + # use the most common synset's category (first synset = most frequent). 
+ word_cat: dict[str, str] = {} + for synset in wn.all_synsets(): + cat = LEXNAME_LABELS.get(synset.lexname(), synset.lexname()) + for lemma in synset.lemmas(): + name = lemma.name().lower() + if "_" not in name and name.isalpha() and 3 <= len(name) <= 20: + if name not in word_cat: + word_cat[name] = cat + + # Sort by word length then alphabetically for determinism, take first n + all_words = sorted(word_cat.keys(), key=lambda w: (len(w), w)) + selected = all_words[:n] + + words = selected + categories = [word_cat[w] for w in words] + + elapsed = time.time() - t0 + print(f" {len(words):,} words extracted in {elapsed:.1f}s") + + # Category distribution + from collections import Counter + + counts = Counter(categories) + print(f" {len(counts)} categories:") + for cat, count in counts.most_common(): + print(f" {cat:20s} {count:6,}") + + return words, categories + + +# --------------------------------------------------------------------------- +# 2. Embedding +# --------------------------------------------------------------------------- + + +def compute_embeddings( + words: list[str], + model_name: str, + batch_size: int = 512, +) -> np.ndarray: + """Embed words with sentence-transformers, cached to disk.""" + cache_path = CACHE_DIR / f"embeddings_{len(words)}_{model_name.replace('/', '_')}.npy" + if cache_path.exists(): + print(f" Loading cached embeddings from {cache_path}") + return np.load(cache_path) + + from sentence_transformers import SentenceTransformer + + print(f" Loading model: {model_name}") + model = SentenceTransformer(model_name) + print(f" Encoding {len(words):,} words (batch_size={batch_size})...") + t0 = time.time() + embeddings = model.encode( + words, + batch_size=batch_size, + show_progress_bar=True, + normalize_embeddings=True, + ) + elapsed = time.time() - t0 + print(f" Done in {elapsed:.1f}s — shape: {embeddings.shape}") + + CACHE_DIR.mkdir(parents=True, exist_ok=True) + np.save(cache_path, embeddings) + return embeddings + + +# 
--------------------------------------------------------------------------- +# 3. TMAP +# --------------------------------------------------------------------------- + + +def build_tmap(embeddings: np.ndarray, k: int) -> TMAP: + """Build TMAP with cosine metric. Stores index for later querying.""" + print(f"Fitting TMAP (metric='cosine', k={k}, n={len(embeddings):,})...") + t0 = time.time() + model = TMAP( + metric="cosine", + n_neighbors=k, + layout_iterations=1000, + seed=42, + store_index=True, + ).fit(embeddings.astype(np.float32)) + elapsed = time.time() - t0 + print(f" Done in {elapsed:.1f}s") + return model + + +# --------------------------------------------------------------------------- +# 4. Analysis +# --------------------------------------------------------------------------- + + +def analyze( + model: TMAP, + words: list[str], + categories: list[str], +) -> str: + """Generate analysis report.""" + tree = model.tree_ + cat_arr = np.array(categories) + lines: list[str] = [] + w = lines.append + + w(f"Word Embedding TMAP — {len(words):,} words, {len(set(categories))} categories\n") + + # 1. Category clustering + be = boundary_edges(tree, cat_arr) + w("1. Category boundaries:") + w( + f" Same-category edges: {len(tree.edges) - len(be):,} / {len(tree.edges):,} " + f"({1 - len(be) / len(tree.edges):.1%})" + ) + w(f" Cross-category edges: {len(be):,} ({len(be) / len(tree.edges):.1%})\n") + + # 2. Subtree purity + purity = subtree_purity(tree, cat_arr, min_size=20) + valid = purity[~np.isnan(purity)] + w(f"2. Subtree purity: mean={valid.mean():.3f} median={np.median(valid):.3f}\n") + + # 3. Most connected category pairs + cmat, classes = confusion_matrix_from_tree(tree, cat_arr) + np.fill_diagonal(cmat, 0) + upper = np.triu_indices_from(cmat, k=1) + pair_counts = cmat[upper] + top_idx = np.argsort(pair_counts)[::-1][:15] + w("3. 
Most connected category pairs:") + for i in top_idx: + if pair_counts[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts[i]:5d} edges: {classes[r]:>15s} <-> {classes[c]}") + w("") + + # 4. Sample paths + word_to_idx = {w: i for i, w in enumerate(words)} + paths_to_trace = [ + ("dog", "cat"), + ("dog", "wolf"), + ("apple", "banana"), + ("guitar", "piano"), + ("doctor", "nurse"), + ("mountain", "ocean"), + ("happy", "sad"), + ("sword", "shield"), + ("rain", "snow"), + ("dog", "guitar"), + ("brain", "computer"), + ] + w("4. Semantic paths:") + for word_a, word_b in paths_to_trace: + if word_a not in word_to_idx or word_b not in word_to_idx: + continue + idx_a = word_to_idx[word_a] + idx_b = word_to_idx[word_b] + path_nodes = tree.path(idx_a, idx_b) + path_words = [words[n] for n in path_nodes] + if len(path_words) <= 10: + path_str = " → ".join(path_words) + else: + path_str = ( + " → ".join(path_words[:5]) + + f" → [{len(path_words) - 7} more] → " + + " → ".join(path_words[-2:]) + ) + w(f" {word_a} → {word_b} ({len(path_words)} hops): {path_str}") + w("") + + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# 5. Visualization +# --------------------------------------------------------------------------- + + +def create_visualization( + model: TMAP, + words: list[str], + categories: list[str], +): + viz = model.to_tmapviz() + viz.title = f"English Vocabulary — {len(words):,} Words" + viz.add_color_layout("category", categories, categorical=True, color="tab20") + viz.add_label("word", words) + return viz + + +# --------------------------------------------------------------------------- +# 6. 
Playground — interactive word query +# --------------------------------------------------------------------------- + +MODEL_SAVE_PATH = CACHE_DIR / "word_tmap.model" +WORDS_SAVE_PATH = CACHE_DIR / "word_list.npy" +CATS_SAVE_PATH = CACHE_DIR / "word_categories.npy" + + +def playground(model_name: str, k_show: int = 10) -> None: + """Interactive REPL: type a word (or phrase), see where it lands.""" + from sentence_transformers import SentenceTransformer + + print("Loading saved model...") + tmap_model = TMAP.load(MODEL_SAVE_PATH) + words = np.load(WORDS_SAVE_PATH, allow_pickle=True).tolist() + categories = np.load(CATS_SAVE_PATH, allow_pickle=True).tolist() + word_to_idx = {w: i for i, w in enumerate(words)} + + print(f"Loading encoder: {model_name}") + encoder = SentenceTransformer(model_name) + + print(f"\nPlayground ready — {len(words):,} words in map") + print("Type a word or phrase to see where it would land.") + print("Type two words separated by ' -> ' to trace a path.") + print("Type 'quit' to exit.\n") + + while True: + try: + query = input(">>> ").strip() + except (EOFError, KeyboardInterrupt): + print() + break + if not query or query.lower() == "quit": + break + + # Path mode: "word_a -> word_b" + if " -> " in query: + parts = [p.strip() for p in query.split(" -> ", 1)] + idx_a = word_to_idx.get(parts[0].lower()) + idx_b = word_to_idx.get(parts[1].lower()) + if idx_a is None: + print(f" '{parts[0]}' not in vocabulary") + continue + if idx_b is None: + print(f" '{parts[1]}' not in vocabulary") + continue + path_nodes = tmap_model.tree_.path(idx_a, idx_b) + path_words = [words[n] for n in path_nodes] + print(f" Path ({len(path_words)} hops):") + print(f" {' → '.join(path_words)}") + print() + continue + + # Query mode: embed and find neighbors + emb = encoder.encode( + [query], + normalize_embeddings=True, + ).astype(np.float32) + + indices, distances = tmap_model.kneighbors(emb) + idx_row = indices[0] + dist_row = distances[0] + + # Check if the query 
itself is in the vocabulary + in_vocab = query.lower() in word_to_idx + + print(f' Query: "{query}"' + (" (in vocabulary)" if in_vocab else " (not in vocabulary)")) + print(" Nearest neighbors on the map:") + for rank, (idx, dist) in enumerate(zip(idx_row, dist_row)): + if idx < 0: + break + word = words[idx] + cat = categories[idx] + coord = tmap_model.embedding_[idx] + marker = " <--" if word == query.lower() else "" + print( + f" {rank + 1:2d}. {word:25s} [{cat:>15s}] " + f"dist={dist:.4f} pos=({coord[0]:.1f}, {coord[1]:.1f}){marker}" + ) + + # Show where it would be placed (centroid of top neighbors) + valid = idx_row[idx_row >= 0][:5] + if len(valid) > 0: + centroid = tmap_model.embedding_[valid].mean(axis=0) + print(f" Approximate position: ({centroid[0]:.1f}, {centroid[1]:.1f})") + print() + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser(description="50K word embedding TMAP") + parser.add_argument("--n", type=int, default=50000, help="Number of words") + parser.add_argument("--model", type=str, default="all-MiniLM-L6-v2") + parser.add_argument("--k", type=int, default=20, help="Number of neighbors") + parser.add_argument("--serve", action="store_true") + parser.add_argument("--port", type=int, default=8050) + parser.add_argument( + "--playground", + action="store_true", + help="Interactive mode: query words against the saved model", + ) + args = parser.parse_args() + + if args.playground: + if not MODEL_SAVE_PATH.exists(): + print(f"No saved model at {MODEL_SAVE_PATH}") + print("Run without --playground first to build and save the model.") + return + playground(args.model) + return + + words, categories = load_wordnet_vocabulary(args.n) + embeddings = compute_embeddings(words, args.model) + model = build_tmap(embeddings, args.k) + + # Save model + word list for playground mode + 
CACHE_DIR.mkdir(parents=True, exist_ok=True) + model.save(MODEL_SAVE_PATH) + np.save(WORDS_SAVE_PATH, np.array(words, dtype=object)) + np.save(CATS_SAVE_PATH, np.array(categories, dtype=object)) + print(f"Model saved to {MODEL_SAVE_PATH}") + + report = analyze(model, words, categories) + report_path = OUTPUT_DIR / "word_embeddings_50k_report.txt" + report_path.write_text(report, encoding="utf-8") + print(f"\nReport saved to {report_path}") + print("\n" + report) + + print("Building visualization...") + viz = create_visualization(model, words, categories) + html_path = viz.write_html(OUTPUT_DIR / "word_embeddings_50k_tmap") + print(f"HTML saved to {html_path}") + + if args.serve: + print(f"Serving on http://127.0.0.1:{args.port}") + viz.serve(port=args.port) + + +if __name__ == "__main__": + main() From d3a779eb9a17a7c7af118247aa9f5c271e62cd29 Mon Sep 17 00:00:00 2001 From: afloresep Date: Thu, 2 Apr 2026 10:39:36 +0200 Subject: [PATCH 2/5] add singlecell helpers: subset_anndata, sample_obs_indices, obs_to_numeric --- src/tmap/utils/__init__.py | 12 +- src/tmap/utils/singlecell.py | 209 +++++++++++++++++++++++++++++++++-- 2 files changed, 211 insertions(+), 10 deletions(-) diff --git a/src/tmap/utils/__init__.py b/src/tmap/utils/__init__.py index 2f0c8fc..2884437 100644 --- a/src/tmap/utils/__init__.py +++ b/src/tmap/utils/__init__.py @@ -20,7 +20,14 @@ read_protein_csv, sequence_properties, ) -from tmap.utils.singlecell import cell_metadata, from_anndata, marker_scores +from tmap.utils.singlecell import ( + cell_metadata, + from_anndata, + marker_scores, + obs_to_numeric, + sample_obs_indices, + subset_anndata, +) __all__ = [ "AVAILABLE_PROPERTIES", @@ -33,6 +40,7 @@ "from_anndata", "marker_scores", "molecular_properties", + "obs_to_numeric", "murcko_scaffolds", "reaction_properties", "parse_alignment", @@ -41,5 +49,7 @@ "read_pdb", "read_pdb_dir", "read_protein_csv", + "sample_obs_indices", "sequence_properties", + "subset_anndata", ] diff --git 
a/src/tmap/utils/singlecell.py b/src/tmap/utils/singlecell.py index 30c6711..427cc47 100644 --- a/src/tmap/utils/singlecell.py +++ b/src/tmap/utils/singlecell.py @@ -1,7 +1,8 @@ """Single-cell RNA-seq utilities for TMAP. Bridge between the scverse/AnnData ecosystem and TMAP. Provides helpers -to extract matrices, cell metadata, and gene scores from AnnData objects. +to extract matrices, subset AnnData objects, sample observations, parse +observation columns, and compute lightweight marker scores. Requires ``anndata`` (install via ``pip install anndata``). """ @@ -29,6 +30,65 @@ def _to_dense(X: np.ndarray | sparse.spmatrix) -> NDArray[np.float32]: return np.asarray(X, dtype=np.float32) +def _group_quotas(counts: NDArray[np.int64], max_items: int, mode: str) -> NDArray[np.int64]: + """Allocate sample sizes across groups without replacement.""" + if mode == "proportional": + quotas = np.floor(max_items * counts / counts.sum()).astype(np.int64) + quotas = np.minimum(np.maximum(quotas, 1), counts) + + while quotas.sum() > max_items: + idx = int(np.argmax(quotas)) + if quotas[idx] > 1: + quotas[idx] -= 1 + while quotas.sum() < max_items: + remaining = counts - quotas + idx = int(np.argmax(remaining)) + if remaining[idx] == 0: + break + quotas[idx] += 1 + return quotas + + if mode == "balanced": + quotas = np.zeros_like(counts, dtype=np.int64) + remaining = int(max_items) + active = np.arange(len(counts), dtype=np.int64) + + while remaining > 0 and len(active) > 0: + share = max(1, remaining // len(active)) + new_active: list[int] = [] + changed = False + for idx in active: + capacity = int(counts[idx] - quotas[idx]) + if capacity <= 0: + continue + take = min(capacity, share) + if take > 0: + quotas[idx] += take + remaining -= take + changed = True + if quotas[idx] < counts[idx]: + new_active.append(int(idx)) + if remaining == 0: + break + if not changed: + break + active = np.asarray(new_active, dtype=np.int64) + + if remaining > 0: + remaining_cap = counts - quotas + 
for idx in np.argsort(-remaining_cap): + take = min(int(remaining_cap[idx]), remaining) + if take <= 0: + continue + quotas[idx] += take + remaining -= take + if remaining == 0: + break + return quotas + + raise ValueError(f"Unknown sampling mode: {mode!r}") + + def from_anndata( adata: AnnData, use_rep: str | None = "X_pca", @@ -53,14 +113,18 @@ def from_anndata( Layer in ``adata.layers`` to use instead of ``adata.X``. Only used when *use_rep* is ``None`` or missing. n_top_genes : int or None - If using ``adata.X`` / a layer, subset to the top *n* highly - variable genes (requires ``adata.var['highly_variable']``). - Ignored when using an obsm representation. + If using ``adata.X`` / a layer and ``adata.var['highly_variable']`` + exists, use the marked highly variable genes when at least + *n_top_genes* are available. Ignored when using an obsm + representation. Returns ------- ndarray of shape ``(n_cells, n_features)``, dtype ``float32`` - Ready to pass to ``TMAP(metric='cosine').fit(X)``. + Ready to pass to ``TMAP(metric='cosine').fit(X)``. Expression + matrices and layers are densified before return, so for large + sparse inputs it is better to subset first or use a precomputed + obsm representation. """ # Try obsm representation first if use_rep is not None and use_rep in adata.obsm: @@ -89,7 +153,11 @@ def from_anndata( n_hvg = int(hvg_mask.sum()) if n_hvg >= n_top_genes: raw = raw[:, hvg_mask] - logger.info("from_anndata: subset to %d HVGs", n_hvg) + logger.info( + "from_anndata: subset to %d highly variable genes (requested at least %d)", + n_hvg, + n_top_genes, + ) else: logger.info( "from_anndata: only %d HVGs found (requested %d), using all", @@ -102,6 +170,128 @@ def from_anndata( return X +def subset_anndata( + adata: AnnData, + *, + obs_mask: Sequence[bool] | NDArray[np.bool_] | None = None, + obs_indices: Sequence[int] | NDArray[np.int64] | None = None, + copy: bool = True, +) -> AnnData: + """Subset observations in one step. 
+ + This is mainly useful for backed ``AnnData`` objects, where repeated + slicing into views is not allowed. Backed inputs are materialized with + ``to_memory()`` after subsetting. + + Parameters + ---------- + adata : AnnData + Annotated data matrix. + obs_mask : sequence[bool] or None + Boolean mask over observations. + obs_indices : sequence[int] or None + Explicit observation indices. + copy : bool + Whether to return a copy for in-memory AnnData inputs. + + Returns + ------- + AnnData + Subsetted object. Backed inputs return an in-memory ``AnnData``. + """ + if obs_mask is not None and obs_indices is not None: + raise ValueError("Provide either obs_mask or obs_indices, not both.") + + if obs_mask is None and obs_indices is None: + if getattr(adata, "isbacked", False): + return adata.to_memory() + return adata.copy() if copy else adata + + if obs_mask is not None: + mask = np.asarray(obs_mask, dtype=bool) + if mask.ndim != 1 or len(mask) != adata.n_obs: + raise ValueError("obs_mask must be a 1D boolean array of length adata.n_obs.") + idx = np.flatnonzero(mask) + else: + idx = np.asarray(obs_indices, dtype=np.int64) + if idx.ndim != 1: + raise ValueError("obs_indices must be a 1D array.") + + subset = adata[np.sort(idx)] + if getattr(adata, "isbacked", False): + return subset.to_memory() + return subset.copy() if copy else subset + + +def sample_obs_indices( + groups: Sequence[object] | NDArray, + *, + max_items: int, + seed: int, + mode: str = "proportional", +) -> NDArray[np.int64]: + """Sample observation indices from groups without replacement. + + Parameters + ---------- + groups : sequence + Group label per observation. + max_items : int + Maximum number of observations to keep. + seed : int + Random seed for group-wise sampling. + mode : {"proportional", "balanced"} + ``"proportional"`` preserves group frequencies as much as + possible. ``"balanced"`` spreads observations more evenly across + groups, capped by the number available in each group. 
+ + Returns + ------- + ndarray of shape ``(n_kept,)``, dtype ``int64`` + Sorted indices into the original array. + """ + groups_arr = np.asarray(groups).astype(str) + if groups_arr.ndim != 1: + raise ValueError("groups must be a 1D array.") + if max_items <= 0: + raise ValueError("max_items must be positive.") + if len(groups_arr) <= max_items: + return np.arange(len(groups_arr), dtype=np.int64) + + unique, counts = np.unique(groups_arr, return_counts=True) + quotas = _group_quotas(counts.astype(np.int64), max_items, mode) + rng = np.random.default_rng(seed) + + keep_parts: list[np.ndarray] = [] + for group, quota in zip(unique, quotas, strict=True): + idx = np.where(groups_arr == group)[0] + chosen = np.sort(rng.choice(idx, size=int(quota), replace=False)) + keep_parts.append(chosen) + + return np.sort(np.concatenate(keep_parts)) + + +def obs_to_numeric(values: Sequence[object] | NDArray) -> NDArray[np.float32] | None: + """Convert an observation-like array to numeric values when possible.""" + arr = np.asarray(values) + + if np.issubdtype(arr.dtype, np.number): + return arr.astype(np.float32, copy=False) + + out = np.full(arr.shape[0], np.nan, dtype=np.float32) + for i, value in enumerate(arr.astype(str)): + digits = "".join(ch for ch in value if ch.isdigit() or ch in ".-") + if digits and digits not in {"-", ".", "-."}: + try: + out[i] = float(digits) + except ValueError: + continue + + if np.isnan(out).all(): + return None + return out + + def cell_metadata( adata: AnnData, keys: Sequence[str] | None = None, @@ -171,12 +361,13 @@ def marker_scores( ndarray of shape ``(n_cells,)``, dtype ``float64`` Mean expression across the gene set for each cell. 
""" - var_names = list(adata.var_names) + var_index = {str(name): i for i, name in enumerate(adata.var_names)} indices = [] missing = [] for g in gene_list: - if g in var_names: - indices.append(var_names.index(g)) + idx = var_index.get(g) + if idx is not None: + indices.append(idx) else: missing.append(g) From 9b11b3118af5acfe8f1b4256beb4a9ba884abe34 Mon Sep 17 00:00:00 2001 From: afloresep Date: Thu, 2 Apr 2026 10:42:14 +0200 Subject: [PATCH 3/5] add tests for singlecell helpers --- tests/test_singlecell.py | 161 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 tests/test_singlecell.py diff --git a/tests/test_singlecell.py b/tests/test_singlecell.py new file mode 100644 index 0000000..fe2bdbb --- /dev/null +++ b/tests/test_singlecell.py @@ -0,0 +1,161 @@ +"""Tests for tmap.utils.singlecell.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from scipy import sparse + +anndata = pytest.importorskip("anndata") + +from tmap.utils.singlecell import ( # noqa: E402 + cell_metadata, + from_anndata, + marker_scores, + obs_to_numeric, + sample_obs_indices, + subset_anndata, +) + + +def _make_adata(): + X = sparse.csr_matrix( + np.array( + [ + [1.0, 0.0, 3.0, 0.0], + [0.0, 2.0, 0.0, 4.0], + [5.0, 0.0, 6.0, 0.0], + ], + dtype=np.float32, + ) + ) + obs = pd.DataFrame( + { + "celltype": pd.Categorical(["prog", "prog", "diff"]), + "day": [2, 2, 7], + "quality": [0.1, 0.2, 0.9], + }, + index=["c0", "c1", "c2"], + ) + var = pd.DataFrame( + {"highly_variable": [True, False, True, False]}, + index=["GeneA", "GeneB", "GeneC", "GeneD"], + ) + + adata = anndata.AnnData(X=X, obs=obs, var=var) + adata.obsm["X_pca"] = np.array( + [ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + ], + dtype=np.float32, + ) + adata.layers["counts"] = sparse.csr_matrix( + np.array( + [ + [10.0, 1.0, 30.0, 0.0], + [0.0, 20.0, 0.0, 40.0], + [50.0, 0.0, 60.0, 0.0], + ], + dtype=np.float32, + ) + ) + return adata + + +def 
test_from_anndata_uses_obsm_representation(): + adata = _make_adata() + X = from_anndata(adata, use_rep="X_pca") + + assert X.shape == (3, 2) + assert X.dtype == np.float32 + np.testing.assert_allclose(X, adata.obsm["X_pca"]) + + +def test_from_anndata_falls_back_to_layer_and_hvgs(): + adata = _make_adata() + X = from_anndata(adata, use_rep="missing_rep", layer="counts", n_top_genes=2) + + assert X.shape == (3, 2) + assert X.dtype == np.float32 + np.testing.assert_allclose( + X, + np.array( + [ + [10.0, 30.0], + [0.0, 0.0], + [50.0, 60.0], + ], + dtype=np.float32, + ), + ) + + +def test_cell_metadata_preserves_numeric_and_categorical_columns(): + adata = _make_adata() + meta = cell_metadata(adata, keys=["celltype", "day", "quality"]) + + assert set(meta) == {"celltype", "day", "quality"} + assert meta["celltype"].dtype == object + assert meta["day"].dtype == np.float64 + assert meta["quality"].dtype == np.float64 + assert meta["celltype"].tolist() == ["prog", "prog", "diff"] + + +def test_obs_to_numeric_parses_embedded_numbers(): + values = pd.Series(["day_2", "t=7.5", "unknown"], dtype="object") + out = obs_to_numeric(values) + + assert out is not None + np.testing.assert_allclose(out[:2], [2.0, 7.5]) + assert np.isnan(out[2]) + + +def test_sample_obs_indices_balanced_and_proportional(): + groups = np.array(["0"] * 6 + ["1"] * 3 + ["2"] * 1) + + proportional = sample_obs_indices(groups, max_items=6, seed=42, mode="proportional") + balanced = sample_obs_indices(groups, max_items=6, seed=42, mode="balanced") + + prop_groups = groups[proportional] + bal_groups = groups[balanced] + assert {g: int((prop_groups == g).sum()) for g in np.unique(prop_groups)} == { + "0": 4, + "1": 1, + "2": 1, + } + assert {g: int((bal_groups == g).sum()) for g in np.unique(bal_groups)} == { + "0": 3, + "1": 2, + "2": 1, + } + + +def test_subset_anndata_materializes_backed_file(tmp_path): + adata = _make_adata() + path = tmp_path / "toy.h5ad" + adata.write_h5ad(path) + + backed = 
anndata.read_h5ad(path, backed="r") + subset = subset_anndata(backed, obs_mask=np.array([True, False, True])) + + assert subset.n_obs == 2 + assert getattr(subset, "isbacked", False) is False + assert subset.obs_names.tolist() == ["c0", "c2"] + np.testing.assert_allclose(subset.obsm["X_pca"], adata.obsm["X_pca"][[0, 2]]) + + +def test_marker_scores_dense_mean_expression(): + adata = _make_adata() + scores = marker_scores(adata, ["GeneA", "GeneC"]) + + np.testing.assert_allclose(scores, [2.0, 0.0, 5.5]) + assert scores.dtype == np.float64 + + +def test_marker_scores_raises_when_no_genes_found(): + adata = _make_adata() + with pytest.raises(ValueError, match="None of the requested genes found"): + marker_scores(adata, ["Missing1", "Missing2"]) From c26b70cf27ffdd9a3d0b93b88036e62e408ff963 Mon Sep 17 00:00:00 2001 From: afloresep Date: Thu, 2 Apr 2026 10:42:39 +0200 Subject: [PATCH 4/5] update examples: add singlecell section to README, format mnist and pet_breed --- examples/README.md | 7 +++ examples/mnist_cosine_tmap.py | 8 +-- examples/pet_breed_audit.py | 95 +++++++++++++++++++++++------------ 3 files changed, 73 insertions(+), 37 deletions(-) diff --git a/examples/README.md b/examples/README.md index abacff6..2c18ce4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -21,6 +21,13 @@ |---------|-------------|--------------| | [`afdb_clusters_tmap.py`](afdb_clusters_tmap.py) | AlphaFold DB: 2.3M structural clusters from Foldseek | Precomputed `KNNGraph`, taxonomy resolution, `node_diversity`, large-scale pipeline | +## Single-cell + +| Example | Description | Key features | +|---------|-------------|--------------| +| [`singlecell_trajectory_tmap.py`](singlecell_trajectory_tmap.py) | Murine lung regeneration trajectory from an official AnnData `.h5ad` | `from_anndata`, `cell_metadata`, `marker_scores`, pseudotime via `distances_from()` | +| [`singlecell_reprogramming_tmap.py`](singlecell_reprogramming_tmap.py) | Morris fibroblast-to-iEP direct 
reprogramming trajectory from an official AnnData `.h5ad` | Backed AnnData filtering, explicit root/target anchors, reference pseudotime comparison | + ## Quick start The fastest way to try TMAP: diff --git a/examples/mnist_cosine_tmap.py b/examples/mnist_cosine_tmap.py index 5d80fe9..5131fda 100644 --- a/examples/mnist_cosine_tmap.py +++ b/examples/mnist_cosine_tmap.py @@ -33,7 +33,7 @@ def main() -> None: cfg = LayoutConfig() cfg.k = 20 cfg.kc = 200 - cfg.node_size = 1/30 + cfg.node_size = 1 / 30 cfg.mmm_repeats = 2 cfg.sl_extra_scaling_steps = 10 cfg.sl_scaling_type = ScalingType.RelativeToDrawing @@ -44,11 +44,7 @@ def main() -> None: print("\nFitting TMAP (metric='cosine', n_neighbors=20)...") t0 = time.perf_counter() model = TMAP( - metric="cosine", - n_neighbors=20, - seed=42, - layout_iterations=1000, - layout_config=cfg + metric="cosine", n_neighbors=20, seed=42, layout_iterations=1000, layout_config=cfg ).fit(X) elapsed = time.perf_counter() - t0 print(f" Done in {elapsed:.1f}s") diff --git a/examples/pet_breed_audit.py b/examples/pet_breed_audit.py index e7b5d13..d3dcf02 100644 --- a/examples/pet_breed_audit.py +++ b/examples/pet_breed_audit.py @@ -62,12 +62,14 @@ def _extract_embeddings( if cache_emb.exists() and cache_lbl.exists(): return np.load(cache_emb), np.load(cache_lbl) - transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) dataset = datasets.OxfordIIITPet( root=str(Path(__file__).parent / "data" / "oxford-iiit-pet"), split=split, @@ -76,15 +78,20 @@ def _extract_embeddings( transform=transform, ) loader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, shuffle=False, num_workers=2, + dataset, + 
batch_size=batch_size, + shuffle=False, + num_workers=2, ) model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) model.eval().to(device) # Hook avgpool output (2048-d) features: list[torch.Tensor] = [] + def _hook(_mod: nn.Module, _inp: tuple, out: torch.Tensor) -> None: features.append(out.squeeze(-1).squeeze(-1).cpu()) + model.avgpool.register_forward_hook(_hook) all_labels: list[int] = [] @@ -219,16 +226,18 @@ def _find_best_failure_paths( continue conf_along = path_properties(tree, idx, best_target, confidences) - paths.append({ - "from": idx, - "to": best_target, - "true_class": class_names[true_cls], - "pred_class": class_names[preds[idx]], - "path_length": int(best_len), - "conf_start": float(confidences[idx]), - "conf_end": float(confidences[best_target]), - "conf_min": float(conf_along.min()), - }) + paths.append( + { + "from": idx, + "to": best_target, + "true_class": class_names[true_cls], + "pred_class": class_names[preds[idx]], + "path_length": int(best_len), + "conf_start": float(confidences[idx]), + "conf_end": float(confidences[best_target]), + "conf_min": float(conf_along.min()), + } + ) paths.sort(key=lambda p: -p["path_length"]) return paths[:top_k] @@ -247,13 +256,15 @@ def _analyze_tree( w = lines.append acc = (preds == true_labels).mean() - w(f"Oxford-IIIT Pets Classifier Audit " - f"({len(true_labels)} test images, {len(class_names)} breeds)") + w( + f"Oxford-IIIT Pets Classifier Audit " + f"({len(true_labels)} test images, {len(class_names)} breeds)" + ) w(f"Overall accuracy: {acc:.1%}\n") # 1. Boundary edges be = boundary_edges(tree, preds) - w(f"1. Boundary edges: {len(be)} / {len(tree.edges)} edges ({len(be)/len(tree.edges):.1%})") + w(f"1. Boundary edges: {len(be)} / {len(tree.edges)} edges ({len(be) / len(tree.edges):.1%})") w(" Edges where neighboring points have different predicted breeds.\n") # 2. 
Confusion matrix @@ -267,8 +278,10 @@ def _analyze_tree( if pair_counts[i] == 0: break r, c = upper[0][i], upper[1][i] - w(f" {class_names[classes[r]]:>30s} <-> " - f"{class_names[classes[c]]:<30s} ({pair_counts[i]} edges)") + w( + f" {class_names[classes[r]]:>30s} <-> " + f"{class_names[classes[c]]:<30s} ({pair_counts[i]} edges)" + ) w("") # 3. Confidence gradients @@ -288,10 +301,14 @@ def _analyze_tree( paths = _find_best_failure_paths(tree, true_labels, preds, confidences, class_names) w(f"5. Top failure paths ({len(paths)} shown):") for p in paths: - w(f" Node {p['from']} ({p['true_class']}, " - f"predicted {p['pred_class']}, conf={p['conf_start']:.2f})") - w(f" -> Node {p['to']} (correct, conf={p['conf_end']:.2f}) " - f"path_len={p['path_length']} min_conf={p['conf_min']:.2f}") + w( + f" Node {p['from']} ({p['true_class']}, " + f"predicted {p['pred_class']}, conf={p['conf_start']:.2f})" + ) + w( + f" -> Node {p['to']} (correct, conf={p['conf_end']:.2f}) " + f"path_len={p['path_length']} min_conf={p['conf_min']:.2f}" + ) w("") # 6. Mislabel candidates @@ -302,8 +319,10 @@ def _analyze_tree( top_wrong = wrong[np.argsort(wrong_conf)[::-1][:10]] w("6. Mislabel candidates (high-confidence misclassifications):") for idx in top_wrong: - w(f" Index {idx:5d} true={class_names[true_labels[idx]]:>25s} " - f"pred={class_names[preds[idx]]:<25s} conf={confidences[idx]:.3f}") + w( + f" Index {idx:5d} true={class_names[true_labels[idx]]:>25s} " + f"pred={class_names[preds[idx]]:<25s} conf={confidences[idx]:.3f}" + ) else: w("6. 
No misclassifications found.") w("") @@ -359,7 +378,9 @@ def main() -> None: # Class names ds = datasets.OxfordIIITPet( root=str(Path(__file__).parent / "data" / "oxford-iiit-pet"), - split="test", target_types="category", download=False, + split="test", + target_types="category", + download=False, ) class_names = [name.replace("_", " ").title() for name in ds.classes] n_classes = len(class_names) @@ -368,7 +389,13 @@ def main() -> None: # Train linear probe print(f"Training linear probe ({args.epochs} epochs)...") preds, confidences = _train_probe( - train_emb, train_labels, test_emb, test_labels, n_classes, args.epochs, device, + train_emb, + train_labels, + test_emb, + test_labels, + n_classes, + args.epochs, + device, ) acc = (preds == test_labels).mean() print(f"Probe accuracy: {acc:.1%}") @@ -391,7 +418,13 @@ def main() -> None: # Visualization print("Building visualization...") viz = _create_visualization( - model, class_names, test_labels, preds, confidences, purity, image_uris, + model, + class_names, + test_labels, + preds, + confidences, + purity, + image_uris, ) html_path = viz.write_html(OUTPUT_DIR / "pets_tmap") print(f"HTML saved to {html_path}") From 4a0cc6cbf610fe9ac512e13681e576320479138e Mon Sep 17 00:00:00 2001 From: afloresep Date: Thu, 2 Apr 2026 10:46:07 +0200 Subject: [PATCH 5/5] add examples: cub200 birds, emnist, flowers, word embeddings, wikiart --- examples/cub200_birds_tmap.py | 426 ++++++++ examples/emnist_characters_tmap.py | 361 +++++++ examples/flowers_tmap.py | 557 +++++++++++ examples/wikiart_tmap.py | 452 +++++++++ examples/word_embeddings_tmap.py | 1488 ++++++++++++++++++++++++++++ 5 files changed, 3284 insertions(+) create mode 100644 examples/cub200_birds_tmap.py create mode 100644 examples/emnist_characters_tmap.py create mode 100644 examples/flowers_tmap.py create mode 100644 examples/wikiart_tmap.py create mode 100644 examples/word_embeddings_tmap.py diff --git a/examples/cub200_birds_tmap.py b/examples/cub200_birds_tmap.py 
new file mode 100644 index 0000000..f31cb3d --- /dev/null +++ b/examples/cub200_birds_tmap.py @@ -0,0 +1,426 @@ +"""CUB-200 bird species — morphological paths across 200 species. + +Build a TMAP of 12K bird images across 200 species. The tree reveals +visual similarity: similar-looking species cluster together, and paths +trace morphological gradients across bird families — from hummingbird +to kingfisher to woodpecker (small, colorful, pointed beak). + +Outputs +------- +examples/cub200_tmap.html Interactive TMAP with bird tooltips +examples/cub200_report.txt Species analysis report + +Data +---- +Downloads CUB-200-2011 from HuggingFace (~1.2 GB, cached). + +Usage +----- + python examples/cub200_birds_tmap.py + python examples/cub200_birds_tmap.py --serve + +Requirements +------------ + pip install datasets torch torchvision +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import time +from pathlib import Path + +import numpy as np +import torch +from PIL import Image +from torchvision import models, transforms + +from tmap import TMAP +from tmap.graph.analysis import ( + boundary_edges, + confusion_matrix_from_tree, + subtree_purity, +) + +CACHE_DIR = Path(__file__).parent / "data" / "cub200_cache" +OUTPUT_DIR = Path(__file__).parent + +# Bird families for supercategory analysis (prefix → family) +BIRD_FAMILY = { + "Albatross": "seabird", + "Auklet": "seabird", + "Cormorant": "seabird", + "Frigate": "seabird", + "Fulmar": "seabird", + "Gull": "seabird", + "Jaeger": "seabird", + "Pelican": "seabird", + "Puffin": "seabird", + "Tern": "seabird", + "Booby": "seabird", + "Petrel": "seabird", + "Grebe": "waterbird", + "Heron": "waterbird", + "Crane": "waterbird", + "Duck": "waterbird", + "Mallard": "waterbird", + "Merganser": "waterbird", + "Kingfisher": "waterbird", + "Sparrow": "songbird", + "Warbler": "songbird", + "Wren": "songbird", + "Finch": "songbird", + "Bunting": "songbird", + "Vireo": "songbird", + "Tanager": 
"songbird", + "Grosbeak": "songbird", + "Oriole": "songbird", + "Cardinal": "songbird", + "Towhee": "songbird", + "Junco": "songbird", + "Goldfinch": "songbird", + "Woodpecker": "woodpecker", + "Flicker": "woodpecker", + "Hummingbird": "hummingbird", + "Jay": "corvid", + "Crow": "corvid", + "Raven": "corvid", + "Hawk": "raptor", + "Eagle": "raptor", + "Falcon": "raptor", + "Owl": "raptor", + "Osprey": "raptor", + "Kite": "raptor", + "Swallow": "aerial", + "Swift": "aerial", + "Nighthawk": "aerial", + "Mockingbird": "mimic", + "Thrasher": "mimic", + "Catbird": "mimic", + "Blackbird": "icterid", + "Meadowlark": "icterid", + "Cowbird": "icterid", + "Grackle": "icterid", + "Bobolink": "icterid", + "Cuckoo": "other", + "Pigeon": "other", + "Parakeet": "other", + "Flycatcher": "flycatcher", + "Kingbird": "flycatcher", + "Pewee": "flycatcher", + "Phoebe": "flycatcher", + "Nuthatch": "tree-clinger", + "Creeper": "tree-clinger", + "Chickadee": "tit", + "Titmouse": "tit", + "Waxwing": "other", + "Shrike": "other", + "Pipit": "other", + "Lark": "other", + "Whip": "other", +} + + +def _guess_family(species_name: str) -> str: + """Map a CUB-200 species name to a bird family.""" + # Species names are like "Black_footed_Albatross" + for key, family in BIRD_FAMILY.items(): + if key.lower() in species_name.lower(): + return family + return "other" + + +# 1. 
Data loading + + +def load_cub200() -> tuple: + """Load CUB-200-2011, return (dataset, labels, species_names).""" + try: + from datasets import concatenate_datasets, load_dataset + except ImportError: + raise ImportError("pip install datasets") + + print("Loading CUB-200-2011 from HuggingFace...") + ds = load_dataset("bentrevett/caltech-ucsd-birds-200-2011") + species_names = ds["train"].features["label"].names + + # Combine train + test + combined = concatenate_datasets([ds["train"], ds["test"]]) + labels = np.array(combined["label"]) + print(f" {len(combined):,} images, {len(species_names)} species") + + return combined, labels, species_names + + +# 2. Embeddings + + +def extract_embeddings( + ds, + batch_size: int, + device: torch.device, +) -> np.ndarray: + """Extract ResNet-50 avgpool embeddings (2048-d), cached.""" + n = len(ds) + cache_path = CACHE_DIR / f"embeddings_resnet50_{n}.npy" + if cache_path.exists(): + print(f" Loading cached embeddings: {cache_path.name}") + return np.load(cache_path) + + transform = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) + model.eval().to(device) + features: list[torch.Tensor] = [] + + def _hook(_mod, _inp, out): + features.append(out.squeeze(-1).squeeze(-1).cpu()) + + model.avgpool.register_forward_hook(_hook) + + print(f" Extracting embeddings ({n:,} images)...") + t0 = time.time() + for start in range(0, n, batch_size): + end = min(start + batch_size, n) + batch = ds[start:end] + tensors = [] + for img in batch["image"]: + if not isinstance(img, Image.Image): + img = Image.fromarray(np.array(img)) + tensors.append(transform(img.convert("RGB"))) + with torch.no_grad(): + model(torch.stack(tensors).to(device)) + if (start // batch_size) % 20 == 0: + print(f" {start * 100 // n}%", flush=True) + + embeddings = 
torch.cat(features).numpy() + print(f" Done in {time.time() - t0:.1f}s — shape: {embeddings.shape}") + CACHE_DIR.mkdir(parents=True, exist_ok=True) + np.save(cache_path, embeddings) + return embeddings + + +# 3. Image encoding +def encode_images(ds, size: int = 96) -> list[str]: + """Encode bird images as base64 JPEG data URIs.""" + print(f" Encoding {len(ds):,} images for tooltips...") + uris: list[str] = [] + for i in range(len(ds)): + img = ds[i]["image"] + if not isinstance(img, Image.Image): + img = Image.fromarray(np.array(img)) + img = img.convert("RGB").resize((size, size), Image.LANCZOS) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=80) + b64 = base64.b64encode(buf.getvalue()).decode("ascii") + uris.append(f"data:image/jpeg;base64,{b64}") + return uris + + +# 4. Analysis +def _clean_name(name: str) -> str: + return name.replace("_", " ") + + +def analyze_birds( + model: TMAP, + labels: np.ndarray, + species_names: list[str], +) -> str: + """Generate bird species analysis report.""" + tree = model.tree_ + lines: list[str] = [] + w = lines.append + + species = np.array([species_names[l] for l in labels]) + families = np.array([_guess_family(s) for s in species]) + + w(f"CUB-200 Bird Species Analysis — {len(labels):,} images, {len(species_names)} species\n") + + # 1. Family boundaries + be_fam = boundary_edges(tree, families) + n_edges = len(tree.edges) + w("1. Bird family boundaries:") + w( + f" Same-family edges: {n_edges - len(be_fam)} / {n_edges} " + f"({(n_edges - len(be_fam)) / n_edges:.1%})" + ) + + cmat_f, cls_f = confusion_matrix_from_tree(tree, families) + np.fill_diagonal(cmat_f, 0) + upper = np.triu_indices_from(cmat_f, k=1) + pair_counts = cmat_f[upper] + cmat_f.T[upper] + top_idx = np.argsort(pair_counts)[::-1][:10] + w(" Most connected families:") + for i in top_idx: + if pair_counts[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts[i]:4d} edges: {cls_f[r]:>15s} <-> {cls_f[c]}") + w("") + + # 2. 
Species confusion + cmat_s, cls_s = confusion_matrix_from_tree(tree, species) + np.fill_diagonal(cmat_s, 0) + upper_s = np.triu_indices_from(cmat_s, k=1) + pair_counts_s = cmat_s[upper_s] + cmat_s.T[upper_s] + top_s = np.argsort(pair_counts_s)[::-1][:15] + w("2. Most visually similar species pairs:") + for i in top_s: + if pair_counts_s[i] == 0: + break + r, c = upper_s[0][i], upper_s[1][i] + w( + f" {pair_counts_s[i]:3d} edges: {_clean_name(cls_s[r]):>30s} <-> " + f"{_clean_name(cls_s[c])}" + ) + w("") + + # 3. Subtree purity + purity_f = subtree_purity(tree, families, min_size=10) + valid_f = purity_f[~np.isnan(purity_f)] + purity_s = subtree_purity(tree, species, min_size=10) + valid_s = purity_s[~np.isnan(purity_s)] + w("3. Subtree purity:") + w(f" By family: mean={valid_f.mean():.3f} median={np.median(valid_f):.3f}") + w(f" By species: mean={valid_s.mean():.3f} median={np.median(valid_s):.3f}\n") + + # 4. Morphological paths + w("4. Morphological paths between species:") + path_pairs = [ + ("Ruby_throated_Hummingbird", "Belted_Kingfisher"), + ("Bald_Eagle", "Osprey"), + ("American_Crow", "Common_Raven"), + ("House_Sparrow", "Song_Sparrow"), + ("Laysan_Albatross", "Herring_Gull"), + ("Red_headed_Woodpecker", "Downy_Woodpecker"), + ("Scarlet_Tanager", "Northern_Cardinal"), + ("Mallard", "Pelican"), + ] + for sp_a, sp_b in path_pairs: + # Fuzzy match — species names might have slight variations + idx_a = np.where([sp_a.lower() in s.lower() for s in species])[0] + idx_b = np.where([sp_b.lower() in s.lower() for s in species])[0] + if len(idx_a) == 0 or len(idx_b) == 0: + w(f" {_clean_name(sp_a)} -> {_clean_name(sp_b)}: (not found)") + continue + + node_a, node_b = int(idx_a[0]), int(idx_b[0]) + try: + path_nodes = tree.path(node_a, node_b) + except IndexError: + w(f" {_clean_name(sp_a):>30s} -> {_clean_name(sp_b):<30s} (disconnected)") + continue + + path_species = species[path_nodes] + path_families = families[path_nodes] + unique_sp = [] + for s in path_species: 
+ if not unique_sp or unique_sp[-1] != s: + unique_sp.append(_clean_name(s)) + + w( + f" {_clean_name(sp_a):>30s} -> {_clean_name(sp_b):<30s} " + f"hops={len(path_nodes):3d} species={len(set(path_species))} " + f"families={len(set(path_families))}" + ) + if len(unique_sp) <= 6: + w(f" Route: {' -> '.join(unique_sp)}") + else: + route = unique_sp[:3] + ["..."] + unique_sp[-2:] + w(f" Route: {' -> '.join(route)}") + w("") + + return "\n".join(lines) + + +# 5. Visualization + + +def create_visualization( + model: TMAP, + labels: np.ndarray, + species_names: list[str], + image_uris: list[str], +): + viz = model.to_tmapviz() + viz.title = f"CUB-200: {len(labels):,} Bird Images" + + species = [species_names[l] for l in labels] + families = [_guess_family(s) for s in species] + + viz.add_color_layout("family", families, categorical=True, color="tab20") + + viz.add_images(image_uris, tooltip_size=100) + viz.add_label("species", [_clean_name(s) for s in species]) + + return viz + + +# Main + + +def main() -> None: + parser = argparse.ArgumentParser(description="CUB-200 Birds TMAP") + parser.add_argument("--k", type=int, default=15) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--serve", action="store_true") + parser.add_argument("--port", type=int, default=8050) + args = parser.parse_args() + + device = torch.device( + args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"Device: {device}") + + # Load + ds, labels, species_names = load_cub200() + + # Embeddings + print("Extracting ResNet-50 embeddings...") + embeddings = extract_embeddings(ds, args.batch_size, device) + + # Build TMAP + print(f"Building TMAP (metric='cosine', k={args.k})...") + t0 = time.time() + model = TMAP( + metric="cosine", + n_neighbors=args.k, + layout_iterations=1000, + seed=42, + ).fit(embeddings.astype(np.float32)) + print(f" Done in {time.time() - t0:.1f}s") + + 
# Analysis + report = analyze_birds(model, labels, species_names) + report_path = OUTPUT_DIR / "cub200_report.txt" + report_path.write_text(report, encoding="utf-8") + print(f"\nReport saved to {report_path}") + print("\n" + report) + + # Visualization + print("Encoding images for tooltips...") + image_uris = encode_images(ds) + + print("Building visualization...") + viz = create_visualization(model, labels, species_names, image_uris) + html_path = viz.write_html(OUTPUT_DIR / "cub200_tmap") + print(f"HTML saved to {html_path}") + + if args.serve: + print(f"Serving on http://127.0.0.1:{args.port}") + viz.serve(port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/emnist_characters_tmap.py b/examples/emnist_characters_tmap.py new file mode 100644 index 0000000..eb590bc --- /dev/null +++ b/examples/emnist_characters_tmap.py @@ -0,0 +1,361 @@ +"""EMNIST handwritten characters — why OCR confuses letters and digits. + +Build a TMAP of handwritten characters (digits + letters) to reveal +visual similarity bridges: 0 ↔ O ↔ Q ↔ D, or 1 ↔ l ↔ I. Each step +along the path shows a tiny handwriting variation, and the endpoints +are different characters that happen to look alike. + +The tree explains exactly why OCR systems confuse certain characters — +you can trace the path and see the gradual morphing at each step. + +Outputs +------- +examples/emnist_tmap.html Interactive TMAP with character tooltips +examples/emnist_report.txt Character confusion analysis + +Data +---- +Downloads EMNIST via torchvision (~500 MB, cached). 
+ +Usage +----- + python examples/emnist_characters_tmap.py + python examples/emnist_characters_tmap.py --max-images 30000 + python examples/emnist_characters_tmap.py --serve +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import time +from pathlib import Path + +import numpy as np +from PIL import Image +from torchvision import datasets + +from tmap import TMAP +from tmap.graph.analysis import ( + boundary_edges, + confusion_matrix_from_tree, + subtree_purity, +) + +CACHE_DIR = Path(__file__).parent / "data" / "emnist_cache" +OUTPUT_DIR = Path(__file__).parent + +# EMNIST balanced split: 47 classes +# 0-9 = digits, 10-35 = A-Z, 36-46 = select lowercase (a,b,d,e,f,g,h,n,q,r,t) +CHAR_MAP_BALANCED = list("0123456789") + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + list("abdefghnqrt") + +# Group characters by visual shape for analysis +SHAPE_GROUP = { + "0": "round", + "O": "round", + "Q": "round", + "D": "round", + "C": "round", + "G": "round", + "1": "vertical", + "I": "vertical", + "l": "vertical", + "t": "vertical", + "7": "angular", + "T": "angular", + "Y": "angular", + "V": "angular", + "2": "curved", + "S": "curved", + "5": "curved", + "Z": "curved", + "3": "curved", + "8": "looped", + "B": "looped", + "6": "looped", + "9": "looped", + "b": "looped", + "d": "looped", + "q": "looped", + "g": "looped", + "4": "angular", + "A": "angular", + "H": "angular", + "K": "angular", + "M": "angular", + "N": "angular", + "W": "angular", + "E": "angular", + "F": "angular", + "L": "angular", + "P": "mixed", + "R": "mixed", + "J": "mixed", + "U": "mixed", + "X": "angular", + "a": "round", + "e": "round", + "f": "vertical", + "h": "vertical", + "n": "curved", + "r": "curved", +} + + +# 1. 
Data loading + + +def load_emnist(max_images: int | None) -> tuple[np.ndarray, np.ndarray, list[str]]: + """Load EMNIST balanced split, return (images_flat, labels, char_names).""" + data_dir = Path(__file__).parent / "data" / "emnist" + + print("Loading EMNIST (balanced split)...") + ds_train = datasets.EMNIST(root=str(data_dir), split="balanced", train=True, download=True) + ds_test = datasets.EMNIST(root=str(data_dir), split="balanced", train=False, download=True) + + # Combine train + test + all_images = [] + all_labels = [] + for ds in [ds_train, ds_test]: + for img, label in ds: + # EMNIST images are transposed — fix orientation + img = img.transpose(Image.TRANSPOSE) + all_images.append(np.array(img, dtype=np.float32).flatten()) + all_labels.append(label) + + images = np.stack(all_images) + labels = np.array(all_labels, dtype=np.int64) + + print(f" {len(images):,} images, {len(np.unique(labels))} classes, {images.shape[1]}D") + + # Subsample if needed + if max_images and len(images) > max_images: + rng = np.random.RandomState(42) + idx = rng.choice(len(images), max_images, replace=False) + idx.sort() + images = images[idx] + labels = labels[idx] + print(f" Subsampled to {len(images):,}") + + return images, labels, CHAR_MAP_BALANCED + + +# 2. Image encoding for tooltips + + +def encode_char_images( + images_flat: np.ndarray, + size: int = 48, +) -> list[str]: + """Convert flat 784-d vectors back to upscaled images for tooltips.""" + print(f" Encoding {len(images_flat):,} character images ({size}x{size})...") + uris: list[str] = [] + side = int(np.sqrt(images_flat.shape[1])) + for vec in images_flat: + arr = vec.reshape(side, side).astype(np.uint8) + img = Image.fromarray(arr, mode="L") + img = img.resize((size, size), Image.NEAREST) + buf = io.BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("ascii") + uris.append(f"data:image/png;base64,{b64}") + return uris + + +# 3. 
Analysis + + +def analyze_characters( + model: TMAP, + labels: np.ndarray, + char_names: list[str], +) -> str: + """Generate character confusion analysis report.""" + tree = model.tree_ + lines: list[str] = [] + w = lines.append + + chars = np.array([char_names[l] for l in labels]) + groups = np.array([SHAPE_GROUP.get(c, "other") for c in chars]) + + n_classes = len(np.unique(labels)) + w("EMNIST Character Confusion Analysis") + w(f" {len(labels):,} images, {n_classes} character classes\n") + + # 1. Character boundaries + be = boundary_edges(tree, chars) + n_edges = len(tree.edges) + w("1. Character boundaries:") + w( + f" Same-character edges: {n_edges - len(be)} / {n_edges} " + f"({(n_edges - len(be)) / n_edges:.1%})" + ) + + # Shape group boundaries + be_g = boundary_edges(tree, groups) + w( + f" Same-shape-group edges: {n_edges - len(be_g)} / {n_edges} " + f"({(n_edges - len(be_g)) / n_edges:.1%})\n" + ) + + # 2. Most confused character pairs + cmat, classes = confusion_matrix_from_tree(tree, chars) + np.fill_diagonal(cmat, 0) + upper = np.triu_indices_from(cmat, k=1) + pair_counts = cmat[upper] + cmat.T[upper] + top_idx = np.argsort(pair_counts)[::-1][:20] + w("2. Most confused character pairs (shared tree edges):") + for i in top_idx: + if pair_counts[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts[i]:4d} edges: '{classes[r]}' <-> '{classes[c]}'") + w("") + + # 3. Subtree purity + purity = subtree_purity(tree, chars, min_size=20) + valid = purity[~np.isnan(purity)] + purity_g = subtree_purity(tree, groups, min_size=20) + valid_g = purity_g[~np.isnan(purity_g)] + w("3. Subtree purity:") + w(f" By character: mean={valid.mean():.3f} median={np.median(valid):.3f}") + w(f" By shape group: mean={valid_g.mean():.3f} median={np.median(valid_g):.3f}\n") + + # 4. OCR confusion paths + w("4. 
OCR confusion paths (why these characters get mixed up):") + path_pairs = [ + ("0", "O"), # digit vs letter, identical + ("0", "D"), # round shapes + ("1", "I"), # vertical strokes + ("5", "S"), # similar curves + ("8", "B"), # looped + ("6", "b"), # mirror + ("9", "q"), # mirror + ("2", "Z"), # similar shape + ("Q", "9"), # tail similarity + ("W", "M"), # inversions + ("g", "9"), # looped tail + ("A", "4"), # angular, pointed top + ] + for ch_a, ch_b in path_pairs: + idx_a = np.where(chars == ch_a)[0] + idx_b = np.where(chars == ch_b)[0] + if len(idx_a) == 0 or len(idx_b) == 0: + continue + + node_a, node_b = int(idx_a[0]), int(idx_b[0]) + try: + path_nodes = tree.path(node_a, node_b) + except IndexError: + w(f" '{ch_a}' -> '{ch_b}': (disconnected)") + continue + + path_chars = chars[path_nodes] + unique = [] + for ch in path_chars: + if not unique or unique[-1] != ch: + unique.append(ch) + + w( + f" '{ch_a}' -> '{ch_b}': {len(path_nodes):3d} hops, " + f"characters crossed: {len(set(path_chars))}" + ) + if len(unique) <= 10: + route = " -> ".join(f"'{c}'" for c in unique) + else: + route = " -> ".join(f"'{c}'" for c in unique[:5]) + route += " -> ... -> " + " -> ".join(f"'{c}'" for c in unique[-3:]) + w(f" Route: {route}") + w("") + + # 5. Digit vs letter separation + is_digit = np.array(["digit" if l < 10 else "letter" for l in labels]) + be_dl = boundary_edges(tree, is_digit) + w("5. Digit vs letter separation:") + w( + f" Same-type edges: {n_edges - len(be_dl)} / {n_edges} " + f"({(n_edges - len(be_dl)) / n_edges:.1%})" + ) + w(f" Cross-type edges: {len(be_dl)} ({len(be_dl) / n_edges:.1%})") + + return "\n".join(lines) + + +# 4. 
Visualization + + +def create_visualization( + model: TMAP, + labels: np.ndarray, + char_names: list[str], + image_uris: list[str], +): + """Build TmapViz with character coloring and image tooltips.""" + viz = model.to_tmapviz() + viz.title = f"EMNIST — {len(labels):,} Handwritten Characters" + + chars = [char_names[l] for l in labels] + groups = [SHAPE_GROUP.get(c, "other") for c in chars] + char_type = ["digit" if l < 10 else "letter" for l in labels] + + viz.add_color_layout("type", char_type, categorical=True, color="Set1") + viz.add_color_layout("shape group", groups, categorical=True, color="tab10") + + viz.add_images(image_uris, tooltip_size=48) + viz.add_label("character", [f"'{c}'" for c in chars]) + + return viz + + +# Main + + +def main() -> None: + parser = argparse.ArgumentParser(description="EMNIST Character TMAP") + parser.add_argument("--max-images", type=int, default=20000) + parser.add_argument("--k", type=int, default=15) + parser.add_argument("--serve", action="store_true") + parser.add_argument("--port", type=int, default=8050) + args = parser.parse_args() + + # Load + images, labels, char_names = load_emnist( + args.max_images if args.max_images > 0 else None, + ) + + # Build TMAP (raw pixels + cosine, like MNIST example) + print(f"Building TMAP (metric='cosine', k={args.k})...") + t0 = time.time() + model = TMAP( + metric="cosine", + n_neighbors=args.k, + layout_iterations=1000, + seed=42, + ).fit(images) + elapsed = time.time() - t0 + print(f" Done in {elapsed:.1f}s") + + # Analysis + report = analyze_characters(model, labels, char_names) + report_path = OUTPUT_DIR / "emnist_report.txt" + report_path.write_text(report, encoding="utf-8") + print(f"\nReport saved to {report_path}") + print("\n" + report) + + # Visualization + print("Encoding character images...") + image_uris = encode_char_images(images) + + print("Building visualization...") + viz = create_visualization(model, labels, char_names, image_uris) + html_path = 
viz.write_html(OUTPUT_DIR / "emnist_tmap") + print(f"HTML saved to {html_path}") + + if args.serve: + print(f"Serving on http://127.0.0.1:{args.port}") + viz.serve(port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/flowers_tmap.py b/examples/flowers_tmap.py new file mode 100644 index 0000000..6e863ee --- /dev/null +++ b/examples/flowers_tmap.py @@ -0,0 +1,557 @@ +"""Oxford Flowers 102 — morphological gradients with TMAP. + +Build a TMAP of 8K flower images across 102 species. The tree reveals +visual morphological gradients: paths trace smooth transitions between +flower shapes, colors, and petal structures. + +Paths like sunflower -> daisy -> dandelion (yellow, radial) or +rose -> camellia -> magnolia show clear visual gradients where each +step is a tiny change in petal shape, color, or structure. + +Outputs +------- +examples/flowers_tmap.html Interactive TMAP with flower tooltips +examples/flowers_report.txt Morphological analysis report + +Data +---- +Downloads Oxford Flowers 102 via torchvision (~350 MB, cached). 
+ +Usage +----- + python examples/flowers_tmap.py + python examples/flowers_tmap.py --serve + python examples/flowers_tmap.py --device cuda + +Requirements +------------ + pip install torch torchvision +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from torchvision import datasets, models, transforms + +from tmap import TMAP +from tmap.graph.analysis import ( + boundary_edges, + confusion_matrix_from_tree, + subtree_purity, +) + +CACHE_DIR = Path(__file__).parent / "data" / "flowers_cache" +OUTPUT_DIR = Path(__file__).parent + +# 102 flower category names (0-indexed, matching torchvision labels) +FLOWER_NAMES = [ + "pink primrose", + "hard-leaved pocket orchid", + "canterbury bells", + "sweet pea", + "english marigold", + "tiger lily", + "moon orchid", + "bird of paradise", + "monkshood", + "globe thistle", + "snapdragon", + "colt's foot", + "king protea", + "spear thistle", + "yellow iris", + "globe-flower", + "purple coneflower", + "peruvian lily", + "balloon flower", + "giant white arum lily", + "fire lily", + "pincushion flower", + "fritillary", + "red ginger", + "grape hyacinth", + "corn poppy", + "prince of wales feathers", + "stemless gentian", + "artichoke", + "sweet william", + "carnation", + "garden phlox", + "love in the mist", + "mexican aster", + "alpine sea holly", + "ruby-lipped cattleya", + "cape flower", + "great masterwort", + "siam tulip", + "lenten rose", + "barbeton daisy", + "daffodil", + "sword lily", + "poinsettia", + "bolero deep blue", + "wallflower", + "marigold", + "buttercup", + "oxeye daisy", + "common dandelion", + "petunia", + "wild pansy", + "primula", + "sunflower", + "pelargonium", + "bishop of llandaff", + "gaura", + "geranium", + "orange dahlia", + "pink-yellow dahlia", + "cautleya spicata", + "japanese anemone", + "black-eyed susan", + "silverbush", + "californian 
poppy", + "osteospermum", + "spring crocus", + "bearded iris", + "windflower", + "tree poppy", + "gazania", + "azalea", + "water lily", + "rose", + "thorn apple", + "morning glory", + "passion flower", + "lotus", + "toad lily", + "anthurium", + "frangipani", + "clematis", + "hibiscus", + "columbine", + "desert-rose", + "tree mallow", + "magnolia", + "cyclamen", + "watercress", + "canna lily", + "hippeastrum", + "bee balm", + "ball moss", + "foxglove", + "bougainvillea", + "camellia", + "mallow", + "mexican petunia", + "bromelia", + "blanket flower", + "trumpet creeper", + "blackberry lily", +] + +# Group flowers by visual characteristics for supercategory analysis +FLOWER_GROUP = { + "sunflower": "daisy-like", + "oxeye daisy": "daisy-like", + "common dandelion": "daisy-like", + "barbeton daisy": "daisy-like", + "black-eyed susan": "daisy-like", + "gazania": "daisy-like", + "osteospermum": "daisy-like", + "blanket flower": "daisy-like", + "mexican aster": "daisy-like", + "colt's foot": "daisy-like", + "marigold": "daisy-like", + "english marigold": "daisy-like", + "rose": "rose-like", + "camellia": "rose-like", + "carnation": "rose-like", + "lenten rose": "rose-like", + "sweet william": "rose-like", + "garden phlox": "rose-like", + "azalea": "rose-like", + "pink primrose": "rose-like", + "pelargonium": "rose-like", + "geranium": "rose-like", + "mallow": "rose-like", + "tree mallow": "rose-like", + "hibiscus": "rose-like", + "desert-rose": "rose-like", + "moon orchid": "orchid-like", + "hard-leaved pocket orchid": "orchid-like", + "ruby-lipped cattleya": "orchid-like", + "siam tulip": "orchid-like", + "tiger lily": "lily-like", + "fire lily": "lily-like", + "giant white arum lily": "lily-like", + "canna lily": "lily-like", + "water lily": "lily-like", + "sword lily": "lily-like", + "toad lily": "lily-like", + "peruvian lily": "lily-like", + "blackberry lily": "lily-like", + "bearded iris": "iris-like", + "yellow iris": "iris-like", + "spring crocus": "iris-like", + 
"morning glory": "trumpet-shaped", + "trumpet creeper": "trumpet-shaped", + "petunia": "trumpet-shaped", + "foxglove": "trumpet-shaped", + "snapdragon": "trumpet-shaped", + "canterbury bells": "trumpet-shaped", + "mexican petunia": "trumpet-shaped", + "passion flower": "exotic", + "bird of paradise": "exotic", + "anthurium": "exotic", + "red ginger": "exotic", + "frangipani": "exotic", + "bougainvillea": "exotic", + "king protea": "exotic", + "bromelia": "exotic", + "cautleya spicata": "exotic", + "lotus": "aquatic", + "hippeastrum": "bulb", + "grape hyacinth": "bulb", + "daffodil": "bulb", + "buttercup": "simple", + "windflower": "simple", + "corn poppy": "simple", + "californian poppy": "simple", + "tree poppy": "simple", + "wild pansy": "simple", + "primula": "simple", + "sweet pea": "simple", + "love in the mist": "simple", + "balloon flower": "simple", + "columbine": "simple", + "magnolia": "tree flower", + "clematis": "vine", + "cyclamen": "bulb", + "poinsettia": "exotic", + "bee balm": "exotic", + "wallflower": "simple", + "artichoke": "thistle-like", + "globe thistle": "thistle-like", + "spear thistle": "thistle-like", + "alpine sea holly": "thistle-like", + "purple coneflower": "thistle-like", + "pincushion flower": "thistle-like", + "ball moss": "other", +} + + +def load_flowers() -> tuple[list[Image.Image], np.ndarray]: + """Load all splits of Flowers 102, return (images, labels).""" + data_dir = Path(__file__).parent / "data" / "flowers" + all_images: list[Image.Image] = [] + all_labels: list[int] = [] + + for split in ("train", "val", "test"): + ds = datasets.Flowers102( + root=str(data_dir), + split=split, + download=True, + ) + for img, label in ds: + if not isinstance(img, Image.Image): + img = Image.fromarray(np.array(img)) + all_images.append(img.convert("RGB")) + all_labels.append(label) + + labels = np.array(all_labels, dtype=np.int64) + print(f" {len(all_images)} images, {len(np.unique(labels))} species") + return all_images, labels + + +def 
def encode_images(images: list[Image.Image], size: int = 96) -> list[str]:
    """Encode images as base64 JPEG data URIs.

    Each image is resized to ``size`` x ``size`` (LANCZOS), JPEG-compressed
    at quality 80, and wrapped in a ``data:`` URI for HTML tooltips.
    """
    print(f"  Encoding {len(images)} images for tooltips ({size}x{size})...")

    def _to_uri(picture: Image.Image) -> str:
        # Resize first so the embedded tooltip payload stays small.
        thumb = picture.resize((size, size), Image.LANCZOS)
        payload = io.BytesIO()
        thumb.save(payload, format="JPEG", quality=80)
        encoded = base64.b64encode(payload.getvalue()).decode("ascii")
        return f"data:image/jpeg;base64,{encoded}"

    return [_to_uri(img) for img in images]
tree = model.tree_ + lines: list[str] = [] + w = lines.append + + species = np.array([FLOWER_NAMES[l] for l in labels]) + groups = np.array([FLOWER_GROUP.get(s, "other") for s in species]) + + n_species = len(np.unique(labels)) + w("Oxford Flowers 102 — Morphological Analysis") + w(f" {len(labels):,} images, {n_species} species\n") + + # 1. Morphological group boundaries + be_group = boundary_edges(tree, groups) + n_edges = len(tree.edges) + w("1. Morphological group boundaries:") + w( + f" Same-group edges: {n_edges - len(be_group)} / {n_edges} " + f"({(n_edges - len(be_group)) / n_edges:.1%})" + ) + + cmat_g, cls_g = confusion_matrix_from_tree(tree, groups) + np.fill_diagonal(cmat_g, 0) + upper = np.triu_indices_from(cmat_g, k=1) + pair_counts = cmat_g[upper] + cmat_g.T[upper] + top_idx = np.argsort(pair_counts)[::-1][:10] + w(" Most connected groups:") + for i in top_idx: + if pair_counts[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts[i]:4d} edges: {cls_g[r]:>15s} <-> {cls_g[c]}") + w("") + + # 2. Species-level boundaries + be_species = boundary_edges(tree, species) + w("2. Species boundaries:") + w( + f" Same-species edges: {n_edges - len(be_species)} / {n_edges} " + f"({(n_edges - len(be_species)) / n_edges:.1%})" + ) + + cmat_s, cls_s = confusion_matrix_from_tree(tree, species) + np.fill_diagonal(cmat_s, 0) + upper_s = np.triu_indices_from(cmat_s, k=1) + pair_counts_s = cmat_s[upper_s] + cmat_s.T[upper_s] + top_s = np.argsort(pair_counts_s)[::-1][:15] + w(" Most visually similar species pairs:") + for i in top_s: + if pair_counts_s[i] == 0: + break + r, c = upper_s[0][i], upper_s[1][i] + w(f" {pair_counts_s[i]:3d} edges: {cls_s[r]:>25s} <-> {cls_s[c]}") + w("") + + # 3. Subtree purity + purity_g = subtree_purity(tree, groups, min_size=10) + valid_g = purity_g[~np.isnan(purity_g)] + purity_s = subtree_purity(tree, species, min_size=10) + valid_s = purity_s[~np.isnan(purity_s)] + w("3. 
Subtree purity:") + w(f" By group: mean={valid_g.mean():.3f} median={np.median(valid_g):.3f}") + w(f" By species: mean={valid_s.mean():.3f} median={np.median(valid_s):.3f}\n") + + # 4. Morphological paths + w("4. Morphological paths (species to species along the tree):") + path_pairs = [ + ("sunflower", "oxeye daisy"), # both daisy-like, yellow + ("sunflower", "rose"), # very different morphology + ("rose", "camellia"), # visually similar + ("water lily", "lotus"), # aquatic flowers + ("bearded iris", "moon orchid"), # complex petals + ("corn poppy", "californian poppy"), # both poppies + ("trumpet creeper", "morning glory"), # both trumpet-shaped + ("sunflower", "passion flower"), # radial vs complex + ("king protea", "artichoke"), # both spiky/thistle-like + ] + for sp_a, sp_b in path_pairs: + idx_a = np.where(species == sp_a)[0] + idx_b = np.where(species == sp_b)[0] + if len(idx_a) == 0 or len(idx_b) == 0: + w(f" {sp_a} -> {sp_b}: (species not found)") + continue + + node_a, node_b = int(idx_a[0]), int(idx_b[0]) + try: + path_nodes = tree.path(node_a, node_b) + except IndexError: + w(f" {sp_a:>25s} -> {sp_b:<25s} (disconnected)") + continue + + path_species = species[path_nodes] + unique_species = [] + for s in path_species: + if not unique_species or unique_species[-1] != s: + unique_species.append(s) + + w( + f" {sp_a:>25s} -> {sp_b:<25s} " + f"hops={len(path_nodes):3d} species crossed={len(set(path_species))}" + ) + # Show route through species + if len(unique_species) <= 8: + w(f" Route: {' -> '.join(unique_species)}") + else: + route = unique_species[:4] + ["..."] + unique_species[-3:] + w(f" Route: {' -> '.join(route)}") + w("") + + # 5. Per-group coherence + w("5. 
def create_visualization(
    model: TMAP,
    labels: np.ndarray,
    image_uris: list[str],
):
    """Build a TmapViz colored by morphological group with flower tooltips."""
    viz = model.to_tmapviz()
    viz.title = f"Oxford Flowers 102 — {len(labels):,} Images"

    species_names = [FLOWER_NAMES[idx] for idx in labels]
    group_names = [FLOWER_GROUP.get(name, "other") for name in species_names]

    # Keep the original call order: label layer, color layout, then images.
    viz.add_label("species", species_names)
    viz.add_color_layout("group", group_names, categorical=True, color="tab20")

    viz.add_images(image_uris, tooltip_size=100)
    # NOTE(review): "flower" duplicates the "species" label values — presumably
    # intentional for the tooltip panel; confirm against TmapViz.add_label.
    viz.add_label("flower", species_names)

    return viz
extract_embeddings(images, args.batch_size, device) + + # Build TMAP + print(f"Building TMAP (metric='cosine', k={args.k})...") + t0 = time.time() + model = TMAP( + metric="cosine", + n_neighbors=args.k, + layout_iterations=1000, + seed=42, + ).fit(embeddings.astype(np.float32)) + elapsed = time.time() - t0 + print(f" Done in {elapsed:.1f}s") + + # Analysis + report = analyze_flowers(model, labels) + report_path = OUTPUT_DIR / "flowers_report.txt" + report_path.write_text(report, encoding="utf-8") + print(f"\nReport saved to {report_path}") + print("\n" + report) + + # Visualization + print("Encoding images for tooltips...") + image_uris = encode_images(images) + + print("Building visualization...") + viz = create_visualization(model, labels, image_uris) + html_path = viz.write_html(OUTPUT_DIR / "flowers_tmap") + print(f"HTML saved to {html_path}") + + if args.serve: + print(f"Serving on http://127.0.0.1:{args.port}") + viz.serve(port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/wikiart_tmap.py b/examples/wikiart_tmap.py new file mode 100644 index 0000000..115b3e5 --- /dev/null +++ b/examples/wikiart_tmap.py @@ -0,0 +1,452 @@ +"""WikiArt — art history as a navigable tree. + +Build a TMAP of paintings from WikiArt, colored by artistic style. +Trace paths between art movements to see how painting styles evolve +through visual similarity: from Impressionism through Post-Impressionism +to Cubism and Abstract — each step a tiny visual shift, the endpoints +radically different. + +Outputs +------- +examples/wikiart_tmap.html Interactive TMAP with painting tooltips +examples/wikiart_report.txt Art style analysis report + +Data +---- +Downloads WikiArt dataset from HuggingFace (~6 GB, cached after first run). 
def load_wikiart(max_images: int | None) -> tuple:
    """Load WikiArt from HuggingFace.

    Parameters
    ----------
    max_images : int | None
        If set, randomly subsample the dataset down to this many paintings
        (seed 42 for reproducibility).

    Returns
    -------
    tuple
        ``(ds, styles, artists, style_names, artist_names)`` where ``ds`` is
        the HuggingFace dataset, ``styles``/``artists`` are integer label
        arrays, and the ``*_names`` lists map those integers to readable
        strings.  (The previous docstring wrongly claimed the first element
        was ``images``.)

    Raises
    ------
    ImportError
        If the ``datasets`` package is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError as err:
        # Chain the original error so the missing-module traceback survives.
        raise ImportError(
            "HuggingFace datasets library is required.\nInstall with: pip install datasets"
        ) from err

    print("Loading WikiArt from HuggingFace...")
    ds = load_dataset("huggan/wikiart", split="train")
    style_names = ds.features["style"].names
    artist_names = ds.features["artist"].names
    print(f"  {len(ds):,} paintings, {len(style_names)} styles, {len(artist_names)} artists")

    if max_images and len(ds) > max_images:
        print(f"  Subsampling to {max_images:,} images...")
        ds = ds.shuffle(seed=42).select(range(max_images))

    # Integer label columns; decoded to names later via style_names/artist_names.
    styles = np.array(ds["style"])
    artists = np.array(ds["artist"])

    return ds, styles, artists, style_names, artist_names
def encode_images(ds, size: int = 80, quality: int = 70) -> list[str]:
    """Encode dataset images as base64 JPEG data URIs for tooltips.

    Each painting is normalized to an RGB PIL image, downscaled to
    ``size`` x ``size`` with LANCZOS, and JPEG-compressed at ``quality``.
    """
    print(f"  Encoding {len(ds):,} images for tooltips ({size}x{size})...")
    data_uris: list[str] = []
    for row_index in range(len(ds)):
        picture = ds[row_index]["image"]
        # Normalize non-PIL payloads (e.g. arrays) to PIL before resizing.
        if not isinstance(picture, Image.Image):
            picture = Image.fromarray(np.array(picture))
        thumb = picture.convert("RGB").resize((size, size), Image.LANCZOS)
        jpeg_buf = io.BytesIO()
        thumb.save(jpeg_buf, format="JPEG", quality=quality)
        encoded = base64.b64encode(jpeg_buf.getvalue()).decode("ascii")
        data_uris.append(f"data:image/jpeg;base64,{encoded}")
    return data_uris
Art era boundaries:") + w( + f" Same-era edges: {n_edges - len(be_era)} / {n_edges} " + f"({(n_edges - len(be_era)) / n_edges:.1%})" + ) + + cmat_era, cls_era = confusion_matrix_from_tree(tree, eras) + np.fill_diagonal(cmat_era, 0) + upper = np.triu_indices_from(cmat_era, k=1) + pair_counts_era = cmat_era[upper] + cmat_era.T[upper] + top_era = np.argsort(pair_counts_era)[::-1][:10] + w(" Most connected eras:") + for i in top_era: + if pair_counts_era[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts_era[i]:4d} edges: {cls_era[r]:>15s} <-> {cls_era[c]}") + w("") + + # 2. Style boundaries + be_style = boundary_edges(tree, style_labels) + w("2. Style boundaries:") + w( + f" Same-style edges: {n_edges - len(be_style)} / {n_edges} " + f"({(n_edges - len(be_style)) / n_edges:.1%})" + ) + + cmat, classes = confusion_matrix_from_tree(tree, style_labels) + np.fill_diagonal(cmat, 0) + upper = np.triu_indices_from(cmat, k=1) + pair_counts = cmat[upper] + cmat.T[upper] + top_idx = np.argsort(pair_counts)[::-1][:15] + w(" Most connected style pairs:") + for i in top_idx: + if pair_counts[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts[i]:4d} edges: {classes[r]:>30s} <-> {classes[c]}") + w("") + + # 3. Subtree purity + purity_style = subtree_purity(tree, style_labels, min_size=20) + valid = purity_style[~np.isnan(purity_style)] + purity_era = subtree_purity(tree, eras, min_size=20) + valid_era = purity_era[~np.isnan(purity_era)] + w("3. Subtree purity:") + w(f" By era: mean={valid_era.mean():.3f} median={np.median(valid_era):.3f}") + w(f" By style: mean={valid.mean():.3f} median={np.median(valid):.3f}\n") + + # 4. Art historical paths + w("4. 
Art historical paths:") + path_pairs = [ + ("Impressionism", "Cubism"), + ("Impressionism", "Abstract_Expressionism"), + ("High_Renaissance", "Impressionism"), + ("Baroque", "Pop_Art"), + ("Realism", "Minimalism"), + ("Romanticism", "Expressionism"), + ("Ukiyo_e", "Impressionism"), + ] + for style_a, style_b in path_pairs: + idx_a = np.where(style_labels == style_a)[0] + idx_b = np.where(style_labels == style_b)[0] + if len(idx_a) == 0 or len(idx_b) == 0: + w(f" {style_a} -> {style_b}: (style not found in subsample)") + continue + + node_a, node_b = int(idx_a[0]), int(idx_b[0]) + try: + path_nodes = tree.path(node_a, node_b) + except IndexError: + w(f" {style_a:>30s} -> {style_b:<30s} (disconnected)") + continue + + path_styles = style_labels[path_nodes] + unique_styles = [] + for s in path_styles: + if not unique_styles or unique_styles[-1] != s: + unique_styles.append(s) + + w( + f" {style_a:>30s} -> {style_b:<30s} " + f"hops={len(path_nodes):4d} styles crossed={len(set(path_styles))}" + ) + w(f" Route: {' -> '.join(unique_styles[:8])}{'...' if len(unique_styles) > 8 else ''}") + w("") + + # 5. Per-style coherence + w("5. Per-style tree coherence:") + w(f" {'Style':>30s} {'Count':>6s} {'Boundary %':>10s}") + style_counts = {} + for s in style_labels: + style_counts[s] = style_counts.get(s, 0) + 1 + + for style in sorted(style_counts, key=style_counts.get, reverse=True): + mask = style_labels == style + style_idx = set(np.where(mask)[0]) + boundary = 0 + internal = 0 + for s, t in tree.edges: + s_in = s in style_idx + t_in = t in style_idx + if s_in and t_in: + internal += 1 + elif s_in or t_in: + boundary += 1 + total = boundary + internal + bfrac = boundary / total if total > 0 else 0.0 + w(f" {style:>30s} {style_counts[style]:6d} {bfrac:10.1%}") + + return "\n".join(lines) + + +# 5. 
def create_visualization(
    model: TMAP,
    styles: np.ndarray,
    artists: np.ndarray,
    style_names: list[str],
    artist_names: list[str],
    image_uris: list[str],
):
    """Build TmapViz with style/era/artist coloring and painting tooltips.

    Parameters
    ----------
    model : TMAP
        Fitted TMAP model (provides the layout via ``to_tmapviz``).
    styles, artists : np.ndarray
        Integer label arrays, decoded through ``style_names``/``artist_names``.
    image_uris : list[str]
        Base64 data URIs shown as hover tooltips.
    """
    # Local import keeps this edit self-contained within the function.
    from collections import Counter

    viz = model.to_tmapviz()
    viz.title = f"WikiArt — {len(styles):,} Paintings"

    style_labels = [style_names[s] for s in styles]
    eras = [STYLE_ERA.get(s, "Other") for s in style_labels]

    # Style and era coloring
    viz.add_color_layout("style", style_labels, categorical=True, color="tab20")
    viz.add_color_layout("era", eras, categorical=True, color="Set2")

    # Top artists only (too many to show all).  Counter replaces the manual
    # tally; sorted() is stable, so count ties keep first-seen order exactly
    # as the hand-rolled dict version did.
    artist_labels = [artist_names[a] for a in artists]
    artist_counts = Counter(artist_labels)
    top_artists = {a for a, _ in sorted(artist_counts.items(), key=lambda kv: -kv[1])[:30]}
    artist_display = [a if a in top_artists else "Other" for a in artist_labels]
    viz.add_color_layout("artist (top 30)", artist_display, categorical=True, color="tab20")

    # Tooltips
    viz.add_images(image_uris, tooltip_size=100)
    viz.add_label(
        "painting", [f"{style_names[s]} | {artist_names[a]}" for s, a in zip(styles, artists)]
    )

    return viz
else None + print(f"Device: {device}") + + # Load data + ds, styles, artists, style_names, artist_names = load_wikiart(max_img) + n = len(ds) + + # Extract embeddings + cache_tag = f"resnet50_{n}" + print("Extracting ResNet-50 embeddings...") + embeddings = extract_embeddings(ds, args.batch_size, device, cache_tag) + + # Build TMAP + print(f"Building TMAP (metric='cosine', k={args.k})...") + t0 = time.time() + model = TMAP( + metric="cosine", + n_neighbors=args.k, + layout_iterations=1000, + seed=42, + ).fit(embeddings.astype(np.float32)) + elapsed = time.time() - t0 + print(f" Done in {elapsed:.1f}s") + + # Analysis + report = analyze_styles(model, styles, artists, style_names, artist_names) + report_path = OUTPUT_DIR / "wikiart_report.txt" + report_path.write_text(report, encoding="utf-8") + print(f"\nReport saved to {report_path}") + print("\n" + report) + + # Visualization + print("Encoding images for tooltips...") + image_uris = encode_images(ds, size=80, quality=70) + + print("Building visualization...") + viz = create_visualization( + model, + styles, + artists, + style_names, + artist_names, + image_uris, + ) + html_path = viz.write_html(OUTPUT_DIR / "wikiart_tmap") + print(f"HTML saved to {html_path}") + + if args.serve: + print(f"Serving on http://127.0.0.1:{args.port}") + viz.serve(port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/word_embeddings_tmap.py b/examples/word_embeddings_tmap.py new file mode 100644 index 0000000..94da700 --- /dev/null +++ b/examples/word_embeddings_tmap.py @@ -0,0 +1,1488 @@ +"""Word embedding TMAP — semantic map of English words. + +Embed ~800 common English words with a sentence-transformer, build a TMAP +with cosine metric, and explore semantic neighborhoods: animals cluster near +animals, countries near countries, and the tree paths between them reveal +thematic transitions (e.g., "salmon → trout → bass → guitar"). 
+ +Outputs +------- +examples/word_embeddings_tmap.html Interactive TMAP — hover to read words +examples/word_embeddings_report.txt Semantic analysis report + +Usage +----- + python examples/word_embeddings_tmap.py + python examples/word_embeddings_tmap.py --serve + +Requirements +------------ + pip install sentence-transformers +""" + +from __future__ import annotations + +import argparse +import time +from pathlib import Path + +import numpy as np + +from tmap import TMAP +from tmap.graph.analysis import ( + boundary_edges, + confusion_matrix_from_tree, + subtree_purity, +) +from tmap.visualization import TmapViz + +OUTPUT_DIR = Path(__file__).parent + +# Vocabulary: ~800 words across 15 semantic categories + +VOCABULARY: dict[str, list[str]] = { + "animals": [ + "dog", + "cat", + "horse", + "cow", + "pig", + "sheep", + "goat", + "chicken", + "duck", + "rabbit", + "mouse", + "rat", + "hamster", + "wolf", + "fox", + "bear", + "deer", + "moose", + "elk", + "lion", + "tiger", + "leopard", + "cheetah", + "elephant", + "giraffe", + "zebra", + "hippo", + "rhino", + "gorilla", + "monkey", + "chimpanzee", + "orangutan", + "baboon", + "lemur", + "eagle", + "hawk", + "falcon", + "owl", + "parrot", + "penguin", + "flamingo", + "pelican", + "crow", + "sparrow", + "robin", + "pigeon", + "swan", + "heron", + "dolphin", + "whale", + "shark", + "salmon", + "trout", + "tuna", + "octopus", + "crab", + "lobster", + "jellyfish", + "starfish", + "seahorse", + "eel", + "butterfly", + "bee", + "ant", + "spider", + "snake", + "scorpion", + "beetle", + "dragonfly", + "mosquito", + "cockroach", + "grasshopper", + "caterpillar", + "turtle", + "frog", + "crocodile", + "lizard", + "panther", + "jaguar", + "hyena", + "buffalo", + "camel", + "llama", + "alpaca", + "donkey", + "mule", + "otter", + "beaver", + "badger", + "raccoon", + "skunk", + "porcupine", + "hedgehog", + "squirrel", + "chipmunk", + "weasel", + "ferret", + "mink", + "koala", + "kangaroo", + "platypus", + "armadillo", + 
"sloth", + ], + "food": [ + "bread", + "rice", + "pasta", + "noodle", + "pizza", + "burger", + "sandwich", + "salad", + "soup", + "stew", + "steak", + "pork", + "beef", + "lamb", + "fish", + "shrimp", + "sushi", + "cheese", + "butter", + "milk", + "yogurt", + "egg", + "bacon", + "sausage", + "apple", + "banana", + "orange", + "grape", + "strawberry", + "blueberry", + "mango", + "pineapple", + "watermelon", + "peach", + "cherry", + "tomato", + "potato", + "onion", + "garlic", + "pepper", + "carrot", + "broccoli", + "spinach", + "corn", + "mushroom", + "chocolate", + "cake", + "cookie", + "pie", + "donut", + "candy", + "honey", + "sugar", + "salt", + "cinnamon", + "vanilla", + "ginger", + "mustard", + "ketchup", + "vinegar", + "olive", + "avocado", + "coconut", + "almond", + "walnut", + "peanut", + "cashew", + "pistachio", + "hazelnut", + "lemon", + "lime", + "grapefruit", + "cranberry", + "raspberry", + "blackberry", + "plum", + "apricot", + "fig", + "pomegranate", + "papaya", + "guava", + "lychee", + "kiwi", + "melon", + "celery", + "cucumber", + "zucchini", + "eggplant", + "pumpkin", + "squash", + "radish", + "turnip", + "beet", + "cabbage", + "lettuce", + "kale", + "cauliflower", + "asparagus", + "artichoke", + "leek", + "parsley", + "basil", + "oregano", + "thyme", + "rosemary", + "mint", + "cilantro", + "dill", + "cumin", + "turmeric", + "paprika", + "saffron", + "nutmeg", + "clove", + "waffle", + "pancake", + "croissant", + "bagel", + "muffin", + "pretzel", + "tortilla", + "dumpling", + "ravioli", + "lasagna", + "risotto", + "paella", + "curry", + "hummus", + "guacamole", + "tahini", + "tofu", + "tempeh", + ], + "professions": [ + "doctor", + "nurse", + "surgeon", + "dentist", + "pharmacist", + "therapist", + "psychiatrist", + "pediatrician", + "cardiologist", + "dermatologist", + "radiologist", + "anesthesiologist", + "paramedic", + "midwife", + "optometrist", + "teacher", + "professor", + "researcher", + "scientist", + "engineer", + "architect", + 
"programmer", + "designer", + "artist", + "musician", + "singer", + "actor", + "director", + "writer", + "journalist", + "editor", + "photographer", + "lawyer", + "judge", + "detective", + "police", + "soldier", + "pilot", + "astronaut", + "farmer", + "carpenter", + "plumber", + "electrician", + "mechanic", + "chef", + "waiter", + "bartender", + "accountant", + "banker", + "economist", + "politician", + "diplomat", + "librarian", + "firefighter", + "biologist", + "chemist", + "physicist", + "mathematician", + "geologist", + "astronomer", + "archaeologist", + "anthropologist", + "psychologist", + "sociologist", + "historian", + "philosopher", + "linguist", + "translator", + "curator", + "sculptor", + "painter", + "illustrator", + "animator", + "choreographer", + "dancer", + "comedian", + "magician", + "acrobat", + "veterinarian", + "botanist", + "zoologist", + "ecologist", + "geneticist", + "neuroscientist", + "epidemiologist", + "pathologist", + "toxicologist", + ], + "sports": [ + "football", + "soccer", + "basketball", + "baseball", + "tennis", + "golf", + "swimming", + "running", + "cycling", + "boxing", + "wrestling", + "judo", + "karate", + "fencing", + "archery", + "skiing", + "snowboarding", + "surfing", + "sailing", + "rowing", + "volleyball", + "hockey", + "rugby", + "cricket", + "badminton", + "marathon", + "sprint", + "javelin", + "gymnastics", + "diving", + "weightlifting", + "climbing", + "skateboarding", + "rollerskating", + "triathlon", + "pentathlon", + "polo", + "lacrosse", + "squash", + "handball", + "curling", + "bobsled", + "luge", + "biathlon", + "decathlon", + "hurdles", + "discus", + "shotput", + "hammerthrow", + "highjump", + "longjump", + "polevault", + "steeplechase", + "waterpolo", + "canoeing", + "kayaking", + "windsurfing", + "kitesurfing", + "paragliding", + "skydiving", + "bungee", + "rafting", + "snorkeling", + "scuba", + ], + "music": [ + "guitar", + "piano", + "violin", + "drums", + "trumpet", + "saxophone", + "flute", + 
"clarinet", + "cello", + "harp", + "harmonica", + "accordion", + "banjo", + "ukulele", + "orchestra", + "symphony", + "concert", + "opera", + "jazz", + "blues", + "rock", + "pop", + "classical", + "hip-hop", + "reggae", + "country", + "folk", + "melody", + "rhythm", + "harmony", + "choir", + "solo", + "trombone", + "oboe", + "bassoon", + "tuba", + "mandolin", + "sitar", + "didgeridoo", + "bagpipes", + "xylophone", + "marimba", + "timpani", + "tambourine", + "bongo", + "tabla", + "synthesizer", + "organ", + "soprano", + "tenor", + "baritone", + "alto", + "bass", + "sonata", + "concerto", + "fugue", + "prelude", + "overture", + "aria", + "ballad", + "anthem", + "lullaby", + "serenade", + "requiem", + "funk", + "soul", + "gospel", + "swing", + "bossa nova", + "salsa", + "techno", + "disco", + "grunge", + "punk", + "metal", + "ska", + ], + "countries": [ + "France", + "Germany", + "Italy", + "Spain", + "Portugal", + "England", + "Scotland", + "Ireland", + "Sweden", + "Norway", + "Denmark", + "Finland", + "Poland", + "Russia", + "Ukraine", + "Greece", + "Turkey", + "Egypt", + "Morocco", + "Nigeria", + "Kenya", + "Ethiopia", + "China", + "Japan", + "Korea", + "India", + "Thailand", + "Vietnam", + "Indonesia", + "Australia", + "Brazil", + "Argentina", + "Mexico", + "Canada", + "Cuba", + "Peru", + "Colombia", + "Chile", + "Iran", + "Iraq", + "Israel", + "Pakistan", + "Netherlands", + "Belgium", + "Switzerland", + "Austria", + "Hungary", + "Romania", + "Bulgaria", + "Serbia", + "Croatia", + "Slovenia", + "Czech Republic", + "Slovakia", + "Lithuania", + "Latvia", + "Estonia", + "Philippines", + "Malaysia", + "Singapore", + "Myanmar", + "Cambodia", + "Mongolia", + "Nepal", + "Bangladesh", + "Sri Lanka", + "Afghanistan", + "Saudi Arabia", + "Yemen", + "Oman", + "Qatar", + "Kuwait", + "Jordan", + "Lebanon", + "Syria", + "Libya", + "Tunisia", + "Algeria", + "Sudan", + "Ghana", + "Senegal", + "Tanzania", + "Uganda", + "Mozambique", + "Zimbabwe", + "Botswana", + "Namibia", + 
"Madagascar", + "Angola", + "Congo", + "Venezuela", + "Ecuador", + "Bolivia", + "Paraguay", + "Uruguay", + "Panama", + "Costa Rica", + "Guatemala", + "Honduras", + "Jamaica", + "Haiti", + "Dominican Republic", + "Trinidad", + "New Zealand", + "Iceland", + ], + "cities": [ + "Paris", + "London", + "Berlin", + "Rome", + "Madrid", + "Barcelona", + "Amsterdam", + "Vienna", + "Prague", + "Budapest", + "Warsaw", + "Moscow", + "Istanbul", + "Athens", + "Lisbon", + "Dublin", + "Edinburgh", + "Copenhagen", + "Stockholm", + "Oslo", + "Helsinki", + "Brussels", + "Zurich", + "Geneva", + "Tokyo", + "Beijing", + "Shanghai", + "Seoul", + "Bangkok", + "Singapore", + "Mumbai", + "Delhi", + "Dubai", + "Cairo", + "Lagos", + "Nairobi", + "New York", + "Los Angeles", + "Chicago", + "San Francisco", + "Miami", + "Toronto", + "Vancouver", + "Montreal", + "Mexico City", + "Havana", + "Buenos Aires", + "Rio de Janeiro", + "Lima", + "Bogota", + "Santiago", + "Sydney", + "Melbourne", + "Auckland", + "Johannesburg", + "Cape Town", + ], + "colors": [ + "red", + "blue", + "green", + "yellow", + "orange", + "purple", + "pink", + "brown", + "black", + "white", + "gray", + "silver", + "gold", + "crimson", + "scarlet", + "turquoise", + "cyan", + "magenta", + "violet", + "indigo", + "maroon", + "beige", + "ivory", + "coral", + "amber", + "teal", + "navy", + "olive", + "khaki", + "lavender", + "lilac", + "peach", + "rust", + "bronze", + "copper", + "platinum", + "emerald", + "sapphire", + "ruby", + "jade", + "pearl", + "charcoal", + "cream", + "tan", + ], + "emotions": [ + "happy", + "sad", + "angry", + "afraid", + "surprised", + "disgusted", + "anxious", + "nervous", + "excited", + "calm", + "peaceful", + "lonely", + "jealous", + "proud", + "ashamed", + "guilty", + "grateful", + "hopeful", + "frustrated", + "confused", + "bored", + "curious", + "nostalgic", + "melancholy", + "furious", + "terrified", + "delighted", + "amused", + "content", + "miserable", + "cheerful", + "gloomy", + "ecstatic", + 
"euphoric", + "blissful", + "serene", + "tranquil", + "restless", + "agitated", + "irritated", + "resentful", + "bitter", + "envious", + "compassionate", + "sympathetic", + "empathetic", + "tender", + "enthusiastic", + "passionate", + "indifferent", + "apathetic", + "numb", + "overwhelmed", + "relieved", + "skeptical", + "suspicious", + "paranoid", + "vulnerable", + "insecure", + "confident", + "determined", + "resolute", + "bewildered", + "perplexed", + "astonished", + "stunned", + "dumbfounded", + ], + "body": [ + "head", + "face", + "eye", + "ear", + "nose", + "mouth", + "tongue", + "tooth", + "lip", + "chin", + "forehead", + "neck", + "shoulder", + "arm", + "elbow", + "wrist", + "hand", + "finger", + "thumb", + "chest", + "stomach", + "back", + "hip", + "leg", + "knee", + "ankle", + "foot", + "toe", + "skin", + "bone", + "muscle", + "brain", + "heart", + "lung", + "liver", + "kidney", + "skull", + "jaw", + "cheek", + "eyebrow", + "eyelash", + "nostril", + "throat", + "collarbone", + "ribcage", + "spine", + "pelvis", + "thigh", + "calf", + "shin", + "heel", + "palm", + "knuckle", + "fingernail", + "tendon", + "ligament", + "cartilage", + "artery", + "vein", + "nerve", + "pancreas", + "spleen", + "bladder", + "intestine", + "esophagus", + "colon", + ], + "vehicles": [ + "car", + "truck", + "bus", + "van", + "motorcycle", + "bicycle", + "scooter", + "train", + "subway", + "tram", + "airplane", + "helicopter", + "jet", + "rocket", + "boat", + "ship", + "yacht", + "canoe", + "kayak", + "ferry", + "ambulance", + "taxi", + "limousine", + "tractor", + "tank", + "sedan", + "convertible", + "minivan", + "pickup", + "jeep", + "hatchback", + "trolley", + "monorail", + "gondola", + "catamaran", + "sailboat", + "submarine", + "cruiser", + "freighter", + "tanker", + "barge", + "glider", + "biplane", + "seaplane", + "hovercraft", + "segway", + "rickshaw", + "chariot", + "carriage", + "sleigh", + "sled", + ], + "weather": [ + "rain", + "snow", + "wind", + "storm", + "thunder", 
+ "lightning", + "fog", + "cloud", + "sunshine", + "rainbow", + "hail", + "frost", + "ice", + "breeze", + "hurricane", + "tornado", + "drought", + "flood", + "blizzard", + "mist", + "dew", + "sleet", + "heat", + "cold", + "monsoon", + "cyclone", + "typhoon", + "drizzle", + "downpour", + "squall", + "gust", + "gale", + "whirlwind", + "avalanche", + "mudslide", + "tsunami", + "humidity", + "overcast", + "haze", + "smog", + ], + "clothing": [ + "shirt", + "pants", + "dress", + "skirt", + "jacket", + "coat", + "sweater", + "hoodie", + "jeans", + "shorts", + "suit", + "tie", + "scarf", + "gloves", + "hat", + "cap", + "boots", + "shoes", + "sandals", + "socks", + "belt", + "uniform", + "pajamas", + "vest", + "raincoat", + "blouse", + "cardigan", + "tuxedo", + "gown", + "robe", + "kimono", + "poncho", + "parka", + "blazer", + "overalls", + "leggings", + "stockings", + "apron", + "bikini", + "swimsuit", + "wetsuit", + "sneakers", + "heels", + "loafers", + "slippers", + "mittens", + "earmuffs", + "beret", + "turban", + "tiara", + "crown", + "veil", + "cape", + "cloak", + "shawl", + ], + "furniture": [ + "chair", + "table", + "desk", + "sofa", + "couch", + "bed", + "mattress", + "pillow", + "blanket", + "shelf", + "bookcase", + "cabinet", + "drawer", + "wardrobe", + "mirror", + "lamp", + "carpet", + "curtain", + "stool", + "bench", + "hammock", + "nightstand", + "dresser", + "armoire", + "ottoman", + "recliner", + "futon", + "crib", + "bunkbed", + "headboard", + "footrest", + "chandelier", + "candelabra", + "vase", + "rug", + "tapestry", + "screen", + "partition", + "credenza", + "hutch", + "sideboard", + ], + "tools": [ + "hammer", + "screwdriver", + "wrench", + "pliers", + "drill", + "saw", + "axe", + "shovel", + "rake", + "scissors", + "knife", + "needle", + "tape", + "glue", + "nail", + "screw", + "bolt", + "wire", + "ladder", + "rope", + "chisel", + "mallet", + "crowbar", + "clamp", + "vise", + "level", + "compass", + "ruler", + "protractor", + "caliper", + "sandpaper", 
+ "soldering iron", + "blowtorch", + "jackhammer", + "chainsaw", + "grinder", + "wheelbarrow", + "hose", + "trowel", + "pickaxe", + "sickle", + ], + "nature": [ + "mountain", + "river", + "lake", + "ocean", + "sea", + "forest", + "jungle", + "desert", + "valley", + "canyon", + "waterfall", + "volcano", + "island", + "beach", + "cliff", + "cave", + "meadow", + "prairie", + "glacier", + "reef", + "swamp", + "marsh", + "hill", + "plateau", + "peninsula", + "archipelago", + "lagoon", + "fjord", + "delta", + "estuary", + "ravine", + "gorge", + "ridge", + "summit", + "tundra", + "savanna", + "steppe", + "oasis", + "geyser", + "spring", + "brook", + "creek", + "pond", + "bay", + "cove", + "strait", + "cape", + "dune", + ], + "science": [ + "atom", + "molecule", + "electron", + "proton", + "neutron", + "photon", + "gravity", + "magnetism", + "electricity", + "radiation", + "frequency", + "wavelength", + "spectrum", + "velocity", + "acceleration", + "momentum", + "entropy", + "energy", + "mass", + "force", + "pressure", + "temperature", + "density", + "viscosity", + "conductivity", + "resistance", + "voltage", + "cell", + "gene", + "DNA", + "RNA", + "protein", + "enzyme", + "chromosome", + "nucleus", + "mitosis", + "mutation", + "evolution", + "species", + "genome", + "bacteria", + "virus", + "fungus", + "parasite", + "antibody", + "vaccine", + "microscope", + "telescope", + "thermometer", + "barometer", + "centrifuge", + "hypothesis", + "theory", + "experiment", + "observation", + "equation", + ], + "technology": [ + "computer", + "laptop", + "tablet", + "smartphone", + "keyboard", + "monitor", + "printer", + "scanner", + "router", + "modem", + "server", + "database", + "algorithm", + "software", + "hardware", + "processor", + "memory", + "storage", + "internet", + "website", + "browser", + "email", + "download", + "upload", + "pixel", + "resolution", + "bandwidth", + "encryption", + "firewall", + "bluetooth", + "wifi", + "satellite", + "antenna", + "cable", + "fiber", + 
"robot", + "drone", + "sensor", + "microchip", + "transistor", + "circuit", + "battery", + "solar panel", + "turbine", + "generator", + "laser", + "hologram", + "virtual reality", + "augmented reality", + "blockchain", + "artificial intelligence", + "machine learning", + "neural network", + ], + "architecture": [ + "house", + "apartment", + "castle", + "palace", + "cathedral", + "church", + "mosque", + "temple", + "synagogue", + "monastery", + "pagoda", + "shrine", + "skyscraper", + "tower", + "bridge", + "tunnel", + "dam", + "aqueduct", + "pyramid", + "colosseum", + "amphitheater", + "stadium", + "arena", + "lighthouse", + "windmill", + "barn", + "cottage", + "villa", + "mansion", + "bungalow", + "cabin", + "igloo", + "tent", + "yurt", + "hut", + "fortress", + "citadel", + "bunker", + "barracks", + "garrison", + "library", + "museum", + "gallery", + "theater", + "cinema", + "hospital", + "school", + "university", + "laboratory", + "observatory", + "greenhouse", + ], +} + + +def build_vocabulary() -> tuple[list[str], list[str]]: + """Return (words, categories) from the vocabulary dict.""" + words: list[str] = [] + categories: list[str] = [] + for category, word_list in VOCABULARY.items(): + for word in word_list: + words.append(word) + categories.append(category) + print(f"Vocabulary: {len(words)} words across {len(VOCABULARY)} categories") + return words, categories + + +def compute_embeddings(words: list[str], model_name: str) -> np.ndarray: + """Embed words with sentence-transformers.""" + from sentence_transformers import SentenceTransformer + + print(f"Loading model: {model_name}") + model = SentenceTransformer(model_name) + print(f"Encoding {len(words)} words...") + t0 = time.time() + embeddings = model.encode(words, normalize_embeddings=True, show_progress_bar=False) + print(f" Done in {time.time() - t0:.1f}s — shape: {embeddings.shape}") + return embeddings + + +def build_tmap(embeddings: np.ndarray, k: int) -> TMAP: + """Build TMAP with cosine metric.""" + 
print(f"Fitting TMAP (metric='cosine', k={k})...") + t0 = time.time() + model = TMAP( + metric="cosine", + n_neighbors=k, + layout_iterations=1000, + seed=42, + ).fit(embeddings.astype(np.float32)) + print(f" Done in {time.time() - t0:.1f}s") + return model + + +def analyze( + model: TMAP, + words: list[str], + categories: list[str], +) -> str: + """Generate semantic analysis report.""" + tree = model.tree_ + cat_arr = np.array(categories) + lines: list[str] = [] + w = lines.append + + w(f"Word Embedding TMAP — {len(words)} words, {len(set(categories))} categories\n") + + # 1. Category clustering + be = boundary_edges(tree, cat_arr) + w("1. Category boundaries:") + w( + f" Same-category edges: {len(tree.edges) - len(be)} / {len(tree.edges)} " + f"({1 - len(be) / len(tree.edges):.1%})" + ) + w(f" Cross-category edges: {len(be)} ({len(be) / len(tree.edges):.1%})\n") + + # 2. Subtree purity + purity = subtree_purity(tree, cat_arr, min_size=5) + valid = purity[~np.isnan(purity)] + w(f"2. Subtree purity: mean={valid.mean():.3f} median={np.median(valid):.3f}\n") + + # 3. Most connected category pairs + cmat, classes = confusion_matrix_from_tree(tree, cat_arr) + np.fill_diagonal(cmat, 0) + upper = np.triu_indices_from(cmat, k=1) + pair_counts = cmat[upper] + top_idx = np.argsort(pair_counts)[::-1][:10] + w("3. Most connected category pairs (cross-category tree edges):") + for i in top_idx: + if pair_counts[i] == 0: + break + r, c = upper[0][i], upper[1][i] + w(f" {pair_counts[i]:3d} edges: {classes[r]:>15s} <-> {classes[c]}") + w("") + + # 4. 
Semantic paths — trace between words + word_to_idx = {w: i for i, w in enumerate(words)} + paths_to_trace = [ + ("dog", "wolf"), + ("dog", "cat"), + ("guitar", "piano"), + ("salmon", "sushi"), + ("rain", "snow"), + ("happy", "sad"), + ("doctor", "nurse"), + ("France", "Germany"), + ("car", "airplane"), + ("guitar", "violin"), + ("hammer", "screwdriver"), + ("mountain", "ocean"), + ("dog", "guitar"), + ("happy", "volcano"), + ] + w("4. Semantic paths (word → word along the tree):") + for word_a, word_b in paths_to_trace: + if word_a not in word_to_idx or word_b not in word_to_idx: + continue + idx_a = word_to_idx[word_a] + idx_b = word_to_idx[word_b] + path_nodes = tree.path(idx_a, idx_b) + path_words = [words[n] for n in path_nodes] + # Show full path if short, otherwise truncated + if len(path_words) <= 8: + path_str = " → ".join(path_words) + else: + path_str = " → ".join(path_words[:4]) + " → ... → " + " → ".join(path_words[-2:]) + w(f" {word_a} → {word_b} ({len(path_words)} hops): {path_str}") + w("") + + return "\n".join(lines) + + +def create_visualization( + model: TMAP, + words: list[str], + categories: list[str], +) -> TmapViz: + """Build TmapViz with category coloring and word labels.""" + + viz = model.to_tmapviz() + viz.title = "Word Embeddings — Semantic Map" + + viz.add_color_layout("category", categories, categorical=True, color="tab20") + viz.add_label("word", words) + + return viz + + +def main() -> None: + parser = argparse.ArgumentParser(description="Word embedding TMAP") + parser.add_argument( + "--model", + type=str, + default="all-MiniLM-L6-v2", + help="Sentence-transformer model (default: all-MiniLM-L6-v2)", + ) + parser.add_argument("--k", type=int, default=10, help="Number of neighbors") + parser.add_argument("--serve", action="store_true") + parser.add_argument("--port", type=int, default=8050) + args = parser.parse_args() + + words, categories = build_vocabulary() + embeddings = compute_embeddings(words, args.model) + model = 
build_tmap(embeddings, args.k) + + report = analyze(model, words, categories) + report_path = OUTPUT_DIR / "word_embeddings_report.txt" + report_path.write_text(report, encoding="utf-8") + print(f"\nReport saved to {report_path}") + print("\n" + report) + + print("Building visualization...") + viz = create_visualization(model, words, categories) + html_path = viz.write_html(OUTPUT_DIR / "word_embeddings_tmap") + print(f"HTML saved to {html_path}") + + if args.serve: + print(f"Serving on http://127.0.0.1:{args.port}") + viz.serve(port=args.port) + + +if __name__ == "__main__": + main()