Skip to content

Commit bac0986

Browse files
authored
Merge pull request #978 from Smit-create/i-928
Support multiple assignments
2 parents 318a6fe + 834c3a2 commit bac0986

File tree

5 files changed

+107
-48
lines changed

5 files changed

+107
-48
lines changed

integration_tests/expr_09.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,41 @@ def main0():
88
print(-i1 ^ -i2)
99
assert -i1 ^ -i2 == 6
1010

11+
12+
def test_multiple_assign_1():
13+
a: i32; b: i32; c: i32
14+
d: f64; e: f32; g: i32
15+
g = 5
16+
d = e = g + 1.0
17+
a = b = c = 10
18+
assert a == b
19+
assert b == c
20+
assert a == 10
21+
x: f32; y: f64
22+
x = y = 23.0
23+
assert abs(x - 23.0) < 1e-6
24+
assert abs(y - 23.0) < 1e-12
25+
assert abs(e - 6.0) < 1e-6
26+
assert abs(d - 6.0) < 1e-12
27+
i: list[f64]; j: list[f64]; k: list[f64] = []
28+
g = 0
29+
for g in range(10):
30+
k.append(g*2.0 + 5.0)
31+
i = j = k
32+
for g in range(10):
33+
assert abs(i[g] - j[g]) < 1e-12
34+
assert abs(i[g] - k[g]) < 1e-12
35+
assert abs(g*2.0 + 5.0 - k[g]) < 1e-12
36+
37+
38+
def test_issue_928():
39+
a: i32; b: i32; c: tuple[i32, i32]
40+
a, b = c = 2, 1
41+
assert a == 2
42+
assert b == 1
43+
assert c[0] == a and c[1] == b
44+
45+
46+
test_multiple_assign_1()
47+
test_issue_928()
1148
main0()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,14 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
366366

367367

368368
ASR::asr_t *tmp;
369+
370+
/*
371+
If `tmp` is not null, then `tmp_vec` is ignored and `tmp` is used as the only result (statement or
372+
expression). If `tmp` is null, then `tmp_vec` is used to return any number of statements:
373+
0 (no statement returned), 1 (redundant, one should use `tmp` for that), 2, 3, ... etc.
374+
*/
375+
std::vector<ASR::asr_t *> tmp_vec;
376+
369377
Allocator &al;
370378
SymbolTable *current_scope;
371379
// The current_module contains the current module that is being visited;
@@ -732,7 +740,7 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
732740
if (ASR::is_a<ASR::Function_t>(*t)) {
733741
new_call_name = (ASR::down_cast<ASR::Function_t>(t))->m_name;
734742
}
735-
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
743+
return make_call_helper(al, t, current_scope, args, new_call_name, loc);
736744
}
737745
if (ASR::down_cast<ASR::Function_t>(s)->m_return_var != nullptr) {
738746
ASR::ttype_t *a_type = nullptr;
@@ -2421,7 +2429,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
24212429
/* a_name */ s2c(al, sym_name),
24222430
/* a_args */ args.p,
24232431
/* n_args */ args.size(),
2424-
/* a_type_params */ tps.p,
2432+
/* a_type_params */ tps.p,
24252433
/* n_type_params */ tps.size(),
24262434
/* a_body */ nullptr,
24272435
/* n_body */ 0,
@@ -2648,6 +2656,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
26482656
// The `body` Vec must already be reserved
26492657
void transform_stmts(Vec<ASR::stmt_t*> &body, size_t n_body, AST::stmt_t **m_body) {
26502658
tmp = nullptr;
2659+
tmp_vec.clear();
26512660
Vec<ASR::stmt_t*>* current_body_copy = current_body;
26522661
current_body = &body;
26532662
for (size_t i=0; i<n_body; i++) {
@@ -2656,9 +2665,18 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
26562665
if (tmp != nullptr) {
26572666
ASR::stmt_t* tmp_stmt = ASRUtils::STMT(tmp);
26582667
body.push_back(al, tmp_stmt);
2668+
} else if (!tmp_vec.empty()) {
2669+
for (auto t: tmp_vec) {
2670+
if (t != nullptr) {
2671+
ASR::stmt_t* tmp_stmt = ASRUtils::STMT(t);
2672+
body.push_back(al, tmp_stmt);
2673+
}
2674+
}
2675+
tmp_vec.clear();
26592676
}
26602677
// To avoid last statement to be entered twice once we exit this node
26612678
tmp = nullptr;
2679+
tmp_vec.clear();
26622680
}
26632681
current_body = current_body_copy;
26642682
}
@@ -2677,9 +2695,16 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
26772695
items.reserve(al, 4);
26782696
for (size_t i=0; i<x.n_body; i++) {
26792697
tmp = nullptr;
2698+
tmp_vec.clear();
26802699
visit_stmt(*x.m_body[i]);
26812700
if (tmp) {
26822701
items.push_back(al, tmp);
2702+
} else if (!tmp_vec.empty()) {
2703+
for (auto t: tmp_vec) {
2704+
if (t) items.push_back(al, t);
2705+
}
2706+
// Ensure that statements in tmp_vec are used only once.
2707+
tmp_vec.clear();
26832708
}
26842709
}
26852710
// These global statements are added to the translation unit for now,
@@ -2782,10 +2807,21 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
27822807
}
27832808

