@@ -151,7 +151,7 @@ def run_epoch():
151
151
compute_loss_fct (XMB , YMB , MMB , clf_logits , lm_logits )
152
152
n_updates += 1
153
153
if n_updates in [1000 , 2000 , 4000 , 8000 , 16000 , 32000 ] and n_epochs == 0 :
154
- log ()
154
+ log (save_dir , desc )
155
155
156
156
157
157
argmax = lambda x : np .argmax (x , 1 )
@@ -194,7 +194,6 @@ def run_epoch():
194
194
parser .add_argument ('--clf_pdrop' , type = float , default = 0.1 )
195
195
parser .add_argument ('--l2' , type = float , default = 0.01 )
196
196
parser .add_argument ('--vector_l2' , action = 'store_true' )
197
- parser .add_argument ('--n_gpu' , type = int , default = 1 )
198
197
parser .add_argument ('--opt' , type = str , default = 'adam' )
199
198
parser .add_argument ('--afn' , type = str , default = 'gelu' )
200
199
parser .add_argument ('--lr_schedule' , type = str , default = 'warmup_linear' )
@@ -223,9 +222,11 @@ def run_epoch():
223
222
desc = args .desc
224
223
data_dir = args .data_dir
225
224
log_dir = args .log_dir
225
+ submission_dir = args .submission_dir
226
226
227
- # torch.device object used throughout this script TODO add gpu setting
228
227
device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
228
+ n_gpu = torch .cuda .device_count ()
229
+ print ("device" , device , "n_gpu" , n_gpu )
229
230
230
231
logger = ResultLogger (path = os .path .join (log_dir , '{}.jsonl' .format (desc )), ** args .__dict__ )
231
232
text_encoder = TextEncoder (args .encoder_path , args .bpe_path )
@@ -259,10 +260,11 @@ def run_epoch():
259
260
260
261
n_train = len (trY )
261
262
n_valid = len (vaY )
262
- n_batch_train = args .n_batch * args . n_gpu
263
+ n_batch_train = args.n_batch * max(n_gpu, 1)  # device_count() is 0 on CPU-only hosts; clamp so the batch size (and the division below) stays valid
263
264
n_updates_total = (n_train // n_batch_train ) * args .n_iter
264
265
265
- dh_model = DoubleHeadModel (args , clf_token , vocab , n_ctx )
266
+ # Build the model unconditionally: guarding construction on n_gpu > 1 left dh_model unbound on CPU / single-GPU hosts (NameError at dh_model.parameters() below).
267
+ dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
266
268
267
269
criterion = nn .CrossEntropyLoss (reduce = False )
268
270
model_opt = OpenAIAdam (dh_model .parameters (),
0 commit comments