Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/align.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ macro_rules! impl_ArrayDefault {
// We want this to be implemented for all `T: Default` where `T` is not `[_; _]`,
// but we can't do that, so we can just add individual
// `impl`s here for types we need it for.
impl_ArrayDefault!(bool);
impl_ArrayDefault!(u8);
impl_ArrayDefault!(i8);
impl_ArrayDefault!(i16);
Expand Down
6 changes: 3 additions & 3 deletions src/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::align::{Align16, Align32, Align4, Align8};
use crate::error::Rav1dResult;
use crate::include::dav1d::headers::{Rav1dFilterMode, Rav1dFrameHeader};
use crate::levels::{
BlockLevel, BlockPartition, BlockSize, MVJoint, SegmentId, TxfmSize, N_COMP_INTER_PRED_MODES,
BlockLevel, BlockPartition, BlockSize, CompInterPredMode, MVJoint, SegmentId, TxfmSize,
N_INTRA_PRED_MODES, N_UV_INTRA_PRED_MODES,
};
use crate::tables::DAV1D_PARTITION_TYPE_COUNT;
Expand Down Expand Up @@ -106,7 +106,7 @@ pub struct CdfModeInterContext {
// inter/switch
pub y_mode: Align32<[[u16; N_INTRA_PRED_MODES + 3]; 4]>,
pub wedge_idx: Align32<[[u16; 16]; 9]>,
pub comp_inter_mode: Align16<[[u16; N_COMP_INTER_PRED_MODES]; 8]>,
pub comp_inter_mode: Align16<[[u16; CompInterPredMode::COUNT]; 8]>,
pub filter: Align8<[[[u16; Rav1dFilterMode::N_FILTERS]; 8]; 2]>,
pub interintra_mode: Align8<[[u16; 4]; 4]>,
pub motion_mode: Align8<[[u16; 3 + 1]; BlockSize::COUNT]>,
Expand Down Expand Up @@ -5039,7 +5039,7 @@ pub(crate) fn rav1d_cdf_thread_update(

update_cdf_2d!(4, N_INTRA_PRED_MODES - 1, mi.y_mode);
update_cdf_2d!(9, 15, mi.wedge_idx);
update_cdf_2d!(8, N_COMP_INTER_PRED_MODES - 1, mi.comp_inter_mode);
update_cdf_2d!(8, CompInterPredMode::COUNT - 1, mi.comp_inter_mode);
update_cdf_3d!(
2,
8,
Expand Down
343 changes: 187 additions & 156 deletions src/decode.rs

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/enum_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ impl<T> DefaultValue for Option<T> {
const DEFAULT: Self = None;
}

impl<T: DefaultValue, const N: usize> DefaultValue for [T; N] {
const DEFAULT: Self = [T::DEFAULT; N];
}

/// A map from an `enum` key `K` to `V`s.
/// `N` is the number of possible `enum` values.
pub struct EnumMap<K, V, const N: usize>
Expand Down
4 changes: 2 additions & 2 deletions src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub struct BlockContext {
pub lcoef: DisjointMut<Align8<[u8; 32]>>,
pub ccoef: [DisjointMut<Align8<[u8; 32]>>; 2],
pub seg_pred: DisjointMut<Align8<[u8; 32]>>,
pub skip: DisjointMut<Align8<[u8; 32]>>,
pub skip_mode: DisjointMut<Align8<[u8; 32]>>,
pub skip: DisjointMut<Align8<[bool; 32]>>,
pub skip_mode: DisjointMut<Align8<[bool; 32]>>,
pub intra: DisjointMut<Align8<[u8; 32]>>,
pub comp_type: DisjointMut<Align8<[Option<CompInterType>; 32]>>,
pub r#ref: [DisjointMut<Align8<[i8; 32]>>; 2],
Expand Down
31 changes: 15 additions & 16 deletions src/in_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::hint::assert_unchecked;

use zerocopy::{AsBytes, FromBytes, FromZeroes};

use crate::const_fn::const_for;
use crate::enum_map::DefaultValue;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct InRange<T, const MIN: u128, const MAX: u128>(T);
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, FromZeroes, FromBytes, AsBytes)]
#[repr(transparent)]
pub struct InRange<T, const MIN: i128, const MAX: i128>(T);

impl<T, const MIN: u128, const MAX: u128> InRange<T, MIN, MAX>
impl<T, const MIN: i128, const MAX: i128> InRange<T, MIN, MAX>
where
T: TryFrom<u128, Error: Debug>,
T: TryFrom<i128, Error: Debug>,
{
pub fn min() -> Self {
Self(MIN.try_into().unwrap())
Expand All @@ -21,9 +24,9 @@ where
}
}

impl<T, const MIN: u128, const MAX: u128> InRange<T, MIN, MAX>
impl<T, const MIN: i128, const MAX: i128> InRange<T, MIN, MAX>
where
T: TryFrom<u128, Error: Debug> + PartialEq + Eq + PartialOrd + Ord,
T: TryFrom<i128, Error: Debug> + PartialEq + Eq + PartialOrd + Ord,
{
fn in_bounds(&self) -> bool {
*self >= Self::min() && *self <= Self::max()
Expand All @@ -45,16 +48,16 @@ where
}
}

impl<T, const MIN: u128, const MAX: u128> Default for InRange<T, MIN, MAX>
impl<T, const MIN: i128, const MAX: i128> Default for InRange<T, MIN, MAX>
where
T: TryFrom<u128, Error: Debug>,
T: TryFrom<i128, Error: Debug>,
{
fn default() -> Self {
Self::min()
}
}

impl<T, const MIN: u128, const MAX: u128> Display for InRange<T, MIN, MAX>
impl<T, const MIN: i128, const MAX: i128> Display for InRange<T, MIN, MAX>
where
T: Display,
{
Expand All @@ -65,14 +68,14 @@ where

macro_rules! impl_const_new {
($T:ty) => {
impl<const MIN: u128, const MAX: u128> DefaultValue for InRange<$T, MIN, MAX> {
impl<const MIN: i128, const MAX: i128> DefaultValue for InRange<$T, MIN, MAX> {
const DEFAULT: Self = Self(0);
}

impl<const MIN: u128, const MAX: u128> InRange<$T, MIN, MAX> {
impl<const MIN: i128, const MAX: i128> InRange<$T, MIN, MAX> {
#[allow(unused)]
pub const fn const_new(value: $T) -> Self {
assert!(value as u128 >= MIN && value as u128 <= MAX);
assert!(value as i128 >= MIN && value as i128 <= MAX);
Self(value)
}

Expand All @@ -85,10 +88,6 @@ macro_rules! impl_const_new {
b
}
}

impl<const N: usize, const MIN: u128, const MAX: u128> DefaultValue for [InRange<$T, MIN, MAX>; N] {
const DEFAULT: Self = [DefaultValue::DEFAULT; N];
}
};
}

Expand Down
108 changes: 81 additions & 27 deletions src/levels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use strum::{EnumCount, FromRepr};
use zerocopy::{AsBytes, FromBytes, FromZeroes};

use crate::align::ArrayDefault;
use crate::enum_map::{DefaultValue, EnumKey};
use crate::enum_map::{enum_map, DefaultValue, EnumKey};
use crate::in_range::InRange;
use crate::include::dav1d::headers::Rav1dFilterMode;

Expand Down Expand Up @@ -206,9 +206,10 @@ impl BlockPartition {
pub const N_SUB8X8_PARTITIONS: usize = 4;
}

#[derive(Clone, Copy, PartialEq, Eq, FromRepr, EnumCount, FromZeroes, Default)]
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, FromRepr, EnumCount, FromZeroes)]
pub enum BlockSize {
#[default]
Bs128x128 = 0,
Bs128x64 = 1,
Bs64x128 = 2,
Expand Down Expand Up @@ -302,12 +303,18 @@ bitflags! {
}
}

pub type InterPredMode = u8;
pub const _N_INTER_PRED_MODES: usize = 4;
pub const NEWMV: InterPredMode = 3;
pub const GLOBALMV: InterPredMode = 2;
pub const NEARMV: InterPredMode = 1;
pub const NEARESTMV: InterPredMode = 0;
#[expect(clippy::enum_variant_names, reason = "match dav1d naming")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterPredMode {
NearestMv = 0,
NearMv = 1,
GlobalMv = 2,
NewMv = 3,
}

impl DefaultValue for InterPredMode {
const DEFAULT: Self = Self::NearestMv;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum DrlProximity {
Expand All @@ -318,16 +325,59 @@ pub enum DrlProximity {
Nearish,
}

pub type CompInterPredMode = u8;
pub const N_COMP_INTER_PRED_MODES: usize = 8;
pub const NEWMV_NEWMV: CompInterPredMode = 7;
pub const GLOBALMV_GLOBALMV: CompInterPredMode = 6;
pub const NEWMV_NEARMV: CompInterPredMode = 5;
pub const NEARMV_NEWMV: CompInterPredMode = 4;
pub const NEWMV_NEARESTMV: CompInterPredMode = 3;
pub const NEARESTMV_NEWMV: CompInterPredMode = 2;
pub const NEARMV_NEARMV: CompInterPredMode = 1;
pub const NEARESTMV_NEARESTMV: CompInterPredMode = 0;
/// Sometimes this can store a [`InterPredMode`] instead, which is smaller.
#[expect(clippy::enum_variant_names, reason = "match dav1d naming")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromRepr, EnumCount, Default)]
pub enum CompInterPredMode {
#[default]
NearestMvNearestMv = 0,
NearMvNearMv = 1,
NearestMvNewMv = 2,
NewMvNearestMv = 3,
NearMvNewMv = 4,
NewMvNearMv = 5,
GlobalMvGlobalMv = 6,
NewMvNewMv = 7,
}

impl From<InterPredMode> for CompInterPredMode {
fn from(value: InterPredMode) -> Self {
CompInterPredMode::from_repr(value as usize).unwrap()
}
}

impl EnumKey<{ Self::COUNT }> for CompInterPredMode {
const VALUES: [Self; Self::COUNT] = [
Self::NearestMvNearestMv,
Self::NearMvNearMv,
Self::NearestMvNewMv,
Self::NewMvNearestMv,
Self::NearMvNewMv,
Self::NewMvNearMv,
Self::GlobalMvGlobalMv,
Self::NewMvNewMv,
];

fn as_usize(self) -> usize {
self as usize
}
}

impl CompInterPredMode {
pub fn split(self) -> [InterPredMode; 2] {
use InterPredMode::*;
enum_map!(CompInterPredMode => [InterPredMode; 2]; match key {
NearestMvNearestMv => [NearestMv, NearestMv],
NearMvNearMv => [NearMv, NearMv],
NearestMvNewMv => [NearestMv, NewMv],
NewMvNearestMv => [NewMv, NearestMv],
NearMvNewMv => [NearMv, NewMv],
NewMvNearMv => [NewMv, NearMv],
GlobalMvGlobalMv => [GlobalMv, GlobalMv],
NewMvNewMv => [NewMv, NewMv],
})[self]
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompInterType {
Expand Down Expand Up @@ -484,11 +534,13 @@ impl From<MaskedInterIntraPredMode> for InterIntraPredMode {
}
}

pub type WedgeIdx = InRange<u8, 0, 15>;

#[derive(Clone, Default, FromZeroes, FromBytes, AsBytes)]
#[repr(C)]
pub struct Av1BlockInter1d {
pub mv: [Mv; 2],
pub wedge_idx: u8,
pub wedge_idx: WedgeIdx,

/// Stored as a [`u8`] since [`bool`] is not [`FromBytes`].
pub mask_sign: u8,
Expand All @@ -515,7 +567,7 @@ pub struct Av1BlockInter2d {
pub matrix: [i16; 4],
}

#[derive(Clone)]
#[derive(Clone, Default)]
#[repr(C)]
pub struct Av1BlockInterNd {
/// Make [`Av1BlockInter1d`] the field instead of [`Av1BlockInter2d`]
Expand Down Expand Up @@ -552,15 +604,17 @@ impl From<Av1BlockInter2d> for Av1BlockInterNd {
}
}

#[derive(Clone)]
pub type Av1BlockInterRefIndex = InRange<i8, -1, 6>;

#[derive(Clone, Default)]
#[repr(C)]
pub struct Av1BlockInter {
pub nd: Av1BlockInterNd,
pub comp_type: Option<CompInterType>,
pub inter_mode: u8,
pub inter_mode: CompInterPredMode,
pub motion_mode: MotionMode,
pub drl_idx: DrlProximity,
pub r#ref: [i8; 2],
pub r#ref: [Av1BlockInterRefIndex; 2],
pub max_ytx: TxfmSize,
pub filter2d: Filter2d,
pub interintra_type: Option<InterIntraType>,
Expand Down Expand Up @@ -594,7 +648,7 @@ impl Default for Av1BlockIntraInter {
/// Within range `0..`[`SegmentId::COUNT`].
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct SegmentId {
id: InRange<u8, 0, { Self::COUNT as u128 - 1 }>,
id: InRange<u8, 0, { Self::COUNT as i128 - 1 }>,
}

impl SegmentId {
Expand Down Expand Up @@ -629,11 +683,11 @@ impl Display for SegmentId {
#[repr(C)]
pub struct Av1Block {
pub bl: BlockLevel,
pub bs: u8,
pub bs: BlockSize,
pub bp: BlockPartition,
pub seg_id: SegmentId,
pub skip_mode: u8,
pub skip: u8,
pub skip_mode: bool,
pub skip: bool,
pub uvtx: TxfmSize,
pub ii: Av1BlockIntraInter,
}
Loading
Loading