Skip to content

Commit 91bc4aa

Browse files
authored
Merge pull request #1726 from Smit-create/i-1608
Initial support for callbacks
2 parents 3d3a53c + e535783 commit 91bc4aa

File tree

8 files changed

+739
-14
lines changed

8 files changed

+739
-14
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,8 @@ RUN(NAME global_syms_04 LABELS cpython llvm c wasm wasm_x64)
525525
RUN(NAME global_syms_05 LABELS cpython llvm c)
526526
RUN(NAME global_syms_06 LABELS cpython llvm c)
527527

528+
RUN(NAME callback_01 LABELS cpython llvm)
529+
528530
# Intrinsic Functions
529531
RUN(NAME intrinsics_01 LABELS cpython llvm) # any
530532

integration_tests/callback_01.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from lpython import i32, Callable
2+
3+
4+
def f(x: i32) -> i32:
5+
return x + 1
6+
7+
def f2(x: i32) -> i32:
8+
return x + 10
9+
10+
def f3(x: i32) -> i32:
11+
return f(x) + f2(x)
12+
13+
14+
def g(func: Callable[[i32], i32], arg: i32) -> i32:
15+
ret: i32
16+
ret = func(arg)
17+
return ret
18+
19+
20+
def check():
21+
assert g(f, 10) == 11
22+
assert g(f2, 20) == 30
23+
assert g(f3, 5) == 21
24+
25+
26+
check()

