Skip to content

Commit 85e95e2

Browse files
committed
wikipediaを間引く (thin out the Wikipedia data by subsampling)
1 parent d017728 commit 85e95e2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/train.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -113,9 +113,9 @@ def prepare_train_data(dataset_id):
113113
if "dataset_load_config" in train_config:
114114
dataset_load_config = train_config["dataset_load_config"]
115115
data = load_dataset(dataset_id, dataset_load_config, split="train", num_proc=32)
116-
if dataset_load_config == "20231101.ja" or dataset_load_config == "20231101.vi" or dataset_load_config == "20231101.es":
116+
if dataset_load_config == "20231101.ja" or dataset_load_config == "20231101.vi" or dataset_load_config == "20231101.es" or dataset_load_config == "20231101.it":
117117
data = data.filter(lambda item, idx: idx % 3 == 0, with_indices=True)
118-
if dataset_load_config == "20231101.de":
118+
if dataset_load_config == "20231101.de" or dataset_load_config == "20231101.fr":
119119
data = data.filter(lambda item, idx: idx % 5 == 0, with_indices=True)
120120
else:
121121
data = load_dataset(dataset_id, split="train", num_proc=32)
@@ -162,7 +162,8 @@ def prepare_train_data(dataset_id):
162162
lambda x: simple_template_for_train(x[input_field_name], x[output_field_name]),
163163
axis=1,
164164
)
165-
165+
# keep only text field
166+
data = data_df[["text"]]
166167
data = Dataset.from_pandas(data_df)
167168
data = data.train_test_split(seed=42, test_size=0.2)
168169
print(len(data["train"]))
@@ -281,7 +282,7 @@ def load_model_and_tokenizer(model_id):
281282
args=training_arguments,
282283
tokenizer=tokenizer,
283284
packing=False,
284-
max_seq_length=512,
285+
max_seq_length=1024,
285286
)
286287

287288
#

0 commit comments

Comments
 (0)