Skip to content

add exx nscf file check #6289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ template <typename TK, typename TR>
OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
HContainer<TR>*hR_in,
const UnitCell& ucell_in,
const K_Vectors& kv_in,
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
const K_Vectors& kv_in,
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
Add_Hexx_Type add_hexx_type_in,
const int istep,
int* two_level_step_in,
const bool restart_in)
const bool restart_in)
: OperatorLCAO<TK, TR>(hsk_in, kv_in.kvec_d, hR_in),
ucell(ucell_in),
kv(kv_in),
Expand All @@ -105,42 +105,75 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,

if (PARAM.inp.calculation == "nscf" && GlobalC::exx_info.info_global.cal_exx)
{ // if nscf, read HexxR first and reallocate hR according to the read-in HexxR
const std::string file_name_exx = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
bool all_exist = true;
for (int is=0;is<PARAM.inp.nspin;++is)
auto file_name_list_csr = []() -> std::vector<std::string>
{
std::ifstream ifs(file_name_exx + "_" + std::to_string(is) + ".csr");
if (!ifs) { all_exist = false; break; }
}
if (all_exist)
std::vector<std::string> file_name_list;
for (int irank=0; irank<PARAM.globalv.nproc; ++irank) {
for (int is=0;is<PARAM.inp.nspin;++is) {
file_name_list.push_back( PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(irank) + "_" + std::to_string(is) + ".csr" );
} }
return file_name_list;
};
auto file_name_list_cereal = []() -> std::vector<std::string>
{
std::vector<std::string> file_name_list;
for (int irank=0; irank<PARAM.globalv.nproc; ++irank)
{ file_name_list.push_back( "HexxR_" + std::to_string(irank) ); }
return file_name_list;
};
auto check_exist = [](const std::vector<std::string> &file_name_list) -> bool
{
for (const std::string &file_name : file_name_list)
{
std::ifstream ifs(file_name);
if (!ifs.is_open())
{ return false; }
}
return true;
};

