Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions csrc/device_lower/analysis/tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ std::unordered_set<const Expr*> getBatchableTmaLoads(
// We have some tests where TMA load is used in an untraditional way.
// e.g. parallelized with threads, serial load, which requires multiple
// mbarriers or reuse of the same mbarrier.
if (std::any_of(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[](const IterDomain* id) {
return id->isThreadDim() ||
id->getParallelType() == ParallelType::Serial;
})) {
auto non_trivial_ids =
tv->getLoopDomain() | std::views::filter([](const IterDomain* id) {
return !id->extent()->isConstScalar() ||
id->extent()->evaluate().as<int64_t>() > 1;
});
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
return id->isThreadDim() ||
id->getParallelType() == ParallelType::Serial;
})) {
return {};
}
non_cb_tma_load_exprs.push_back(expr);
Expand Down
3 changes: 2 additions & 1 deletion csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ auto parseEnvOptions(
available_options.end(),
std::back_inserter(option_values),
[](const auto& kv) { return kv.first; });
std::sort(option_values.begin(), option_values.end());
std::ranges::sort(option_values);
NVF_CHECK(
false,
"Parsing ",
Expand Down Expand Up @@ -174,6 +174,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"tma_pointwise", EnableOption::TmaPointwise},
{"tma_inner_persistent", EnableOption::TmaInnerPersistent},
{"tma_reduction", EnableOption::TmaReduction},
{"tma_transpose", EnableOption::TmaTranspose},
{"ws_normalization", EnableOption::WarpSpecializedNormalization},
{"host_ir_lowering", EnableOption::HostIrLowering},
{"host_ir_jit", EnableOption::HostIrJit},
Expand Down
19 changes: 11 additions & 8 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <visibility.h>

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>
Expand All @@ -22,7 +23,7 @@ namespace nvfuser {
//!
//! These can be set through the `NVFUSER_DUMP` environment variable
//!
enum class DebugDumpOption {
enum class DebugDumpOption : std::uint8_t {
CutlassCompile, //!< Dump compile commands and compile times for
//!< CutlassExecutor
FunctionTrace, //!< Dump the function trace of selected internal function. The
Expand Down Expand Up @@ -97,7 +98,7 @@ enum class DebugDumpOption {
//!
//! These can be set through the `NVFUSER_ENABLE` environment variable
//!
enum class EnableOption {
enum class EnableOption : std::uint8_t {
CutlassScheduler, //! Enable the CUTLASS scheduler and executor
FuseMatmul, //! Enable automatic fusion of matmul and linear ops
FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel
Expand All @@ -118,6 +119,7 @@ enum class EnableOption {
TmaPointwise, //! Enable TMA pointwise kernel
TmaInnerPersistent, //! Enable TMA inner persistent kernel
TmaReduction, //! Enable TMA reduction kernel
TmaTranspose, //! Enable TMA transpose kernel
WarpSpecializedNormalization, //! Enable warp specialized persistent kernel
HostIrLowering, //! Enable FusionKernelRuntime lowering to host IR
HostIrJit, //! Enable Host IR JIT compilation with LLVM
Expand All @@ -134,7 +136,7 @@ enum class EnableOption {
//!
//! These can be set through the `NVFUSER_DISABLE` environment variable
//!
enum class DisableOption {
enum class DisableOption : std::uint8_t {
CompileToSass, //! Disable direct compilation to sass so the ptx can be
//! examined
ContigIndexing, //! Disable contiguous indexing
Expand Down Expand Up @@ -176,7 +178,7 @@ enum class DisableOption {
//!
//! These can be set through the `NVFUSER_PROF` environment variable
//!
enum class ProfilerOption {
enum class ProfilerOption : std::uint8_t {
Enable, //! Enables the profiler.
EnableNocupti, //! Enables the profiler, but disables CUPTI specific
//! profiling in order to measure true host time without
Expand All @@ -197,10 +199,11 @@ class Options {
public:
Options() : options_(getOptionsFromEnv()) {}

Options(const Options& other) {
std::lock_guard<std::mutex> lock_other(other.mutex_);
options_ = other.options_;
}
Options(const Options& other)
: options_([&other]() {
std::lock_guard<std::mutex> lock_other(other.mutex_);
return other.options_;
}()) {}

Options& operator=(const Options& other) {
std::lock_guard<std::mutex> lock_other(other.mutex_);
Expand Down
10 changes: 6 additions & 4 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,11 @@ std::unique_ptr<HeuristicParams> TransposeScheduler::computeHeuristics(

std::unique_ptr<TransposeParams> tparams = nullptr;

// Try TMA path first
tparams =
transpose::tma::getTransposeHeuristics(fusion, runtime_info, data_cache);
// Try TMA path first if enabled
if (isOptionEnabled(EnableOption::TmaTranspose)) {
tparams = transpose::tma::getTransposeHeuristics(
fusion, runtime_info, data_cache);
}

// Fallback to non-TMA scheduler if TMA is not applicable
if (tparams == nullptr) {
Expand All @@ -431,7 +433,7 @@ void TransposeScheduler::schedule(
"Incorrect parameters sent to TransposeScheduler::schedule",
params);

if (tparams->use_tma_load) {
if (tparams->use_tma_load || tparams->use_tma_store) {
transpose::tma::scheduleTranspose(fusion, tparams);
} else {
transpose::non_tma::scheduleTranspose(fusion, tparams);
Expand Down
32 changes: 31 additions & 1 deletion csrc/scheduler/transpose_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace nvfuser {
// are equivalent!
class TransposeParams : public HeuristicParams {
public:
TransposeParams() : HeuristicParams(SchedulerType::Transpose) {};
TransposeParams() : HeuristicParams(SchedulerType::Transpose){};
static constexpr int64_t getMaxThreadsPerBlock() {
return 128;
}
Expand All @@ -39,6 +39,20 @@ class TransposeParams : public HeuristicParams {

// Whether to use TMA for loading inputs
bool use_tma_load = false;
bool use_tma_store = false;

// Which side of shared memory holds the transposed (swizzled) layout.
// false: input smem is swizzled, transpose happens on smem->register read.
// true: output smem is swizzled, transpose happens on register->smem write.
// This is independent of use_tma_load/use_tma_store — TMA can be used for
// either side regardless of where the transpose swizzle lives.
bool is_output_smem_transpose = false;

// In a 128-byte swizzled TMA load, the innermost dim is split into 8 chunks of
// 16 bytes each. Each thread may handle multiple chunks along the innermost
// dim.
int64_t chunks_per_thread = 1;
int64_t elements_per_chunk = 1;

// Vectorization factor for tensors in the first group
int64_t vectorize_factor1 = 1;
Expand All @@ -65,6 +79,10 @@ class TransposeParams : public HeuristicParams {
}
bool attr_equal = other->cparams == cparams &&
other->use_tma_load == use_tma_load &&
other->use_tma_store == use_tma_store &&
other->is_output_smem_transpose == is_output_smem_transpose &&
other->chunks_per_thread == chunks_per_thread &&
other->elements_per_chunk == elements_per_chunk &&
other->split_before_tiling == split_before_tiling &&
other->dims_merged_with_1 == dims_merged_with_1 &&
other->dims_merged_with_2 == dims_merged_with_2 &&
Expand Down Expand Up @@ -99,6 +117,14 @@ class TransposeParams : public HeuristicParams {
if (unroll_factor2 > 1) {
ss << "Unroll group 2, Factor: " << unroll_factor2 << "\n";
}
if (use_tma_load || use_tma_store) {
ss << "TMA: load=" << (use_tma_load ? "true" : "false")
<< " store=" << (use_tma_store ? "true" : "false")
<< " is_output_smem_transpose="
<< (is_output_smem_transpose ? "true" : "false")
<< " chunks_per_thread=" << chunks_per_thread
<< " elements_per_chunk=" << elements_per_chunk << "\n";
}
if (!split_before_tiling.empty() || !dims_merged_with_1.empty() ||
!dims_merged_with_2.empty()) {
ss << "Virtual inner-most dim:\n";
Expand Down Expand Up @@ -146,6 +172,10 @@ class TransposeParams : public HeuristicParams {
size_t hash() const override {
return c10::get_hash(
use_tma_load,
use_tma_store,
is_output_smem_transpose,
chunks_per_thread,
elements_per_chunk,
split_before_tiling,
dims_merged_with_1,
dims_merged_with_2,
Expand Down
Loading
Loading