Skip to content

Commit 5a66456

Browse files
Merge pull request #2801 from tanay-man/inheritance
Initial implementation of Inheritance and Polymorphic functions
2 parents a104c30 + f42a561 commit 5a66456

File tree

9 files changed

+207
-34
lines changed

9 files changed

+207
-34
lines changed

integration_tests/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,9 @@ RUN(NAME class_01 LABELS cpython llvm llvm_jit)
841841
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
842842
RUN(NAME class_03 LABELS cpython llvm llvm_jit)
843843
RUN(NAME class_04 LABELS cpython llvm llvm_jit)
844+
RUN(NAME class_05 LABELS cpython llvm llvm_jit)
845+
RUN(NAME class_06 LABELS cpython llvm llvm_jit)
846+
844847

845848
# callback_04 is to test emulation. So just run with cpython
846849
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)

integration_tests/class_05.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from lpython import i32
2+
3+
class Animal:
4+
def __init__(self:"Animal"):
5+
self.species: str = "Generic Animal"
6+
self.age: i32 = 0
7+
self.is_domestic: bool = True
8+
9+
class Dog(Animal):
10+
def __init__(self:"Dog", name:str, age:i32):
11+
super().__init__()
12+
self.species: str = "Dog"
13+
self.name: str = name
14+
self.age: i32 = age
15+
16+
class Cat(Animal):
17+
def __init__(self:"Cat", name: str, age: i32):
18+
super().__init__()
19+
self.species: str = "Cat"
20+
self.name:str = name
21+
self.age: i32 = age
22+
23+
def main():
24+
dog: Dog = Dog("Buddy", 5)
25+
cat: Cat = Cat("Whiskers", 3)
26+
op1: str = str(dog.name+" is a "+str(dog.age)+"-year-old "+dog.species+".")
27+
print(op1)
28+
assert op1 == "Buddy is a 5-year-old Dog."
29+
print(dog.is_domestic)
30+
assert dog.is_domestic == True
31+
op2: str = str(cat.name+ " is a "+ str(cat.age)+ "-year-old "+ cat.species+ ".")
32+
print(op2)
33+
assert op2 == "Whiskers is a 3-year-old Cat."
34+
print(cat.is_domestic)
35+
assert cat.is_domestic == True
36+
37+
main()

integration_tests/class_06.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from lpython import i32
2+
3+
class Base():
4+
def __init__(self:"Base"):
5+
self.x : i32 = 10
6+
7+
def get_x(self:"Base")->i32:
8+
print(self.x)
9+
return self.x
10+
11+
#Testing polymorphic fn calls
12+
def get_x_static(d: Base)->i32:
13+
print(d.x)
14+
return d.x
15+
16+
class Derived(Base):
17+
def __init__(self: "Derived"):
18+
super().__init__()
19+
self.y : i32 = 20
20+
21+
def get_y(self:"Derived")->i32:
22+
print(self.y)
23+
return self.y
24+
25+
26+
def main():
27+
d : Derived = Derived()
28+
x : i32 = get_x_static(d)
29+
assert x == 10
30+
# Testing parent method call using der obj
31+
x = d.get_x()
32+
assert x == 10
33+
y: i32 = d.get_y()
34+
assert y == 20
35+
36+
main()

src/libasr/ASR.asdl

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ ttype
215215
| Array(ttype type, dimension* dims, array_physical_type physical_type)
216216
| FunctionType(ttype* arg_types, ttype? return_var_type, abi abi, deftype deftype, string? bindc_name, bool elemental, bool pure, bool module, bool inline, bool static, symbol* restrictions, bool is_restriction)
217217

218-
cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray
218+
cast_kind = RealToInteger | IntegerToReal | LogicalToReal | RealToReal | IntegerToInteger | RealToComplex | IntegerToComplex | IntegerToLogical | RealToLogical | CharacterToLogical | CharacterToInteger | CharacterToList | ComplexToLogical | ComplexToComplex | ComplexToReal | ComplexToInteger | LogicalToInteger | RealToCharacter | IntegerToCharacter | LogicalToCharacter | UnsignedIntegerToInteger | UnsignedIntegerToUnsignedInteger | UnsignedIntegerToReal | UnsignedIntegerToLogical | IntegerToUnsignedInteger | RealToUnsignedInteger | CPtrToUnsignedInteger | UnsignedIntegerToCPtr | IntegerToSymbolicExpression | ListToArray | DerivedToBase
219219
storage_type = Default | Save | Parameter
220220
access = Public | Private
221221
intent = Local | In | Out | InOut | ReturnVar | Unspecified

