Skip to content

Commit 077b7ff

Browse files
committed
query: Add Query class and begin implementing decode
1 parent 5bb7bdc commit 077b7ff

File tree

7 files changed

+314
-52
lines changed

7 files changed

+314
-52
lines changed

src/skytable_py/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from .connection import Config, Connection
16+
from .connection import Connection
17+
from .query import Query, UInt, SInt
18+
from .config import Config

src/skytable_py/config.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024, Sayan Nandan <[email protected]>
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
from .connection import Connection
18+
from .exception import ClientException
19+
20+
21+
class Config:
22+
def __init__(self, username: str, password: str, host: str = "127.0.0.1", port: int = 2003) -> None:
23+
self._username = username
24+
self._password = password
25+
self._host = host
26+
self._port = port
27+
28+
def get_username(self) -> str:
29+
return self._username
30+
31+
def get_password(self) -> str:
32+
return self._password
33+
34+
def get_host(self) -> str:
35+
return self._host
36+
37+
def get_port(self) -> int:
38+
return self._port
39+
40+
def __hs(self) -> bytes:
41+
return f"H\0\0\0\0\0{len(self.get_username())}\n{len(self.get_password())}\n{self.get_username()}{self.get_password()}".encode()
42+
43+
async def connect(self) -> Connection:
44+
"""
45+
Establish a connection to the database instance using the set configuration.
46+
47+
## Exceptions
48+
Exceptions are raised in the following scenarios:
49+
- If the server responds with a handshake error
50+
- If the server sends an unknown handshake (usually caused by version incompatibility)
51+
"""
52+
reader, writer = await asyncio.open_connection(self.get_host(), self.get_port())
53+
con = Connection(reader, writer)
54+
await con._write_all(self.__hs())
55+
resp = await con._read_exact(4)
56+
a, b, c, d = resp[0], resp[1], resp[2], resp[3]
57+
if resp == b"H\0\0\0":
58+
return con
59+
elif a == ord(b'H') and b == 0 and c == 1:
60+
raise ClientException(f"handshake error {d}")
61+
else:
62+
raise ClientException("unknown handshake")

src/skytable_py/connection.py

+63-51
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import asyncio
1716
from asyncio import StreamReader, StreamWriter
18-
19-
20-
class ClientException(Exception):
21-
"""
22-
An exception thrown by this client library
23-
"""
24-
pass
17+
from .query import Query
18+
from .exception import ProtocolException
2519

2620

2721
class Connection:
@@ -32,6 +26,8 @@ class Connection:
3226
def __init__(self, reader: StreamReader, writer: StreamWriter) -> None:
3327
self._reader = reader
3428
self._writer = writer
29+
self._cursor = 0
30+
self.buffer = bytes()
3531

3632
async def _write_all(self, bytes: bytes):
3733
self._write(bytes)
@@ -40,6 +36,9 @@ async def _write_all(self, bytes: bytes):
4036
def _write(self, bytes: bytes) -> None:
4137
self._writer.write(bytes)
4238

39+
def __buffer(self) -> bytes:
40+
return self.buffer[:self._cursor]
41+
4342
async def _flush(self):
4443
await self._writer.drain()
4544

@@ -53,46 +52,59 @@ async def close(self):
5352
self._writer.close()
5453
await self._writer.wait_closed()
5554

56-
57-
class Config:
58-
def __init__(self, username: str, password: str, host: str = "127.0.0.1", port: int = 2003) -> None:
59-
self._username = username
60-
self._password = password
61-
self._host = host
62-
self._port = port
63-
64-
def get_username(self) -> str:
65-
return self._username
66-
67-
def get_password(self) -> str:
68-
return self._password
69-
70-
def get_host(self) -> str:
71-
return self._host
72-
73-
def get_port(self) -> int:
74-
return self._port
75-
76-
def __hs(self) -> bytes:
77-
return f"H\0\0\0\0\0{len(self.get_username())}\n{len(self.get_password())}\n{self.get_username()}{self.get_password()}".encode()
78-
79-
async def connect(self) -> Connection:
80-
"""
81-
Establish a connection to the database instance using the set configuration.
82-
83-
## Exceptions
84-
Exceptions are raised in the following scenarios:
85-
- If the server responds with a handshake error
86-
- If the server sends an unknown handshake (usually caused by version incompatibility)
87-
"""
88-
reader, writer = await asyncio.open_connection(self.get_host(), self.get_port())
89-
con = Connection(reader, writer)
90-
await con._write_all(self.__hs())
91-
resp = await con._read_exact(4)
92-
a, b, c, d = resp[0], resp[1], resp[2], resp[3]
93-
if resp == b"H\0\0\0":
94-
return con
95-
elif a == ord(b'H') and b == 0 and c == 1:
96-
raise ClientException(f"handshake error {d}")
97-
else:
98-
raise ClientException("unknown handshake")
55+
def __parse_string(self) -> None | str:
56+
strlen = self.__parse_int()
57+
if strlen:
58+
if len(self.__buffer()) >= strlen:
59+
strlen = self.__buffer()[:strlen].decode()
60+
self._cursor += strlen
61+
return strlen
62+
63+
def __parse_binary(self) -> None | bytes:
64+
binlen = self.__parse_int()
65+
if binlen:
66+
if len(self.__buffer()) >= binlen:
67+
binlen = self.__buffer()[:binlen].decode()
68+
self._cursor += binlen
69+
return binlen
70+
71+
def __parse_int(self) -> None | int:
72+
i = 0
73+
strlen = 0
74+
stop = False
75+
buffer = self.__buffer()
76+
77+
while i < len(buffer) and not stop:
78+
digit = None
79+
if 48 <= buffer[i] <= 57:
80+
digit = buffer[i] - 48
81+
82+
if digit is not None:
83+
strlen = (10 * strlen) + digit
84+
i += 1
85+
else:
86+
raise ProtocolException("invalid response from server")
87+
88+
if i < len(buffer) and buffer[i] == ord(b'\n'):
89+
stop = True
90+
i += 1
91+
92+
if stop:
93+
self._cursor += i
94+
self._cursor += 1 # for LF
95+
return strlen
96+
97+
async def run_simple_query(self, query: Query):
98+
query_window_str = str(len(query._q_window))
99+
total_packet_size = len(query_window_str) + 1 + len(query._buffer)
100+
# write metaframe
101+
metaframe = f"S{str(total_packet_size)}\n{query_window_str}\n"
102+
await self._write_all(metaframe.encode())
103+
# write dataframe
104+
await self._write_all(query._buffer)
105+
# now enter read loop
106+
while True:
107+
read = await self._reader.read(1024)
108+
if len(read) == 0:
109+
raise ConnectionResetError
110+
self.buffer = self.buffer + read

