Skip to content

Commit f68b5f8

Browse files
committed
Change timeout behaviour to better mimic usage as timeout of Weaviate op, not of HTTP op
1 parent 3833ceb commit f68b5f8

File tree

8 files changed

+210
-14
lines changed

8 files changed

+210
-14
lines changed

mock_tests/conftest.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import time
23
from concurrent import futures
34
from typing import Generator, Mapping
45

@@ -8,11 +9,19 @@
89
from grpc_health.v1.health_pb2 import HealthCheckResponse, HealthCheckRequest
910
from grpc_health.v1.health_pb2_grpc import HealthServicer, add_HealthServicer_to_server
1011
from pytest_httpserver import HTTPServer, HeaderValueMatcher
11-
from werkzeug.wrappers import Response
12+
from werkzeug.wrappers import Request, Response
1213

1314
import weaviate
1415
from weaviate.connect.base import ConnectionParams, ProtocolParams
15-
from weaviate.proto.v1 import properties_pb2, tenants_pb2, search_get_pb2, weaviate_pb2_grpc
16+
from weaviate.proto.v1 import (
17+
batch_pb2,
18+
properties_pb2,
19+
tenants_pb2,
20+
search_get_pb2,
21+
weaviate_pb2_grpc,
22+
)
23+
24+
from mock_tests.mock_data import mock_class
1625

1726
MOCK_IP = "127.0.0.1"
1827
MOCK_PORT = 23536
@@ -76,6 +85,26 @@ def weaviate_auth_mock(weaviate_mock: HTTPServer):
7685
yield weaviate_mock
7786

7887

