Skip to content

Commit 2c12efb

Browse files
Support global list import with a FunctionCall as an initializer.
Co-authored-by: Ondřej Čertík <[email protected]>
1 parent a4f13a4 commit 2c12efb

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

src/libasr/asr_scopes.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,12 @@ std::string SymbolTable::get_unique_name(const std::string &name) {
137137

138138
void SymbolTable::move_symbols_from_global_scope(Allocator &al,
139139
SymbolTable *module_scope, Vec<char *> &syms,
140-
Vec<char *> &mod_dependencies, Vec<ASR::stmt_t*> &var_init) {
140+
Vec<char *> &mod_dependencies, Vec<char *> &func_dependencies,
141+
Vec<ASR::stmt_t*> &var_init) {
141142
// TODO: This isn't scalable. We have write a visitor in asdl_cpp.py
142143
syms.reserve(al, 4);
143144
mod_dependencies.reserve(al, 4);
145+
func_dependencies.reserve(al, 4);
144146
var_init.reserve(al, 4);
145147
for (auto &a : scope) {
146148
switch (a.second->type) {
@@ -211,7 +213,35 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
211213
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(
212214
al, v->base.base.loc, (ASR::symbol_t *) es));
213215
ASR::expr_t *value = v->m_symbolic_value;
214-
216+
v->m_symbolic_value = nullptr;
217+
v->m_value = nullptr;
218+
if (ASR::is_a<ASR::FunctionCall_t>(*value)) {
219+
ASR::FunctionCall_t *call =
220+
ASR::down_cast<ASR::FunctionCall_t>(value);
221+
ASR::Module_t *m = ASRUtils::get_sym_module(s);
222+
ASR::symbol_t *func = m->m_symtab->get_symbol(
223+
ASRUtils::symbol_name(call->m_name));
224+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(func);
225+
std::string func_name = std::string(m->m_name) +
226+
"@" + f->m_name;
227+
ASR::symbol_t *es_func;
228+
if (!module_scope->get_symbol(func_name)) {
229+
es_func = ASR::down_cast<ASR::symbol_t>(
230+
ASR::make_ExternalSymbol_t(al, f->base.base.loc,
231+
module_scope, s2c(al, func_name), func, m->m_name,
232+
nullptr, 0, s2c(al, f->m_name), ASR::accessType::Public));
233+
module_scope->add_symbol(func_name, es_func);
234+
if (!present(func_dependencies, s2c(al, func_name))) {
235+
func_dependencies.push_back(al, s2c(al,func_name));
236+
}
237+
} else {
238+
es_func = module_scope->get_symbol(func_name);
239+
}
240+
value = ASRUtils::EXPR(ASR::make_FunctionCall_t(al,
241+
call->base.base.loc, es_func, call->m_original_name,
242+
call->m_args, call->n_args, call->m_type,
243+
call->m_value, call->m_dt));
244+
}
215245
ASR::asr_t* assign = ASR::make_Assignment_t(al,
216246
v->base.base.loc, target, value, nullptr);
217247
var_init.push_back(al, ASRUtils::STMT(assign));

src/libasr/asr_scopes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ struct SymbolTable {
8585

8686
void move_symbols_from_global_scope(Allocator &al,
8787
SymbolTable *module_scope, Vec<char *> &syms,
88-
Vec<char *> &mod_dependencies, Vec<ASR::stmt_t*> &var_init);
88+
Vec<char *> &mod_dependencies, Vec<char *> &func_dependencies,
89+
Vec<ASR::stmt_t*> &var_init);
8990
};
9091

9192
} // namespace LCompilers

src/libasr/pass/global_symbols.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ void pass_wrap_global_syms_into_module(Allocator &al,
2222
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
2323
Vec<char *> moved_symbols;
2424
Vec<char *> mod_dependencies;
25+
Vec<char *> func_dependencies;
2526
Vec<ASR::stmt_t*> var_init;
2627

2728
// Move all the symbols from global into the module scope
2829
unit.m_global_scope->move_symbols_from_global_scope(al, module_scope,
29-
moved_symbols, mod_dependencies, var_init);
30+
moved_symbols, mod_dependencies, func_dependencies, var_init);
3031

3132
// Erase the symbols that are moved into the module
3233
for (auto &sym: moved_symbols) {
@@ -39,20 +40,19 @@ void pass_wrap_global_syms_into_module(Allocator &al,
3940
for (size_t i = 0; i < f->n_body; i++) {
4041
var_init.push_back(al, f->m_body[i]);
4142
}
43+
for (size_t i = 0; i < f->n_dependencies; i++) {
44+
func_dependencies.push_back(al, f->m_dependencies[i]);
45+
}
4246
f->m_body = var_init.p;
4347
f->n_body = var_init.n;
48+
f->m_dependencies = func_dependencies.p;
49+
f->n_dependencies = func_dependencies.n;
4450
// Overwrites the function: `_lpython_main_program`
4551
module_scope->add_symbol(f->m_name, (ASR::symbol_t *) f);
4652
}
4753

48-
Vec<char *> m_dependencies;
49-
m_dependencies.reserve(al, mod_dependencies.size());
50-
for( auto &dep: mod_dependencies) {
51-
m_dependencies.push_back(al, dep);
52-
}
53-
5454
ASR::symbol_t *module = (ASR::symbol_t *) ASR::make_Module_t(al, loc,
55-
module_scope, module_name, m_dependencies.p, m_dependencies.n,
55+
module_scope, module_name, mod_dependencies.p, mod_dependencies.n,
5656
false, false);
5757
unit.m_global_scope->add_symbol(module_name, module);
5858
}

0 commit comments

Comments
 (0)