Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
const MODULE: Option<&'static str> = Some("numpy");

fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
unsafe { npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {
Expand Down Expand Up @@ -233,7 +233,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
let mut dims = dims.into_dimension();
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype(py).into_dtype_ptr(),
dims.ndim_cint(),
dims.as_dims_ptr(),
Expand All @@ -259,7 +259,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
let mut dims = dims.into_dimension();
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
npyffi::get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype(py).into_dtype_ptr(),
dims.ndim_cint(),
dims.as_dims_ptr(),
Expand Down
4 changes: 2 additions & 2 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use pyo3::{
};

use crate::npyffi::{
NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
self, NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
};
Expand Down Expand Up @@ -58,7 +58,7 @@ unsafe impl PyTypeInfo for PyArrayDescr {

#[inline]
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
unsafe { npyffi::get_type_object(py, NpyTypes::PyArrayDescr_Type) }
}
}

Expand Down
256 changes: 112 additions & 144 deletions src/npyffi/array.rs

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions src/npyffi/flags.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::npy_uint32;
use super::{npy_uint32, npy_uint64};
use std::os::raw::c_int;

pub const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x0001;
Expand All @@ -11,8 +11,8 @@ pub const NPY_ARRAY_ELEMENTSTRIDES: c_int = 0x0080;
pub const NPY_ARRAY_ALIGNED: c_int = 0x0100;
pub const NPY_ARRAY_NOTSWAPPED: c_int = 0x0200;
pub const NPY_ARRAY_WRITEABLE: c_int = 0x0400;
pub const NPY_ARRAY_UPDATEIFCOPY: c_int = 0x1000;
pub const NPY_ARRAY_WRITEBACKIFCOPY: c_int = 0x2000;
pub const NPY_ARRAY_ENSURENOCOPY: c_int = 0x4000;
pub const NPY_ARRAY_BEHAVED: c_int = NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE;
pub const NPY_ARRAY_BEHAVED_NS: c_int = NPY_ARRAY_BEHAVED | NPY_ARRAY_NOTSWAPPED;
pub const NPY_ARRAY_CARRAY: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED;
Expand All @@ -22,13 +22,14 @@ pub const NPY_ARRAY_FARRAY_RO: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNE
pub const NPY_ARRAY_DEFAULT: c_int = NPY_ARRAY_CARRAY;
pub const NPY_ARRAY_IN_ARRAY: c_int = NPY_ARRAY_CARRAY_RO;
pub const NPY_ARRAY_OUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY;
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
pub const NPY_ARRAY_INOUT_ARRAY2: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
pub const NPY_ARRAY_IN_FARRAY: c_int = NPY_ARRAY_FARRAY_RO;
pub const NPY_ARRAY_OUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_UPDATEIFCOPY;
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
pub const NPY_ARRAY_INOUT_FARRAY2: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
pub const NPY_ARRAY_UPDATE_ALL: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;
pub const NPY_ARRAY_UPDATE_ALL: c_int =
NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNED;

pub const NPY_ITER_C_INDEX: npy_uint32 = 0x00000001;
pub const NPY_ITER_F_INDEX: npy_uint32 = 0x00000002;
Expand Down Expand Up @@ -63,19 +64,18 @@ pub const NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE: npy_uint32 = 0x40000000;
pub const NPY_ITER_GLOBAL_FLAGS: npy_uint32 = 0x0000ffff;
pub const NPY_ITER_PER_OP_FLAGS: npy_uint32 = 0xffff0000;

pub const NPY_ITEM_REFCOUNT: u64 = 0x01;
pub const NPY_ITEM_HASOBJECT: u64 = 0x01;
pub const NPY_LIST_PICKLE: u64 = 0x02;
pub const NPY_ITEM_IS_POINTER: u64 = 0x04;
pub const NPY_NEEDS_INIT: u64 = 0x08;
pub const NPY_NEEDS_PYAPI: u64 = 0x10;
pub const NPY_USE_GETITEM: u64 = 0x20;
pub const NPY_USE_SETITEM: u64 = 0x40;
#[allow(overflowing_literals)]
pub const NPY_ALIGNED_STRUCT: u64 = 0x80;
pub const NPY_FROM_FIELDS: u64 =
pub const NPY_ITEM_REFCOUNT: npy_uint64 = 0x01;
pub const NPY_ITEM_HASOBJECT: npy_uint64 = 0x01;
pub const NPY_LIST_PICKLE: npy_uint64 = 0x02;
pub const NPY_ITEM_IS_POINTER: npy_uint64 = 0x04;
pub const NPY_NEEDS_INIT: npy_uint64 = 0x08;
pub const NPY_NEEDS_PYAPI: npy_uint64 = 0x10;
pub const NPY_USE_GETITEM: npy_uint64 = 0x20;
pub const NPY_USE_SETITEM: npy_uint64 = 0x40;
pub const NPY_ALIGNED_STRUCT: npy_uint64 = 0x80;
pub const NPY_FROM_FIELDS: npy_uint64 =
NPY_NEEDS_INIT | NPY_LIST_PICKLE | NPY_ITEM_REFCOUNT | NPY_NEEDS_PYAPI;
pub const NPY_OBJECT_DTYPE_FLAGS: u64 = NPY_LIST_PICKLE
pub const NPY_OBJECT_DTYPE_FLAGS: npy_uint64 = NPY_LIST_PICKLE
| NPY_USE_GETITEM
| NPY_ITEM_IS_POINTER
| NPY_ITEM_REFCOUNT
Expand Down
114 changes: 74 additions & 40 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,40 @@

use std::mem::forget;
use std::os::raw::{c_uint, c_void};
use std::ptr::NonNull;

use pyo3::{
ffi::PyTypeObject,
sync::PyOnceLock,
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
PyResult, Python,
};

pub const API_VERSION_2_0: c_uint = 0x00000012;

static API_VERSION: PyOnceLock<c_uint> = PyOnceLock::new();

fn get_numpy_api<'py>(
py: Python<'py>,
module: &str,
capsule: &str,
) -> PyResult<*const *const c_void> {
) -> PyResult<NonNull<*const c_void>> {
let module = PyModule::import(py, module)?;
let capsule = module.getattr(capsule)?.cast_into::<PyCapsule>()?;

let api = capsule
.pointer_checked(None)?
.cast::<*const c_void>()
.as_ptr()
.cast_const();
let api = capsule.pointer_checked(None)?;

// Intentionally leak a reference to the capsule
// so we can safely cache a pointer into its interior.
forget(capsule);

Ok(api)
Ok(api.cast())
}

