diff --git a/bbq/vm/builtin_globals.go b/bbq/vm/builtin_globals.go index 81e2905bd..dde59dfd7 100644 --- a/bbq/vm/builtin_globals.go +++ b/bbq/vm/builtin_globals.go @@ -228,13 +228,21 @@ func init() { for _, declaration := range interpreter.ConverterDeclarations { // NOTE: declare in loop, as captured in closure below convert := declaration.Convert + convertWithRounding := declaration.ConvertWithRounding functionType := sema.BaseValueActivation.Find(declaration.Name).Type.(*sema.FunctionType) + var nativeFn interpreter.NativeFunction + if convertWithRounding != nil { + nativeFn = interpreter.NativeConverterFunctionWithRounding(convert, convertWithRounding) + } else { + nativeFn = interpreter.NativeConverterFunction(convert) + } + function := NewNativeFunctionValue( declaration.Name, functionType, - interpreter.NativeConverterFunction(convert), + nativeFn, ) registerBuiltinFunction(function) diff --git a/bbq/vm/value_function.go b/bbq/vm/value_function.go index 1cdc37738..1270705b1 100644 --- a/bbq/vm/value_function.go +++ b/bbq/vm/value_function.go @@ -315,6 +315,13 @@ func (v *NativeFunctionValue) GetMember( ) } +func (v *NativeFunctionValue) SetField(name string, value interpreter.Value) { + if v.fields == nil { + v.fields = make(map[string]interpreter.Value) + } + v.fields[name] = value +} + func (*NativeFunctionValue) RemoveMember(_ interpreter.ValueTransferContext, _ string) interpreter.Value { panic(errors.NewUnreachableError()) } diff --git a/interpreter/fixedpoint_test.go b/interpreter/fixedpoint_test.go index 1534cff97..bbbd1dea1 100644 --- a/interpreter/fixedpoint_test.go +++ b/interpreter/fixedpoint_test.go @@ -28,10 +28,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/cadence/activations" "github.com/onflow/cadence/ast" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/fixedpoint" "github.com/onflow/cadence/interpreter" "github.com/onflow/cadence/sema" + "github.com/onflow/cadence/stdlib" . "github.com/onflow/cadence/test_utils/common_utils" . "github.com/onflow/cadence/test_utils/interpreter_utils" . "github.com/onflow/cadence/test_utils/sema_utils" @@ -1350,3 +1353,544 @@ func TestInterpretFixedPointLeastSignificantDecimalHandling(t *testing.T) { } }) } + +func parseCheckAndPrepareWithRoundingRule(t *testing.T, code string) Invokable { + t.Helper() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + baseActivation := activations.NewActivation(nil, interpreter.BaseActivation) + + valueDeclaration := stdlib.InterpreterRoundingRuleConstructor + baseValueActivation.DeclareValue(valueDeclaration) + interpreter.Declare(baseActivation, valueDeclaration) + + invokable, err := parseCheckAndPrepareWithOptions( + t, + code, + ParseCheckAndInterpretOptions{ + ParseAndCheckOptions: &ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + InterpreterConfig: &interpreter.Config{ + BaseActivationHandler: func(_ common.Location) *interpreter.VariableActivation { + return baseActivation + }, + }, + }, + ) + require.NoError(t, err) + + return invokable +} + +func TestInterpretFix64WithRoundingRule(t *testing.T) { + t.Parallel() + + // Fix128 has 24 decimal places, Fix64 has 8. + // Rounding applies to the 9th+ decimal places when converting Fix128 → Fix64. + // + // Test values: + // 1.000000003... (9th digit < 5): non-halfway, fractional part below midpoint + // 1.000000005... (9th digit = 5): exact halfway between 1.00000000 and 1.00000001 + // 1.000000007... (9th digit > 5): non-halfway, fractional part above midpoint + // + // Expected results per rounding mode: + // towardZero: always truncate → 1.00000000 (positive), -1.00000000 (negative) + // awayFromZero: always round up magnitude → 1.00000001 (positive), -1.00000001 (negative) + // nearestHalfAway: < 5 → 1.00000000, = 5 → 1.00000001 (away), > 5 → 1.00000001 + // nearestHalfEven: < 5 → 1.00000000, = 5 → 1.00000000 (even), > 5 → 1.00000001 + + type testCase struct { + name string + code string + expected interpreter.Fix64Value + } + + tests := []testCase{ + // towardZero: truncates fractional part beyond 8 decimals + { + name: "towardZero, positive, non-halfway below", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000003000000000000000 + return Fix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000000), // 1.00000000 + }, + { + name: "towardZero, positive, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000000), // 1.00000000 + }, + { + name: "towardZero, positive, non-halfway above", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000007000000000000000 + return Fix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000000), // 1.00000000 + }, + { + name: "towardZero, negative, non-halfway below", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000003000000000000000 + return Fix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000000), // -1.00000000 + }, + { + name: "towardZero, negative, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000000), // -1.00000000 + }, + { + name: "towardZero, negative, non-halfway above", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000007000000000000000 + return Fix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000000), // -1.00000000 + }, + + // awayFromZero: rounds up magnitude for any nonzero fractional part + { + name: "awayFromZero, positive, non-halfway below", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000003000000000000000 + return Fix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000001), // 1.00000001 + }, + { + name: "awayFromZero, positive, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000001), // 1.00000001 + }, + { + name: "awayFromZero, positive, non-halfway above", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000007000000000000000 + return Fix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000001), // 1.00000001 + }, + { + name: "awayFromZero, negative, non-halfway below", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000003000000000000000 + return Fix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000001), // -1.00000001 + }, + { + name: "awayFromZero, negative, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000001), // -1.00000001 + }, + { + name: "awayFromZero, negative, non-halfway above", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000007000000000000000 + return Fix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000001), // -1.00000001 + }, + + // nearestHalfAway: nearest, tie breaks away from zero + { + name: "nearestHalfAway, positive, non-halfway below", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000003000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000000), // 1.00000000 + }, + { + name: "nearestHalfAway, positive, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000001), // 1.00000001 (tie → away) + }, + { + name: "nearestHalfAway, positive, non-halfway above", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000007000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000001), // 1.00000001 + }, + { + name: "nearestHalfAway, negative, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000001), // -1.00000001 (tie → away) + }, + + // nearestHalfEven: nearest, tie breaks to even + { + name: "nearestHalfEven, positive, non-halfway below", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000003000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000000), // 1.00000000 + }, + { + name: "nearestHalfEven, positive, exact halfway (last digit 0 is even)", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000000), // 1.00000000 (tie → even, 0 is even) + }, + { + name: "nearestHalfEven, positive, non-halfway above", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000007000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000001), // 1.00000001 + }, + { + name: "nearestHalfEven, positive, exact halfway (last digit 1 is odd)", + code: `fun main(): Fix64 { + let x: Fix128 = 1.000000015000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredFix64Value(100000002), // 1.00000002 (tie → even, 2 is even) + }, + { + name: "nearestHalfEven, negative, exact halfway", + code: `fun main(): Fix64 { + let x: Fix128 = -1.000000005000000000000000 + return Fix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredFix64Value(-100000000), // -1.00000000 (tie → even) + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, tc.code) + result, err := invokable.Invoke("main") + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } + + t.Run("backward compat, no rounding truncates", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepare(t, ` + fun main(): Fix64 { + let x: Fix128 = 1.000000005000000000000000 + return Fix64(x) + } + `) + + result, err := invokable.Invoke("main") + require.NoError(t, err) + assert.Equal(t, interpreter.NewUnmeteredFix64Value(100000000), result) + }) + + t.Run("integer conversion ignores rounding", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): Fix64 { + return Fix64(42, rounding: RoundingRule.awayFromZero) + } + `) + + result, err := invokable.Invoke("main") + require.NoError(t, err) + assert.Equal(t, interpreter.NewUnmeteredFix64ValueWithInteger(42), result) + }) +} + +func TestInterpretUFix64WithRoundingRule(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + code string + expected interpreter.UFix64Value + } + + tests := []testCase{ + // towardZero + { + name: "towardZero, non-halfway below", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000003000000000000000 + return UFix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000000), + }, + { + name: "towardZero, exact halfway", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000005000000000000000 + return UFix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000000), + }, + { + name: "towardZero, non-halfway above", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000007000000000000000 + return UFix64(x, rounding: RoundingRule.towardZero) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000000), + }, + + // awayFromZero + { + name: "awayFromZero, non-halfway below", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000003000000000000000 + return UFix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000001), + }, + { + name: "awayFromZero, exact halfway", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000005000000000000000 + return UFix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000001), + }, + { + name: "awayFromZero, non-halfway above", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000007000000000000000 + return UFix64(x, rounding: RoundingRule.awayFromZero) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000001), + }, + + // nearestHalfAway + { + name: "nearestHalfAway, non-halfway below", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000003000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000000), + }, + { + name: "nearestHalfAway, exact halfway", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000005000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000001), // tie → away + }, + { + name: "nearestHalfAway, non-halfway above", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000007000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfAway) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000001), + }, + + // nearestHalfEven + { + name: "nearestHalfEven, non-halfway below", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000003000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000000), + }, + { + name: "nearestHalfEven, exact halfway (last digit 0 is even)", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000005000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000000), // tie → even, 0 is even + }, + { + name: "nearestHalfEven, non-halfway above", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000007000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000001), + }, + { + name: "nearestHalfEven, exact halfway (last digit 1 is odd)", + code: `fun main(): UFix64 { + let x: UFix128 = 1.000000015000000000000000 + return UFix64(x, rounding: RoundingRule.nearestHalfEven) + }`, + expected: interpreter.NewUnmeteredUFix64Value(100000002), // tie → even, 2 is even + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, tc.code) + result, err := invokable.Invoke("main") + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } + + t.Run("backward compat, no rounding truncates", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepare(t, ` + fun main(): UFix64 { + let x: UFix128 = 1.000000005000000000000000 + return UFix64(x) + } + `) + + result, err := invokable.Invoke("main") + require.NoError(t, err) + assert.Equal(t, interpreter.NewUnmeteredUFix64Value(100000000), result) + }) +} + +func TestInterpretFix64WithRoundingRuleOverflow(t *testing.T) { + t.Parallel() + + t.Run("Fix128 overflow", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): Fix64 { + // Fix128 max is much larger than Fix64 max + let x: Fix128 = Fix128.max + return Fix64(x, rounding: RoundingRule.towardZero) + } + `) + + _, err := invokable.Invoke("main") + RequireError(t, err) + var expectedError *interpreter.OverflowError + require.ErrorAs(t, err, &expectedError) + }) + + t.Run("Fix128 negative overflow", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): Fix64 { + let x: Fix128 = Fix128.min + return Fix64(x, rounding: RoundingRule.towardZero) + } + `) + + _, err := invokable.Invoke("main") + RequireError(t, err) + var expectedError *interpreter.UnderflowError + require.ErrorAs(t, err, &expectedError) + }) + + t.Run("rounding causes overflow", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): Fix64 { + // Fix64.max as Fix128, plus a fraction that would round up + let x: Fix128 = 92233720368.547758079999999999999999 + return Fix64(x, rounding: RoundingRule.awayFromZero) + } + `) + + _, err := invokable.Invoke("main") + RequireError(t, err) + var expectedError *interpreter.OverflowError + require.ErrorAs(t, err, &expectedError) + }) +} + +func TestInterpretUFix64WithRoundingRuleOverflow(t *testing.T) { + t.Parallel() + + t.Run("UFix128 overflow", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): UFix64 { + let x: UFix128 = UFix128.max + return UFix64(x, rounding: RoundingRule.towardZero) + } + `) + + _, err := invokable.Invoke("main") + RequireError(t, err) + var expectedError *interpreter.OverflowError + require.ErrorAs(t, err, &expectedError) + }) + + t.Run("Fix128 negative to UFix64", func(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): UFix64 { + let x: Fix128 = -1.0 + return UFix64(x, rounding: RoundingRule.towardZero) + } + `) + + _, err := invokable.Invoke("main") + RequireError(t, err) + var expectedError *interpreter.UnderflowError + require.ErrorAs(t, err, &expectedError) + }) +} + +func TestInterpretRoundingRuleEnum(t *testing.T) { + t.Parallel() + + invokable := parseCheckAndPrepareWithRoundingRule(t, ` + fun main(): [UInt8] { + return [ + RoundingRule.towardZero.rawValue, + RoundingRule.awayFromZero.rawValue, + RoundingRule.nearestHalfAway.rawValue, + RoundingRule.nearestHalfEven.rawValue + ] + } + `) + + result, err := invokable.Invoke("main") + require.NoError(t, err) + + arrayValue := result.(*interpreter.ArrayValue) + require.Equal(t, 4, arrayValue.Count()) + + assert.Equal(t, interpreter.UInt8Value(0), arrayValue.Get(nil, 0)) + assert.Equal(t, interpreter.UInt8Value(1), arrayValue.Get(nil, 1)) + assert.Equal(t, interpreter.UInt8Value(2), arrayValue.Get(nil, 2)) + assert.Equal(t, interpreter.UInt8Value(3), arrayValue.Get(nil, 3)) +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 20fedbb73..c7a481b74 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -34,6 +34,8 @@ import ( "github.com/onflow/atree" "go.opentelemetry.io/otel/attribute" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/activations" "github.com/onflow/cadence/ast" "github.com/onflow/cadence/common" @@ -3510,10 +3512,11 @@ var BigEndianBytesConverters = func() map[string]TypedBigEndianBytesConverter { }() type ValueConverterDeclaration struct { - Min Value - Max Value - Convert func(common.MemoryGauge, Value) Value - nestedVariables []struct { + Min Value + Max Value + Convert func(common.MemoryGauge, Value) Value + ConvertWithRounding func(common.MemoryGauge, Value, fix.RoundingMode) Value + nestedVariables []struct { Name string Value Value } @@ -3678,6 +3681,9 @@ var ConverterDeclarations = []ValueConverterDeclaration{ Convert: func(gauge common.MemoryGauge, value Value) Value { return ConvertFix64(gauge, value) }, + ConvertWithRounding: func(gauge common.MemoryGauge, value Value, round fix.RoundingMode) Value { + return ConvertFix64WithRounding(gauge, value, round) + }, Min: NewUnmeteredFix64Value(math.MinInt64), Max: NewUnmeteredFix64Value(math.MaxInt64), }, @@ -3694,6 +3700,9 @@ var ConverterDeclarations = []ValueConverterDeclaration{ Convert: func(gauge common.MemoryGauge, value Value) Value { return ConvertUFix64(gauge, value) }, + ConvertWithRounding: func(gauge common.MemoryGauge, value Value, round fix.RoundingMode) Value { + return ConvertUFix64WithRounding(gauge, value, round) + }, Min: NewUnmeteredUFix64Value(0), Max: NewUnmeteredUFix64Value(math.MaxUint64), }, @@ -4158,12 +4167,20 @@ var converterFunctionValues = func() []converterFunction { for index, declaration := range ConverterDeclarations { // NOTE: declare in loop, as captured in closure below convert := declaration.Convert + convertWithRounding := declaration.ConvertWithRounding converterFunctionType := sema.BaseValueActivation.Find(declaration.Name).Type.(*sema.FunctionType) + var nativeFn NativeFunction + if convertWithRounding != nil { + nativeFn = NativeConverterFunctionWithRounding(convert, convertWithRounding) + } else { + nativeFn = NativeConverterFunction(convert) + } + converterFunctionValue := NewUnmeteredStaticHostFunctionValueFromNativeFunction( converterFunctionType, - NativeConverterFunction(convert), + nativeFn, ) addMember := func(name string, value Value) { @@ -4434,6 +4451,25 @@ func NativeConverterFunction(convert func(memoryGauge common.MemoryGauge, value } } +func NativeConverterFunctionWithRounding( + convert func(memoryGauge common.MemoryGauge, value Value) Value, + convertWithRounding func(memoryGauge common.MemoryGauge, value Value, round fix.RoundingMode) Value, +) NativeFunction { + return func( + context NativeFunctionContext, + _ TypeArgumentsIterator, + _ ArgumentTypesIterator, + _ Value, + args []Value, + ) Value { + if len(args) > 1 { + roundingRule := extractRoundingRule(args[1]) + return convertWithRounding(context, args[0], roundingRule) + } + return convert(context, args[0]) + } +} + func NativeFromStringFunction(parser StringValueParser) NativeFunction { return func( context NativeFunctionContext, diff --git a/interpreter/value_fix64.go b/interpreter/value_fix64.go index 3e4842467..08a4da79c 100644 --- a/interpreter/value_fix64.go +++ b/interpreter/value_fix64.go @@ -27,6 +27,8 @@ import ( "github.com/onflow/atree" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/ast" "github.com/onflow/cadence/common" "github.com/onflow/cadence/errors" @@ -514,6 +516,40 @@ func ConvertFix64(memoryGauge common.MemoryGauge, value Value) Fix64Value { } } +func ConvertFix64WithRounding(memoryGauge common.MemoryGauge, value Value, roundingRule fix.RoundingMode) Fix64Value { + switch value := value.(type) { + case Fix128Value: + return NewFix64Value( + memoryGauge, + func() int64 { + result, err := fix.Fix128(value).ToFix64(roundingRule) + if err != nil { + handleFixedPointConversionError(err) + } + return int64(result) + }, + ) + + case UFix128Value: + return NewFix64Value( + memoryGauge, + func() int64 { + result, err := fix.UFix128(value).ToUFix64(roundingRule) + if err != nil { + handleFixedPointConversionError(err) + } + if uint64(result) > Fix64MaxValue { + panic(&OverflowError{}) + } + return int64(result) + }, + ) + + default: + return ConvertFix64(memoryGauge, value) + } +} + func (v Fix64Value) GetMember(context MemberAccessibleContext, name string, memberKind common.DeclarationKind) Value { return GetMember( context, diff --git a/interpreter/value_fixedpoint.go b/interpreter/value_fixedpoint.go new file mode 100644 index 000000000..a95ed3f6f --- /dev/null +++ b/interpreter/value_fixedpoint.go @@ -0,0 +1,60 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter + +import ( + fix "github.com/onflow/fixed-point" + + "github.com/onflow/cadence/errors" + "github.com/onflow/cadence/sema" +) + +func extractRoundingRule(value Value) fix.RoundingMode { + composite, ok := value.(*SimpleCompositeValue) + if !ok { + panic(errors.NewUnreachableError()) + } + rawValue, ok := composite.Fields[sema.EnumRawValueFieldName].(UInt8Value) + if !ok { + panic(errors.NewUnreachableError()) + } + return fix.RoundingMode(rawValue) +} + +// handleFixedPointConversionError handles errors from the fixed-point library +// during narrowing conversions (e.g. Fix128 → Fix64). +// +// Unlike handleFixedpointError (used for Fix128 arithmetic), +// this function does NOT ignore UnderflowError: +// for narrowing conversions, a nonzero value that rounds to zero +// is a loss of the entire value, not just precision. +func handleFixedPointConversionError(err error) { + switch err.(type) { + case nil: + return + case fix.PositiveOverflowError: + panic(&OverflowError{}) + case fix.NegativeOverflowError: + panic(&UnderflowError{}) + case fix.UnderflowError: + panic(&UnderflowError{}) + default: + panic(err) + } +} diff --git a/interpreter/value_ufix64.go b/interpreter/value_ufix64.go index 46979bf77..a9908b6d6 100644 --- a/interpreter/value_ufix64.go +++ b/interpreter/value_ufix64.go @@ -26,6 +26,8 @@ import ( "github.com/onflow/atree" + fix "github.com/onflow/fixed-point" + "github.com/onflow/cadence/ast" "github.com/onflow/cadence/common" "github.com/onflow/cadence/errors" @@ -167,6 +169,42 @@ func ConvertUFix64(memoryGauge common.MemoryGauge, value Value) UFix64Value { } } +func ConvertUFix64WithRounding(memoryGauge common.MemoryGauge, value Value, roundingRule fix.RoundingMode) UFix64Value { + switch value := value.(type) { + case UFix128Value: + return NewUFix64Value( + memoryGauge, + func() uint64 { + result, err := fix.UFix128(value).ToUFix64(roundingRule) + if err != nil { + handleFixedPointConversionError(err) + } + return uint64(result) + }, + ) + + case Fix128Value: + return NewUFix64Value( + memoryGauge, + func() uint64 { + fix128 := fix.Fix128(value) + if fix128.IsNeg() { + panic(&UnderflowError{}) + } + // A non-negative Fix128 has the same bit representation as UFix128 + result, err := fix.UFix128(fix128).ToUFix64(roundingRule) + if err != nil { + handleFixedPointConversionError(err) + } + return uint64(result) + }, + ) + + default: + return ConvertUFix64(memoryGauge, value) + } +} + var _ Value = UFix64Value{} var _ atree.Storable = UFix64Value{} var _ NumberValue = UFix64Value{} diff --git a/runtime/account_test.go b/runtime/account_test.go index 3ad2141ca..0b00ee686 100644 --- a/runtime/account_test.go +++ b/runtime/account_test.go @@ -1223,6 +1223,7 @@ var AccountKeyType = ExportedBuiltinType(sema.AccountKeyType).(*cadence.StructTy var PublicKeyType = ExportedBuiltinType(sema.PublicKeyType).(*cadence.StructType) var SignAlgoType = ExportedBuiltinType(sema.SignatureAlgorithmType).(*cadence.EnumType) var HashAlgoType = ExportedBuiltinType(sema.HashAlgorithmType).(*cadence.EnumType) +var RoundingRuleEnumType = ExportedBuiltinType(sema.RoundingRuleType).(*cadence.EnumType) func ExportedBuiltinType(internalType sema.Type) cadence.Type { return ExportType(internalType, map[sema.TypeID]cadence.Type{}) diff --git a/runtime/convertValues.go b/runtime/convertValues.go index 7a5465b17..fa7ad08be 100644 --- a/runtime/convertValues.go +++ b/runtime/convertValues.go @@ -1582,6 +1582,9 @@ func (i valueImporter) importCompositeValue( // (e.g. it has host functions) return i.importSignatureAlgorithm(fields) + case sema.RoundingRuleType: + return i.importRoundingRule(fields) + default: return nil, errors.NewDefaultUserError( "cannot import value of type %s", @@ -1771,3 +1774,56 @@ func (valueImporter) importSignatureAlgorithm( return caseValue, nil } + +func (valueImporter) importRoundingRule( + fields []interpreter.CompositeField, +) ( + interpreter.MemberAccessibleValue, + error, +) { + + var foundRawValue bool + var rawValue interpreter.UInt8Value + + ty := sema.RoundingRuleType + + for _, field := range fields { + switch field.Name { + case sema.EnumRawValueFieldName: + rawValue, foundRawValue = field.Value.(interpreter.UInt8Value) + if !foundRawValue { + return nil, errors.NewDefaultUserError( + "cannot import value of type '%s'. invalid value for field '%s': %v", + ty, + field.Name, + field.Value, + ) + } + + default: + return nil, errors.NewDefaultUserError( + "cannot import value of type '%s'. invalid field '%s'", + ty, + field.Name, + ) + } + } + + if !foundRawValue { + return nil, errors.NewDefaultUserError( + "cannot import value of type '%s'. missing field '%s'", + ty, + sema.EnumRawValueFieldName, + ) + } + + caseValue, ok := stdlib.RoundingRuleCaseValues[rawValue] + if !ok { + return nil, errors.NewDefaultUserError( + "unknown RoundingRule with rawValue %d", + rawValue, + ) + } + + return caseValue, nil +} diff --git a/runtime/crypto_test.go b/runtime/crypto_test.go index c4db701b9..271c0b24f 100644 --- a/runtime/crypto_test.go +++ b/runtime/crypto_test.go @@ -185,7 +185,7 @@ func TestRuntimeHashingAlgorithmExport(t *testing.T) { runtimeInterface := &TestRuntimeInterface{} nextScriptLocation := NewScriptLocationGenerator() - testHashAlgorithm := func(algo sema.CryptoAlgorithm) { + testHashAlgorithm := func(algo sema.NativeEnumCase) { script := fmt.Sprintf(` access(all) fun main(): HashAlgorithm { return HashAlgorithm.%s @@ -231,7 +231,7 @@ func TestRuntimeSignatureAlgorithmExport(t *testing.T) { runtimeInterface := &TestRuntimeInterface{} nextScriptLocation := NewScriptLocationGenerator() - testSignatureAlgorithm := func(algo sema.CryptoAlgorithm) { + testSignatureAlgorithm := func(algo sema.NativeEnumCase) { script := fmt.Sprintf(` access(all) fun main(): SignatureAlgorithm { return SignatureAlgorithm.%s @@ -288,7 +288,7 @@ func TestRuntimeSignatureAlgorithmImport(t *testing.T) { nextScriptLocation := NewScriptLocationGenerator() - testSignatureAlgorithm := func(algo sema.CryptoAlgorithm) { + testSignatureAlgorithm := func(algo sema.NativeEnumCase) { value, err := runtime.ExecuteScript( Script{ @@ -344,7 +344,7 @@ func TestRuntimeHashAlgorithmImport(t *testing.T) { } ` - testHashAlgorithm := func(algo sema.CryptoAlgorithm) { + testHashAlgorithm := func(algo sema.NativeEnumCase) { var logs []string var hashCalls int diff --git a/runtime/program_params_validation_test.go b/runtime/program_params_validation_test.go index eb2216297..4e5d98f0d 100644 --- a/runtime/program_params_validation_test.go +++ b/runtime/program_params_validation_test.go @@ -474,6 +474,13 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) { }, ).WithType(SignAlgoType) + case sema.RoundingRuleType: + value = cadence.NewEnum( + []cadence.Value{ + cadence.NewUInt8(0), + }, + ).WithType(RoundingRuleEnumType) + case sema.PublicKeyType: value = cadence.NewStruct( []cadence.Value{ @@ -1066,6 +1073,13 @@ func TestRuntimeTransactionParameterTypeValidation(t *testing.T) { }, ).WithType(SignAlgoType) + case sema.RoundingRuleType: + value = cadence.NewEnum( + []cadence.Value{ + cadence.NewUInt8(0), + }, + ).WithType(RoundingRuleEnumType) + case sema.PublicKeyType: value = cadence.NewStruct( []cadence.Value{ diff --git a/runtime/rounding_rule_test.go b/runtime/rounding_rule_test.go new file mode 100644 index 000000000..c14d8876d --- /dev/null +++ b/runtime/rounding_rule_test.go @@ -0,0 +1,242 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package runtime_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/encoding/json" + . "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/sema" + . "github.com/onflow/cadence/test_utils/runtime_utils" +) + +func newRoundingRuleArgument(rawValue uint8) cadence.Value { + return cadence.NewEnum([]cadence.Value{ + cadence.UInt8(rawValue), + }).WithType(cadence.NewEnumType( + nil, + sema.RoundingRuleTypeName, + cadence.UInt8Type, + []cadence.Field{ + { + Identifier: sema.EnumRawValueFieldName, + Type: cadence.UInt8Type, + }, + }, + nil, + )) +} + +func TestRuntimeRoundingRuleExport(t *testing.T) { + + t.Parallel() + + runtime := NewTestRuntime() + runtimeInterface := &TestRuntimeInterface{} + nextScriptLocation := NewScriptLocationGenerator() + + testRoundingRule := func(rule sema.NativeEnumCase) { + script := fmt.Sprintf(` + access(all) fun main(): RoundingRule { + return RoundingRule.%s + } + `, + rule.Name(), + ) + + value, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: nextScriptLocation(), + UseVM: *compile, + }, + ) + + require.NoError(t, err) + + require.IsType(t, cadence.Enum{}, value) + enumValue := value.(cadence.Enum) + + fields := cadence.FieldsMappedByName(enumValue) + require.Len(t, fields, 1) + assert.Equal(t, + cadence.NewUInt8(rule.RawValue()), + fields[sema.EnumRawValueFieldName], + ) + } + + for _, rule := range sema.RoundingRules { + testRoundingRule(rule) + } +} + +func TestRuntimeRoundingRuleImport(t *testing.T) { + + t.Parallel() + + runtime := NewTestRuntime() + runtimeInterface := &TestRuntimeInterface{ + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + nextScriptLocation := NewScriptLocationGenerator() + + const script = ` + access(all) fun main(rule: RoundingRule): UInt8 { + return rule.rawValue + } + ` + + testRoundingRule := func(rule sema.NativeEnumCase) { + + value, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: encodeArgs([]cadence.Value{ + newRoundingRuleArgument(rule.RawValue()), + }), + }, + Context{ + Interface: runtimeInterface, + Location: nextScriptLocation(), + UseVM: *compile, + }, + ) + + require.NoError(t, err) + assert.Equal(t, cadence.UInt8(rule.RawValue()), value) + } + + for _, rule := range sema.RoundingRules { + testRoundingRule(rule) + } +} + +func TestRuntimeRoundingRuleImportInvalid(t *testing.T) { + + t.Parallel() + + runtime := NewTestRuntime() + runtimeInterface := &TestRuntimeInterface{ + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + nextScriptLocation := NewScriptLocationGenerator() + + const script = ` + access(all) fun main(rule: RoundingRule): UInt8 { + return rule.rawValue + } + ` + + _, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: encodeArgs([]cadence.Value{ + newRoundingRuleArgument(99), // invalid raw value + }), + }, + Context{ + Interface: runtimeInterface, + Location: nextScriptLocation(), + UseVM: *compile, + }, + ) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown RoundingRule") +} + +func TestRuntimeFix64ConversionWithRoundingRuleArgument(t *testing.T) { + + t.Parallel() + + runtime := NewTestRuntime() + runtimeInterface := &TestRuntimeInterface{ + OnDecodeArgument: func(b []byte, t cadence.Type) (cadence.Value, error) { + return json.Decode(nil, b) + }, + } + nextScriptLocation := NewScriptLocationGenerator() + + t.Run("Fix128 to Fix64 with rounding", func(t *testing.T) { + t.Parallel() + + const script = ` + access(all) fun main(rule: RoundingRule): Fix64 { + let x: Fix128 = 1.000000005000000000000000 + return Fix64(x, rounding: rule) + } + ` + + // towardZero should truncate + value, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: encodeArgs([]cadence.Value{newRoundingRuleArgument(0)}), + }, + Context{ + Interface: runtimeInterface, + Location: nextScriptLocation(), + UseVM: *compile, + }, + ) + + require.NoError(t, err) + assert.Equal(t, cadence.Fix64(100000000), value) // 1.00000000 + }) + + t.Run("UFix128 to UFix64 with rounding", func(t *testing.T) { + t.Parallel() + + const script = ` + access(all) fun main(rule: RoundingRule): UFix64 { + let x: UFix128 = 1.000000005000000000000000 + return UFix64(x, rounding: rule) + } + ` + + // awayFromZero should round up + value, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: encodeArgs([]cadence.Value{newRoundingRuleArgument(1)}), + }, + Context{ + Interface: runtimeInterface, + Location: nextScriptLocation(), + UseVM: *compile, + }, + ) + + require.NoError(t, err) + assert.Equal(t, cadence.UFix64(100000001), value) // 1.00000001 + }) +} diff --git a/sema/crypto_test.go b/sema/crypto_test.go index 6c19362dd..43be41ef1 100644 --- a/sema/crypto_test.go +++ b/sema/crypto_test.go @@ -39,7 +39,7 @@ func TestCheckHashAlgorithmCases(t *testing.T) { baseValueActivation.DeclareValue(value) } - test := func(algorithm sema.CryptoAlgorithm) { + test := func(algorithm sema.NativeEnumCase) { _, err := ParseAndCheckWithOptions(t, fmt.Sprintf( @@ -120,7 +120,7 @@ func TestCheckSignatureAlgorithmCases(t *testing.T) { baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) baseValueActivation.DeclareValue(stdlib.InterpreterSignatureAlgorithmConstructor) - test := func(algorithm sema.CryptoAlgorithm) { + test := func(algorithm sema.NativeEnumCase) { _, err := ParseAndCheckWithOptions(t, fmt.Sprintf( diff --git a/sema/rounding_rule_test.go b/sema/rounding_rule_test.go new file mode 100644 index 000000000..40c2fc769 --- /dev/null +++ b/sema/rounding_rule_test.go @@ -0,0 +1,230 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/common" + "github.com/onflow/cadence/sema" + "github.com/onflow/cadence/stdlib" + . "github.com/onflow/cadence/test_utils/sema_utils" +) + +func TestCheckRoundingRuleCases(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + for _, value := range stdlib.InterpreterDefaultScriptStandardLibraryValues(nil) { + baseValueActivation.DeclareValue(value) + } + + test := func(rule sema.NativeEnumCase) { + + _, err := ParseAndCheckWithOptions(t, + fmt.Sprintf( + ` + let rule: RoundingRule = RoundingRule.%s + `, + rule.Name(), + ), + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) + } + + for _, rule := range sema.RoundingRules { + test(rule) + } +} + +func TestCheckRoundingRuleConstructor(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + baseValueActivation.DeclareValue(stdlib.InterpreterRoundingRuleConstructor) + + _, err := ParseAndCheckWithOptions(t, + ` + let rule = RoundingRule(rawValue: 0) + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) +} + +func TestCheckRoundingRuleRawValue(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + for _, value := range stdlib.InterpreterDefaultScriptStandardLibraryValues(nil) { + baseValueActivation.DeclareValue(value) + } + + _, err := ParseAndCheckWithOptions(t, + ` + let rule = RoundingRule.towardZero + let rawValue: UInt8 = rule.rawValue + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) +} + +func TestCheckFix64WithRoundingRule(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + for _, value := range stdlib.InterpreterDefaultScriptStandardLibraryValues(nil) { + baseValueActivation.DeclareValue(value) + } + + t.Run("with rounding", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + let x: Fix64 = Fix64(1, rounding: RoundingRule.towardZero) + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) + }) + + t.Run("without rounding", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + let x: Fix64 = Fix64(1) + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) + }) + + t.Run("invalid rounding type", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + let x: Fix64 = Fix64(1, rounding: 42) + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.Error(t, err) + }) +} + +func TestCheckUFix64WithRoundingRule(t *testing.T) { + + t.Parallel() + + baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) + for _, value := range stdlib.InterpreterDefaultScriptStandardLibraryValues(nil) { + baseValueActivation.DeclareValue(value) + } + + t.Run("with rounding", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + let x: UFix64 = UFix64(1, rounding: RoundingRule.nearestHalfEven) + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) + }) + + t.Run("without rounding", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheckWithOptions(t, + ` + let x: UFix64 = UFix64(1) + `, + ParseAndCheckOptions{ + CheckerConfig: &sema.Config{ + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + + require.NoError(t, err) + }) +} diff --git a/sema/rounding_rule_type.go b/sema/rounding_rule_type.go new file mode 100644 index 000000000..b5cb88fae --- /dev/null +++ b/sema/rounding_rule_type.go @@ -0,0 +1,143 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +import "github.com/onflow/cadence/errors" + +const RoundingRuleTypeName = "RoundingRule" + +var RoundingRuleType = newNativeEnumType( + RoundingRuleTypeName, + UInt8Type, + nil, +) + +var RoundingRuleTypeAnnotation = NewTypeAnnotation(RoundingRuleType) + +type RoundingRule uint8 + +// NOTE: only add new rules, do *NOT* change existing items, +// reuse raw values for other items, swap the order, etc. +// +// # Existing stored values use these raw values and should not change +// +// IMPORTANT: update RoundingRules +const ( + RoundingRuleTowardZero RoundingRule = iota + RoundingRuleAwayFromZero + RoundingRuleNearestHalfAway + RoundingRuleNearestHalfEven + + // !!! *WARNING* !!! + // ADD NEW RULES *BEFORE* THIS WARNING. + // DO *NOT* ADD NEW RULES AFTER THIS LINE! + RoundingRule_Count +) + +var RoundingRules = []RoundingRule{ + RoundingRuleTowardZero, + RoundingRuleAwayFromZero, + RoundingRuleNearestHalfAway, + RoundingRuleNearestHalfEven, +} + +func (rule RoundingRule) Name() string { + switch rule { + case RoundingRuleTowardZero: + return "towardZero" + case RoundingRuleAwayFromZero: + return "awayFromZero" + case RoundingRuleNearestHalfAway: + return "nearestHalfAway" + case RoundingRuleNearestHalfEven: + return "nearestHalfEven" + } + + panic(errors.NewUnreachableError()) +} + +func (rule RoundingRule) RawValue() uint8 { + switch rule { + case RoundingRuleTowardZero: + return 0 + case RoundingRuleAwayFromZero: + return 1 + case RoundingRuleNearestHalfAway: + return 2 + case RoundingRuleNearestHalfEven: + return 3 + } + + panic(errors.NewUnreachableError()) +} + +func (rule RoundingRule) DocString() string { + switch rule { + case RoundingRuleTowardZero: + return RoundingRuleTowardZeroDocString + case RoundingRuleAwayFromZero: + return RoundingRuleAwayFromZeroDocString + case RoundingRuleNearestHalfAway: + return RoundingRuleNearestHalfAwayDocString + case RoundingRuleNearestHalfEven: + return RoundingRuleNearestHalfEvenDocString + } + + panic(errors.NewUnreachableError()) +} + +const RoundingRuleTowardZeroDocString = ` +Round to the closest representable fixed-point value that has +a magnitude less than or equal to the magnitude of the real result, +effectively truncating the fractional part. + +e.g. 5e-8 / 2 = 2e-8, -5e-8 / 2 = -2e-8 +` + +const RoundingRuleAwayFromZeroDocString = ` +Round to the closest representable fixed-point value that has +a magnitude greater than or equal to the magnitude of the real result, +effectively rounding up any fractional part. + +e.g. 5e-8 / 2 = 3e-8, -5e-8 / 2 = -3e-8 +` + +const RoundingRuleNearestHalfAwayDocString = ` +Round to the closest representable fixed-point value to the real result, +which could be larger (rounded up) or smaller (rounded down) depending on +if the unrepresentable portion is greater than or less than one half +the difference between two available values. + +If two representable values are equally close, +the value will be rounded away from zero. + +e.g. 7e-8 / 2 = 4e-8, 5e-8 / 2 = 3e-8 +` + +const RoundingRuleNearestHalfEvenDocString = ` +Round to the closest representable fixed-point value to the real result, +which could be larger (rounded up) or smaller (rounded down) depending on +if the unrepresentable portion is greater than or less than one half +the difference between two available values. + +If two representable values are equally close, +the value with an even digit in the smallest decimal place will be chosen. + +e.g. 7e-8 / 2 = 4e-8, 5e-8 / 2 = 2e-8 +` diff --git a/sema/rounding_rule_type_test.go b/sema/rounding_rule_type_test.go new file mode 100644 index 000000000..408a75c3e --- /dev/null +++ b/sema/rounding_rule_type_test.go @@ -0,0 +1,56 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRoundingRuleValues(t *testing.T) { + t.Parallel() + + // Ensure that the values of the RoundingRule enum are not accidentally changed, + // e.g. by adding a new value in between or by changing an existing value. + + expectedValues := map[RoundingRule]uint8{ + RoundingRuleTowardZero: 0, + RoundingRuleAwayFromZero: 1, + RoundingRuleNearestHalfAway: 2, + RoundingRuleNearestHalfEven: 3, + RoundingRule_Count: 4, + } + + // Check all expected values. + for rule, expectedValue := range expectedValues { + require.Equal(t, expectedValue, uint8(rule), "value mismatch for %d", rule) + } + + // Check that no new values have been added + // without updating the expected values above. + for i := uint8(0); i < uint8(RoundingRule_Count); i++ { + rule := RoundingRule(i) + _, ok := expectedValues[rule] + require.True(t, ok, + fmt.Sprintf("unexpected RoundingRule value %d: update expectedValues", i), + ) + } +} diff --git a/sema/type.go b/sema/type.go index 45bf8a58a..e037dc0ca 100644 --- a/sema/type.go +++ b/sema/type.go @@ -4570,6 +4570,7 @@ var AllBuiltinTypes = common.Concat( PublicKeyType, SignatureAlgorithmType, HashAlgorithmType, + RoundingRuleType, StorageCapabilityControllerType, AccountCapabilityControllerType, DeploymentResultType, @@ -4752,7 +4753,13 @@ func init() { panic(errors.NewUnreachableError()) } - functionType := NumberConversionFunctionType(numberType) + var functionType *FunctionType + switch numberType { + case Fix64Type, UFix64Type: + functionType = FixedPoint64ConversionFunctionType(numberType) + default: + functionType = NumberConversionFunctionType(numberType) + } addMember := func(member *Member) { if functionType.Members == nil { @@ -4865,6 +4872,28 @@ func NumberConversionFunctionType(numberType Type) *FunctionType { } } +func FixedPoint64ConversionFunctionType(numberType Type) *FunctionType { + return &FunctionType{ + Purity: FunctionPurityView, + TypeFunctionType: numberType, + Parameters: []Parameter{ + { + Label: ArgumentLabelNotRequired, + Identifier: "value", + TypeAnnotation: NumberTypeAnnotation, + }, + { + Label: "rounding", + Identifier: "rounding", + TypeAnnotation: RoundingRuleTypeAnnotation, + }, + }, + Arity: &Arity{Min: 1, Max: 2}, + ReturnTypeAnnotation: NewTypeAnnotation(numberType), + ArgumentExpressionsCheck: numberFunctionArgumentExpressionsChecker(numberType), + } +} + func numberConversionDocString(targetDescription string) string { return fmt.Sprintf( "Converts the given number to %s. %s", @@ -9671,7 +9700,7 @@ var PublicKeyTypeVerifyPoPFunctionType = NewSimpleFunctionType( BoolTypeAnnotation, ) -type CryptoAlgorithm interface { +type NativeEnumCase interface { RawValue() uint8 Name() string DocString() string @@ -10176,6 +10205,7 @@ func init() { PublicKeyType, HashAlgorithmType, SignatureAlgorithmType, + RoundingRuleType, AccountType, DeploymentResultType, } diff --git a/stdlib/builtin.go b/stdlib/builtin.go index 76d39f9ef..95ed6fb3e 100644 --- a/stdlib/builtin.go +++ b/stdlib/builtin.go @@ -57,6 +57,7 @@ func InterpreterDefaultStandardLibraryValues(handler StandardLibraryHandler) []S InterpreterAssertFunction, InterpreterPanicFunction, InterpreterSignatureAlgorithmConstructor, + InterpreterRoundingRuleConstructor, InterpreterInclusiveRangeConstructor, NewInterpreterLogFunction(handler), NewInterpreterRevertibleRandomFunction(handler), @@ -76,6 +77,7 @@ func VMDefaultStandardLibraryValues(handler StandardLibraryHandler) []StandardLi VMAssertFunction, VMPanicFunction, VMSignatureAlgorithmConstructor, + VMRoundingRuleConstructor, VMInclusiveRangeConstructor, NewVMLogFunction(handler), NewVMRevertibleRandomFunction(handler), @@ -155,6 +157,7 @@ func VMValues(handler StandardLibraryHandler) []VMValue { return common.Concat( VMSignatureAlgorithmCaseValues, NewVMHashAlgorithmCaseValues(handler), + VMRoundingRuleCaseValues, ) } diff --git a/stdlib/crypto.go b/stdlib/crypto.go index 8f1da9277..9bfbf4c59 100644 --- a/stdlib/crypto.go +++ b/stdlib/crypto.go @@ -20,69 +20,6 @@ package stdlib import ( "github.com/onflow/cadence/common" - "github.com/onflow/cadence/interpreter" - "github.com/onflow/cadence/sema" ) const CryptoContractLocation = common.IdentifierLocation("Crypto") - -func cryptoAlgorithmEnumLookupType[T sema.CryptoAlgorithm]( - enumType *sema.CompositeType, - enumCases []T, -) *sema.FunctionType { - - functionType := sema.EnumLookupFunctionType(enumType) - - for _, algo := range enumCases { - name := algo.Name() - functionType.Members.Set( - name, - sema.NewUnmeteredPublicConstantFieldMember( - enumType, - name, - enumType, - algo.DocString(), - ), - ) - } - - return functionType -} - -type enumCaseConstructor func(rawValue interpreter.UInt8Value) interpreter.MemberAccessibleValue - -func interpreterCryptoAlgorithmEnumValueAndCaseValues[T sema.CryptoAlgorithm]( - functionType *sema.FunctionType, - enumCases []T, - caseConstructor enumCaseConstructor, -) ( - functionValue interpreter.FunctionValue, - cases map[interpreter.UInt8Value]interpreter.MemberAccessibleValue, -) { - - caseCount := len(enumCases) - caseValues := make([]interpreter.EnumCase, caseCount) - constructorNestedVariables := make(map[string]interpreter.Variable, caseCount) - cases = make(map[interpreter.UInt8Value]interpreter.MemberAccessibleValue, caseCount) - - for i, enumCase := range enumCases { - rawValue := interpreter.UInt8Value(enumCase.RawValue()) - caseValue := caseConstructor(rawValue) - cases[rawValue] = caseValue - caseValues[i] = interpreter.EnumCase{ - Value: caseValue, - RawValue: rawValue, - } - constructorNestedVariables[enumCase.Name()] = - interpreter.NewVariableWithValue(nil, caseValue) - } - - functionValue = interpreter.EnumLookupFunction( - nil, - functionType, - caseValues, - constructorNestedVariables, - ) - - return -} diff --git a/stdlib/enum.go b/stdlib/enum.go new file mode 100644 index 000000000..0b18ffe7c --- /dev/null +++ b/stdlib/enum.go @@ -0,0 +1,85 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdlib + +import ( + "github.com/onflow/cadence/interpreter" + "github.com/onflow/cadence/sema" +) + +func nativeEnumLookupType[T sema.NativeEnumCase]( + enumType *sema.CompositeType, + enumCases []T, +) *sema.FunctionType { + + functionType := sema.EnumLookupFunctionType(enumType) + + for _, algo := range enumCases { + name := algo.Name() + functionType.Members.Set( + name, + sema.NewUnmeteredPublicConstantFieldMember( + enumType, + name, + enumType, + algo.DocString(), + ), + ) + } + + return functionType +} + +type enumCaseConstructor func(rawValue interpreter.UInt8Value) interpreter.MemberAccessibleValue + +func interpreterNativeEnumValueAndCaseValues[T sema.NativeEnumCase]( + functionType *sema.FunctionType, + enumCases []T, + caseConstructor enumCaseConstructor, +) ( + functionValue interpreter.FunctionValue, + cases map[interpreter.UInt8Value]interpreter.MemberAccessibleValue, +) { + + caseCount := len(enumCases) + caseValues := make([]interpreter.EnumCase, caseCount) + constructorNestedVariables := make(map[string]interpreter.Variable, caseCount) + cases = make(map[interpreter.UInt8Value]interpreter.MemberAccessibleValue, caseCount) + + for i, enumCase := range enumCases { + rawValue := interpreter.UInt8Value(enumCase.RawValue()) + caseValue := caseConstructor(rawValue) + cases[rawValue] = caseValue + caseValues[i] = interpreter.EnumCase{ + Value: caseValue, + RawValue: rawValue, + } + constructorNestedVariables[enumCase.Name()] = + interpreter.NewVariableWithValue(nil, caseValue) + } + + functionValue = interpreter.EnumLookupFunction( + nil, + functionType, + caseValues, + constructorNestedVariables, + ) + + return +} diff --git a/stdlib/hashalgorithm.go b/stdlib/hashalgorithm.go index f7878bb6d..be6aaabda 100644 --- a/stdlib/hashalgorithm.go +++ b/stdlib/hashalgorithm.go @@ -27,7 +27,7 @@ import ( "github.com/onflow/cadence/sema" ) -var hashAlgorithmLookupType = cryptoAlgorithmEnumLookupType( +var hashAlgorithmLookupType = nativeEnumLookupType( sema.HashAlgorithmType, sema.HashAlgorithms, ) @@ -237,7 +237,7 @@ func hash( // these functions are left as is, since there are differences in the implementations between interpreter and vm func NewInterpreterHashAlgorithmConstructor(hasher Hasher) StandardLibraryValue { - interpreterHashAlgorithmConstructorValue, _ := interpreterCryptoAlgorithmEnumValueAndCaseValues( + interpreterHashAlgorithmConstructorValue, _ := interpreterNativeEnumValueAndCaseValues( hashAlgorithmLookupType, sema.HashAlgorithms, func(rawValue interpreter.UInt8Value) interpreter.MemberAccessibleValue { diff --git a/stdlib/roundingrule.go b/stdlib/roundingrule.go new file mode 100644 index 000000000..dd2ee2585 --- /dev/null +++ b/stdlib/roundingrule.go @@ -0,0 +1,112 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdlib + +import ( + "github.com/onflow/cadence/bbq/commons" + "github.com/onflow/cadence/bbq/vm" + "github.com/onflow/cadence/common" + "github.com/onflow/cadence/interpreter" + "github.com/onflow/cadence/sema" +) + +var roundingRuleStaticType interpreter.StaticType = interpreter.ConvertSemaCompositeTypeToStaticCompositeType( + nil, + sema.RoundingRuleType, +) + +func NewRoundingRuleCase(rawValue interpreter.UInt8Value) interpreter.MemberAccessibleValue { + + fields := map[string]interpreter.Value{ + sema.EnumRawValueFieldName: rawValue, + } + + return interpreter.NewSimpleCompositeValue( + nil, + sema.RoundingRuleType.ID(), + roundingRuleStaticType, + []string{sema.EnumRawValueFieldName}, + fields, + nil, + nil, + nil, + nil, + ) +} + +var roundingRuleLookupType = nativeEnumLookupType( + sema.RoundingRuleType, + sema.RoundingRules, +) + +var interpreterRoundingRuleConstructorValue, RoundingRuleCaseValues = interpreterNativeEnumValueAndCaseValues( + roundingRuleLookupType, + sema.RoundingRules, + NewRoundingRuleCase, +) + +var InterpreterRoundingRuleConstructor = StandardLibraryValue{ + Name: sema.RoundingRuleTypeName, + Type: roundingRuleLookupType, + Value: interpreterRoundingRuleConstructorValue, + Kind: common.DeclarationKindEnum, +} + +var vmRoundingRuleConstructorValue = vm.NewNativeFunctionValue( + sema.RoundingRuleTypeName, + roundingRuleLookupType, + func( + context interpreter.NativeFunctionContext, + _ interpreter.TypeArgumentsIterator, + _ interpreter.ArgumentTypesIterator, + _ interpreter.Value, + args []interpreter.Value, + ) interpreter.Value { + rawValue := args[0].(interpreter.UInt8Value) + + caseValue, ok := RoundingRuleCaseValues[rawValue] + if !ok { + return interpreter.Nil + } + + return interpreter.NewSomeValueNonCopying(context, caseValue) + }, +) + +var VMRoundingRuleConstructor = StandardLibraryValue{ + Name: sema.RoundingRuleTypeName, + Type: roundingRuleLookupType, + Value: vmRoundingRuleConstructorValue, + Kind: common.DeclarationKindEnum, +} + +var VMRoundingRuleCaseValues = func() []VMValue { + values := make([]VMValue, len(sema.RoundingRules)) + for i, roundingRule := range sema.RoundingRules { + rawValue := interpreter.UInt8Value(roundingRule.RawValue()) + values[i] = VMValue{ + Name: commons.TypeQualifiedName( + sema.RoundingRuleType, + roundingRule.Name(), + ), + Value: RoundingRuleCaseValues[rawValue], + } + } + return values +}() diff --git a/stdlib/signaturealgorithm.go b/stdlib/signaturealgorithm.go index 201afe59b..e6ffa9752 100644 --- a/stdlib/signaturealgorithm.go +++ b/stdlib/signaturealgorithm.go @@ -50,12 +50,12 @@ func NewSignatureAlgorithmCase(rawValue interpreter.UInt8Value) interpreter.Memb ) } -var signatureAlgorithmLookupType = cryptoAlgorithmEnumLookupType( +var signatureAlgorithmLookupType = nativeEnumLookupType( sema.SignatureAlgorithmType, sema.SignatureAlgorithms, ) -var interpreterSignatureAlgorithmConstructorValue, SignatureAlgorithmCaseValues = interpreterCryptoAlgorithmEnumValueAndCaseValues( +var interpreterSignatureAlgorithmConstructorValue, SignatureAlgorithmCaseValues = interpreterNativeEnumValueAndCaseValues( signatureAlgorithmLookupType, sema.SignatureAlgorithms, NewSignatureAlgorithmCase, diff --git a/test_utils/test_utils.go b/test_utils/test_utils.go index 7c0f0bec5..36da821d4 100644 --- a/test_utils/test_utils.go +++ b/test_utils/test_utils.go @@ -20,6 +20,7 @@ package test_utils import ( "fmt" + "slices" "strings" "testing" @@ -319,6 +320,29 @@ func ParseCheckAndPrepareWithOptions( // (i.e: only get the values that were added externally for tests) interpreterBaseActivationVariables := interpreterBaseActivation.ValuesInCurrentLevel() + // Collect nested variables (e.g. enum case values like "RoundingRule.towardZero") + // from HostFunctionValues, so they can be registered in both the VM and compiler activations. + type nestedVariableEntry struct { + qualifiedName string + value interpreter.Value + } + var nestedEntries []nestedVariableEntry + + for name, variable := range interpreterBaseActivationVariables { //nolint:maprange + value := variable.GetValue(nil) + if functionValue, ok := value.(*interpreter.HostFunctionValue); ok { + for nestedName, nestedVar := range functionValue.NestedVariables { //nolint:maprange + nestedEntries = append(nestedEntries, nestedVariableEntry{ + qualifiedName: name + "." + nestedName, + value: nestedVar.GetValue(nil), + }) + } + } + } + slices.SortFunc(nestedEntries, func(a, b nestedVariableEntry) int { + return strings.Compare(a.qualifiedName, b.qualifiedName) + }) + vmConfig.BuiltinGlobalsProvider = func(_ common.Location) *activations.Activation[vm.Variable] { activation := activations.NewActivation(nil, vm.DefaultBuiltinGlobals()) @@ -336,7 +360,7 @@ func ParseCheckAndPrepareWithOptions( value := variable.GetValue(nil) if functionValue, ok := value.(*interpreter.HostFunctionValue); ok { - value = vm.NewNativeFunctionValue( + nativeFn := vm.NewNativeFunctionValue( name, functionValue.Type, func( @@ -371,6 +395,13 @@ func ParseCheckAndPrepareWithOptions( }, ) + // Transfer nested variables (e.g. enum case values) + // from the interpreter's HostFunctionValue to the VM's NativeFunctionValue. + for nestedName, nestedVar := range functionValue.NestedVariables { //nolint:maprange + nativeFn.SetField(nestedName, nestedVar.GetValue(nil)) + } + + value = nativeFn } vmVariable := interpreter.NewVariableWithValue( @@ -381,6 +412,14 @@ func ParseCheckAndPrepareWithOptions( activation.Set(name, vmVariable) } + // Register nested variables as separate qualified globals. + for _, entry := range nestedEntries { + activation.Set( + entry.qualifiedName, + interpreter.NewVariableWithValue(nil, entry.value), + ) + } + return activation } @@ -399,6 +438,15 @@ func ParseCheckAndPrepareWithOptions( compiler.NewGlobalImport(name), ) } + + // Register nested variables as separate qualified compiler globals. + for _, entry := range nestedEntries { + activation.Set( + entry.qualifiedName, + compiler.NewGlobalImport(entry.qualifiedName), + ) + } + return activation }, }