# ollama_talk_user.txt
import glob
import io
import logging
import os
import sys
import time

from aiohttp import web
from ollama import Client
from pydub import AudioSegment
from pydub.playback import play
from server import PromptServer

class OllamaTalk:
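    """ComfyUI node that chats with a local Ollama server.

    Keeps a running conversation context (in memory or on disk) and can
    pause mid-workflow until the user sends or edits a prompt from the
    front end via the HTTP routes defined below.
    """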
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"user_prompt": ("STRING", {"multiline": True}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"max_tokens": ("INT", {"default": 600, "min": 1, "max": 4096}),
"vram_retention_minutes": ("INT", {"default": 0, "min": 0, "max": 99}),
"answer_single_line": ("BOOLEAN", {"default": False}),
"waiting_for_prompt": ("BOOLEAN", {"default": False}),
"use_context_file": ("BOOLEAN", {"default": False}),
"use_context_file_as_user": ("BOOLEAN", {"default": False}),
# "context_size": ("INT", {"default": 0, "min": 0, "max": 1000}),
},
"optional": {
"OLLAMA_CONFIG": ("OLLAMA_CONFIG", {"forceInput": True}),
"context": ("STRING", {"multiline": True, "forceInput": True}),
"OLLAMA_JOB": ("OLLAMA_JOB", {
"forceInput": True
}),
}
}
RETURN_TYPES = ("STRING", "STRING", "STRING")
RETURN_NAMES = ("ollama_response", "updated_context", "system_prompt")
FUNCTION = "chat_response"
CATEGORY = "Bjornulf"
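    # Class-level state shared with the HTTP routes below, so the front end
    # can resume or interrupt a node that is blocked waiting for a prompt.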
is_paused = True
is_interrupted = False
current_instance = None
def __init__(self):
self.last_content_hash = None
self.waiting = False
self.OLLAMA_CONFIG = None
self.OLLAMA_JOB = None
self.context = ""
self.answer_single_line = True
self.vram_retention_minutes = 1
self.ollama_response = ""
self.widgets = {}
self.use_context_file = False
self.use_context_file_as_user = False
OllamaTalk.current_instance = self
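    # Ring a notification bell so the user knows the node is paused and
    # waiting for input; uses winsound on Windows, pydub playback elsewhere.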
def play_audio(self):
try:
if sys.platform.startswith('win'):
try:
audio_file = os.path.join(os.path.dirname(__file__), 'bell.m4a')
sound = AudioSegment.from_file(audio_file, format="m4a")
wav_io = io.BytesIO()
sound.export(wav_io, format='wav')
wav_data = wav_io.getvalue()
import winsound
winsound.PlaySound(wav_data, winsound.SND_MEMORY)
except Exception as e:
print(f"An error occurred: {e}")
else:
audio_file = os.path.join(os.path.dirname(__file__), 'bell.m4a')
sound = AudioSegment.from_file(audio_file, format="m4a")
play(sound)
except Exception:
pass # Silently handle exceptions, no console output
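    # ComfyUI calls IS_CHANGED to decide whether a node must re-run; NaN is
    # never equal to itself, so returning it forces re-execution on every
    # queue while the node is set to wait for a prompt.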
@classmethod
def IS_CHANGED(cls, waiting_for_prompt, **kwargs):
if waiting_for_prompt:
return float("nan")
return float(0)
def save_context(self, context):
# Save original context
original_path = os.path.join("Bjornulf", "ollama", "ollama_context.txt")
os.makedirs(os.path.dirname(original_path), exist_ok=True)
with open(original_path, "a", encoding="utf-8") as f:
f.write(context + "\n")
# Save swapped context
swapped_path = os.path.join("Bjornulf", "ollama", "ollama_context_user.txt")
os.makedirs(os.path.dirname(swapped_path), exist_ok=True)
# Swap User/Assistant in the context
swapped_context = context
if not os.path.exists(swapped_path):
# Add initial line only if file doesn't exist
swapped_context = "User: Let's start a conversation.\n" + swapped_context
swapped_context = swapped_context.replace("User:", "_TEMP_")
swapped_context = swapped_context.replace("Assistant:", "User:")
swapped_context = swapped_context.replace("_TEMP_", "Assistant:")
with open(swapped_path, "a", encoding="utf-8") as f:
f.write(swapped_context + "\n")
def load_context(self):
os_path = os.path.join("Bjornulf", "ollama", "ollama_context.txt")
if os.path.exists(os_path):
with open(os_path, "r", encoding="utf-8") as f:
return f.read().strip()
return ""
def load_context_user(self):
os_path = os.path.join("Bjornulf", "ollama", "ollama_context_user.txt")
if os.path.exists(os_path):
with open(os_path, "r", encoding="utf-8") as f:
return f.read().strip()
return ""
def process_ollama_request(self, user_prompt, answer_single_line, max_tokens, use_context_file=False):
if self.OLLAMA_CONFIG is None:
            self.OLLAMA_CONFIG = {
                "model": "llama3.2:3b",
                "url": "http://127.0.0.1:11434"  # default local Ollama endpoint
            }
selected_model = self.OLLAMA_CONFIG["model"]
ollama_url = self.OLLAMA_CONFIG["url"]
if self.OLLAMA_JOB is None:
            OLLAMA_JOB_text = "You are a helpful AI assistant."
else:
OLLAMA_JOB_text = self.OLLAMA_JOB["prompt"]
formatted_prompt = "User: " + user_prompt
if use_context_file:
if self.use_context_file_as_user:
file_context = self.load_context_user()
else:
file_context = self.load_context()
conversation = file_context + "\n" + formatted_prompt if file_context else formatted_prompt
else:
conversation = self.context + "\n" + formatted_prompt if self.context else formatted_prompt
keep_alive_minutes = self.vram_retention_minutes
try:
client = Client(host=ollama_url)
response = client.generate(
model=selected_model,
system=OLLAMA_JOB_text,
prompt=conversation,
            options={
                "num_predict": max_tokens  # cap the number of generated tokens
            },
keep_alive=f"{keep_alive_minutes}m"
)
result = response['response']
updated_context = conversation + "\nAssistant: " + result
self.context = updated_context
if use_context_file:
self.save_context(formatted_prompt + "\nAssistant: " + result)
if answer_single_line:
result = ' '.join(result.split())
self.ollama_response = result
return result, updated_context
except Exception as e:
logging.error(f"Connection to {ollama_url} failed: {e}")
return "Connection to Ollama failed.", self.context
def chat_response(self, user_prompt, seed, vram_retention_minutes, waiting_for_prompt=False,
context="", OLLAMA_CONFIG=None, OLLAMA_JOB=None, answer_single_line=False,
use_context_file=False, max_tokens=600, context_size=0, use_context_file_as_user=False):
# Store configurations
self.OLLAMA_CONFIG = OLLAMA_CONFIG
self.OLLAMA_JOB = OLLAMA_JOB
self.context = context
self.answer_single_line = answer_single_line
self.vram_retention_minutes = vram_retention_minutes
self.user_prompt = user_prompt
self.max_tokens = max_tokens
self.use_context_file = use_context_file
if waiting_for_prompt:
self.play_audio()
# Wait until either resumed or interrupted
while OllamaTalk.is_paused and not OllamaTalk.is_interrupted:
time.sleep(1)
# Check if we were interrupted
if OllamaTalk.is_interrupted:
OllamaTalk.is_paused = True
OllamaTalk.is_interrupted = False
return ("Interrupted", self.context, self.OLLAMA_JOB["prompt"] if self.OLLAMA_JOB else "")
OllamaTalk.is_paused = True
return (self.ollama_response, self.context, self.OLLAMA_JOB["prompt"] if self.OLLAMA_JOB else "")
else:
# Direct execution without waiting
result, updated_context = self.process_ollama_request(user_prompt, answer_single_line, max_tokens, use_context_file)
return (result, updated_context, OLLAMA_JOB["prompt"] if OLLAMA_JOB else "")
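# Route called by the front end to resume a paused node: runs the Ollama
# request with the (possibly edited) prompt, then clears the pause flag so
# chat_response can return its stored response.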
@PromptServer.instance.routes.post("/bjornulf_ollama_send_prompt")
async def resume_node(request):
if OllamaTalk.current_instance:
instance = OllamaTalk.current_instance
# Get the data from the request
data = await request.json()
updated_prompt = data.get('user_prompt')
# Use the updated_prompt directly if it's not None
prompt_to_use = updated_prompt if updated_prompt is not None else instance.user_prompt
result, updated_context = instance.process_ollama_request(
prompt_to_use,
instance.answer_single_line,
instance.max_tokens,
            use_context_file=instance.use_context_file  # honor the node's saved-context setting
)
OllamaTalk.is_paused = False
return web.Response(text="Node resumed")
return web.Response(text="No active instance", status=400)
@PromptServer.instance.routes.post("/get_current_context_size")
async def get_current_context_size(request):
counter_file = os.path.join("Bjornulf", "ollama", "ollama_context.txt")
try:
if not os.path.exists(counter_file):
logging.info("Context file does not exist")
return web.json_response({"success": True, "value": 0}, status=200)
with open(counter_file, 'r', encoding='utf-8') as f:
# Count non-empty lines in the file
lines = [line.strip() for line in f.readlines() if line.strip()]
line_count = len(lines)
logging.info(f"Found {line_count} lines in context file")
return web.json_response({"success": True, "value": line_count}, status=200)
except Exception as e:
logging.error(f"Error reading context size: {str(e)}")
return web.json_response({
"success": False,
"error": str(e),
"value": 0
}, status=500)
def get_next_filename(base_path, base_name):
"""
Find the next available filename with format base_name.XXX.txt
where XXX is a 3-digit number starting from 001
"""
pattern = os.path.join(base_path, f"{base_name}.[0-9][0-9][0-9].txt")
existing_files = glob.glob(pattern)
if not existing_files:
return f"{base_name}.001.txt"
# Extract numbers from existing files and find the highest
numbers = []
for f in existing_files:
try:
num = int(f.split('.')[-2])
numbers.append(num)
except (ValueError, IndexError):
continue
next_number = max(numbers) + 1 if numbers else 1
return f"{base_name}.{next_number:03d}.txt"
@PromptServer.instance.routes.post("/reset_lines_context")
async def reset_lines_context(request):
logging.info("Reset lines counter called")
base_dir = os.path.join("Bjornulf", "ollama")
base_file = "ollama_context"
counter_file = os.path.join(base_dir, f"{base_file}.txt")
try:
if os.path.exists(counter_file):
# Get new filename and rename
new_filename = os.path.join(base_dir, get_next_filename(base_dir, base_file))
os.rename(counter_file, new_filename)
logging.info(f"Renamed {counter_file} to {new_filename}")
# Send notification through ComfyUI
notification = {
"ui": {
"notification_text": [f"Context file renamed to: {os.path.basename(new_filename)}"]
}
}
return web.json_response({
"success": True,
**notification
}, status=200)
return web.json_response({
"success": True,
"ui": {
"notification_text": ["No context file to rename"]
}
}, status=200)
except Exception as e:
error_msg = str(e)
return web.json_response({
"success": False,
"error": error_msg,
"ui": {
"notification_text": [f"Error renaming file: {error_msg}"]
}
}, status=500)
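# Sets the interrupt flag so a node blocked in chat_response returns
# "Interrupted" instead of continuing to wait for a prompt.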
@PromptServer.instance.routes.post("/bjornulf_ollama_interrupt")
async def interrupt_node(request):
OllamaTalk.is_interrupted = True
return web.Response(text="Node interrupted")