diff --git a/protoc-gen-connect-python/generator/generator.go b/protoc-gen-connect-python/generator/generator.go index 8f4827a..7f5a3a6 100644 --- a/protoc-gen-connect-python/generator/generator.go +++ b/protoc-gen-connect-python/generator/generator.go @@ -2,65 +2,49 @@ package generator import ( "bytes" + "context" "fmt" "path" "slices" "strings" "unicode" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" + "github.com/bufbuild/protoplugin" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/pluginpb" ) -func Generate(r *pluginpb.CodeGeneratorRequest) *pluginpb.CodeGeneratorResponse { - resp := &pluginpb.CodeGeneratorResponse{} +func Handle(ctx context.Context, _ protoplugin.PluginEnv, responseWriter protoplugin.ResponseWriter, request protoplugin.Request) error { + responseWriter.SetFeatureProto3Optional() + responseWriter.SetFeatureSupportsEditions( + descriptorpb.Edition_EDITION_PROTO3, + descriptorpb.Edition_EDITION_2023, + ) - resp.SupportedFeatures = proto.Uint64(uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) | uint64(pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS)) - resp.MinimumEdition = proto.Int32(int32(descriptorpb.Edition_EDITION_PROTO3)) - resp.MaximumEdition = proto.Int32(int32(descriptorpb.Edition_EDITION_2023)) + conf := parseConfig(request.Parameter()) - conf := parseConfig(r.GetParameter()) - - files := r.GetFileToGenerate() - if len(files) == 0 { - resp.Error = proto.String("no files to generate") - return resp - } - - fds := &descriptorpb.FileDescriptorSet{ - File: r.GetProtoFile(), - } - reg, err := protodesc.NewFiles(fds) + fileDescriptors, err := request.FileDescriptorsToGenerate() if err != nil { - panic(err) + return fmt.Errorf("failed to get file descriptors to generate: %w", err) } - reg.RangeFiles(func(fd protoreflect.FileDescriptor) bool { - if !slices.Contains(files, string(fd.Path())) { - return true - } - + for _, fileDescriptor := range fileDescriptors { // We don't generate any code for non-services - if fd.Services().Len() == 0 { - return true + if fileDescriptor.Services().Len() == 0 { + continue } - connectFile, err := GenerateConnectFile(fd, conf) + name, content, err := generateConnectFile(fileDescriptor, conf) if err != nil { - resp.Error = proto.String("File[" + fd.Path() + "][generate]: " + err.Error()) - return false + return fmt.Errorf("failed to generate file %q: %w", fileDescriptor.Path(), err) } - resp.File = append(resp.File, connectFile) - return true - }) + responseWriter.AddFile(name, content) + } - return resp + return nil } -func GenerateConnectFile(fd protoreflect.FileDescriptor, conf Config) (*pluginpb.CodeGeneratorResponse_File, error) { +func generateConnectFile(fd protoreflect.FileDescriptor, conf Config) (string, string, error) { filename := fd.Path() fileNameWithoutSuffix := strings.TrimSuffix(filename, path.Ext(filename)) @@ -128,15 +112,11 @@ func GenerateConnectFile(fd protoreflect.FileDescriptor, conf Config) (*pluginpb var buf = &bytes.Buffer{} err := ConnectTemplate.Execute(buf, vars) if err != nil { - return nil, err - } - - resp := &pluginpb.CodeGeneratorResponse_File{ - Name: proto.String(strings.TrimSuffix(filename, path.Ext(filename)) + "_connect.py"), - Content: proto.String(buf.String()), + return "", "", fmt.Errorf("failed to execute template: %w", err) } - return resp, nil + outputName := strings.TrimSuffix(filename, path.Ext(filename)) + "_connect.py" + return outputName, buf.String(), nil } func sanitizePythonName(name string) string { diff --git a/protoc-gen-connect-python/generator/generator_test.go b/protoc-gen-connect-python/generator/generator_test.go index e15ed39..3bac8ad 100644 --- a/protoc-gen-connect-python/generator/generator_test.go +++ b/protoc-gen-connect-python/generator/generator_test.go @@ -1,9 +1,12 @@ package generator import ( + "bytes" + "io" "strings" "testing" + "github.com/bufbuild/protoplugin" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/types/descriptorpb" @@ -11,6 +14,8 @@ import ( ) func TestGenerateConnectFile(t *testing.T) { + t.Parallel() + tests := []struct { name string input *descriptorpb.FileDescriptorProto @@ -95,17 +100,17 @@ func TestGenerateConnectFile(t *testing.T) { t.Fatalf("Failed to create FileDescriptorProto: %v", err) return } - got, err := GenerateConnectFile(fd, Config{}) + gotName, gotContent, err := generateConnectFile(fd, Config{}) if (err != nil) != tt.wantErr { - t.Errorf("GenerateConnectFile() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("generateConnectFile() error = %v, wantErr %v", err, tt.wantErr) return } if err == nil { - if got.GetName() != tt.wantFile { - t.Errorf("GenerateConnectFile() got filename = %v, want %v", got.GetName(), tt.wantFile) + if gotName != tt.wantFile { + t.Errorf("generateConnectFile() got filename = %v, want %v", gotName, tt.wantFile) } - content := got.GetContent() + content := gotContent if !strings.Contains(content, "from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping") { t.Error("Generated code missing required imports") } @@ -118,6 +123,8 @@ func TestGenerateConnectFile(t *testing.T) { } func TestGenerate(t *testing.T) { + t.Parallel() + tests := []struct { name string req *pluginpb.CodeGeneratorRequest @@ -193,21 +200,21 @@ func TestGenerate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := Generate(tt.req) + resp := generate(t, tt.req) if tt.wantErr { if resp.GetError() == "" { - t.Error("Generate() expected error but got none") + t.Error("generate() expected error but got none") } } else { if resp.GetError() != "" { - t.Errorf("Generate() unexpected error: %v", resp.GetError()) + t.Errorf("generate() unexpected error: %v", resp.GetError()) } if len(resp.GetFile()) == 0 { - t.Error("Generate() returned no files") + t.Error("generate() returned no files") } for _, s := range tt.wantStrings { if !strings.Contains(resp.GetFile()[0].GetContent(), s) { - t.Errorf("Generate() missing expected string: %v", s) + t.Errorf("generate() missing expected string: %v", s) } } } @@ -216,6 +223,8 @@ func TestGenerate(t *testing.T) { } func TestEdition2023Support(t *testing.T) { + t.Parallel() + // Create a request with an Edition 2023 proto file edition2023 := descriptorpb.Edition_EDITION_2023 @@ -274,11 +283,11 @@ func TestEdition2023Support(t *testing.T) { } // Call Generate - resp := Generate(req) + resp := generate(t, req) // Verify no error occurred if resp.GetError() != "" { - t.Fatalf("Generate() failed for Edition 2023 proto: %v", resp.GetError()) + t.Fatalf("generate() failed for Edition 2023 proto: %v", resp.GetError()) } // Verify the generator declared Edition support @@ -310,3 +319,47 @@ func TestEdition2023Support(t *testing.T) { } } } + +// generate is a test helper that runs the plugin handler using [protoplugin.Run]. +func generate(t *testing.T, req *pluginpb.CodeGeneratorRequest) *pluginpb.CodeGeneratorResponse { + t.Helper() + + // Marshal request to bytes for stdin + reqBytes, err := proto.Marshal(req) + if err != nil { + resp := &pluginpb.CodeGeneratorResponse{} + resp.Error = proto.String("failed to marshal request: " + err.Error()) + return resp + } + + // Prepare stdin and stdout + stdin := bytes.NewReader(reqBytes) + stdout := &bytes.Buffer{} + + // Run the plugin + err = protoplugin.Run( + t.Context(), + protoplugin.Env{ + Args: nil, + Environ: nil, + Stdin: stdin, + Stdout: stdout, + Stderr: io.Discard, + }, + protoplugin.HandlerFunc(Handle), + ) + if err != nil { + resp := &pluginpb.CodeGeneratorResponse{} + resp.Error = proto.String("failed to run plugin: " + err.Error()) + return resp + } + + // Unmarshal response + resp := &pluginpb.CodeGeneratorResponse{} + if err := proto.Unmarshal(stdout.Bytes(), resp); err != nil { + errorResp := &pluginpb.CodeGeneratorResponse{} + errorResp.Error = proto.String("failed to unmarshal response: " + err.Error()) + return errorResp + } + return resp +} diff --git a/protoc-gen-connect-python/generator/template_test.go b/protoc-gen-connect-python/generator/template_test.go index ff614c4..ee57afc 100644 --- a/protoc-gen-connect-python/generator/template_test.go +++ b/protoc-gen-connect-python/generator/template_test.go @@ -7,6 +7,7 @@ import ( ) func TestConnectTemplate(t *testing.T) { + t.Parallel() tests := []struct { name string vars ConnectTemplateVariables @@ -73,6 +74,7 @@ func TestConnectTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() var buf bytes.Buffer err := ConnectTemplate.Execute(&buf, tt.vars) if err != nil { diff --git a/protoc-gen-connect-python/go.mod b/protoc-gen-connect-python/go.mod index dc683ae..db701de 100644 --- a/protoc-gen-connect-python/go.mod +++ b/protoc-gen-connect-python/go.mod @@ -2,4 +2,7 @@ module github.com/connectrpc/connect-python/protoc-gen-connect-python go 1.24.3 -require google.golang.org/protobuf v1.36.11 +require ( + github.com/bufbuild/protoplugin v0.0.0-20250218205857-750e09ce93e1 + google.golang.org/protobuf v1.36.11 +) diff --git a/protoc-gen-connect-python/go.sum b/protoc-gen-connect-python/go.sum index 296be18..3287406 100644 --- a/protoc-gen-connect-python/go.sum +++ b/protoc-gen-connect-python/go.sum @@ -1,4 +1,18 @@ +github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= +github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= +github.com/bufbuild/protoplugin v0.0.0-20250218205857-750e09ce93e1 h1:V1xulAoqLqVg44rY97xOR+mQpD2N+GzhMHVwJ030WEU= +github.com/bufbuild/protoplugin v0.0.0-20250218205857-750e09ce93e1/go.mod h1:c5D8gWRIZ2HLWO3gXYTtUfw/hbJyD8xikv2ooPxnklQ= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/protoc-gen-connect-python/main.go b/protoc-gen-connect-python/main.go index 6e49b62..911ddee 100644 --- a/protoc-gen-connect-python/main.go +++ b/protoc-gen-connect-python/main.go @@ -1,44 +1,11 @@ package main import ( - "io" - "log" - "os" - - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/pluginpb" + "github.com/bufbuild/protoplugin" "github.com/connectrpc/connect-python/protoc-gen-connect-python/generator" ) func main() { - data, err := io.ReadAll(os.Stdin) - if err != nil { - log.Fatalln("could not read from stdin", err) - return - } - req := &pluginpb.CodeGeneratorRequest{} - err = proto.Unmarshal(data, req) - if err != nil { - log.Fatalln("could not unmarshal proto", err) - return - } - if len(req.GetFileToGenerate()) == 0 { - log.Fatalln("no files to generate") - return - } - resp := generator.Generate(req) - - if resp == nil { - resp = &pluginpb.CodeGeneratorResponse{} - } - - data, err = proto.Marshal(resp) - if err != nil { - log.Fatalln("could not unmarshal response proto", err) - } - _, err = os.Stdout.Write(data) - if err != nil { - log.Fatalln("could not write response to stdout", err) - } + protoplugin.Main(protoplugin.HandlerFunc(generator.Handle)) } diff --git a/protoc-gen-connect-python/pyproject.toml b/protoc-gen-connect-python/pyproject.toml index 6c5b25f..08f496d 100644 --- a/protoc-gen-connect-python/pyproject.toml +++ b/protoc-gen-connect-python/pyproject.toml @@ -37,6 +37,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Software Development :: Code Generators", "Topic :: Software Development :: Compilers", "Topic :: System :: Networking", diff --git a/pyproject.toml b/pyproject.toml index ad206fe..fb373d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12",