From 0e130670991aeac07306756061b6fb189edd127c Mon Sep 17 00:00:00 2001
From: Ori Avtalion <ori@avtalion.name>
Date: Fri, 25 Jan 2019 16:37:15 +0200
Subject: [PATCH] bson: Support encoding sets, dict view objects, bytearray,
 range

---
 bson/__init__.py    | 43 +++++++++++++++++-----
 bson/_cbsonmodule.c | 87 +++++++++++++++++++++++++++++++++++++++++++++
 test/test_bson.py   | 28 +++++++++++++++
 3 files changed, 149 insertions(+), 9 deletions(-)

diff --git a/bson/__init__.py b/bson/__init__.py
index 16573b3d..33fa1d5c 100644
--- a/bson/__init__.py
+++ b/bson/__init__.py
@@ -28,6 +28,13 @@ float                                    number (real)  both
 string                                   string         py -> bson
 unicode                                  string         both
 list                                     array          both
+set                                      array          py -> bson
+frozenset                                array          py -> bson
+dict_keys                                array          py -> bson
+dict_values                              array          py -> bson
+dict_items                               array          py -> bson
+xrange (Python 2)                        array          py -> bson
+range (Python 3)                         array          py -> bson
 dict / `SON`                             object         both
 datetime.datetime [#dt]_ [#dt2]_         date           both
 `bson.regex.Regex`                       regex          both
@@ -40,6 +47,7 @@ unicode                                  code           bson -> py
 `bson.code.Code`                         code           py -> bson
 unicode                                  symbol         bson -> py
 bytes (Python 3) [#bytes]_               binary         both
+bytearray [#bytes]_                      binary         py -> bson
 =======================================  =============  ===================
 
 Note that, when using Python 2.x, to save binary data it must be wrapped as
@@ -58,10 +66,10 @@ type.
    objects from ``re.compile()`` are both saved as BSON regular expressions.
    BSON regular expressions are decoded as :class:`~bson.regex.Regex`
    instances.
-.. [#bytes] The bytes type from Python 3.x is encoded as BSON binary with
-   subtype 0. In Python 3.x it will be decoded back to bytes. In Python 2.x
-   it will be decoded to an instance of :class:`~bson.binary.Binary` with
-   subtype 0.
+.. [#bytes] The bytes type from Python 3.x and the bytearray type are encoded
+   as BSON binary with subtype 0. In Python 3.x it will be decoded back to
+   bytes.  In Python 2.x it will be decoded to an instance of
+   :class:`~bson.binary.Binary` with subtype 0.
 """
 
 import calendar
@@ -524,6 +532,14 @@ else:
         return b"\x02" + name + _PACK_INT(len(value) + 1) + value + b"\x00"
 
 
+def _encode_bytearray(name, value, dummy0, dummy1):
+    """Encode a python bytearray."""
+    # Special case. Store 'bytes' as BSON binary subtype 0.
+    if not PY3:
+        value = str(value)
+    return b"\x05" + name + _PACK_INT(len(value)) + b"\x00" + value
+
+
 def _encode_mapping(name, value, check_keys, opts):
     """Encode a mapping type."""
     if _raw_document_class(value):
@@ -553,8 +569,8 @@ def _encode_dbref(name, value, check_keys, opts):
     return bytes(buf)
 
 
-def _encode_list(name, value, check_keys, opts):
-    """Encode a list/tuple."""
+def _encode_iterable(name, value, check_keys, opts):
+    """Encode an iterable."""
     lname = gen_list_name()
     data = b"".join([_name_value_to_bson(next(lname), item,
                                          check_keys, opts)
@@ -702,14 +718,17 @@ def _encode_maxkey(name, dummy0, dummy1, dummy2):
 _ENCODERS = {
     bool: _encode_bool,
     bytes: _encode_bytes,
+    bytearray: _encode_bytearray,
     datetime.datetime: _encode_datetime,
     dict: _encode_mapping,
     float: _encode_float,
     int: _encode_int,
-    list: _encode_list,
+    list: _encode_iterable,
+    set: _encode_iterable,
+    frozenset: _encode_iterable,
     # unicode in py2, str in py3
     text_type: _encode_text,
-    tuple: _encode_list,
+    tuple: _encode_iterable,
     type(None): _encode_none,
     uuid.UUID: _encode_uuid,
     Binary: _encode_binary,
@@ -725,7 +744,10 @@ _ENCODERS = {
     Timestamp: _encode_timestamp,
     UUIDLegacy: _encode_binary,
     Decimal128: _encode_decimal128,
-    # Special case. This will never be looked up directly.
+    # Special cases. This will never be looked up directly.
+    abc.KeysView: _encode_iterable,
+    abc.ValuesView: _encode_iterable,
+    abc.ItemsView: _encode_iterable,
     abc.Mapping: _encode_mapping,
 }
 
@@ -744,6 +766,9 @@ _MARKERS = {
 
 if not PY3:
     _ENCODERS[long] = _encode_long
+    _ENCODERS[xrange] = _encode_iterable
+else:
+    _ENCODERS[range] = _encode_iterable
 
 
 def _name_value_to_bson(name, value, check_keys, opts):
diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c
index afe05dcd..ab22eb14 100644
--- a/bson/_cbsonmodule.c
+++ b/bson/_cbsonmodule.c
@@ -1107,6 +1107,93 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
         buffer_write_int32_at_position(
             buffer, length_location, (int32_t)length);
         return 1;
+    } else if (PyAnySet_Check(value) || PyDictKeys_Check(value) ||
+               PyDictValues_Check(value) || PyDictItems_Check(value) ||
+               PyRange_Check(value)) {
+        Py_ssize_t items, i;
+        PyObject *iterator, *item_value;
+        int start_position,
+            length_location,
+            length;
+        char zero = 0;
+
+        *(buffer_get_buffer(buffer) + type_byte) = 0x04;
+        start_position = buffer_get_position(buffer);
+
+        /* save space for length */
+        length_location = buffer_save_space(buffer, 4);
+        if (length_location == -1) {
+            PyErr_NoMemory();
+            return 0;
+        }
+
+        if ((items = PySequence_Size(value)) > BSON_MAX_SIZE) {
+            PyObject* BSONError = _error("BSONError");
+            if (BSONError) {
+                PyErr_SetString(BSONError,
+                                "Too many items to serialize.");
+                Py_DECREF(BSONError);
+            }
+            return 0;
+        }
+
+        iterator = PyObject_GetIter(value);
+        for(i = 0; i < items; i++) {
+            int list_type_byte = buffer_save_space(buffer, 1);
+            char name[16];
+
+            if (list_type_byte == -1) {
+                PyErr_NoMemory();
+                return 0;
+            }
+            INT2STRING(name, (int)i);
+            if (!buffer_write_bytes(buffer, name, (int)strlen(name) + 1)) {
+                Py_DECREF(iterator);
+                return 0;
+            }
+
+            if (!(item_value = PyIter_Next(iterator))) {
+                Py_DECREF(iterator);
+                return 0;
+            }
+            if (!write_element_to_buffer(self, buffer, list_type_byte,
+                                         item_value, check_keys, options)) {
+                Py_DECREF(item_value);
+                Py_DECREF(iterator);
+                return 0;
+            }
+            Py_DECREF(item_value);
+        }
+        Py_DECREF(iterator);
+
+        /* write null byte and fill in length */
+        if (!buffer_write_bytes(buffer, &zero, 1)) {
+            return 0;
+        }
+        length = buffer_get_position(buffer) - start_position;
+        buffer_write_int32_at_position(
+            buffer, length_location, (int32_t)length);
+        return 1;
+    /* Special case. Store bytes as BSON binary subtype 0. */
+    } else if (PyByteArray_Check(value)) {
+        char subtype = 0;
+        int size;
+        const char* data = PyByteArray_AS_STRING(value);
+        if (!data)
+            return 0;
+        if ((size = _downcast_and_check(PyByteArray_GET_SIZE(value), 0)) == -1)
+            return 0;
+        *(buffer_get_buffer(buffer) + type_byte) = 0x05;
+        if (!buffer_write_int32(buffer, (int32_t)size)) {
+            return 0;
+        }
+        if (!buffer_write_bytes(buffer, &subtype, 1)) {
+            return 0;
+        }
+        if (!buffer_write_bytes(buffer, data, size)) {
+            return 0;
+        }
+        return 1;
 #if PY_MAJOR_VERSION >= 3
     /* Python3 special case. Store bytes as BSON binary subtype 0. */
     } else if (PyBytes_Check(value)) {
diff --git a/test/test_bson.py b/test/test_bson.py
index 86653c9c..da1e41a0 100644
--- a/test/test_bson.py
+++ b/test/test_bson.py
@@ -173,6 +173,34 @@ class TestBSON(unittest.TestCase):
         BSON.encode(dct)
         self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')]))
 
+    def test_encoding_set(self):
+        lst = ['bar', 'foo']
+        for typ in (set, frozenset):
+            encoded_set = BSON.encode({'test': typ(lst)})
+            decoded_set = sorted(BSON.decode(encoded_set)['test'])
+            self.assertEqual(decoded_set, lst)
+
+    def test_encoding_dict_views(self):
+        if not PY3:
+            return
+        dct = {'a': 'foo', 'b': 'bar'}
+        for meth in ('keys', 'values', 'items'):
+            encoded_view = BSON.encode({'test': getattr(dct, meth)()})
+            encoded_list = BSON.encode({'test': list(getattr(dct, meth)())})
+            self.assertEqual(encoded_view, encoded_list)
+
+    def test_encoding_bytearray(self):
+        binary_data = Binary(b'\x01\x02\x03')
+        encoded_bytearray = BSON.encode({'test': bytearray(binary_data)})
+        encoded_bytes = BSON.encode({'test': binary_data})
+        self.assertEqual(encoded_bytearray, encoded_bytes)
+
+    def test_encoding_range(self):
+        range_func = range if PY3 else xrange
+        encoded_range = BSON.encode({'test': range_func(0, 10)})
+        encoded_list = BSON.encode({'test': list(range_func(0, 10))})
+        self.assertEqual(encoded_range, encoded_list)
+
     def test_basic_validation(self):
         self.assertRaises(TypeError, is_valid, 100)
         self.assertRaises(TypeError, is_valid, u"test")
-- 
2.20.1