src/libasr/casting_utils.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ namespace LCompilers::CastingUtil {
4141
{ASR::ttypeType::Complex, ASR::cast_kindType::ComplexToComplex},
4242
{ASR::ttypeType::Real, ASR::cast_kindType::RealToReal},
4343
{ASR::ttypeType::Integer, ASR::cast_kindType::IntegerToInteger},
44-
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger}
44+
{ASR::ttypeType::UnsignedInteger, ASR::cast_kindType::UnsignedIntegerToUnsignedInteger},
45+
{ASR::ttypeType::StructType, ASR::cast_kindType::DerivedToBase}
4546
};
4647

4748
int get_type_priority(ASR::ttypeType type) {

src/libasr/codegen/asr_to_llvm.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -7725,6 +7725,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
77257725
tmp = LLVM::CreateLoad(*builder, list_api->get_pointer_to_list_data(tmp));
77267726
break;
77277727
}
7728+
case (ASR::cast_kindType::DerivedToBase) : {
7729+
this->visit_expr(*x.m_arg);
7730+
tmp = llvm_utils->create_gep(tmp, 0);
7731+
break;
7732+
}
77287733
default : throw CodeGenError("Cast kind not implemented");
77297734
}
77307735
}

src/lpython/semantics/python_ast_to_asr.cpp

+121-30
Original file line numberDiff line numberDiff line change
@@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
784784
ASR::call_arg_t c_arg;
785785
c_arg.loc = args[i].loc;
786786
c_arg.m_value = args[i].m_value;
787-
cast_helper(m_args[i], c_arg.m_value, true);
788787
ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]);
789788
ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value);
789+
if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
790+
ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
791+
ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
792+
ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
793+
ASRUtils::symbol_get_past_external(
794+
l_type->m_derived_type));
795+
ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
796+
ASRUtils::symbol_get_past_external(
797+
r_type->m_derived_type));
798+
if ( ASRUtils::is_derived_type_similar(l2_type, r2_type) ) {
799+
cast_helper(m_args[i], c_arg.m_value, true, true);
800+
check_type_equality = false;
801+
} else {
802+
cast_helper(m_args[i], c_arg.m_value, true);
803+
}
804+
} else {
805+
cast_helper(m_args[i], c_arg.m_value, true);
806+
}
790807
if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) {
791808
std::string ltype = ASRUtils::type_to_str_python(left_type);
792809
std::string rtype = ASRUtils::type_to_str_python(right_type);
@@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
29622979
std::string obj_name = x.m_args.m_args->m_arg;
29632980
for(size_t i = 0; i < x.n_body; i++) {
29642981
std::string var_name;
2965-
if (! AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
2966-
throw SemanticError("Only AnnAssign implemented in __init__ ",
2967-
x.m_body[i]->base.loc);
2982+
if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
2983+
continue;
29682984
}
29692985
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
29702986
if(AST::is_a<AST::Attribute_t>(*ann_assign.m_target)){
@@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33013317
current_scope->add_symbol(x_m_name, class_type);
33023318
}
33033319
} else {
3304-
if( x.n_bases > 0 ) {
3305-
throw SemanticError("Inheritance in classes isn't supported yet.",
3320+
ASR::symbol_t* parent = nullptr;
3321+
if( x.n_bases > 1 ) {
3322+
throw SemanticError("Multiple inheritance in classes isn't supported yet.",
33063323
x.base.base.loc);
33073324
}
3325+
else if (x.n_bases == 1) {
3326+
std::string b_name = "";
3327+
if ( AST::is_a<AST::Name_t>(*x.m_bases[0]) ) {
3328+
b_name = AST::down_cast<AST::Name_t>(x.m_bases[0])->m_id;
3329+
} else {
3330+
throw SemanticError("Expected a Name here", x.base.base.loc);
3331+
}
3332+
parent = current_scope->resolve_symbol(b_name);
3333+
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*parent));
3334+
}
33083335
SymbolTable *parent_scope = current_scope;
33093336
if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) {
33103337
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*sym));
@@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33163343
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
33173344
init_self_type(*f, sym, x.base.base.loc);
33183345
if ( std::string(f->m_name) == std::string("__init__") ) {
3319-
this->visit_init_body(*f);
3346+
this->visit_init_body(*f, st->m_parent, x.m_body[i]->base.loc);
33203347
} else {
33213348
this->visit_stmt(*x.m_body[i]);
33223349
}
@@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33443371
member_names.p, member_names.size(), member_fn_names.p,
33453372
member_fn_names.size(), class_abi, ASR::accessType::Public,
33463373
false, false, member_init.p, member_init.size(),
3347-
nullptr, nullptr));
3374+
nullptr, parent));
33483375
parent_scope->add_symbol(x.m_name, class_sym);
33493376
visit_ClassMembers(x, member_names, member_fn_names,
33503377
struct_dependencies, member_init, false, class_abi, true);
@@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33873414
current_scope = parent_scope;
33883415
}
33893416