/// Returns whether the runtime `numpy` version is 2.0 or greater.
pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to look at the usages of this. These are used to conditionally cast structures, which we should not do anymore now that we are targeting a single ABI, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is used only in npyffi and in a test in src/dtype.rs. In npyffi it is used exactly like in the NumPy compatibility header. Functions PyDataType_SET_ELSIZE and PyDataType_FLAGS match. Macro define_descr_accessor matches the NumPy macro DESCR_ACCESSOR.

let api_version = *API_VERSION.get_or_init(py, || unsafe {
PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
});
api_version >= API_VERSION_2_0
api_version >= NPY_2_0_API_VERSION
}

// Implements wrappers for NumPy's Array and UFunc API
Expand All @@ -57,52 +53,90 @@ macro_rules! impl_api {
[$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg : $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
let f: extern "C" fn ($($arg : $t), *) $(-> $ret)* = self.get(py, $offset).cast().read();
f($($arg), *)
}
};
}

// API with version constraints, checked at runtime
[$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
assert!(
!is_numpy_2(py),
"{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
*API_VERSION.get(py).expect("API_VERSION is initialized"),
);
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
}
// Define type objects associated with the NumPy API
macro_rules! impl_array_type {
($(($api:ident [ $offset:expr ] , $tname:ident)),* $(,)?) => {
/// All type objects exported by the NumPy API.
#[allow(non_camel_case_types)]
pub enum NpyTypes { $($tname),* }

};
[$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
#[allow(non_snake_case)]
pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
assert!(
is_numpy_2(py),
"{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
stringify!($fname),
API_VERSION_2_0,
*API_VERSION.get(py).expect("API_VERSION is initialized"),
);
let fptr = self.get(py, $offset) as *const extern "C" fn ($($arg: $t), *) $(-> $ret)*;
(*fptr)($($arg), *)
/// Get a pointer of the type object associated with `ty`.
pub unsafe fn get_type_object<'py>(py: Python<'py>, ty: NpyTypes) -> *mut PyTypeObject {
match ty {
$( NpyTypes::$tname => $api.get(py, $offset).read() as _ ),*
}
}
}
}

};
impl_array_type! {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart, I like the syntax and that this can now support returning type objects from different APIs.

// Multiarray API
// Slot 1 was never meaningfully used by NumPy
(PY_ARRAY_API[2], PyArray_Type),
(PY_ARRAY_API[3], PyArrayDescr_Type),
// Unused slot 4, was `PyArrayFlags_Type`
(PY_ARRAY_API[5], PyArrayIter_Type),
(PY_ARRAY_API[6], PyArrayMultiIter_Type),
// (PY_ARRAY_API[7], NPY_NUMUSERTYPES) -> c_int,
(PY_ARRAY_API[8], PyBoolArrType_Type),
// (PY_ARRAY_API[9], _PyArrayScalar_BoolValues) -> *mut PyBoolScalarObject,
(PY_ARRAY_API[10], PyGenericArrType_Type),
(PY_ARRAY_API[11], PyNumberArrType_Type),
(PY_ARRAY_API[12], PyIntegerArrType_Type),
(PY_ARRAY_API[13], PySignedIntegerArrType_Type),
(PY_ARRAY_API[14], PyUnsignedIntegerArrType_Type),
(PY_ARRAY_API[15], PyInexactArrType_Type),
(PY_ARRAY_API[16], PyFloatingArrType_Type),
(PY_ARRAY_API[17], PyComplexFloatingArrType_Type),
(PY_ARRAY_API[18], PyFlexibleArrType_Type),
(PY_ARRAY_API[19], PyCharacterArrType_Type),
(PY_ARRAY_API[20], PyByteArrType_Type),
(PY_ARRAY_API[21], PyShortArrType_Type),
(PY_ARRAY_API[22], PyIntArrType_Type),
(PY_ARRAY_API[23], PyLongArrType_Type),
(PY_ARRAY_API[24], PyLongLongArrType_Type),
(PY_ARRAY_API[25], PyUByteArrType_Type),
(PY_ARRAY_API[26], PyUShortArrType_Type),
(PY_ARRAY_API[27], PyUIntArrType_Type),
(PY_ARRAY_API[28], PyULongArrType_Type),
(PY_ARRAY_API[29], PyULongLongArrType_Type),
(PY_ARRAY_API[30], PyFloatArrType_Type),
(PY_ARRAY_API[31], PyDoubleArrType_Type),
(PY_ARRAY_API[32], PyLongDoubleArrType_Type),
(PY_ARRAY_API[33], PyCFloatArrType_Type),
(PY_ARRAY_API[34], PyCDoubleArrType_Type),
(PY_ARRAY_API[35], PyCLongDoubleArrType_Type),
(PY_ARRAY_API[36], PyObjectArrType_Type),
(PY_ARRAY_API[37], PyStringArrType_Type),
(PY_ARRAY_API[38], PyUnicodeArrType_Type),
(PY_ARRAY_API[39], PyVoidArrType_Type),
(PY_ARRAY_API[214], PyTimeIntegerArrType_Type),
(PY_ARRAY_API[215], PyDatetimeArrType_Type),
(PY_ARRAY_API[216], PyTimedeltaArrType_Type),
(PY_ARRAY_API[217], PyHalfArrType_Type),
(PY_ARRAY_API[218], NpyIter_Type),
// UFunc API
(PY_UFUNC_API[0], PyUFunc_Type),
}

