
Commit f9df694

Make add_relu an internal function (pytorch#46676) (pytorch#46765)
Summary: Cleanup for 1.7

Pull Request resolved: pytorch#46676

Reviewed By: gchanan

Differential Revision: D24458565

Pulled By: albanD

fbshipit-source-id: b1e4b4630233d3f1a4bac20e3077411d1ae17f7b

# Conflicts:
#	test/backward_compatibility/check_backward_compatibility.py
1 parent 6394982 commit f9df694
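What this means for callers: the fused add+relu kernel is no longer exposed as a public torch function; it stays reachable only through internal entry points, which the updated tests below exercise. A minimal sketch of the resulting behavior, assuming the internal torch._VF path used by test/test_nn.py in this diff (internal, subject to change without notice):

import torch

a, b = torch.randn(4), torch.randn(4)

# Reference result: plain add followed by relu.
expected = torch.relu(a + b)

# torch.add_relu is gone from the public namespace after this change;
# the fused kernel is reached through the internal torch._VF binding.
fused = torch._VF._add_relu(a, b)

assert torch.allclose(fused, expected)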

File tree

9 files changed: +21 -20 lines changed


aten/src/ATen/core/NamedRegistrations.cpp

Lines changed: 3 additions & 3 deletions
@@ -37,9 +37,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   m.impl("add.out", CppFunction::makeFallthrough());
   m.impl("add_.Scalar", CppFunction::makeFallthrough());
   m.impl("add_.Tensor", CppFunction::makeFallthrough());
-  m.impl("add_relu.Tensor", CppFunction::makeFallthrough());
-  m.impl("add_relu.out", CppFunction::makeFallthrough());
-  m.impl("add_relu_.Tensor", CppFunction::makeFallthrough());
+  m.impl("_add_relu.Tensor", CppFunction::makeFallthrough());
+  m.impl("_add_relu.out", CppFunction::makeFallthrough());
+  m.impl("_add_relu_.Tensor", CppFunction::makeFallthrough());
   m.impl("addcdiv", CppFunction::makeFallthrough());
   m.impl("addcdiv.out", CppFunction::makeFallthrough());
   m.impl("addcdiv_", CppFunction::makeFallthrough());

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 3 deletions
@@ -368,19 +368,19 @@
     SparseCUDA: add_out_sparse_cuda
     MkldnnCPU: mkldnn_add_out

-- func: add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
   variants: function
   dispatch:
     CPU: add_relu

-- func: add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+- func: _add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
   use_c10_dispatcher: full
   variants: function
   dispatch:
     CPU: add_relu_

-- func: add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+- func: _add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
     CPU: add_relu_out
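Schemas registered in native_functions.yaml remain visible to the dispatcher even when they are not bound as public torch functions. A hedged sketch of calling the renamed op through the generic torch.ops namespace (this access path is shown for illustration only and is not a supported API):

import torch

a, b = torch.randn(3), torch.randn(3)

# Resolves the "_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)"
# schema declared above via the dispatcher's Python op registry.
out = torch.ops.aten._add_relu(a, b, alpha=1)

print(torch.allclose(out, torch.relu(a + b)))  # expected: True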

test/backward_compatibility/check_backward_compatibility.py

Lines changed: 2 additions & 0 deletions
@@ -118,6 +118,8 @@
     ('aten::contiguous', datetime.date(2020, 12, 1)),
     ('aten::to', datetime.date(2020, 12, 1)),
     ("tensorexpr::Group", datetime.date(2020, 12, 1)),
+    ("aten::add_relu", datetime.date(2020, 12, 1)),
+    ("aten::add_relu_", datetime.date(2020, 12, 1)),
 ]
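Each allow-list entry tells the backward-compatibility checker to ignore the disappearance of a schema until the given date; here the two removed public add_relu schemas are whitelisted through 2020-12-01. A rough sketch of observing the removal, using the same schema enumeration the checker script relies on (the private _jit_get_all_schemas call is an internal detail):

import torch

names = {schema.name for schema in torch._C._jit_get_all_schemas()}
print("aten::add_relu" in names)   # expected: False after this change
print("aten::_add_relu" in names)  # expected: True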

test/test_jit.py

Lines changed: 6 additions & 6 deletions
@@ -581,7 +581,7 @@ def forward(self, a, b, c):
         m = torch.jit.load(buffer)
         new_res = m(a, b, c)
         FileCheck().check_not("aten::relu(") \
-            .check("aten::add_relu(") \
+            .check("aten::_add_relu(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)

@@ -600,7 +600,7 @@ def forward(self, a, b, c):
         m = torch.jit.load(buffer)
         new_res = m(a, b, c)
         FileCheck().check_not("aten::relu_(") \
-            .check("aten::add_relu(") \
+            .check("aten::_add_relu(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)

@@ -631,10 +631,10 @@ def forward(self, a, b):
         new_res = m(a_copy, b)
         FileCheck().check_not("aten::add_(") \
             .check_not("aten::relu_(") \
-            .check("aten::add_relu_(") \
+            .check("aten::_add_relu_(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)
-        # Since add_relu_ does inplace mutation ensure
+        # Since _add_relu_ does inplace mutation ensure
         # a_copy is modified
         torch.testing.assert_allclose(orig_res, a_copy)

@@ -669,10 +669,10 @@ def forward(self, a, b):
         new_res = m(a_copy, b)
         FileCheck().check_not("aten::add(") \
             .check_not("aten::relu_(") \
-            .check("aten::add_relu(") \
+            .check("aten::_add_relu(") \
             .run(m.graph)
         torch.testing.assert_allclose(orig_res, new_res)
-        # Since add_relu_ with out=a does inplace mutation ensure
+        # Since _add_relu_ with out=a does inplace mutation ensure
         # a_copy is modified
         torch.testing.assert_allclose(orig_res, a_copy)
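The FileCheck assertions above run against the textual JIT graph after the add/relu fusion has fired. A standalone sketch of the same check, assuming the internal pass binding torch._C._jit_pass_fuse_add_relu (an assumption about how the pass is driven from Python; not a public API):

import torch
from torch.testing import FileCheck

class M(torch.nn.Module):
    def forward(self, a, b):
        return torch.relu(a + b)

m = torch.jit.script(M())

# Assumed internal binding for the rewrite implemented in fuse_relu.cpp;
# it replaces add-followed-by-relu with the fused aten::_add_relu node.
torch._C._jit_pass_fuse_add_relu(m.graph)

FileCheck().check_not("aten::relu(") \
           .check("aten::_add_relu(") \
           .run(m.graph)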

test/test_mobile_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def forward(self, x):
             .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
             .check_not("aten::add(") \
             .check_not("aten::relu(") \
-            .check_count("aten::add_relu(", 1, exactly=True) \
+            .check_count("aten::_add_relu(", 1, exactly=True) \
             .run(optimized_scripted_model.graph)
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
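optimize_for_mobile is one of the entry points that ends up emitting the fused op, which is why this graph check tracks the rename. A rough usage sketch (the toy module and the expectation that this particular pattern fuses are assumptions for illustration):

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)

scripted = torch.jit.script(AddRelu().eval())
optimized = optimize_for_mobile(scripted)

# After optimization the graph should reference the internal fused op
# aten::_add_relu rather than separate aten::add / aten::relu nodes.
print(optimized.graph)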

test/test_nn.py

Lines changed: 1 addition & 1 deletion
@@ -9152,7 +9152,7 @@ def test_add_relu(self):
         a = a + 5
         add_res = a + b
         relu_res = torch.relu(add_res)
-        add_relu_res = torch.add_relu(a, b)
+        add_relu_res = torch._VF._add_relu(a, b)

         self.assertTrue(torch.allclose(add_relu_res, relu_res))

tools/code_analyzer/default_op_deps.yaml

Lines changed: 2 additions & 2 deletions
@@ -1903,7 +1903,7 @@
   - name: aten::resize_as_
   - name: aten::scalar_tensor
   - name: aten::to
-- name: aten::add_relu
+- name: aten::_add_relu
   depends:
   - name: aten::as_strided_
   - name: aten::copy_

@@ -1915,7 +1915,7 @@
   - name: aten::resize_
   - name: aten::resize_as_
   - name: aten::to
-- name: aten::add_relu_
+- name: aten::_add_relu_
   depends:
   - name: aten::as_strided_
   - name: aten::copy_

torch/csrc/jit/passes/fuse_relu.cpp

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@ void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
       return (%res))";
   std::string add_relu_fused = R"(
     graph(%a, %b, %alpha):
-      %res = aten::add_relu(%a, %b, %alpha)
+      %res = aten::_add_relu(%a, %b, %alpha)
       return (%res))";
   rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused);

@@ -35,7 +35,7 @@ void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
       return (%res))";
   std::string add_inplace_relu_fused = R"(
     graph(%a, %b, %alpha):
-      %res = aten::add_relu_(%a, %b, %alpha)
+      %res = aten::_add_relu_(%a, %b, %alpha)
       return (%res))";
   rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused);

@@ -46,7 +46,7 @@ void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
       return (%res))";
   std::string add_out_relu_fused = R"(
     graph(%a, %b, %alpha, %out):
-      %res = aten::add_relu(%a, %b, %alpha, %out)
+      %res = aten::_add_relu(%a, %b, %alpha, %out)
       return (%res))";

   rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused);

torch/overrides.py

Lines changed: 0 additions & 1 deletion
@@ -211,7 +211,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.arccos: lambda input, out=None: -1,
         torch.acosh: lambda input, out=None: -1,
         torch.arccosh: lambda input, out=None: -1,
-        torch.add_relu: lambda input, other, out=None: -1,
         torch.add: lambda input, other, out=None: -1,
         torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
         torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
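get_testing_overrides() maps every overridable public torch callable to a dummy lambda, so the entry for the removed public binding is dropped along with it. A quick sketch of confirming the public-surface change with the public torch.overrides helper:

import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
print(hasattr(torch, "add_relu"))  # expected: False after this change
print(torch.add in overrides)      # torch.add itself is still listed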
