Skip to content

Commit 8aa878f

Browse files
authored
[jit] fix tuple alias analysis (pytorch#41992)
Previously when analyzing a TupleConstruct, we ignored the aliasing information of the inputs and simply marked all elements of the returned tuple as wildcards. But since we can fully reason about the contents of a tuple statically, we should be able to assign them aliasing information. This analysis was not only incomplete but produced incorrect results, since if `a` is not a wildcard, `a noalias wilcard`. So if we looked at `tuple(a)` and reported the aliasing info as `tuple(wildcard)`, then `tuple[0] noalias a`, which is...wrong.
1 parent 7c7c9c3 commit 8aa878f

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

test/jit/test_freezing.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -794,9 +794,6 @@ def forward(self, x):
794794
expected = m_s.forward(inp)
795795
self.assertEqual(out, expected)
796796

797-
# Check attribute a is preserved. Alias analysis detects that 'a' has output writers.
798-
# In this example, 'a' is not mutated. However, we do not track which sub
799-
# values of a composite ivalue is mutated.
800797
def test_freeze_module_with_aliased_attr2(self):
801798
class FreezeMe(nn.Module):
802799
def __init__(self):
@@ -815,7 +812,6 @@ def forward(self, x):
815812
m_s = torch.jit.script(m)
816813
m_s.eval()
817814
m_f = torch._C._freeze_module(m_s._c)
818-
self.assertTrue(m_f.hasattr('a'))
819815
inp = torch.tensor([5])
820816
out = m_f.forward(inp)
821817
expected = m.forward(inp)

torch/csrc/jit/ir/alias_analysis.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ void AliasDb::analyzeImpl(Node* node) {
498498
case prim::tolist:
499499
return analyzeCreator(node);
500500
case prim::TupleConstruct:
501+
return analyzeTupleConstruct(node);
501502
case prim::DictConstruct:
502503
case prim::ListConstruct:
503504
return analyzeContainerConstruct(node);
@@ -864,20 +865,30 @@ void AliasDb::analyzeConservative(Node* node) {
864865
}
865866
}
866867

868+
void AliasDb::analyzeTupleConstruct(Node* node) {
869+
TORCH_INTERNAL_ASSERT(node->kind() == prim::TupleConstruct);
870+
// tuples which contain immutable types are immutable
871+
if (!isMutableTypeInternal(node->output())) {
872+
return;
873+
}
874+
875+
giveFreshAlias(node->output());
876+
877+
for (const auto& input : node->inputs()) {
878+
if (isMutableTypeInternal(input)) {
879+
addToContainedElements(input, node->output());
880+
}
881+
}
882+
}
883+
867884
// List or dict or tuple: construct: create an aliasing element for the actual
868885
// container, then mark all inputs as wildcards, since they've gone inside the
869886
// container. Then, add the wildcard sets of appropriate type to the contained
870887
// elements of the container.
871888
void AliasDb::analyzeContainerConstruct(Node* node) {
872889
TORCH_INTERNAL_ASSERT(
873890
node->kind() == prim::ListConstruct ||
874-
node->kind() == prim::DictConstruct ||
875-
node->kind() == prim::TupleConstruct);
876-
877-
// tuples which contain immutable types are immutable
878-
if (!isMutableTypeInternal(node->output())) {
879-
return;
880-
}
891+
node->kind() == prim::DictConstruct);
881892

882893
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
883894
auto container = node->output();

torch/csrc/jit/ir/alias_analysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class AliasDb {
194194
void analyzeSetAttr(Node* node);
195195
void analyzeConservative(Node* node);
196196
void analyzeContainerConstruct(Node* node);
197+
void analyzeTupleConstruct(Node* node);
197198
bool tryRegisteredAnalysis(Node* node);
198199

199200
/**

0 commit comments

Comments
 (0)