Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <utility>

#include "device_lower/analysis/circular_buffer.h"
#include "device_lower/analysis/trivial_broadcast.h"
#include "device_lower/lower2device.h"
#include "device_lower/utils.h"
#include "disjoint_set.h"
Expand All @@ -25,7 +24,6 @@
#include "iter_visitor.h"
#include "logical_domain_map.h"
#include "transform_iter.h"
#include "val_graph_visitor.h"

namespace nvfuser {

Expand Down Expand Up @@ -495,6 +493,61 @@ std::vector<std::vector<Val*>> getTriviallyMappedIds(Expr* expr) {
return mapped_ids;
}

// Maps the outer output of a split-of-split with the outer output of a
// single split of the same root when both have the same extent.
//
// This is a subpattern of
// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
//
//   outer, _ = split(root)
//   outermost_grand, _ = split(outer)
//   outer', _ = split(root)
//
// If outermost_grand and outer' have the same extent, they are mapped in
// `graph`. Only the front expression of each expression group is inspected.
void mapSplitOfSplit(ValGraph& graph) {
  // Front expression of `group` viewed as a Split, or nullptr when it is
  // some other kind of expression.
  const auto front_as_split = [](const ExprGroup& group) -> Split* {
    return dynamic_cast<Split*>(group->front());
  };
  // Extent of the front IterDomain of `group`.
  const auto front_extent = [](const ValGroup& group) -> Val* {
    return group->front()->as<IterDomain>()->extent();
  };

  // Mappings are collected first and applied only after the traversal, so
  // the graph is not modified while its groups are being iterated.
  std::vector<std::pair<Val*, Val*>> pending_mappings;

  for (const ValGroup& root_group : graph.disjointValSets().disjointSets()) {
    const ExprGroups& root_uses = graph.getUses(root_group);

    // Pass 1: gather every outer output reachable through two consecutive
    // splits that each follow the outer output.
    std::vector<ValGroup> grand_outers;
    for (const ExprGroup& root_use : root_uses) {
      Split* first_split = front_as_split(root_use);
      if (first_split == nullptr) {
        continue;
      }
      // Only follow the outer output of the first split; outer and inner
      // must not be conflated.
      const ValGroup& first_outer = graph.toGroup(first_split->outer());
      for (const ExprGroup& outer_use : graph.getUses(first_outer)) {
        if (Split* second_split = front_as_split(outer_use);
            second_split != nullptr) {
          grand_outers.push_back(graph.toGroup(second_split->outer()));
        }
      }
    }

    // Pass 2: pair each gathered grand-outer with every direct split of the
    // root whose outer output has the same extent.
    for (const ValGroup& grand_outer : grand_outers) {
      Val* grand_extent = front_extent(grand_outer);

      for (const ExprGroup& root_use : root_uses) {
        Split* direct_split = front_as_split(root_use);
        if (direct_split == nullptr) {
          continue;
        }

        const ValGroup& direct_outer = graph.toGroup(direct_split->outer());
        if (front_extent(direct_outer)->sameAs(grand_extent)) {
          pending_mappings.emplace_back(
              grand_outer->front(), direct_outer->front());
        }
      }
    }
  }

  for (const auto& [grand_id, outer_id] : pending_mappings) {
    graph.mapVals(grand_id, outer_id);
  }
}

} // namespace

