Skip to content

Commit 93cf438

Browse files
committed
Adding support of domains to config
1 parent 4d6c1d3 commit 93cf438

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

src/nanotron/config/config.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,12 @@ def __post_init__(self):
111111
class MultilingualNanosetDatasetsArgs:
112112
training_folder: Union[str, dict, List[str]]
113113
validation_folder: Union[str, List[str]]
114-
languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
114+
domains: Optional[List[str]] = None # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
115+
languages: Optional[List[str]] = None # NOTE(@paultltc): For back-compatibility
115116

116117
def __post_init__(self):
118+
if self.languages is not None and self.domains is None:
119+
self.domains = self.languages
117120
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
118121
self.training_folder = [self.training_folder]
119122
self.validation_folder = [self.validation_folder]
@@ -125,13 +128,13 @@ def __post_init__(self):
125128
self.training_folder = list(tmp_training_folder.keys())
126129
self.dataset_weights = list(tmp_training_folder.values())
127130

128-
assert len(self.training_folder) == len(
129-
self.languages
130-
), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
131+
# assert len(self.training_folder) == len(
132+
# self.domains
133+
# ), f"The sizes of training_folder and domains mismatch ({len(self.training_folder)} vs {len(self.domains)})"
131134

132-
assert len(self.training_folder) == len(
133-
self.validation_folder
134-
), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
135+
# assert len(self.training_folder) == len(
136+
# self.validation_folder
137+
# ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
135138

136139

137140
@dataclass
@@ -189,7 +192,6 @@ class GeneralArgs:
189192
190193
Args:
191194
project: Name of the project (a project gather several runs in common tensorboard/hub-folders)
192-
entity: Weights and bias entity name (optional)
193195
run: Name of the run
194196
step: Global step (updated when we save the checkpoint)
195197
consumed_train_samples: Number of samples consumed during training (should be actually just step*batch_size)

src/nanotron/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
679679
)
680680

681681
lang_losses = {
682-
lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages
682+
lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.domains
683683
}
684684
lang_losses_list = list(lang_losses.keys())
685685

0 commit comments

Comments
 (0)