add sql prompt support, fix deepspeed zero3 save issue
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Dec 8, 2023
1 parent 3c8678d commit 773842b
Showing 3 changed files with 61 additions and 12 deletions.
33 changes: 32 additions & 1 deletion examples/language-modeling/README.md
@@ -446,9 +446,40 @@ python ../gaudi_spawn.py \
--dataset_concatenation \
--max_seq_length 512 \
--ddp_bucket_cap_mb 50 \
--adam_epsilon 1e-08
--adam_epsilon 1e-08 \
--low_cpu_mem_usage True
```

- Multi-card finetuning of codegen-16B-mono:
```bash
python ../gaudi_spawn.py \
--world_size 8 --use_mpi run_lora_clm.py \
--model_name_or_path Salesforce/codegen-16B-mono \
--dataset_name b-mc2/sql-create-context \
--sql_prompt \
--bf16 True \
--output_dir ./finetuned-models/codegen-finetune-on-sql-create-context-hpu8-lora8-bs4 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 1e-4 \
--logging_steps 1 \
--dataset_concatenation \
--do_train \
--use_habana \
--use_lazy_mode \
--throughput_warmup_steps 3 \
--use_hpu_graphs_for_inference \
--lora_target_modules "qkv_proj" \
--lora_rank 8 \
--do_eval \
--validation_split_percentage 10 \
    --use_cache False
```

- Multi-card finetuning of Falcon-40B:
```bash
LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
3 changes: 2 additions & 1 deletion examples/language-modeling/llama2_ds_zero3_config.json
@@ -10,6 +10,7 @@
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
"contiguous_gradients": false
"contiguous_gradients": false,
"stage3_gather_16bit_weights_on_model_save": true
}
}
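
The added `stage3_gather_16bit_weights_on_model_save` flag tells DeepSpeed to gather the ZeRO-3 parameter shards into the full 16-bit weights whenever the model is saved; without it, `trainer.save_model()` (see the script change below) would only see per-rank shards. A purely illustrative sketch, assuming the config file sits in the current directory:

```python
import json

# Illustrative sanity check: a ZeRO stage-3 config should gather the 16-bit
# weights on save, otherwise the saved checkpoint is incomplete.
with open("llama2_ds_zero3_config.json") as f:
    ds_config = json.load(f)

zero = ds_config.get("zero_optimization", {})
if zero.get("stage") == 3 and not zero.get("stage3_gather_16bit_weights_on_model_save", False):
    print(
        "ZeRO-3 without 'stage3_gather_16bit_weights_on_model_save': true -- "
        "trainer.save_model() would write per-rank shards instead of the full weights."
    )
```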
37 changes: 27 additions & 10 deletions examples/language-modeling/run_lora_clm.py
Expand Up @@ -42,7 +42,6 @@
DataCollatorForLanguageModeling,
HfArgumentParser,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
@@ -230,7 +229,10 @@ class DataArguments:
        default=False,
        metadata={"help": "Whether to concatenate the sentence for more efficient training."},
    )

    sql_prompt: bool = field(
        default=False,
        metadata={"help": "Whether to use the SQL-style prompt template."},
    )

@dataclass
class FinetuneArguments:
@@ -273,6 +275,13 @@ class FinetuneArguments:
),
}

SQL_PROMPT = (
"You are a text-to-SQL model. Your job is to answer questions about a database. "
"You are given a question and a context regarding one or more tables in the database.\n\n"
"You must output the SQL query that answers the question. The SQL query must be between [SQL] and [/SQL] tags.\n\n"
"### Question: \n{question}\n\n### Context: \n{context}\n\n### Response:"
)


def create_prompts(examples):
prompts = {}
@@ -288,6 +297,17 @@ def create_prompts(examples):
return prompts


def create_sql_prompts(examples):
    prompts = {}
    prompts["source"] = []
    prompts["target"] = []
    for example in examples:
        source = SQL_PROMPT.format_map(example)
        prompts["source"].append(source)
        prompts["target"].append(example["answer"])
    return prompts
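
As an illustration, here is a standalone sketch of what this prompt construction produces; the record is a made-up example in the `question`/`context`/`answer` shape that `create_sql_prompts` expects:

```python
# Standalone sketch: build one source/target pair the way create_sql_prompts does.
SQL_PROMPT = (
    "You are a text-to-SQL model. Your job is to answer questions about a database. "
    "You are given a question and a context regarding one or more tables in the database.\n\n"
    "You must output the SQL query that answers the question. The SQL query must be between [SQL] and [/SQL] tags.\n\n"
    "### Question: \n{question}\n\n### Context: \n{context}\n\n### Response:"
)

# Made-up record with the fields the function reads (question, context, answer).
example = {
    "question": "How many heads of the departments are older than 56?",
    "context": "CREATE TABLE head (age INTEGER)",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}

source = SQL_PROMPT.format_map(example)  # goes into the "prompt_sources" column
target = example["answer"]               # goes into the "prompt_targets" column
print(source)
print(target)
```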


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@@ -444,10 +464,11 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
)
if data_args.dataset_name == "tatsu-lab/alpaca":

if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
# Preprocessing the datasets.
for key in raw_datasets:
prompts = create_prompts(raw_datasets[key])
prompts = create_prompts(raw_datasets[key]) if not data_args.sql_prompt else create_sql_prompts(raw_datasets[key])
columns_to_be_removed = list(raw_datasets[key].features.keys())
raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
@@ -565,7 +586,7 @@ def concatenate_data(dataset, max_seq_length):
concatenated_dataset[column] = reshaped_data
return datasets.Dataset.from_dict(concatenated_dataset)

if data_args.dataset_name == "tatsu-lab/alpaca":
if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"])
if training_args.do_eval:
tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(
@@ -651,11 +672,7 @@ def compute_metrics(eval_preds):

if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

with training_args.main_process_first(desc="save model"):
if is_main_process(training_args.local_rank):
unwrapped_model = unwrap_model(lora_model)
unwrapped_model.save_pretrained(training_args.output_dir, state_dict=unwrapped_model.state_dict())
trainer.save_model()

metrics = train_result.metrics
trainer.log_metrics("train", metrics)
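
Replacing the manual unwrap-and-save block with `trainer.save_model()` is the ZeRO-3 save fix from the commit title. A hedged sketch of the reasoning, with the behavioral notes (standard Trainer/DeepSpeed integration) kept as comments:

```python
# Under ZeRO-3 each rank holds only a shard of the parameters, so the old path
# (unwrap_model(lora_model).state_dict() on the main process) could serialize
# empty or partial tensors.
# trainer.save_model() routes the save through the Trainer's DeepSpeed
# integration, which, with "stage3_gather_16bit_weights_on_model_save": true in
# the config above, gathers the full 16-bit weights before writing them.
if training_args.do_train:
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()  # correct for both non-sharded and ZeRO-3 sharded runs
```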
