Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//nolint:all // Forked from anchor-go generator, maintaining original code structure
package generator

import (
"strings"
"testing"

"github.com/dave/jennifer/jen"
"github.com/gagliardetto/anchor-go/idl"
"github.com/gagliardetto/anchor-go/idl/idltype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// collidingNamedFields is an IDL shape where two distinct field names normalize to the
// same Go identifier via tools.ToCamelUpper (foo_bar and fooBar -> FooBar). Struct
// generation deconflicts these as FooBar and FooBar1; marshal/unmarshal must use the same names.
func collidingNamedFields() idl.IdlDefinedFieldsNamed {
return idl.IdlDefinedFieldsNamed{
{Name: "foo_bar", Ty: &idltype.U8{}},
{Name: "fooBar", Ty: &idltype.U8{}},
}
}

func TestGenerateUniqueFieldNames_collidingIDLNames(t *testing.T) {
fields := collidingNamedFields()
m := generateUniqueFieldNames(fields)
require.Len(t, m, 2)
assert.Equal(t, "FooBar", m["foo_bar"])
assert.Equal(t, "FooBar1", m["fooBar"])
}

// TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames documents the regression where
// gen_MarshalWithEncoder_struct / gen_UnmarshalWithDecoder_struct used tools.ToCamelUpper(field.Name)
// for accessors instead of generateUniqueFieldNames: both fields targeted obj.FooBar, so one
// value was serialized twice and the FooBar1 sibling was never written or read.
//
// The expected assertions describe the correct fixed behavior; they fail until uniquified names
// are threaded through marshal/unmarshal generation.
func TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames(t *testing.T) {
idlMinimal := &idl.Idl{}
fields := collidingNamedFields()
receiver := "CollideAccount"

marshalCode := gen_MarshalWithEncoder_struct(
idlMinimal,
false,
receiver,
"",
fields,
true,
)
unmarshalCode := gen_UnmarshalWithDecoder_struct(
idlMinimal,
false,
receiver,
"",
fields,
)

f := jen.NewFile("fixture")
f.Add(marshalCode)
f.Add(unmarshalCode)
src := f.GoString()

// Correct codegen must reference both uniquified struct fields.
assert.Contains(t, src, "obj.FooBar1", "marshal/unmarshal must access the deconflicted FooBar1 field")

// Buggy codegen encodes/decodes the same field twice; reject duplicate bare obj.FooBar
// Encode/Decode when a second distinct IDL field exists.
encodeFooBar := strings.Count(src, "Encode(obj.FooBar)")
decodeFooBar := strings.Count(src, "Decode(&obj.FooBar)")
assert.Equal(t, 1, encodeFooBar, "each IDL field must map to a single Encode(obj.<Field>); duplicate Encode(obj.FooBar) indicates silent corruption")
assert.Equal(t, 1, decodeFooBar, "each IDL field must map to a single Decode(&obj.<Field>); duplicate Decode(&obj.FooBar) indicates silent corruption")

assert.Contains(t, src, "Encode(obj.FooBar1)")
assert.Contains(t, src, "Decode(&obj.FooBar1)")
}

func TestGenerateUniqueParamNames_collidingIDLNames(t *testing.T) {
fields := collidingNamedFields()
m := generateUniqueParamNames(fields)
require.Len(t, m, 2)
assert.NotEqual(t, m[fields[0].Name], m[fields[1].Name])
b0 := formatParamName(fields[0].Name)
b1 := formatParamName(fields[1].Name)
if b0 == b1 {
assert.Equal(t, b0+"1", m[fields[1].Name])
}
}

func TestGenInstructionType_uniquifiesArgFieldsAndDecode(t *testing.T) {
ins := idl.IdlInstruction{
Name: "do_test",
Args: []idl.IdlField(collidingNamedFields()),
Accounts: []idl.IdlInstructionAccountItem{},
}
g := &Generator{idl: &idl.Idl{}, options: &GeneratorOptions{Package: "test"}}
code, err := g.gen_instructionType(ins)
require.NoError(t, err)

f := jen.NewFile("test")
f.Add(code)
src := f.GoString()

assert.Contains(t, src, "FooBar1")
assert.Equal(t, 1, strings.Count(src, "Decode(&obj.FooBar)"))
assert.Contains(t, src, "Decode(&obj.FooBar1)")
}

func TestGenInstructions_builderEncodesEachArgOnce(t *testing.T) {
fields := collidingNamedFields()
paramNames := generateUniqueParamNames(fields)

idlData := &idl.Idl{
Instructions: []idl.IdlInstruction{
{
Name: "do_test",
Args: []idl.IdlField(fields),
Accounts: []idl.IdlInstructionAccountItem{},
},
},
}
gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}}
out, err := gen.gen_instructions()
require.NoError(t, err)
s := out.File.GoString()

