Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
package com.google.firebase.ai.ksp

import com.google.devtools.ksp.KspExperimental
import com.google.devtools.ksp.isPublic
import com.google.devtools.ksp.processing.CodeGenerator
import com.google.devtools.ksp.processing.Dependencies
import com.google.devtools.ksp.processing.KSPLogger
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.processing.SymbolProcessor
import com.google.devtools.ksp.symbol.ClassKind
import com.google.devtools.ksp.symbol.FunctionKind
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSAnnotation
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSVisitorVoid
import com.google.devtools.ksp.symbol.Modifier
Expand All @@ -40,30 +43,172 @@ import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.ksp.toClassName
import com.squareup.kotlinpoet.ksp.toTypeName
import com.squareup.kotlinpoet.ksp.writeTo
import java.util.Locale
import javax.annotation.processing.Generated

public class SchemaSymbolProcessor(
public class FirebaseSymbolProcessor(
private val codeGenerator: CodeGenerator,
private val logger: KSPLogger,
) : SymbolProcessor {
private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL)
private val propertyKdocRegex =
Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL)

override fun process(resolver: Resolver): List<KSAnnotated> {
resolver
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Generable")
.filterIsInstance<KSClassDeclaration>()
.map { it to SchemaSymbolProcessorVisitor(it, resolver) }
.forEach { it.second.visitClassDeclaration(it.first, Unit) }

resolver
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Tool")
.filterIsInstance<KSFunctionDeclaration>()
.map { it to FunctionSymbolProcessorVisitor(it, resolver) }
.forEach { it.second.visitFunctionDeclaration(it.first, Unit) }

return emptyList()
}

private inner class FunctionSymbolProcessorVisitor(
private val func: KSFunctionDeclaration,
private val resolver: Resolver,
) : KSVisitorVoid() {
override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Unit) {
var shouldError = false
val fullFunctionName = function.qualifiedName!!.asString()
if (!function.isPublic()) {
logger.warn("$fullFunctionName must be public.")
shouldError = true
}
val containingClass = function.parentDeclaration as? KSClassDeclaration
if (containingClass == null || !containingClass.isCompanionObject) {
logger.warn("$fullFunctionName must be within a companion object ${containingClass!!.qualifiedName!!.asString()}")
shouldError = true
}
if (function.parameters.size != 1) {
logger.warn("$fullFunctionName must have exactly one parameter")
shouldError = true
}
val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration
if (parameter != null) {
if (parameter.annotations.find { it.shortName.getShortName() == "Generable" } == null) {
logger.warn("$fullFunctionName parameter must be annotated @Generable")
shouldError = true
}
if (parameter.annotations.find { it.shortName.getShortName() == "Serializable" } == null) {
logger.warn("$fullFunctionName parameter must be annotated @Serializable")
shouldError = true
}
}
val output = function.returnType?.resolve()
if (
output != null &&
output.toClassName().canonicalName != "kotlinx.serialization.json.JsonObject"
) {
if (
output.declaration.annotations.find { it.shortName.getShortName() != "Generable" } == null
) {
logger.warn("$fullFunctionName output must be annotated @Generable")
shouldError = true
}
if (
output.declaration.annotations.find { it.shortName.getShortName() != "Serializable" } ==
null
) {
logger.warn("$fullFunctionName output must be annotated @Serializable")
shouldError = true
}
}
if (shouldError) {
logger.error("$fullFunctionName has one or more errors, please resolve them.")
}
val generatedFunctionFile = generateFileSpec(function)
generatedFunctionFile.writeTo(
codeGenerator,
Dependencies(true, function.containingFile!!),
)
}

private fun generateFileSpec(functionDeclaration: KSFunctionDeclaration): FileSpec {
val generatedClassName = functionDeclaration.simpleName.asString()
.replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString() } + "GeneratedFunctionDeclaration"
return FileSpec.builder(
functionDeclaration.packageName.asString(),
generatedClassName
)
.addImport("com.google.firebase.ai.type", "AutoFunctionDeclaration")
.addType(
TypeSpec.classBuilder(generatedClassName)
.addAnnotation(Generated::class)
.addType(
TypeSpec.companionObjectBuilder()
.addProperty(
PropertySpec.builder(
"FUNCTION_DECLARATION",
ClassName("com.google.firebase.ai.type", "AutoFunctionDeclaration")
.parameterizedBy(
functionDeclaration.parameters.first().type.resolve().toClassName(),
functionDeclaration.returnType?.resolve()?.toClassName()
?: ClassName("kotlinx.serialization.json", "JsonObject")
),
KModifier.PUBLIC,
)
.mutable(false)
.initializer(
CodeBlock.builder()
.add(generateCodeBlockForFunctionDeclaration(functionDeclaration))
.build()
)
.build()
)
.build()
)
.build()
)
.build()
}
fun generateCodeBlockForFunctionDeclaration(
functionDeclaration: KSFunctionDeclaration
): CodeBlock {
val builder = CodeBlock.builder()
val hasTypedOutput =
!(functionDeclaration.returnType == null ||
functionDeclaration.returnType!!.resolve().toClassName().canonicalName ==
"kotlinx.serialization.json.JsonObject")
val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) }
val annotationDescription =
getStringFromAnnotation(
functionDeclaration.annotations.find { it.shortName.getShortName() == "Tool" },
"description"
)
val description = annotationDescription ?: kdocDescription ?: ""
val inputSchemaName =
"${functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName}GeneratedSchema.SCHEMA"
builder
.addStatement("AutoFunctionDeclaration.create(")
.indent()
.addStatement("functionName = %S,", functionDeclaration.simpleName.getShortName())
.addStatement("description = %S,", description)
.addStatement("inputSchema = $inputSchemaName,")
if (hasTypedOutput) {
val outputSchemaName =
"${functionDeclaration.returnType!!.resolve().toClassName().canonicalName}GeneratedSchema.SCHEMA"
builder.addStatement("outputSchema = $outputSchemaName,")
}
builder.addStatement(
"functionReference = ${functionDeclaration.qualifiedName!!.getQualifier()}::${functionDeclaration.qualifiedName!!.getShortName()},"
)
builder.unindent().addStatement(")")
return builder.build()
}
}

