1
- from typing import Optional , Union , List , Dict , Awaitable
1
+ from typing import Optional , Union , List , Dict , Awaitable , Any
2
2
3
3
from tqdm .asyncio import tqdm
4
4
import asyncio
@@ -91,14 +91,14 @@ async def upsert(
91
91
max_concurrent_requests : Optional [int ] = None ,
92
92
semaphore : Optional [asyncio .Semaphore ] = None ,
93
93
** kwargs ,
94
- ) -> Awaitable [ UpsertResponse ] :
94
+ ) -> UpsertResponse :
95
95
timeout = kwargs .pop ("timeout" , None )
96
96
vectors = list (map (VectorFactoryGRPC .build , vectors ))
97
97
semaphore = self ._get_semaphore (max_concurrent_requests , semaphore )
98
98
99
99
if batch_size is None :
100
100
return await self ._upsert_batch (
101
- vectors , namespace , timeout = timeout , semaphore = semaphore , ** kwargs
101
+ vectors = vectors , namespace = namespace , timeout = timeout , semaphore = semaphore , ** kwargs
102
102
)
103
103
104
104
if not isinstance (batch_size , int ) or batch_size <= 0 :
@@ -132,7 +132,7 @@ async def _upsert_batch(
132
132
namespace : Optional [str ],
133
133
timeout : Optional [int ] = None ,
134
134
** kwargs ,
135
- ) -> Awaitable [ UpsertResponse ] :
135
+ ) -> UpsertResponse :
136
136
args_dict = parse_non_empty_args ([("namespace" , namespace )])
137
137
request = UpsertRequest (vectors = vectors , ** args_dict )
138
138
return await self .runner .run_asyncio (
@@ -151,7 +151,7 @@ async def _query(
151
151
sparse_vector : Optional [Union [GRPCSparseValues , SparseVectorTypedDict ]] = None ,
152
152
semaphore : Optional [asyncio .Semaphore ] = None ,
153
153
** kwargs ,
154
- ) -> Awaitable [ Dict ]:
154
+ ) -> dict [ str , Any ]:
155
155
if vector is not None and id is not None :
156
156
raise ValueError ("Cannot specify both `id` and `vector`" )
157
157
@@ -182,7 +182,8 @@ async def _query(
182
182
response = await self .runner .run_asyncio (
183
183
self .stub .Query , request , timeout = timeout , semaphore = semaphore
184
184
)
185
- return json_format .MessageToDict (response )
185
+ parsed = json_format .MessageToDict (response )
186
+ return parsed
186
187
187
188
async def query (
188
189
self ,
@@ -196,7 +197,7 @@ async def query(
196
197
sparse_vector : Optional [Union [GRPCSparseValues , SparseVectorTypedDict ]] = None ,
197
198
semaphore : Optional [asyncio .Semaphore ] = None ,
198
199
** kwargs ,
199
- ) -> Awaitable [ QueryResponse ] :
200
+ ) -> QueryResponse :
200
201
"""
201
202
The Query operation searches a namespace, using a query vector.
202
203
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -257,9 +258,9 @@ async def query(
257
258
258
259
async def composite_query (
259
260
self ,
260
- vector : Optional [ List [float ]] = None ,
261
- namespaces : Optional [ List [str ]] = None ,
262
- top_k : Optional [int ] = 10 ,
261
+ vector : List [float ],
262
+ namespaces : List [str ],
263
+ top_k : Optional [int ] = None ,
263
264
filter : Optional [Dict [str , Union [str , float , int , bool , List , dict ]]] = None ,
264
265
include_values : Optional [bool ] = None ,
265
266
include_metadata : Optional [bool ] = None ,
@@ -268,17 +269,23 @@ async def composite_query(
268
269
max_concurrent_requests : Optional [int ] = None ,
269
270
semaphore : Optional [asyncio .Semaphore ] = None ,
270
271
** kwargs ,
271
- ) -> Awaitable [ CompositeQueryResults ] :
272
+ ) -> CompositeQueryResults :
272
273
aggregator_lock = asyncio .Lock ()
273
274
semaphore = self ._get_semaphore (max_concurrent_requests , semaphore )
274
275
275
- # The caller may only want the topK=1 result across all queries,
276
+ if len (namespaces ) == 0 :
277
+ raise ValueError ("At least one namespace must be specified" )
278
+ if len (vector ) == 0 :
279
+ raise ValueError ("Query vector must not be empty" )
280
+
281
+ # The caller may only want the top_k=1 result across all queries,
276
282
# but we need to get at least 2 results from each query in order to
277
283
# aggregate them correctly. So we'll temporarily set top_k to 2 for the
278
284
# subqueries, and then we'll take the top_k=1 results from the aggregated
279
285
# results.
280
- aggregator = QueryResultsAggregator (top_k = top_k )
281
- subquery_topk = top_k if top_k > 2 else 2
286
+ overall_topk = top_k if top_k is not None else 10
287
+ aggregator = QueryResultsAggregator (top_k = overall_topk )
288
+ subquery_topk = overall_topk if overall_topk > 2 else 2
282
289
283
290
target_namespaces = set (namespaces ) # dedup namespaces
284
291
query_tasks = [
0 commit comments