Skip to content

Commit 2831a5e

Browse files
authored
[Refactor] Extract GrpcRunner from GRPCIndexBase class (#395)
## Problem This is another extractive refactoring in preparation for grpc with asyncio. ## Solution The generated stub class, `VectorServiceStub`, is what knows how to call the Pinecone grpc service, but our wrapper code needs to do some work to make sure we have a consistent approach to "metadata" (grpc-speak for request headers) and handling other request params like `timeout`. Previously this work was accomplished in a private method of the `GRPCIndexBase` base class called `_wrap_grpc_call()`. Since we will need to perform almost identical marshaling of metadata for requests with asyncio, I pulled this logic out into a separate class `GrpcRunner` and renamed `_wrap_grpc_call` to `run`. You can see there is also a parallel method implementation called `run_asyncio`; currently this is unused and untested, but kind of illustrates why this refactor is useful. ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [x] None of the above: Mechanical refactor, should have no net impact to functionality. ## Test Plan Tests should still be green
1 parent 4c18899 commit 2831a5e

10 files changed

+287
-144
lines changed

pinecone/grpc/base.py

Lines changed: 6 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
1-
import logging
21
from abc import ABC, abstractmethod
3-
from functools import wraps
4-
from typing import Dict, Optional
2+
from typing import Optional
53

64
import grpc
7-
from grpc._channel import _InactiveRpcError, Channel
5+
from grpc._channel import Channel
86

9-
from .retry import RetryConfig
107
from .channel_factory import GrpcChannelFactory
118

129
from pinecone import Config
13-
from .utils import _generate_request_id
1410
from .config import GRPCClientConfig
15-
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
16-
from pinecone.exceptions.exceptions import PineconeException
17-
18-
_logger = logging.getLogger(__name__)
11+
from .grpc_runner import GrpcRunner
1912

2013

2114
class GRPCIndexBase(ABC):
@@ -35,18 +28,12 @@ def __init__(
3528
):
3629
self.config = config
3730
self.grpc_client_config = grpc_config or GRPCClientConfig()
38-
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()
39-
40-
self.fixed_metadata = {
41-
"api-key": config.api_key,
42-
"service-name": index_name,
43-
"client-version": CLIENT_VERSION,
44-
}
45-
if self.grpc_client_config.additional_metadata:
46-
self.fixed_metadata.update(self.grpc_client_config.additional_metadata)
4731

4832
self._endpoint_override = _endpoint_override
4933

34+
self.runner = GrpcRunner(
35+
index_name=index_name, config=config, grpc_config=self.grpc_client_config
36+
)
5037
self.channel_factory = GrpcChannelFactory(
5138
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
5239
)
@@ -91,44 +78,6 @@ def close(self):
9178
except TypeError:
9279
pass
9380

94-
def _wrap_grpc_call(
95-
self,
96-
func,
97-
request,
98-
timeout=None,
99-
metadata=None,
100-
credentials=None,
101-
wait_for_ready=None,
102-
compression=None,
103-
):
104-
@wraps(func)
105-
def wrapped():
106-
user_provided_metadata = metadata or {}
107-
_metadata = tuple(
108-
(k, v)
109-
for k, v in {
110-
**self.fixed_metadata,
111-
**self._request_metadata(),
112-
**user_provided_metadata,
113-
}.items()
114-
)
115-
try:
116-
return func(
117-
request,
118-
timeout=timeout,
119-
metadata=_metadata,
120-
credentials=credentials,
121-
wait_for_ready=wait_for_ready,
122-
compression=compression,
123-
)
124-
except _InactiveRpcError as e:
125-
raise PineconeException(e._state.debug_error_string) from e
126-
127-
return wrapped()
128-
129-
def _request_metadata(self) -> Dict[str, str]:
130-
return {REQUEST_ID: _generate_request_id()}
131-
13281
def __enter__(self):
13382
return self
13483

pinecone/grpc/grpc_runner.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from functools import wraps
2+
from typing import Dict, Tuple, Optional
3+
4+
from grpc._channel import _InactiveRpcError
5+
6+
from pinecone import Config
7+
from .utils import _generate_request_id
8+
from .config import GRPCClientConfig
9+
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
10+
from pinecone.exceptions.exceptions import PineconeException
11+
from grpc import CallCredentials, Compression
12+
from google.protobuf.message import Message
13+
14+
15+
class GrpcRunner:
16+
def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig):
17+
self.config = config
18+
self.grpc_client_config = grpc_config
19+
20+
self.fixed_metadata = {
21+
"api-key": config.api_key,
22+
"service-name": index_name,
23+
"client-version": CLIENT_VERSION,
24+
}
25+
if self.grpc_client_config.additional_metadata:
26+
self.fixed_metadata.update(self.grpc_client_config.additional_metadata)
27+
28+
def run(
29+
self,
30+
func,
31+
request: Message,
32+
timeout: Optional[int] = None,
33+
metadata: Optional[Dict[str, str]] = None,
34+
credentials: Optional[CallCredentials] = None,
35+
wait_for_ready: Optional[bool] = None,
36+
compression: Optional[Compression] = None,
37+
):
38+
@wraps(func)
39+
def wrapped():
40+
user_provided_metadata = metadata or {}
41+
_metadata = self._prepare_metadata(user_provided_metadata)
42+
try:
43+
return func(
44+
request,
45+
timeout=timeout,
46+
metadata=_metadata,
47+
credentials=credentials,
48+
wait_for_ready=wait_for_ready,
49+
compression=compression,
50+
)
51+
except _InactiveRpcError as e:
52+
raise PineconeException(e._state.debug_error_string) from e
53+
54+
return wrapped()
55+
56+
async def run_asyncio(
57+
self,
58+
func,
59+
request: Message,
60+
timeout: Optional[int] = None,
61+
metadata: Optional[Dict[str, str]] = None,
62+
credentials: Optional[CallCredentials] = None,
63+
wait_for_ready: Optional[bool] = None,
64+
compression: Optional[Compression] = None,
65+
):
66+
@wraps(func)
67+
async def wrapped():
68+
user_provided_metadata = metadata or {}
69+
_metadata = self._prepare_metadata(user_provided_metadata)
70+
try:
71+
return await func(
72+
request,
73+
timeout=timeout,
74+
metadata=_metadata,
75+
credentials=credentials,
76+
wait_for_ready=wait_for_ready,
77+
compression=compression,
78+
)
79+
except _InactiveRpcError as e:
80+
raise PineconeException(e._state.debug_error_string) from e
81+
82+
return await wrapped()
83+
84+
def _prepare_metadata(
85+
self, user_provided_metadata: Dict[str, str]
86+
) -> Tuple[Tuple[str, str], ...]:
87+
return tuple(
88+
(k, v)
89+
for k, v in {
90+
**self.fixed_metadata,
91+
**self._request_metadata(),
92+
**user_provided_metadata,
93+
}.items()
94+
)
95+
96+
def _request_metadata(self) -> Dict[str, str]:
97+
return {REQUEST_ID: _generate_request_id()}

