Skip to content

Commit e5308f1

Browse files
authored Jun 14, 2024
Merge pull request #2711 from advikkabra/membership
Add membership checks in dictionaries and sets
2 parents ec2dfb5 + 3f9731c commit e5308f1

22 files changed

+388
-124
lines changed
 

‎grammar/Python.asdl

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ module LPython
7373
-- need sequences for compare to distinguish between
7474
-- x < 4 < 3 and (x < 4) < 3
7575
| Compare(expr left, cmpop ops, expr* comparators)
76+
| Membership(expr left, membershipop op, expr right)
7677
| Call(expr func, expr* args, keyword* keywords)
7778
| FormattedValue(expr value, int conversion, expr? format_spec)
7879
| JoinedStr(expr* values)
@@ -110,7 +111,9 @@ module LPython
110111

111112
unaryop = Invert | Not | UAdd | USub
112113

113-
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
114+
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot
115+
116+
membershipop = In | NotIn
114117

115118
comprehension = (expr target, expr iter, expr* ifs, int is_async)
116119

‎integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -600,6 +600,7 @@ RUN(NAME test_import_05 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x
600600
RUN(NAME test_import_06 LABELS cpython llvm llvm_jit)
601601
RUN(NAME test_import_07 LABELS cpython llvm llvm_jit c)
602602
RUN(NAME test_math LABELS cpython llvm llvm_jit NOFAST)
603+
RUN(NAME test_membership_01 LABELS cpython llvm)
603604
RUN(NAME test_numpy_01 LABELS cpython llvm llvm_jit c)
604605
RUN(NAME test_numpy_02 LABELS cpython llvm llvm_jit c)
605606
RUN(NAME test_numpy_03 LABELS cpython llvm llvm_jit c)
Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,36 @@
1+
# Membership test on a dict[i32, i32]: `in` with a literal key, `not in` with
# an absent literal key, and `in` with a key held in a variable.
def test_int_dict():
2+
 a: dict[i32, i32] = {1:2, 2:3, 3:4, 4:5}
3+
 i: i32
4+
 # Key present in the dict (literal operand).
 assert (1 in a)
5+
 # Key absent from the dict.
 assert (6 not in a)
6+
 i = 4
7+
 # Key present in the dict (variable operand).
 assert (i in a)
8+
9+
# Membership test on a dict[str, str]: `in` with a literal key, `not in` with
# an absent literal key, and `in` with a key held in a variable.
def test_str_dict():
10+
 a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'}
11+
 i: str
12+
 # Key present in the dict (literal operand).
 assert ('a' in a)
13+
 # Key absent from the dict.
 assert ('d' not in a)
14+
 i = 'c'
15+
 # Key present in the dict (variable operand).
 assert (i in a)
16+
17+
# Membership test on a set[i32]: `in` with a literal element, `not in` with
# an absent literal element, and `in` with an element held in a variable.
def test_int_set():
18+
 a: set[i32] = {1, 2, 3, 4}
19+
 i: i32
20+
 # Element present in the set (literal operand).
 assert (1 in a)
21+
 # Element absent from the set.
 assert (6 not in a)
22+
 i = 4
23+
 # Element present in the set (variable operand).
 assert (i in a)
24+
25+
# Membership test on a set[str]: only the positive (`in`) cases are exercised;
# the negative case is commented out below.
def test_str_set():
26+
 a: set[str] = {'a', 'b', 'c', 'e', 'f'}
27+
 i: str
28+
 # Element present in the set (literal operand).
 assert ('a' in a)
29+
 # NOTE(review): the `not in` check for an absent string element is disabled,
 # unlike the other three tests in this file. The reason is not visible here —
 # confirm whether it fails on a backend and track it before relying on
 # `not in` for set[str].
 # assert ('d' not in a)
30+
 i = 'c'
31+
 # Element present in the set (variable operand).
 assert (i in a)
32+
33+
# Run every membership test defined above, in definition order.
test_int_dict()
34+
test_str_dict()
35+
test_int_set()
36+
test_str_set()

‎src/libasr/ASR.asdl

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,19 +118,22 @@ expr
118118
| ListConcat(expr left, expr right, ttype type, expr? value)
119119
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
120120
| ListCount(expr arg, expr ele, ttype type, expr? value)
121+
| ListContains(expr left, expr right, ttype type, expr? value)
121122
| SetConstant(expr* elements, ttype type)
122123
| SetLen(expr arg, ttype type, expr? value)
123124
| TupleConstant(expr* elements, ttype type)
124125
| TupleLen(expr arg, ttype type, expr value)
125126
| TupleCompare(expr left, cmpop op, expr right, ttype type, expr? value)
126127
| TupleConcat(expr left, expr right, ttype type, expr? value)
128+
| TupleContains(expr left, expr right, ttype type, expr? value)
127129
| StringConstant(string s, ttype type)
128130
| StringConcat(expr left, expr right, ttype type, expr? value)
129131
| StringRepeat(expr left, expr right, ttype type, expr? value)
130132
| StringLen(expr arg, ttype type, expr? value)
131133
| StringItem(expr arg, expr idx, ttype type, expr? value)
132134
| StringSection(expr arg, expr? start, expr? end, expr? step, ttype type, expr? value)
133135
| StringCompare(expr left, cmpop op, expr right, ttype type, expr? value)
136+
| StringContains(expr left, expr right, ttype type, expr? value)
134137
| StringOrd(expr arg, ttype type, expr? value)
135138
| StringChr(expr arg, ttype type, expr? value)
136139
| StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value)
@@ -176,6 +179,8 @@ expr
176179
| ListRepeat(expr left, expr right, ttype type, expr? value)
177180
| DictPop(expr a, expr key, ttype type, expr? value)
178181
| SetPop(expr a, ttype type, expr? value)
182+
| SetContains(expr left, expr right, ttype type, expr? value)
183+
| DictContains(expr left, expr right, ttype type, expr? value)
179184
| IntegerBitLen(expr a, ttype type, expr? value)
180185
| Ichar(expr arg, ttype type, expr? value)
181186
| Iachar(expr arg, ttype type, expr? value)

