Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions cmd/avrogo/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ import (
"strconv"
"strings"

"golang.org/x/text/cases"
"golang.org/x/text/language"

"github.com/actgardner/gogen-avro/v10/parser"
"github.com/actgardner/gogen-avro/v10/schema"
)
Expand Down Expand Up @@ -45,7 +42,7 @@ func shouldImportAvroTypeGen(namespace *parser.Namespace, definitions []schema.Q
return false
}

func generate(w io.Writer, pkg string, ns *parser.Namespace, definitions []schema.QualifiedName) error {
func generate(w io.Writer, pkg string, caser func(string) string, ns *parser.Namespace, definitions []schema.QualifiedName) error {
extTypes, err := externalTypeMap(ns)
if err != nil {
return err
Expand All @@ -63,6 +60,7 @@ func generate(w io.Writer, pkg string, ns *parser.Namespace, definitions []schem
gc := &generateContext{
imports: make(map[string]string),
extTypes: extTypes,
caser: caser,
}
// Add avrotypegen package conditionally when there is a RecordDefinition in the namespace.
if shouldImportAvroTypeGen(ns, definitions) {
Expand Down Expand Up @@ -405,7 +403,7 @@ func (gc *generateContext) defaultFuncLiteral(v interface{}, t schema.AvroType)
if err != nil {
return "", fmt.Errorf("at field %s: %v", field.Name(), err)
}
ident, err := goName(field.Name())
ident, err := gc.goName(field.Name())
if err != nil {
return "", err
}
Expand All @@ -423,10 +421,10 @@ func (gc *generateContext) defaultFuncLiteral(v interface{}, t schema.AvroType)
}

// goName returns an exported Go identifier for the Avro name s.
func goName(s string) (string, error) {
func (gc *generateContext) goName(s string) (string, error) {
lastIndex := strings.LastIndex(s, ".")
name := s[lastIndex+1:]
name = cases.Title(language.Und, cases.NoLower).String(strings.Trim(name, "_"))
name = gc.caser(name)
if !isExportedGoIdentifier(name) {
return "", fmt.Errorf("cannot form an exported Go identifier from %q", s)
}
Expand Down Expand Up @@ -489,6 +487,7 @@ func writeUnionComment(w io.Writer, union []typeInfo, indent string) {
type generateContext struct {
imports map[string]string
extTypes map[schema.QualifiedName]goType
caser func(string) string
}

func (gc *generateContext) GoTypeOf(t schema.AvroType) typeInfo {
Expand Down
128 changes: 127 additions & 1 deletion cmd/avrogo/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,133 @@ func TestGenerate(t *testing.T) {
ns, fileDefinitions, err := parseFiles([]string{"testdata/schema/object.avsc"})
assert.NoError(t, err)

err = generate(&buf, testPackage, ns, fileDefinitions[0])
err = generate(&buf, testPackage, getCaser(), ns, fileDefinitions[0])
assert.NoError(t, err)
g.Assert(t, "object", buf.Bytes())
}

func TestGoName(t *testing.T) {
var testcases = []struct {
testName string
goInitialisms bool
extraInitialisms string
avroName string
goName string
}{
{
testName: "default naming",
goInitialisms: false,
avroName: "user.first_name",
goName: "First_name",
},
{
testName: "Go initialisms",
goInitialisms: true,
avroName: "user.first_name",
goName: "FirstName",
},
{
testName: "default naming with ID",
goInitialisms: false,
avroName: "user.user_id",
goName: "User_id",
},
{
testName: "Go initialisms with ID",
goInitialisms: true,
avroName: "user.user_id",
goName: "UserID",
},
{
testName: "Go initialisms without extra initialisms",
goInitialisms: true,
avroName: "power.power_mw",
goName: "PowerMw",
},
{
testName: "Go initialisms with extra initialisms",
goInitialisms: true,
extraInitialisms: "KW,MW,GW",
avroName: "power.power_mw",
goName: "PowerMW",
},
}

c := qt.New(t)

for _, test := range testcases {
c.Run(test.testName, func(c *qt.C) {
goInitialismsFlag = &test.goInitialisms
extraInitialismsFlag = &test.extraInitialisms
gc := generateContext{caser: getCaser()}

gotName, err := gc.goName(test.avroName)
c.Assert(err, qt.IsNil)
c.Assert(gotName, qt.Equals, test.goName)
})
}
}

func TestSymbolName(t *testing.T) {
enumDef := avro.NewEnumDefinition(avro.QualifiedName{Namespace: "ns", Name: "name"}, nil, nil, "", "", nil)

var testcases = []struct {
testName string
goInitialisms bool
extraInitialisms string
symbol string
goName string
}{
{
testName: "default naming",
goInitialisms: false,
symbol: "OPTION_ONE",
goName: "NameOPTION_ONE",
},
{
testName: "Go initialisms",
goInitialisms: true,
symbol: "OPTION_ONE",
goName: "NameOptionOne",
},
{
testName: "default naming with ID",
goInitialisms: false,
symbol: "OPTION_ID",
goName: "NameOPTION_ID",
},
{
testName: "Go initialisms with ID",
goInitialisms: true,
symbol: "OPTION_ID",
goName: "NameOptionID",
},
{
testName: "Go initialisms without extra initialisms",
goInitialisms: true,
symbol: "OPTION_TWO",
goName: "NameOptionTwo",
},
{
testName: "Go initialisms with extra initialisms",
goInitialisms: true,
extraInitialisms: "ONE,TWO,THREE",
symbol: "OPTION_TWO",
goName: "NameOptionTWO",
},
}

c := qt.New(t)

for _, test := range testcases {
c.Run(test.testName, func(c *qt.C) {
goInitialismsFlag = &test.goInitialisms
extraInitialismsFlag = &test.extraInitialisms
gc := generateContext{caser: getCaser()}

gotName := symbolName(&gc, enumDef, test.symbol)
c.Assert(gotName, qt.Equals, test.goName)
})
}
}

46 changes: 40 additions & 6 deletions cmd/avrogo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
// suffix for generated files (default "_gen")
// -tokenize
// if true, generate one dedicated file per qualified name found in the schema files
// -goinitialisms
// if true, use standard Go initialisms in names
// -extrainitialisms
// comma separated list of initialisms to use in names in addition to standard Go initialisms
//
// By default, a type is generated for each Avro definition
// in the schema. Some additional metadata fields are
Expand All @@ -42,18 +46,23 @@ import (
"github.com/actgardner/gogen-avro/v10/parser"
"github.com/actgardner/gogen-avro/v10/resolver"
"github.com/actgardner/gogen-avro/v10/schema"
"github.com/ettle/strcase"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)

// Generate the tests.

//go:generate go run ./generatetestcode.go

var (
dirFlag = flag.String("d", ".", "directory to write Go files to")
pkgFlag = flag.String("p", os.Getenv("GOPACKAGE"), "package name (defaults to $GOPACKAGE)")
testFlag = flag.Bool("t", strings.HasSuffix(os.Getenv("GOFILE"), "_test.go"), "generated files will have _test.go suffix (defaults to true if $GOFILE is a test file)")
suffixFlag = flag.String("s", "_gen", "suffix for generated files")
tokenizeFlag = flag.Bool("tokenize", false, "generate one dedicated file per qualified name found in the input schema files")
dirFlag = flag.String("d", ".", "directory to write Go files to")
pkgFlag = flag.String("p", os.Getenv("GOPACKAGE"), "package name (defaults to $GOPACKAGE)")
testFlag = flag.Bool("t", strings.HasSuffix(os.Getenv("GOFILE"), "_test.go"), "generated files will have _test.go suffix (defaults to true if $GOFILE is a test file)")
suffixFlag = flag.String("s", "_gen", "suffix for generated files")
tokenizeFlag = flag.Bool("tokenize", false, "generate one dedicated file per qualified name found in the input schema files")
goInitialismsFlag = flag.Bool("goinitialisms", false, "use standard Go initialisms in names")
extraInitialismsFlag = flag.String("extrainitialisms", "", "comma separated list of initialisms to use in names in addition to standard Go initialisms")
)

var flag = stdflag.NewFlagSet("", stdflag.ContinueOnError)
Expand All @@ -79,6 +88,10 @@ func main1() int {
fmt.Fprintf(os.Stderr, "avrogo: -p flag must specify a package name or set $GOPACKAGE\n")
return 1
}
if *extraInitialismsFlag != "" && !*goInitialismsFlag {
fmt.Fprintf(os.Stderr, "avrogo: -extrainitialisms flag must only be used with -goinitialisms\n")
return 1
}
if err := generateFiles(files); err != nil {
fmt.Fprintf(os.Stderr, "avrogo: %v\n", err)
return 1
Expand Down Expand Up @@ -196,9 +209,30 @@ func baseN(name string, n int) (string, bool) {
return strings.Join(parts, "_"), ok
}

// getCaser returns a function to be used to set the case of names in generated
// Go. The behaviour is determined by flags.
func getCaser() func(string) string {
if *goInitialismsFlag {
return strcase.NewCaser(true, parseExtraInitialismsFlag(), nil).ToPascal
}
return func(name string) string {
return cases.Title(language.Und, cases.NoLower).String(strings.Trim(name, "_"))
}
}

func parseExtraInitialismsFlag() map[string]bool {
result := map[string]bool{}
for _, initial := range strings.Split(*extraInitialismsFlag, ",") {
if initial != "" {
result[initial] = true
}
}
return result
}

func generateFile(outFile string, ns *parser.Namespace, definitions []schema.QualifiedName) error {
var buf bytes.Buffer
if err := generate(&buf, *pkgFlag, ns, definitions); err != nil {
if err := generate(&buf, *pkgFlag, getCaser(), ns, definitions); err != nil {
return err
}
if buf.Len() == 0 {
Expand Down
18 changes: 9 additions & 9 deletions cmd/avrogo/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package main

import (
"go/token"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"reflect"
"regexp"
"strconv"
Expand All @@ -27,9 +25,11 @@ var templateFuncs = template.FuncMap{
"isExportedGoIdentifier": isExportedGoIdentifier,
"defName": defName,
"symbolName": symbolName,
"goName": goName,
"indent": indent,
"doc": doc,
"goName": func(gc *generateContext, name string) (string, error) {
return gc.goName(name)
},
"indent": indent,
"doc": doc,
"import": func(gc *generateContext, pkg string) string {
gc.addImport(pkg)
return ""
Expand Down Expand Up @@ -73,7 +73,7 @@ var bodyTemplate = newTemplate(`
«- if isExportedGoIdentifier .Name»
«- .Name» «$type.GoType»
«- else»
«- goName .Name» «$type.GoType» ` + "`" + `json:«printf "%q" .Name»` + "`" + `
«- goName $.Ctx .Name» «$type.GoType» ` + "`" + `json:«printf "%q" .Name»` + "`" + `
«- end»
«end»
}
Expand All @@ -89,7 +89,7 @@ var bodyTemplate = newTemplate(`
type «defName .» int
const (
«- range $i, $sym := .Symbols»
«symbolName $def $sym»«if eq $i 0» «defName $def» = iota«end»
«symbolName $.Ctx $def $sym»«if eq $i 0» «defName $def» = iota«end»
«- end»
)

Expand Down Expand Up @@ -141,8 +141,8 @@ func defName(def schema.Definition) string {
return goTypeForDefinition(def).Name
}

func symbolName(e *schema.EnumDefinition, symbol string) string {
return defName(e) + cases.Title(language.Und, cases.NoLower).String(symbol)
func symbolName(gc *generateContext, e *schema.EnumDefinition, symbol string) string {
return defName(e) + gc.caser(symbol)
}

func quote(s string) string {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.23.0

require (
github.com/actgardner/gogen-avro/v10 v10.2.1
github.com/ettle/strcase v0.2.0
github.com/frankban/quicktest v1.14.0
github.com/google/uuid v1.6.0
github.com/kr/pretty v0.3.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ettle/strcase v0.2.0 h1:fGNiVF21fHXpX1niBgk0aROov1LagYsOwV/xqKDKR/Q=
github.com/ettle/strcase v0.2.0/go.mod h1:DajmHElDSaX76ITe3/VHVyMin4LWSJN5Z909Wp+ED1A=
github.com/frankban/quicktest v1.2.2/go.mod h1:Qh/WofXFeiAFII1aEBu529AtJo6Zg2VHscnEsbBnJ20=
github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y=
Expand Down