Skip to content

Commit d92a038

Browse files
committed
add exx nscf file check
1 parent 808af53 commit d92a038

File tree

4 files changed

+97
-36
lines changed

4 files changed

+97
-36
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ template <typename TK, typename TR>
8282
OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
8383
HContainer<TR>*hR_in,
8484
const UnitCell& ucell_in,
85-
const K_Vectors& kv_in,
86-
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
87-
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
85+
const K_Vectors& kv_in,
86+
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
87+
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
8888
Add_Hexx_Type add_hexx_type_in,
8989
const int istep,
9090
int* two_level_step_in,
91-
const bool restart_in)
91+
const bool restart_in)
9292
: OperatorLCAO<TK, TR>(hsk_in, kv_in.kvec_d, hR_in),
9393
ucell(ucell_in),
9494
kv(kv_in),
@@ -105,42 +105,75 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
105105

106106
if (PARAM.inp.calculation == "nscf" && GlobalC::exx_info.info_global.cal_exx)
107107
{ // if nscf, read HexxR first and reallocate hR according to the read-in HexxR
108-
const std::string file_name_exx = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
109-
bool all_exist = true;
110-
for (int is=0;is<PARAM.inp.nspin;++is)
108+
auto file_name_list_csr = []() -> std::vector<std::string>
111109
{
112-
std::ifstream ifs(file_name_exx + "_" + std::to_string(is) + ".csr");
113-
if (!ifs) { all_exist = false; break; }
114-
}
115-
if (all_exist)
110+
std::vector<std::string> file_name_list;
111+
for (int irank=0; irank<PARAM.globalv.nproc; ++irank) {
112+
for (int is=0;is<PARAM.inp.nspin;++is) {
113+
file_name_list.push_back( PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(irank) + "_" + std::to_string(is) + ".csr" );
114+
} }
115+
return file_name_list;
116+
};
117+
auto file_name_list_cereal = []() -> std::vector<std::string>
118+
{
119+
std::vector<std::string> file_name_list;
120+
for (int irank=0; irank<PARAM.globalv.nproc; ++irank)
121+
{ file_name_list.push_back( "HexxR_" + std::to_string(irank) ); }
122+
return file_name_list;
123+
};
124+
auto check_exist = [](const std::vector<std::string> &file_name_list) -> bool
125+
{
126+
for (const std::string &file_name : file_name_list)
127+
{
128+
std::ifstream ifs(file_name);
129+
if (!ifs.is_open())
130+
{ return false; }
131+
}
132+
return true;
133+
};
134+
135+
std::cout<<" Attention: The number of MPI processes must be strictly identical between SCF and NSCF when computing exact-exchange."<<std::endl;
136+
if (check_exist(file_name_list_csr()))
116137
{
138+
const std::string file_name_exx_csr = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(PARAM.globalv.myrank);
117139
// Read HexxR in CSR format
118140
if (GlobalC::exx_info.info_ri.real_number)
119141
{
120-
ModuleIO::read_Hexxs_csr(file_name_exx, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
121-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
142+
ModuleIO::read_Hexxs_csr(file_name_exx_csr, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
143+
if (this->add_hexx_type == Add_Hexx_Type::R)
144+
{ reallocate_hcontainer(*Hexxd, this->hR); }
122145
}
123146
else
124147
{
125-
ModuleIO::read_Hexxs_csr(file_name_exx, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
126-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
148+
ModuleIO::read_Hexxs_csr(file_name_exx_csr, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
149+
if (this->add_hexx_type == Add_Hexx_Type::R)
150+
{ reallocate_hcontainer(*Hexxc, this->hR); }
127151
}
128152
}
129-
else
153+
else if (check_exist(file_name_list_cereal()))
130154
{
131155
// Read HexxR in binary format (old version)
132-
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
156+
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(PARAM.globalv.myrank);
157+
std::ifstream ifs(file_name_exx_cereal, std::ios::binary);
158+
if (!ifs)
159+
{ ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file < " + file_name_exx_cereal + " >."); }
133160
if (GlobalC::exx_info.info_ri.real_number)
134161
{
135162
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxd);
136-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
163+
if (this->add_hexx_type == Add_Hexx_Type::R)
164+
{ reallocate_hcontainer(*Hexxd, this->hR); }
137165
}
138166
else
139167
{
140168
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxc);
141-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
169+
if (this->add_hexx_type == Add_Hexx_Type::R)
170+
{ reallocate_hcontainer(*Hexxc, this->hR); }
142171
}
143172
}
173+
else
174+
{
175+
ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file in " + PARAM.globalv.global_readin_dir);
176+
}
144177
this->use_cell_nearest = false;
145178
}
146179
else
@@ -207,7 +240,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
207240
else if (this->add_hexx_type == Add_Hexx_Type::R)
208241
{
209242
// read in Hexx(R)
210-
const std::string restart_HR_path = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
243+
const std::string restart_HR_path = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(PARAM.globalv.myrank);
211244
bool all_exist = true;
212245
for (int is = 0; is < PARAM.inp.nspin; ++is)
213246
{
@@ -227,7 +260,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
227260
else
228261
{
229262
// Read HexxR in binary format (old version)
230-
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(GlobalV::MY_RANK);
263+
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(PARAM.globalv.myrank);
231264
if (GlobalC::exx_info.info_ri.real_number) {
232265
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxd);
233266
}

source/module_io/restart_exx_csr.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ namespace ModuleIO
2828
{
2929
const std::vector<int>& R = csr.getRCoordinate(iR);
3030
TC dR({ R[0], R[1], R[2] });
31-
Hexxs[is][iat1][{iat2, dR}] = RI::Tensor<Tdata>({ static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw), static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw) });
31+
Hexxs[is][iat1][{iat2, dR}] = RI::Tensor<Tdata>(
32+
{
33+
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw),
34+
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw)
35+
}
36+
);
3237
}
3338
}
3439
}
@@ -44,7 +49,12 @@ namespace ModuleIO
4449
const int& npol = ucell.get_npol();
4550
const int& i = ijv.first.first * npol;
4651
const int& j = ijv.first.second * npol;
47-
Hexxs.at(is).at(ucell.iwt2iat[i]).at({ ucell.iwt2iat[j], { R[0], R[1], R[2] } })(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
52+
Hexxs.at(is).at(ucell.iwt2iat[i]).at(
53+
{
54+
ucell.iwt2iat[j],
55+
{ R[0], R[1], R[2] }
56+
}
57+
)(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
4858
}
4959
}
5060
}
@@ -57,6 +67,8 @@ namespace ModuleIO
5767
ModuleBase::TITLE("Exx_LRI", "read_Hexxs_cereal");
5868
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
5969
std::ifstream ifs(file_name, std::ios::binary);
70+
if(!ifs.is_open())
71+
{ ModuleBase::WARNING_QUIT("read_Hexxs_cereal", file_name+" not found."); }
6072
cereal::BinaryInputArchive iar(ifs);
6173
iar(Hexxs);
6274
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");

