Skip to content

Commit

Permalink
make fft sfinae friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
alfC committed Mar 8, 2025
1 parent 397383c commit 23bda93
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
28 changes: 22 additions & 6 deletions include/boost/multi/adaptors/cufft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ constexpr sign backward{CUFFT_INVERSE};

static_assert(forward != none && none != backward && backward != forward);

template<dimensionality_type DD = -1, class Alloc = void*>
template<dimensionality_type DD, class Alloc = void*>
class plan {
Alloc alloc_;
::size_t workSize_ = 0;
Expand Down Expand Up @@ -185,6 +185,8 @@ class plan {
}
}

static_assert(sizeof(ILayout*) == 0);

if(first_howmany_ == D) {
if constexpr(std::is_same_v<Alloc, void*>) {
cufftSafeCall(::cufftPlanMany(
Expand Down Expand Up @@ -286,14 +288,20 @@ class plan {

private:

template<typename = void>
void ExecZ2Z_(complex_type const* idata, complex_type* odata, int direction) const{
cufftSafeCall(cufftExecZ2Z(h_, const_cast<complex_type*>(idata), odata, direction)); // NOLINT(cppcoreguidelines-pro-type-const-cast) wrap legacy interface
// cudaDeviceSynchronize();
}

public:
template<class IPtr, class OPtr>
void execute(IPtr idata, OPtr odata, int direction) { // TODO(correaa) make const
auto execute(IPtr idata, OPtr odata, int direction)
-> decltype((void)(
reinterpret_cast<complex_type const*>(::thrust::raw_pointer_cast(idata)),
reinterpret_cast<complex_type*>(::thrust::raw_pointer_cast(odata))
))
{ // TODO(correaa) make const
if(first_howmany_ == DD) {
ExecZ2Z_(reinterpret_cast<complex_type const*>(::thrust::raw_pointer_cast(idata)), reinterpret_cast<complex_type*>(::thrust::raw_pointer_cast(odata)), direction); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) wrap a legacy interface
return;
Expand Down Expand Up @@ -381,24 +389,32 @@ class cached_plan {
if(it_ == LEAKY_cache.end()) {it_ = LEAKY_cache.insert(std::make_pair(std::make_tuple(which, in, out), plan<D, Alloc>(which, in, out, alloc))).first;}
}
template<class IPtr, class OPtr>
void execute(IPtr idata, OPtr odata, int direction) {
auto execute(IPtr idata, OPtr odata, int direction)
-> decltype(
(void)
(std::declval<
typename std::map<std::tuple<std::array<bool, D>, multi::layout_t<D>, multi::layout_t<D>>, plan<D, Alloc> >::iterator&
>()->second.execute(idata, odata, direction)
)
)
{
// assert(it_ != LEAKY_cache.end());
it_->second.execute(idata, odata, direction);
}
};

template<typename In, class Out, dimensionality_type D = In::rank::value, std::enable_if_t<!multi::has_get_allocator<In>::value, int> =0>
template<typename In, class Out, dimensionality_type D = In::rank::value, std::enable_if_t<!multi::has_get_allocator<In>::value, int> =0, typename = decltype(raw_pointer_cast(std::declval<In const&>().base()))>
auto dft(std::array<bool, +D> which, In const& in, Out&& out, int sgn)
->decltype(cufft::cached_plan<D>{which, in.layout(), out.layout()}.execute(in.base(), out.base(), sgn), std::forward<Out>(out)) {
return cufft::cached_plan<D>{which, in.layout(), out.layout()}.execute(in.base(), out.base(), sgn), std::forward<Out>(out); }

template<typename In, class Out, dimensionality_type D = In::rank::value, std::enable_if_t< multi::has_get_allocator<In>::value, int> =0>
template<typename In, class Out, dimensionality_type D = In::rank::value, std::enable_if_t< multi::has_get_allocator<In>::value, int> =0, typename = decltype(raw_pointer_cast(std::declval<In const&>().base()))>
auto dft(std::array<bool, +D> which, In const& in, Out&& out, int sgn)
->decltype(cufft::cached_plan<D /*, typename std::allocator_traits<typename In::allocator_type>::rebind_alloc<char>*/ >{which, in.layout(), out.layout()/*, i.get_allocator()*/}.execute(in.base(), out.base(), sgn), std::forward<Out>(out)) {
return cufft::cached_plan<D /*, typename std::allocator_traits<typename In::allocator_type>::rebind_alloc<char>*/ >{which, in.layout(), out.layout()/*, i.get_allocator()*/}.execute(in.base(), out.base(), sgn), std::forward<Out>(out); }

template<typename In, class Out, dimensionality_type D = In::rank::value>//, std::enable_if_t<not multi::has_get_allocator<In>::value, int> =0>
auto dft_forward(std::array<bool, +D> which, In const& in, Out&& out) -> Out&& {
auto dft_forward(std::array<bool, +D> which, In const& in, Out&& out) -> Out&& {
//->decltype(cufft::plan<D>{which, i.layout(), o.layout()}.execute(i.base(), o.base(), cufft::forward), std::forward<Out>(o)) {
return cufft::cached_plan<D>{which, in.layout(), out.layout()}.execute(in.base(), out.base(), cufft::forward), std::forward<Out>(out); }

Expand Down
4 changes: 2 additions & 2 deletions include/boost/multi/adaptors/cufft/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ include(CTest)
include_directories(${CMAKE_BINARY_DIR})

# file(GLOB TEST_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cpp)
set(TEST_SRCS cufft.cpp)
#set(TEST_SRCS cufft.cpp)

foreach(TEST_FILE ${TEST_SRCS})
set(TEST_EXE "${TEST_FILE}.x")
add_executable(${TEST_EXE} ${TEST_FILE})
if(ENABLE_CUDA OR DEFINED CXXCUDA)
set_source_files_properties(${TEST_FILE} PROPERTIES LANGUAGE CUDA)
target_compile_options(${TEST_EXE} PRIVATE -std=c++17)
# target_compile_options(${TEST_EXE} PRIVATE -std=c++17)
endif()
# target_compile_features (${TEST_EXE} PUBLIC cxx_std_17)
target_compile_definitions(${TEST_EXE} PRIVATE "BOOST_PP_VARIADICS") # needed by Boost.Test and NVCC
Expand Down
2 changes: 1 addition & 1 deletion include/boost/multi/adaptors/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ namespace boost::multi::fft{
#endif
template<class In>
auto dft_all(In&& in) {
auto const all_true = std::apply([](auto... es) { return std::array{(es, true)...}; }, std::array<bool, std::decay_t<In>::dimensionality>{});
auto const all_true = std::apply([](auto... es) { return std::array{((void)es, true)...}; }, std::array<bool, std::decay_t<In>::dimensionality>{});
return dft(all_true, std::forward<In>(in), fft::forward);
}

Expand Down
2 changes: 1 addition & 1 deletion include/boost/multi/adaptors/fftw/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ foreach(TEST_FILE ${TEST_SRCS})
add_executable(${TEST_EXE} ${TEST_FILE})
if(ENABLE_CUDA OR DEFINED CXXCUDA)
set_source_files_properties(${TEST_FILE} PROPERTIES LANGUAGE CUDA)
target_compile_options(${TEST_EXE} PRIVATE -std=c++17)
# target_compile_options(${TEST_EXE} PRIVATE -std=c++17)
endif()

target_include_directories(${TEST_EXE} PRIVATE ${PROJECT_SOURCE_DIR}/include)
Expand Down
2 changes: 1 addition & 1 deletion include/boost/multi/adaptors/tblis/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ foreach(TEST_FILE ${TEST_SRCS})
add_executable(${TEST_EXE} ${TEST_FILE})
if(ENABLE_CUDA OR DEFINED CXXCUDA)
set_source_files_properties(${TEST_FILE} PROPERTIES LANGUAGE CUDA)
target_compile_options(${TEST_EXE} PRIVATE -std=c++17)
# target_compile_options(${TEST_EXE} PRIVATE -std=c++17)
endif()
# target_compile_features (${TEST_EXE} PUBLIC cxx_std_17)
target_compile_definitions(${TEST_EXE} PRIVATE "BOOST_PP_VARIADICS")
Expand Down

0 comments on commit 23bda93

Please sign in to comment.