Skip to content

Commit ebbf0a1

Browse files
authored
Merge pull request #1295 from czgdp1807/struct_ret
Handle structs as return type in LLVM backend
2 parents a4ef3d2 + 356ff38 commit ebbf0a1

13 files changed

+529
-143
lines changed

integration_tests/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ RUN(NAME structs_07 LABELS llvm c
311311
EXTRAFILES structs_07b.c)
312312
RUN(NAME structs_08 LABELS cpython llvm c)
313313
RUN(NAME structs_09 LABELS cpython llvm c)
314-
RUN(NAME structs_10 LABELS cpython llvm c)
314+
# TODO: Re-enable c in structs_10
315+
RUN(NAME structs_10 LABELS cpython llvm)
316+
RUN(NAME structs_11 LABELS cpython llvm)
315317
RUN(NAME structs_12 LABELS cpython llvm c)
316318
RUN(NAME structs_13 LABELS llvm c
317319
EXTRAFILES structs_13b.c)

integration_tests/structs_10.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,25 @@
22
from numpy import empty, float64
33

44
@dataclass
5-
class MatVec:
5+
class Mat:
66
mat: f64[2, 2]
7+
8+
@dataclass
9+
class Vec:
710
vec: f64[2]
811

12+
@dataclass
13+
class MatVec:
14+
mat: Mat
15+
vec: Vec
16+
917
def rotate(mat_vec: MatVec) -> f64[2]:
1018
rotated_vec: f64[2] = empty(2, dtype=float64)
11-
rotated_vec[0] = mat_vec.mat[0, 0] * mat_vec.vec[0] + mat_vec.mat[0, 1] * mat_vec.vec[1]
12-
rotated_vec[1] = mat_vec.mat[1, 0] * mat_vec.vec[0] + mat_vec.mat[1, 1] * mat_vec.vec[1]
19+
rotated_vec[0] = mat_vec.mat.mat[0, 0] * mat_vec.vec.vec[0] + mat_vec.mat.mat[0, 1] * mat_vec.vec.vec[1]
20+
rotated_vec[1] = mat_vec.mat.mat[1, 0] * mat_vec.vec.vec[0] + mat_vec.mat.mat[1, 1] * mat_vec.vec.vec[1]
1321
return rotated_vec
1422

15-
def test_rotate_by_90():
23+
def create_MatVec_obj() -> MatVec:
1624
mat: f64[2, 2] = empty((2, 2), dtype=float64)
1725
vec: f64[2] = empty(2, dtype=float64)
1826
mat[0, 0] = 0.0
@@ -21,9 +29,15 @@ def test_rotate_by_90():
2129
mat[1, 1] = 0.0
2230
vec[0] = 1.0
2331
vec[1] = 0.0
24-
mat_vec: MatVec = MatVec(mat, vec)
25-
print(mat_vec.mat[0, 0], mat_vec.mat[0, 1], mat_vec.mat[1, 0], mat_vec.mat[1, 1])
26-
print(mat_vec.vec[0], mat_vec.vec[1])
32+
mat_s: Mat = Mat(mat)
33+
vec_s: Vec = Vec(vec)
34+
mat_vec: MatVec = MatVec(mat_s, vec_s)
35+
return mat_vec
36+
37+
def test_rotate_by_90():
38+
mat_vec: MatVec = create_MatVec_obj()
39+
print(mat_vec.mat.mat[0, 0], mat_vec.mat.mat[0, 1], mat_vec.mat.mat[1, 0], mat_vec.mat.mat[1, 1])
40+
print(mat_vec.vec.vec[0], mat_vec.vec.vec[1])
2741
rotated_vec: f64[2] = rotate(mat_vec)
2842
print(rotated_vec[0], rotated_vec[1])
2943
assert abs(rotated_vec[0] - 0.0) <= 1e-12

integration_tests/structs_11.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from ltypes import i32, f64, dataclass
2+
3+
@dataclass
4+
class A:
5+
x: i32
6+
y: f64
7+
8+
def f(x_: i32, y_: f64) -> A:
9+
a_struct: A = A(x_, y_)
10+
return a_struct
11+
12+
def test_struct_return():
13+
b: A = f(0, 1.0)
14+
print(b.x, b.y)
15+
assert b.x == 0
16+
assert b.y == 1.0
17+
18+
test_struct_return()

