# HG changeset patch
# User Aviram Hassan <aviram.hassan@biocatch.com>
# Date 1594207330 -10800
#      Wed Jul 08 14:22:10 2020 +0300
# Node ID 9817391863459df8020b4cc14c1f47dfe8ef8676
# Parent  83db8324a4cb41c42fc1d0096bca1ffe6078f677
add numpy primitives encoding

diff --git a/src/serialize/dict.rs b/src/serialize/dict.rs
--- a/src/serialize/dict.rs
+++ b/src/serialize/dict.rs
@@ -215,7 +215,8 @@
                 }
             }
             ObType::Tuple
-            | ObType::Array
+            | ObType::NumpyScalar
+            | ObType::NumpyArray
             | ObType::Dict
             | ObType::List
             | ObType::Dataclass
diff --git a/src/serialize/encode.rs b/src/serialize/encode.rs
--- a/src/serialize/encode.rs
+++ b/src/serialize/encode.rs
@@ -1,7 +1,6 @@
 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
 
 use crate::exc::*;
-use crate::ffi::PyDict_GET_SIZE;
 use crate::ffi::*;
 use crate::opt::*;
 use crate::serialize::dataclass::*;
@@ -33,7 +32,7 @@
     let mut buf = BytesWriter::new();
     let obtype = pyobject_to_obtype(ptr, opts);
     match obtype {
-        ObType::List | ObType::Dict | ObType::Dataclass | ObType::Array => {
+        ObType::List | ObType::Dict | ObType::Dataclass | ObType::NumpyArray => {
             buf.resize(1024);
         }
         _ => {}
@@ -75,7 +74,8 @@
     Tuple,
     Uuid,
     Dataclass,
-    Array,
+    NumpyScalar,
+    NumpyArray,
     Enum,
     StrSubclass,
     Unknown,
@@ -145,11 +145,10 @@
             ObType::Dict
         } else if ffi!(PyDict_Contains((*ob_type).tp_dict, DATACLASS_FIELDS_STR)) == 1 {
             ObType::Dataclass
-        } else if opts & SERIALIZE_NUMPY != 0
-            && ARRAY_TYPE.is_some()
-            && ob_type == ARRAY_TYPE.unwrap().as_ptr()
-        {
-            ObType::Array
+        } else if opts & SERIALIZE_NUMPY != 0 && is_numpy_scalar(ob_type) {
+            ObType::NumpyScalar
+        } else if opts & SERIALIZE_NUMPY != 0 && is_numpy_array(ob_type) {
+            ObType::NumpyArray
         } else {
             ObType::Unknown
         }
@@ -433,7 +432,7 @@
                 )
                 .serialize(serializer)
             }
-            ObType::Array => match PyArray::new(self.ptr) {
+            ObType::NumpyArray => match PyArray::new(self.ptr) {
                 Ok(val) => val.serialize(serializer),
                 Err(PyArrayError::Malformed) => err!("numpy array is malformed"),
                 Err(PyArrayError::NotContiguous) | Err(PyArrayError::UnsupportedDataType) => {
@@ -447,6 +446,18 @@
                     .serialize(serializer)
                 }
             },
+            ObType::NumpyScalar => match pyobj_to_numpy_obj(self.ptr) {
+                Ok(numpy_obj) => match numpy_obj {
+                    NumpyObjects::Float32(obj) => obj.serialize(serializer),
+                    NumpyObjects::Float64(obj) => obj.serialize(serializer),
+                    NumpyObjects::Int32(obj) => obj.serialize(serializer),
+                    NumpyObjects::Int64(obj) => obj.serialize(serializer),
+                    NumpyObjects::Uint32(obj) => obj.serialize(serializer),
+                    NumpyObjects::Uint64(obj) => obj.serialize(serializer),
+                },
+                Err(NumpyError::InvalidType) => err!("invalid numpy type"),
+                Err(NumpyError::NotAvailable) => err!("numpy not available"),
+            },
             ObType::Unknown => DefaultSerializer::new(
                 self.ptr,
                 self.opts,
diff --git a/src/serialize/numpy.rs b/src/serialize/numpy.rs
--- a/src/serialize/numpy.rs
+++ b/src/serialize/numpy.rs
@@ -1,8 +1,7 @@
-// SPDX-License-Identifier: (Apache-2.0 OR MIT)
-
-use crate::typeref::ARRAY_STRUCT_STR;
+use crate::typeref::{ARRAY_STRUCT_STR, NUMPY_TYPES};
 use pyo3::ffi::*;
 use serde::ser::{Serialize, SerializeSeq, Serializer};
+use std::ops::DerefMut;
 use std::os::raw::{c_char, c_int, c_void};
 
 macro_rules! slice {
@@ -21,8 +20,6 @@
     pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *);
 }
 
-// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__
-
 #[repr(C)]
 pub struct PyArrayInterface {
     pub two: c_int,
@@ -53,6 +50,11 @@
     UnsupportedDataType,
 }
 
+pub enum NumpyError {
+    NotAvailable,
+    InvalidType,
+}
+
 // >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32)
 // >>> arr.ndim
 // 3
@@ -256,7 +258,7 @@
 }
 
 #[repr(transparent)]
