# HG changeset patch # User ijl <ijl@mailbox.org> # Date 1595553764 0 # Fri Jul 24 01:22:44 2020 +0000 # Node ID b0c889bccb37e811fbb6b3d3b42f9be5d956f38d # Parent 9817391863459df8020b4cc14c1f47dfe8ef8676 Refactor numpy diff --git a/README.md b/README.md --- a/README.md +++ b/README.md @@ -750,7 +750,9 @@ ### numpy -orjson natively serializes `numpy.ndarray` instances. Arrays may have a +orjson natively serializes `numpy.ndarray` and individual `numpy.float64`, +`numpy.float32`, `numpy.int64`, `numpy.int32`, `numpy.uint64`, and +`numpy.uint32` instances. Arrays may have a `dtype` of `numpy.bool`, `numpy.float32`, `numpy.float64`, `numpy.int32`, `numpy.int64`, `numpy.uint32`, `numpy.uint64`, `numpy.uintp`, or `numpy.intp`. orjson is faster than all compared libraries at serializing diff --git a/src/serialize/encode.rs b/src/serialize/encode.rs --- a/src/serialize/encode.rs +++ b/src/serialize/encode.rs @@ -432,7 +432,7 @@ ) .serialize(serializer) } - ObType::NumpyArray => match PyArray::new(self.ptr) { + ObType::NumpyArray => match NumpyArray::new(self.ptr) { Ok(val) => val.serialize(serializer), Err(PyArrayError::Malformed) => err!("numpy array is malformed"), Err(PyArrayError::NotContiguous) | Err(PyArrayError::UnsupportedDataType) => { @@ -446,18 +446,7 @@ .serialize(serializer) } }, - ObType::NumpyScalar => match pyobj_to_numpy_obj(self.ptr) { - Ok(numpy_obj) => match numpy_obj { - NumpyObjects::Float32(obj) => obj.serialize(serializer), - NumpyObjects::Float64(obj) => obj.serialize(serializer), - NumpyObjects::Int32(obj) => obj.serialize(serializer), - NumpyObjects::Int64(obj) => obj.serialize(serializer), - NumpyObjects::Uint32(obj) => obj.serialize(serializer), - NumpyObjects::Uint64(obj) => obj.serialize(serializer), - }, - Err(NumpyError::InvalidType) => err!("invalid numpy type"), - Err(NumpyError::NotAvailable) => err!("numpy not available"), - }, + ObType::NumpyScalar => NumpyScalar::new(self.ptr).serialize(serializer), ObType::Unknown => DefaultSerializer::new( self.ptr, self.opts, diff --git a/src/serialize/numpy.rs b/src/serialize/numpy.rs --- a/src/serialize/numpy.rs +++ b/src/serialize/numpy.rs @@ -10,6 +10,28 @@ }; } +pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool { + if unsafe { NUMPY_TYPES.is_none() } { + false + } else { + let scalar_types = unsafe { NUMPY_TYPES.as_ref().unwrap() }; + ob_type == scalar_types.float64 + || ob_type == scalar_types.float32 + || ob_type == scalar_types.int64 + || ob_type == scalar_types.int32 + || ob_type == scalar_types.uint64 + || ob_type == scalar_types.uint32 + } +} + +pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool { + if unsafe { NUMPY_TYPES.is_none() } { + false + } else { + unsafe { ob_type == NUMPY_TYPES.as_ref().unwrap().array } + } +} + #[repr(C)] pub struct PyCapsule { pub ob_refcnt: Py_ssize_t, @@ -20,6 +42,8 @@ pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *); } +// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__ + #[repr(C)] pub struct PyArrayInterface { pub two: c_int, @@ -33,7 +57,6 @@ pub descr: *mut PyObject, } -#[derive(Copy, Clone, PartialEq)] pub enum ItemType { BOOL, F32, @@ -50,9 +73,42 @@ UnsupportedDataType, } -pub enum NumpyError { - NotAvailable, - InvalidType, +#[repr(transparent)] +pub struct NumpyScalar { + pub ptr: *mut pyo3::ffi::PyObject, +} + +impl NumpyScalar { + pub fn new(ptr: *mut PyObject) -> Self { + NumpyScalar { ptr } + } +} + +impl<'p> Serialize for NumpyScalar { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + unsafe { + let ob_type = ob_type!(self.ptr); + let scalar_types = NUMPY_TYPES.deref_mut().as_ref().unwrap(); + if ob_type == scalar_types.float64 { + (*(self.ptr as *mut NumpyFloat64)).serialize(serializer) + } else if ob_type == scalar_types.float32 { + (*(self.ptr as *mut NumpyFloat32)).serialize(serializer) + } else if ob_type == scalar_types.int64 { + (*(self.ptr as *mut NumpyInt64)).serialize(serializer) + } else if ob_type == scalar_types.int32 { + (*(self.ptr as *mut NumpyInt32)).serialize(serializer) + } else if ob_type == scalar_types.uint64 { + (*(self.ptr as *mut NumpyUint64)).serialize(serializer) + } else if ob_type == scalar_types.uint32 { + (*(self.ptr as *mut NumpyUint32)).serialize(serializer) + } else { + unreachable!() + } + } + } } // >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32) @@ -62,16 +118,15 @@ // (2, 2, 2) // >>> arr.strides // (16, 8, 4) -pub struct PyArray { +pub struct NumpyArray { array: *mut PyArrayInterface, position: Vec<isize>, - children: Vec<PyArray>, + children: Vec<NumpyArray>, depth: usize, capsule: *mut PyCapsule, } -impl<'a> PyArray { - #[cold] +impl<'a> NumpyArray { pub fn new(ptr: *mut PyObject) -> Result<Self, PyArrayError> { let capsule = ffi!(PyObject_GetAttr(ptr, ARRAY_STRUCT_STR)); let array = unsafe { (*(capsule as *mut PyCapsule)).pointer as *mut PyArrayInterface }; @@ -86,7 +141,7 @@ if num_dimensions == 0 { return Err(PyArrayError::UnsupportedDataType); } - let mut pyarray = PyArray { + let mut pyarray = NumpyArray { array: array, position: vec![0; num_dimensions], children: Vec::with_capacity(num_dimensions), @@ -105,7 +160,7 @@ } fn from_parent(&self, position: Vec<isize>, num_children: usize) -> Self { - let mut arr = PyArray { + let mut arr = NumpyArray { array: self.array, position: position, children: Vec::with_capacity(num_children), @@ -145,7 +200,7 @@ } } - fn data(&self) -> *mut c_void { + fn data(&self) -> *const c_void { let offset = self .strides() .iter() @@ -173,7 +228,7 @@ } } -impl Drop for PyArray { +impl Drop for NumpyArray { fn drop(&mut self) { if self.depth == 0 { ffi!(Py_XDECREF(self.capsule as *mut pyo3::ffi::PyObject)) @@ -181,7 +236,7 @@ } } -impl<'p> Serialize for PyArray { +impl<'p> Serialize for NumpyArray { fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, @@ -342,7 +397,6 @@ } #[repr(C)] -#[derive(Copy, Clone)] pub struct NumpyInt32 { pub ob_refcnt: Py_ssize_t, pub ob_type: *mut PyTypeObject, @@ -359,7 +413,6 @@ } #[repr(C)] -#[derive(Copy, Clone)] pub struct NumpyInt64 { pub ob_refcnt: Py_ssize_t, pub ob_type: *mut PyTypeObject, @@ -376,7 +429,6 @@ } #[repr(C)] -#[derive(Copy, Clone)] pub struct NumpyUint32 { pub ob_refcnt: Py_ssize_t, pub ob_type: *mut PyTypeObject, @@ -393,7 +445,6 @@ } #[repr(C)] -#[derive(Copy, Clone)] pub struct NumpyUint64 { pub ob_refcnt: Py_ssize_t, pub ob_type: *mut PyTypeObject, @@ -410,7 +461,6 @@ } #[repr(C)] -#[derive(Copy, Clone)] pub struct NumpyFloat32 { pub ob_refcnt: Py_ssize_t, pub ob_type: *mut PyTypeObject, @@ -427,7 +477,6 @@ } #[repr(C)] -#[derive(Copy, Clone)] pub struct NumpyFloat64 { pub ob_refcnt: Py_ssize_t, pub ob_type: *mut PyTypeObject, @@ -442,83 +491,3 @@ serializer.serialize_f64(self.value) } } - -pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool { - let available_types; - unsafe { - match NUMPY_TYPES.deref_mut() { - Some(v) => available_types = v, - _ => return false, - } - } - - let numpy_scalars = [ - available_types.float32, - available_types.float64, - available_types.int32, - available_types.int64, - available_types.uint32, - available_types.uint64, - ]; - numpy_scalars.contains(&ob_type) -} - -pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool { - let available_types; - unsafe { - match NUMPY_TYPES.deref_mut() { - Some(v) => available_types = v, - _ => return false, - } - } - available_types.array == ob_type -} - -// pub fn serialize_numpy_scalar<S>(obj: *mut pyo3::ffi::PyObject, serializer: S) -> Result<S::Ok, S::Error> -// where -// S: Serializer, -// { -// let ob_type = ob_type!(obj); -// let numpy = match ob_type { - -// } -// } - -pub enum NumpyObjects { - Float32(NumpyFloat32), - Float64(NumpyFloat64), - Int32(NumpyInt32), - Int64(NumpyInt64), - Uint32(NumpyUint32), - Uint64(NumpyUint64), -} - -pub fn pyobj_to_numpy_obj(obj: *mut pyo3::ffi::PyObject) -> Result<NumpyObjects, NumpyError> { - let available_types; - unsafe { - match NUMPY_TYPES.deref_mut() { - Some(v) => available_types = v, - _ => return Err(NumpyError::NotAvailable), - } - } - - let ob_type = ob_type!(obj); - - unsafe { - if ob_type == available_types.float32 { - return Ok(NumpyObjects::Float32(*(obj as *mut NumpyFloat32))); - } else if ob_type == available_types.float64 { - return Ok(NumpyObjects::Float64(*(obj as *mut NumpyFloat64))); - } else if ob_type == available_types.int32 { - return Ok(NumpyObjects::Int32(*(obj as *mut NumpyInt32))); - } else if ob_type == available_types.int64 { - return Ok(NumpyObjects::Int64(*(obj as *mut NumpyInt64))); - } else if ob_type == available_types.uint32 { - return Ok(NumpyObjects::Uint32(*(obj as *mut NumpyUint32))); - } else if ob_type == available_types.uint64 { - return Ok(NumpyObjects::Uint64(*(obj as *mut NumpyUint64))); - } else { - return Err(NumpyError::InvalidType); - } - } -} diff --git a/src/typeref.rs b/src/typeref.rs --- a/src/typeref.rs +++ b/src/typeref.rs @@ -7,13 +7,13 @@ use std::sync::Once; pub struct NumpyTypes { + pub float64: *mut PyTypeObject, pub float32: *mut PyTypeObject, - pub array: *mut PyTypeObject, - pub float64: *mut PyTypeObject, + pub int64: *mut PyTypeObject, pub int32: *mut PyTypeObject, - pub int64: *mut PyTypeObject, + pub uint64: *mut PyTypeObject, pub uint32: *mut PyTypeObject, - pub uint64: *mut PyTypeObject, + pub array: *mut PyTypeObject, } pub static mut HASH_SEED: u64 = 0;