@@ -4359,6 +4359,128 @@ namespace LCompilers {
4359
4359
llvm_utils->start_new_block (loopend);
4360
4360
}
4361
4361
4362
+ llvm::Value *LLVMDict::is_key_present (llvm::Value *dict, llvm::Value *key,
4363
+ ASR::Dict_t *dict_type, llvm::Module &module) {
4364
+ llvm::Value *capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
4365
+ llvm::Value *key_hash = get_key_hash (capacity, key, dict_type->m_key_type , module);
4366
+ llvm::Value *key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
4367
+ llvm::Value *key_list = get_key_list (dict);
4368
+
4369
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask, module, dict_type->m_key_type , true );
4370
+ llvm::Value *pos = LLVM::CreateLoad (*builder, pos_ptr);
4371
+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
4372
+ llvm_utils->list_api ->read_item (key_list, pos, false , module,
4373
+ LLVM::is_llvm_struct (dict_type->m_key_type )), module, dict_type->m_key_type );
4374
+
4375
+ return is_key_matching;
4376
+ }
4377
+
4378
+ llvm::Value *LLVMDictOptimizedLinearProbing::is_key_present (llvm::Value *dict, llvm::Value *key,
4379
+ ASR::Dict_t *dict_type, llvm::Module &module) {
4380
+ /* *
4381
+ * C++ equivalent:
4382
+ *
4383
+ * key_mask_value = key_mask[key_hash];
4384
+ * is_prob_not_needed = key_mask_value == 1;
4385
+ * if( is_prob_not_needed ) {
4386
+ * is_key_matching = key == key_list[key_hash];
4387
+ * if( is_key_matching ) {
4388
+ * pos = key_hash;
4389
+ * }
4390
+ * else {
4391
+ * return is_key_matching;
4392
+ * }
4393
+ * }
4394
+ * else {
4395
+ * resolve_collision(key, for_read=true); // modifies pos
4396
+ * }
4397
+ *
4398
+ * is_key_matching = key == key_list[pos];
4399
+ * return is_key_matching;
4400
+ */
4401
+
4402
+ llvm::Value* key_list = get_key_list (dict);
4403
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
4404
+ llvm::Value *key_hash = get_key_hash (capacity, key, dict_type->m_key_type , module);
4405
+ llvm::Value* key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
4406
+ get_builder0 ()
4407
+ pos_ptr = builder0.CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
4408
+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
4409
+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
4410
+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
4411
+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
4412
+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
4413
+ llvm_utils->create_ptr_gep (key_mask, key_hash));
4414
+ llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ (key_mask_value,
4415
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
4416
+ bool to_return = false ;
4417
+ builder->CreateCondBr (is_prob_not_neeeded, thenBB, elseBB);
4418
+ builder->SetInsertPoint (thenBB);
4419
+ {
4420
+ // A single by value comparison is needed even though
4421
+ // we don't need to do linear probing. This is because
4422
+ // the user can provide a key which is absent in the dict
4423
+ // but is giving the same hash value as one of the keys present in the dict.
4424
+ // In the above case we will end up returning value for a key
4425
+ // which is not present in the dict. Instead we should return an error
4426
+ // which is done in the below code.
4427
+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
4428
+ llvm_utils->list_api ->read_item (key_list, key_hash, false , module,
4429
+ LLVM::is_llvm_struct (dict_type->m_key_type )), module, dict_type->m_key_type );
4430
+
4431
+ llvm_utils->create_if_else (is_key_matching, [=]() {
4432
+ LLVM::CreateStore (*builder, key_hash, pos_ptr);
4433
+ }, [&]() {
4434
+ // to_return = true;
4435
+ });
4436
+ }
4437
+ builder->CreateBr (mergeBB);
4438
+ llvm_utils->start_new_block (elseBB);
4439
+ {
4440
+ this ->resolve_collision (capacity, key_hash, key, key_list, key_mask,
4441
+ module, dict_type->m_key_type , true );
4442
+ }
4443
+ llvm_utils->start_new_block (mergeBB);
4444
+ if (to_return) {
4445
+ return llvm::ConstantInt::get (llvm::Type::getInt1Ty (context), 0 );
4446
+ }
4447
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
4448
+ // Check if the actual key is present or not
4449
+ llvm::Value* is_key_matching = llvm_utils->is_equal_by_value (key,
4450
+ llvm_utils->list_api ->read_item (key_list, pos, false , module,
4451
+ LLVM::is_llvm_struct (dict_type->m_key_type )), module, dict_type->m_key_type );
4452
+
4453
+ return is_key_matching;
4454
+ }
4455
+
4456
+ llvm::Value *LLVMDictSeparateChaining::is_key_present (llvm::Value *dict, llvm::Value *key,
4457
+ ASR::Dict_t *dict_type, llvm::Module &module) {
4458
+ llvm::Value *capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (dict));
4459
+ llvm::Value *key_hash = get_key_hash (capacity, key, dict_type->m_key_type , module);
4460
+ llvm::Value *key_mask = LLVM::CreateLoad (*builder, get_pointer_to_keymask (dict));
4461
+ llvm::Value* key_value_pairs = LLVM::CreateLoad (*builder, get_pointer_to_key_value_pairs (dict));
4462
+ llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep (key_value_pairs, key_hash);
4463
+ llvm::Type* kv_struct_type = get_key_value_pair_type (dict_type->m_key_type , dict_type->m_value_type );
4464
+ this ->resolve_collision (capacity, key_hash, key, key_value_pair_linked_list,
4465
+ kv_struct_type, key_mask, module, dict_type->m_key_type );
4466
+ std::pair<std::string, std::string> llvm_key = std::make_pair (
4467
+ ASRUtils::get_type_code (dict_type->m_key_type ),
4468
+ ASRUtils::get_type_code (dict_type->m_value_type )
4469
+ );
4470
+ llvm::Type* value_type = std::get<2 >(typecode2dicttype[llvm_key]).second ;
4471
+ get_builder0 ()
4472
+ tmp_value_ptr = builder0.CreateAlloca (value_type, nullptr );
4473
+ llvm::Value* key_mask_value = LLVM::CreateLoad (*builder,
4474
+ llvm_utils->create_ptr_gep (key_mask, key_hash));
4475
+ llvm::Value* does_kv_exists = builder->CreateICmpEQ (key_mask_value,
4476
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
4477
+ does_kv_exists = builder->CreateAnd (does_kv_exists,
4478
+ builder->CreateICmpNE (LLVM::CreateLoad (*builder, chain_itr),
4479
+ llvm::ConstantPointerNull::get (llvm::Type::getInt8PtrTy (context)))
4480
+ );
4481
+ return does_kv_exists;
4482
+ }
4483
+
4362
4484
llvm::Value* LLVMList::read_item (llvm::Value* list, llvm::Value* pos,
4363
4485
bool enable_bounds_checking,
4364
4486
llvm::Module& module, bool get_pointer) {
@@ -6825,6 +6947,113 @@ namespace LCompilers {
6825
6947
llvm_utils->start_new_block (loopend);
6826
6948
}
6827
6949
6950
+ llvm::Value *LLVMSetLinearProbing::is_el_present (
6951
+ llvm::Value *set, llvm::Value *el,
6952
+ llvm::Module &module, ASR::ttype_t *el_asr_type) {
6953
+ /* *
6954
+ * C++ equivalent:
6955
+ *
6956
+ * el_mask_value = el_mask[el_hash];
6957
+ * is_prob_needed = el_mask_value == 1;
6958
+ * if( is_prob_needed ) {
6959
+ * is_el_matching = el == el_list[el_hash];
6960
+ * if( is_el_matching ) {
6961
+ * pos = el_hash;
6962
+ * }
6963
+ * else {
6964
+ * return is_el_matching;
6965
+ * }
6966
+ * }
6967
+ * else {
6968
+ * resolve_collision(el, for_read=true); // modifies pos
6969
+ * }
6970
+ *
6971
+ * is_el_matching = el == el_list[pos];
6972
+ * return is_el_matching
6973
+ */
6974
+
6975
+ get_builder0 ()
6976
+ llvm::Value* el_list = get_el_list (set);
6977
+ llvm::Value* el_mask = LLVM::CreateLoad (*builder, get_pointer_to_mask (set));
6978
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (set));
6979
+ llvm::Value *el_hash = get_el_hash (capacity, el, el_asr_type, module);
6980
+ pos_ptr = builder0.CreateAlloca (llvm::Type::getInt32Ty (context), nullptr );
6981
+ llvm::Function *fn = builder->GetInsertBlock ()->getParent ();
6982
+ llvm::BasicBlock *thenBB = llvm::BasicBlock::Create (context, " then" , fn);
6983
+ llvm::BasicBlock *elseBB = llvm::BasicBlock::Create (context, " else" );
6984
+ llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create (context, " ifcont" );
6985
+ llvm::Value* el_mask_value = LLVM::CreateLoad (*builder,
6986
+ llvm_utils->create_ptr_gep (el_mask, el_hash));
6987
+ llvm::Value* is_prob_not_needed = builder->CreateICmpEQ (el_mask_value,
6988
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
6989
+ bool to_return = false ;
6990
+ builder->CreateCondBr (is_prob_not_needed, thenBB, elseBB);
6991
+ builder->SetInsertPoint (thenBB);
6992
+ {
6993
+ // reasoning for this check explained in
6994
+ // LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check
6995
+ llvm::Value* is_el_matching = llvm_utils->is_equal_by_value (el,
6996
+ llvm_utils->list_api ->read_item (el_list, el_hash, false , module,
6997
+ LLVM::is_llvm_struct (el_asr_type)), module, el_asr_type);
6998
+
6999
+ llvm_utils->create_if_else (is_el_matching, [=]() {
7000
+ LLVM::CreateStore (*builder, el_hash, pos_ptr);
7001
+ }, [&]() {
7002
+ // to_return = true; // Need to check why this is not working
7003
+ });
7004
+ }
7005
+ builder->CreateBr (mergeBB);
7006
+ llvm_utils->start_new_block (elseBB);
7007
+ {
7008
+ this ->resolve_collision (capacity, el_hash, el, el_list, el_mask,
7009
+ module, el_asr_type, true );
7010
+ }
7011
+ llvm_utils->start_new_block (mergeBB);
7012
+ if (to_return) {
7013
+ return llvm::ConstantInt::get (llvm::Type::getInt1Ty (context), 0 );
7014
+ }
7015
+ llvm::Value* pos = LLVM::CreateLoad (*builder, pos_ptr);
7016
+ // Check if the actual element is present or not
7017
+ llvm::Value* is_el_matching = llvm_utils->is_equal_by_value (el,
7018
+ llvm_utils->list_api ->read_item (el_list, pos, false , module,
7019
+ LLVM::is_llvm_struct (el_asr_type)), module, el_asr_type);
7020
+
7021
+
7022
+ return is_el_matching;
7023
+ }
7024
+
7025
+ llvm::Value *LLVMSetSeparateChaining::is_el_present (
7026
+ llvm::Value *set, llvm::Value *el,
7027
+ llvm::Module &module, ASR::ttype_t *el_asr_type) {
7028
+ /* *
7029
+ * C++ equivalent:
7030
+ *
7031
+ * resolve_collision(el); // modified chain_itr
7032
+ * does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr;
7033
+ * return does_el_exist;
7034
+ *
7035
+ */
7036
+ llvm::Value* elems = LLVM::CreateLoad (*builder, get_pointer_to_elems (set));
7037
+ llvm::Value* capacity = LLVM::CreateLoad (*builder, get_pointer_to_capacity (set));
7038
+ llvm::Value* el_hash = get_el_hash (capacity, el, el_asr_type, module);
7039
+ llvm::Value* el_linked_list = llvm_utils->create_ptr_gep (elems, el_hash);
7040
+ llvm::Value* el_mask = LLVM::CreateLoad (*builder, get_pointer_to_mask (set));
7041
+ std::string el_type_code = ASRUtils::get_type_code (el_asr_type);
7042
+ llvm::Type* el_struct_type = typecode2elstruct[el_type_code];
7043
+ this ->resolve_collision (el_hash, el, el_linked_list,
7044
+ el_struct_type, el_mask, module, el_asr_type);
7045
+ llvm::Value* el_mask_value = LLVM::CreateLoad (*builder,
7046
+ llvm_utils->create_ptr_gep (el_mask, el_hash));
7047
+ llvm::Value* does_el_exist = builder->CreateICmpEQ (el_mask_value,
7048
+ llvm::ConstantInt::get (llvm::Type::getInt8Ty (context), llvm::APInt (8 , 1 )));
7049
+ does_el_exist = builder->CreateAnd (does_el_exist,
7050
+ builder->CreateICmpNE (LLVM::CreateLoad (*builder, chain_itr),
7051
+ llvm::ConstantPointerNull::get (llvm::Type::getInt8PtrTy (context)))
7052
+ );
7053
+
7054
+ return does_el_exist;
7055
+ }
7056
+
6828
7057
llvm::Value* LLVMSetInterface::len (llvm::Value* set) {
6829
7058
return LLVM::CreateLoad (*builder, get_pointer_to_occupancy (set));
6830
7059
}
0 commit comments