@@ -277,21 +277,24 @@ def test_resume(self):
277277 lr_cfg = lr_cfg ,
278278 tokenizer_path = self .tokenizer_path ,
279279 global_batch_size = 2 ,
280- total_step = 10 ,
280+ total_step = 6 ,
281281 work_dir = str (self .work_dir ),
282282 hf_interval = 3 ,
283283 hf_max_keep = 2 ,
284284 seed = 42 ,
285285 debug = False ,
286- checkpoint_interval = 5 ,
286+ checkpoint_interval = 2 ,
287+ checkpoint_maxkeep = 2 ,
287288 )
288289
289290 trainer .fit ()
290291 dist .barrier ()
292+ # 0. Test checkpoint_maxkeep is consistent with meta file
293+ assert len (trainer .meta .latest_exp .checkpoint_list ) == 2
291294
292295 # Test resume
293296 # TODO: It's hard to test the accuracy of resuming in unit test now, need to improve
294- # 1. Test auto resume
297+ # 1. Test auto_resume
295298 resume_trainer1 = Trainer (
296299 load_from = str (self .fake_hf_model_dir ),
297300 model_cfg = model_cfg ,
@@ -308,14 +311,47 @@ def test_resume(self):
308311 hf_max_keep = 2 ,
309312 seed = 42 ,
310313 debug = False ,
311- checkpoint_interval = 3 ,
314+ checkpoint_interval = 2 ,
315+ checkpoint_maxkeep = 2 ,
312316 resume_cfg = ResumeConfig (
313317 auto_resume = True ,
314318 ),
315319 )
316- assert resume_trainer1 .cur_step == 10
320+ assert resume_trainer1 .cur_step == 6
317321 assert resume_trainer1 .exp_dir == trainer .exp_dir
322+ resume_trainer1 .fit ()
323+ dist .barrier ()
324+
325+ # 1.1 auto_resume twice
326+ resume_trainer1_2 = Trainer (
327+ load_from = str (self .fake_hf_model_dir ),
328+ model_cfg = model_cfg ,
329+ optim_cfg = optim_cfg ,
330+ fsdp_cfg = fsdp_cfg ,
331+ dataset_cfg = dataset_cfg ,
332+ dataloader_cfg = dataloader_cfg ,
333+ lr_cfg = lr_cfg ,
334+ tokenizer_path = self .tokenizer_path ,
335+ global_batch_size = 2 ,
336+ total_step = 16 ,
337+ work_dir = str (self .work_dir ),
338+ hf_interval = 3 ,
339+ hf_max_keep = 2 ,
340+ seed = 42 ,
341+ debug = False ,
342+ checkpoint_interval = 2 ,
343+ checkpoint_maxkeep = 2 ,
344+ resume_cfg = ResumeConfig (
345+ auto_resume = True ,
346+ ),
347+ )
348+ assert resume_trainer1_2 .cur_step == 10
349+ assert resume_trainer1_2 .exp_dir == trainer .exp_dir
350+ resume_trainer1_2 .fit ()
351+ assert resume_trainer1_2 .cur_step == 16
352+ dist .barrier ()
318353
354+ # 2. Test resume_from
319355 resume_trainer2 = Trainer (
320356 load_from = str (self .fake_hf_model_dir ),
321357 model_cfg = model_cfg ,
@@ -326,20 +362,22 @@ def test_resume(self):
326362 lr_cfg = lr_cfg ,
327363 tokenizer_path = self .tokenizer_path ,
328364 global_batch_size = 2 ,
329- total_step = 10 ,
365+ total_step = 20 ,
330366 work_dir = str (self .work_dir ),
331367 hf_interval = 3 ,
332368 hf_max_keep = 2 ,
333369 seed = 42 ,
334370 debug = False ,
335- checkpoint_interval = 3 ,
371+ checkpoint_interval = 5 ,
372+ checkpoint_maxkeep = 2 ,
336373 resume_cfg = ResumeConfig (
337- resume_from = trainer .meta .latest_exp .checkpoint_list [- 2 ],
374+ resume_from = resume_trainer1_2 .meta .latest_exp .checkpoint_list [- 2 ],
338375 ),
339376 )
340- assert resume_trainer2 .cur_step == 5
377+ assert resume_trainer2 .cur_step == 14
341378 resume_trainer2 .fit ()
342- assert resume_trainer2 .cur_step == 10
379+ assert resume_trainer2 .cur_step == 20
380+ assert resume_trainer2 .exp_dir != resume_trainer1_2 .exp_dir
343381
344382 @property
345383 def world_size (self ) -> int :