Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ RUN(NAME global_syms_04 LABELS cpython llvm c wasm wasm_x64)
RUN(NAME global_syms_05 LABELS cpython llvm c)
RUN(NAME global_syms_06 LABELS cpython llvm c)

RUN(NAME callback_01 LABELS cpython llvm)

# Intrinsic Functions
RUN(NAME intrinsics_01 LABELS cpython llvm) # any

Expand Down
26 changes: 26 additions & 0 deletions integration_tests/callback_01.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from lpython import i32, Callable


def f(x: i32) -> i32:
return x + 1

def f2(x: i32) -> i32:
return x + 10

def f3(x: i32) -> i32:
return f(x) + f2(x)


def g(func: Callable[[i32], i32], arg: i32) -> i32:
ret: i32
ret = func(arg)
return ret


def check():
assert g(f, 10) == 11
assert g(f2, 20) == 30
assert g(f3, 5) == 21


check()
26 changes: 26 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,11 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
m_dims = nullptr;
break;
}
case ASR::ttypeType::FunctionType: {
n_dims = 0;
m_dims = nullptr;
break;
}
case ASR::ttypeType::Dict: {
n_dims = 0;
m_dims = nullptr;
Expand Down Expand Up @@ -2340,6 +2345,27 @@ inline bool check_equal_type(ASR::ttype_t* x, ASR::ttype_t* y) {
std::string left_param = left_tp->m_param;
std::string right_param = right_tp->m_param;
return left_param.compare(right_param) == 0;
} else if (ASR::is_a<ASR::FunctionType_t>(*x) && ASR::is_a<ASR::FunctionType_t>(*y)) {
ASR::FunctionType_t* left_ft = ASR::down_cast<ASR::FunctionType_t>(x);
ASR::FunctionType_t* right_ft = ASR::down_cast<ASR::FunctionType_t>(y);
if (left_ft->n_arg_types != right_ft->n_arg_types) {
return false;
}
bool result;
for (size_t i=0; i<left_ft->n_arg_types; i++) {
result = check_equal_type(left_ft->m_arg_types[i],
right_ft->m_arg_types[i]);
if (!result) return false;
}
if (left_ft->m_return_var_type == nullptr &&
right_ft->m_return_var_type == nullptr) {
return true;
} else if (left_ft->m_return_var_type != nullptr &&
right_ft->m_return_var_type != nullptr) {
return check_equal_type(left_ft->m_return_var_type,
right_ft->m_return_var_type);
}
return false;
}

return types_equal(x, y);
Expand Down
120 changes: 107 additions & 13 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,31 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Tuple_t(al, loc,
types.p, types.size()));
return type;
} else if (var_annotation == "Callable") {
LCOMPILERS_ASSERT(AST::is_a<AST::Tuple_t>(*s->m_slice));
AST::Tuple_t *t = AST::down_cast<AST::Tuple_t>(s->m_slice);
LCOMPILERS_ASSERT(t->n_elts <= 2 && t->n_elts >= 1);
Vec<ASR::ttype_t*> arg_types;
LCOMPILERS_ASSERT(AST::is_a<AST::List_t>(*t->m_elts[0]));

AST::List_t *arg_list = AST::down_cast<AST::List_t>(t->m_elts[0]);
if (arg_list->n_elts > 0) {
arg_types.reserve(al, arg_list->n_elts);
for (size_t i=0; i<arg_list->n_elts; i++) {
arg_types.push_back(al, ast_expr_to_asr_type(loc, *arg_list->m_elts[i]));
}
} else {
arg_types.reserve(al, 1);
}
ASR::ttype_t* ret_type = nullptr;
if (t->n_elts == 2) {
ret_type = ast_expr_to_asr_type(loc, *t->m_elts[1]);
}
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_FunctionType_t(al, loc, arg_types.p,
arg_types.size(), ret_type, ASR::abiType::Source,
ASR::deftypeType::Interface, nullptr, false, false,
false, false, false, nullptr, 0, nullptr, 0, false));
return type;
} else if (var_annotation == "set") {
if (AST::is_a<AST::Name_t>(*s->m_slice)) {
ASR::ttype_t *type = ast_expr_to_asr_type(loc, *s->m_slice);
Expand Down Expand Up @@ -3380,6 +3405,69 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
tmp = tmp0;
}

ASR::symbol_t* create_implicit_interface_function(Location &loc, ASR::FunctionType_t *func, std::string func_name) {
SymbolTable *parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, func->n_arg_types);
std::string sym_name = to_lower(func_name);
for (size_t i=0; i<func->n_arg_types; i++) {
std::string arg_name = sym_name + "_arg_" + std::to_string(i);
arg_name = to_lower(arg_name);
ASR::symbol_t *v;
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
func->m_arg_types[i]);
v = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, loc,
current_scope, s2c(al, arg_name), variable_dependencies_vec.p,
variable_dependencies_vec.size(), ASRUtils::intent_unspecified,
nullptr, nullptr, ASR::storage_typeType::Default, func->m_arg_types[i],
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
false));
current_scope->add_symbol(arg_name, v);
LCOMPILERS_ASSERT(v != nullptr)
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc,
v)));
}