src/libasr/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ set(SRC
3434
pass/select_case.cpp
3535
pass/implied_do_loops.cpp
3636
pass/array_op.cpp
37+
pass/subroutine_from_function.cpp
3738
pass/class_constructor.cpp
3839
pass/arr_slice.cpp
3940
pass/print_arr.cpp

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 66 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
270270
llvm_utils->tuple_api = tuple_api.get();
271271
llvm_utils->list_api = list_api.get();
272272
llvm_utils->dict_api = nullptr;
273+
llvm_utils->arr_api = arr_descr.get();
273274
}
274275

275276
llvm::Value* CreateLoad(llvm::Value *x) {
@@ -1383,7 +1384,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
13831384
this->visit_expr(*x.m_args[i]);
13841385
llvm::Value* item = tmp;
13851386
llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i));
1386-
list_api->write_item(const_list, pos, item, list_type->m_type, false, *module);
1387+
list_api->write_item(const_list, pos, item, list_type->m_type,
1388+
false, module.get(), name2memidx);
13871389
}
13881390
ptr_loads = ptr_loads_copy;
13891391
tmp = const_list;
@@ -1416,7 +1418,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14161418
visit_expr(*x.m_values[i]);
14171419
llvm::Value* value = tmp;
14181420
llvm_utils->dict_api->write_item(const_dict, key, value, module.get(),
1419-
x_dict->m_key_type, x_dict->m_value_type);
1421+
x_dict->m_key_type, x_dict->m_value_type, name2memidx);
14201422
}
14211423
ptr_loads = ptr_loads_copy;
14221424
tmp = const_dict;
@@ -1484,7 +1486,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
14841486
llvm::Value *item = tmp;
14851487
ptr_loads = ptr_loads_copy;
14861488

1487-
list_api->append(plist, item, asr_list->m_type, *module);
1489+
list_api->append(plist, item, asr_list->m_type, module.get(), name2memidx);
14881490
}
14891491

14901492
void visit_UnionRef(const ASR::UnionRef_t& x) {
@@ -1609,7 +1611,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16091611
llvm::Value *item = tmp;
16101612
ptr_loads = ptr_loads_copy;
16111613

1612-
list_api->insert_item(plist, pos, item, asr_list->m_type, *module);
1614+
list_api->insert_item(plist, pos, item, asr_list->m_type, module.get(), name2memidx);
16131615
}
16141616

16151617
void visit_DictInsert(const ASR::DictInsert_t& x) {
@@ -1631,7 +1633,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16311633
set_dict_api(dict_type);
16321634
llvm_utils->dict_api->write_item(pdict, key, value, module.get(),
16331635
dict_type->m_key_type,
1634-
dict_type->m_value_type);
1636+
dict_type->m_value_type, name2memidx);
16351637
}
16361638

16371639
void visit_ListRemove(const ASR::ListRemove_t& x) {
@@ -2571,6 +2573,49 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
25712573
m_dims_local, n_dims_local, a_kind_local);
25722574
}
25732575