std::cout<<" Attention: The number of MPI processes must be strictly identical between SCF and NSCF when computing exact-exchange."<<std::endl;
if (check_exist(file_name_list_csr()))
{
const std::string file_name_exx_csr = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(PARAM.globalv.myrank);
// Read HexxR in CSR format
if (GlobalC::exx_info.info_ri.real_number)
{
ModuleIO::read_Hexxs_csr(file_name_exx, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
ModuleIO::read_Hexxs_csr(file_name_exx_csr, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
if (this->add_hexx_type == Add_Hexx_Type::R)
{ reallocate_hcontainer(*Hexxd, this->hR); }
}
else
{
ModuleIO::read_Hexxs_csr(file_name_exx, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
ModuleIO::read_Hexxs_csr(file_name_exx_csr, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
if (this->add_hexx_type == Add_Hexx_Type::R)
{ reallocate_hcontainer(*Hexxc, this->hR); }
}
}
else
else if (check_exist(file_name_list_cereal()))
{
// Read HexxR in binary format (old version)
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(PARAM.globalv.myrank);
std::ifstream ifs(file_name_exx_cereal, std::ios::binary);
if (!ifs)
{ ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file < " + file_name_exx_cereal + " >."); }
if (GlobalC::exx_info.info_ri.real_number)
{
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxd);
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
if (this->add_hexx_type == Add_Hexx_Type::R)
{ reallocate_hcontainer(*Hexxd, this->hR); }
}
else
{
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxc);
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
if (this->add_hexx_type == Add_Hexx_Type::R)
{ reallocate_hcontainer(*Hexxc, this->hR); }
}
}
else
{
ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file in " + PARAM.globalv.global_readin_dir);
}
this->use_cell_nearest = false;
}
else
Expand Down Expand Up @@ -207,7 +240,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
else if (this->add_hexx_type == Add_Hexx_Type::R)
{
// read in Hexx(R)
const std::string restart_HR_path = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
const std::string restart_HR_path = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(PARAM.globalv.myrank);
bool all_exist = true;
for (int is = 0; is < PARAM.inp.nspin; ++is)
{
Expand All @@ -227,7 +260,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
else
{
// Read HexxR in binary format (old version)
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(GlobalV::MY_RANK);
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(PARAM.globalv.myrank);
if (GlobalC::exx_info.info_ri.real_number) {
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxd);
}
Expand Down
16 changes: 14 additions & 2 deletions source/module_io/restart_exx_csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ namespace ModuleIO
{
const std::vector<int>& R = csr.getRCoordinate(iR);
TC dR({ R[0], R[1], R[2] });
Hexxs[is][iat1][{iat2, dR}] = RI::Tensor<Tdata>({ static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw), static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw) });
Hexxs[is][iat1][{iat2, dR}] = RI::Tensor<Tdata>(
{
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw),
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw)
}
);
}
}
}
Expand All @@ -44,7 +49,12 @@ namespace ModuleIO
const int& npol = ucell.get_npol();
const int& i = ijv.first.first * npol;
const int& j = ijv.first.second * npol;
Hexxs.at(is).at(ucell.iwt2iat[i]).at({ ucell.iwt2iat[j], { R[0], R[1], R[2] } })(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
Hexxs.at(is).at(ucell.iwt2iat[i]).at(
{
ucell.iwt2iat[j],
{ R[0], R[1], R[2] }
}
)(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
}
}
}
Expand All @@ -57,6 +67,8 @@ namespace ModuleIO
ModuleBase::TITLE("Exx_LRI", "read_Hexxs_cereal");
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
std::ifstream ifs(file_name, std::ios::binary);
if(!ifs.is_open())
{ ModuleBase::WARNING_QUIT("read_Hexxs_cereal", file_name+" not found."); }
cereal::BinaryInputArchive iar(ifs);
iar(Hexxs);
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
Expand Down
26 changes: 13 additions & 13 deletions source/module_ri/ABFs_Construct-PCA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ namespace PCA
dsyev_(&jobz, &uplo, &nr, a.ptr(), &nc, w, work.data(), &lwork, &info);
}

