Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ pub trait FromPyObject<'a, 'py>: Sized {
impl<T> FromPyObjectSequence for NeverASequence<T> {
type Target = T;

fn to_vec(&self) -> Vec<Self::Target> {
fn to_vec(&self) -> PyResult<Vec<Self::Target>> {
unreachable!()
}

Expand Down Expand Up @@ -480,7 +480,7 @@ mod from_py_object_sequence {
pub trait FromPyObjectSequence {
type Target;

fn to_vec(&self) -> Vec<Self::Target>;
fn to_vec(&self) -> PyResult<Vec<Self::Target>>;

fn to_array<const N: usize>(&self) -> PyResult<[Self::Target; N]>;
}
Expand Down
90 changes: 83 additions & 7 deletions src/conversions/std/num.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
use crate::buffer::PyBuffer;
use crate::conversion::private::Reference;
use crate::conversion::{FromPyObjectSequence, IntoPyObject};
use crate::ffi_ptr_ext::FfiPtrExt;
Expand All @@ -6,10 +8,14 @@ use crate::inspect::types::TypeInfo;
#[cfg(feature = "experimental-inspect")]
use crate::inspect::PyStaticExpr;
use crate::py_result_ext::PyResultExt;
#[cfg(feature = "experimental-inspect")]
use crate::type_object::PyTypeInfo;
use crate::types::{PyByteArray, PyByteArrayMethods, PyBytes, PyInt};
use crate::{exceptions, ffi, Borrowed, Bound, FromPyObject, PyAny, PyErr, PyResult, Python};
use crate::types::sequence::PySequenceMethods;
use crate::types::{
any::PyAnyMethods, PyByteArray, PyByteArrayMethods, PyBytes, PyInt, PySequence,
};
use crate::{
exceptions, ffi, Borrowed, Bound, CastError, FromPyObject, PyAny, PyErr, PyResult, PyTypeInfo,
Python,
};
use std::convert::Infallible;
use std::ffi::c_long;
use std::mem::MaybeUninit;
Expand Down Expand Up @@ -317,6 +323,10 @@ impl<'py> FromPyObject<'_, 'py> for u8 {
} else if let Ok(byte_array) = obj.cast::<PyByteArray>() {
Some(BytesSequenceExtractor::ByteArray(byte_array))
} else {
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
if unsafe { ffi::PyObject_CheckBuffer(obj.as_ptr()) != 0 } {
return Some(BytesSequenceExtractor::Buffer(obj.to_any()));
}
None
}
}
Expand All @@ -325,6 +335,8 @@ impl<'py> FromPyObject<'_, 'py> for u8 {
pub(crate) enum BytesSequenceExtractor<'a, 'py> {
Bytes(Borrowed<'a, 'py, PyBytes>),
ByteArray(Borrowed<'a, 'py, PyByteArray>),
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
Buffer(Borrowed<'a, 'py, PyAny>),
}

impl BytesSequenceExtractor<'_, '_> {
Expand All @@ -348,18 +360,43 @@ impl BytesSequenceExtractor<'_, '_> {
copy_slice(unsafe { b.as_bytes() })
})
}
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
BytesSequenceExtractor::Buffer(any) => {
// Fall back to sequence semantics if the buffer is incompatible with u8
// (e.g., array('I')).
if let Ok(buf) = PyBuffer::<u8>::get(any) {
// Safety: we're about to write the entire `out` slice.
let target = unsafe {
std::slice::from_raw_parts_mut(out.as_mut_ptr().cast::<u8>(), out.len())
};
buf.copy_to_slice(any.py(), target)?;
Ok(())
} else {
fill_u8_slice_from_sequence(*any, out)
}
}
}
}
}

