diff --git a/bbq/vm/builtin_globals.go b/bbq/vm/builtin_globals.go index dde59dfd7..6c9089d64 100644 --- a/bbq/vm/builtin_globals.go +++ b/bbq/vm/builtin_globals.go @@ -301,6 +301,10 @@ func init() { registerBuiltinCommonTypeBoundFunctions() registerBuiltinSaturatingArithmeticFunctions() + + registerBuiltinFixedPointPowFunctions() + + registerBuiltinFixedPointMultiplyDivideFunctions() } func registerBuiltinCommonTypeBoundFunctions() { @@ -429,6 +433,32 @@ func registerBuiltinTypeSaturatingArithmeticFunctions(t sema.SaturatingArithmeti } } +func registerBuiltinFixedPointPowFunctions() { + for baseType, funcType := range sema.FixedPointPowFunctionTypes { //nolint:maprange + registerBuiltinTypeBoundFunction( + commons.TypeQualifier(baseType), + NewNativeFunctionValue( + sema.FixedPointNumericTypePowFunctionName, + funcType, + interpreter.NativeFixedPointPowFunction, + ), + ) + } +} + +func registerBuiltinFixedPointMultiplyDivideFunctions() { + for baseType, funcType := range sema.FixedPointMultiplyDivideFunctionTypes { + registerBuiltinTypeBoundFunction( + commons.TypeQualifier(baseType), + NewNativeFunctionValue( + sema.FixedPointNumericTypeMultiplyDivideFunctionName, + funcType, + interpreter.NativeFixedPointMultiplyDivideFunction, + ), + ) + } +} + func newFromStringFunction(typedParser interpreter.TypedStringValueParser) *NativeFunctionValue { functionType := sema.FromStringFunctionType(typedParser.ReceiverType) parser := typedParser.Parser diff --git a/interpreter/fixedpoint_test.go b/interpreter/fixedpoint_test.go index bbbd1dea1..89dc3e555 100644 --- a/interpreter/fixedpoint_test.go +++ b/interpreter/fixedpoint_test.go @@ -23,6 +23,7 @@ import ( "math" "math/big" "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -40,6 +41,531 @@ import ( . "github.com/onflow/cadence/test_utils/sema_utils" ) +func fix128BigInt(s string) *big.Int { + // Parse a decimal string like "33.333333333333333333333333" + // into the raw scaled big.Int (removing the decimal point). + // The fractional part must have exactly Fix128Scale (24) digits. + parts := strings.SplitN(s, ".", 2) + if len(parts) == 1 { + // No decimal point — treat as integer, scale up + v, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("invalid fix128 string: " + s) + } + return v.Mul(v, sema.Fix128FactorIntBig) + } + if len(parts[1]) != sema.Fix128Scale { + panic(fmt.Sprintf("expected %d fractional digits, got %d: %s", sema.Fix128Scale, len(parts[1]), s)) + } + raw := parts[0] + parts[1] + v, ok := new(big.Int).SetString(raw, 10) + if !ok { + panic("invalid fix128 string: " + s) + } + return v +} + +func TestInterpretFixedPointMultiplyDivide(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected uint64 + expectedError bool + } + + // Expected values pre-computed using the fixed-point library's UFix64.FMD(). + testCases := []testCase{ + // Basic: 2*3/1 = 6 + {a: "2.00000000", b: "3.00000000", c: "1.00000000", rounding: "towardZero", expected: 600000000}, + // Rounding modes: 10*10/3 + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "towardZero", expected: 3333333333}, + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "awayFromZero", expected: 3333333334}, + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "nearestHalfAway", expected: 3333333333}, + // Fractional: 0.5*0.5/1.0 = 0.25 + {a: "0.50000000", b: "0.50000000", c: "1.00000000", rounding: "towardZero", expected: 25000000}, + // Zero factor + {a: "0.00000000", b: "5.00000000", c: "2.00000000", rounding: "towardZero", expected: 0}, + {a: "5.00000000", b: "0.00000000", c: "2.00000000", rounding: "towardZero", expected: 0}, + // Larger values: 100*200/50 = 400 + {a: "100.00000000", b: "200.00000000", c: "50.00000000", rounding: "towardZero", expected: 40000000000}, + // 1*1/3 with different rounding + {a: "1.00000000", b: "1.00000000", c: "3.00000000", rounding: "towardZero", expected: 33333333}, + {a: "1.00000000", b: "1.00000000", c: "3.00000000", rounding: "awayFromZero", expected: 33333334}, + // Division by zero + {a: "1.00000000", b: "2.00000000", c: "0.00000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix64 { + let a: UFix64 = %s + let b: UFix64 = %s + let c: UFix64 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUnmeteredUFix64Value(tc.expected) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("Fix64", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected int64 + expectedError bool + } + + testCases := []testCase{ + // Basic: 2*3/1 = 6 + {a: "2.00000000", b: "3.00000000", c: "1.00000000", rounding: "towardZero", expected: 600000000}, + // Signed: (-2)*3/1 = -6 + {a: "-2.00000000", b: "3.00000000", c: "1.00000000", rounding: "towardZero", expected: -600000000}, + // (-5)*(-3)/2 = 7.5 + {a: "-5.00000000", b: "-3.00000000", c: "2.00000000", rounding: "towardZero", expected: 750000000}, + // Rounding: 10*10/3 + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "towardZero", expected: 3333333333}, + {a: "10.00000000", b: "10.00000000", c: "3.00000000", rounding: "awayFromZero", expected: 3333333334}, + // 1/3 truncated + {a: "1.00000000", b: "1.00000000", c: "3.00000000", rounding: "towardZero", expected: 33333333}, + // Division by zero + {a: "1.00000000", b: "2.00000000", c: "0.00000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): Fix64 { + let a: Fix64 = %s + let b: Fix64 = %s + let c: Fix64 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUnmeteredFix64Value(tc.expected) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("UFix128", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected string + expectedError bool + } + + testCases := []testCase{ + {a: "2.000000000000000000000000", b: "3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "6.000000000000000000000000"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "towardZero", expected: "33.333333333333333333333333"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "awayFromZero", expected: "33.333333333333333333333334"}, + {a: "0.500000000000000000000000", b: "0.500000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "0.250000000000000000000000"}, + {a: "100.000000000000000000000000", b: "200.000000000000000000000000", c: "50.000000000000000000000000", rounding: "towardZero", expected: "400.000000000000000000000000"}, + // Division by zero + {a: "1.000000000000000000000000", b: "2.000000000000000000000000", c: "0.000000000000000000000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix128 { + let a: UFix128 = %s + let b: UFix128 = %s + let c: UFix128 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUFix128ValueFromBigInt(nil, fix128BigInt(tc.expected)) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("Fix128", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + a, b, c string + rounding string + expected string + expectedError bool + } + + testCases := []testCase{ + {a: "2.000000000000000000000000", b: "3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "6.000000000000000000000000"}, + {a: "-2.000000000000000000000000", b: "3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "-6.000000000000000000000000"}, + {a: "-2.000000000000000000000000", b: "-3.000000000000000000000000", c: "1.000000000000000000000000", rounding: "towardZero", expected: "6.000000000000000000000000"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "towardZero", expected: "33.333333333333333333333333"}, + {a: "10.000000000000000000000000", b: "10.000000000000000000000000", c: "3.000000000000000000000000", rounding: "awayFromZero", expected: "33.333333333333333333333334"}, + // Division by zero + {a: "1.000000000000000000000000", b: "2.000000000000000000000000", c: "0.000000000000000000000000", rounding: "towardZero", expectedError: true}, + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s * %s / %s (%s)", tc.a, tc.b, tc.c, tc.rounding) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): Fix128 { + let a: Fix128 = %s + let b: Fix128 = %s + let c: Fix128 = %s + return a.multiplyDivide(b, c, rounding: RoundingRule.%s) + } + `, + tc.a, tc.b, tc.c, tc.rounding, + ) + + inter := parseCheckAndPrepareWithRoundingRule(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewFix128ValueFromBigInt(nil, fix128BigInt(tc.expected)) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("default rounding (truncate)", func(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): UFix64 { + let a: UFix64 = 10.0 + let b: UFix64 = 10.0 + let c: UFix64 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.33333333 + expected := interpreter.NewUnmeteredUFix64Value(3333333333) + AssertValuesEqual(t, inter, expected, result) + }) + + t.Run("Fix64", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): Fix64 { + let a: Fix64 = 10.0 + let b: Fix64 = 10.0 + let c: Fix64 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.33333333 + expected := interpreter.NewUnmeteredFix64Value(3333333333) + AssertValuesEqual(t, inter, expected, result) + }) + + t.Run("UFix128", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): UFix128 { + let a: UFix128 = 10.0 + let b: UFix128 = 10.0 + let c: UFix128 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.333333333333333333333333 + expected := interpreter.NewUFix128ValueFromBigInt(nil, fix128BigInt("33.333333333333333333333333")) + AssertValuesEqual(t, inter, expected, result) + }) + + t.Run("Fix128", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndPrepareWithRoundingRule(t, ` + fun test(): Fix128 { + let a: Fix128 = 10.0 + let b: Fix128 = 10.0 + let c: Fix128 = 3.0 + return a.multiplyDivide(b, c) + } + `) + result, err := inter.Invoke("test") + require.NoError(t, err) + + // 10*10/3 truncated = 33.333333333333333333333333 + expected := interpreter.NewFix128ValueFromBigInt(nil, fix128BigInt("33.333333333333333333333333")) + AssertValuesEqual(t, inter, expected, result) + }) + }) +} + +func TestInterpretFixedPointPow(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + base string + exponent string + expected uint64 + expectedError bool + } + + // Expected values were pre-computed using the fixed-point library's UFix64.Pow(Fix64). + testCases := []testCase{ + // Edge cases + {base: "0.00000000", exponent: "0.00000000", expected: 100000000}, // 0^0 = 1 + {base: "0.00000000", exponent: "2.00000000", expected: 0}, // 0^2 = 0 + {base: "1.00000000", exponent: "0.00000000", expected: 100000000}, // 1^0 = 1 + {base: "1.00000000", exponent: "5.00000000", expected: 100000000}, // 1^5 = 1 + {base: "2.00000000", exponent: "0.00000000", expected: 100000000}, // 2^0 = 1 + {base: "2.00000000", exponent: "1.00000000", expected: 200000000}, // 2^1 = 2 + + // Integer exponents + {base: "2.00000000", exponent: "3.00000000", expected: 800000000}, // 2^3 = 8 + {base: "5.00000000", exponent: "2.00000000", expected: 2500000000}, // 5^2 = 25 + {base: "10.00000000", exponent: "3.00000000", expected: 100000000000}, // 10^3 = 1000 + + // Negative exponents + {base: "2.00000000", exponent: "-1.00000000", expected: 50000000}, // 2^(-1) = 0.5 + {base: "4.00000000", exponent: "-1.00000000", expected: 25000000}, // 4^(-1) = 0.25 + {base: "10.00000000", exponent: "-2.00000000", expected: 1000000}, // 10^(-2) = 0.01 + + // Fractional bases + {base: "0.50000000", exponent: "2.00000000", expected: 25000000}, // 0.5^2 = 0.25 + {base: "1.50000000", exponent: "2.00000000", expected: 225000000}, // 1.5^2 = 2.25 + {base: "0.25000000", exponent: "3.00000000", expected: 1562500}, // 0.25^3 = 0.015625 + + // Fractional exponents + {base: "4.00000000", exponent: "0.50000000", expected: 200000000}, // 4^0.5 = 2 + {base: "9.00000000", exponent: "0.50000000", expected: 300000000}, // 9^0.5 = 3 + {base: "8.00000000", exponent: "0.33333333", expected: 199999999}, // 8^(1/3) ≈ 2 + + // Values from library test data + {base: "0.11111111", exponent: "2.00000000", expected: 1234568}, // (1/9)^2 + {base: "0.33333333", exponent: "3.00000000", expected: 3703704}, // (1/3)^3 + {base: "2.71828183", exponent: "1.00000000", expected: 271828183}, // e^1 + {base: "3.14159265", exponent: "-0.50000000", expected: 56418958}, // pi^(-0.5) + {base: "0.14285714", exponent: "2.00000000", expected: 2040816}, // (1/7)^2 + {base: "123.45678901", exponent: "0.50000000", expected: 1111111106}, // 123.45678901^0.5 + + // Repeating decimal bases with negative exponents + {base: "0.66666666", exponent: "-1.00000000", expected: 150000002}, // (2/3)^(-1) + {base: "0.50000000", exponent: "-2.00000000", expected: 400000000}, // 0.5^(-2) = 4 + + // Overflow + {base: "429496.72960000", exponent: "2.00000000", expectedError: true}, // sqrt(MaxUFix64)^2 overflows + {base: "10.00000000", exponent: "20.00000000", expectedError: true}, // 10^20 overflows + + // Underflow (truncated to 0 by handleFixedpointError) + {base: "0.00000003", exponent: "2.00000000", expected: 0}, // 0.00000003^2 underflows to 0 + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s ^ %s", tc.base, tc.exponent) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix64 { + let base: UFix64 = %s + let exponent: Fix64 = %s + return base.pow(exponent) + } + `, + tc.base, + tc.exponent, + ) + + inter := parseCheckAndPrepare(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUnmeteredUFix64Value(tc.expected) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) + + t.Run("UFix128", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + base string + exponent string + expected string + expectedError bool + } + + // Expected values were pre-computed using the fixed-point library's UFix128.Pow(Fix128). + testCases := []testCase{ + // Edge cases + {base: "0.000000000000000000000000", exponent: "0.000000000000000000000000", expected: "1.000000000000000000000000"}, // 0^0 = 1 + {base: "1.000000000000000000000000", exponent: "0.000000000000000000000000", expected: "1.000000000000000000000000"}, // 1^0 = 1 + {base: "1.000000000000000000000000", exponent: "5.000000000000000000000000", expected: "1.000000000000000000000000"}, // 1^5 = 1 + {base: "2.000000000000000000000000", exponent: "0.000000000000000000000000", expected: "1.000000000000000000000000"}, // 2^0 = 1 + {base: "2.000000000000000000000000", exponent: "1.000000000000000000000000", expected: "2.000000000000000000000000"}, // 2^1 = 2 + + // Integer exponents + {base: "2.000000000000000000000000", exponent: "3.000000000000000000000000", expected: "8.000000000000000000000000"}, // 2^3 = 8 + {base: "5.000000000000000000000000", exponent: "2.000000000000000000000000", expected: "25.000000000000000000000000"}, // 5^2 = 25 + {base: "10.000000000000000000000000", exponent: "3.000000000000000000000000", expected: "1000.000000000000000000000000"}, // 10^3 = 1000 + + // Negative exponents + {base: "2.000000000000000000000000", exponent: "-1.000000000000000000000000", expected: "0.500000000000000000000000"}, // 2^(-1) = 0.5 + {base: "4.000000000000000000000000", exponent: "-1.000000000000000000000000", expected: "0.250000000000000000000000"}, // 4^(-1) = 0.25 + + // Fractional base + {base: "0.500000000000000000000000", exponent: "2.000000000000000000000000", expected: "0.250000000000000000000000"}, // 0.5^2 = 0.25 + + // Fractional exponent + {base: "4.000000000000000000000000", exponent: "0.500000000000000000000000", expected: "2.000000000000000000000000"}, // 4^0.5 = 2 + {base: "9.000000000000000000000000", exponent: "0.500000000000000000000000", expected: "3.000000000000000000000000"}, // 9^0.5 = 3 + } + + for _, tc := range testCases { + + testName := fmt.Sprintf("%s ^ %s", tc.base, tc.exponent) + + t.Run(testName, func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test(): UFix128 { + let base: UFix128 = %s + let exponent: Fix128 = %s + return base.pow(exponent) + } + `, + tc.base, + tc.exponent, + ) + + inter := parseCheckAndPrepare(t, code) + + if tc.expectedError { + _, err := inter.Invoke("test") + require.Error(t, err) + } else { + result, err := inter.Invoke("test") + require.NoError(t, err) + + expected := interpreter.NewUFix128ValueFromBigInt(nil, fix128BigInt(tc.expected)) + AssertValuesEqual(t, inter, expected, result) + } + }) + } + }) +} + func TestInterpretNegativeZeroFixedPoint(t *testing.T) { t.Parallel() diff --git a/interpreter/value.go b/interpreter/value.go index 0374ce61a..c114db2f5 100644 --- a/interpreter/value.go +++ b/interpreter/value.go @@ -23,6 +23,8 @@ import ( "github.com/onflow/atree" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/sema" ) @@ -322,6 +324,12 @@ type FixedPointValue interface { NumberValue IntegerPart() NumberValue Scale() int + MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, + ) NumberValue } type AuthorizedValue interface { diff --git a/interpreter/value_fix128.go b/interpreter/value_fix128.go index 1d5a6077a..b82bd11a4 100644 --- a/interpreter/value_fix128.go +++ b/interpreter/value_fix128.go @@ -370,6 +370,43 @@ func (v Fix128Value) Mod(context NumberValueArithmeticContext, other NumberValue return NewFix128Value(context, valueGetter) } +func (v Fix128Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(Fix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(Fix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() fix.Fix128 { + result, err := fix.Fix128(v).FMD( + fix.Fix128(f), + fix.Fix128(d), + rounding, + ) + handleFixedpointError(err) + return result + } + + return NewFix128Value(context, valueGetter) +} + func (v Fix128Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(Fix128Value) if !ok { diff --git a/interpreter/value_fix64.go b/interpreter/value_fix64.go index 08a4da79c..1e68328b9 100644 --- a/interpreter/value_fix64.go +++ b/interpreter/value_fix64.go @@ -387,6 +387,42 @@ func (v Fix64Value) Mod(context NumberValueArithmeticContext, other NumberValue) return v.Minus(context, truncatedQuotient.Mul(context, o)) } +func (v Fix64Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(Fix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(Fix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() int64 { + a := fix.Fix64(uint64(v)) + b := fix.Fix64(uint64(f)) + c := fix.Fix64(uint64(d)) + result, err := a.FMD(b, c, rounding) + handleFixedpointError(err) + return int64(result) + } + + return NewFix64Value(context, valueGetter) +} + func (v Fix64Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(Fix64Value) if !ok { diff --git a/interpreter/value_number.go b/interpreter/value_number.go index d1a96b228..8766e9de0 100644 --- a/interpreter/value_number.go +++ b/interpreter/value_number.go @@ -21,7 +21,10 @@ package interpreter import ( "math/big" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/common" + "github.com/onflow/cadence/errors" "github.com/onflow/cadence/sema" ) @@ -103,6 +106,30 @@ func getNumberValueFunctionMember( sema.SaturatingArithmeticTypeFunctionTypes[typ], NativeNumberSaturatingDivideFunction, ) + + case sema.FixedPointNumericTypePowFunctionName: + funcType, ok := sema.FixedPointPowFunctionTypes[typ] + if !ok { + return nil + } + return NewBoundHostFunctionValue( + context, + v, + funcType, + NativeFixedPointPowFunction, + ) + + case sema.FixedPointNumericTypeMultiplyDivideFunctionName: + funcType, ok := sema.FixedPointMultiplyDivideFunctionTypes[typ] + if !ok { + return nil + } + return NewBoundHostFunctionValue( + context, + v, + funcType, + NativeFixedPointMultiplyDivideFunction, + ) } return nil @@ -214,3 +241,44 @@ var NativeNumberSaturatingDivideFunction = NativeFunction( return receiver.(NumberValue).SaturatingDiv(context, other) }, ) + +var NativeFixedPointMultiplyDivideFunction = NativeFunction( + func( + context NativeFunctionContext, + _ TypeArgumentsIterator, + _ ArgumentTypesIterator, + receiver Value, + args []Value, + ) Value { + factor := AssertValueOfType[FixedPointValue](args[0]) + divisor := AssertValueOfType[FixedPointValue](args[1]) + var rounding fix.RoundingMode + if len(args) > 2 { + rounding = extractRoundingRule(args[2]) + } else { + rounding = fix.RoundTruncate + } + return receiver.(FixedPointValue).MultiplyDivide(context, factor, divisor, rounding) + }, +) + +var NativeFixedPointPowFunction = NativeFunction( + func( + context NativeFunctionContext, + _ TypeArgumentsIterator, + _ ArgumentTypesIterator, + receiver Value, + args []Value, + ) Value { + switch v := receiver.(type) { + case UFix64Value: + exponent := AssertValueOfType[Fix64Value](args[0]) + return v.Pow(context, exponent) + case UFix128Value: + exponent := AssertValueOfType[Fix128Value](args[0]) + return v.Pow(context, exponent) + default: + panic(errors.NewUnreachableError()) + } + }, +) diff --git a/interpreter/value_ufix128.go b/interpreter/value_ufix128.go index d4ec4dfb5..1480f1877 100644 --- a/interpreter/value_ufix128.go +++ b/interpreter/value_ufix128.go @@ -362,6 +362,53 @@ func (v UFix128Value) Mod(context NumberValueArithmeticContext, other NumberValu return NewUFix128Value(context, valueGetter) } +func (v UFix128Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(UFix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(UFix128Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() fix.UFix128 { + result, err := fix.UFix128(v).FMD( + fix.UFix128(f), + fix.UFix128(d), + rounding, + ) + handleFixedpointError(err) + return result + } + + return NewUFix128Value(context, valueGetter) +} + +func (v UFix128Value) Pow(context NumberValueArithmeticContext, other Fix128Value) NumberValue { + valueGetter := func() fix.UFix128 { + result, err := fix.UFix128(v).Pow(fix.Fix128(other)) + handleFixedpointError(err) + return result + } + + return NewUFix128Value(context, valueGetter) +} + func (v UFix128Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(UFix128Value) if !ok { diff --git a/interpreter/value_ufix64.go b/interpreter/value_ufix64.go index a9908b6d6..b219573de 100644 --- a/interpreter/value_ufix64.go +++ b/interpreter/value_ufix64.go @@ -393,6 +393,54 @@ func (v UFix64Value) Mod(context NumberValueArithmeticContext, other NumberValue return UFix64Value{UFix64Value: result} } +func (v UFix64Value) MultiplyDivide( + context NumberValueArithmeticContext, + factor FixedPointValue, + divisor FixedPointValue, + rounding fix.RoundingMode, +) NumberValue { + f, ok := factor.(UFix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: factor.StaticType(context), + }) + } + + d, ok := divisor.(UFix64Value) + if !ok { + panic(&InvalidOperandsError{ + FunctionName: sema.FixedPointNumericTypeMultiplyDivideFunctionName, + LeftType: v.StaticType(context), + RightType: divisor.StaticType(context), + }) + } + + valueGetter := func() uint64 { + a := fix.UFix64(uint64(v.UFix64Value)) + b := fix.UFix64(uint64(f.UFix64Value)) + c := fix.UFix64(uint64(d.UFix64Value)) + result, err := a.FMD(b, c, rounding) + handleFixedpointError(err) + return uint64(result) + } + + return NewUFix64Value(context, valueGetter) +} + +func (v UFix64Value) Pow(context NumberValueArithmeticContext, other Fix64Value) NumberValue { + valueGetter := func() uint64 { + a := fix.UFix64(uint64(v.UFix64Value)) + b := fix.Fix64(uint64(other)) + result, err := a.Pow(b) + handleFixedpointError(err) + return uint64(result) + } + + return NewUFix64Value(context, valueGetter) +} + func (v UFix64Value) Less(context ValueComparisonContext, other ComparableValue) BoolValue { o, ok := other.(UFix64Value) if !ok { diff --git a/sema/fixedpoint_test.go b/sema/fixedpoint_test.go index 2b1584142..7ac560b82 100644 --- a/sema/fixedpoint_test.go +++ b/sema/fixedpoint_test.go @@ -27,11 +27,219 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/format" "github.com/onflow/cadence/sema" + "github.com/onflow/cadence/stdlib" . "github.com/onflow/cadence/test_utils/sema_utils" ) +func TestCheckFixedPointMultiplyDivide(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + for _, value := range stdlib.InterpreterDefaultScriptStandardLibraryValues(nil) { + baseValueActivation.DeclareValue(value) + } + + parseAndCheckWithRoundingRule := func(t *testing.T, code string) (*sema.Checker, error) { + return ParseAndCheckWithOptions(t, + code, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + } + + for _, fixedPointType := range common.Concat( + sema.AllSignedFixedPointTypes, + sema.AllUnsignedFixedPointTypes, + ) { + + t.Run(fixedPointType.String(), func(t *testing.T) { + + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + checker, err := parseAndCheckWithRoundingRule(t, + fmt.Sprintf( + ` + let a: %[1]s = 2.0 + let b: %[1]s = 3.0 + let c: %[1]s = 1.0 + let result = a.multiplyDivide(b, c, rounding: RoundingRule.towardZero) + `, + fixedPointType, + ), + ) + require.NoError(t, err) + + resultType := RequireGlobalValue(t, checker.Elaboration, "result") + assert.Equal(t, fixedPointType, resultType) + }) + + t.Run("invalid, wrong factor type", func(t *testing.T) { + t.Parallel() + + _, err := parseAndCheckWithRoundingRule(t, + fmt.Sprintf( + ` + let a: %[1]s = 2.0 + let b: Int = 3 + let c: %[1]s = 1.0 + let result = a.multiplyDivide(b, c, rounding: RoundingRule.towardZero) + `, + fixedPointType, + ), + ) + require.Error(t, err) + }) + + t.Run("valid, without rounding", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, + fmt.Sprintf( + ` + let a: %[1]s = 2.0 + let b: %[1]s = 3.0 + let c: %[1]s = 1.0 + let result = a.multiplyDivide(b, c) + `, + fixedPointType, + ), + ) + require.NoError(t, err) + }) + }) + } + + t.Run("not available on integer types", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let a: Int = 2 + let result = a.multiplyDivide(3, 1, rounding: RoundingRule.towardZero) + `) + require.Error(t, err) + }) +} + +func TestCheckFixedPointPow(t *testing.T) { + + t.Parallel() + + t.Run("UFix64", func(t *testing.T) { + + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let result = 2.0.pow(3.0) + `) + require.NoError(t, err) + + resultType := RequireGlobalValue(t, checker.Elaboration, "result") + assert.Equal(t, sema.UFix64Type, resultType) + }) + + t.Run("valid, negative exponent", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let result = 2.0.pow(-1.0) + `) + require.NoError(t, err) + }) + + t.Run("valid, explicit types", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: UFix64 = 2.0 + let exponent: Fix64 = 3.0 + let result = base.pow(exponent) + `) + require.NoError(t, err) + }) + + t.Run("invalid, wrong exponent type", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let result = 2.0.pow(3.0 as UFix64) + `) + require.Error(t, err) + }) + }) + + t.Run("UFix128", func(t *testing.T) { + + t.Parallel() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let base: UFix128 = 2.0 + let exponent: Fix128 = 3.0 + let result = base.pow(exponent) + `) + require.NoError(t, err) + + resultType := RequireGlobalValue(t, checker.Elaboration, "result") + assert.Equal(t, sema.UFix128Type, resultType) + }) + + t.Run("invalid, wrong exponent type", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: UFix128 = 2.0 + let exponent: Fix64 = 3.0 + let result = base.pow(exponent) + `) + require.Error(t, err) + }) + }) + + t.Run("not available on signed types", func(t *testing.T) { + + t.Parallel() + + t.Run("Fix64", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: Fix64 = 2.0 + let exponent: Fix64 = 3.0 + let result = base.pow(exponent) + `) + require.Error(t, err) + }) + + t.Run("Fix128", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + let base: Fix128 = 2.0 + let exponent: Fix128 = 3.0 + let result = base.pow(exponent) + `) + require.Error(t, err) + }) + }) +} + func TestCheckFixedPointLiteralTypeConversionInVariableDeclaration(t *testing.T) { t.Parallel() diff --git a/sema/type.go b/sema/type.go index e037dc0ca..45f8c327f 100644 --- a/sema/type.go +++ b/sema/type.go @@ -1326,17 +1326,26 @@ func registerSaturatingArithmeticType(t Type) { ) } -func addSaturatingArithmeticFunctions(t SaturatingArithmeticType, members map[string]MemberResolver) { +func addSaturatingArithmeticFunctions( + t SaturatingArithmeticType, + members map[string]MemberResolver, +) { + functionType := SaturatingArithmeticTypeFunctionTypes[t] addArithmeticFunction := func(name string, docString string) { members[name] = MemberResolver{ Kind: common.DeclarationKindFunction, - Resolve: func(memoryGauge common.MemoryGauge, _ string, _ ast.HasPosition, _ func(error)) *Member { + Resolve: func( + memoryGauge common.MemoryGauge, + _ string, + _ ast.HasPosition, + _ func(error), + ) *Member { return NewPublicFunctionMember( memoryGauge, t, name, - SaturatingArithmeticTypeFunctionTypes[t], + functionType, docString, ) }, @@ -1379,6 +1388,110 @@ type SaturatingArithmeticSupport struct { Divide bool } +const FixedPointNumericTypePowFunctionName = "pow" +const fixedPointNumericTypePowFunctionDocString = ` +Returns this value raised to the power of the given exponent. +The exponent may be negative or fractional. +` + +var FixedPointPowFunctionTypes = map[Type]*FunctionType{} + +func registerFixedPointPowFunction(t *FixedPointNumericType, exponentType *FixedPointNumericType) { + FixedPointPowFunctionTypes[t] = NewSimpleFunctionType( + FunctionPurityView, + []Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "exponent", + TypeAnnotation: NewTypeAnnotation(exponentType), + }, + }, + NewTypeAnnotation(t), + ) +} + +func addFixedPointPowFunction( + t *FixedPointNumericType, + members map[string]MemberResolver, +) { + functionType := FixedPointPowFunctionTypes[t] + + members[FixedPointNumericTypePowFunctionName] = MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func( + memoryGauge common.MemoryGauge, + _ string, + _ ast.HasPosition, + _ func(error), + ) *Member { + return NewPublicFunctionMember( + memoryGauge, + t, + FixedPointNumericTypePowFunctionName, + functionType, + fixedPointNumericTypePowFunctionDocString, + ) + }, + } +} + +const FixedPointNumericTypeMultiplyDivideFunctionName = "multiplyDivide" +const fixedPointNumericTypeMultiplyDivideFunctionDocString = ` +Returns self * factor / divisor, without intermediate rounding +` + +var FixedPointMultiplyDivideFunctionTypes = map[Type]*FunctionType{} + +func registerFixedPointMultiplyDivideFunction(t *FixedPointNumericType) { + FixedPointMultiplyDivideFunctionTypes[t] = &FunctionType{ + Purity: FunctionPurityView, + Parameters: []Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "factor", + TypeAnnotation: NewTypeAnnotation(t), + }, + { + Label: ArgumentLabelNotRequired, + Identifier: "divisor", + TypeAnnotation: NewTypeAnnotation(t), + }, + { + Label: "rounding", + Identifier: "rounding", + TypeAnnotation: RoundingRuleTypeAnnotation, + }, + }, + Arity: &Arity{Min: 2, Max: 3}, + ReturnTypeAnnotation: NewTypeAnnotation(t), + } +} + +func addFixedPointMultiplyDivideFunction( + t *FixedPointNumericType, + members map[string]MemberResolver, +) { + functionType := FixedPointMultiplyDivideFunctionTypes[t] + + members[FixedPointNumericTypeMultiplyDivideFunctionName] = MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func( + memoryGauge common.MemoryGauge, + _ string, + _ ast.HasPosition, + _ func(error), + ) *Member { + return NewPublicFunctionMember( + memoryGauge, + t, + FixedPointNumericTypeMultiplyDivideFunctionName, + functionType, + fixedPointNumericTypeMultiplyDivideFunctionDocString, + ) + }, + } +} + // NumericType represent all the types in the integer range // and non-fractional ranged types. type NumericType struct { @@ -1634,9 +1747,11 @@ var _ FractionalRangedType = &FixedPointNumericType{} var _ SaturatingArithmeticType = &FixedPointNumericType{} func NewFixedPointNumericType(typeName string) *FixedPointNumericType { - return &FixedPointNumericType{ + t := &FixedPointNumericType{ name: typeName, } + registerFixedPointMultiplyDivideFunction(t) + return t } func (t *FixedPointNumericType) Tag() TypeTag { @@ -1677,6 +1792,11 @@ func (t *FixedPointNumericType) WithSaturatingFunctions(saturatingArithmetic Sat return t } +func (t *FixedPointNumericType) WithPowFunction(exponentType *FixedPointNumericType) *FixedPointNumericType { + registerFixedPointPowFunction(t, exponentType) + return t +} + func (t *FixedPointNumericType) SupportsSaturatingAdd() bool { return t.saturatingArithmetic.Add } @@ -1817,6 +1937,10 @@ func (t *FixedPointNumericType) GetMembers() map[string]MemberResolver { // Compute members and cache them computedMembers := map[string]MemberResolver{} addSaturatingArithmeticFunctions(t, computedMembers) + if _, ok := FixedPointPowFunctionTypes[t]; ok { + addFixedPointPowFunction(t, computedMembers) + } + addFixedPointMultiplyDivideFunction(t, computedMembers) computedMembers = withBuiltinMembers(t, computedMembers) t.memberResolvers.Store(&computedMembers) return computedMembers @@ -2142,7 +2266,8 @@ var UFix64Type = NewFixedPointNumericType(UFix64TypeName). Add: true, Subtract: true, Multiply: true, - }) + }). + WithPowFunction(Fix64Type) var UFix64TypeAnnotation = NewTypeAnnotation(UFix64Type) // UFix128Type represents the 128-bit unsigned decimal fixed-point type `UFix128` @@ -2156,7 +2281,8 @@ var UFix128Type = NewFixedPointNumericType(UFix128TypeName). Add: true, Subtract: true, Multiply: true, - }) + }). + WithPowFunction(Fix128Type) var UFix128TypeAnnotation = NewTypeAnnotation(UFix128Type) // Numeric type ranges