Skip to content

Commit 1b054fb

Browse files
authored
Merge pull request #28 from iryna-kondr/hotfix_ml_classifier
Improved multi-label classifier compatibility with gpt4all
2 parents b2bedb1 + 08e4b6f commit 1b054fb

File tree

3 files changed, with 26 additions and 22 deletions.

skllm/models/gpt_zero_shot_clf.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,11 @@ def _get_prompt(self, x) -> str:
 98  98   def _predict_single(self, x):
 99  99   completion = self._get_chat_completion(x)
100 100   try:
101     - if self.openai_model.startswith("gpt4all::"):
102     - label = str(
103     - extract_json_key(
104     - completion["choices"][0]["message"]["content"], "label"
105     - )
106     - )
107     - else:
108     - label = str(
109     - extract_json_key(completion.choices[0].message["content"], "label")
    101 + label = str(
    102 + extract_json_key(
    103 + completion["choices"][0]["message"]["content"], "label"
110 104   )
    105 + )
111 106   except Exception as e:
112 107   print(completion)
113 108   print(f"Could not extract the label from the completion: {str(e)}")
@@ -155,10 +150,13 @@ def _get_prompt(self, x) -> str:
155 150   def _predict_single(self, x):
156 151   completion = self._get_chat_completion(x)
157 152   try:
158     - labels = extract_json_key(completion.choices[0].message["content"], "label")
    153 + labels = extract_json_key(completion["choices"][0]["message"]["content"], "label")
159 154   if not isinstance(labels, list):
160     - raise RuntimeError("Invalid labels type, expected list")
161     - except Exception:
    155 + labels = labels.split(",")
    156 + labels = [l.strip() for l in labels]
    157 + except Exception as e:
    158 + print(completion)
    159 + print(f"Could not extract the label from the completion: {str(e)}")
162 160   labels = []
163 161
164 162   labels = list(filter(lambda l: l in self.classes_, labels))

skllm/openai/chatgpt.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,18 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3
 34  34
 35  35
 36  36   def extract_json_key(json_, key):
 37     - try:
 38     - json_ = json_.replace("\n", "")
 39     - json_ = find_json_in_string(json_)
 40     - as_json = json.loads(json_)
 41     - if key not in as_json.keys():
 42     - raise KeyError("The required key was not found")
 43     - return as_json[key]
 44     - except Exception:
 45     - return None
     37 + original_json = json_
     38 + for i in range(2):
     39 + try:
     40 + json_ = original_json.replace("\n", "")
     41 + if i == 1:
     42 + json_ = json_.replace("'", "\"")
     43 + json_ = find_json_in_string(json_)
     44 + as_json = json.loads(json_)
     45 + if key not in as_json.keys():
     46 + raise KeyError("The required key was not found")
     47 + return as_json[key]
     48 + except Exception:
     49 + if i == 0:
     50 + continue
     51 + return None

skllm/prompts/templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
 44  44   Perform the following tasks:
 45  45   1. Identify to which categories the provided text belongs to with the highest probability.
 46  46   2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.
 47     - 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON.
     47 + 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON.
 48  48
 49  49   List of categories: {labels}
 50  50

0 commit comments

Comments
 (0)