Skip to content

Commit 3f96a84

Browse files
authored
[UnitTest] Add a unit test for resuming twice (#1216)
* [UnitTest] Add a unit test for resuming twice
1 parent 6835e9b commit 3f96a84

File tree

1 file changed

+48
-10
lines changed

1 file changed

+48
-10
lines changed

tests/train/test_trainer.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,21 +277,24 @@ def test_resume(self):
277277
lr_cfg=lr_cfg,
278278
tokenizer_path=self.tokenizer_path,
279279
global_batch_size=2,
280-
total_step=10,
280+
total_step=6,
281281
work_dir=str(self.work_dir),
282282
hf_interval=3,
283283
hf_max_keep=2,
284284
seed=42,
285285
debug=False,
286-
checkpoint_interval=5,
286+
checkpoint_interval=2,
287+
checkpoint_maxkeep=2,
287288
)
288289

289290
trainer.fit()
290291
dist.barrier()
292+
# 0. Check that the number of checkpoints recorded in the meta file matches checkpoint_maxkeep
293+
assert len(trainer.meta.latest_exp.checkpoint_list) == 2
291294

292295
# Test resume
293296
# TODO: It is currently hard to verify the accuracy of resuming in a unit test; this needs to be improved.
294-
# 1. Test auto resume
297+
# 1. Test auto_resume
295298
resume_trainer1 = Trainer(
296299
load_from=str(self.fake_hf_model_dir),
297300
model_cfg=model_cfg,
@@ -308,14 +311,47 @@ def test_resume(self):
308311
hf_max_keep=2,
309312
seed=42,
310313
debug=False,
311-
checkpoint_interval=3,
314+
checkpoint_interval=2,
315+
checkpoint_maxkeep=2,
312316
resume_cfg=ResumeConfig(
313317
auto_resume=True,
314318
),
315319
)
316-
assert resume_trainer1.cur_step == 10
320+
assert resume_trainer1.cur_step == 6
317321
assert resume_trainer1.exp_dir == trainer.exp_dir
322+
resume_trainer1.fit()
323+
dist.barrier()
324+
325+
# 1.1 auto_resume twice
326+
resume_trainer1_2 = Trainer(
327+
load_from=str(self.fake_hf_model_dir),
328+
model_cfg=model_cfg,
329+
optim_cfg=optim_cfg,
330+
fsdp_cfg=fsdp_cfg,
331+
dataset_cfg=dataset_cfg,
332+
dataloader_cfg=dataloader_cfg,
333+
lr_cfg=lr_cfg,
334+
tokenizer_path=self.tokenizer_path,
335+
global_batch_size=2,
336+
total_step=16,
337+
work_dir=str(self.work_dir),
338+
hf_interval=3,
339+
hf_max_keep=2,
340+
seed=42,
341+
debug=False,
342+
checkpoint_interval=2,
343+
checkpoint_maxkeep=2,
344+
resume_cfg=ResumeConfig(
345+
auto_resume=True,
346+
),
347+
)
348+
assert resume_trainer1_2.cur_step == 10
349+
assert resume_trainer1_2.exp_dir == trainer.exp_dir
350+
resume_trainer1_2.fit()
351+
assert resume_trainer1_2.cur_step == 16
352+
dist.barrier()
318353

354+
# 2. Test resume_from
319355
resume_trainer2 = Trainer(
320356
load_from=str(self.fake_hf_model_dir),
321357
model_cfg=model_cfg,
@@ -326,20 +362,22 @@ def test_resume(self):
326362
lr_cfg=lr_cfg,
327363
tokenizer_path=self.tokenizer_path,
328364
global_batch_size=2,
329-
total_step=10,
365+
total_step=20,
330366
work_dir=str(self.work_dir),
331367
hf_interval=3,
332368
hf_max_keep=2,
333369
seed=42,
334370
debug=False,
335-
checkpoint_interval=3,
371+
checkpoint_interval=5,
372+
checkpoint_maxkeep=2,
336373
resume_cfg=ResumeConfig(
337-
resume_from=trainer.meta.latest_exp.checkpoint_list[-2],
374+
resume_from=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
338375
),
339376
)
340-
assert resume_trainer2.cur_step == 5
377+
assert resume_trainer2.cur_step == 14
341378
resume_trainer2.fit()
342-
assert resume_trainer2.cur_step == 10
379+
assert resume_trainer2.cur_step == 20
380+
assert resume_trainer2.exp_dir != resume_trainer1_2.exp_dir
343381

344382
@property
345383
def world_size(self) -> int:

0 commit comments

Comments
 (0)