p0 := paramNames[fields[0].Name]
p1 := paramNames[fields[1].Name]
assert.Equal(t, 1, strings.Count(s, "Encode("+p0+")"), "each arg must be encoded exactly once")
assert.Equal(t, 1, strings.Count(s, "Encode("+p1+")"))
assert.NotEqual(t, p0, p1)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func (g *Generator) gen_instructions() (*OutputFile, error) {
file.HeaderComment("This file contains instructions and instruction parsers.")
{
for _, instruction := range g.idl.Instructions {
uniqueParamNames := generateUniqueParamNames(instruction.Args)
ixCode := Empty()
{
declarerName := newInstructionFuncName(instruction.Name)
Expand Down Expand Up @@ -47,7 +48,7 @@ func (g *Generator) gen_instructions() (*OutputFile, error) {
if IsOption(param.Ty) || IsCOption(param.Ty) {
paramType = Op("*").Add(paramType)
}
paramsCode.Id(formatParamName(param.Name)).Add(paramType)
paramsCode.Id(uniqueParamNames[param.Name]).Add(paramType)
}
},
),
Expand Down Expand Up @@ -140,12 +141,12 @@ func (g *Generator) gen_instructions() (*OutputFile, error) {
instruction.Args,
checkNil,
func(param idl.IdlField) *Statement {
return Id(formatParamName(param.Name))
return Id(uniqueParamNames[param.Name])
},
"enc__",
true, // returnNilErr
func(param idl.IdlField) string {
return formatParamName(param.Name)
return uniqueParamNames[param.Name]
},
)
})
Expand Down Expand Up @@ -306,6 +307,30 @@ func formatParamName(paramName string) string {
return tools.ToCamelLower(paramName)
}

// generateUniqueParamNames creates unique Go parameter names for instruction arguments,
// mirroring generateUniqueFieldNames but using formatParamName as the base identifier
// (builder params use a different convention than struct field names).
func generateUniqueParamNames(fields []idl.IdlField) map[string]string {
fieldNameMap := make(map[string]string)
usedNames := make(map[string]int)

for _, field := range fields {
baseName := formatParamName(field.Name)
finalName := baseName

if count, exists := usedNames[baseName]; exists {
finalName = baseName + fmt.Sprintf("%d", count+1)
usedNames[baseName] = count + 1
} else {
usedNames[baseName] = 0
}

fieldNameMap[field.Name] = finalName
}

return fieldNameMap
}

