Skip to content

Commit 425c025

Browse files
author
David Motsonashvili
committed
Add Function generation capability to ksp processor
1 parent b9a59d3 commit 425c025

File tree

4 files changed

+229
-63
lines changed

4 files changed

+229
-63
lines changed

firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt renamed to firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt

Lines changed: 205 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,18 @@
1717
package com.google.firebase.ai.ksp
1818

1919
import com.google.devtools.ksp.KspExperimental
20+
import com.google.devtools.ksp.isPublic
2021
import com.google.devtools.ksp.processing.CodeGenerator
2122
import com.google.devtools.ksp.processing.Dependencies
2223
import com.google.devtools.ksp.processing.KSPLogger
2324
import com.google.devtools.ksp.processing.Resolver
2425
import com.google.devtools.ksp.processing.SymbolProcessor
2526
import com.google.devtools.ksp.symbol.ClassKind
27+
import com.google.devtools.ksp.symbol.FunctionKind
2628
import com.google.devtools.ksp.symbol.KSAnnotated
2729
import com.google.devtools.ksp.symbol.KSAnnotation
2830
import com.google.devtools.ksp.symbol.KSClassDeclaration
31+
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
2932
import com.google.devtools.ksp.symbol.KSType
3033
import com.google.devtools.ksp.symbol.KSVisitorVoid
3134
import com.google.devtools.ksp.symbol.Modifier
@@ -40,30 +43,172 @@ import com.squareup.kotlinpoet.TypeSpec
4043
import com.squareup.kotlinpoet.ksp.toClassName
4144
import com.squareup.kotlinpoet.ksp.toTypeName
4245
import com.squareup.kotlinpoet.ksp.writeTo
46+
import java.util.Locale
4347
import javax.annotation.processing.Generated
4448

45-
public class SchemaSymbolProcessor(
49+
public class FirebaseSymbolProcessor(
4650
private val codeGenerator: CodeGenerator,
4751
private val logger: KSPLogger,
4852
) : SymbolProcessor {
53+
private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL)
54+
private val propertyKdocRegex =
55+
Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL)
56+
4957
override fun process(resolver: Resolver): List<KSAnnotated> {
5058
resolver
5159
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Generable")
5260
.filterIsInstance<KSClassDeclaration>()
5361
.map { it to SchemaSymbolProcessorVisitor(it, resolver) }
5462
.forEach { it.second.visitClassDeclaration(it.first, Unit) }
5563

64+
resolver
65+
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Tool")
66+
.filterIsInstance<KSFunctionDeclaration>()
67+
.map { it to FunctionSymbolProcessorVisitor(it, resolver) }
68+
.forEach { it.second.visitFunctionDeclaration(it.first, Unit) }
69+
5670
return emptyList()
5771
}
5872

