-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathText Gen Webui in Langflow
214 lines (183 loc) · 12.4 KB
/
Text Gen Webui in Langflow
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import requests
import json
import logging
from langflow.custom import Component
from langflow.inputs import (
MessageTextInput,
StrInput,
SecretStrInput,
IntInput,
FloatInput,
BoolInput,
)
from langflow.template import Output
from langflow.schema.message import Message
from typing import Dict, Any, List, Optional
# Set up logging: INFO threshold with a "timestamp - level - message" layout,
# then grab a logger named after this module for all records emitted below.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class LocalModel(Component):
    """Langflow component for the oobabooga text-generation-webui API.

    oobabooga/text-generation-webui - A Gradio web UI for Large Language Models
    https://github.com/oobabooga/text-generation-webui

    The component POSTs a prompt plus the generation parameters configured in
    the Langflow UI to the configured endpoint and returns the generated text
    as a Langflow ``Message``. The API key is collected through a
    ``SecretStrInput`` and is masked before request headers are logged.
    """

    display_name = "text-generation-webui"  # Name displayed in the Langflow UI
    description = "Connects to oobabooga text-generation-webui"
    # oobabooga API documentation
    documentation = "https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API"
    icon = "prompts"  # Icon for the component in Langflow
    name = "text-generation-webui"  # Internal name

    # Inputs: the parameters the user sets in Langflow and that are sent to
    # text-generation-webui.
    inputs = [
        MessageTextInput(  # The prompt to send to the LLM.
            name="prompt", display_name="Prompt", value="What is the meaning of life?",
            info="The prompt text sent to the model."
        ),
        StrInput(  # The URL of the Oobabooga API endpoint.
            name="api_url", display_name="API URL", value="http://127.0.0.1:5000/v1/completions",
            info="API endpoint URL for text-generation-webui."
        ),
        SecretStrInput(  # The API key for authentication.
            name="api_key", display_name="API Key",
            info="Your text-generation-webui API key."
        ),
        # --- Generation Input Parameters ---
        IntInput(  # Maximum number of new tokens to generate
            name="max_tokens", display_name="Max Tokens", value=250, required=True,
            info="The maximum number of new tokens to generate.",
        ),
        FloatInput(  # Controls randomness (0.0 - deterministic, 1.0 - very random)
            name="temperature", display_name="Temperature", value=0.7, required=False,
            info="Control randomness (0.0-1.0 - 0=deterministic, 1=very random)."
        ),
        # BoolInput(name="temperature_last", display_name="Temperature Last", value=False, info="Apply temperature only to last token."),
        # BoolInput(name="dynamic_temperature", display_name="Dynamic Temperature", value=False, info="Enable dynamic temperature."),
        # FloatInput(name="dynatemp_low", display_name="Dynatemp Low", value=1.0, required=False, info="Lower bound for dynamic temperature."),
        # FloatInput(name="dynatemp_high", display_name="Dynatemp High", value=1.0, required=False, info="Upper bound for dynamic temperature."),
        # FloatInput(name="dynatemp_exponent", display_name="Dynatemp Exponent", value=1.0, required=False, info="Exponent for dynamic temperature."),
        # FloatInput(name="smoothing_factor", display_name="Smoothing Factor", value=0.0, required=False, info="Smoothing factor."),
        # FloatInput(name="smoothing_curve", display_name="Smoothing Curve", value=1.0, required=False, info="Smoothing curve."),
        FloatInput(name="top_p", display_name="Top P", value=0.37, required=False),  # Nucleus sampling parameter
        FloatInput(name="min_p", display_name="Min P", value=0.0, required=False, info="Minimum sampling probability."),
        IntInput(name="top_k", display_name="Top K", value=100, required=False, info="Sample from top k tokens."),
        FloatInput(name="repetition_penalty", display_name="Repetition Penalty", value=1.18, required=False, info="Penalty for repeating tokens."),
        # FloatInput(name="presence_penalty", display_name="Presence Penalty", value=0.0, required=False, info="Penalty for new tokens."),
        # FloatInput(name="frequency_penalty", display_name="Frequency Penalty", value=0.0, required=False, info="Penalty for frequent tokens."),
        # IntInput(name="repetition_penalty_range", display_name="Repetition Penalty Range", value=1024, required=False, info="Range for repetition penalty."),
        FloatInput(name="typical_p", display_name="Typical P", value=1.0, required=False, info="Typical sampling probability."),
        # FloatInput(name="tfs", display_name="TFS", value=1.0, required=False, info="Tail free sampling."),
        # FloatInput(name="top_a", display_name="Top A", value=0.0, required=False, info="Top-a sampling."),
        # FloatInput(name="guidance_scale", display_name="Guidance Scale", value=1.0, required=False, info="Guidance scale."),
        # FloatInput(name="penalty_alpha", display_name="Penalty Alpha", value=0.0, required=False, info="Penalty alpha."),
        # IntInput(name="mirostat_mode", display_name="Mirostat Mode", value=0, required=False, info="Mirostat mode."),
        # FloatInput(name="mirostat_tau", display_name="Mirostat Tau", value=5.0, required=False, info="Mirostat tau."),
        # FloatInput(name="mirostat_eta", display_name="Mirostat Eta", value=0.1, required=False, info="Mirostat eta."),
        # BoolInput(name="do_sample", display_name="Do Sample", value=True, info="Whether to sample."),
        # FloatInput(name="encoder_repetition_penalty", display_name="Encoder Repetition Penalty", value=1.0, required=False, info="Encoder repetition penalty."),
        # IntInput(name="no_repeat_ngram_size", display_name="No Repeat N-gram Size", value=0, required=False, info="Size of n-grams to avoid."),
        # FloatInput(name="dry_multiplier", display_name="Dry Multiplier", value=0.0, required=False, info="Dry multiplier."),  # Type corrected to float
        # FloatInput(name="dry_base", display_name="Dry Base", value=1.75, required=False, info="Dry base value."),  # Type corrected to float
        # IntInput(name="dry_allowed_length", display_name="Dry Allowed Length", value=2, required=False, info="Dry allowed length."),
        # StrInput(  # Input: Comma-separated string of dry sequence breakers
        #     name="dry_sequence_breakers", display_name="Dry Sequence Breakers",
        #     value='"\n", ":", "\\"", "*"',  # Set your defaults here. Comma-separated string.
        #     required=False, info='Comma-separated list of dry sequence breakers (e.g., "\\n", ":", "\\"", "*").',
        # ),
        StrInput(  # Comma-separated string of stop sequences. Defaults to an empty string.
            name="stop_sequences", display_name="Stop Sequences", value="", required=False,
            info="Comma-separated list of stop sequences.",
        ),
        # --- End of Generation Parameters ---
    ]

    # Output: the generated text from the model as a Langflow Message.
    outputs = [
        Output(
            display_name="Message Output", name="message_output", method="generate_text", description="The generated text."
        )
    ]

    def build_headers(self) -> Dict[str, str]:
        """Build the HTTP headers for the API request.

        Returns:
            A dict with a JSON ``Content-Type`` and, when ``self.api_key`` is
            truthy, an ``Authorization`` header in Bearer-token format.
        """
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def mask_sensitive_data(self, data: Dict[str, Any], sensitive_keys: List[str]) -> Dict[str, Any]:
        """Return a shallow copy of ``data`` with selected values masked for logging.

        Args:
            data: Mapping to sanitize; it is not modified.
            sensitive_keys: Top-level keys whose values are replaced with
                ``"***MASKED***"`` when present. Matching is exact and only
                applies to the top level.

        Returns:
            The sanitized copy.
        """
        masked_data = data.copy()  # Copy so the caller's dict keeps the real values
        for key in sensitive_keys:
            if key in masked_data:
                masked_data[key] = "***MASKED***"
        return masked_data

    @staticmethod
    def _parse_stop_sequences(raw: Any) -> List[str]:
        """Split a comma-separated stop-sequence string into a clean list.

        Surrounding whitespace is trimmed from each entry and empty entries
        are dropped. A falsy value (``None``, ``""``) yields an empty list.
        """
        if not raw:
            return []
        return [part.strip() for part in str(raw).split(",") if part.strip()]

    def build_payload(self) -> Dict[str, Any]:
        """Construct the JSON payload for the request to text-generation-webui.

        Returns:
            A dict holding the prompt and the generation parameters currently
            set on the component. A ``"stop"`` list is included only when at
            least one non-empty stop sequence was provided.
        """
        payload = {
            "prompt": self.prompt,
            # All parameters from GENERATE_PARAMS
            # Use what user provides in Langflow
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            # "temperature_last": self.temperature_last,
            # "dynamic_temperature": self.dynamic_temperature,
            # "dynatemp_low": self.dynatemp_low,
            # "dynatemp_high": self.dynatemp_high,
            # "dynatemp_exponent": self.dynatemp_exponent,
            # "smoothing_factor": self.smoothing_factor,
            # "smoothing_curve": self.smoothing_curve,
            "top_p": self.top_p,
            "min_p": self.min_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            # "presence_penalty": self.presence_penalty,
            # "frequency_penalty": self.frequency_penalty,
            # "repetition_penalty_range": self.repetition_penalty_range,
            "typical_p": self.typical_p,
            # "tfs": self.tfs,
            # "top_a": self.top_a,
            # "guidance_scale": self.guidance_scale,
            # "penalty_alpha": self.penalty_alpha,
            # "mirostat_mode": self.mirostat_mode,
            # "mirostat_tau": self.mirostat_tau,
            # "mirostat_eta": self.mirostat_eta,
            # "do_sample": self.do_sample,
            # "encoder_repetition_penalty": self.encoder_repetition_penalty,
            # "no_repeat_ngram_size": self.no_repeat_ngram_size,
            # "dry_multiplier": self.dry_multiplier,
            # "dry_base": self.dry_base,
            # "dry_allowed_length": self.dry_allowed_length,
            # "dry_sequence_breakers": self.dry_sequence_breakers,
        }
        # The input is documented as comma-separated, so send a list of
        # individual sequences rather than the raw string (which would be
        # interpreted as one single stop sequence containing commas).
        stop_sequences = self._parse_stop_sequences(self.stop_sequences)
        if stop_sequences:
            payload["stop"] = stop_sequences
        return payload

    def generate_text(self) -> Message:
        """Send the request to text-generation-webui and return the generated text.

        Returns:
            A ``Message`` whose text is the first choice's ``"text"`` field.
            If the response carries no usable text, or the request or response
            processing fails, the ``Message`` text describes the problem
            instead; no exception from those paths is propagated to the caller.
        """
        headers = self.build_headers()
        payload = self.build_payload()
        try:
            # Never log the real Authorization header.
            masked_headers = self.mask_sensitive_data(headers, ["Authorization"])
            logger.info("Sending request to: %s", self.api_url)
            logger.info("Request headers: %s", masked_headers)
            logger.info("Request payload: %s", payload)

            response = requests.post(self.api_url, json=payload, headers=headers, timeout=60)
            response.raise_for_status()  # Turn HTTP 4xx/5xx into RequestException
            response_json = response.json()
            if not isinstance(response_json, dict):
                raise ValueError(
                    f"Expected a JSON object in the response, got {type(response_json).__name__}."
                )

            sensitive_response_keys: List[str] = []  # Customize with keys to mask in the response if needed
            masked_response = self.mask_sensitive_data(response_json, sensitive_response_keys)
            logger.info("Model response: %s", masked_response)

            # Tolerate a missing, null, empty, or malformed "choices" entry
            # instead of indexing blindly into it.
            choices = response_json.get("choices")
            first_choice = choices[0] if isinstance(choices, list) and choices else {}
            generated_text = first_choice.get("text", "") if isinstance(first_choice, dict) else ""

            if not generated_text:
                warning_msg = "Model returned an empty or incomplete response."
                logger.warning(warning_msg)
                return Message(text=warning_msg)
            return Message(text=generated_text)
        except requests.exceptions.RequestException as e:  # Connection, timeout, HTTP status, bad JSON body
            logger.exception("Request Error: %s", e)
            return Message(text=f"Request Error: {e}")
        except (ValueError, KeyError, TypeError, AttributeError, IndexError) as e:  # Unexpected response shape
            logger.exception("Data Processing Error: %s", e)
            return Message(text=f"Data Processing Error: {e}")