Skip to content

Commit f320101

Browse files
committed
Fix handling of example CSV in LLM annotator
1 parent 2122510 commit f320101

File tree

1 file changed: +25 additions, -8 deletions

processors/machine_learning/annotate_text.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from common.lib.dmi_service_manager import DmiServiceManager, DmiServiceManagerException, DsmOutOfMemory
1212
from common.lib.exceptions import QueryParametersException
1313
from common.lib.user_input import UserInput
14-
from common.lib.helpers import sniff_encoding
14+
from common.lib.helpers import sniff_encoding, sniff_csv_dialect
1515
from common.config_manager import config
1616

1717
__author__ = "Stijn Peeters"
@@ -141,7 +141,6 @@ def process(self):
141141

142142
model = self.parameters.get("model")
143143
textfield = self.parameters.get("text-column")
144-
labels = {l.strip(): [] for l in self.parameters.get("categories").split(",") if l.strip()}
145144

146145
# Make output dir
147146
staging_area = self.dataset.get_staging_area()
@@ -172,6 +171,28 @@ def process(self):
172171
return self.dataset.finish_with_error(
173172
"Cannot connect to DMI Service Manager. Verify that this 4CAT server has access to it.")
174173

174+
if self.parameters["shotstyle"] == "fewshot":
175+
# do we have examples?
176+
example_path = self.dataset.get_results_path().with_suffix(".importing")
177+
if not example_path.exists():
178+
return self.dataset.finish_with_error("Cannot open example file")
179+
180+
labels = {}
181+
with example_path.open() as infile:
182+
dialect, has_header = sniff_csv_dialect(infile)
183+
reader = csv.reader(infile, dialect=dialect)
184+
for row in reader:
185+
if row[0] not in labels:
186+
labels[row[0]] = []
187+
labels[row[0]].append(row[1])
188+
189+
example_path.unlink()
190+
191+
else:
192+
# if we have no examples, just include an empty list
193+
labels = {l.strip(): [] for l in self.parameters.get("categories").split(",") if l.strip()}
194+
195+
175196
# store labels in a file (since we don't know how much data this is)
176197
labels_path = staging_area.joinpath("labels.temp.json")
177198
with labels_path.open("w") as outfile:
@@ -288,14 +309,10 @@ def validate_query(query, request, user):
288309

289310
# we want a very specific type of CSV file!
290311
encoding = sniff_encoding(file)
291-
292312
wrapped_file = io.TextIOWrapper(file, encoding=encoding)
293313
try:
294-
sample = wrapped_file.read(1024 * 1024)
295314
wrapped_file.seek(0)
296-
has_header = csv.Sniffer().has_header(sample)
297-
dialect = csv.Sniffer().sniff(sample, delimiters=(",", ";", "\t"))
298-
315+
dialect, has_header = sniff_csv_dialect(file)
299316
reader = csv.reader(wrapped_file, dialect=dialect) if not has_header else csv.DictReader(wrapped_file)
300317
row = next(reader)
301318
if len(list(row)) != 2:
@@ -326,7 +343,7 @@ def after_create(query, dataset, request):
326343
if query.get("shotstyle") != "fewshot":
327344
return
328345

329-
file = request.files["option-category_file"]
346+
file = request.files["option-category-file"]
330347
file.seek(0)
331348
with dataset.get_results_path().with_suffix(".importing").open("wb") as outfile:
332349
while True:

0 commit comments

Comments (0)