Commit 0507c7f

Merge remote-tracking branch 'origin/main' into moe
2 parents 3bce1f4 + 4d6c1d3

16 files changed: +380 −96 lines

docs/nanoset.md (+1 −1)

````diff
@@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument:
 
 Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
 ```shell
-torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
+torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
 ```
 
 ## Under the hood
````

src/nanotron/config/config.py (+97 −17)

```diff
@@ -101,12 +101,16 @@ class NanosetDatasetsArgs:
     dataset_folder: Union[str, dict, List[str]]
 
     def __post_init__(self):
-        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset file
+        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset folder
             self.dataset_folder = [self.dataset_folder]
             self.dataset_weights = [1]
-        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset file
-            self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.dataset_folder, dict):  # Case 3: dict with > 1 dataset_folder and weights
+        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset folder
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.dataset_folder, dict
+        ):  # Case 3: dict with > 1 dataset_folder and weights
             tmp_dataset_folder = self.dataset_folder.copy()
             self.dataset_folder = list(tmp_dataset_folder.keys())
             self.dataset_weights = list(tmp_dataset_folder.values())
```

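Aside from the comment fix ("Dataset file" → "Dataset folder"), this hunk is a line-length re-wrap. For reference, the three accepted shapes of `dataset_folder` normalize to a folder list plus optional weights; a minimal standalone sketch of that normalization (not the nanotron class itself; folder names are invented):

```python
from typing import Dict, List, Optional, Tuple, Union

def normalize_dataset_folder(
    dataset_folder: Union[str, Dict[str, float], List[str]]
) -> Tuple[List[str], Optional[list]]:
    """Mirror of the __post_init__ logic shown in the diff above."""
    if isinstance(dataset_folder, str):  # Case 1: a single folder
        return [dataset_folder], [1]
    if isinstance(dataset_folder, list):  # Case 2: several folders, no weights
        return dataset_folder, None  # None -> consume all samples randomly
    if isinstance(dataset_folder, dict):  # Case 3: {folder: weight}
        return list(dataset_folder.keys()), list(dataset_folder.values())
    raise TypeError(f"Unsupported dataset_folder type: {type(dataset_folder)}")

# Invented paths, for illustration only.
print(normalize_dataset_folder("datasets/fineweb"))                      # (['datasets/fineweb'], [1])
print(normalize_dataset_folder({"datasets/a": 0.8, "datasets/b": 0.2}))  # (['datasets/a', 'datasets/b'], [0.8, 0.2])
```
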
```diff
@@ -116,16 +120,55 @@ def __post_init__(self):
 class MultilingualNanosetDatasetsArgs:
     training_folder: Union[str, dict, List[str]]
     validation_folder: Union[str, List[str]]
-    languages: List[str]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
+    languages: List[
+        str
+    ]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
 
     def __post_init__(self):
         if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
             self.training_folder = [self.training_folder]
             self.validation_folder = [self.validation_folder]
             self.dataset_weights = [1]
         elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
-            self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.training_folder, dict):  # Case 3: dict with > 1 training_folder and weights
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.training_folder, dict
+        ):  # Case 3: dict with > 1 training_folder and weights
+            tmp_training_folder = self.training_folder.copy()
+            self.training_folder = list(tmp_training_folder.keys())
+            self.dataset_weights = list(tmp_training_folder.values())
+
+        assert len(self.training_folder) == len(
+            self.languages
+        ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
+
+        assert len(self.training_folder) == len(
+            self.validation_folder
+        ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
+
+
+@dataclass
+class MultilingualNanosetDatasetsArgs:
+    training_folder: Union[str, dict, List[str]]
+    validation_folder: Union[str, List[str]]
+    languages: List[
+        str
+    ]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
+
+    def __post_init__(self):
+        if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
+            self.training_folder = [self.training_folder]
+            self.validation_folder = [self.validation_folder]
+            self.dataset_weights = [1]
+        elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.training_folder, dict
+        ):  # Case 3: dict with > 1 training_folder and weights
             tmp_training_folder = self.training_folder.copy()
             self.training_folder = list(tmp_training_folder.keys())
             self.dataset_weights = list(tmp_training_folder.values())
```

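The asserts added above require one language label and one validation folder per training folder. A small illustration of the dict form and that alignment (paths, weights, and language codes are invented):

```python
# Hypothetical multilingual dataset arguments; not taken from the repo.
training_folder = {"data/train_en": 0.7, "data/train_fr": 0.3}  # Case 3: folder -> weight
validation_folder = ["data/valid_en", "data/valid_fr"]
languages = ["en", "fr"]

folders = list(training_folder.keys())
weights = list(training_folder.values())
assert len(folders) == len(languages)          # one language label per training folder
assert len(folders) == len(validation_folder)  # one validation folder per training folder
print(folders, weights)  # ['data/train_en', 'data/train_fr'] [0.7, 0.3]
```
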
```diff
@@ -167,7 +210,9 @@ class DatasetStageArgs:
 
     def __post_init__(self):
         if self.start_training_step < 0:
-            raise ValueError(f"training_steps should be a positive integer and not {self.start_training_step}")
+            raise ValueError(
+                f"training_steps should be a positive integer and not {self.start_training_step}"
+            )
 
 
 @dataclass
```

```diff
@@ -182,6 +227,7 @@ class CheckpointsArgs:
     checkpoints_path: Path
     checkpoint_interval: int
     save_initial_state: Optional[bool] = False
+    save_final_state: Optional[bool] = False
     resume_checkpoint_path: Optional[Path] = None
     checkpoints_path_is_shared_file_system: Optional[bool] = False
 
```
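For context, a checkpoints section using the new flag might look like the sketch below; the values are invented, and the exact semantics of `save_final_state` (presumably writing one last checkpoint when training finishes) are not spelled out in this diff.

```python
# Hypothetical checkpoints configuration, written as a Python dict for illustration.
checkpoints = {
    "checkpoints_path": "checkpoints/moe-run",  # made-up path
    "checkpoint_interval": 1000,
    "save_initial_state": False,
    "save_final_state": True,  # new field introduced by this commit
    "resume_checkpoint_path": None,
    "checkpoints_path_is_shared_file_system": False,
}
```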

```diff
@@ -387,13 +433,19 @@ def __post_init__(self):
         if self.profiler is not None and self.profiler.profiler_export_path is not None:
             assert self.tokens.train_steps < 10
 
-        if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
+        if (
+            self.optimizer is not None
+            and self.optimizer.learning_rate_scheduler.lr_decay_steps is None
+        ):
             self.optimizer.learning_rate_scheduler.lr_decay_steps = (
-                self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
+                self.tokens.train_steps
+                - self.optimizer.learning_rate_scheduler.lr_warmup_steps
             )
 
         if self.data_stages is not None:
-            self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
+            self.data_stages = sorted(
+                self.data_stages, key=lambda stage: stage.start_training_step
+            )
             names = [stage.name for stage in self.data_stages]
             training_steps = [stage.start_training_step for stage in self.data_stages]
             assert any(
```

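A quick illustration of the `lr_decay_steps` defaulting above (numbers invented):

```python
# If lr_decay_steps is not set, it defaults to the steps remaining after warmup.
train_steps, lr_warmup_steps = 1000, 100
lr_decay_steps = train_steps - lr_warmup_steps  # 900 decay steps
```
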
```diff
@@ -402,7 +454,9 @@ def __post_init__(self):
 
             for stage in self.data_stages:
                 if names.count(stage.name) > 1:
-                    raise ValueError(f"Each stage should have unique names and not {names}")
+                    raise ValueError(
+                        f"Each stage should have unique names and not {names}"
+                    )
 
                 if training_steps.count(stage.start_training_step) > 1:
                     raise ValueError(
```

```diff
@@ -411,13 +465,29 @@ def __post_init__(self):
 
         # NOTE: must order the stages by start_training_step from lowest to highest
         assert all(
-            self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
+            self.data_stages[i].start_training_step
+            < self.data_stages[i + 1].start_training_step
             for i in range(len(self.data_stages) - 1)
         ), "The stages are not sorted by start_training_step in increasing order"
 
         # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
         # must comply with val_check_interval % iteration_step_info_interval = 0
-        if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0:
+        if (
+            not self.tokens.val_check_interval
+            % self.logging.iteration_step_info_interval
+            == 0
+        ):
+            raise ValueError(
+                f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
+            )
+
+        # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
+        # must comply with val_check_interval % iteration_step_info_interval = 0
+        if (
+            not self.tokens.val_check_interval
+            % self.logging.iteration_step_info_interval
+            == 0
+        ):
             raise ValueError(
                 f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
             )
```

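The modulo check above (added twice by this merge) just requires every validation step to also be a logging step; for instance (invented values):

```python
val_check_interval = 100
iteration_step_info_interval = 20
# Passes: validation always lands on a logging step.
assert val_check_interval % iteration_step_info_interval == 0
# A pair like (100, 30) would raise the ValueError shown above.
```
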
```diff
@@ -428,7 +498,11 @@ def __post_init__(self):
 
     @property
    def global_batch_size(self):
-        return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
+        return (
+            self.tokens.micro_batch_size
+            * self.tokens.batch_accumulation_per_replica
+            * self.parallelism.dp
+        )
 
     def save_as_yaml(self, file_path: str):
         config_dict = serialize(self)
```

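The property above is the usual global-batch-size product; a worked example with invented values:

```python
micro_batch_size = 4                 # sequences per micro-batch
batch_accumulation_per_replica = 8   # gradient-accumulation steps
dp = 16                              # data-parallel replicas
global_batch_size = micro_batch_size * batch_accumulation_per_replica * dp  # 512 sequences per step
```
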
```diff
@@ -460,12 +534,18 @@ def get_config_from_dict(
     if skip_unused_config_keys:
         logger.warning("skip_unused_config_keys set")
         config_dict = {
-            field.name: config_dict[field.name] for field in fields(config_class) if field.name in config_dict
+            field.name: config_dict[field.name]
+            for field in fields(config_class)
+            if field.name in config_dict
         }
     if skip_null_keys:
         logger.warning("Skip_null_keys set")
         config_dict = {
-            k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v)
+            k: (
+                {kk: vv for kk, vv in v.items() if vv is not None}
+                if isinstance(v, dict)
+                else v
+            )
             for k, v in config_dict.items()
             if v is not None
         }
```

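The `skip_null_keys` branch above drops `None` values one level deep; a tiny illustration with an invented dict:

```python
config_dict = {"model": {"hidden_size": 512, "rope_theta": None}, "unused": None}
config_dict = {
    k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v)
    for k, v in config_dict.items()
    if v is not None
}
print(config_dict)  # {'model': {'hidden_size': 512}}
```
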

src/nanotron/config/lighteval_config.py (+1)

```diff
@@ -51,6 +51,7 @@ def __post_init__(self):
 class LightEvalTasksArgs:
     """Arguments related to tasks for LightEval"""
 
+    langs: Optional[str] = None
     tasks: Optional[str] = None
     custom_tasks: Optional[str] = None
     max_samples: Optional[int] = None
```


src/nanotron/config/parallelism_config.py (+2)

```diff
@@ -34,6 +34,8 @@ class ParallelismArgs:
     tp_linear_async_communication: Optional[bool] = None
     recompute_layer: bool = False
 
+    tp_recompute_allgather: bool = True
+
     expert_parallel_size: int = 1
 
     def __post_init__(self):
```

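A parallelism section exercising the new flag might look like the sketch below; the field names `dp`/`tp`/`pp` exist in `ParallelismArgs`, but the values are invented and the exact effect of `tp_recompute_allgather` (recomputing the tensor-parallel all-gather in backward to save memory) is an assumption, not stated by this diff.

```python
# Hypothetical parallelism configuration, as a Python dict for illustration.
parallelism = {
    "dp": 2,   # data parallel
    "tp": 4,   # tensor parallel
    "pp": 1,   # pipeline parallel
    "tp_linear_async_communication": True,
    "tp_recompute_allgather": True,  # new flag, defaults to True per the diff above
}
```
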

src/nanotron/data/multilingual_nanoset.py (+27 −9)

```diff
@@ -38,7 +38,9 @@ def __init__(
 
         # Checks
         if isinstance(dataset_folders, str):
-            warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
+            warnings.warn(
+                "dataset_folders should be of type List[str] but str was provided. Converting to List[str]"
+            )
             dataset_folders = [dataset_folders]
 
         # Init
```

```diff
@@ -63,7 +65,9 @@ def __init__(
 
         # Build Nanoset Index
         ## To build the index we need the length of each dataset
-        self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
+        self.dataset_lengths = [
+            len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets
+        ]
         ## Set dataset weights
         if (
             dataset_weights is None
```

```diff
@@ -76,10 +80,14 @@ def __init__(
         ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
         ## Build dataset index and dataset sample index
         if is_valid:  # Valid MultilingualNanoset
-            self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths)
+            self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(
+                self.dataset_lengths
+            )
 
         else:  # Train MultilingualNanoset
-            self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index()
+            self.dataset_index, self.dataset_sample_index = (
+                self.build_train_nanoset_index()
+            )
 
         self.print_nanoset_info()
 
```

```diff
@@ -129,7 +137,9 @@ def build_train_nanoset_index(self) -> np.ndarray:
         numpy_random_state.shuffle(dataset_sample_index)
         # Concatenate num_epochs the shuffled indexes
         dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
-        dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
+        dataset_sample_index = np.concatenate(
+            [dataset_sample_index for _ in range(num_epochs)]
+        )
         # Just keep the necessary samples
         dataset_index = dataset_index[: self.train_split_num_samples]
         dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]
```

```diff
@@ -152,7 +162,9 @@ def print_nanoset_info(self):
         )
 
         # Print samples from each dataset + weight
-        dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
+        dataset_sample_count = count_dataset_indexes(
+            self.dataset_index, len(self.dataset_folders)
+        )
         for index, sample_count in enumerate(dataset_sample_count):
             log_rank(
                 f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
```

```diff
@@ -174,7 +186,9 @@ def build_train_nanoset_index_helper(
     """
     # Create empty arrays for dataset indices and dataset sample indices
     dataset_index = np.empty((n_samples,), dtype="uint")
-    dataset_sample_index = np.empty((n_samples,), dtype="long")  # Supports dataset with up to 2**64 samples
+    dataset_sample_index = np.empty(
+        (n_samples,), dtype="long"
+    )  # Supports dataset with up to 2**64 samples
 
     # Initialize buffer for number of samples used for each dataset
     current_samples = np.zeros((len(weights),), dtype="long")
```

```diff
@@ -191,7 +205,9 @@ def build_train_nanoset_index_helper(
 
         # Assign the dataset index and update the sample index
         dataset_index[sample_idx] = max_error_index
-        dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
+        dataset_sample_index[sample_idx] = (
+            current_samples[max_error_index] % dataset_sizes[max_error_index]
+        )
 
         # Update the total samples for the selected dataset
         current_samples[max_error_index] += 1
```

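The helper these two hunks touch builds the training index by greedy weighted interleaving: at each step it emits a sample from the dataset whose realized share lags its target weight the most, wrapping small datasets with the modulo above. The sketch below captures that idea under those assumptions; it is not the exact nanotron implementation.

```python
import numpy as np

def weighted_interleave(weights, dataset_sizes, n_samples):
    """Greedy weighted interleaving sketch (illustrative, not nanotron's code)."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    dataset_index = np.empty((n_samples,), dtype=np.uint64)
    dataset_sample_index = np.empty((n_samples,), dtype=np.int64)
    current_samples = np.zeros((len(weights),), dtype=np.int64)
    for sample_idx in range(n_samples):
        # How far each dataset lags behind its target share so far.
        errors = weights * (sample_idx + 1) - current_samples
        max_error_index = int(np.argmax(errors))
        dataset_index[sample_idx] = max_error_index
        # Wrap around so small datasets are revisited (multiple epochs).
        dataset_sample_index[sample_idx] = (
            current_samples[max_error_index] % dataset_sizes[max_error_index]
        )
        current_samples[max_error_index] += 1
    return dataset_index, dataset_sample_index

idx, _ = weighted_interleave([0.75, 0.25], [1000, 50], 8)
print(idx)  # [0 0 1 0 0 0 1 0] -> roughly a 3:1 interleave
```
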
```diff
@@ -211,4 +227,6 @@ def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray:
         dataset_index.extend([i] * length)
         dataset_sample_index.extend(range(length))
 
-    return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long")
+    return np.array(dataset_index, dtype="uint"), np.array(
+        dataset_sample_index, dtype="long"
+    )
```

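For the validation split, the index above simply concatenates every dataset sequentially; for example (invented lengths):

```python
import numpy as np

dataset_lengths = [3, 2]  # invented dataset sizes
dataset_index, dataset_sample_index = [], []
for i, length in enumerate(dataset_lengths):
    dataset_index.extend([i] * length)
    dataset_sample_index.extend(range(length))
print(np.array(dataset_index), np.array(dataset_sample_index))  # [0 0 0 1 1] [0 1 2 0 1]
```
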