88+
@pytest.fixture(scope="function")
89+
def weaviate_timeouts_mock(weaviate_no_auth_mock: HTTPServer):
90+
def slow_get(request: Request) -> Response:
91+
time.sleep(1)
92+
return Response(json.dumps({"doesn't": "matter"}), content_type="application/json")
93+
94+
def slow_post(request: Request) -> Response:
95+
time.sleep(2)
96+
return Response(json.dumps({"doesn't": "matter"}), content_type="application/json")
97+
98+
weaviate_no_auth_mock.expect_request(
99+
f"/v1/schema/{mock_class['class']}", method="GET"
100+
).respond_with_handler(slow_get)
101+
weaviate_no_auth_mock.expect_request("/v1/objects", method="POST").respond_with_handler(
102+
slow_post
103+
)
104+
105+
yield weaviate_no_auth_mock
106+
107+
79108
@pytest.fixture(scope="function")
80109
def start_grpc_server() -> Generator[grpc.Server, None, None]:
81110
# Create a gRPC server
@@ -110,6 +139,22 @@ def weaviate_client(
110139
client.close()
111140

112141

142+
@pytest.fixture(scope="function")
143+
def weaviate_timeouts_client(
144+
weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server
145+
) -> Generator[weaviate.WeaviateClient, None, None]:
146+
client = weaviate.connect_to_local(
147+
host=MOCK_IP,
148+
port=MOCK_PORT,
149+
grpc_port=MOCK_PORT_GRPC,
150+
additional_config=weaviate.classes.init.AdditionalConfig(
151+
timeout=weaviate.classes.init.Timeout(query=0.5, insert=1.5)
152+
),
153+
)
154+
yield client
155+
client.close()
156+
157+
113158
@pytest.fixture(scope="function")
114159
def tenants_collection(
115160
weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
@@ -184,3 +229,24 @@ def Search(
184229

185230
weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
186231
return weaviate_client.collections.get("YearZeroCollection")
232+
233+
234+
@pytest.fixture(scope="function")
235+
def timeouts_collection(
236+
weaviate_timeouts_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
237+
) -> weaviate.collections.Collection:
238+
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
239+
def Search(
240+
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
241+
) -> search_get_pb2.SearchReply:
242+
time.sleep(1)
243+
return search_get_pb2.SearchReply()
244+
245+
def BatchObjects(
246+
self, request: batch_pb2.BatchObjectsRequest, context: grpc.ServicerContext
247+
) -> batch_pb2.BatchObjectsReply:
248+
time.sleep(2)
249+
return batch_pb2.BatchObjectsReply()
250+
251+
weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
252+
return weaviate_timeouts_client.collections.get(mock_class["class"])

mock_tests/mock_data.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
mock_class = {
2+
"class": "Something",
3+
"description": "It's something!",
4+
"invertedIndexConfig": {
5+
"bm25": {"b": 0.8, "k1": 1.3},
6+
"cleanupIntervalSeconds": 61,
7+
"indexPropertyLength": True,
8+
"indexTimestamps": True,
9+
"stopwords": {"additions": None, "preset": "en", "removals": ["the"]},
10+
},
11+
"moduleConfig": {
12+
"generative-openai": {},
13+
"text2vec-contextionary": {"vectorizeClassName": True},
14+
},
15+
"multiTenancyConfig": {
16+
"autoTenantActivation": False,
17+
"autoTenantCreation": False,
18+
"enabled": False,
19+
},
20+
"properties": [
21+
{
22+
"dataType": ["text[]"],
23+
"indexFilterable": True,
24+
"indexRangeFilters": False,
25+
"indexSearchable": True,
26+
"moduleConfig": {
27+
"text2vec-contextionary": {"skip": False, "vectorizePropertyName": False}
28+
},
29+
"name": "names",
30+
"tokenization": "word",
31+
}
32+
],
33+
"replicationConfig": {"asyncEnabled": False, "factor": 1},
34+
"shardingConfig": {
35+
"virtualPerPhysical": 128,
36+
"desiredCount": 1,
37+
"actualCount": 1,
38+
"desiredVirtualCount": 128,
39+
"actualVirtualCount": 128,
40+
"key": "_id",
41+
"strategy": "hash",
42+
"function": "murmur3",
43+
},
44+
"vectorIndexConfig": {
45+
"skip": True,
46+
"cleanupIntervalSeconds": 300,
47+
"maxConnections": 64,
48+
"efConstruction": 128,
49+
"ef": -2,
50+
"dynamicEfMin": 101,
51+
"dynamicEfMax": 501,
52+
"dynamicEfFactor": 9,
53+
"vectorCacheMaxObjects": 1000000000001,
54+
"flatSearchCutoff": 40001,
55+
"distance": "cosine",
56+
"pq": {
57+
"enabled": True,
58+
"bitCompression": True,
59+
"segments": 1,
60+
"centroids": 257,
61+
"trainingLimit": 100001,
62+
"encoder": {"type": "tile", "distribution": "normal"},
63+
},
64+
"bq": {"enabled": False},
65+
"sq": {"enabled": False, "trainingLimit": 100000, "rescoreLimit": 20},
66+
},
67+
"vectorIndexType": "hnsw",
68+
"vectorizer": "text2vec-contextionary",
69+
}

mock_tests/test_timeouts.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
import weaviate
3+
from weaviate.exceptions import WeaviateTimeoutError, WeaviateQueryError
4+
5+
6+
def test_timeout_rest_query(timeouts_collection: weaviate.collections.Collection):
7+
with pytest.raises(WeaviateTimeoutError):
8+
timeouts_collection.config.get()
9+
10+
11+
def test_timeout_rest_insert(timeouts_collection: weaviate.collections.Collection):
12+
with pytest.raises(WeaviateTimeoutError):
13+
timeouts_collection.data.insert(properties={"what": "ever"})
14+
15+
16+
def test_timeout_grpc_query(timeouts_collection: weaviate.collections.Collection):
17+
with pytest.raises(WeaviateQueryError) as recwarn:
18+
timeouts_collection.query.fetch_objects()
19+
assert "DEADLINE_EXCEEDED" in str(recwarn)
20+
21+
22+
def test_timeout_grpc_insert(timeouts_collection: weaviate.collections.Collection):
23+
with pytest.raises(WeaviateQueryError) as recwarn:
24+
timeouts_collection.data.insert_many([{"what": "ever"}])
25+
assert "DEADLINE_EXCEEDED" in str(recwarn)

weaviate/client_base.py

+1
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ async def graphql_raw_query(self, gql_query: str) -> _RawGQLReturn:
209209
weaviate_object=json_query,
210210
error_msg="Raw GQL query failed",
211211
status_codes=_ExpectedStatusCodes(ok_in=[200], error="GQL query"),
212+
is_gql_query=True,
212213
)
213214

214215
res = _decode_json_response_dict(response, "GQL query")

weaviate/collections/batch/grpc_batch_objects.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def pack_vector(vector: Any) -> bytes:
7979
for obj in objects
8080
]
8181

82-
async def objects(self, objects: List[_BatchObject], timeout: int) -> BatchObjectReturn:
82+
async def objects(
83+
self, objects: List[_BatchObject], timeout: Union[int, float]
84+
) -> BatchObjectReturn:
8385
"""Insert multiple objects into Weaviate through the gRPC API.
8486
8587
Parameters:
@@ -131,7 +133,7 @@ async def objects(self, objects: List[_BatchObject], timeout: int) -> BatchObjec
131133
)
132134

133135
async def __send_batch(
134-
self, batch: List[batch_pb2.BatchObject], timeout: int
136+
self, batch: List[batch_pb2.BatchObject], timeout: Union[int, float]
135137
) -> Dict[int, str]:
136138
metadata = self._get_metadata()
137139
try:

weaviate/config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def __post_init__(self) -> None:
4848
class Timeout(BaseModel):
4949
"""Timeouts for the different operations in the client."""
5050

51-
query: int = Field(default=30, ge=0)
52-
insert: int = Field(default=90, ge=0)
53-
init: int = Field(default=2, ge=0)
51+
query: Union[int, float] = Field(default=30, ge=0)
52+
insert: Union[int, float] = Field(default=90, ge=0)
53+
init: Union[int, float] = Field(default=2, ge=0)
5454

5555

5656
class Proxies(BaseModel):

weaviate/connect/v4.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
HTTPStatusError,
2626
Limits,
2727
ReadError,
28+
ReadTimeout,
2829
RemoteProtocolError,
2930
RequestError,
3031
Response,
@@ -55,6 +56,7 @@
5556
WeaviateConnectionError,
5657
WeaviateGRPCUnavailableError,
5758
WeaviateStartUpError,
59+
WeaviateTimeoutError,
5860
)
5961
from weaviate.proto.v1 import weaviate_pb2_grpc
6062
from weaviate.util import (
@@ -219,12 +221,6 @@ def __make_mounts(self) -> Dict[str, AsyncHTTPTransport]:
219221
def __make_async_client(self) -> AsyncClient:
220222
return AsyncClient(
221223
headers=self._headers,
222-
timeout=Timeout(
223-
None,
224-
connect=self.timeout_config.init,
225-
read=self.timeout_config.query,
226-
write=self.timeout_config.insert,
227-
),
228224
mounts=self.__make_mounts(),
229225
)
230226

@@ -409,12 +405,36 @@ def __get_latest_headers(self) -> Dict[str, str]:
409405
copied_headers.update({"authorization": self.get_current_bearer_token()})
410406
return copied_headers
411407

408+
def __get_timeout(
409+
self, method: Literal["DELETE", "GET", "HEAD", "PATCH", "POST", "PUT"], is_gql_query: bool
410+
) -> Timeout:
411+
"""
412+
In this way, the client waits the `httpx` default of 5s when connecting to a socket (connect), writing chunks (write), and
413+
acquiring a connection from the pool (pool), but a custom amount as specified for reading the response (read).
414+
415+
From the PoV of the user, a request is considered to be timed out if no response is received within the specified time.
416+
They specify the times depending on how they expect Weaviate to behave. For example, a query might take longer than an insert or vice versa.
417+
418+
https://www.python-httpx.org/advanced/timeouts/
419+
"""
420+
timeout = None
421+
if method == "DELETE" or method == "PATCH" or method == "PUT":
422+
timeout = self.timeout_config.insert
423+
elif method == "GET" or method == "HEAD":
424+
timeout = self.timeout_config.query
425+
elif method == "POST" and is_gql_query:
426+
timeout = self.timeout_config.query
427+
elif method == "POST" and not is_gql_query:
428+
timeout = self.timeout_config.insert
429+
return Timeout(timeout=5.0, read=timeout)
430+
412431
async def __send(
413432
self,
414433
method: Literal["DELETE", "GET", "HEAD", "PATCH", "POST", "PUT"],
415434
url: str,
416435
error_msg: str,
417436
status_codes: Optional[_ExpectedStatusCodes],
437+
is_gql_query: bool = False,
418438
weaviate_object: Optional[JSONPayload] = None,
419439
params: Optional[Dict[str, Any]] = None,
420440
) -> Response:
@@ -430,6 +450,7 @@ async def __send(
430450
json=weaviate_object,
431451
params=params,
432452
headers=self.__get_latest_headers(),
453+
timeout=self.__get_timeout(method, is_gql_query),
433454
)
434455
res = await self._client.send(req)
435456
if status_codes is not None and res.status_code not in status_codes.ok:
@@ -439,6 +460,8 @@ async def __send(
439460
raise WeaviateClosedClientError() from e
440461
except ConnectError as conn_err:
441462
raise WeaviateConnectionError(error_msg) from conn_err
463+
except ReadTimeout as read_err:
464+
raise WeaviateTimeoutError(error_msg) from read_err
442465
except Exception as e:
443466
raise e
444467

@@ -483,6 +506,7 @@ async def post(
483506
params: Optional[Dict[str, Any]] = None,
484507
error_msg: str = "",
485508
status_codes: Optional[_ExpectedStatusCodes] = None,
509+
is_gql_query: bool = False,
486510
) -> Response:
487511
return await self.__send(
488512
"POST",
@@ -491,6 +515,7 @@ async def post(
491515
params=params,
492516
error_msg=error_msg,
493517
status_codes=status_codes,
518+
is_gql_query=is_gql_query,
494519
)
495520

496521
async def put(

weaviate/exceptions.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ class WeaviateConnectionError(WeaviateBaseError):
317317
"""Is raised when the connection to Weaviate fails."""
318318

319319
def __init__(self, message: str = "") -> None:
320-
msg = f"""Connection to Weaviate failed. {message}"""
320+
msg = f"""Connection to Weaviate failed. Details: {message}"""
321321
super().__init__(msg)
322322

323323

@@ -327,3 +327,11 @@ class WeaviateUnsupportedFeatureError(WeaviateBaseError):
327327
def __init__(self, feature: str, current: str, minimum: str) -> None:
328328
msg = f"""{feature} is not supported by your connected server's Weaviate version. The current version is {current}, but the feature requires at least version {minimum}."""
329329
super().__init__(msg)
330+
331+
332+
class WeaviateTimeoutError(WeaviateBaseError):
333+
"""Is raised when a request to Weaviate times out."""
334+
335+
def __init__(self, message: str = "") -> None:
336+
msg = f"""The request to Weaviate timed out while awaiting a response. Try adjusting the timeout config for your client. Details: {message}"""
337+
super().__init__(msg)

0 commit comments

Comments
 (0)