-struct DataTypeF64 {
+pub struct DataTypeF64 {
     pub obj: f64,
 }
 
@@ -270,7 +272,7 @@
 }
 
 #[repr(transparent)]
-struct DataTypeI32 {
+pub struct DataTypeI32 {
     pub obj: i32,
 }
 
@@ -284,7 +286,7 @@
 }
 
 #[repr(transparent)]
-struct DataTypeI64 {
+pub struct DataTypeI64 {
     pub obj: i64,
 }
 
@@ -298,7 +300,7 @@
 }
 
 #[repr(transparent)]
-struct DataTypeU32 {
+pub struct DataTypeU32 {
     pub obj: u32,
 }
 
@@ -312,7 +314,7 @@
 }
 
 #[repr(transparent)]
-struct DataTypeU64 {
+pub struct DataTypeU64 {
     pub obj: u64,
 }
 
@@ -326,7 +328,7 @@
 }
 
 #[repr(transparent)]
-struct DataTypeBOOL {
+pub struct DataTypeBOOL {
     pub obj: u8,
 }
 
@@ -338,3 +340,185 @@
         serializer.serialize_bool(self.obj == 1)
     }
 }
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct NumpyInt32 {
+    pub ob_refcnt: Py_ssize_t,
+    pub ob_type: *mut PyTypeObject,
+    pub value: i32,
+}
+
+impl<'p> Serialize for NumpyInt32 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_i32(self.value)
+    }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct NumpyInt64 {
+    pub ob_refcnt: Py_ssize_t,
+    pub ob_type: *mut PyTypeObject,
+    pub value: i64,
+}
+
+impl<'p> Serialize for NumpyInt64 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_i64(self.value)
+    }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct NumpyUint32 {
+    pub ob_refcnt: Py_ssize_t,
+    pub ob_type: *mut PyTypeObject,
+    pub value: u32,
+}
+
+impl<'p> Serialize for NumpyUint32 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_u32(self.value)
+    }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct NumpyUint64 {
+    pub ob_refcnt: Py_ssize_t,
+    pub ob_type: *mut PyTypeObject,
+    pub value: u64,
+}
+
+impl<'p> Serialize for NumpyUint64 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_u64(self.value)
+    }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct NumpyFloat32 {
+    pub ob_refcnt: Py_ssize_t,
+    pub ob_type: *mut PyTypeObject,
+    pub value: f32,
+}
+
+impl<'p> Serialize for NumpyFloat32 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_f32(self.value)
+    }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct NumpyFloat64 {
+    pub ob_refcnt: Py_ssize_t,
+    pub ob_type: *mut PyTypeObject,
+    pub value: f64,
+}
+
+impl<'p> Serialize for NumpyFloat64 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_f64(self.value)
+    }
+}
+
+pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool {
+    let available_types;
+    unsafe {
+        match NUMPY_TYPES.deref_mut() {
+            Some(v) => available_types = v,
+            _ => return false,
+        }
+    }
+
+    let numpy_scalars = [
+        available_types.float32,
+        available_types.float64,
+        available_types.int32,
+        available_types.int64,
+        available_types.uint32,
+        available_types.uint64,
+    ];
+    numpy_scalars.contains(&ob_type)
+}
+
+pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool {
+    let available_types;
+    unsafe {
+        match NUMPY_TYPES.deref_mut() {
+            Some(v) => available_types = v,
+            _ => return false,
+        }
+    }
+    available_types.array == ob_type
+}
+
+// pub fn serialize_numpy_scalar<S>(obj: *mut pyo3::ffi::PyObject, serializer: S) -> Result<S::Ok, S::Error>
+// where
+//     S: Serializer,
+//     {
+//         let ob_type = ob_type!(obj);
+//         let numpy = match ob_type {
+
+//         }
+//     }
+
+pub enum NumpyObjects {
+    Float32(NumpyFloat32),
+    Float64(NumpyFloat64),
+    Int32(NumpyInt32),
+    Int64(NumpyInt64),
+    Uint32(NumpyUint32),
+    Uint64(NumpyUint64),
+}
+
+pub fn pyobj_to_numpy_obj(obj: *mut pyo3::ffi::PyObject) -> Result<NumpyObjects, NumpyError> {
+    let available_types;
+    unsafe {
+        match NUMPY_TYPES.deref_mut() {
+            Some(v) => available_types = v,
+            _ => return Err(NumpyError::NotAvailable),
+        }
+    }
+
+    let ob_type = ob_type!(obj);
+
+    unsafe {
+        if ob_type == available_types.float32 {
+            return Ok(NumpyObjects::Float32(*(obj as *mut NumpyFloat32)));
+        } else if ob_type == available_types.float64 {
+            return Ok(NumpyObjects::Float64(*(obj as *mut NumpyFloat64)));
+        } else if ob_type == available_types.int32 {
+            return Ok(NumpyObjects::Int32(*(obj as *mut NumpyInt32)));
+        } else if ob_type == available_types.int64 {
+            return Ok(NumpyObjects::Int64(*(obj as *mut NumpyInt64)));
+        } else if ob_type == available_types.uint32 {
+            return Ok(NumpyObjects::Uint32(*(obj as *mut NumpyUint32)));
+        } else if ob_type == available_types.uint64 {
+            return Ok(NumpyObjects::Uint64(*(obj as *mut NumpyUint64)));
+        } else {
+            return Err(NumpyError::InvalidType);
+        }
+    }
+}
diff --git a/src/typeref.rs b/src/typeref.rs
--- a/src/typeref.rs
+++ b/src/typeref.rs
@@ -6,6 +6,15 @@
 use std::ptr::NonNull;
 use std::sync::Once;
 
