The primary aim of this project was to assess the performance bottlenecks introduced by high-level machine learning frameworks like PyTorch, and to explore how much speedup could be achieved by implementing a low-level CUDA solution. Additionally, it served as a learning opportunity to deepen my understanding of CUDA and its direct application in neural network training.
Although I have extensive experience with Python-based frameworks such as PyTorch and JAX, this is my first larger-scale project written primarily in CUDA. The focus was on keeping the code readable and accessible to others who are starting their journey with CUDA programming.
Surprisingly, PyTorch exhibits significant overhead, particularly with smaller networks. Even when utilizing PyTorch 2.0's `torch.compile` feature (tuned with `mode="max-autotune"` and `fullgraph=True` to reduce Python overhead), my results indicate that it can be up to 6 times slower than the CUDA implementation.
As the network scales, the gap narrows, yet PyTorch still lags behind by around 20%, even in larger models.
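For context, step times on the CUDA side can be measured with CUDA events, which time work on the GPU itself rather than on the host. The sketch below is only illustrative: `dummy_step` is a stand-in for the real forward/backward kernel launches, and the sizes and iteration counts are placeholders, not the project's actual benchmark harness.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for one training iteration (forward + backward + update).
// In the real project this would be the network's kernel launches.
__global__ void dummy_step(float* x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = x[i] * 0.999f + 0.001f;
}

int main() {
    const int n = 1 << 20, warmup = 10, iters = 100;
    float* d_x;
    cudaMalloc(&d_x, n * sizeof(float));

    // Warm-up launches so one-time setup costs are excluded from the measurement.
    for (int i = 0; i < warmup; ++i)
        dummy_step<<<(n + 255) / 256, 256>>>(d_x, n);
    cudaDeviceSynchronize();

    // CUDA events record timestamps in the GPU's command stream,
    // so the measurement is independent of host-side launch overhead.
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i)
        dummy_step<<<(n + 255) / 256, 256>>>(d_x, n);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);  // wait for all recorded work to finish

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("avg step time: %.4f ms\n", ms / iters);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    cudaFree(d_x);
    return 0;
}
```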
Several factors contribute to PyTorch's comparative slowness:
- My CUDA implementation leverages fp16 precision during matrix multiplications, whereas PyTorch defaults to fp32, prioritizing numerical stability. Notably, NVIDIA highlights that using fp16 can theoretically double throughput, and my benchmarks didn't reveal any instability with fp16 (a minimal half-precision GEMM is sketched after this list).
- The CUDA version was carefully fine-tuned for my specific hardware, while I'm unsure whether PyTorch's `max-autotune` mode fully optimizes for hardware-specific parameters.
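To make the fp16 point concrete, here is a minimal sketch of a half-precision matrix multiplication routed through cuBLAS. This is not the project's actual kernel (which may be custom-written); it simply shows the standard `cublasHgemm` path with placeholder dimensions and uninitialized inputs, kept short for illustration.

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// C = A * B with all operands stored in fp16 (half precision).
// cuBLAS assumes column-major storage; dimensions here are placeholders.
void gemm_fp16(cublasHandle_t handle,
               const __half* A, const __half* B, __half* C,
               int m, int n, int k) {
    const __half alpha = __float2half(1.0f);
    const __half beta  = __float2half(0.0f);
    cublasHgemm(handle,
                CUBLAS_OP_N, CUBLAS_OP_N,
                m, n, k,
                &alpha,
                A, m,      // A is m x k
                B, k,      // B is k x n
                &beta,
                C, m);     // C is m x n
}

int main() {
    const int m = 1024, n = 1024, k = 1024;
    __half *A, *B, *C;
    // Inputs are left uninitialized for brevity; a real run would fill them.
    cudaMalloc(&A, m * k * sizeof(__half));
    cudaMalloc(&B, k * n * sizeof(__half));
    cudaMalloc(&C, m * n * sizeof(__half));

    cublasHandle_t handle;
    cublasCreate(&handle);
    gemm_fp16(handle, A, B, C, m, n, k);   // half-precision matmul
    cudaDeviceSynchronize();

    cublasDestroy(handle);
    cudaFree(A); cudaFree(B); cudaFree(C);
    return 0;
}
```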
Note that I applied optimization techniques to both the PyTorch and CUDA versions:
- To minimize host-to-device transfer delays, all data is preloaded into GPU memory before the timed runs (see the sketch after this list).
- I allowed PyTorch to perform several warm-up iterations to give its JIT compiler time to optimize the computation graph before timing the runs.
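As a rough illustration of the preloading point, the snippet below stages a dataset in device memory with a single host-to-device copy before the training loop, so no transfers land inside the timed iterations. The dataset shape, batch size, and buffer names are made up for the example and are not taken from the project.

```cuda
#include <vector>
#include <cuda_runtime.h>

int main() {
    // Hypothetical dataset: 60,000 examples of 784 floats each (placeholder sizes).
    const size_t num_examples = 60000, features = 784;
    std::vector<float> host_data(num_examples * features, 0.0f);

    // One large host-to-device copy up front, instead of one copy per batch.
    float* d_data;
    cudaMalloc(&d_data, host_data.size() * sizeof(float));
    cudaMemcpy(d_data, host_data.data(),
               host_data.size() * sizeof(float), cudaMemcpyHostToDevice);

    // Training loop: each batch is just an offset into device memory, no transfers.
    const size_t batch_size = 128;
    for (size_t start = 0; start + batch_size <= num_examples; start += batch_size) {
        const float* d_batch = d_data + start * features;
        (void)d_batch;  // forward/backward kernels would consume d_batch here
    }

    cudaFree(d_data);
    return 0;
}
```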
Although the CUDA implementation delivers superior speed, there is still room for improvement. For instance, I didn't use vectorized loads in some element-wise operations such as `ReLU`, which could yield further performance gains.
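This is the kind of change I have in mind: the scalar kernel below reads one float per thread, while the `float4` variant loads four values per memory transaction. Whether it actually helps depends on tensor sizes and memory layout, so treat this as a sketch rather than a measured optimization.

```cuda
#include <cuda_runtime.h>

// Scalar ReLU: one element per thread.
__global__ void relu_scalar(float* x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = fmaxf(x[i], 0.0f);
}

// Vectorized ReLU: each thread loads/stores four floats at once via float4,
// issuing wider 128-bit memory transactions. Assumes n is a multiple of 4 and
// x is 16-byte aligned (true for buffers returned by cudaMalloc).
__global__ void relu_vec4(float4* x, int n4) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n4) {
        float4 v = x[i];
        v.x = fmaxf(v.x, 0.0f);
        v.y = fmaxf(v.y, 0.0f);
        v.z = fmaxf(v.z, 0.0f);
        v.w = fmaxf(v.w, 0.0f);
        x[i] = v;
    }
}

int main() {
    const int n = 1 << 20;  // placeholder size, multiple of 4
    float* d_x;
    cudaMalloc(&d_x, n * sizeof(float));

    relu_scalar<<<(n + 255) / 256, 256>>>(d_x, n);
    relu_vec4<<<(n / 4 + 255) / 256, 256>>>(reinterpret_cast<float4*>(d_x), n / 4);
    cudaDeviceSynchronize();

    cudaFree(d_x);
    return 0;
}
```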
To verify correctness, I compared the loss curves produced by both implementations. As expected, the training behavior is consistent between PyTorch and CUDA, indicating that the lower-level optimizations did not affect convergence.