27842809
void visit_Assign(const AST::Assign_t &x) {
2785-
ASR::expr_t *target;
2786-
if (x.n_targets == 1) {
2787-
if (AST::is_a<AST::Subscript_t>(*x.m_targets[0])) {
2788-
AST::Subscript_t *sb = AST::down_cast<AST::Subscript_t>(x.m_targets[0]);
2810+
ASR::expr_t *target, *assign_value = nullptr, *tmp_value;
2811+
this->visit_expr(*x.m_value);
2812+
if (tmp) {
2813+
// This happens if `m.m_value` is `empty`, such as in:
2814+
// a = empty(16)
2815+
// We skip this statement for now, the array is declared
2816+
// by the annotation.
2817+
// TODO: enforce that empty(), ones(), zeros() is called
2818+
// for every declaration.
2819+
assign_value = ASRUtils::EXPR(tmp);
2820+
}
2821+
for (size_t i=0; i<x.n_targets; i++) {
2822+
tmp_value = assign_value;
2823+
if (AST::is_a<AST::Subscript_t>(*x.m_targets[i])) {
2824+
AST::Subscript_t *sb = AST::down_cast<AST::Subscript_t>(x.m_targets[i]);
27892825
if (AST::is_a<AST::Name_t>(*sb->m_value)) {
27902826
std::string name = AST::down_cast<AST::Name_t>(sb->m_value)->m_id;
27912827
ASR::symbol_t *s = current_scope->get_symbol(name);
@@ -2799,8 +2835,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
27992835
// dict insert case;
28002836
this->visit_expr(*sb->m_slice);
28012837
ASR::expr_t *key = ASRUtils::EXPR(tmp);
2802-
this->visit_expr(*x.m_value);
2803-
ASR::expr_t *val = ASRUtils::EXPR(tmp);
28042838
ASR::ttype_t *key_type = ASR::down_cast<ASR::Dict_t>(type)->m_key_type;
28052839
ASR::ttype_t *value_type = ASR::down_cast<ASR::Dict_t>(type)->m_value_type;
28062840
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) {
@@ -2815,29 +2849,30 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
28152849
);
28162850
throw SemanticAbort();
28172851
}
2818-
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(val), value_type)) {
2819-
std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(val));
2852+
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) {
2853+
std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value));
28202854
std::string totype = ASRUtils::type_to_str_python(value_type);
28212855
diag.add(diag::Diagnostic(
28222856
"Type mismatch in dictionary value, the types must be compatible",
28232857
diag::Level::Error, diag::Stage::Semantic, {
28242858
diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')",
2825-
{val->base.loc})
2859+
{tmp_value->base.loc})
28262860
})
28272861
);
28282862
throw SemanticAbort();
28292863
}
28302864
ASR::expr_t* se = ASR::down_cast<ASR::expr_t>(
28312865
ASR::make_Var_t(al, x.base.base.loc, s));
2832-
tmp = make_DictInsert_t(al, x.base.base.loc, se, key, val);
2833-
return;
2866+
tmp = nullptr;
2867+
tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value));
2868+
continue;
28342869
} else if (ASRUtils::is_immutable(type)) {
28352870
throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support"
28362871
" item assignment", x.base.base.loc);
28372872
}
28382873
}
2839-
} else if (AST::is_a<AST::Attribute_t>(*x.m_targets[0])) {
2840-
AST::Attribute_t *attr = AST::down_cast<AST::Attribute_t>(x.m_targets[0]);
2874+
} else if (AST::is_a<AST::Attribute_t>(*x.m_targets[i])) {
2875+
AST::Attribute_t *attr = AST::down_cast<AST::Attribute_t>(x.m_targets[i]);
28412876
if (AST::is_a<AST::Name_t>(*attr->m_value)) {
28422877
std::string name = AST::down_cast<AST::Name_t>(attr->m_value)->m_id;
28432878
ASR::symbol_t *s = current_scope->get_symbol(name);
@@ -2852,62 +2887,51 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
28522887
}
28532888
}
28542889
}
2855-
this->visit_expr(*x.m_targets[0]);
2890+
if (!tmp_value) continue;
2891+
this->visit_expr(*x.m_targets[i]);
28562892
target = ASRUtils::EXPR(tmp);
2857-
} else {
2858-
throw SemanticError("Assignment to multiple targets not supported",
2859-
x.base.base.loc);
2860-
}
2861-
2862-
this->visit_expr(*x.m_value);
2863-
if (tmp == nullptr) {
2864-
// This happens if `m.m_value` is `empty`, such as in:
2865-
// a = empty(16)
2866-
// We skip this statement for now, the array is declared
2867-
// by the annotation.
2868-
// TODO: enforce that empty(), ones(), zeros() is called
2869-
// for every declaration.
2870-
tmp = nullptr;
2871-
} else {
2872-
ASR::expr_t *value = ASRUtils::EXPR(tmp);
28732893
ASR::ttype_t *target_type = ASRUtils::expr_type(target);
2874-
ASR::ttype_t *value_type = ASRUtils::expr_type(value);
2894+
ASR::ttype_t *value_type = ASRUtils::expr_type(tmp_value);
28752895
if( ASR::is_a<ASR::Pointer_t>(*target_type) &&
28762896
ASR::is_a<ASR::Var_t>(*target) ) {
2877-
if( !ASR::is_a<ASR::GetPointer_t>(*value) ) {
2897+
if( !ASR::is_a<ASR::GetPointer_t>(*tmp_value) ) {
28782898
throw SemanticError("A pointer variable can only "
28792899
"be associated with the output "
28802900
"of pointer() call.",
2881-
value->base.loc);
2901+
tmp_value->base.loc);
28822902
}
28832903
if( !ASRUtils::check_equal_type(target_type, value_type) ) {
28842904
throw SemanticError("Casting not supported for different pointer types. Received "
28852905
"target pointer type, " + ASRUtils::type_to_str_python(target_type) +
28862906
" and value pointer type, " + ASRUtils::type_to_str_python(value_type),
28872907
x.base.base.loc);
28882908
}
2889-
tmp = ASR::make_Assignment_t(al, x.base.base.loc, target, value, nullptr);
2890-
return ;
2909+
tmp = nullptr;
2910+
tmp_vec.push_back(ASR::make_Assignment_t(al, x.base.base.loc, target,
2911+
tmp_value, nullptr));
2912+
continue;
28912913
}
2892-
2893-
cast_helper(target, value, true);
2894-
value_type = ASRUtils::expr_type(value);
2914+
cast_helper(target, tmp_value, true);
2915+
value_type = ASRUtils::expr_type(tmp_value);
28952916
if (!ASRUtils::check_equal_type(target_type, value_type)) {
28962917
std::string ltype = ASRUtils::type_to_str_python(target_type);
28972918
std::string rtype = ASRUtils::type_to_str_python(value_type);
28982919
diag.add(diag::Diagnostic(
28992920
"Type mismatch in assignment, the types must be compatible",
29002921
diag::Level::Error, diag::Stage::Semantic, {
29012922
diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')",
2902-
{target->base.loc, value->base.loc})
2923+
{target->base.loc, tmp_value->base.loc})
29032924
})
29042925
);
29052926
throw SemanticAbort();
29062927
}
29072928
ASR::stmt_t *overloaded=nullptr;
2908-
tmp = ASR::make_Assignment_t(al, x.base.base.loc, target, value,
2909-
overloaded);
2929+
tmp = nullptr;
2930+
tmp_vec.push_back(ASR::make_Assignment_t(al, x.base.base.loc, target, tmp_value,
2931+
overloaded));
29102932
}
2933+
// to make sure that we add only those statements in tmp_vec
2934+
tmp = nullptr;
29112935
}
29122936