3390-
virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;
3417+
virtual void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) = 0;
33913418

33923419
void add_name(const Location &loc) {
33933420
std::string var_name = "__name__";
@@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
44214448
// Implement visit_Global for Symbol Table visitor.
44224449
void visit_Global(const AST::Global_t &/*x*/) {}
44234450

4424-
void visit_init_body (const AST::FunctionDef_t &/*x*/) {
4451+
void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/) {
44254452
//Implemented in BodyVisitor
44264453
}
44274454

@@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
51535180
tmp = asr;
51545181
}
51555182

5156-
void visit_init_body (const AST::FunctionDef_t &x) {
5183+
void visit_init_body (const AST::FunctionDef_t &x, ASR::symbol_t* parent_sym, const Location loc) {
51575184
SymbolTable *old_scope = current_scope;
51585185
ASR::symbol_t *t = current_scope->get_symbol("__init__");
51595186
if ( t==nullptr ) {
@@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
51635190
throw SemanticError("__init__ is not a function", x.base.base.loc);
51645191
}
51655192
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
5193+
current_scope = f->m_symtab;
51665194
//Transform statements into correct format
5167-
Vec<AST::stmt_t*> new_body;
5168-
new_body.reserve(al, 1);
5195+
Vec<AST::stmt_t*> body;
5196+
body.reserve(al, 1);
5197+
ASR::stmt_t* super_call_stmt = nullptr;
51695198
for (size_t i=0; i<x.n_body; i++) {
5170-
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
5171-
if ( ann_assign.m_value != nullptr ) {
5172-
Vec<AST::expr_t*>target;
5173-
target.reserve(al, 1);
5174-
target.push_back(al, ann_assign.m_target);
5175-
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
5176-
target.p, 1, ann_assign.m_value, nullptr);
5177-
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
5178-
new_body.push_back(al, assgn);
5199+
if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
5200+
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
5201+
if ( ann_assign.m_value != nullptr ) {
5202+
Vec<AST::expr_t*>target;
5203+
target.reserve(al, 1);
5204+
target.push_back(al, ann_assign.m_target);
5205+
AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
5206+
target.p, 1, ann_assign.m_value, nullptr);
5207+
AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
5208+
body.push_back(al, assgn);
5209+
}
5210+
} else if (AST::is_a<AST::Expr_t>(*x.m_body[i]) &&
5211+
AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value))) {
5212+
AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value);
5213+
5214+
if ( !AST::is_a<AST::Attribute_t>(*(c->m_func))
5215+
|| !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value)) ) {
5216+
body.push_back(al, x.m_body[i]);
5217+
continue;
5218+
}
5219+
AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value);
5220+
std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func)->m_attr;
5221+
if ( AST::is_a<AST::Name_t>(*(super_call->m_func)) &&
5222+
std::string(AST::down_cast<AST::Name_t>(super_call->m_func)->m_id)=="super" &&
5223+
attr == "__init__") {
5224+
if (parent_sym == nullptr) {
5225+
throw SemanticError("The class doesn't have a base class",loc);
5226+
}
5227+
Vec<ASR::call_arg_t> args;
5228+
args.reserve(al, 1);
5229+
parse_args(*super_call,args);
5230+
ASR::call_arg_t first_arg;
5231+
first_arg.loc = loc;
5232+
ASR::symbol_t* self_sym = current_scope->get_symbol("self");
5233+
first_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al,loc,self_sym));
5234+
ASR::ttype_t* target_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc,parent_sym));
5235+
cast_helper(target_type, first_arg.m_value, x.base.base.loc, true);
5236+
Vec<ASR::call_arg_t> args_w_first; args_w_first.reserve(al,1);
5237+
args_w_first.push_back(al, first_arg);
5238+
for( size_t i = 0; i < args.size(); i++ ) {
5239+
args_w_first.push_back(al,args[i]);
5240+
}
5241+
std::string call_name = "__init__";
5242+
ASR::symbol_t* call_sym = get_struct_member(parent_sym,call_name,loc);
5243+
super_call_stmt = ASRUtils::STMT(
5244+
ASR::make_SubroutineCall_t(al, loc, call_sym, call_sym, args_w_first.p,
5245+
args_w_first.size(), nullptr));
5246+
}
5247+
} else {
5248+
body.push_back(al, x.m_body[i]);
51795249
}
51805250
}
51815251
current_scope = f->m_symtab;
5182-
Vec<ASR::stmt_t*> body;
5183-
body.reserve(al, x.n_body);
5252+
Vec<ASR::stmt_t*> body_asr;
5253+
body_asr.reserve(al, x.n_body);
5254+
if ( super_call_stmt ) {
5255+
body_asr.push_back(al, super_call_stmt);
5256+
}
51845257
Vec<ASR::symbol_t*> rts;
51855258
rts.reserve(al, 4);
51865259
dependencies.clear(al);
5187-
transform_stmts(body, new_body.n, new_body.p);
5260+
transform_stmts(body_asr, body.n, body.p);
51885261
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
5189-
f->m_body = body.p;
5190-
f->n_body = body.size();
5262+
f->m_body = body_asr.p;
5263+
f->n_body = body_asr.size();
51915264
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
51925265
f->m_function_signature);
51935266
func_type->m_restrictions = rts.p;
@@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
62396312
for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) {
62406313
member_found = std::string(der_type->m_members[i]) == member_name;
62416314
}
6242-
if( !member_found ) {
6315+
if( !member_found && !der_type->m_parent ) {
62436316
throw SemanticError("No member " + member_name +
62446317
" found in " + std::string(der_type->m_name),
62456318
loc);
6319+
} else if ( !member_found && der_type->m_parent ) {
6320+
ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent));
6321+
visit_AttributeUtil(parent_type,attr_char,t,loc);
6322+
return;
62466323
}
62476324
ASR::expr_t *val = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
62486325
ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);
@@ -8064,7 +8141,8 @@ we will have to use something else.
80648141
//TODO: Correct Class and ClassType
80658142
// call to struct member function
80668143
// modifying args to pass the object as self
8067-
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
8144+
ASR::symbol_t* der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
8145+
ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
80688146
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
80698147
ASR::call_arg_t self_arg;
80708148
self_arg.loc = args[0].loc;
@@ -8073,7 +8151,20 @@ we will have to use something else.
80738151
for (size_t i=0; i<args.n; i++) {
80748152
new_args.push_back(al, args[i]);
80758153
}
8076-
st = get_struct_member(der, call_name, loc);
8154+
if ( der->m_symtab->get_symbol(call_name) ) {
8155+
st = get_struct_member(der_sym, call_name, loc);
8156+
} else if ( der->m_parent ) {
8157+
ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent);
8158+
if ( !parent->m_symtab->get_symbol(call_name) ) {
8159+
throw SemanticError("Method not found in the class "+ std::string(der->m_name) +
8160+
" or it's parents",loc);
8161+
} else {
8162+
st = get_struct_member(der->m_parent, call_name, loc);
8163+
}
8164+
} else {
8165+
throw SemanticError("Method not found in the class "+std::string(der->m_name)+
8166+
" or it's parents",loc);
8167+
}
80778168
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
80788169
return;
80798170
} else {

tests/reference/asr-structs_09-f3ffe08.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@
88
"stdout": null,
99
"stdout_hash": null,
1010
"stderr": "asr-structs_09-f3ffe08.stderr",
11-
"stderr_hash": "f59ab2d213f6423e0a891e43d5a19e83d4405391b1c7bf481b4b939e",
11+
"stderr_hash": "14119a0bc6420ad242b99395d457f2092014d96d2a1ac81d376c649d",
1212
"returncode": 2
1313
}

0 commit comments

Comments
 (0)