@@ -101,12 +101,16 @@ class NanosetDatasetsArgs:
    dataset_folder: Union[str, dict, List[str]]

    def __post_init__(self):
-        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset file
+        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset folder
            self.dataset_folder = [self.dataset_folder]
            self.dataset_weights = [1]
-        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset file
-            self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.dataset_folder, dict):  # Case 3: dict with > 1 dataset_folder and weights
+        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset folder
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.dataset_folder, dict
+        ):  # Case 3: dict with > 1 dataset_folder and weights
            tmp_dataset_folder = self.dataset_folder.copy()
            self.dataset_folder = list(tmp_dataset_folder.keys())
            self.dataset_weights = list(tmp_dataset_folder.values())
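For reference, this standalone sketch mirrors what __post_init__ above does with the three accepted dataset_folder forms (a hypothetical helper with illustrative folder names, not part of the config module):

from typing import List, Optional, Tuple, Union

def normalize_dataset_folder(
    dataset_folder: Union[str, dict, List[str]]
) -> Tuple[List[str], Optional[list]]:
    # Case 1: a single folder -> one entry with weight 1
    if isinstance(dataset_folder, str):
        return [dataset_folder], [1]
    # Case 2: several folders without weights -> weights stay None (consume all samples randomly)
    if isinstance(dataset_folder, list):
        return list(dataset_folder), None
    # Case 3: {folder: weight} mapping -> split into parallel lists
    if isinstance(dataset_folder, dict):
        return list(dataset_folder.keys()), list(dataset_folder.values())
    raise TypeError(f"Unsupported dataset_folder type: {type(dataset_folder)}")

# normalize_dataset_folder({"data/web": 0.7, "data/code": 0.3})
# -> (["data/web", "data/code"], [0.7, 0.3])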
@@ -116,16 +120,55 @@ def __post_init__(self):
class MultilingualNanosetDatasetsArgs:
    training_folder: Union[str, dict, List[str]]
    validation_folder: Union[str, List[str]]
-    languages: List[str]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
+    languages: List[
+        str
+    ]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB

    def __post_init__(self):
        if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
            self.training_folder = [self.training_folder]
            self.validation_folder = [self.validation_folder]
            self.dataset_weights = [1]
        elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
-            self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.training_folder, dict):  # Case 3: dict with > 1 training_folder and weights
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.training_folder, dict
+        ):  # Case 3: dict with > 1 training_folder and weights
+            tmp_training_folder = self.training_folder.copy()
+            self.training_folder = list(tmp_training_folder.keys())
+            self.dataset_weights = list(tmp_training_folder.values())
+
+        assert len(self.training_folder) == len(
+            self.languages
+        ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
+
+        assert len(self.training_folder) == len(
+            self.validation_folder
+        ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
+
+
+@dataclass
+class MultilingualNanosetDatasetsArgs:
+    training_folder: Union[str, dict, List[str]]
+    validation_folder: Union[str, List[str]]
+    languages: List[
+        str
+    ]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
+
+    def __post_init__(self):
+        if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
+            self.training_folder = [self.training_folder]
+            self.validation_folder = [self.validation_folder]
+            self.dataset_weights = [1]
+        elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.training_folder, dict
+        ):  # Case 3: dict with > 1 training_folder and weights
            tmp_training_folder = self.training_folder.copy()
            self.training_folder = list(tmp_training_folder.keys())
            self.dataset_weights = list(tmp_training_folder.values())
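A rough usage sketch of the dict form handled by Case 3 and of the size checks added above (folder paths, language codes, and the commented import path are illustrative assumptions):

# from nanotron.config import MultilingualNanosetDatasetsArgs  # import path assumed

training_folder = {"data/train/en": 0.6, "data/train/de": 0.4}  # Case 3: folder -> weight
validation_folder = ["data/valid/en", "data/valid/de"]
languages = ["en", "de"]

# __post_init__ splits the dict into parallel lists and then enforces:
#   len(training_folder) == len(languages) == len(validation_folder)
folders, weights = list(training_folder.keys()), list(training_folder.values())
# folders == ["data/train/en", "data/train/de"], weights == [0.6, 0.4]
assert len(folders) == len(languages) == len(validation_folder)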
@@ -167,7 +210,9 @@ class DatasetStageArgs:

    def __post_init__(self):
        if self.start_training_step < 0:
-            raise ValueError(f"training_steps should be a positive integer and not {self.start_training_step}")
+            raise ValueError(
+                f"training_steps should be a positive integer and not {self.start_training_step}"
+            )


@dataclass
@@ -182,6 +227,7 @@ class CheckpointsArgs:
    checkpoints_path: Path
    checkpoint_interval: int
    save_initial_state: Optional[bool] = False
+    save_final_state: Optional[bool] = False
    resume_checkpoint_path: Optional[Path] = None
    checkpoints_path_is_shared_file_system: Optional[bool] = False

@@ -387,13 +433,19 @@ def __post_init__(self):
        if self.profiler is not None and self.profiler.profiler_export_path is not None:
            assert self.tokens.train_steps < 10

-        if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
+        if (
+            self.optimizer is not None
+            and self.optimizer.learning_rate_scheduler.lr_decay_steps is None
+        ):
            self.optimizer.learning_rate_scheduler.lr_decay_steps = (
-                self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
+                self.tokens.train_steps
+                - self.optimizer.learning_rate_scheduler.lr_warmup_steps
            )

        if self.data_stages is not None:
-            self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
+            self.data_stages = sorted(
+                self.data_stages, key=lambda stage: stage.start_training_step
+            )
            names = [stage.name for stage in self.data_stages]
            training_steps = [stage.start_training_step for stage in self.data_stages]
            assert any(
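The lr_decay_steps default set above is plain arithmetic; with illustrative numbers:

# If lr_decay_steps is left unset, it defaults to the steps remaining after warmup:
train_steps = 10_000
lr_warmup_steps = 1_000
lr_decay_steps = train_steps - lr_warmup_steps  # 9_000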
@@ -402,7 +454,9 @@ def __post_init__(self):

            for stage in self.data_stages:
                if names.count(stage.name) > 1:
-                    raise ValueError(f"Each stage should have unique names and not {names}")
+                    raise ValueError(
+                        f"Each stage should have unique names and not {names}"
+                    )

                if training_steps.count(stage.start_training_step) > 1:
                    raise ValueError(
@@ -411,13 +465,29 @@ def __post_init__(self):

            # NOTE: must order the stages by start_training_step from lowest to highest
            assert all(
-                self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
+                self.data_stages[i].start_training_step
+                < self.data_stages[i + 1].start_training_step
                for i in range(len(self.data_stages) - 1)
            ), "The stages are not sorted by start_training_step in increasing order"

        # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
        # must comply with val_check_interval % iteration_step_info_interval = 0
-        if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0:
+        if (
+            not self.tokens.val_check_interval
+            % self.logging.iteration_step_info_interval
+            == 0
+        ):
+            raise ValueError(
+                f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
+            )
+
+        # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
+        # must comply with val_check_interval % iteration_step_info_interval = 0
+        if (
+            not self.tokens.val_check_interval
+            % self.logging.iteration_step_info_interval
+            == 0
+        ):
            raise ValueError(
                f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
            )
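The constraint enforced here is that validation always lands on a logging step, i.e. val_check_interval must be a multiple of iteration_step_info_interval; with illustrative values:

# Illustrative values for the val_check_interval constraint:
val_check_interval = 500
iteration_step_info_interval = 100
assert val_check_interval % iteration_step_info_interval == 0  # OK: 500 is a multiple of 100
# val_check_interval = 500 with iteration_step_info_interval = 300 would raise,
# since validation would then fall on a non-logging step.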
@@ -428,7 +498,11 @@ def __post_init__(self):

    @property
    def global_batch_size(self):
-        return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
+        return (
+            self.tokens.micro_batch_size
+            * self.tokens.batch_accumulation_per_replica
+            * self.parallelism.dp
+        )

    def save_as_yaml(self, file_path: str):
        config_dict = serialize(self)
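The global_batch_size property is the product of the three factors above; with illustrative values:

# Illustrative values for the global_batch_size arithmetic:
micro_batch_size = 2
batch_accumulation_per_replica = 8
dp = 16  # number of data-parallel replicas
global_batch_size = micro_batch_size * batch_accumulation_per_replica * dp  # 256 sequences per step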
@@ -460,12 +534,18 @@ def get_config_from_dict(
    if skip_unused_config_keys:
        logger.warning("skip_unused_config_keys set")
        config_dict = {
-            field.name: config_dict[field.name] for field in fields(config_class) if field.name in config_dict
+            field.name: config_dict[field.name]
+            for field in fields(config_class)
+            if field.name in config_dict
        }
    if skip_null_keys:
        logger.warning("Skip_null_keys set")
        config_dict = {
-            k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v)
+            k: (
+                {kk: vv for kk, vv in v.items() if vv is not None}
+                if isinstance(v, dict)
+                else v
+            )
            for k, v in config_dict.items()
            if v is not None
        }
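To illustrate what the two filters above keep and drop, a hypothetical loaded config dict:

# Hypothetical loaded config dict, to show what skip_null_keys strips:
config_dict = {"optimizer": {"lr": 3e-4, "weight_decay": None}, "profiler": None}
cleaned = {
    k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v)
    for k, v in config_dict.items()
    if v is not None
}
# cleaned == {"optimizer": {"lr": 3e-4}}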