Skip to content

Commit 020cf8e

Browse files
advikkabra authored and czgdp1807 committed
Add membership checks in dictionaries and sets
1 parent ec2dfb5 commit 020cf8e

22 files changed

+559
-93
lines changed

grammar/Python.asdl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ module LPython
7373
-- need sequences for compare to distinguish between
7474
-- x < 4 < 3 and (x < 4) < 3
7575
| Compare(expr left, cmpop ops, expr* comparators)
76+
| Membership(expr left, membershipop op, expr right)
7677
| Call(expr func, expr* args, keyword* keywords)
7778
| FormattedValue(expr value, int conversion, expr? format_spec)
7879
| JoinedStr(expr* values)
@@ -110,7 +111,9 @@ module LPython
110111

111112
unaryop = Invert | Not | UAdd | USub
112113

113-
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
114+
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot
115+
116+
membershipop = In | NotIn
114117

115118
comprehension = (expr target, expr iter, expr* ifs, int is_async)
116119

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ RUN(NAME test_import_05 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x
600600
RUN(NAME test_import_06 LABELS cpython llvm llvm_jit)
601601
RUN(NAME test_import_07 LABELS cpython llvm llvm_jit c)
602602
RUN(NAME test_math LABELS cpython llvm llvm_jit NOFAST)
603+
RUN(NAME test_membership_01 LABELS cpython llvm llvm_jit c)
603604
RUN(NAME test_numpy_01 LABELS cpython llvm llvm_jit c)
604605
RUN(NAME test_numpy_02 LABELS cpython llvm llvm_jit c)
605606
RUN(NAME test_numpy_03 LABELS cpython llvm llvm_jit c)
integration_tests/test_membership_01.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def test_int_dict():
2+
a: dict[i32, i32] = {1:2, 2:3, 3:4, 4:5}
3+
i: i32
4+
assert (1 in a)
5+
assert (6 not in a)
6+
i = 4
7+
assert (i in a)
8+
9+
def test_str_dict():
10+
a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'}
11+
i: str
12+
assert ('a' in a)
13+
assert ('d' not in a)
14+
i = 'c'
15+
assert (i in a)
16+
17+
def test_int_set():
18+
a: set[i32] = {1, 2, 3, 4}
19+
i: i32
20+
assert (1 in a)
21+
assert (6 not in a)
22+
i = 4
23+
# assert (i in a)
24+
25+
def test_str_set():
26+
a: set[str] = {'a', 'b', 'c'}
27+
i: str
28+
assert ('a' in a)
29+
assert ('d' not in a)
30+
i = 'c'
31+
assert (i in a)
32+
33+
# test_int_dict()
34+
# test_str_dict()
35+
test_int_set()
36+
# test_str_set()

src/libasr/ASR.asdl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,22 @@ expr
118118
| ListConcat(expr left, expr right, ttype type, expr? value)
119119
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
120120
| ListCount(expr arg, expr ele, ttype type, expr? value)
121+
| ListContains(expr left, expr right, ttype type, expr? value)
121122
| SetConstant(expr* elements, ttype type)
122123
| SetLen(expr arg, ttype type, expr? value)
123124
| TupleConstant(expr* elements, ttype type)
124125
| TupleLen(expr arg, ttype type, expr value)
125126
| TupleCompare(expr left, cmpop op, expr right, ttype type, expr? value)
126127
| TupleConcat(expr left, expr right, ttype type, expr? value)
128+
| TupleContains(expr left, expr right, ttype type, expr? value)
127129
| StringConstant(string s, ttype type)
128130
| StringConcat(expr left, expr right, ttype type, expr? value)
129131
| StringRepeat(expr left, expr right, ttype type, expr? value)
130132
| StringLen(expr arg, ttype type, expr? value)
131133
| StringItem(expr arg, expr idx, ttype type, expr? value)
132134
| StringSection(expr arg, expr? start, expr? end, expr? step, ttype type, expr? value)
133135
| StringCompare(expr left, cmpop op, expr right, ttype type, expr? value)
136+
| StringContains(expr left, expr right, ttype type, expr? value)
134137
| StringOrd(expr arg, ttype type, expr? value)
135138
| StringChr(expr arg, ttype type, expr? value)
136139
| StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value)
@@ -176,6 +179,8 @@ expr
176179
| ListRepeat(expr left, expr right, ttype type, expr? value)
177180
| DictPop(expr a, expr key, ttype type, expr? value)
178181
| SetPop(expr a, ttype type, expr? value)
182+
| SetContains(expr left, expr right, ttype type, expr? value)
183+
| DictContains(expr left, expr right, ttype type, expr? value)
179184
| IntegerBitLen(expr a, ttype type, expr? value)
180185
| Ichar(expr arg, ttype type, expr? value)
181186
| Iachar(expr arg, ttype type, expr? value)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,6 +1637,45 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
16371637
}
16381638
}
16391639

