From 57072b3e094a2e5af9818b36aceaddb907d3ce77 Mon Sep 17 00:00:00 2001 From: Alexander Artemenko Date: Fri, 28 Aug 2009 00:29:26 -0400 Subject: [PATCH] Added __getitem__ for cursor, to support cursor[M:N] and cursor[K] notation. Added __len__ to support len(cursor), which returns real row count, not total object count. --- pymongo/cursor.py | 30 ++++++++++++++ test/test_cursor.py | 108 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 0 deletions(-) diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 7a323a9..6429042 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -60,6 +60,8 @@ class Cursor(object): self.__connection_id = None self.__retrieved = 0 self.__killed = False + self.__left_bound = 0 + self.__right_bound = 100000 def __del__(self): if self.__id and not self.__killed: @@ -211,6 +213,30 @@ class Cursor(object): self.__skip = skip return self + def __getitem__(self, index): + """Applies skip and limit, when cursor[2:10] notation is used.""" + self.__check_okay_to_chain() + + if isinstance(index, slice): + if index.stop is not None: + self.__right_bound = min(self.__right_bound, self.__left_bound + index.stop) + + if index.start is not None: + self.__left_bound += index.start + + self.skip(self.__left_bound) + self.limit(self.__right_bound - self.__left_bound) + return self + + clone = self.clone() + clone.skip(index) + clone.limit(1) + obj = list(clone) + if len(obj) != 1: + raise IndexError + return obj[0] + + def sort(self, key_or_list, direction=None): """Sorts this cursors results. @@ -248,6 +274,10 @@ class Cursor(object): return 0 return int(response["n"]) + def __len__(self): + total = self.count() + return min(total, self.__limit or total) + def explain(self): """Returns an explain plan record for this cursor. """ diff --git a/test/test_cursor.py b/test/test_cursor.py index faa7e46..c3772ec 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -188,6 +188,114 @@ class TestCursor(unittest.TestCase): break self.assertRaises(InvalidOperation, a.skip, 5) + def test_slice_with_skip(self): + from itertools import izip, count + db = self.db + db.drop_collection("test") + + for i in range(100): + db.test.save({"x": i}) + + for i, v in izip(count(0), db.test.find()): + self.assertEqual(v["x"], i) + + for i, v in izip(count(20), db.test.find()[20:]): + self.assertEqual(v["x"], i) + + for i, v in izip(count(99), db.test.find()[99:]): + self.assertEqual(v["x"], i) + + for i, v in izip(count(1), db.test.find()[1:]): + self.assertEqual(v["x"], i) + + for i, v in izip(count(0), db.test.find()[0:]): + self.assertEqual(v["x"], i) + + for i, v in izip(count(60), db.test.find()[0:][50:][10:]): + self.assertEqual(v["x"], i) + + for i in db.test.find()[1000:]: + self.fail() + + def test_slice_with_limit(self): + from itertools import izip, count + db = self.db + db.drop_collection("test") + + for i in range(100): + db.test.save({"x": i}) + + for i, v in izip(count(0), db.test.find()): + self.assertEqual(v["x"], i) + + + result = db.test.find()[20:25] + self.assertEqual(len(result), 5) + + for i, v in izip(count(20), result): + self.assertEqual(v["x"], i) + + + result = db.test.find()[99:100] + self.assertEqual(len(result), 1) + + for i, v in izip(count(99), result): + self.assertEqual(v["x"], i) + + + result = db.test.find()[1:11] + self.assertEqual(len(result), 10) + + for i, v in izip(count(1), result): + self.assertEqual(v["x"], i) + + + result = db.test.find()[0:10] + self.assertEqual(len(result), 10) + + for i, v in izip(count(0), result): + self.assertEqual(v["x"], i) + + + result = db.test.find()[10:50][25:100] + self.assertEqual(len(result), 15) + + for i, v in izip(count(35), result): + self.assertEqual(v["x"], i) + + + result = db.test.find()[20:50][0:10] + self.assertEqual(len(result), 10) + + for i, v in izip(count(20), result): + self.assertEqual(v["x"], i) + + + for i in db.test.find()[1000:10]: + self.fail() + + def test_get_single_item(self): + db = self.db + db.drop_collection("test") + + for i in range(100): + db.test.save({"x": i}) + + result = db.test.find() + self.assertEqual(result[0]['x'], 0) + self.assertEqual(result[50]['x'], 50) + self.assertEqual(result[99]['x'], 99) + + def test_length(self): + from itertools import izip, count + db = self.db + db.drop_collection("test") + + for i in range(10): + db.test.save({"x": i}) + + self.assertEqual(len(db.test.find()), 10) + def test_sort(self): db = self.db -- 1.6.3.2