Skip to content

Commit 7cbb460

Browse files
committed
Add dataset_output_field_values_to_texts option
1 parent 62598bb commit 7cbb460

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

recipes/A5000_24GB_x8/fake-news-detector-en-1.5T.yaml

+8 -5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ output_base_dir: /data/output
55
dataset_id: mrm8488/fake-news
66
dataset_input_field_name: text
77
dataset_output_field_name: label
8+
dataset_output_field_values_to_texts:
9+
0: "Real"
10+
1: "Fake"
811
dataset_train_split_seed: 42
912
dataset_train_split_test_size: 0.2
1013
lora_r: 8
@@ -20,16 +23,16 @@ inference_max_new_tokens: 2
2023
evaluations:
2124
-
2225
prompt: "Donald Trump has never been President of the United States."
23-
expected_output: "1"
26+
expected_output: "Fake"
2427
-
2528
prompt: "The Earth is flat."
26-
expected_output: "1"
29+
expected_output: "Fake"
2730
-
2831
prompt: "Martians visited Japan in 2011."
29-
expected_output: "1"
32+
expected_output: "Fake"
3033
-
3134
prompt: "The World Trade Center collapsed when the plane hit it."
32-
expected_output: "0"
35+
expected_output: "Real"
3336
-
34-
expected_output: "0"
3537
prompt: "The United States is a country in North America."
38+
expected_output: "Real"

recipes/A5000_24GB_x8/fake-news-detector-en.yaml

+8 -5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ output_base_dir: /data/output
55
dataset_id: mrm8488/fake-news
66
dataset_input_field_name: text
77
dataset_output_field_name: label
8+
dataset_output_field_values_to_texts:
9+
0: "Real"
10+
1: "Fake"
811
dataset_train_split_seed: 42
912
dataset_train_split_test_size: 0.2
1013
lora_r: 8
@@ -20,16 +23,16 @@ inference_max_new_tokens: 2
2023
evaluations:
2124
-
2225
prompt: "Donald Trump has never been President of the United States."
23-
expected_output: "1"
26+
expected_output: "Fake"
2427
-
2528
prompt: "The Earth is flat."
26-
expected_output: "1"
29+
expected_output: "Fake"
2730
-
2831
prompt: "Martians visited Japan in 2011."
29-
expected_output: "1"
32+
expected_output: "Fake"
3033
-
3134
prompt: "The World Trade Center collapsed when the plane hit it."
32-
expected_output: "0"
35+
expected_output: "Real"
3336
-
34-
expected_output: "0"
3537
prompt: "The United States is a country in North America."
38+
expected_output: "Real"

src/train.py

+5
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def prepare_train_data(dataset_id):
125125
data_df["text"] = data_df[input_field_name].apply(lambda x: simple_template_for_pretrain(x))
126126
else:
127127
output_field_name = train_config["dataset_output_field_name"]
128+
if "dataset_output_field_values_to_texts" in train_config:
129+
output_field_values_to_texts = train_config["dataset_output_field_values_to_texts"]
130+
data_df[output_field_name] = data_df[output_field_name].apply(
131+
lambda x: output_field_values_to_texts.get(x, x)
132+
)
128133
if "dataset_context_field_name" in train_config:
129134
context_field_name = train_config["dataset_context_field_name"]
130135
if "dataset_context_hint" not in train_config:

0 commit comments

Comments
 (0)