source/module_ri/Exx_LRI_interface.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ class Exx_LRI_Interface
4242
}
4343
Exx_LRI_Interface() = delete;
4444

45-
/// read and write Hexxs using cereal
46-
void write_Hexxs_cereal(const std::string& file_name) const;
47-
void read_Hexxs_cereal(const std::string& file_name);
45+
///// read and write Hexxs using cereal
46+
//void write_Hexxs_cereal(const std::string& file_name) const;
47+
//void read_Hexxs_cereal(const std::string& file_name);
4848

4949
std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; }
5050
double &get_Eexx() const { return this->exx_ptr->Eexx; }

source/module_ri/Exx_LRI_interface.hpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <stdexcept>
1919
#include <string>
2020

21+
/*
2122
template<typename T, typename Tdata>
2223
void Exx_LRI_Interface<T, Tdata>::write_Hexxs_cereal(const std::string& file_name) const
2324
{
@@ -32,13 +33,17 @@ void Exx_LRI_Interface<T, Tdata>::write_Hexxs_cereal(const std::string& file_nam
3233
template<typename T, typename Tdata>
3334
void Exx_LRI_Interface<T, Tdata>::read_Hexxs_cereal(const std::string& file_name)
3435
{
35-
ModuleBase::TITLE("Exx_LRI", "read_Hexxs_cereal");
36-
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
37-
std::ifstream ifs(file_name + "_" + std::to_string(GlobalV::MY_RANK), std::ofstream::binary);
36+
ModuleBase::TITLE("Exx_LRI_Interface", "read_Hexxs_cereal");
37+
ModuleBase::timer::tick("Exx_LRI_Interface", "read_Hexxs_cereal");
38+
const std::string file_name_rank = file_name + "_" + std::to_string(GlobalV::MY_RANK);
39+
std::ifstream ifs(file_name_rank, std::ofstream::binary);
40+
if(!ifs.is_open())
41+
{ ModuleBase::WARNING_QUIT("Exx_LRI_Interface", file_name_rank+" not found."); }
3842
cereal::BinaryInputArchive iar(ifs);
3943
iar(this->exx_ptr->Hexxs);
4044
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
4145
}
46+
*/
4247

