Skip to content

Commit c005b4f

Browse files
64bitifsheldon
authored andcommitted
feat: configurable per request path (64bit#479)
* refactor to avoid repetition * configurable per request path (cherry picked from commit 6013669)
1 parent fc17c9c commit c005b4f

File tree

3 files changed

+84
-150
lines changed

3 files changed

+84
-150
lines changed

async-openai/src/client.rs

Lines changed: 62 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,35 @@ impl<C: Config> Client<C> {
196196
&self.config
197197
}
198198

199+
/// Helper function to build a request builder with common configuration
200+
fn build_request_builder(
201+
&self,
202+
method: reqwest::Method,
203+
path: &str,
204+
request_options: &RequestOptions,
205+
) -> reqwest::RequestBuilder {
206+
let mut request_builder = if let Some(path) = request_options.path() {
207+
self.http_client
208+
.request(method, self.config.url(path.as_str()))
209+
} else {
210+
self.http_client.request(method, self.config.url(path))
211+
};
212+
213+
request_builder = request_builder
214+
.query(&self.config.query())
215+
.headers(self.config.headers());
216+
217+
if let Some(headers) = request_options.headers() {
218+
request_builder = request_builder.headers(headers.clone());
219+
}
220+
221+
if !request_options.query().is_empty() {
222+
request_builder = request_builder.query(request_options.query());
223+
}
224+
225+
request_builder
226+
}
227+
199228
/// Make a GET request to {path} and deserialize the response body
200229
pub(crate) async fn get<O>(
201230
&self,
@@ -205,22 +234,10 @@ impl<C: Config> Client<C> {
205234
where
206235
O: DeserializeOwned,
207236
{
208-
self.execute(async {
209-
let mut request_builder = self
210-
.http_client
211-
.get(self.config.url(path))
212-
.query(&self.config.query())
213-
.headers(self.config.headers());
214-
215-
if let Some(headers) = request_options.headers() {
216-
request_builder = request_builder.headers(headers.clone());
217-
}
218-
219-
if !request_options.query().is_empty() {
220-
request_builder = request_builder.query(request_options.query());
221-
}
222-
223-
Ok(request_builder.build()?)
237+
let request_maker = || async {
238+
Ok(self
239+
.build_request_builder(reqwest::Method::GET, path, request_options)
240+
.build()?)
224241
})
225242
.await
226243
}
@@ -235,21 +252,9 @@ impl<C: Config> Client<C> {
235252
O: DeserializeOwned,
236253
{
237254
self.execute(async {
238-
let mut request_builder = self
239-
.http_client
240-
.delete(self.config.url(path))
241-
.query(&self.config.query())
242-
.headers(self.config.headers());
243-
244-
if let Some(headers) = request_options.headers() {
245-
request_builder = request_builder.headers(headers.clone());
246-
}
247-
248-
if !request_options.query().is_empty() {
249-
request_builder = request_builder.query(request_options.query());
250-
}
251-
252-
Ok(request_builder.build()?)
255+
Ok(self
256+
.build_request_builder(reqwest::Method::DELETE, path, request_options)
257+
.build()?)
253258
})
254259
.await
255260
}
@@ -261,21 +266,9 @@ impl<C: Config> Client<C> {
261266
request_options: &RequestOptions,
262267
) -> Result<(Bytes, HeaderMap), OpenAIError> {
263268
self.execute_raw(async {
264-
let mut request_builder = self
265-
.http_client
266-
.get(self.config.url(path))
267-
.query(&self.config.query())
268-
.headers(self.config.headers());
269-
270-
if let Some(headers) = request_options.headers() {
271-
request_builder = request_builder.headers(headers.clone());
272-
}
273-
274-
if !request_options.query().is_empty() {
275-
request_builder = request_builder.query(request_options.query());
276-
}
277-
278-
Ok(request_builder.build()?)
269+
Ok(self
270+
.build_request_builder(reqwest::Method::GET, path, request_options)
271+
.build()?)
279272
})
280273
.await
281274
}
@@ -291,22 +284,10 @@ impl<C: Config> Client<C> {
291284
I: Serialize,
292285
{
293286
self.execute_raw(async {
294-
let mut request_builder = self
295-
.http_client
296-
.post(self.config.url(path))
297-
.query(&self.config.query())
298-
.headers(self.config.headers())
299-
.json(&request);
300-
301-
if let Some(headers) = request_options.headers() {
302-
request_builder = request_builder.headers(headers.clone());
303-
}
304-
305-
if !request_options.query().is_empty() {
306-
request_builder = request_builder.query(request_options.query());
307-
}
308-
309-
Ok(request_builder.build()?)
287+
Ok(self
288+
.build_request_builder(reqwest::Method::POST, path, request_options)
289+
.json(&request)
290+
.build()?)
310291
})
311292
.await
312293
}
@@ -323,22 +304,10 @@ impl<C: Config> Client<C> {
323304
O: DeserializeOwned,
324305
{
325306
self.execute(async {
326-
let mut request_builder = self
327-
.http_client
328-
.post(self.config.url(path))
329-
.query(&self.config.query())
330-
.headers(self.config.headers())
331-
.json(&request);
332-
333-
if let Some(headers) = request_options.headers() {
334-
request_builder = request_builder.headers(headers.clone());
335-
}
336-
337-
if !request_options.query().is_empty() {
338-
request_builder = request_builder.query(request_options.query());
339-
}
340-
341-
Ok(request_builder.build()?)
307+
Ok(self
308+
.build_request_builder(reqwest::Method::POST, path, request_options)
309+
.json(&request)
310+
.build()?)
342311
})
343312
.await
344313
}
@@ -355,22 +324,10 @@ impl<C: Config> Client<C> {
355324
{
356325
self.execute_raw(async {
357326
let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
358-
let mut request_builder = self
359-
.http_client
360-
.post(self.config.url(path))
361-
.query(&self.config.query())
362-
.headers(self.config.headers())
363-
.multipart(form);
364-
365-
if let Some(headers) = request_options.headers() {
366-
request_builder = request_builder.headers(headers.clone());
367-
}
368-
369-
if !request_options.query().is_empty() {
370-
request_builder = request_builder.query(request_options.query());
371-
}
372-
373-
Ok(request_builder.build()?)
327+
Ok(self
328+
.build_request_builder(reqwest::Method::POST, path, request_options)
329+
.multipart(form)
330+
.build()?)
374331
})
375332
.await
376333
}
@@ -388,22 +345,10 @@ impl<C: Config> Client<C> {
388345
{
389346
self.execute(async {
390347
let form = <Form as AsyncTryFrom<F>>::try_from(form).await?;
391-
let mut request_builder = self
392-
.http_client
393-
.post(self.config.url(path))
394-
.query(&self.config.query())
395-
.headers(self.config.headers())
396-
.multipart(form);
397-
398-
if let Some(headers) = request_options.headers() {
399-
request_builder = request_builder.headers(headers.clone());
400-
}
401-
402-
if !request_options.query().is_empty() {
403-
request_builder = request_builder.query(request_options.query());
404-
}
405-
406-
Ok(request_builder.build()?)
348+
Ok(self
349+
.build_request_builder(reqwest::Method::POST, path, request_options)
350+
.multipart(form)
351+
.build()?)
407352
})
408353
.await
409354
}
@@ -421,20 +366,9 @@ impl<C: Config> Client<C> {
421366
{
422367
// Build and execute request manually since multipart::Form is not Clone
423368
// and .eventsource() requires cloneability
424-
let mut request_builder = self
425-
.http_client
426-
.post(self.config.url(path))
427-
.query(&self.config.query())
428-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
429-
.headers(self.config.headers());
430-
431-
if let Some(headers) = request_options.headers() {
432-
request_builder = request_builder.headers(headers.clone());
433-
}
434-
435-
if !request_options.query().is_empty() {
436-
request_builder = request_builder.query(request_options.query());
437-
}
369+
let request_builder = self
370+
.build_request_builder(reqwest::Method::POST, path, request_options)
371+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
438372

439373
let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
440374

@@ -509,21 +443,10 @@ impl<C: Config> Client<C> {
509443
I: Serialize,
510444
O: DeserializeOwned + Send + 'static,
511445
{
512-
let mut request_builder = self
513-
.http_client
514-
.post(self.config.url(path))
515-
.query(&self.config.query())
516-
.headers(self.config.headers())
446+
let request_builder = self
447+
.build_request_builder(reqwest::Method::POST, path, request_options)
517448
.json(&request);
518449

519-
if let Some(headers) = request_options.headers() {
520-
request_builder = request_builder.headers(headers.clone());
521-
}
522-
523-
if !request_options.query().is_empty() {
524-
request_builder = request_builder.query(request_options.query());
525-
}
526-
527450
let event_source = request_builder.eventsource().unwrap();
528451

529452
OpenAIEventStream::new(event_source)
@@ -540,21 +463,10 @@ impl<C: Config> Client<C> {
540463
I: Serialize,
541464
O: DeserializeOwned + Send + 'static,
542465
{
543-
let mut request_builder = self
544-
.http_client
545-
.post(self.config.url(path))
546-
.query(&self.config.query())
547-
.headers(self.config.headers())
466+
let request_builder = self
467+
.build_request_builder(reqwest::Method::POST, path, request_options)
548468
.json(&request);
549469

550-
if let Some(headers) = request_options.headers() {
551-
request_builder = request_builder.headers(headers.clone());
552-
}
553-
554-
if !request_options.query().is_empty() {
555-
request_builder = request_builder.query(request_options.query());
556-
}
557-
558470
let event_source = request_builder.eventsource().unwrap();
559471

560472
OpenAIEventStream::with_event_mapping(event_source, event_mapper)

async-openai/src/request_options.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,28 @@ use crate::{config::OPENAI_API_BASE, error::OpenAIError};
88
pub struct RequestOptions {
99
query: Option<Vec<(String, String)>>,
1010
headers: Option<HeaderMap>,
11+
path: Option<String>,
1112
}
1213

1314
impl RequestOptions {
1415
pub(crate) fn new() -> Self {
1516
Self {
1617
query: None,
1718
headers: None,
19+
path: None,
1820
}
1921
}
2022

23+
pub(crate) fn with_path(&mut self, path: &str) -> Result<(), OpenAIError> {
24+
if path.is_empty() {
25+
return Err(OpenAIError::InvalidArgument(
26+
"Path cannot be empty".to_string(),
27+
));
28+
}
29+
self.path = Some(path.to_string());
30+
Ok(())
31+
}
32+
2133
pub(crate) fn with_headers(&mut self, headers: HeaderMap) {
2234
// merge with existing headers or update with new headers
2335
if let Some(existing_headers) = &mut self.headers {
@@ -81,4 +93,8 @@ impl RequestOptions {
8193
pub(crate) fn headers(&self) -> Option<&HeaderMap> {
8294
self.headers.as_ref()
8395
}
96+
97+
pub(crate) fn path(&self) -> Option<&String> {
98+
self.path.as_ref()
99+
}
84100
}

async-openai/src/traits.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,10 @@ pub trait RequestOptionsBuilder: Sized {
5353
self.options_mut().with_query(query)?;
5454
Ok(self)
5555
}
56+
57+
/// Add a path to RequestOptions
58+
fn path<P: Into<String>>(mut self, path: P) -> Result<Self, OpenAIError> {
59+
self.options_mut().with_path(path.into().as_str())?;
60+
Ok(self)
61+
}
5662
}

0 commit comments

Comments
 (0)