ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided.
#693
Open
industrialeaf opened this issue on Sep 6, 2024 · 1 comment
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction (minimal, reproducible, runnable)
Error message:
```
Traceback (most recent call last):
  File "/home/ubuntu/projects/seq2seq/train_t5_small.py", line 48, in <module>
    trainer = Seq2SeqNeuronTrainer(
              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/py311/lib/python3.11/site-packages/optimum/neuron/trainers.py", line 144, in __init__
    raise ValueError(
ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided.
```
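The error implies that `Seq2SeqNeuronTrainingArguments` is not a subclass of `NeuronTrainingArguments`. A minimal sketch of how that can happen, assuming (not verified here) that both classes combine a shared Neuron mixin with different `transformers` base classes. All classes below are empty, hypothetical stand-ins, not the real optimum-neuron definitions:

```python
# Hypothetical stand-ins that mirror the ASSUMED optimum-neuron hierarchy.
class TrainingArguments: ...
class Seq2SeqTrainingArguments(TrainingArguments): ...
class NeuronTrainingArgumentsMixin: ...

# Both Neuron argument classes share the mixin but not each other.
class NeuronTrainingArguments(NeuronTrainingArgumentsMixin, TrainingArguments): ...
class Seq2SeqNeuronTrainingArguments(NeuronTrainingArgumentsMixin, Seq2SeqTrainingArguments): ...

args = Seq2SeqNeuronTrainingArguments()
print(isinstance(args, NeuronTrainingArguments))       # False: a sibling, not a subclass
print(isinstance(args, NeuronTrainingArgumentsMixin))  # True: the shared mixin
```

Under this assumption, a strict `isinstance(self.args, NeuronTrainingArguments)` check rejects the seq2seq arguments even though they carry the same Neuron-specific fields.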
Minimal example to reproduce:
Run the following script with `torchrun train.py`.
```python
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from optimum.neuron import Seq2SeqNeuronTrainer, Seq2SeqNeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism

# Load dataset
dataset = load_dataset("samsum")

# Load tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-small")


# Preprocess the data
def preprocess_function(examples):
    # Prepend the T5 task prefix to every dialogue in the batch
    inputs = ["summarize: " + doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    # Tokenize the reference summaries as targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=150, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]

    # Debug output: all columns should have one entry per example in the batch
    print("keys", model_inputs.keys())
    print("len labels", len(model_inputs["labels"]))
    print("len inpids", len(model_inputs["input_ids"]))
    print("len attmsk", len(model_inputs["attention_mask"]))
    return model_inputs


tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Define training arguments
training_args = Seq2SeqNeuronTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=False,  # should be False since we don't provide a generation_config
)

# Load model
with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Initialize the trainer (this is where the ValueError is raised)
trainer = Seq2SeqNeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
)

# Train the model
trainer.train()
```
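With `batched=True`, `dataset.map` passes `preprocess_function` a dict of lists, so the task prefix has to be applied to each element. A standalone sketch of just that step; `add_task_prefix` and the toy batch are hypothetical illustrations, not part of the original script:

```python
def add_task_prefix(examples, prefix="summarize: "):
    """Prepend the T5 task prefix to every dialogue in a batched example dict.

    With batched=True, examples["dialogue"] is a list of strings (one per row).
    """
    return [prefix + doc for doc in examples["dialogue"]]


# Hypothetical toy batch shaped like what dataset.map(..., batched=True) passes in
batch = {"dialogue": ["A: hi", "B: bye"], "summary": ["greeting", "farewell"]}
print(add_task_prefix(batch))
```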
Expected behavior
The NeuronTrainer should accept Seq2SeqNeuronTrainingArguments.
As a workaround, I patched the type check in `optimum/neuron/trainers.py` so that it also accepts Seq2SeqNeuronTrainingArguments:
```python
if not isinstance(self.args, NeuronTrainingArguments) and not isinstance(self.args, Seq2SeqNeuronTrainingArguments):
    raise ValueError(
        f"The NeuronTrainer only accepts NeuronTrainingArguments and Seq2SeqNeuronTrainingArguments, but {type(self.args)} was provided."
    )
```
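The same check can be written more compactly by passing a tuple of classes to `isinstance`, which returns True if the object is an instance of any of them. A sketch with empty stand-in classes (hypothetical, not the real optimum-neuron types) and a hypothetical `check_args` helper, showing that the patched logic accepts both argument classes and rejects anything else:

```python
# Hypothetical stand-ins for the optimum-neuron argument classes
class NeuronTrainingArguments: ...
class Seq2SeqNeuronTrainingArguments: ...

# isinstance accepts a tuple: True if the object matches any listed class
ACCEPTED_ARGS = (NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments)


def check_args(args):
    """Raise ValueError unless args is one of the accepted argument classes."""
    if not isinstance(args, ACCEPTED_ARGS):
        raise ValueError(
            "The NeuronTrainer only accepts NeuronTrainingArguments and "
            f"Seq2SeqNeuronTrainingArguments, but {type(args)} was provided."
        )


check_args(NeuronTrainingArguments())         # passes
check_args(Seq2SeqNeuronTrainingArguments())  # passes
```

Behavior is identical to the two chained `not isinstance(...)` calls; the tuple form just keeps the accepted classes in one place.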
Who can help?
@michaelbenayoun