‎src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1637,6 +1637,51 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16371637
}
16381638
}
16391639

1640+
// Emits LLVM IR for an ASR DictContains node. `x.m_right` is the dictionary
// (its type is down-cast to ASR::Dict_t) and `x.m_left` is the key being
// looked up. The resulting value is left in `tmp`.
void visit_DictContains(const ASR::DictContains_t &x) {
1641+
// If the node already carries a value expression, emit that and stop.
if (x.m_value) {
1642+
this->visit_expr(*x.m_value);
1643+
return;
1644+
}
1645+
1646+
// Remember the current ptr_loads so it can be restored after both operands
// have been visited.
int64_t ptr_loads_copy = ptr_loads;
1647+
// Visit the dict operand with ptr_loads = 0
// (presumably so the dict is obtained by pointer — TODO confirm).
ptr_loads = 0;
1648+
this->visit_expr(*x.m_right);
1649+
llvm::Value *right = tmp;
1650+
ASR::Dict_t *dict_type = ASR::down_cast<ASR::Dict_t>(
1651+
ASRUtils::expr_type(x.m_right));
1652+
// Visit the key with ptr_loads = 1 unless the key type is an LLVM struct,
// in which case ptr_loads = 0.
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type);
1653+
this->visit_expr(*x.m_left);
1654+
llvm::Value *left = tmp;
1655+
ptr_loads = ptr_loads_copy;
1656+
// Load the dict's capacity; it is an input to the key hash below.
llvm::Value *capacity = LLVM::CreateLoad(*builder,
1657+
llvm_utils->dict_api->get_pointer_to_capacity(right));
1658+
llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module);
1659+
1660+
// The trailing `true` is the check_if_exists argument: the dict
// implementations changed in this commit return the key-match flag in that
// mode instead of continuing to the lookup/KeyError path.
tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true);
1661+
}
1662+
1663+
// Emits LLVM IR for an ASR SetContains node. `x.m_right` is the set and
// `x.m_left` is the element being looked up. The resulting value is left in
// `tmp`.
void visit_SetContains(const ASR::SetContains_t &x) {
1664+
// If the node already carries a value expression, emit that and stop.
if (x.m_value) {
1665+
this->visit_expr(*x.m_value);
1666+
return;
1667+
}
1668+
1669+
// Remember the current ptr_loads so it can be restored after both operands
// have been visited.
int64_t ptr_loads_copy = ptr_loads;
1670+
// Visit the set operand with ptr_loads = 0
// (presumably so the set is obtained by pointer — TODO confirm).
ptr_loads = 0;
1671+
this->visit_expr(*x.m_right);
1672+
llvm::Value *right = tmp;
1673+
// NOTE(review): the element type is taken from the queried operand
// (x.m_left), not from the set's own type as visit_DictContains does with
// the dict key type — confirm the two are guaranteed to agree.
ASR::ttype_t *el_type = ASRUtils::expr_type(x.m_left);
1674+
// Visit the element with ptr_loads = 1 unless its type is an LLVM struct,
// in which case ptr_loads = 0.
ptr_loads = !LLVM::is_llvm_struct(el_type);
1675+
this->visit_expr(*x.m_left);
1676+
llvm::Value *left = tmp;
1677+
ptr_loads = ptr_loads_copy;
1678+
// Load the set's capacity; it is an input to the element hash below.
llvm::Value *capacity = LLVM::CreateLoad(*builder,
1679+
llvm_utils->set_api->get_pointer_to_capacity(right));
1680+
llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module);
1681+
1682+
// Trailing arguments are throw_key_error = false and check_if_exists = true:
// the set implementations changed in this commit return the element-match
// flag in that mode rather than emitting the KeyError exit.
tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true);
1683+
}
1684+
16401685
void visit_DictLen(const ASR::DictLen_t& x) {
16411686
if (x.m_value) {
16421687
this->visit_expr(*x.m_value);

‎src/libasr/codegen/llvm_utils.cpp

Lines changed: 76 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -3177,7 +3177,7 @@ namespace LCompilers {
31773177
llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check(
31783178
llvm::Value* dict, llvm::Value* key_hash,
31793179
llvm::Value* key, llvm::Module& module,
3180-
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
3180+
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) {
31813181
llvm::Value* key_list = get_key_list(dict);
31823182
llvm::Value* value_list = get_value_list(dict);
31833183
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
@@ -3187,6 +3187,8 @@ namespace LCompilers {
31873187
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
31883188
llvm_utils->list_api->read_item(key_list, pos, false, module,
31893189
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
3190+
if (check_if_exists)
3191+
return is_key_matching;
31903192

31913193
llvm_utils->create_if_else(is_key_matching, [&]() {
31923194
}, [&]() {
@@ -3245,7 +3247,7 @@ namespace LCompilers {
32453247
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check(
32463248
llvm::Value* dict, llvm::Value* key_hash,
32473249
llvm::Value* key, llvm::Module& module,
3248-
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
3250+
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) {
32493251

32503252
/**
32513253
* C++ equivalent:
@@ -3287,6 +3289,9 @@ namespace LCompilers {
32873289
llvm_utils->create_ptr_gep(key_mask, key_hash));
32883290
llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value,
32893291
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
3292+
llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
3293+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr);
3294+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);
32903295
builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB);
32913296
builder->SetInsertPoint(thenBB);
32923297
{
@@ -3304,6 +3309,9 @@ namespace LCompilers {
33043309
llvm_utils->create_if_else(is_key_matching, [=]() {
33053310
LLVM::CreateStore(*builder, key_hash, pos_ptr);
33063311
}, [&]() {
3312+
if (check_if_exists) {
3313+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr);
3314+
} else {
33073315
std::string message = "The dict does not contain the specified key";
33083316
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
33093317
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
@@ -3312,7 +3320,7 @@ namespace LCompilers {
33123320
llvm::Value *exit_code = llvm::ConstantInt::get(context,
33133321
llvm::APInt(32, exit_code_int));
33143322
exit(context, module, *builder, exit_code);
3315-
});
3323+
}});
33163324
}
33173325
builder->CreateBr(mergeBB);
33183326
llvm_utils->start_new_block(elseBB);
@@ -3321,11 +3329,24 @@ namespace LCompilers {
33213329
module, key_asr_type, true);
33223330
}
33233331
llvm_utils->start_new_block(mergeBB);
3324-
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
3325-
// Check if the actual key is present or not
3326-
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
3332+
llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr);
3333+
llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr);
3334+
llvm::AllocaInst *is_key_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
3335+
3336+
llvm_utils->create_if_else(flag, [&](){
3337+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_key_matching_ptr);
3338+
}, [&](){
3339+
// Check if the actual element is present or not
3340+
LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(key,
33273341
llvm_utils->list_api->read_item(key_list, pos, false, module,
3328-
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
3342+
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type), is_key_matching_ptr);
3343+
});
3344+
3345+
llvm::Value *is_key_matching = LLVM::CreateLoad(*builder, is_key_matching_ptr);
3346+
3347+
if (check_if_exists) {
3348+
return is_key_matching;
3349+
}
33293350

33303351
llvm_utils->create_if_else(is_key_matching, [&]() {
33313352
}, [&]() {
@@ -3471,7 +3492,7 @@ namespace LCompilers {
34713492
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check(
34723493
llvm::Value* dict, llvm::Value* key_hash,
34733494
llvm::Value* key, llvm::Module& module,
3474-
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) {
3495+
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists) {
34753496
/**
34763497
* C++ equivalent:
34773498
*
@@ -3506,6 +3527,10 @@ namespace LCompilers {
35063527
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)))
35073528
);
35083529

3530+
if (check_if_exists) {
3531+
return does_kv_exists;
3532+
}
3533+
35093534
llvm_utils->create_if_else(does_kv_exists, [&]() {
35103535
llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
35113536
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo());
@@ -4358,6 +4383,7 @@ namespace LCompilers {
43584383
// end
43594384
llvm_utils->start_new_block(loopend);
43604385
}
4386+
43614387

43624388
llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
43634389
bool enable_bounds_checking,
@@ -6393,9 +6419,9 @@ namespace LCompilers {
63936419
el_asr_type, name2memidx);
63946420
}
63956421

6396-
void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check(
6422+
llvm::Value* LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check(
63976423
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
6398-
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
6424+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) {
63996425

64006426
/**
64016427
* C++ equivalent:
@@ -6423,18 +6449,22 @@ namespace LCompilers {
64236449
*/
64246450

64256451
get_builder0()
6452+
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
64266453
llvm::Value* el_list = get_el_list(set);
64276454
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
64286455
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
6429-
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
64306456
llvm::Function *fn = builder->GetInsertBlock()->getParent();
6431-
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
6432-
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
6433-
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
6457+
std::string s = check_if_exists ? "qq" : "pp";
6458+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then"+s, fn);
6459+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"+s);
6460+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"+s);
64346461
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
64356462
llvm_utils->create_ptr_gep(el_mask, el_hash));
64366463
llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value,
64376464
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
6465+
llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
6466+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);
6467+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr);
64386468
builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB);
64396469
builder->SetInsertPoint(thenBB);
64406470
{
@@ -6447,6 +6477,9 @@ namespace LCompilers {
64476477
llvm_utils->create_if_else(is_el_matching, [=]() {
64486478
LLVM::CreateStore(*builder, el_hash, pos_ptr);
64496479
}, [&]() {
6480+
if (check_if_exists) {
6481+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr);
6482+
} else {
64506483
if (throw_key_error) {
64516484
std::string message = "The set does not contain the specified element";
64526485
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
@@ -6457,7 +6490,7 @@ namespace LCompilers {
64576490
llvm::APInt(32, exit_code_int));
64586491
exit(context, module, *builder, exit_code);
64596492
}
6460-
});
6493+
}});
64616494
}
64626495
builder->CreateBr(mergeBB);
64636496
llvm_utils->start_new_block(elseBB);
@@ -6466,11 +6499,25 @@ namespace LCompilers {
64666499
module, el_asr_type, true);
64676500
}
64686501
llvm_utils->start_new_block(mergeBB);
6469-
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
6502+
llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr);
6503+
llvm::AllocaInst *is_el_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
6504+
6505+
llvm_utils->create_if_else(flag, [&](){
6506+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_el_matching_ptr);
6507+
}, [&](){
64706508
// Check if the actual element is present or not
6471-
llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el,
6472-
llvm_utils->list_api->read_item(el_list, pos, false, module,
6473-
LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type);
6509+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
6510+
llvm::Value* item = llvm_utils->list_api->read_item(el_list, pos, false, module,
6511+
LLVM::is_llvm_struct(el_asr_type)) ;
6512+
llvm::Value *iseq =llvm_utils->is_equal_by_value(el,
6513+
item, module, el_asr_type) ;
6514+
LLVM::CreateStore(*builder, iseq, is_el_matching_ptr);
6515+
});
6516+
6517+
llvm::Value *is_el_matching = LLVM::CreateLoad(*builder, is_el_matching_ptr);
6518+
if (check_if_exists) {
6519+
return is_el_matching;
6520+
}
64746521

64756522
llvm_utils->create_if_else(is_el_matching, []() {}, [&]() {
64766523
if (throw_key_error) {
@@ -6484,11 +6531,13 @@ namespace LCompilers {
64846531
exit(context, module, *builder, exit_code);
64856532
}
64866533
});
6534+
6535+
return nullptr;
64876536
}
64886537

6489-
void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check(
6538+
llvm::Value* LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check(
64906539
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
6491-
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
6540+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) {
64926541
/**
64936542
* C++ equivalent:
64946543
*
@@ -6515,6 +6564,10 @@ namespace LCompilers {
65156564
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)))
65166565
);
65176566

6567+
if (check_if_exists) {
6568+
return does_el_exist;
6569+
}
6570+
65186571
llvm_utils->create_if_else(does_el_exist, []() {}, [&]() {
65196572
if (throw_key_error) {
65206573
std::string message = "The set does not contain the specified element";
@@ -6527,6 +6580,8 @@ namespace LCompilers {
65276580
exit(context, module, *builder, exit_code);
65286581
}
65296582
});
6583+
6584+
return nullptr;
65306585
}
65316586

65326587
void LLVMSetLinearProbing::remove_item(

0 commit comments

Comments
 (0)