Fix ScopePusher regression issue (#3637) (#3644)
* Remove torch_xla::ScopePusher and related functions

* Update torch_xla::ScopePusher to torch::lazy::ScopePusher

* Add unit test for ScopePusher

* Run linter

* Fix unit tests

* Update unit tests
wonjoolee95 authored Jun 10, 2022
1 parent eaba928 commit 60d39ee
Showing 8 changed files with 44 additions and 78 deletions.
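For context, the diffs below all follow one pattern: call sites that previously constructed `torch_xla::ScopePusher` now construct `torch::lazy::ScopePusher`, and the thread-local scope-tracking machinery in `ir.cpp`/`ir.h` is deleted in favor of the upstream implementation. The following is a minimal sketch of how the restored RAII guard behaves, based on the new unit test in this commit; the include paths, the `ScalarOp` helper availability, the `ScopePusherSketch` function, and the `MyModel.forward` scope name are illustrative assumptions, not part of the commit.

```cpp
// Sketch only -- header paths are assumptions based on the test added in this
// commit, not verified against the tree.
#include <torch/csrc/lazy/core/config.h>       // FLAGS_torch_lazy_ir_debug
#include <torch/csrc/lazy/core/ir_metadata.h>  // torch::lazy::ScopePusher

#include "torch_xla/csrc/ops/ops.h"            // ScalarOp (path assumed)

namespace torch_xla {

void ScopePusherSketch() {
  // Scope and frame metadata are only recorded while IR debugging is enabled.
  FLAGS_torch_lazy_ir_debug = true;
  {
    // RAII guard: pushes the scope for the current thread on construction.
    torch::lazy::ScopePusher scope("MyModel.forward");
    torch::lazy::NodePtr node = ScalarOp(1.0, xla::F32);
    // node->metadata().scope now contains "MyModel.forward"; the new unit
    // test asserts exactly this kind of substring match.
  }
  // Leaving the block pops the scope, so nodes created afterwards no longer
  // carry it. XLATensor::MarkStep() additionally calls
  // torch::lazy::ScopePusher::ResetScopes() to clear any leftover scopes at
  // step boundaries (see the tensor.cpp diff below).
}

}  // namespace torch_xla
```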
22 changes: 22 additions & 0 deletions test/cpp/test_ir.cpp
@@ -72,5 +72,27 @@ TEST(IrTest, TestSelectUnselect) {
});
}

TEST(IrTest, TestScopePusherWithoutDebugging) {
bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug;
FLAGS_torch_lazy_ir_debug = false;
torch::lazy::ScopePusher scope("TestScope");
torch::lazy::NodePtr nodeptr = ScalarOp(1.0, xla::F32);
auto metaWithScope = nodeptr->metadata();
EXPECT_EQ(metaWithScope.scope, "");
EXPECT_EQ(metaWithScope.frame_info.size(), 0);
FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug;
}

TEST(IrTest, TestScopePusherWithDebugging) {
bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug;
FLAGS_torch_lazy_ir_debug = true;
torch::lazy::ScopePusher scope("TestScope");
torch::lazy::NodePtr nodeptr = ScalarOp(1.0, xla::F32);
auto metaWithScope = nodeptr->metadata();
ASSERT_TRUE(metaWithScope.scope.find("TestScope") != std::string::npos);
EXPECT_EQ(metaWithScope.frame_info.size(), 1);
FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug;
}

} // namespace cpp_test
} // namespace torch_xla
14 changes: 8 additions & 6 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -798,12 +798,14 @@ void BuildProfilerSubmodule(py::module* m) {
.def_static("is_enabled",
&tensorflow::profiler::TraceMeWrapper::IsEnabled);

py::class_<ScopePusher, std::unique_ptr<ScopePusher>> scope_pusher_class(
profiler, "ScopePusher");
profiler.def("scope_pusher",
[](const std::string& name) -> std::unique_ptr<ScopePusher> {
return absl::make_unique<ScopePusher>(name);
});
py::class_<torch::lazy::ScopePusher,
std::unique_ptr<torch::lazy::ScopePusher>>
scope_pusher_class(profiler, "ScopePusher");
profiler.def(
"scope_pusher",
[](const std::string& name) -> std::unique_ptr<torch::lazy::ScopePusher> {
return absl::make_unique<torch::lazy::ScopePusher>(name);
});
}

void InitXlaModuleBindings(py::module m) {
48 changes: 0 additions & 48 deletions torch_xla/csrc/ir.cpp
@@ -19,48 +19,6 @@ namespace {
using ShapeCache =
xla::util::Cache<torch::lazy::hash_t, xla::Shape, torch::lazy::HashReducer>;

struct ScopeEntry {
std::string name;
size_t saved_next_id = 1;
};

struct ScopeContext {
std::vector<ScopeEntry> scopes;
size_t next_id = 1;
};

thread_local ScopeContext g_scope_context;

void PushScope(const std::string& name) {
size_t id = g_scope_context.next_id;
g_scope_context.scopes.push_back(
{absl::StrCat(name, ".", id), g_scope_context.next_id + 1});
g_scope_context.next_id = 1;
}

void PopScope() {
XLA_CHECK(!g_scope_context.scopes.empty());
g_scope_context.next_id = g_scope_context.scopes.back().saved_next_id;
g_scope_context.scopes.pop_back();
}

void ResetScopeContext() {
XLA_CHECK_EQ(g_scope_context.scopes.size(), 0);
g_scope_context.next_id = 1;
}

std::string GetCurrentScope() {
std::string scope;
for (auto& scope_entry : g_scope_context.scopes) {
if (scope.empty()) {
absl::StrAppend(&scope, scope_entry.name);
} else {
absl::StrAppend(&scope, "/", scope_entry.name);
}
}
return scope;
}

ShapeCache* GetShapeCache() {
static int64_t shape_cache_size =
xla::sys_util::GetEnvInt("XLA_IR_SHAPE_CACHE_SIZE", 4096);
@@ -202,12 +160,6 @@ xla::Shape XlaNode::GetOpShape(
return *shape;
}

ScopePusher::ScopePusher(const std::string& name) { PushScope(name); }

ScopePusher::~ScopePusher() { PopScope(); }

void ScopePusher::ResetScopes() { ResetScopeContext(); }

const xla::Shape& GetXlaShape(const torch::lazy::Value& value) {
XlaNode* casted = dynamic_cast<XlaNode*>(value.node.get());
return casted->xla_shape(value.index);
10 changes: 0 additions & 10 deletions torch_xla/csrc/ir.h
@@ -128,16 +128,6 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t dag_hash_;
};

// RAII data structure to be used a stack variable to enter a new IR scope. IR
// scope names will appear in the IR and will help identifying the source of the
// single IR nodes.
struct ScopePusher {
explicit ScopePusher(const std::string& name);
~ScopePusher();

static void ResetScopes();
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
stream << node.ToString();
return stream;
18 changes: 9 additions & 9 deletions torch_xla/csrc/ops/ops.cpp
@@ -659,7 +659,7 @@ torch::lazy::NodePtr Norm(const torch::lazy::Value& input,
const c10::optional<at::Scalar>& p,
c10::optional<at::ScalarType> dtype,
absl::Span<const int64_t> dims, bool keepdim) {
ScopePusher ir_scope(at::aten::norm.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::norm.toQualString());
auto dimensions = torch::lazy::ToVector<int64_t>(dims);
if (dimensions.empty()) {
dimensions = torch::lazy::Iota<int64_t>(GetXlaShape(input).rank());
@@ -769,31 +769,31 @@ torch::lazy::NodePtr GeluBackward(const torch::lazy::Value& grad_output,

torch::lazy::NodePtr Lshift(const torch::lazy::Value& input,
const at::Scalar& other) {
ScopePusher ir_scope(at::aten::__lshift__.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::__lshift__.toQualString());
return input * ScalarOp(pow(2, other.to<double>()), GetXlaShape(input));
}

torch::lazy::NodePtr Lshift(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
ScopePusher ir_scope(at::aten::__lshift__.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::__lshift__.toQualString());
return input * Pow(ScalarOp(2, GetXlaShape(input)), other);
}

torch::lazy::NodePtr Rshift(const torch::lazy::Value& input,
const at::Scalar& other) {
ScopePusher ir_scope(at::aten::__rshift__.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::__rshift__.toQualString());
return input / ScalarOp(pow(2, other.to<double>()), GetXlaShape(input));
}

torch::lazy::NodePtr Rshift(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
ScopePusher ir_scope(at::aten::__rshift__.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::__rshift__.toQualString());
return input / Pow(ScalarOp(2, GetXlaShape(input)), other);
}

torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
const torch::lazy::Value& divisor) {
ScopePusher ir_scope(at::aten::remainder.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::remainder.toQualString());
torch::lazy::NodePtr f = Fmod(
input,
torch::lazy::MakeNode<Abs>(divisor, std::vector<torch::lazy::Shape>()));
@@ -865,7 +865,7 @@ torch::lazy::NodePtr Take(const torch::lazy::Value& input,

torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
// TODO: add proper lowering function
ScopePusher ir_scope("aten::tanh_gelu");
torch::lazy::ScopePusher ir_scope("aten::tanh_gelu");
const xla::Shape& shape = GetXlaShape(input);
// inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(input, 3))
// input * 0.5 * (1.0 + torch.tanh(inner))
@@ -882,7 +882,7 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
const torch::lazy::Value& input) {
// TODO: add proper lowering function
ScopePusher ir_scope("aten::tanh_gelu_backward");
torch::lazy::ScopePusher ir_scope("aten::tanh_gelu_backward");
const xla::Shape& shape = GetXlaShape(input);
constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
torch::lazy::NodePtr beta = ScalarOp(kBeta, shape);
@@ -975,7 +975,7 @@ torch::lazy::NodePtr BaddBmm(const torch::lazy::Value& lhs,
torch::lazy::NodePtr Lerp(const torch::lazy::Value& start,
const torch::lazy::Value& end,
const torch::lazy::Value& weight) {
ScopePusher ir_scope(at::aten::lerp.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::lerp.toQualString());
return start + weight * (end - start);
}

2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.cpp
@@ -1491,7 +1491,7 @@ void XLATensor::SyncLiveTensorsGraph(const torch::lazy::BackendDevice* device,
void XLATensor::MarkStep(const torch::lazy::BackendDevice& device) {
XLA_COUNTER("MarkStep", 1);
DeviceContextArena::Get()->MarkStep(device);
ScopePusher::ResetScopes();
torch::lazy::ScopePusher::ResetScopes();
g_tls_data.Reset();
}

4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1806,15 +1806,15 @@ XLATensor XLATensor::lt(const XLATensor& input, const XLATensor& other) {

void XLATensor::masked_fill_(XLATensor& input, const XLATensor& mask,
const at::Scalar& value) {
ScopePusher ir_scope(at::aten::masked_fill.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::masked_fill.toQualString());
input.SetIrValue(torch::lazy::MakeNode<MaskedFill>(
input.GetIrValue(), MaybeExpand(mask.GetIrValue(), input.shape()),
value));
}

void XLATensor::masked_scatter_(XLATensor& input, const XLATensor& mask,
const XLATensor& source) {
ScopePusher ir_scope(at::aten::masked_scatter.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::masked_scatter.toQualString());
input.SetIrValue(torch::lazy::MakeNode<MaskedScatter>(
input.GetIrValue(), MaybeExpand(mask.GetIrValue(), input.shape()),
source.GetIrValue()));
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor_ops.cpp
@@ -98,7 +98,7 @@ XLATensor MakeMatrixWithDiagonal(const XLATensor& input, int64_t diagonal) {

XLATensor SmoothL1Loss(const XLATensor& input, const XLATensor& target,
ReductionMode reduction, double beta) {
torch_xla::ScopePusher ir_scope(at::aten::smooth_l1_loss.toQualString());
torch::lazy::ScopePusher ir_scope(at::aten::smooth_l1_loss.toQualString());
auto broadcasted_inputs = XLATensor::broadcast_tensors({input, target});
XLA_CHECK_EQ(broadcasted_inputs.size(), 2);
const XLATensor& broadcasted_input = broadcasted_inputs[0];
@@ -134,7 +134,7 @@ XLATensor SmoothL1Loss(const XLATensor& input, const XLATensor& target,
XLATensor SmoothL1LossBackward(const XLATensor& grad_output,
const XLATensor& input, const XLATensor& target,
ReductionMode reduction, double beta) {
torch_xla::ScopePusher ir_scope(
torch::lazy::ScopePusher ir_scope(
at::aten::smooth_l1_loss_backward.toQualString());
auto broadcasted_inputs = XLATensor::broadcast_tensors({input, target});
XLA_CHECK_EQ(broadcasted_inputs.size(), 2);
