Skip to content

Commit 6d8ab22

Browse files
Add an index bounds checking option to detect out-of-range list indices.
1 parent d4a6d53 commit 6d8ab22

File tree

5 files changed

+88
-49
lines changed

5 files changed

+88
-49
lines changed

src/bin/lpython.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,7 @@ int main(int argc, char *argv[])
12821282
app.add_flag("--no-warnings", compiler_options.no_warnings, "Turn off all warnings");
12831283
app.add_flag("--no-error-banner", compiler_options.no_error_banner, "Turn off error banner");
12841284
app.add_option("--backend", arg_backend, "Select a backend (llvm, cpp, x86, wasm, wasm_x86)")->capture_default_str();
1285+
app.add_flag("--enable-bounds-checking", compiler_options.enable_bounds_checking, "Turn on index bounds checking");
12851286
app.add_flag("--openmp", compiler_options.openmp, "Enable openmp");
12861287
app.add_flag("--fast", compiler_options.fast, "Best performance (disable strict standard compliance)");
12871288
app.add_option("--target", compiler_options.target, "Generate code for the given target")->capture_default_str();

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
164164
bool emit_debug_info;
165165
std::string infile;
166166
bool emit_debug_line_column;
167+
bool enable_bounds_checking;
167168
Allocator &al;
168169

169170
llvm::Value *tmp;
@@ -242,14 +243,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
242243

243244
ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, Platform platform,
244245
bool emit_debug_info, std::string infile, bool emit_debug_line_column,
245-
diag::Diagnostics &diagnostics) :
246+
bool enable_bounds_checking, diag::Diagnostics &diagnostics) :
246247
diag{diagnostics},
247248
context(context),
248249
builder(std::make_unique<llvm::IRBuilder<>>(context)),
249250
platform{platform},
250251
emit_debug_info{emit_debug_info},
251252
infile{infile},
252253
emit_debug_line_column{emit_debug_line_column},
254+
enable_bounds_checking{enable_bounds_checking},
253255
al{al},
254256
prototype_only(false),
255257
llvm_utils(std::make_unique<LLVMUtils>(context, builder.get())),
@@ -1380,7 +1382,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
13801382
this->visit_expr(*x.m_args[i]);
13811383
llvm::Value* item = tmp;
13821384
llvm::Value* pos = llvm::ConstantInt::get(context, llvm::APInt(32, i));
1383-
list_api->write_item(const_list, pos, item, list_type->m_type, *module);
1385+
list_api->write_item(const_list, pos, item, list_type->m_type, false, *module);
13841386
}
13851387
ptr_loads = ptr_loads_copy;
13861388
tmp = const_list;
@@ -1520,9 +1522,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15201522
ptr_loads = ptr_loads_copy;
15211523
llvm::Value *pos = tmp;
15221524

1523-
tmp = list_api->read_item(plist, pos,
1524-
(LLVM::is_llvm_struct(el_type) ||
1525-
ptr_loads == 0));
1525+
tmp = list_api->read_item(plist, pos, enable_bounds_checking, *module,
1526+
(LLVM::is_llvm_struct(el_type) || ptr_loads == 0));
15261527
}
15271528

15281529
void visit_DictItem(const ASR::DictItem_t& x) {
@@ -3914,7 +3915,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
39143915
llvm::Value* list = tmp;
39153916
this->visit_expr_wrapper(asr_target0->m_pos, true);
39163917
llvm::Value* pos = tmp;
3917-
target = list_api->read_item(list, pos, true);
3918+
3919+
target = list_api->read_item(list, pos, enable_bounds_checking,
3920+
*module, true);
39183921
}
39193922
} else {
39203923
ASR::Variable_t *asr_target = EXPR2VAR(x.m_target);
@@ -6397,7 +6400,7 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
63976400
context.setOpaquePointers(false);
63986401
#endif
63996402
ASRToLLVMVisitor v(al, context, co.platform, co.emit_debug_info, infile,
6400-
co.emit_debug_line_column, diagnostics);
6403+
co.emit_debug_line_column, co.enable_bounds_checking, diagnostics);
64016404
LCompilers::PassOptions pass_options;
64026405
pass_options.run_fun = run_fn;
64036406
pass_options.always_run = false;

