diff --git a/Cargo.lock b/Cargo.lock index 91f2173a2..f4981c252 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,6 +40,12 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "bytemuck" +version = "1.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9134a6ef01ce4b366b50689c94f82c14bc72bc5d0386829828a2e2752ef7958c" + [[package]] name = "byteorder" version = "1.5.0" @@ -152,6 +158,7 @@ dependencies = [ "assert_matches", "atomig", "bitflags", + "bytemuck", "cc", "cfg-if", "libc", diff --git a/Cargo.toml b/Cargo.toml index b6412f730..d63b5b7b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ raw-cpuid = "11.0.1" strum = { version = "0.26", features = ["derive"] } to_method = "1.1.0" zerocopy = { version = "0.7.32", features = ["derive"] } +bytemuck = { version = "1.23.0" } [build-dependencies] cc = "1.0.79" diff --git a/src/lf_mask.rs b/src/lf_mask.rs index 113e4e1c4..3ec6ebd70 100644 --- a/src/lf_mask.rs +++ b/src/lf_mask.rs @@ -14,6 +14,7 @@ use crate::src::levels::SegmentId; use crate::src::levels::TxfmSize; use crate::src::relaxed_atomic::RelaxedAtomic; use crate::src::tables::dav1d_txfm_dimensions; +use bytemuck; use libc::ptrdiff_t; use parking_lot::RwLock; use std::cmp; @@ -149,7 +150,7 @@ fn decomp_tx( }; } -#[inline] +#[inline(always)] fn mask_edges_inter( masks: &[[[[RelaxedAtomic; 2]; 3]; 32]; 2], by4: usize, @@ -261,7 +262,7 @@ fn mask_edges_inter( a[..w4].copy_from_slice(txa_slice); } -#[inline] +#[inline(always)] fn mask_edges_intra( masks: &[[[[RelaxedAtomic; 2]; 3]; 32]; 2], by4: usize, @@ -452,12 +453,13 @@ pub(crate) fn rav1d_create_lf_mask_intra( if bw4 != 0 && bh4 != 0 { let mut level_cache_off = by * b4_stride + bx; for _y in 0..bh4 { + let idx = 4 * level_cache_off; + let lvl = &mut *level_cache.index_mut((idx + 0.., ..4 * bw4)); + let lvl: &mut [[u8; 4]] = bytemuck::cast_slice_mut(lvl); for x in 0..bw4 { - let idx = 4 * (level_cache_off + x); // `0.., ..2` is for Y - let lvl = &mut *level_cache.index_mut((idx + 0.., ..2)); - lvl[0] = filter_level[0][0][0]; - lvl[1] = filter_level[1][0][0]; + lvl[x][0] = filter_level[0][0][0]; + lvl[x][1] = filter_level[1][0][0]; } level_cache_off += b4_stride; } @@ -490,12 +492,13 @@ pub(crate) fn rav1d_create_lf_mask_intra( let mut level_cache_off = (by >> ss_ver) * b4_stride + (bx >> ss_hor); for _y in 0..cbh4 { + let idx = 4 * level_cache_off; + let lvl = &mut *level_cache.index_mut((idx + 0.., ..4 * cbw4)); + let lvl: &mut [[u8; 4]] = bytemuck::cast_slice_mut(lvl); for x in 0..cbw4 { - let idx = 4 * (level_cache_off + x); // `2.., ..2` is for UV - let lvl = &mut *level_cache.index_mut((idx + 2.., ..2)); - lvl[0] = filter_level[2][0][0]; - lvl[1] = filter_level[3][0][0]; + lvl[x][2] = filter_level[2][0][0]; + lvl[x][3] = filter_level[3][0][0]; } level_cache_off += b4_stride; } @@ -550,12 +553,13 @@ pub(crate) fn rav1d_create_lf_mask_inter( if bw4 != 0 && bh4 != 0 { let mut level_cache_off = by * b4_stride + bx; for _y in 0..bh4 { + let idx = 4 * level_cache_off; + let lvl = &mut *level_cache.index_mut((idx + 0.., ..4 * bw4)); + let lvl: &mut [[u8; 4]] = bytemuck::cast_slice_mut(lvl); for x in 0..bw4 { - let idx = 4 * (level_cache_off + x); - // `0.., ..2` is for Y - let lvl = &mut *level_cache.index_mut((idx + 0.., ..2)); - lvl[0] = filter_level[0][r#ref][is_gmv]; - lvl[1] = filter_level[1][r#ref][is_gmv]; + // 0, 1 is for Y + lvl[x][0] = filter_level[0][r#ref][is_gmv]; + lvl[x][1] = filter_level[1][r#ref][is_gmv]; } level_cache_off += b4_stride; } @@ -599,12 +603,13 @@ pub(crate) fn rav1d_create_lf_mask_inter( let mut level_cache_off = (by >> ss_ver) * b4_stride + (bx >> ss_hor); for _y in 0..cbh4 { + let idx = 4 * level_cache_off; + let lvl = &mut *level_cache.index_mut((idx + 0.., ..4 * cbw4)); + let lvl: &mut [[u8; 4]] = bytemuck::cast_slice_mut(lvl); for x in 0..cbw4 { - let idx = 4 * (level_cache_off + x); - // `2.., ..2` is for UV - let lvl = &mut *level_cache.index_mut((idx + 2.., ..2)); - lvl[0] = filter_level[2][r#ref][is_gmv]; - lvl[1] = filter_level[3][r#ref][is_gmv]; + // 2, 3 is for UV + lvl[x][2] = filter_level[2][r#ref][is_gmv]; + lvl[x][3] = filter_level[3][r#ref][is_gmv]; } level_cache_off += b4_stride; }