src/libasr/asr_utils.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,6 +1551,11 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
15511551
m_dims = nullptr;
15521552
break;
15531553
}
1554+
case ASR::ttypeType::FunctionType: {
1555+
n_dims = 0;
1556+
m_dims = nullptr;
1557+
break;
1558+
}
15541559
case ASR::ttypeType::Dict: {
15551560
n_dims = 0;
15561561
m_dims = nullptr;
@@ -2340,6 +2345,27 @@ inline bool check_equal_type(ASR::ttype_t* x, ASR::ttype_t* y) {
23402345
std::string left_param = left_tp->m_param;
23412346
std::string right_param = right_tp->m_param;
23422347
return left_param.compare(right_param) == 0;
2348+
} else if (ASR::is_a<ASR::FunctionType_t>(*x) && ASR::is_a<ASR::FunctionType_t>(*y)) {
2349+
ASR::FunctionType_t* left_ft = ASR::down_cast<ASR::FunctionType_t>(x);
2350+
ASR::FunctionType_t* right_ft = ASR::down_cast<ASR::FunctionType_t>(y);
2351+
if (left_ft->n_arg_types != right_ft->n_arg_types) {
2352+
return false;
2353+
}
2354+
bool result;
2355+
for (size_t i=0; i<left_ft->n_arg_types; i++) {
2356+
result = check_equal_type(left_ft->m_arg_types[i],
2357+
right_ft->m_arg_types[i]);
2358+
if (!result) return false;
2359+
}
2360+
if (left_ft->m_return_var_type == nullptr &&
2361+
right_ft->m_return_var_type == nullptr) {
2362+
return true;
2363+
} else if (left_ft->m_return_var_type != nullptr &&
2364+
right_ft->m_return_var_type != nullptr) {
2365+
return check_equal_type(left_ft->m_return_var_type,
2366+
right_ft->m_return_var_type);
2367+
}
2368+
return false;
23432369
}
23442370

23452371
return types_equal(x, y);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,31 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
15211521
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Tuple_t(al, loc,
15221522
types.p, types.size()));
15231523
return type;
1524+
} else if (var_annotation == "Callable") {
1525+
LCOMPILERS_ASSERT(AST::is_a<AST::Tuple_t>(*s->m_slice));
1526+
AST::Tuple_t *t = AST::down_cast<AST::Tuple_t>(s->m_slice);
1527+
LCOMPILERS_ASSERT(t->n_elts <= 2 && t->n_elts >= 1);
1528+
Vec<ASR::ttype_t*> arg_types;
1529+
LCOMPILERS_ASSERT(AST::is_a<AST::List_t>(*t->m_elts[0]));
1530+
1531+
AST::List_t *arg_list = AST::down_cast<AST::List_t>(t->m_elts[0]);
1532+
if (arg_list->n_elts > 0) {
1533+
arg_types.reserve(al, arg_list->n_elts);
1534+
for (size_t i=0; i<arg_list->n_elts; i++) {
1535+
arg_types.push_back(al, ast_expr_to_asr_type(loc, *arg_list->m_elts[i]));
1536+
}
1537+
} else {
1538+
arg_types.reserve(al, 1);
1539+
}
1540+
ASR::ttype_t* ret_type = nullptr;
1541+
if (t->n_elts == 2) {
1542+
ret_type = ast_expr_to_asr_type(loc, *t->m_elts[1]);
1543+
}
1544+
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_FunctionType_t(al, loc, arg_types.p,
1545+
arg_types.size(), ret_type, ASR::abiType::Source,
1546+
ASR::deftypeType::Interface, nullptr, false, false,
1547+
false, false, false, nullptr, 0, nullptr, 0, false));
1548+
return type;
15241549
} else if (var_annotation == "set") {
15251550
if (AST::is_a<AST::Name_t>(*s->m_slice)) {
15261551
ASR::ttype_t *type = ast_expr_to_asr_type(loc, *s->m_slice);
@@ -3380,6 +3405,69 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
33803405
tmp = tmp0;
33813406
}
33823407

3408+
ASR::symbol_t* create_implicit_interface_function(Location &loc, ASR::FunctionType_t *func, std::string func_name) {
3409+
SymbolTable *parent_scope = current_scope;
3410+
current_scope = al.make_new<SymbolTable>(parent_scope);
3411+
3412+
Vec<ASR::expr_t*> args;
3413+
args.reserve(al, func->n_arg_types);
3414+
std::string sym_name = to_lower(func_name);
3415+
for (size_t i=0; i<func->n_arg_types; i++) {
3416+
std::string arg_name = sym_name + "_arg_" + std::to_string(i);
3417+
arg_name = to_lower(arg_name);
3418+
ASR::symbol_t *v;
3419+
SetChar variable_dependencies_vec;
3420+
variable_dependencies_vec.reserve(al, 1);
3421+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
3422+
func->m_arg_types[i]);
3423+
v = ASR::down_cast<ASR::symbol_t>(
3424+
ASR::make_Variable_t(al, loc,
3425+
current_scope, s2c(al, arg_name), variable_dependencies_vec.p,
3426+
variable_dependencies_vec.size(), ASRUtils::intent_unspecified,
3427+
nullptr, nullptr, ASR::storage_typeType::Default, func->m_arg_types[i],
3428+
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
3429+
false));
3430+
current_scope->add_symbol(arg_name, v);
3431+
LCOMPILERS_ASSERT(v != nullptr)
3432+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc,
3433+
v)));
3434+
}
3435+
3436+
ASR::expr_t *to_return = nullptr;
3437+
if (func->m_return_var_type) {
3438+
std::string return_var_name = sym_name + "_return_var_name";
3439+
SetChar variable_dependencies_vec;
3440+
variable_dependencies_vec.reserve(al, 1);
3441+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
3442+
func->m_return_var_type);
3443+
ASR::asr_t *return_var = ASR::make_Variable_t(al, loc,
3444+
current_scope, s2c(al, return_var_name), variable_dependencies_vec.p,
3445+
variable_dependencies_vec.size(), ASRUtils::intent_return_var,
3446+
nullptr, nullptr, ASR::storage_typeType::Default, func->m_return_var_type,
3447+
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
3448+
false);
3449+
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(return_var));
3450+
to_return = ASRUtils::EXPR(ASR::make_Var_t(al, loc,
3451+
ASR::down_cast<ASR::symbol_t>(return_var)));
3452+
}
3453+
3454+
tmp = ASRUtils::make_Function_t_util(
3455+
al, loc,
3456+
/* a_symtab */ current_scope,
3457+
/* a_name */ s2c(al, sym_name),
3458+
nullptr, 0,
3459+
/* a_args */ args.p,
3460+
/* n_args */ args.size(),
3461+
/* a_body */ nullptr,
3462+
/* n_body */ 0,
3463+
/* a_return_var */ to_return,
3464+
ASR::abiType::BindC, ASR::accessType::Public, ASR::deftypeType::Interface,
3465+
nullptr, false, false, false, false, false, /* a_type_parameters */ nullptr,
3466+
/* n_type_parameters */ 0, nullptr, 0, false, false, false);
3467+
current_scope = parent_scope;
3468+
return ASR::down_cast<ASR::symbol_t>(tmp);
3469+
}
3470+
33833471
void visit_FunctionDef(const AST::FunctionDef_t &x) {
33843472
dependencies.clear(al);
33853473
SymbolTable *parent_scope = current_scope;
@@ -3497,20 +3585,26 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
34973585
if (current_procedure_abi_type == ASR::abiType::BindC) {
34983586
value_attr = true;
34993587
}
3500-
SetChar variable_dependencies_vec;
3501-
variable_dependencies_vec.reserve(al, 1);
3502-
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
3503-
ASR::asr_t *v = ASR::make_Variable_t(al, loc, current_scope,
3504-
s2c(al, arg_s), variable_dependencies_vec.p,
3505-
variable_dependencies_vec.size(),
3506-
s_intent, init_expr, value, storage_type, arg_type,
3507-
current_procedure_abi_type, s_access, s_presence,
3508-
value_attr);
3509-
current_scope->add_symbol(arg_s, ASR::down_cast<ASR::symbol_t>(v));
3510-
3511-
ASR::symbol_t *var = current_scope->get_symbol(arg_s);
3588+
ASR::symbol_t *v;
3589+
if (ASR::is_a<ASR::FunctionType_t>(*arg_type)) {
3590+
ASR::FunctionType_t *func = ASR::down_cast<ASR::FunctionType_t>(arg_type);
3591+
v = create_implicit_interface_function(loc, func, arg_s);
3592+
} else {
3593+
SetChar variable_dependencies_vec;
3594+
variable_dependencies_vec.reserve(al, 1);
3595+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
3596+
ASR::asr_t *_tmp = ASR::make_Variable_t(al, loc, current_scope,
3597+
s2c(al, arg_s), variable_dependencies_vec.p,
3598+
variable_dependencies_vec.size(),
3599+
s_intent, init_expr, value, storage_type, arg_type,
3600+
current_procedure_abi_type, s_access, s_presence,
3601+
value_attr);
3602+
v = ASR::down_cast<ASR::symbol_t>(_tmp);
3603+
3604+
}
3605+
current_scope->add_symbol(arg_s, v);
35123606
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc,
3513-
var)));
3607+
v)));
35143608
}
35153609
ASR::accessType s_access = ASR::accessType::Public;
35163610
ASR::deftypeType deftype = ASR::deftypeType::Implementation;

