Skip to content
62 changes: 22 additions & 40 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"log/slog"
"maps"
"os"
"path/filepath"
"path"
"reflect"
"strings"

Expand Down Expand Up @@ -517,70 +517,54 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
return result, nil
}

// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
func LoadPromptDir(r api.Registry, dir string, namespace string) {
useDefaultDir := false
if dir == "" {
dir = "./prompts"
useDefaultDir = true
// LoadPromptDirFromFS loads prompts and partials from a filesystem for the given namespace.
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
// The dir parameter specifies the directory within the filesystem where prompts are located.
func LoadPromptDirFromFS(r api.Registry, fsys fs.FS, dir, namespace string) {
if fsys == nil {
panic(errors.New("no prompt filesystem provided"))
}

path, err := filepath.Abs(dir)
if err != nil {
if !useDefaultDir {
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
}
slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir)
return
if _, err := fs.Stat(fsys, dir); err != nil {
panic(fmt.Errorf("failed to access prompt directory %q in filesystem: %w", dir, err))
}

if _, err := os.Stat(path); os.IsNotExist(err) {
if !useDefaultDir {
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
}
slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir)
return
}

loadPromptDir(r, path, namespace)
}

// loadPromptDir recursively loads prompts and partials from the directory.
func loadPromptDir(r api.Registry, dir string, namespace string) {
entries, err := os.ReadDir(dir)
entries, err := fs.ReadDir(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to read prompt directory structure: %w", err))
}

for _, entry := range entries {
filename := entry.Name()
path := filepath.Join(dir, filename)
filePath := path.Join(dir, filename)
if entry.IsDir() {
loadPromptDir(r, path, namespace)
LoadPromptDirFromFS(r, fsys, filePath, namespace)
} else if strings.HasSuffix(filename, ".prompt") {
if strings.HasPrefix(filename, "_") {
partialName := strings.TrimSuffix(filename[1:], ".prompt")
source, err := os.ReadFile(path)
source, err := fs.ReadFile(fsys, filePath)
if err != nil {
slog.Error("Failed to read partial file", "error", err)
continue
}
r.RegisterPartial(partialName, string(source))
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path)
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", filePath)
} else {
LoadPrompt(r, dir, filename, namespace)
LoadPromptFromFS(r, fsys, dir, filename, namespace)
}
}
}
}

