Skip to content

Commit 1b518ef

Browse files
authored
Implement embed for inference (#44)
## Problem

The `embed` endpoint was not implemented.

## Solution

This PR implements the `embed` endpoint for inference.

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

## Test Plan

All new and existing test cases pass.
1 parent 05ed9ff commit 1b518ef

File tree

4 files changed

+215
-19
lines changed

4 files changed

+215
-19
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
use openapi::apis::inference_api;
2+
use openapi::models::EmbedRequest;
3+
4+
use crate::pinecone::PineconeClient;
5+
use crate::utils::errors::PineconeError;
6+
7+
pub use openapi::models::{EmbedRequestParameters, EmbeddingsList};
8+
9+
impl PineconeClient {
10+
/// Generate embeddings for input data.
11+
///
12+
/// # Arguments
13+
/// * `model: &str` - The model to use for embedding.
14+
/// * `parameters: Option<EmbedRequestParameters>` - Model-specific parameters.
15+
/// * `inputs: &Vec<&str>` - The input data to embed.
16+
pub async fn embed(
17+
&self,
18+
model: &str,
19+
parameters: Option<EmbedRequestParameters>,
20+
inputs: &Vec<&str>,
21+
) -> Result<EmbeddingsList, PineconeError> {
22+
let request = EmbedRequest {
23+
model: model.to_string(),
24+
parameters: parameters.map(|x| Box::new(x)),
25+
inputs: inputs
26+
.iter()
27+
.map(|&x| openapi::models::EmbedRequestInputsInner {
28+
text: Some(x.to_string()),
29+
})
30+
.collect(),
31+
};
32+
33+
let res = inference_api::embed(&self.openapi_config, Some(request))
34+
.await
35+
.map_err(|e| PineconeError::from(e))?;
36+
37+
Ok(res)
38+
}
39+
}
40+
41+
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::prelude::*;

    /// A 200 response from the mocked `/embed` endpoint is returned to the
    /// caller with its model, data and usage fields populated.
    #[tokio::test]
    async fn test_embed() -> Result<(), PineconeError> {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(POST).path("/embed");
            then.status(200)
                .header("content-type", "application/json")
                .body(
                    r#"
                    {
                        "model": "multilingual-e5-large",
                        "data": [
                            {"values": [0.01849365234375, -0.003767013549804688, -0.037261962890625, 0.0222930908203125]}
                        ],
                        "usage": {"total_tokens": 1632}
                    }
                    "#,
                );
        });

        // Point the client at the mock server so the request never leaves the test.
        let client = PineconeClient::new(None, Some(server.base_url().as_str()), None, None)?;
        let response = client
            .embed("multilingual-e5-large", None, &vec!["Hello, world!"])
            .await
            .expect("Failed to embed");

        mock.assert();

        assert_eq!(response.model.unwrap(), "multilingual-e5-large");
        assert_eq!(response.data.unwrap().len(), 1);
        assert_eq!(response.usage.unwrap().total_tokens, Some(1632));

        Ok(())
    }

    /// A 400 response from the mocked `/embed` endpoint surfaces as an error
    /// from `embed`.
    // The test attribute was previously missing, so this function was never
    // executed by the test harness.
    #[tokio::test]
    async fn test_embed_invalid_arguments() -> Result<(), PineconeError> {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(POST).path("/embed");
            then.status(400)
                .header("content-type", "application/json")
                .body(
                    r#"
                    {
                        "error": {
                            "code": "INVALID_ARGUMENT",
                            "message": "Invalid parameter value input_type='bad-parameter' for model 'multilingual-e5-large', must be one of [query, passage]"
                        },
                        "status": 400
                    }
                    "#,
                );
        });

        // The client must target the mock server; otherwise `mock.assert()`
        // below can never observe the request.
        let client = PineconeClient::new(None, Some(server.base_url().as_str()), None, None)?;

        let parameters = EmbedRequestParameters {
            input_type: Some("bad-parameter".to_string()),
            truncate: Some("bad-parameter".to_string()),
        };

        let _ = client
            .embed(
                "multilingual-e5-large",
                Some(parameters),
                &vec!["Hello, world!"],
            )
            .await
            .expect_err("Expected to fail embedding with invalid arguments");

        mock.assert();

        Ok(())
    }
}

pinecone_sdk/src/pinecone/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ pub mod control;
1515
/// Data plane module.
1616
pub mod data;
1717

18+
/// Inference module.
19+
pub mod inference;
20+
1821
/// The `PineconeClient` struct is the main entry point for interacting with Pinecone via this Rust SDK.
1922
#[derive(Debug, Clone)]
2023
pub struct PineconeClient {

pinecone_sdk/src/utils/errors.rs

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ pub enum PineconeError {
131131
/// Error status
132132
status: tonic::Status,
133133
},
134+
135+
/// InferenceError: Failed to perform an inference operation.
136+
InferenceError {
137+
/// Error status
138+
status: tonic::Status,
139+
},
134140
}
135141

