diff --git a/bbq/vm/builtin_globals.go b/bbq/vm/builtin_globals.go index 81e2905bd..6dcba979c 100644 --- a/bbq/vm/builtin_globals.go +++ b/bbq/vm/builtin_globals.go @@ -293,6 +293,8 @@ func init() { registerBuiltinCommonTypeBoundFunctions() registerBuiltinSaturatingArithmeticFunctions() + + registerBuiltinFixedPointPowFunctions() } func registerBuiltinCommonTypeBoundFunctions() { @@ -421,6 +423,19 @@ func registerBuiltinTypeSaturatingArithmeticFunctions(t sema.SaturatingArithmeti } } +func registerBuiltinFixedPointPowFunctions() { + for baseType, funcType := range sema.FixedPointPowFunctionTypes { //nolint:maprange + registerBuiltinTypeBoundFunction( + commons.TypeQualifier(baseType), + NewNativeFunctionValue( + sema.FixedPointNumericTypePowFunctionName, + funcType, + interpreter.NativeFixedPointPowFunction, + ), + ) + } +} + func newFromStringFunction(typedParser interpreter.TypedStringValueParser) *NativeFunctionValue { functionType := sema.FromStringFunctionType(typedParser.ReceiverType) parser := typedParser.Parser diff --git a/interpreter/fixedpoint_test.go b/interpreter/fixedpoint_test.go index 1534cff97..a7ccb4c2a 100644 --- a/interpreter/fixedpoint_test.go +++ b/interpreter/fixedpoint_test.go @@ -37,6 +37,183 @@ import ( . "github.com/onflow/cadence/test_utils/sema_utils" ) +func TestInterpretFixedPointPow(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + base string + exponent string + expected uint64 + expectedError bool + } + + // Expected values were pre-computed using the fixed-point library's UFix64.Pow(Fix64). + testCases := []testCase{ + // Edge cases + {base: "0.00000000", exponent: "0.00000000", expected: 100000000}, // 0^0 = 1 + {base: "0.00000000", exponent: "2.00000000", expected: 0}, // 0^2 = 0 + {base: "1.00000000", exponent: "0.00000000", expected: 100000000}, // 1^0 = 1 + {base: "1.00000000", exponent: "5.00000000", expected: 100000000}, // 1^5 = 1 + {base: "2.00000000", exponent: "0.00000000", expected: 100000000}, // 2^0 = 1 + {base: "2.00000000", exponent: "1.00000000", expected: 200000000}, // 2^1 = 2 + + // Integer exponents + {base: "2.00000000", exponent: "3.00000000", expected: 800000000}, // 2^3 = 8 + {base: "5.00000000", exponent: "2.00000000", expected: 2500000000}, // 5^2 = 25 + {base: "10.00000000", exponent: "3.00000000", expected: 100000000000}, // 10^3 = 1000 + + // Negative exponents + {base: "2.00000000", exponent: "-1.00000000", expected: 50000000}, // 2^(-1) = 0.5 + {base: "4.00000000", exponent: "-1.00000000", expected: 25000000}, // 4^(-1) = 0.25 + {base: "10.00000000", exponent: "-2.00000000", expected: 1000000}, // 10^(-2) = 0.01 + + // Fractional bases + {base: "0.50000000", exponent: "2.00000000", expected: 25000000}, // 0.5^2 = 0.25 + {base: "1.50000000", exponent: "2.00000000", expected: 225000000}, // 1.5^2 = 2.25 + {base: "0.25000000", exponent: "3.00000000", expected: 1562500}, // 0.25^3 = 0.015625 + + // Fractional exponents + {base: "4.00000000", exponent: "0.50000000", expected: 200000000}, // 4^0.5 = 2 + {base: "9.00000000", exponent: "0.50000000", expected: 300000000}, // 9^0.5 = 3 + {base: "8.00000000", exponent: "0.33333333", expected: 199999999}, // 8^(1/3) ≈ 2 + + // Values from library test data + {base: "0.11111111", exponent: "2.00000000", expected: 1234568}, // (1/9)^2 + {base: "0.33333333", exponent: "3.00000000", expected: 3703704}, // (1/3)^3 + {base: "2.71828183", exponent: "1.00000000", expected: 271828183}, // e^1 + {base: "3.14159265", exponent: "-0.50000000", expected: 56418958}, // pi^(-0.5) + {base: "0.14285714", exponent: "2.00000000", expected: 2040816}, // (1/7)^2 + {base: "123.45678901", exponent: "0.50000000", expected: 1111111106}, // 123.45678901^0.5 + + // Repeating decimal bases with negative exponents + {base: "0.66666666", exponent: "-1.00000000", expected: 150000002}, // (2/3)^(-1) + {base: "0.50000000", exponent: "-2.00000000", expected: 400000000}, // 0.5^(-2) = 4 + + // Overflow + {base: "429496.72960000", exponent: "2.00000000", expectedError: true}, // sqrt(MaxUFix64)^2 overflows + {base: "10.00000000", exponent: "20.00000000", expectedError: true}, // 10^20 overflows + + // Underflow (truncated to 0 by handleFixedpointError) + {base: "0.00000003", exponent: "2.00000000", expected: 0}, // 0.00000003^2 underflows to 0 + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s ^ %s", tc.base, tc.exponent) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix64 { + let base: UFix64 = %s + let exponent: Fix64 = %s + return base.pow(exponent) + } + `, + tc.base, + tc.exponent, + ) + + inter := parseCheckAndPrepare(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUnmeteredUFix64Value(tc.expected) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("UFix128", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + base string + exponent string + expected string + expectedError bool + } + + // Expected values were pre-computed using the fixed-point library's UFix128.Pow(Fix128). + testCases := []testCase{ + // Edge cases + {base: "0.000000000000000000000000", exponent: "0.000000000000000000000000", expected: "1.000000000000000000000000"}, // 0^0 = 1 + {base: "1.000000000000000000000000", exponent: "0.000000000000000000000000", expected: "1.000000000000000000000000"}, // 1^0 = 1 + {base: "1.000000000000000000000000", exponent: "5.000000000000000000000000", expected: "1.000000000000000000000000"}, // 1^5 = 1 + {base: "2.000000000000000000000000", exponent: "0.000000000000000000000000", expected: "1.000000000000000000000000"}, // 2^0 = 1 + {base: "2.000000000000000000000000", exponent: "1.000000000000000000000000", expected: "2.000000000000000000000000"}, // 2^1 = 2 + + // Integer exponents + {base: "2.000000000000000000000000", exponent: "3.000000000000000000000000", expected: "8.000000000000000000000000"}, // 2^3 = 8 + {base: "5.000000000000000000000000", exponent: "2.000000000000000000000000", expected: "25.000000000000000000000000"}, // 5^2 = 25 + {base: "10.000000000000000000000000", exponent: "3.000000000000000000000000", expected: "1000.000000000000000000000000"}, // 10^3 = 1000 + + // Negative exponents + {base: "2.000000000000000000000000", exponent: "-1.000000000000000000000000", expected: "0.500000000000000000000000"}, // 2^(-1) = 0.5 + {base: "4.000000000000000000000000", exponent: "-1.000000000000000000000000", expected: "0.250000000000000000000000"}, // 4^(-1) = 0.25 + + // Fractional base + {base: "0.500000000000000000000000", exponent: "2.000000000000000000000000", expected: "0.250000000000000000000000"}, // 0.5^2 = 0.25 + + // Fractional exponent + {base: "4.000000000000000000000000", exponent: "0.500000000000000000000000", expected: "2.000000000000000000000000"}, // 4^0.5 = 2 + {base: "9.000000000000000000000000", exponent: "0.500000000000000000000000", expected: "3.000000000000000000000000"}, // 9^0.5 = 3 + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s ^ %s", tc.base, tc.exponent) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix128 { + let base: UFix128 = %s + let exponent: Fix128 = %s + return base.pow(exponent) + } + `, + tc.base, + tc.exponent, + ) + + inter := parseCheckAndPrepare(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := parseCheckAndPrepare(t, fmt.Sprintf( + `let expected: UFix128 = %s`, + tc.expected, + )).GetGlobal("expected") + + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) +} + func TestInterpretNegativeZeroFixedPoint(t *testing.T) { t.Parallel() diff --git a/interpreter/value_number.go b/interpreter/value_number.go index d1a96b228..67004e2d7 100644 --- a/interpreter/value_number.go +++ b/interpreter/value_number.go @@ -22,6 +22,7 @@ import ( "math/big" "github.com/onflow/cadence/common" + "github.com/onflow/cadence/errors" "github.com/onflow/cadence/sema" ) @@ -103,6 +104,18 @@ func getNumberValueFunctionMember( sema.SaturatingArithmeticTypeFunctionTypes[typ], NativeNumberSaturatingDivideFunction, ) + + case sema.FixedPointNumericTypePowFunctionName: + funcType, ok := sema.FixedPointPowFunctionTypes[typ] + if !ok { + return nil + } + return NewBoundHostFunctionValue( + context, + v, + funcType, + NativeFixedPointPowFunction, + ) } return nil @@ -214,3 +227,24 @@ var NativeNumberSaturatingDivideFunction = NativeFunction( return receiver.(NumberValue).SaturatingDiv(context, other) }, ) + +var NativeFixedPointPowFunction = NativeFunction( + func( + context NativeFunctionContext, + _ TypeArgumentsIterator, + _ ArgumentTypesIterator, + receiver Value, + args []Value, + ) Value { + switch v := receiver.(type) { + case UFix64Value: + exponent := AssertValueOfType[Fix64Value](args[0]) + return v.Pow(context, exponent) + case UFix128Value: + exponent := AssertValueOfType[Fix128Value](args[0]) + return v.Pow(context, exponent) + default: + panic(errors.NewUnreachableError()) + } + }, +) diff --git a/interpreter/value_ufix128.go b/interpreter/value_ufix128.go index d4ec4dfb5..bc3958784 100644 --- a/interpreter/value_ufix128.go +++ b/interpreter/value_ufix128.go @@ -362,6 +362,16 @@ func (v UFix128Value) Mod(context NumberValueArithmeticContext, other NumberValu return NewUFix128Value(context, valueGetter) } +func (v UFix128Value) Pow(context NumberValueArithmeticContext, other Fix128Value) NumberValue { + valueGetter := func() fix.UFix128 { + result, err := fix.UFix128(v).Pow(fix.Fix128(other)) + handleFixedpointError(err) + return result + } + + return NewUFix128Value(context, valueGetter) +} + func (v UFix128Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(UFix128Value) if !ok { diff --git a/interpreter/value_ufix64.go b/interpreter/value_ufix64.go index 46979bf77..48c561606 100644 --- a/interpreter/value_ufix64.go +++ b/interpreter/value_ufix64.go @@ -26,6 +26,8 @@ import ( "github.com/onflow/atree" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/ast" "github.com/onflow/cadence/common" "github.com/onflow/cadence/errors" @@ -355,6 +357,18 @@ func (v UFix64Value) Mod(context NumberValueArithmeticContext, other NumberValue return UFix64Value{UFix64Value: result} } +func (v UFix64Value) Pow(context NumberValueArithmeticContext, other Fix64Value) NumberValue { + valueGetter := func() uint64 { + a := fix.UFix64(uint64(v.UFix64Value)) + b := fix.Fix64(uint64(other)) + result, err := a.Pow(b) + handleFixedpointError(err) + return uint64(result) + } + + return NewUFix64Value(context, valueGetter) +} + func (v UFix64Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(UFix64Value) if !ok { diff --git a/sema/fixedpoint_test.go b/sema/fixedpoint_test.go index 2b1584142..205b90bab 100644 --- a/sema/fixedpoint_test.go +++ b/sema/fixedpoint_test.go @@ -32,6 +32,114 @@ import ( . "github.com/onflow/cadence/test_utils/sema_utils" ) +func TestCheckFixedPointPow(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let result = 2.0.pow(3.0) + `) + require.NoError(t, err) + + resultType := RequireGlobalValue(t, checker.Elaboration, "result") + assert.Equal(t, sema.UFix64Type, resultType) + }) + + t.Run("valid, negative exponent", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let result = 2.0.pow(-1.0) + `) + require.NoError(t, err) + }) + + t.Run("valid, explicit types", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: UFix64 = 2.0 + let exponent: Fix64 = 3.0 + let result = base.pow(exponent) + `) + require.NoError(t, err) + }) + + t.Run("invalid, wrong exponent type", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let result = 2.0.pow(3.0 as UFix64) + `) + require.Error(t, err) + }) + }) + + t.Run("UFix128", func(t *testing.T) { + + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let base: UFix128 = 2.0 + let exponent: Fix128 = 3.0 + let result = base.pow(exponent) + `) + require.NoError(t, err) + + resultType := RequireGlobalValue(t, checker.Elaboration, "result") + assert.Equal(t, sema.UFix128Type, resultType) + }) + + t.Run("invalid, wrong exponent type", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: UFix128 = 2.0 + let exponent: Fix64 = 3.0 + let result = base.pow(exponent) + `) + require.Error(t, err) + }) + }) + + t.Run("not available on signed types", func(t *testing.T) { + + t.Parallel() + + t.Run("Fix64", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: Fix64 = 2.0 + let exponent: Fix64 = 3.0 + let result = base.pow(exponent) + `) + require.Error(t, err) + }) + + t.Run("Fix128", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: Fix128 = 2.0 + let exponent: Fix128 = 3.0 + let result = base.pow(exponent) + `) + require.Error(t, err) + }) + }) +} + func TestCheckFixedPointLiteralTypeConversionInVariableDeclaration(t *testing.T) { t.Parallel() diff --git a/sema/type.go b/sema/type.go index 45bf8a58a..894b54837 100644 --- a/sema/type.go +++ b/sema/type.go @@ -1326,17 +1326,26 @@ func registerSaturatingArithmeticType(t Type) { ) } -func addSaturatingArithmeticFunctions(t SaturatingArithmeticType, members map[string]MemberResolver) { +func addSaturatingArithmeticFunctions( + t SaturatingArithmeticType, + members map[string]MemberResolver, +) { + functionType := SaturatingArithmeticTypeFunctionTypes[t] addArithmeticFunction := func(name string, docString string) { members[name] = MemberResolver{ Kind: common.DeclarationKindFunction, - Resolve: func(memoryGauge common.MemoryGauge, _ string, _ ast.HasPosition, _ func(error)) *Member { + Resolve: func( + memoryGauge common.MemoryGauge, + _ string, + _ ast.HasPosition, + _ func(error), + ) *Member { return NewPublicFunctionMember( memoryGauge, t, name, - SaturatingArithmeticTypeFunctionTypes[t], + functionType, docString, ) }, @@ -1379,6 +1388,53 @@ type SaturatingArithmeticSupport struct { Divide bool } +const FixedPointNumericTypePowFunctionName = "pow" +const fixedPointNumericTypePowFunctionDocString = ` +Returns this value raised to the power of the given exponent. +The exponent may be negative or fractional. +` + +var FixedPointPowFunctionTypes = map[Type]*FunctionType{} + +func registerFixedPointPowFunction(t *FixedPointNumericType, exponentType *FixedPointNumericType) { + FixedPointPowFunctionTypes[t] = NewSimpleFunctionType( + FunctionPurityView, + []Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "exponent", + TypeAnnotation: NewTypeAnnotation(exponentType), + }, + }, + NewTypeAnnotation(t), + ) +} + +func addFixedPointPowFunction( + t *FixedPointNumericType, + members map[string]MemberResolver, +) { + functionType := FixedPointPowFunctionTypes[t] + + members[FixedPointNumericTypePowFunctionName] = MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func( + memoryGauge common.MemoryGauge, + _ string, + _ ast.HasPosition, + _ func(error), + ) *Member { + return NewPublicFunctionMember( + memoryGauge, + t, + FixedPointNumericTypePowFunctionName, + functionType, + fixedPointNumericTypePowFunctionDocString, + ) + }, + } +} + // NumericType represent all the types in the integer range // and non-fractional ranged types. type NumericType struct { @@ -1677,6 +1733,11 @@ func (t *FixedPointNumericType) WithSaturatingFunctions(saturatingArithmetic Sat return t } +func (t *FixedPointNumericType) WithPowFunction(exponentType *FixedPointNumericType) *FixedPointNumericType { + registerFixedPointPowFunction(t, exponentType) + return t +} + func (t *FixedPointNumericType) SupportsSaturatingAdd() bool { return t.saturatingArithmetic.Add } @@ -1817,6 +1878,9 @@ func (t *FixedPointNumericType) GetMembers() map[string]MemberResolver { // Compute members and cache them computedMembers := map[string]MemberResolver{} addSaturatingArithmeticFunctions(t, computedMembers) + if _, ok := FixedPointPowFunctionTypes[t]; ok { + addFixedPointPowFunction(t, computedMembers) + } computedMembers = withBuiltinMembers(t, computedMembers) t.memberResolvers.Store(&computedMembers) return computedMembers @@ -2142,7 +2206,8 @@ var UFix64Type = NewFixedPointNumericType(UFix64TypeName). Add: true, Subtract: true, Multiply: true, - }) + }). + WithPowFunction(Fix64Type) var UFix64TypeAnnotation = NewTypeAnnotation(UFix64Type) // UFix128Type represents the 128-bit unsigned decimal fixed-point type `UFix128` @@ -2156,7 +2221,8 @@ var UFix128Type = NewFixedPointNumericType(UFix128TypeName). Add: true, Subtract: true, Multiply: true, - }) + }). + WithPowFunction(Fix128Type) var UFix128TypeAnnotation = NewTypeAnnotation(UFix128Type) // Numeric type ranges