Skip to content

Commit f2049d0

Browse files
class React
1 parent e354ff5 commit f2049d0

File tree

2 files changed

+74
-40
lines changed

2 files changed

+74
-40
lines changed

aimon/extensions/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .react import ReactConfig, react
1+
from .react import ReactConfig, React

aimon/extensions/react.py

+73-39
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,94 @@
1-
from aimon import Detect
1+
from aimon import Client, Detect
22
from dataclasses import dataclass
33

44
@dataclass
55
class ReactConfig:
66
publish: bool
7+
async_mode:bool
78
model_name: str
89
max_attempts: int
910
aimon_api_key: str
1011
application_name: str
11-
values_returned: list[str]
1212
hallucination_threshold: float
13-
aimon_config: dict[str, dict[str, str]]
13+
framework:str = None
1414

15-
## ReAct -> Reason and Act
15+
class React:
16+
17+
## Initailze the AIMon Client here
18+
def __init__(self, llm_app, react_configuration, context_extractor):
19+
20+
self.llm_app = llm_app
21+
self.context_extractor = context_extractor
22+
self.react_configuration = react_configuration
23+
self.client = Client(auth_header="Bearer {}".format(self.react_configuration.aimon_api_key))
24+
25+
def create_payload(self, context, user_query, user_instructions, generated_text):
26+
27+
aimon_payload = {
28+
'context':context,
29+
'user_query':user_query,
30+
'generated_text':generated_text,
31+
'instructions':user_instructions,
32+
}
33+
34+
aimon_payload['publish'] = self.react_configuration.publish
35+
aimon_payload['async_mode'] = self.react_configuration.async_mode
36+
aimon_payload['config'] = { 'hallucination': {'detector_name': 'default'},
37+
'instruction_adherence': {'detector_name': 'default'},}
38+
39+
if self.react_configuration.publish:
40+
aimon_payload['application_name'] = self.react_configuration.application_name
41+
aimon_payload['model_name'] = self.react_configuration.model_name
42+
43+
return aimon_payload
1644

17-
def react( llm_app,
18-
user_query,
19-
user_instructions,
20-
context_extractor,
21-
react_configuration,
22-
):
45+
46+
## ReAct -> Reason and Act
47+
def react(self, user_query, user_instructions,):
48+
49+
llm_response = self.llm_app(user_query, user_instructions, reprompted_flag=False)
50+
51+
context = self.context_extractor(user_query, user_instructions, llm_response)
52+
53+
## Generated text for LLM Response, if the user employs the LlamaIndex framework
54+
if llm_response.response or self.react_configuration.framework=="llamaindex":
55+
generated_text = llm_response.response
56+
else:
57+
generated_text = llm_response
2358

24-
detect = Detect(values_returned = react_configuration.values_returned,
25-
api_key = react_configuration.aimon_api_key,
26-
config = react_configuration.aimon_config,
27-
publish = react_configuration.publish,
28-
application_name = react_configuration.application_name,
29-
model_name = react_configuration.model_name,
30-
)
31-
32-
llm_response = llm_app(user_query, user_instructions, reprompted_flag=False)
33-
34-
## Decorating the context_extractor function with AIMon's "detect"
35-
context_extractor = detect(context_extractor)
59+
aimon_payload = self.create_payload(context, user_query, user_instructions, generated_text)
3660

37-
_, _, _, query_result, aimon_response = context_extractor(user_query, user_instructions, llm_response)
61+
detect_response = self.client.inference.detect(body=[aimon_payload])
62+
63+
for _ in range(self.react_configuration.max_attempts):
3864

39-
for _ in range(react_configuration.max_attempts):
65+
failed_instructions = []
66+
## Loop to check for failed instructions
67+
for x in detect_response.instruction_adherence['results']:
68+
if x['adherence'] == False:
69+
failed_instructions.append(x['instruction'])
4070

41-
failed_instructions = []
42-
## Loop to check for failed instructions
43-
for x in aimon_response.detect_response.instruction_adherence['results']:
44-
if x['adherence'] == False:
45-
failed_instructions.append(x['instruction'])
71+
hallucination_score = detect_response.hallucination['score']
4672

47-
hallucination_score = aimon_response.detect_response.hallucination['score']
73+
## Check whether the hallucination score is greater than the required threshold OR if any of the supplied instructions are not complied with
74+
if self.react_configuration.hallucination_threshold > 0 and \
75+
(hallucination_score > self.react_configuration.hallucination_threshold or len(failed_instructions)>0):
76+
77+
llm_response = self.llm_app(user_query, user_instructions, reprompted_flag=True, hallucination_score=hallucination_score)
78+
79+
context = self.context_extractor(user_query, user_instructions, llm_response)
80+
81+
## Generated text for LLM Response, if the user employs the LlamaIndex framework
82+
if llm_response.response or self.react_configuration.framework=="llamaindex":
83+
generated_text = llm_response.response
84+
else:
85+
generated_text = llm_response
4886

49-
## Check whether the hallucination score is greater than the required threshold OR if any of the supplied instructions are not complied with
50-
if react_configuration.hallucination_threshold > 0 and \
51-
(hallucination_score > react_configuration.hallucination_threshold or len(failed_instructions)>0):
52-
53-
llm_response = llm_app(user_query, user_instructions, reprompted_flag=True, hallucination_score=hallucination_score)
87+
new_aimon_payload = self.create_payload(context, user_query, user_instructions, generated_text)
5488

55-
_, _, _, query_result, aimon_response = context_extractor(user_query, user_instructions, llm_response)
89+
detect_response = self.client.inference.detect(body=[new_aimon_payload])
5690

57-
if hallucination_score > react_configuration.hallucination_threshold:
58-
return f"The generated LLM response, even after {react_configuration.max_attempts} attempts of ReAct is still hallucinated. The response: {query_result}"
91+
if hallucination_score > self.react_configuration.hallucination_threshold:
92+
return f"The generated LLM response, even after {self.react_configuration.max_attempts} attempts of ReAct is still hallucinated. The response: {generated_text}"
5993

60-
return query_result
94+
return generated_text

0 commit comments

Comments
 (0)