Skip to content

Commit 561c4e8

Browse files
committed
update train.py
1 parent 2c6d894 commit 561c4e8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@
296296
wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
297297
total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
298298
if total_step <= wanted_step:
299+
if num_train // Unfreeze_batch_size == 0:
300+
raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
299301
wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
300302
print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
301303
print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))

0 commit comments

Comments
 (0)