1717package com.google.firebase.ai.ksp
1818
1919import com.google.devtools.ksp.KspExperimental
20+ import com.google.devtools.ksp.isPublic
2021import com.google.devtools.ksp.processing.CodeGenerator
2122import com.google.devtools.ksp.processing.Dependencies
2223import com.google.devtools.ksp.processing.KSPLogger
2324import com.google.devtools.ksp.processing.Resolver
2425import com.google.devtools.ksp.processing.SymbolProcessor
2526import com.google.devtools.ksp.symbol.ClassKind
27+ import com.google.devtools.ksp.symbol.FunctionKind
2628import com.google.devtools.ksp.symbol.KSAnnotated
2729import com.google.devtools.ksp.symbol.KSAnnotation
2830import com.google.devtools.ksp.symbol.KSClassDeclaration
31+ import com.google.devtools.ksp.symbol.KSFunctionDeclaration
2932import com.google.devtools.ksp.symbol.KSType
3033import com.google.devtools.ksp.symbol.KSVisitorVoid
3134import com.google.devtools.ksp.symbol.Modifier
@@ -40,30 +43,172 @@ import com.squareup.kotlinpoet.TypeSpec
4043import com.squareup.kotlinpoet.ksp.toClassName
4144import com.squareup.kotlinpoet.ksp.toTypeName
4245import com.squareup.kotlinpoet.ksp.writeTo
46+ import java.util.Locale
4347import javax.annotation.processing.Generated
4448
45- public class SchemaSymbolProcessor (
49+ public class FirebaseSymbolProcessor (
4650 private val codeGenerator : CodeGenerator ,
4751 private val logger : KSPLogger ,
4852) : SymbolProcessor {
53+ private val baseKdocRegex = Regex (" ^\\ s*(.*?)((@\\ w* .*)|\\ z)" , RegexOption .DOT_MATCHES_ALL )
54+ private val propertyKdocRegex =
55+ Regex (" \\ s*@property (\\ w*) (.*?)(?=@\\ w*|\\ z)" , RegexOption .DOT_MATCHES_ALL )
56+
4957 override fun process (resolver : Resolver ): List <KSAnnotated > {
5058 resolver
5159 .getSymbolsWithAnnotation(" com.google.firebase.ai.annotations.Generable" )
5260 .filterIsInstance<KSClassDeclaration >()
5361 .map { it to SchemaSymbolProcessorVisitor (it, resolver) }
5462 .forEach { it.second.visitClassDeclaration(it.first, Unit ) }
5563
64+ resolver
65+ .getSymbolsWithAnnotation(" com.google.firebase.ai.annotations.Tool" )
66+ .filterIsInstance<KSFunctionDeclaration >()
67+ .map { it to FunctionSymbolProcessorVisitor (it, resolver) }
68+ .forEach { it.second.visitFunctionDeclaration(it.first, Unit ) }
69+
5670 return emptyList()
5771 }
5872
73+ private inner class FunctionSymbolProcessorVisitor (
74+ private val func : KSFunctionDeclaration ,
75+ private val resolver : Resolver ,
76+ ) : KSVisitorVoid() {
77+ override fun visitFunctionDeclaration (function : KSFunctionDeclaration , data : Unit ) {
78+ var shouldError = false
79+ val fullFunctionName = function.qualifiedName!! .asString()
80+ if (! function.isPublic()) {
81+ logger.warn(" $fullFunctionName must be public." )
82+ shouldError = true
83+ }
84+ val containingClass = function.parentDeclaration as ? KSClassDeclaration
85+ if (containingClass == null || ! containingClass.isCompanionObject) {
86+ logger.warn(" $fullFunctionName must be within a companion object ${containingClass!! .qualifiedName!! .asString()} " )
87+ shouldError = true
88+ }
89+ if (function.parameters.size != 1 ) {
90+ logger.warn(" $fullFunctionName must have exactly one parameter" )
91+ shouldError = true
92+ }
93+ val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration
94+ if (parameter != null ) {
95+ if (parameter.annotations.find { it.shortName.getShortName() == " Generable" } == null ) {
96+ logger.warn(" $fullFunctionName parameter must be annotated @Generable" )
97+ shouldError = true
98+ }
99+ if (parameter.annotations.find { it.shortName.getShortName() == " Serializable" } == null ) {
100+ logger.warn(" $fullFunctionName parameter must be annotated @Serializable" )
101+ shouldError = true
102+ }
103+ }
104+ val output = function.returnType?.resolve()
105+ if (
106+ output != null &&
107+ output.toClassName().canonicalName != " kotlinx.serialization.json.JsonObject"
108+ ) {
109+ if (
110+ output.declaration.annotations.find { it.shortName.getShortName() != " Generable" } == null
111+ ) {
112+ logger.warn(" $fullFunctionName output must be annotated @Generable" )
113+ shouldError = true
114+ }
115+ if (
116+ output.declaration.annotations.find { it.shortName.getShortName() != " Serializable" } ==
117+ null
118+ ) {
119+ logger.warn(" $fullFunctionName output must be annotated @Serializable" )
120+ shouldError = true
121+ }
122+ }
123+ if (shouldError) {
124+ logger.error(" $fullFunctionName has one or more errors, please resolve them." )
125+ }
126+ val generatedFunctionFile = generateFileSpec(function)
127+ generatedFunctionFile.writeTo(
128+ codeGenerator,
129+ Dependencies (true , function.containingFile!! ),
130+ )
131+ }
132+
133+ private fun generateFileSpec (functionDeclaration : KSFunctionDeclaration ): FileSpec {
134+ val generatedClassName = functionDeclaration.simpleName.asString()
135+ .replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale .ROOT ) else it.toString() } + " GeneratedFunctionDeclaration"
136+ return FileSpec .builder(
137+ functionDeclaration.packageName.asString(),
138+ generatedClassName
139+ )
140+ .addImport(" com.google.firebase.ai.type" , " AutoFunctionDeclaration" )
141+ .addType(
142+ TypeSpec .classBuilder(generatedClassName)
143+ .addAnnotation(Generated ::class )
144+ .addType(
145+ TypeSpec .companionObjectBuilder()
146+ .addProperty(
147+ PropertySpec .builder(
148+ " FUNCTION_DECLARATION" ,
149+ ClassName (" com.google.firebase.ai.type" , " AutoFunctionDeclaration" )
150+ .parameterizedBy(
151+ functionDeclaration.parameters.first().type.resolve().toClassName(),
152+ functionDeclaration.returnType?.resolve()?.toClassName()
153+ ? : ClassName (" kotlinx.serialization.json" , " JsonObject" )
154+ ),
155+ KModifier .PUBLIC ,
156+ )
157+ .mutable(false )
158+ .initializer(
159+ CodeBlock .builder()
160+ .add(generateCodeBlockForFunctionDeclaration(functionDeclaration))
161+ .build()
162+ )
163+ .build()
164+ )
165+ .build()
166+ )
167+ .build()
168+ )
169+ .build()
170+ }
171+ fun generateCodeBlockForFunctionDeclaration (
172+ functionDeclaration : KSFunctionDeclaration
173+ ): CodeBlock {
174+ val builder = CodeBlock .builder()
175+ val hasTypedOutput =
176+ ! (functionDeclaration.returnType == null ||
177+ functionDeclaration.returnType!! .resolve().toClassName().canonicalName ==
178+ " kotlinx.serialization.json.JsonObject" )
179+ val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) }
180+ val annotationDescription =
181+ getStringFromAnnotation(
182+ functionDeclaration.annotations.find { it.shortName.getShortName() == " Tool" },
183+ " description"
184+ )
185+ val description = annotationDescription ? : kdocDescription ? : " "
186+ val inputSchemaName =
187+ " ${functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName} GeneratedSchema.SCHEMA"
188+ builder
189+ .addStatement(" AutoFunctionDeclaration.create(" )
190+ .indent()
191+ .addStatement(" functionName = %S," , functionDeclaration.simpleName.getShortName())
192+ .addStatement(" description = %S," , description)
193+ .addStatement(" inputSchema = $inputSchemaName ," )
194+ if (hasTypedOutput) {
195+ val outputSchemaName =
196+ " ${functionDeclaration.returnType!! .resolve().toClassName().canonicalName} GeneratedSchema.SCHEMA"
197+ builder.addStatement(" outputSchema = $outputSchemaName ," )
198+ }
199+ builder.addStatement(
200+ " functionReference = ${functionDeclaration.qualifiedName!! .getQualifier()} ::${functionDeclaration.qualifiedName!! .getShortName()} ,"
201+ )
202+ builder.unindent().addStatement(" )" )
203+ return builder.build()
204+ }
205+ }
206+
59207 private inner class SchemaSymbolProcessorVisitor (
60208 private val klass : KSClassDeclaration ,
61209 private val resolver : Resolver ,
62210 ) : KSVisitorVoid() {
63211 private val numberTypes = setOf (" kotlin.Int" , " kotlin.Long" , " kotlin.Double" , " kotlin.Float" )
64- private val baseKdocRegex = Regex (" ^\\ s*(.*?)((@\\ w* .*)|\\ z)" , RegexOption .DOT_MATCHES_ALL )
65- private val propertyKdocRegex =
66- Regex (" \\ s*@property (\\ w*) (.*?)(?=@\\ w*|\\ z)" , RegexOption .DOT_MATCHES_ALL )
67212
68213 override fun visitClassDeclaration (classDeclaration : KSClassDeclaration , data : Unit ) {
69214 val isDataClass = classDeclaration.modifiers.contains(Modifier .DATA )
@@ -265,73 +410,73 @@ public class SchemaSymbolProcessor(
265410 builder.addStatement(" nullable = %L)" , className.isNullable).unindent()
266411 return builder.build()
267412 }
413+ }
268414
269- private fun getDescriptionFromAnnotations (
270- guideAnnotation : KSAnnotation ? ,
271- guideClassAnnotation : KSAnnotation ? ,
272- description : String? ,
273- baseKdoc : String? ,
274- ): String? {
275- val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, " description" )
415+ private fun getDescriptionFromAnnotations (
416+ guideAnnotation : KSAnnotation ? ,
417+ guideClassAnnotation : KSAnnotation ? ,
418+ description : String? ,
419+ baseKdoc : String? ,
420+ ): String? {
421+ val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, " description" )
276422
277- val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, " description" )
423+ val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, " description" )
278424
279- return guidePropertyDescription ? : guideClassDescription ? : description ? : baseKdoc
280- }
425+ return guidePropertyDescription ? : guideClassDescription ? : description ? : baseKdoc
426+ }
281427
282- private fun getDoubleFromAnnotation (
283- guideAnnotation : KSAnnotation ? ,
284- doubleName : String ,
285- ): Double? {
286- val guidePropertyDoubleValue =
287- guideAnnotation
288- ?.arguments
289- ?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
290- ?.value as ? Double
291- if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == - 1.0 ) {
292- return null
293- }
294- return guidePropertyDoubleValue
428+ private fun getDoubleFromAnnotation (
429+ guideAnnotation : KSAnnotation ? ,
430+ doubleName : String ,
431+ ): Double? {
432+ val guidePropertyDoubleValue =
433+ guideAnnotation
434+ ?.arguments
435+ ?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
436+ ?.value as ? Double
437+ if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == - 1.0 ) {
438+ return null
295439 }
440+ return guidePropertyDoubleValue
441+ }
296442
297- private fun getIntFromAnnotation (guideAnnotation : KSAnnotation ? , intName : String ): Int? {
298- val guidePropertyIntValue =
299- guideAnnotation
300- ?.arguments
301- ?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
302- ?.value as ? Int
303- if (guidePropertyIntValue == null || guidePropertyIntValue == - 1 ) {
304- return null
305- }
306- return guidePropertyIntValue
443+ private fun getIntFromAnnotation (guideAnnotation : KSAnnotation ? , intName : String ): Int? {
444+ val guidePropertyIntValue =
445+ guideAnnotation
446+ ?.arguments
447+ ?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
448+ ?.value as ? Int
449+ if (guidePropertyIntValue == null || guidePropertyIntValue == - 1 ) {
450+ return null
307451 }
452+ return guidePropertyIntValue
453+ }
308454
309- private fun getStringFromAnnotation (
310- guideAnnotation : KSAnnotation ? ,
311- stringName : String ,
312- ): String? {
313- val guidePropertyStringValue =
314- guideAnnotation
315- ?.arguments
316- ?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
317- ?.value as ? String
318- if (guidePropertyStringValue.isNullOrEmpty()) {
319- return null
320- }
321- return guidePropertyStringValue
455+ private fun getStringFromAnnotation (
456+ guideAnnotation : KSAnnotation ? ,
457+ stringName : String ,
458+ ): String? {
459+ val guidePropertyStringValue =
460+ guideAnnotation
461+ ?.arguments
462+ ?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
463+ ?.value as ? String
464+ if (guidePropertyStringValue.isNullOrEmpty()) {
465+ return null
322466 }
467+ return guidePropertyStringValue
468+ }
323469
324- private fun extractBaseKdoc (kdoc : String ): String? {
325- return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1 )?.value?.trim().let {
326- if (it.isNullOrEmpty()) null else it
327- }
470+ private fun extractBaseKdoc (kdoc : String ): String? {
471+ return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1 )?.value?.trim().let {
472+ if (it.isNullOrEmpty()) null else it
328473 }
474+ }
329475
330- private fun extractPropertyKdocs (kdoc : String ): Map <String , String > {
331- return propertyKdocRegex
332- .findAll(kdoc)
333- .map { it.groups[1 ]!! .value to it.groups[2 ]!! .value.replace(" \n " , " " ).trim() }
334- .toMap()
335- }
476+ private fun extractPropertyKdocs (kdoc : String ): Map <String , String > {
477+ return propertyKdocRegex
478+ .findAll(kdoc)
479+ .map { it.groups[1 ]!! .value to it.groups[2 ]!! .value.replace(" \n " , " " ).trim() }
480+ .toMap()
336481 }
337482}
0 commit comments