121 changes: 100 additions & 21 deletions README.md
@@ -1,38 +1,117 @@
# Intel® Tensor Processing Primitives (TPP) Extension for PyTorch

© Intel Corporation

[![BSD 3-Clause License](https://img.shields.io/badge/license-BSD3-blue.svg "BSD 3-Clause License")](LICENSE.md)

The **Intel® Tensor Processing Primitives (TPP)** extension for PyTorch brings highly optimized deep learning kernels to PyTorch, delivering **accelerated performance on Intel architectures**. It is designed to efficiently execute compute-intensive operations using Intel® AVX-512 and other architectural features through JIT-compiled kernels.

## What is TPP?

[**TPP (Tensor Processing Primitives)**](https://libxsmm.readthedocs.io/en/latest/libxsmm_tpp/) is a collection of low-level building blocks designed for performance-critical deep learning workloads. It is part of the [LIBXSMM](https://github.com/libxsmm/libxsmm) ecosystem, which focuses on:

- Just-In-Time (JIT) code generation for tensor operations.
- Optimized support for small and medium-sized GEMMs (General Matrix-Matrix Multiplications), especially BRGEMMs (Batch Reduce GEMMs).
- Cache-aware loop transformations and blocking strategies.
- Vectorization and multi-threading support.

TPP supports multiple operation types through a unified JIT dispatch mechanism: **Unary Operations** (e.g., ReLU, copy, negation), **Binary Operations** (e.g., add, multiply), **Ternary Operations** (e.g., fused multiply-add), and **GEMM**/**BRGEMM** kernels.
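
To make the BRGEMM pattern concrete, here is a plain-PyTorch reference for the math a batch-reduce GEMM kernel performs (illustrative only; this is not how the TPP kernels themselves are invoked):

```python
import torch

# Batch-Reduce GEMM (BRGEMM): accumulate a batch of small matrix products
# into a single output block, C = sum_i A[i] @ B[i].
n_blocks, M, K, N = 4, 32, 64, 32
A = torch.randn(n_blocks, M, K)
B = torch.randn(n_blocks, K, N)

C = torch.zeros(M, N)
for i in range(n_blocks):
    C += A[i] @ B[i]

# Equivalent one-liner, for reference:
assert torch.allclose(C, torch.einsum("bmk,bkn->mn", A, B), atol=1e-4)
```
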
---

## About This Extension

This PyTorch extension integrates TPP kernels to accelerate various deep learning operators. Notably:

- It does **not** use `torch.compile` or other PyTorch dynamic compiler paths.
- It allows **direct invocation** of optimized C++/x86 kernels in PyTorch workflows (a minimal sketch follows this list).
- Ideal for research and experimentation on performance-aware PyTorch models.
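
A minimal sketch of such a direct call, using the `Conv1dOpti` layer shipped under `tpp_pytorch_extension.cnn` (parameter values are illustrative, and defaults for optional arguments such as `enable_BF16` are assumed; see `examples/cnn/` for the full comparison script):

```python
import torch
from tpp_pytorch_extension.cnn.Conv1dOpti_ext import Conv1dOpti

# Drop-in replacement for nn.Conv1d, backed by JIT-compiled TPP kernels.
conv = Conv1dOpti(in_channels=16, out_channels=16, kernel_size=51,
                  stride=1, padding=0, dilation=8, bias=False)

x = torch.randn(56, 16, 4096, requires_grad=True)  # (batch, channels, width)
y = conv(x)          # forward pass runs the optimized kernel directly
y.sum().backward()   # the custom backward pass is optimized as well
```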

## Prerequisites

- **Operating System:** Linux-based system (e.g., Ubuntu)
- **Compiler:** GCC 8.3.0 or higher
- **Environment:** Anaconda/Miniconda (recommended for Python and package management)
- **PyTorch:** 1.4.0 or higher
- **Python:** 3.6 or higher

---

## Installation Guide

### 1. Set Up Conda Environment

Use the provided script to create and configure a Conda environment.

```bash
# Create a new conda env; optionally specify the conda installation path
bash utils/setup_conda.sh [-p <conda_install_path>]
```

This generates an `env.sh` script for activating the environment.

### 2. Install the TPP Extension

Activate the environment and install the extension:

```bash
# Activate Conda environment
source env.sh

# Initialize Git submodules
git submodule update --init

# Install the extension
python setup.py install
```

`setup.py install` is deprecated. For modern packaging, consider using:

```bash
pip install .
```

## Multi-Node Support (Optional)

To enable distributed training across multiple nodes, install the `torch_ccl` communication library:

```bash
bash utils/install_torch_ccl.sh
```
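
Importing the installed module registers the `ccl` backend with `torch.distributed`. A hedged initialization sketch follows; the exact import name varies across `torch_ccl` versions, and the rank/address environment variables are assumed to be set by your launcher (e.g., `mpirun` or `torchrun`):

```python
import torch.distributed as dist
import torch_ccl  # noqa: F401 -- registers the "ccl" backend (import name is version-dependent)

# Assumes MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are set by the launcher.
dist.init_process_group(backend="ccl")
print(f"rank {dist.get_rank()} of {dist.get_world_size()}")
```
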
## Related Work & Repositories

### Core Libraries
- [libxsmm/libxsmm-dnn](https://github.com/libxsmm/libxsmm-dnn): LIBXSMM-based DNN kernels.
- [libxsmm/parlooper](https://github.com/libxsmm/parlooper): Loop parallelization library.

📄 [Tensor Processing Primitives (TPP)](https://arxiv.org/pdf/2104.05755)

📄 [Parlooper](https://arxiv.org/pdf/2304.12576)


### Kernel Testing
- [BRGEMM sample kernel](https://github.com/libxsmm/libxsmm/blob/main/samples/xgemm/gemm_kernel.c)
- [GEMM test script](https://github.com/libxsmm/libxsmm/blob/main/samples/xgemm/kernel_test/generate_gemm_test_scripts.sh#L8)

### MLIR Compiler Integration
- [tpp-mlir](https://github.com/libxsmm/tpp-mlir): MLIR-based TPP backend.

📄 [TPP-MLIR Compiler Paper](https://arxiv.org/abs/2404.15204v1)

### Triton CPU Upstreaming Efforts
- [triton-cpu (xsmm-main)](https://github.com/libxsmm/triton-cpu/tree/xsmm-main): Extending Triton with TPP support.

## Examples
### [Intel-alphafold](https://github.com/IntelLabs/open-omics-alphafold)
- [TPP optimization of AlphaFold2](examples/alphafold/README.md)
### BERT
- [BERT SQuAD Fine-tuning](examples/bert/squad/README.txt)
- [BERT MLPerf pre-training](examples/bert/pretrain_mlperf/README.txt)

### [GNN](examples/gnn/README.md)
- [GraphSage](examples/gnn/graphsage/README.md)
- [Graph Attention Network (GAT)](examples/gnn/gat/README.md)

### LLM
- [GPT-J](examples/llm/README.txt)


119 changes: 119 additions & 0 deletions examples/cnn/Efficiency_test.py
@@ -0,0 +1,119 @@
import time
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F

from tpp_pytorch_extension.cnn.Conv1dOpti_ext import Conv1dOpti, ReLU_bf16  # Import layers from the extension


"""
Set parameters here for testing the convolutional layer. By default layer run in single-precsion (FP32) format

To run code in BFloat16 set enable_BF16 flag to True. BFloat16 code runs only when parameters of
Input width, number of filters and input channels to the layer are even number.
Ex. - Input_width = 60000, Filters = 16, Channels = 16, enable_BF16 = True ------ BF16 run

If any of the previous parameters is an odd number than code runs in FP32 format.


Keep batch size as multiple of CPU (Ex. - 28, 56, 84, 128 .... on a 28 core cascade lake) for optimal
performance with the Conv1dOpti layer. Each batch will run on a seperate thread thus performance
may go down if some core are not free, or batch size is not equal to the number of free cores.
Keep the batch size as power of 2 with the MKLDNN backend (Conv1d) for optimal performance.

"""

Input_width = 60400 # Width of the input signal track (60400)
Channels = 16 # Number of channels in the input (15)
Filters = 16 # Number of filters in the layer (15)
Dilation = 8 # Amount of dilation (8)
Kernel_size = 51 # Size of each filter (51)
enable_BF16 = False # Enable layer compute in BFloat16 (Only works when Filters and channels are both even numbers)


class Net1(nn.Module): # First network containing inbuilt PyTorch layer
def __init__(self):
super(Net1, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=Channels, out_channels=Filters, kernel_size=Kernel_size,
                               stride=1, padding=0, dilation=Dilation, bias=False)  # PyTorch convolutional layer

def forward(self, x):
x = self.conv1(x)
# x = F.relu(x) # If applying relu
return x


class Net2(nn.Module): # Second network containing our optimized layer
def __init__(self):
super(Net2, self).__init__()
        self.conv2 = Conv1dOpti(in_channels=Channels, out_channels=Filters, kernel_size=Kernel_size,
                                stride=1, padding=0, dilation=Dilation, bias=False, enable_BF16=enable_BF16)  # Optimized convolutional layer

def forward(self, x):
x = self.conv2(x)
# x = ReLU_bf16.apply(x) # If applying BF16 relu
return x


net1 = Net1() # Initialize neural networks
net2 = Net2()

torch.manual_seed(11) # Fixed Random Seed for comparison

random_weights = torch.randn(Filters, Channels, Kernel_size) # Random weights
net1.conv1.weight.data = random_weights # Assign random weights to the layer
net2.conv2.weight.data = random_weights

###------------------------------------- Timing check part -----------------------------------###

forward1 = 0 # variables to store time values
forward2 = 0
backward1 = 0
backward2 = 0


Batch_size_1 = 64 # Batch size for oneDNN (64)
X = torch.randn(Batch_size_1, Channels, Input_width, requires_grad=True) # Random Input (Batch_size, channel, width)

N = 20 # Number of iterations
for _ in range(N): # MKLDNN PyTorch layer Forward and Backward pass timing
start = time.time()
Y1 = net1.forward(X)
forward1 += time.time() - start

start = time.time()
Y1.sum().backward()
backward1 += time.time() - start


Batch_size_2 = 56 # Multiple of core count for optimized layer
X = torch.randn(Batch_size_2, Channels, Input_width, requires_grad=True) # Random Input (Batch_size, channel, width)

if enable_BF16:  # if BFloat16 computation is enabled
X = X.to(torch.bfloat16)

for _ in range(N): # Optimized PyTorch layer Forward and Backward pass timing
start = time.time()
Y2 = net2.forward(X)
forward2 += time.time() - start

start = time.time()
Y2.sum().backward()
backward2 += time.time() - start

print('Forward pass time (PyTorch layer): {:.3f} ms | Forward pass time (Optimized layer): {:.3f} ms'.format(forward1 * 1e3/N, forward2 * 1e3/N))
print('Backward pass time (PyTorch layer): {:.3f} ms | Backward pass time (Optimized layer): {:.3f} ms'.format(backward1 * 1e3/N, backward2 * 1e3/N))
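
# FLOP/s estimates below: the output width after dilated convolution is
# Input_width - (Kernel_size - 1)*Dilation, each output element costs
# Channels*Filters*Kernel_size multiply-accumulates (2 FLOPs each), and the
# backward pass is counted as roughly twice the forward pass (hence the extra 2x).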

forward1_flops = 2*Batch_size_1*Channels*Filters*Kernel_size*(Input_width - (Kernel_size - 1)*Dilation)/(forward1 / N)
backward1_flops = 2*2*Batch_size_1*Channels*Filters*Kernel_size*(Input_width - (Kernel_size - 1)*Dilation)/(backward1 / N)

forward2_flops = 2*Batch_size_2*Channels*Filters*Kernel_size*(Input_width - (Kernel_size - 1)*Dilation)/(forward2 / N)
backward2_flops = 2*2*Batch_size_2*Channels*Filters*Kernel_size*(Input_width - (Kernel_size - 1)*Dilation)/(backward2 / N)



print("\n")
print('Forward pass flops (PyTorch layer): {:e} | Forward pass flops (Optimized layer): {:e} '.format(forward1_flops, forward2_flops))
print('Backward pass flops (PyTorch layer): {:e} | Backward pass flops (Optimized layer): {:e} '.format(backward1_flops, backward2_flops))
76 changes: 76 additions & 0 deletions examples/cnn/README.md
@@ -0,0 +1,76 @@
# Optimized 1D U-Net with TPP Convolution Layer
---
An optimized 1D convolutional layer (`Conv1dOpti`) is developed using LIBXSMM Tensor Processing Primitives (TPPs) to enable high-performance execution. The implementation is integrated into the `cnn/` directory of the `tpp_pytorch_extension` package, allowing seamless access upon compilation. The optimized layer is subsequently integrated into a full 1D U-Net architecture to serve as a complete and practical example.

## 🚀 Features

- ✅ Drop-in replacement of `nn.Conv1d` with `Conv1dOpti`

---

## 📁 File Overview

- `examples/cnn/unet_example.py`: U-Net implementation
- `tpp_pytorch_extension/cnn/Conv1dOpti_ext.py`: Optimized convolution PyTorch extension
- `examples/cnn/Efficiency_test.py`: Original Conv1d vs Conv1dOpti comparison

---

## ⚙️ Environment Setup

Set up your environment using conda:

```bash
conda create -n tpp-unet python=3.10 -y   # Python 3.6 or higher is required
conda activate tpp-unet

# Install PyTorch (1.4.0 or higher)
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 cpuonly -c pytorch

# Install required Python packages
pip install numpy==1.24
```

Install build tools and compile the TPP extension:

```bash
sudo apt-get update
sudo apt-get install -y build-essential git

# Initialize submodules (run from the repository root)
git submodule update --init

# Build the TPP PyTorch extension
python setup.py install
```

---

## 🛠️ Development Guide

To use the optimized convolution in your own model, import:

```python
from tpp_pytorch_extension.cnn.Conv1dOpti_ext import Conv1dOpti, ReLU_bf16
```

### Example Usage:
```python
import torch

conv = Conv1dOpti(in_channels=16, out_channels=16, kernel_size=51, stride=1,
                  padding=0, dilation=8, bias=False)

x = torch.randn(8, 16, 1024)   # input of shape (batch, channels, width)
out = conv(x)
```
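
For context, here is a sketch of how the layer might slot into a 1D U-Net encoder stage. The block structure and names are illustrative assumptions, not taken from `unet_example.py`, and `dilation=1` is assumed to be supported:

```python
import torch.nn as nn
from tpp_pytorch_extension.cnn.Conv1dOpti_ext import Conv1dOpti

class DoubleConv1d(nn.Module):
    """Two Conv1dOpti layers with ReLU, a common U-Net encoder building block."""

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            Conv1dOpti(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size,
                       stride=1, padding=0, dilation=1, bias=False),
            nn.ReLU(),
            Conv1dOpti(in_channels=out_ch, out_channels=out_ch, kernel_size=kernel_size,
                       stride=1, padding=0, dilation=1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)
```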

---

## 📈 Run the U-Net


```bash
cd examples/cnn
python unet_example.py
```

---