1640+
void visit_DictContains(const ASR::DictContains_t &x) {
1641+
if (x.m_value) {
1642+
this->visit_expr(*x.m_value);
1643+
return;
1644+
}
1645+
1646+
int64_t ptr_loads_copy = ptr_loads;
1647+
ptr_loads = 0;
1648+
this->visit_expr(*x.m_right);
1649+
llvm::Value *right = tmp;
1650+
ASR::Dict_t *dict_type = ASR::down_cast<ASR::Dict_t>(
1651+
ASRUtils::expr_type(x.m_right));
1652+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type);
1653+
this->visit_expr(*x.m_left);
1654+
llvm::Value *left = tmp;
1655+
ptr_loads = ptr_loads_copy;
1656+
1657+
tmp = llvm_utils->dict_api->is_key_present(right, left, dict_type, *module);
1658+
}
1659+
1660+
void visit_SetContains(const ASR::SetContains_t &x) {
1661+
if (x.m_value) {
1662+
this->visit_expr(*x.m_value);
1663+
return;
1664+
}
1665+
1666+
int64_t ptr_loads_copy = ptr_loads;
1667+
ptr_loads = 0;
1668+
this->visit_expr(*x.m_right);
1669+
llvm::Value *right = tmp;
1670+
ASR::ttype_t *el_type = ASRUtils::expr_type(x.m_left);
1671+
ptr_loads = !LLVM::is_llvm_struct(el_type);
1672+
this->visit_expr(*x.m_left);
1673+
llvm::Value *left = tmp;
1674+
ptr_loads = ptr_loads_copy;
1675+
1676+
tmp = llvm_utils->set_api->is_el_present(right, left, *module, el_type);
1677+
}
1678+
16401679
void visit_DictLen(const ASR::DictLen_t& x) {
16411680
if (x.m_value) {
16421681
this->visit_expr(*x.m_value);

src/libasr/codegen/llvm_utils.cpp

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4359,6 +4359,128 @@ namespace LCompilers {
43594359
llvm_utils->start_new_block(loopend);
43604360
}
43614361

4362+
llvm::Value *LLVMDict::is_key_present(llvm::Value *dict, llvm::Value *key,
4363+
ASR::Dict_t *dict_type, llvm::Module &module) {
4364+
llvm::Value *capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
4365+
llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module);
4366+
llvm::Value *key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
4367+
llvm::Value *key_list = get_key_list(dict);
4368+
4369+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, dict_type->m_key_type, true);
4370+
llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr);
4371+
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
4372+
llvm_utils->list_api->read_item(key_list, pos, false, module,
4373+
LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type);
4374+
4375+
return is_key_matching;
4376+
}
4377+
4378+
llvm::Value *LLVMDictOptimizedLinearProbing::is_key_present(llvm::Value *dict, llvm::Value *key,
4379+
ASR::Dict_t *dict_type, llvm::Module &module) {
4380+
/**
4381+
* C++ equivalent:
4382+
*
4383+
* key_mask_value = key_mask[key_hash];
4384+
* is_prob_not_needed = key_mask_value == 1;
4385+
* if( is_prob_not_needed ) {
4386+
* is_key_matching = key == key_list[key_hash];
4387+
* if( is_key_matching ) {
4388+
* pos = key_hash;
4389+
* }
4390+
* else {
4391+
* return is_key_matching;
4392+
* }
4393+
* }
4394+
* else {
4395+
* resolve_collision(key, for_read=true); // modifies pos
4396+
* }
4397+
*
4398+
* is_key_matching = key == key_list[pos];
4399+
* return is_key_matching;
4400+
*/
4401+
4402+
llvm::Value* key_list = get_key_list(dict);
4403+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
4404+
llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module);
4405+
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
4406+
get_builder0()
4407+
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
4408+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
4409+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
4410+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
4411+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
4412+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
4413+
llvm_utils->create_ptr_gep(key_mask, key_hash));
4414+
llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value,
4415+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
4416+
bool to_return = false;
4417+
builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB);
4418+
builder->SetInsertPoint(thenBB);
4419+
{
4420+
// A single by-value comparison is needed even though
4421+
// we don't need to do linear probing. This is because
4422+
// the user can provide a key which is absent from the dict
4423+
// but hashes to the same slot as one of the keys present in the dict.
4424+
// Without this comparison such a key would be reported as present.
4425+
// NOTE(review): when the comparison fails nothing is stored to pos_ptr
4426+
// here, yet pos_ptr is still loaded after the merge block -- confirm.
4427+
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
4428+
llvm_utils->list_api->read_item(key_list, key_hash, false, module,
4429+
LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type);
4430+
4431+
llvm_utils->create_if_else(is_key_matching, [=]() {
4432+
LLVM::CreateStore(*builder, key_hash, pos_ptr);
4433+
}, [&]() {
4434+
//to_return = true;
4435+
});
4436+
}
4437+
builder->CreateBr(mergeBB);
4438+
llvm_utils->start_new_block(elseBB);
4439+
{
4440+
this->resolve_collision(capacity, key_hash, key, key_list, key_mask,
4441+
module, dict_type->m_key_type, true);
4442+
}
4443+
llvm_utils->start_new_block(mergeBB);
4444+
if (to_return) {
4445+
return llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0);
4446+
}
4447+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
4448+
// Check if the actual key is present or not
4449+
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
4450+
llvm_utils->list_api->read_item(key_list, pos, false, module,
4451+
LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type);
4452+
4453+
return is_key_matching;
4454+
}
4455+
4456+
llvm::Value *LLVMDictSeparateChaining::is_key_present(llvm::Value *dict, llvm::Value *key,
4457+
ASR::Dict_t *dict_type, llvm::Module &module) {
4458+
llvm::Value *capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
4459+
llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module);
4460+
llvm::Value *key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
4461+
llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict));
4462+
llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash);
4463+
llvm::Type* kv_struct_type = get_key_value_pair_type(dict_type->m_key_type, dict_type->m_value_type);
4464+
this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list,
4465+
kv_struct_type, key_mask, module, dict_type->m_key_type);
4466+
std::pair<std::string, std::string> llvm_key = std::make_pair(
4467+
ASRUtils::get_type_code(dict_type->m_key_type),
4468+
ASRUtils::get_type_code(dict_type->m_value_type)
4469+
);
4470+
llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second;
4471+
get_builder0()
4472+
tmp_value_ptr = builder0.CreateAlloca(value_type, nullptr);
4473+
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
4474+
llvm_utils->create_ptr_gep(key_mask, key_hash));
4475+
llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value,
4476+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
4477+
does_kv_exists = builder->CreateAnd(does_kv_exists,
4478+
builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr),
4479+
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)))
4480+
);
4481+
return does_kv_exists;
4482+
}
4483+
43624484
llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
43634485
bool enable_bounds_checking,
43644486
llvm::Module& module, bool get_pointer) {
@@ -6825,6 +6947,113 @@ namespace LCompilers {
68256947
llvm_utils->start_new_block(loopend);
68266948
}
68276949

6950+
llvm::Value *LLVMSetLinearProbing::is_el_present(
6951+
llvm::Value *set, llvm::Value *el,
6952+
llvm::Module &module, ASR::ttype_t *el_asr_type) {
6953+
/**
6954+
* C++ equivalent:
6955+
*
6956+
* el_mask_value = el_mask[el_hash];
6957+
* is_prob_not_needed = el_mask_value == 1;
6958+
* if( is_prob_not_needed ) {
6959+
* is_el_matching = el == el_list[el_hash];
6960+
* if( is_el_matching ) {
6961+
* pos = el_hash;
6962+
* }
6963+
* else {
6964+
* return is_el_matching;
6965+
* }
6966+
* }
6967+
* else {
6968+
* resolve_collision(el, for_read=true); // modifies pos
6969+
* }
6970+
*
6971+
* is_el_matching = el == el_list[pos];
6972+
* return is_el_matching;
6973+
*/
6974+
6975+
get_builder0()
6976+
llvm::Value* el_list = get_el_list(set);
6977+
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
6978+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
6979+
llvm::Value *el_hash = get_el_hash(capacity, el, el_asr_type, module);
6980+
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
6981+
llvm::Function *fn = builder->GetInsertBlock()->getParent();
6982+
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
6983+
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
6984+
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
6985+
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
6986+
llvm_utils->create_ptr_gep(el_mask, el_hash));
6987+
llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value,
6988+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
6989+
bool to_return = false;
6990+
builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB);
6991+
builder->SetInsertPoint(thenBB);
6992+
{
6993+
// reasoning for this check explained in
6994+
// LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check
6995+
llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el,
6996+
llvm_utils->list_api->read_item(el_list, el_hash, false, module,
6997+
LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type);
6998+
6999+
llvm_utils->create_if_else(is_el_matching, [=]() {
7000+
LLVM::CreateStore(*builder, el_hash, pos_ptr);
7001+
}, [&]() {
7002+
//to_return = true; // Cannot work: to_return is a host-side C++ bool tested while the IR is emitted, so it cannot follow the runtime value of is_el_matching
7003+
});
7004+
}
7005+
builder->CreateBr(mergeBB);
7006+
llvm_utils->start_new_block(elseBB);
7007+
{
7008+
this->resolve_collision(capacity, el_hash, el, el_list, el_mask,
7009+
module, el_asr_type, true);
7010+
}
7011+
llvm_utils->start_new_block(mergeBB);
7012+
if (to_return) {
7013+
return llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0);
7014+
}
7015+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
7016+
// Check if the actual element is present or not
7017+
llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el,
7018+
llvm_utils->list_api->read_item(el_list, pos, false, module,
7019+
LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type);
7020+
7021+
7022+
return is_el_matching;
7023+
}
7024+
7025+
llvm::Value *LLVMSetSeparateChaining::is_el_present(
7026+
llvm::Value *set, llvm::Value *el,
7027+
llvm::Module &module, ASR::ttype_t *el_asr_type) {
7028+
/**
7029+
* C++ equivalent:
7030+
*
7031+
* resolve_collision(el); // modifies chain_itr
7032+
* does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr;
7033+
* return does_el_exist;
7034+
*
7035+
*/
7036+
llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set));
7037+
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
7038+
llvm::Value* el_hash = get_el_hash(capacity, el, el_asr_type, module);
7039+
llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash);
7040+
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
7041+
std::string el_type_code = ASRUtils::get_type_code(el_asr_type);
7042+
llvm::Type* el_struct_type = typecode2elstruct[el_type_code];
7043+
this->resolve_collision(el_hash, el, el_linked_list,
7044+
el_struct_type, el_mask, module, el_asr_type);
7045+
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
7046+
llvm_utils->create_ptr_gep(el_mask, el_hash));
7047+
llvm::Value* does_el_exist = builder->CreateICmpEQ(el_mask_value,
7048+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
7049+
does_el_exist = builder->CreateAnd(does_el_exist,
7050+
builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr),
7051+
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)))
7052+
);
7053+
7054+
return does_el_exist;
7055+
}
7056+
68287057
llvm::Value* LLVMSetInterface::len(llvm::Value* set) {
68297058
return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set));
68307059
}

0 commit comments

Comments
 (0)