|
11 | 11 | from common.lib.dmi_service_manager import DmiServiceManager, DmiServiceManagerException, DsmOutOfMemory
|
12 | 12 | from common.lib.exceptions import QueryParametersException
|
13 | 13 | from common.lib.user_input import UserInput
|
14 |
| -from common.lib.helpers import sniff_encoding |
| 14 | +from common.lib.helpers import sniff_encoding, sniff_csv_dialect |
15 | 15 | from common.config_manager import config
|
16 | 16 |
|
17 | 17 | __author__ = "Stijn Peeters"
|
@@ -141,7 +141,6 @@ def process(self):
|
141 | 141 |
|
142 | 142 | model = self.parameters.get("model")
|
143 | 143 | textfield = self.parameters.get("text-column")
|
144 |
| - labels = {l.strip(): [] for l in self.parameters.get("categories").split(",") if l.strip()} |
145 | 144 |
|
146 | 145 | # Make output dir
|
147 | 146 | staging_area = self.dataset.get_staging_area()
|
@@ -172,6 +171,28 @@ def process(self):
|
172 | 171 | return self.dataset.finish_with_error(
|
173 | 172 | "Cannot connect to DMI Service Manager. Verify that this 4CAT server has access to it.")
|
174 | 173 |
|
| 174 | + if self.parameters["shotstyle"] == "fewshot": |
| 175 | + # do we have examples? |
| 176 | + example_path = self.dataset.get_results_path().with_suffix(".importing") |
| 177 | + if not example_path.exists(): |
| 178 | + return self.dataset.finish_with_error("Cannot open example file") |
| 179 | + |
| 180 | + labels = {} |
| 181 | + with example_path.open() as infile: |
| 182 | + dialect, has_header = sniff_csv_dialect(infile) |
| 183 | + reader = csv.reader(infile, dialect=dialect) |
| 184 | + for row in reader: |
| 185 | + if row[0] not in labels: |
| 186 | + labels[row[0]] = [] |
| 187 | + labels[row[0]].append(row[1]) |
| 188 | + |
| 189 | + example_path.unlink() |
| 190 | + |
| 191 | + else: |
| 192 | + # if we have no examples, just include an empty list |
| 193 | + labels = {l.strip(): [] for l in self.parameters.get("categories").split(",") if l.strip()} |
| 194 | + |
| 195 | + |
175 | 196 | # store labels in a file (since we don't know how much data this is)
|
176 | 197 | labels_path = staging_area.joinpath("labels.temp.json")
|
177 | 198 | with labels_path.open("w") as outfile:
|
@@ -288,14 +309,10 @@ def validate_query(query, request, user):
|
288 | 309 |
|
289 | 310 | # we want a very specific type of CSV file!
|
290 | 311 | encoding = sniff_encoding(file)
|
291 |
| - |
292 | 312 | wrapped_file = io.TextIOWrapper(file, encoding=encoding)
|
293 | 313 | try:
|
294 |
| - sample = wrapped_file.read(1024 * 1024) |
295 | 314 | wrapped_file.seek(0)
|
296 |
| - has_header = csv.Sniffer().has_header(sample) |
297 |
| - dialect = csv.Sniffer().sniff(sample, delimiters=(",", ";", "\t")) |
298 |
| - |
| 315 | + dialect, has_header = sniff_csv_dialect(file) |
299 | 316 | reader = csv.reader(wrapped_file, dialect=dialect) if not has_header else csv.DictReader(wrapped_file)
|
300 | 317 | row = next(reader)
|
301 | 318 | if len(list(row)) != 2:
|
@@ -326,7 +343,7 @@ def after_create(query, dataset, request):
|
326 | 343 | if query.get("shotstyle") != "fewshot":
|
327 | 344 | return
|
328 | 345 |
|
329 |
| - file = request.files["option-category_file"] |
| 346 | + file = request.files["option-category-file"] |
330 | 347 | file.seek(0)
|
331 | 348 | with dataset.get_results_path().with_suffix(".importing").open("wb") as outfile:
|
332 | 349 | while True:
|
|
0 commit comments