Skip to content

Commit 4c18899

Browse files
authored
[Refactor] Extract GrpcChannelFactory from GRPCIndexBase (#394)
## Problem I'm preparing to implement asyncio for the data plane, and I had a need to extract some of this grpc channel configuration into a spot where it could be reused more easily across both sync and async implementations. ## Solution - Extract `GrpcChannelFactory` from `GRPCIndexBase` - Add some unit tests for this new class ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [x] None of the above: Refactoring only, should be no functional change
1 parent 1d0f046 commit 4c18899

File tree

3 files changed

+248
-64
lines changed

3 files changed

+248
-64
lines changed

pinecone/grpc/base.py

Lines changed: 7 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,16 @@
33
from functools import wraps
44
from typing import Dict, Optional
55

6-
import certifi
76
import grpc
87
from grpc._channel import _InactiveRpcError, Channel
9-
import json
108

119
from .retry import RetryConfig
10+
from .channel_factory import GrpcChannelFactory
1211

1312
from pinecone import Config
1413
from .utils import _generate_request_id
1514
from .config import GRPCClientConfig
16-
from pinecone.utils.constants import MAX_MSG_SIZE, REQUEST_ID, CLIENT_VERSION
17-
from pinecone.utils.user_agent import get_user_agent_grpc
15+
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
1816
from pinecone.exceptions.exceptions import PineconeException
1917

2018
_logger = logging.getLogger(__name__)
@@ -35,8 +33,6 @@ def __init__(
3533
grpc_config: Optional[GRPCClientConfig] = None,
3634
_endpoint_override: Optional[str] = None,
3735
):
38-
self.name = index_name
39-
4036
self.config = config
4137
self.grpc_client_config = grpc_config or GRPCClientConfig()
4238
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()
@@ -51,35 +47,10 @@ def __init__(
5147

5248
self._endpoint_override = _endpoint_override
5349

54-
self.method_config = json.dumps(
55-
{
56-
"methodConfig": [
57-
{
58-
"name": [{"service": "VectorService.Upsert"}],
59-
"retryPolicy": {
60-
"maxAttempts": 5,
61-
"initialBackoff": "0.1s",
62-
"maxBackoff": "1s",
63-
"backoffMultiplier": 2,
64-
"retryableStatusCodes": ["UNAVAILABLE"],
65-
},
66-
},
67-
{
68-
"name": [{"service": "VectorService"}],
69-
"retryPolicy": {
70-
"maxAttempts": 5,
71-
"initialBackoff": "0.1s",
72-
"maxBackoff": "1s",
73-
"backoffMultiplier": 2,
74-
"retryableStatusCodes": ["UNAVAILABLE"],
75-
},
76-
},
77-
]
78-
}
50+
self.channel_factory = GrpcChannelFactory(
51+
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
7952
)
80-
81-
options = {"grpc.primary_user_agent": get_user_agent_grpc(config)}
82-
self._channel = channel or self._gen_channel(options=options)
53+
self._channel = channel or self._gen_channel()
8354
self.stub = self.stub_class(self._channel)
8455

8556
@property
@@ -93,36 +64,8 @@ def _endpoint(self):
9364
grpc_host = f"{grpc_host}:443"
9465
return self._endpoint_override if self._endpoint_override else grpc_host
9566

96-
def _gen_channel(self, options=None):
97-
target = self._endpoint()
98-
default_options = {
99-
"grpc.max_send_message_length": MAX_MSG_SIZE,
100-
"grpc.max_receive_message_length": MAX_MSG_SIZE,
101-
"grpc.service_config": self.method_config,
102-
"grpc.enable_retries": True,
103-
"grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE,
104-
}
105-
if self.grpc_client_config.secure:
106-
default_options["grpc.ssl_target_name_override"] = target.split(":")[0]
107-
if self.config.proxy_url:
108-
default_options["grpc.http_proxy"] = self.config.proxy_url
109-
user_provided_options = options or {}
110-
_options = tuple((k, v) for k, v in {**default_options, **user_provided_options}.items())
111-
_logger.debug(
112-
"creating new channel with endpoint %s options %s and config %s",
113-
target,
114-
_options,
115-
self.grpc_client_config,
116-
)
117-
if not self.grpc_client_config.secure:
118-
channel = grpc.insecure_channel(target, options=_options)
119-
else:
120-
ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where()
121-
root_cas = open(ca_certs, "rb").read()
122-
tls = grpc.ssl_channel_credentials(root_certificates=root_cas)
123-
channel = grpc.secure_channel(target, tls, options=_options)
124-
125-
return channel
67+
def _gen_channel(self):
68+
return self.channel_factory.create_channel(self._endpoint())
12669

12770
@property
12871
def channel(self):

pinecone/grpc/channel_factory.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import logging
2+
from typing import Optional
3+
4+
import certifi
5+
import grpc
6+
import json
7+
8+
from pinecone import Config
9+
from .config import GRPCClientConfig
10+
from pinecone.utils.constants import MAX_MSG_SIZE
11+
from pinecone.utils.user_agent import get_user_agent_grpc
12+
13+
_logger = logging.getLogger(__name__)
14+
15+
16+
class GrpcChannelFactory:
17+
def __init__(
18+
self,
19+
config: Config,
20+
grpc_client_config: GRPCClientConfig,
21+
use_asyncio: Optional[bool] = False,
22+
):
23+
self.config = config
24+
self.grpc_client_config = grpc_client_config
25+
self.use_asyncio = use_asyncio
26+
27+
def _get_service_config(self):
28+
# https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto
29+
return json.dumps(
30+
{
31+
"methodConfig": [
32+
{
33+
"name": [{"service": "VectorService.Upsert"}],
34+
"retryPolicy": {
35+
"maxAttempts": 5,
36+
"initialBackoff": "0.1s",
37+
"maxBackoff": "1s",
38+
"backoffMultiplier": 2,
39+
"retryableStatusCodes": ["UNAVAILABLE"],
40+
},
41+
},
42+
{
43+
"name": [{"service": "VectorService"}],
44+
"retryPolicy": {
45+
"maxAttempts": 5,
46+
"initialBackoff": "0.1s",
47+
"maxBackoff": "1s",
48+
"backoffMultiplier": 2,
49+
"retryableStatusCodes": ["UNAVAILABLE"],
50+
},
51+
},
52+
]
53+
}
54+
)
55+
56+
def _build_options(self, target):
57+
# For property definitions, see https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
58+
options = {
59+
"grpc.max_send_message_length": MAX_MSG_SIZE,
60+
"grpc.max_receive_message_length": MAX_MSG_SIZE,
61+
"grpc.service_config": self._get_service_config(),
62+
"grpc.enable_retries": True,
63+
"grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE,
64+
"grpc.primary_user_agent": get_user_agent_grpc(self.config),
65+
}
66+
if self.grpc_client_config.secure:
67+
options["grpc.ssl_target_name_override"] = target.split(":")[0]
68+
if self.config.proxy_url:
69+
options["grpc.http_proxy"] = self.config.proxy_url
70+
71+
options_tuple = tuple((k, v) for k, v in options.items())
72+
return options_tuple
73+
74+
def _build_channel_credentials(self):
75+
ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where()
76+
root_cas = open(ca_certs, "rb").read()
77+
channel_creds = grpc.ssl_channel_credentials(root_certificates=root_cas)
78+
return channel_creds
79+
80+
def create_channel(self, endpoint):
81+
options_tuple = self._build_options(endpoint)
82+
83+
_logger.debug(
84+
"Creating new channel with endpoint %s options %s and config %s",
85+
endpoint,
86+
options_tuple,
87+
self.grpc_client_config,
88+
)
89+
90+
if not self.grpc_client_config.secure:
91+
create_channel_fn = (
92+
grpc.aio.insecure_channel if self.use_asyncio else grpc.insecure_channel
93+
)
94+
channel = create_channel_fn(endpoint, options=options_tuple)
95+
else:
96+
channel_creds = self._build_channel_credentials()
97+
create_channel_fn = grpc.aio.secure_channel if self.use_asyncio else grpc.secure_channel
98+
channel = create_channel_fn(endpoint, credentials=channel_creds, options=options_tuple)
99+
100+
return channel
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import grpc
2+
import re
3+
import pytest
4+
from unittest.mock import patch, MagicMock, ANY
5+
6+
from pinecone import Config
7+
from pinecone.grpc.channel_factory import GrpcChannelFactory, GRPCClientConfig
8+
from pinecone.utils.constants import MAX_MSG_SIZE
9+
10+
11+
@pytest.fixture
12+
def config():
13+
return Config(ssl_ca_certs=None, proxy_url=None)
14+
15+
16+
@pytest.fixture
17+
def grpc_client_config():
18+
return GRPCClientConfig(secure=True)
19+
20+
21+
class TestGrpcChannelFactory:
22+
def test_create_secure_channel_with_default_settings(self, config, grpc_client_config):
23+
factory = GrpcChannelFactory(
24+
config=config, grpc_client_config=grpc_client_config, use_asyncio=False
25+
)
26+
endpoint = "test.endpoint:443"
27+
28+
with patch("grpc.secure_channel") as mock_secure_channel, patch(
29+
"certifi.where", return_value="/path/to/certifi/cacert.pem"
30+
), patch("builtins.open", new_callable=MagicMock) as mock_open:
31+
# Mock the file object to return bytes when read() is called
32+
mock_file = MagicMock()
33+
mock_file.read.return_value = b"mocked_cert_data"
34+
mock_open.return_value = mock_file
35+
channel = factory.create_channel(endpoint)
36+
37+
mock_secure_channel.assert_called_once()
38+
assert mock_secure_channel.call_args[0][0] == endpoint
39+
assert isinstance(mock_secure_channel.call_args[1]["options"], tuple)
40+
41+
options = dict(mock_secure_channel.call_args[1]["options"])
42+
assert options["grpc.ssl_target_name_override"] == "test.endpoint"
43+
assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE
44+
assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE
45+
assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE
46+
assert "grpc.service_config" in options
47+
assert options["grpc.enable_retries"] is True
48+
assert (
49+
re.search(
50+
r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"]
51+
)
52+
is not None
53+
)
54+
55+
assert isinstance(channel, MagicMock)
56+
57+
def test_create_secure_channel_with_proxy(self):
58+
grpc_client_config = GRPCClientConfig(secure=True)
59+
config = Config(proxy_url="http://test.proxy:8080")
60+
factory = GrpcChannelFactory(
61+
config=config, grpc_client_config=grpc_client_config, use_asyncio=False
62+
)
63+
endpoint = "test.endpoint:443"
64+
65+
with patch("grpc.secure_channel") as mock_secure_channel:
66+
channel = factory.create_channel(endpoint)
67+
68+
mock_secure_channel.assert_called_once()
69+
assert "grpc.http_proxy" in dict(mock_secure_channel.call_args[1]["options"])
70+
assert (
71+
"http://test.proxy:8080"
72+
== dict(mock_secure_channel.call_args[1]["options"])["grpc.http_proxy"]
73+
)
74+
assert isinstance(channel, MagicMock)
75+
76+
def test_create_insecure_channel(self, config):
77+
grpc_client_config = GRPCClientConfig(secure=False)
78+
factory = GrpcChannelFactory(
79+
config=config, grpc_client_config=grpc_client_config, use_asyncio=False
80+
)
81+
endpoint = "test.endpoint:50051"
82+
83+
with patch("grpc.insecure_channel") as mock_insecure_channel:
84+
channel = factory.create_channel(endpoint)
85+
86+
mock_insecure_channel.assert_called_once_with(endpoint, options=ANY)
87+
assert isinstance(channel, MagicMock)
88+
89+
90+
class TestGrpcChannelFactoryAsyncio:
91+
def test_create_secure_channel_with_default_settings(self, config, grpc_client_config):
92+
factory = GrpcChannelFactory(
93+
config=config, grpc_client_config=grpc_client_config, use_asyncio=True
94+
)
95+
endpoint = "test.endpoint:443"
96+
97+
with patch("grpc.aio.secure_channel") as mock_secure_aio_channel, patch(
98+
"certifi.where", return_value="/path/to/certifi/cacert.pem"
99+
), patch("builtins.open", new_callable=MagicMock) as mock_open:
100+
# Mock the file object to return bytes when read() is called
101+
mock_file = MagicMock()
102+
mock_file.read.return_value = b"mocked_cert_data"
103+
mock_open.return_value = mock_file
104+
channel = factory.create_channel(endpoint)
105+
106+
mock_secure_aio_channel.assert_called_once()
107+
assert mock_secure_aio_channel.call_args[0][0] == endpoint
108+
assert isinstance(mock_secure_aio_channel.call_args[1]["options"], tuple)
109+
110+
options = dict(mock_secure_aio_channel.call_args[1]["options"])
111+
assert options["grpc.ssl_target_name_override"] == "test.endpoint"
112+
assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE
113+
assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE
114+
assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE
115+
assert "grpc.service_config" in options
116+
assert options["grpc.enable_retries"] is True
117+
assert (
118+
re.search(
119+
r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"]
120+
)
121+
is not None
122+
)
123+
124+
security_credentials = mock_secure_aio_channel.call_args[1]["credentials"]
125+
assert security_credentials is not None
126+
assert isinstance(security_credentials, grpc.ChannelCredentials)
127+
128+
assert isinstance(channel, MagicMock)
129+
130+
def test_create_insecure_channel_asyncio(self, config):
131+
grpc_client_config = GRPCClientConfig(secure=False)
132+
factory = GrpcChannelFactory(
133+
config=config, grpc_client_config=grpc_client_config, use_asyncio=True
134+
)
135+
endpoint = "test.endpoint:50051"
136+
137+
with patch("grpc.aio.insecure_channel") as mock_aio_insecure_channel:
138+
channel = factory.create_channel(endpoint)
139+
140+
mock_aio_insecure_channel.assert_called_once_with(endpoint, options=ANY)
141+
assert isinstance(channel, MagicMock)

0 commit comments

Comments
 (0)