Skip to content

Commit fa0be76

Browse files
WisdomPilldvora-h
andauthored
Made sync lock consistent and added types to it (#2137)
* Made sync lock consistent and added types to it * Made linters happy * Fixed cluster client lock signature Co-authored-by: dvora-h <[email protected]>
1 parent 05fc203 commit fa0be76

File tree

5 files changed

+80
-25
lines changed

5 files changed

+80
-25
lines changed

redis/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,7 @@ def lock(
10981098
name,
10991099
timeout=None,
11001100
sleep=0.1,
1101+
blocking=True,
11011102
blocking_timeout=None,
11021103
lock_class=None,
11031104
thread_local=True,
@@ -1113,6 +1114,12 @@ def lock(
11131114
when the lock is in blocking mode and another client is currently
11141115
holding the lock.
11151116
1117+
``blocking`` indicates whether calling ``acquire`` should block until
1118+
the lock has been acquired or to fail immediately, causing ``acquire``
1119+
to return False and the lock not being acquired. Defaults to True.
1120+
Note this value can be overridden by passing a ``blocking``
1121+
argument to ``acquire``.
1122+
11161123
``blocking_timeout`` indicates the maximum amount of time in seconds to
11171124
spend trying to acquire the lock. A value of ``None`` indicates
11181125
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -1155,6 +1162,7 @@ def lock(
11551162
name,
11561163
timeout=timeout,
11571164
sleep=sleep,
1165+
blocking=blocking,
11581166
blocking_timeout=blocking_timeout,
11591167
thread_local=thread_local,
11601168
)

redis/cluster.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ def lock(
766766
name,
767767
timeout=None,
768768
sleep=0.1,
769+
blocking=True,
769770
blocking_timeout=None,
770771
lock_class=None,
771772
thread_local=True,
@@ -781,6 +782,12 @@ def lock(
781782
when the lock is in blocking mode and another client is currently
782783
holding the lock.
783784
785+
``blocking`` indicates whether calling ``acquire`` should block until
786+
the lock has been acquired or to fail immediately, causing ``acquire``
787+
to return False and the lock not being acquired. Defaults to True.
788+
Note this value can be overridden by passing a ``blocking``
789+
argument to ``acquire``.
790+
784791
``blocking_timeout`` indicates the maximum amount of time in seconds to
785792
spend trying to acquire the lock. A value of ``None`` indicates
786793
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -823,6 +830,7 @@ def lock(
823830
name,
824831
timeout=timeout,
825832
sleep=sleep,
833+
blocking=blocking,
826834
blocking_timeout=blocking_timeout,
827835
thread_local=thread_local,
828836
)

redis/lock.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import threading
22
import time as mod_time
33
import uuid
4-
from types import SimpleNamespace
4+
from types import SimpleNamespace, TracebackType
5+
from typing import Optional, Type
56

67
from redis.exceptions import LockError, LockNotOwnedError
8+
from redis.typing import Number
79

810

911
class Lock:
@@ -74,12 +76,13 @@ class Lock:
7476
def __init__(
7577
self,
7678
redis,
77-
name,
78-
timeout=None,
79-
sleep=0.1,
80-
blocking=True,
81-
blocking_timeout=None,
82-
thread_local=True,
79+
name: str,
80+
*,
81+
timeout: Optional[Number] = None,
82+
sleep: Number = 0.1,
83+
blocking: bool = True,
84+
blocking_timeout: Optional[Number] = None,
85+
thread_local: bool = True,
8386
):
8487
"""
8588
Create a new Lock instance named ``name`` using the Redis client
@@ -142,7 +145,7 @@ def __init__(
142145
self.local.token = None
143146
self.register_scripts()
144147

145-
def register_scripts(self):
148+
def register_scripts(self) -> None:
146149
cls = self.__class__
147150
client = self.redis
148151
if cls.lua_release is None:
@@ -152,15 +155,27 @@ def register_scripts(self):
152155
if cls.lua_reacquire is None:
153156
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
154157

155-
def __enter__(self):
158+
def __enter__(self) -> "Lock":
156159
if self.acquire():
157160
return self
158161
raise LockError("Unable to acquire lock within the time specified")
159162

160-
def __exit__(self, exc_type, exc_value, traceback):
163+
def __exit__(
164+
self,
165+
exc_type: Optional[Type[BaseException]],
166+
exc_value: Optional[BaseException],
167+
traceback: Optional[TracebackType],
168+
) -> None:
161169
self.release()
162170

163-
def acquire(self, blocking=None, blocking_timeout=None, token=None):
171+
def acquire(
172+
self,
173+
*,
174+
sleep: Optional[Number] = None,
175+
blocking: Optional[bool] = None,
176+
blocking_timeout: Optional[Number] = None,
177+
token: Optional[str] = None,
178+
):
164179
"""
165180
Use Redis to hold a shared, distributed lock named ``name``.
166181
Returns True once the lock is acquired.
@@ -176,7 +191,8 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
176191
object with the default encoding. If a token isn't specified, a UUID
177192
will be generated.
178193
"""
179-
sleep = self.sleep
194+
if sleep is None:
195+
sleep = self.sleep
180196
if token is None:
181197
token = uuid.uuid1().hex.encode()
182198
else:
@@ -200,7 +216,7 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
200216
return False
201217
mod_time.sleep(sleep)
202218

203-
def do_acquire(self, token):
219+
def do_acquire(self, token: str) -> bool:
204220
if self.timeout:
205221
# convert to milliseconds
206222
timeout = int(self.timeout * 1000)
@@ -210,13 +226,13 @@ def do_acquire(self, token):
210226
return True
211227
return False
212228

213-
def locked(self):
229+
def locked(self) -> bool:
214230
"""
215231
Returns True if this key is locked by any process, otherwise False.
216232
"""
217233
return self.redis.get(self.name) is not None
218234

219-
def owned(self):
235+
def owned(self) -> bool:
220236
"""
221237
Returns True if this key is locked by this lock, otherwise False.
222238
"""
@@ -228,21 +244,23 @@ def owned(self):
228244
stored_token = encoder.encode(stored_token)
229245
return self.local.token is not None and stored_token == self.local.token
230246

231-
def release(self):
232-
"Releases the already acquired lock"
247+
def release(self) -> None:
248+
"""
249+
Releases the already acquired lock
250+
"""
233251
expected_token = self.local.token
234252
if expected_token is None:
235253
raise LockError("Cannot release an unlocked lock")
236254
self.local.token = None
237255
self.do_release(expected_token)
238256

239-
def do_release(self, expected_token):
257+
def do_release(self, expected_token: str) -> None:
240258
if not bool(
241259
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
242260
):
243261
raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")
244262

245-
def extend(self, additional_time, replace_ttl=False):
263+
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
246264
"""
247265
Adds more time to an already acquired lock.
248266
@@ -259,19 +277,19 @@ def extend(self, additional_time, replace_ttl=False):
259277
raise LockError("Cannot extend a lock with no timeout")
260278
return self.do_extend(additional_time, replace_ttl)
261279

262-
def do_extend(self, additional_time, replace_ttl):
280+
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
263281
additional_time = int(additional_time * 1000)
264282
if not bool(
265283
self.lua_extend(
266284
keys=[self.name],
267-
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
285+
args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
268286
client=self.redis,
269287
)
270288
):
271-
raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
289+
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
272290
return True
273291

274-
def reacquire(self):
292+
def reacquire(self) -> bool:
275293
"""
276294
Resets a TTL of an already acquired lock back to a timeout value.
277295
"""
@@ -281,12 +299,12 @@ def reacquire(self):
281299
raise LockError("Cannot reacquire a lock with no timeout")
282300
return self.do_reacquire()
283301

284-
def do_reacquire(self):
302+
def do_reacquire(self) -> bool:
285303
timeout = int(self.timeout * 1000)
286304
if not bool(
287305
self.lua_reacquire(
288306
keys=[self.name], args=[self.local.token, timeout], client=self.redis
289307
)
290308
):
291-
raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
309+
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
292310
return True

redis/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from redis.connection import ConnectionPool, Encoder
1212

1313

14+
Number = Union[int, float]
1415
EncodedT = Union[bytes, memoryview]
1516
DecodedT = Union[str, int, float]
1617
EncodableT = Union[EncodedT, DecodedT]

tests/test_lock.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ def test_context_manager(self, r):
116116
assert r.get("foo") == lock.local.token
117117
assert r.get("foo") is None
118118

119+
def test_context_manager_blocking_timeout(self, r):
120+
with self.get_lock(r, "foo", blocking=False):
121+
bt = 0.4
122+
sleep = 0.05
123+
lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
124+
start = time.monotonic()
125+
assert not lock2.acquire()
126+
# The elapsed duration should be less than the total blocking_timeout
127+
assert bt > (time.monotonic() - start) > bt - sleep
128+
119129
def test_context_manager_raises_when_locked_not_acquired(self, r):
120130
r.set("foo", "bar")
121131
with pytest.raises(LockError):
@@ -221,6 +231,16 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
221231
with pytest.raises(LockNotOwnedError):
222232
lock.reacquire()
223233

234+
def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r):
235+
with self.get_lock(r, "foo", timeout=None, blocking=False) as lock:
236+
with pytest.raises(LockError):
237+
lock.reacquire()
238+
239+
def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r):
240+
with pytest.raises(LockNotOwnedError):
241+
with self.get_lock(r, "foo", timeout=10, blocking=False):
242+
r.set("foo", "a")
243+
224244

225245
class TestLockClassSelection:
226246
def test_lock_class_argument(self, r):

0 commit comments

Comments
 (0)