diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 44a0ad611df..12a49def29b 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -309,8 +309,15 @@ impl FnType { #[derive(Clone, Debug)] pub enum SelfType { - Receiver { mutable: bool, span: Span }, - TryFromBoundRef(Span), + Receiver { + mutable: bool, + non_null: bool, + span: Span, + }, + TryFromBoundRef { + span: Span, + non_null: bool, + }, } #[derive(Clone, Copy)] @@ -348,8 +355,18 @@ impl SelfType { let slf = syn::Ident::new("_slf", Span::call_site()); let Ctx { pyo3_path, .. } = ctx; match self { - SelfType::Receiver { span, mutable } => { - let arg = quote! { unsafe { #pyo3_path::impl_::extract_argument::cast_function_argument(#py, #slf) } }; + SelfType::Receiver { + span, + mutable, + non_null, + } => { + let cast_fn = if *non_null { + quote!(cast_non_null_function_argument) + } else { + quote!(cast_function_argument) + }; + let arg = + quote! { unsafe { #pyo3_path::impl_::extract_argument::#cast_fn(#py, #slf) } }; let method = if *mutable { syn::Ident::new("extract_pyclass_ref_mut", *span) } else { @@ -367,8 +384,12 @@ impl SelfType { ctx, ) } - SelfType::TryFromBoundRef(span) => { - let bound_ref = quote! { unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf) } }; + SelfType::TryFromBoundRef { span, non_null } => { + let bound_ref = if *non_null { + quote! { unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_non_null(#py, &#slf) } } + } else { + quote! { unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf) } } + }; let pyo3_path = pyo3_path.to_tokens_spanned(*span); error_mode.handle_error( quote_spanned! { *span => @@ -427,7 +448,7 @@ pub struct FnSpec<'a> { pub output: syn::ReturnType, } -pub fn parse_method_receiver(arg: &syn::FnArg) -> Result { +pub fn parse_method_receiver(arg: &syn::FnArg, non_null: bool) -> Result { match arg { syn::FnArg::Receiver( recv @ syn::Receiver { @@ -439,12 +460,16 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> Result { syn::FnArg::Receiver(recv @ syn::Receiver { mutability, .. }) => Ok(SelfType::Receiver { mutable: mutability.is_some(), span: recv.span(), + non_null, }), syn::FnArg::Typed(syn::PatType { ty, .. }) => { if let syn::Type::ImplTrait(_) = &**ty { bail_spanned!(ty.span() => IMPL_TRAIT_ERR); } - Ok(SelfType::TryFromBoundRef(ty.span())) + Ok(SelfType::TryFromBoundRef { + span: ty.span(), + non_null, + }) } } } @@ -515,6 +540,14 @@ impl<'a> FnSpec<'a> { python_name: &mut Option, ) -> Result { let mut method_attributes = parse_method_attributes(meth_attrs)?; + let receiver_non_null = method_attributes.iter().any(|attr| { + matches!( + attr, + MethodTypeAttribute::Getter(_, _) + | MethodTypeAttribute::Setter(_, _) + | MethodTypeAttribute::Deleter(_, _) + ) + }); let name = &sig.ident; let parse_receiver = |msg: &'static str| { @@ -522,7 +555,7 @@ impl<'a> FnSpec<'a> { .inputs .first() .ok_or_else(|| err_spanned!(sig.span() => msg))?; - parse_method_receiver(first_arg) + parse_method_receiver(first_arg, receiver_non_null) }; // strip get_ or set_ diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 841b0ee0316..0e8e4d5194b 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -2133,7 +2133,10 @@ fn complex_enum_variant_field_getter( let py = FnArg::parse(&mut arg)?; let signature = FunctionSignature::from_arguments(vec![py]); - let self_type = crate::method::SelfType::TryFromBoundRef(field_span); + let self_type = crate::method::SelfType::TryFromBoundRef { + span: field_span, + non_null: true, + }; let spec = FnSpec { tp: crate::method::FnType::Getter(self_type.clone()), diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 6e64903e698..977ddf420eb 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -422,10 +422,11 @@ fn impl_traverse_slot( } // check that the receiver does not try to smuggle an (implicit) `Python` token into here - if let FnType::Fn(SelfType::TryFromBoundRef(span)) + if let FnType::Fn(SelfType::TryFromBoundRef { span, .. }) | FnType::Fn(SelfType::Receiver { mutable: true, span, + .. }) = spec.tp { bail_spanned! { span => @@ -608,6 +609,7 @@ pub fn impl_py_setter_def( let slf = SelfType::Receiver { mutable: true, span: Span::call_site(), + non_null: true, } .receiver(cls, ExtractErrorMode::Raise, &mut holders, ctx); if let Some(ident) = &field.ident { @@ -716,11 +718,11 @@ pub fn impl_py_setter_def( #cfg_attrs unsafe fn #wrapper_ident( py: #pyo3_path::Python<'_>, - _slf: *mut #pyo3_path::ffi::PyObject, - _value: *mut #pyo3_path::ffi::PyObject, + _slf: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>, + _value: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>, ) -> #pyo3_path::PyResult<::std::ffi::c_int> { use ::std::convert::Into; - let _value = #pyo3_path::impl_::extract_argument::cast_function_argument(py, _value); + let _value = #pyo3_path::impl_::extract_argument::cast_non_null_function_argument(py, _value); #init_holders #extract #warnings @@ -850,7 +852,7 @@ pub fn impl_py_getter_def( #cfg_attrs unsafe fn #wrapper_ident( py: #pyo3_path::Python<'_>, - _slf: *mut #pyo3_path::ffi::PyObject + _slf: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject> ) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> { #init_holders #warnings @@ -897,7 +899,7 @@ pub fn impl_py_deleter_def( let associated_method = quote! { unsafe fn #wrapper_ident( py: #pyo3_path::Python<'_>, - _slf: *mut #pyo3_path::ffi::PyObject, + _slf: ::std::ptr::NonNull<#pyo3_path::ffi::PyObject>, ) -> #pyo3_path::PyResult<::std::ffi::c_int> { #init_holders #warnings diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index 396c079ac0f..4fcc3bad5f9 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -1340,7 +1340,7 @@ impl( _py: Python<'_>, - obj: &'a *mut ffi::PyObject, + obj: &'a NonNull, ) -> Result, PyBorrowError> { unsafe { PyClassGuard::try_borrow(NonNull::from(obj).cast::>().as_ref()) } } @@ -1352,7 +1352,7 @@ unsafe fn ensure_no_mutable_alias<'a, ClassT: PyClass>( /// - there must be a value of type `FieldT` at the calculated offset within `ClassT` unsafe fn pyo3_get_value_into_pyobject_ref( py: Python<'_>, - obj: *mut ffi::PyObject, + obj: NonNull, ) -> PyResult<*mut ffi::PyObject> where ClassT: PyClass, @@ -1365,25 +1365,24 @@ where /// - value of type `FieldT` must exist at the given offset within obj unsafe fn inner( py: Python<'_>, - obj: *const (), + obj: NonNull<()>, offset: usize, ) -> PyResult<*mut ffi::PyObject> where for<'a, 'py> &'a FieldT: IntoPyObject<'py>, { // SAFETY: caller upholds safety invariants - let value = unsafe { &*obj.byte_add(offset).cast::() }; + let value = unsafe { obj.byte_add(offset).cast::().as_ref() }; value.into_py_any(py).map(Py::into_ptr) } // SAFETY: `obj` is a valid pointer to `ClassT` let _holder = unsafe { ensure_no_mutable_alias::(py, &obj)? }; let class_ptr = obj.cast::<::Layout>(); - let class_obj = unsafe { &*class_ptr }; - let contents_ptr = ptr::from_ref(class_obj.contents()); + let class_obj = unsafe { class_ptr.as_ref() }; // SAFETY: _holder prevents mutable aliasing, caller upholds other safety invariants - unsafe { inner::(py, contents_ptr.cast(), OFFSET) } + unsafe { inner::(py, NonNull::from(class_obj.contents()).cast(), OFFSET) } } /// Gets a field value from a pyclass and produces a python value using `IntoPyObject` for `FieldT`, @@ -1394,7 +1393,7 @@ where /// - there must be a value of type `FieldT` at the calculated offset within `ClassT` unsafe fn pyo3_get_value_into_pyobject( py: Python<'_>, - obj: *mut ffi::PyObject, + obj: NonNull, ) -> PyResult<*mut ffi::PyObject> where ClassT: PyClass, @@ -1407,25 +1406,24 @@ where /// - value of type `FieldT` must exist at the given offset within obj unsafe fn inner( py: Python<'_>, - obj: *const (), + obj: NonNull<()>, offset: usize, ) -> PyResult<*mut ffi::PyObject> where for<'py> FieldT: IntoPyObject<'py> + Clone, { // SAFETY: caller upholds safety invariants - let value = unsafe { &*obj.byte_add(offset).cast::() }; + let value = unsafe { obj.byte_add(offset).cast::().as_ref() }; value.clone().into_py_any(py).map(Py::into_ptr) } // SAFETY: `obj` is a valid pointer to `ClassT` let _holder = unsafe { ensure_no_mutable_alias::(py, &obj)? }; let class_ptr = obj.cast::<::Layout>(); - let class_obj = unsafe { &*class_ptr }; - let contents_ptr = ptr::from_ref(class_obj.contents()); + let class_obj = unsafe { class_ptr.as_ref() }; // SAFETY: _holder prevents mutable aliasing, caller upholds other safety invariants - unsafe { inner::(py, contents_ptr.cast(), OFFSET) } + unsafe { inner::(py, NonNull::from(class_obj.contents()).cast(), OFFSET) } } pub struct ConvertField< diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index ae391637dd7..7dc00b3cd7a 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -300,10 +300,14 @@ impl fmt::Debug for PyClassAttributeDef { /// Class getter / setters pub(crate) type Getter = - for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult<*mut ffi::PyObject>; -pub(crate) type Setter = - for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject, *mut ffi::PyObject) -> PyResult; -pub(crate) type Deleter = for<'py> unsafe fn(Python<'py>, *mut ffi::PyObject) -> PyResult; + for<'py> unsafe fn(Python<'py>, NonNull) -> PyResult<*mut ffi::PyObject>; +pub(crate) type Setter = for<'py> unsafe fn( + Python<'py>, + NonNull, + NonNull, +) -> PyResult; +pub(crate) type Deleter = + for<'py> unsafe fn(Python<'py>, NonNull) -> PyResult; impl PyGetterDef { /// Define a getter. diff --git a/src/pyclass/create_type_object.rs b/src/pyclass/create_type_object.rs index 86429a2b61e..9c72a9b4d1a 100644 --- a/src/pyclass/create_type_object.rs +++ b/src/pyclass/create_type_object.rs @@ -687,6 +687,7 @@ impl GetSetDefType { slf: *mut ffi::PyObject, closure: *mut c_void, ) -> *mut ffi::PyObject { + let slf = unsafe { NonNull::new_unchecked(slf) }; // Safety: PyO3 sets the closure when constructing the ffi getter so this cast should always be valid let getter: Getter = unsafe { std::mem::transmute(closure) }; unsafe { trampoline(|py| getter(py, slf)) } @@ -699,14 +700,15 @@ impl GetSetDefType { value: *mut ffi::PyObject, closure: *mut c_void, ) -> c_int { + let slf = unsafe { NonNull::new_unchecked(slf) }; // Safety: PyO3 sets the closure when constructing the ffi setter so this cast should always be valid let setter: Setter = unsafe { std::mem::transmute(closure) }; unsafe { trampoline(|py| { - if value.is_null() { - Err(PyAttributeError::new_err("property has no deleter")) - } else { + if let Some(value) = NonNull::new(value) { setter(py, slf, value) + } else { + Err(PyAttributeError::new_err("property has no deleter")) } }) } @@ -718,6 +720,7 @@ impl GetSetDefType { slf: *mut ffi::PyObject, closure: *mut c_void, ) -> *mut ffi::PyObject { + let slf = unsafe { NonNull::new_unchecked(slf) }; let getset: &GetSetDeleteCombination = unsafe { &*closure.cast() }; // we only call this method if getter is set unsafe { trampoline(|py| getset.getter.unwrap_unchecked()(py, slf)) } @@ -728,17 +731,18 @@ impl GetSetDefType { value: *mut ffi::PyObject, closure: *mut c_void, ) -> c_int { + let slf = unsafe { NonNull::new_unchecked(slf) }; let getset: &GetSetDeleteCombination = unsafe { &*closure.cast() }; unsafe { trampoline(|py| { - if value.is_null() { - getset.deleter.ok_or_else(|| { - PyAttributeError::new_err("property has no deleter") - })?(py, slf) - } else { + if let Some(value) = NonNull::new(value) { getset.setter.ok_or_else(|| { PyAttributeError::new_err("property has no setter") })?(py, slf, value) + } else { + getset.deleter.ok_or_else(|| { + PyAttributeError::new_err("property has no deleter") + })?(py, slf) } }) }