# HG changeset patch
# User Bob Ippolito <bob@redivi.com>
# Date 1336680348 25200
#      Thu May 10 13:05:48 2012 -0700
# Node ID 42e60cf4e8ce0e6516e8bd916ad0ec4c591b7498
# Parent  bc00ddd724e1e0dc7410d0eed6c0c52d355b2739
Support for use_decimal=True in sub-interpreters

diff --git a/CHANGES.txt b/CHANGES.txt
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,3 +1,9 @@
+Version 2.5.1 released 2012-05-10
+
+* Support for use_decimal=True in environments that use Python
+  sub-interpreters such as uWSGI
+  https://github.com/simplejson/simplejson/issues/34
+
 Version 2.5.0 released 2012-03-29
 
 * New item_sort_key option for encoder to allow fine grained control of sorted
diff --git a/conf.py b/conf.py
--- a/conf.py
+++ b/conf.py
@@ -44,7 +44,7 @@
 # The short X.Y version.
 version = '2.5'
 # The full version, including alpha/beta/rc tags.
-release = '2.5.0'
+release = '2.5.1'
 
 # There are two options for replacing |today|: either, you set today to some
 # non-false value, then it is used:
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
     DistutilsPlatformError
 
 IS_PYPY = hasattr(sys, 'pypy_translation_info')
-VERSION = '2.5.0'
+VERSION = '2.5.1'
 DESCRIPTION = "Simple, fast, extensible JSON encoder/decoder for Python"
 LONG_DESCRIPTION = open('README.rst', 'r').read()
 
