Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
247cdd2
Add blas_copy
Cstandardlib Oct 21, 2025
43c49ad
Add blas_nrm2
Cstandardlib Oct 21, 2025
b41f00c
Fix blas_nrm2 unittest
Cstandardlib Oct 21, 2025
b639675
Fix scnrm2_ return type
Cstandardlib Oct 21, 2025
7da4779
Sort and classify lapack routines
Cstandardlib Oct 22, 2025
281d49e
Merge branch 'develop' into feature/container-blas-と-lapack
Cstandardlib Oct 22, 2025
dba1900
Merge branch 'develop' into feature/container-blas-と-lapack
Cstandardlib Oct 27, 2025
0d3f2d0
Add geqrf
Cstandardlib Oct 27, 2025
76ed3b3
Merge branch 'develop' into feature/container-blas-と-lapack
Cstandardlib Oct 27, 2025
ec2582c
add geqrf lapack C interface
Cstandardlib Oct 27, 2025
e2fe179
geqrf_inplace with tests
Cstandardlib Oct 28, 2025
850fcfa
Comment test auxiliary code to be used later
Cstandardlib Oct 28, 2025
e82c17c
Merge branch 'develop' into feature/container-blas-と-lapack
Cstandardlib Oct 28, 2025
ab382f7
Merge branch 'develop' into feature/container-blas-and-lapack
Cstandardlib Oct 29, 2025
d1ad8fb
Merge branch 'develop' into feature/container-blas-and-lapack
Cstandardlib Oct 30, 2025
10664f6
Add the description of cusolver_utils.h, temporarily disabled
Cstandardlib Oct 30, 2025
7fa2e21
Merge branch 'develop' into feature/container-blas-and-lapack
Cstandardlib Nov 3, 2025
831625e
Update heevd interface to add lda
Cstandardlib Nov 3, 2025
62d695a
Update lapack_test to new interface
Cstandardlib Nov 4, 2025
999c5ed
Merge branch 'develop' into feature/container-blas-and-lapack
Cstandardlib Nov 4, 2025
01c30c6
Merge branch 'develop' into feature/container-blas-and-lapack
mohanchen Nov 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion source/source_base/module_container/ATen/kernels/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@
namespace container {
namespace kernels {


template <typename T>
struct blas_copy<T, DEVICE_CPU> {
void operator()(
const int n,
const T *x,
const int incx,
T *y,
const int incy)
{
BlasConnector::copy(n, x, incx, y, incy);
}
};

template <typename T>
struct blas_nrm2<T, DEVICE_CPU> {
using Real = typename GetTypeReal<T>::type;
Real operator()(
const int n,
const T *x,
const int incx)
{
return BlasConnector::nrm2(n, x, incx);
}
};

template <typename T>
struct blas_dot<T, DEVICE_CPU> {
void operator()(
Expand Down Expand Up @@ -175,6 +201,17 @@ struct blas_gemm_batched_strided<T, DEVICE_CPU> {
};

// Explicitly instantiate functors for the types of functor registered.

template struct blas_copy<float , DEVICE_CPU>;
template struct blas_copy<double, DEVICE_CPU>;
template struct blas_copy<std::complex<float >, DEVICE_CPU>;
template struct blas_copy<std::complex<double>, DEVICE_CPU>;

template struct blas_nrm2<float , DEVICE_CPU>;
template struct blas_nrm2<double, DEVICE_CPU>;
template struct blas_nrm2<std::complex<float >, DEVICE_CPU>;
template struct blas_nrm2<std::complex<double>, DEVICE_CPU>;

template struct blas_dot<float , DEVICE_CPU>;
template struct blas_dot<double, DEVICE_CPU>;
template struct blas_dot<std::complex<float >, DEVICE_CPU>;
Expand Down Expand Up @@ -221,4 +258,4 @@ template struct blas_gemm_batched_strided<std::complex<float >, DEVICE_CPU>;
template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_CPU>;

} // namespace kernels
} // namespace container
} // namespace container
22 changes: 21 additions & 1 deletion source/source_base/module_container/ATen/kernels/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@
namespace container {
namespace kernels {

template <typename T, typename Device>
struct blas_copy {
// DCOPY copies a vector, x, to a vector, y.
void operator()(
const int n,
const T *x,
const int incx,
T *y,
const int incy);
};

template <typename T, typename Device>
struct blas_nrm2 {
using Real = typename GetTypeReal<T>::type;
Real operator()(
const int n,
const T *x,
const int incx);
};

template <typename T, typename Device>
struct blas_dot {
void operator()(
Expand Down Expand Up @@ -168,4 +188,4 @@ void destroyGpuBlasHandle(); // destory blas handle
} // namespace kernels
} // namespace container

#endif // ATEN_KERNELS_BLAS_H_
#endif // ATEN_KERNELS_BLAS_H_
43 changes: 41 additions & 2 deletions source/source_base/module_container/ATen/kernels/cuda/blas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,32 @@ void destroyGpuBlasHandle() {
}
}

template <typename T>
struct blas_nrm2<T, DEVICE_GPU> {
using Real = typename GetTypeReal<T>::type;
Real operator()(
const int n,
const T *x,
const int incx)
{
Real result;
cuBlasConnector::nrm2(cublas_handle, n, x, incx, &result);
return result;
}
};

template <typename T>
struct blas_copy<T, DEVICE_GPU> {
void operator()(
const int n,
const T * x,
const int incx,
T *y,
const int incy)
{
cuBlasConnector::copy(cublas_handle, n, x, incx, y, incy);
}
};

template <typename T>
struct blas_dot<T, DEVICE_GPU> {
Expand Down Expand Up @@ -76,7 +102,7 @@ struct blas_gemv<T, DEVICE_GPU> {
const int& incx,
const T* beta,
T* y,
const int& incy)
const int& incy)
{
cuBlasConnector::gemv(cublas_handle, trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy);
}
Expand Down Expand Up @@ -196,6 +222,19 @@ struct blas_gemm_batched_strided<T, DEVICE_GPU> {
};

// Explicitly instantiate functors for the types of functor registered.



template struct blas_copy<float , DEVICE_GPU>;
template struct blas_copy<double, DEVICE_GPU>;
template struct blas_copy<std::complex<float> , DEVICE_GPU>;
template struct blas_copy<std::complex<double>, DEVICE_GPU>;

template struct blas_nrm2<float , DEVICE_GPU>;
template struct blas_nrm2<double, DEVICE_GPU>;
template struct blas_nrm2<std::complex<float> , DEVICE_GPU>;
template struct blas_nrm2<std::complex<double>, DEVICE_GPU>;

template struct blas_dot<float , DEVICE_GPU>;
template struct blas_dot<double, DEVICE_GPU>;
template struct blas_dot<std::complex<float> , DEVICE_GPU>;
Expand Down Expand Up @@ -242,4 +281,4 @@ template struct blas_gemm_batched_strided<std::complex<float >, DEVICE_GPU>;
template struct blas_gemm_batched_strided<std::complex<double>, DEVICE_GPU>;

} // namespace kernels
} // namespace container
} // namespace container
Loading
Loading