add support for the b-mc2/sql-create-context dataset for codegen
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Nov 22, 2023
1 parent a45f740 commit 0fb0359
Showing 2 changed files with 57 additions and 3 deletions.
29 changes: 29 additions & 0 deletions examples/language-modeling/README.md
@@ -455,6 +455,35 @@ LOWER_LIST=ops_bf16.txt python3 ../gaudi_spawn.py \
    --low_cpu_mem_usage True
```

- Multi-card finetuning of codegen-16B-mono:
```bash
python ../gaudi_spawn.py \
    --world_size 8 --use_mpi run_lora_clm.py \
    --model_name_or_path Salesforce/codegen-16B-mono \
    --dataset_name b-mc2/sql-create-context \
    --bf16 True \
    --output_dir ./finetuned-models/codegen-finetune-on-sql-create-context-hpu8-lora8-bs4 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 1e-4 \
    --logging_steps 1 \
    --dataset_concatenation \
    --do_train \
    --use_habana \
    --use_lazy_mode \
    --throughput_warmup_steps 3 \
    --use_hpu_graphs_for_inference \
    --lora_target_modules "qkv_proj" \
    --lora_rank 8 \
    --do_eval \
    --validation_split_percentage 10 \
    --use_cache False
```
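Each `b-mc2/sql-create-context` row has `question`, `context`, and `answer` columns; `run_lora_clm.py` renders the first two into the SQL prompt template added in this commit (see the `SQL_PROMPT` diff below). For a made-up illustrative row, the source prompt fed to the model would look like:
```
You are a text-to-SQL model. Your job is to answer questions about a database. You are given a question and a context regarding one or more tables in the database.

You must output the SQL query that answers the question. The SQL query must be between [SQL] and [/SQL] tags.

### Question:
How many users signed up in 2023?

### Context:
CREATE TABLE users (id INTEGER, signup_year INTEGER)

### Response:
```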

## Streaming

To use the streaming dataset mode, which can be very useful for large datasets, add `--streaming` with `--max_steps` specified in the command line. This is currently supported by `run_mlm.py` and `run_clm.py`.
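For instance, a minimal sketch of a single-card streaming run with `run_clm.py` (the model, dataset, step count, and Gaudi config here are placeholders; the remaining flags follow the other examples in this README):
```bash
python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --streaming \
    --max_steps 1000 \
    --do_train \
    --output_dir /tmp/clm-streaming \
    --gaudi_config_name Habana/gpt2 \
    --use_habana \
    --use_lazy_mode
```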
31 changes: 28 additions & 3 deletions examples/language-modeling/run_lora_clm.py
@@ -269,6 +269,13 @@ class FinetuneArguments:
    ),
}

+SQL_PROMPT = (
+    "You are a text-to-SQL model. Your job is to answer questions about a database. "
+    "You are given a question and a context regarding one or more tables in the database.\n\n"
+    "You must output the SQL query that answers the question. The SQL query must be between [SQL] and [/SQL] tags.\n\n"
+    "### Question: \n{question}\n\n### Context: \n{context}\n\n### Response:"
+)


def create_prompts(examples):
    prompts = {}
@@ -283,6 +290,16 @@ def create_prompts(examples):
prompts["target"].append(example["output"])
return prompts

+def create_sql_prompts(examples):
+    prompts = {}
+    prompts["source"] = []
+    prompts["target"] = []
+    for example in examples:
+        source = SQL_PROMPT.format_map(example)
+        prompts["source"].append(source)
+        prompts["target"].append(example["answer"])
+    return prompts
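As a quick usage sketch (the row below is fabricated for illustration), `create_sql_prompts` returns parallel `source`/`target` lists that the training loop later attaches as `prompt_sources` and `prompt_targets` columns:
```python
# Illustrative only: one made-up sql-create-context-style row.
example_rows = [
    {
        "question": "How many users signed up in 2023?",
        "context": "CREATE TABLE users (id INTEGER, signup_year INTEGER)",
        "answer": "SELECT COUNT(*) FROM users WHERE signup_year = 2023",
    }
]

prompts = create_sql_prompts(example_rows)
print(prompts["source"][0])  # SQL_PROMPT with {question} and {context} filled in
print(prompts["target"][0])  # the reference SQL query used as the training target
```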


def main():
    # See all possible arguments in src/transformers/training_args.py
@@ -438,10 +455,18 @@ def main():
            use_auth_token=True if model_args.use_auth_token else None,
            **dataset_args,
        )
-    if data_args.dataset_name == "tatsu-lab/alpaca":
+    for key in raw_datasets:
+        # If the alpaca or SQL dataset is passed in as a JSON file, detect it from its column names so it still works.
+        if sorted(list(raw_datasets[key].features.keys())) == sorted(['input', 'output', 'instruction']):
+            data_args.dataset_name = "tatsu-lab/alpaca"
+        if sorted(list(raw_datasets[key].features.keys())) == sorted(['question', 'context', 'answer']):
+            data_args.dataset_name = "b-mc2/sql-create-context"
+
+    if data_args.dataset_name in ["tatsu-lab/alpaca", "b-mc2/sql-create-context"]:
        # Preprocessing the datasets.
+        is_alpaca = data_args.dataset_name == "tatsu-lab/alpaca"
        for key in raw_datasets:
-            prompts = create_prompts(raw_datasets[key])
+            prompts = create_prompts(raw_datasets[key]) if is_alpaca else create_sql_prompts(raw_datasets[key])
            columns_to_be_removed = list(raw_datasets[key].features.keys())
            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
@@ -558,7 +583,7 @@ def concatenate_data(dataset, max_seq_length):
                concatenated_dataset[column] = reshaped_data
            return datasets.Dataset.from_dict(concatenated_dataset)

-        if data_args.dataset_name == "tatsu-lab/alpaca":
+        if data_args.dataset_name in ["tatsu-lab/alpaca", "b-mc2/sql-create-context"]:
            tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"])
            if training_args.do_eval:
                tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(
