Skip to content

Commit a104c30

Browse files
authored
Merge pull request #2802 from advikkabra/empty-membership
Prevent the hashing function being called when capacity is zero
2 parents ba2dff6 + a010b04 commit a104c30

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

integration_tests/test_membership_01.py

+12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ def test_int_dict():
66
i = 4
77
assert (i in a)
88

9+
a = {}
10+
assert (1 not in a)
11+
912
def test_str_dict():
1013
a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'}
1114
i: str
@@ -14,6 +17,9 @@ def test_str_dict():
1417
i = 'c'
1518
assert (i in a)
1619

20+
a = {}
21+
assert ('a' not in a)
22+
1723
def test_int_set():
1824
a: set[i32] = {1, 2, 3, 4}
1925
i: i32
@@ -22,6 +28,9 @@ def test_int_set():
2228
i = 4
2329
assert (i in a)
2430

31+
a = set()
32+
assert (1 not in a)
33+
2534
def test_str_set():
2635
a: set[str] = {'a', 'b', 'c', 'e', 'f'}
2736
i: str
@@ -30,6 +39,9 @@ def test_str_set():
3039
i = 'c'
3140
assert (i in a)
3241

42+
a = set()
43+
assert ('a' not in a)
44+
3345
test_int_dict()
3446
test_str_dict()
3547
test_int_set()

src/libasr/codegen/asr_to_llvm.cpp

+22-6
Original file line numberDiff line numberDiff line change
@@ -1726,9 +1726,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17261726
ptr_loads = ptr_loads_copy;
17271727
llvm::Value *capacity = LLVM::CreateLoad(*builder,
17281728
llvm_utils->dict_api->get_pointer_to_capacity(right));
1729-
llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module);
1730-
1731-
tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true);
1729+
get_builder0();
1730+
llvm::AllocaInst *res = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
1731+
llvm_utils->create_if_else(builder->CreateICmpEQ(
1732+
capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))),
1733+
[&]() {
1734+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), res);
1735+
}, [&]() {
1736+
llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module);
1737+
LLVM::CreateStore(*builder, llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true), res);
1738+
});
1739+
tmp = LLVM::CreateLoad(*builder, res);
17321740
}
17331741

17341742
void visit_SetContains(const ASR::SetContains_t &x) {
@@ -1748,9 +1756,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17481756
ptr_loads = ptr_loads_copy;
17491757
llvm::Value *capacity = LLVM::CreateLoad(*builder,
17501758
llvm_utils->set_api->get_pointer_to_capacity(right));
1751-
llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module);
1752-
1753-
tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true);
1759+
get_builder0();
1760+
llvm::AllocaInst *res = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
1761+
llvm_utils->create_if_else(builder->CreateICmpEQ(
1762+
capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))),
1763+
[&]() {
1764+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), res);
1765+
}, [&]() {
1766+
llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module);
1767+
LLVM::CreateStore(*builder, llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true), res);
1768+
});
1769+
tmp = LLVM::CreateLoad(*builder, res);
17541770
}
17551771

17561772
void visit_DictLen(const ASR::DictLen_t& x) {

0 commit comments

Comments
 (0)