diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt
index 3837e96955..8d801d3850 100644
--- a/.cspell-wordlist.txt
+++ b/.cspell-wordlist.txt
@@ -215,3 +215,4 @@ MATEUSZ
BLAZEFACE
Blazeface
blazeface
+nums
diff --git a/apps/text-embeddings/app/clip-embeddings/index.tsx b/apps/text-embeddings/app/clip-embeddings/index.tsx
index 02a8a9c656..affe3c2955 100644
--- a/apps/text-embeddings/app/clip-embeddings/index.tsx
+++ b/apps/text-embeddings/app/clip-embeddings/index.tsx
@@ -17,6 +17,7 @@ import {
useTextEmbeddings,
useImageEmbeddings,
ImageEmbeddingsProps,
+ dotProduct,
} from 'react-native-executorch';
type ImageEmbeddingModel = ImageEmbeddingsProps['model'];
@@ -35,7 +36,6 @@ const IMAGE_MODELS: { label: string; value: ImageEmbeddingModel }[] = [
];
import { launchImageLibrary } from 'react-native-image-picker';
import { useIsFocused } from 'expo-router';
-import { dotProduct } from '../../utils/math';
import { ModelPicker } from '../../components/ModelPicker';
const DEFAULT_LABELS = [
diff --git a/apps/text-embeddings/app/text-embeddings/index.tsx b/apps/text-embeddings/app/text-embeddings/index.tsx
index 88e39ce063..fb4711e837 100644
--- a/apps/text-embeddings/app/text-embeddings/index.tsx
+++ b/apps/text-embeddings/app/text-embeddings/index.tsx
@@ -5,7 +5,6 @@ import {
TextInput,
TouchableOpacity,
View,
- SafeAreaView,
ScrollView,
KeyboardAvoidingView,
Platform,
@@ -16,10 +15,18 @@ import {
models,
useTextEmbeddings,
TextEmbeddingsProps,
+ EmbeddingResult,
+ dotProduct,
+ maxSim,
} from 'react-native-executorch';
+import { useIsFocused } from 'expo-router';
+import ErrorBanner from '../../components/ErrorBanner';
+import { SafeAreaView } from 'react-native-safe-area-context';
+
const textEmbedding = models.text_embedding;
type TextEmbeddingModel = TextEmbeddingsProps['model'];
+type Encoding = Float32Array | EmbeddingResult;
const MODELS: { label: string; value: TextEmbeddingModel }[] = [
{ label: 'MiniLM L6', value: textEmbedding.all_minilm_l6_v2() },
@@ -43,10 +50,38 @@ const MODELS: { label: string; value: TextEmbeddingModel }[] = [
label: 'Multilingual Paraphrase',
value: textEmbedding.paraphrase_multilingual_minilm_l12_v2(),
},
+ {
+ label: 'LFM2.5 Embedding',
+ value: textEmbedding.lfm2_5_embedding_350m(),
+ },
+ {
+ label: 'LFM2.5 ColBERT (late-interaction)',
+ value: textEmbedding.lfm2_5_colbert_350m(),
+ },
+];
+
+const CORPUS: string[] = [
+ 'The forecast says heavy showers this afternoon.',
+ "It's so sunny outside today!",
+ 'A thick fog rolled in over the harbor at dawn.',
+ 'The home team scored in the final minute to win the match.',
+ 'She sprinted the last lap and broke the national record.',
+ 'Fans packed the stadium for the championship game.',
+ 'Simmer the tomatoes with garlic before adding the pasta.',
+ 'He whisked the eggs and folded in the melted chocolate.',
+ 'The new phone has a faster chip and a brighter screen.',
+ 'Our servers crashed under the sudden spike in traffic.',
+ 'The flight to Tokyo was delayed by three hours.',
+ 'We hiked along the coast and camped near the cliffs.',
+];
+
+const EXAMPLE_QUERIES: string[] = [
+ "What's the weather like?",
+ 'Who won the match?',
+ 'Tell me about the latest technology',
+ 'How do I cook dinner?',
+ 'Where did they travel?',
];
-import { useIsFocused } from 'expo-router';
-import { dotProduct } from '../../utils/math';
-import ErrorBanner from '../../components/ErrorBanner';
export default function TextEmbeddingsScreenWrapper() {
const isFocused = useIsFocused();
@@ -54,6 +89,8 @@ export default function TextEmbeddingsScreenWrapper() {
return isFocused ? : null;
}
+type RankedResult = { sentence: string; similarity: number };
+
function TextEmbeddingsScreen() {
const [selectedModel, setSelectedModel] = useState(
textEmbedding.all_minilm_l6_v2()
@@ -61,88 +98,70 @@ function TextEmbeddingsScreen() {
const model = useTextEmbeddings({ model: selectedModel });
const [error, setError] = useState(null);
- const [inputSentence, setInputSentence] = useState('');
- const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState<
- { sentence: string; embedding: Float32Array }[]
- >([]);
- const [topMatches, setTopMatches] = useState<
- { sentence: string; similarity: number }[]
+ const isMultiVector = !!selectedModel.multiVector;
+ const skipListIds = selectedModel.skipListIds ?? [];
+
+ const [query, setQuery] = useState('');
+ const [corpusEmbeddings, setCorpusEmbeddings] = useState<
+ { sentence: string; embedding: Encoding }[]
>([]);
+ const [results, setResults] = useState([]);
const [embeddingTime, setEmbeddingTime] = useState(null);
+ const [indexing, setIndexing] = useState(false);
useEffect(
() => {
- const computeEmbeddings = async () => {
+ let cancelled = false;
+ const indexCorpus = async () => {
if (!model.isReady) return;
-
- const sentences = [
- 'The weather is lovely today.',
- "It's so sunny outside!",
- 'He drove to the stadium.',
- ];
-
+ setIndexing(true);
+ setResults([]);
try {
- const embeddings = [];
- for (const sentence of sentences) {
- const embedding = await model.forward(sentence);
- embeddings.push({ sentence, embedding });
+ const embedded = [];
+ for (const sentence of CORPUS) {
+ const embedding = await model.forward(sentence, 'document');
+ if (cancelled) return;
+ embedded.push({ sentence, embedding });
}
-
- setSentencesWithEmbeddings(embeddings);
- } catch (e) {
- setError(e instanceof Error ? e.message : String(e));
+ setCorpusEmbeddings(embedded);
+ } finally {
+ if (!cancelled) setIndexing(false);
}
};
-
- computeEmbeddings();
+ indexCorpus();
+ return () => {
+ cancelled = true;
+ };
},
+
// eslint-disable-next-line react-hooks/exhaustive-deps
- [model.isReady]
+ [model.isReady, selectedModel]
);
- const checkSimilarities = async () => {
- if (!model.isReady || !inputSentence.trim()) return;
-
+ const runSearch = async (queryText: string = query) => {
+ const q = queryText.trim();
+ if (!model.isReady || !q || corpusEmbeddings.length === 0) return;
+ setQuery(queryText);
try {
const start = Date.now();
- const inputEmbedding = await model.forward(inputSentence);
+ const queryEmbedding = (await model.forward(q, 'query')) as Encoding;
setEmbeddingTime(Date.now() - start);
- const matches = sentencesWithEmbeddings.map(
- ({ sentence, embedding }) => ({
+ const ranked = corpusEmbeddings
+ .map(({ sentence, embedding }) => ({
sentence,
- similarity: dotProduct(inputEmbedding, embedding),
- })
- );
- matches.sort((a, b) => b.similarity - a.similarity);
- setTopMatches(matches.slice(0, 3));
- } catch (e) {
- setError(e instanceof Error ? e.message : String(e));
- }
- };
-
- const addToSentences = async () => {
- if (!model.isReady || !inputSentence.trim()) return;
-
- try {
- const start = Date.now();
- const embedding = await model.forward(inputSentence);
- setEmbeddingTime(Date.now() - start);
- setSentencesWithEmbeddings((prev) => [
- ...prev,
- { sentence: inputSentence, embedding },
- ]);
- } catch (e) {
- setError(e instanceof Error ? e.message : String(e));
- }
-
- setInputSentence('');
- setTopMatches([]);
- };
-
- const clearList = async () => {
- if (!model.isReady) return;
- try {
- setSentencesWithEmbeddings([]);
+ similarity: isMultiVector
+ ? maxSim(
+ queryEmbedding as EmbeddingResult,
+ embedding as EmbeddingResult,
+ skipListIds
+ )
+ : dotProduct(
+ queryEmbedding as Float32Array,
+ embedding as Float32Array
+ ),
+ }))
+ .sort((a, b) => b.similarity - a.similarity);
+ setResults(ranked);
} catch (e) {
setError(e instanceof Error ? e.message : String(e));
}
@@ -158,6 +177,9 @@ function TextEmbeddingsScreen() {
return model.isGenerating ? 'Generating...' : 'Model is ready';
};
+ const ready = model.isReady && !indexing && corpusEmbeddings.length > 0;
+ const canSearch = ready && !!query.trim();
+
return (
- Text Embeddings Playground
+ Semantic Search
{getModelStatusText()}
{
setSelectedModel(m);
- setSentencesWithEmbeddings([]);
- setTopMatches([]);
+ setCorpusEmbeddings([]);
+ setResults([]);
+ setQuery('');
}}
/>
setError(null)} />
- List of Existing Sentences
- {sentencesWithEmbeddings.map((item, index) => (
-
- - {item.sentence}
-
- ))}
-
-
- Try Your Sentence
+
+ Search the corpus ({CORPUS.length} sentences)
+
+
+ {isMultiVector
+ ? 'Ranks per-token vectors with MaxSim (late interaction). Ask a full question — tap an example or type your own.'
+ : 'Ranks every sentence by meaning. Ask a full question — tap an example or type your own.'}
+
+
+ {EXAMPLE_QUERIES.map((q) => (
+ runSearch(q)}
+ >
+ {q}
+
+ ))}
+
runSearch()}
+ returnKeyType="search"
/>
-
- runSearch()}
+ style={[
+ styles.buttonPrimary,
+ !canSearch && styles.buttonDisabled,
+ ]}
+ disabled={!canSearch}
+ >
+
+
-
-
- Find Similar
-
-
-
-
-
-
- Add to List
-
-
-
-
-
- Clear List
-
-
-
-
+ {indexing ? 'Indexing corpus…' : 'Search'}
+
+
{embeddingTime !== null && (
- Embedding time: {embeddingTime} ms
+ Query embedded in {embeddingTime} ms
)}
- {topMatches.length > 0 && (
-
- Top Matches
- {topMatches.map((item, index) => (
-
- {item.sentence} ({item.similarity.toFixed(2)})
-
- ))}
-
- )}
+
+ {results.length > 0 && (
+
+ Results
+ {results.map((item, index) => (
+
+ ))}
+
+ )}
);
}
+function ResultRow({
+ sentence,
+ similarity,
+ best,
+ rank,
+}: {
+ sentence: string;
+ similarity: number;
+ best: number;
+ rank: number;
+}) {
+ const fraction = best > 0 ? Math.max(0, similarity / best) : 0;
+ return (
+
+
+ {sentence}
+ {similarity.toFixed(2)}
+
+
+
+
+
+ );
+}
+
const styles = StyleSheet.create({
container: {
flex: 1,
@@ -323,11 +342,68 @@ const styles = StyleSheet.create({
marginBottom: 12,
color: '#1E293B',
},
- sentenceText: {
- fontSize: 14,
+ hint: {
+ fontSize: 13,
+ color: '#64748B',
+ marginBottom: 12,
+ lineHeight: 18,
+ },
+ chipRow: {
+ flexDirection: 'row',
+ flexWrap: 'wrap',
+ gap: 8,
+ marginBottom: 12,
+ },
+ chip: {
+ backgroundColor: '#EEF2FF',
+ borderColor: '#C7D2FE',
+ borderWidth: 1,
+ borderRadius: 16,
+ paddingHorizontal: 12,
+ paddingVertical: 6,
+ },
+ chipDisabled: {
+ opacity: 0.4,
+ },
+ chipText: {
+ fontSize: 13,
+ color: 'navy',
+ },
+ resultRow: {
+ marginBottom: 14,
+ },
+ resultHeader: {
+ flexDirection: 'row',
+ justifyContent: 'space-between',
+ alignItems: 'flex-start',
marginBottom: 6,
+ gap: 8,
+ },
+ resultText: {
+ flex: 1,
+ fontSize: 14,
color: '#334155',
},
+ resultScore: {
+ fontSize: 14,
+ fontWeight: '600',
+ color: '#0F172A',
+ fontVariant: ['tabular-nums'],
+ },
+ barTrack: {
+ height: 8,
+ borderRadius: 4,
+ backgroundColor: '#E2E8F0',
+ overflow: 'hidden',
+ },
+ barFill: {
+ height: '100%',
+ borderRadius: 4,
+ backgroundColor: '#94A3B8',
+ },
+ barFillTop: {
+ backgroundColor: 'navy',
+ },
input: {
backgroundColor: '#F1F5F9',
borderRadius: 10,
@@ -338,17 +414,8 @@ const styles = StyleSheet.create({
minHeight: 40,
textAlignVertical: 'top',
},
- buttonContainer: {
- width: '100%',
- gap: 10,
- },
- buttonGroup: {
- flexDirection: 'row',
- justifyContent: 'space-between',
- gap: 10,
- },
buttonPrimary: {
- flex: 1,
+ width: '100%',
backgroundColor: 'navy',
padding: 12,
borderRadius: 10,
@@ -356,17 +423,6 @@ const styles = StyleSheet.create({
alignItems: 'center',
justifyContent: 'center',
},
- buttonSecondary: {
- flex: 1,
- backgroundColor: 'transparent',
- borderWidth: 2,
- borderColor: 'navy',
- padding: 12,
- borderRadius: 10,
- flexDirection: 'row',
- alignItems: 'center',
- justifyContent: 'center',
- },
buttonDisabled: {
backgroundColor: '#f0f0f0',
borderColor: '#d3d3d3',
@@ -376,17 +432,9 @@ const styles = StyleSheet.create({
textAlign: 'center',
fontWeight: '500',
},
- buttonTextOutline: {
- color: 'navy',
- textAlign: 'center',
- fontWeight: '500',
- },
buttonTextDisabled: {
color: 'gray',
},
- topMatchesContainer: {
- marginTop: 20,
- },
statsText: {
fontSize: 13,
color: '#64748B',
diff --git a/apps/text-embeddings/utils/math.ts b/apps/text-embeddings/utils/math.ts
deleted file mode 100644
index 50c70d1f92..0000000000
--- a/apps/text-embeddings/utils/math.ts
+++ /dev/null
@@ -1,19 +0,0 @@
-import {
- RnExecutorchError,
- RnExecutorchErrorCode,
-} from 'react-native-executorch';
-
-export const dotProduct = (a: Float32Array, b: Float32Array) => {
- if (a.length !== b.length) {
- throw new RnExecutorchError(
- RnExecutorchErrorCode.WrongDimensions,
- `dotProduct needs both vector to have the same length: got a: ${a.length}, b: ${b.length}`
- );
- }
-
- let sum = 0;
- for (let i = 0; i < a.length; i++) {
- sum += a[i] * b[i];
- }
- return sum;
-};
diff --git a/docs/docs/03-hooks/01-natural-language-processing/useTextEmbeddings.md b/docs/docs/03-hooks/01-natural-language-processing/useTextEmbeddings.md
index b9ba8c41b9..439de17560 100644
--- a/docs/docs/03-hooks/01-natural-language-processing/useTextEmbeddings.md
+++ b/docs/docs/03-hooks/01-natural-language-processing/useTextEmbeddings.md
@@ -45,7 +45,13 @@ try {
`useTextEmbeddings` takes [`TextEmbeddingsProps`](../../06-api-reference/interfaces/TextEmbeddingsProps.md) that consists of:
-- `model` of type `object` containing the [model source](../../06-api-reference/interfaces/TextEmbeddingsProps.md#modelsource) and [tokenizer source](../../06-api-reference/interfaces/TextEmbeddingsProps.md#tokenizersource).
+- `model` of type `object` ([`TextEmbeddingsModel`](../../06-api-reference/interfaces/TextEmbeddingsModel.md)) containing:
+ - `modelName` - Unique name identifying the model.
+ - `modelSource` - Location of the used model.
+ - `tokenizerSource` - Location of the used tokenizer.
+ - `prompts` _(optional)_ - Asymmetric `query`/`document` prompts the model is trained with. When present, `forward` requires a `role` and prepends the matching prompt.
+ - `multiVector` _(optional)_ - When `true`, `forward` returns the per-token [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) instead of a single pooled `Float32Array`.
+ - `skipListIds` _(optional)_ - Token ids to exclude from late-interaction (MaxSim) scoring.
- An optional flag [`preventLoad`](../../06-api-reference/interfaces/TextEmbeddingsProps.md#preventload) which prevents auto-loading of the model.
You need more details? Check the following resources:
@@ -60,16 +66,30 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/TextEmbeddingsType.md#forward) method. It accepts one argument, which is a string representing the text you want to embed. The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding.
+To run the model, you can use the [`forward`](../../06-api-reference/interfaces/TextEmbeddingsType.md#forward) method. It accepts the text to embed and, for models trained with asymmetric prompts, an optional `role`. The return type depends on the model:
+
+- **Pooled models** (the default, e.g. MiniLM, MPNet, LFM2.5-Embedding) resolve to a single `Float32Array` — one normalized vector for the whole input.
+- **Multi-vector models** (`multiVector: true`, e.g. LFM2.5-ColBERT) resolve to an [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) with the per-token vectors (`vectors`, `numTokens`, `embeddingDim`, `tokenIds`).
+
+For background on why a dense bi-encoder pools to one vector while a late-interaction model keeps per-token vectors, see Liquid AI's [LFM2.5 Retrievers blog post](https://www.liquid.ai/blog/lfm2-5-retrievers).
+
+### Asymmetric prompts (`role`)
+
+Some retrieval models are trained to embed queries and documents with different prefixes (e.g. LFM2.5 uses `query: `/`document: `, ColBERT uses `[Q] `/`[D] `). For these models the model config carries the prompts and `forward` requires a `role`:
+
+```typescript
+const queryEmbedding = await model.forward('What is the weather?', 'query');
+const docEmbedding = await model.forward('It is sunny today.', 'document');
+```
+
+The matching prompt is prepended automatically; for models without prompts the `role` argument is absent.
## Example
```typescript
-import { models, useTextEmbeddings } from 'react-native-executorch';
-const dotProduct = (a: number[], b: number[]) =>
- a.reduce((sum, val, i) => sum + val * b[i], 0);
+import { models, useTextEmbeddings, dotProduct } from 'react-native-executorch';
-const cosineSimilarity = (a: number[], b: number[]) => {
+const cosineSimilarity = (a: Float32Array, b: Float32Array) => {
const dot = dotProduct(a, b);
const normA = Math.sqrt(dotProduct(a, a));
const normB = Math.sqrt(dotProduct(b, b));
@@ -112,6 +132,8 @@ function App() {
| [distiluse-base-multilingual-cased-v2](https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased-v2) | 50+ languages | 126 | 512 | Multilingual DistilBERT with a 768→512 projection head. Recommended when broader language coverage matters more than the exact English quality of MiniLM/MPNet. |
| [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) | 50+ languages | 126 | 384 | Multilingual MiniLM-L12 distilled from paraphrase-multilingual-mpnet-base-v2. Compact (≈118 M params) sentence encoder for cross-lingual semantic similarity and retrieval across 50+ languages. |
| [clip-vit-base-patch32-text](https://huggingface.co/openai/clip-vit-base-patch32) | English | 74 | 512 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. CLIP allows to embed images and text into the same vector space. This allows to find similar images as well as to implement image search. This is the text encoder part of the CLIP model. To embed images checkout [clip-vit-base-patch32-image](../02-computer-vision/useImageEmbeddings.md#supported-models). |
+| [LFM2.5-Embedding-350M](https://huggingface.co/LiquidAI/LFM2.5-Embedding-350M) | Multilingual | 512 | 1024 | Dense bi-encoder from Liquid AI with CLS pooling. Trained with asymmetric `query: `/`document: ` prompts, so `forward` requires a `role`. On iOS it runs on the GPU via the MLX backend (physical device only); Android uses XNNPACK. |
+| [LFM2.5-ColBERT-350M](https://huggingface.co/LiquidAI/LFM2.5-ColBERT-350M) | Multilingual | 512 | 128 (per token) | Late-interaction (multi-vector) retriever from Liquid AI: a `Linear(1024→128)` head emits one normalized vector per token. `forward` returns an `EmbeddingResult`; score query/document pairs with MaxSim (see below). Uses `[Q] `/`[D] ` role prompts. |
**`Max Tokens`** - The maximum number of tokens that can be processed by the model. If the input text exceeds this limit, it will be truncated.
@@ -120,3 +142,28 @@ function App() {
:::note
For the supported models, the returned embedding vector is normalized, meaning that its length is equal to 1. This allows for easier comparison of vectors using cosine similarity, just calculate the dot product of two vectors to get the cosine similarity score.
:::
+
+## Late interaction (multi-vector models)
+
+Multi-vector models such as LFM2.5-ColBERT do not pool the sequence into a single vector. Instead, `forward` returns an [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) holding one normalized vector per token. You score a query against a document with **MaxSim**: for every query-token vector, take its highest dot product against the document-token vectors, then sum those maxima. The model also ships a `skipListIds` array — the punctuation token ids excluded from scoring.
+
+The library ships a `maxSim` helper (and a `dotProduct` helper for pooled models), so you can score directly without reimplementing it:
+
+```typescript
+import { models, useTextEmbeddings, maxSim } from 'react-native-executorch';
+
+const colbert = models.text_embedding.lfm2_5_colbert_350m();
+const skipListIds = colbert.skipListIds ?? [];
+
+function App() {
+ const model = useTextEmbeddings({ model: colbert });
+
+ // ...
+
+ const query = await model.forward('What is the weather?', 'query');
+ const doc = await model.forward('It is sunny today.', 'document');
+ const score = maxSim(query, doc, skipListIds);
+}
+```
+
+The `skipListIds` shipped on the model config are the punctuation token ids excluded from scoring (derived from the model's training config). Per-token vectors are L2-normalized by the graph, so the dot product equals cosine similarity.
diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
index aa563c213d..b4cd478b0e 100644
--- a/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
+++ b/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
@@ -30,13 +30,20 @@ All methods of `TextEmbeddingsModule` are explained in details here: [`TextEmbed
Use the static [`fromModelName`](../../06-api-reference/classes/TextEmbeddingsModule.md#frommodelname) factory method. It accepts a model config object (e.g. `ALL_MINILM_L6_V2`) containing:
-- [`modelSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#modelsource) - Location of the used model.
-- [`tokenizerSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#tokenizersource) - Location of the used tokenizer.
+- `modelName` - Unique name identifying the model.
+- `modelSource` - Location of the used model.
+- `tokenizerSource` - Location of the used tokenizer.
+- `prompts` _(optional)_ - Asymmetric `query`/`document` prompts the model is trained with. When present, `forward` requires a `role` and prepends the matching prompt.
+- `multiVector` _(optional)_ - When `true`, `forward` returns the per-token `EmbeddingResult` instead of a single pooled `Float32Array`.
+- `skipListIds` _(optional)_ - Token ids to exclude from late-interaction (MaxSim) scoring.
-And an optional `onDownloadProgress` callback. It returns a promise resolving to a `TextEmbeddingsModule` instance.
+And an optional `onDownloadProgress` callback (receiving a value between 0 and 1). It returns a promise resolving to a `TextEmbeddingsModule` instance.
For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/TextEmbeddingsModule.md#forward) method. It accepts one argument, which is the text you want to embed. The method returns a promise, which can resolve either to an error or an array of numbers representing the embedding.
+To run the model, use the [`forward`](../../06-api-reference/classes/TextEmbeddingsModule.md#forward) method. It accepts the text to embed and, for models with asymmetric prompts, an optional `role` (`'query' | 'document'`). The method returns a promise resolving to:
+
+- a `Float32Array` — a single pooled vector — for standard models, or
+- an [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) with the per-token vectors for `multiVector` models.
diff --git a/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useTextEmbeddings.md b/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useTextEmbeddings.md
index 3e23a88630..304969034b 100644
--- a/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useTextEmbeddings.md
+++ b/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useTextEmbeddings.md
@@ -45,7 +45,13 @@ try {
`useTextEmbeddings` takes [`TextEmbeddingsProps`](../../06-api-reference/interfaces/TextEmbeddingsProps.md) that consists of:
-- `model` of type `object` containing the [model source](../../06-api-reference/interfaces/TextEmbeddingsProps.md#modelsource) and [tokenizer source](../../06-api-reference/interfaces/TextEmbeddingsProps.md#tokenizersource).
+- `model` of type `object` ([`TextEmbeddingsModel`](../../06-api-reference/interfaces/TextEmbeddingsModel.md)) containing:
+ - `modelName` - Unique name identifying the model.
+ - `modelSource` - Location of the used model.
+ - `tokenizerSource` - Location of the used tokenizer.
+ - `prompts` _(optional)_ - Asymmetric `query`/`document` prompts the model is trained with. When present, `forward` requires a `role` and prepends the matching prompt.
+ - `multiVector` _(optional)_ - When `true`, `forward` returns the per-token [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) instead of a single pooled `Float32Array`.
+ - `skipListIds` _(optional)_ - Token ids to exclude from late-interaction (MaxSim) scoring.
- An optional flag [`preventLoad`](../../06-api-reference/interfaces/TextEmbeddingsProps.md#preventload) which prevents auto-loading of the model.
You need more details? Check the following resources:
@@ -60,16 +66,30 @@ You need more details? Check the following resources:
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/interfaces/TextEmbeddingsType.md#forward) method. It accepts one argument, which is a string representing the text you want to embed. The function returns a promise, which can resolve either to an error or an array of numbers representing the embedding.
+To run the model, you can use the [`forward`](../../06-api-reference/interfaces/TextEmbeddingsType.md#forward) method. It accepts the text to embed and, for models trained with asymmetric prompts, an optional `role`. The return type depends on the model:
+
+- **Pooled models** (the default, e.g. MiniLM, MPNet, LFM2.5-Embedding) resolve to a single `Float32Array` — one normalized vector for the whole input.
+- **Multi-vector models** (`multiVector: true`, e.g. LFM2.5-ColBERT) resolve to an [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) with the per-token vectors (`vectors`, `numTokens`, `embeddingDim`, `tokenIds`).
+
+For background on why a dense bi-encoder pools to one vector while a late-interaction model keeps per-token vectors, see Liquid AI's [LFM2.5 Retrievers blog post](https://www.liquid.ai/blog/lfm2-5-retrievers).
+
+### Asymmetric prompts (`role`)
+
+Some retrieval models are trained to embed queries and documents with different prefixes (e.g. LFM2.5 uses `query: `/`document: `, ColBERT uses `[Q] `/`[D] `). For these models the model config carries the prompts and `forward` requires a `role`:
+
+```typescript
+const queryEmbedding = await model.forward('What is the weather?', 'query');
+const docEmbedding = await model.forward('It is sunny today.', 'document');
+```
+
+The matching prompt is prepended automatically; for models without prompts the `role` argument is absent.
## Example
```typescript
-import { models, useTextEmbeddings } from 'react-native-executorch';
-const dotProduct = (a: number[], b: number[]) =>
- a.reduce((sum, val, i) => sum + val * b[i], 0);
+import { models, useTextEmbeddings, dotProduct } from 'react-native-executorch';
-const cosineSimilarity = (a: number[], b: number[]) => {
+const cosineSimilarity = (a: Float32Array, b: Float32Array) => {
const dot = dotProduct(a, b);
const normA = Math.sqrt(dotProduct(a, a));
const normB = Math.sqrt(dotProduct(b, b));
@@ -112,6 +132,8 @@ function App() {
| [distiluse-base-multilingual-cased-v2](https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased-v2) | 50+ languages | 126 | 512 | Multilingual DistilBERT with a 768→512 projection head. Recommended when broader language coverage matters more than the exact English quality of MiniLM/MPNet. |
| [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) | 50+ languages | 126 | 384 | Multilingual MiniLM-L12 distilled from paraphrase-multilingual-mpnet-base-v2. Compact (≈118 M params) sentence encoder for cross-lingual semantic similarity and retrieval across 50+ languages. |
| [clip-vit-base-patch32-text](https://huggingface.co/openai/clip-vit-base-patch32) | English | 74 | 512 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. CLIP allows to embed images and text into the same vector space. This allows to find similar images as well as to implement image search. This is the text encoder part of the CLIP model. To embed images checkout [clip-vit-base-patch32-image](../02-computer-vision/useImageEmbeddings.md#supported-models). |
+| [LFM2.5-Embedding-350M](https://huggingface.co/LiquidAI/LFM2.5-Embedding-350M) | Multilingual | 512 | 1024 | Dense bi-encoder from Liquid AI with CLS pooling. Trained with asymmetric `query: `/`document: ` prompts, so `forward` requires a `role`. On iOS it runs on the GPU via the MLX backend (physical device only); Android uses XNNPACK. |
+| [LFM2.5-ColBERT-350M](https://huggingface.co/LiquidAI/LFM2.5-ColBERT-350M) | Multilingual | 512 | 128 (per token) | Late-interaction (multi-vector) retriever from Liquid AI: a `Linear(1024→128)` head emits one normalized vector per token. `forward` returns an `EmbeddingResult`; score query/document pairs with MaxSim (see below). Uses `[Q] `/`[D] ` role prompts. |
**`Max Tokens`** - The maximum number of tokens that can be processed by the model. If the input text exceeds this limit, it will be truncated.
@@ -120,3 +142,28 @@ function App() {
:::note
For the supported models, the returned embedding vector is normalized, meaning that its length is equal to 1. This allows for easier comparison of vectors using cosine similarity, just calculate the dot product of two vectors to get the cosine similarity score.
:::
+
+## Late interaction (multi-vector models)
+
+Multi-vector models such as LFM2.5-ColBERT do not pool the sequence into a single vector. Instead, `forward` returns an [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) holding one normalized vector per token. You score a query against a document with **MaxSim**: for every query-token vector, take its highest dot product against the document-token vectors, then sum those maxima. The model also ships a `skipListIds` array — the punctuation token ids excluded from scoring.
+
+The library ships a `maxSim` helper (and a `dotProduct` helper for pooled models), so you can score directly without reimplementing it:
+
+```typescript
+import { models, useTextEmbeddings, maxSim } from 'react-native-executorch';
+
+const colbert = models.text_embedding.lfm2_5_colbert_350m();
+const skipListIds = colbert.skipListIds ?? [];
+
+function App() {
+ const model = useTextEmbeddings({ model: colbert });
+
+ // ...
+
+ const query = await model.forward('What is the weather?', 'query');
+ const doc = await model.forward('It is sunny today.', 'document');
+ const score = maxSim(query, doc, skipListIds);
+}
+```
+
+The `skipListIds` shipped on the model config are the punctuation token ids excluded from scoring (derived from the model's training config). Per-token vectors are L2-normalized by the graph, so the dot product equals cosine similarity.
diff --git a/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md b/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
index aa563c213d..b4cd478b0e 100644
--- a/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
+++ b/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
@@ -30,13 +30,20 @@ All methods of `TextEmbeddingsModule` are explained in details here: [`TextEmbed
Use the static [`fromModelName`](../../06-api-reference/classes/TextEmbeddingsModule.md#frommodelname) factory method. It accepts a model config object (e.g. `ALL_MINILM_L6_V2`) containing:
-- [`modelSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#modelsource) - Location of the used model.
-- [`tokenizerSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#tokenizersource) - Location of the used tokenizer.
+- `modelName` - Unique name identifying the model.
+- `modelSource` - Location of the used model.
+- `tokenizerSource` - Location of the used tokenizer.
+- `prompts` _(optional)_ - Asymmetric `query`/`document` prompts the model is trained with. When present, `forward` requires a `role` and prepends the matching prompt.
+- `multiVector` _(optional)_ - When `true`, `forward` returns the per-token `EmbeddingResult` instead of a single pooled `Float32Array`.
+- `skipListIds` _(optional)_ - Token ids to exclude from late-interaction (MaxSim) scoring.
-And an optional `onDownloadProgress` callback. It returns a promise resolving to a `TextEmbeddingsModule` instance.
+And an optional `onDownloadProgress` callback (receiving a value between 0 and 1). It returns a promise resolving to a `TextEmbeddingsModule` instance.
For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
-To run the model, you can use the [`forward`](../../06-api-reference/classes/TextEmbeddingsModule.md#forward) method. It accepts one argument, which is the text you want to embed. The method returns a promise, which can resolve either to an error or an array of numbers representing the embedding.
+To run the model, use the [`forward`](../../06-api-reference/classes/TextEmbeddingsModule.md#forward) method. It accepts the text to embed and, for models with asymmetric prompts, an optional `role` (`'query' | 'document'`). The method returns a promise resolving to:
+
+- a `Float32Array` — a single pooled vector — for standard models, or
+- an [`EmbeddingResult`](../../06-api-reference/interfaces/EmbeddingResult.md) with the per-token vectors for `multiVector` models.
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/classes/TextEmbeddingsModule.md b/docs/versioned_docs/version-0.9.x/06-api-reference/classes/TextEmbeddingsModule.md
index 2c6141349e..9bce0bfdff 100644
--- a/docs/versioned_docs/version-0.9.x/06-api-reference/classes/TextEmbeddingsModule.md
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/classes/TextEmbeddingsModule.md
@@ -1,8 +1,8 @@
# Class: TextEmbeddingsModule
-Defined in: [modules/natural_language_processing/TextEmbeddingsModule.ts:13](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L13)
+Defined in: [modules/natural\_language\_processing/TextEmbeddingsModule.ts:19](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L19)
-Module for generating text embeddings from input text.
+Module for managing a Text Embeddings model instance.
## Extends
@@ -23,14 +23,12 @@ making it worklet-compatible and safe to call from VisionCamera's
frame processor thread.
**Performance characteristics:**
-
- **Zero-copy path**: When using `frame.getNativeBuffer()` from VisionCamera v5,
frame data is accessed directly without copying (fastest, recommended).
- **Copy path**: When using `frame.toArrayBuffer()`, pixel data is copied
from native to JS, then accessed from native code (slower, fallback).
**Usage with VisionCamera:**
-
```typescript
const frameOutput = useFrameOutput({
pixelFormat: 'rgb',
@@ -39,16 +37,12 @@ const frameOutput = useFrameOutput({
// Zero-copy approach (recommended)
const nativeBuffer = frame.getNativeBuffer();
const result = model.generateFromFrame(
- {
- nativeBuffer: nativeBuffer.pointer,
- width: frame.width,
- height: frame.height,
- },
+ { nativeBuffer: nativeBuffer.pointer, width: frame.width, height: frame.height },
...args
);
nativeBuffer.release();
frame.dispose();
- },
+ }
});
```
@@ -80,7 +74,7 @@ Model-specific output (e.g., detections, classifications, embeddings)
`BaseModule.generateFromFrame`
----
+***
### nativeModule
@@ -116,15 +110,16 @@ Always call this method when you're done with a model to prevent memory leaks.
`BaseModule.delete`
----
+***
### forward()
-> **forward**(`input`): `Promise`\<`Float32Array`\<`ArrayBufferLike`\>\>
+> **forward**(`input`, `role?`): `Promise`\<`Float32Array`\<`ArrayBufferLike`\> \| [`EmbeddingResult`](../interfaces/EmbeddingResult.md)\>
-Defined in: [modules/natural_language_processing/TextEmbeddingsModule.ts:82](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L82)
+Defined in: [modules/natural\_language\_processing/TextEmbeddingsModule.ts:101](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L101)
-Executes the model's forward pass to generate an embedding for the provided text.
+Embed text into a pooled `Float32Array`, or a per-token `EmbeddingResult`
+for `multiVector` models.
#### Parameters
@@ -132,15 +127,26 @@ Executes the model's forward pass to generate an embedding for the provided text
`string`
-The text string to embed.
+The text to embed.
+
+##### role?
+
+[`EmbeddingRole`](../type-aliases/EmbeddingRole.md)
+
+Optional role ('query' | 'document') for models with
+ asymmetric prompts; prepends the model's prompt for that role.
#### Returns
-`Promise`\<`Float32Array`\<`ArrayBufferLike`\>\>
+`Promise`\<`Float32Array`\<`ArrayBufferLike`\> \| [`EmbeddingResult`](../interfaces/EmbeddingResult.md)\>
-A Promise resolving to a `Float32Array` containing the embedding vector.
+A `Float32Array` for pooled models, an `EmbeddingResult` otherwise.
----
+#### Throws
+
+If the model is not loaded.
+
+***
### forwardET()
@@ -171,7 +177,7 @@ Array of output tensors.
`BaseModule.forwardET`
----
+***
### getInputShape()
@@ -205,16 +211,18 @@ The input shape as an array of numbers.
`BaseModule.getInputShape`
----
+***
### fromCustomModel()
> `static` **fromCustomModel**(`modelSource`, `tokenizerSource`, `onDownloadProgress?`): `Promise`\<`TextEmbeddingsModule`\>
-Defined in: [modules/natural_language_processing/TextEmbeddingsModule.ts:62](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L62)
+Defined in: [modules/natural\_language\_processing/TextEmbeddingsModule.ts:77](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L77)
-Creates a text embeddings instance with a user-provided model binary and tokenizer.
-Use this when working with a custom-exported model that is not one of the built-in presets.
+Creates a text embeddings instance with a user-provided model binary.
+Use this when working with a custom-exported embeddings model. Internally
+uses `'custom'` as the model name. Note that prompts, multi-vector output,
+and skipLists are model-config features and are not configured here.
#### Parameters
@@ -242,18 +250,13 @@ Optional callback to monitor download progress, receiving a value between 0 and
A Promise resolving to a `TextEmbeddingsModule` instance.
-#### Remarks
-
-The native model contract for this method is not formally defined and may change
-between releases. Refer to the native source code for the current expected tensor interface.
-
----
+***
### fromModelName()
> `static` **fromModelName**(`namedSources`, `onDownloadProgress?`): `Promise`\<`TextEmbeddingsModule`\>
-Defined in: [modules/natural_language_processing/TextEmbeddingsModule.ts:25](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L25)
+Defined in: [modules/natural\_language\_processing/TextEmbeddingsModule.ts:42](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts#L42)
Creates a text embeddings instance for a built-in model.
@@ -261,25 +264,17 @@ Creates a text embeddings instance for a built-in model.
##### namedSources
-An object specifying which built-in model to load and where to fetch it from.
+[`TextEmbeddingsModel`](../interfaces/TextEmbeddingsModel.md)
-###### modelName
-
-[`TextEmbeddingsModelName`](../type-aliases/TextEmbeddingsModelName.md)
-
-###### modelSource
-
-[`ResourceSource`](../type-aliases/ResourceSource.md)
-
-###### tokenizerSource
-
-[`ResourceSource`](../type-aliases/ResourceSource.md)
+An object specifying the model name, model source,
+ tokenizer source, and optional `prompts` / `multiVector` / `skipListIds`.
##### onDownloadProgress?
(`progress`) => `void`
-Optional callback to monitor download progress, receiving a value between 0 and 1.
+Optional callback to monitor download progress,
+ receiving a value between 0 and 1.
#### Returns
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/functions/useTextEmbeddings.md b/docs/versioned_docs/version-0.9.x/06-api-reference/functions/useTextEmbeddings.md
index 6bc23e5219..b5de9d57b1 100644
--- a/docs/versioned_docs/version-0.9.x/06-api-reference/functions/useTextEmbeddings.md
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/functions/useTextEmbeddings.md
@@ -1,21 +1,30 @@
# Function: useTextEmbeddings()
-> **useTextEmbeddings**(`TextEmbeddingsProps`): [`TextEmbeddingsType`](../interfaces/TextEmbeddingsType.md)
+> **useTextEmbeddings**\<`M`\>(`TextEmbeddingsProps`): [`TextEmbeddingsType`](../interfaces/TextEmbeddingsType.md)\<`M`\>
-Defined in: [hooks/natural_language_processing/useTextEmbeddings.ts:14](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts#L14)
+Defined in: [hooks/natural\_language\_processing/useTextEmbeddings.ts:20](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts#L20)
React hook for managing a Text Embeddings model instance.
+## Type Parameters
+
+### M
+
+`M` *extends* [`TextEmbeddingsModel`](../interfaces/TextEmbeddingsModel.md)
+
## Parameters
### TextEmbeddingsProps
-[`TextEmbeddingsProps`](../interfaces/TextEmbeddingsProps.md)
+[`TextEmbeddingsProps`](../interfaces/TextEmbeddingsProps.md)\<`M`\>
Configuration object containing `model` source and optional `preventLoad` flag.
## Returns
-[`TextEmbeddingsType`](../interfaces/TextEmbeddingsType.md)
+[`TextEmbeddingsType`](../interfaces/TextEmbeddingsType.md)\<`M`\>
-Ready to use Text Embeddings model.
+Ready to use Text Embeddings model. `forward` returns a
+ `Float32Array` for pooled models and an `EmbeddingResult` (per-token
+ vectors) for multi-vector models. Models with prompts require a `role`
+ ('query' | 'document') on `forward`.
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/EmbeddingPrompts.md b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/EmbeddingPrompts.md
new file mode 100644
index 0000000000..0244afadd1
--- /dev/null
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/EmbeddingPrompts.md
@@ -0,0 +1,22 @@
+# Interface: EmbeddingPrompts
+
+Defined in: [types/textEmbeddings.ts:49](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L49)
+
+Asymmetric prompts a model is trained with. When a model config carries
+these, `forward` requires a `role` so the matching prompt is always applied.
+
+## Properties
+
+### document
+
+> **document**: `string`
+
+Defined in: [types/textEmbeddings.ts:51](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L51)
+
+***
+
+### query
+
+> **query**: `string`
+
+Defined in: [types/textEmbeddings.ts:50](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L50)
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/EmbeddingResult.md b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/EmbeddingResult.md
new file mode 100644
index 0000000000..e02bd77aa4
--- /dev/null
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/EmbeddingResult.md
@@ -0,0 +1,47 @@
+# Interface: EmbeddingResult
+
+Defined in: [types/textEmbeddings.ts:25](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L25)
+
+Per-token (multi-vector) embedding output for late-interaction models (e.g.
+ColBERT). Only `multiVector` models yield this; standard models return a
+pooled `Float32Array` from `forward` instead.
+
+## Properties
+
+### embeddingDim
+
+> **embeddingDim**: `number`
+
+Defined in: [types/textEmbeddings.ts:31](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L31)
+
+Per-token vector dimension.
+
+***
+
+### numTokens
+
+> **numTokens**: `number`
+
+Defined in: [types/textEmbeddings.ts:29](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L29)
+
+Number of token rows.
+
+***
+
+### tokenIds
+
+> **tokenIds**: `number`[]
+
+Defined in: [types/textEmbeddings.ts:33](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L33)
+
+Input token ids per row.
+
+***
+
+### vectors
+
+> **vectors**: `Float32Array`
+
+Defined in: [types/textEmbeddings.ts:27](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L27)
+
+Flat [numTokens * embeddingDim] fp32 vectors (row-major).
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsModel.md b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsModel.md
new file mode 100644
index 0000000000..6bd254a93a
--- /dev/null
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsModel.md
@@ -0,0 +1,60 @@
+# Interface: TextEmbeddingsModel
+
+Defined in: [types/textEmbeddings.ts:60](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L60)
+
+A text embeddings model config. Two optional flags drive `forward`:
+`prompts` makes a `role` argument required, and `multiVector` makes it return
+a per-token `EmbeddingResult` instead of a pooled `Float32Array`.
+
+## Properties
+
+### modelName
+
+> **modelName**: [`TextEmbeddingsModelName`](../type-aliases/TextEmbeddingsModelName.md)
+
+Defined in: [types/textEmbeddings.ts:61](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L61)
+
+***
+
+### modelSource
+
+> **modelSource**: [`ResourceSource`](../type-aliases/ResourceSource.md)
+
+Defined in: [types/textEmbeddings.ts:62](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L62)
+
+***
+
+### multiVector?
+
+> `optional` **multiVector**: `boolean`
+
+Defined in: [types/textEmbeddings.ts:65](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L65)
+
+***
+
+### prompts?
+
+> `optional` **prompts**: [`EmbeddingPrompts`](EmbeddingPrompts.md)
+
+Defined in: [types/textEmbeddings.ts:64](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L64)
+
+***
+
+### skipListIds?
+
+> `optional` **skipListIds**: `number`[]
+
+Defined in: [types/textEmbeddings.ts:72](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L72)
+
+Document token ids to exclude from late-interaction scoring (e.g. ColBERT's
+punctuation skipList). Derived from the model's training config, so it's
+shipped here rather than reconstructed by the consumer, who passes it to
+their own MaxSim scoring.
+
+***
+
+### tokenizerSource
+
+> **tokenizerSource**: [`ResourceSource`](../type-aliases/ResourceSource.md)
+
+Defined in: [types/textEmbeddings.ts:63](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L63)
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsProps.md b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsProps.md
index 1581b79edb..4556bd9dbd 100644
--- a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsProps.md
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsProps.md
@@ -1,43 +1,31 @@
-# Interface: TextEmbeddingsProps
+# Interface: TextEmbeddingsProps\
-Defined in: [types/textEmbeddings.ts:26](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L26)
+Defined in: [types/textEmbeddings.ts:112](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L112)
Props for the useTextEmbeddings hook.
-## Properties
-
-### model
-
-> **model**: `object`
-
-Defined in: [types/textEmbeddings.ts:27](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L27)
-
-An object containing the model configuration.
-
-#### modelName
-
-> **modelName**: [`TextEmbeddingsModelName`](../type-aliases/TextEmbeddingsModelName.md)
+## Type Parameters
-The unique name of the text embeddings model.
+### M
-#### modelSource
+`M` *extends* [`TextEmbeddingsModel`](TextEmbeddingsModel.md) = [`TextEmbeddingsModel`](TextEmbeddingsModel.md)
-> **modelSource**: [`ResourceSource`](../type-aliases/ResourceSource.md)
+## Properties
-The source of the text embeddings model binary.
+### model
-#### tokenizerSource
+> **model**: `M`
-> **tokenizerSource**: [`ResourceSource`](../type-aliases/ResourceSource.md)
+Defined in: [types/textEmbeddings.ts:115](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L115)
-The source of the tokenizer JSON file.
+An object containing the model configuration.
----
+***
### preventLoad?
> `optional` **preventLoad**: `boolean`
-Defined in: [types/textEmbeddings.ts:41](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L41)
+Defined in: [types/textEmbeddings.ts:116](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L116)
Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook.
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsType.md b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsType.md
index 78c267daf8..5f4b9c90dc 100644
--- a/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsType.md
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/interfaces/TextEmbeddingsType.md
@@ -1,8 +1,14 @@
-# Interface: TextEmbeddingsType
+# Interface: TextEmbeddingsType\
-Defined in: [types/textEmbeddings.ts:48](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L48)
+Defined in: [types/textEmbeddings.ts:123](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L123)
-React hook state and methods for managing a Text Embeddings model instance.
+React hook state and methods for a Text Embeddings model instance.
+
+## Type Parameters
+
+### M
+
+`M` *extends* [`TextEmbeddingsModel`](TextEmbeddingsModel.md) = [`TextEmbeddingsModel`](TextEmbeddingsModel.md)
## Properties
@@ -10,64 +16,62 @@ React hook state and methods for managing a Text Embeddings model instance.
> **downloadProgress**: `number`
-Defined in: [types/textEmbeddings.ts:67](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L67)
+Defined in: [types/textEmbeddings.ts:141](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L141)
Tracks the progress of the model download process (value between 0 and 1).
----
+***
### error
> **error**: [`RnExecutorchError`](../classes/RnExecutorchError.md) \| `null`
-Defined in: [types/textEmbeddings.ts:52](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L52)
+Defined in: [types/textEmbeddings.ts:129](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L129)
Contains the error message if the model failed to load or during inference.
----
+***
-### isGenerating
+### forward
-> **isGenerating**: `boolean`
-
-Defined in: [types/textEmbeddings.ts:62](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L62)
+> **forward**: [`ForwardFn`](../type-aliases/ForwardFn.md)\<`M`\>
-Indicates whether the model is currently generating embeddings.
+Defined in: [types/textEmbeddings.ts:149](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L149)
----
+Runs the text embeddings model on the provided input string.
-### isReady
+#### Param
-> **isReady**: `boolean`
+The text string to embed.
-Defined in: [types/textEmbeddings.ts:57](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L57)
+#### Param
-Indicates whether the embeddings model has successfully loaded and is ready for inference.
+Optional role for models with asymmetric prompts. Required if the model has `prompts`.
-## Methods
+#### Returns
-### forward()
+A promise resolving to a Float32Array or EmbeddingResult containing the vector embeddings.
-> **forward**(`input`): `Promise`\<`Float32Array`\<`ArrayBufferLike`\>\>
+#### Throws
-Defined in: [types/textEmbeddings.ts:75](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L75)
+If the model is not loaded or is currently processing another request.
-Runs the text embeddings model on the provided input string.
+***
-#### Parameters
+### isGenerating
-##### input
+> **isGenerating**: `boolean`
-`string`
+Defined in: [types/textEmbeddings.ts:137](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L137)
-The text string to embed.
+Indicates whether the model is currently generating embeddings.
-#### Returns
+***
-`Promise`\<`Float32Array`\<`ArrayBufferLike`\>\>
+### isReady
-A promise resolving to a Float32Array containing the vector embeddings.
+> **isReady**: `boolean`
-#### Throws
+Defined in: [types/textEmbeddings.ts:133](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L133)
-If the model is not loaded or is currently processing another request.
+Indicates whether the embeddings model has successfully loaded and is ready for inference.
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/EmbeddingRole.md b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/EmbeddingRole.md
new file mode 100644
index 0000000000..16d869dd78
--- /dev/null
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/EmbeddingRole.md
@@ -0,0 +1,9 @@
+# Type Alias: EmbeddingRole
+
+> **EmbeddingRole** = `"query"` \| `"document"`
+
+Defined in: [types/textEmbeddings.ts:42](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L42)
+
+Role for `forward`. Some models are trained with asymmetric query/document
+prompts (e.g. LFM2.5 uses `query: `/`document: `, ColBERT uses `[Q] `/`[D] `).
+Passing a role auto-prepends the model's configured prompt for that role.
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/ForwardFn.md b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/ForwardFn.md
new file mode 100644
index 0000000000..13e311b31e
--- /dev/null
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/ForwardFn.md
@@ -0,0 +1,15 @@
+# Type Alias: ForwardFn\
+
+> **ForwardFn**\<`M`\> = `M` *extends* `object` ? (`input`, `role`) => `Promise`\<[`ForwardReturn`](ForwardReturn.md)\<`M`\>\> : `undefined` *extends* `M`\[`"prompts"`\] ? `M`\[`"prompts"`\] *extends* `undefined` ? (`input`) => `Promise`\<[`ForwardReturn`](ForwardReturn.md)\<`M`\>\> : (`input`, `role?`) => `Promise`\<[`ForwardReturn`](ForwardReturn.md)\<`M`\>\> : (`input`) => `Promise`\<[`ForwardReturn`](ForwardReturn.md)\<`M`\>\>
+
+Defined in: [types/textEmbeddings.ts:90](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L90)
+
+`forward`'s signature, computed from the model config: `role` is required
+when the model has `prompts`, omitted when it has none, and optional when
+unknown (e.g. a heterogeneous model list).
+
+## Type Parameters
+
+### M
+
+`M` *extends* [`TextEmbeddingsModel`](../interfaces/TextEmbeddingsModel.md)
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/ForwardReturn.md b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/ForwardReturn.md
new file mode 100644
index 0000000000..8ee72147b8
--- /dev/null
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/ForwardReturn.md
@@ -0,0 +1,14 @@
+# Type Alias: ForwardReturn\
+
+> **ForwardReturn**\<`M`\> = `M` *extends* `object` ? [`EmbeddingResult`](../interfaces/EmbeddingResult.md) : `Float32Array`
+
+Defined in: [types/textEmbeddings.ts:79](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L79)
+
+`forward`'s return type: `EmbeddingResult` for `multiVector` models,
+`Float32Array` otherwise.
+
+## Type Parameters
+
+### M
+
+`M` *extends* [`TextEmbeddingsModel`](../interfaces/TextEmbeddingsModel.md)
diff --git a/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/TextEmbeddingsModelName.md b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/TextEmbeddingsModelName.md
index 4d419240ce..54abdf4901 100644
--- a/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/TextEmbeddingsModelName.md
+++ b/docs/versioned_docs/version-0.9.x/06-api-reference/type-aliases/TextEmbeddingsModelName.md
@@ -1,6 +1,6 @@
# Type Alias: TextEmbeddingsModelName
-> **TextEmbeddingsModelName** = `"all-minilm-l6-v2"` \| `"all-mpnet-base-v2"` \| `"multi-qa-minilm-l6-cos-v1"` \| `"multi-qa-mpnet-base-dot-v1"` \| `"distiluse-base-multilingual-cased-v2-8da4w"` \| `"paraphrase-multilingual-minilm-l12-v2-quantized"` \| `"clip-vit-base-patch32-text"`
+> **TextEmbeddingsModelName** = `"all-minilm-l6-v2"` \| `"all-mpnet-base-v2"` \| `"multi-qa-minilm-l6-cos-v1"` \| `"multi-qa-mpnet-base-dot-v1"` \| `"distiluse-base-multilingual-cased-v2-8da4w"` \| `"paraphrase-multilingual-minilm-l12-v2-quantized"` \| `"clip-vit-base-patch32-text"` \| `"lfm2-5-embedding-350m"` \| `"lfm2-5-colbert-350m"`
Defined in: [types/textEmbeddings.ts:8](https://github.com/software-mansion/react-native-executorch/blob/0e95b8934cc7318c1b30a8e534444a8b50d25230/packages/react-native-executorch/src/types/textEmbeddings.ts#L8)
diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
index 76e0fb90c7..dfd9243c48 100644
--- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
@@ -26,17 +26,15 @@ TokenizerModule::TokenizerModule(
memorySizeLowerBound = std::filesystem::file_size(modelPath);
}
-std::vector TokenizerModule::encode(std::string s) const {
+// When the tokenizer.json defines a post_processor, the underlying HFTokenizer
+// treats non-zero bos/eos as a flag to run it with add_special_token=true (not
+// a literal count). So bos=eos=0 skips special tokens; bos=eos=1 applies them.
+std::vector TokenizerModule::encodeImpl(const std::string &s,
+ int8_t bos, int8_t eos) const {
if (!tokenizer) {
THROW_NOT_LOADED_ERROR();
}
-
- // If the used tokenizer.json has defined post_processor field,
- // setting any of bos or eos arguments to value other than provided constant
- // ( which is 0) will result in running the post_processor with
- // 'add_special_token' flag
- auto encodeResult =
- tokenizer->encode(s, numOfAddedBoSTokens, numOfAddedEoSTokens);
+ auto encodeResult = tokenizer->encode(s, bos, eos);
if (!encodeResult.ok()) {
throw RnExecutorchError(
RnExecutorchErrorCode::TokenizerError,
@@ -46,6 +44,15 @@ std::vector TokenizerModule::encode(std::string s) const {
return encodeResult.get();
}
+std::vector TokenizerModule::encode(std::string s) const {
+ return encodeImpl(s, numOfAddedBoSTokens, numOfAddedEoSTokens);
+}
+
+std::vector
+TokenizerModule::encodeWithSpecialTokens(std::string s) const {
+ return encodeImpl(s, /*bos=*/1, /*eos=*/1);
+}
+
std::string TokenizerModule::decode(std::vector vec,
bool skipSpecialTokens) const {
if (!tokenizer) {
diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
index 3c90b25557..0e1356f121 100644
--- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
+++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
@@ -13,6 +13,8 @@ class TokenizerModule {
std::shared_ptr callInvoker);
[[nodiscard("Registered non-void function")]] std::vector
encode(std::string s) const;
+ [[nodiscard("Registered non-void function")]] std::vector
+ encodeWithSpecialTokens(std::string s) const;
[[nodiscard("Registered non-void function")]] std::string
decode(std::vector vec, bool skipSpecialTokens) const;
[[nodiscard("Registered non-void function")]] std::string
@@ -24,6 +26,9 @@ class TokenizerModule {
std::size_t getMemoryLowerBound() const noexcept;
private:
+ std::vector encodeImpl(const std::string &s, int8_t bos,
+ int8_t eos) const;
+
std::unique_ptr tokenizer;
std::size_t memorySizeLowerBound{0};
};
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
index e4209b2f79..fdc87cd9af 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
@@ -17,6 +17,7 @@
#include
#include
+#include
#include
#include
#include
@@ -707,6 +708,30 @@ getJsiValue(const models::style_transfer::PixelDataResult &result,
return obj;
}
+inline jsi::Value getJsiValue(const models::embeddings::EmbeddingResult &result,
+ jsi::Runtime &runtime) {
+ jsi::Object obj(runtime);
+
+ auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr);
+ auto float32ArrayCtor =
+ runtime.global().getPropertyAsFunction(runtime, "Float32Array");
+ auto float32Array = float32ArrayCtor.callAsConstructor(runtime, arrayBuffer)
+ .getObject(runtime);
+ obj.setProperty(runtime, "dataPtr", float32Array);
+
+ obj.setProperty(runtime, "numTokens", jsi::Value(result.numTokens));
+ obj.setProperty(runtime, "embeddingDim", jsi::Value(result.embeddingDim));
+
+ auto idsArray = jsi::Array(runtime, result.tokenIds.size());
+ for (size_t i = 0; i < result.tokenIds.size(); ++i) {
+ idsArray.setValueAtIndex(
+ runtime, i, jsi::Value(static_cast(result.tokenIds[i])));
+ }
+ obj.setProperty(runtime, "tokenIds", idsArray);
+
+ return obj;
+}
+
inline jsi::Value getJsiValue(
const rnexecutorch::models::semantic_segmentation::SegmentationResult
&result,
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp
deleted file mode 100644
index bf291136c1..0000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "BaseEmbeddings.h"
-
-#include
-
-namespace rnexecutorch::models::embeddings {
-
-BaseEmbeddings::BaseEmbeddings(const std::string &modelSource,
- std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker) {}
-
-std::shared_ptr
-BaseEmbeddings::postprocess(const Result> &forwardResult) {
- auto forwardResultTensor = forwardResult->at(0).toTensor();
- auto buffer = std::make_shared(
- forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes());
- return buffer;
-}
-
-} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h
deleted file mode 100644
index 216d6bf8ce..0000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h
+++ /dev/null
@@ -1,17 +0,0 @@
-#pragma once
-
-#include
-
-namespace rnexecutorch::models::embeddings {
-
-class BaseEmbeddings : public BaseModel {
-public:
- BaseEmbeddings(const std::string &modelSource,
- std::shared_ptr callInvoker);
-
-protected:
- std::shared_ptr
- postprocess(const Result> &forwardResult);
-};
-
-}; // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h
new file mode 100644
index 0000000000..f2de1e899a
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace rnexecutorch::models::embeddings {
+
+// Text embedding output as a [numTokens, embeddingDim] fp32 matrix. Pooled
+// single-vector models output numTokens == 1 (the exported graph pools + L2-
+// normalizes); multi-vector (late-interaction / ColBERT) models output
+// numTokens == sequence length. The TS layer reduces to a single vector or
+// keeps the per-token matrix based on the model's config. `tokenIds` are the
+// input ids (used JS-side for late-interaction skiplist masking).
+struct EmbeddingResult {
+ std::shared_ptr dataPtr;
+ int32_t numTokens;
+ int32_t embeddingDim;
+ std::vector tokenIds;
+};
+
+} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
index ba2c3243b2..52a10b6e40 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
@@ -11,12 +11,12 @@ using namespace executorch::extension;
TextEmbeddings::TextEmbeddings(const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr callInvoker)
- : BaseEmbeddings(modelSource, callInvoker),
+ : BaseModel(modelSource, callInvoker),
tokenizer(
std::make_unique(tokenizerSource, callInvoker)) {}
TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
- auto inputIds = tokenizer->encode(input);
+ auto inputIds = tokenizer->encodeWithSpecialTokens(input);
// Tokenizers-cpp return tokens as int32, but text embedding models require
// int64 as input
std::vector inputIds64;
@@ -40,8 +40,7 @@ void TextEmbeddings::unload() noexcept {
BaseModel::unload();
}
-std::shared_ptr
-TextEmbeddings::generate(const std::string input) {
+EmbeddingResult TextEmbeddings::generate(const std::string input) {
std::scoped_lock lock(inference_mutex_);
auto preprocessed = preprocess(input);
@@ -58,7 +57,37 @@ TextEmbeddings::generate(const std::string input) {
auto forwardResult = BaseModel::forward({tokenIds, attnMask});
CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult);
- return BaseEmbeddings::postprocess(forwardResult);
+ return buildResult(forwardResult->at(0).toTensor(),
+ std::move(preprocessed.inputIds));
+}
+
+EmbeddingResult
+TextEmbeddings::buildResult(const executorch::aten::Tensor &output,
+ std::vector tokenIds) {
+ auto sizes = output.sizes();
+ if (sizes.size() < 2) {
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput,
+ "Embedding output must be at least 2D, got rank " +
+ std::to_string(sizes.size()));
+ }
+
+ const auto numTokens = static_cast(sizes[sizes.size() - 2]);
+ const auto inputTokens = static_cast(tokenIds.size());
+ if (numTokens != 1 && numTokens != inputTokens) {
+ throw RnExecutorchError(
+ RnExecutorchErrorCode::InvalidModelOutput,
+ "Embedding output rows (" + std::to_string(numTokens) +
+ ") != input tokens (" + std::to_string(inputTokens) +
+ "); per-token tokenIds alignment is broken.");
+ }
+
+ return EmbeddingResult{
+ .dataPtr = std::make_shared(output.const_data_ptr(),
+ output.nbytes()),
+ .numTokens = numTokens,
+ .embeddingDim = static_cast(sizes[sizes.size() - 1]),
+ .tokenIds = std::move(tokenIds),
+ };
}
} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
index 93d0988c04..587f697bd4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
@@ -3,7 +3,8 @@
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
#include
#include
-#include
+#include
+#include
namespace rnexecutorch {
namespace models::embeddings {
@@ -13,13 +14,12 @@ struct TokenIdsWithAttentionMask {
std::vector attentionMask;
};
-class TextEmbeddings final : public BaseEmbeddings {
+class TextEmbeddings final : public BaseModel {
public:
TextEmbeddings(const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr callInvoker);
- [[nodiscard(
- "Registered non-void function")]] std::shared_ptr
+ [[nodiscard("Registered non-void function")]] EmbeddingResult
generate(const std::string input);
void unload() noexcept;
@@ -27,6 +27,8 @@ class TextEmbeddings final : public BaseEmbeddings {
mutable std::mutex inference_mutex_;
std::vector> inputShapes;
TokenIdsWithAttentionMask preprocess(const std::string &input);
+ static EmbeddingResult buildResult(const executorch::aten::Tensor &output,
+ std::vector tokenIds);
std::unique_ptr tokenizer;
};
} // namespace models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp
index 68a9a9fef4..3bf5fa2206 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp
@@ -16,9 +16,10 @@ Encoder::Encoder(const std::string &tokenizerSource,
encoderSource, tokenizerSource, callInvoker)) {}
std::vector Encoder::generate(std::string input) {
- std::shared_ptr embeddingsText = encoder->generate(input);
+ std::shared_ptr embeddingsText =
+ encoder->generate(input).dataPtr;
std::shared_ptr embeddingsUncond =
- encoder->generate(std::string(constants::kBosToken));
+ encoder->generate(std::string(constants::kBosToken)).dataPtr;
assert(embeddingsText->size() == embeddingsUncond->size());
size_t embeddingsSize = embeddingsText->size() / sizeof(float);
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
index 5f9d7287a5..a901cd56fc 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
@@ -218,7 +218,6 @@ add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp
add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
@@ -230,7 +229,6 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp
add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
${TOKENIZER_SOURCES}
LIBS tokenizers_deps
)
@@ -306,7 +304,6 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp
${RNEXECUTORCH_DIR}/models/text_to_image/Decoder.cpp
${RNEXECUTORCH_DIR}/models/text_to_image/Scheduler.cpp
${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
${TOKENIZER_SOURCES}
LIBS tokenizers_deps
)
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp
index ff1abd4c30..cf7d6c4804 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp
@@ -53,23 +53,23 @@ TEST(TextEmbeddingsGenerateTests, EmptyStringReturnsResults) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("");
- EXPECT_NE(result, nullptr);
- EXPECT_GT(result->size(), 0u);
+ EXPECT_NE(result.dataPtr, nullptr);
+ EXPECT_GT(result.dataPtr->size(), 0u);
}
TEST(TextEmbeddingsGenerateTests, ValidTextReturnsResults) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("Hello, world!");
- EXPECT_NE(result, nullptr);
- EXPECT_GT(result->size(), 0u);
+ EXPECT_NE(result.dataPtr, nullptr);
+ EXPECT_GT(result.dataPtr->size(), 0u);
}
TEST(TextEmbeddingsGenerateTests, ResultsHaveCorrectSize) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("This is a test sentence.");
- size_t numFloats = result->size() / sizeof(float);
+ size_t numFloats = result.dataPtr->size() / sizeof(float);
EXPECT_EQ(numFloats, kMiniLmEmbeddingDimensions);
}
@@ -78,8 +78,8 @@ TEST(TextEmbeddingsGenerateTests, ResultsAreNormalized) {
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("The quick brown fox jumps over the lazy dog.");
- const float *data = reinterpret_cast(result->data());
- size_t numFloats = result->size() / sizeof(float);
+ const float *data = reinterpret_cast(result.dataPtr->data());
+ size_t numFloats = result.dataPtr->size() / sizeof(float);
float sumOfSquares = 0.0f;
for (size_t i = 0; i < numFloats; ++i) {
@@ -94,8 +94,8 @@ TEST(TextEmbeddingsGenerateTests, ResultsContainValidValues) {
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("Testing valid values.");
- const float *data = reinterpret_cast(result->data());
- size_t numFloats = result->size() / sizeof(float);
+ const float *data = reinterpret_cast(result.dataPtr->data());
+ size_t numFloats = result.dataPtr->size() / sizeof(float);
for (size_t i = 0; i < numFloats; ++i) {
EXPECT_FALSE(std::isnan(data[i]));
@@ -110,9 +110,9 @@ TEST(TextEmbeddingsGenerateTests, DifferentTextProducesDifferentEmbeddings) {
auto result1 = model.generate("Hello, world!");
auto result2 = model.generate("Goodbye, moon!");
- const float *data1 = reinterpret_cast(result1->data());
- const float *data2 = reinterpret_cast(result2->data());
- size_t numFloats = result1->size() / sizeof(float);
+ const float *data1 = reinterpret_cast(result1.dataPtr->data());
+ const float *data2 = reinterpret_cast(result2.dataPtr->data());
+ size_t numFloats = result1.dataPtr->size() / sizeof(float);
bool allEqual = true;
for (size_t i = 0; i < numFloats; ++i) {
@@ -131,9 +131,9 @@ TEST(TextEmbeddingsGenerateTests, SimilarTextProducesSimilarEmbeddings) {
auto result1 = model.generate("I love programming");
auto result2 = model.generate("I enjoy coding");
- const float *data1 = reinterpret_cast(result1->data());
- const float *data2 = reinterpret_cast(result2->data());
- size_t numFloats = result1->size() / sizeof(float);
+ const float *data1 = reinterpret_cast(result1.dataPtr->data());
+ const float *data2 = reinterpret_cast(result2.dataPtr->data());
+ size_t numFloats = result1.dataPtr->size() / sizeof(float);
float dotProduct = 0.0f;
for (size_t i = 0; i < numFloats; ++i) {
@@ -142,6 +142,39 @@ TEST(TextEmbeddingsGenerateTests, SimilarTextProducesSimilarEmbeddings) {
EXPECT_GT(dotProduct, 0.5f);
}
+TEST(TextEmbeddingsGenerateTests, PooledResultMetadataIsConsistent) {
+ TextEmbeddings model(kValidTextEmbeddingsModelPath,
+ kValidTextEmbeddingsTokenizerPath, nullptr);
+ auto result = model.generate("A pooled embedding has a single row.");
+
+ EXPECT_EQ(result.numTokens, 1);
+ EXPECT_EQ(result.embeddingDim,
+ static_cast(kMiniLmEmbeddingDimensions));
+ EXPECT_EQ(result.dataPtr->size(),
+ static_cast(result.numTokens) * result.embeddingDim *
+ sizeof(float));
+}
+
+TEST(TextEmbeddingsGenerateTests, TokenIdsIncludeSpecialTokens) {
+ TextEmbeddings model(kValidTextEmbeddingsModelPath,
+ kValidTextEmbeddingsTokenizerPath, nullptr);
+ auto result = model.generate("Hello");
+
+ // The tokenizer post_processor wraps the input as [CLS] ... [SEP], so even a
+ // single word yields more than one token id.
+ EXPECT_GT(result.tokenIds.size(), 1u);
+}
+
+TEST(TextEmbeddingsGenerateTests, TokenIdsGrowWithInputLength) {
+ TextEmbeddings model(kValidTextEmbeddingsModelPath,
+ kValidTextEmbeddingsTokenizerPath, nullptr);
+ auto shortResult = model.generate("Hi");
+ auto longResult =
+ model.generate("This sentence is considerably longer than the other.");
+
+ EXPECT_GT(longResult.tokenIds.size(), shortResult.tokenIds.size());
+}
+
TEST(TextEmbeddingsInheritedTests, GetInputShapeWorks) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts
index 5a25c16bd6..5cbc3d981b 100644
--- a/packages/react-native-executorch/src/constants/modelRegistry.ts
+++ b/packages/react-native-executorch/src/constants/modelRegistry.ts
@@ -3,6 +3,10 @@ import { isEmulatorSync } from 'react-native-device-info';
import * as M from './modelUrls';
import * as OCR from './ocr/models';
import { symbols } from './ocr/symbols';
+import {
+ LFM_COLBERT_PROMPTS,
+ LFM_COLBERT_SKIP_LIST,
+} from './textEmbeddings/colbert';
import {
KOKORO_AMERICAN_ENGLISH_FEMALE_HEART,
KOKORO_AMERICAN_ENGLISH_FEMALE_RIVER,
@@ -260,6 +264,53 @@ const GEMMA4_E2B_MM_VARIANTS = {
},
};
+const LFM_EMBEDDING_PROMPTS = { query: 'query: ', document: 'document: ' };
+
+const LFM2_5_EMBEDDING_350M_CONFIG = {
+ modelName: 'lfm2-5-embedding-350m' as const,
+ tokenizerSource: M.LFM2_5_EMBEDDING_350M_TOKENIZER,
+ prompts: LFM_EMBEDDING_PROMPTS,
+ multiVector: false as const,
+};
+
+const LFM2_5_EMBEDDING_350M_VARIANTS = {
+ mlx: {
+ base: {
+ ...LFM2_5_EMBEDDING_350M_CONFIG,
+ modelSource: M.LFM2_5_EMBEDDING_350M_MLX_MODEL,
+ },
+ },
+ xnnpack: {
+ base: {
+ ...LFM2_5_EMBEDDING_350M_CONFIG,
+ modelSource: M.LFM2_5_EMBEDDING_350M_XNNPACK_MODEL,
+ },
+ },
+};
+
+const LFM2_5_COLBERT_350M_CONFIG = {
+ modelName: 'lfm2-5-colbert-350m' as const,
+ tokenizerSource: M.LFM2_5_COLBERT_350M_TOKENIZER,
+ prompts: LFM_COLBERT_PROMPTS,
+ multiVector: true as const,
+ skipListIds: LFM_COLBERT_SKIP_LIST,
+};
+
+const LFM2_5_COLBERT_350M_VARIANTS = {
+ mlx: {
+ base: {
+ ...LFM2_5_COLBERT_350M_CONFIG,
+ modelSource: M.LFM2_5_COLBERT_350M_MLX_MODEL,
+ },
+ },
+ xnnpack: {
+ base: {
+ ...LFM2_5_COLBERT_350M_CONFIG,
+ modelSource: M.LFM2_5_COLBERT_350M_XNNPACK_MODEL,
+ },
+ },
+};
+
const LFM2_5_350M_VARIANTS = {
mlx: { base: { ...M.LFM2_5_350M, modelSource: M.LFM2_5_350M_MLX_MODEL } },
xnnpack: { base: M.LFM2_5_350M, quant: M.LFM2_5_350M_QUANTIZED },
@@ -799,6 +850,14 @@ export const models = {
M.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED
),
clip_vit_base_patch32_text: base(M.CLIP_VIT_BASE_PATCH32_TEXT),
+ lfm2_5_embedding_350m: variant(LFM2_5_EMBEDDING_350M_VARIANTS, {
+ ios: 'mlx',
+ android: 'xnnpack',
+ }),
+ lfm2_5_colbert_350m: variant(LFM2_5_COLBERT_350M_VARIANTS, {
+ ios: 'mlx',
+ android: 'xnnpack',
+ }),
},
image_embedding: {
clip_vit_base_patch32_image: pair(
diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts
index c74fb5a289..8429155dd2 100644
--- a/packages/react-native-executorch/src/constants/modelUrls.ts
+++ b/packages/react-native-executorch/src/constants/modelUrls.ts
@@ -1199,6 +1199,12 @@ export const DISTILUSE_BASE_MULTILINGUAL_CASED_V2_8DA4W_MODEL = `${URL_PREFIX}-d
export const DISTILUSE_BASE_MULTILINGUAL_CASED_V2_TOKENIZER = `${URL_PREFIX}-distiluse-base-multilingual-cased-v2/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
const PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED_MODEL = `${URL_PREFIX}-paraphrase-multilingual-MiniLM-L12-v2/${PREVIOUS_VERSION_TAG}/xnnpack/paraphrase_multilingual_minilm_l12_v2_xnnpack_8da4w.pte`;
const PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_TOKENIZER = `${URL_PREFIX}-paraphrase-multilingual-MiniLM-L12-v2/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
+export const LFM2_5_EMBEDDING_350M_XNNPACK_MODEL = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/xnnpack/lfm_2_5_embedding_350m_xnnpack_8da4w.pte`;
+export const LFM2_5_EMBEDDING_350M_MLX_MODEL = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/mlx/lfm_2_5_embedding_350m_mlx_int4.pte`;
+export const LFM2_5_EMBEDDING_350M_TOKENIZER = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
+export const LFM2_5_COLBERT_350M_XNNPACK_MODEL = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/xnnpack/lfm_2_5_colbert_350m_xnnpack_8da4w.pte`;
+export const LFM2_5_COLBERT_350M_MLX_MODEL = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/mlx/lfm_2_5_colbert_350m_mlx_int4.pte`;
+export const LFM2_5_COLBERT_350M_TOKENIZER = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
const CLIP_VIT_BASE_PATCH32_TEXT_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${PREVIOUS_VERSION_TAG}/xnnpack/clip_vit_base_patch32_text_xnnpack_fp32.pte`;
const CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER = `${URL_PREFIX}-clip-vit-base-patch32/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
diff --git a/packages/react-native-executorch/src/constants/textEmbeddings/colbert.ts b/packages/react-native-executorch/src/constants/textEmbeddings/colbert.ts
new file mode 100644
index 0000000000..9f60f5d87b
--- /dev/null
+++ b/packages/react-native-executorch/src/constants/textEmbeddings/colbert.ts
@@ -0,0 +1,7 @@
+export const LFM_COLBERT_SKIP_LIST = [
+ 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524,
+ 535, 536, 537, 538, 539, 540, 541, 568, 569, 570, 571, 572, 573, 600, 601,
+ 602, 603,
+];
+
+export const LFM_COLBERT_PROMPTS = { query: '[Q] ', document: '[D] ' };
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
index 31ee179925..9e3fa7f0e4 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
@@ -1,6 +1,9 @@
import { TextEmbeddingsModule } from '../../modules/natural_language_processing/TextEmbeddingsModule';
import { useModuleFactory } from '../useModuleFactory';
import {
+ EmbeddingRole,
+ ForwardFn,
+ TextEmbeddingsModel,
TextEmbeddingsType,
TextEmbeddingsProps,
} from '../../types/textEmbeddings';
@@ -9,12 +12,15 @@ import {
* React hook for managing a Text Embeddings model instance.
* @category Hooks
* @param TextEmbeddingsProps - Configuration object containing `model` source and optional `preventLoad` flag.
- * @returns Ready to use Text Embeddings model.
+ * @returns Ready to use Text Embeddings model. `forward` returns a
+ * `Float32Array` for pooled models and an `EmbeddingResult` (per-token
+ * vectors) for multi-vector models. Models with prompts require a `role`
+ * ('query' | 'document') on `forward`.
*/
-export const useTextEmbeddings = ({
+export const useTextEmbeddings = ({
model,
preventLoad = false,
-}: TextEmbeddingsProps): TextEmbeddingsType => {
+}: TextEmbeddingsProps): TextEmbeddingsType => {
const { error, isReady, isGenerating, downloadProgress, runForward } =
useModuleFactory({
factory: (config, onProgress) =>
@@ -24,7 +30,8 @@ export const useTextEmbeddings = ({
preventLoad,
});
- const forward = (input: string) => runForward((inst) => inst.forward(input));
+ const forward = ((input: string, role?: EmbeddingRole) =>
+ runForward((inst) => inst.forward(input, role))) as ForwardFn;
return { error, isReady, isGenerating, downloadProgress, forward };
};
diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts
index 1f190d41f5..ba7ac384f9 100644
--- a/packages/react-native-executorch/src/index.ts
+++ b/packages/react-native-executorch/src/index.ts
@@ -215,6 +215,7 @@ export * from './utils/llm';
export * from './common/Logger';
export * from './utils/llms/context_strategy';
export * from './utils/segmentAnythingPrompts';
+export * from './utils/textEmbeddings';
// types
export * from './types/objectDetection';
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
index 27b0e59ceb..b9e2e866d1 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
@@ -1,5 +1,11 @@
import { ResourceSource } from '../../types/common';
-import { TextEmbeddingsModelName } from '../../types/textEmbeddings';
+import {
+ EmbeddingPrompts,
+ EmbeddingResult,
+ EmbeddingRole,
+ TextEmbeddingsModel,
+ TextEmbeddingsModelName,
+} from '../../types/textEmbeddings';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { BaseModule } from '../BaseModule';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
@@ -7,27 +13,34 @@ import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';
/**
- * Module for generating text embeddings from input text.
+ * Module for managing a Text Embeddings model instance.
* @category Typescript API
*/
export class TextEmbeddingsModule extends BaseModule {
- private constructor(nativeModule: unknown) {
+ private prompts?: EmbeddingPrompts;
+ private multiVector: boolean;
+
+ private constructor(
+ nativeModule: unknown,
+ prompts: EmbeddingPrompts | undefined,
+ multiVector: boolean
+ ) {
super();
this.nativeModule = nativeModule;
+ this.prompts = prompts;
+ this.multiVector = multiVector;
}
/**
* Creates a text embeddings instance for a built-in model.
- * @param namedSources - An object specifying which built-in model to load and where to fetch it from.
- * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
+ * @param namedSources - An object specifying the model name, model source,
+ * tokenizer source, and optional `prompts` / `multiVector` / `skipListIds`.
+ * @param onDownloadProgress - Optional callback to monitor download progress,
+ * receiving a value between 0 and 1.
* @returns A Promise resolving to a `TextEmbeddingsModule` instance.
*/
static async fromModelName(
- namedSources: {
- modelName: TextEmbeddingsModelName;
- modelSource: ResourceSource;
- tokenizerSource: ResourceSource;
- },
+ namedSources: TextEmbeddingsModel,
onDownloadProgress: (progress: number) => void = () => {}
): Promise {
try {
@@ -41,7 +54,9 @@ export class TextEmbeddingsModule extends BaseModule {
throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted);
}
return new TextEmbeddingsModule(
- await global.loadTextEmbeddings(modelPath, tokenizerPath)
+ await global.loadTextEmbeddings(modelPath, tokenizerPath),
+ namedSources.prompts,
+ namedSources.multiVector ?? false
);
} catch (error) {
Logger.error('Load failed:', error);
@@ -50,10 +65,10 @@ export class TextEmbeddingsModule extends BaseModule {
}
/**
- * Creates a text embeddings instance with a user-provided model binary and tokenizer.
- * Use this when working with a custom-exported model that is not one of the built-in presets.
- * @remarks The native model contract for this method is not formally defined and may change
- * between releases. Refer to the native source code for the current expected tensor interface.
+ * Creates a text embeddings instance with a user-provided model binary.
+ * Use this when working with a custom-exported embeddings model. Internally
+ * uses `'custom'` as the model name. Note that prompts, multi-vector output,
+ * and skipLists are model-config features and are not configured here.
* @param modelSource - A fetchable resource pointing to the model binary.
* @param tokenizerSource - A fetchable resource pointing to the tokenizer file.
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
@@ -75,13 +90,33 @@ export class TextEmbeddingsModule extends BaseModule {
}
/**
- * Executes the model's forward pass to generate an embedding for the provided text.
- * @param input - The text string to embed.
- * @returns A Promise resolving to a `Float32Array` containing the embedding vector.
+ * Embed text into a pooled `Float32Array`, or a per-token `EmbeddingResult`
+ * for `multiVector` models.
+ * @param input - The text to embed.
+ * @param role - Role ('query' | 'document') for models with asymmetric
+ * prompts; the matching prompt is prepended. The `useTextEmbeddings` types
+ * require it for prompted models and omit it for the rest; at the module
+ * level it is optional and a no-op when the model has no prompts.
+ * @returns A `Float32Array` for pooled models, an `EmbeddingResult` otherwise.
+ * @throws {RnExecutorchError} If the model is not loaded.
*/
- async forward(input: string): Promise {
+ async forward(
+ input: string,
+ role?: EmbeddingRole
+ ): Promise {
if (this.nativeModule == null)
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded);
- return new Float32Array(await this.nativeModule.generate(input));
+ const prefix = (role && this.prompts?.[role]) || '';
+ const res = await this.nativeModule.generate(prefix + input);
+ const vectors = res.dataPtr as Float32Array;
+ if (!this.multiVector) {
+ return vectors.subarray(0, res.embeddingDim);
+ }
+ return {
+ vectors,
+ numTokens: res.numTokens,
+ embeddingDim: res.embeddingDim,
+ tokenIds: res.tokenIds,
+ };
}
}
diff --git a/packages/react-native-executorch/src/types/textEmbeddings.ts b/packages/react-native-executorch/src/types/textEmbeddings.ts
index d9cd120e26..9f24cbcf1d 100644
--- a/packages/react-native-executorch/src/types/textEmbeddings.ts
+++ b/packages/react-native-executorch/src/types/textEmbeddings.ts
@@ -12,7 +12,90 @@ export type TextEmbeddingsModelName =
| 'multi-qa-mpnet-base-dot-v1'
| 'distiluse-base-multilingual-cased-v2-8da4w'
| 'paraphrase-multilingual-minilm-l12-v2-quantized'
- | 'clip-vit-base-patch32-text';
+ | 'clip-vit-base-patch32-text'
+ | 'lfm2-5-embedding-350m'
+ | 'lfm2-5-colbert-350m';
+
+/**
+ * Per-token (multi-vector) embedding output for late-interaction models (e.g.
+ * ColBERT). Only `multiVector` models yield this; standard models return a
+ * pooled `Float32Array` from `forward` instead.
+ * @category Types
+ */
+export interface EmbeddingResult {
+ /** Flat [numTokens * embeddingDim] fp32 vectors (row-major). */
+ vectors: Float32Array;
+ /** Number of token rows. */
+ numTokens: number;
+ /** Per-token vector dimension. */
+ embeddingDim: number;
+ /** Input token ids per row. */
+ tokenIds: number[];
+}
+
+/**
+ * Role for `forward`. Some models are trained with asymmetric query/document
+ * prompts (e.g. LFM2.5 uses `query: `/`document: `, ColBERT uses `[Q] `/`[D] `).
+ * Passing a role auto-prepends the model's configured prompt for that role.
+ * @category Types
+ */
+export type EmbeddingRole = 'query' | 'document';
+
+/**
+ * Asymmetric prompts a model is trained with. When a model config carries
+ * these, `forward` requires a `role` so the matching prompt is always applied.
+ * @category Types
+ */
+export interface EmbeddingPrompts {
+ query: string;
+ document: string;
+}
+
+/**
+ * A text embeddings model config. Two optional flags drive `forward`:
+ * `prompts` makes a `role` argument required, and `multiVector` makes it return
+ * a per-token `EmbeddingResult` instead of a pooled `Float32Array`.
+ * @category Types
+ */
+export interface TextEmbeddingsModel {
+ modelName: TextEmbeddingsModelName;
+ modelSource: ResourceSource;
+ tokenizerSource: ResourceSource;
+ prompts?: EmbeddingPrompts;
+ multiVector?: boolean;
+ /**
+ * Document token ids to exclude from late-interaction scoring (e.g. ColBERT's
+ * punctuation skipList). Derived from the model's training config, so it's
+ * shipped here rather than reconstructed by the consumer, who passes it to
+ * their own MaxSim scoring.
+ */
+ skipListIds?: number[];
+}
+
+/**
+ * `forward`'s return type: `EmbeddingResult` for `multiVector` models,
+ * `Float32Array` otherwise.
+ */
+export type ForwardReturn = M extends {
+ multiVector: true;
+}
+ ? EmbeddingResult
+ : Float32Array;
+
+/**
+ * `forward`'s signature, computed from the model config: `role` is required
+ * when the model has `prompts`, omitted when it has none, and optional when
+ * unknown (e.g. a heterogeneous model list).
+ */
+export type ForwardFn = M extends {
+ prompts: EmbeddingPrompts;
+}
+ ? (input: string, role: EmbeddingRole) => Promise>
+ : undefined extends M['prompts']
+ ? M['prompts'] extends undefined
+ ? (input: string) => Promise>
+ : (input: string, role?: EmbeddingRole) => Promise>
+ : (input: string) => Promise>;
/**
* Props for the useTextEmbeddings hook.
@@ -21,56 +104,47 @@ export type TextEmbeddingsModelName =
* @property {TextEmbeddingsModelName} model.modelName - Unique name identifying the model.
* @property {ResourceSource} model.modelSource - The source of the text embeddings model binary.
* @property {ResourceSource} model.tokenizerSource - The source of the tokenizer JSON file.
+ * @property {EmbeddingPrompts} [model.prompts] - Optional asymmetric prompts for query/document roles.
+ * @property {boolean} [model.multiVector] - Optional flag indicating if the model returns per-token embeddings.
+ * @property {number[]} [model.skipListIds] - Optional array of token IDs to skip during scoring.
* @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook.
*/
-export interface TextEmbeddingsProps {
- model: {
- /**
- * The unique name of the text embeddings model.
- */
- modelName: TextEmbeddingsModelName;
- /**
- * The source of the text embeddings model binary.
- */
- modelSource: ResourceSource;
- /**
- * The source of the tokenizer JSON file.
- */
- tokenizerSource: ResourceSource;
- };
+export interface TextEmbeddingsProps<
+ M extends TextEmbeddingsModel = TextEmbeddingsModel,
+> {
+ model: M;
preventLoad?: boolean;
}
/**
- * React hook state and methods for managing a Text Embeddings model instance.
+ * React hook state and methods for a Text Embeddings model instance.
* @category Types
*/
-export interface TextEmbeddingsType {
+export interface TextEmbeddingsType<
+ M extends TextEmbeddingsModel = TextEmbeddingsModel,
+> {
/**
* Contains the error message if the model failed to load or during inference.
*/
error: null | RnExecutorchError;
-
/**
* Indicates whether the embeddings model has successfully loaded and is ready for inference.
*/
isReady: boolean;
-
/**
* Indicates whether the model is currently generating embeddings.
*/
isGenerating: boolean;
-
/**
* Tracks the progress of the model download process (value between 0 and 1).
*/
downloadProgress: number;
-
/**
* Runs the text embeddings model on the provided input string.
* @param input - The text string to embed.
- * @returns A promise resolving to a Float32Array containing the vector embeddings.
+ * @param role - Optional role for models with asymmetric prompts. Required if the model has `prompts`.
+ * @returns A promise resolving to a Float32Array or EmbeddingResult containing the vector embeddings.
* @throws {RnExecutorchError} If the model is not loaded or is currently processing another request.
*/
- forward(input: string): Promise;
+ forward: ForwardFn;
}
diff --git a/packages/react-native-executorch/src/utils/textEmbeddings.ts b/packages/react-native-executorch/src/utils/textEmbeddings.ts
new file mode 100644
index 0000000000..1dd241661c
--- /dev/null
+++ b/packages/react-native-executorch/src/utils/textEmbeddings.ts
@@ -0,0 +1,43 @@
+import { EmbeddingResult } from '../types/textEmbeddings';
+import { RnExecutorchError } from '../errors/errorUtils';
+import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
+
+export const dotProduct = (a: Float32Array, b: Float32Array) => {
+ if (a.length !== b.length) {
+ throw new RnExecutorchError(
+ RnExecutorchErrorCode.WrongDimensions,
+ `dotProduct needs both vector to have the same length: got a: ${a.length}, b: ${b.length}`
+ );
+ }
+
+ let sum = 0;
+ for (let i = 0; i < a.length; i++) {
+ sum += (a[i] ?? 0) * (b[i] ?? 0);
+ }
+ return sum;
+};
+
+export const maxSim = (
+ query: EmbeddingResult,
+ doc: EmbeddingResult,
+ skipListIds: number[] = []
+) => {
+ const dim = query.embeddingDim;
+ const skip = new Set(skipListIds);
+ let score = 0;
+ for (let qi = 0; qi < query.numTokens; qi++) {
+ const qOff = qi * dim;
+ let best = -Infinity;
+ for (let di = 0; di < doc.numTokens; di++) {
+ if (skip.has(doc.tokenIds[di]!)) continue;
+ const dOff = di * dim;
+ let dot = 0;
+ for (let k = 0; k < dim; k++) {
+ dot += (query.vectors[qOff + k] ?? 0) * (doc.vectors[dOff + k] ?? 0);
+ }
+ if (dot > best) best = dot;
+ }
+ if (best !== -Infinity) score += best;
+ }
+ return score;
+};