From 1458dcf4252e27369884f1917340f3d1af02faed Mon Sep 17 00:00:00 2001 From: Ardit Marku Date: Thu, 28 Sep 2023 11:18:50 +0300 Subject: [PATCH 1/2] Remove workaround with compositeType.ResolveMembers() --- runtime/sema/type.go | 14 ++------------ runtime/stdlib/test_contract.go | 1 - 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index ec47eee467..4eb524bc5e 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -4294,11 +4294,7 @@ func (t *CompositeType) GetMembers() map[string]MemberResolver { } func (t *CompositeType) initializeMemberResolvers() { - t.memberResolversOnce.Do(t.initializerMemberResolversFunc()) -} - -func (t *CompositeType) initializerMemberResolversFunc() func() { - return func() { + t.memberResolversOnce.Do(func() { memberResolvers := MembersMapAsResolvers(t.Members) // Check conformances. @@ -4317,13 +4313,7 @@ func (t *CompositeType) initializerMemberResolversFunc() func() { }) t.memberResolvers = withBuiltinMembers(t, memberResolvers) - } -} - -func (t *CompositeType) ResolveMembers() { - if t.Members.Len() != len(t.GetMembers()) { - t.initializerMemberResolversFunc()() - } + }) } func (t *CompositeType) FieldPosition(name string, declaration ast.CompositeLikeDeclaration) ast.Position { diff --git a/runtime/stdlib/test_contract.go b/runtime/stdlib/test_contract.go index 3a7bcd12fb..41610f5661 100644 --- a/runtime/stdlib/test_contract.go +++ b/runtime/stdlib/test_contract.go @@ -1193,7 +1193,6 @@ func newTestContractType() *TestContractType { ty.expectFailureFunction = newTestTypeExpectFailureFunction( expectFailureFunctionType, ) - compositeType.ResolveMembers() return ty } From e4bba039a3374fd925f4dbfc340870da57725599 Mon Sep 17 00:00:00 2001 From: Ardit Marku Date: Thu, 28 Sep 2023 11:20:01 +0300 Subject: [PATCH 2/2] Add native function declarations for all revelent Test contract functions --- runtime/stdlib/contracts/test.cdc | 83 ++++++++++++++++ runtime/stdlib/test_contract.go | 10 +- runtime/stdlib/test_test.go | 151 +++++++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 7 deletions(-) diff --git a/runtime/stdlib/contracts/test.cdc b/runtime/stdlib/contracts/test.cdc index 8e7a39da7d..83c0d3cdc2 100644 --- a/runtime/stdlib/contracts/test.cdc +++ b/runtime/stdlib/contracts/test.cdc @@ -437,4 +437,87 @@ access(all) contract Test { assert(found, message: "the error message did not contain the given sub-string") } + /// Creates a matcher with a test function. + /// The test function is of type '((T): Bool)', + /// where 'T' is bound to 'AnyStruct'. + /// + access(all) + native fun newMatcher(_ test: ((T): Bool)): Test.Matcher {} + + /// Wraps a function call in a closure, and expects it to fail with + /// an error message that contains the given error message portion. + /// + access(all) + native fun expectFailure( + _ functionWrapper: ((): Void), + errorMessageSubstring: String + ) {} + + /// Expect function tests a value against a matcher + /// and fails the test if it's not a match. + /// + access(all) + native fun expect(_ value: T, _ matcher: Test.Matcher) {} + + /// Returns a matcher that succeeds if the tested + /// value is equal to the given value. + /// + access(all) + native fun equal(_ value: T): Test.Matcher {} + + /// Fails the test-case if the given values are not equal, and + /// reports a message which explains how the two values differ. + /// + access(all) + native fun assertEqual(_ expected: AnyStruct, _ actual: AnyStruct) {} + + /// Returns a matcher that succeeds if the tested value is + /// an array or dictionary and the tested value contains + /// no elements. + /// + access(all) + native fun beEmpty(): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value is + /// an array or dictionary and has the given number of elements. + /// + access(all) + native fun haveElementCount(_ count: Int): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value is + /// an array that contains a value that is equal to the given + /// value, or the tested value is a dictionary that contains + /// an entry where the key is equal to the given value. + /// + access(all) + native fun contain(_ element: AnyStruct): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value + /// is a number and greater than the given number. + /// + access(all) + native fun beGreaterThan(_ value: Number): Test.Matcher {} + + /// Returns a matcher that succeeds if the tested value + /// is a number and less than the given number. + /// + access(all) + native fun beLessThan(_ value: Number): Test.Matcher {} + + /// Read a local file, and return the content as a string. + /// + access(all) + native fun readFile(_ path: String): String {} + + /// Fails the test-case if the given condition is false, + /// and reports a message which explains how the condition is false. + /// + access(all) + native fun assert(_ condition: Bool, message: String): Void {} + + /// Fails the test-case with a message. + /// + access(all) + native fun fail(message: String): Void {} + } diff --git a/runtime/stdlib/test_contract.go b/runtime/stdlib/test_contract.go index 41610f5661..9aadd919da 100644 --- a/runtime/stdlib/test_contract.go +++ b/runtime/stdlib/test_contract.go @@ -950,7 +950,10 @@ func newTestContractType() *TestContractType { program, err := parser.ParseProgram( nil, contracts.TestContract, - parser.Config{}, + parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, ) if err != nil { panic(err) @@ -965,8 +968,9 @@ func newTestContractType() *TestContractType { TestContractLocation, nil, &sema.Config{ - BaseValueActivation: activation, - AccessCheckMode: sema.AccessCheckModeStrict, + BaseValueActivation: activation, + AccessCheckMode: sema.AccessCheckModeStrict, + AllowNativeDeclarations: true, }, ) if err != nil { diff --git a/runtime/stdlib/test_test.go b/runtime/stdlib/test_test.go index de5545174a..6772ec871e 100644 --- a/runtime/stdlib/test_test.go +++ b/runtime/stdlib/test_test.go @@ -57,7 +57,10 @@ func newTestContractInterpreterWithTestFramework( program, err := parser.ParseProgram( nil, []byte(code), - parser.Config{}, + parser.Config{ + NativeModifierEnabled: true, + TypeParametersEnabled: true, + }, ) require.NoError(t, err) @@ -70,8 +73,9 @@ func newTestContractInterpreterWithTestFramework( utils.TestLocation, nil, &sema.Config{ - BaseValueActivation: activation, - AccessCheckMode: sema.AccessCheckModeStrict, + BaseValueActivation: activation, + AccessCheckMode: sema.AccessCheckModeStrict, + AllowNativeDeclarations: true, ImportHandler: func( checker *sema.Checker, importedLocation common.Location, @@ -660,6 +664,74 @@ func TestTestEqualMatcher(t *testing.T) { }) } +func TestAssertFunction(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + pub fun testAssertWithNoArgs() { + Test.assert(true) + } + + pub fun testAssertWithNoArgsFail() { + Test.assert(false) + } + + pub fun testAssertWithMessage() { + Test.assert(true, message: "some reason") + } + + pub fun testAssertWithMessageFail() { + Test.assert(false, message: "some reason") + } + ` + + inter, err := newTestContractInterpreter(t, script) + require.NoError(t, err) + + _, err = inter.Invoke("testAssertWithNoArgs") + require.NoError(t, err) + + _, err = inter.Invoke("testAssertWithNoArgsFail") + require.Error(t, err) + assert.ErrorContains(t, err, "assertion failed") + + _, err = inter.Invoke("testAssertWithMessage") + require.NoError(t, err) + + _, err = inter.Invoke("testAssertWithMessageFail") + require.Error(t, err) + require.ErrorContains(t, err, "assertion failed: some reason") +} + +func TestFailFunction(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + pub fun testFailWithoutMessage() { + Test.fail() + } + + pub fun testFailWithMessage() { + Test.fail(message: "some error") + } + ` + + inter, err := newTestContractInterpreter(t, script) + require.NoError(t, err) + + _, err = inter.Invoke("testFailWithoutMessage") + require.Error(t, err) + require.ErrorContains(t, err, "assertion failed") + + _, err = inter.Invoke("testFailWithMessage") + require.Error(t, err) + require.ErrorContains(t, err, "assertion failed: some error") +} + func TestAssertEqual(t *testing.T) { t.Parallel() @@ -931,7 +1003,7 @@ func TestAssertEqual(t *testing.T) { pub fun test() { let foo = Foo() let bar <- create Bar() - Test.expect(foo, Test.equal(<-bar)) + Test.assertEqual(foo, <-bar) } pub struct Foo {} @@ -2645,6 +2717,77 @@ func TestBlockchain(t *testing.T) { assert.True(t, getAccountInvoked) }) + t.Run("readFile", func(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + pub fun test() { + let content = Test.readFile("some_file.cdc") + Test.assertEqual("Hey there!", content) + } + ` + + readFileInvoked := false + + testFramework := &mockedTestFramework{ + emulatorBackend: func() stdlib.Blockchain { + return &mockedBlockchain{} + }, + readFile: func(path string) (string, error) { + readFileInvoked = true + assert.Equal(t, "some_file.cdc", path) + + return "Hey there!", nil + }, + } + + inter, err := newTestContractInterpreterWithTestFramework(t, script, testFramework) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.NoError(t, err) + + assert.True(t, readFileInvoked) + }) + + t.Run("readFile with failure", func(t *testing.T) { + t.Parallel() + + const script = ` + import Test + + pub fun test() { + let content = Test.readFile("some_file.cdc") + Test.assertEqual("Hey there!", content) + } + ` + + readFileInvoked := false + + testFramework := &mockedTestFramework{ + emulatorBackend: func() stdlib.Blockchain { + return &mockedBlockchain{} + }, + readFile: func(path string) (string, error) { + readFileInvoked = true + assert.Equal(t, "some_file.cdc", path) + + return "", fmt.Errorf("could not read file: %s", path) + }, + } + + inter, err := newTestContractInterpreterWithTestFramework(t, script, testFramework) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.Error(t, err) + assert.ErrorContains(t, err, "could not read file: some_file.cdc") + + assert.True(t, readFileInvoked) + }) + // TODO: Add more tests for the remaining functions. }