Skip to content

Commit 2117e4b

Browse files
committed
Fix mypy issues
1 parent 1380065 commit 2117e4b

File tree

4 files changed

+53
-27
lines changed

4 files changed

+53
-27
lines changed

pinecone/grpc/index_grpc_asyncio.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union, List, Dict, Awaitable
1+
from typing import Optional, Union, List, Dict, Awaitable, Any
22

33
from tqdm.asyncio import tqdm
44
import asyncio
@@ -91,14 +91,14 @@ async def upsert(
9191
max_concurrent_requests: Optional[int] = None,
9292
semaphore: Optional[asyncio.Semaphore] = None,
9393
**kwargs,
94-
) -> Awaitable[UpsertResponse]:
94+
) -> UpsertResponse:
9595
timeout = kwargs.pop("timeout", None)
9696
vectors = list(map(VectorFactoryGRPC.build, vectors))
9797
semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
9898

9999
if batch_size is None:
100100
return await self._upsert_batch(
101-
vectors, namespace, timeout=timeout, semaphore=semaphore, **kwargs
101+
vectors=vectors, namespace=namespace, timeout=timeout, semaphore=semaphore, **kwargs
102102
)
103103

104104
if not isinstance(batch_size, int) or batch_size <= 0:
@@ -132,7 +132,7 @@ async def _upsert_batch(
132132
namespace: Optional[str],
133133
timeout: Optional[int] = None,
134134
**kwargs,
135-
) -> Awaitable[UpsertResponse]:
135+
) -> UpsertResponse:
136136
args_dict = parse_non_empty_args([("namespace", namespace)])
137137
request = UpsertRequest(vectors=vectors, **args_dict)
138138
return await self.runner.run_asyncio(
@@ -151,7 +151,7 @@ async def _query(
151151
sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
152152
semaphore: Optional[asyncio.Semaphore] = None,
153153
**kwargs,
154-
) -> Awaitable[Dict]:
154+
) -> dict[str, Any]:
155155
if vector is not None and id is not None:
156156
raise ValueError("Cannot specify both `id` and `vector`")
157157

@@ -182,7 +182,8 @@ async def _query(
182182
response = await self.runner.run_asyncio(
183183
self.stub.Query, request, timeout=timeout, semaphore=semaphore
184184
)
185-
return json_format.MessageToDict(response)
185+
parsed = json_format.MessageToDict(response)
186+
return parsed
186187

187188
async def query(
188189
self,
@@ -196,7 +197,7 @@ async def query(
196197
sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
197198
semaphore: Optional[asyncio.Semaphore] = None,
198199
**kwargs,
199-
) -> Awaitable[QueryResponse]:
200+
) -> QueryResponse:
200201
"""
201202
The Query operation searches a namespace, using a query vector.
202203
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -257,9 +258,9 @@ async def query(
257258

258259
async def composite_query(
259260
self,
260-
vector: Optional[List[float]] = None,
261-
namespaces: Optional[List[str]] = None,
262-
top_k: Optional[int] = 10,
261+
vector: List[float],
262+
namespaces: List[str],
263+
top_k: Optional[int] = None,
263264
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
264265
include_values: Optional[bool] = None,
265266
include_metadata: Optional[bool] = None,
@@ -268,17 +269,23 @@ async def composite_query(
268269
max_concurrent_requests: Optional[int] = None,
269270
semaphore: Optional[asyncio.Semaphore] = None,
270271
**kwargs,
271-
) -> Awaitable[CompositeQueryResults]:
272+
) -> CompositeQueryResults:
272273
aggregator_lock = asyncio.Lock()
273274
semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
274275

275-
# The caller may only want the topK=1 result across all queries,
276+
if len(namespaces) == 0:
277+
raise ValueError("At least one namespace must be specified")
278+
if len(vector) == 0:
279+
raise ValueError("Query vector must not be empty")
280+
281+
# The caller may only want the top_k=1 result across all queries,
276282
# but we need to get at least 2 results from each query in order to
277283
# aggregate them correctly. So we'll temporarily set topK to 2 for the
278284
# subqueries, and then we'll take the topK=1 results from the aggregated
279285
# results.
280-
aggregator = QueryResultsAggregator(top_k=top_k)
281-
subquery_topk = top_k if top_k > 2 else 2
286+
overall_topk = top_k if top_k is not None else 10
287+
aggregator = QueryResultsAggregator(top_k=overall_topk)
288+
subquery_topk = overall_topk if overall_topk > 2 else 2
282289

283290
target_namespaces = set(namespaces) # dedup namespaces
284291
query_tasks = [

pinecone/grpc/query_results.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import TypedDict, List, Dict, Any
2+
3+
4+
class ScoredVectorTypedDict(TypedDict):
    """One match in a query result.

    Presumably models the dict produced by ``json_format.MessageToDict`` on a
    scored-vector protobuf message — TODO confirm against the query path.
    """

    # Vector id of the match.
    id: str
    # Similarity score; interpretation depends on the index's distance metric.
    score: float
    # Dense vector values (may be omitted by the server unless requested).
    values: List[float]
    # Arbitrary JSON-like metadata object. Typed as Dict[str, Any] for
    # consistency with the typing generics used by the sibling TypedDict
    # (plain `dict` was the odd one out); runtime behavior is unchanged.
    metadata: Dict[str, Any]
9+
10+
11+
class QueryResultsTypedDict(TypedDict):
    """Typed shape of a single namespace's query response after parsing to a dict."""

    # Matches for this query; may be an empty list when nothing matched.
    matches: List[ScoredVectorTypedDict]
    # Namespace the query was executed against.
    namespace: str
    # Usage accounting, e.g. a "readUnits" entry — NOTE(review): key names
    # inferred from aggregator code reading results["usage"]["readUnits"];
    # confirm the full schema against the server response.
    usage: Dict[str, Any]

pinecone/grpc/query_results_aggregator.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import List, Tuple
1+
from typing import List, Tuple, Optional, Any, Dict
22
import json
33
import heapq
4-
from pinecone.core.openapi.data.models import QueryResponse, Usage
4+
from pinecone.core.openapi.data.models import Usage
55

66
from dataclasses import dataclass, asdict
77

@@ -15,14 +15,14 @@ class ScoredVectorWithNamespace:
1515
sparse_values: dict
1616
metadata: dict
1717

18-
def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, dict, str]):
18+
def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
1919
json_vector = aggregate_results_heap_tuple[2]
2020
self.namespace = aggregate_results_heap_tuple[3]
21-
self.id = json_vector.get("id")
22-
self.score = json_vector.get("score")
23-
self.values = json_vector.get("values")
24-
self.sparse_values = json_vector.get("sparse_values", None)
25-
self.metadata = json_vector.get("metadata", None)
21+
self.id = json_vector.get("id") # type: ignore
22+
self.score = json_vector.get("score") # type: ignore
23+
self.values = json_vector.get("values") # type: ignore
24+
self.sparse_values = json_vector.get("sparse_values", None) # type: ignore
25+
self.metadata = json_vector.get("metadata", None) # type: ignore
2626

