Skip to content

Commit 59b62c8

Browse files
committed
PYTHON-2389 Add session support to find_raw_batches and aggregate_raw_batches (#658)
(cherry picked from commit 0e0c4fd)
1 parent 498c673 commit 59b62c8

File tree

4 files changed: 163 additions and 33 deletions.

pymongo/collection.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,21 +1547,16 @@ def find_raw_batches(self, *args, **kwargs):
15471547
>>> for batch in cursor:
15481548
... print(bson.decode_all(batch))
15491549
1550-
.. note:: find_raw_batches does not support sessions or auto
1551-
encryption.
1550+
.. note:: find_raw_batches does not support auto encryption.
15521551
15531552
.. versionchanged:: 3.12
15541553
Instead of ignoring the user-specified read concern, this method
15551554
now sends it to the server when connected to MongoDB 3.6+.
15561555
1556+
Added session support.
1557+
15571558
.. versionadded:: 3.6
15581559
"""
1559-
# OP_MSG with document stream returns is required to support
1560-
# sessions.
1561-
if "session" in kwargs:
1562-
raise ConfigurationError(
1563-
"find_raw_batches does not support sessions")
1564-
15651560
# OP_MSG is required to support encryption.
15661561
if self.__database.client._encrypter:
15671562
raise InvalidOperation(
@@ -2505,7 +2500,7 @@ def aggregate(self, pipeline, session=None, **kwargs):
25052500
explicit_session=session is not None,
25062501
**kwargs)
25072502

2508-
def aggregate_raw_batches(self, pipeline, **kwargs):
2503+
def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
25092504
"""Perform an aggregation and retrieve batches of raw BSON.
25102505
25112506
Similar to the :meth:`aggregate` method but returns a
@@ -2522,28 +2517,25 @@ def aggregate_raw_batches(self, pipeline, **kwargs):
25222517
>>> for batch in cursor:
25232518
... print(bson.decode_all(batch))
25242519
2525-
.. note:: aggregate_raw_batches does not support sessions or auto
2526-
encryption.
2520+
.. note:: aggregate_raw_batches does not support auto encryption.
2521+
2522+
.. versionchanged:: 3.12
2523+
Added session support.
25272524
25282525
.. versionadded:: 3.6
25292526
"""
2530-
# OP_MSG with document stream returns is required to support
2531-
# sessions.
2532-
if "session" in kwargs:
2533-
raise ConfigurationError(
2534-
"aggregate_raw_batches does not support sessions")
2535-
25362527
# OP_MSG is required to support encryption.
25372528
if self.__database.client._encrypter:
25382529
raise InvalidOperation(
25392530
"aggregate_raw_batches does not support auto encryption")
25402531

2541-
return self._aggregate(_CollectionRawAggregationCommand,
2542-
pipeline,
2543-
RawBatchCommandCursor,
2544-
session=None,
2545-
explicit_session=False,
2546-
**kwargs)
2532+
with self.__database.client._tmp_session(session, close=False) as s:
2533+
return self._aggregate(_CollectionRawAggregationCommand,
2534+
pipeline,
2535+
RawBatchCommandCursor,
2536+
session=s,
2537+
explicit_session=session is not None,
2538+
**kwargs)
25472539

25482540
def watch(self, pipeline=None, full_document=None, resume_after=None,
25492541
max_await_time_ms=None, batch_size=None, collation=None,

pymongo/message.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,8 @@ def use_command(self, sock_info):
473473

474474
class _RawBatchGetMore(_GetMore):
475475
def use_command(self, sock_info):
476+
# Compatibility checks.
477+
super(_RawBatchGetMore, self).use_command(sock_info)
476478
if sock_info.max_wire_version >= 8:
477479
# MongoDB 4.2+ supports exhaust over OP_MSG
478480
return True

test/test_cursor.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
unittest,
4747
IntegrationTest)
4848
from test.utils import (EventListener,
49+
OvertCommandListener,
4950
ignore_deprecations,
5051
rs_or_single_client,
5152
WhiteListEventListener)
@@ -1478,6 +1479,76 @@ def test_manipulate(self):
14781479
with self.assertRaises(InvalidOperation):
14791480
c.find_raw_batches(manipulate=True)
14801481

1482+
@client_context.require_transactions
1483+
def test_find_raw_transaction(self):
1484+
c = self.db.test
1485+
c.drop()
1486+
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
1487+
c.insert_many(docs)
1488+
1489+
listener = OvertCommandListener()
1490+
client = rs_or_single_client(event_listeners=[listener])
1491+
with client.start_session() as session:
1492+
with session.start_transaction():
1493+
batches = list(client[self.db.name].test.find_raw_batches(
1494+
session=session).sort('_id'))
1495+
cmd = listener.results['started'][0]
1496+
self.assertEqual(cmd.command_name, 'find')
1497+
self.assertEqual(cmd.command['$clusterTime'],
1498+
decode_all(session.cluster_time.raw)[0])
1499+
self.assertEqual(cmd.command['startTransaction'], True)
1500+
self.assertEqual(cmd.command['txnNumber'], 1)
1501+
1502+
self.assertEqual(1, len(batches))
1503+
self.assertEqual(docs, decode_all(batches[0]))
1504+
1505+
@client_context.require_sessions
1506+
@client_context.require_failCommand_fail_point
1507+
def test_find_raw_retryable_reads(self):
1508+
c = self.db.test
1509+
c.drop()
1510+
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
1511+
c.insert_many(docs)
1512+
1513+
listener = OvertCommandListener()
1514+
client = rs_or_single_client(event_listeners=[listener],
1515+
retryReads=True)
1516+
with self.fail_point({
1517+
'mode': {'times': 1}, 'data': {'failCommands': ['find'],
1518+
'closeConnection': True}}):
1519+
batches = list(
1520+
client[self.db.name].test.find_raw_batches().sort('_id'))
1521+
1522+
self.assertEqual(1, len(batches))
1523+
self.assertEqual(docs, decode_all(batches[0]))
1524+
self.assertEqual(len(listener.results['started']), 2)
1525+
for cmd in listener.results['started']:
1526+
self.assertEqual(cmd.command_name, 'find')
1527+
1528+
@client_context.require_version_min(5, 0, 0)
1529+
@client_context.require_no_standalone
1530+
def test_find_raw_snapshot_reads(self):
1531+
c = self.db.get_collection(
1532+
"test", write_concern=WriteConcern(w="majority"))
1533+
c.drop()
1534+
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
1535+
c.insert_many(docs)
1536+
1537+
listener = OvertCommandListener()
1538+
client = rs_or_single_client(event_listeners=[listener],
1539+
retryReads=True)
1540+
db = client[self.db.name]
1541+
with client.start_session(snapshot=True) as session:
1542+
db.test.distinct('x', {}, session=session)
1543+
batches = list(db.test.find_raw_batches(
1544+
session=session).sort('_id'))
1545+
self.assertEqual(1, len(batches))
1546+
self.assertEqual(docs, decode_all(batches[0]))
1547+
1548+
find_cmd = listener.results['started'][1].command
1549+
self.assertEqual(find_cmd['readConcern']['level'], 'snapshot')
1550+
self.assertIsNotNone(find_cmd['readConcern']['atClusterTime'])
1551+
14811552
def test_explain(self):
14821553
c = self.db.test
14831554
c.insert_one({})
@@ -1602,6 +1673,75 @@ def test_aggregate_raw(self):
16021673
self.assertEqual(1, len(batches))
16031674
self.assertEqual(docs, decode_all(batches[0]))
16041675

1676+
@client_context.require_transactions
1677+
def test_aggregate_raw_transaction(self):
1678+
c = self.db.test
1679+
c.drop()
1680+
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
1681+
c.insert_many(docs)
1682+
1683+
listener = OvertCommandListener()
1684+
client = rs_or_single_client(event_listeners=[listener])
1685+
with client.start_session() as session:
1686+
with session.start_transaction():
1687+
batches = list(client[self.db.name].test.aggregate_raw_batches(
1688+
[{'$sort': {'_id': 1}}], session=session))
1689+
cmd = listener.results['started'][0]
1690+
self.assertEqual(cmd.command_name, 'aggregate')
1691+
self.assertEqual(cmd.command['$clusterTime'], session.cluster_time)
1692+
self.assertEqual(cmd.command['startTransaction'], True)
1693+
self.assertEqual(cmd.command['txnNumber'], 1)
1694+
self.assertEqual(1, len(batches))
1695+
self.assertEqual(docs, decode_all(batches[0]))
1696+
1697+
@client_context.require_sessions
1698+
@client_context.require_failCommand_fail_point
1699+
def test_aggregate_raw_retryable_reads(self):
1700+
c = self.db.test
1701+
c.drop()
1702+
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
1703+
c.insert_many(docs)
1704+
1705+
listener = OvertCommandListener()
1706+
client = rs_or_single_client(event_listeners=[listener],
1707+
retryReads=True)
1708+
with self.fail_point({
1709+
'mode': {'times': 1}, 'data': {'failCommands': ['aggregate'],
1710+
'closeConnection': True}}):
1711+
batches = list(client[self.db.name].test.aggregate_raw_batches(
1712+
[{'$sort': {'_id': 1}}]))
1713+
1714+
self.assertEqual(1, len(batches))
1715+
self.assertEqual(docs, decode_all(batches[0]))
1716+
self.assertEqual(len(listener.results['started']), 3)
1717+
cmds = listener.results['started']
1718+
self.assertEqual(cmds[0].command_name, 'aggregate')
1719+
self.assertEqual(cmds[1].command_name, 'aggregate')
1720+
1721+
@client_context.require_version_min(5, 0, -1)
1722+
@client_context.require_no_standalone
1723+
def test_aggregate_raw_snapshot_reads(self):
1724+
c = self.db.get_collection(
1725+
"test", write_concern=WriteConcern(w="majority"))
1726+
c.drop()
1727+
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
1728+
c.insert_many(docs)
1729+
1730+
listener = OvertCommandListener()
1731+
client = rs_or_single_client(event_listeners=[listener],
1732+
retryReads=True)
1733+
db = client[self.db.name]
1734+
with client.start_session(snapshot=True) as session:
1735+
db.test.distinct('x', {}, session=session)
1736+
batches = list(db.test.aggregate_raw_batches(
1737+
[{'$sort': {'_id': 1}}], session=session))
1738+
self.assertEqual(1, len(batches))
1739+
self.assertEqual(docs, decode_all(batches[0]))
1740+
1741+
find_cmd = listener.results['started'][1].command
1742+
self.assertEqual(find_cmd['readConcern']['level'], 'snapshot')
1743+
self.assertIsNotNone(find_cmd['readConcern']['atClusterTime'])
1744+
16051745
def test_server_error(self):
16061746
c = self.db.test
16071747
c.drop()

test/test_session.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,12 @@ def test_reads(self):
891891
lambda coll, session: coll.count_documents({}, session=session))
892892
self._test_reads(
893893
lambda coll, session: coll.distinct('foo', session=session))
894+
self._test_reads(
895+
lambda coll, session: list(coll.aggregate_raw_batches(
896+
[], session=session)))
897+
self._test_reads(
898+
lambda coll, session: list(coll.find_raw_batches(
899+
{}, session=session)))
894900

895901
# SERVER-40938 removed support for casually consistent mapReduce.
896902
map_reduce_exc = None
@@ -916,16 +922,6 @@ def scan(coll, session):
916922
self._test_reads(
917923
lambda coll, session: scan(coll, session=session))
918924

919-
self.assertRaises(
920-
ConfigurationError,
921-
self._test_reads,
922-
lambda coll, session: list(
923-
coll.aggregate_raw_batches([], session=session)))
924-
self.assertRaises(
925-
ConfigurationError,
926-
self._test_reads,
927-
lambda coll, session: list(
928-
coll.find_raw_batches({}, session=session)))
929925
self.assertRaises(
930926
ConfigurationError,
931927
self._test_reads,

Comments (0)