136142
// Implement the conversion from OpenApiError to PineconeError for CreateIndexError.
@@ -198,67 +204,70 @@ impl std::fmt::Display for PineconeError {
198204
"Unknown response error: status: {}, message: {}",
199205
status, message
200206
)
201-
},
207+
}
202208
PineconeError::ResourceAlreadyExistsError { source } => {
203209
write!(f, "Resource already exists error: {}", source)
204-
},
210+
}
205211
PineconeError::UnprocessableEntityError { source } => {
206212
write!(f, "Unprocessable entity error: {}", source)
207-
},
213+
}
208214
PineconeError::PendingCollectionError { source } => {
209215
write!(f, "Pending collection error: {}", source)
210-
},
216+
}
211217
PineconeError::InternalServerError { source } => {
212218
write!(f, "Internal server error: {}", source)
213-
},
219+
}
214220
PineconeError::ReqwestError { source } => {
215221
write!(f, "Reqwest error: {}", source.to_string())
216-
},
222+
}
217223
PineconeError::SerdeError { source } => {
218224
write!(f, "Serde error: {}", source.to_string())
219-
},
225+
}
220226
PineconeError::IoError { message } => {
221227
write!(f, "IO error: {}", message)
222-
},
228+
}
223229
PineconeError::BadRequestError { source } => {
224230
write!(f, "Bad request error: {}", source)
225-
},
231+
}
226232
PineconeError::UnauthorizedError { source } => {
227233
write!(f, "Unauthorized error: status: {}", source)
228-
},
234+
}
229235
PineconeError::PodQuotaExceededError { source } => {
230236
write!(f, "Pod quota exceeded error: {}", source)
231-
},
237+
}
232238
PineconeError::CollectionsQuotaExceededError { source } => {
233239
write!(f, "Collections quota exceeded error: {}", source)
234-
},
240+
}
235241
PineconeError::InvalidCloudError { source } => {
236242
write!(f, "Invalid cloud error: status: {}", source)
237-
},
243+
}
238244
PineconeError::InvalidRegionError { source } => {
239245
write!(f, "Invalid region error: {}", source)
240-
},
246+
}
241247
PineconeError::CollectionNotFoundError { source } => {
242248
write!(f, "Collection not found error: {}", source)
243-
},
249+
}
244250
PineconeError::IndexNotFoundError { source } => {
245251
write!(f, "Index not found error: status: {}", source)
246-
},
252+
}
247253
PineconeError::APIKeyMissingError { message } => {
248254
write!(f, "API key missing error: {}", message)
249-
},
255+
}
250256
PineconeError::InvalidHeadersError { message } => {
251257
write!(f, "Invalid headers error: {}", message)
252-
},
258+
}
253259
PineconeError::TimeoutError { message } => {
254260
write!(f, "Timeout error: {}", message)
255-
},
261+
}
256262
PineconeError::ConnectionError { source } => {
257263
write!(f, "Connection error: {}", source)
258264
}
259265
PineconeError::DataPlaneError { status } => {
260266
write!(f, "Data plane error: {}", status)
261267
}
268+
PineconeError::InferenceError { status } => {
269+
write!(f, "Inference error: {}", status)
270+
}
262271
}
263272
}
264273
}
@@ -290,6 +299,7 @@ impl std::error::Error for PineconeError {
290299
PineconeError::TimeoutError { message: _ } => None,
291300
PineconeError::ConnectionError { source } => Some(source.as_ref()),
292301
PineconeError::DataPlaneError { status } => Some(status),
302+
PineconeError::InferenceError { status } => Some(status),
293303
}
294304
}
295305
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use pinecone_sdk::pinecone::inference::EmbedRequestParameters;
2+
use pinecone_sdk::pinecone::PineconeClient;
3+
use pinecone_sdk::utils::errors::PineconeError;
4+
5+
#[tokio::test]
6+
async fn test_embed() -> Result<(), PineconeError> {
7+
let pinecone = PineconeClient::new(None, None, None, None).unwrap();
8+
9+
let parameters = EmbedRequestParameters {
10+
input_type: Some("query".to_string()),
11+
truncate: Some("END".to_string()),
12+
};
13+
14+
let response = pinecone
15+
.embed(
16+
"multilingual-e5-large",
17+
Some(parameters),
18+
&vec!["Hello, world!"],
19+
)
20+
.await
21+
.expect("Failed to embed");
22+
23+
assert_eq!(response.model.unwrap(), "multilingual-e5-large");
24+
assert_eq!(response.data.unwrap().len(), 1);
25+
26+
Ok(())
27+
}
28+
29+
#[tokio::test]
30+
async fn test_embed_invalid_model() -> Result<(), PineconeError> {
31+
let pinecone = PineconeClient::new(None, None, None, None).unwrap();
32+
33+
let _ = pinecone
34+
.embed("invalid-model", None, &vec!["Hello, world!"])
35+
.await
36+
.expect_err("Expected to fail embedding with invalid model");
37+
38+
Ok(())
39+
}
40+
41+
#[tokio::test]
42+
async fn test_embed_invalid_parameters() -> Result<(), PineconeError> {
43+
let pinecone = PineconeClient::new(None, None, None, None).unwrap();
44+
45+
let parameters = EmbedRequestParameters {
46+
input_type: Some("bad-parameter".to_string()),
47+
truncate: Some("bad-parameter".to_string()),
48+
};
49+
50+
let _ = pinecone
51+
.embed(
52+
"multilingual-e5-large",
53+
Some(parameters),
54+
&vec!["Hello, world!"],
55+
)
56+
.await
57+
.expect_err("Expected to fail embedding with invalid model parameters");
58+
59+
Ok(())
60+
}

0 commit comments

Comments
 (0)