Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,15 @@ impl FnType {

#[derive(Clone, Debug)]
pub enum SelfType {
Receiver { mutable: bool, span: Span },
TryFromBoundRef(Span),
Receiver {
mutable: bool,
non_null: bool,
span: Span,
},
TryFromBoundRef {
span: Span,
non_null: bool,
},
}

#[derive(Clone, Copy)]
Expand Down Expand Up @@ -348,8 +355,18 @@ impl SelfType {
let slf = syn::Ident::new("_slf", Span::call_site());
let Ctx { pyo3_path, .. } = ctx;
match self {
SelfType::Receiver { span, mutable } => {
let arg = quote! { unsafe { #pyo3_path::impl_::extract_argument::cast_function_argument(#py, #slf) } };
SelfType::Receiver {
span,
mutable,
non_null,
} => {
let cast_fn = if *non_null {
quote!(cast_non_null_function_argument)
} else {
quote!(cast_function_argument)
};
let arg =
quote! { unsafe { #pyo3_path::impl_::extract_argument::#cast_fn(#py, #slf) } };
let method = if *mutable {
syn::Ident::new("extract_pyclass_ref_mut", *span)
} else {
Expand All @@ -367,8 +384,12 @@ impl SelfType {
ctx,
)
}
SelfType::TryFromBoundRef(span) => {
let bound_ref = quote! { unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf) } };
SelfType::TryFromBoundRef { span, non_null } => {
let bound_ref = if *non_null {
quote! { unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_non_null(#py, &#slf) } }
} else {
quote! { unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf) } }
};
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
error_mode.handle_error(
quote_spanned! { *span =>
Expand Down Expand Up @@ -427,7 +448,7 @@ pub struct FnSpec<'a> {
pub output: syn::ReturnType,
}

pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
pub fn parse_method_receiver(arg: &syn::FnArg, non_null: bool) -> Result<SelfType> {
match arg {
syn::FnArg::Receiver(
recv @ syn::Receiver {
Expand All @@ -439,12 +460,16 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
syn::FnArg::Receiver(recv @ syn::Receiver { mutability, .. }) => Ok(SelfType::Receiver {
mutable: mutability.is_some(),
span: recv.span(),
non_null,
}),
syn::FnArg::Typed(syn::PatType { ty, .. }) => {
if let syn::Type::ImplTrait(_) = &**ty {
bail_spanned!(ty.span() => IMPL_TRAIT_ERR);
}
Ok(SelfType::TryFromBoundRef(ty.span()))
Ok(SelfType::TryFromBoundRef {
span: ty.span(),
non_null,
})
}
}
}
Expand Down Expand Up @@ -515,14 +540,22 @@ impl<'a> FnSpec<'a> {
python_name: &mut Option<syn::Ident>,
) -> Result<FnType> {
let mut method_attributes = parse_method_attributes(meth_attrs)?;
let receiver_non_null = method_attributes.iter().any(|attr| {
matches!(
attr,
MethodTypeAttribute::Getter(_, _)
| MethodTypeAttribute::Setter(_, _)
| MethodTypeAttribute::Deleter(_, _)
)
});

let name = &sig.ident;
let parse_receiver = |msg: &'static str| {
let first_arg = sig
.inputs
.first()
.ok_or_else(|| err_spanned!(sig.span() => msg))?;
parse_method_receiver(first_arg)
parse_method_receiver(first_arg, receiver_non_null)
};

// strip get_ or set_
Expand Down
5 changes: 4 additions & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2133,7 +2133,10 @@ fn complex_enum_variant_field_getter(
let py = FnArg::parse(&mut arg)?;
let signature = FunctionSignature::from_arguments(vec![py]);

let self_type = crate::method::SelfType::TryFromBoundRef(field_span);
let self_type = crate::method::SelfType::TryFromBoundRef {
span: field_span,
non_null: true,
};

let spec = FnSpec {
tp: crate::method::FnType::Getter(self_type.clone()),
Expand Down
14 changes: 8 additions & 6 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,11 @@ fn impl_traverse_slot(
}

// check that the receiver does not try to smuggle an (implicit) `Python` token into here
if let FnType::Fn(SelfType::TryFromBoundRef(span))
if let FnType::Fn(SelfType::TryFromBoundRef { span, .. })
| FnType::Fn(SelfType::Receiver {
mutable: true,
span,
..
}) = spec.tp
{
bail_spanned! { span =>
Expand Down Expand Up @@ -608,6 +609,7 @@ pub fn impl_py_setter_def(
let slf = SelfType::Receiver {
mutable: true,
span: Span::call_site(),
non_null: true,
}
.receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx);
if let Some(ident) = &field.ident {
Expand Down Expand Up @@ -716,11 +718,11 @@ pub fn impl_py_setter_def(
#cfg_attrs
unsafe fn #wrapper_ident(
py: #pyo3_path::Python<'_>,
_slf: *mut #pyo3_path::ffi::PyObject,
_value: *mut #pyo3_path::ffi::PyObject,
_slf: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>,
_value: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>,
) -> #pyo3_path::PyResult<::std::ffi::c_int> {
use ::std::convert::Into;
let _value = #pyo3_path::impl_::extract_argument::cast_function_argument(py, _value);
let _value = #pyo3_path::impl_::extract_argument::cast_non_null_function_argument(py, _value);
#init_holders
#extract
#warnings
Expand Down Expand Up @@ -850,7 +852,7 @@ pub fn impl_py_getter_def(
#cfg_attrs
unsafe fn #wrapper_ident(
py: #pyo3_path::Python<'_>,
_slf: *mut #pyo3_path::ffi::PyObject
_slf: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
#init_holders
#warnings
Expand Down Expand Up @@ -897,7 +899,7 @@ pub fn impl_py_deleter_def(
let associated_method = quote! {
unsafe fn #wrapper_ident(
py: #pyo3_path::Python<'_>,
_slf: *mut #pyo3_path::ffi::PyObject,
_slf: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>,
) -> #pyo3_path::PyResult<::std::ffi::c_int> {
#init_holders
#warnings
Expand Down
24 changes: 11 additions & 13 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1340,7 +1340,7 @@ impl<ClassT: PyClass, FieldT, const OFFSET: usize, const IMPLEMENTS_INTOPYOBJECT
#[inline]
unsafe fn ensure_no_mutable_alias<'a, ClassT: PyClass>(
_py: Python<'_>,
obj: &'a *mut ffi::PyObject,
obj: &'a NonNull<ffi::PyObject>,
) -> Result<PyClassGuard<'a, ClassT>, PyBorrowError> {
unsafe { PyClassGuard::try_borrow(NonNull::from(obj).cast::<Py<ClassT>>().as_ref()) }
}
Expand All @@ -1352,7 +1352,7 @@ unsafe fn ensure_no_mutable_alias<'a, ClassT: PyClass>(
/// - there must be a value of type `FieldT` at the calculated offset within `ClassT`
unsafe fn pyo3_get_value_into_pyobject_ref<ClassT, FieldT, const OFFSET: usize>(
py: Python<'_>,
obj: *mut ffi::PyObject,
obj: NonNull<ffi::PyObject>,
) -> PyResult<*mut ffi::PyObject>
where
ClassT: PyClass,
Expand All @@ -1365,25 +1365,24 @@ where
/// - value of type `FieldT` must exist at the given offset within obj
unsafe fn inner<FieldT>(
py: Python<'_>,
obj: *const (),
obj: NonNull<()>,
offset: usize,
) -> PyResult<*mut ffi::PyObject>
where
for<'a, 'py> &'a FieldT: IntoPyObject<'py>,
{
// SAFETY: caller upholds safety invariants
let value = unsafe { &*obj.byte_add(offset).cast::<FieldT>() };
let value = unsafe { obj.byte_add(offset).cast::<FieldT>().as_ref() };
value.into_py_any(py).map(Py::into_ptr)
}

// SAFETY: `obj` is a valid pointer to `ClassT`
let _holder = unsafe { ensure_no_mutable_alias::<ClassT>(py, &obj)? };
let class_ptr = obj.cast::<<ClassT as PyClassImpl>::Layout>();
let class_obj = unsafe { &*class_ptr };
let contents_ptr = ptr::from_ref(class_obj.contents());
let class_obj = unsafe { class_ptr.as_ref() };

// SAFETY: _holder prevents mutable aliasing, caller upholds other safety invariants
unsafe { inner::<FieldT>(py, contents_ptr.cast(), OFFSET) }
unsafe { inner::<FieldT>(py, NonNull::from(class_obj.contents()).cast(), OFFSET) }
}

/// Gets a field value from a pyclass and produces a python value using `IntoPyObject` for `FieldT`,
Expand All @@ -1394,7 +1393,7 @@ where
/// - there must be a value of type `FieldT` at the calculated offset within `ClassT`
unsafe fn pyo3_get_value_into_pyobject<ClassT, FieldT, const OFFSET: usize>(
py: Python<'_>,
obj: *mut ffi::PyObject,
obj: NonNull<ffi::PyObject>,
) -> PyResult<*mut ffi::PyObject>
where
ClassT: PyClass,
Expand All @@ -1407,25 +1406,24 @@ where
/// - value of type `FieldT` must exist at the given offset within obj
unsafe fn inner<FieldT>(
py: Python<'_>,
obj: *const (),
obj: NonNull<()>,
offset: usize,
) -> PyResult<*mut ffi::PyObject>
where
for<'py> FieldT: IntoPyObject<'py> + Clone,
{
// SAFETY: caller upholds safety invariants
let value = unsafe { &*obj.byte_add(offset).cast::<FieldT>() };
let value = unsafe { obj.byte_add(offset).cast::<FieldT>().as_ref() };
value.clone().into_py_any(py).map(Py::into_ptr)
}

// SAFETY: `obj` is a valid pointer to `ClassT`
let _holder = unsafe { ensure_no_mutable_alias::<ClassT>(py, &obj)? };
let class_ptr = obj.cast::<<ClassT as PyClassImpl>::Layout>();
let class_obj = unsafe { &*class_ptr };
let contents_ptr = ptr::from_ref(class_obj.contents());
let class_obj = unsafe { class_ptr.as_ref() };

// SAFETY: _holder prevents mutable aliasing, caller upholds other safety invariants
unsafe { inner::<FieldT>(py, contents_ptr.cast(), OFFSET) }
unsafe { inner::<FieldT>(py, NonNull::from(class_obj.contents()).cast(), OFFSET) }
}

pub struct ConvertField<
Expand Down
12 changes: 8 additions & 4 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,14 @@ impl fmt::Debug for PyClassAttributeDef {

/// Class getter / setters
pub(crate) type Getter =
for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<*mut ffi::PyObject>;
pub(crate) type Setter =
for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject, *mut ffi::PyObject) -> PyResult<c_int>;
pub(crate) type Deleter = for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<c_int>;
for<'py> unsafe fn(Python<'py>, NonNull<ffi::PyObject>) -> PyResult<*mut ffi::PyObject>;
pub(crate) type Setter = for<'py> unsafe fn(
Python<'py>,
NonNull<ffi::PyObject>,
NonNull<ffi::PyObject>,
) -> PyResult<c_int>;
pub(crate) type Deleter =
for<'py> unsafe fn(Python<'py>, NonNull<ffi::PyObject>) -> PyResult<c_int>;

impl PyGetterDef {
/// Define a getter.
Expand Down
20 changes: 12 additions & 8 deletions src/pyclass/create_type_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ impl GetSetDefType {
slf: *mut ffi::PyObject,
closure: *mut c_void,
) -> *mut ffi::PyObject {
let slf = unsafe { NonNull::new_unchecked(slf) };
// Safety: PyO3 sets the closure when constructing the ffi getter so this cast should always be valid
let getter: Getter = unsafe { std::mem::transmute(closure) };
unsafe { trampoline(|py| getter(py, slf)) }
Expand All @@ -699,14 +700,15 @@ impl GetSetDefType {
value: *mut ffi::PyObject,
closure: *mut c_void,
) -> c_int {
let slf = unsafe { NonNull::new_unchecked(slf) };
// Safety: PyO3 sets the closure when constructing the ffi setter so this cast should always be valid
let setter: Setter = unsafe { std::mem::transmute(closure) };
unsafe {
trampoline(|py| {
if value.is_null() {
Err(PyAttributeError::new_err("property has no deleter"))
} else {
if let Some(value) = NonNull::new(value) {
setter(py, slf, value)
} else {
Err(PyAttributeError::new_err("property has no deleter"))
}
})
}
Expand All @@ -718,6 +720,7 @@ impl GetSetDefType {
slf: *mut ffi::PyObject,
closure: *mut c_void,
) -> *mut ffi::PyObject {
let slf = unsafe { NonNull::new_unchecked(slf) };
let getset: &GetSetDeleteCombination = unsafe { &*closure.cast() };
// we only call this method if getter is set
unsafe { trampoline(|py| getset.getter.unwrap_unchecked()(py, slf)) }
Expand All @@ -728,17 +731,18 @@ impl GetSetDefType {
value: *mut ffi::PyObject,
closure: *mut c_void,
) -> c_int {
let slf = unsafe { NonNull::new_unchecked(slf) };
let getset: &GetSetDeleteCombination = unsafe { &*closure.cast() };
unsafe {
trampoline(|py| {
if value.is_null() {
getset.deleter.ok_or_else(|| {
PyAttributeError::new_err("property has no deleter")
})?(py, slf)
} else {
if let Some(value) = NonNull::new(value) {
getset.setter.ok_or_else(|| {
PyAttributeError::new_err("property has no setter")
})?(py, slf, value)
} else {
getset.deleter.ok_or_else(|| {
PyAttributeError::new_err("property has no deleter")
})?(py, slf)
}
})
}
Expand Down
Loading