src/libasr/codegen/llvm_utils.cpp

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,8 @@ namespace LFortran {
670670
llvm_utils->start_new_block(loopbody);
671671
{
672672
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
673-
llvm::Value* srci = read_item(src, pos, true);
674-
llvm::Value* desti = read_item(dest, pos, true);
673+
llvm::Value* srci = read_item(src, pos, false, module, true);
674+
llvm::Value* desti = read_item(dest, pos, false, module, true);
675675
llvm_utils->deepcopy(srci, desti, element_type, module);
676676
llvm::Value* tmp = builder->CreateAdd(
677677
pos,
@@ -951,25 +951,55 @@ namespace LFortran {
951951
LLVM::CreateStore(*builder, dest_key_value_pairs, get_pointer_to_key_value_pairs(dest));
952952
}
953953

954-
void LLVMList::check_index_within_bounds(llvm::Value* /*list*/, llvm::Value* /*pos*/) {
954+
// Emits LLVM IR that validates `pos` against the list's current end point.
// The generated check takes the error path when `pos >= end_point` or
// `pos < 0` (signed comparisons); on that path it calls print_error with the
// format "IndexError: %s\n" and the message "list index out of range", then
// calls exit with code 1. Otherwise control continues after the check.
//
// list   - list value passed to get_pointer_to_current_end_point.
// pos    - index value to validate.
// module - forwarded to the print_error and exit helpers.
void LLVMList::check_index_within_bounds(llvm::Value* list,
955+
llvm::Value* pos, llvm::Module& module) {
956+
// Load the list's current end point, used as the exclusive upper bound.
llvm::Value* end_point = LLVM::CreateLoad(*builder,
957+
get_pointer_to_current_end_point(list));
958+
// 32-bit zero used as the lower bound.
// NOTE(review): assumes pos and end_point are also i32 -- confirm.
llvm::Value* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
959+
llvm::APInt(32, 0));
955960

961+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
962+
// Only thenBB is attached to the function at creation; elseBB and mergeBB
// are created detached and handed to start_new_block below.
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
963+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
964+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
965+
966+
// cond is true when the index is out of range: pos >= end_point || pos < 0.
llvm::Value* cond = builder->CreateOr(
967+
builder->CreateICmpSGE(pos, end_point),
968+
builder->CreateICmpSLT(pos, zero));
969+
builder->CreateCondBr(cond, thenBB, elseBB);
970+
builder->SetInsertPoint(thenBB);
971+
{
972+
// Error path: report the failure and request termination with status 1.
std::string message = "list index out of range";
973+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("IndexError: %s\n");
974+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
975+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
976+
int exit_code_int = 1;
977+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
978+
llvm::APInt(32, exit_code_int));
979+
exit(context, module, *builder, exit_code);
980+
}
981+
// thenBB still ends with a branch to mergeBB after the exit call.
// NOTE(review): presumably needed so the block has a terminator -- confirm.
builder->CreateBr(mergeBB);
982+
983+
// No instructions are emitted into elseBB here; mergeBB is started right
// after it. NOTE(review): presumably start_new_block links the previous block
// to the new one and moves the insert point -- confirm against LLVMUtils.
llvm_utils->start_new_block(elseBB);
984+
llvm_utils->start_new_block(mergeBB);
956985
}
957986

958987
// Emits IR that stores `item` into the slot at index `pos` of `list`, using
// llvm_utils->deepcopy with the element's ASR type.
// When the bounds-checking flag is true, check_index_within_bounds is emitted
// first.
// Diff view: marker lines ending in '-' precede the removed signature
// (check_index_bound as the trailing parameter); marker lines ending in '+'
// precede its replacement (enable_bounds_checking placed before module).
void LLVMList::write_item(llvm::Value* list, llvm::Value* pos,
959988
llvm::Value* item, ASR::ttype_t* asr_type,
960-
llvm::Module& module, bool check_index_bound) {
961-
if( check_index_bound ) {
962-
check_index_within_bounds(list, pos);
989+
bool enable_bounds_checking, llvm::Module& module) {
990+
if( enable_bounds_checking ) {
991+
check_index_within_bounds(list, pos, module);
963992
}
964993
// Load the list's data pointer and compute the address of element `pos`.
llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list));
965994
llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos);
966995
// Copy the item into the element slot according to asr_type.
llvm_utils->deepcopy(item, element_ptr, asr_type, module);
967996
}
968997

