Skip to content

Commit 515b1b0

Browse files
committed
Initial support for callbacks
1 parent dd2413b commit 515b1b0

File tree

2 files changed

+116
-13
lines changed

2 files changed

+116
-13
lines changed

src/libasr/asr_utils.h

+5
Original file line numberDiff line numberDiff line change
@@ -1551,6 +1551,11 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
15511551
m_dims = nullptr;
15521552
break;
15531553
}
1554+
case ASR::ttypeType::FunctionType: {
1555+
n_dims = 0;
1556+
m_dims = nullptr;
1557+
break;
1558+
}
15541559
case ASR::ttypeType::Dict: {
15551560
n_dims = 0;
15561561
m_dims = nullptr;

src/lpython/semantics/python_ast_to_asr.cpp

+111-13
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
871871
cast_helper(m_args[i], c_arg.m_value, true);
872872
ASR::ttype_t* left_type = ASRUtils::expr_type(m_args[i]);
873873
ASR::ttype_t* right_type = ASRUtils::expr_type(c_arg.m_value);
874+
if (ASR::is_a<ASR::FunctionType_t>(*left_type) ) {
875+
// TODO: add FunctionType in check_equal_type
876+
continue;
877+
}
874878
if( check_type_equality && !ASRUtils::check_equal_type(left_type, right_type) ) {
875879
std::string ltype = ASRUtils::type_to_str_python(left_type);
876880
std::string rtype = ASRUtils::type_to_str_python(right_type);
@@ -1506,6 +1510,31 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
15061510
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Tuple_t(al, loc,
15071511
types.p, types.size()));
15081512
return type;
1513+
} else if (var_annotation == "Callable") {
1514+
LCOMPILERS_ASSERT(AST::is_a<AST::Tuple_t>(*s->m_slice));
1515+
AST::Tuple_t *t = AST::down_cast<AST::Tuple_t>(s->m_slice);
1516+
LCOMPILERS_ASSERT(t->n_elts <= 2 && t->n_elts >= 1);
1517+
Vec<ASR::ttype_t*> arg_types;
1518+
LCOMPILERS_ASSERT(AST::is_a<AST::List_t>(*t->m_elts[0]));
1519+
1520+
AST::List_t *arg_list = AST::down_cast<AST::List_t>(t->m_elts[0]);
1521+
if (arg_list->n_elts > 0) {
1522+
arg_types.reserve(al, arg_list->n_elts);
1523+
for (size_t i=0; i<arg_list->n_elts; i++) {
1524+
arg_types.push_back(al, ast_expr_to_asr_type(loc, *arg_list->m_elts[i]));
1525+
}
1526+
} else {
1527+
arg_types.reserve(al, 1);
1528+
}
1529+
ASR::ttype_t* ret_type = nullptr;
1530+
if (t->n_elts == 2) {
1531+
ret_type = ast_expr_to_asr_type(loc, *t->m_elts[1]);
1532+
}
1533+
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_FunctionType_t(al, loc, arg_types.p,
1534+
arg_types.size(), ret_type, ASR::abiType::Source,
1535+
ASR::deftypeType::Interface, nullptr, false, false,
1536+
false, false, false, nullptr, 0, nullptr, 0, false));
1537+
return type;
15091538
} else if (var_annotation == "set") {
15101539
if (AST::is_a<AST::Name_t>(*s->m_slice)) {
15111540
ASR::ttype_t *type = ast_expr_to_asr_type(loc, *s->m_slice);
@@ -3365,6 +3394,69 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
33653394
tmp = tmp0;
33663395
}
33673396

3397+
ASR::symbol_t* create_implicit_interface_function(Location &loc, ASR::FunctionType_t *func, std::string func_name) {
3398+
SymbolTable *parent_scope = current_scope;
3399+
current_scope = al.make_new<SymbolTable>(parent_scope);
3400+
3401+
Vec<ASR::expr_t*> args;
3402+
args.reserve(al, func->n_arg_types);
3403+
std::string sym_name = to_lower(func_name);
3404+
for (size_t i=0; i<func->n_arg_types; i++) {
3405+
std::string arg_name = sym_name + "_arg_" + std::to_string(i);
3406+
arg_name = to_lower(arg_name);
3407+
ASR::symbol_t *v;
3408+
SetChar variable_dependencies_vec;
3409+
variable_dependencies_vec.reserve(al, 1);
3410+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
3411+
func->m_arg_types[i]);
3412+
v = ASR::down_cast<ASR::symbol_t>(
3413+
ASR::make_Variable_t(al, loc,
3414+
current_scope, s2c(al, arg_name), variable_dependencies_vec.p,
3415+
variable_dependencies_vec.size(), ASRUtils::intent_unspecified,
3416+
nullptr, nullptr, ASR::storage_typeType::Default, func->m_arg_types[i],
3417+
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
3418+
false));
3419+
current_scope->add_symbol(arg_name, v);
3420+
LCOMPILERS_ASSERT(v != nullptr)
3421+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc,
3422+
v)));
3423+
}
3424+
3425+
ASR::expr_t *to_return = nullptr;
3426+
if (func->m_return_var_type) {
3427+
std::string return_var_name = sym_name + "_return_var_name";
3428+
SetChar variable_dependencies_vec;
3429+
variable_dependencies_vec.reserve(al, 1);
3430+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
3431+
func->m_return_var_type);
3432+
ASR::asr_t *return_var = ASR::make_Variable_t(al, loc,
3433+
current_scope, s2c(al, return_var_name), variable_dependencies_vec.p,
3434+
variable_dependencies_vec.size(), ASRUtils::intent_return_var,
3435+
nullptr, nullptr, ASR::storage_typeType::Default, func->m_return_var_type,
3436+
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
3437+
false);
3438+
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(return_var));
3439+
to_return = ASRUtils::EXPR(ASR::make_Var_t(al, loc,
3440+
ASR::down_cast<ASR::symbol_t>(return_var)));
3441+
}
3442+
3443+
tmp = ASRUtils::make_Function_t_util(
3444+
al, loc,
3445+
/* a_symtab */ current_scope,
3446+
/* a_name */ s2c(al, sym_name),
3447+
nullptr, 0,
3448+
/* a_args */ args.p,
3449+
/* n_args */ args.size(),
3450+
/* a_body */ nullptr,
3451+
/* n_body */ 0,
3452+
/* a_return_var */ to_return,
3453+
ASR::abiType::BindC, ASR::accessType::Public, ASR::deftypeType::Interface,
3454+
nullptr, false, false, false, false, false, /* a_type_parameters */ nullptr,
3455+
/* n_type_parameters */ 0, nullptr, 0, false, false, false);
3456+
current_scope = parent_scope;
3457+
return ASR::down_cast<ASR::symbol_t>(tmp);
3458+
}
3459+
33683460
void visit_FunctionDef(const AST::FunctionDef_t &x) {
33693461
dependencies.clear(al);
33703462
SymbolTable *parent_scope = current_scope;
@@ -3482,20 +3574,26 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
34823574
if (current_procedure_abi_type == ASR::abiType::BindC) {
34833575
value_attr = true;
34843576
}
3485-
SetChar variable_dependencies_vec;
3486-
variable_dependencies_vec.reserve(al, 1);
3487-
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
3488-
ASR::asr_t *v = ASR::make_Variable_t(al, loc, current_scope,
3489-
s2c(al, arg_s), variable_dependencies_vec.p,
3490-
variable_dependencies_vec.size(),
3491-
s_intent, init_expr, value, storage_type, arg_type,
3492-
current_procedure_abi_type, s_access, s_presence,
3493-
value_attr);
3494-
current_scope->add_symbol(arg_s, ASR::down_cast<ASR::symbol_t>(v));
3495-
3496-
ASR::symbol_t *var = current_scope->get_symbol(arg_s);
3577+
ASR::symbol_t *v;
3578+
if (ASR::is_a<ASR::FunctionType_t>(*arg_type)) {
3579+
ASR::FunctionType_t *func = ASR::down_cast<ASR::FunctionType_t>(arg_type);
3580+
v = create_implicit_interface_function(loc, func, arg_s);
3581+
} else {
3582+
SetChar variable_dependencies_vec;
3583+
variable_dependencies_vec.reserve(al, 1);
3584+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
3585+
ASR::asr_t *_tmp = ASR::make_Variable_t(al, loc, current_scope,
3586+
s2c(al, arg_s), variable_dependencies_vec.p,
3587+
variable_dependencies_vec.size(),
3588+
s_intent, init_expr, value, storage_type, arg_type,
3589+
current_procedure_abi_type, s_access, s_presence,
3590+
value_attr);
3591+
v = ASR::down_cast<ASR::symbol_t>(_tmp);
3592+
3593+
}
3594+
current_scope->add_symbol(arg_s, v);
34973595
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc,
3498-
var)));
3596+
v)));
34993597
}
35003598
ASR::accessType s_access = ASR::accessType::Public;
35013599
ASR::deftypeType deftype = ASR::deftypeType::Implementation;

0 commit comments

Comments
 (0)