@@ -66,68 +66,38 @@ def _build_input_queue(
       global_batch_size: int,
       num_batches: Optional[int] = None,
       repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
-    not_train = split != 'train'
-    per_device_batch_size = int(global_batch_size / N_GPUS)
-
-    seq_len = self._seq_len  # TODO: define it somewhere else?
-    dtype = torch.int32  # TODO: decide between int32 and int64.
-
-    # Only create and iterate over the tf input pipeline in one Python process
-    # to avoid creating too many threads.
-    if RANK == 0:
-      np_iter = super()._build_input_queue(
-          data_rng=data_rng,
-          split=split,
-          data_dir=data_dir,
-          global_batch_size=global_batch_size,
-          num_batches=num_batches,
-          repeat_final_dataset=repeat_final_dataset)
+    """Build an input queue for the given split."""
+    from algoperf.workloads.lm.input_pipeline import get_hf_dataloader
+
+    loader = get_hf_dataloader(
+        cache_dir=data_dir,
+        data_rng=data_rng,
+        batch_size=global_batch_size,
+        seq_len=self._seq_len,
+        framework="torch",
+        split=split)
+    seq_len = self._seq_len
 
     weights = None
-
-    while True:
-      # Only iterate over the tf input pipeline in one Python process to
-      # avoid creating too many threads.
-      if RANK == 0:
-        batch = next(np_iter)  # pylint: disable=stop-iteration-return
-        inputs = torch.as_tensor(
-            batch['inputs'], dtype=dtype,
-            device=DEVICE)  # (N_GPUS, global_batch_size, seq_len)
-        targets = torch.as_tensor(
-            batch['targets'], dtype=dtype,
-            device=DEVICE)  # (N_GPUS, global_batch_size, seq_len)
-
-        # Send the batch to the other devices when using DDP.
-        if USE_PYTORCH_DDP:
-          if not_train:
-            # During eval, the batch size of the remainder might be different.
-            per_device_batch_size = torch.tensor(
-                len(targets[0]), dtype=dtype, device=DEVICE)
-            dist.broadcast(per_device_batch_size, src=0)
-          # We don't broadcast the shard for RANK 0.
-          dist.broadcast(inputs[1:], src=0)
-          dist.broadcast(targets[1:], src=0)
-
-        # RANK 0 extracts its shard. If not DDP, this just flattens.
-        inputs, targets = inputs[0], targets[0]
-
-      else:
-        # Receive the batch from rank 0.
-        if not_train:
-          # During eval, the batch size of the remainder might be different.
-          per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE)
+
+    dtype = torch.long
+    is_train = split == 'train'
+
+    for batch in loader:
+      inputs, targets = batch
+
+      if USE_PYTORCH_DDP:
+        if not is_train:
+          # During eval, the batch size of the remainder might be different.
+          per_device_batch_size = torch.tensor(
+              len(targets[0]), dtype=dtype, device=DEVICE)
           dist.broadcast(per_device_batch_size, src=0)
-
-        # N_GPUS - 1 since we don't broadcast the shard for RANK 0.
-        inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
-                             dtype=dtype,
-                             device=DEVICE)
-        targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
-                              dtype=dtype,
-                              device=DEVICE)
+
+        # Broadcast the batch to all devices.
         dist.broadcast(inputs, src=0)
         dist.broadcast(targets, src=0)
-        # RANK - 1 since we don't broadcast the shard for RANK 0.
-        inputs, targets = inputs[RANK - 1], targets[RANK - 1]
+
+      if weights is None:
+        weights = torch.ones(inputs.shape[0], device=DEVICE)
 
       if weights is None:
         weights = torch.ones(per_device_batch_size, device=DEVICE)
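In the new loop, every rank iterates the HuggingFace dataloader and, under DDP, the tensors are broadcast so that all devices see rank 0's data. A minimal sketch of that broadcast contract, assuming the usual `torch.distributed` process group is already initialized (the helper name and shapes below are illustrative, not part of the commit):

    import torch
    import torch.distributed as dist

    def sync_batch_with_rank0(inputs: torch.Tensor,
                              targets: torch.Tensor) -> None:
      # Every rank must call broadcast with identically-shaped tensors.
      # Rank 0 sends its data; the other ranks overwrite theirs in place,
      # so afterwards every device holds rank 0's copy of the batch.
      dist.broadcast(inputs, src=0)
      dist.broadcast(targets, src=0)

Unlike the deleted pipeline, this hunk does not shard: each device ends up with the full `(global_batch_size, seq_len)` batch rather than a `per_device_batch_size` slice.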
@@ -138,10 +108,51 @@ def _build_input_queue(
       }
       yield batch
 
+  def is_output_params(self, param_name: str) -> bool:
+    """Return whether the given parameter is an output parameter."""
+    return 'output.weight' in param_name or 'output.bias' in param_name
+
   def _eval_batch(self,
                   params: spec.ParameterContainer,
                   batch: Dict[str, spec.Tensor],
                   model_state: spec.ModelAuxiliaryState,
                   rng: spec.RandomState) -> spec.Tensor:
     """Evaluate the model on a single batch."""
-    pass
+    model = params
+    logits, _ = self.model_fn(
+        model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
+    targets = batch['targets']
+
+    # Cross-entropy loss, summed over the batch (assumes one-hot targets).
+    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+    loss = -torch.sum(targets * log_probs)
+    return loss
+
+  def loss_fn(
+      self,
+      label_batch: spec.Tensor,
+      logits_batch: spec.Tensor,
+      mask_batch: Optional[spec.Tensor] = None,
+      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:
+    """Compute cross-entropy loss for language modeling in PyTorch."""
+
+    if len(label_batch.shape) == len(logits_batch.shape):
+      # One-hot labels.
+      log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1)
+      loss = -torch.sum(label_batch * log_probs, dim=-1)
+    else:
+      # Dense (token-id) labels.
+      loss = torch.nn.functional.cross_entropy(
+          logits_batch,
+          label_batch,
+          reduction='none')
+
+    if mask_batch is not None:
+      loss = loss * mask_batch
+
+    n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0]
+    return {
+        'summed': loss.sum(),
+        'n_valid_examples': n_valid,
+        'per_example': loss,
+    }
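Callers typically reduce the dictionary returned by `loss_fn` to a scalar mean. One caveat on the dense-label branch: `torch.nn.functional.cross_entropy` expects the class dimension at position 1, so `(batch, seq_len, vocab)` logits must be flattened first. A hedged sketch of both, with illustrative helper names that are not part of the commit:

    import torch
    import torch.nn.functional as F

    def mean_loss(loss_dict):
      # Reduce the 'summed' / 'n_valid_examples' pair to a scalar mean.
      return loss_dict['summed'] / loss_dict['n_valid_examples']

    def per_token_cross_entropy(logits, token_ids):
      # Flatten (batch, seq_len, vocab) -> (batch * seq_len, vocab) so that
      # cross_entropy sees the class dimension where it expects it, then
      # restore the (batch, seq_len) layout of the per-token losses.
      vocab_size = logits.shape[-1]
      loss = F.cross_entropy(
          logits.reshape(-1, vocab_size),
          token_ids.reshape(-1),
          reduction='none')
      return loss.reshape(token_ids.shape)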