diff --git a/src/serialize/numpy.rs b/src/serialize/numpy.rs index 54f6f13a86196e1e0b70a2757b60892da51d4e22_c3JjL3NlcmlhbGl6ZS9udW1weS5ycw==..003c12335774cec4816bdefc6d38b7c4ec126665_c3JjL3NlcmlhbGl6ZS9udW1weS5ycw== 100644 --- a/src/serialize/numpy.rs +++ b/src/serialize/numpy.rs @@ -242,23 +242,9 @@ S: Serializer, { let mut seq = serializer.serialize_seq(None).unwrap(); - if !self.children.is_empty() { - for child in &self.children { - seq.serialize_element(child).unwrap(); - } - } else { - let data_ptr = self.data(); - let num_items = self.num_items(); - match self.kind().unwrap() { - ItemType::F64 => { - let slice: &[f64] = slice!(data_ptr as *const f64, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeF64 { obj: each }).unwrap(); - } - } - ItemType::F32 => { - let slice: &[f32] = slice!(data_ptr as *const f32, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeF32 { obj: each }).unwrap(); - } + + if self.depth >= self.shape().len() || self.shape()[self.depth] != 0 { + if !self.children.is_empty() { + for child in &self.children { + seq.serialize_element(child).unwrap(); } @@ -264,6 +250,12 @@ } - ItemType::I64 => { - let slice: &[i64] = slice!(data_ptr as *const i64, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeI64 { obj: each }).unwrap(); + + } else { + let data_ptr = self.data(); + let num_items = self.num_items(); + match self.kind().unwrap() { + ItemType::F64 => { + let slice: &[f64] = slice!(data_ptr as *const f64, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeF64 { obj: each }).unwrap(); + } } @@ -269,7 +261,13 @@ } - } - ItemType::I32 => { - let slice: &[i32] = slice!(data_ptr as *const i32, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeI32 { obj: each }).unwrap(); + ItemType::F32 => { + let slice: &[f32] = slice!(data_ptr as *const f32, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeF32 { obj: each }).unwrap(); + } + } + ItemType::I64 => { + let slice: &[i64] = slice!(data_ptr as *const i64, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeI64 { obj: each }).unwrap(); + } } @@ -275,7 +273,13 @@ } - } - ItemType::U64 => { - let slice: &[u64] = slice!(data_ptr as *const u64, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeU64 { obj: each }).unwrap(); + ItemType::I32 => { + let slice: &[i32] = slice!(data_ptr as *const i32, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeI32 { obj: each }).unwrap(); + } + } + ItemType::U64 => { + let slice: &[u64] = slice!(data_ptr as *const u64, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeU64 { obj: each }).unwrap(); + } } @@ -281,7 +285,7 @@ } - } - ItemType::U32 => { - let slice: &[u32] = slice!(data_ptr as *const u32, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeU32 { obj: each }).unwrap(); + ItemType::U32 => { + let slice: &[u32] = slice!(data_ptr as *const u32, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeU32 { obj: each }).unwrap(); + } } @@ -287,9 +291,9 @@ } - } - ItemType::BOOL => { - let slice: &[u8] = slice!(data_ptr as *const u8, num_items); - for &each in slice.iter() { - seq.serialize_element(&DataTypeBOOL { obj: each }).unwrap(); + ItemType::BOOL => { + let slice: &[u8] = slice!(data_ptr as *const u8, num_items); + for &each in slice.iter() { + seq.serialize_element(&DataTypeBOOL { obj: each }).unwrap(); + } } } } diff --git a/test/test_numpy.py b/test/test_numpy.py index 54f6f13a86196e1e0b70a2757b60892da51d4e22_dGVzdC90ZXN0X251bXB5LnB5..003c12335774cec4816bdefc6d38b7c4ec126665_dGVzdC90ZXN0X251bXB5LnB5 100644 --- a/test/test_numpy.py +++ b/test/test_numpy.py @@ -235,6 +235,39 @@ with self.assertRaises(orjson.JSONEncodeError): orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY) + array = numpy.empty((0, 4, 2)) + self.assertEqual( + orjson.loads( + orjson.dumps( + array, + option=orjson.OPT_SERIALIZE_NUMPY, + ) + ), + array.tolist(), + ) + + array = numpy.empty((4, 0, 2)) + self.assertEqual( + orjson.loads( + orjson.dumps( + array, + option=orjson.OPT_SERIALIZE_NUMPY, + ) + ), + array.tolist(), + ) + + array = numpy.empty((2, 4, 0)) + self.assertEqual( + orjson.loads( + orjson.dumps( + array, + option=orjson.OPT_SERIALIZE_NUMPY, + ) + ), + array.tolist(), + ) + def test_numpy_array_dimension_max(self): array = numpy.random.rand( 1,