|
1 | 1 | from functools import wraps
|
| 2 | +import os |
2 | 3 |
|
3 |
| -from .common import AimonClientSingleton |
| 4 | +from aimon import Client |
| 5 | +from .evaluate import Application, Model |
4 | 6 |
|
class DetectResult:
    """
    Container for the outcome of an AIMon detection call.

    Holds the HTTP status of the detection, the raw response returned by the
    synchronous detect endpoint, and — when the payload was also published to
    the AIMon UI (or the decorator ran in async mode) — the publish response.

    Attributes:
    -----------
    status : int
        HTTP status code of the detection operation.
    detect_response : object
        Response object from the AIMon synchronous detection.
    publish_response : list
        Response from publishing to the AIMon UI; empty list when the result
        was not published.
    """

    def __init__(self, status, detect_response, publish=None):
        self.status = status
        self.detect_response = detect_response
        # Normalize "not published" to an empty list so callers can iterate
        # publish_response unconditionally.
        if publish is None:
            publish = []
        self.publish_response = publish

    def __str__(self):
        return (
            f"DetectResult(status={self.status}, "
            f"detect_response={self.detect_response}, "
            f"publish_response={self.publish_response})"
        )

    def __repr__(self):
        # repr intentionally mirrors str for readable logs and debuggers.
        return self.__str__()