29132937
void visit_Assert(const AST::Assert_t &x) {

src/lpython/semantics/python_comptime_eval.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,6 @@ struct PythonIntrinsicProcedures {
659659
throw SemanticError("str.capitalize() takes no arguments", loc);
660660
}
661661
ASR::expr_t *arg = args[0];
662-
ASR::ttype_t *arg_type = ASRUtils::expr_type(arg);
663662
std::string val = ASR::down_cast<ASR::StringConstant_t>(arg)->m_s;
664663
if (val.size()) {
665664
val[0] = std::toupper(val[0]);
@@ -676,7 +675,6 @@ struct PythonIntrinsicProcedures {
676675
throw SemanticError("str.lower() takes no arguments", loc);
677676
}
678677
ASR::expr_t *arg = args[0];
679-
ASR::ttype_t *arg_type = ASRUtils::expr_type(arg);
680678
std::string val = ASR::down_cast<ASR::StringConstant_t>(arg)->m_s;
681679
for (auto &i: val) {
682680
if (i >= 'A' && i <= 'Z') {

tests/reference/asr-expr_09-f3e89c8.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
"basename": "asr-expr_09-f3e89c8",
33
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
44
"infile": "tests/../integration_tests/expr_09.py",
5-
"infile_hash": "7a3cdb6538c8d2d8e4555683aeac4f9b074be2fcaa6fe4532c01bf1a",
5+
"infile_hash": "51dfe55e01443840104d583e5e21ba3dd48fa33a95f1f943aac1d5d0",
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-expr_09-f3e89c8.stdout",
9-
"stdout_hash": "167f5176a21663f13aff75078c19a6bd7e07d7cab2605ef6302d9b8a",
9+
"stdout_hash": "66cf441a7ed60ad292ae9933ae51d8aac76f803201b809eac2689a33",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)