In PyTorch 1.8's SyncBatchNorm implementation, both the forward and the backward pass involve NCCL communication:
- Forward: all_gather of (mean, invstd, count): https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/modules/_functions.py#L28
  `dist.all_gather(combined_list, combined, async_op=False)` is a blocking/synchronous call here: when `async_op=False`, the wrapper does `work = ...` and then `work.wait()` before returning. See the forward sketch after this list.
- Backward: all_reduce of sum_dy and sum_dy_xmu, needed to compute the input grad: https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/modules/_functions.py#L80 (see the backward sketch below)
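
A minimal sketch of the forward-side communication, assuming the script is launched with `torchrun --nproc_per_node=N ...` so the process-group environment variables are already set. The names `combined`, `combined_list`, `mean_all`, `invstd_all`, `count_all` loosely mirror the linked `_functions.py`; the stats math here is hand-rolled for illustration rather than PyTorch's internal ops:

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # nccl when running on GPUs
world_size = dist.get_world_size()

C = 4                    # number of channels (illustrative)
x = torch.randn(8, C)    # this rank's local batch, shape (N, C)

# Per-rank statistics over the local batch.
mean = x.mean(dim=0)                                              # (C,)
invstd = 1.0 / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)    # (C,)
count = torch.tensor([x.size(0)], dtype=mean.dtype)               # (1,)

# Pack (mean, invstd, count) into one tensor so a single all_gather suffices.
combined = torch.cat([mean, invstd, count])
combined_list = [torch.empty_like(combined) for _ in range(world_size)]

# Blocking collective: with async_op=False the wrapper calls work.wait()
# internally, so every rank holds all ranks' stats once this returns.
dist.all_gather(combined_list, combined, async_op=False)

stacked = torch.stack(combined_list)                    # (world_size, 2C+1)
mean_all, invstd_all, count_all = torch.split(stacked, [C, C, 1], dim=1)
print(f"rank {dist.get_rank()}: gathered counts = {count_all.view(-1)}")
```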
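And a sketch of the backward-side communication under the same torchrun assumption. `sum_dy` / `sum_dy_xmu` follow the names in the linked source; packing both into one tensor before the all_reduce matches its pattern of fusing the two reductions into a single collective. The gradient formula at the end is the standard BN input-gradient expression, written out by hand:

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # nccl when running on GPUs

C = 4
x = torch.randn(8, C)     # this rank's local input from the forward pass
dy = torch.randn(8, C)    # upstream gradient for the local batch
weight = torch.ones(C)
# mean/invstd would be the *global* stats saved from forward; for this
# standalone sketch we fake them from the local batch.
mean = x.mean(dim=0)
invstd = 1.0 / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)

# Local partial reductions, per channel.
sum_dy = dy.sum(dim=0)                          # (C,)
sum_dy_xmu = (dy * (x - mean)).sum(dim=0)       # (C,)

# One blocking all_reduce over the packed pair.
combined = torch.cat([sum_dy, sum_dy_xmu])
dist.all_reduce(combined, op=dist.ReduceOp.SUM, async_op=False)
sum_dy, sum_dy_xmu = torch.split(combined, C)

# BN input gradient using the globally reduced sums:
# dx = w * invstd * (dy - sum_dy/n - (x - mu) * invstd^2 * sum_dy_xmu / n)
n = x.size(0) * dist.get_world_size()           # total samples per channel
grad_input = weight * invstd * (
    dy - sum_dy / n - (x - mean) * invstd**2 * sum_dy_xmu / n
)
print(f"rank {dist.get_rank()}: grad_input shape {tuple(grad_input.shape)}")
```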
Related reading from Li Mu's (limu) team: "跨卡同步 Batch Normalization" (cross-GPU synchronized Batch Normalization).