@@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
784
784
ASR::call_arg_t c_arg;
785
785
c_arg.loc = args[i].loc;
786
786
c_arg.m_value = args[i].m_value;
787
- cast_helper(m_args[i], c_arg.m_value, true);
788
787
ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]);
789
788
ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value);
789
+ if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
790
+ ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
791
+ ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
792
+ ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
793
+ ASRUtils::symbol_get_past_external(
794
+ l_type->m_derived_type));
795
+ ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
796
+ ASRUtils::symbol_get_past_external(
797
+ r_type->m_derived_type));
798
+ if ( ASRUtils::is_derived_type_similar(l2_type, r2_type) ) {
799
+ cast_helper(m_args[i], c_arg.m_value, true, true);
800
+ check_type_equality = false;
801
+ } else {
802
+ cast_helper(m_args[i], c_arg.m_value, true);
803
+ }
804
+ } else {
805
+ cast_helper(m_args[i], c_arg.m_value, true);
806
+ }
790
807
if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) {
791
808
std::string ltype = ASRUtils::type_to_str_python(left_type);
792
809
std::string rtype = ASRUtils::type_to_str_python(right_type);
@@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
2962
2979
std::string obj_name = x.m_args.m_args->m_arg;
2963
2980
for(size_t i = 0; i < x.n_body; i++) {
2964
2981
std::string var_name;
2965
- if (! AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
2966
- throw SemanticError("Only AnnAssign implemented in __init__ ",
2967
- x.m_body[i]->base.loc);
2982
+ if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body[i]) ){
2983
+ continue;
2968
2984
}
2969
2985
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
2970
2986
if(AST::is_a<AST::Attribute_t>(*ann_assign.m_target)){
@@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3301
3317
current_scope->add_symbol(x_m_name, class_type);
3302
3318
}
3303
3319
} else {
3304
- if( x.n_bases > 0 ) {
3305
- throw SemanticError("Inheritance in classes isn't supported yet.",
3320
+ ASR::symbol_t* parent = nullptr;
3321
+ if( x.n_bases > 1 ) {
3322
+ throw SemanticError("Multiple inheritance in classes isn't supported yet.",
3306
3323
x.base.base.loc);
3307
3324
}
3325
+ else if (x.n_bases == 1) {
3326
+ std::string b_name = "";
3327
+ if ( AST::is_a<AST::Name_t>(*x.m_bases[0]) ) {
3328
+ b_name = AST::down_cast<AST::Name_t>(x.m_bases[0])->m_id;
3329
+ } else {
3330
+ throw SemanticError("Expected a Name here", x.base.base.loc);
3331
+ }
3332
+ parent = current_scope->resolve_symbol(b_name);
3333
+ LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*parent));
3334
+ }
3308
3335
SymbolTable *parent_scope = current_scope;
3309
3336
if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) {
3310
3337
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*sym));
@@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3316
3343
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
3317
3344
init_self_type(*f, sym, x.base.base.loc);
3318
3345
if ( std::string(f->m_name) == std::string("__init__") ) {
3319
- this->visit_init_body(*f);
3346
+ this->visit_init_body(*f, st->m_parent, x.m_body[i]->base.loc );
3320
3347
} else {
3321
3348
this->visit_stmt(*x.m_body[i]);
3322
3349
}
@@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3344
3371
member_names.p, member_names.size(), member_fn_names.p,
3345
3372
member_fn_names.size(), class_abi, ASR::accessType::Public,
3346
3373
false, false, member_init.p, member_init.size(),
3347
- nullptr, nullptr ));
3374
+ nullptr, parent ));
3348
3375
parent_scope->add_symbol(x.m_name, class_sym);
3349
3376
visit_ClassMembers(x, member_names, member_fn_names,
3350
3377
struct_dependencies, member_init, false, class_abi, true);
@@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3387
3414
current_scope = parent_scope;
3388
3415
}
3389
3416
3390
- virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;
3417
+ virtual void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/ ) = 0;
3391
3418
3392
3419
void add_name(const Location &loc) {
3393
3420
std::string var_name = "__name__";
@@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
4421
4448
// Implement visit_Global for Symbol Table visitor.
4422
4449
void visit_Global(const AST::Global_t &/*x*/) {}
4423
4450
4424
- void visit_init_body (const AST::FunctionDef_t &/*x*/) {
4451
+ void visit_init_body (const AST::FunctionDef_t &/*x*/, ASR::symbol_t* /*parent_sym*/, const Location /*loc*/ ) {
4425
4452
//Implemented in BodyVisitor
4426
4453
}
4427
4454
@@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5153
5180
tmp = asr;
5154
5181
}
5155
5182
5156
- void visit_init_body (const AST::FunctionDef_t &x) {
5183
+ void visit_init_body (const AST::FunctionDef_t &x, ASR::symbol_t* parent_sym, const Location loc ) {
5157
5184
SymbolTable *old_scope = current_scope;
5158
5185
ASR::symbol_t *t = current_scope->get_symbol("__init__");
5159
5186
if ( t==nullptr ) {
@@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5163
5190
throw SemanticError("__init__ is not a function", x.base.base.loc);
5164
5191
}
5165
5192
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
5193
+ current_scope = f->m_symtab;
5166
5194
//Transform statements into correct format
5167
- Vec<AST::stmt_t*> new_body;
5168
- new_body.reserve(al, 1);
5195
+ Vec<AST::stmt_t*> body;
5196
+ body.reserve(al, 1);
5197
+ ASR::stmt_t* super_call_stmt = nullptr;
5169
5198
for (size_t i=0; i<x.n_body; i++) {
5170
- AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
5171
- if ( ann_assign.m_value != nullptr ) {
5172
- Vec<AST::expr_t*>target;
5173
- target.reserve(al, 1);
5174
- target.push_back(al, ann_assign.m_target);
5175
- AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
5176
- target.p, 1, ann_assign.m_value, nullptr);
5177
- AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
5178
- new_body.push_back(al, assgn);
5199
+ if (AST::is_a<AST::AnnAssign_t>(*x.m_body[i])) {
5200
+ AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body[i]);
5201
+ if ( ann_assign.m_value != nullptr ) {
5202
+ Vec<AST::expr_t*>target;
5203
+ target.reserve(al, 1);
5204
+ target.push_back(al, ann_assign.m_target);
5205
+ AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc,
5206
+ target.p, 1, ann_assign.m_value, nullptr);
5207
+ AST::stmt_t* assgn = AST::down_cast<AST::stmt_t>(assgn_ast);
5208
+ body.push_back(al, assgn);
5209
+ }
5210
+ } else if (AST::is_a<AST::Expr_t>(*x.m_body[i]) &&
5211
+ AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value))) {
5212
+ AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body[i])->m_value);
5213
+
5214
+ if ( !AST::is_a<AST::Attribute_t>(*(c->m_func))
5215
+ || !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value)) ) {
5216
+ body.push_back(al, x.m_body[i]);
5217
+ continue;
5218
+ }
5219
+ AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func)->m_value);
5220
+ std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func)->m_attr;
5221
+ if ( AST::is_a<AST::Name_t>(*(super_call->m_func)) &&
5222
+ std::string(AST::down_cast<AST::Name_t>(super_call->m_func)->m_id)=="super" &&
5223
+ attr == "__init__") {
5224
+ if (parent_sym == nullptr) {
5225
+ throw SemanticError("The class doesn't have a base class",loc);
5226
+ }
5227
+ Vec<ASR::call_arg_t> args;
5228
+ args.reserve(al, 1);
5229
+ parse_args(*super_call,args);
5230
+ ASR::call_arg_t first_arg;
5231
+ first_arg.loc = loc;
5232
+ ASR::symbol_t* self_sym = current_scope->get_symbol("self");
5233
+ first_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al,loc,self_sym));
5234
+ ASR::ttype_t* target_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc,parent_sym));
5235
+ cast_helper(target_type, first_arg.m_value, x.base.base.loc, true);
5236
+ Vec<ASR::call_arg_t> args_w_first; args_w_first.reserve(al,1);
5237
+ args_w_first.push_back(al, first_arg);
5238
+ for( size_t i = 0; i < args.size(); i++ ) {
5239
+ args_w_first.push_back(al,args[i]);
5240
+ }
5241
+ std::string call_name = "__init__";
5242
+ ASR::symbol_t* call_sym = get_struct_member(parent_sym,call_name,loc);
5243
+ super_call_stmt = ASRUtils::STMT(
5244
+ ASR::make_SubroutineCall_t(al, loc, call_sym, call_sym, args_w_first.p,
5245
+ args_w_first.size(), nullptr));
5246
+ }
5247
+ } else {
5248
+ body.push_back(al, x.m_body[i]);
5179
5249
}
5180
5250
}
5181
5251
current_scope = f->m_symtab;
5182
- Vec<ASR::stmt_t*> body;
5183
- body.reserve(al, x.n_body);
5252
+ Vec<ASR::stmt_t*> body_asr;
5253
+ body_asr.reserve(al, x.n_body);
5254
+ if ( super_call_stmt ) {
5255
+ body_asr.push_back(al, super_call_stmt);
5256
+ }
5184
5257
Vec<ASR::symbol_t*> rts;
5185
5258
rts.reserve(al, 4);
5186
5259
dependencies.clear(al);
5187
- transform_stmts(body, new_body .n, new_body .p);
5260
+ transform_stmts(body_asr, body .n, body .p);
5188
5261
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
5189
- f->m_body = body .p;
5190
- f->n_body = body .size();
5262
+ f->m_body = body_asr .p;
5263
+ f->n_body = body_asr .size();
5191
5264
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
5192
5265
f->m_function_signature);
5193
5266
func_type->m_restrictions = rts.p;
@@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
6239
6312
for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) {
6240
6313
member_found = std::string(der_type->m_members[i]) == member_name;
6241
6314
}
6242
- if( !member_found ) {
6315
+ if( !member_found && !der_type->m_parent ) {
6243
6316
throw SemanticError("No member " + member_name +
6244
6317
" found in " + std::string(der_type->m_name),
6245
6318
loc);
6319
+ } else if ( !member_found && der_type->m_parent ) {
6320
+ ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent));
6321
+ visit_AttributeUtil(parent_type,attr_char,t,loc);
6322
+ return;
6246
6323
}
6247
6324
ASR::expr_t *val = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
6248
6325
ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);
@@ -8064,7 +8141,8 @@ we will have to use something else.
8064
8141
//TODO: Correct Class and ClassType
8065
8142
// call to struct member function
8066
8143
// modifying args to pass the object as self
8067
- ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
8144
+ ASR::symbol_t* der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
8145
+ ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
8068
8146
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
8069
8147
ASR::call_arg_t self_arg;
8070
8148
self_arg.loc = args[0].loc;
@@ -8073,7 +8151,20 @@ we will have to use something else.
8073
8151
for (size_t i=0; i<args.n; i++) {
8074
8152
new_args.push_back(al, args[i]);
8075
8153
}
8076
- st = get_struct_member(der, call_name, loc);
8154
+ if ( der->m_symtab->get_symbol(call_name) ) {
8155
+ st = get_struct_member(der_sym, call_name, loc);
8156
+ } else if ( der->m_parent ) {
8157
+ ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent);
8158
+ if ( !parent->m_symtab->get_symbol(call_name) ) {
8159
+ throw SemanticError("Method not found in the class "+ std::string(der->m_name) +
8160
+ " or it's parents",loc);
8161
+ } else {
8162
+ st = get_struct_member(der->m_parent, call_name, loc);
8163
+ }
8164
+ } else {
8165
+ throw SemanticError("Method not found in the class "+std::string(der->m_name)+
8166
+ " or it's parents",loc);
8167
+ }
8077
8168
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
8078
8169
return;
8079
8170
} else {
0 commit comments