+pub struct NumpyTypes {
+    pub float32: *mut PyTypeObject,
+    pub array: *mut PyTypeObject,
+    pub float64: *mut PyTypeObject,
+    pub int32: *mut PyTypeObject,
+    pub int64: *mut PyTypeObject,
+    pub uint32: *mut PyTypeObject,
+    pub uint64: *mut PyTypeObject,
+}
 pub static mut HASH_SEED: u64 = 0;
 
 pub static mut NONE: *mut PyObject = 0 as *mut PyObject;
@@ -25,8 +34,7 @@
 pub static mut TUPLE_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
 pub static mut UUID_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
 pub static mut ENUM_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
-pub static mut ARRAY_TYPE: Lazy<Option<NonNull<PyTypeObject>>> =
-    Lazy::new(|| unsafe { look_up_array_type() });
+pub static mut NUMPY_TYPES: Lazy<Option<NumpyTypes>> = Lazy::new(|| unsafe { load_numpy_types() });
 pub static mut FIELD_TYPE: Lazy<NonNull<PyObject>> = Lazy::new(|| unsafe { look_up_field_type() });
 
 pub static mut BYTES_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
@@ -116,19 +124,35 @@
     res
 }
 
-unsafe fn look_up_array_type() -> Option<NonNull<PyTypeObject>> {
+unsafe fn look_up_numpy_type(
+    numpy_module: *mut PyObject,
+    np_type: &str,
+) -> Option<NonNull<PyTypeObject>> {
+    let mod_dict = PyModule_GetDict(numpy_module);
+    let ptr = PyMapping_GetItemString(mod_dict, np_type.as_ptr() as *const c_char);
+    Py_XDECREF(ptr);
+    // Py_XDECREF(mod_dict) causes segfault when pytest exits
+    Some(NonNull::new_unchecked(ptr as *mut PyTypeObject))
+}
+
+unsafe fn load_numpy_types() -> Option<NumpyTypes> {
     let numpy = PyImport_ImportModule("numpy\0".as_ptr() as *const c_char);
     if numpy.is_null() {
         PyErr_Clear();
         return None;
-    } else {
-        let mod_dict = PyModule_GetDict(numpy);
-        let ptr = PyMapping_GetItemString(mod_dict, "ndarray\0".as_ptr() as *const c_char);
-        Py_XDECREF(ptr);
-        // Py_XDECREF(mod_dict) causes segfault when pytest exits
-        Py_XDECREF(numpy);
-        Some(NonNull::new_unchecked(ptr as *mut PyTypeObject))
     }
+
+    let types = Some(NumpyTypes {
+        array: look_up_numpy_type(numpy, "ndarray\0")?.as_ptr(),
+        float32: look_up_numpy_type(numpy, "float32\0")?.as_ptr(),
+        float64: look_up_numpy_type(numpy, "float64\0")?.as_ptr(),
+        int32: look_up_numpy_type(numpy, "int32\0")?.as_ptr(),
+        int64: look_up_numpy_type(numpy, "int64\0")?.as_ptr(),
+        uint32: look_up_numpy_type(numpy, "uint32\0")?.as_ptr(),
+        uint64: look_up_numpy_type(numpy, "uint64\0")?.as_ptr(),
+    });
+    Py_XDECREF(numpy);
+    types
 }
 
 unsafe fn look_up_field_type() -> NonNull<PyObject> {
diff --git a/test/test_numpy.py b/test/test_numpy.py
--- a/test/test_numpy.py
+++ b/test/test_numpy.py
@@ -2,10 +2,9 @@
 
 import unittest
 
+import orjson
 import pytest
 
-import orjson
-
 try:
     import numpy
 except ImportError:
@@ -146,10 +145,6 @@
             b"[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]",
         )
 
-    def test_numpy_array_d0(self):
-        with self.assertRaises(orjson.JSONEncodeError):
-            orjson.dumps(numpy.int32(1), option=orjson.OPT_SERIALIZE_NUMPY)
-
     def test_numpy_array_fotran(self):
         array = numpy.array([[1, 2], [3, 4]], order="F")
         assert array.flags["F_CONTIGUOUS"] == True
@@ -254,3 +249,59 @@
             orjson.loads(orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY,)),
             array.tolist(),
         )
