diff --git a/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go b/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go new file mode 100644 index 00000000..d05d41a1 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go @@ -0,0 +1,134 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// collidingNamedFields is an IDL shape where two distinct field names normalize to the +// same Go identifier via tools.ToCamelUpper (foo_bar and fooBar -> FooBar). Struct +// generation deconflicts these as FooBar and FooBar1; marshal/unmarshal must use the same names. +func collidingNamedFields() idl.IdlDefinedFieldsNamed { + return idl.IdlDefinedFieldsNamed{ + {Name: "foo_bar", Ty: &idltype.U8{}}, + {Name: "fooBar", Ty: &idltype.U8{}}, + } +} + +func TestGenerateUniqueFieldNames_collidingIDLNames(t *testing.T) { + fields := collidingNamedFields() + m := generateUniqueFieldNames(fields) + require.Len(t, m, 2) + assert.Equal(t, "FooBar", m["foo_bar"]) + assert.Equal(t, "FooBar1", m["fooBar"]) +} + +// TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames documents the regression where +// gen_MarshalWithEncoder_struct / gen_UnmarshalWithDecoder_struct used tools.ToCamelUpper(field.Name) +// for accessors instead of generateUniqueFieldNames: both fields targeted obj.FooBar, so one +// value was serialized twice and the FooBar1 sibling was never written or read. +// +// The expected assertions describe the correct fixed behavior; they fail until uniquified names +// are threaded through marshal/unmarshal generation. +func TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames(t *testing.T) { + idlMinimal := &idl.Idl{} + fields := collidingNamedFields() + receiver := "CollideAccount" + + marshalCode := gen_MarshalWithEncoder_struct( + idlMinimal, + false, + receiver, + "", + fields, + true, + ) + unmarshalCode := gen_UnmarshalWithDecoder_struct( + idlMinimal, + false, + receiver, + "", + fields, + ) + + f := jen.NewFile("fixture") + f.Add(marshalCode) + f.Add(unmarshalCode) + src := f.GoString() + + // Correct codegen must reference both uniquified struct fields. + assert.Contains(t, src, "obj.FooBar1", "marshal/unmarshal must access the deconflicted FooBar1 field") + + // Buggy codegen encodes/decodes the same field twice; reject duplicate bare obj.FooBar + // Encode/Decode when a second distinct IDL field exists. + encodeFooBar := strings.Count(src, "Encode(obj.FooBar)") + decodeFooBar := strings.Count(src, "Decode(&obj.FooBar)") + assert.Equal(t, 1, encodeFooBar, "each IDL field must map to a single Encode(obj.); duplicate Encode(obj.FooBar) indicates silent corruption") + assert.Equal(t, 1, decodeFooBar, "each IDL field must map to a single Decode(&obj.); duplicate Decode(&obj.FooBar) indicates silent corruption") + + assert.Contains(t, src, "Encode(obj.FooBar1)") + assert.Contains(t, src, "Decode(&obj.FooBar1)") +} + +func TestGenerateUniqueParamNames_collidingIDLNames(t *testing.T) { + fields := collidingNamedFields() + m := generateUniqueParamNames(fields) + require.Len(t, m, 2) + assert.NotEqual(t, m[fields[0].Name], m[fields[1].Name]) + b0 := formatParamName(fields[0].Name) + b1 := formatParamName(fields[1].Name) + if b0 == b1 { + assert.Equal(t, b0+"1", m[fields[1].Name]) + } +} + +func TestGenInstructionType_uniquifiesArgFieldsAndDecode(t *testing.T) { + ins := idl.IdlInstruction{ + Name: "do_test", + Args: []idl.IdlField(collidingNamedFields()), + Accounts: []idl.IdlInstructionAccountItem{}, + } + g := &Generator{idl: &idl.Idl{}, options: &GeneratorOptions{Package: "test"}} + code, err := g.gen_instructionType(ins) + require.NoError(t, err) + + f := jen.NewFile("test") + f.Add(code) + src := f.GoString() + + assert.Contains(t, src, "FooBar1") + assert.Equal(t, 1, strings.Count(src, "Decode(&obj.FooBar)")) + assert.Contains(t, src, "Decode(&obj.FooBar1)") +} + +func TestGenInstructions_builderEncodesEachArgOnce(t *testing.T) { + fields := collidingNamedFields() + paramNames := generateUniqueParamNames(fields) + + idlData := &idl.Idl{ + Instructions: []idl.IdlInstruction{ + { + Name: "do_test", + Args: []idl.IdlField(fields), + Accounts: []idl.IdlInstructionAccountItem{}, + }, + }, + } + gen := &Generator{idl: idlData, options: &GeneratorOptions{Package: "test"}} + out, err := gen.gen_instructions() + require.NoError(t, err) + s := out.File.GoString() + + p0 := paramNames[fields[0].Name] + p1 := paramNames[fields[1].Name] + assert.Equal(t, 1, strings.Count(s, "Encode("+p0+")"), "each arg must be encoded exactly once") + assert.Equal(t, 1, strings.Count(s, "Encode("+p1+")")) + assert.NotEqual(t, p0, p1) +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/instructions.go b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go index 2f8123ba..a6c7e799 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/instructions.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go @@ -17,6 +17,7 @@ func (g *Generator) gen_instructions() (*OutputFile, error) { file.HeaderComment("This file contains instructions and instruction parsers.") { for _, instruction := range g.idl.Instructions { + uniqueParamNames := generateUniqueParamNames(instruction.Args) ixCode := Empty() { declarerName := newInstructionFuncName(instruction.Name) @@ -47,7 +48,7 @@ func (g *Generator) gen_instructions() (*OutputFile, error) { if IsOption(param.Ty) || IsCOption(param.Ty) { paramType = Op("*").Add(paramType) } - paramsCode.Id(formatParamName(param.Name)).Add(paramType) + paramsCode.Id(uniqueParamNames[param.Name]).Add(paramType) } }, ), @@ -140,12 +141,12 @@ func (g *Generator) gen_instructions() (*OutputFile, error) { instruction.Args, checkNil, func(param idl.IdlField) *Statement { - return Id(formatParamName(param.Name)) + return Id(uniqueParamNames[param.Name]) }, "enc__", true, // returnNilErr func(param idl.IdlField) string { - return formatParamName(param.Name) + return uniqueParamNames[param.Name] }, ) }) @@ -306,6 +307,30 @@ func formatParamName(paramName string) string { return tools.ToCamelLower(paramName) } +// generateUniqueParamNames creates unique Go parameter names for instruction arguments, +// mirroring generateUniqueFieldNames but using formatParamName as the base identifier +// (builder params use a different convention than struct field names). +func generateUniqueParamNames(fields []idl.IdlField) map[string]string { + fieldNameMap := make(map[string]string) + usedNames := make(map[string]int) + + for _, field := range fields { + baseName := formatParamName(field.Name) + finalName := baseName + + if count, exists := usedNames[baseName]; exists { + finalName = baseName + fmt.Sprintf("%d", count+1) + usedNames[baseName] = count + 1 + } else { + usedNames[baseName] = 0 + } + + fieldNameMap[field.Name] = finalName + } + + return fieldNameMap +} + func newInstructionFuncName(instructionName string) string { // Check if the instruction name already ends with "instruction" (case-insensitive) instructionNameLower := strings.ToLower(instructionName) @@ -477,6 +502,8 @@ func (g *Generator) gen_instructionParser(typeNames []string, discriminatorNames func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, error) { code := Empty() + uniqueArgFieldNames := generateUniqueFieldNames(instruction.Args) + // Check if the instruction name already ends with "instruction" (case-insensitive) instructionNameLower := strings.ToLower(instruction.Name) var typeName string @@ -496,7 +523,7 @@ func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, e if IsOption(arg.Ty) || IsCOption(arg.Ty) { fieldType = Op("*").Add(fieldType) } - structGroup.Id(tools.ToCamelUpper(arg.Name)).Add(fieldType).Tag(map[string]string{ + structGroup.Id(uniqueArgFieldNames[arg.Name]).Add(fieldType).Tag(map[string]string{ "json": arg.Name, }) } @@ -581,7 +608,7 @@ func (g *Generator) gen_instructionType(instruction idl.IdlInstruction) (Code, e ) } for _, arg := range instruction.Args { - fieldName := tools.ToCamelUpper(arg.Name) + fieldName := uniqueArgFieldNames[arg.Name] block.Commentf("Deserialize `%s`:", fieldName) if IsOption(arg.Ty) || IsCOption(arg.Ty) { diff --git a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go index d5dd14a2..d06ed940 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go @@ -7,7 +7,6 @@ import ( . "github.com/dave/jennifer/jen" "github.com/gagliardetto/anchor-go/idl" "github.com/gagliardetto/anchor-go/idl/idltype" - "github.com/gagliardetto/anchor-go/tools" ) func gen_MarshalWithEncoder_struct( @@ -45,32 +44,34 @@ func gen_MarshalWithEncoder_struct( } switch fields := fields.(type) { case idl.IdlDefinedFieldsNamed: + uniqueFieldNames := generateUniqueFieldNames(fields) gen_marshal_DefinedFieldsNamed( body, fields, checkNil, func(field idl.IdlField) *Statement { - return Id("obj").Dot(tools.ToCamelUpper(field.Name)) + return Id("obj").Dot(uniqueFieldNames[field.Name]) }, "encoder", false, // returnNilErr func(field idl.IdlField) string { - return tools.ToCamelUpper(field.Name) + return uniqueFieldNames[field.Name] }, ) case idl.IdlDefinedFieldsTuple: convertedFields := tupleToFieldsNamed(fields) + uniqueFieldNames := generateUniqueFieldNames(convertedFields) gen_marshal_DefinedFieldsNamed( body, convertedFields, checkNil, func(field idl.IdlField) *Statement { - return Id("obj").Dot(tools.ToCamelUpper(field.Name)) + return Id("obj").Dot(uniqueFieldNames[field.Name]) }, "encoder", false, // returnNilErr func(field idl.IdlField) string { - return tools.ToCamelUpper(field.Name) + return uniqueFieldNames[field.Name] }, ) case nil: diff --git a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go index 5e7b6a35..f9d7cb01 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go @@ -118,10 +118,10 @@ func gen_UnmarshalWithDecoder_struct( switch fields := fields.(type) { case idl.IdlDefinedFieldsNamed: - gen_unmarshal_DefinedFieldsNamed(body, fields) + gen_unmarshal_DefinedFieldsNamed(body, fields, generateUniqueFieldNames(fields)) case idl.IdlDefinedFieldsTuple: convertedFields := tupleToFieldsNamed(fields) - gen_unmarshal_DefinedFieldsNamed(body, convertedFields) + gen_unmarshal_DefinedFieldsNamed(body, convertedFields, generateUniqueFieldNames(convertedFields)) case nil: // No fields, just an empty struct. // TODO: should we panic here? @@ -229,9 +229,11 @@ func tupleToFieldsNamed( func gen_unmarshal_DefinedFieldsNamed( body *Group, fields idl.IdlDefinedFieldsNamed, + uniqueFieldNames map[string]string, ) { for _, field := range fields { - exportedArgName := tools.ToCamelUpper(field.Name) + goFieldName := uniqueFieldNames[field.Name] + exportedArgName := goFieldName if IsOption(field.Ty) || IsCOption(field.Ty) { body.Commentf("Deserialize `%s` (optional):", exportedArgName) } else { @@ -248,7 +250,7 @@ func gen_unmarshal_DefinedFieldsNamed( { argBody.Var().Err().Error() argBody.List( - Id("obj").Dot(exportedArgName), + Id("obj").Dot(goFieldName), Err(), ).Op("=").Id(formatEnumParserName(enumName)).Call(Id("decoder")) } @@ -264,11 +266,11 @@ func gen_unmarshal_DefinedFieldsNamed( // Read the array items: argBody.For( Id("i").Op(":=").Lit(0), - Id("i").Op("<").Len(Id("obj").Dot(exportedArgName)), + Id("i").Op("<").Len(Id("obj").Dot(goFieldName)), Id("i").Op("++"), ).BlockFunc(func(forBody *Group) { forBody.List( - Id("obj").Dot(exportedArgName).Index(Id("i")), + Id("obj").Dot(goFieldName).Index(Id("i")), Err(), ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")) forBody.If(Err().Op("!=").Nil()).Block( @@ -301,7 +303,7 @@ func gen_unmarshal_DefinedFieldsNamed( ), ) // Create the vector: - argBody.Id("obj").Dot(exportedArgName).Op("=").Make(Index().Id(enumTypeName), Id("vecLen")) + argBody.Id("obj").Dot(goFieldName).Op("=").Make(Index().Id(enumTypeName), Id("vecLen")) // Read the vector items: argBody.For( Id("i").Op(":=").Lit(0), @@ -309,7 +311,7 @@ func gen_unmarshal_DefinedFieldsNamed( Id("i").Op("++"), ).BlockFunc(func(forBody *Group) { forBody.List( - Id("obj").Dot(exportedArgName).Index(Id("i")), + Id("obj").Dot(goFieldName).Index(Id("i")), Err(), ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")) forBody.If(Err().Op("!=").Nil()).Block( @@ -351,7 +353,7 @@ func gen_unmarshal_DefinedFieldsNamed( ), ) optGroup.If(Id("ok")).Block( - Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(exportedArgName)), + Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(goFieldName)), If(Err().Op("!=").Nil()).Block( Return( Qual(PkgAnchorGoErrors, "NewField").Call( @@ -363,7 +365,7 @@ func gen_unmarshal_DefinedFieldsNamed( ) }) } else { - body.Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(exportedArgName)) + body.Err().Op("=").Id("decoder").Dot("Decode").Call(Op("&").Id("obj").Dot(goFieldName)) body.If(Err().Op("!=").Nil()).Block( Return( Qual(PkgAnchorGoErrors, "NewField").Call(