73+
private inner class FunctionSymbolProcessorVisitor(
74+
private val func: KSFunctionDeclaration,
75+
private val resolver: Resolver,
76+
) : KSVisitorVoid() {
77+
override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Unit) {
78+
var shouldError = false
79+
val fullFunctionName = function.qualifiedName!!.asString()
80+
if (!function.isPublic()) {
81+
logger.warn("$fullFunctionName must be public.")
82+
shouldError = true
83+
}
84+
val containingClass = function.parentDeclaration as? KSClassDeclaration
85+
if (containingClass == null || !containingClass.isCompanionObject) {
86+
logger.warn("$fullFunctionName must be within a companion object ${containingClass!!.qualifiedName!!.asString()}")
87+
shouldError = true
88+
}
89+
if (function.parameters.size != 1) {
90+
logger.warn("$fullFunctionName must have exactly one parameter")
91+
shouldError = true
92+
}
93+
val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration
94+
if (parameter != null) {
95+
if (parameter.annotations.find { it.shortName.getShortName() == "Generable" } == null) {
96+
logger.warn("$fullFunctionName parameter must be annotated @Generable")
97+
shouldError = true
98+
}
99+
if (parameter.annotations.find { it.shortName.getShortName() == "Serializable" } == null) {
100+
logger.warn("$fullFunctionName parameter must be annotated @Serializable")
101+
shouldError = true
102+
}
103+
}
104+
val output = function.returnType?.resolve()
105+
if (
106+
output != null &&
107+
output.toClassName().canonicalName != "kotlinx.serialization.json.JsonObject"
108+
) {
109+
if (
110+
output.declaration.annotations.find { it.shortName.getShortName() != "Generable" } == null
111+
) {
112+
logger.warn("$fullFunctionName output must be annotated @Generable")
113+
shouldError = true
114+
}
115+
if (
116+
output.declaration.annotations.find { it.shortName.getShortName() != "Serializable" } ==
117+
null
118+
) {
119+
logger.warn("$fullFunctionName output must be annotated @Serializable")
120+
shouldError = true
121+
}
122+
}
123+
if (shouldError) {
124+
logger.error("$fullFunctionName has one or more errors, please resolve them.")
125+
}
126+
val generatedFunctionFile = generateFileSpec(function)
127+
generatedFunctionFile.writeTo(
128+
codeGenerator,
129+
Dependencies(true, function.containingFile!!),
130+
)
131+
}
132+
133+
private fun generateFileSpec(functionDeclaration: KSFunctionDeclaration): FileSpec {
134+
val generatedClassName = functionDeclaration.simpleName.asString()
135+
.replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString() } + "GeneratedFunctionDeclaration"
136+
return FileSpec.builder(
137+
functionDeclaration.packageName.asString(),
138+
generatedClassName
139+
)
140+
.addImport("com.google.firebase.ai.type", "AutoFunctionDeclaration")
141+
.addType(
142+
TypeSpec.classBuilder(generatedClassName)
143+
.addAnnotation(Generated::class)
144+
.addType(
145+
TypeSpec.companionObjectBuilder()
146+
.addProperty(
147+
PropertySpec.builder(
148+
"FUNCTION_DECLARATION",
149+
ClassName("com.google.firebase.ai.type", "AutoFunctionDeclaration")
150+
.parameterizedBy(
151+
functionDeclaration.parameters.first().type.resolve().toClassName(),
152+
functionDeclaration.returnType?.resolve()?.toClassName()
153+
?: ClassName("kotlinx.serialization.json", "JsonObject")
154+
),
155+
KModifier.PUBLIC,
156+
)
157+
.mutable(false)
158+
.initializer(
159+
CodeBlock.builder()
160+
.add(generateCodeBlockForFunctionDeclaration(functionDeclaration))
161+
.build()
162+
)
163+
.build()
164+
)
165+
.build()
166+
)
167+
.build()
168+
)
169+
.build()
170+
}
171+
fun generateCodeBlockForFunctionDeclaration(
172+
functionDeclaration: KSFunctionDeclaration
173+
): CodeBlock {
174+
val builder = CodeBlock.builder()
175+
val hasTypedOutput =
176+
!(functionDeclaration.returnType == null ||
177+
functionDeclaration.returnType!!.resolve().toClassName().canonicalName ==
178+
"kotlinx.serialization.json.JsonObject")
179+
val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) }
180+
val annotationDescription =
181+
getStringFromAnnotation(
182+
functionDeclaration.annotations.find { it.shortName.getShortName() == "Tool" },
183+
"description"
184+
)
185+
val description = annotationDescription ?: kdocDescription ?: ""
186+
val inputSchemaName =
187+
"${functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName}GeneratedSchema.SCHEMA"
188+
builder
189+
.addStatement("AutoFunctionDeclaration.create(")
190+
.indent()
191+
.addStatement("functionName = %S,", functionDeclaration.simpleName.getShortName())
192+
.addStatement("description = %S,", description)
193+
.addStatement("inputSchema = $inputSchemaName,")
194+
if (hasTypedOutput) {
195+
val outputSchemaName =
196+
"${functionDeclaration.returnType!!.resolve().toClassName().canonicalName}GeneratedSchema.SCHEMA"
197+
builder.addStatement("outputSchema = $outputSchemaName,")
198+
}
199+
builder.addStatement(
200+
"functionReference = ${functionDeclaration.qualifiedName!!.getQualifier()}::${functionDeclaration.qualifiedName!!.getShortName()},"
201+
)
202+
builder.unindent().addStatement(")")
203+
return builder.build()
204+
}
205+
}
206+
59207
private inner class SchemaSymbolProcessorVisitor(
60208
private val klass: KSClassDeclaration,
61209
private val resolver: Resolver,
62210
) : KSVisitorVoid() {
63211
private val numberTypes = setOf("kotlin.Int", "kotlin.Long", "kotlin.Double", "kotlin.Float")
64-
private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL)
65-
private val propertyKdocRegex =
66-
Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL)
67212