969998
void LLVMList::write_item(llvm::Value* list, llvm::Value* pos,
970-
llvm::Value* item, bool check_index_bound) {
971-
if( check_index_bound ) {
972-
check_index_within_bounds(list, pos);
999+
llvm::Value* item, bool enable_bounds_checking,
1000+
llvm::Module& module) {
1001+
if( enable_bounds_checking ) {
1002+
check_index_within_bounds(list, pos, module);
9731003
}
9741004
llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list));
9751005
llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos);
@@ -1108,7 +1138,7 @@ namespace LFortran {
11081138
builder->SetInsertPoint(thenBB);
11091139
{
11101140
llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos,
1111-
LLVM::is_llvm_struct(key_asr_type), false);
1141+
false, module, LLVM::is_llvm_struct(key_asr_type));
11121142
is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module,
11131143
key_asr_type);
11141144
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
@@ -1196,7 +1226,7 @@ namespace LFortran {
11961226
builder->SetInsertPoint(thenBB);
11971227
{
11981228
llvm::Value* original_key = llvm_utils->list_api->read_item(key_list, pos,
1199-
LLVM::is_llvm_struct(key_asr_type), false);
1229+
false, module, LLVM::is_llvm_struct(key_asr_type));
12001230
is_key_matching = llvm_utils->is_equal_by_value(key, original_key, module,
12011231
key_asr_type);
12021232
LLVM::CreateStore(*builder, is_key_matching, is_key_matching_var);
@@ -1323,9 +1353,9 @@ namespace LFortran {
13231353
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type);
13241354
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
13251355
llvm_utils->list_api->write_item(key_list, pos, key,
1326-
key_asr_type, *module, false);
1356+
key_asr_type, false, *module);
13271357
llvm_utils->list_api->write_item(value_list, pos, value,
1328-
value_asr_type, *module, false);
1358+
value_asr_type, false, *module);
13291359
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
13301360
llvm_utils->create_ptr_gep(key_mask, pos));
13311361
llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value,
@@ -1352,9 +1382,9 @@ namespace LFortran {
13521382
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, *module, key_asr_type);
13531383
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
13541384
llvm_utils->list_api->write_item(key_list, pos, key,
1355-
key_asr_type, *module, false);
1385+
key_asr_type, false, *module);
13561386
llvm_utils->list_api->write_item(value_list, pos, value,
1357-
value_asr_type, *module, false);
1387+
value_asr_type, false, *module);
13581388

13591389
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
13601390
llvm_utils->create_ptr_gep(key_mask, pos));
@@ -1455,7 +1485,7 @@ namespace LFortran {
14551485
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
14561486
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, key_asr_type, true);
14571487
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1458-
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false);
1488+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, false, module, true);
14591489
return item;
14601490
}
14611491