pinecone/grpc/index_grpc.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def upsert(
133133
if async_req:
134134
args_dict = self._parse_non_empty_args([("namespace", namespace)])
135135
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
136-
future = self._wrap_grpc_call(self.stub.Upsert.future, request, timeout=timeout)
136+
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
137137
return PineconeGrpcFuture(future)
138138

139139
if batch_size is None:
@@ -155,15 +155,11 @@ def upsert(
155155
return UpsertResponse(upserted_count=total_upserted)
156156

157157
def _upsert_batch(
158-
self,
159-
vectors: List[GRPCVector],
160-
namespace: Optional[str],
161-
timeout: Optional[float],
162-
**kwargs,
158+
self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs
163159
) -> UpsertResponse:
164160
args_dict = self._parse_non_empty_args([("namespace", namespace)])
165161
request = UpsertRequest(vectors=vectors, **args_dict)
166-
return self._wrap_grpc_call(self.stub.Upsert, request, timeout=timeout, **kwargs)
162+
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)
167163

168164
def upsert_from_dataframe(
169165
self,
@@ -280,10 +276,10 @@ def delete(
280276

281277
request = DeleteRequest(**args_dict, **kwargs)
282278
if async_req:
283-
future = self._wrap_grpc_call(self.stub.Delete.future, request, timeout=timeout)
279+
future = self.runner.run(self.stub.Delete.future, request, timeout=timeout)
284280
return PineconeGrpcFuture(future)
285281
else:
286-
return self._wrap_grpc_call(self.stub.Delete, request, timeout=timeout)
282+
return self.runner.run(self.stub.Delete, request, timeout=timeout)
287283

288284
def fetch(
289285
self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs
@@ -308,7 +304,7 @@ def fetch(
308304
args_dict = self._parse_non_empty_args([("namespace", namespace)])
309305

310306
request = FetchRequest(ids=ids, **args_dict, **kwargs)
311-
response = self._wrap_grpc_call(self.stub.Fetch, request, timeout=timeout)
307+
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
312308
json_response = json_format.MessageToDict(response)
313309
return parse_fetch_response(json_response)
314310

@@ -388,7 +384,7 @@ def query(
388384
request = QueryRequest(**args_dict)
389385

390386
timeout = kwargs.pop("timeout", None)
391-
response = self._wrap_grpc_call(self.stub.Query, request, timeout=timeout)
387+
response = self.runner.run(self.stub.Query, request, timeout=timeout)
392388
json_response = json_format.MessageToDict(response)
393389
return parse_query_response(json_response, _check_type=False)
394390

@@ -451,10 +447,10 @@ def update(
451447

452448
request = UpdateRequest(id=id, **args_dict)
453449
if async_req:
454-
future = self._wrap_grpc_call(self.stub.Update.future, request, timeout=timeout)
450+
future = self.runner.run(self.stub.Update.future, request, timeout=timeout)
455451
return PineconeGrpcFuture(future)
456452
else:
457-
return self._wrap_grpc_call(self.stub.Update, request, timeout=timeout)
453+
return self.runner.run(self.stub.Update, request, timeout=timeout)
458454

459455
def list_paginated(
460456
self,
@@ -499,7 +495,7 @@ def list_paginated(
499495
)
500496
request = ListRequest(**args_dict, **kwargs)
501497
timeout = kwargs.pop("timeout", None)
502-
response = self._wrap_grpc_call(self.stub.List, request, timeout=timeout)
498+
response = self.runner.run(self.stub.List, request, timeout=timeout)
503499

504500
if response.pagination and response.pagination.next != "":
505501
pagination = Pagination(next=response.pagination.next)
@@ -572,7 +568,7 @@ def describe_index_stats(
572568
timeout = kwargs.pop("timeout", None)
573569

574570
request = DescribeIndexStatsRequest(**args_dict)
575-
response = self._wrap_grpc_call(self.stub.DescribeIndexStats, request, timeout=timeout)
571+
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
576572
json_response = json_format.MessageToDict(response)
577573
return parse_stats_response(json_response)
578574

tests/unit_grpc/test_grpc_index_describe_index_stats.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ def setup_method(self):
1212
)
1313

1414
def test_describeIndexStats_callWithoutFilter_CalledWithoutFilter(self, mocker):
15-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
15+
mocker.patch.object(self.index.runner, "run", autospec=True)
1616
self.index.describe_index_stats()
17-
self.index._wrap_grpc_call.assert_called_once_with(
17+
self.index.runner.run.assert_called_once_with(
1818
self.index.stub.DescribeIndexStats, DescribeIndexStatsRequest(), timeout=None
1919
)
2020

2121
def test_describeIndexStats_callWithFilter_CalledWithFilter(self, mocker, filter1):
22-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
22+
mocker.patch.object(self.index.runner, "run", autospec=True)
2323
self.index.describe_index_stats(filter=filter1)
24-
self.index._wrap_grpc_call.assert_called_once_with(
24+
self.index.runner.run.assert_called_once_with(
2525
self.index.stub.DescribeIndexStats,
2626
DescribeIndexStatsRequest(filter=dict_to_proto_struct(filter1)),
2727
timeout=None,

tests/unit_grpc/test_grpc_index_fetch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ def setup_method(self):
1111
)
1212

1313
def test_fetch_byIds_fetchByIds(self, mocker):
14-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
14+
mocker.patch.object(self.index.runner, "run", autospec=True)
1515
self.index.fetch(["vec1", "vec2"])
16-
self.index._wrap_grpc_call.assert_called_once_with(
16+
self.index.runner.run.assert_called_once_with(
1717
self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"]), timeout=None
1818
)
1919

2020
def test_fetch_byIdsAndNS_fetchByIdsAndNS(self, mocker):
21-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
21+
mocker.patch.object(self.index.runner, "run", autospec=True)
2222
self.index.fetch(["vec1", "vec2"], namespace="ns", timeout=30)
23-
self.index._wrap_grpc_call.assert_called_once_with(
23+
self.index.runner.run.assert_called_once_with(
2424
self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"], namespace="ns"), timeout=30
2525
)

tests/unit_grpc/test_grpc_index_initialization.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,6 @@ def test_init_with_default_config(self):
1515
assert index.grpc_client_config.grpc_channel_options is None
1616
assert index.grpc_client_config.additional_metadata is None
1717

18-
# Default metadata, grpc equivalent to http request headers
19-
assert len(index.fixed_metadata) == 3
20-
assert index.fixed_metadata["api-key"] == "YOUR_API_KEY"
21-
assert index.fixed_metadata["service-name"] == "my-index"
22-
assert index.fixed_metadata["client-version"] is not None
23-
24-
def test_init_with_additional_metadata(self):
25-
pc = PineconeGRPC(api_key="YOUR_API_KEY")
26-
config = GRPCClientConfig(
27-
additional_metadata={"debug-header": "value123", "debug-header2": "value456"}
28-
)
29-
index = pc.Index(name="my-index", host="host", grpc_config=config)
30-
assert len(index.fixed_metadata) == 5
31-
assert index.fixed_metadata["api-key"] == "YOUR_API_KEY"
32-
assert index.fixed_metadata["service-name"] == "my-index"
33-
assert index.fixed_metadata["client-version"] is not None
34-
assert index.fixed_metadata["debug-header"] == "value123"
35-
assert index.fixed_metadata["debug-header2"] == "value456"
36-
3718
def test_init_with_grpc_config_from_dict(self):
3819
pc = PineconeGRPC(api_key="YOUR_API_KEY")
3920
config = GRPCClientConfig._from_dict({"timeout": 10})

tests/unit_grpc/test_grpc_index_query.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ def setup_method(self):
1414
)
1515

1616
def test_query_byVectorNoFilter_queryVectorNoFilter(self, mocker, vals1):
17-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
17+
mocker.patch.object(self.index.runner, "run", autospec=True)
1818
self.index.query(top_k=10, vector=vals1)
19-
self.index._wrap_grpc_call.assert_called_once_with(
19+
self.index.runner.run.assert_called_once_with(
2020
self.index.stub.Query, QueryRequest(top_k=10, vector=vals1), timeout=None
2121
)
2222

2323
def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, filter1):
24-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
24+
mocker.patch.object(self.index.runner, "run", autospec=True)
2525
self.index.query(top_k=10, vector=vals1, filter=filter1, namespace="ns", timeout=10)
26-
self.index._wrap_grpc_call.assert_called_once_with(
26+
self.index.runner.run.assert_called_once_with(
2727
self.index.stub.Query,
2828
QueryRequest(
2929
top_k=10, vector=vals1, filter=dict_to_proto_struct(filter1), namespace="ns"
@@ -32,9 +32,9 @@ def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, fil
3232
)
3333

3434
def test_query_byVecId_queryByVecId(self, mocker):
35-
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
35+
mocker.patch.object(self.index.runner, "run", autospec=True)
3636
self.index.query(top_k=10, id="vec1", include_metadata=True, include_values=False)
37-
self.index._wrap_grpc_call.assert_called_once_with(
37+
self.index.runner.run.assert_called_once_with(
3838
self.index.stub.Query,
3939
QueryRequest(top_k=10, id="vec1", include_metadata=True, include_values=False),
4040
timeout=None,

0 commit comments

Comments
 (0)