diff --git a/simplejson/__init__.py b/simplejson/__init__.py
--- a/simplejson/__init__.py
+++ b/simplejson/__init__.py
@@ -97,7 +97,7 @@
     $ echo '{ 1.2:3.4}' | python -m simplejson.tool
     Expecting property name: line 1 column 2 (char 2)
 """
-__version__ = '2.5.0'
+__version__ = '2.5.1'
 __all__ = [
     'dump', 'dumps', 'load', 'loads',
     'JSONDecoder', 'JSONDecodeError', 'JSONEncoder',
diff --git a/simplejson/_speedups.c b/simplejson/_speedups.c
--- a/simplejson/_speedups.c
+++ b/simplejson/_speedups.c
@@ -44,11 +44,9 @@
 #define PyScanner_CheckExact(op) (Py_TYPE(op) == &PyScannerType)
 #define PyEncoder_Check(op) PyObject_TypeCheck(op, &PyEncoderType)
 #define PyEncoder_CheckExact(op) (Py_TYPE(op) == &PyEncoderType)
-#define Decimal_Check(op) (PyObject_TypeCheck(op, DecimalTypePtr))
 
 static PyTypeObject PyScannerType;
 static PyTypeObject PyEncoderType;
-static PyTypeObject *DecimalTypePtr;
 
 typedef struct _PyScannerObject {
     PyObject_HEAD
@@ -84,6 +82,7 @@
     PyObject *sort_keys;
     PyObject *skipkeys;
     PyObject *key_memo;
+    PyObject *Decimal;
     int fast_encode;
     int allow_nan;
     int use_decimal;
@@ -2053,6 +2052,7 @@
         s->skipkeys = NULL;
         s->key_memo = NULL;
         s->item_sort_key = NULL;
+        s->Decimal = NULL;
     }
     return (PyObject *)s;
 }
@@ -2065,15 +2065,18 @@
 
     PyEncoderObject *s;
     PyObject *markers, *defaultfn, *encoder, *indent, *key_separator;
-    PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo, *use_decimal, *namedtuple_as_object, *tuple_as_array, *bigint_as_string, *item_sort_key;
+    PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo;
+    PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array;
+    PyObject *bigint_as_string, *item_sort_key, *Decimal;
 
     assert(PyEncoder_Check(self));
     s = (PyEncoderObject *)self;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOO:make_encoder", kwlist,
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOO:make_encoder", kwlist,
         &markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator,
         &sort_keys, &skipkeys, &allow_nan, &key_memo, &use_decimal,
-        &namedtuple_as_object, &tuple_as_array, &bigint_as_string, &item_sort_key))
+        &namedtuple_as_object, &tuple_as_array, &bigint_as_string,
+        &item_sort_key, &Decimal))
         return -1;
 
     s->markers = markers;
@@ -2092,6 +2095,7 @@
     s->tuple_as_array = PyObject_IsTrue(tuple_as_array);
     s->bigint_as_string = PyObject_IsTrue(bigint_as_string);
     s->item_sort_key = item_sort_key;
+    s->Decimal = Decimal;
 
     Py_INCREF(s->markers);
     Py_INCREF(s->defaultfn);
@@ -2103,6 +2107,7 @@
     Py_INCREF(s->skipkeys);
     Py_INCREF(s->key_memo);
     Py_INCREF(s->item_sort_key);
+    Py_INCREF(s->Decimal);
     return 0;
 }
 
@@ -2255,7 +2260,7 @@
         else if (PyDict_Check(obj)) {
             rv = encoder_listencode_dict(s, rval, obj, indent_level);
         }
-        else if (s->use_decimal && Decimal_Check(obj)) {
+        else if (s->use_decimal && PyObject_TypeCheck(obj, s->Decimal)) {
             PyObject *encoded = PyObject_Str(obj);
             if (encoded != NULL)
                 rv = _steal_list_append(rval, encoded);
@@ -2649,6 +2654,7 @@
     Py_CLEAR(s->skipkeys);
     Py_CLEAR(s->key_memo);
     Py_CLEAR(s->item_sort_key);
+    Py_CLEAR(s->Decimal);
     return 0;
 }
 
@@ -2716,7 +2722,7 @@
 void
 init_speedups(void)
 {
-    PyObject *m, *decimal;
+    PyObject *m;
     PyScannerType.tp_new = PyType_GenericNew;
     if (PyType_Ready(&PyScannerType) < 0)
         return;
@@ -2724,13 +2730,6 @@
     if (PyType_Ready(&PyEncoderType) < 0)
         return;
 
-    decimal = PyImport_ImportModule("decimal");
-    if (decimal == NULL)
-        return;
-    DecimalTypePtr = (PyTypeObject*)PyObject_GetAttrString(decimal, "Decimal");
-    Py_DECREF(decimal);
-    if (DecimalTypePtr == NULL)
-        return;
 
     m = Py_InitModule3("_speedups", speedups_methods, module_doc);
     Py_INCREF((PyObject*)&PyScannerType);
diff --git a/simplejson/encoder.py b/simplejson/encoder.py
--- a/simplejson/encoder.py
+++ b/simplejson/encoder.py
@@ -297,14 +297,16 @@
                 self.key_separator, self.item_separator, self.sort_keys,
                 self.skipkeys, self.allow_nan, key_memo, self.use_decimal,
                 self.namedtuple_as_object, self.tuple_as_array,
-                self.bigint_as_string, self.item_sort_key)
+                self.bigint_as_string, self.item_sort_key,
+                Decimal)
         else:
             _iterencode = _make_iterencode(
                 markers, self.default, _encoder, self.indent, floatstr,
                 self.key_separator, self.item_separator, self.sort_keys,
                 self.skipkeys, _one_shot, self.use_decimal,
                 self.namedtuple_as_object, self.tuple_as_array,
-                self.bigint_as_string, self.item_sort_key)
+                self.bigint_as_string, self.item_sort_key,
+                Decimal=Decimal)
         try:
             return _iterencode(o, 0)
         finally:
diff --git a/simplejson/tests/test_decimal.py b/simplejson/tests/test_decimal.py
--- a/simplejson/tests/test_decimal.py
+++ b/simplejson/tests/test_decimal.py
@@ -1,3 +1,4 @@
+import decimal
 from decimal import Decimal
 from unittest import TestCase
 from StringIO import StringIO
@@ -22,11 +23,11 @@
     def test_decimal_encode(self):
         for d in map(Decimal, self.NUMS):
             self.assertEquals(self.dumps(d, use_decimal=True), str(d))
-    
+
     def test_decimal_decode(self):
         for s in self.NUMS:
             self.assertEquals(self.loads(s, parse_float=Decimal), Decimal(s))
-    
+
     def test_decimal_roundtrip(self):
         for d in map(Decimal, self.NUMS):
             # The type might not be the same (int and Decimal) but they
@@ -46,10 +47,20 @@
         self.assertRaises(TypeError, json.dumps, d, use_decimal=False)
         self.assertEqual('1.1', json.dumps(d))
         self.assertEqual('1.1', json.dumps(d, use_decimal=True))
-        self.assertRaises(TypeError, json.dump, d, StringIO(), use_decimal=False)
+        self.assertRaises(TypeError, json.dump, d, StringIO(),
+                          use_decimal=False)
         sio = StringIO()
         json.dump(d, sio)
         self.assertEqual('1.1', sio.getvalue())
         sio = StringIO()
         json.dump(d, sio, use_decimal=True)
         self.assertEqual('1.1', sio.getvalue())
+
+    def test_decimal_reload(self):
+        # Simulate a subinterpreter that reloads the Python modules but not
+        # the C code https://github.com/simplejson/simplejson/issues/34
+        global Decimal
+        Decimal = reload(decimal).Decimal
+        import simplejson.encoder
+        simplejson.encoder.Decimal = Decimal
+        self.test_decimal_roundtrip()