@@ -4359,6 +4359,128 @@ namespace LCompilers {
43594359 llvm_utils->start_new_block (loopend);
43604360 }
43614361
4362+ llvm::Value *LLVMDict::is_key_present (llvm::Value *dict, llvm::Value *key,
4363+ ASR::Dict_t *dict_type, llvm::Module &module ) {
4364+ llvm::Value *capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
4365+ llvm::Value *key_hash = get_key_hash (capacity, key, dict_type->m_key_type , module );
4366+ llvm::Value *key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
4367+ llvm::Value *key_list = get_key_list (dict);
4368+
4369+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module , dict_type->m_key_type , true );
4370+ llvm::Value *pos = LLVM::CreateLoad (*builder, pos_ptr);
4371+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
4372+ llvm_utils->list_api ->read_item (key_list, pos, false , module ,
4373+ LLVM::is_llvm_struct (dict_type->m_key_type )), module , dict_type->m_key_type );
4374+
4375+ return is_key_matching;
4376+ }
4377+
4378+ llvm::Value *LLVMDictOptimizedLinearProbing::is_key_present (llvm::Value *dict, llvm::Value *key,
4379+ ASR::Dict_t *dict_type, llvm::Module &module ) {
4380+ /* *
4381+ * C++ equivalent:
4382+ *
4383+ * key_mask_value = key_mask[key_hash];
4384+ * is_prob_not_needed = key_mask_value == 1;
4385+ * if( is_prob_not_needed ) {
4386+ * is_key_matching = key == key_list[key_hash];
4387+ * if( is_key_matching ) {
4388+ * pos = key_hash;
4389+ * }
4390+ * else {
4391+ * return is_key_matching;
4392+ * }
4393+ * }
4394+ * else {
4395+ * resolve_collision(key, for_read=true); // modifies pos
4396+ * }
4397+ *
4398+ * is_key_matching = key == key_list[pos];
4399+ * return is_key_matching;
4400+ */
4401+
4402+ llvm::Value* key_list = get_key_list (dict);
4403+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
4404+ llvm::Value *key_hash = get_key_hash (capacity, key, dict_type->m_key_type , module );
4405+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
4406+ get_builder0 ()
4407+ pos_ptr = builder0.CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
4408+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
4409+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
4410+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
4411+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
4412+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
4413+ llvm_utils->create_ptr_gep (key_mask, key_hash));
4414+ llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ (key_mask_value,
4415+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
4416+ bool to_return = false ;
4417+ builder->CreateCondBr (is_prob_not_neeeded, thenBB, elseBB);
4418+ builder->SetInsertPoint (thenBB);
4419+ {
4420+ // A single by value comparison is needed even though
4421+ // we don't need to do linear probing. This is because
4422+ // the user can provide a key which is absent in the dict
4423+ // but is giving the same hash value as one of the keys present in the dict.
4424+ // In the above case we will end up returning value for a key
4425+ // which is not present in the dict. Instead we should return an error
4426+ // which is done in the below code.
4427+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
4428+ llvm_utils->list_api ->read_item (key_list, key_hash, false , module ,
4429+ LLVM::is_llvm_struct (dict_type->m_key_type )), module , dict_type->m_key_type );
4430+
4431+ llvm_utils->create_if_else (is_key_matching, [=]() {
4432+ LLVM::CreateStore (*builder, key_hash, pos_ptr);
4433+ }, [&]() {
4434+ // to_return = true;
4435+ });
4436+ }
4437+ builder->CreateBr (mergeBB);
4438+ llvm_utils->start_new_block (elseBB);
4439+ {
4440+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask,
4441+ module , dict_type->m_key_type , true );
4442+ }
4443+ llvm_utils->start_new_block (mergeBB);
4444+ if (to_return) {
4445+ return llvm::ConstantInt::get (llvm::Type::getInt1Ty (context), 0 );
4446+ }
4447+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
4448+ // Check if the actual key is present or not
4449+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
4450+ llvm_utils->list_api ->read_item (key_list, pos, false , module ,
4451+ LLVM::is_llvm_struct (dict_type->m_key_type )), module , dict_type->m_key_type );
4452+
4453+ return is_key_matching;
4454+ }
4455+
4456+ llvm::Value *LLVMDictSeparateChaining::is_key_present (llvm::Value *dict, llvm::Value *key,
4457+ ASR::Dict_t *dict_type, llvm::Module &module ) {
4458+ llvm::Value *capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
4459+ llvm::Value *key_hash = get_key_hash (capacity, key, dict_type->m_key_type , module );
4460+ llvm::Value *key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
4461+ llvm::Value* key_value_pairs = LLVM::CreateLoad (*builder, get_pointer_to_key_value_pairs (dict));
4462+ llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep (key_value_pairs, key_hash);
4463+ llvm::Type* kv_struct_type = get_key_value_pair_type (dict_type->m_key_type , dict_type->m_value_type );
4464+ this ->resolve_collision (capacity, key_hash, key, key_value_pair_linked_list,
4465+ kv_struct_type, key_mask, module , dict_type->m_key_type );
4466+ std::pair<std::string, std::string> llvm_key = std::make_pair (
4467+ ASRUtils::get_type_code (dict_type->m_key_type ),
4468+ ASRUtils::get_type_code (dict_type->m_value_type )
4469+ );
4470+ llvm::Type* value_type = std::get<2 >(typecode2dicttype[llvm_key]).second ;
4471+ get_builder0 ()
4472+ tmp_value_ptr = builder0.CreateAlloca (value_type, nullptr );
4473+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
4474+ llvm_utils->create_ptr_gep (key_mask, key_hash));
4475+ llvm::Value* does_kv_exists = builder->CreateICmpEQ (key_mask_value,
4476+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
4477+ does_kv_exists = builder->CreateAnd (does_kv_exists,
4478+ builder->CreateICmpNE (LLVM::CreateLoad (*builder, chain_itr),
4479+ llvm::ConstantPointerNull::get (llvm::Type::getInt8PtrTy (context)))
4480+ );
4481+ return does_kv_exists;
4482+ }
4483+
43624484 llvm::Value* LLVMList::read_item (llvm::Value* list, llvm::Value* pos,
43634485 bool enable_bounds_checking,
43644486 llvm::Module& module , bool get_pointer) {
@@ -6825,6 +6947,113 @@ namespace LCompilers {
68256947 llvm_utils->start_new_block (loopend);
68266948 }
68276949
6950+ llvm::Value *LLVMSetLinearProbing::is_el_present (
6951+ llvm::Value *set, llvm::Value *el,
6952+ llvm::Module &module , ASR::ttype_t *el_asr_type) {
6953+ /* *
6954+ * C++ equivalent:
6955+ *
6956+ * el_mask_value = el_mask[el_hash];
6957+ * is_prob_needed = el_mask_value == 1;
6958+ * if( is_prob_needed ) {
6959+ * is_el_matching = el == el_list[el_hash];
6960+ * if( is_el_matching ) {
6961+ * pos = el_hash;
6962+ * }
6963+ * else {
6964+ * return is_el_matching;
6965+ * }
6966+ * }
6967+ * else {
6968+ * resolve_collision(el, for_read=true); // modifies pos
6969+ * }
6970+ *
6971+ * is_el_matching = el == el_list[pos];
6972+ * return is_el_matching
6973+ */
6974+
6975+ get_builder0 ()
6976+ llvm::Value* el_list = get_el_list (set);
6977+ llvm::Value* el_mask = LLVM::CreateLoad (*builder, get_pointer_to_mask (set));
6978+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (set));
6979+ llvm::Value *el_hash = get_el_hash (capacity, el, el_asr_type, module );
6980+ pos_ptr = builder0.CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
6981+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
6982+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
6983+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
6984+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
6985+ llvm::Value* el_mask_value = LLVM::CreateLoad (*builder,
6986+ llvm_utils->create_ptr_gep (el_mask, el_hash));
6987+ llvm::Value* is_prob_not_needed = builder->CreateICmpEQ (el_mask_value,
6988+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
6989+ bool to_return = false ;
6990+ builder->CreateCondBr (is_prob_not_needed, thenBB, elseBB);
6991+ builder->SetInsertPoint (thenBB);
6992+ {
6993+ // reasoning for this check explained in
6994+ // LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check
6995+ llvm::Value* is_el_matching = llvm_utils->is_equal_by_value (el,
6996+ llvm_utils->list_api ->read_item (el_list, el_hash, false , module ,
6997+ LLVM::is_llvm_struct (el_asr_type)), module , el_asr_type);
6998+
6999+ llvm_utils->create_if_else (is_el_matching, [=]() {
7000+ LLVM::CreateStore (*builder, el_hash, pos_ptr);
7001+ }, [&]() {
7002+ // to_return = true; // Need to check why this is not working
7003+ });
7004+ }
7005+ builder->CreateBr (mergeBB);
7006+ llvm_utils->start_new_block (elseBB);
7007+ {
7008+ this ->resolve_collision (capacity, el_hash, el, el_list, el_mask,
7009+ module , el_asr_type, true );
7010+ }
7011+ llvm_utils->start_new_block (mergeBB);
7012+ if (to_return) {
7013+ return llvm::ConstantInt::get (llvm::Type::getInt1Ty (context), 0 );
7014+ }
7015+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
7016+ // Check if the actual element is present or not
7017+ llvm::Value* is_el_matching = llvm_utils->is_equal_by_value (el,
7018+ llvm_utils->list_api ->read_item (el_list, pos, false , module ,
7019+ LLVM::is_llvm_struct (el_asr_type)), module , el_asr_type);
7020+
7021+
7022+ return is_el_matching;
7023+ }
7024+
7025+ llvm::Value *LLVMSetSeparateChaining::is_el_present (
7026+ llvm::Value *set, llvm::Value *el,
7027+ llvm::Module &module , ASR::ttype_t *el_asr_type) {
7028+ /* *
7029+ * C++ equivalent:
7030+ *
7031+ * resolve_collision(el); // modified chain_itr
7032+ * does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr;
7033+ * return does_el_exist;
7034+ *
7035+ */
7036+ llvm::Value* elems = LLVM::CreateLoad (*builder, get_pointer_to_elems (set));
7037+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (set));
7038+ llvm::Value* el_hash = get_el_hash (capacity, el, el_asr_type, module );
7039+ llvm::Value* el_linked_list = llvm_utils->create_ptr_gep (elems, el_hash);
7040+ llvm::Value* el_mask = LLVM::CreateLoad (*builder, get_pointer_to_mask (set));
7041+ std::string el_type_code = ASRUtils::get_type_code (el_asr_type);
7042+ llvm::Type* el_struct_type = typecode2elstruct[el_type_code];
7043+ this ->resolve_collision (el_hash, el, el_linked_list,
7044+ el_struct_type, el_mask, module , el_asr_type);
7045+ llvm::Value* el_mask_value = LLVM::CreateLoad (*builder,
7046+ llvm_utils->create_ptr_gep (el_mask, el_hash));
7047+ llvm::Value* does_el_exist = builder->CreateICmpEQ (el_mask_value,
7048+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
7049+ does_el_exist = builder->CreateAnd (does_el_exist,
7050+ builder->CreateICmpNE (LLVM::CreateLoad (*builder, chain_itr),
7051+ llvm::ConstantPointerNull::get (llvm::Type::getInt8PtrTy (context)))
7052+ );
7053+
7054+ return does_el_exist;
7055+ }
7056+
68287057 llvm::Value* LLVMSetInterface::len (llvm::Value* set) {
68297058 return LLVM::CreateLoad (*builder, get_pointer_to_occupancy (set));
68307059 }
0 commit comments