@@ -163,24 +163,51 @@ NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::AbstractA
     # but allows us to write a generalized kernel supporting partial reductions.
     R′ = reshape(R, (size(R)..., 1))
 
-    # determine how many threads we can launch
+    # how many threads do we want?
+    #
+    # threads in a block work together to reduce values across the reduction dimensions;
+    # we want as many as possible to improve algorithm efficiency and execution occupancy.
+    dev = device()
+    wanted_threads = shuffle ? nextwarp(dev, length(Rreduce)) : nextpow(2, length(Rreduce))
+    function compute_threads(max_threads)
+        if wanted_threads > max_threads
+            shuffle ? prevwarp(dev, max_threads) : prevpow(2, max_threads)
+        else
+            wanted_threads
+        end
+    end
+
+    # how many threads can we launch?
+    #
+    # we might not be able to launch all those threads to reduce each slice in one go.
+    # that's why each thread also loops across its inputs, processing multiple values
+    # so that we can span the entire reduction dimension using a single thread block.
     args = (f, op, init, Rreduce, Rother, Val(shuffle), R′, As...)
     kernel_args = cudaconvert.(args)
     kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
     kernel = cufunction(partial_mapreduce_grid, kernel_tt)
-    kernel_config =
-        launch_configuration(kernel.fun; shmem = shuffle ? 0 : threads->2*threads*sizeof(T))
+    compute_shmem(threads) = shuffle ? 0 : 2*threads*sizeof(T)
+    kernel_config = launch_configuration(kernel.fun; shmem=compute_shmem∘compute_threads)
+    reduce_threads = compute_threads(kernel_config.threads)
+    reduce_shmem = compute_shmem(reduce_threads)
+
+    # how many blocks should we launch?
+    #
+    # even though we can always reduce each slice in a single thread block, that may not be
+    # optimal as it might not saturate the GPU. we already launch some blocks to process
+    # independent dimensions in parallel; pad that number to ensure full occupancy.
+    other_blocks = length(Rother)
+    reduce_blocks = if other_blocks >= kernel_config.blocks
+        1
+    else
+        min(cld(length(Rreduce), reduce_threads),       # how many we need at most
+            cld(kernel_config.blocks, other_blocks))    # maximize occupancy
+    end
 
     # determine the launch configuration
-    dev = device()
-    reduce_threads = shuffle ? nextwarp(dev, length(Rreduce)) : nextpow(2, length(Rreduce))
-    if reduce_threads > kernel_config.threads
-        reduce_threads = shuffle ? prevwarp(dev, kernel_config.threads) : prevpow(2, kernel_config.threads)
-    end
-    reduce_blocks = min(reduce_threads, cld(length(Rreduce), reduce_threads))
-    other_blocks = length(Rother)
-    threads, blocks = reduce_threads, reduce_blocks*other_blocks
-    shmem = shuffle ? 0 : 2*threads*sizeof(T)
+    threads = reduce_threads
+    shmem = reduce_shmem
+    blocks = reduce_blocks*other_blocks
 
     # perform the actual reduction
     if reduce_blocks == 1
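
To make the heuristic above concrete, here is a minimal CPU-side sketch in plain Julia. It is not the actual kernel launch: the occupancy limits (max_threads, max_blocks) and the array sizes are made-up numbers standing in for launch_configuration's results, and it follows the non-shuffle path so nextwarp/prevwarp are not needed.

# hedged sketch: assumed occupancy limits in place of launch_configuration
length_Rreduce = 3000   # hypothetical length of the reduction dimension
length_Rother  = 4      # hypothetical number of independent slices
max_threads    = 1024   # assumed occupancy-derived thread limit
max_blocks     = 80     # assumed occupancy-derived block count

# how many threads do we want? the next power of two covering the slice
wanted_threads = nextpow(2, length_Rreduce)                 # 4096

# how many can we launch? fall back to the previous power of two
reduce_threads = wanted_threads > max_threads ?
                 prevpow(2, max_threads) : wanted_threads   # 1024

# pad the independent-slice blocks until the device is saturated
reduce_blocks = length_Rother >= max_blocks ? 1 :
    min(cld(length_Rreduce, reduce_threads),    # at most this many are useful
        cld(max_blocks, length_Rother))         # enough to fill the GPU

threads, blocks = reduce_threads, reduce_blocks * length_Rother
@show threads blocks    # threads = 1024, blocks = 12

With these numbers, each of the 4 slices gets 3 blocks of 1024 threads, so every thread still loops over multiple input values (3000 elements per slice exceed 3*1024 only in the partial-reduction step), which is exactly the looping behavior the new comments describe.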