@@ -95,6 +95,94 @@ def __init__(self,
 # pylint: enable=too-few-public-methods


+def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
+    """A default set of length-bucket boundaries."""
+    assert length_bucket_step > 1.0
+    x = min_length
+    boundaries = []
+    while x < max_length:
+        boundaries.append(x)
+        x = max(x + 1, int(x * length_bucket_step))
+    return boundaries
+
+
+def get_batching_scheme(batch_size: int,
+                        max_length: int = None,
+                        min_length_bucket: int = 8,
+                        length_bucket_step: float = 1.1,
+                        drop_long_sequences: bool = False,
+                        shard_multiplier: int = 1,
+                        length_multiplier: int = 1,
+                        min_length: int = 0) -> BatchingScheme:
117+ """A batching scheme based on model hyperparameters.
118+ Every batch contains a number of sequences divisible by `shard_multiplier`.
119+ Args:
120+ batch_size: int, total number of tokens in a batch.
121+ max_length: int, sequences longer than this will be skipped. Defaults to
122+ batch_size.
123+ min_length_bucket: int
124+ length_bucket_step: float greater than 1.0
125+ drop_long_sequences: bool, if True, then sequences longer than
126+ `max_length` are dropped. This prevents generating batches with
127+ more than the usual number of tokens, which can cause out-of-memory
128+ errors.
129+ shard_multiplier: an integer increasing the batch_size to suit splitting
130+ across datashards.
131+ length_multiplier: an integer multiplier that is used to increase the
132+ batch sizes and sequence length tolerance.
133+ min_length: int, sequences shorter than this will be skipped.
134+ Returns:
135+ A dictionary with parameters that can be passed to input_pipeline:
136+ * boundaries: list of bucket boundaries
137+ * batch_sizes: list of batch sizes for each length bucket
138+ * max_length: int, maximum length of an example
139+ Raises:
140+ ValueError: If min_length > max_length
141+ """
+    max_length = max_length or batch_size
+    if max_length < min_length:
+        raise ValueError(
+            "max_length must be greater than or equal to min_length")
+
+    boundaries = _bucket_boundaries(max_length, min_length_bucket,
+                                    length_bucket_step)
+    boundaries = [boundary * length_multiplier for boundary in boundaries]
+    max_length *= length_multiplier
+
+    batch_sizes = [
+        max(1, batch_size // length) for length in boundaries + [max_length]
+    ]
+    max_batch_size = max(batch_sizes)
+    # Since the Datasets API only allows a single constant for window_size,
+    # and it needs to divide all bucket_batch_sizes, we pick a highly-composite
+    # window size and then round down all batch sizes to divisors of that
+    # window size, so that a window can always be divided evenly into batches.
+    # TODO(noam): remove this when Dataset API improves.
+    highly_composite_numbers = [
+        1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260,
+        1680, 2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360,
+        50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960,
+        554400, 665280, 720720, 1081080, 1441440, 2162160, 2882880, 3603600,
+        4324320, 6486480, 7207200, 8648640, 10810800, 14414400, 17297280,
+        21621600, 32432400, 36756720, 43243200, 61261200, 73513440, 110270160
+    ]
+    window_size = max(
+        [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
+    divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
+    batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
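+    # For instance (illustrative numbers): with max_batch_size == 384, the
+    # largest highly composite number not exceeding 3 * 384 is 840, and a
+    # bucket batch size of 384 is rounded down to 280, the largest divisor
+    # of 840 that fits.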
+    window_size *= shard_multiplier
+    batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
+    # The Datasets API splits one window into multiple batches, which
+    # produces runs of many consecutive batches of the same size. This
+    # is bad for training. To solve this, we will shuffle the batches
+    # using a queue which must be several times as large as the maximum
+    # number of batches per window.
+    max_batches_per_window = window_size // min(batch_sizes)
+    shuffle_queue_size = max_batches_per_window * 3
+
+    ret = BatchingScheme(bucket_boundaries=boundaries,
+                         bucket_batch_sizes=batch_sizes)
+    return ret
+
 # The protected functions below are designed to convert the ambiguous spec
 # structures to a normalized form.

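A minimal usage sketch (the concrete `batch_size` and `max_length` values are arbitrary; it assumes only the two `BatchingScheme` fields populated above):

    scheme = get_batching_scheme(batch_size=3072, max_length=50)
    # bucket_boundaries is an increasing list of sequence lengths, starting at
    # min_length_bucket (8 by default) and growing roughly geometrically;
    # bucket_batch_sizes has one entry per bucket plus one overflow bucket:
    assert len(scheme.bucket_batch_sizes) == len(scheme.bucket_boundaries) + 1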