5 | 42 |
|
6 | 43 | class Detect:
|
| 44 | + """ |
| 45 | + A decorator class for detecting various qualities in LLM-generated text using AIMon's detection services. |
| 46 | +
|
| 47 | + This decorator wraps a function that generates text using an LLM and sends the generated text |
| 48 | + along with context to AIMon for analysis. It can be used in both synchronous and asynchronous modes, |
| 49 | + and optionally publishes results to the AIMon UI. |
| 50 | +
|
| 51 | + Parameters: |
| 52 | + ----------- |
| 53 | + values_returned : list |
| 54 | + A list of values in the order returned by the decorated function. |
| 55 | + Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'. |
| 56 | + api_key : str, optional |
| 57 | + The API key to use for the AIMon client. If not provided, it will attempt to use the AIMON_API_KEY environment variable. |
| 58 | + config : dict, optional |
| 59 | + A dictionary of configuration options for the detector. Defaults to {'hallucination': {'detector_name': 'default'}}. |
| 60 | + async_mode : bool, optional |
| 61 | + If True, the detect() function will return immediately with a DetectResult object. Default is False. |
| 62 | + publish : bool, optional |
| 63 | + If True, the payload will be published to AIMon and can be viewed on the AIMon UI. Default is False. |
| 64 | + application_name : str, optional |
| 65 | + The name of the application to use when publish is True. |
| 66 | + model_name : str, optional |
| 67 | + The name of the model to use when publish is True. |
| 68 | +
|
| 69 | + Example: |
| 70 | + -------- |
| 71 | + >>> from aimon.decorators import Detect |
| 72 | + >>> import os |
| 73 | + >>> |
| 74 | + >>> # Configure the detector |
| 75 | + >>> detect = Detect( |
| 76 | + ... values_returned=['context', 'generated_text', 'user_query'], |
| 77 | + ... api_key=os.getenv('AIMON_API_KEY'), |
| 78 | + ... config={ |
| 79 | + ... 'hallucination': {'detector_name': 'default'}, |
| 80 | + ... 'toxicity': {'detector_name': 'default'} |
| 81 | + ... }, |
| 82 | + ... publish=True, |
| 83 | + ... application_name='my_summarization_app', |
| 84 | + ... model_name='gpt-3.5-turbo' |
| 85 | + ... ) |
| 86 | + >>> |
| 87 | + >>> # Define a simple lambda function to simulate an LLM |
| 88 | + >>> your_llm_function = lambda context, query: f"Summary of '{context}' based on query: {query}" |
| 89 | + >>> |
| 90 | + >>> # Use the decorator on your LLM function |
| 91 | + >>> @detect |
| 92 | + ... def generate_summary(context, query): |
| 93 | + ... summary = your_llm_function(context, query) |
| 94 | + ... return context, summary, query |
| 95 | + >>> |
| 96 | + >>> # Use the decorated function |
| 97 | + >>> context = "The quick brown fox jumps over the lazy dog." |
| 98 | + >>> query = "Summarize the given text." |
| 99 | + >>> context, summary, query, aimon_result = generate_summary(context, query) |
| 100 | + >>> |
| 101 | + >>> # Print the generated summary |
| 102 | + >>> print(f"Generated summary: {summary}") |
| 103 | + >>> |
| 104 | + >>> # Check the AIMon detection results |
| 105 | + >>> print(f"Hallucination score: {aimon_result.detect_response.hallucination['score']}") |
| 106 | + >>> print(f"Toxicity score: {aimon_result.detect_response.toxicity['score']}") |
| 107 | + """ |
7 | 108 | DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}}
|
8 | 109 |
|
9 |
| - def __init__(self, values_returned, api_key=None, config=None): |
| 110 | + def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None): |
10 | 111 | """
|
11 | 112 | :param values_returned: A list of values in the order returned by the decorated function
|
12 | 113 | Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'
|
| 114 | + :param api_key: The API key to use for the AIMon client |
| 115 | + :param config: A dictionary of configuration options for the detector |
| 116 | + :param async_mode: Boolean, if True, the detect() function will return immediately with a DetectResult object. Default is False. |
| 117 | + The payload will also be published to AIMon and can be viewed on the AIMon UI. |
| 118 | + :param publish: Boolean, if True, the payload will be published to AIMon and can be viewed on the AIMon UI. Default is False. |
| 119 | + :param application_name: The name of the application to use when publish is True |
| 120 | + :param model_name: The name of the model to use when publish is True |
13 | 121 | """
|
14 |
| - self.client = AimonClientSingleton.get_instance(api_key) |
| 122 | + api_key = os.getenv('AIMON_API_KEY') if not api_key else api_key |
| 123 | + if api_key is None: |
| 124 | + raise ValueError("API key is None") |
| 125 | + self.client = Client(auth_header="Bearer {}".format(api_key)) |
15 | 126 | self.config = config if config else self.DEFAULT_CONFIG
|
16 | 127 | self.values_returned = values_returned
|
17 | 128 | if self.values_returned is None or len(self.values_returned) == 0:
|
18 |
| - raise ValueError("Values returned by the decorated function must be specified") |
19 |
| - if "generated_text" not in self.values_returned: |
20 |
| - raise ValueError("values_returned must contain 'generated_text'") |
| 129 | + raise ValueError("values_returned by the decorated function must be specified") |
21 | 130 | if "context" not in self.values_returned:
|
22 | 131 | raise ValueError("values_returned must contain 'context'")
|
| 132 | + self.async_mode = async_mode |
| 133 | + self.publish = publish |
| 134 | + if self.async_mode: |
| 135 | + self.publish = True |
| 136 | + if self.publish: |
| 137 | + if application_name is None: |
| 138 | + raise ValueError("Application name must be provided if publish is True") |
| 139 | + if model_name is None: |
| 140 | + raise ValueError("Model name must be provided if publish is True") |
| 141 | + self.application = Application(application_name, stage="production") |
| 142 | + self.model = Model(model_name, "text") |
| 143 | + self._initialize_application_model() |
| 144 | + |
| 145 | + def _initialize_application_model(self): |
| 146 | + # Create or retrieve the model |
| 147 | + self._am_model = self.client.models.create( |
| 148 | + name=self.model.name, |
| 149 | + type=self.model.model_type, |
| 150 | + description="This model is named {} and is of type {}".format(self.model.name, self.model.model_type), |
| 151 | + metadata=self.model.metadata |
| 152 | + ) |
| 153 | + |
| 154 | + # Create or retrieve the application |
| 155 | + self._am_app = self.client.applications.create( |
| 156 | + name=self.application.name, |
| 157 | + model_name=self._am_model.name, |
| 158 | + stage=self.application.stage, |
| 159 | + type=self.application.type, |
| 160 | + metadata=self.application.metadata |
| 161 | + ) |
| 162 | + |
| 163 | + def _call_analyze(self, result_dict): |
| 164 | + if "generated_text" not in result_dict: |
| 165 | + raise ValueError("Result of the wrapped function must contain 'generated_text'") |
| 166 | + if "context" not in result_dict: |
| 167 | + raise ValueError("Result of the wrapped function must contain 'context'") |
| 168 | + _context = result_dict['context'] if isinstance(result_dict['context'], list) else [result_dict['context']] |
| 169 | + aimon_payload = { |
| 170 | + "application_id": self._am_app.id, |
| 171 | + "version": self._am_app.version, |
| 172 | + "output": result_dict['generated_text'], |
| 173 | + "context_docs": _context, |
| 174 | + "user_query": result_dict["user_query"] if 'user_query' in result_dict else "No User Query Specified", |
| 175 | + "prompt": result_dict['prompt'] if 'prompt' in result_dict else "No Prompt Specified", |
| 176 | + } |
| 177 | + if 'instructions' in result_dict: |
| 178 | + aimon_payload['instructions'] = result_dict['instructions'] |
| 179 | + if 'actual_request_timestamp' in result_dict: |
| 180 | + aimon_payload["actual_request_timestamp"] = result_dict['actual_request_timestamp'] |
| 181 | + |
| 182 | + aimon_payload['config'] = self.config |
| 183 | + analyze_response = self.client.analyze.create(body=[aimon_payload]) |
| 184 | + return analyze_response |
| 185 | + |
23 | 186 |
|
24 | 187 | def __call__(self, func):
|
25 | 188 | @wraps(func)
|
@@ -50,7 +213,14 @@ def wrapper(*args, **kwargs):
|
50 | 213 |
|
51 | 214 | data_to_send = [aimon_payload]
|
52 | 215 |
|
53 |
| - aimon_response = self.client.inference.detect(body=data_to_send)[0] |
54 |
| - return result + (aimon_response,) |
| 216 | + if self.async_mode: |
| 217 | + analyze_res = self._call_analyze(result_dict) |
| 218 | + return result + (DetectResult(analyze_res.status, analyze_res),) |
| 219 | + else: |
| 220 | + detect_response = self.client.inference.detect(body=data_to_send)[0] |
| 221 | + if self.publish: |
| 222 | + analyze_res = self._call_analyze(result_dict) |
| 223 | + return result + (DetectResult(max(200 if detect_response is not None else 500, analyze_res.status), detect_response, analyze_res),) |
| 224 | + return result + (DetectResult(200 if detect_response is not None else 500, detect_response),) |
55 | 225 |
|
56 | 226 | return wrapper
|
0 commit comments