src/runtime/lpython/lpython.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__slots__ = ["i8", "i16", "i32", "i64", "f32", "f64", "c32", "c64", "CPtr",
1010
"overload", "ccall", "TypeVar", "pointer", "c_p_pointer", "Pointer",
1111
"p_c_pointer", "vectorize", "inline", "Union", "static", "with_goto",
12-
"packed", "Const", "sizeof", "ccallable", "ccallback"]
12+
"packed", "Const", "sizeof", "ccallable", "ccallback", "Callable"]
1313

1414
# data-types
1515

@@ -55,6 +55,7 @@ def __init__(self, type, dims):
5555
c64 = Type("c64")
5656
CPtr = Type("c_ptr")
5757
Const = ConstType("Const")
58+
Callable = Type("Callable")
5859
Union = ctypes.Union
5960
Pointer = PointerType("Pointer")
6061

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-callback_01-64f7a94",
3+
"cmd": "lpython --show-asr --indent --no-color {infile} -o {outfile}",
4+
"infile": "tests/../integration_tests/callback_01.py",
5+
"infile_hash": "c3ab71a93f40edda000ae863149c38c388bb43a8329ebae9320a7ab4",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": "asr-callback_01-64f7a94.stdout",
9+
"stdout_hash": "0b2b8730f07fc9aad59a2c4f1dc9060bcd022d05fb5bf2f34e5f8b4b",
10+
"stderr": null,
11+
"stderr_hash": null,
12+
"returncode": 0
13+
}

0 commit comments

Comments
 (0)