# HG changeset patch # User Aviram Hassan <aviram.hassan@biocatch.com> # Date 1594207330 -10800 # Wed Jul 08 14:22:10 2020 +0300 # Node ID 9817391863459df8020b4cc14c1f47dfe8ef8676 # Parent 83db8324a4cb41c42fc1d0096bca1ffe6078f677 add numpy primitives encoding diff --git a/src/serialize/dict.rs b/src/serialize/dict.rs --- a/src/serialize/dict.rs +++ b/src/serialize/dict.rs @@ -215,7 +215,8 @@ } } ObType::Tuple - | ObType::Array + | ObType::NumpyScalar + | ObType::NumpyArray | ObType::Dict | ObType::List | ObType::Dataclass diff --git a/src/serialize/encode.rs b/src/serialize/encode.rs --- a/src/serialize/encode.rs +++ b/src/serialize/encode.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) use crate::exc::*; -use crate::ffi::PyDict_GET_SIZE; use crate::ffi::*; use crate::opt::*; use crate::serialize::dataclass::*; @@ -33,7 +32,7 @@ let mut buf = BytesWriter::new(); let obtype = pyobject_to_obtype(ptr, opts); match obtype { - ObType::List | ObType::Dict | ObType::Dataclass | ObType::Array => { + ObType::List | ObType::Dict | ObType::Dataclass | ObType::NumpyArray => { buf.resize(1024); } _ => {} @@ -75,7 +74,8 @@ Tuple, Uuid, Dataclass, - Array, + NumpyScalar, + NumpyArray, Enum, StrSubclass, Unknown, @@ -145,11 +145,10 @@ ObType::Dict } else if ffi!(PyDict_Contains((*ob_type).tp_dict, DATACLASS_FIELDS_STR)) == 1 { ObType::Dataclass - } else if opts & SERIALIZE_NUMPY != 0 - && ARRAY_TYPE.is_some() - && ob_type == ARRAY_TYPE.unwrap().as_ptr() - { - ObType::Array + } else if opts & SERIALIZE_NUMPY != 0 && is_numpy_scalar(ob_type) { + ObType::NumpyScalar + } else if opts & SERIALIZE_NUMPY != 0 && is_numpy_array(ob_type) { + ObType::NumpyArray } else { ObType::Unknown } @@ -433,7 +432,7 @@ ) .serialize(serializer) } - ObType::Array => match PyArray::new(self.ptr) { + ObType::NumpyArray => match PyArray::new(self.ptr) { Ok(val) => val.serialize(serializer), Err(PyArrayError::Malformed) => err!("numpy array is malformed"), Err(PyArrayError::NotContiguous) | Err(PyArrayError::UnsupportedDataType) => { @@ -447,6 +446,18 @@ .serialize(serializer) } }, + ObType::NumpyScalar => match pyobj_to_numpy_obj(self.ptr) { + Ok(numpy_obj) => match numpy_obj { + NumpyObjects::Float32(obj) => obj.serialize(serializer), + NumpyObjects::Float64(obj) => obj.serialize(serializer), + NumpyObjects::Int32(obj) => obj.serialize(serializer), + NumpyObjects::Int64(obj) => obj.serialize(serializer), + NumpyObjects::Uint32(obj) => obj.serialize(serializer), + NumpyObjects::Uint64(obj) => obj.serialize(serializer), + }, + Err(NumpyError::InvalidType) => err!("invalid numpy type"), + Err(NumpyError::NotAvailable) => err!("numpy not available"), + }, ObType::Unknown => DefaultSerializer::new( self.ptr, self.opts, diff --git a/src/serialize/numpy.rs b/src/serialize/numpy.rs --- a/src/serialize/numpy.rs +++ b/src/serialize/numpy.rs @@ -1,8 +1,7 @@ -// SPDX-License-Identifier: (Apache-2.0 OR MIT) - -use crate::typeref::ARRAY_STRUCT_STR; +use crate::typeref::{ARRAY_STRUCT_STR, NUMPY_TYPES}; use pyo3::ffi::*; use serde::ser::{Serialize, SerializeSeq, Serializer}; +use std::ops::DerefMut; use std::os::raw::{c_char, c_int, c_void}; macro_rules! slice { @@ -21,8 +20,6 @@ pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *); } -// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__ - #[repr(C)] pub struct PyArrayInterface { pub two: c_int, @@ -53,6 +50,11 @@ UnsupportedDataType, } +pub enum NumpyError { + NotAvailable, + InvalidType, +} + // >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32) // >>> arr.ndim // 3 @@ -256,7 +258,7 @@ } #[repr(transparent)] -struct DataTypeF64 { +pub struct DataTypeF64 { pub obj: f64, } @@ -270,7 +272,7 @@ } #[repr(transparent)] -struct DataTypeI32 { +pub struct DataTypeI32 { pub obj: i32, } @@ -284,7 +286,7 @@ } #[repr(transparent)] -struct DataTypeI64 { +pub struct DataTypeI64 { pub obj: i64, } @@ -298,7 +300,7 @@ } #[repr(transparent)] -struct DataTypeU32 { +pub struct DataTypeU32 { pub obj: u32, } @@ -312,7 +314,7 @@ } #[repr(transparent)] -struct DataTypeU64 { +pub struct DataTypeU64 { pub obj: u64, } @@ -326,7 +328,7 @@ } #[repr(transparent)] -struct DataTypeBOOL { +pub struct DataTypeBOOL { pub obj: u8, } @@ -338,3 +340,185 @@ serializer.serialize_bool(self.obj == 1) } } + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct NumpyInt32 { + pub ob_refcnt: Py_ssize_t, + pub ob_type: *mut PyTypeObject, + pub value: i32, +} + +impl<'p> Serialize for NumpyInt32 { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_i32(self.value) + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct NumpyInt64 { + pub ob_refcnt: Py_ssize_t, + pub ob_type: *mut PyTypeObject, + pub value: i64, +} + +impl<'p> Serialize for NumpyInt64 { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_i64(self.value) + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct NumpyUint32 { + pub ob_refcnt: Py_ssize_t, + pub ob_type: *mut PyTypeObject, + pub value: u32, +} + +impl<'p> Serialize for NumpyUint32 { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_u32(self.value) + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct NumpyUint64 { + pub ob_refcnt: Py_ssize_t, + pub ob_type: *mut PyTypeObject, + pub value: u64, +} + +impl<'p> Serialize for NumpyUint64 { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_u64(self.value) + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct NumpyFloat32 { + pub ob_refcnt: Py_ssize_t, + pub ob_type: *mut PyTypeObject, + pub value: f32, +} + +impl<'p> Serialize for NumpyFloat32 { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_f32(self.value) + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct NumpyFloat64 { + pub ob_refcnt: Py_ssize_t, + pub ob_type: *mut PyTypeObject, + pub value: f64, +} + +impl<'p> Serialize for NumpyFloat64 { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_f64(self.value) + } +} + +pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool { + let available_types; + unsafe { + match NUMPY_TYPES.deref_mut() { + Some(v) => available_types = v, + _ => return false, + } + } + + let numpy_scalars = [ + available_types.float32, + available_types.float64, + available_types.int32, + available_types.int64, + available_types.uint32, + available_types.uint64, + ]; + numpy_scalars.contains(&ob_type) +} + +pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool { + let available_types; + unsafe { + match NUMPY_TYPES.deref_mut() { + Some(v) => available_types = v, + _ => return false, + } + } + available_types.array == ob_type +} + +// pub fn serialize_numpy_scalar<S>(obj: *mut pyo3::ffi::PyObject, serializer: S) -> Result<S::Ok, S::Error> +// where +// S: Serializer, +// { +// let ob_type = ob_type!(obj); +// let numpy = match ob_type { + +// } +// } + +pub enum NumpyObjects { + Float32(NumpyFloat32), + Float64(NumpyFloat64), + Int32(NumpyInt32), + Int64(NumpyInt64), + Uint32(NumpyUint32), + Uint64(NumpyUint64), +} + +pub fn pyobj_to_numpy_obj(obj: *mut pyo3::ffi::PyObject) -> Result<NumpyObjects, NumpyError> { + let available_types; + unsafe { + match NUMPY_TYPES.deref_mut() { + Some(v) => available_types = v, + _ => return Err(NumpyError::NotAvailable), + } + } + + let ob_type = ob_type!(obj); + + unsafe { + if ob_type == available_types.float32 { + return Ok(NumpyObjects::Float32(*(obj as *mut NumpyFloat32))); + } else if ob_type == available_types.float64 { + return Ok(NumpyObjects::Float64(*(obj as *mut NumpyFloat64))); + } else if ob_type == available_types.int32 { + return Ok(NumpyObjects::Int32(*(obj as *mut NumpyInt32))); + } else if ob_type == available_types.int64 { + return Ok(NumpyObjects::Int64(*(obj as *mut NumpyInt64))); + } else if ob_type == available_types.uint32 { + return Ok(NumpyObjects::Uint32(*(obj as *mut NumpyUint32))); + } else if ob_type == available_types.uint64 { + return Ok(NumpyObjects::Uint64(*(obj as *mut NumpyUint64))); + } else { + return Err(NumpyError::InvalidType); + } + } +} diff --git a/src/typeref.rs b/src/typeref.rs --- a/src/typeref.rs +++ b/src/typeref.rs @@ -6,6 +6,15 @@ use std::ptr::NonNull; use std::sync::Once; +pub struct NumpyTypes { + pub float32: *mut PyTypeObject, + pub array: *mut PyTypeObject, + pub float64: *mut PyTypeObject, + pub int32: *mut PyTypeObject, + pub int64: *mut PyTypeObject, + pub uint32: *mut PyTypeObject, + pub uint64: *mut PyTypeObject, +} pub static mut HASH_SEED: u64 = 0; pub static mut NONE: *mut PyObject = 0 as *mut PyObject; @@ -25,8 +34,7 @@ pub static mut TUPLE_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject; pub static mut UUID_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject; pub static mut ENUM_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject; -pub static mut ARRAY_TYPE: Lazy<Option<NonNull<PyTypeObject>>> = - Lazy::new(|| unsafe { look_up_array_type() }); +pub static mut NUMPY_TYPES: Lazy<Option<NumpyTypes>> = Lazy::new(|| unsafe { load_numpy_types() }); pub static mut FIELD_TYPE: Lazy<NonNull<PyObject>> = Lazy::new(|| unsafe { look_up_field_type() }); pub static mut BYTES_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject; @@ -116,19 +124,35 @@ res } -unsafe fn look_up_array_type() -> Option<NonNull<PyTypeObject>> { +unsafe fn look_up_numpy_type( + numpy_module: *mut PyObject, + np_type: &str, +) -> Option<NonNull<PyTypeObject>> { + let mod_dict = PyModule_GetDict(numpy_module); + let ptr = PyMapping_GetItemString(mod_dict, np_type.as_ptr() as *const c_char); + Py_XDECREF(ptr); + // Py_XDECREF(mod_dict) causes segfault when pytest exits + Some(NonNull::new_unchecked(ptr as *mut PyTypeObject)) +} + +unsafe fn load_numpy_types() -> Option<NumpyTypes> { let numpy = PyImport_ImportModule("numpy\0".as_ptr() as *const c_char); if numpy.is_null() { PyErr_Clear(); return None; - } else { - let mod_dict = PyModule_GetDict(numpy); - let ptr = PyMapping_GetItemString(mod_dict, "ndarray\0".as_ptr() as *const c_char); - Py_XDECREF(ptr); - // Py_XDECREF(mod_dict) causes segfault when pytest exits - Py_XDECREF(numpy); - Some(NonNull::new_unchecked(ptr as *mut PyTypeObject)) } + + let types = Some(NumpyTypes { + array: look_up_numpy_type(numpy, "ndarray\0")?.as_ptr(), + float32: look_up_numpy_type(numpy, "float32\0")?.as_ptr(), + float64: look_up_numpy_type(numpy, "float64\0")?.as_ptr(), + int32: look_up_numpy_type(numpy, "int32\0")?.as_ptr(), + int64: look_up_numpy_type(numpy, "int64\0")?.as_ptr(), + uint32: look_up_numpy_type(numpy, "uint32\0")?.as_ptr(), + uint64: look_up_numpy_type(numpy, "uint64\0")?.as_ptr(), + }); + Py_XDECREF(numpy); + types } unsafe fn look_up_field_type() -> NonNull<PyObject> { diff --git a/test/test_numpy.py b/test/test_numpy.py --- a/test/test_numpy.py +++ b/test/test_numpy.py @@ -2,10 +2,9 @@ import unittest +import orjson import pytest -import orjson - try: import numpy except ImportError: @@ -146,10 +145,6 @@ b"[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]", ) - def test_numpy_array_d0(self): - with self.assertRaises(orjson.JSONEncodeError): - orjson.dumps(numpy.int32(1), option=orjson.OPT_SERIALIZE_NUMPY) - def test_numpy_array_fotran(self): array = numpy.array([[1, 2], [3, 4]], order="F") assert array.flags["F_CONTIGUOUS"] == True @@ -254,3 +249,59 @@ orjson.loads(orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY,)), array.tolist(), ) + + def test_numpy_primitives(self): + # int32 + self.assertEqual( + orjson.dumps(numpy.int32(1), option=orjson.OPT_SERIALIZE_NUMPY), b"1" + ) + self.assertEqual( + orjson.dumps(numpy.int32(2147483647), option=orjson.OPT_SERIALIZE_NUMPY), + b"2147483647", + ) + self.assertEqual( + orjson.dumps(numpy.int32(-2147483648), option=orjson.OPT_SERIALIZE_NUMPY), + b"-2147483648", + ) + # int 64 + self.assertEqual( + orjson.dumps( + numpy.int64(-9223372036854775808), option=orjson.OPT_SERIALIZE_NUMPY + ), + b"-9223372036854775808", + ) + self.assertEqual( + orjson.dumps( + numpy.int64(9223372036854775807), option=orjson.OPT_SERIALIZE_NUMPY + ), + b"9223372036854775807", + ) + # uint32 + self.assertEqual( + orjson.dumps(numpy.uint32(0), option=orjson.OPT_SERIALIZE_NUMPY), b"0" + ) + self.assertEqual( + orjson.dumps(numpy.uint32(4294967295), option=orjson.OPT_SERIALIZE_NUMPY), + b"4294967295", + ) + # uint64 + self.assertEqual( + orjson.dumps(numpy.uint64(0), option=orjson.OPT_SERIALIZE_NUMPY), b"0" + ) + self.assertEqual( + orjson.dumps( + numpy.uint64(18446744073709551615), option=orjson.OPT_SERIALIZE_NUMPY + ), + b"18446744073709551615", + ) + + # float32 + self.assertEqual( + orjson.dumps(numpy.float32(1.0), option=orjson.OPT_SERIALIZE_NUMPY), b"1.0" + ) + + # float64 + self.assertEqual( + orjson.dumps(numpy.float64(123.123), option=orjson.OPT_SERIALIZE_NUMPY), + b"123.123", + )