@@ -1493,9 +1523,8 @@ namespace LFortran {
14931523
llvm::BasicBlock *elseBB_single_match = llvm::BasicBlock::Create(context, "else");
14941524
llvm::BasicBlock *mergeBB_single_match = llvm::BasicBlock::Create(context, "ifcont");
14951525
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
1496-
llvm_utils->list_api->read_item(key_list, key_hash,
1497-
LLVM::is_llvm_struct(key_asr_type), false),
1498-
module, key_asr_type);
1526+
llvm_utils->list_api->read_item(key_list, key_hash, false, module,
1527+
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
14991528
builder->CreateCondBr(is_key_matching, thenBB_single_match, elseBB_single_match);
15001529
builder->SetInsertPoint(thenBB_single_match);
15011530
LLVM::CreateStore(*builder, key_hash, pos_ptr);
@@ -1521,7 +1550,8 @@ namespace LFortran {
15211550
}
15221551
llvm_utils->start_new_block(mergeBB);
15231552
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1524-
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos, true, false);
1553+
llvm::Value* item = llvm_utils->list_api->read_item(value_list, pos,
1554+
false, module, true);
15251555
return item;
15261556
}
15271557

@@ -1757,18 +1787,18 @@ namespace LFortran {
17571787
builder->SetInsertPoint(thenBB);
17581788
{
17591789
llvm::Value* key = llvm_utils->list_api->read_item(key_list, idx,
1760-
LLVM::is_llvm_struct(key_asr_type), false);
1761-
llvm::Value* value = llvm_utils->list_api->read_item(value_list, idx,
1762-
LLVM::is_llvm_struct(value_asr_type), false);
1790+
false, *module, LLVM::is_llvm_struct(key_asr_type));
1791+
llvm::Value* value = llvm_utils->list_api->read_item(value_list,
1792+
idx, false, *module, LLVM::is_llvm_struct(value_asr_type));
17631793
llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module);
17641794
this->resolve_collision(current_capacity, key_hash, key, new_key_list,
17651795
new_key_mask, *module, key_asr_type);
17661796
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
1767-
llvm::Value* key_dest = llvm_utils->list_api->read_item(new_key_list, pos,
1768-
true, false);
1797+
llvm::Value* key_dest = llvm_utils->list_api->read_item(
1798+
new_key_list, pos, false, *module, true);
17691799
llvm_utils->deepcopy(key, key_dest, key_asr_type, *module);
1770-
llvm::Value* value_dest = llvm_utils->list_api->read_item(new_value_list, pos,
1771-
true, false);
1800+
llvm::Value* value_dest = llvm_utils->list_api->read_item(
1801+
new_value_list, pos, false, *module, true);
17721802
llvm_utils->deepcopy(value, value_dest, value_asr_type, *module);
17731803

17741804
llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos);
@@ -2143,10 +2173,11 @@ namespace LFortran {
21432173
return LLVM::CreateLoad(*builder, value_ptr);
21442174
}
21452175

2146-
llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool get_pointer,
2147-
bool check_index_bound) {
2148-
if( check_index_bound ) {
2149-
check_index_within_bounds(list, pos);
2176+
llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
2177+
bool enable_bounds_checking,
2178+
llvm::Module& module, bool get_pointer) {
2179+
if( enable_bounds_checking ) {
2180+
check_index_within_bounds(list, pos, module);
21502181
}
21512182
llvm::Value* list_data = LLVM::CreateLoad(*builder, get_pointer_to_list_data(list));
21522183
llvm::Value* element_ptr = llvm_utils->create_ptr_gep(list_data, pos);
@@ -2233,7 +2264,7 @@ namespace LFortran {
22332264
llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]);
22342265
resize_if_needed(list, current_end_point, current_capacity,
22352266
type_size, el_type, module);
2236-
write_item(list, current_end_point, item, asr_type, module);
2267+
write_item(list, current_end_point, item, asr_type, false, module);
22372268
shift_end_point_by_one(list);
22382269
}
22392270

