sync-bn.md

In PyTorch 1.8, the SyncBN implementation performs NCCL communication in both the forward and the backward pass:

Forward

  1. all_gather of (mean, invstd, count): https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/modules/_functions.py#L28
dist.all_gather(combined_list, combined, async_op=False) # blocking/synchronous here: with async_op=False the call waits on the work handle before returning (work = xxx; work.wait() if not async_op)
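To make the forward step concrete, here is a minimal pure-Python sketch of how the gathered (mean, invstd, count) triples could be merged into global statistics for a single channel. It is not the PyTorch code: the list comprehension stands in for `dist.all_gather`, and `local_stats`, `merge_stats`, `rank_inputs`, and `EPS` are hypothetical names chosen for illustration.

```python
import math

EPS = 1e-5  # assumed epsilon, same role as the eps argument of BatchNorm


def local_stats(xs, eps=EPS):
    """Per-rank statistics for one channel: (mean, invstd, count)."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n  # biased variance
    return mean, 1.0 / math.sqrt(var + eps), n


def merge_stats(gathered, eps=EPS):
    """Merge the (mean, invstd, count) triples collected from every rank."""
    total = sum(c for _, _, c in gathered)
    mean = sum(m * c for m, _, c in gathered) / total
    # Recover each rank's variance from its invstd, then add the shift
    # between the rank mean and the global mean.
    var = sum(((1.0 / s ** 2 - eps) + (m - mean) ** 2) * c
              for m, s, c in gathered) / total
    return mean, 1.0 / math.sqrt(var + eps)


# Hypothetical data: two ranks with different local batch sizes.
rank_inputs = [[1.0, 2.0, 3.0], [10.0, 20.0]]
gathered = [local_stats(xs) for xs in rank_inputs]  # stands in for all_gather
global_mean, global_invstd = merge_stats(gathered)
```

The count is gathered alongside mean and invstd because ranks may hold different numbers of samples, so the merge has to be weighted by count rather than a plain average.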

Backward

  1. all_reduce of sum_dy and sum_dy_xmu, which are needed to compute the input gradient: https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/modules/_functions.py#L80
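The backward step can be sketched the same way. The pure-Python simulation below, for a single channel, shows why sum_dy and sum_dy_xmu must be summed across ranks before the input gradient can be formed. It uses the standard batch-norm input-gradient formula and is not the PyTorch kernel: `all_reduce_sum` stands in for `dist.all_reduce` with a SUM op, and `syncbn_input_grad`, `rank_x`, and `rank_dy` are hypothetical names.

```python
import math

EPS = 1e-5  # assumed epsilon


def all_reduce_sum(per_rank):
    """Stand-in for an all_reduce with SUM: element-wise sum over ranks."""
    return [sum(vals) for vals in zip(*per_rank)]


def syncbn_input_grad(rank_x, rank_dy, weight=1.0, eps=EPS):
    # Global statistics, as already produced by the forward pass.
    all_x = [x for xs in rank_x for x in xs]
    n = len(all_x)
    mean = sum(all_x) / n
    var = sum((x - mean) ** 2 for x in all_x) / n
    invstd = 1.0 / math.sqrt(var + eps)

    # Each rank reduces only over its own samples ...
    local = [
        (sum(dys), sum(dy * (x - mean) for x, dy in zip(xs, dys)))
        for xs, dys in zip(rank_x, rank_dy)
    ]
    # ... then the two partial sums are summed across ranks.
    sum_dy, sum_dy_xmu = all_reduce_sum(local)

    mean_dy = sum_dy / n
    mean_dy_xmu = sum_dy_xmu / n
    grads = [
        [(dy - mean_dy - (x - mean) * invstd ** 2 * mean_dy_xmu) * invstd * weight
         for x, dy in zip(xs, dys)]
        for xs, dys in zip(rank_x, rank_dy)
    ]
    return grads, sum_dy, sum_dy_xmu


# Hypothetical inputs and upstream gradients on two ranks.
rank_x = [[1.0, 2.0, 3.0], [10.0, 20.0]]
rank_dy = [[0.1, -0.2, 0.3], [0.5, -0.4]]
grads, sum_dy, sum_dy_xmu = syncbn_input_grad(rank_x, rank_dy)
```

Every sample's gradient depends on the two batch-wide sums, so a rank that only saw its local sums would compute a gradient for a different (per-rank) normalization than the one used in the forward pass.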

References:

Mu Li's team (limu): Cross-GPU Synchronized Batch Normalization

CVPR 2017 tutorial: Kaiming He and RG explain how BN works

Variants of BN

SenseTime's implementation