src/skytable_py/exception.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2024, Sayan Nandan <[email protected]>
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
class ClientException(Exception):
18+
"""
19+
An exception thrown by this client library
20+
"""
21+
pass
22+
23+
24+
class ProtocolException(ClientException):
25+
"""
26+
An exception thrown by the protocol
27+
"""
28+
pass

src/skytable_py/query.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2024, Sayan Nandan <[email protected]>
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from abc import ABC
17+
# internal
18+
from .exception import ClientException
19+
20+
21+
class Query:
22+
def __init__(self, query: str, *argv) -> None:
23+
self._buffer = query.encode()
24+
self._param_cnt = 0
25+
self._q_window = len(self._buffer)
26+
for param in argv:
27+
self.add_param(param)
28+
29+
def add_param(self, param: any) -> None:
30+
payload, param_cnt = encode_parameter(param)
31+
self._param_cnt += param_cnt
32+
self._buffer = self._buffer + payload
33+
34+
def get_param_count(self) -> int:
35+
return self._param_cnt
36+
37+
38+
class SkyhashParameter(ABC):
39+
def encode_self(self) -> tuple[bytes, int]: pass
40+
41+
42+
class UInt(SkyhashParameter):
43+
def __init__(self, v: int) -> None:
44+
if v < 0:
45+
raise ClientException("unsigned int can't be negative")
46+
self.v = v
47+
48+
def encode_self(self) -> tuple[bytes, int]:
49+
return (f"\x02{self.v}\n".encode(), 1)
50+
51+
52+
class SInt(SkyhashParameter):
53+
def __init__(self, v: int) -> None:
54+
self.v = v
55+
56+
def encode_self(self) -> tuple[bytes, int]:
57+
return (f"\x03{self.v}\n".encode(), 1)
58+
59+
60+
def encode_parameter(parameter: any) -> tuple[bytes, int]:
61+
encoded = None
62+
if isinstance(parameter, SkyhashParameter):
63+
return parameter.encode_self()
64+
elif parameter is None:
65+
encoded = "\0".encode()
66+
elif isinstance(parameter, bool):
67+
encoded = f"\1{1 if parameter else 0}".encode()
68+
elif isinstance(parameter, float):
69+
encoded = f"\x04{parameter}\n".encode()
70+
elif isinstance(parameter, bytes):
71+
encoded = f"\x05{len(parameter)}\n".encode() + parameter
72+
elif isinstance(parameter, str):
73+
encoded = f"\x06{len(parameter)}\n{parameter}".encode()
74+
else:
75+
raise ClientException("unsupported type")
76+
return (encoded, 1)

tests/test_encoding.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2024, Sayan Nandan <[email protected]>
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
from src.skytable_py.query import encode_parameter, UInt, SInt
18+
from src.skytable_py.exception import ClientException
19+
20+
21+
class TestConfig(unittest.TestCase):
22+
def test_encode_null(self):
23+
self.assertEqual(encode_parameter(None), (b"\0", 1))
24+
25+
def test_encode_bool(self):
26+
self.assertEqual(encode_parameter(False), (b"\x010", 1))
27+
self.assertEqual(encode_parameter(True), (b"\x011", 1))
28+
29+
def test_encode_uint(self):
30+
self.assertEqual(encode_parameter(UInt(1234)), (b"\x021234\n", 1))
31+
32+
def test_encode_sint(self):
33+
self.assertEqual(encode_parameter(SInt(-1234)), (b"\x03-1234\n", 1))
34+
35+
def test_encode_float(self):
36+
self.assertEqual(encode_parameter(3.141592654),
37+
(b"\x043.141592654\n", 1))
38+
39+
def test_encode_bin(self):
40+
self.assertEqual(encode_parameter(b"binary"), (b"\x056\nbinary", 1))
41+
42+
def test_encode_str(self):
43+
self.assertEqual(encode_parameter("string"), (b"\x066\nstring", 1))
44+
45+
def test_int_causes_exception(self):
46+
try:
47+
encode_parameter(1234)
48+
except ClientException as e:
49+
if str(e) == "unsupported type":
50+
pass
51+
else:
52+
self.fail(f"expected 'unsupported type' but got '{e}'")
53+
else:
54+
self.fail("expected exception but no exception was raised")

0 commit comments

Comments
 (0)