ASR::expr_t *to_return = nullptr;
if (func->m_return_var_type) {
std::string return_var_name = sym_name + "_return_var_name";
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
func->m_return_var_type);
ASR::asr_t *return_var = ASR::make_Variable_t(al, loc,
current_scope, s2c(al, return_var_name), variable_dependencies_vec.p,
variable_dependencies_vec.size(), ASRUtils::intent_return_var,
nullptr, nullptr, ASR::storage_typeType::Default, func->m_return_var_type,
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
false);
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(return_var));
to_return = ASRUtils::EXPR(ASR::make_Var_t(al, loc,
ASR::down_cast<ASR::symbol_t>(return_var)));
}

tmp = ASRUtils::make_Function_t_util(
al, loc,
/* a_symtab */ current_scope,
/* a_name */ s2c(al, sym_name),
nullptr, 0,
/* a_args */ args.p,
/* n_args */ args.size(),
/* a_body */ nullptr,
/* n_body */ 0,
/* a_return_var */ to_return,
ASR::abiType::BindC, ASR::accessType::Public, ASR::deftypeType::Interface,
nullptr, false, false, false, false, false, /* a_type_parameters */ nullptr,
/* n_type_parameters */ 0, nullptr, 0, false, false, false);
current_scope = parent_scope;
return ASR::down_cast<ASR::symbol_t>(tmp);
}

void visit_FunctionDef(const AST::FunctionDef_t &x) {
dependencies.clear(al);
SymbolTable *parent_scope = current_scope;
Expand Down Expand Up @@ -3497,20 +3585,26 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
if (current_procedure_abi_type == ASR::abiType::BindC) {
value_attr = true;
}
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
ASR::asr_t *v = ASR::make_Variable_t(al, loc, current_scope,
s2c(al, arg_s), variable_dependencies_vec.p,
variable_dependencies_vec.size(),
s_intent, init_expr, value, storage_type, arg_type,
current_procedure_abi_type, s_access, s_presence,
value_attr);
current_scope->add_symbol(arg_s, ASR::down_cast<ASR::symbol_t>(v));

ASR::symbol_t *var = current_scope->get_symbol(arg_s);
ASR::symbol_t *v;
if (ASR::is_a<ASR::FunctionType_t>(*arg_type)) {
ASR::FunctionType_t *func = ASR::down_cast<ASR::FunctionType_t>(arg_type);
v = create_implicit_interface_function(loc, func, arg_s);
} else {
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
ASR::asr_t *_tmp = ASR::make_Variable_t(al, loc, current_scope,
s2c(al, arg_s), variable_dependencies_vec.p,
variable_dependencies_vec.size(),
s_intent, init_expr, value, storage_type, arg_type,
current_procedure_abi_type, s_access, s_presence,
value_attr);
v = ASR::down_cast<ASR::symbol_t>(_tmp);

}
current_scope->add_symbol(arg_s, v);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc,
var)));
v)));
}
ASR::accessType s_access = ASR::accessType::Public;
ASR::deftypeType deftype = ASR::deftypeType::Implementation;
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__slots__ = ["i8", "i16", "i32", "i64", "f32", "f64", "c32", "c64", "CPtr",
"overload", "ccall", "TypeVar", "pointer", "c_p_pointer", "Pointer",
"p_c_pointer", "vectorize", "inline", "Union", "static", "with_goto",
"packed", "Const", "sizeof", "ccallable", "ccallback"]
"packed", "Const", "sizeof", "ccallable", "ccallback", "Callable"]

# data-types

Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(self, type, dims):
c64 = Type("c64")
CPtr = Type("c_ptr")
Const = ConstType("Const")
Callable = Type("Callable")
Union = ctypes.Union
Pointer = PointerType("Pointer")

Expand Down
13 changes: 13 additions & 0 deletions tests/reference/asr-callback_01-64f7a94.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-callback_01-64f7a94",
"cmd": "lpython --show-asr --indent --no-color {infile} -o {outfile}",
"infile": "tests/../integration_tests/callback_01.py",
"infile_hash": "c3ab71a93f40edda000ae863149c38c388bb43a8329ebae9320a7ab4",
"outfile": null,
"outfile_hash": null,
"stdout": "asr-callback_01-64f7a94.stdout",
"stdout_hash": "0b2b8730f07fc9aad59a2c4f1dc9060bcd022d05fb5bf2f34e5f8b4b",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
Loading