2727
def __getitem__(self, key):
2828
if hasattr(self, key):
@@ -106,10 +106,11 @@ def __init__(self, top_k: int):
106106
raise QueryResultsAggregatorInvalidTopKError(top_k)
107107
self.top_k = top_k
108108
self.usage_read_units = 0
109-
self.heap = []
109+
self.heap: List[Tuple[float, int, object, str]] = []
110110
self.insertion_counter = 0
111111
self.is_dotproduct = None
112112
self.read = False
113+
self.final_results: Optional[CompositeQueryResults] = None
113114

114115
def _is_dotproduct_index(self, matches):
115116
# The interpretation of the score depends on the similar metric used.
@@ -135,15 +136,15 @@ def _process_matches(self, matches, ns, heap_item_fn):
135136
else:
136137
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
137138

138-
def add_results(self, results: QueryResponse):
139+
def add_results(self, results: Dict[str, Any]):
139140
if self.read:
140141
# This is mainly just to sanity check in test cases which get quite confusing
141142
# if you read results twice due to the heap being emptied when constructing
142143
# the ordered results.
143144
raise ValueError("Results have already been read. Cannot add more results.")
144145

145146
matches = results.get("matches", [])
146-
ns = results.get("namespace")
147+
ns: str = results.get("namespace", "")
147148
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
148149

149150
if len(matches) == 0:
@@ -161,7 +162,11 @@ def add_results(self, results: QueryResponse):
161162

162163
def get_results(self) -> CompositeQueryResults:
163164
if self.read:
164-
return self.final_results
165+
if self.final_results is not None:
166+
return self.final_results
167+
else:
168+
# I don't think this branch can ever actually be reached, but the type checker disagrees
169+
raise ValueError("Results have already been read. Cannot get results again.")
165170
self.read = True
166171

167172
self.final_results = CompositeQueryResults(

pinecone/grpc/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
QueryResponse,
1111
DescribeIndexStatsResponse,
1212
NamespaceSummary,
13-
SparseValues as GRPCSparseValues,
1413
)
14+
from pinecone.core.grpc.protos.vector_service_pb2 import SparseValues as GRPCSparseValues
1515
from .sparse_vector import SparseVectorTypedDict
1616

1717
from google.protobuf.struct_pb2 import Struct

0 commit comments

Comments
 (0)