4348
template<typename T, typename Tdata>
4449
void Exx_LRI_Interface<T, Tdata>::init(const MPI_Comm &mpi_comm,
@@ -115,7 +120,8 @@ void Exx_LRI_Interface<T, Tdata>::exx_before_all_runners(const K_Vectors& kv, co
115120
if (this->exx_spacegroup_symmetry)
116121
{
117122
const std::array<int, 3>& period = RI_Util::get_Born_vonKarmen_period(kv);
118-
this->symrot_.find_irreducible_sector(ucell.symm, ucell.atoms, ucell.st,
123+
this->symrot_.find_irreducible_sector(
124+
ucell.symm, ucell.atoms, ucell.st,
119125
RI_Util::get_Born_von_Karmen_cells(period), period, ucell.lat);
120126
// this->symrot_.set_Cs_rotation(this->exx_ptr->get_abfs_nchis());
121127
this->symrot_.cal_Ms(kv, ucell, pv);
@@ -209,8 +215,19 @@ void Exx_LRI_Interface<T, Tdata>::exx_eachiterinit(const int istep,
209215
{ this->mix_DMk_2D.mix(dm_in.get_DMK_vector(), flag_restart); }
210216
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>
211217
Ds = PARAM.globalv.gamma_only_local
212-
? RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_gamma_out(), *dm_in.get_paraV_pointer(), PARAM.inp.nspin)
213-
: RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_k_out(), *dm_in.get_paraV_pointer(), PARAM.inp.nspin, this->exx_spacegroup_symmetry);
218+
? RI_2D_Comm::split_m2D_ktoR<Tdata>(
219+
ucell,
220+
*this->exx_ptr->p_kv,
221+
this->mix_DMk_2D.get_DMk_gamma_out(),
222+
*dm_in.get_paraV_pointer(),
223+
PARAM.inp.nspin)
224+
: RI_2D_Comm::split_m2D_ktoR<Tdata>(
225+
ucell,
226+
*this->exx_ptr->p_kv,
227+
this->mix_DMk_2D.get_DMk_k_out(),
228+
*dm_in.get_paraV_pointer(),
229+
PARAM.inp.nspin,
230+
this->exx_spacegroup_symmetry);
214231

215232
if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace)
216233
{ this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); }
@@ -240,11 +257,10 @@ void Exx_LRI_Interface<T, Tdata>::exx_hamilt2density(elecstate::ElecState& elec,
240257
{
241258
if (GlobalV::MY_RANK == 0)
242259
{
243-
try { GlobalC::restart.load_disk("Eexx", 0, 1, &this->exx_ptr->Eexx); }
260+
try
261+
{ GlobalC::restart.load_disk("Eexx", 0, 1, &this->exx_ptr->Eexx); }
244262
catch (const std::exception& e)
245-
{
246-
std::cout << "WARNING: Cannot read Eexx from disk, the energy of the 1st loop will be wrong, sbut it does not influence the subsequent loops." << std::endl;
247-
}
263+
{ std::cout << "WARNING: Cannot read Eexx from disk, the energy of the 1st loop will be wrong, sbut it does not influence the subsequent loops." << std::endl; }
248264
}
249265
Parallel_Common::bcast_double(this->exx_ptr->Eexx);
250266
this->exx_ptr->Eexx /= GlobalC::exx_info.info_global.hybrid_alpha;

0 commit comments

Comments
 (0)