Skip to content

Commit b7989c2

Browse files
authored
[Fix] Return type from embed() should be iterator (#454)
## Problem

When migrating `embed` and `rerank` over from the plugin, I forgot to include the custom return objects.

## Solution

Add custom return types, and adjust the tests to ensure the result of `embed` is iterable.

## Type of Change

- [x] Bug fix (non-breaking change which fixes an issue)
1 parent 4dced5e commit b7989c2

File tree

7 files changed

+102
-32
lines changed

7 files changed

+102
-32
lines changed

pinecone/data/features/inference/inference.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pinecone.openapi_support import ApiClient
55
from pinecone.core.openapi.inference.apis import InferenceApi
6-
from pinecone.core.openapi.inference.models import EmbeddingsList, RerankResult
6+
from .models import EmbeddingsList, RerankResult
77
from pinecone.core.openapi.inference import API_VERSION
88
from pinecone.utils import setup_openapi_client, PluginAware
99

@@ -84,7 +84,8 @@ def embed(
8484
request_body = InferenceRequestBuilder.embed_request(
8585
model=model, inputs=inputs, parameters=parameters
8686
)
87-
return self.__inference_api.embed(embed_request=request_body)
87+
resp = self.__inference_api.embed(embed_request=request_body)
88+
return EmbeddingsList(resp)
8889

8990
def rerank(
9091
self,
@@ -162,4 +163,5 @@ def rerank(
162163
top_n=top_n,
163164
parameters=parameters,
164165
)
165-
return self.__inference_api.rerank(rerank_request=rerank_request)
166+
resp = self.__inference_api.rerank(rerank_request=rerank_request)
167+
return RerankResult(resp)

pinecone/data/features/inference/inference_asyncio.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional, Dict, List, Union, Any
22

33
from pinecone.core.openapi.inference.api.inference_api import AsyncioInferenceApi
4-
from pinecone.core.openapi.inference.models import EmbeddingsList, RerankResult
4+
from .models import EmbeddingsList, RerankResult
55

66
from .inference_request_builder import (
77
InferenceRequestBuilder,
@@ -64,7 +64,8 @@ async def embed(
6464
request_body = InferenceRequestBuilder.embed_request(
6565
model=model, inputs=inputs, parameters=parameters
6666
)
67-
return await self.__inference_api.embed(embed_request=request_body)
67+
resp = await self.__inference_api.embed(embed_request=request_body)
68+
return EmbeddingsList(resp)
6869

6970
async def rerank(
7071
self,
@@ -142,4 +143,5 @@ async def rerank(
142143
top_n=top_n,
143144
parameters=parameters,
144145
)
145-
return await self.__inference_api.rerank(rerank_request=rerank_request)
146+
resp = await self.__inference_api.rerank(rerank_request=rerank_request)
147+
return RerankResult(resp)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .embedding_list import EmbeddingsList
2+
from .rerank_result import RerankResult
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from pinecone.core.openapi.inference.models import EmbeddingsList as OpenAPIEmbeddingsList
2+
3+
4+
class EmbeddingsList:
    """
    Wrapper around an OpenAPI ``EmbeddingsList`` response.

    Indexing, ``len()`` and iteration operate on the wrapped response's
    ``"data"`` entry. Any attribute that is not defined on the wrapper itself
    is looked up on the wrapped response instead.
    """

    def __init__(self, embeddings_list: "OpenAPIEmbeddingsList"):
        """
        Store the response to wrap.

        :param embeddings_list: The response object; it must support
            ``.get("data")``.
        """
        self.embeddings_list = embeddings_list
        # Not read anywhere in this class; kept so the instance's attribute
        # surface stays unchanged for existing callers.
        self.current = 0

    def __getitem__(self, index):
        """Return the item(s) of the ``"data"`` entry selected by ``index``."""
        return self.embeddings_list.get("data")[index]

    def __len__(self):
        """Return the number of items in the ``"data"`` entry."""
        return len(self.embeddings_list.get("data"))

    def __iter__(self):
        """Return a fresh iterator over the ``"data"`` entry."""
        return iter(self.embeddings_list.get("data"))

    def __str__(self):
        """Return ``str()`` of the wrapped response."""
        return str(self.embeddings_list)

    def __repr__(self):
        """Return ``repr()`` of the wrapped response."""
        return repr(self.embeddings_list)

    def __getattr__(self, attr):
        """
        Forward lookups of attributes missing on the wrapper to the response.

        :raises AttributeError: If the wrapped response is not set yet, or if
            it does not provide ``attr`` either.
        """
        # __getattr__ only runs after normal lookup has failed, so reaching
        # here with "embeddings_list" means __init__ has not run (e.g. the
        # instance is being rebuilt by copy/pickle). Reading
        # self.embeddings_list would re-enter this method forever, so fail
        # with the conventional AttributeError instead.
        if attr == "embeddings_list":
            raise AttributeError(attr)
        return getattr(self.embeddings_list, attr)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from pinecone.core.openapi.inference.models import RerankResult as OpenAPIRerankResult
2+
3+
4+
class RerankResult:
    """
    A wrapper around OpenAPIRerankResult.

    ``str()`` and ``repr()`` are taken from the wrapped object, and any
    attribute that is not defined on the wrapper itself is looked up on the
    wrapped object instead.
    """

    def __init__(self, rerank_result: "OpenAPIRerankResult"):
        """
        Store the result to wrap.

        :param rerank_result: The object that attribute lookups are
            forwarded to.
        """
        self.rerank_result = rerank_result

    def __str__(self):
        """Return ``str()`` of the wrapped result."""
        return str(self.rerank_result)

    def __repr__(self):
        """Return ``repr()`` of the wrapped result."""
        return repr(self.rerank_result)

    def __getattr__(self, attr):
        """
        Forward lookups of attributes missing on the wrapper to the result.

        :raises AttributeError: If the wrapped result is not set yet, or if
            it does not provide ``attr`` either.
        """
        # __getattr__ only runs after normal lookup has failed, so reaching
        # here with "rerank_result" means __init__ has not run (e.g. the
        # instance is being rebuilt by copy/pickle). Reading
        # self.rerank_result would re-enter this method forever, so fail
        # with the conventional AttributeError instead.
        if attr == "rerank_result":
            raise AttributeError(attr)
        return getattr(self.rerank_result, attr)

tests/integration/inference/test_asyncio_inference.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,32 @@ async def test_create_embeddings(self, model_input, model_output):
1919
parameters={"input_type": "query", "truncate": "END"},
2020
)
2121
assert embeddings.vector_type == "dense"
22+
assert embeddings.get("vector_type") == "dense"
2223
assert embeddings.model == model_output
24+
assert embeddings.get("model") == model_output
2325
assert len(embeddings.data) == 2
24-
assert len(embeddings.data[0].values) == 1024
25-
assert len(embeddings.data[1].values) == 1024
26+
assert len(embeddings.get("data")) == 2
27+
assert embeddings.usage is not None
2628

27-
# Dict-style bracket accessors
28-
assert embeddings["vector_type"] == "dense"
29-
assert embeddings["model"] == model_output
30-
assert len(embeddings["data"]) == 2
29+
individual_embedding = embeddings[0]
30+
assert len(individual_embedding.values) == 1024
31+
assert individual_embedding.vector_type.value == "dense"
32+
assert len(individual_embedding["values"]) == 1024
3133

32-
# Dict-style get accessors for embeddings
33-
assert embeddings.get("vector_type") == "dense"
34-
assert embeddings.get("model") == model_output
35-
assert len(embeddings.get("data")) == 2
36-
assert len(embeddings.get("data")[0]["values"]) == 1024
37-
assert len(embeddings.get("data")[1]["values"]) == 1024
38-
assert embeddings.get("model") == model_output
34+
await pc.close()
35+
36+
async def test_embedding_result_is_iterable(self):
    """Check that the awaited result of ``embed`` works in a ``for`` loop.

    Two inputs are embedded with ``EmbedModel.Multilingual_E5_Large``; the
    loop must yield exactly two items, each with 1024 ``values``.
    """
    pc = PineconeAsyncio()
    try:
        embeddings = await pc.inference.embed(
            model=EmbedModel.Multilingual_E5_Large,
            inputs=["The quick brown fox jumps over the lazy dog.", "lorem ipsum"],
            parameters={"input_type": "query", "truncate": "END"},
        )
        iter_count = 0
        for embedding in embeddings:
            iter_count += 1
            assert len(embedding.values) == 1024
        assert iter_count == 2
    finally:
        # Close the client even when the request or an assertion fails, so a
        # failing test does not leave it open.
        await pc.close()
4049

4150
@pytest.mark.parametrize(

tests/integration/inference/test_inference.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,30 @@ def test_create_embeddings(self, model_input, model_output):
1818
parameters={"input_type": "query", "truncate": "END"},
1919
)
2020
assert embeddings.vector_type == "dense"
21+
assert embeddings.get("vector_type") == "dense"
2122
assert embeddings.model == model_output
23+
assert embeddings.get("model") == model_output
2224
assert len(embeddings.data) == 2
23-
assert len(embeddings.data[0].values) == 1024
24-
assert len(embeddings.data[1].values) == 1024
25+
assert len(embeddings.get("data")) == 2
26+
assert embeddings.usage is not None
2527

26-
# Dict-style bracket accessors
27-
assert embeddings["vector_type"] == "dense"
28-
assert embeddings["model"] == model_output
29-
assert len(embeddings["data"]) == 2
28+
individual_embedding = embeddings[0]
29+
assert len(individual_embedding.values) == 1024
30+
assert individual_embedding.vector_type.value == "dense"
31+
assert len(individual_embedding["values"]) == 1024
3032

31-
# Dict-style get accessors for embeddings
32-
assert embeddings.get("vector_type") == "dense"
33-
assert embeddings.get("model") == model_output
34-
assert len(embeddings.get("data")) == 2
35-
assert len(embeddings.get("data")[0]["values"]) == 1024
36-
assert len(embeddings.get("data")[1]["values"]) == 1024
37-
assert embeddings.get("model") == model_output
33+
def test_embedding_result_is_iterable(self):
    """Check that the result of ``embed`` works in a ``for`` loop.

    Two inputs are embedded with ``EmbedModel.Multilingual_E5_Large``; the
    loop must yield exactly two items, each with 1024 ``values``.
    """
    client = Pinecone()
    texts = ["The quick brown fox jumps over the lazy dog.", "lorem ipsum"]
    options = {"input_type": "query", "truncate": "END"}
    result = client.inference.embed(
        model=EmbedModel.Multilingual_E5_Large, inputs=texts, parameters=options
    )

    # `seen` stays 0 when the loop body never runs, matching a manual counter.
    seen = 0
    for seen, item in enumerate(result, start=1):
        assert len(item.values) == 1024
    assert seen == 2
3845

3946
@pytest.mark.parametrize(
4047
"model_input,model_output",

0 commit comments

Comments
 (0)