RI::Tensor<double> get_sub_matrix(
RI::Tensor<double> get_sub_matrix(
const RI::Tensor<double> & m, // size: (lcaos, lcaos, abfs)
const std::size_t & T,
const std::size_t & L,
const ModuleBase::Element_Basis_Index::Range & range,
const ModuleBase::Element_Basis_Index::IndexLNM & index )
{
ModuleBase::TITLE("ABFs_Construct::PCA::get_sub_matrix");
ModuleBase::TITLE("ABFs_Construct::PCA::get_sub_matrix");
assert(m.shape.size() == 3);
RI::Tensor<double> m_sub({ m.shape[0], m.shape[1], range[T][L].N });
for (std::size_t ir=0; ir!=m.shape[0]; ++ir) {
Expand Down Expand Up @@ -74,12 +74,12 @@ namespace PCA
std::vector<std::vector<std::pair<std::vector<double>, RI::Tensor<double>>>> cal_PCA(
const UnitCell &ucell,
const LCAO_Orbitals& orb,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &lcaos,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &lcaos,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &abfs,
const double kmesh_times )
{
ModuleBase::TITLE("ABFs_Construct::PCA::cal_PCA");

const ModuleBase::Element_Basis_Index::Range
range_lcaos = Exx_Abfs::Abfs_Index::construct_range( lcaos );
const ModuleBase::Element_Basis_Index::IndexLNM
Expand Down Expand Up @@ -107,27 +107,27 @@ namespace PCA
m_abfslcaos_lcaos.init_radial_table(delta_R);

GlobalC::exx_info.info_ri.abfs_Lmax = Lmax_bak;

std::vector<std::vector<std::pair<std::vector<double>,RI::Tensor<double>>>> eig(abfs.size());
for( std::size_t T=0; T!=abfs.size(); ++T )
{
const RI::Tensor<double> A = m_abfslcaos_lcaos.cal_overlap_matrix<double>(
T,
T,
const RI::Tensor<double> A = m_abfslcaos_lcaos.cal_overlap_matrix<double>(
T,
T,
ModuleBase::Vector3<double>{0,0,0},
ModuleBase::Vector3<double>{0,0,0},
ModuleBase::Vector3<double>{0,0,0},
index_abfs,
index_abfs,
index_lcaos,
index_lcaos,
Matrix_Orbs21::Matrix_Order::A2BA1);

eig[T].resize(abfs[T].size());
for( std::size_t L=0; L!=abfs[T].size(); ++L )
{
const RI::Tensor<double> A_sub = get_sub_matrix( A, T, L, range_abfs, index_abfs );
RI::Tensor<double> mm = A_sub.transpose() * A_sub;
std::vector<double> eig_value(mm.shape[0]);

int info=1;

tensor_dsyev('V', 'L', mm, eig_value.data(), info);
Expand Down Expand Up @@ -158,7 +158,7 @@ namespace PCA
eig[T][L] = std::make_pair( eig_value, mm );
}
}

return eig;
}

Expand Down
4 changes: 2 additions & 2 deletions source/module_ri/ABFs_Construct-PCA.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ namespace ABFs_Construct
{
namespace PCA
{
extern std::vector<std::vector<std::pair<std::vector<double>,RI::Tensor<double>>>> cal_PCA(
extern std::vector<std::vector<std::pair<std::vector<double>,RI::Tensor<double>>>> cal_PCA(
const UnitCell& ucell,
const LCAO_Orbitals &orb,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &lcaos,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &lcaos,
const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &abfs, // abfs must be orthonormal
const double kmesh_times );
}
Expand Down
6 changes: 3 additions & 3 deletions source/module_ri/Exx_LRI_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ class Exx_LRI_Interface
}
Exx_LRI_Interface() = delete;

/// read and write Hexxs using cereal
void write_Hexxs_cereal(const std::string& file_name) const;
void read_Hexxs_cereal(const std::string& file_name);
///// read and write Hexxs using cereal
//void write_Hexxs_cereal(const std::string& file_name) const;
//void read_Hexxs_cereal(const std::string& file_name);

std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; }
double &get_Eexx() const { return this->exx_ptr->Eexx; }
Expand Down
36 changes: 26 additions & 10 deletions source/module_ri/Exx_LRI_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <stdexcept>
#include <string>

/*
template<typename T, typename Tdata>
void Exx_LRI_Interface<T, Tdata>::write_Hexxs_cereal(const std::string& file_name) const
{
Expand All @@ -32,13 +33,17 @@ void Exx_LRI_Interface<T, Tdata>::write_Hexxs_cereal(const std::string& file_nam
template<typename T, typename Tdata>
void Exx_LRI_Interface<T, Tdata>::read_Hexxs_cereal(const std::string& file_name)
{
ModuleBase::TITLE("Exx_LRI", "read_Hexxs_cereal");
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
std::ifstream ifs(file_name + "_" + std::to_string(GlobalV::MY_RANK), std::ofstream::binary);
ModuleBase::TITLE("Exx_LRI_Interface", "read_Hexxs_cereal");
ModuleBase::timer::tick("Exx_LRI_Interface", "read_Hexxs_cereal");
const std::string file_name_rank = file_name + "_" + std::to_string(GlobalV::MY_RANK);
std::ifstream ifs(file_name_rank, std::ofstream::binary);
if(!ifs.is_open())
{ ModuleBase::WARNING_QUIT("Exx_LRI_Interface", file_name_rank+" not found."); }
cereal::BinaryInputArchive iar(ifs);
iar(this->exx_ptr->Hexxs);
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
}
*/

template<typename T, typename Tdata>
void Exx_LRI_Interface<T, Tdata>::init(const MPI_Comm &mpi_comm,
Expand Down Expand Up @@ -115,7 +120,8 @@ void Exx_LRI_Interface<T, Tdata>::exx_before_all_runners(const K_Vectors& kv, co
if (this->exx_spacegroup_symmetry)
{
const std::array<int, 3>& period = RI_Util::get_Born_vonKarmen_period(kv);
this->symrot_.find_irreducible_sector(ucell.symm, ucell.atoms, ucell.st,
this->symrot_.find_irreducible_sector(
ucell.symm, ucell.atoms, ucell.st,
RI_Util::get_Born_von_Karmen_cells(period), period, ucell.lat);
// this->symrot_.set_Cs_rotation(this->exx_ptr->get_abfs_nchis());
this->symrot_.cal_Ms(kv, ucell, pv);
Expand Down Expand Up @@ -209,8 +215,19 @@ void Exx_LRI_Interface<T, Tdata>::exx_eachiterinit(const int istep,
{ this->mix_DMk_2D.mix(dm_in.get_DMK_vector(), flag_restart); }
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>
Ds = PARAM.globalv.gamma_only_local
? RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_gamma_out(), *dm_in.get_paraV_pointer(), PARAM.inp.nspin)
: RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_k_out(), *dm_in.get_paraV_pointer(), PARAM.inp.nspin, this->exx_spacegroup_symmetry);
? RI_2D_Comm::split_m2D_ktoR<Tdata>(
ucell,
*this->exx_ptr->p_kv,
this->mix_DMk_2D.get_DMk_gamma_out(),
*dm_in.get_paraV_pointer(),
PARAM.inp.nspin)
: RI_2D_Comm::split_m2D_ktoR<Tdata>(
ucell,
*this->exx_ptr->p_kv,
this->mix_DMk_2D.get_DMk_k_out(),
*dm_in.get_paraV_pointer(),
PARAM.inp.nspin,
this->exx_spacegroup_symmetry);

if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace)
{ this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); }
Expand Down Expand Up @@ -240,11 +257,10 @@ void Exx_LRI_Interface<T, Tdata>::exx_hamilt2density(elecstate::ElecState& elec,
{
if (GlobalV::MY_RANK == 0)
{
try { GlobalC::restart.load_disk("Eexx", 0, 1, &this->exx_ptr->Eexx); }
try
{ GlobalC::restart.load_disk("Eexx", 0, 1, &this->exx_ptr->Eexx); }
catch (const std::exception& e)
{
std::cout << "WARNING: Cannot read Eexx from disk, the energy of the 1st loop will be wrong, sbut it does not influence the subsequent loops." << std::endl;
}
{ std::cout << "WARNING: Cannot read Eexx from disk, the energy of the 1st loop will be wrong, sbut it does not influence the subsequent loops." << std::endl; }
}
Parallel_Common::bcast_double(this->exx_ptr->Eexx);
this->exx_ptr->Eexx /= GlobalC::exx_info.info_global.hybrid_alpha;
Expand Down
6 changes: 3 additions & 3 deletions source/module_ri/exx_abfs-abfs_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ ModuleBase::Element_Basis_Index::Range
range[T].resize( orb[T].size() );
for( size_t L=0; L!=range[T].size(); ++L )
{
range[T][L].N = orb[T][L].size();
range[T][L].M = 2*L+1;
}
range[T][L].N = orb[T][L].size();
range[T][L].M = 2*L+1;
}
}
return range;
}
2 changes: 1 addition & 1 deletion source/module_ri/exx_abfs-abfs_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class LCAO_Orbitals;
class Exx_Abfs::Abfs_Index
{
public:
static ModuleBase::Element_Basis_Index::Range construct_range( const LCAO_Orbitals &orb );
static ModuleBase::Element_Basis_Index::Range construct_range( const LCAO_Orbitals &orb );
static ModuleBase::Element_Basis_Index::Range construct_range( const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &orb );
};

Expand Down
Loading
Loading