13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
- import asyncio
17
16
from asyncio import StreamReader , StreamWriter
18
-
19
-
20
- class ClientException (Exception ):
21
- """
22
- An exception thrown by this client library
23
- """
24
- pass
17
+ from .query import Query
18
+ from .exception import ProtocolException
25
19
26
20
27
21
class Connection :
@@ -32,6 +26,8 @@ class Connection:
32
26
def __init__ (self , reader : StreamReader , writer : StreamWriter ) -> None :
33
27
self ._reader = reader
34
28
self ._writer = writer
29
+ self ._cursor = 0
30
+ self .buffer = bytes ()
35
31
36
32
async def _write_all (self , bytes : bytes ):
37
33
self ._write (bytes )
@@ -40,6 +36,9 @@ async def _write_all(self, bytes: bytes):
40
36
def _write (self , bytes : bytes ) -> None :
41
37
self ._writer .write (bytes )
42
38
39
+ def __buffer (self ) -> bytes :
40
+ return self .buffer [:self ._cursor ]
41
+
43
42
async def _flush (self ):
44
43
await self ._writer .drain ()
45
44
@@ -53,46 +52,59 @@ async def close(self):
53
52
self ._writer .close ()
54
53
await self ._writer .wait_closed ()
55
54
56
-
57
- class Config :
58
- def __init__ (self , username : str , password : str , host : str = "127.0.0.1" , port : int = 2003 ) -> None :
59
- self ._username = username
60
- self ._password = password
61
- self ._host = host
62
- self ._port = port
63
-
64
- def get_username (self ) -> str :
65
- return self ._username
66
-
67
- def get_password (self ) -> str :
68
- return self ._password
69
-
70
- def get_host (self ) -> str :
71
- return self ._host
72
-
73
- def get_port (self ) -> int :
74
- return self ._port
75
-
76
- def __hs (self ) -> bytes :
77
- return f"H\0 \0 \0 \0 \0 { len (self .get_username ())} \n { len (self .get_password ())} \n { self .get_username ()} { self .get_password ()} " .encode ()
78
-
79
- async def connect (self ) -> Connection :
80
- """
81
- Establish a connection to the database instance using the set configuration.
82
-
83
- ## Exceptions
84
- Exceptions are raised in the following scenarios:
85
- - If the server responds with a handshake error
86
- - If the server sends an unknown handshake (usually caused by version incompatibility)
87
- """
88
- reader , writer = await asyncio .open_connection (self .get_host (), self .get_port ())
89
- con = Connection (reader , writer )
90
- await con ._write_all (self .__hs ())
91
- resp = await con ._read_exact (4 )
92
- a , b , c , d = resp [0 ], resp [1 ], resp [2 ], resp [3 ]
93
- if resp == b"H\0 \0 \0 " :
94
- return con
95
- elif a == ord (b'H' ) and b == 0 and c == 1 :
96
- raise ClientException (f"handshake error { d } " )
97
- else :
98
- raise ClientException ("unknown handshake" )
55
+ def __parse_string (self ) -> None | str :
56
+ strlen = self .__parse_int ()
57
+ if strlen :
58
+ if len (self .__buffer ()) >= strlen :
59
+ strlen = self .__buffer ()[:strlen ].decode ()
60
+ self ._cursor += strlen
61
+ return strlen
62
+
63
+ def __parse_binary (self ) -> None | bytes :
64
+ binlen = self .__parse_int ()
65
+ if binlen :
66
+ if len (self .__buffer ()) >= binlen :
67
+ binlen = self .__buffer ()[:binlen ].decode ()
68
+ self ._cursor += binlen
69
+ return binlen
70
+
71
+ def __parse_int (self ) -> None | int :
72
+ i = 0
73
+ strlen = 0
74
+ stop = False
75
+ buffer = self .__buffer ()
76
+
77
+ while i < len (buffer ) and not stop :
78
+ digit = None
79
+ if 48 <= buffer [i ] <= 57 :
80
+ digit = buffer [i ] - 48
81
+
82
+ if digit is not None :
83
+ strlen = (10 * strlen ) + digit
84
+ i += 1
85
+ else :
86
+ raise ProtocolException ("invalid response from server" )
87
+
88
+ if i < len (buffer ) and buffer [i ] == ord (b'\n ' ):
89
+ stop = True
90
+ i += 1
91
+
92
+ if stop :
93
+ self ._cursor += i
94
+ self ._cursor += 1 # for LF
95
+ return strlen
96
+
97
+ async def run_simple_query (self , query : Query ):
98
+ query_window_str = str (len (query ._q_window ))
99
+ total_packet_size = len (query_window_str ) + 1 + len (query ._buffer )
100
+ # write metaframe
101
+ metaframe = f"S{ str (total_packet_size )} \n { query_window_str } \n "
102
+ await self ._write_all (metaframe .encode ())
103
+ # write dataframe
104
+ await self ._write_all (query ._buffer )
105
+ # now enter read loop
106
+ while True :
107
+ read = await self ._reader .read (1024 )
108
+ if len (read ) == 0 :
109
+ raise ConnectionResetError
110
+ self .buffer = self .buffer + read
0 commit comments