-
Notifications
You must be signed in to change notification settings - Fork 710
feat(js): ai.countTokens support #5116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,18 +49,22 @@ import { | |
| import { GenerateResponse } from './generate/response.js'; | ||
| import { Message } from './message.js'; | ||
| import { | ||
| GenerateRequestSchema, | ||
| GenerateResponseChunkData, | ||
| GenerateResponseData, | ||
| GenerationUsageSchema, | ||
| ResolvedModel, | ||
| resolveModel, | ||
| type GenerateActionOptions, | ||
| type GenerateRequest, | ||
| type GenerationCommonConfigSchema, | ||
| type GenerationUsage, | ||
| type MessageData, | ||
| type MiddlewareRef, | ||
| type ModelArgument, | ||
| type ModelMiddlewareArgument, | ||
| type Part, | ||
| type TokenCounterAction, | ||
| type ToolRequestPart, | ||
| type ToolResponsePart, | ||
| } from './model.js'; | ||
|
|
@@ -796,3 +800,106 @@ export function tagAsPreamble(msgs?: MessageData[]): MessageData[] | undefined { | |
| }, | ||
| })); | ||
| } | ||
|
|
||
| /** | ||
| * Counts the tokens for a given generate request. | ||
| */ | ||
| export async function countTokens< | ||
| O extends z.ZodTypeAny = z.ZodTypeAny, | ||
| CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, | ||
| >( | ||
| registry: Registry, | ||
| options: | ||
| | GenerateOptions<O, CustomOptions> | ||
| | PromiseLike<GenerateOptions<O, CustomOptions>> | ||
| ): Promise<GenerationUsage> { | ||
| const resolvedOptions: GenerateOptions<O, CustomOptions> = { | ||
| ...(await Promise.resolve(options)), | ||
| }; | ||
|
|
||
| const childRegistry = Registry.withParent(registry); | ||
|
|
||
| maybeRegisterDynamicTools(childRegistry, resolvedOptions); | ||
| maybeRegisterDynamicResources(childRegistry, resolvedOptions); | ||
|
|
||
| const resolvedModel = await resolveModel( | ||
| childRegistry, | ||
| resolvedOptions.model | ||
| ); | ||
|
|
||
| const tools = await toolsToActionRefs(childRegistry, resolvedOptions.tools); | ||
| const resources = await resourcesToActionRefs( | ||
| childRegistry, | ||
| resolvedOptions.resources | ||
| ); | ||
|
|
||
| const request = await toGenerateRequest(childRegistry, { | ||
| ...resolvedOptions, | ||
| tools, | ||
| resources, | ||
| }); | ||
|
|
||
| request.config = { | ||
| ...(resolvedModel?.version ? { version: resolvedModel.version } : {}), | ||
| ...stripUndefinedOptions(resolvedModel?.config), | ||
| ...stripUndefinedOptions(request.config), | ||
| }; | ||
| if (Object.keys(request.config || {}).length === 0) { | ||
| delete request.config; | ||
| } | ||
|
|
||
| const middlewareRefs = await normalizeMiddleware( | ||
| childRegistry, | ||
| resolvedOptions.use | ||
| ); | ||
| const resolvedMiddleware = await resolveMiddleware( | ||
| childRegistry, | ||
| middlewareRefs | ||
| ); | ||
| maybeRegisterDynamicMiddlewareTools(childRegistry, resolvedMiddleware); | ||
|
|
||
| let interceptedRequest = request; | ||
| if (resolvedMiddleware && resolvedMiddleware.length > 0) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why? there's no middleware for counters...
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ai.countTokens takes a fully formed request. Middleware can modify the request. Example: when the model itself doesn't support X so the middleware modifies the request to a format the model does support - Like downloading an image from an unsupported url and including the image in the media instead of the url only. That's a large difference in tokens used. If we countTokens on the unmodified request it will include the length of the url. If we countTokens on the post-middleware request it will include the length of the base64 string...
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that is true that middleware can modify the request, but also middleware can do a lot of other things (even with potential side-effects)... I'm worried that it may not be obvious to the user and can lead to undesired side-effects. I think that countTokens should not take a GenerateRequest but just a list of Messages.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That leads to inaccuracies in token counting though... and potentially large ones. The details of the request including middleware and everything else that happens to it is important if we want accurate token counting. And there's a BIG difference between a url string and a downloaded media dataurl...
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also I suspect our users will often want something like: and it would probably be bad if tokens.totalTokens was thousands of tokens off the actual answer... |
||
| const dispatchModel = async ( | ||
| index: number, | ||
| req: GenerateRequest | ||
| ): Promise<GenerateResponseData> => { | ||
| if (index === resolvedMiddleware.length) { | ||
| interceptedRequest = req; | ||
| // Return a dummy response to safely unwind the middleware chain without | ||
| // executing the actual model generation, since we only want to intercept the mutated request. | ||
| return { | ||
| message: { role: 'model', content: [] }, | ||
| finishReason: 'stop', | ||
| }; | ||
| } | ||
| const currentMiddleware = resolvedMiddleware[index]; | ||
| if (currentMiddleware.model) { | ||
| return currentMiddleware.model(req, {}, async (modifiedReq) => | ||
|
ifielker marked this conversation as resolved.
|
||
| dispatchModel(index + 1, modifiedReq || req) | ||
| ); | ||
| } else { | ||
| return dispatchModel(index + 1, req); | ||
| } | ||
| }; | ||
| await dispatchModel(0, request); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why call the model?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's how we get the modified request. We essentially apply all the middleware to the request and then return a dummy response to safely unwind the middleware chain without executing the model and then we have the mutated request that we can give to countTokens. |
||
| } | ||
|
|
||
| const counterActionName = `/model-token-counter/${resolvedModel.modelAction.__action.name}`; | ||
| const counterAction = | ||
| resolvedModel.modelAction.__tokenCounterAction || | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in what case is
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not set if the countTokens was not defined during the model definition (not all models have token counting endpoints in Gemini API for example. Theoretically if someone wanted a tokenCounter that is not tied to a model, they could add one using something like: and then we would still find it using this search. Unlikely, but it keeps the possibility open. Also the error for not finding the countTokens action will be the standard "action not found". |
||
| (await childRegistry.lookupAction< | ||
| typeof GenerateRequestSchema, | ||
| typeof GenerationUsageSchema, | ||
| TokenCounterAction<CustomOptions> | ||
| >(counterActionName)); | ||
|
|
||
| if (!counterAction) { | ||
| throw new GenkitError({ | ||
| status: 'NOT_FOUND', | ||
| message: `Model '${resolvedModel.modelAction.__action.name}' does not support token counting (model-token-counter action not found).`, | ||
| }); | ||
| } | ||
|
|
||
| return await counterAction(interceptedRequest); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.