pub mod array;
pub mod flags;
mod npy_common;
mod numpyconfig;
pub mod objects;
pub mod types;
pub mod ufunc;

pub use self::array::*;
pub use self::flags::*;
pub use self::npy_common::*;
pub use self::numpyconfig::*;
pub use self::objects::*;
pub use self::types::*;
pub use self::ufunc::*;
8 changes: 8 additions & 0 deletions src/npyffi/npy_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use std::ffi::c_int;

/// Unknown CPU endianness.
pub const NPY_CPU_UNKNOWN_ENDIAN: c_int = 0;
/// CPU is little-endian.
pub const NPY_CPU_LITTLE: c_int = 1;
/// CPU is big-endian.
pub const NPY_CPU_BIG: c_int = 2;
18 changes: 18 additions & 0 deletions src/npyffi/numpyconfig.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// This file matches the numpyconfig.h header.

use std::ffi::c_uint;

/// The current target ABI version
const NPY_ABI_VERSION: c_uint = 0x02000000;

/// The current target API version (v1.15)
const NPY_API_VERSION: c_uint = 0x0000000c;

pub(super) const NPY_2_0_API_VERSION: c_uint = 0x00000012;

/// The current version of the `ndarray` object (ABI version).
pub const NPY_VERSION: c_uint = NPY_ABI_VERSION;
/// The current version of C API.
pub const NPY_FEATURE_VERSION: c_uint = NPY_API_VERSION;
/// The string representation of current version C API.
pub const NPY_FEATURE_VERSION_STRING: &str = "1.15";
Loading
Loading