@@ -2271,7 +2302,7 @@ namespace LFortran {
22712302
// LLVMList should treat them as data members and create them
22722303
// only if they are NULL
22732304
llvm::AllocaInst *tmp_ptr = builder->CreateAlloca(el_type, nullptr);
2274-
LLVM::CreateStore(*builder, read_item(list, pos, false), tmp_ptr);
2305+
LLVM::CreateStore(*builder, read_item(list, pos, false, module, false), tmp_ptr);
22752306
llvm::Value* tmp = nullptr;
22762307

22772308
// TODO: Should be created outside the user loop and not here.
@@ -2300,8 +2331,8 @@ namespace LFortran {
23002331
llvm::Value* next_index = builder->CreateAdd(
23012332
LLVM::CreateLoad(*builder, pos_ptr),
23022333
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
2303-
tmp = read_item(list, next_index, false);
2304-
write_item(list, next_index, LLVM::CreateLoad(*builder, tmp_ptr));
2334+
tmp = read_item(list, next_index, false, module, false);
2335+
write_item(list, next_index, LLVM::CreateLoad(*builder, tmp_ptr), false, module);
23052336
LLVM::CreateStore(*builder, tmp, tmp_ptr);
23062337

23072338
tmp = builder->CreateAdd(
@@ -2314,7 +2345,7 @@ namespace LFortran {
23142345
// end
23152346
llvm_utils->start_new_block(loopend);
23162347

2317-
write_item(list, pos, item, asr_type, module);
2348+
write_item(list, pos, item, asr_type, false, module);
23182349
shift_end_point_by_one(list);
23192350
}
23202351

@@ -2350,7 +2381,7 @@ namespace LFortran {
23502381
llvm_utils->start_new_block(loophead);
23512382
{
23522383
llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i),
2353-
LLVM::is_llvm_struct(item_type));
2384+
false, module, LLVM::is_llvm_struct(item_type));
23542385
llvm::Value* is_item_not_equal = builder->CreateNot(
23552386
llvm_utils->is_equal_by_value(
23562387
left_arg, item,
@@ -2445,7 +2476,7 @@ namespace LFortran {
24452476
LLVM::CreateLoad(*builder, item_pos),
24462477
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
24472478
write_item(list, LLVM::CreateLoad(*builder, item_pos),
2448-
read_item(list, tmp, false));
2479+
read_item(list, tmp, false, module, false), false, module);
24492480
LLVM::CreateStore(*builder, tmp, item_pos);
24502481
}
24512482
builder->CreateBr(loophead);

src/libasr/codegen/llvm_utils.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,18 +170,21 @@ namespace LFortran {
170170
llvm::Module& module);
171171

172172
llvm::Value* read_item(llvm::Value* list, llvm::Value* pos,
173-
bool get_pointer=false, bool check_index_bound=true);
173+
bool enable_bounds_checking,
174+
llvm::Module& module, bool get_pointer=false);
174175

175176
llvm::Value* len(llvm::Value* list);
176177

177-
void check_index_within_bounds(llvm::Value* list, llvm::Value* pos);
178+
void check_index_within_bounds(llvm::Value* list, llvm::Value* pos,
179+
llvm::Module& module);
178180

179181
void write_item(llvm::Value* list, llvm::Value* pos,
180182
llvm::Value* item, ASR::ttype_t* asr_type,
181-
llvm::Module& module, bool check_index_bound=true);
183+
bool enable_bounds_checking, llvm::Module& module);
182184

183185
void write_item(llvm::Value* list, llvm::Value* pos,
184-
llvm::Value* item, bool check_index_bound=true);
186+
llvm::Value* item, bool enable_bounds_checking,
187+
llvm::Module& module);
185188

186189
void append(llvm::Value* list, llvm::Value* item,
187190
ASR::ttype_t* asr_type, llvm::Module& module);

src/libasr/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct CompilerOptions {
3232
bool generate_object_code = false;
3333
bool no_warnings = false;
3434
bool no_error_banner = false;
35+
bool enable_bounds_checking = false;
3536
std::string error_format = "human";
3637
bool new_parser = false;
3738
bool implicit_typing = false;

0 commit comments

Comments
 (0)