     print(f"Warning: {e}. Moving ahead without these qaic modules.")


-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -56,6 +56,7 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TRAIN_CONFIG()
     update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config, kwargs)
     device = train_config.device

     # dist init
@@ -78,12 +79,30 @@ def main(**kwargs):
     # Load the pre-trained model and setup its configuration
     # config = AutoConfig.from_pretrained(train_config.model_name)
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )

     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
@@ -127,7 +146,6 @@ def main(**kwargs):
     model.print_trainable_parameters()

     # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
     dataset_processer = tokenizer

     # Load and preprocess the dataset for training and validation
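For context, the new seq_classification branch follows a common fine-tuning pattern: load a sequence-classification head on top of the pretrained backbone, freeze every backbone parameter (reached through base_model_prefix), and cast whatever is still trainable (the classification head) back to float32 so optimizer updates run in full precision while the frozen backbone stays in float16. A minimal standalone sketch of that pattern follows; the checkpoint name "distilbert-base-uncased" and num_labels=2 are illustrative placeholders, not part of this PR.

# Minimal sketch of the freeze-and-cast pattern introduced above.
# The checkpoint name and num_labels below are placeholder assumptions.
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",  # placeholder checkpoint
    num_labels=2,               # placeholder label count
    torch_dtype=torch.float16,
)

# Freeze the backbone; base_model_prefix names the attribute that holds it
# (e.g. "distilbert", "model", or "transformer", depending on the architecture).
for param in getattr(model, model.base_model_prefix).parameters():
    param.requires_grad = False

# Only the classification head remains trainable; keep it in float32.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")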