from dataclasses import dataclass
from typing import Optional

from aimon import Client, Detect
@dataclass
5
5
class ReactConfig :
6
6
publish : bool
7
+ async_mode :bool
7
8
model_name : str
8
9
max_attempts : int
9
10
aimon_api_key : str
10
11
application_name : str
11
- values_returned : list [str ]
12
12
hallucination_threshold : float
13
- aimon_config : dict [ str , dict [ str , str ]]
13
+ framework : str = None
14
14
15
- ## ReAct -> Reason and Act
15
+ class React :
16
+
17
+ ## Initailze the AIMon Client here
18
+ def __init__ (self , llm_app , react_configuration , context_extractor ):
19
+
20
+ self .llm_app = llm_app
21
+ self .context_extractor = context_extractor
22
+ self .react_configuration = react_configuration
23
+ self .client = Client (auth_header = "Bearer {}" .format (self .react_configuration .aimon_api_key ))
24
+
25
+ def create_payload (self , context , user_query , user_instructions , generated_text ):
26
+
27
+ aimon_payload = {
28
+ 'context' :context ,
29
+ 'user_query' :user_query ,
30
+ 'generated_text' :generated_text ,
31
+ 'instructions' :user_instructions ,
32
+ }
33
+
34
+ aimon_payload ['publish' ] = self .react_configuration .publish
35
+ aimon_payload ['async_mode' ] = self .react_configuration .async_mode
36
+ aimon_payload ['config' ] = { 'hallucination' : {'detector_name' : 'default' },
37
+ 'instruction_adherence' : {'detector_name' : 'default' },}
38
+
39
+ if self .react_configuration .publish :
40
+ aimon_payload ['application_name' ] = self .react_configuration .application_name
41
+ aimon_payload ['model_name' ] = self .react_configuration .model_name
42
+
43
+ return aimon_payload
16
44
17
- def react ( llm_app ,
18
- user_query ,
19
- user_instructions ,
20
- context_extractor ,
21
- react_configuration ,
22
- ):
45
+
46
+ ## ReAct -> Reason and Act
47
+ def react (self , user_query , user_instructions ,):
48
+
49
+ llm_response = self .llm_app (user_query , user_instructions , reprompted_flag = False )
50
+
51
+ context = self .context_extractor (user_query , user_instructions , llm_response )
52
+
53
+ ## Generated text for LLM Response, if the user employs the LlamaIndex framework
54
+ if llm_response .response or self .react_configuration .framework == "llamaindex" :
55
+ generated_text = llm_response .response
56
+ else :
57
+ generated_text = llm_response
23
58
24
- detect = Detect (values_returned = react_configuration .values_returned ,
25
- api_key = react_configuration .aimon_api_key ,
26
- config = react_configuration .aimon_config ,
27
- publish = react_configuration .publish ,
28
- application_name = react_configuration .application_name ,
29
- model_name = react_configuration .model_name ,
30
- )
31
-
32
- llm_response = llm_app (user_query , user_instructions , reprompted_flag = False )
33
-
34
- ## Decorating the context_extractor function with AIMon's "detect"
35
- context_extractor = detect (context_extractor )
59
+ aimon_payload = self .create_payload (context , user_query , user_instructions , generated_text )
36
60
37
- _ , _ , _ , query_result , aimon_response = context_extractor (user_query , user_instructions , llm_response )
61
+ detect_response = self .client .inference .detect (body = [aimon_payload ])
62
+
63
+ for _ in range (self .react_configuration .max_attempts ):
38
64
39
- for _ in range (react_configuration .max_attempts ):
65
+ failed_instructions = []
66
+ ## Loop to check for failed instructions
67
+ for x in detect_response .instruction_adherence ['results' ]:
68
+ if x ['adherence' ] == False :
69
+ failed_instructions .append (x ['instruction' ])
40
70
41
- failed_instructions = []
42
- ## Loop to check for failed instructions
43
- for x in aimon_response .detect_response .instruction_adherence ['results' ]:
44
- if x ['adherence' ] == False :
45
- failed_instructions .append (x ['instruction' ])
71
+ hallucination_score = detect_response .hallucination ['score' ]
46
72
47
- hallucination_score = aimon_response .detect_response .hallucination ['score' ]
73
+ ## Check whether the hallucination score is greater than the required threshold OR if any of the supplied instructions are not complied with
74
+ if self .react_configuration .hallucination_threshold > 0 and \
75
+ (hallucination_score > self .react_configuration .hallucination_threshold or len (failed_instructions )> 0 ):
76
+
77
+ llm_response = self .llm_app (user_query , user_instructions , reprompted_flag = True , hallucination_score = hallucination_score )
78
+
79
+ context = self .context_extractor (user_query , user_instructions , llm_response )
80
+
81
+ ## Generated text for LLM Response, if the user employs the LlamaIndex framework
82
+ if llm_response .response or self .react_configuration .framework == "llamaindex" :
83
+ generated_text = llm_response .response
84
+ else :
85
+ generated_text = llm_response
48
86
49
- ## Check whether the hallucination score is greater than the required threshold OR if any of the supplied instructions are not complied with
50
- if react_configuration .hallucination_threshold > 0 and \
51
- (hallucination_score > react_configuration .hallucination_threshold or len (failed_instructions )> 0 ):
52
-
53
- llm_response = llm_app (user_query , user_instructions , reprompted_flag = True , hallucination_score = hallucination_score )
87
+ new_aimon_payload = self .create_payload (context , user_query , user_instructions , generated_text )
54
88
55
- _ , _ , _ , query_result , aimon_response = context_extractor ( user_query , user_instructions , llm_response )
89
+ detect_response = self . client . inference . detect ( body = [ new_aimon_payload ] )
56
90
57
- if hallucination_score > react_configuration .hallucination_threshold :
58
- return f"The generated LLM response, even after { react_configuration .max_attempts } attempts of ReAct is still hallucinated. The response: { query_result } "
91
+ if hallucination_score > self . react_configuration .hallucination_threshold :
92
+ return f"The generated LLM response, even after { self . react_configuration .max_attempts } attempts of ReAct is still hallucinated. The response: { generated_text } "
59
93
60
- return query_result
94
+ return generated_text
0 commit comments