11import asyncio
22import copy
33import enum
4- import errno
54import inspect
65import io
76import os
5554if HIREDIS_AVAILABLE :
5655 import hiredis
5756
58- NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
59- BlockingIOError : errno .EWOULDBLOCK ,
60- ssl .SSLWantReadError : 2 ,
61- ssl .SSLWantWriteError : 2 ,
62- ssl .SSLError : 2 ,
63- }
64-
65- NONBLOCKING_EXCEPTIONS = tuple (NONBLOCKING_EXCEPTION_ERROR_NUMBERS .keys ())
66-
67-
6857SYM_STAR = b"*"
6958SYM_DOLLAR = b"$"
7059SYM_CRLF = b"\r \n "
@@ -229,11 +218,9 @@ def __init__(
229218 self ,
230219 stream_reader : asyncio .StreamReader ,
231220 socket_read_size : int ,
232- socket_timeout : Optional [float ],
233221 ):
234222 self ._stream : Optional [asyncio .StreamReader ] = stream_reader
235223 self .socket_read_size = socket_read_size
236- self .socket_timeout = socket_timeout
237224 self ._buffer : Optional [io .BytesIO ] = io .BytesIO ()
238225 # number of bytes written to the buffer from the socket
239226 self .bytes_written = 0
@@ -244,52 +231,35 @@ def __init__(
244231 def length (self ):
245232 return self .bytes_written - self .bytes_read
246233
247- async def _read_from_socket (
248- self ,
249- length : Optional [int ] = None ,
250- timeout : Union [float , None , _Sentinel ] = SENTINEL ,
251- raise_on_timeout : bool = True ,
252- ) -> bool :
234+ async def _read_from_socket (self , length : Optional [int ] = None ) -> bool :
253235 buf = self ._buffer
254236 if buf is None or self ._stream is None :
255237 raise RedisError ("Buffer is closed." )
256238 buf .seek (self .bytes_written )
257239 marker = 0
258- timeout = timeout if timeout is not SENTINEL else self .socket_timeout
259240
260- try :
261- while True :
262- async with async_timeout .timeout (timeout ):
263- data = await self ._stream .read (self .socket_read_size )
264- # an empty string indicates the server shutdown the socket
265- if isinstance (data , bytes ) and len (data ) == 0 :
266- raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
267- buf .write (data )
268- data_length = len (data )
269- self .bytes_written += data_length
270- marker += data_length
271-
272- if length is not None and length > marker :
273- continue
274- return True
275- except (socket .timeout , asyncio .TimeoutError ):
276- if raise_on_timeout :
277- raise TimeoutError ("Timeout reading from socket" )
278- return False
279- except NONBLOCKING_EXCEPTIONS as ex :
280- # if we're in nonblocking mode and the recv raises a
281- # blocking error, simply return False indicating that
282- # there's no data to be read. otherwise raise the
283- # original exception.
284- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS .get (ex .__class__ , - 1 )
285- if not raise_on_timeout and ex .errno == allowed :
286- return False
287- raise ConnectionError (f"Error while reading from socket: { ex .args } " )
241+ while True :
242+ data = await self ._stream .read (self .socket_read_size )
243+ # an empty string indicates the server shutdown the socket
244+ if isinstance (data , bytes ) and len (data ) == 0 :
245+ raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
246+ buf .write (data )
247+ data_length = len (data )
248+ self .bytes_written += data_length
249+ marker += data_length
250+
251+ if length is not None and length > marker :
252+ continue
253+ return True
288254
289255 async def can_read_destructive (self ) -> bool :
290- return bool (self .length ) or await self ._read_from_socket (
291- timeout = 0 , raise_on_timeout = False
292- )
256+ if self .length :
257+ return True
258+ try :
259+ async with async_timeout .timeout (0 ):
260+ return await self ._read_from_socket ()
261+ except asyncio .TimeoutError :
262+ return False
293263
294264 async def read (self , length : int ) -> bytes :
295265 length = length + 2 # make sure to read the \r\n terminator
@@ -372,9 +342,7 @@ def on_connect(self, connection: "Connection"):
372342 if self ._stream is None :
373343 raise RedisError ("Buffer is closed." )
374344
375- self ._buffer = SocketBuffer (
376- self ._stream , self ._read_size , connection .socket_timeout
377- )
345+ self ._buffer = SocketBuffer (self ._stream , self ._read_size )
378346 self .encoder = connection .encoder
379347
380348 def on_disconnect (self ):
@@ -444,14 +412,13 @@ async def read_response(
444412class HiredisParser (BaseParser ):
445413 """Parser class for connections using Hiredis"""
446414
447- __slots__ = BaseParser .__slots__ + ("_reader" , "_socket_timeout" )
415+ __slots__ = BaseParser .__slots__ + ("_reader" ,)
448416
449417 def __init__ (self , socket_read_size : int ):
450418 if not HIREDIS_AVAILABLE :
451419 raise RedisError ("Hiredis is not available." )
452420 super ().__init__ (socket_read_size = socket_read_size )
453421 self ._reader : Optional [hiredis .Reader ] = None
454- self ._socket_timeout : Optional [float ] = None
455422
456423 def on_connect (self , connection : "Connection" ):
457424 self ._stream = connection ._reader
@@ -464,7 +431,6 @@ def on_connect(self, connection: "Connection"):
464431 kwargs ["errors" ] = connection .encoder .encoding_errors
465432
466433 self ._reader = hiredis .Reader (** kwargs )
467- self ._socket_timeout = connection .socket_timeout
468434
469435 def on_disconnect (self ):
470436 self ._stream = None
@@ -475,39 +441,20 @@ async def can_read_destructive(self):
475441 raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
476442 if self ._reader .gets ():
477443 return True
478- return await self .read_from_socket (timeout = 0 , raise_on_timeout = False )
479-
480- async def read_from_socket (
481- self ,
482- timeout : Union [float , None , _Sentinel ] = SENTINEL ,
483- raise_on_timeout : bool = True ,
484- ):
485- timeout = self ._socket_timeout if timeout is SENTINEL else timeout
486444 try :
487- if timeout is None :
488- buffer = await self ._stream .read (self ._read_size )
489- else :
490- async with async_timeout .timeout (timeout ):
491- buffer = await self ._stream .read (self ._read_size )
492- if not buffer or not isinstance (buffer , bytes ):
493- raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR ) from None
494- self ._reader .feed (buffer )
495- # data was read from the socket and added to the buffer.
496- # return True to indicate that data was read.
497- return True
498- except (socket .timeout , asyncio .TimeoutError ):
499- if raise_on_timeout :
500- raise TimeoutError ("Timeout reading from socket" ) from None
445+ async with async_timeout .timeout (0 ):
446+ return await self .read_from_socket ()
447+ except asyncio .TimeoutError :
501448 return False
502- except NONBLOCKING_EXCEPTIONS as ex :
503- # if we're in nonblocking mode and the recv raises a
504- # blocking error, simply return False indicating that
505- # there's no data to be read. otherwise raise the
506- # original exception.
507- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS . get ( ex . __class__ , - 1 )
508- if not raise_on_timeout and ex . errno == allowed :
509- return False
510- raise ConnectionError ( f"Error while reading from socket: { ex . args } " )
449+
450+ async def read_from_socket ( self ):
451+ buffer = await self . _stream . read ( self . _read_size )
452+ if not buffer or not isinstance ( buffer , bytes ):
453+ raise ConnectionError ( SERVER_CLOSED_CONNECTION_ERROR ) from None
454+ self . _reader . feed ( buffer )
455+ # data was read from the socket and added to the buffer.
456+ # return True to indicate that data was read.
457+ return True
511458
512459 async def read_response (
513460 self , disable_decoding : bool = False
@@ -922,11 +869,16 @@ async def can_read_destructive(self):
922869 f"Error while reading from { self .host } :{ self .port } : { e .args } "
923870 )
924871
925- async def read_response (self , disable_decoding : bool = False ):
872+ async def read_response (
873+ self ,
874+ disable_decoding : bool = False ,
875+ timeout : Optional [float ] = None ,
876+ ):
926877 """Read the response from a previously sent command"""
878+ read_timeout = timeout if timeout is not None else self .socket_timeout
927879 try :
928- if self . socket_timeout :
929- async with async_timeout .timeout (self . socket_timeout ):
880+ if read_timeout is not None :
881+ async with async_timeout .timeout (read_timeout ):
930882 response = await self ._parser .read_response (
931883 disable_decoding = disable_decoding
932884 )
@@ -935,6 +887,10 @@ async def read_response(self, disable_decoding: bool = False):
935887 disable_decoding = disable_decoding
936888 )
937889 except asyncio .TimeoutError :
890+ if timeout is not None :
891+ # user requested timeout, return None
892+ return None
893+ # it was a self.socket_timeout error.
938894 await self .disconnect (nowait = True )
939895 raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
940896 except OSError as e :
0 commit comments