diff --git a/rust/map.rs b/rust/map.rs index 4f114962e5a6b..84c94007960cd 100644 --- a/rust/map.rs +++ b/rust/map.rs @@ -6,10 +6,10 @@ // https://developers.google.com/open-source/licenses/bsd use crate::{ - AsMut, AsView, IntoMut, IntoProxied, IntoView, Message, Mut, MutProxied, Proxied, Singular, - View, __internal::runtime::{InnerMap, InnerMapMut, RawMap, RawMapIter}, __internal::{Private, SealedInternal}, + AsMut, AsView, IntoMut, IntoProxied, IntoView, Message, Mut, MutProxied, Proxied, Singular, + View, }; use std::marker::PhantomData; @@ -225,6 +225,13 @@ impl<'msg, K: MapKey, V: MapValue> MapView<'msg, K, V> { self.len() == 0 } + pub fn contains_key<'a>(self, key: impl Into>) -> bool + where + K: 'a, + { + self.get(key).is_some() + } + /// Returns an iterator visiting all key-value pairs in arbitrary order. /// /// The iterator element type is `(View, View)`. @@ -362,6 +369,13 @@ impl<'msg, K: MapKey, V: MapValue> MapMut<'msg, K, V> { self.len() == 0 } + pub fn contains_key<'a>(&self, key: impl Into>) -> bool + where + K: 'a, + { + self.as_view().contains_key(key) + } + /// Adds a key-value pair to the map. /// /// Returns `true` if the entry was newly inserted. @@ -730,4 +744,15 @@ mod tests { assert_that!(format!("{:?}", map.as_view()), eq("MapView(\"i32\", \"f64\")")); assert_that!(format!("{:?}", map.as_mut()), eq("MapMut(\"i32\", \"f64\")")); } + #[gtest] + fn test_contains_key() { + let mut map = Map::::new(); + let mut map_mut = map.as_mut(); + assert!(!map_mut.contains_key(42)); + assert!(!map_mut.as_view().contains_key(42)); + + map_mut.insert(42, 1.0); + assert!(map_mut.contains_key(42)); + assert!(map_mut.as_view().contains_key(42)); + } } diff --git a/rust/repeated.rs b/rust/repeated.rs index f71493130b3fe..91beb1410c8fa 100644 --- a/rust/repeated.rs +++ b/rust/repeated.rs @@ -204,6 +204,24 @@ impl<'msg, T: Singular> RepeatedView<'msg, T> { pub fn iter(self) -> RepeatedIter<'msg, T> { self.into_iter() } + + /// Returns the first element of the repeated field, or `None` if it is empty. + #[inline] + pub fn first(self) -> Option> { + self.get(0) + } + + /// Returns the last element of the repeated field, or `None` if it is empty. + #[inline] + pub fn last(self) -> Option> { + let len = self.len(); + if len == 0 { + None + } else { + // SAFETY: len is not 0, so len - 1 is a valid index + unsafe { Some(self.get_unchecked(len - 1)) } + } + } } impl<'msg, T: Singular> AsView for RepeatedView<'msg, T> { @@ -380,6 +398,18 @@ impl<'msg, T: Singular> RepeatedMut<'msg, T> { pub fn clear(&mut self) { T::repeated_clear(Private, self.as_mut()) } + + /// Returns the first element of the repeated field, or `None` if it is empty. + #[inline] + pub fn first(&self) -> Option> { + self.as_view().first() + } + + /// Returns the last element of the repeated field, or `None` if it is empty. + #[inline] + pub fn last(&self) -> Option> { + self.as_view().last() + } } impl<'msg, T: Singular> IntoProxied> for RepeatedMut<'msg, T> { @@ -441,11 +471,23 @@ where } } +impl std::iter::FromIterator for Repeated +where + U: IntoProxied, +{ + fn from_iter>(iter: I) -> Self { + let mut repeated = Repeated::new(); + repeated.as_mut().extend(iter); + repeated + } +} + /// An iterator over the values inside of a [`View>`](RepeatedView). #[derive(Clone)] pub struct RepeatedIter<'msg, T> { view: RepeatedView<'msg, T>, current_index: usize, + end_index: usize, } impl<'msg, T> Debug for RepeatedIter<'msg, T> { @@ -462,11 +504,13 @@ impl<'msg, T: Singular> iter::Iterator for RepeatedIter<'msg, T> { #[inline] fn next(&mut self) -> Option { - let val = self.view.get(self.current_index); - if val.is_some() { - self.current_index += 1; + if self.current_index >= self.end_index { + return None; } - val + // SAFETY: current_index < end_index <= view.len() + let val = unsafe { self.view.get_unchecked(self.current_index) }; + self.current_index += 1; + Some(val) } fn size_hint(&self) -> (usize, Option) { @@ -477,11 +521,22 @@ impl<'msg, T: Singular> iter::Iterator for RepeatedIter<'msg, T> { impl<'msg, T: Singular> ExactSizeIterator for RepeatedIter<'msg, T> { fn len(&self) -> usize { - self.view.len() - self.current_index + self.end_index - self.current_index + } +} + +impl<'msg, T: Singular> iter::DoubleEndedIterator for RepeatedIter<'msg, T> { + #[inline] + fn next_back(&mut self) -> Option { + if self.current_index >= self.end_index { + return None; + } + self.end_index -= 1; + // SAFETY: current_index <= end_index < view.len() + Some(unsafe { self.view.get_unchecked(self.end_index) }) } } -// TODO: impl DoubleEndedIterator impl<'msg, T: Singular> FusedIterator for RepeatedIter<'msg, T> {} impl<'msg, T: Singular> iter::IntoIterator for RepeatedView<'msg, T> { @@ -489,7 +544,7 @@ impl<'msg, T: Singular> iter::IntoIterator for RepeatedView<'msg, T> { type IntoIter = RepeatedIter<'msg, T>; fn into_iter(self) -> Self::IntoIter { - RepeatedIter { view: self, current_index: 0 } + RepeatedIter { view: self, current_index: 0, end_index: self.len() } } } @@ -498,7 +553,7 @@ impl<'msg, T: Singular> iter::IntoIterator for &'_ RepeatedView<'msg, T> { type IntoIter = RepeatedIter<'msg, T>; fn into_iter(self) -> Self::IntoIter { - RepeatedIter { view: *self, current_index: 0 } + RepeatedIter { view: *self, current_index: 0, end_index: self.len() } } } @@ -507,7 +562,8 @@ impl<'borrow, T: Singular> iter::IntoIterator for &'borrow RepeatedMut<'_, T> { type IntoIter = RepeatedIter<'borrow, T>; fn into_iter(self) -> Self::IntoIter { - RepeatedIter { view: self.as_view(), current_index: 0 } + let view = self.as_view(); + RepeatedIter { view, current_index: 0, end_index: view.len() } } } @@ -515,7 +571,7 @@ impl<'borrow, T: Singular> iter::IntoIterator for &'borrow RepeatedMut<'_, T> { pub struct RepeatedMutIter<'msg, T> { inner: InnerRepeatedMut<'msg>, current_index: usize, - len: usize, + end_index: usize, _phantom: PhantomData<&'msg mut T>, } @@ -524,7 +580,7 @@ impl<'msg, T: Message> iter::Iterator for RepeatedMutIter<'msg, T> { #[inline] fn next(&mut self) -> Option { - if self.current_index >= self.len { + if self.current_index >= self.end_index { return None; } let index = self.current_index; @@ -547,7 +603,26 @@ impl<'msg, T: Message> iter::Iterator for RepeatedMutIter<'msg, T> { impl<'msg, T: Message> ExactSizeIterator for RepeatedMutIter<'msg, T> { fn len(&self) -> usize { - self.len - self.current_index + self.end_index - self.current_index + } +} + +impl<'msg, T: Message> iter::DoubleEndedIterator for RepeatedMutIter<'msg, T> { + #[inline] + fn next_back(&mut self) -> Option { + if self.current_index >= self.end_index { + return None; + } + self.end_index -= 1; + let index = self.end_index; + + // SAFETY: index is valid. + let val = unsafe { + let temp_repeated = RepeatedMut::from_inner(Private, self.inner); + T::repeated_get_mut_unchecked(Private, temp_repeated, index) + }; + + Some(val) } } @@ -557,12 +632,36 @@ impl<'msg, T: Message> RepeatedMut<'msg, T> { /// Returns an iterator that allows modifying each value. pub fn iter_mut(self) -> RepeatedMutIter<'msg, T> { RepeatedMutIter { - len: self.len(), + end_index: self.len(), inner: self.inner, current_index: 0, _phantom: PhantomData, } } + + /// Returns a mutable reference to the first element of the repeated field, or `None` if it is empty. + #[inline] + pub fn first_mut<'r>(&'r mut self) -> Option> + where + 'r: 'msg, + { + self.get_mut(0) + } + + /// Returns a mutable reference to the last element of the repeated field, or `None` if it is empty. + #[inline] + pub fn last_mut<'r>(&'r mut self) -> Option> + where + 'r: 'msg, + { + let len = self.len(); + if len == 0 { + None + } else { + // SAFETY: len is not 0, so len - 1 is valid + Some(unsafe { self.get_mut_unchecked(len - 1) }) + } + } } #[cfg(test)] @@ -639,4 +738,40 @@ mod tests { assert_that!(clone.next(), eq(Some(1))); assert_that!(iter.next(), eq(Some(1))); } + #[gtest] + fn test_first_last() { + let mut r = Repeated::::new(); + assert!(r.as_view().first().is_none()); + assert!(r.as_view().last().is_none()); + assert!(r.as_mut().first().is_none()); + assert!(r.as_mut().last().is_none()); + + let mut m = r.as_mut(); + m.push(10); + assert_that!(m.first(), eq(Some(10))); + assert_that!(m.last(), eq(Some(10))); + assert_that!(m.as_view().first(), eq(Some(10))); + assert_that!(m.as_view().last(), eq(Some(10))); + + m.push(20); + assert_that!(m.first(), eq(Some(10))); + assert_that!(m.last(), eq(Some(20))); + } + + #[gtest] + fn test_double_ended_iter() { + let r: Repeated = [1, 2, 3].into_iter().collect(); + let mut iter = r.as_view().into_iter(); + assert_that!(iter.next(), eq(Some(1))); + assert_that!(iter.next_back(), eq(Some(3))); + assert_that!(iter.next(), eq(Some(2))); + assert_that!(iter.next(), eq(None)); + assert_that!(iter.next_back(), eq(None)); + } + + #[gtest] + fn test_from_iter() { + let r: Repeated = [10, 20, 30].into_iter().collect(); + assert_that!(r.as_view(), elements_are![eq(10), eq(20), eq(30)]); + } } diff --git a/rust/test/shared/accessors_map_test.rs b/rust/test/shared/accessors_map_test.rs index 3fc847f1c4bf5..e25692b4d1885 100644 --- a/rust/test/shared/accessors_map_test.rs +++ b/rust/test/shared/accessors_map_test.rs @@ -427,3 +427,13 @@ generate_map_with_msg_values_tests!( (bool, true, false), (string, "foo", "bar"), ); + +#[gtest] +fn test_map_contains_key() { + let mut msg = TestMap::new(); + assert_that!(msg.map_int32_int32().contains_key(42), eq(false)); + + msg.map_int32_int32_mut().insert(42, 100); + assert_that!(msg.map_int32_int32().contains_key(42), eq(true)); + assert_that!(msg.map_int32_int32_mut().contains_key(42), eq(true)); +} diff --git a/rust/test/shared/accessors_repeated_test.rs b/rust/test/shared/accessors_repeated_test.rs index 50326cadbc669..2386ed9d5e5c1 100644 --- a/rust/test/shared/accessors_repeated_test.rs +++ b/rust/test/shared/accessors_repeated_test.rs @@ -378,3 +378,59 @@ fn test_repeated_bytes() { older_msg.repeated_bytes_mut().clear(); assert_that!(older_msg.repeated_bytes(), is_empty()); } + +#[gtest] +fn test_repeated_numeric_first_last() { + let mut msg = TestAllTypes::new(); + let mut mutator = msg.repeated_int32_mut(); + assert_that!(mutator.first(), none()); + assert_that!(mutator.last(), none()); + + mutator.push(42); + assert_that!(mutator.first(), some(eq(42))); + assert_that!(mutator.last(), some(eq(42))); + + mutator.push(100); + assert_that!(mutator.first(), some(eq(42))); + assert_that!(mutator.last(), some(eq(100))); +} + +#[gtest] +fn test_repeated_message_first_last_mut() { + let mut msg = TestAllTypes::new(); + assert!(msg.repeated_nested_message_mut().first_mut().is_none()); + assert!(msg.repeated_nested_message_mut().last_mut().is_none()); + + msg.repeated_nested_message_mut().push_default().set_bb(10); + assert_that!(msg.repeated_nested_message_mut().first_mut().unwrap().bb(), eq(10)); + assert_that!(msg.repeated_nested_message_mut().last_mut().unwrap().bb(), eq(10)); + + msg.repeated_nested_message_mut().push_default().set_bb(20); + assert_that!(msg.repeated_nested_message_mut().first_mut().unwrap().bb(), eq(10)); + assert_that!(msg.repeated_nested_message_mut().last_mut().unwrap().bb(), eq(20)); + + msg.repeated_nested_message_mut().last_mut().unwrap().set_bb(30); + assert_that!(msg.repeated_nested_message().last().unwrap().bb(), eq(30)); +} + +#[gtest] +fn test_repeated_double_ended_iter() { + let mut msg = TestAllTypes::new(); + let mut mutator = msg.repeated_int32_mut(); + mutator.push(1); + mutator.push(2); + mutator.push(3); + + let mut iter = mutator.as_view().into_iter(); + assert_that!(iter.next(), some(eq(1))); + assert_that!(iter.next_back(), some(eq(3))); + assert_that!(iter.next(), some(eq(2))); + assert_that!(iter.next(), none()); + assert_that!(iter.next_back(), none()); +} + +#[gtest] +fn test_repeated_from_iter() { + let r: Repeated = [10, 20, 30].into_iter().collect(); + assert_that!(r.as_view(), elements_are![eq(10), eq(20), eq(30)]); +}