Skip to content

Commit 70cfe46

Browse files
authored
PYTHON-3290 Support nested pymongo.timeout() calls (#962)
1 parent 890cd26 commit 70cfe46

File tree

4 files changed

+71
-12
lines changed

4 files changed

+71
-12
lines changed

pymongo/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,18 @@ def timeout(seconds: Optional[float]) -> ContextManager:
127127
NetworkTimeout) as exc:
128128
print(f"block timed out: {exc!r}")
129129
130+
When nesting :func:`~pymongo.timeout`, the nested block overrides the
131+
timeout. When exiting the block, the previous deadline is restored::
132+
133+
with pymongo.timeout(5):
134+
coll.find_one() # Uses the 5 second deadline.
135+
with pymongo.timeout(3):
136+
coll.find_one() # Uses the 3 second deadline.
137+
coll.find_one() # Uses the original 5 second deadline.
138+
with pymongo.timeout(10):
139+
coll.find_one() # Uses the 10 second deadline.
140+
coll.find_one() # Uses the original 5 second deadline.
141+
130142
:Parameters:
131143
- `seconds`: A non-negative floating point number expressing seconds, or None.
132144

pymongo/_csot.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
"""Internal helpers for CSOT."""
1616

1717
import time
18-
from contextvars import ContextVar
19-
from typing import Optional
18+
from contextvars import ContextVar, Token
19+
from typing import Optional, Tuple
2020

2121
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
2222
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
@@ -39,11 +39,6 @@ def set_rtt(rtt: float) -> None:
3939
RTT.set(rtt)
4040

4141

42-
def set_timeout(timeout: Optional[float]) -> None:
43-
TIMEOUT.set(timeout)
44-
DEADLINE.set(time.monotonic() + timeout if timeout else float("inf"))
45-
46-
4742
def remaining() -> Optional[float]:
4843
if not get_timeout():
4944
return None
@@ -67,14 +62,24 @@ class _TimeoutContext(object):
6762
client.test.test.insert_one({})
6863
"""
6964

70-
__slots__ = ("_timeout",)
65+
__slots__ = ("_timeout", "_tokens")
7166

7267
def __init__(self, timeout: Optional[float]):
7368
self._timeout = timeout
69+
self._tokens: Optional[Tuple[Token, Token, Token]] = None
7470

7571
def __enter__(self):
76-
set_timeout(self._timeout)
72+
timeout_token = TIMEOUT.set(self._timeout)
73+
deadline_token = DEADLINE.set(
74+
time.monotonic() + self._timeout if self._timeout else float("inf")
75+
)
76+
rtt_token = RTT.set(0.0)
77+
self._tokens = (timeout_token, deadline_token, rtt_token)
7778
return self
7879

7980
def __exit__(self, exc_type, exc_val, exc_tb):
80-
set_timeout(None)
81+
if self._tokens:
82+
timeout_token, deadline_token, rtt_token = self._tokens
83+
TIMEOUT.reset(timeout_token)
84+
DEADLINE.reset(deadline_token)
85+
RTT.reset(rtt_token)

pymongo/topology.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ def _select_server(self, selector, server_selection_timeout=None, address=None):
270270
def select_server(self, selector, server_selection_timeout=None, address=None):
271271
"""Like select_servers, but choose a random server if several match."""
272272
server = self._select_server(selector, server_selection_timeout, address)
273-
_csot.set_rtt(server.description.round_trip_time)
273+
if _csot.get_timeout():
274+
_csot.set_rtt(server.description.round_trip_time)
274275
return server
275276

276277
def select_server_by_address(self, address, server_selection_timeout=None):

test/test_csot.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,55 @@
1919

2020
sys.path[0:0] = [""]
2121

22-
from test import unittest
22+
from test import IntegrationTest, unittest
2323
from test.unified_format import generate_test_classes
2424

25+
import pymongo
26+
from pymongo import _csot
27+
2528
# Location of JSON test specifications.
2629
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "csot")
2730

2831
# Generate unified tests.
2932
globals().update(generate_test_classes(TEST_PATH, module=__name__))
3033

34+
35+
class TestCSOT(IntegrationTest):
36+
def test_timeout_nested(self):
37+
coll = self.db.coll
38+
self.assertEqual(_csot.get_timeout(), None)
39+
self.assertEqual(_csot.get_deadline(), float("inf"))
40+
self.assertEqual(_csot.get_rtt(), 0.0)
41+
with pymongo.timeout(10):
42+
coll.find_one()
43+
self.assertEqual(_csot.get_timeout(), 10)
44+
deadline_10 = _csot.get_deadline()
45+
46+
with pymongo.timeout(15):
47+
coll.find_one()
48+
self.assertEqual(_csot.get_timeout(), 15)
49+
self.assertGreater(_csot.get_deadline(), deadline_10)
50+
51+
# Should be reset to previous values
52+
self.assertEqual(_csot.get_timeout(), 10)
53+
self.assertEqual(_csot.get_deadline(), deadline_10)
54+
coll.find_one()
55+
56+
with pymongo.timeout(5):
57+
coll.find_one()
58+
self.assertEqual(_csot.get_timeout(), 5)
59+
self.assertLess(_csot.get_deadline(), deadline_10)
60+
61+
# Should be reset to previous values
62+
self.assertEqual(_csot.get_timeout(), 10)
63+
self.assertEqual(_csot.get_deadline(), deadline_10)
64+
coll.find_one()
65+
66+
# Should be reset to previous values
67+
self.assertEqual(_csot.get_timeout(), None)
68+
self.assertEqual(_csot.get_deadline(), float("inf"))
69+
self.assertEqual(_csot.get_rtt(), 0.0)
70+
71+
3172
if __name__ == "__main__":
3273
unittest.main()

0 commit comments

Comments
 (0)