15
15
16
16
*/
17
17
18
- use std:: fmt:: Debug ;
18
+ use std:: { collections :: BTreeMap , fmt:: Debug } ;
19
19
20
+ use async_trait:: async_trait;
20
21
use axum:: http:: HeaderMap ;
21
22
use http:: header:: CONTENT_TYPE ;
22
23
use hyper:: StatusCode ;
23
- use serde:: Deserialize ;
24
+ use serde:: { Deserialize , Serialize } ;
25
+ use tracing:: info;
24
26
use url:: Url ;
25
27
26
28
use super :: {
27
29
Error ,
28
- http:: { HttpClientExt , JSON_CONTENT_TYPE , RequestBody , ResponseBody } ,
30
+ http:: { JSON_CONTENT_TYPE , RequestBody , ResponseBody } ,
31
+ } ;
32
+ use crate :: {
33
+ clients:: {
34
+ Client , HttpClient , create_http_client,
35
+ openai:: { Message , Tool } ,
36
+ } ,
37
+ config:: ServiceConfig ,
38
+ health:: HealthCheckResult ,
39
+ models:: { DetectionResult , DetectorParams , EvidenceObj , Metadata } ,
29
40
} ;
30
41
31
- pub mod text_contents;
32
- pub use text_contents:: * ;
33
- pub mod text_chat;
34
- pub use text_chat:: * ;
35
- pub mod text_context_doc;
36
- pub use text_context_doc:: * ;
37
- pub mod text_generation;
38
- pub use text_generation:: * ;
39
-
40
- const DEFAULT_PORT : u16 = 8080 ;
42
+ pub const DEFAULT_PORT : u16 = 8080 ;
43
+ pub const MODEL_HEADER_NAME : & str = "x-model-name" ;
41
44
pub const DETECTOR_ID_HEADER_NAME : & str = "detector-id" ;
42
- const MODEL_HEADER_NAME : & str = "x-model-name" ;
45
+ pub const CONTENTS_DETECTOR_ENDPOINT : & str = "/api/v1/text/contents" ;
46
+ pub const CHAT_DETECTOR_ENDPOINT : & str = "/api/v1/text/chat" ;
47
+ pub const CONTEXT_DOC_DETECTOR_ENDPOINT : & str = "/api/v1/text/context/doc" ;
48
+ pub const GENERATION_DETECTOR_ENDPOINT : & str = "/api/v1/text/generation" ;
43
49
44
- #[ derive( Debug , Clone , Deserialize ) ]
45
- pub struct DetectorError {
46
- pub code : u16 ,
47
- pub message : String ,
50
+ #[ derive( Clone ) ]
51
+ pub struct DetectorClient {
52
+ client : HttpClient ,
53
+ health_client : Option < HttpClient > ,
48
54
}
49
55
50
- impl From < DetectorError > for Error {
51
- fn from ( error : DetectorError ) -> Self {
52
- Error :: Http {
53
- code : StatusCode :: from_u16 ( error. code ) . unwrap ( ) ,
54
- message : error. message ,
55
- }
56
+ impl DetectorClient {
57
+ pub async fn new (
58
+ config : & ServiceConfig ,
59
+ health_config : Option < & ServiceConfig > ,
60
+ ) -> Result < Self , Error > {
61
+ let client = create_http_client ( DEFAULT_PORT , config) . await ?;
62
+ let health_client = if let Some ( health_config) = health_config {
63
+ Some ( create_http_client ( DEFAULT_PORT , health_config) . await ?)
64
+ } else {
65
+ None
66
+ } ;
67
+ Ok ( Self {
68
+ client,
69
+ health_client,
70
+ } )
56
71
}
57
- }
58
-
59
- /// This trait should be implemented by all detectors.
60
- /// If the detector has an HTTP client (currently all detector clients are HTTP) this trait will
61
- /// implicitly extend the client with an HTTP detector specific post function.
62
- pub trait DetectorClient { }
63
-
64
- /// Provides a helper extension for HTTP detector clients.
65
- pub trait DetectorClientExt : HttpClientExt {
66
- /// Wraps the post function with extra detector functionality
67
- /// (detector id header injection & error handling)
68
- async fn post_to_detector < U : ResponseBody > (
69
- & self ,
70
- model_id : & str ,
71
- url : Url ,
72
- headers : HeaderMap ,
73
- request : impl RequestBody ,
74
- ) -> Result < U , Error > ;
75
-
76
- /// Wraps call to inner HTTP client endpoint function.
77
- fn endpoint ( & self , path : & str ) -> Url ;
78
- }
79
72
80
- impl < C : DetectorClient + HttpClientExt > DetectorClientExt for C {
81
- async fn post_to_detector < U : ResponseBody > (
73
+ async fn post < U : ResponseBody > (
82
74
& self ,
83
75
model_id : & str ,
84
76
url : Url ,
@@ -90,7 +82,7 @@ impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
90
82
// Header used by a router component, if available
91
83
headers. append ( MODEL_HEADER_NAME , model_id. parse ( ) . unwrap ( ) ) ;
92
84
93
- let response = self . inner ( ) . post ( url, headers, request) . await ?;
85
+ let response = self . client . post ( url, headers, request) . await ?;
94
86
95
87
let status = response. status ( ) ;
96
88
match status {
@@ -106,7 +98,222 @@ impl<C: DetectorClient + HttpClientExt> DetectorClientExt for C {
106
98
}
107
99
}
108
100
109
- fn endpoint ( & self , path : & str ) -> Url {
110
- self . inner ( ) . endpoint ( path)
101
+ pub async fn text_contents (
102
+ & self ,
103
+ model_id : & str ,
104
+ request : ContentAnalysisRequest ,
105
+ headers : HeaderMap ,
106
+ ) -> Result < Vec < Vec < ContentAnalysisResponse > > , Error > {
107
+ let url = self . client . endpoint ( CONTENTS_DETECTOR_ENDPOINT ) ;
108
+ info ! ( "sending text content detector request to {}" , url) ;
109
+ self . post ( model_id, url, headers, request) . await
110
+ }
111
+
112
+ pub async fn text_chat (
113
+ & self ,
114
+ model_id : & str ,
115
+ request : ChatDetectionRequest ,
116
+ headers : HeaderMap ,
117
+ ) -> Result < Vec < DetectionResult > , Error > {
118
+ let url = self . client . endpoint ( CHAT_DETECTOR_ENDPOINT ) ;
119
+ info ! ( "sending text chat detector request to {}" , url) ;
120
+ self . post ( model_id, url, headers, request) . await
121
+ }
122
+
123
+ pub async fn text_context_doc (
124
+ & self ,
125
+ model_id : & str ,
126
+ request : ContextDocsDetectionRequest ,
127
+ headers : HeaderMap ,
128
+ ) -> Result < Vec < DetectionResult > , Error > {
129
+ let url = self . client . endpoint ( CONTEXT_DOC_DETECTOR_ENDPOINT ) ;
130
+ info ! ( "sending text context doc detector request to {}" , url) ;
131
+ self . post ( model_id, url, headers, request) . await
132
+ }
133
+
134
+ pub async fn text_generation (
135
+ & self ,
136
+ model_id : & str ,
137
+ request : GenerationDetectionRequest ,
138
+ headers : HeaderMap ,
139
+ ) -> Result < Vec < DetectionResult > , Error > {
140
+ let url = self . client . endpoint ( GENERATION_DETECTOR_ENDPOINT ) ;
141
+ info ! ( "sending text generation detector request to {}" , url) ;
142
+ self . post ( model_id, url, headers, request) . await
143
+ }
144
+ }
145
+
146
+ #[ async_trait]
147
+ impl Client for DetectorClient {
148
+ fn name ( & self ) -> & str {
149
+ "detector"
150
+ }
151
+
152
+ async fn health ( & self ) -> HealthCheckResult {
153
+ if let Some ( health_client) = & self . health_client {
154
+ health_client. health ( ) . await
155
+ } else {
156
+ self . client . health ( ) . await
157
+ }
158
+ }
159
+ }
160
+
161
+ #[ derive( Debug , Clone , Deserialize ) ]
162
+ pub struct DetectorError {
163
+ pub code : u16 ,
164
+ pub message : String ,
165
+ }
166
+
167
+ impl From < DetectorError > for Error {
168
+ fn from ( error : DetectorError ) -> Self {
169
+ Error :: Http {
170
+ code : StatusCode :: from_u16 ( error. code ) . unwrap ( ) ,
171
+ message : error. message ,
172
+ }
173
+ }
174
+ }
175
+
176
+ /// Request for text content analysis
177
+ /// Results of this request will contain analysis / detection of each of the provided documents
178
+ /// in the order they are present in the `contents` object.
179
+ #[ derive( Clone , Debug , Serialize , Deserialize , PartialEq ) ]
180
+ pub struct ContentAnalysisRequest {
181
+ /// Field allowing users to provide list of documents for analysis
182
+ pub contents : Vec < String > ,
183
+ /// Detector parameters (available parameters depend on the detector)
184
+ pub detector_params : DetectorParams ,
185
+ }
186
+
187
+ impl ContentAnalysisRequest {
188
+ pub fn new ( contents : Vec < String > , detector_params : DetectorParams ) -> ContentAnalysisRequest {
189
+ ContentAnalysisRequest {
190
+ contents,
191
+ detector_params,
192
+ }
193
+ }
194
+ }
195
+
196
+ /// Response of text content analysis endpoint
197
+ #[ derive( Clone , Debug , Default , Serialize , Deserialize , PartialEq ) ]
198
+ pub struct ContentAnalysisResponse {
199
+ /// Start index of detection
200
+ pub start : usize ,
201
+ /// End index of detection
202
+ pub end : usize ,
203
+ /// Text corresponding to detection
204
+ pub text : String ,
205
+ /// Relevant detection class
206
+ pub detection : String ,
207
+ /// Detection type or aggregate detection label
208
+ pub detection_type : String ,
209
+ /// Optional, ID of Detector
210
+ pub detector_id : Option < String > ,
211
+ /// Score of detection
212
+ pub score : f64 ,
213
+ /// Optional, any applicable evidence for detection
214
+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
215
+ pub evidence : Option < Vec < EvidenceObj > > ,
216
+ // Optional metadata block
217
+ #[ serde( default , skip_serializing_if = "BTreeMap::is_empty" ) ]
218
+ pub metadata : Metadata ,
219
+ }
220
+
221
+ impl From < ContentAnalysisResponse > for crate :: models:: TokenClassificationResult {
222
+ fn from ( value : ContentAnalysisResponse ) -> Self {
223
+ Self {
224
+ start : value. start as u32 ,
225
+ end : value. end as u32 ,
226
+ word : value. text ,
227
+ entity : value. detection ,
228
+ entity_group : value. detection_type ,
229
+ detector_id : value. detector_id ,
230
+ score : value. score ,
231
+ token_count : None ,
232
+ }
233
+ }
234
+ }
235
+
236
+ /// A struct representing a request to a detector compatible with the
237
+ /// /api/v1/text/chat endpoint.
238
+ #[ derive( Debug , Clone , Serialize ) ]
239
+ pub struct ChatDetectionRequest {
240
+ /// Chat messages to run detection on
241
+ pub messages : Vec < Message > ,
242
+ /// Optional list of tool definitions
243
+ pub tools : Vec < Tool > ,
244
+ /// Detector parameters (available parameters depend on the detector)
245
+ pub detector_params : DetectorParams ,
246
+ }
247
+
248
+ impl ChatDetectionRequest {
249
+ pub fn new ( messages : Vec < Message > , tools : Vec < Tool > , detector_params : DetectorParams ) -> Self {
250
+ Self {
251
+ messages,
252
+ tools,
253
+ detector_params,
254
+ }
255
+ }
256
+ }
257
+
258
+ /// A struct representing a request to a detector compatible with the
259
+ /// /api/v1/text/context/doc endpoint.
260
+ #[ cfg_attr( test, derive( PartialEq ) ) ]
261
+ #[ derive( Debug , Clone , Serialize ) ]
262
+ pub struct ContextDocsDetectionRequest {
263
+ /// Content to run detection on
264
+ pub content : String ,
265
+ /// Type of context being sent
266
+ pub context_type : ContextType ,
267
+ /// Context to run detection on
268
+ pub context : Vec < String > ,
269
+ /// Detector parameters (available parameters depend on the detector)
270
+ pub detector_params : DetectorParams ,
271
+ }
272
+
273
+ /// Enum representing the context type of a detection
274
+ #[ derive( Clone , Debug , PartialEq , Serialize , Deserialize ) ]
275
+ pub enum ContextType {
276
+ #[ serde( rename = "docs" ) ]
277
+ Document ,
278
+ #[ serde( rename = "url" ) ]
279
+ Url ,
280
+ }
281
+
282
+ impl ContextDocsDetectionRequest {
283
+ pub fn new (
284
+ content : String ,
285
+ context_type : ContextType ,
286
+ context : Vec < String > ,
287
+ detector_params : DetectorParams ,
288
+ ) -> Self {
289
+ Self {
290
+ content,
291
+ context_type,
292
+ context,
293
+ detector_params,
294
+ }
295
+ }
296
+ }
297
+
298
+ /// A struct representing a request to a detector compatible with the
299
+ /// /api/v1/text/generation endpoint.
300
+ #[ cfg_attr( test, derive( PartialEq ) ) ]
301
+ #[ derive( Debug , Clone , Serialize ) ]
302
+ pub struct GenerationDetectionRequest {
303
+ /// User prompt sent to LLM
304
+ pub prompt : String ,
305
+ /// Text generated from an LLM
306
+ pub generated_text : String ,
307
+ /// Detector parameters (available parameters depend on the detector)
308
+ pub detector_params : DetectorParams ,
309
+ }
310
+
311
+ impl GenerationDetectionRequest {
312
+ pub fn new ( prompt : String , generated_text : String , detector_params : DetectorParams ) -> Self {
313
+ Self {
314
+ prompt,
315
+ generated_text,
316
+ detector_params,
317
+ }
111
318
}
112
319
}
0 commit comments