func newInstructionFuncName(instructionName string) string {
// Check if the instruction name already ends with "instruction" (case-insensitive)
instructionNameLower := strings.ToLower(instructionName)
Expand Down Expand Up @@ -477,6 +502,8 @@ func (g *Generator) gen_instructionParser(typeNames []string, discriminatorNames
func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, error) {
code := Empty()

uniqueArgFieldNames := generateUniqueFieldNames(instruction.Args)

// Check if the instruction name already ends with "instruction" (case-insensitive)
instructionNameLower := strings.ToLower(instruction.Name)
var typeName string
Expand All @@ -496,7 +523,7 @@ func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, e
if IsOption(arg.Ty) || IsCOption(arg.Ty) {
fieldType = Op("*").Add(fieldType)
}
structGroup.Id(tools.ToCamelUpper(arg.Name)).Add(fieldType).Tag(map[string]string{
structGroup.Id(uniqueArgFieldNames[arg.Name]).Add(fieldType).Tag(map[string]string{
"json": arg.Name,
})
}
Expand Down Expand Up @@ -581,7 +608,7 @@ func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, e
)
}
for _, arg := range instruction.Args {
fieldName := tools.ToCamelUpper(arg.Name)
fieldName := uniqueArgFieldNames[arg.Name]
block.Commentf("Deserialize `%s`:", fieldName)

if IsOption(arg.Ty) || IsCOption(arg.Ty) {
Expand Down
11 changes: 6 additions & 5 deletions cmd/generate-bindings/solana/anchor-go/generator/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
. "github.com/dave/jennifer/jen"
"github.com/gagliardetto/anchor-go/idl"
"github.com/gagliardetto/anchor-go/idl/idltype"
"github.com/gagliardetto/anchor-go/tools"
)

func gen_MarshalWithEncoder_struct(
Expand Down Expand Up @@ -45,32 +44,34 @@ func gen_MarshalWithEncoder_struct(
}
switch fields := fields.(type) {
case idl.IdlDefinedFieldsNamed:
uniqueFieldNames := generateUniqueFieldNames(fields)
gen_marshal_DefinedFieldsNamed(
body,
fields,
checkNil,
func(field idl.IdlField) *Statement {
return Id("obj").Dot(tools.ToCamelUpper(field.Name))
return Id("obj").Dot(uniqueFieldNames[field.Name])
},
"encoder",
false, // returnNilErr
func(field idl.IdlField) string {
return tools.ToCamelUpper(field.Name)
return uniqueFieldNames[field.Name]
},
)
case idl.IdlDefinedFieldsTuple:
convertedFields := tupleToFieldsNamed(fields)
uniqueFieldNames := generateUniqueFieldNames(convertedFields)
gen_marshal_DefinedFieldsNamed(
body,
convertedFields,
checkNil,
func(field idl.IdlField) *Statement {
return Id("obj").Dot(tools.ToCamelUpper(field.Name))
return Id("obj").Dot(uniqueFieldNames[field.Name])
},
"encoder",
false, // returnNilErr
func(field idl.IdlField) string {
return tools.ToCamelUpper(field.Name)
return uniqueFieldNames[field.Name]
},
)
case nil:
Expand Down
22 changes: 12 additions & 10 deletions cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ func gen_UnmarshalWithDecoder_struct(

switch fields := fields.(type) {
case idl.IdlDefinedFieldsNamed:
gen_unmarshal_DefinedFieldsNamed(body, fields)
gen_unmarshal_DefinedFieldsNamed(body, fields, generateUniqueFieldNames(fields))
case idl.IdlDefinedFieldsTuple:
convertedFields := tupleToFieldsNamed(fields)
gen_unmarshal_DefinedFieldsNamed(body, convertedFields)
gen_unmarshal_DefinedFieldsNamed(body, convertedFields, generateUniqueFieldNames(convertedFields))
case nil:
// No fields, just an empty struct.
// TODO: should we panic here?
Expand Down Expand Up @@ -229,9 +229,11 @@ func tupleToFieldsNamed(
func gen_unmarshal_DefinedFieldsNamed(
body *Group,
fields idl.IdlDefinedFieldsNamed,
uniqueFieldNames map[string]string,
) {
for _, field := range fields {
exportedArgName := tools.ToCamelUpper(field.Name)
goFieldName := uniqueFieldNames[field.Name]
exportedArgName := goFieldName
if IsOption(field.Ty) || IsCOption(field.Ty) {
body.Commentf("Deserialize `%s` (optional):", exportedArgName)
} else {
Expand All @@ -248,7 +250,7 @@ func gen_unmarshal_DefinedFieldsNamed(
{
argBody.Var().Err().Error()
argBody.List(
Id("obj").Dot(exportedArgName),
Id("obj").Dot(goFieldName),
Err(),
).Op("=").Id(formatEnumParserName(enumName)).Call(Id("decoder"))
}
Expand All @@ -264,11 +266,11 @@ func gen_unmarshal_DefinedFieldsNamed(
// Read the array items:
argBody.For(
Id("i").Op(":=").Lit(0),
Id("i").Op("<").Len(Id("obj").Dot(exportedArgName)),
Id("i").Op("<").Len(Id("obj").Dot(goFieldName)),
Id("i").Op("++"),
).BlockFunc(func(forBody *Group) {
forBody.List(
Id("obj").Dot(exportedArgName).Index(Id("i")),
Id("obj").Dot(goFieldName).Index(Id("i")),
Err(),
).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder"))
forBody.If(Err().Op("!=").Nil()).Block(
Expand Down Expand Up @@ -301,15 +303,15 @@ func gen_unmarshal_DefinedFieldsNamed(
),
)
// Create the vector:
argBody.Id("obj").Dot(exportedArgName).Op("=").Make(Index().Id(enumTypeName), Id("vecLen"))
argBody.Id("obj").Dot(goFieldName).Op("=").Make(Index().Id(enumTypeName), Id("vecLen"))
// Read the vector items:
argBody.For(
Id("i").Op(":=").Lit(0),
Id("i").Op("<").Id("vecLen"),
Id("i").Op("++"),
).BlockFunc(func(forBody *Group) {
forBody.List(
Id("obj").Dot(exportedArgName).Index(Id("i")),
Id("obj").Dot(goFieldName).Index(Id("i")),
Err(),
).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder"))
forBody.If(Err().Op("!=").Nil()).Block(
Expand Down Expand Up @@ -351,7 +353,7 @@ func gen_unmarshal_DefinedFieldsNamed(
),
)
optGroup.If(Id("ok")).Block(
Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(exportedArgName)),
Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(goFieldName)),
If(Err().Op("!=").Nil()).Block(
Return(
Qual(PkgAnchorGoErrors, "NewField").Call(
Expand All @@ -363,7 +365,7 @@ func gen_unmarshal_DefinedFieldsNamed(
)
})
} else {
body.Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(exportedArgName))
body.Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(goFieldName))
body.If(Err().Op("!=").Nil()).Block(
Return(
Qual(PkgAnchorGoErrors, "NewField").Call(
Expand Down