Skip to content
Snippets Groups Projects
Commit b0c889bccb37 authored by ijl's avatar ijl
Browse files

Refactor numpy

parent 981739186345
No related branches found
No related tags found
No related merge requests found
...@@ -750,7 +750,9 @@ ...@@ -750,7 +750,9 @@
### numpy ### numpy
orjson natively serializes `numpy.ndarray` instances. Arrays may have a orjson natively serializes `numpy.ndarray` and individual `numpy.float64`,
`numpy.float32`, `numpy.int64`, `numpy.int32`, `numpy.uint64`, and
`numpy.uint32` instances. Arrays may have a
`dtype` of `numpy.bool`, `numpy.float32`, `numpy.float64`, `numpy.int32`, `dtype` of `numpy.bool`, `numpy.float32`, `numpy.float64`, `numpy.int32`,
`numpy.int64`, `numpy.uint32`, `numpy.uint64`, `numpy.uintp`, or `numpy.intp`. `numpy.int64`, `numpy.uint32`, `numpy.uint64`, `numpy.uintp`, or `numpy.intp`.
orjson is faster than all compared libraries at serializing orjson is faster than all compared libraries at serializing
......
...@@ -432,7 +432,7 @@ ...@@ -432,7 +432,7 @@
) )
.serialize(serializer) .serialize(serializer)
} }
ObType::NumpyArray => match PyArray::new(self.ptr) { ObType::NumpyArray => match NumpyArray::new(self.ptr) {
Ok(val) => val.serialize(serializer), Ok(val) => val.serialize(serializer),
Err(PyArrayError::Malformed) => err!("numpy array is malformed"), Err(PyArrayError::Malformed) => err!("numpy array is malformed"),
Err(PyArrayError::NotContiguous) | Err(PyArrayError::UnsupportedDataType) => { Err(PyArrayError::NotContiguous) | Err(PyArrayError::UnsupportedDataType) => {
...@@ -446,18 +446,7 @@ ...@@ -446,18 +446,7 @@
.serialize(serializer) .serialize(serializer)
} }
}, },
ObType::NumpyScalar => match pyobj_to_numpy_obj(self.ptr) { ObType::NumpyScalar => NumpyScalar::new(self.ptr).serialize(serializer),
Ok(numpy_obj) => match numpy_obj {
NumpyObjects::Float32(obj) => obj.serialize(serializer),
NumpyObjects::Float64(obj) => obj.serialize(serializer),
NumpyObjects::Int32(obj) => obj.serialize(serializer),
NumpyObjects::Int64(obj) => obj.serialize(serializer),
NumpyObjects::Uint32(obj) => obj.serialize(serializer),
NumpyObjects::Uint64(obj) => obj.serialize(serializer),
},
Err(NumpyError::InvalidType) => err!("invalid numpy type"),
Err(NumpyError::NotAvailable) => err!("numpy not available"),
},
ObType::Unknown => DefaultSerializer::new( ObType::Unknown => DefaultSerializer::new(
self.ptr, self.ptr,
self.opts, self.opts,
......
...@@ -10,6 +10,28 @@ ...@@ -10,6 +10,28 @@
}; };
} }
/// Reports whether `ob_type` is the type object of one of the numpy scalar
/// types recorded in `NUMPY_TYPES`: `float64`, `float32`, `int64`, `int32`,
/// `uint64`, or `uint32`.
///
/// Returns `false` when `NUMPY_TYPES` holds `None`.
pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool {
    // A single read of the static replaces the former `is_none()` probe
    // followed by a separate `unwrap()`.
    // NOTE(review): presumably NUMPY_TYPES is populated once and never
    // mutated afterwards, which is what makes this shared borrow sound — confirm.
    match unsafe { NUMPY_TYPES.as_ref() } {
        Some(scalar_types) => {
            ob_type == scalar_types.float64
                || ob_type == scalar_types.float32
                || ob_type == scalar_types.int64
                || ob_type == scalar_types.int32
                || ob_type == scalar_types.uint64
                || ob_type == scalar_types.uint32
        }
        None => false,
    }
}
/// Reports whether `ob_type` is the array type object recorded in
/// `NUMPY_TYPES`.
///
/// Returns `false` when `NUMPY_TYPES` holds `None`.
pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool {
    // A single read of the static replaces the former `is_none()` probe
    // followed by a separate `unwrap()`.
    // NOTE(review): presumably NUMPY_TYPES is populated once and never
    // mutated afterwards, which is what makes this shared borrow sound — confirm.
    match unsafe { NUMPY_TYPES.as_ref() } {
        Some(numpy_types) => ob_type == numpy_types.array,
        None => false,
    }
}
#[repr(C)] #[repr(C)]
pub struct PyCapsule { pub struct PyCapsule {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
...@@ -20,6 +42,8 @@ ...@@ -20,6 +42,8 @@
pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *); pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *);
} }
// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__
#[repr(C)] #[repr(C)]
pub struct PyArrayInterface { pub struct PyArrayInterface {
pub two: c_int, pub two: c_int,
...@@ -33,7 +57,6 @@ ...@@ -33,7 +57,6 @@
pub descr: *mut PyObject, pub descr: *mut PyObject,
} }
#[derive(Copy, Clone, PartialEq)]
pub enum ItemType { pub enum ItemType {
BOOL, BOOL,
F32, F32,
...@@ -50,9 +73,42 @@ ...@@ -50,9 +73,42 @@
UnsupportedDataType, UnsupportedDataType,
} }
pub enum NumpyError { #[repr(transparent)]
NotAvailable, pub struct NumpyScalar {
InvalidType, pub ptr: *mut pyo3::ffi::PyObject,
}
impl NumpyScalar {
pub fn new(ptr: *mut PyObject) -> Self {
NumpyScalar { ptr }
}
}
impl<'p> Serialize for NumpyScalar {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
unsafe {
let ob_type = ob_type!(self.ptr);
let scalar_types = NUMPY_TYPES.deref_mut().as_ref().unwrap();
if ob_type == scalar_types.float64 {
(*(self.ptr as *mut NumpyFloat64)).serialize(serializer)
} else if ob_type == scalar_types.float32 {
(*(self.ptr as *mut NumpyFloat32)).serialize(serializer)
} else if ob_type == scalar_types.int64 {
(*(self.ptr as *mut NumpyInt64)).serialize(serializer)
} else if ob_type == scalar_types.int32 {
(*(self.ptr as *mut NumpyInt32)).serialize(serializer)
} else if ob_type == scalar_types.uint64 {
(*(self.ptr as *mut NumpyUint64)).serialize(serializer)
} else if ob_type == scalar_types.uint32 {
(*(self.ptr as *mut NumpyUint32)).serialize(serializer)
} else {
unreachable!()
}
}
}
} }
// >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32) // >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32)
...@@ -62,6 +118,6 @@ ...@@ -62,6 +118,6 @@
// (2, 2, 2) // (2, 2, 2)
// >>> arr.strides // >>> arr.strides
// (16, 8, 4) // (16, 8, 4)
pub struct PyArray { pub struct NumpyArray {
array: *mut PyArrayInterface, array: *mut PyArrayInterface,
position: Vec<isize>, position: Vec<isize>,
...@@ -66,7 +122,7 @@ ...@@ -66,7 +122,7 @@
array: *mut PyArrayInterface, array: *mut PyArrayInterface,
position: Vec<isize>, position: Vec<isize>,
children: Vec<PyArray>, children: Vec<NumpyArray>,
depth: usize, depth: usize,
capsule: *mut PyCapsule, capsule: *mut PyCapsule,
} }
...@@ -69,9 +125,8 @@ ...@@ -69,9 +125,8 @@
depth: usize, depth: usize,
capsule: *mut PyCapsule, capsule: *mut PyCapsule,
} }
impl<'a> PyArray { impl<'a> NumpyArray {
#[cold]
pub fn new(ptr: *mut PyObject) -> Result<Self, PyArrayError> { pub fn new(ptr: *mut PyObject) -> Result<Self, PyArrayError> {
let capsule = ffi!(PyObject_GetAttr(ptr, ARRAY_STRUCT_STR)); let capsule = ffi!(PyObject_GetAttr(ptr, ARRAY_STRUCT_STR));
let array = unsafe { (*(capsule as *mut PyCapsule)).pointer as *mut PyArrayInterface }; let array = unsafe { (*(capsule as *mut PyCapsule)).pointer as *mut PyArrayInterface };
...@@ -86,7 +141,7 @@ ...@@ -86,7 +141,7 @@
if num_dimensions == 0 { if num_dimensions == 0 {
return Err(PyArrayError::UnsupportedDataType); return Err(PyArrayError::UnsupportedDataType);
} }
let mut pyarray = PyArray { let mut pyarray = NumpyArray {
array: array, array: array,
position: vec![0; num_dimensions], position: vec![0; num_dimensions],
children: Vec::with_capacity(num_dimensions), children: Vec::with_capacity(num_dimensions),
...@@ -105,7 +160,7 @@ ...@@ -105,7 +160,7 @@
} }
fn from_parent(&self, position: Vec<isize>, num_children: usize) -> Self { fn from_parent(&self, position: Vec<isize>, num_children: usize) -> Self {
let mut arr = PyArray { let mut arr = NumpyArray {
array: self.array, array: self.array,
position: position, position: position,
children: Vec::with_capacity(num_children), children: Vec::with_capacity(num_children),
...@@ -145,7 +200,7 @@ ...@@ -145,7 +200,7 @@
} }
} }
fn data(&self) -> *mut c_void { fn data(&self) -> *const c_void {
let offset = self let offset = self
.strides() .strides()
.iter() .iter()
...@@ -173,7 +228,7 @@ ...@@ -173,7 +228,7 @@
} }
} }
impl Drop for PyArray { impl Drop for NumpyArray {
fn drop(&mut self) { fn drop(&mut self) {
if self.depth == 0 { if self.depth == 0 {
ffi!(Py_XDECREF(self.capsule as *mut pyo3::ffi::PyObject)) ffi!(Py_XDECREF(self.capsule as *mut pyo3::ffi::PyObject))
...@@ -181,7 +236,7 @@ ...@@ -181,7 +236,7 @@
} }
} }
impl<'p> Serialize for PyArray { impl<'p> Serialize for NumpyArray {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
...@@ -342,7 +397,6 @@ ...@@ -342,7 +397,6 @@
} }
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct NumpyInt32 { pub struct NumpyInt32 {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject, pub ob_type: *mut PyTypeObject,
...@@ -359,7 +413,6 @@ ...@@ -359,7 +413,6 @@
} }
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct NumpyInt64 { pub struct NumpyInt64 {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject, pub ob_type: *mut PyTypeObject,
...@@ -376,7 +429,6 @@ ...@@ -376,7 +429,6 @@
} }
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct NumpyUint32 { pub struct NumpyUint32 {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject, pub ob_type: *mut PyTypeObject,
...@@ -393,7 +445,6 @@ ...@@ -393,7 +445,6 @@
} }
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct NumpyUint64 { pub struct NumpyUint64 {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject, pub ob_type: *mut PyTypeObject,
...@@ -410,7 +461,6 @@ ...@@ -410,7 +461,6 @@
} }
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct NumpyFloat32 { pub struct NumpyFloat32 {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject, pub ob_type: *mut PyTypeObject,
...@@ -427,7 +477,6 @@ ...@@ -427,7 +477,6 @@
} }
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct NumpyFloat64 { pub struct NumpyFloat64 {
pub ob_refcnt: Py_ssize_t, pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject, pub ob_type: *mut PyTypeObject,
...@@ -442,83 +491,3 @@ ...@@ -442,83 +491,3 @@
serializer.serialize_f64(self.value) serializer.serialize_f64(self.value)
} }
} }
/// Reports whether `ob_type` equals one of the scalar type objects stored in
/// `NUMPY_TYPES` (`float32`, `float64`, `int32`, `int64`, `uint32`, `uint64`).
///
/// Returns `false` when `NUMPY_TYPES` holds no value.
pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool {
    let known = match unsafe { NUMPY_TYPES.deref_mut() } {
        Some(found) => found,
        None => return false,
    };
    // Same candidates, in the same order, as a linear membership scan.
    ob_type == known.float32
        || ob_type == known.float64
        || ob_type == known.int32
        || ob_type == known.int64
        || ob_type == known.uint32
        || ob_type == known.uint64
}
/// Reports whether `ob_type` equals the array type object stored in
/// `NUMPY_TYPES`.
///
/// Returns `false` when `NUMPY_TYPES` holds no value.
pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool {
    let known = match unsafe { NUMPY_TYPES.deref_mut() } {
        Some(found) => found,
        None => return false,
    };
    known.array == ob_type
}
// pub fn serialize_numpy_scalar<S>(obj: *mut pyo3::ffi::PyObject, serializer: S) -> Result<S::Ok, S::Error>
// where
// S: Serializer,
// {
// let ob_type = ob_type!(obj);
// let numpy = match ob_type {
// }
// }
/// A numpy scalar held by value, tagged with which of the six supported
/// scalar struct types it is. Each variant wraps the `Numpy*` struct of the
/// same name.
pub enum NumpyObjects {
/// Holds a `NumpyFloat32`.
Float32(NumpyFloat32),
/// Holds a `NumpyFloat64`.
Float64(NumpyFloat64),
/// Holds a `NumpyInt32`.
Int32(NumpyInt32),
/// Holds a `NumpyInt64`.
Int64(NumpyInt64),
/// Holds a `NumpyUint32`.
Uint32(NumpyUint32),
/// Holds a `NumpyUint64`.
Uint64(NumpyUint64),
}
/// Converts a raw Python object pointer into the `NumpyObjects` variant that
/// matches the object's type.
///
/// The object's type is compared against the scalar type objects stored in
/// `NUMPY_TYPES`; on a match the object is read through a pointer cast to the
/// corresponding `Numpy*` struct and copied into the returned variant.
///
/// # Errors
///
/// * `NumpyError::NotAvailable` when `NUMPY_TYPES` holds no value.
/// * `NumpyError::InvalidType` when the object's type matches none of the six
///   scalar types.
pub fn pyobj_to_numpy_obj(obj: *mut pyo3::ffi::PyObject) -> Result<NumpyObjects, NumpyError> {
    let known = match unsafe { NUMPY_TYPES.deref_mut() } {
        Some(found) => found,
        None => return Err(NumpyError::NotAvailable),
    };
    let ob_type = ob_type!(obj);
    unsafe {
        if ob_type == known.float32 {
            Ok(NumpyObjects::Float32(*(obj as *mut NumpyFloat32)))
        } else if ob_type == known.float64 {
            Ok(NumpyObjects::Float64(*(obj as *mut NumpyFloat64)))
        } else if ob_type == known.int32 {
            Ok(NumpyObjects::Int32(*(obj as *mut NumpyInt32)))
        } else if ob_type == known.int64 {
            Ok(NumpyObjects::Int64(*(obj as *mut NumpyInt64)))
        } else if ob_type == known.uint32 {
            Ok(NumpyObjects::Uint32(*(obj as *mut NumpyUint32)))
        } else if ob_type == known.uint64 {
            Ok(NumpyObjects::Uint64(*(obj as *mut NumpyUint64)))
        } else {
            Err(NumpyError::InvalidType)
        }
    }
}
...@@ -7,4 +7,5 @@ ...@@ -7,4 +7,5 @@
use std::sync::Once; use std::sync::Once;
pub struct NumpyTypes { pub struct NumpyTypes {
pub float64: *mut PyTypeObject,
pub float32: *mut PyTypeObject, pub float32: *mut PyTypeObject,
...@@ -10,4 +11,3 @@ ...@@ -10,4 +11,3 @@
pub float32: *mut PyTypeObject, pub float32: *mut PyTypeObject,
pub array: *mut PyTypeObject, pub int64: *mut PyTypeObject,
pub float64: *mut PyTypeObject,
pub int32: *mut PyTypeObject, pub int32: *mut PyTypeObject,
...@@ -13,3 +13,3 @@ ...@@ -13,3 +13,3 @@
pub int32: *mut PyTypeObject, pub int32: *mut PyTypeObject,
pub int64: *mut PyTypeObject, pub uint64: *mut PyTypeObject,
pub uint32: *mut PyTypeObject, pub uint32: *mut PyTypeObject,
...@@ -15,5 +15,5 @@ ...@@ -15,5 +15,5 @@
pub uint32: *mut PyTypeObject, pub uint32: *mut PyTypeObject,
pub uint64: *mut PyTypeObject, pub array: *mut PyTypeObject,
} }
pub static mut HASH_SEED: u64 = 0; pub static mut HASH_SEED: u64 = 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment