From 0046eef05fe63bb693134719baa66928dcdb15c6 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Sat, 14 Mar 2026 03:35:27 -0700 Subject: [PATCH] add classify --local mode using local embeddings + cosine similarity Co-Authored-By: Claude --- README.md | 5 +++- jina_cli/api.py | 32 ++++++++++++++++++++ jina_cli/main.py | 15 ++++++---- tests/test_local.py | 72 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 tests/test_local.py diff --git a/README.md b/README.md index 8078fc7..24c91cb 100644 --- a/README.md +++ b/README.md @@ -204,7 +204,7 @@ jina grep serve stop # stop when done ## Local mode -`jina embed`, `jina rerank`, and `jina dedup` support `--local` to run on Apple Silicon via the jina-grep embedding server instead of the Jina API. No API key needed. +`jina embed`, `jina rerank`, `jina classify`, and `jina dedup` support `--local` to run on Apple Silicon via the jina-grep embedding server instead of the Jina API. No API key needed. ```bash # Start the local server first @@ -217,6 +217,9 @@ cat texts.txt | jina embed --local --json # Local reranking (cosine similarity on local embeddings) cat docs.txt | jina rerank --local "machine learning" +# Local classification (cosine similarity on local embeddings) +jina classify --local "this is great" --labels positive,negative + # Local deduplication cat items.txt | jina dedup --local ``` diff --git a/jina_cli/api.py b/jina_cli/api.py index f09896e..a8a396f 100644 --- a/jina_cli/api.py +++ b/jina_cli/api.py @@ -450,6 +450,38 @@ def local_embed( return data.get("data", []) +def local_classify( + texts: list[str], + labels: list[str], + model: str = "jina-embeddings-v5-nano", + task: str = "text-matching", +) -> list[dict]: + """Classify texts into labels using local embeddings and cosine similarity.""" + all_texts = texts + labels + embeddings_data = local_embed(all_texts, model=model, task=task) + embeddings = [item["embedding"] for item in embeddings_data] + + text_embs = embeddings[:len(texts)] + label_embs = embeddings[len(texts):] + + results = [] + for i, text_emb in enumerate(text_embs): + best_label = labels[0] + best_score = -1.0 + for j, label_emb in enumerate(label_embs): + score = _cosine_similarity(text_emb, label_emb) + if score > best_score: + best_score = score + best_label = labels[j] + results.append({ + "index": i, + "prediction": best_label, + "score": best_score, + }) + + return results + + def local_rerank( query: str, documents: list[str], diff --git a/jina_cli/main.py b/jina_cli/main.py index c54ad95..b696435 100644 --- a/jina_cli/main.py +++ b/jina_cli/main.py @@ -410,11 +410,12 @@ def dedup(ctx, k, local, as_json, api_key): @click.argument("text", nargs=-1) @click.option("--labels", required=True, multiple=True, help="Labels for classification (comma-separated or repeated --labels)") -@click.option("--model", default=None, help="Model name (default: jina-embeddings-v5-text-small)") +@click.option("--model", default=None, help="Model name (default: jina-embeddings-v5-text-small, or v5-nano with --local)") +@click.option("--local", is_flag=True, help="Use local MLX server (requires: jina-grep serve start)") @click.option("--json", "as_json", is_flag=True, help="Output as JSON") @click.option("--api-key", default=None, help="Jina API key") @click.pass_context -def classify(ctx, text, labels, model, as_json, api_key): +def classify(ctx, text, labels, model, local, as_json, api_key): """Classify text into labels. Input from arguments or stdin (one text per line). @@ -424,6 +425,7 @@ def classify(ctx, text, labels, model, as_json, api_key): jina classify "this is great" --labels positive,negative echo "stock price rose" | jina classify --labels business,sports,tech jina classify "text1" "text2" --labels cat1 --labels cat2 --labels cat3 + jina classify --local "this is great" --labels positive,negative """ key = api_key or ctx.obj.get("api_key") @@ -450,10 +452,13 @@ def classify(ctx, text, labels, model, as_json, api_key): "Fix: --labels positive,negative", err=True) sys.exit(EXIT_USER_ERROR) - _model = model or "jina-embeddings-v5-text-small" - try: - result = api.classify(texts, parsed_labels, api_key=key, model=_model) + if local: + _model = model or "jina-embeddings-v5-nano" + result = api.local_classify(texts, parsed_labels, model=_model) + else: + _model = model or "jina-embeddings-v5-text-small" + result = api.classify(texts, parsed_labels, api_key=key, model=_model) click.echo(utils.format_classify_results(result, as_json=as_json)) except Exception as e: utils.handle_http_error(e) diff --git a/tests/test_local.py b/tests/test_local.py new file mode 100644 index 0000000..2f50e88 --- /dev/null +++ b/tests/test_local.py @@ -0,0 +1,72 @@ +"""Unit tests for local mode functions (no API key or server needed).""" + +from unittest.mock import patch + +from jina_cli.api import local_classify, _cosine_similarity + + +class TestCosineSimlarity: + def test_identical_vectors(self): + assert abs(_cosine_similarity([1, 0, 0], [1, 0, 0]) - 1.0) < 1e-6 + + def test_orthogonal_vectors(self): + assert abs(_cosine_similarity([1, 0], [0, 1])) < 1e-6 + + def test_zero_vector(self): + assert _cosine_similarity([0, 0], [1, 1]) == 0.0 + + +class TestLocalClassify: + def test_single_text(self): + fake_embeddings = [ + {"embedding": [0.9, 0.1, 0.0]}, # "I love this" - text + {"embedding": [0.8, 0.2, 0.0]}, # "positive" - label (close) + {"embedding": [0.0, 0.1, 0.9]}, # "negative" - label (far) + ] + + with patch("jina_cli.api.local_embed", return_value=fake_embeddings): + result = local_classify( + texts=["I love this"], + labels=["positive", "negative"], + ) + + assert len(result) == 1 + assert result[0]["prediction"] == "positive" + assert result[0]["score"] > 0.5 + assert result[0]["index"] == 0 + + def test_multiple_texts(self): + fake_embeddings = [ + {"embedding": [0.9, 0.1]}, # text 1 - closer to label 1 + {"embedding": [0.1, 0.9]}, # text 2 - closer to label 2 + {"embedding": [0.8, 0.2]}, # label "sports" + {"embedding": [0.2, 0.8]}, # label "politics" + ] + + with patch("jina_cli.api.local_embed", return_value=fake_embeddings): + result = local_classify( + texts=["goal scored", "election results"], + labels=["sports", "politics"], + ) + + assert len(result) == 2 + assert result[0]["prediction"] == "sports" + assert result[1]["prediction"] == "politics" + + def test_result_format(self): + """Results should have index, prediction, score keys.""" + fake_embeddings = [ + {"embedding": [1.0, 0.0]}, + {"embedding": [0.9, 0.1]}, + ] + + with patch("jina_cli.api.local_embed", return_value=fake_embeddings): + result = local_classify( + texts=["test"], + labels=["label1"], + ) + + assert "index" in result[0] + assert "prediction" in result[0] + assert "score" in result[0] + assert result[0]["prediction"] == "label1"