impl FromPyObjectSequence for BytesSequenceExtractor<'_, '_> {
type Target = u8;

fn to_vec(&self) -> Vec<Self::Target> {
match self {
fn to_vec(&self) -> PyResult<Vec<Self::Target>> {
Ok(match self {
BytesSequenceExtractor::Bytes(b) => b.as_bytes().to_vec(),
BytesSequenceExtractor::ByteArray(b) => b.to_vec(),
}
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
BytesSequenceExtractor::Buffer(any) => {
// Fall back to sequence semantics if the buffer is incompatible with u8
// (e.g., array('I')).
if let Ok(buf) = PyBuffer::<u8>::get(any) {
return buf.to_vec(any.py());
} else {
return extract_u8_vec_from_sequence(*any);
}
}
})
}

fn to_array<const N: usize>(&self) -> PyResult<[u8; N]> {
Expand All @@ -377,6 +414,45 @@ impl FromPyObjectSequence for BytesSequenceExtractor<'_, '_> {
}
}

fn extract_u8_vec_from_sequence<'a, 'py>(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Vec<u8>> {
// Types that pass `PySequence_Check` usually implement enough of the sequence protocol
// to support this function and if not, we will only fail extraction safely.
let seq = unsafe {
if ffi::PySequence_Check(obj.as_ptr()) != 0 {
obj.cast_unchecked::<PySequence>()
} else {
return Err(CastError::new(obj, PySequence::type_object(obj.py()).into_any()).into());
}
};

let mut v = Vec::with_capacity(seq.len().unwrap_or(0));
for item in seq.try_iter()? {
v.push(item?.extract::<u8>()?);
}
Ok(v)
}

fn fill_u8_slice_from_sequence<'a, 'py>(
obj: Borrowed<'a, 'py, PyAny>,
out: &mut [MaybeUninit<u8>],
) -> PyResult<()> {
let seq = unsafe {
if ffi::PySequence_Check(obj.as_ptr()) != 0 {
obj.cast_unchecked::<PySequence>()
} else {
return Err(CastError::new(obj, PySequence::type_object(obj.py()).into_any()).into());
}
};
let seq_len = seq.len()?;
if seq_len != out.len() {
return Err(invalid_sequence_length(out.len(), seq_len));
}
for (idx, item) in seq.try_iter()?.enumerate() {
out[idx].write(item?.extract::<u8>()?);
}
Ok(())
}

int_fits_c_long!(i8);
int_fits_c_long!(i16);
int_fits_c_long!(u16);
Expand Down
2 changes: 1 addition & 1 deletion src/conversions/std/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ where

fn extract(obj: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
if let Some(extractor) = T::sequence_extractor(obj, crate::conversion::private::Token) {
return Ok(extractor.to_vec());
return extractor.to_vec();
}

if obj.is_instance_of::<PyString>() {
Expand Down
45 changes: 44 additions & 1 deletion tests/test_buffer_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::buffer::PyBuffer;
use pyo3::exceptions::PyBufferError;
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use pyo3::types::{IntoPyDict, PyBytes, PyDict};
use std::ffi::CString;
use std::ffi::{c_int, c_void};
use std::ptr;
Expand All @@ -15,6 +15,11 @@ use std::sync::Arc;

mod test_utils;

#[pyfunction]
fn vec_u8_to_pybytes(py: Python<'_>, bytes: Vec<u8>) -> Bound<'_, PyBytes> {
PyBytes::new(py, &bytes)
}

#[pyclass]
struct TestBufferClass {
vec: Vec<u8>,
Expand Down Expand Up @@ -94,6 +99,44 @@ fn test_buffer_referenced() {
assert!(drop_called.load(Ordering::Relaxed));
}

#[test]
fn test_extract_vec_u8_from_buffer_exporter() {
let drop_called = Arc::new(AtomicBool::new(false));

Python::attach(|py| {
let instance = Py::new(
py,
TestBufferClass {
vec: vec![b'A', b'B', b'C'],
drop_called: drop_called.clone(),
},
)
.unwrap();
let f = wrap_pyfunction!(vec_u8_to_pybytes)(py).unwrap();
let env = PyDict::new(py);
env.set_item("ob", instance).unwrap();
env.set_item("f", f).unwrap();
py_assert!(py, *env, "f(ob) == b'ABC'");
});

assert!(drop_called.load(Ordering::Relaxed));
}

#[test]
fn test_extract_vec_u8_falls_back_when_buffer_incompatible() {
Python::attach(|py| {
let array_mod = py.import("array").unwrap();
let ob = array_mod
.call_method1("array", ("I", vec![65u32, 66u32, 67u32]))
.unwrap();
let f = wrap_pyfunction!(vec_u8_to_pybytes)(py).unwrap();
let env = PyDict::new(py);
env.set_item("ob", ob).unwrap();
env.set_item("f", f).unwrap();
py_assert!(py, *env, "f(ob) == b'ABC'");
});
}

#[test]
#[cfg(Py_3_8)] // sys.unraisablehook not available until Python 3.8
fn test_releasebuffer_unraisable_error() {
Expand Down
Loading