diff --git a/codegen/masm/intrinsics/mem.masm b/codegen/masm/intrinsics/mem.masm index 1afe92e2f..364ed5d1d 100644 --- a/codegen/masm/intrinsics/mem.masm +++ b/codegen/masm/intrinsics/mem.masm @@ -167,26 +167,44 @@ pub proc load_sw # [addr, offset] # load the element containing the data we want mem_load else # [addr, offset] + # convert the byte offset to a bit offset + swap.1 push.8 u32wrapping_mul swap.1 # [addr, bit_offset] # the load crosses an element boundary # # 1. load the first element - dup.0 mem_load # [e0, addr, offset] + dup.0 mem_load # [e0, addr, bit_offset] # 2. load the second element - swap.1 # [addr, e0, offset] - push.1 u32overflowing_add # [overflowed, addr + 1, e0, offset] - assertz mem_load # [e1, e0, offset] - # shift low bits - push.32 dup.3 # [offset, 32, e1, e0, offset] - u32overflowing_sub assertz # [32 - offset, e1, e0, offset] - u32shr # [lo, e0, offset] - # shift high bits left by the offset - swap.2 # [offset, e0, lo] - u32shl # [hi, lo] + swap.1 # [addr, e0, bit_offset] + push.1 u32overflowing_add # [overflowed, addr + 1, e0, bit_offset] + assertz mem_load # [e1, e0, bit_offset] + # Reconstruct the 32-bit window whose first byte begins at the original byte pointer. + # `e0` contributes the low part after shifting right, and `e1` contributes the carried + # high part after shifting left into the vacated bits. + swap.1 # [e0, e1, bit_offset] + dup.2 # [bit_offset, e0, e1, bit_offset] + u32shr # [lo, e1, bit_offset] + movup.2 # [bit_offset, lo, e1] + push.32 swap.1 # [bit_offset, 32, lo, e1] + u32overflowing_sub assertz # [32 - bit_offset, lo, e1] + movup.2 swap.1 # [32 - bit_offset, e1, lo] + u32shl # [hi, lo] # combine the two halves u32or # [result] end end +# Load a 16-bit integer from the given native pointer tuple. +# +# A native pointer tuple consists of an element address where the data begins, and a byte offset, +# which is the offset of the first byte, in the 32-bit representation of that element. 
+# +# Stack transition: [addr, offset] -> [value] +pub proc load_u16(addr: ptr, offset: u8) -> u16 + exec.load_sw + push.65535 + u32and +end + # This handles emitting code that handles aligning an unaligned 64-bit value which is split across # three elements. # @@ -436,6 +454,26 @@ pub proc store_sw # [addr, offset, value] end end +# Store a 16-bit integer to the given native pointer tuple. +# +# A native pointer tuple consists of an element address where the data begins, and a byte offset, +# which is the offset of the first byte, in the 32-bit representation of that element. +# +# Stack transition: [addr, offset, value] -> [] +pub proc store_u16(addr: ptr, offset: u8, value: u16) + # Load the current 32-bit window at the destination, keep its upper half, then overwrite the + # target two bytes before delegating the write-back to `store_sw`. + dup.1 dup.1 exec.load_sw # [window, addr, offset, value] + push.4294901760 # 0xffff0000 + u32and # [masked_window, addr, offset, value] + movup.3 # [value, masked_window, addr, offset] + push.65535 + u32and # [value16, masked_window, addr, offset] + u32or # [combined, addr, offset] + swap.2 swap.1 # [addr, offset, combined] + exec.store_sw +end + # Store two 32-bit words to the given native pointer tuple. # # A native pointer tuple consists of an element address where the data begins, and a byte offset, diff --git a/codegen/masm/src/emit/felt.rs b/codegen/masm/src/emit/felt.rs index abc43d77c..6578cf5b2 100644 --- a/codegen/masm/src/emit/felt.rs +++ b/codegen/masm/src/emit/felt.rs @@ -32,7 +32,13 @@ impl OpEmitter<'_> { /// `[a, ..] => [a, ..]` #[inline(always)] pub fn assert_felt_is_zero(&mut self, span: SourceSpan) { - self.emit_all([masm::Instruction::Dup0, masm::Instruction::Assertz], span); + self.emit_all( + [ + masm::Instruction::Dup0, + Self::assertz_with_message_inst("expected felt value to be zero", span), + ], + span, + ); } /// Convert a field element to i128 by zero-extension. 
@@ -85,7 +91,7 @@ impl OpEmitter<'_> { // Split into u32 limbs masm::Instruction::U32Split, // Assert most significant 32 bits are unused - masm::Instruction::Assertz, + Self::assertz_with_message_inst("felt value does not fit in 32 bits", span), ], span, ); @@ -105,7 +111,7 @@ impl OpEmitter<'_> { // Split into u32 limbs masm::Instruction::U32Split, // Assert most significant 32 bits are unused - masm::Instruction::Assertz, + Self::assertz_with_message_inst("felt value does not fit in 32 bits", span), ], span, ); diff --git a/codegen/masm/src/emit/int128.rs b/codegen/masm/src/emit/int128.rs index 2b2db8a50..22bc850dc 100644 --- a/codegen/masm/src/emit/int128.rs +++ b/codegen/masm/src/emit/int128.rs @@ -78,7 +78,11 @@ impl OpEmitter<'_> { // // What remains on the stack at this point are the low 64-bits, // which is also our result. - self.emit_n(2, masm::Instruction::Assertz, span); + self.emit_n( + 2, + Self::assertz_with_message_inst("128-bit value does not fit in u64", span), + span, + ); } /// Convert a 128-bit value to u32 @@ -95,7 +99,11 @@ impl OpEmitter<'_> { // // What remains on the stack at this point are the low 32-bits, // which is also our result. - self.emit_n(3, masm::Instruction::Assertz, span); + self.emit_n( + 3, + Self::assertz_with_message_inst("128-bit value does not fit in u32", span), + span, + ); } /// Convert a unsigned 128-bit value to i64 @@ -139,7 +147,7 @@ impl OpEmitter<'_> { [ // Assert that both 32-bit limbs of the most significant 64 bits match, // consuming them in the process - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst("128-bit value does not fit in i64", span), // At this point, the stack is: [is_signed, x1, x0] // // Select an expected value for the sign bit based on the is_signed flag @@ -158,7 +166,7 @@ impl OpEmitter<'_> { // any other combination will trap. 
// // [x1, x0] - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst("128-bit value does not fit in i64", span), ], span, ); diff --git a/codegen/masm/src/emit/int32.rs b/codegen/masm/src/emit/int32.rs index 56afce7a8..71dcbd7c7 100644 --- a/codegen/masm/src/emit/int32.rs +++ b/codegen/masm/src/emit/int32.rs @@ -107,7 +107,7 @@ impl OpEmitter<'_> { #[inline] pub fn assert_signed_int32(&mut self, span: SourceSpan) { self.is_signed_int32(span); - self.emit(masm::Instruction::Assert, span); + self.emit(Self::assert_with_message_inst("expected a signed i32 value", span), span); } /// Emits code to assert that a 32-bit value on the operand stack does not have the i32 sign bit @@ -119,7 +119,7 @@ impl OpEmitter<'_> { #[inline] pub fn assert_unsigned_int32(&mut self, span: SourceSpan) { self.is_signed_int32(span); - self.emit(masm::Instruction::Assertz, span); + self.emit(Self::assertz_with_message_inst("expected a non-negative i32 value", span), span); } /// Assert that the 32-bit value on the stack is a valid i32 value @@ -131,7 +131,7 @@ impl OpEmitter<'_> { // the value is <= i32::MIN, which is 1 more than i32::MAX. self.push_i32(i32::MIN, span); self.emit(masm::Instruction::U32Lte, span); - self.emit(masm::Instruction::Assert, span); + self.emit(Self::assert_with_message_inst("value does not fit in i32", span), span); } /// Emits code to assert that a 32-bit value on the operand stack is equal to the given constant @@ -148,7 +148,10 @@ impl OpEmitter<'_> { [ masm::Instruction::Dup0, masm::Instruction::EqImm(Felt::new(value as u64).into()), - masm::Instruction::Assert, + Self::assert_with_message_inst( + format!("expected u32 value to equal {value}"), + span, + ), ], span, ); @@ -164,7 +167,13 @@ impl OpEmitter<'_> { /// `[expected, input, ..] 
=> [input, ..]` #[inline] pub fn assert_eq_u32(&mut self, span: SourceSpan) { - self.emit_all([masm::Instruction::Dup1, masm::Instruction::AssertEq], span); + self.emit_all( + [ + masm::Instruction::Dup1, + Self::assert_eq_with_message_inst("expected u32 values to be equal", span), + ], + span, + ); } /// Emits code to select a constant u32 value, using the `n`th value on the operand @@ -244,7 +253,10 @@ impl OpEmitter<'_> { // Apply the mask masm::Instruction::U32And, // Assert that the masked bits and the mask are equal - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst( + format!("value does not fit in signed {n}-bit range"), + span, + ), ], span, ); @@ -293,7 +305,13 @@ impl OpEmitter<'_> { self.emit_push(mask, span); self.emit(masm::Instruction::U32And, span); // Assert the masked value is all 0s - self.emit(masm::Instruction::Assertz, span); + self.emit( + Self::assertz_with_message_inst( + format!("value does not fit in unsigned {n}-bit range"), + span, + ), + span, + ); } /// Convert an i32/u32 value on the stack to an unsigned N-bit integer value diff --git a/codegen/masm/src/emit/int64.rs b/codegen/masm/src/emit/int64.rs index 53c0bd935..668e3b178 100644 --- a/codegen/masm/src/emit/int64.rs +++ b/codegen/masm/src/emit/int64.rs @@ -14,7 +14,7 @@ impl OpEmitter<'_> { // Assert that value is <= P, then unsplit the limbs to get a felt self.push_u64(P, span); self.lt_u64(span); - self.emit(masm::Instruction::Assert, span); + self.emit(Self::assert_with_message_inst("u64 value does not fit in felt", span), span); // `u32unsplit` expects `[hi, lo]` on the stack; u64 values are represented as `[lo, hi]`. self.emit(masm::Instruction::Swap1, span); self.u32unsplit(span); @@ -41,14 +41,27 @@ impl OpEmitter<'_> { // Bring `hi` to the top of the stack and assert it is zero. This consumes `hi`, // leaving only `lo` on the stack. 
masm::Instruction::Swap1, - masm::Instruction::Assertz, + // Assert hi bits are zero + Self::assertz_with_message_inst( + format!("u64 value does not fit in unsigned {n}-bit range"), + span, + ), // Check that the remaining bits fit in range masm::Instruction::Dup0, ], span, ); self.emit_push(Felt::new(2u64.pow(n) - 1), span); - self.emit_all([masm::Instruction::U32Lte, masm::Instruction::Assert], span); + self.emit_all( + [ + masm::Instruction::U32Lte, + Self::assert_with_message_inst( + format!("u64 value does not fit in unsigned {n}-bit range"), + span, + ), + ], + span, + ); } /// Convert an i64 value to a signed N-bit integer, where N <= 32 @@ -75,7 +88,10 @@ impl OpEmitter<'_> { self.emit_all( [ // [is_unsigned, x_lo] - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst( + format!("i64 value does not fit in signed {n}-bit range"), + span, + ), // [x_lo, is_unsigned, x_lo] masm::Instruction::Dup1, ], @@ -104,7 +120,10 @@ impl OpEmitter<'_> { // [expected_sign_bits, sign_bits, x_lo] masm::Instruction::CDrop, // [x_lo] - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst( + format!("i64 value does not fit in signed {n}-bit range"), + span, + ), ], span, ); @@ -220,7 +239,7 @@ impl OpEmitter<'_> { // the value is <= i64::MIN, which is 1 more than i64::MAX. 
self.push_i64(i64::MIN, span); self.lte_u64(span); - self.emit(masm::Instruction::Assert, span); + self.emit(Self::assert_with_message_inst("value does not fit in i64", span), span); } /// Duplicate the i64/u64 value on top of the stack @@ -428,7 +447,7 @@ impl OpEmitter<'_> { match overflow { Overflow::Checked => { self.raw_exec("::miden::core::math::u64::overflowing_add", span); - self.emit(masm::Instruction::Assertz, span); + self.emit(Self::assertz_with_message_inst("u64 addition overflowed", span), span); } Overflow::Unchecked | Overflow::Wrapping => { self.raw_exec("::miden::core::math::u64::wrapping_add", span); @@ -493,7 +512,10 @@ impl OpEmitter<'_> { match overflow { Overflow::Checked => { self.raw_exec("::miden::core::math::u64::overflowing_sub", span); - self.emit(masm::Instruction::Assertz, span); + self.emit( + Self::assertz_with_message_inst("u64 subtraction overflowed", span), + span, + ); } Overflow::Unchecked | Overflow::Wrapping => { self.raw_exec("::miden::core::math::u64::wrapping_sub", span); @@ -575,7 +597,7 @@ impl OpEmitter<'_> { masm::Instruction::Drop, // Bring overflow back to the top and assert it is zero masm::Instruction::MovUp2, - masm::Instruction::Assertz, + Self::assertz_with_message_inst("u64 multiplication overflowed", span), ], span, ); diff --git a/codegen/masm/src/emit/mem.rs b/codegen/masm/src/emit/mem.rs index d1d2d4984..8c0707bfd 100644 --- a/codegen/masm/src/emit/mem.rs +++ b/codegen/masm/src/emit/mem.rs @@ -9,6 +9,94 @@ use crate::{OperandStack, lower::NativePtr}; /// Allocation impl OpEmitter<'_> { + /// Emit the loop header for a counted `while.true` loop. + /// + /// The caller provides the concrete `dup` instruction needed to bring `count` to the top of + /// the stack after the loop index has been seeded with zero, because each caller carries + /// `count` at a different depth in its loop state. 
+ /// + /// Stack transition: + /// + /// - Before: `[loop_state..]` + /// - After: `[count > 0, i = 0, loop_state..]` + /// + /// For example: + /// + /// - `memset`: `[dst, count, value..] -> [count > 0, i = 0, dst, count, value..]` + /// - `memcpy`: `[src, dst, count] -> [count > 0, i = 0, src, dst, count]` + fn emit_counted_loop_header(&mut self, count_dup: masm::Instruction, span: SourceSpan) { + self.emit_push(0u32, span); + self.emit(count_dup, span); + self.emit_push(0u32, span); + self.emit(masm::Instruction::U32Gt, span); + } + + /// Emit the loop back-edge condition for a counted `while.true` loop. + /// + /// The caller provides the concrete `dup` instruction needed to bring `count` to the top of + /// the stack after incrementing the loop index, because each caller carries `count` at a + /// different depth in its loop state. + /// + /// Stack transition: + /// + /// - Before: `[i, loop_state..]` + /// - After: `[i + 1 < count, i + 1, loop_state..]` + /// + /// For example: + /// + /// - `memset`: `[i, dst, count, value..] -> [i + 1 < count, i + 1, dst, count, value..]` + /// - `memcpy`: `[i, src, dst, count] -> [i + 1 < count, i + 1, src, dst, count]` + fn emit_counted_loop_next_condition(&mut self, count_dup: masm::Instruction, span: SourceSpan) { + self.emit_all( + [ + masm::Instruction::U32WrappingAddImm(1.into()), + masm::Instruction::Dup0, + count_dup, + masm::Instruction::U32Lt, + ], + span, + ); + } + + /// Convert the byte pointer on top of the stack to a word-aligned element address. + /// + /// This traps unless the input byte address is aligned to a 16-byte Miden word boundary. + fn emit_word_aligned_element_addr_from_byte_ptr(&mut self, span: SourceSpan) { + self.emit_all( + [ + masm::Instruction::U32DivModImm(16.into()), + Self::assertz_with_message_inst( + "expected a 16-byte-aligned byte pointer for the word-copy fast path", + span, + ), + // `u32widening_mul` leaves `[lo, hi]` on the stack; assert on `hi` and keep `lo`. 
+ masm::Instruction::U32WideningMulImm(4.into()), + masm::Instruction::Swap1, + Self::assertz_with_message_inst( + "word-copy fast path element address conversion overflowed", + span, + ), + ], + span, + ); + } + + /// Build a MASM block whose stack protocol is managed by the caller. + /// + /// This is used for branch bodies which operate on a known stack shape from the enclosing + /// emitter, but which do not need to synchronize typed operand-stack state back to it. + fn build_masm_block( + &mut self, + span: SourceSpan, + emit: impl FnOnce(&mut OpEmitter<'_>), + ) -> masm::Block { + let mut ops = Vec::default(); + let mut stack = OperandStack::new(self.context_rc()); + let mut emitter = OpEmitter::new(self.invoked, &mut ops, &mut stack); + emit(&mut emitter); + masm::Block::new(span, ops) + } + /// Grow the heap (from the perspective of Wasm programs) by N pages, returning the previous /// size of the heap (in pages) if successful, or -1 if the heap could not be grown. pub fn mem_grow(&mut self, span: SourceSpan) { @@ -219,6 +307,18 @@ impl OpEmitter<'_> { } } + if ty.size_in_bits() == 16 { + self.load_16bit_dynamic(span); + return; + } + + self.load_small_from_current_element(ty, span); + } + + /// Load a sub-word value which is fully contained in the current 32-bit element. + /// + /// Stack transition: `[addr, offset] -> [value]`. + fn load_small_from_current_element(&mut self, ty: &Type, span: SourceSpan) { // Stack: [element_addr, byte_offset] // First, load the aligned word containing our value @@ -257,6 +357,16 @@ impl OpEmitter<'_> { self.emit(masm::Instruction::U32And, span); } + /// Load a 16-bit value from a dynamic native pointer tuple. + /// + /// This delegates to a dedicated intrinsic which owns the complete stack protocol for both the + /// within-element and cross-element cases. + /// + /// Stack transition: `[addr, offset] -> [value]`. 
+ fn load_16bit_dynamic(&mut self, span: SourceSpan) { + self.raw_exec("::intrinsics::mem::load_u16", span); + } + fn load_double_word_imm(&mut self, ptr: NativePtr, span: SourceSpan) { if ptr.is_element_aligned() { self.emit_all( @@ -612,7 +722,11 @@ impl OpEmitter<'_> { body_emitter.emit_all( [ masm::Instruction::U32WideningMadd, // [value_size * i + dst, i, dst, count, value] - masm::Instruction::Assertz, // [aligned_dst, i, dst, count, value..] + masm::Instruction::Swap1, + Self::assertz_with_message_inst( + "memset destination address computation overflowed", + span, + ), // [aligned_dst, i, dst, count, value..] ], span, ); @@ -630,27 +744,15 @@ impl OpEmitter<'_> { body_emitter.store(span); // [i, dst, count, value] // Loop body - increment iteration count, determine whether to continue loop - body_emitter.emit_all( - [ - masm::Instruction::U32WrappingAddImm(1.into()), - masm::Instruction::Dup0, // [i++, i++, dst, count, value] - masm::Instruction::Dup3, // [count, i++, i++, dst, count, value] - masm::Instruction::U32Gte, // [i++ >= count, i++, dst, count, value] - ], - span, - ); + body_emitter.emit_counted_loop_next_condition(masm::Instruction::Dup3, span); + // [i++ < count, i++, dst, count, value] // Switch back to original block and emit loop header and 'while.true' instruction // // Loop header - prepare to loop until `count` iterations have been performed // [dst, count, value..] - self.emit_push(0u32, span); // [i, dst, count, value..] - self.emit(masm::Instruction::Dup2, span); // [count, i, dst, count, value..] - self.emit_push(Felt::ZERO, span); - self.emit( - masm::Instruction::Gte, // [count > 0, i, dst, count, value..] - span, - ); + self.emit_counted_loop_header(masm::Instruction::Dup2, span); + // [count > 0, i, dst, count, value..] 
self.current_block.push(masm::Op::While { span, body: masm::Block::new(span, body), @@ -670,7 +772,13 @@ impl OpEmitter<'_> { /// /// The semantics of this instruction are as follows: /// - /// * The `` + /// * `count` is expressed in units of the pointee type, not bytes + /// * the effective byte length is `count * size_of(*src)` + /// * `count == 0` leaves memory unchanged and performs no copy + /// * source and destination pointers are interpreted in the address space described by their + /// pointer type + /// * optimized word-copy fast paths are only used for byte-addressable pointers; native + /// pointers fall back to the generic loop pub fn memcpy(&mut self, span: SourceSpan) { let src = self.stack.pop().expect("operand stack is empty"); let dst = self.stack.pop().expect("operand stack is empty"); @@ -681,6 +789,10 @@ impl OpEmitter<'_> { assert_eq!(ty, dst.ty(), "expected src and dst operands to have the same type"); let value_ty = ty.pointee().unwrap().clone(); let value_size = u32::try_from(value_ty.size_in_bytes()).expect("invalid value size"); + let is_byte_pointer = match &ty { + Type::Ptr(ptr_ty) => ptr_ty.is_byte_pointer(), + _ => unreachable!("memcpy expects pointer operands"), + }; // Use optimized intrinsics when available match value_size { @@ -718,94 +830,89 @@ impl OpEmitter<'_> { ); // then: convert byte addresses/count to element units and delegate to core - let mut then_ops = Vec::default(); - let mut then_stack = OperandStack::new(self.context_rc()); - let mut then_emitter = OpEmitter::new(self.invoked, &mut then_ops, &mut then_stack); - then_emitter.emit_all( - [ - // Convert `src` to element address - masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, - // Convert `dst` to an element address - masm::Instruction::Swap1, - masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, - // Bring `count` to top to convert to element count - masm::Instruction::Swap2, - 
masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, - ], - span, - ); - then_emitter.raw_exec("::miden::core::mem::memcopy_elements", span); - - // else: fall back to the generic implementation - let mut else_ops = Vec::default(); - let mut else_stack = OperandStack::new(self.context_rc()); - let mut else_emitter = OpEmitter::new(self.invoked, &mut else_ops, &mut else_stack); - else_emitter.emit_memcpy_fallback_loop( - src.clone(), - dst.clone(), - count.clone(), - value_ty.clone(), - value_size, - span, - ); + let then_blk = self.build_masm_block(span, |then_emitter| { + then_emitter.emit_all( + [ + // Convert `src` to element address + masm::Instruction::U32DivModImm(4.into()), + Self::assertz_with_message_inst( + "memcpy byte-copy fast path expected the source pointer to be \ + 4-byte aligned", + span, + ), + // Convert `dst` to an element address + masm::Instruction::Swap1, + masm::Instruction::U32DivModImm(4.into()), + Self::assertz_with_message_inst( + "memcpy byte-copy fast path expected the destination pointer to \ + be 4-byte aligned", + span, + ), + // Bring `count` to top to convert to element count + masm::Instruction::Swap2, + masm::Instruction::U32DivModImm(4.into()), + Self::assertz_with_message_inst( + "memcpy byte-copy fast path expected the byte count to be \ + divisible by 4", + span, + ), + ], + span, + ); + then_emitter.raw_exec("::miden::core::mem::memcopy_elements", span); + }); + + let else_blk = self.build_masm_block(span, |else_emitter| { + else_emitter.emit_memcpy_fallback_loop( + src.clone(), + dst.clone(), + count.clone(), + value_ty.clone(), + value_size, + span, + ); + }); self.current_block.push(masm::Op::If { span, - then_blk: masm::Block::new(span, then_ops), - else_blk: masm::Block::new(span, else_ops), + then_blk, + else_blk, }); return; } // Word-sized values have an optimized intrinsic we can lean on - 16 => { - // We have to convert byte addresses to element addresses - self.emit_all( - [ - // Convert 
`src` to element address, and assert aligned to an element address - // - // TODO: We should probably also assert that the address is word-aligned, but - // that is going to happen anyway. That said, the closer to the source the - // better for debugging. - masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, - // Convert `dst` to an element address the same way - masm::Instruction::Swap1, - masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, - // Swap with `count` to get us into the correct ordering: [count, src, dst] - masm::Instruction::Swap2, - ], - span, - ); + 16 if is_byte_pointer => { + // Convert `src` to a word-aligned element address. + self.emit_word_aligned_element_addr_from_byte_ptr(span); + // Convert `dst` to an element address the same way. + self.emit(masm::Instruction::Swap1, span); + self.emit_word_aligned_element_addr_from_byte_ptr(span); + // Swap with `count` to get us into the correct ordering: [count, src, dst]. + self.emit(masm::Instruction::Swap2, span); self.raw_exec("::miden::core::mem::memcopy_words", span); return; } // Values which can be broken up into word-sized chunks can piggy-back on the // intrinsic for word-sized values, but we have to compute a new `count` by // multiplying `count` by the number of words in each value - size if size > 16 && size.is_multiple_of(16) => { + size if is_byte_pointer && size > 16 && size.is_multiple_of(16) => { let factor = size / 16; + // Convert `src` to a word-aligned element address. + self.emit_word_aligned_element_addr_from_byte_ptr(span); + // Convert `dst` to an element address the same way. + self.emit(masm::Instruction::Swap1, span); + self.emit_word_aligned_element_addr_from_byte_ptr(span); self.emit_all( [ - // Convert `src` to element address, and assert aligned to an element address - // - // TODO: We should probably also assert that the address is word-aligned, but - // that is going to happen anyway. 
That said, the closer to the source the - // better for debugging. - masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, - // Convert `dst` to an element address the same way - masm::Instruction::Swap1, - masm::Instruction::U32DivModImm(4.into()), - masm::Instruction::Assertz, // Swap with `count` to get us into the correct ordering: [count, src, dst] masm::Instruction::Swap2, // Compute the corrected count masm::Instruction::U32WideningMulImm(factor.into()), - masm::Instruction::Assertz, // [count * (size / 16), src, dst] + masm::Instruction::Swap1, + Self::assertz_with_message_inst( + "memcpy word-copy fast path element count overflowed", + span, + ), // [count * (size / 16), src, dst] ], span, ); @@ -847,9 +954,13 @@ impl OpEmitter<'_> { body_emitter.emit_all( [ masm::Instruction::U32WideningMadd, - masm::Instruction::Assertz, // [new_dst := i * offset + dst, i, src, dst, count] - masm::Instruction::Dup2, // [src, new_dst, i, src, dst, count] - masm::Instruction::Dup2, // [i, src, new_dst, i, src, dst, count] + masm::Instruction::Swap1, + Self::assertz_with_message_inst( + "memcpy destination address computation overflowed", + span, + ), // [new_dst := i * offset + dst, i, src, dst, count] + masm::Instruction::Dup2, // [src, new_dst, i, src, dst, count] + masm::Instruction::Dup2, // [i, src, new_dst, i, src, dst, count] ], span, ); @@ -857,7 +968,11 @@ impl OpEmitter<'_> { body_emitter.emit_all( [ masm::Instruction::U32WideningMadd, - masm::Instruction::Assertz, // [new_src := i * offset + src, new_dst, i, src, dst, count] + masm::Instruction::Swap1, + Self::assertz_with_message_inst( + "memcpy source address computation overflowed", + span, + ), // [new_src := i * offset + src, new_dst, i, src, dst, count] ], span, ); @@ -876,28 +991,16 @@ impl OpEmitter<'_> { body_emitter.store(span); // [i, src, dst, count] // Increment iteration count, determine whether to continue loop - body_emitter.emit_all( - [ - 
masm::Instruction::U32WrappingAddImm(1.into()), - masm::Instruction::Dup0, // [i++, i++, src, dst, count] - masm::Instruction::Dup4, // [count, i++, i++, src, dst, count] - masm::Instruction::U32Gte, // [i++ >= count, i++, src, dst, count] - ], - span, - ); + body_emitter.emit_counted_loop_next_condition(masm::Instruction::Dup4, span); + // [i++ < count, i++, src, dst, count] // Switch back to original block and emit loop header and 'while.true' instruction // // Loop header - prepare to loop until `count` iterations have been performed // [src, dst, count] - self.emit_push(0u32, span); // [i, src, dst, count] - self.emit(masm::Instruction::Dup3, span); // [count, i, src, dst, count] - self.emit_push(Felt::ZERO, span); - self.emit( - masm::Instruction::Gte, // [count > 0, i, src, dst, count] - span, - ); + self.emit_counted_loop_header(masm::Instruction::Dup3, span); + // [count > 0, i, src, dst, count] self.current_block.push(masm::Op::While { span, body: masm::Block::new(span, body), @@ -1052,6 +1155,21 @@ impl OpEmitter<'_> { return; } + if type_size == 16 { + self.store_16bit_dynamic(span); + return; + } + + self.store_small_within_element( + u32::try_from(type_size).expect("invalid sub-word type size"), + span, + ); + } + + /// Store a sub-word value which is fully contained in the current 32-bit element. + /// + /// Stack transition: `[addr, offset, value] -> []`. + fn store_small_within_element(&mut self, type_size: u32, span: SourceSpan) { // Stack: [addr, offset, value] // Load the current aligned value self.emit_all( @@ -1100,6 +1218,16 @@ impl OpEmitter<'_> { ); } + /// Store a 16-bit value to a dynamic native pointer tuple. + /// + /// This delegates to a dedicated intrinsic which owns the complete stack protocol for both the + /// within-element and cross-element cases. + /// + /// Stack transition: `[addr, offset, value] -> []`. 
+ fn store_16bit_dynamic(&mut self, span: SourceSpan) { + self.raw_exec("::intrinsics::mem::store_u16", span); + } + /// Store a sub-word value using an immediate pointer /// /// This function stores sub-word values (u8, u16, etc.) to memory at a specific immediate address. @@ -1113,6 +1241,14 @@ impl OpEmitter<'_> { /// - Before: [value] (where value is already truncated to the correct size) /// - After: [] fn store_small_imm(&mut self, ty: &Type, imm: NativePtr, span: SourceSpan) { + if ty.size_in_bits() == 16 && !imm.is_element_aligned() { + // Route unaligned 16-bit immediates through the dynamic path so they share the same + // cross-element windowing logic as byte-pointer stores. + self.push_native_ptr(imm, span); + self.store_small(ty, None, span); + return; + } + assert!(imm.alignment() as usize >= ty.min_alignment()); // For immediate pointers, we always load from the element-aligned address diff --git a/codegen/masm/src/emit/mod.rs b/codegen/masm/src/emit/mod.rs index 16e37a332..24605cea2 100644 --- a/codegen/masm/src/emit/mod.rs +++ b/codegen/masm/src/emit/mod.rs @@ -1,4 +1,4 @@ -use alloc::rc::Rc; +use alloc::{rc::Rc, sync::Arc}; use midenc_session::diagnostics::Span; @@ -155,6 +155,48 @@ pub struct OpEmitter<'a> { current_block: &'a mut Vec, } impl<'a> OpEmitter<'a> { + /// Build a MASM `assert` instruction with an inline diagnostic. + #[inline] + pub fn assert_with_message_inst( + message: impl Into>, + span: SourceSpan, + ) -> masm::Instruction { + masm::Instruction::AssertWithError(masm::Immediate::Value(Span::new(span, message.into()))) + } + + /// Build a MASM `assert_eq` instruction with an inline diagnostic. + #[inline] + pub fn assert_eq_with_message_inst( + message: impl Into>, + span: SourceSpan, + ) -> masm::Instruction { + masm::Instruction::AssertEqWithError(masm::Immediate::Value(Span::new( + span, + message.into(), + ))) + } + + /// Build a MASM `assert_eqw` instruction with an inline diagnostic. 
+ #[inline] + pub fn assert_eqw_with_message_inst( + message: impl Into>, + span: SourceSpan, + ) -> masm::Instruction { + masm::Instruction::AssertEqwWithError(masm::Immediate::Value(Span::new( + span, + message.into(), + ))) + } + + /// Build a MASM `assertz` instruction with an inline diagnostic. + #[inline] + pub fn assertz_with_message_inst( + message: impl Into>, + span: SourceSpan, + ) -> masm::Instruction { + masm::Instruction::AssertzWithError(masm::Immediate::Value(Span::new(span, message.into()))) + } + #[inline(always)] pub fn new( invoked: &'a mut BTreeSet, @@ -748,6 +790,24 @@ mod tests { }; } + /// Assert that the emitted block ends by delegating to the dedicated 16-bit memory intrinsic. + fn assert_unaligned_16bit_intrinsic(block: &[Op], intrinsic: &str) { + let execs = block + .iter() + .filter_map(|op| match op { + Op::Inst(inst) => match inst.inner() { + masm::Instruction::Exec(target) => Some(target.to_string()), + _ => None, + }, + _ => None, + }) + .collect::>(); + assert!( + execs.iter().any(|target| target == intrinsic), + "expected block to delegate to `{intrinsic}`, found execs: {execs:?}" + ); + } + #[test] fn op_emitter_stack_manipulation_test() { let mut block = Vec::default(); @@ -2059,7 +2119,7 @@ mod tests { emitter.assert_eq_imm(ten, SourceSpan::default()); assert_eq!(emitter.stack_len(), 2); - emitter.assert_eq(SourceSpan::default()); + emitter.assert_eq(None, SourceSpan::default()); assert_eq!(emitter.stack_len(), 0); } @@ -2137,6 +2197,66 @@ mod tests { assert_eq!(emitter.stack()[1], Type::U32); } + #[test] + fn op_emitter_unaligned_u16_load_imm_test() { + let mut block = Vec::default(); + let context = Rc::new(Context::default()); + let mut stack = OperandStack::new(context.clone()); + let mut invoked = BTreeSet::default(); + let mut emitter = OpEmitter::new(&mut invoked, &mut block, &mut stack); + + emitter.load_imm(130, Type::U16, SourceSpan::default()); + + assert_eq!(emitter.stack_len(), 1); + assert_eq!(emitter.stack()[0], 
Type::U16); + assert_unaligned_16bit_intrinsic(&block, "::intrinsics::mem::load_u16"); + } + + #[test] + fn op_emitter_unaligned_i16_load_imm_test() { + let mut block = Vec::default(); + let context = Rc::new(Context::default()); + let mut stack = OperandStack::new(context.clone()); + let mut invoked = BTreeSet::default(); + let mut emitter = OpEmitter::new(&mut invoked, &mut block, &mut stack); + + emitter.load_imm(130, Type::I16, SourceSpan::default()); + + assert_eq!(emitter.stack_len(), 1); + assert_eq!(emitter.stack()[0], Type::I16); + assert_unaligned_16bit_intrinsic(&block, "::intrinsics::mem::load_u16"); + } + + #[test] + fn op_emitter_unaligned_u16_store_imm_test() { + let mut block = Vec::default(); + let context = Rc::new(Context::default()); + let mut stack = OperandStack::new(context.clone()); + let mut invoked = BTreeSet::default(); + let mut emitter = OpEmitter::new(&mut invoked, &mut block, &mut stack); + + emitter.push(Type::U16); + emitter.store_imm(130, SourceSpan::default()); + + assert_eq!(emitter.stack_len(), 0); + assert_unaligned_16bit_intrinsic(&block, "::intrinsics::mem::store_u16"); + } + + #[test] + fn op_emitter_unaligned_i16_store_imm_test() { + let mut block = Vec::default(); + let context = Rc::new(Context::default()); + let mut stack = OperandStack::new(context.clone()); + let mut invoked = BTreeSet::default(); + let mut emitter = OpEmitter::new(&mut invoked, &mut block, &mut stack); + + emitter.push(Type::I16); + emitter.store_imm(130, SourceSpan::default()); + + assert_eq!(emitter.stack_len(), 0); + assert_unaligned_16bit_intrinsic(&block, "::intrinsics::mem::store_u16"); + } + #[test] fn op_emitter_truncate_stack_drops_all_with_remainder() { let mut block = Vec::default(); diff --git a/codegen/masm/src/emit/primop.rs b/codegen/masm/src/emit/primop.rs index 529aa9208..76f13ab25 100644 --- a/codegen/masm/src/emit/primop.rs +++ b/codegen/masm/src/emit/primop.rs @@ -1,4 +1,5 @@ use miden_assembly_syntax::parser::WordValue; +use 
midenc_dialect_hir::assertions; use midenc_hir::{ Felt, Immediate, SourceSpan, Type, dialects::builtin::attributes::{ArgumentExtension, Signature}, @@ -8,12 +9,26 @@ use super::{OpEmitter, int64, masm}; use crate::TraceEvent; impl OpEmitter<'_> { + /// Format a diagnostic message for a HIR assertion code when one is available. + fn assertion_message(code: Option<u32>, default: impl Into<String>) -> String { + let default = default.into(); + match code.filter(|code| *code != 0) { + Some(assertions::ASSERT_FAILED_ALIGNMENT) => { + "pointer address does not meet minimum alignment for the type".into() + } + Some(code) => format!("{default} (assertion code 0x{code:08x})"), + None => default, + } + } + /// Assert that an integer value on the stack has the value 1 /// /// This operation consumes the input value. - pub fn assert(&mut self, _code: Option<u32>, span: SourceSpan) { + pub fn assert(&mut self, code: Option<u32>, span: SourceSpan) { let arg = self.stack.pop().expect("operand stack is empty"); - match arg.ty() { + let ty = arg.ty().clone(); + let message = Self::assertion_message(code, format!("expected {ty} value to equal 1")); + match ty { Type::Felt | Type::U32 | Type::I32 @@ -22,7 +37,7 @@ impl OpEmitter<'_> { | Type::U8 | Type::I8 | Type::I1 => { - self.emit(masm::Instruction::Assert, span); + self.emit(Self::assert_with_message_inst(message, span), span); } Type::I128 | Type::U128 => { self.emit_all( [ masm::Instruction::Push(masm::Immediate::Value(Span::new( span, WordValue([Felt::ZERO, Felt::ZERO, Felt::ZERO, Felt::ONE]).into(), ))), - masm::Instruction::AssertEqw, + Self::assert_eqw_with_message_inst(message, span), ], span, ); } Type::U64 | Type::I64 => { - self.emit_all([masm::Instruction::Assertz, masm::Instruction::Assert], span); + self.emit_all( + [ + Self::assertz_with_message_inst(message.clone(), span), + Self::assert_with_message_inst(message, span), + ], + span, + ); } ty if !ty.is_integer() => { panic!("invalid argument to assert: expected integer, got {ty}") } @@ -49,9 +70,11 @@ impl 
OpEmitter<'_> { /// Assert that an integer value on the stack has the value 0 /// /// This operation consumes the input value. - pub fn assertz(&mut self, _code: Option<u32>, span: SourceSpan) { + pub fn assertz(&mut self, code: Option<u32>, span: SourceSpan) { let arg = self.stack.pop().expect("operand stack is empty"); - match arg.ty() { + let ty = arg.ty().clone(); + let message = Self::assertion_message(code, format!("expected {ty} value to equal 0")); + match ty { Type::Felt | Type::U32 | Type::I32 @@ -60,10 +83,16 @@ impl OpEmitter<'_> { | Type::U8 | Type::I8 | Type::I1 => { - self.emit(masm::Instruction::Assertz, span); + self.emit(Self::assertz_with_message_inst(message, span), span); } Type::U64 | Type::I64 => { - self.emit_all([masm::Instruction::Assertz, masm::Instruction::Assertz], span); + self.emit_all( + [ + Self::assertz_with_message_inst(message.clone(), span), + Self::assertz_with_message_inst(message, span), + ], + span, + ); } Type::U128 | Type::I128 => { self.emit_all( @@ -72,7 +101,7 @@ impl OpEmitter<'_> { span, WordValue([Felt::ZERO; 4]).into(), ))), - masm::Instruction::AssertEqw, + Self::assert_eqw_with_message_inst(message, span), ], span, ); @@ -87,11 +116,12 @@ impl OpEmitter<'_> { /// Assert that the top two integer values on the stack have the same value /// /// This operation consumes the input values. 
- pub fn assert_eq(&mut self, span: SourceSpan) { + pub fn assert_eq(&mut self, code: Option<u32>, span: SourceSpan) { let rhs = self.pop().expect("operand stack is empty"); let lhs = self.pop().expect("operand stack is empty"); - let ty = lhs.ty(); + let ty = lhs.ty().clone(); assert_eq!(ty, rhs.ty(), "expected assert_eq operands to have the same type"); + let message = Self::assertion_message(code, format!("expected {ty} values to be equal")); match ty { Type::Felt | Type::U32 @@ -101,17 +131,19 @@ impl OpEmitter<'_> { | Type::U8 | Type::I8 | Type::I1 => { - self.emit(masm::Instruction::AssertEq, span); + self.emit(Self::assert_eq_with_message_inst(message, span), span); + } + Type::U128 | Type::I128 => { + self.emit(Self::assert_eqw_with_message_inst(message, span), span) } - Type::U128 | Type::I128 => self.emit(masm::Instruction::AssertEqw, span), Type::U64 | Type::I64 => { self.emit_all( [ // compare the hi bits masm::Instruction::MovUp2, - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst(message.clone(), span), // compare the low bits - masm::Instruction::AssertEq, + Self::assert_eq_with_message_inst(message, span), ], span, ); @@ -130,7 +162,8 @@ impl OpEmitter<'_> { #[allow(unused)] pub fn assert_eq_imm(&mut self, imm: Immediate, span: SourceSpan) { let lhs = self.pop().expect("operand stack is empty"); - let ty = lhs.ty(); + let ty = lhs.ty().clone(); + let message = format!("expected {ty} value to equal {imm}"); assert_eq!(ty, imm.ty(), "expected assert_eq_imm operands to have the same type"); match ty { Type::Felt @@ -144,14 +177,14 @@ impl OpEmitter<'_> { | Type::I1 => { self.emit_all( [ masm::Instruction::EqImm(imm.as_felt().unwrap().into()), - masm::Instruction::Assert, + Self::assert_with_message_inst(message, span), ], span, ); } Type::I128 | Type::U128 => { self.push_immediate(imm, span); - self.emit(masm::Instruction::AssertEqw, span) + self.emit(Self::assert_eqw_with_message_inst(message, span), span) } Type::I64 | Type::U64 => { let imm = match imm { 
@@ -163,9 +196,9 @@ impl OpEmitter<'_> { self.emit_all( [ masm::Instruction::EqImm(Felt::new(hi as u64).into()), - masm::Instruction::Assert, + Self::assert_with_message_inst(message.clone(), span), masm::Instruction::EqImm(Felt::new(lo as u64).into()), - masm::Instruction::Assert, + Self::assert_with_message_inst(message, span), ], span, ) diff --git a/codegen/masm/src/emit/smallint.rs b/codegen/masm/src/emit/smallint.rs index 3a66de0ce..20abe1b77 100644 --- a/codegen/masm/src/emit/smallint.rs +++ b/codegen/masm/src/emit/smallint.rs @@ -46,7 +46,13 @@ impl OpEmitter<'_> { 1 => (), n => { self.is_signed_smallint(n, span); - self.emit(masm::Instruction::Assert, span); + self.emit( + Self::assert_with_message_inst( + format!("{n}-bit integer signedness check failed"), + span, + ), + span, + ); } } } diff --git a/codegen/masm/src/emit/unary.rs b/codegen/masm/src/emit/unary.rs index 265343a02..7cda1d5ee 100644 --- a/codegen/masm/src/emit/unary.rs +++ b/codegen/masm/src/emit/unary.rs @@ -345,7 +345,16 @@ impl OpEmitter<'_> { // bit being set will make the i8 larger than 0 or 1 self.emit(masm::Instruction::Dup0, span); self.emit_push(2u32, span); - self.emit_all([masm::Instruction::Lt, masm::Instruction::Assert], span); + self.emit_all( + [ + masm::Instruction::Lt, + Self::assert_with_message_inst( + "expected i8 value to be 0 or 1 when casting to i1", + span, + ), + ], + span, + ); } // i1 (Type::I1, _) => self.zext_smallint(src_bits, dst_bits, span), @@ -436,7 +445,7 @@ impl OpEmitter<'_> { masm::Instruction::Swap1, masm::Instruction::Sub, masm::Instruction::U32OverflowingSubImm(1.into()), - masm::Instruction::Assertz, + Self::assertz_with_message_inst("ilog2 is undefined for zero", span), ], span, ); @@ -935,7 +944,10 @@ impl OpEmitter<'_> { self.emit_all( [ // Assert that the high bits are zero - masm::Instruction::Assertz, + Self::assertz_with_message_inst( + "u64 exponent for pow2 must fit in u32", + span, + ), // This asserts if value > 63, thus result is 
guaranteed to fit in u64 masm::Instruction::Pow2, // Obtain the u64 representation by splitting the felt result diff --git a/codegen/masm/src/lower/lowering.rs b/codegen/masm/src/lower/lowering.rs index b0e1bd65f..c4a840321 100644 --- a/codegen/masm/src/lower/lowering.rs +++ b/codegen/masm/src/lower/lowering.rs @@ -13,7 +13,9 @@ use midenc_session::diagnostics::{Report, Severity, Spanned}; use smallvec::{SmallVec, smallvec}; use super::*; -use crate::{Constraint, emitter::BlockEmitter, masm, opt::operands::SolverOptions}; +use crate::{ + Constraint, emit::OpEmitter, emitter::BlockEmitter, masm, opt::operands::SolverOptions, +}; /// Convert a resolved callee [`midenc_hir::SymbolPath`] into a MASM [`masm::InvocationTarget`]. fn invocation_target_from_symbol_path( @@ -462,7 +464,9 @@ impl HirLowering for hir::Assertz { impl HirLowering for hir::AssertEq { fn emit(&self, emitter: &mut BlockEmitter<'_>) -> Result<(), Report> { - emitter.emitter().assert_eq(self.span()); + let code = *self.get_code(); + + emitter.emitter().assert_eq(Some(code), self.span()); Ok(()) } @@ -475,7 +479,8 @@ impl HirLowering for ub::Unreachable { let span = self.span(); let mut op_emitter = emitter.emitter(); op_emitter.emit_push(0u32, span); - op_emitter.emit(masm::Instruction::Assert, span); + op_emitter + .emit(OpEmitter::assert_with_message_inst("entered unreachable code", span), span); Ok(()) } diff --git a/tests/integration/src/codegen/intrinsics/mem.rs b/tests/integration/src/codegen/intrinsics/mem.rs index bdac4e28b..2c1929ef8 100644 --- a/tests/integration/src/codegen/intrinsics/mem.rs +++ b/tests/integration/src/codegen/intrinsics/mem.rs @@ -365,6 +365,121 @@ fn load_u16() { } } +macro_rules! define_unaligned_16bit_load_tests { + ( + $run_fn:ident, + $rust_ty:ty, + $hir_ty:expr, + $offset_1_test:ident, + $offset_2_test:ident, + $offset_3_test:ident + ) => { + #[doc = concat!( + "Runs a `", + stringify!($rust_ty), + "` load test from the specified unaligned byte offset." 
+ )] + fn $run_fn(offset: u32) { + setup::enable_compiler_instrumentation(); + + let write_to = 17 * 2u32.pow(16); + let read_from = write_to + offset; + + let (package, context) = compile_test_module( + [Type::from(PointerType::new($hir_ty))], + [$hir_ty], + |builder| { + let block = builder.current_block(); + let ptr = block.borrow().arguments()[0] as ValueRef; + let loaded = builder.load(ptr, SourceSpan::default()).unwrap(); + builder.ret(Some(loaded), SourceSpan::default()).unwrap(); + }, + ); + + let config = proptest::test_runner::Config::with_cases(10); + let res = TestRunner::new(config).run(&any::<$rust_ty>(), move |value| { + let expected = value.to_le_bytes(); + let mut initial_bytes = [0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88]; + initial_bytes[offset as usize] = expected[0]; + initial_bytes[offset as usize + 1] = expected[1]; + let initializers = [Initializer::MemoryBytes { + addr: write_to, + bytes: &initial_bytes, + }]; + + let args = [Felt::new(read_from as u64)]; + let output = eval_package::<$rust_ty, _, _>( + &package, + initializers, + &args, + context.session(), + |_| Ok(()), + )?; + + prop_assert_eq!(output, value, "expected 0x{:x}; found 0x{:x}", value, output,); + + Ok(()) + }); + + match res { + Err(TestError::Fail(reason, value)) => { + panic!("FAILURE: {}\nMinimal failing case: {value:?}", reason.message()); + } + Ok(_) => (), + _ => panic!("Unexpected test result: {res:?}"), + } + } + + #[doc = concat!( + "Tests that loading a `", + stringify!($rust_ty), + "` from byte offset 1 stays within the current element." + )] + #[test] + fn $offset_1_test() { + $run_fn(1); + } + + #[doc = concat!( + "Tests that loading a `", + stringify!($rust_ty), + "` from byte offset 2 stays within the current element." + )] + #[test] + fn $offset_2_test() { + $run_fn(2); + } + + #[doc = concat!( + "Tests that loading a `", + stringify!($rust_ty), + "` from byte offset 3 correctly reconstructs the value across the next element \ + boundary." 
+ )] + #[test] + fn $offset_3_test() { + $run_fn(3); + } + }; +} + +define_unaligned_16bit_load_tests!( + run_load_unaligned_u16, + u16, + Type::U16, + load_unaligned_u16_offset_1, + load_unaligned_u16_offset_2, + load_unaligned_u16 +); +define_unaligned_16bit_load_tests!( + run_load_unaligned_i16, + i16, + Type::I16, + load_unaligned_i16_offset_1, + load_unaligned_i16_offset_2, + load_unaligned_i16 +); + /// Tests the memory load intrinsic for loads of boolean (i.e. 1-bit) values #[test] fn load_bool() { @@ -570,6 +685,165 @@ fn store_u16() { } } +macro_rules! define_unaligned_16bit_store_tests { + ( + $run_fn:ident, + $rust_ty:ty, + $hir_ty:expr, + $to_felt:expr, + $offset_1_test:ident, + $offset_2_test:ident, + $offset_3_test:ident + ) => { + #[doc = concat!( + "Runs a `", + stringify!($rust_ty), + "` store test at the specified unaligned byte offset." + )] + fn $run_fn(offset: u32) { + setup::enable_compiler_instrumentation(); + + let write_to = 17 * 2u32.pow(16); + let store_to = write_to + offset; + + let (package, context) = compile_test_module([$hir_ty], [Type::U32], |builder| { + let block = builder.current_block(); + let value = block.borrow().arguments()[0] as ValueRef; + + let addr = builder.u32(store_to, SourceSpan::default()); + let ptr = builder + .inttoptr(addr, Type::from(PointerType::new($hir_ty)), SourceSpan::default()) + .unwrap(); + + builder.store(ptr, value, SourceSpan::default()).unwrap(); + + let result = builder.u32(1, SourceSpan::default()); + builder.ret(Some(result), SourceSpan::default()).unwrap(); + }); + + let config = proptest::test_runner::Config::with_cases(32); + let res = TestRunner::new(config).run(&any::<$rust_ty>(), move |store_value| { + let initial_bytes = [0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88]; + let initializers = [Initializer::MemoryBytes { + addr: write_to, + bytes: &initial_bytes, + }]; + + let args = [($to_felt)(store_value)]; + let output = eval_package::( + &package, + initializers, + &args, + 
context.session(), + |trace| { + let expected = store_value.to_le_bytes(); + let mut expected_bytes = initial_bytes; + expected_bytes[offset as usize] = expected[0]; + expected_bytes[offset as usize + 1] = expected[1]; + + let word0 = + trace.read_from_rust_memory::(write_to).ok_or_else(|| { + TestCaseError::fail(format!( + "failed to read from byte address {write_to}" + )) + })?; + let word1 = + trace.read_from_rust_memory::(write_to + 4).ok_or_else(|| { + TestCaseError::fail(format!( + "failed to read from byte address {}", + write_to + 4 + )) + })?; + let observed_bytes = [ + (word0 & 0xff) as u8, + ((word0 >> 8) & 0xff) as u8, + ((word0 >> 16) & 0xff) as u8, + ((word0 >> 24) & 0xff) as u8, + (word1 & 0xff) as u8, + ((word1 >> 8) & 0xff) as u8, + ((word1 >> 16) & 0xff) as u8, + ((word1 >> 24) & 0xff) as u8, + ]; + + for (index, (stored, expected_byte)) in + observed_bytes.into_iter().zip(expected_bytes).enumerate() + { + prop_assert_eq!( + stored, + expected_byte, + "unexpected byte at address {}", + write_to + index as u32 + ); + } + + Ok(()) + }, + )?; + + prop_assert_eq!(output, 1u32); + Ok(()) + }); + + match res { + Err(TestError::Fail(reason, value)) => { + panic!("FAILURE: {}\nMinimal failing case: {value:?}", reason.message()); + } + Ok(_) => (), + _ => panic!("Unexpected test result: {res:?}"), + } + } + + #[doc = concat!( + "Tests that storing a `", + stringify!($rust_ty), + "` at byte offset 1 updates only the target bytes." + )] + #[test] + fn $offset_1_test() { + $run_fn(1); + } + + #[doc = concat!( + "Tests that storing a `", + stringify!($rust_ty), + "` at byte offset 2 updates only the target bytes." + )] + #[test] + fn $offset_2_test() { + $run_fn(2); + } + + #[doc = concat!( + "Tests that storing a `", + stringify!($rust_ty), + "` at byte offset 3 updates only the target bytes across the element boundary." 
+ )] + #[test] + fn $offset_3_test() { + $run_fn(3); + } + }; +} + +define_unaligned_16bit_store_tests!( + run_store_unaligned_u16, + u16, + Type::U16, + |store_value: u16| Felt::new(store_value as u64), + store_unaligned_u16_offset_1, + store_unaligned_u16_offset_2, + store_unaligned_u16 +); +define_unaligned_16bit_store_tests!( + run_store_unaligned_i16, + i16, + Type::I16, + |store_value: i16| Felt::new(store_value as u16 as u64), + store_unaligned_i16_offset_1, + store_unaligned_i16_offset_2, + store_unaligned_i16 +); + /// Tests that u8 stores only affect the targeted byte and don't corrupt surrounding memory #[test] fn store_u8() { diff --git a/tests/integration/src/rust_masm_tests/instructions.rs b/tests/integration/src/rust_masm_tests/instructions.rs index 1d23401c4..653cdeebd 100644 --- a/tests/integration/src/rust_masm_tests/instructions.rs +++ b/tests/integration/src/rust_masm_tests/instructions.rs @@ -14,7 +14,7 @@ use proptest::{ use super::run_masm_vs_rust; use crate::{ CompilerTest, - testing::{Initializer, eval_package}, + testing::{Initializer, eval_package, setup}, }; macro_rules! 
test_bin_op { @@ -834,3 +834,545 @@ fn test_hmerge() { _ => panic!("Unexpected test result: {res:?}"), } } + +#[test] +fn test_memory_copy_aligned_bytes() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u32; 12], src: &[u32; 16]) { + unsafe { + let src_ptr = (src.as_ptr() as *const u8).add(4); + let dst_ptr = dst.as_mut_ptr() as *mut u8; + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 48); + } + } + + let mut src = [0u32; 16]; + let src_bytes = src.as_mut_ptr() as *mut u8; + let mut i = 0usize; + while i < 64 { + unsafe { *src_bytes.add(i) = i as u8; } + i += 1; + } + + let mut dst = [0u32; 12]; + do_copy(&mut dst, &src); + + let dst_bytes = dst.as_ptr() as *const u8; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 48 { + let observed = unsafe { *dst_bytes.add(i) }; + if observed != (i as u8).wrapping_add(4) { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_aligned_bytes_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_u128_fast_path() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u128; 2], src: &[u128; 3]) { + unsafe { + let src_ptr = src.as_ptr().add(1); + let dst_ptr = dst.as_mut_ptr(); + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 2); + } + } + + let src = [ + 0x00112233445566778899aabbccddeeff_u128, + 0x102132435465768798a9bacbdcedfe0f_u128, + 0xfedcba98765432100123456789abcdef_u128, + ]; + let mut dst = [0u128; 2]; + do_copy(&mut dst, &src); + + let expected = [src[1], src[2]]; + let mut mismatches = 0u32; + let mut i = 
0usize; + while i < 2 { + if dst[i] != expected[i] { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_u128_fast_path", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_multiword_fast_path() { + let main_fn = r#"() -> Felt { + struct Chunk([u128; 2]); + + #[inline(never)] + fn do_copy(dst: &mut [Chunk; 1], src: &[Chunk; 2]) { + unsafe { + let src_ptr = src.as_ptr().add(1); + let dst_ptr = dst.as_mut_ptr(); + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 1); + } + } + + let src = [ + Chunk([ + 0x00112233445566778899aabbccddeeff_u128, + 0x112233445566778899aabbccddeeff00_u128, + ]), + Chunk([ + 0xaabbccddeeff00112233445566778899_u128, + 0xffeeddccbbaa99887766554433221100_u128, + ]), + ]; + let mut dst = [Chunk([0u128; 2])]; + do_copy(&mut dst, &src); + + let expected = &src[1].0; + let observed = &dst[0].0; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 2 { + if observed[i] != expected[i] { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_multiword_fast_path", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_aligned_addresses_misaligned_count() { + let 
main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u32; 12], src: &[u32; 16]) { + unsafe { + let src_ptr = (src.as_ptr() as *const u8).add(4); + let dst_ptr = dst.as_mut_ptr() as *mut u8; + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 47); + } + } + + let mut src = [0u32; 16]; + let src_bytes = src.as_mut_ptr() as *mut u8; + let mut i = 0usize; + while i < 64 { + unsafe { *src_bytes.add(i) = i as u8; } + i += 1; + } + + let mut dst = [0xffff_ffffu32; 12]; + do_copy(&mut dst, &src); + + let dst_bytes = dst.as_ptr() as *const u8; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 48 { + let observed = unsafe { *dst_bytes.add(i) }; + let expected = if i < 47 { + (i as u8).wrapping_add(4) + } else { + 0xff + }; + if observed != expected { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_aligned_addresses_misaligned_count_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_unaligned() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u8; 48], src: &[u8; 64]) { + unsafe { + let src_ptr = src.as_ptr().add(3); + let dst_ptr = dst.as_mut_ptr(); + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 48); + } + } + + let mut src = [0u8; 64]; + let mut i = 0usize; + while i < 64 { + src[i] = i as u8; + i += 1; + } + + let mut dst = [0u8; 48]; + do_copy(&mut dst, &src); + + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 48 { + if dst[i] != (i as u8).wrapping_add(3) { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + 
setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_unaligned_src_len_48_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_unaligned_dst() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u8; 53], src: &[u8; 64]) { + unsafe { + let src_ptr = src.as_ptr().add(3); + let dst_ptr = dst.as_mut_ptr().add(5); + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 48); + } + } + + let mut src = [0u8; 64]; + let mut i = 0usize; + while i < 64 { + src[i] = i as u8; + i += 1; + } + + let mut dst = [0xffu8; 53]; + do_copy(&mut dst, &src); + + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 53 { + let expected = if i < 5 { 0xff } else { (i as u8).wrapping_sub(2) }; + if dst[i] != expected { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_unaligned_dst_len_48_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_unaligned_dst_short_count() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u8; 8], src: &[u8; 16]) { + unsafe { + let src_ptr = src.as_ptr().add(3); + let dst_ptr = dst.as_mut_ptr().add(2); + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 3); + } + } + + let mut src = [0u8; 16]; + 
let mut i = 0usize; + while i < 16 { + src[i] = i as u8; + i += 1; + } + + let mut dst = [0xffu8; 8]; + do_copy(&mut dst, &src); + + let expected = [0xffu8, 0xff, 3, 4, 5, 0xff, 0xff, 0xff]; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 8 { + if dst[i] != expected[i] { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_unaligned_dst_short_count_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_copy_unaligned_zero_count() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_copy(dst: &mut [u8; 8], src: &[u8; 16]) { + unsafe { + let src_ptr = src.as_ptr().add(1); + let dst_ptr = dst.as_mut_ptr().add(2); + core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, 0); + } + } + + let mut src = [0u8; 16]; + let mut i = 0usize; + while i < 16 { + src[i] = i as u8; + i += 1; + } + + let mut dst = [0xffu8; 8]; + do_copy(&mut dst, &src); + + let expected = [0xffu8; 8]; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 8 { + if dst[i] != expected[i] { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_copy_unaligned_zero_count_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + 
+#[test] +fn test_memory_set_unaligned() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_set(dst: &mut [u8; 11]) { + unsafe { + let dst_ptr = dst.as_mut_ptr().add(3); + core::ptr::write_bytes(dst_ptr, 0x5a, 5); + } + } + + let mut dst = [0xffu8; 11]; + do_set(&mut dst); + + let expected = [0xffu8, 0xff, 0xff, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0xff, 0xff, 0xff]; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 11 { + if dst[i] != expected[i] { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = + CompilerTest::rust_fn_body_with_stdlib_sys("memory_set_unaligned_u8s", main_fn, config, []); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +} + +#[test] +fn test_memory_set_unaligned_zero_count() { + let main_fn = r#"() -> Felt { + #[inline(never)] + fn do_set(dst: &mut [u8; 11]) { + unsafe { + let dst_ptr = dst.as_mut_ptr().add(3); + core::ptr::write_bytes(dst_ptr, 0x5a, 0); + } + } + + let mut dst = [0xffu8; 11]; + do_set(&mut dst); + + let expected = [0xffu8; 11]; + let mut mismatches = 0u32; + let mut i = 0usize; + while i < 11 { + if dst[i] != expected[i] { + mismatches += 1; + } + i += 1; + } + + Felt::from_u32(mismatches) + }"#; + + setup::enable_compiler_instrumentation(); + let config = WasmTranslationConfig::default(); + let mut test = CompilerTest::rust_fn_body_with_stdlib_sys( + "memory_set_unaligned_zero_count_u8s", + main_fn, + config, + [], + ); + + let package = test.compile_package(); + let args: [Felt; 0] = []; + + eval_package::(&package, [], &args, &test.session, |trace| { + let res: Felt = trace.parse_result().unwrap(); + assert_eq!(res, Felt::ZERO); + Ok(()) + }) + .unwrap(); +}