diff --git a/ast/access.go b/ast/access.go index 86bb9afc9..6f95e6f34 100644 --- a/ast/access.go +++ b/ast/access.go @@ -33,6 +33,7 @@ import ( type Access interface { isAccess() + Walk(walkChild func(Element)) Keyword() string Description() string String() string @@ -172,6 +173,10 @@ func NewEntitlementAccess(entitlements EntitlementSet) EntitlementAccess { func (EntitlementAccess) isAccess() {} +func (e EntitlementAccess) Walk(walkChild func(Element)) { + e.EntitlementSet.Walk(walkChild) +} + func (EntitlementAccess) Description() string { return "entitled access" } @@ -300,6 +305,8 @@ func PrimitiveAccessCount() int { func (PrimitiveAccess) isAccess() {} +func (PrimitiveAccess) Walk(_ func(Element)) {} + // TODO: remove. // only used by tests which are not updated yet // to include contract and account access diff --git a/ast/access_test.go b/ast/access_test.go index c7944ca99..ec66527cc 100644 --- a/ast/access_test.go +++ b/ast/access_test.go @@ -87,6 +87,29 @@ func TestMappedAccess_Walk(t *testing.T) { assert.Equal(t, []Element{mapType}, visited) } +func TestEntitlementAccess_Walk(t *testing.T) { + + t.Parallel() + + e := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + f := &NominalType{ + Identifier: Identifier{Identifier: "F"}, + } + + access := NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{e, f}), + ) + + var visited []Element + access.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, []Element{e, f}, visited) +} + func TestConjunctiveEntitlementSet_MarshalJSON(t *testing.T) { t.Parallel() diff --git a/ast/attachment.go b/ast/attachment.go index afa1a682f..0d297ded0 100644 --- a/ast/attachment.go +++ b/ast/attachment.go @@ -71,7 +71,14 @@ func (*AttachmentDeclaration) ElementType() ElementType { } func (d *AttachmentDeclaration) Walk(walkChild func(Element)) { - walkDeclarations(walkChild, d.Members.declarations) + if d.Access != nil { + d.Access.Walk(walkChild) + } + if d.BaseType != nil { + walkChild(d.BaseType) + } + walkElements(walkChild, d.Conformances) + walkElements(walkChild, d.Members.declarations) } func (*AttachmentDeclaration) isDeclaration() {} diff --git a/ast/attachment_test.go b/ast/attachment_test.go index c958b00db..ebe3651e3 100644 --- a/ast/attachment_test.go +++ b/ast/attachment_test.go @@ -252,38 +252,141 @@ func TestAttachmentDeclaration_Walk(t *testing.T) { t.Parallel() - field := &FieldDeclaration{ - Identifier: Identifier{Identifier: "field"}, - TypeAnnotation: &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Int"}, + t.Run("members only", func(t *testing.T) { + t.Parallel() + + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + TypeAnnotation: &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, }, - }, - } + } - function := &FunctionDeclaration{ - Identifier: Identifier{Identifier: "function"}, - } + function := &FunctionDeclaration{ + Identifier: Identifier{Identifier: "function"}, + } - decl := &AttachmentDeclaration{ - Members: NewUnmeteredMembers([]Declaration{ - field, - function, - }), - } + decl := &AttachmentDeclaration{ + Members: NewUnmeteredMembers([]Declaration{ + field, + function, + }), + } - var visited []Element - decl.Walk(func(element Element) { - visited = append(visited, element) + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + field, + function, + }, + visited, + ) }) - assert.Equal(t, - []Element{ - field, - function, - }, - visited, - ) + t.Run("with base type", func(t *testing.T) { + t.Parallel() + + baseType := &NominalType{ + Identifier: Identifier{Identifier: "Base"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &AttachmentDeclaration{ + BaseType: baseType, + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + baseType, + field, + }, + visited, + ) + }) + + t.Run("with conformances", func(t *testing.T) { + t.Parallel() + + conformance := &NominalType{ + Identifier: Identifier{Identifier: "Foo"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &AttachmentDeclaration{ + Conformances: []*NominalType{conformance}, + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + conformance, + field, + }, + visited, + ) + }) + + t.Run("with entitlement access, base type, and conformances", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + baseType := &NominalType{ + Identifier: Identifier{Identifier: "Base"}, + } + conformance := &NominalType{ + Identifier: Identifier{Identifier: "Iface"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &AttachmentDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + BaseType: baseType, + Conformances: []*NominalType{conformance}, + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + baseType, + conformance, + field, + }, + visited, + ) + }) } func TestAttachExpressionMarshallJSON(t *testing.T) { diff --git a/ast/block.go b/ast/block.go index be02b717e..4e2c1f543 100644 --- a/ast/block.go +++ b/ast/block.go @@ -51,7 +51,7 @@ func (b *Block) IsEmpty() bool { } func (b *Block) Walk(walkChild func(Element)) { - walkStatements(walkChild, b.Statements) + walkElements(walkChild, b.Statements) } var blockStartDoc prettier.Doc = prettier.Text("{") diff --git a/ast/block_test.go b/ast/block_test.go index f3bd0de6c..9ff8b60fd 100644 --- a/ast/block_test.go +++ b/ast/block_test.go @@ -870,7 +870,7 @@ func TestFunctionBlock_Walk(t *testing.T) { ) } -func TestTestCondition_Doc(t *testing.T) { +func TestCondition_Doc(t *testing.T) { t.Parallel() t.Run("with test and message", func(t *testing.T) { @@ -972,6 +972,87 @@ func TestTestCondition_Doc(t *testing.T) { }) } +func TestCondition_Walk(t *testing.T) { + + t.Parallel() + + t.Run("with test only", func(t *testing.T) { + t.Parallel() + + test := &BoolExpression{Value: true} + + condition := TestCondition{ + Test: test, + } + + var visited []Element + condition.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, []Element{test}, visited) + }) + + t.Run("with test and message", func(t *testing.T) { + t.Parallel() + + test := &BoolExpression{Value: true} + message := &StringExpression{Value: "fail"} + + condition := TestCondition{ + Test: test, + Message: message, + } + + var visited []Element + condition.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, []Element{test, message}, visited) + }) +} + +func TestConditions_Walk(t *testing.T) { + + t.Parallel() + + t.Run("empty", func(t *testing.T) { + t.Parallel() + + conditions := &Conditions{} + + var visited []Element + conditions.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Empty(t, visited) + }) + + t.Run("with conditions", func(t *testing.T) { + t.Parallel() + + cond1 := TestCondition{ + Test: &BoolExpression{Value: true}, + } + cond2 := TestCondition{ + Test: &BoolExpression{Value: false}, + } + + conditions := &Conditions{ + Conditions: []Condition{cond1, cond2}, + } + + var visited []Element + conditions.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, []Element{cond1, cond2}, visited) + }) +} + func TestEmitCondition_Walk(t *testing.T) { t.Parallel() diff --git a/ast/composite.go b/ast/composite.go index 6f5148c46..3a7e64109 100644 --- a/ast/composite.go +++ b/ast/composite.go @@ -92,7 +92,11 @@ func (*CompositeDeclaration) ElementType() ElementType { } func (d *CompositeDeclaration) Walk(walkChild func(Element)) { - walkDeclarations(walkChild, d.Members.declarations) + if d.Access != nil { + d.Access.Walk(walkChild) + } + walkElements(walkChild, d.Conformances) + walkElements(walkChild, d.Members.declarations) } func (*CompositeDeclaration) isDeclaration() {} @@ -354,6 +358,9 @@ func (*FieldDeclaration) ElementType() ElementType { } func (d *FieldDeclaration) Walk(walkChild func(Element)) { + if d.Access != nil { + d.Access.Walk(walkChild) + } if d.TypeAnnotation != nil { walkChild(d.TypeAnnotation) } @@ -525,8 +532,10 @@ func (*EnumCaseDeclaration) ElementType() ElementType { return ElementTypeEnumCaseDeclaration } -func (*EnumCaseDeclaration) Walk(_ func(Element)) { - // NO-OP +func (d *EnumCaseDeclaration) Walk(walkChild func(Element)) { + if d.Access != nil { + d.Access.Walk(walkChild) + } } func (*EnumCaseDeclaration) isDeclaration() {} diff --git a/ast/composite_test.go b/ast/composite_test.go index 5f06023e8..4b38bfa25 100644 --- a/ast/composite_test.go +++ b/ast/composite_test.go @@ -472,6 +472,40 @@ func TestFieldDeclaration_Walk(t *testing.T) { assert.Empty(t, visited) }) + + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + typeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, + } + + decl := &FieldDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Identifier: Identifier{Identifier: "foo"}, + TypeAnnotation: typeAnnotation, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + typeAnnotation, + }, + visited, + ) + }) } func TestCompositeDeclaration_MarshalJSON(t *testing.T) { @@ -1051,6 +1085,53 @@ enum AB: CD { }) } +func TestEnumCaseDeclaration_Walk(t *testing.T) { + + t.Parallel() + + t.Run("without access", func(t *testing.T) { + t.Parallel() + + decl := &EnumCaseDeclaration{ + Identifier: Identifier{Identifier: "x"}, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Empty(t, visited) + }) + + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + + decl := &EnumCaseDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Identifier: Identifier{Identifier: "x"}, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + }, + visited, + ) + }) +} + func TestEnumCaseDeclaration_Doc(t *testing.T) { t.Parallel() @@ -1092,28 +1173,125 @@ func TestCompositeDeclaration_Walk(t *testing.T) { t.Parallel() - field := &FieldDeclaration{ - Identifier: Identifier{Identifier: "field"}, - TypeAnnotation: &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Int"}, + t.Run("members only", func(t *testing.T) { + t.Parallel() + + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + TypeAnnotation: &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, }, - }, - } + } - decl := &CompositeDeclaration{ - Members: NewUnmeteredMembers([]Declaration{field}), - } + decl := &CompositeDeclaration{ + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) - var visited []Element - decl.Walk(func(element Element) { - visited = append(visited, element) + assert.Equal(t, + []Element{ + field, + }, + visited, + ) }) - assert.Equal(t, - []Element{ - field, - }, - visited, - ) + t.Run("with conformances", func(t *testing.T) { + t.Parallel() + + conformance1 := &NominalType{ + Identifier: Identifier{Identifier: "Foo"}, + } + conformance2 := &NominalType{ + Identifier: Identifier{Identifier: "Bar"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &CompositeDeclaration{ + Conformances: []*NominalType{conformance1, conformance2}, + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + conformance1, + conformance2, + field, + }, + visited, + ) + }) + + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &CompositeDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + field, + }, + visited, + ) + }) + + t.Run("with mapped access", func(t *testing.T) { + t.Parallel() + + entitlementMap := &NominalType{ + Identifier: Identifier{Identifier: "M"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &CompositeDeclaration{ + Access: NewMappedAccess(entitlementMap, Position{}), + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlementMap, + field, + }, + visited, + ) + }) } diff --git a/ast/entitlement_declaration.go b/ast/entitlement_declaration.go index c1667588f..d088f5c14 100644 --- a/ast/entitlement_declaration.go +++ b/ast/entitlement_declaration.go @@ -60,7 +60,11 @@ func (*EntitlementDeclaration) ElementType() ElementType { return ElementTypeEntitlementDeclaration } -func (*EntitlementDeclaration) Walk(_ func(Element)) {} +func (d *EntitlementDeclaration) Walk(walkChild func(Element)) { + if d.Access != nil { + d.Access.Walk(walkChild) + } +} func (*EntitlementDeclaration) isDeclaration() {} @@ -127,6 +131,7 @@ func (d *EntitlementDeclaration) String() string { type EntitlementMapElement interface { isEntitlementMapElement() + Walk(walkChild func(Element)) Doc() prettier.Doc } @@ -162,6 +167,15 @@ func (d *EntitlementMapRelation) Doc() prettier.Doc { } } +func (d *EntitlementMapRelation) Walk(walkChild func(Element)) { + if d.Input != nil { + walkChild(d.Input) + } + if d.Output != nil { + walkChild(d.Output) + } +} + // EntitlementMappingDeclaration type EntitlementMappingDeclaration struct { Access Access @@ -198,7 +212,19 @@ func (*EntitlementMappingDeclaration) ElementType() ElementType { return ElementTypeEntitlementMappingDeclaration } -func (*EntitlementMappingDeclaration) Walk(_ func(Element)) {} +func (d *EntitlementMappingDeclaration) Walk(walkChild func(Element)) { + if d.Access != nil { + d.Access.Walk(walkChild) + } + for _, element := range d.Elements { + switch element := element.(type) { + case Element: + walkChild(element) + case EntitlementMapElement: + element.Walk(walkChild) + } + } +} func (*EntitlementMappingDeclaration) isDeclaration() {} diff --git a/ast/entitlement_declaration_test.go b/ast/entitlement_declaration_test.go index 70c6a040a..e4a3c7313 100644 --- a/ast/entitlement_declaration_test.go +++ b/ast/entitlement_declaration_test.go @@ -422,3 +422,221 @@ include X ) }) } + +func TestEntitlementDeclaration_Walk(t *testing.T) { + + t.Parallel() + + t.Run("without access", func(t *testing.T) { + t.Parallel() + + decl := &EntitlementDeclaration{ + Identifier: Identifier{Identifier: "E"}, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Empty(t, visited) + }) + + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "X"}, + } + + decl := &EntitlementDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Identifier: Identifier{Identifier: "E"}, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + }, + visited, + ) + }) +} + +func TestEntitlementMapRelation_Walk(t *testing.T) { + + t.Parallel() + + t.Run("with input and output", func(t *testing.T) { + t.Parallel() + + input := &NominalType{ + Identifier: Identifier{Identifier: "X"}, + } + output := &NominalType{ + Identifier: Identifier{Identifier: "Y"}, + } + + relation := &EntitlementMapRelation{ + Input: input, + Output: output, + } + + var visited []Element + relation.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + input, + output, + }, + visited, + ) + }) + + t.Run("with nil input", func(t *testing.T) { + t.Parallel() + + output := &NominalType{ + Identifier: Identifier{Identifier: "Y"}, + } + + relation := &EntitlementMapRelation{ + Output: output, + } + + var visited []Element + relation.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + output, + }, + visited, + ) + }) +} + +func TestEntitlementMappingDeclaration_Walk(t *testing.T) { + + t.Parallel() + + t.Run("empty", func(t *testing.T) { + t.Parallel() + + decl := &EntitlementMappingDeclaration{} + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Empty(t, visited) + }) + + t.Run("with inclusion", func(t *testing.T) { + t.Parallel() + + inclusion := &NominalType{ + Identifier: Identifier{Identifier: "X"}, + } + + decl := &EntitlementMappingDeclaration{ + Elements: []EntitlementMapElement{inclusion}, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + inclusion, + }, + visited, + ) + }) + + t.Run("with relation", func(t *testing.T) { + t.Parallel() + + input := &NominalType{ + Identifier: Identifier{Identifier: "X"}, + } + output := &NominalType{ + Identifier: Identifier{Identifier: "Y"}, + } + + decl := &EntitlementMappingDeclaration{ + Elements: []EntitlementMapElement{ + &EntitlementMapRelation{ + Input: input, + Output: output, + }, + }, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + input, + output, + }, + visited, + ) + }) + + t.Run("with inclusion and relation", func(t *testing.T) { + t.Parallel() + + inclusion := &NominalType{ + Identifier: Identifier{Identifier: "X"}, + } + input := &NominalType{ + Identifier: Identifier{Identifier: "A"}, + } + output := &NominalType{ + Identifier: Identifier{Identifier: "B"}, + } + + decl := &EntitlementMappingDeclaration{ + Elements: []EntitlementMapElement{ + inclusion, + &EntitlementMapRelation{ + Input: input, + Output: output, + }, + }, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + inclusion, + input, + output, + }, + visited, + ) + }) +} diff --git a/ast/expression.go b/ast/expression.go index b51ea46bb..8b37aa26c 100644 --- a/ast/expression.go +++ b/ast/expression.go @@ -264,7 +264,7 @@ func (*StringTemplateExpression) isExpression() {} func (*StringTemplateExpression) isIfStatementTest() {} func (e *StringTemplateExpression) Walk(walkChild func(Element)) { - walkExpressions(walkChild, e.Expressions) + walkElements(walkChild, e.Expressions) } func (e *StringTemplateExpression) String() string { @@ -513,7 +513,7 @@ func (*ArrayExpression) isExpression() {} func (*ArrayExpression) isIfStatementTest() {} func (e *ArrayExpression) Walk(walkChild func(Element)) { - walkExpressions(walkChild, e.Values) + walkElements(walkChild, e.Values) } func (e *ArrayExpression) String() string { diff --git a/ast/function_declaration.go b/ast/function_declaration.go index 0601056ba..bbab8f65f 100644 --- a/ast/function_declaration.go +++ b/ast/function_declaration.go @@ -141,6 +141,9 @@ func (d *FunctionDeclaration) EndPosition(memoryGauge common.MemoryGauge) Positi } func (d *FunctionDeclaration) Walk(walkChild func(Element)) { + if d.Access != nil { + d.Access.Walk(walkChild) + } if d.TypeParameterList != nil { d.TypeParameterList.Walk(walkChild) } diff --git a/ast/function_declaration_test.go b/ast/function_declaration_test.go index a00411314..02b1972ea 100644 --- a/ast/function_declaration_test.go +++ b/ast/function_declaration_test.go @@ -503,80 +503,122 @@ func TestFunctionDeclaration_Walk(t *testing.T) { t.Parallel() - typeBound := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Bound"}, - }, - } + t.Run("without access", func(t *testing.T) { + t.Parallel() - paramTypeAnnotation := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "String"}, - }, - } + typeBound := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Bound"}, + }, + } - defaultArg := &IntegerExpression{ - PositiveLiteral: []byte("42"), - Base: 10, - } + paramTypeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "String"}, + }, + } - returnTypeAnnotation := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Int"}, - }, - } + defaultArg := &IntegerExpression{ + PositiveLiteral: []byte("42"), + Base: 10, + } - body := &IntegerExpression{ - PositiveLiteral: []byte("1"), - Base: 10, - } + returnTypeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, + } - decl := &FunctionDeclaration{ - Identifier: Identifier{Identifier: "foo"}, - TypeParameterList: &TypeParameterList{ - TypeParameters: []*TypeParameter{ - { - Identifier: Identifier{Identifier: "T"}, - TypeBound: typeBound, + body := &IntegerExpression{ + PositiveLiteral: []byte("1"), + Base: 10, + } + + decl := &FunctionDeclaration{ + Identifier: Identifier{Identifier: "foo"}, + TypeParameterList: &TypeParameterList{ + TypeParameters: []*TypeParameter{ + { + Identifier: Identifier{Identifier: "T"}, + TypeBound: typeBound, + }, }, }, - }, - ParameterList: &ParameterList{ - Parameters: []*Parameter{ - { - Identifier: Identifier{Identifier: "x"}, - TypeAnnotation: paramTypeAnnotation, - DefaultArgument: defaultArg, + ParameterList: &ParameterList{ + Parameters: []*Parameter{ + { + Identifier: Identifier{Identifier: "x"}, + TypeAnnotation: paramTypeAnnotation, + DefaultArgument: defaultArg, + }, }, }, - }, - ReturnTypeAnnotation: returnTypeAnnotation, - FunctionBlock: &FunctionBlock{ - Block: &Block{ - Statements: []Statement{ - &ReturnStatement{ - Expression: body, + ReturnTypeAnnotation: returnTypeAnnotation, + FunctionBlock: &FunctionBlock{ + Block: &Block{ + Statements: []Statement{ + &ReturnStatement{ + Expression: body, + }, }, }, }, - }, - } + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) - var visited []Element - decl.Walk(func(element Element) { - visited = append(visited, element) + assert.Equal(t, + []Element{ + typeBound, + paramTypeAnnotation, + defaultArg, + returnTypeAnnotation, + decl.FunctionBlock, + }, + visited, + ) }) - assert.Equal(t, - []Element{ - typeBound, - paramTypeAnnotation, - defaultArg, - returnTypeAnnotation, - decl.FunctionBlock, - }, - visited, - ) + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + + returnTypeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, + } + + decl := &FunctionDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Identifier: Identifier{Identifier: "foo"}, + ParameterList: &ParameterList{ + Parameters: []*Parameter{}, + }, + ReturnTypeAnnotation: returnTypeAnnotation, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + returnTypeAnnotation, + }, + visited, + ) + }) } func TestSpecialFunctionDeclaration_MarshalJSON(t *testing.T) { @@ -979,80 +1021,84 @@ func TestSpecialFunctionDeclaration_Walk(t *testing.T) { t.Parallel() - typeBound := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Bound"}, - }, - } + t.Run("without access", func(t *testing.T) { + t.Parallel() - paramTypeAnnotation := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "String"}, - }, - } + typeBound := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Bound"}, + }, + } - defaultArg := &IntegerExpression{ - PositiveLiteral: []byte("42"), - Base: 10, - } + paramTypeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "String"}, + }, + } - returnTypeAnnotation := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Int"}, - }, - } + defaultArg := &IntegerExpression{ + PositiveLiteral: []byte("42"), + Base: 10, + } - body := &IntegerExpression{ - PositiveLiteral: []byte("1"), - Base: 10, - } + returnTypeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, + } - decl := &SpecialFunctionDeclaration{ - FunctionDeclaration: &FunctionDeclaration{ - Identifier: Identifier{Identifier: "foo"}, - TypeParameterList: &TypeParameterList{ - TypeParameters: []*TypeParameter{ - { - Identifier: Identifier{Identifier: "T"}, - TypeBound: typeBound, + body := &IntegerExpression{ + PositiveLiteral: []byte("1"), + Base: 10, + } + + decl := &SpecialFunctionDeclaration{ + FunctionDeclaration: &FunctionDeclaration{ + Identifier: Identifier{Identifier: "foo"}, + TypeParameterList: &TypeParameterList{ + TypeParameters: []*TypeParameter{ + { + Identifier: Identifier{Identifier: "T"}, + TypeBound: typeBound, + }, }, }, - }, - ParameterList: &ParameterList{ - Parameters: []*Parameter{ - { - Identifier: Identifier{Identifier: "x"}, - TypeAnnotation: paramTypeAnnotation, - DefaultArgument: defaultArg, + ParameterList: &ParameterList{ + Parameters: []*Parameter{ + { + Identifier: Identifier{Identifier: "x"}, + TypeAnnotation: paramTypeAnnotation, + DefaultArgument: defaultArg, + }, }, }, - }, - ReturnTypeAnnotation: returnTypeAnnotation, - FunctionBlock: &FunctionBlock{ - Block: &Block{ - Statements: []Statement{ - &ReturnStatement{ - Expression: body, + ReturnTypeAnnotation: returnTypeAnnotation, + FunctionBlock: &FunctionBlock{ + Block: &Block{ + Statements: []Statement{ + &ReturnStatement{ + Expression: body, + }, }, }, }, }, - }, - } + } - var visited []Element - decl.Walk(func(element Element) { - visited = append(visited, element) - }) + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) - assert.Equal(t, - []Element{ - typeBound, - paramTypeAnnotation, - defaultArg, - returnTypeAnnotation, - decl.FunctionDeclaration.FunctionBlock, - }, - visited, - ) + assert.Equal(t, + []Element{ + typeBound, + paramTypeAnnotation, + defaultArg, + returnTypeAnnotation, + decl.FunctionDeclaration.FunctionBlock, + }, + visited, + ) + }) } diff --git a/ast/interface.go b/ast/interface.go index b551abe58..8dae62955 100644 --- a/ast/interface.go +++ b/ast/interface.go @@ -70,7 +70,11 @@ func (*InterfaceDeclaration) ElementType() ElementType { } func (d *InterfaceDeclaration) Walk(walkChild func(Element)) { - walkDeclarations(walkChild, d.Members.declarations) + if d.Access != nil { + d.Access.Walk(walkChild) + } + walkElements(walkChild, d.Conformances) + walkElements(walkChild, d.Members.declarations) } func (*InterfaceDeclaration) isDeclaration() {} diff --git a/ast/interface_test.go b/ast/interface_test.go index a87101869..bc5abf161 100644 --- a/ast/interface_test.go +++ b/ast/interface_test.go @@ -306,36 +306,96 @@ func TestInterfaceDeclaration_Walk(t *testing.T) { t.Parallel() - field := &FieldDeclaration{ - Identifier: Identifier{Identifier: "field"}, - TypeAnnotation: &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "String"}, + t.Run("members only", func(t *testing.T) { + t.Parallel() + + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + TypeAnnotation: &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "String"}, + }, + }, + } + + function := &FunctionDeclaration{ + Identifier: Identifier{Identifier: "function"}, + } + + decl := &InterfaceDeclaration{ + Members: NewUnmeteredMembers([]Declaration{ + field, + function, + }), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + field, + function, }, - }, - } - - function := &FunctionDeclaration{ - Identifier: Identifier{Identifier: "function"}, - } - - decl := &InterfaceDeclaration{ - Members: NewUnmeteredMembers([]Declaration{ - field, - function, - }), - } - - var visited []Element - decl.Walk(func(element Element) { - visited = append(visited, element) + visited, + ) }) - assert.Equal(t, - []Element{ - field, - function, - }, - visited, - ) + t.Run("with conformances", func(t *testing.T) { + t.Parallel() + + conformance := &NominalType{ + Identifier: Identifier{Identifier: "Foo"}, + } + field := &FieldDeclaration{ + Identifier: Identifier{Identifier: "field"}, + } + + decl := &InterfaceDeclaration{ + Conformances: []*NominalType{conformance}, + Members: NewUnmeteredMembers([]Declaration{field}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + conformance, + field, + }, + visited, + ) + }) + + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + + decl := &InterfaceDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Members: NewUnmeteredMembers([]Declaration{}), + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + }, + visited, + ) + }) } diff --git a/ast/program.go b/ast/program.go index bd402d581..ef5d10e96 100644 --- a/ast/program.go +++ b/ast/program.go @@ -67,7 +67,7 @@ func (p *Program) EndPosition(memoryGauge common.MemoryGauge) Position { } func (p *Program) Walk(walkChild func(Element)) { - walkDeclarations(walkChild, p.declarations) + walkElements(walkChild, p.declarations) } func (p *Program) PragmaDeclarations() []*PragmaDeclaration { diff --git a/ast/statement.go b/ast/statement.go index ca71c5162..0f137e485 100644 --- a/ast/statement.go +++ b/ast/statement.go @@ -853,7 +853,7 @@ func (s *SwitchStatement) Walk(walkChild func(Element)) { if expression != nil { walkChild(expression) } - walkStatements(walkChild, switchCase.Statements) + walkElements(walkChild, switchCase.Statements) } } diff --git a/ast/variable_declaration.go b/ast/variable_declaration.go index f9cd98a83..ec6dde75f 100644 --- a/ast/variable_declaration.go +++ b/ast/variable_declaration.go @@ -107,6 +107,9 @@ func (d *VariableDeclaration) EndPosition(memoryGauge common.MemoryGauge) Positi func (*VariableDeclaration) isIfStatementTest() {} func (d *VariableDeclaration) Walk(walkChild func(Element)) { + if d.Access != nil { + d.Access.Walk(walkChild) + } if d.TypeAnnotation != nil { walkChild(d.TypeAnnotation) } diff --git a/ast/variable_declaration_test.go b/ast/variable_declaration_test.go index 3e9d5c202..43912973b 100644 --- a/ast/variable_declaration_test.go +++ b/ast/variable_declaration_test.go @@ -419,40 +419,78 @@ func TestVariableDeclaration_Walk(t *testing.T) { t.Parallel() - typeAnnotation := &TypeAnnotation{ - Type: &NominalType{ - Identifier: Identifier{Identifier: "Int"}, - }, - } + t.Run("without access", func(t *testing.T) { + t.Parallel() - value := &IntegerExpression{ - PositiveLiteral: []byte("1"), - Base: 10, - } + typeAnnotation := &TypeAnnotation{ + Type: &NominalType{ + Identifier: Identifier{Identifier: "Int"}, + }, + } - secondValue := &IntegerExpression{ - PositiveLiteral: []byte("2"), - Base: 10, - } + value := &IntegerExpression{ + PositiveLiteral: []byte("1"), + Base: 10, + } - decl := &VariableDeclaration{ - Identifier: Identifier{Identifier: "x"}, - TypeAnnotation: typeAnnotation, - Value: value, - SecondValue: secondValue, - } + secondValue := &IntegerExpression{ + PositiveLiteral: []byte("2"), + Base: 10, + } + + decl := &VariableDeclaration{ + Identifier: Identifier{Identifier: "x"}, + TypeAnnotation: typeAnnotation, + Value: value, + SecondValue: secondValue, + } - var visited []Element - decl.Walk(func(element Element) { - visited = append(visited, element) + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + typeAnnotation, + value, + secondValue, + }, + visited, + ) }) - assert.Equal(t, - []Element{ - typeAnnotation, - value, - secondValue, - }, - visited, - ) + t.Run("with entitlement access", func(t *testing.T) { + t.Parallel() + + entitlement := &NominalType{ + Identifier: Identifier{Identifier: "E"}, + } + + value := &IntegerExpression{ + PositiveLiteral: []byte("1"), + Base: 10, + } + + decl := &VariableDeclaration{ + Access: NewEntitlementAccess( + NewConjunctiveEntitlementSet([]*NominalType{entitlement}), + ), + Identifier: Identifier{Identifier: "x"}, + Value: value, + } + + var visited []Element + decl.Walk(func(element Element) { + visited = append(visited, element) + }) + + assert.Equal(t, + []Element{ + entitlement, + value, + }, + visited, + ) + }) } diff --git a/ast/walk.go b/ast/walk.go index 9b272b97e..cb279a585 100644 --- a/ast/walk.go +++ b/ast/walk.go @@ -44,20 +44,8 @@ func Walk(walker Walker, element Element) { walker.Walk(nil) } -func walkExpressions(walkChild func(Element), expressions []Expression) { - for _, expression := range expressions { - walkChild(expression) - } -} - -func walkStatements(walkChild func(Element), statements []Statement) { - for _, statement := range statements { - walkChild(statement) - } -} - -func walkDeclarations(walkChild func(Element), declarations []Declaration) { - for _, declaration := range declarations { - walkChild(declaration) +func walkElements[E Element](walkChild func(Element), elements []E) { + for _, element := range elements { + walkChild(element) } }