diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go index a5bd2bf5..eb787fa3 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go @@ -18,6 +18,16 @@ func isComplexEnum(envel idltype.IdlType) bool { return false } +func isOptionalComplexEnum(ty idltype.IdlType) bool { + switch v := ty.(type) { + case *idltype.Option: + return isComplexEnum(v.Option) + case *idltype.COption: + return isComplexEnum(v.COption) + } + return false +} + func register_TypeName_as_ComplexEnum(name string) { typeRegistryComplexEnum[name] = struct{}{} } diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go b/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go new file mode 100644 index 00000000..002a9f06 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go @@ -0,0 +1,106 @@ +//nolint:all // Forked from anchor-go generator, maintaining original code structure +package generator + +import ( + "strings" + "testing" + + "github.com/dave/jennifer/jen" + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/stretchr/testify/assert" +) + +// complexEnumGuard mirrors the condition used in gen_marshal_DefinedFieldsNamed +// and gen_unmarshal_DefinedFieldsNamed to decide whether a field is routed to +// the specialized enum encoder/parser or falls through to the generic +// Encode/Decode path. +func complexEnumGuard(ty idltype.IdlType) bool { + return isComplexEnum(ty) || + (IsArray(ty) && isComplexEnum(ty.(*idltype.Array).Type)) || + (IsVec(ty) && isComplexEnum(ty.(*idltype.Vec).Vec)) || + isOptionalComplexEnum(ty) +} + +func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) { + const name = "Outcome" + register_TypeName_as_ComplexEnum(name) + t.Cleanup(func() { delete(typeRegistryComplexEnum, name) }) + + defined := &idltype.Defined{Name: name} + + assert.True(t, complexEnumGuard(defined), "bare Defined") + assert.True(t, complexEnumGuard(&idltype.Option{Option: defined}), "Option") + assert.True(t, complexEnumGuard(&idltype.COption{COption: defined}), "COption") +} + +// TestComplexEnumGuard_rejectsNonComplexOptionals ensures the guard does NOT +// fire for Option/COption wrapping a non-complex Defined or a primitive. +// A false positive here would cause the switch to enter the Option/COption case +// where .Option.(*idltype.Defined) would panic on a non-Defined inner type. +func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) { + const complexName = "Outcome" + register_TypeName_as_ComplexEnum(complexName) + t.Cleanup(func() { delete(typeRegistryComplexEnum, complexName) }) + + nonComplex := &idltype.Defined{Name: "PlainStruct"} + + assert.False(t, complexEnumGuard(&idltype.Option{Option: nonComplex}), + "Option must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(&idltype.COption{COption: nonComplex}), + "COption must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.U64{}}), + "Option must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(&idltype.COption{COption: &idltype.U8{}}), + "COption must not trigger the complex-enum path") + assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}), + "Option> — nested containers not supported, must not match") +} + +// TestComplexEnumCodegen_optionalComplexEnum runs the actual marshal/unmarshal +// generator with Option and COption fields and +// verifies the generated Go source uses the specialized enum encoder/parser +// instead of the generic Encode/Decode. +func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) { + const enumName = "Outcome" + register_TypeName_as_ComplexEnum(enumName) + t.Cleanup(func() { delete(typeRegistryComplexEnum, enumName) }) + + fields := idl.IdlDefinedFieldsNamed{ + {Name: "id", Ty: &idltype.U64{}}, + {Name: "verdict", Ty: &idltype.Option{Option: &idltype.Defined{Name: enumName}}}, + {Name: "alt_verdict", Ty: &idltype.COption{COption: &idltype.Defined{Name: enumName}}}, + {Name: "checksum", Ty: &idltype.U64{}}, + } + + marshalCode := gen_MarshalWithEncoder_struct( + &idl.Idl{}, false, "Report", "", fields, true, + ) + unmarshalCode := gen_UnmarshalWithDecoder_struct( + &idl.Idl{}, false, "Report", "", fields, + ) + + f := jen.NewFile("fixture") + f.Add(marshalCode) + f.Add(unmarshalCode) + src := f.GoString() + + // Specialized enum encoder/parser must appear. + assert.Contains(t, src, "EncodeOutcome", + "Option/COption fields must call the specialized enum encoder") + assert.Contains(t, src, "DecodeOutcome", + "Option/COption fields must call the specialized enum parser") + + // Option flags must still be written/read. + assert.Contains(t, src, "WriteOption") + assert.Contains(t, src, "WriteCOption") + assert.Contains(t, src, "ReadOption") + assert.Contains(t, src, "ReadCOption") + + // Only the two plain U64 fields (Id, Checksum) should use the generic + // encoder/decoder. If the enum fields also fall through, the count is 4. + assert.Equal(t, 2, strings.Count(src, ".Encode("), + "generic Encode must only be used for non-enum fields") + assert.Equal(t, 2, strings.Count(src, ".Decode("), + "generic Decode must only be used for non-enum fields") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go index d5dd14a2..468167a2 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go @@ -152,7 +152,7 @@ func gen_marshal_DefinedFieldsNamed( body.Commentf("Serialize `%s`:", exportedArgName) } - if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) { + if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) { switch field.Ty.(type) { case *idltype.Defined: enumTypeName := field.Ty.(*idltype.Defined).Name @@ -260,6 +260,12 @@ func gen_marshal_DefinedFieldsNamed( ) }) }) + case *idltype.Option: + enumTypeName := field.Ty.(*idltype.Option).Option.(*idltype.Defined).Name + gen_marshal_optionalComplexEnum(body, "WriteOption", enumTypeName, field, checkNil, nameFormatter, encoderVariableName, returnNilErr, exportedArgName) + case *idltype.COption: + enumTypeName := field.Ty.(*idltype.COption).COption.(*idltype.Defined).Name + gen_marshal_optionalComplexEnum(body, "WriteCOption", enumTypeName, field, checkNil, nameFormatter, encoderVariableName, returnNilErr, exportedArgName) } } else { if IsOption(field.Ty) || IsCOption(field.Ty) { @@ -380,3 +386,58 @@ func gen_marshal_DefinedFieldsNamed( } } } + +func gen_marshal_optionalComplexEnum( + body *Group, + optionalityWriterName string, + enumTypeName string, + field idl.IdlField, + checkNil bool, + nameFormatter func(field idl.IdlField) *Statement, + encoderVariableName string, + returnNilErr bool, + exportedArgName string, +) { + errReturn := func(wrapped Code) *Statement { + return ReturnFunc(func(returnBody *Group) { + if returnNilErr { + returnBody.Nil() + } + returnBody.Add(wrapped) + }) + } + optionalityErr := func() *Statement { + return errReturn( + Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call(Lit("error while encoding optionality: %w"), Err()), + ), + ) + } + fieldErr := func() *Statement { + return errReturn( + Qual(PkgAnchorGoErrors, "NewField").Call(Lit(exportedArgName), Err()), + ) + } + + if checkNil { + body.BlockFunc(func(optGroup *Group) { + optGroup.If(nameFormatter(field).Op("==").Nil()).Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(False()), + If(Err().Op("!=").Nil()).Block(optionalityErr()), + ).Else().Block( + Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()), + If(Err().Op("!=").Nil()).Block(optionalityErr()), + Err().Op("=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)), + If(Err().Op("!=").Nil()).Block(fieldErr()), + ) + }) + } else { + body.BlockFunc(func(optGroup *Group) { + optGroup.Err().Op("=").Id(encoderVariableName).Dot(optionalityWriterName).Call(True()) + optGroup.If(Err().Op("!=").Nil()).Block(optionalityErr()) + optGroup.Err().Op("=").Id(formatEnumEncoderName(enumTypeName)).Call(Id(encoderVariableName), nameFormatter(field)) + optGroup.If(Err().Op("!=").Nil()).Block(fieldErr()) + }) + } +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go index 5e7b6a35..784e832a 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go @@ -238,9 +238,7 @@ func gen_unmarshal_DefinedFieldsNamed( body.Commentf("Deserialize `%s`:", exportedArgName) } - if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) { - // TODO: this assumes this cannot be an option; - // - check whether this is an option? + if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) { switch field.Ty.(type) { case *idltype.Defined: enumName := field.Ty.(*idltype.Defined).Name @@ -325,6 +323,12 @@ func gen_unmarshal_DefinedFieldsNamed( ) }) }) + case *idltype.Option: + enumTypeName := field.Ty.(*idltype.Option).Option.(*idltype.Defined).Name + gen_unmarshal_optionalComplexEnum(body, "ReadOption", enumTypeName, exportedArgName) + case *idltype.COption: + enumTypeName := field.Ty.(*idltype.COption).COption.(*idltype.Defined).Name + gen_unmarshal_optionalComplexEnum(body, "ReadCOption", enumTypeName, exportedArgName) } } else { if IsOption(field.Ty) || IsCOption(field.Ty) { @@ -376,3 +380,39 @@ func gen_unmarshal_DefinedFieldsNamed( } } } + +func gen_unmarshal_optionalComplexEnum( + body *Group, + optionalityReaderName string, + enumTypeName string, + exportedArgName string, +) { + body.BlockFunc(func(optGroup *Group) { + optGroup.List(Id("ok"), Err()).Op(":=").Id("decoder").Dot(optionalityReaderName).Call() + optGroup.If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewOption").Call( + Lit(exportedArgName), + Qual("fmt", "Errorf").Call( + Lit("error while reading optionality: %w"), + Err(), + ), + ), + ), + ) + optGroup.If(Id("ok")).Block( + List( + Id("obj").Dot(exportedArgName), + Err(), + ).Op("=").Id(formatEnumParserName(enumTypeName)).Call(Id("decoder")), + If(Err().Op("!=").Nil()).Block( + Return( + Qual(PkgAnchorGoErrors, "NewField").Call( + Lit(exportedArgName), + Err(), + ), + ), + ), + ) + }) +}