private inner class SchemaSymbolProcessorVisitor(
private val klass: KSClassDeclaration,
private val resolver: Resolver,
) : KSVisitorVoid() {
private val numberTypes = setOf("kotlin.Int", "kotlin.Long", "kotlin.Double", "kotlin.Float")
private val baseKdocRegex = Regex("^\\s*(.*?)((@\\w* .*)|\\z)", RegexOption.DOT_MATCHES_ALL)
private val propertyKdocRegex =
Regex("\\s*@property (\\w*) (.*?)(?=@\\w*|\\z)", RegexOption.DOT_MATCHES_ALL)

override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) {
val isDataClass = classDeclaration.modifiers.contains(Modifier.DATA)
Expand Down Expand Up @@ -265,73 +410,73 @@ public class SchemaSymbolProcessor(
builder.addStatement("nullable = %L)", className.isNullable).unindent()
return builder.build()
}
}

private fun getDescriptionFromAnnotations(
guideAnnotation: KSAnnotation?,
guideClassAnnotation: KSAnnotation?,
description: String?,
baseKdoc: String?,
): String? {
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")
private fun getDescriptionFromAnnotations(
guideAnnotation: KSAnnotation?,
guideClassAnnotation: KSAnnotation?,
description: String?,
baseKdoc: String?,
): String? {
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")

val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description")
val guideClassDescription = getStringFromAnnotation(guideClassAnnotation, "description")

return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
}
return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
}

private fun getDoubleFromAnnotation(
guideAnnotation: KSAnnotation?,
doubleName: String,
): Double? {
val guidePropertyDoubleValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
?.value as? Double
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
return null
}
return guidePropertyDoubleValue
private fun getDoubleFromAnnotation(
guideAnnotation: KSAnnotation?,
doubleName: String,
): Double? {
val guidePropertyDoubleValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
?.value as? Double
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
return null
}
return guidePropertyDoubleValue
}

private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
val guidePropertyIntValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
?.value as? Int
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
return null
}
return guidePropertyIntValue
private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
val guidePropertyIntValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
?.value as? Int
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
return null
}
return guidePropertyIntValue
}

private fun getStringFromAnnotation(
guideAnnotation: KSAnnotation?,
stringName: String,
): String? {
val guidePropertyStringValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
?.value as? String
if (guidePropertyStringValue.isNullOrEmpty()) {
return null
}
return guidePropertyStringValue
private fun getStringFromAnnotation(
guideAnnotation: KSAnnotation?,
stringName: String,
): String? {
val guidePropertyStringValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
?.value as? String
if (guidePropertyStringValue.isNullOrEmpty()) {
return null
}
return guidePropertyStringValue
}

private fun extractBaseKdoc(kdoc: String): String? {
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
if (it.isNullOrEmpty()) null else it
}
private fun extractBaseKdoc(kdoc: String): String? {
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
if (it.isNullOrEmpty()) null else it
}
}

private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
return propertyKdocRegex
.findAll(kdoc)
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
.toMap()
}
private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
return propertyKdocRegex
.findAll(kdoc)
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
.toMap()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import com.google.devtools.ksp.processing.SymbolProcessor
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider

public class SchemaSymbolProcessorProvider : SymbolProcessorProvider {
public class FirebaseSymbolProcessorProvider : SymbolProcessorProvider {
override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
return SchemaSymbolProcessor(environment.codeGenerator, environment.logger)
return FirebaseSymbolProcessor(environment.codeGenerator, environment.logger)
}
}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
com.google.firebase.ai.ksp.SchemaSymbolProcessorProvider
com.google.firebase.ai.ksp.FirebaseSymbolProcessorProvider
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.ai.annotations

@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.SOURCE)
public annotation class Tool(public val description: String = "")
Loading