107 changes: 107 additions & 0 deletions js/ai/src/generate.ts
@@ -49,18 +49,22 @@ import {
import { GenerateResponse } from './generate/response.js';
import { Message } from './message.js';
import {
GenerateRequestSchema,
GenerateResponseChunkData,
GenerateResponseData,
GenerationUsageSchema,
ResolvedModel,
resolveModel,
type GenerateActionOptions,
type GenerateRequest,
type GenerationCommonConfigSchema,
type GenerationUsage,
type MessageData,
type MiddlewareRef,
type ModelArgument,
type ModelMiddlewareArgument,
type Part,
type TokenCounterAction,
type ToolRequestPart,
type ToolResponsePart,
} from './model.js';
@@ -796,3 +800,106 @@ export function tagAsPreamble(msgs?: MessageData[]): MessageData[] | undefined {
},
}));
}

/**
* Counts the tokens for a given generate request.
*/
export async function countTokens<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
registry: Registry,
options:
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
): Promise<GenerationUsage> {
const resolvedOptions: GenerateOptions<O, CustomOptions> = {
...(await Promise.resolve(options)),
};

const childRegistry = Registry.withParent(registry);

maybeRegisterDynamicTools(childRegistry, resolvedOptions);
maybeRegisterDynamicResources(childRegistry, resolvedOptions);

const resolvedModel = await resolveModel(
childRegistry,
resolvedOptions.model
);

const tools = await toolsToActionRefs(childRegistry, resolvedOptions.tools);
const resources = await resourcesToActionRefs(
childRegistry,
resolvedOptions.resources
);

const request = await toGenerateRequest(childRegistry, {
...resolvedOptions,
tools,
resources,
});

request.config = {
...(resolvedModel?.version ? { version: resolvedModel.version } : {}),
...stripUndefinedOptions(resolvedModel?.config),
...stripUndefinedOptions(request.config),
};
if (Object.keys(request.config || {}).length === 0) {
delete request.config;
}

const middlewareRefs = await normalizeMiddleware(
childRegistry,
resolvedOptions.use
);
const resolvedMiddleware = await resolveMiddleware(
childRegistry,
middlewareRefs
);
maybeRegisterDynamicMiddlewareTools(childRegistry, resolvedMiddleware);

let interceptedRequest = request;
if (resolvedMiddleware && resolvedMiddleware.length > 0) {
Member: why? there's no middleware for counters...

Collaborator (author): ai.countTokens takes a fully formed request. Middleware can modify the request. Example: when the model itself doesn't support X, the middleware modifies the request into a format the model does support, like downloading an image from an unsupported URL and including the image in the media instead of the URL only. That's a large difference in tokens used. If we countTokens on the unmodified request, it will include the length of the URL. If we countTokens on the post-middleware request, it will include the length of the base64 string.

Member: It is true that middleware can modify the request, but middleware can also do a lot of other things (even with potential side effects). I'm worried that it may not be obvious to the user and can lead to undesired side effects. I think that countTokens should not take a GenerateRequest but just a list of Messages.

Collaborator (author): That leads to inaccuracies in token counting, though, and potentially large ones. The details of the request, including middleware and everything else that happens to it, are important if we want accurate token counting. And there's a big difference between a URL string and a downloaded media data URL.

Collaborator (author): Also, I suspect our users will often want something like:

const request = {model, config, ...};
const tokens = await ai.countTokens(request);
if (tokens.totalTokens < MAX_TOKENS_ALLOWED) {
  const response = await ai.generate(request);
  // ...
}

and it would probably be bad if tokens.totalTokens was thousands of tokens off the actual answer.
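The URL-downloading middleware discussed in this thread could be sketched roughly as follows. This is a self-contained illustration, not Genkit's real API: the `SimpleRequest`/`MediaPart` shapes and the `fetchAsDataUrl` helper are stand-ins invented for the example.

```typescript
// Minimal stand-ins for the message/part shapes in this thread (illustrative,
// not Genkit's actual types).
interface MediaPart {
  media?: { url: string };
  text?: string;
}
interface SimpleRequest {
  messages: { role: string; content: MediaPart[] }[];
}

// Hypothetical downloader: fetch a remote image and inline it as a data URL.
async function fetchAsDataUrl(url: string): Promise<string> {
  const res = await fetch(url);
  const bytes = new Uint8Array(await res.arrayBuffer());
  let binary = '';
  for (const b of bytes) binary += String.fromCharCode(b);
  const type = res.headers.get('content-type') ?? 'application/octet-stream';
  return `data:${type};base64,${btoa(binary)}`;
}

// Middleware-style wrapper: rewrite http(s) media URLs to inline data URLs,
// then hand the mutated request to the next handler. This is the kind of
// mutation countTokens needs to observe to stay accurate.
async function inlineRemoteMedia<T>(
  req: SimpleRequest,
  next: (modified: SimpleRequest) => Promise<T>
): Promise<T> {
  for (const msg of req.messages) {
    for (const part of msg.content) {
      if (part.media && /^https?:/.test(part.media.url)) {
        part.media.url = await fetchAsDataUrl(part.media.url);
      }
    }
  }
  return next(req);
}
```

Counting tokens on the request before this middleware runs would measure the short URL string; counting after measures the full base64 payload, which is the point being argued above.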

const dispatchModel = async (
index: number,
req: GenerateRequest
): Promise<GenerateResponseData> => {
if (index === resolvedMiddleware.length) {
interceptedRequest = req;
// Return a dummy response to safely unwind the middleware chain without
// executing the actual model generation, since we only want to intercept the mutated request.
return {
message: { role: 'model', content: [] },
finishReason: 'stop',
};
}
const currentMiddleware = resolvedMiddleware[index];
if (currentMiddleware.model) {
return currentMiddleware.model(req, {}, async (modifiedReq) =>
dispatchModel(index + 1, modifiedReq || req)
);
} else {
return dispatchModel(index + 1, req);
}
};
await dispatchModel(0, request);
Member: why call the model?

Collaborator (author): That's how we get the modified request. We essentially apply all the middleware to the request, then return a dummy response to safely unwind the middleware chain without executing the model, and then we have the mutated request that we can give to countTokens.
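Stripped of the Genkit specifics, the dummy-response unwind described above can be sketched as a self-contained function (the names and the string-based request type here are purely illustrative):

```typescript
// A middleware takes a request and a `next` continuation, and may hand `next`
// a modified request.
type Middleware = (
  req: string,
  next: (r: string) => Promise<string>
) => Promise<string>;

// Run the whole chain with a dummy terminal handler: instead of calling a
// real model at the end, capture the final (fully mutated) request and return
// a placeholder so the chain unwinds safely.
async function captureFinalRequest(
  middleware: Middleware[],
  req: string
): Promise<string> {
  let final = req;
  const dispatch = async (i: number, r: string): Promise<string> => {
    if (i === middleware.length) {
      final = r; // this is what countTokens would consume
      return 'dummy-response'; // unwind without doing real work
    }
    return middleware[i](r, (nextReq) => dispatch(i + 1, nextReq));
  };
  await dispatch(0, req);
  return final;
}
```

The dummy return value is discarded; only the captured request matters, which mirrors how the patch feeds `interceptedRequest` to the token counter.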

}

const counterActionName = `/model-token-counter/${resolvedModel.modelAction.__action.name}`;
const counterAction =
resolvedModel.modelAction.__tokenCounterAction ||
Member: in what case is resolvedModel.modelAction.__tokenCounterAction not set?

Collaborator (author) (Apr 16, 2026): It's not set if countTokens was not defined during the model definition (not all models have token-counting endpoints in the Gemini API, for example). Theoretically, if someone wanted a token counter that is not tied to a model, they could add one using something like:

const genericTokenCounter = action(
  {
    actionType: 'model-token-counter',
    name: 'generic-counter',
    inputSchema: GenerateRequestSchema,
    outputSchema: GenerationUsageSchema,
  },
  async (request) => {
    return {
      totalTokens: 42,
    };
  }
);

and then we would still find it using this lookup. Unlikely, but it keeps the possibility open. Also, the error for not finding the countTokens action will be the standard "action not found".

(await childRegistry.lookupAction<
typeof GenerateRequestSchema,
typeof GenerationUsageSchema,
TokenCounterAction<CustomOptions>
>(counterActionName));

if (!counterAction) {
throw new GenkitError({
status: 'NOT_FOUND',
message: `Model '${resolvedModel.modelAction.__action.name}' does not support token counting (model-token-counter action not found).`,
});
}

return await counterAction(interceptedRequest);
}
90 changes: 89 additions & 1 deletion js/ai/src/genkit-ai.ts
@@ -34,14 +34,19 @@ import {
type EmbeddingBatch,
} from './embedder.js';
import {
countTokens,
generate,
generateStream,
type GenerateOptions,
type GenerateResponse,
type GenerateStreamOptions,
type GenerateStreamResponse,
} from './generate.js';
import { GenerationCommonConfigSchema, type Part } from './model-types.js';
import {
GenerationCommonConfigSchema,
type GenerationUsage,
type Part,
} from './model-types.js';

/**
* `GenkitAI` encapsulates Genkit's AI APIs.
@@ -263,6 +268,89 @@ export class GenkitAI {
return generateStream(this.registry, options);
}

/**
* Make a countTokens call to the default model with a simple text prompt.
*
* ```ts
* const ai = genkit({
* plugins: [googleAI()],
* model: googleAI.model('gemini-flash-latest'), // default model
* })
*
* const usage = await ai.countTokens('hi');
* ```
*/
countTokens(strPrompt: string): Promise<GenerationUsage>;

/**
* Make a countTokens call to the default model with a multipart request.
*
* ```ts
* const ai = genkit({
* plugins: [googleAI()],
* model: googleAI.model('gemini-flash-latest'), // default model
* })
*
* const usage = await ai.countTokens([
* { media: {url: 'http://....'} },
* { text: 'describe this image' }
* ]);
* ```
*/
countTokens(parts: Part[]): Promise<GenerationUsage>;

/**
* Calculates the token usage of a generative model based on the provided prompt and configuration.
*
* See {@link GenerateOptions} for detailed information about available options.
*
* ```ts
* const ai = genkit({
* plugins: [googleAI()],
* })
*
* const usage = await ai.countTokens({
* system: 'talk like a pirate',
* prompt: [
* { media: { url: 'http://....' } },
* { text: 'describe this image' }
* ],
* messages: conversationHistory,
* tools: [ userInfoLookup ],
* model: googleAI.model('gemini-flash-latest'),
* });
* ```
*/
countTokens<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
opts:
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
): Promise<GenerationUsage>;

async countTokens<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
options:
| string
| Part[]
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
): Promise<GenerationUsage> {
if (typeof options === 'string' || Array.isArray(options)) {
options = { prompt: options };
}
return countTokens(
this.registry,
options as
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
);
}

/**
* Checks the status of a given operation. Returns a new operation which will contain the updated status.
*
5 changes: 5 additions & 0 deletions js/ai/src/index.ts
@@ -44,6 +44,7 @@ export {
GenerateResponseChunk,
GenerationBlockedError,
GenerationResponseError,
countTokens,
generate,
generateOperation,
generateStream,
@@ -75,8 +76,10 @@ export {
ModelResponseSchema,
PartSchema,
RoleSchema,
isModelAction,
modelActionMetadata,
modelRef,
registerModelAction,
type GenerateRequest,
type GenerateRequestData,
type GenerateResponseChunkData,
@@ -91,6 +94,8 @@ export {
type ModelResponseData,
type Part,
type Role,
type TokenCounterAction,
type TokenCounterMiddleware,
type ToolRequestPart,
type ToolResponsePart,
} from './model.js';
2 changes: 2 additions & 0 deletions js/ai/src/model-types.ts
@@ -266,7 +266,9 @@ export const GenerationUsageSchema = z.object({
outputVideos: z.number().optional(),
inputAudioFiles: z.number().optional(),
outputAudioFiles: z.number().optional(),
/** @deprecated use `raw` instead */
custom: z.record(z.number()).optional(),
raw: z.unknown().optional(),
thoughtsTokens: z.number().optional(),
cachedContentTokens: z.number().optional(),
});
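The schema change above deprecates the numeric `custom` map in favor of an opaque `raw` field. A consumer migrating between the two might read a provider counter like this; note that the `UsageLike` subset and any field names inside `raw` are illustrative guesses, since `raw` is provider-dependent by design.

```typescript
// Hand-written subset of the usage shape relevant to this change
// (illustrative; not the generated Genkit type).
interface UsageLike {
  totalTokens?: number;
  /** @deprecated use `raw` instead */
  custom?: Record<string, number>;
  raw?: unknown; // provider-specific payload, shape not guaranteed
}

// Read a provider-specific counter, preferring the new `raw` payload and
// falling back to the deprecated `custom` map when `raw` lacks the key.
function providerCounter(usage: UsageLike, key: string): number | undefined {
  const raw = usage.raw as Record<string, unknown> | undefined;
  const fromRaw =
    raw && typeof raw[key] === 'number' ? (raw[key] as number) : undefined;
  return fromRaw ?? usage.custom?.[key];
}
```

Keeping the fallback lets callers migrate without breaking against providers that still populate only `custom`.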