Skip to content

Commit f78c909

Browse files
feat(gpu): stream pools
1 parent 27e2fbd commit f78c909

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

backends/tfhe-cuda-backend/cuda/include/device.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
#include <cstdlib>
77
#include <cuda_runtime.h>
88

9+
#define CUDA_STREAM_POOL
10+
11+
enum CudaStreamType
12+
{
13+
KEY = 0,
14+
ALLOC = 1,
15+
TEMP_HELPER = 2,
16+
};
17+
918
extern "C" {
1019

1120
#define check_cuda_error(ans) \

backends/tfhe-cuda-backend/cuda/src/device.cu

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
#include <cuda_runtime.h>
44
#include <mutex>
55

6+
#ifdef CUDA_STREAM_POOL
7+
#include <deque>
8+
#include <mutex>
9+
#include <unordered_map>
10+
#endif
11+
612
uint32_t cuda_get_device() {
713
int device;
814
check_cuda_error(cudaGetDevice(&device));
@@ -109,18 +115,89 @@ void cuda_event_destroy(cudaEvent_t event, uint32_t gpu_index) {
109115
check_cuda_error(cudaEventDestroy(event));
110116
}
111117

118+
#ifdef CUDA_STREAM_POOL
119+
struct CudaBoundStream
120+
{
121+
cudaStream_t stream;
122+
uint32_t gpu_index;
123+
};
124+
125+
class CudaStreamPool
126+
{
127+
std::vector<CudaBoundStream> poolCompute;
128+
std::vector<CudaBoundStream> poolTransfer;
129+
130+
std::mutex mutex_pools;
131+
132+
size_t nextStream = 0;
133+
134+
const size_t MAX_STREAMS = 16;
135+
136+
public:
137+
cudaStream_t create_stream(uint32_t gpu_index)
138+
{
139+
std::lock_guard<std::mutex> lock(mutex_pools);
140+
if (poolCompute.empty())
141+
{
142+
poolCompute.reserve(MAX_STREAMS);
143+
144+
cuda_set_device(gpu_index);
145+
for (size_t i = 0; i < MAX_STREAMS; i++)
146+
{
147+
cudaStream_t stream;
148+
check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
149+
poolCompute.push_back(CudaBoundStream{stream, gpu_index});
150+
}
151+
}
152+
153+
PANIC_IF_FALSE(gpu_index == poolCompute[nextStream].gpu_index, "Bad gpu in stream pool");
154+
cudaStream_t res = poolCompute[nextStream].stream;
155+
nextStream = (nextStream + 1) % poolCompute.size();
156+
return res;
157+
}
158+
159+
void destroy_stream(cudaStream_t stream, uint32_t gpu_index)
160+
{
161+
//do nothing
162+
}
163+
};
164+
165+
166+
class CudaMultiStreamPool {
167+
std::unordered_map<uint32_t, CudaStreamPool> per_gpu_pools;
168+
std::mutex pools_mutex; // for creation of the mem managers
169+
170+
public:
171+
CudaStreamPool &get(uint32_t gpu_index) {
172+
std::lock_guard<std::mutex> guard(pools_mutex);
173+
return per_gpu_pools[gpu_index]; // creates it if it does not exist
174+
}
175+
};
176+
177+
CudaMultiStreamPool gCudaStreamPool;
178+
#endif
179+
180+
112181
/// Unsafe function to create a CUDA stream, must check first that GPU exists
113182
cudaStream_t cuda_create_stream(uint32_t gpu_index) {
183+
#ifdef CUDA_STREAM_POOL
184+
return gCudaStreamPool.get(gpu_index).create_stream(gpu_index);
185+
#else
114186
cuda_set_device(gpu_index);
115187
cudaStream_t stream;
116188
check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
117189
return stream;
190+
#endif
118191
}
119192

120193
/// Unsafe function to destroy CUDA stream, must check first the GPU exists
121194
void cuda_destroy_stream(cudaStream_t stream, uint32_t gpu_index) {
195+
#ifdef CUDA_STREAM_POOL
196+
gCudaStreamPool.get(gpu_index).destroy_stream(stream, gpu_index);
197+
#else
122198
cuda_set_device(gpu_index);
123199
check_cuda_error(cudaStreamDestroy(stream));
200+
#endif
124201
}
125202

126203
void cuda_synchronize_stream(cudaStream_t stream, uint32_t gpu_index) {

0 commit comments

Comments
 (0)