Commit e61d4d3

Add CUDA device_guard to 2nd part of fused attn
1 parent 62fa472 commit e61d4d3

1 file changed: +1 -0 lines changed

exllama_ext/exllama_ext.cpp

Lines changed: 1 addition & 0 deletions
@@ -518,6 +518,7 @@ void q4_attn_2
 {
     TORCH_CHECK_DTYPE(x, kHalf);
     TORCH_CHECK_DTYPE(attn_output, kHalf);
+    const at::cuda::OptionalCUDAGuard device_guard(x.device());
 
     int height = x.size(0);

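The added line is an RAII guard: while `device_guard` is in scope, the current CUDA device is set to the device that holds `x`, and the previous device is restored when the function returns. The sketch below models only that pattern. It makes no CUDA or PyTorch calls, and `g_current_device`, `Device`, `OptionalDeviceGuardSketch`, and `launch_on` are hypothetical names invented for illustration, not part of exllama or libtorch.

```cpp
// Hypothetical stand-in for the runtime's "current device" state.
static int g_current_device = 0;

// Hypothetical stand-in for a tensor's device descriptor.
struct Device {
    bool is_cuda;
    int index;
};

// Minimal model of an optional device guard (an assumption-level sketch,
// not the real at::cuda::OptionalCUDAGuard): it switches the current
// device only when given a CUDA device, and restores the previous device
// on destruction.
class OptionalDeviceGuardSketch {
public:
    explicit OptionalDeviceGuardSketch(Device d)
        : active_(d.is_cuda), previous_(g_current_device) {
        if (active_) {
            g_current_device = d.index;
        }
    }

    ~OptionalDeviceGuardSketch() {
        if (active_) {
            g_current_device = previous_;
        }
    }

    // A guard owns a scope, so copying it would make no sense.
    OptionalDeviceGuardSketch(const OptionalDeviceGuardSketch&) = delete;
    OptionalDeviceGuardSketch& operator=(const OptionalDeviceGuardSketch&) = delete;

private:
    bool active_;
    int previous_;
};

// Returns the device that work issued inside the function body would target.
int launch_on(Device x_device) {
    const OptionalDeviceGuardSketch device_guard(x_device);
    return g_current_device;
}
```

In this model, `launch_on({true, 1})` targets device 1 while the guard is alive and leaves the current device unchanged afterwards, which is the behavior one would expect the guard to give `q4_attn_2` when `x` lives on a non-default GPU.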