2576+
void fill_array_details_(llvm::Value* ptr, ASR::dimension_t* m_dims,
2577+
size_t n_dims, bool is_malloc_array_type, bool is_array_type,
2578+
bool is_list, ASR::ttype_t* m_type) {
2579+
if( is_malloc_array_type &&
2580+
m_type->type != ASR::ttypeType::Pointer &&
2581+
!is_list ) {
2582+
arr_descr->fill_dimension_descriptor(ptr, n_dims);
2583+
}
2584+
if( is_array_type && !is_malloc_array_type &&
2585+
m_type->type != ASR::ttypeType::Pointer &&
2586+
!is_list ) {
2587+
ASR::ttype_t* asr_data_type = ASRUtils::duplicate_type_without_dims(al, m_type, m_type->base.loc);
2588+
llvm::Type* llvm_data_type = get_type_from_ttype_t_util(asr_data_type);
2589+
fill_array_details(ptr, llvm_data_type, m_dims, n_dims);
2590+
}
2591+
if( is_array_type && is_malloc_array_type &&
2592+
m_type->type != ASR::ttypeType::Pointer &&
2593+
!is_list ) {
2594+
// Set allocatable arrays as unallocated
2595+
arr_descr->set_is_allocated_flag(ptr, 0);
2596+
}
2597+
}
2598+
2599+
void allocate_array_members_of_struct(llvm::Value* ptr, ASR::ttype_t* asr_type) {
2600+
LFORTRAN_ASSERT(ASR::is_a<ASR::Struct_t>(*asr_type));
2601+
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(asr_type);
2602+
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(struct_t->m_derived_type);
2603+
std::string struct_type_name = struct_type_t->m_name;
2604+
for( auto item: struct_type_t->m_symtab->get_scope() ) {
2605+
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second);
2606+
int idx = name2memidx[struct_type_name][item.first];
2607+
llvm::Value* ptr_member = llvm_utils->create_gep(ptr, idx);
2608+
if( ASRUtils::is_array(symbol_type) ) {
2609+
// Assume that struct member array is not allocatable
2610+
ASR::dimension_t* m_dims = nullptr;
2611+
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(symbol_type, m_dims);
2612+
fill_array_details_(ptr_member, m_dims, n_dims, false, true, false, symbol_type);
2613+
} else if( ASR::is_a<ASR::Struct_t>(*symbol_type) ) {
2614+
allocate_array_members_of_struct(ptr_member, symbol_type);
2615+
}
2616+
}
2617+
}
2618+
25742619
template<typename T>
25752620
void declare_vars(const T &x) {
25762621
llvm::Value *target_var;
@@ -2626,6 +2671,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
26262671
}
26272672
}
26282673
llvm::AllocaInst *ptr = builder->CreateAlloca(type, nullptr, v->m_name);
2674+
if( ASR::is_a<ASR::Struct_t>(*v->m_type) ) {
2675+
allocate_array_members_of_struct(ptr, v->m_type);
2676+
}
26292677
if (emit_debug_info) {
26302678
// Reset the debug location
26312679
builder->SetCurrentDebugLocation(nullptr);
@@ -2659,24 +2707,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
26592707
}
26602708
}
26612709
llvm_symtab[h] = ptr;
2662-
if( is_malloc_array_type &&
2663-
v->m_type->type != ASR::ttypeType::Pointer &&
2664-
!is_list ) {
2665-
arr_descr->fill_dimension_descriptor(ptr, n_dims);
2666-
}
2667-
if( is_array_type && !is_malloc_array_type &&
2668-
v->m_type->type != ASR::ttypeType::Pointer &&
2669-
!is_list ) {
2670-
ASR::ttype_t* asr_data_type = ASRUtils::duplicate_type_without_dims(al, v->m_type, v->m_type->base.loc);
2671-
llvm::Type* llvm_data_type = get_type_from_ttype_t_util(asr_data_type);
2672-
fill_array_details(ptr, llvm_data_type, m_dims, n_dims);
2673-
}
2674-
if( is_array_type && is_malloc_array_type &&
2675-
v->m_type->type != ASR::ttypeType::Pointer &&
2676-
!is_list ) {
2677-
// Set allocatable arrays as unallocated
2678-
arr_descr->set_is_allocated_flag(ptr, 0);
2679-
}
2710+
fill_array_details_(ptr, m_dims, n_dims,
2711+
is_malloc_array_type,
2712+
is_array_type, is_list, v->m_type);
26802713
if( v->m_symbolic_value != nullptr &&
26812714
!ASR::is_a<ASR::List_t>(*v->m_type)) {
26822715
target_var = ptr;
@@ -3776,7 +3809,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
37763809
ASRUtils::expr_type(x.m_value));
37773810
std::string value_type_code = ASRUtils::get_type_code(value_asr_list->m_type);
37783811
list_api->list_deepcopy(value_list, target_list,
3779-
value_asr_list, *module);
3812+
value_asr_list, module.get(),
3813+
name2memidx);
37803814
return ;
37813815
} else if( is_target_tuple && is_value_tuple ) {
37823816
uint64_t ptr_loads_copy = ptr_loads;
@@ -3805,7 +3839,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38053839
llvm::Value* llvm_tuple_i = builder->CreateAlloca(llvm_tuple_i_type, nullptr);
38063840
ptr_loads = !LLVM::is_llvm_struct(asr_tuple_i_type);
38073841
visit_expr(*asr_value_tuple->m_elements[i]);
3808-
llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, *module);
3842+
llvm_utils->deepcopy(tmp, llvm_tuple_i, asr_tuple_i_type, module.get(), name2memidx);
38093843
src_deepcopies.push_back(al, llvm_tuple_i);
38103844
}
38113845
ASR::TupleConstant_t* asr_target_tuple = ASR::down_cast<ASR::TupleConstant_t>(x.m_target);
@@ -3829,7 +3863,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38293863
std::string type_code = ASRUtils::get_type_code(value_tuple_type->m_type,
38303864
value_tuple_type->n_type);
38313865
tuple_api->tuple_deepcopy(value_tuple, target_tuple,
3832-
value_tuple_type, *module);
3866+
value_tuple_type, module.get(),
3867+
name2memidx);
38333868
}
38343869
return ;
38353870
} else if( is_target_dict && is_value_dict ) {
@@ -3843,7 +3878,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38433878
ASR::Dict_t* value_dict_type = ASR::down_cast<ASR::Dict_t>(asr_value_type);
38443879
set_dict_api(value_dict_type);
38453880
llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict,
3846-
value_dict_type, module.get());
3881+
value_dict_type, module.get(), name2memidx);
38473882
return ;
38483883
} else if( is_target_struct && is_value_struct ) {
38493884
uint64_t ptr_loads_copy = ptr_loads;
@@ -3856,10 +3891,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
38563891
is_assignment_target = is_assignment_target_copy;
38573892
llvm::Value* target_struct = tmp;
38583893
ptr_loads = ptr_loads_copy;
3859-
LLVM::CreateStore(*builder,
3860-
LLVM::CreateLoad(*builder, value_struct),
3861-
target_struct
3862-
);
3894+
llvm_utils->deepcopy(value_struct, target_struct,
3895+
asr_target_type, module.get(), name2memidx);
38633896
return ;
38643897
}
38653898