+
+    def test_numpy_primitives(self):
+        # int32
+        self.assertEqual(
+            orjson.dumps(numpy.int32(1), option=orjson.OPT_SERIALIZE_NUMPY), b"1"
+        )
+        self.assertEqual(
+            orjson.dumps(numpy.int32(2147483647), option=orjson.OPT_SERIALIZE_NUMPY),
+            b"2147483647",
+        )
+        self.assertEqual(
+            orjson.dumps(numpy.int32(-2147483648), option=orjson.OPT_SERIALIZE_NUMPY),
+            b"-2147483648",
+        )
+        # int 64
+        self.assertEqual(
+            orjson.dumps(
+                numpy.int64(-9223372036854775808), option=orjson.OPT_SERIALIZE_NUMPY
+            ),
+            b"-9223372036854775808",
+        )
+        self.assertEqual(
+            orjson.dumps(
+                numpy.int64(9223372036854775807), option=orjson.OPT_SERIALIZE_NUMPY
+            ),
+            b"9223372036854775807",
+        )
+        # uint32
+        self.assertEqual(
+            orjson.dumps(numpy.uint32(0), option=orjson.OPT_SERIALIZE_NUMPY), b"0"
+        )
+        self.assertEqual(
+            orjson.dumps(numpy.uint32(4294967295), option=orjson.OPT_SERIALIZE_NUMPY),
+            b"4294967295",
+        )
+        # uint64
+        self.assertEqual(
+            orjson.dumps(numpy.uint64(0), option=orjson.OPT_SERIALIZE_NUMPY), b"0"
+        )
+        self.assertEqual(
+            orjson.dumps(
+                numpy.uint64(18446744073709551615), option=orjson.OPT_SERIALIZE_NUMPY
+            ),
+            b"18446744073709551615",
+        )
+
+        # float32
+        self.assertEqual(
+            orjson.dumps(numpy.float32(1.0), option=orjson.OPT_SERIALIZE_NUMPY), b"1.0"
+        )
+
+        # float64
+        self.assertEqual(
+            orjson.dumps(numpy.float64(123.123), option=orjson.OPT_SERIALIZE_NUMPY),
+            b"123.123",
+        )