Skip to content

Commit c0ac791

Browse files
authored
[src,build] support for cuda 11.1 and rtx ampere (#4295)
1 parent 36f2065 commit c0ac791

File tree

2 files changed: 10 additions (+10) and 4 deletions (-4)

src/configure

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -368,11 +368,16 @@ Either your CUDA is too new or too old."
368368
MIN_UNSUPPORTED_GCC_VER_NUM=80000;
369369
CUSOLVER=true
370370
;;
371-
10_1 | 10_* | 11_*)
371+
10_1 | 10_*)
372372
MIN_UNSUPPORTED_GCC_VER="9.0"
373373
MIN_UNSUPPORTED_GCC_VER_NUM=90000;
374374
CUSOLVER=true
375375
;;
376+
11_*)
377+
MIN_UNSUPPORTED_GCC_VER="10.0"
378+
MIN_UNSUPPORTED_GCC_VER_NUM=100000;
379+
CUSOLVER=true
380+
;;
376381
*)
377382
echo "Unsupported CUDA_VERSION (CUDA_VERSION=$CUDA_VERSION), please report it to Kaldi mailing list, together with 'nvcc -h' or 'ptxas -h' which lists allowed -gencode values..."; exit 1;
378383
;;
@@ -392,7 +397,8 @@ Either your CUDA is too new or too old."
392397
#8_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61" ;;
393398
9_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70" ;;
394399
10_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" ;;
395-
11_*) CUDA_ARCH="-gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80" ;;
400+
11_0) CUDA_ARCH="-gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80" ;;
401+
11_1 | 11_*) CUDA_ARCH="-gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86" ;;
396402
*) echo "Unsupported CUDA_VERSION (CUDA_VERSION=$CUDA_VERSION), please report it to Kaldi mailing list, together with 'nvcc -h' or 'ptxas -h' which lists allowed -gencode values..."; exit 1 ;;
397403
esac
398404
;;

src/cudamatrix/cublas-wrappers.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -265,13 +265,13 @@ inline cusparseStatus_t cusparse_csrmm2(cusparseHandle_t handle,
265265

266266
size_t buffer_size;
267267
status = cusparseSpMM_bufferSize(handle, transA, transB, alpha, matA, matB,
268-
beta, matC, valType, CUSPARSE_MM_ALG_DEFAULT,
268+
beta, matC, valType, CUSPARSE_SPMM_CSR_ALG2,
269269
&buffer_size);
270270
if (status != CUSPARSE_STATUS_SUCCESS) return status;
271271

272272
void *buffer = (buffer_size > 0) ? CuDevice::Instantiate().Malloc(buffer_size) : NULL;
273273
status = cusparseSpMM(handle, transA, transB, alpha, matA, matB, beta, matC,
274-
valType, CUSPARSE_MM_ALG_DEFAULT, buffer);
274+
valType, CUSPARSE_SPMM_CSR_ALG2, buffer);
275275

276276
if (status != CUSPARSE_STATUS_SUCCESS) return status;
277277
if(buffer)

0 commit comments

Comments
 (0)