Skip to content

Commit 528086b

Browse files
declark1mdevino
andauthored
Refactor detector clients to single DetectorClient (#462)
* Refactor to single DetectorClient Signed-off-by: declark1 <[email protected]> * Drop unneeded cfg_attr Co-authored-by: Mateus Devino <[email protected]> Signed-off-by: Dan Clark <[email protected]> * Drop unneeded client method Signed-off-by: declark1 <[email protected]> --------- Signed-off-by: declark1 <[email protected]> Signed-off-by: Dan Clark <[email protected]> Co-authored-by: Mateus Devino <[email protected]>
1 parent 96de3d4 commit 528086b

File tree

9 files changed

+279
-664
lines changed

9 files changed

+279
-664
lines changed

src/clients.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ pub mod http;
4949
pub use http::{HttpClient, http_trace_layer};
5050

5151
pub mod chunker;
52+
pub use chunker::ChunkerClient;
5253

5354
pub mod detector;
54-
pub use detector::TextContentsDetectorClient;
55+
pub use detector::DetectorClient;
5556

5657
pub mod tgis;
5758
pub use tgis::TgisClient;

src/clients/detector.rs

Lines changed: 258 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,70 +15,62 @@
1515
1616
*/
1717

18-
use std::fmt::Debug;
18+
use std::{collections::BTreeMap, fmt::Debug};
1919

20+
use async_trait::async_trait;
2021
use axum::http::HeaderMap;
2122
use http::header::CONTENT_TYPE;
2223
use hyper::StatusCode;
23-
use serde::Deserialize;
24+
use serde::{Deserialize, Serialize};
25+
use tracing::info;
2426
use url::Url;
2527

2628
use super::{
2729
Error,
28-
http::{HttpClientExt, JSON_CONTENT_TYPE, RequestBody, ResponseBody},
30+
http::{JSON_CONTENT_TYPE, RequestBody, ResponseBody},
31+
};
32+
use crate::{
33+
clients::{
34+
Client, HttpClient, create_http_client,
35+
openai::{Message, Tool},
36+
},
37+
config::ServiceConfig,
38+
health::HealthCheckResult,
39+
models::{DetectionResult, DetectorParams, EvidenceObj, Metadata},
2940
};
3041

31-
pub mod text_contents;
32-
pub use text_contents::*;
33-
pub mod text_chat;
34-
pub use text_chat::*;
35-
pub mod text_context_doc;
36-
pub use text_context_doc::*;
37-
pub mod text_generation;
38-
pub use text_generation::*;
39-
40-
const DEFAULT_PORT: u16 = 8080;
42+
pub const DEFAULT_PORT: u16 = 8080;
43+
pub const MODEL_HEADER_NAME: &str = "x-model-name";
4144
pub const DETECTOR_ID_HEADER_NAME: &str = "detector-id";
42-
const MODEL_HEADER_NAME: &str = "x-model-name";
45+
pub const CONTENTS_DETECTOR_ENDPOINT: &str = "/api/v1/text/contents";
46+
pub const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat";
47+
pub const CONTEXT_DOC_DETECTOR_ENDPOINT: &str = "/api/v1/text/context/doc";
48+
pub const GENERATION_DETECTOR_ENDPOINT: &str = "/api/v1/text/generation";
4349

44-
#[derive(Debug, Clone, Deserialize)]
45-
pub struct DetectorError {
46-
pub code: u16,
47-
pub message: String,
50+
#[derive(Clone)]
51+
pub struct DetectorClient {
52+
client: HttpClient,
53+
health_client: Option<HttpClient>,
4854
}
4955

50-
impl From<DetectorError> for Error {
51-
fn from(error: DetectorError) -> Self {
52-
Error::Http {
53-
code: StatusCode::from_u16(error.code).unwrap(),
54-
message: error.message,
55-
}
56+
impl DetectorClient {
57+
pub async fn new(
58+
config: &ServiceConfig,
59+
health_config: Option<&ServiceConfig>,
60+
) -> Result<Self, Error> {
61+
let client = create_http_client(DEFAULT_PORT, config).await?;
62+
let health_client = if let Some(health_config) = health_config {
63+
Some(create_http_client(DEFAULT_PORT, health_config).await?)
64+
} else {
65+
None
66+
};
67+
Ok(Self {
68+
client,
69+
health_client,
70+
})
5671
}
57-
}
58-
59-
/// This trait should be implemented by all detectors.
60-
/// If the detector has an HTTP client (currently all detector clients are HTTP) this trait will
61-
/// implicitly extend the client with an HTTP detector specific post function.
62-
pub trait DetectorClient {}
63-
64-
/// Provides a helper extension for HTTP detector clients.
65-
pub trait DetectorClientExt: HttpClientExt {
66-
/// Wraps the post function with extra detector functionality
67-
/// (detector id header injection & error handling)
68-
async fn post_to_detector<U: ResponseBody>(
69-
&self,
70-
model_id: &str,
71-
url: Url,
72-
headers: HeaderMap,
73-
request: impl RequestBody,
74-
) -> Result<U, Error>;
75-
76-
/// Wraps call to inner HTTP client endpoint function.
77-
fn endpoint(&self, path: &str) -> Url;
78-
}
7972

80-
impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
81-
async fn post_to_detector<U: ResponseBody>(
73+
async fn post<U: ResponseBody>(
8274
&self,
8375
model_id: &str,
8476
url: Url,
@@ -90,7 +82,7 @@ impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
9082
// Header used by a router component, if available
9183
headers.append(MODEL_HEADER_NAME, model_id.parse().unwrap());
9284

93-
let response = self.inner().post(url, headers, request).await?;
85+
let response = self.client.post(url, headers, request).await?;
9486

9587
let status = response.status();
9688
match status {
@@ -106,7 +98,222 @@ impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
10698
}
10799
}
108100

109-
fn endpoint(&self, path: &str) -> Url {
110-
self.inner().endpoint(path)
101+
pub async fn text_contents(
102+
&self,
103+
model_id: &str,
104+
request: ContentAnalysisRequest,
105+
headers: HeaderMap,
106+
) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
107+
let url = self.client.endpoint(CONTENTS_DETECTOR_ENDPOINT);
108+
info!("sending text content detector request to {}", url);
109+
self.post(model_id, url, headers, request).await
110+
}
111+
112+
pub async fn text_chat(
113+
&self,
114+
model_id: &str,
115+
request: ChatDetectionRequest,
116+
headers: HeaderMap,
117+
) -> Result<Vec<DetectionResult>, Error> {
118+
let url = self.client.endpoint(CHAT_DETECTOR_ENDPOINT);
119+
info!("sending text chat detector request to {}", url);
120+
self.post(model_id, url, headers, request).await
121+
}
122+
123+
pub async fn text_context_doc(
124+
&self,
125+
model_id: &str,
126+
request: ContextDocsDetectionRequest,
127+
headers: HeaderMap,
128+
) -> Result<Vec<DetectionResult>, Error> {
129+
let url = self.client.endpoint(CONTEXT_DOC_DETECTOR_ENDPOINT);
130+
info!("sending text context doc detector request to {}", url);
131+
self.post(model_id, url, headers, request).await
132+
}
133+
134+
pub async fn text_generation(
135+
&self,
136+
model_id: &str,
137+
request: GenerationDetectionRequest,
138+
headers: HeaderMap,
139+
) -> Result<Vec<DetectionResult>, Error> {
140+
let url = self.client.endpoint(GENERATION_DETECTOR_ENDPOINT);
141+
info!("sending text generation detector request to {}", url);
142+
self.post(model_id, url, headers, request).await
143+
}
144+
}
145+
146+
#[async_trait]
147+
impl Client for DetectorClient {
148+
fn name(&self) -> &str {
149+
"detector"
150+
}
151+
152+
async fn health(&self) -> HealthCheckResult {
153+
if let Some(health_client) = &self.health_client {
154+
health_client.health().await
155+
} else {
156+
self.client.health().await
157+
}
158+
}
159+
}
160+
161+
#[derive(Debug, Clone, Deserialize)]
162+
pub struct DetectorError {
163+
pub code: u16,
164+
pub message: String,
165+
}
166+
167+
impl From<DetectorError> for Error {
168+
fn from(error: DetectorError) -> Self {
169+
Error::Http {
170+
code: StatusCode::from_u16(error.code).unwrap(),
171+
message: error.message,
172+
}
173+
}
174+
}
175+
176+
/// Request for text content analysis
177+
/// Results of this request will contain analysis / detection of each of the provided documents
178+
/// in the order they are present in the `contents` object.
179+
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
180+
pub struct ContentAnalysisRequest {
181+
/// Field allowing users to provide list of documents for analysis
182+
pub contents: Vec<String>,
183+
/// Detector parameters (available parameters depend on the detector)
184+
pub detector_params: DetectorParams,
185+
}
186+
187+
impl ContentAnalysisRequest {
188+
pub fn new(contents: Vec<String>, detector_params: DetectorParams) -> ContentAnalysisRequest {
189+
ContentAnalysisRequest {
190+
contents,
191+
detector_params,
192+
}
193+
}
194+
}
195+
196+
/// Response of text content analysis endpoint
197+
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
198+
pub struct ContentAnalysisResponse {
199+
/// Start index of detection
200+
pub start: usize,
201+
/// End index of detection
202+
pub end: usize,
203+
/// Text corresponding to detection
204+
pub text: String,
205+
/// Relevant detection class
206+
pub detection: String,
207+
/// Detection type or aggregate detection label
208+
pub detection_type: String,
209+
/// Optional, ID of Detector
210+
pub detector_id: Option<String>,
211+
/// Score of detection
212+
pub score: f64,
213+
/// Optional, any applicable evidence for detection
214+
#[serde(skip_serializing_if = "Option::is_none")]
215+
pub evidence: Option<Vec<EvidenceObj>>,
216+
// Optional metadata block
217+
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
218+
pub metadata: Metadata,
219+
}
220+
221+
impl From<ContentAnalysisResponse> for crate::models::TokenClassificationResult {
222+
fn from(value: ContentAnalysisResponse) -> Self {
223+
Self {
224+
start: value.start as u32,
225+
end: value.end as u32,
226+
word: value.text,
227+
entity: value.detection,
228+
entity_group: value.detection_type,
229+
detector_id: value.detector_id,
230+
score: value.score,
231+
token_count: None,
232+
}
233+
}
234+
}
235+
236+
/// A struct representing a request to a detector compatible with the
237+
/// /api/v1/text/chat endpoint.
238+
#[derive(Debug, Clone, Serialize)]
239+
pub struct ChatDetectionRequest {
240+
/// Chat messages to run detection on
241+
pub messages: Vec<Message>,
242+
/// Optional list of tool definitions
243+
pub tools: Vec<Tool>,
244+
/// Detector parameters (available parameters depend on the detector)
245+
pub detector_params: DetectorParams,
246+
}
247+
248+
impl ChatDetectionRequest {
249+
pub fn new(messages: Vec<Message>, tools: Vec<Tool>, detector_params: DetectorParams) -> Self {
250+
Self {
251+
messages,
252+
tools,
253+
detector_params,
254+
}
255+
}
256+
}
257+
258+
/// A struct representing a request to a detector compatible with the
259+
/// /api/v1/text/context/doc endpoint.
260+
#[cfg_attr(test, derive(PartialEq))]
261+
#[derive(Debug, Clone, Serialize)]
262+
pub struct ContextDocsDetectionRequest {
263+
/// Content to run detection on
264+
pub content: String,
265+
/// Type of context being sent
266+
pub context_type: ContextType,
267+
/// Context to run detection on
268+
pub context: Vec<String>,
269+
/// Detector parameters (available parameters depend on the detector)
270+
pub detector_params: DetectorParams,
271+
}
272+
273+
/// Enum representing the context type of a detection
274+
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
275+
pub enum ContextType {
276+
#[serde(rename = "docs")]
277+
Document,
278+
#[serde(rename = "url")]
279+
Url,
280+
}
281+
282+
impl ContextDocsDetectionRequest {
283+
pub fn new(
284+
content: String,
285+
context_type: ContextType,
286+
context: Vec<String>,
287+
detector_params: DetectorParams,
288+
) -> Self {
289+
Self {
290+
content,
291+
context_type,
292+
context,
293+
detector_params,
294+
}
295+
}
296+
}
297+
298+
/// A struct representing a request to a detector compatible with the
299+
/// /api/v1/text/generation endpoint.
300+
#[cfg_attr(test, derive(PartialEq))]
301+
#[derive(Debug, Clone, Serialize)]
302+
pub struct GenerationDetectionRequest {
303+
/// User prompt sent to LLM
304+
pub prompt: String,
305+
/// Text generated from an LLM
306+
pub generated_text: String,
307+
/// Detector parameters (available parameters depend on the detector)
308+
pub detector_params: DetectorParams,
309+
}
310+
311+
impl GenerationDetectionRequest {
312+
pub fn new(prompt: String, generated_text: String, detector_params: DetectorParams) -> Self {
313+
Self {
314+
prompt,
315+
generated_text,
316+
detector_params,
317+
}
111318
}
112319
}

0 commit comments

Comments
 (0)