ValGraph& IdModel::buildAlmostExactGraph() {
Expand Down Expand Up @@ -554,6 +607,8 @@ ValGraph& IdModel::buildAlmostExactGraph() {
almost_exact_graph.mapVals(id1, id2);
}

mapSplitOfSplit(almost_exact_graph);

almost_exact_graph.validateConsistency();

if (!allow_self_mapping_) {
Expand Down
9 changes: 5 additions & 4 deletions csrc/val_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,12 @@ const ExprGroups& ValGraph::getDefinitions(const ValGroup& val_group) const {

// Returns the expression groups recorded in unique_uses_ for `val_group`.
//
// `val_group` must be non-null. A group with no entry in unique_uses_ is
// not an error: a shared empty ExprGroups is returned instead, so callers
// that walk every group of the graph can iterate the result without first
// checking whether the group has any uses.
const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
  NVF_ERROR(val_group, "Nullptr not allowed");

  // Immutable fallback handed out by reference when no uses are recorded.
  static const ExprGroups empty_expr_groups{};
  const auto it = unique_uses_.find(val_group);
  if (it == unique_uses_.end()) {
    return empty_expr_groups;
  }
  return it->second;
}

Expand Down
81 changes: 55 additions & 26 deletions tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
*/
// clang-format on

#include <fstream>

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

Expand All @@ -17,7 +15,6 @@
#include "id_model/loop_promotion.h"
#include "id_model/schedule.h"
#include "id_model/to_string.h"
#include "ir/graphviz.h"
#include "ops/all_ops.h"
#include "scheduler/tools/inlining.h"
#include "scheduler/tools/resize_utils.h"
Expand Down Expand Up @@ -235,8 +232,7 @@ void validateIELResolution(
auto promotion_id = iel_promotion_map_it->second;
ASSERT_TRUE(
exact_graph.disjointValSets().strictAreMapped(promotion_id, ref_id))
<< "Unexpected promotion. "
<< "Expected: " << ref_id->toString()
<< "Unexpected promotion. Expected: " << ref_id->toString()
<< ". Actual: " << promotion_id->toString();
ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped(id, promotion_id))
<< "Promotion of " << id->toString()
Expand Down Expand Up @@ -376,9 +372,9 @@ void checkStep4Results(
const auto& iel_promotion_map = tester.s4_iel_promotion_map;

EXPECT_EQ(iel_promotion_map.size(), ref_promotion_map.size())
<< "Mismatched Step-4 result map. "
<< "Expected to have " << ref_promotion_map.size()
<< " mappings but found " << iel_promotion_map.size();
<< "Mismatched Step-4 result map. Expected to have "
<< ref_promotion_map.size() << " mappings but found "
<< iel_promotion_map.size();

for (const auto& ref_promotion_pair : ref_promotion_map) {
const auto& ref_promotion_group = ref_promotion_pair.first;
Expand Down Expand Up @@ -2937,9 +2933,8 @@ TEST_F(IdModelTest, LoopPromotionCyclicGraphWar) {
// Test to verify the split-aware covered group analysis. See
// also https://github.com/NVIDIA/Fuser/pull/3877.
TEST_F(IdModelTest, CoveredGroups) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({-1, 1});
fusion.addInput(tv0);
Expand Down Expand Up @@ -3000,7 +2995,7 @@ TEST_F(IdModelTest, CoveredGroups) {
TEST_F(IdModelTest, InvalidLoopPromotion) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
FusionGuard fg(&fusion);

auto T0 = makeContigConcreteTensor({1, 32, 6});
fusion.addInput(T0);
Expand Down Expand Up @@ -3086,9 +3081,8 @@ TEST_F(IdModelTest, InvalidLoopPromotion) {
// When a loop group only includes broadcast IDs, the group should not
// need to be promoted
TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({-1, 1});
fusion.addInput(tv0);
Expand Down Expand Up @@ -3130,9 +3124,8 @@ TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {

// Scatter output uses unique mapping schemes
TEST_F(IdModelTest, ScatterLoopMapping) {
auto fusion_ptr = std::make_unique<Fusion>();
auto& fusion = *fusion_ptr;
FusionGuard fg(fusion_ptr.get());
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(1);
fusion.addInput(tv0);
Expand Down Expand Up @@ -3185,8 +3178,7 @@ TEST_F(IdModelTest, ScatterLoopMapping) {
// required but is a WAR for special ops like
// PreprocessGroupedMatmulInputSf. See also issue #5391.
TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2);
Expand Down Expand Up @@ -3219,8 +3211,7 @@ TEST_F(IdModelTest, LoopPromotionIncludeOnlyLoopIds) {
}

TEST_F(IdModelTest, PermissiveResizeGraph) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({36});
Expand Down Expand Up @@ -3270,8 +3261,7 @@ TEST_F(IdModelTest, PermissiveResizeGraph) {
// This is the failing segment of the reproducer of
// https://github.com/NVIDIA/Fuser/issues/5803.
TEST_F(IdModelTest, ReproIssue5803) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv2 = makeContigConcreteTensor({4}, DataType::Int);
Expand Down Expand Up @@ -3312,8 +3302,7 @@ TEST_F(IdModelTest, ReproIssue5803) {
// This is a minimal fusion pattern to trigger the loop promotion
// issue as reported in https://github.com/NVIDIA/Fuser/issues/5803
TEST_F(IdModelTest, ReproIssue5803Minimal) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({4, 8});
Expand All @@ -3337,4 +3326,44 @@ TEST_F(IdModelTest, ReproIssue5803Minimal) {
IdModel id_model(&fusion, true);
}

// A reshape that splits the input, followed by equal outer splits on both
// the input and the output: the outermost axes must be mapped in the
// almost-exact graph, while the other checked axis pairs must stay unmapped.
TEST_F(IdModelTest, SplitingReshape) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* input = makeContigConcreteTensor({2 * 2 * 2});
  fusion.addInput(input);
  TensorView* output = reshape(input, {2 * 2 * 2}, {2 * 2, 2});
  fusion.addOutput(output);

  input->outer_split(0, 2);
  output->outer_split(0, 2);

  IdModel id_model(&fusion);
  const ValGraph& graph = id_model.buildAlmostExactGraph();
  const auto& val_sets = graph.disjointValSets();

  // Outermost axes on both sides are expected to be mapped.
  EXPECT_TRUE(val_sets.strictAreMapped(input->axis(0), output->axis(0)));
  // The remaining checked pairs must not be mapped.
  EXPECT_FALSE(val_sets.strictAreMapped(input->axis(0), output->axis(1)));
  EXPECT_FALSE(val_sets.strictAreMapped(input->axis(1), output->axis(2)));
}

// Negative counterpart of the splitting-reshape case: the input and output
// are outer-split by different factors (2 vs. 3), so their outermost axes
// must not be mapped in the almost-exact graph.
TEST_F(IdModelTest, SplitingReshape_DifferentExtents) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* input = makeContigConcreteTensor({12});
  fusion.addInput(input);
  TensorView* output = reshape(input, {12}, {6, 2});
  fusion.addOutput(output);

  input->outer_split(0, 2);
  output->outer_split(0, 3);

  IdModel id_model(&fusion);
  const ValGraph& graph = id_model.buildAlmostExactGraph();
  const auto& val_sets = graph.disjointValSets();

  EXPECT_FALSE(val_sets.strictAreMapped(input->axis(0), output->axis(0)));
}

} // namespace nvfuser