68213
override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) {
69214
val isDataClass = classDeclaration.modifiers.contains(Modifier.DATA)
@@ -265,73 +410,73 @@ public class SchemaSymbolProcessor(
265410
builder.addStatement("nullable = %L)", className.isNullable).unindent()
266411
return builder.build()
267412
}
413+
}
268414

269-
private fun getDescriptionFromAnnotations(
270-
guideAnnotation: KSAnnotation?,
271-
guideClassAnnotation: KSAnnotation?,
272-
description: String?,
273-
baseKdoc: String?,
274-
): String? {
275-
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")
415+
private fun getDescriptionFromAnnotations(
416+
guideAnnotation: KSAnnotation?,
417+
guideClassAnnotation: KSAnnotation?,
418+
description: String?,
419+
baseKdoc: String?,
420+
): String? {
421+
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")
276422

277-
val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description")
423+
val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description")
278424

279-
return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
280-
}
425+
return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
426+
}
281427

282-
private fun getDoubleFromAnnotation(
283-
guideAnnotation: KSAnnotation?,
284-
doubleName: String,
285-
): Double? {
286-
val guidePropertyDoubleValue =
287-
guideAnnotation
288-
?.arguments
289-
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
290-
?.value as? Double
291-
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
292-
return null
293-
}
294-
return guidePropertyDoubleValue
428+
private fun getDoubleFromAnnotation(
429+
guideAnnotation: KSAnnotation?,
430+
doubleName: String,
431+
): Double? {
432+
val guidePropertyDoubleValue =
433+
guideAnnotation
434+
?.arguments
435+
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
436+
?.value as? Double
437+
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
438+
return null
295439
}
440+
return guidePropertyDoubleValue
441+
}
296442

297-
private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
298-
val guidePropertyIntValue =
299-
guideAnnotation
300-
?.arguments
301-
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
302-
?.value as? Int
303-
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
304-
return null
305-
}
306-
return guidePropertyIntValue
443+
private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
444+
val guidePropertyIntValue =
445+
guideAnnotation
446+
?.arguments
447+
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
448+
?.value as? Int
449+
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
450+
return null
307451
}
452+
return guidePropertyIntValue
453+
}
308454

309-
private fun getStringFromAnnotation(
310-
guideAnnotation: KSAnnotation?,
311-
stringName: String,
312-
): String? {
313-
val guidePropertyStringValue =
314-
guideAnnotation
315-
?.arguments
316-
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
317-
?.value as? String
318-
if (guidePropertyStringValue.isNullOrEmpty()) {
319-
return null
320-
}
321-
return guidePropertyStringValue
455+
private fun getStringFromAnnotation(
456+
guideAnnotation: KSAnnotation?,
457+
stringName: String,
458+
): String? {
459+
val guidePropertyStringValue =
460+
guideAnnotation
461+
?.arguments
462+
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
463+
?.value as? String
464+
if (guidePropertyStringValue.isNullOrEmpty()) {
465+
return null
322466
}
467+
return guidePropertyStringValue
468+
}
323469

324-
private fun extractBaseKdoc(kdoc: String): String? {
325-
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
326-
if (it.isNullOrEmpty()) null else it
327-
}
470+
private fun extractBaseKdoc(kdoc: String): String? {
471+
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
472+
if (it.isNullOrEmpty()) null else it
328473
}
474+
}
329475

330-
private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
331-
return propertyKdocRegex
332-
.findAll(kdoc)
333-
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
334-
.toMap()
335-
}
476+
private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
477+
return propertyKdocRegex
478+
.findAll(kdoc)
479+
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
480+
.toMap()
336481
}
337482
}

firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt renamed to firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import com.google.devtools.ksp.processing.SymbolProcessor
2020
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
2121
import com.google.devtools.ksp.processing.SymbolProcessorProvider
2222

23-
public class SchemaSymbolProcessorProvider : SymbolProcessorProvider {
23+
public class FirebaseSymbolProcessorProvider : SymbolProcessorProvider {
2424
override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
25-
return SchemaSymbolProcessor(environment.codeGenerator, environment.logger)
25+
return FirebaseSymbolProcessor(environment.codeGenerator, environment.logger)
2626
}
2727
}
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
com.google.firebase.ai.ksp.SchemaSymbolProcessorProvider
1+
com.google.firebase.ai.ksp.FirebaseSymbolProcessorProvider
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.ai.annotations
18+
19+
@Target(AnnotationTarget.FUNCTION)
20+
@Retention(AnnotationRetention.SOURCE)
21+
public annotation class Tool(public val description: String = "")

0 commit comments

Comments
 (0)