// LoadPrompt loads a single prompt into the registry.
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
// LoadPromptFromFS loads a single prompt from a filesystem into the registry.
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
// The dir parameter specifies the directory within the filesystem where the prompt is located.
func LoadPromptFromFS(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt {
name := strings.TrimSuffix(filename, ".prompt")
name, variant, _ := strings.Cut(name, ".")

sourceFile := filepath.Join(dir, filename)
source, err := os.ReadFile(sourceFile)
sourceFile := path.Join(dir, filename)
source, err := fs.ReadFile(fsys, sourceFile)
if err != nil {
slog.Error("Failed to read prompt file", "file", sourceFile, "error", err)
return nil
Expand Down Expand Up @@ -696,12 +680,10 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

promptOpts := []PromptOption{opts}

// Add system prompt if found
if systemText != "" {
promptOpts = append(promptOpts, WithSystem(systemText))
}

// If there are non-system messages, use WithMessages, otherwise use WithPrompt for template
if len(nonSystemMessages) > 0 {
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
} else if systemText == "" {
Expand Down
151 changes: 137 additions & 14 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"path/filepath"
"strings"
"testing"
"testing/fstest"

"github.com/firebase/genkit/go/core/api"
"github.com/firebase/genkit/go/internal/base"
Expand Down Expand Up @@ -913,7 +914,7 @@ Hello, {{name}}!
reg := registry.New()

// Call loadPrompt
LoadPrompt(reg, tempDir, "example.prompt", "test-namespace")
LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.prompt", "test-namespace")

// Verify that the prompt was registered correctly
prompt := LookupPrompt(reg, "test-namespace/example")
Expand Down Expand Up @@ -962,7 +963,7 @@ input:
}

reg := registry.New()
LoadPrompt(reg, tempDir, "snake.prompt", "snake-namespace")
LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "snake.prompt", "snake-namespace")

prompt := LookupPrompt(reg, "snake-namespace/snake")
if prompt == nil {
Expand Down Expand Up @@ -1010,8 +1011,9 @@ func TestLoadPrompt_FileNotFound(t *testing.T) {
// Initialize a mock registry
reg := registry.New()

// Call loadPrompt with a non-existent file
LoadPrompt(reg, "./nonexistent", "missing.prompt", "test-namespace")
// Call loadPrompt with a non-existent file in a valid temp directory
tempDir := t.TempDir()
LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "missing.prompt", "test-namespace")

// Verify that the prompt was not registered
prompt := LookupPrompt(reg, "missing")
Expand All @@ -1036,7 +1038,7 @@ func TestLoadPrompt_InvalidPromptFile(t *testing.T) {
reg := registry.New()

// Call loadPrompt
LoadPrompt(reg, tempDir, "invalid.prompt", "test-namespace")
LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "invalid.prompt", "test-namespace")

// Verify that the prompt was not registered
prompt := LookupPrompt(reg, "invalid")
Expand Down Expand Up @@ -1067,7 +1069,7 @@ Hello, {{name}}!
reg := registry.New()

// Call loadPrompt
LoadPrompt(reg, tempDir, "example.variant.prompt", "test-namespace")
LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.variant.prompt", "test-namespace")

// Verify that the prompt was registered correctly
prompt := LookupPrompt(reg, "test-namespace/example.variant")
Expand Down Expand Up @@ -1122,7 +1124,7 @@ Hello, {{name}}!
reg := registry.New()

// Call LoadPromptFolder
LoadPromptDir(reg, tempDir, "test-namespace")
LoadPromptDirFromFS(reg, os.DirFS(tempDir), ".", "test-namespace")

// Verify that the prompt was registered correctly
prompt := LookupPrompt(reg, "test-namespace/example")
Expand All @@ -1137,16 +1139,137 @@ Hello, {{name}}!
}
}

func TestLoadPromptFolder_DirectoryNotFound(t *testing.T) {
func TestLoadPromptFolder_EmptyDirectory(t *testing.T) {
// Initialize a mock registry
reg := &registry.Registry{}
reg := registry.New()

// Call LoadPromptFolder with a non-existent directory
LoadPromptDir(reg, "", "test-namespace")
// Create an empty temp directory
tempDir := t.TempDir()

// Call LoadPromptFolder with an empty directory
LoadPromptDirFromFS(reg, os.DirFS(tempDir), ".", "test-namespace")

// Verify that no prompts were registered
if prompt := LookupPrompt(reg, "example"); prompt != nil {
t.Fatalf("Prompt should not have been registered for a non-existent directory")
t.Fatalf("Prompt should not have been registered for an empty directory")
}
}

func TestLoadPromptFS(t *testing.T) {
mockPromptContent := `---
model: test/chat
description: A test prompt
input:
schema:
type: object
properties:
name:
type: string
output:
format: text
schema:
type: string
---

Hello, {{name}}!
`
mockPartialContent := `Welcome {{name}}!`

fsys := fstest.MapFS{
"prompts/example.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)},
"prompts/sub/nested.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)},
"prompts/_greeting.prompt": &fstest.MapFile{Data: []byte(mockPartialContent)},
}

reg := registry.New()

LoadPromptDirFromFS(reg, fsys, "prompts", "test-namespace")

prompt := LookupPrompt(reg, "test-namespace/example")
if prompt == nil {
t.Fatalf("Prompt 'test-namespace/example' was not registered")
}

nestedPrompt := LookupPrompt(reg, "test-namespace/nested")
if nestedPrompt == nil {
t.Fatalf("Nested prompt 'test-namespace/nested' was not registered")
}
}

func TestLoadPromptFS_WithVariant(t *testing.T) {
mockPromptContent := `---
model: test/chat
description: A test prompt with variant
---

Hello from variant!
`

fsys := fstest.MapFS{
"prompts/greeting.experimental.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)},
}

reg := registry.New()

LoadPromptDirFromFS(reg, fsys, "prompts", "")

prompt := LookupPrompt(reg, "greeting.experimental")
if prompt == nil {
t.Fatalf("Prompt with variant 'greeting.experimental' was not registered")
}
}

func TestLoadPromptFS_NilFS(t *testing.T) {
reg := registry.New()

defer func() {
if r := recover(); r == nil {
t.Errorf("Expected panic for nil filesystem")
}
}()

LoadPromptDirFromFS(reg, nil, "prompts", "test-namespace")
}

func TestLoadPromptFS_InvalidRoot(t *testing.T) {
fsys := fstest.MapFS{
"other/example.prompt": &fstest.MapFile{Data: []byte("test")},
}

reg := registry.New()

defer func() {
if r := recover(); r == nil {
t.Errorf("Expected panic for invalid root directory")
}
}()

LoadPromptDirFromFS(reg, fsys, "nonexistent", "test-namespace")
}

func TestLoadPromptFromFS(t *testing.T) {
mockPromptContent := `---
model: test/chat
description: A single prompt test
---

Test content
`

fsys := fstest.MapFS{
"prompts/single.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)},
}

reg := registry.New()

prompt := LoadPromptFromFS(reg, fsys, "prompts", "single.prompt", "ns")
if prompt == nil {
t.Fatalf("LoadPromptFromFS failed to load prompt")
}

lookedUp := LookupPrompt(reg, "ns/single")
if lookedUp == nil {
t.Fatalf("Prompt 'ns/single' was not registered")
}
}

Expand Down Expand Up @@ -1206,7 +1329,7 @@ Hello!
ConfigureFormats(reg)
definePromptModel(reg)

prompt := LoadPrompt(reg, tempDir, "example.prompt", "multi-namespace")
prompt := LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.prompt", "multi-namespace")

_, err = prompt.Execute(context.Background())
if err != nil {
Expand All @@ -1233,7 +1356,7 @@ Hello!
t.Fatalf("Failed to create mock prompt file: %v", err)
}

prompt := LoadPrompt(registry.New(), tempDir, "example.prompt", "multi-namespace-roles")
prompt := LoadPromptFromFS(registry.New(), os.DirFS(tempDir), ".", "example.prompt", "multi-namespace-roles")

actionOpts, err := prompt.Render(context.Background(), map[string]any{})
if err != nil {
Expand Down
Loading
Loading