Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 142 additions & 32 deletions github/gen-accessors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@

//go:build ignore

// gen-accessors generates accessor methods for structs with pointer fields.
// gen-accessors generates accessor methods for all struct fields.
// This is so that interfaces can be easily crafted by users of this repo
// within their own code bases.
// See https://github.com/google/go-github/issues/4059 for details.
//
// It is meant to be used by go-github contributors in conjunction with the
// go generate tool before sending a PR to GitHub.
// Please see the CONTRIBUTING.md file for more information.
//
// Usage:
//
// go run gen-accessors.go [-v [file1.go file2.go ...]]
package main

import (
Expand Down Expand Up @@ -39,14 +46,15 @@ var (

// skipStructMethods lists "struct.method" combos to skip.
skipStructMethods = map[string]bool{
"RepositoryContent.GetContent": true,
"AbuseRateLimitError.GetResponse": true,
"Client.GetBaseURL": true,
"Client.GetUploadURL": true,
"ErrorResponse.GetResponse": true,
"RateLimitError.GetResponse": true,
"AbuseRateLimitError.GetResponse": true,
"MarketplaceService.GetStubbed": true,
"PackageVersion.GetBody": true,
"PackageVersion.GetMetadata": true,
"RateLimitError.GetResponse": true,
"RepositoryContent.GetContent": true,
}
// skipStructs lists structs to skip.
skipStructs = map[string]bool{
Expand All @@ -67,6 +75,18 @@ func logf(fmt string, args ...any) {

func main() {
flag.Parse()

// For debugging purposes, processing just a single or a few files is helpful:
var processOnly map[string]bool
if *verbose { // Only create the map if args are provided.
for _, arg := range flag.Args() {
if processOnly == nil {
processOnly = map[string]bool{}
}
processOnly[arg] = true
}
}

fset := token.NewFileSet()

pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
Expand All @@ -83,6 +103,10 @@ func main() {
Imports: map[string]string{},
}
for filename, f := range pkg.Files {
if *verbose && processOnly != nil && !processOnly[filename] {
continue
}

logf("Processing %v...", filename)
if err := t.processAST(f); err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -116,8 +140,12 @@ func (t *templateData) processAST(f *ast.File) error {
logf("Struct %v is in skip list; skipping.", ts.Name)
continue
}
if _, ok := ts.Type.(*ast.Ident); ok { // e.g. type SomeService service
continue
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
logf("Skipping TypeSpec of type %T", ts.Type)
continue
}
for _, field := range st.Fields.List {
Expand All @@ -141,14 +169,21 @@ func (t *templateData) processAST(f *ast.File) error {
if !ok {
switch x := field.Type.(type) {
case *ast.MapType:
logf("processAST: addMapType(x, %q, %q)", ts.Name.String(), fieldName.String())
t.addMapType(x, ts.Name.String(), fieldName.String(), false)
continue
case *ast.ArrayType:
if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] {
logf("Method %v is whitelist; adding getter method.", key)
t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
continue
}
logf("processAST: addArrayType(x, %q, %q)", ts.Name.String(), fieldName.String())
t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
continue
case *ast.Ident:
logf("processAST: addSimpleValueIdent(x, %q, %q)", ts.Name.String(), fieldName.String())
t.addSimpleValueIdent(x, ts.Name.String(), fieldName.String())
continue
case *ast.SelectorExpr:
logf("processAST: addSimpleValueSelectorExpr(x, %q, %q)", ts.Name.String(), fieldName.String())
t.addSimpleValueSelectorExpr(x, ts.Name.String(), fieldName.String())
continue
}

logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
Expand Down Expand Up @@ -254,24 +289,62 @@ func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName st
t.Getters = append(t.Getters, ng)
}

func (t *templateData) addSimpleValueIdent(x *ast.Ident, receiverType, fieldName string) {
getter := genIdentGetter(x, receiverType, fieldName)
getter.IsSimpleValue = true
logf("addSimpleValueIdent: Processing %q - fieldName=%q, getter.ZeroValue=%q, x.Obj=%#v", x.String(), fieldName, getter.ZeroValue, x.Obj)
if getter.ZeroValue == "nil" {
if x.Obj == nil {
switch x.String() {
case "any": // NOOP - leave as `nil`
default:
getter.ZeroValue = x.String() + "{}"
}
} else {
if ts, ok := x.Obj.Decl.(*ast.TypeSpec); ok {
logf("addSimpleValueIdent: Processing %q of type %T", x.String(), ts.Type)
switch xX := ts.Type.(type) {
case *ast.Ident:
logf("addSimpleValueIdent: Processing %q of type %T - zero value is %q", x.String(), ts.Type, getter.ZeroValue)
getter.ZeroValue = zeroValueOfIdent(xX)
case *ast.StructType:
getter.ZeroValue = x.String() + "{}"
logf("addSimpleValueIdent: Processing %q of type %T - zero value is %q", x.String(), ts.Type, getter.ZeroValue)
case *ast.InterfaceType, *ast.ArrayType: // NOOP - leave as `nil`
logf("addSimpleValueIdent: Processing %q of type %T - zero value is %q", x.String(), ts.Type, getter.ZeroValue)
default:
log.Fatalf("addSimpleValueIdent: unhandled case %T", xX)
}
}
}
}
t.Getters = append(t.Getters, getter)
}

func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
var zeroValue string
var namedStruct bool
getter := genIdentGetter(x, receiverType, fieldName)
t.Getters = append(t.Getters, getter)
}

func zeroValueOfIdent(x *ast.Ident) string {
switch x.String() {
case "int", "int64":
zeroValue = "0"
case "int", "int64", "float64", "uint8", "uint16":
return "0"
case "string":
zeroValue = `""`
return `""`
case "bool":
zeroValue = "false"
return "false"
case "Timestamp":
zeroValue = "Timestamp{}"
return "Timestamp{}"
default:
zeroValue = "nil"
namedStruct = true
return "nil"
}
}

t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
func genIdentGetter(x *ast.Ident, receiverType, fieldName string) *getter {
zeroValue := zeroValueOfIdent(x)
namedStruct := zeroValue == "nil"
return newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)
}

func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
Expand Down Expand Up @@ -300,10 +373,28 @@ func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string
t.Getters = append(t.Getters, ng)
}

func (t *templateData) addSimpleValueSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
getter := t.genSelectorExprGetter(x, receiverType, fieldName)
if getter == nil {
return
}
getter.IsSimpleValue = true
logf("addSimpleValueSelectorExpr: Processing field name %q - %#v - zero value is %q", fieldName, x, getter.ZeroValue)
t.Getters = append(t.Getters, getter)
}

func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
getter := t.genSelectorExprGetter(x, receiverType, fieldName)
if getter == nil {
return
}
t.Getters = append(t.Getters, getter)
}

func (t *templateData) genSelectorExprGetter(x *ast.SelectorExpr, receiverType, fieldName string) *getter {
if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
return nil
}

var xX string
if xx, ok := x.X.(*ast.Ident); ok {
Expand All @@ -322,10 +413,12 @@ func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldN
if xX == "time" && x.Sel.Name == "Duration" {
zeroValue = "0"
}
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
return newGetter(receiverType, fieldName, fieldType, zeroValue, false)
default:
logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
}

return nil
}

type templateData struct {
Expand All @@ -337,15 +430,16 @@ type templateData struct {
}

type getter struct {
sortVal string // Lower-case version of "ReceiverType.FieldName".
ReceiverVar string // The one-letter variable name to match the ReceiverType.
ReceiverType string
FieldName string
FieldType string
ZeroValue string
NamedStruct bool // Getter for named struct.
MapType bool
ArrayType bool
sortVal string // Lower-case version of "ReceiverType.FieldName".
ReceiverVar string // The one-letter variable name to match the ReceiverType.
ReceiverType string
FieldName string
FieldType string
ZeroValue string
NamedStruct bool // Getter for named struct.
MapType bool
ArrayType bool
IsSimpleValue bool
}

const source = `// Code generated by gen-accessors; DO NOT EDIT.
Expand All @@ -366,7 +460,15 @@ import (
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
{{if .IsSimpleValue}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
if {{.ReceiverVar}} == nil {
return {{.ZeroValue}}
}
return {{.ReceiverVar}}.{{.FieldName}}
}
{{else if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
if {{.ReceiverVar}} == nil {
Expand Down Expand Up @@ -413,7 +515,15 @@ import (
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
{{if .IsSimpleValue}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
tt.Parallel()
{{.ReceiverVar}} := &{{.ReceiverType}}{}
{{.ReceiverVar}}.Get{{.FieldName}}()
{{.ReceiverVar}} = nil
{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{else if .NamedStruct}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
tt.Parallel()
{{.ReceiverVar}} := &{{.ReceiverType}}{}
Expand Down
Loading
Loading