Skip to content

Commit b538dcd

Browse files
authored
Fixed handling of member access of pointer to structs (#1308)
1 parent 26d828e commit b538dcd

File tree

4 files changed

+54
-36
lines changed

4 files changed

+54
-36
lines changed

integration_tests/structs_13.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ def is_null(ptr: CPtr) -> i32:
1616
def add_A_members(Ax: i32, Ay: i16) -> i32:
1717
return Ax + i32(Ay)
1818

19+
@ccall
20+
def add_Aptr_members(Ax: i32, Ay: i16) -> i32:
21+
pass
22+
1923
def test_A_member_passing():
2024
array_cptr: CPtr = cmalloc(sizeof(A) * i64(10))
2125
assert not bool(is_null(array_cptr)), "Failed to allocate array on memory"
@@ -40,4 +44,13 @@ def test_A_member_passing():
4044
print(sum_A_members)
4145
assert sum_A_members == 2*i + 1
4246

47+
def test_Aptr_member_passing():
48+
a_cptr: CPtr = cmalloc(sizeof(A) * i64(1))
49+
assert not bool(is_null(a_cptr)), "Failed to allocate array on memory"
50+
a_ptr: Pointer[A]
51+
c_p_pointer(a_cptr, a_ptr)
52+
print(add_A_members(a_ptr.x, a_ptr.y), add_Aptr_members(a_ptr.x, a_ptr.y))
53+
assert add_A_members(a_ptr.x, a_ptr.y) == add_Aptr_members(a_ptr.x, a_ptr.y)
54+
4355
test_A_member_passing()
56+
test_Aptr_member_passing()

integration_tests/structs_13b.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ void* cmalloc(int64_t size) {
77
int32_t is_null(void* ptr) {
88
return ptr == NULL;
99
}
10+
11+
int32_t add_Aptr_members(int32_t Ax, int16_t Ay) {
12+
return Ax + Ay;
13+
}

integration_tests/structs_13b.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
void* cmalloc(int64_t size);
44
int32_t is_null(void* ptr);
5+
int32_t add_Aptr_members(int32_t Ax, int16_t Ay);

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
230230
std::unique_ptr<LLVMDictInterface> dict_api_sc;
231231
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;
232232

233-
uint64_t ptr_loads;
233+
int64_t ptr_loads;
234234
bool lookup_enum_value_for_nonints;
235235
bool is_assignment_target;
236236

@@ -1378,7 +1378,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
13781378
llvm::Type* const_list_type = list_api->get_list_type(llvm_el_type, type_code, type_size);
13791379
llvm::Value* const_list = builder->CreateAlloca(const_list_type, nullptr, "const_list");
13801380
list_api->list_init(type_code, const_list, *module, x.n_args, x.n_args);
1381-
uint64_t ptr_loads_copy = ptr_loads;
1381+
int64_t ptr_loads_copy = ptr_loads;
13821382
ptr_loads = 1;
13831383
for( size_t i = 0; i < x.n_args; i++ ) {
13841384
this->visit_expr(*x.m_args[i]);
@@ -1407,9 +1407,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14071407
std::string key_type_code = ASRUtils::get_type_code(x_dict->m_key_type);
14081408
std::string value_type_code = ASRUtils::get_type_code(x_dict->m_value_type);
14091409
llvm_utils->dict_api->dict_init(key_type_code, value_type_code, const_dict, module.get(), x.n_keys);
1410-
uint64_t ptr_loads_key = LLVM::is_llvm_struct(x_dict->m_key_type) ? 0 : 2;
1411-
uint64_t ptr_loads_value = LLVM::is_llvm_struct(x_dict->m_value_type) ? 0 : 2;
1412-
uint64_t ptr_loads_copy = ptr_loads;
1410+
int64_t ptr_loads_key = LLVM::is_llvm_struct(x_dict->m_key_type) ? 0 : 2;
1411+
int64_t ptr_loads_value = LLVM::is_llvm_struct(x_dict->m_value_type) ? 0 : 2;
1412+
int64_t ptr_loads_copy = ptr_loads;
14131413
for( size_t i = 0; i < x.n_keys; i++ ) {
14141414
ptr_loads = ptr_loads_key;
14151415
visit_expr(*x.m_keys[i]);
@@ -1442,7 +1442,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14421442
llvm::Type* const_tuple_type = tuple_api->get_tuple_type(type_code, llvm_el_types);
14431443
llvm::Value* const_tuple = builder->CreateAlloca(const_tuple_type, nullptr, "const_tuple");
14441444
std::vector<llvm::Value*> init_values;
1445-
uint64_t ptr_loads_copy = ptr_loads;
1445+
int64_t ptr_loads_copy = ptr_loads;
14461446
ptr_loads = 2;
14471447
for( size_t i = 0; i < x.n_elements; i++ ) {
14481448
this->visit_expr(*x.m_elements[i]);
@@ -1476,7 +1476,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14761476

14771477
void visit_ListAppend(const ASR::ListAppend_t& x) {
14781478
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(ASRUtils::expr_type(x.m_a));
1479-
uint64_t ptr_loads_copy = ptr_loads;
1479+
int64_t ptr_loads_copy = ptr_loads;
14801480
ptr_loads = 0;
14811481
this->visit_expr(*x.m_a);
14821482
llvm::Value* plist = tmp;
@@ -1490,7 +1490,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14901490
}
14911491

14921492
void visit_UnionRef(const ASR::UnionRef_t& x) {
1493-
uint64_t ptr_loads_copy = ptr_loads;
1493+
int64_t ptr_loads_copy = ptr_loads;
14941494
ptr_loads = 0;
14951495
this->visit_expr(*x.m_v);
14961496
ptr_loads = ptr_loads_copy;
@@ -1515,7 +1515,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15151515
void visit_ListItem(const ASR::ListItem_t& x) {
15161516
ASR::ttype_t* el_type = ASRUtils::get_contained_type(
15171517
ASRUtils::expr_type(x.m_a));
1518-
uint64_t ptr_loads_copy = ptr_loads;
1518+
int64_t ptr_loads_copy = ptr_loads;
15191519
ptr_loads = 0;
15201520
this->visit_expr(*x.m_a);
15211521
llvm::Value* plist = tmp;
@@ -1532,7 +1532,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15321532
void visit_DictItem(const ASR::DictItem_t& x) {
15331533
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
15341534
ASRUtils::expr_type(x.m_a));
1535-
uint64_t ptr_loads_copy = ptr_loads;
1535+
int64_t ptr_loads_copy = ptr_loads;
15361536
ptr_loads = 0;
15371537
this->visit_expr(*x.m_a);
15381538
llvm::Value* pdict = tmp;
@@ -1550,7 +1550,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15501550
void visit_DictPop(const ASR::DictPop_t& x) {
15511551
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
15521552
ASRUtils::expr_type(x.m_a));
1553-
uint64_t ptr_loads_copy = ptr_loads;
1553+
int64_t ptr_loads_copy = ptr_loads;
15541554
ptr_loads = 0;
15551555
this->visit_expr(*x.m_a);
15561556
llvm::Value* pdict = tmp;
@@ -1569,7 +1569,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15691569
if (x.m_value) {
15701570
this->visit_expr(*x.m_value);
15711571
} else {
1572-
uint64_t ptr_loads_copy = ptr_loads;
1572+
int64_t ptr_loads_copy = ptr_loads;
15731573
ptr_loads = 0;
15741574
this->visit_expr(*x.m_arg);
15751575
ptr_loads = ptr_loads_copy;
@@ -1584,7 +1584,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15841584
return ;
15851585
}
15861586

1587-
uint64_t ptr_loads_copy = ptr_loads;
1587+
int64_t ptr_loads_copy = ptr_loads;
15881588
ptr_loads = 0;
15891589
this->visit_expr(*x.m_arg);
15901590
ptr_loads = ptr_loads_copy;
@@ -1597,7 +1597,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15971597
void visit_ListInsert(const ASR::ListInsert_t& x) {
15981598
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(
15991599
ASRUtils::expr_type(x.m_a));
1600-
uint64_t ptr_loads_copy = ptr_loads;
1600+
int64_t ptr_loads_copy = ptr_loads;
16011601
ptr_loads = 0;
16021602
this->visit_expr(*x.m_a);
16031603
llvm::Value* plist = tmp;
@@ -1617,7 +1617,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16171617
void visit_DictInsert(const ASR::DictInsert_t& x) {
16181618
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
16191619
ASRUtils::expr_type(x.m_a));
1620-
uint64_t ptr_loads_copy = ptr_loads;
1620+
int64_t ptr_loads_copy = ptr_loads;
16211621
ptr_loads = 0;
16221622
this->visit_expr(*x.m_a);
16231623
llvm::Value* pdict = tmp;
@@ -1638,7 +1638,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16381638

16391639
void visit_ListRemove(const ASR::ListRemove_t& x) {
16401640
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_a));
1641-
uint64_t ptr_loads_copy = ptr_loads;
1641+
int64_t ptr_loads_copy = ptr_loads;
16421642
ptr_loads = 0;
16431643
this->visit_expr(*x.m_a);
16441644
llvm::Value* plist = tmp;
@@ -1651,7 +1651,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16511651
}
16521652

16531653
void visit_ListClear(const ASR::ListClear_t& x) {
1654-
uint64_t ptr_loads_copy = ptr_loads;
1654+
int64_t ptr_loads_copy = ptr_loads;
16551655
ptr_loads = 0;
16561656
this->visit_expr(*x.m_a);
16571657
llvm::Value* plist = tmp;
@@ -1666,7 +1666,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16661666
}
16671667

16681668
void visit_TupleItem(const ASR::TupleItem_t& x) {
1669-
uint64_t ptr_loads_copy = ptr_loads;
1669+
int64_t ptr_loads_copy = ptr_loads;
16701670
ptr_loads = 0;
16711671
this->visit_expr(*x.m_a);
16721672
ptr_loads = ptr_loads_copy;
@@ -1742,7 +1742,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17421742
std::vector<llvm::Value*> indices;
17431743
for( size_t r = 0; r < x.n_args; r++ ) {
17441744
ASR::array_index_t curr_idx = x.m_args[r];
1745-
uint64_t ptr_loads_copy = ptr_loads;
1745+
int64_t ptr_loads_copy = ptr_loads;
17461746
ptr_loads = 2;
17471747
this->visit_expr_wrapper(curr_idx.m_right, true);
17481748
ptr_loads = ptr_loads_copy;
@@ -1919,11 +1919,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19191919
}
19201920
der_type_name = "";
19211921
ASR::ttype_t* x_m_v_type = ASRUtils::expr_type(x.m_v);
1922-
uint64_t ptr_loads_copy = ptr_loads;
1922+
int64_t ptr_loads_copy = ptr_loads;
19231923
if( ASR::is_a<ASR::UnionRef_t>(*x.m_v) ) {
19241924
ptr_loads = 0;
19251925
} else {
1926-
ptr_loads = ptr_loads_copy - ASR::is_a<ASR::Pointer_t>(*x_m_v_type);
1926+
ptr_loads = 2 - ASR::is_a<ASR::Pointer_t>(*x_m_v_type);
19271927
}
19281928
this->visit_expr(*x.m_v);
19291929
ptr_loads = ptr_loads_copy;
@@ -3628,7 +3628,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
36283628
}
36293629

36303630
void visit_CLoc(const ASR::CLoc_t& x) {
3631-
uint64_t ptr_loads_copy = ptr_loads;
3631+
int64_t ptr_loads_copy = ptr_loads;
36323632
ptr_loads = 0;
36333633
this->visit_expr(*x.m_arg);
36343634
ptr_loads = ptr_loads_copy;
@@ -3677,7 +3677,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
36773677
}
36783678

36793679
void visit_GetPointer(const ASR::GetPointer_t& x) {
3680-
uint64_t ptr_loads_copy = ptr_loads;
3680+
int64_t ptr_loads_copy = ptr_loads;
36813681
ptr_loads = 0;
36823682
this->visit_expr(*x.m_arg);
36833683
ptr_loads = ptr_loads_copy;
@@ -3686,7 +3686,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
36863686
}
36873687

36883688
void visit_PointerToCPtr(const ASR::PointerToCPtr_t& x) {
3689-
uint64_t ptr_loads_copy = ptr_loads;
3689+
int64_t ptr_loads_copy = ptr_loads;
36903690
ptr_loads = 0;
36913691
this->visit_expr(*x.m_arg);
36923692
ptr_loads = ptr_loads_copy;
@@ -3708,7 +3708,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
37083708
reduce_loads = cptr_var->m_intent == ASRUtils::intent_in;
37093709
}
37103710
if( ASRUtils::is_array(ASRUtils::expr_type(fptr)) ) {
3711-
uint64_t ptr_loads_copy = ptr_loads;
3711+
int64_t ptr_loads_copy = ptr_loads;
37123712
ptr_loads = 1 - reduce_loads;
37133713
this->visit_expr(*cptr);
37143714
llvm::Value* llvm_cptr = tmp;
@@ -3758,7 +3758,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
37583758
builder->CreateStore(builder->CreateAdd(builder->CreateSub(new_ub, new_lb), i32_one), desi_size);
37593759
}
37603760
} else {
3761-
uint64_t ptr_loads_copy = ptr_loads;
3761+
int64_t ptr_loads_copy = ptr_loads;
37623762
ptr_loads = 1 - reduce_loads;
37633763
this->visit_expr(*cptr);
37643764
llvm::Value* llvm_cptr = tmp;
@@ -3798,7 +3798,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
37983798
bool is_target_struct = ASR::is_a<ASR::Struct_t>(*asr_target_type);
37993799
bool is_value_struct = ASR::is_a<ASR::Struct_t>(*asr_value_type);
38003800
if( is_target_list && is_value_list ) {
3801-
uint64_t ptr_loads_copy = ptr_loads;
3801+
int64_t ptr_loads_copy = ptr_loads;
38023802
ptr_loads = 0;
38033803
this->visit_expr(*x.m_target);
38043804
llvm::Value* target_list = tmp;
@@ -3813,7 +3813,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38133813
name2memidx);
38143814
return ;
38153815
} else if( is_target_tuple && is_value_tuple ) {
3816-
uint64_t ptr_loads_copy = ptr_loads;
3816+
int64_t ptr_loads_copy = ptr_loads;
38173817
if( ASR::is_a<ASR::TupleConstant_t>(*x.m_target) &&
38183818
!ASR::is_a<ASR::TupleConstant_t>(*x.m_value) ) {
38193819
ptr_loads = 0;
@@ -3868,7 +3868,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38683868
}
38693869
return ;
38703870
} else if( is_target_dict && is_value_dict ) {
3871-
uint64_t ptr_loads_copy = ptr_loads;
3871+
int64_t ptr_loads_copy = ptr_loads;
38723872
ptr_loads = 0;
38733873
this->visit_expr(*x.m_value);
38743874
llvm::Value* value_dict = tmp;
@@ -3881,7 +3881,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38813881
value_dict_type, module.get(), name2memidx);
38823882
return ;
38833883
} else if( is_target_struct && is_value_struct ) {
3884-
uint64_t ptr_loads_copy = ptr_loads;
3884+
int64_t ptr_loads_copy = ptr_loads;
38853885
ptr_loads = 0;
38863886
this->visit_expr(*x.m_value);
38873887
llvm::Value* value_struct = tmp;
@@ -3944,7 +3944,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
39443944
}
39453945
} else if( ASR::is_a<ASR::ListItem_t>(*x.m_target) ) {
39463946
ASR::ListItem_t* asr_target0 = ASR::down_cast<ASR::ListItem_t>(x.m_target);
3947-
uint64_t ptr_loads_copy = ptr_loads;
3947+
int64_t ptr_loads_copy = ptr_loads;
39483948
ptr_loads = 0;
39493949
this->visit_expr(*asr_target0->m_a);
39503950
ptr_loads = ptr_loads_copy;
@@ -4961,7 +4961,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
49614961
uint32_t x_h = get_hash((ASR::asr_t*)x);
49624962
LFORTRAN_ASSERT(llvm_symtab.find(x_h) != llvm_symtab.end());
49634963
llvm::Value* x_v = llvm_symtab[x_h];
4964-
uint64_t ptr_loads_copy = ptr_loads;
4964+
int64_t ptr_loads_copy = ptr_loads;
49654965
tmp = x_v;
49664966
while( ptr_loads_copy-- ) {
49674967
tmp = CreateLoad(tmp);
@@ -5507,7 +5507,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
55075507
end = builder->CreateGlobalStringPtr("\n");
55085508
}
55095509
for (size_t i=0; i<x.n_values; i++) {
5510-
uint64_t ptr_loads_copy = ptr_loads;
5510+
int64_t ptr_loads_copy = ptr_loads;
55115511
int reduce_loads = 0;
55125512
ptr_loads = 2;
55135513
if( ASR::is_a<ASR::Var_t>(*x.m_values[i]) ) {
@@ -5882,7 +5882,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
58825882
}
58835883
} else {
58845884
ASR::ttype_t* arg_type = expr_type(x.m_args[i].m_value);
5885-
uint64_t ptr_loads_copy = ptr_loads;
5885+
int64_t ptr_loads_copy = ptr_loads;
58865886
ptr_loads = !LLVM::is_llvm_struct(arg_type);
58875887
this->visit_expr_wrapper(x.m_args[i].m_value);
58885888
if( x_abi == ASR::abiType::BindC ) {
@@ -6348,7 +6348,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
63486348
}
63496349
int output_kind = ASRUtils::extract_kind_from_ttype_t(x.m_type);
63506350
int dim_kind = 4;
6351-
uint64_t ptr_loads_copy = ptr_loads;
6351+
int64_t ptr_loads_copy = ptr_loads;
63526352
ptr_loads = 2 - // Sync: instead of 2 - , should this be ptr_loads_copy -
63536353
(ASRUtils::expr_type(x.m_v)->type ==
63546354
ASR::ttypeType::Pointer);
@@ -6380,7 +6380,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
63806380
tmp = llvm::ConstantInt::get(context, llvm::APInt(kind * 8, bound_value));
63816381
return ;
63826382
}
6383-
uint64_t ptr_loads_copy = ptr_loads;
6383+
int64_t ptr_loads_copy = ptr_loads;
63846384
ptr_loads = 2 - // Sync: instead of 2 - , should this be ptr_loads_copy -
63856385
(ASRUtils::expr_type(x.m_v)->type ==
63866386
ASR::ttypeType::Pointer);

0 commit comments

Comments
 (0)