diff --git a/cmd/avrogo/generate.go b/cmd/avrogo/generate.go index 7fb212f..d20bcad 100644 --- a/cmd/avrogo/generate.go +++ b/cmd/avrogo/generate.go @@ -10,9 +10,6 @@ import ( "strconv" "strings" - "golang.org/x/text/cases" - "golang.org/x/text/language" - "github.com/actgardner/gogen-avro/v10/parser" "github.com/actgardner/gogen-avro/v10/schema" ) @@ -45,7 +42,7 @@ func shouldImportAvroTypeGen(namespace *parser.Namespace, definitions []schema.Q return false } -func generate(w io.Writer, pkg string, ns *parser.Namespace, definitions []schema.QualifiedName) error { +func generate(w io.Writer, pkg string, caser func(string) string, ns *parser.Namespace, definitions []schema.QualifiedName) error { extTypes, err := externalTypeMap(ns) if err != nil { return err @@ -63,6 +60,7 @@ func generate(w io.Writer, pkg string, ns *parser.Namespace, definitions []schem gc := &generateContext{ imports: make(map[string]string), extTypes: extTypes, + caser: caser, } // Add avrotypegen package conditionally when there is a RecordDefinition in the namespace. if shouldImportAvroTypeGen(ns, definitions) { @@ -405,7 +403,7 @@ func (gc *generateContext) defaultFuncLiteral(v interface{}, t schema.AvroType) if err != nil { return "", fmt.Errorf("at field %s: %v", field.Name(), err) } - ident, err := goName(field.Name()) + ident, err := gc.goName(field.Name()) if err != nil { return "", err } @@ -423,10 +421,10 @@ func (gc *generateContext) defaultFuncLiteral(v interface{}, t schema.AvroType) } // goName returns an exported Go identifier for the Avro name s. -func goName(s string) (string, error) { +func (gc *generateContext) goName(s string) (string, error) { lastIndex := strings.LastIndex(s, ".") name := s[lastIndex+1:] - name = cases.Title(language.Und, cases.NoLower).String(strings.Trim(name, "_")) + name = gc.caser(name) if !isExportedGoIdentifier(name) { return "", fmt.Errorf("cannot form an exported Go identifier from %q", s) } @@ -489,6 +487,7 @@ func writeUnionComment(w io.Writer, union []typeInfo, indent string) { type generateContext struct { imports map[string]string extTypes map[schema.QualifiedName]goType + caser func(string) string } func (gc *generateContext) GoTypeOf(t schema.AvroType) typeInfo { diff --git a/cmd/avrogo/generate_test.go b/cmd/avrogo/generate_test.go index 2f978e0..2530f5c 100644 --- a/cmd/avrogo/generate_test.go +++ b/cmd/avrogo/generate_test.go @@ -88,7 +88,133 @@ func TestGenerate(t *testing.T) { ns, fileDefinitions, err := parseFiles([]string{"testdata/schema/object.avsc"}) assert.NoError(t, err) - err = generate(&buf, testPackage, ns, fileDefinitions[0]) + err = generate(&buf, testPackage, getCaser(), ns, fileDefinitions[0]) assert.NoError(t, err) g.Assert(t, "object", buf.Bytes()) } + +func TestGoName(t *testing.T) { + var testcases = []struct { + testName string + goInitialisms bool + extraInitialisms string + avroName string + goName string + }{ + { + testName: "default naming", + goInitialisms: false, + avroName: "user.first_name", + goName: "First_name", + }, + { + testName: "Go initialisms", + goInitialisms: true, + avroName: "user.first_name", + goName: "FirstName", + }, + { + testName: "default naming with ID", + goInitialisms: false, + avroName: "user.user_id", + goName: "User_id", + }, + { + testName: "Go initialisms with ID", + goInitialisms: true, + avroName: "user.user_id", + goName: "UserID", + }, + { + testName: "Go initialisms without extra initialisms", + goInitialisms: true, + avroName: "power.power_mw", + goName: "PowerMw", + }, + { + testName: "Go initialisms with extra initialisms", + goInitialisms: true, + extraInitialisms: "KW,MW,GW", + avroName: "power.power_mw", + goName: "PowerMW", + }, + } + + c := qt.New(t) + + for _, test := range testcases { + c.Run(test.testName, func(c *qt.C) { + goInitialismsFlag = &test.goInitialisms + extraInitialismsFlag = &test.extraInitialisms + gc := generateContext{caser: getCaser()} + + gotName, err := gc.goName(test.avroName) + c.Assert(err, qt.IsNil) + c.Assert(gotName, qt.Equals, test.goName) + }) + } +} + +func TestSymbolName(t *testing.T) { + enumDef := avro.NewEnumDefinition(avro.QualifiedName{Namespace: "ns", Name: "name"}, nil, nil, "", "", nil) + + var testcases = []struct { + testName string + goInitialisms bool + extraInitialisms string + symbol string + goName string + }{ + { + testName: "default naming", + goInitialisms: false, + symbol: "OPTION_ONE", + goName: "NameOPTION_ONE", + }, + { + testName: "Go initialisms", + goInitialisms: true, + symbol: "OPTION_ONE", + goName: "NameOptionOne", + }, + { + testName: "default naming with ID", + goInitialisms: false, + symbol: "OPTION_ID", + goName: "NameOPTION_ID", + }, + { + testName: "Go initialisms with ID", + goInitialisms: true, + symbol: "OPTION_ID", + goName: "NameOptionID", + }, + { + testName: "Go initialisms without extra initialisms", + goInitialisms: true, + symbol: "OPTION_TWO", + goName: "NameOptionTwo", + }, + { + testName: "Go initialisms with extra initialisms", + goInitialisms: true, + extraInitialisms: "ONE,TWO,THREE", + symbol: "OPTION_TWO", + goName: "NameOptionTWO", + }, + } + + c := qt.New(t) + + for _, test := range testcases { + c.Run(test.testName, func(c *qt.C) { + goInitialismsFlag = &test.goInitialisms + extraInitialismsFlag = &test.extraInitialisms + gc := generateContext{caser: getCaser()} + + gotName := symbolName(&gc, enumDef, test.symbol) + c.Assert(gotName, qt.Equals, test.goName) + }) + } +} + diff --git a/cmd/avrogo/main.go b/cmd/avrogo/main.go index 447980b..a9d5ac6 100644 --- a/cmd/avrogo/main.go +++ b/cmd/avrogo/main.go @@ -19,6 +19,10 @@ // suffix for generated files (default "_gen") // -tokenize // if true, generate one dedicated file per qualified name found in the schema files +// -goinitialisms +// if true, use standard Go initialisms in names +// -extrainitialisms +// comma separated list of initialisms to use in names in addition to standard Go initialisms // // By default, a type is generated for each Avro definition // in the schema. Some additional metadata fields are @@ -42,6 +46,9 @@ import ( "github.com/actgardner/gogen-avro/v10/parser" "github.com/actgardner/gogen-avro/v10/resolver" "github.com/actgardner/gogen-avro/v10/schema" + "github.com/ettle/strcase" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) // Generate the tests. @@ -49,11 +56,13 @@ import ( //go:generate go run ./generatetestcode.go var ( - dirFlag = flag.String("d", ".", "directory to write Go files to") - pkgFlag = flag.String("p", os.Getenv("GOPACKAGE"), "package name (defaults to $GOPACKAGE)") - testFlag = flag.Bool("t", strings.HasSuffix(os.Getenv("GOFILE"), "_test.go"), "generated files will have _test.go suffix (defaults to true if $GOFILE is a test file)") - suffixFlag = flag.String("s", "_gen", "suffix for generated files") - tokenizeFlag = flag.Bool("tokenize", false, "generate one dedicated file per qualified name found in the input schema files") + dirFlag = flag.String("d", ".", "directory to write Go files to") + pkgFlag = flag.String("p", os.Getenv("GOPACKAGE"), "package name (defaults to $GOPACKAGE)") + testFlag = flag.Bool("t", strings.HasSuffix(os.Getenv("GOFILE"), "_test.go"), "generated files will have _test.go suffix (defaults to true if $GOFILE is a test file)") + suffixFlag = flag.String("s", "_gen", "suffix for generated files") + tokenizeFlag = flag.Bool("tokenize", false, "generate one dedicated file per qualified name found in the input schema files") + goInitialismsFlag = flag.Bool("goinitialisms", false, "use standard Go initialisms in names") + extraInitialismsFlag = flag.String("extrainitialisms", "", "comma separated list of initialisms to use in names in addition to standard Go initialisms") ) var flag = stdflag.NewFlagSet("", stdflag.ContinueOnError) @@ -79,6 +88,10 @@ func main1() int { fmt.Fprintf(os.Stderr, "avrogo: -p flag must specify a package name or set $GOPACKAGE\n") return 1 } + if *extraInitialismsFlag != "" && !*goInitialismsFlag { + fmt.Fprintf(os.Stderr, "avrogo: -extrainitialisms flag must only be used with -goinitialisms\n") + return 1 + } if err := generateFiles(files); err != nil { fmt.Fprintf(os.Stderr, "avrogo: %v\n", err) return 1 @@ -196,9 +209,30 @@ func baseN(name string, n int) (string, bool) { return strings.Join(parts, "_"), ok } +// getCaser returns a function to be used to set the case of names in generated +// Go. The behaviour is determined by flags. +func getCaser() func(string) string { + if *goInitialismsFlag { + return strcase.NewCaser(true, parseExtraInitialismsFlag(), nil).ToPascal + } + return func(name string) string { + return cases.Title(language.Und, cases.NoLower).String(strings.Trim(name, "_")) + } +} + +func parseExtraInitialismsFlag() map[string]bool { + result := map[string]bool{} + for _, initial := range strings.Split(*extraInitialismsFlag, ",") { + if initial != "" { + result[initial] = true + } + } + return result +} + func generateFile(outFile string, ns *parser.Namespace, definitions []schema.QualifiedName) error { var buf bytes.Buffer - if err := generate(&buf, *pkgFlag, ns, definitions); err != nil { + if err := generate(&buf, *pkgFlag, getCaser(), ns, definitions); err != nil { return err } if buf.Len() == 0 { diff --git a/cmd/avrogo/template.go b/cmd/avrogo/template.go index 998167e..cc22631 100644 --- a/cmd/avrogo/template.go +++ b/cmd/avrogo/template.go @@ -2,8 +2,6 @@ package main import ( "go/token" - "golang.org/x/text/cases" - "golang.org/x/text/language" "reflect" "regexp" "strconv" @@ -27,9 +25,11 @@ var templateFuncs = template.FuncMap{ "isExportedGoIdentifier": isExportedGoIdentifier, "defName": defName, "symbolName": symbolName, - "goName": goName, - "indent": indent, - "doc": doc, + "goName": func(gc *generateContext, name string) (string, error) { + return gc.goName(name) + }, + "indent": indent, + "doc": doc, "import": func(gc *generateContext, pkg string) string { gc.addImport(pkg) return "" @@ -73,7 +73,7 @@ var bodyTemplate = newTemplate(` «- if isExportedGoIdentifier .Name» «- .Name» «$type.GoType» «- else» - «- goName .Name» «$type.GoType» ` + "`" + `json:«printf "%q" .Name»` + "`" + ` + «- goName $.Ctx .Name» «$type.GoType» ` + "`" + `json:«printf "%q" .Name»` + "`" + ` «- end» «end» } @@ -89,7 +89,7 @@ var bodyTemplate = newTemplate(` type «defName .» int const ( «- range $i, $sym := .Symbols» - «symbolName $def $sym»«if eq $i 0» «defName $def» = iota«end» + «symbolName $.Ctx $def $sym»«if eq $i 0» «defName $def» = iota«end» «- end» ) @@ -141,8 +141,8 @@ func defName(def schema.Definition) string { return goTypeForDefinition(def).Name } -func symbolName(e *schema.EnumDefinition, symbol string) string { - return defName(e) + cases.Title(language.Und, cases.NoLower).String(symbol) +func symbolName(gc *generateContext, e *schema.EnumDefinition, symbol string) string { + return defName(e) + gc.caser(symbol) } func quote(s string) string { diff --git a/go.mod b/go.mod index cdc54b8..145e139 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.0 require ( github.com/actgardner/gogen-avro/v10 v10.2.1 + github.com/ettle/strcase v0.2.0 github.com/frankban/quicktest v1.14.0 github.com/google/uuid v1.6.0 github.com/kr/pretty v0.3.0 diff --git a/go.sum b/go.sum index a3539eb..c55e49a 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/ettle/strcase v0.2.0 h1:fGNiVF21fHXpX1niBgk0aROov1LagYsOwV/xqKDKR/Q= +github.com/ettle/strcase v0.2.0/go.mod h1:DajmHElDSaX76ITe3/VHVyMin4LWSJN5Z909Wp+ED1A= github.com/frankban/quicktest v1.2.2/go.mod h1:Qh/WofXFeiAFII1aEBu529AtJo6Zg2VHscnEsbBnJ20= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y=