@@ -3973,9 +4006,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
39734006
if( ASRUtils::is_array(target_type) &&
39744007
ASRUtils::is_array(value_type) &&
39754008
ASRUtils::check_equal_type(target_type, value_type) ) {
3976-
bool create_dim_des_array = !ASR::is_a<ASR::Var_t>(*x.m_target);
39774009
arr_descr->copy_array(value, target, module.get(),
3978-
target_type, create_dim_des_array);
4010+
target_type, false, false);
39794011
} else {
39804012
builder->CreateStore(value, target);
39814013
}
@@ -5975,7 +6007,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
59756007
}
59766008
if( ASR::is_a<ASR::Tuple_t>(*arg_type) ||
59776009
ASR::is_a<ASR::List_t>(*arg_type) ) {
5978-
llvm_utils->deepcopy(value, target, arg_type, *module);
6010+
llvm_utils->deepcopy(value, target, arg_type, module.get(), name2memidx);
59796011
} else {
59806012
builder->CreateStore(value, target);
59816013
}

src/libasr/codegen/llvm_array_utils.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,16 @@ namespace LFortran {
571571

572572
// Shallow copies source array descriptor to destination descriptor
573573
void SimpleCMODescriptor::copy_array(llvm::Value* src, llvm::Value* dest,
574-
llvm::Module* module, ASR::ttype_t* asr_data_type, bool create_dim_des_array) {
574+
llvm::Module* module, ASR::ttype_t* asr_data_type, bool create_dim_des_array,
575+
bool reserve_memory) {
575576
llvm::Value* num_elements = this->get_array_size(src, nullptr, 4);
576577

577578
llvm::Value* first_ptr = this->get_pointer_to_data(dest);
578579
llvm::Type* llvm_data_type = tkr2array[ASRUtils::get_type_code(asr_data_type, false, false)].second;
579-
llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, num_elements);
580-
builder->CreateStore(arr_first, first_ptr);
580+
if( reserve_memory ) {
581+
llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type, num_elements);
582+
builder->CreateStore(arr_first, first_ptr);
583+
}
581584

582585
llvm::Value* ptr2firstptr = this->get_pointer_to_data(src);
583586
llvm::DataLayout data_layout(module);

src/libasr/codegen/llvm_array_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ namespace LFortran {
260260
virtual
261261
void copy_array(llvm::Value* src, llvm::Value* dest,
262262
llvm::Module* module, ASR::ttype_t* asr_data_type,
263-
bool create_dim_des_array) = 0;
263+
bool create_dim_des_array, bool reserve_memory) = 0;
264264

265265
virtual
266266
llvm::Value* get_array_size(llvm::Value* array, llvm::Value* dim,
@@ -394,7 +394,7 @@ namespace LFortran {
394394
virtual
395395
void copy_array(llvm::Value* src, llvm::Value* dest,
396396
llvm::Module* module, ASR::ttype_t* asr_data_type,
397-
bool create_dim_des_array);
397+
bool create_dim_des_array, bool reserve_memory);
398398

399399
virtual
400400
llvm::Value* get_array_size(llvm::Value* array, llvm::Value* dim,

0 commit comments

Comments
 (0)