Skip to content

Commit ede1437

Browse files
linpeizePeizeLin
andauthored
add check and update code format in exx (#6244)
* fix bug and update code format in exx * Fix bug in Exx_LRI_Interface. Change && to || * update exx in ESolver_KS_LCAO and FORCE_STRESS * update runtime check in Exx_LRI_Interface * move exx_lri_double from ESolver_KS_LCAO to Exx_LRI_Interface --------- Co-authored-by: linpz <[email protected]>
1 parent f7cb1d3 commit ede1437

File tree

12 files changed

+334
-260
lines changed

12 files changed

+334
-260
lines changed

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,11 @@ ESolver_KS_LCAO<TK, TR>::ESolver_KS_LCAO()
8787
// because some members like two_level_step are used outside if(cal_exx)
8888
if (GlobalC::exx_info.info_ri.real_number)
8989
{
90-
this->exx_lri_double = std::make_shared<Exx_LRI<double>>(GlobalC::exx_info.info_ri);
91-
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(exx_lri_double);
90+
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(GlobalC::exx_info.info_ri);
9291
}
9392
else
9493
{
95-
this->exx_lri_complex = std::make_shared<Exx_LRI<std::complex<double>>>(GlobalC::exx_info.info_ri);
96-
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(exx_lri_complex);
94+
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(GlobalC::exx_info.info_ri);
9795
}
9896
#endif
9997
}
@@ -183,12 +181,12 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
183181
// initialize 2-center radial tables for EXX-LRI
184182
if (GlobalC::exx_info.info_ri.real_number)
185183
{
186-
this->exx_lri_double->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
184+
this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
187185
this->exd->exx_before_all_runners(this->kv, ucell, this->pv);
188186
}
189187
else
190188
{
191-
this->exx_lri_complex->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
189+
this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
192190
this->exc->exx_before_all_runners(this->kv, ucell, this->pv);
193191
}
194192
}
@@ -327,8 +325,8 @@ void ESolver_KS_LCAO<TK, TR>::cal_force(UnitCell& ucell, ModuleBase::matrix& for
327325
this->pw_rho,
328326
this->solvent,
329327
#ifdef __EXX
330-
*this->exx_lri_double,
331-
*this->exx_lri_complex,
328+
*this->exd,
329+
*this->exc,
332330
#endif
333331
&ucell.symm);
334332

@@ -479,8 +477,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
479477
this->gd
480478
#ifdef __EXX
481479
,
482-
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
483-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
480+
this->exd ? &this->exd->get_Hexxs() : nullptr,
481+
this->exc ? &this->exc->get_Hexxs() : nullptr
484482
#endif
485483
);
486484
}
@@ -508,8 +506,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
508506
this->two_center_bundle_
509507
#ifdef __EXX
510508
,
511-
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
512-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
509+
this->exd ? &this->exd->get_Hexxs() : nullptr,
510+
this->exc ? &this->exc->get_Hexxs() : nullptr
513511
#endif
514512
);
515513
}

source/module_esolver/esolver_ks_lcao.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
9595
#ifdef __EXX
9696
std::shared_ptr<Exx_LRI_Interface<TK, double>> exd = nullptr;
9797
std::shared_ptr<Exx_LRI_Interface<TK, std::complex<double>>> exc = nullptr;
98-
std::shared_ptr<Exx_LRI<double>> exx_lri_double = nullptr;
99-
std::shared_ptr<Exx_LRI<std::complex<double>>> exx_lri_complex = nullptr;
10098
#endif
10199

102100
friend class LR::ESolver_LR<double, double>;

source/module_esolver/lcao_before_scf.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
198198
,
199199
istep,
200200
GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step,
201-
GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr,
202-
GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs
201+
GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr,
202+
GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs()
203203
#endif
204204
);
205205
}

source/module_esolver/lcao_others.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
204204
,
205205
istep,
206206
GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step,
207-
GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr,
208-
GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs
207+
GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr,
208+
GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs()
209209
#endif
210210
);
211211
}

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
5151
ModulePW::PW_Basis* rhopw,
5252
surchem& solvent,
5353
#ifdef __EXX
54-
Exx_LRI<double>& exx_lri_double,
55-
Exx_LRI<std::complex<double>>& exx_lri_complex,
54+
Exx_LRI_Interface<T, double>& exd,
55+
Exx_LRI_Interface<T, std::complex<double>>& exc,
5656
#endif
5757
ModuleSymmetry::Symmetry* symm)
5858
{
@@ -372,26 +372,26 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
372372
{
373373
if (GlobalC::exx_info.info_ri.real_number)
374374
{
375-
exx_lri_double.cal_exx_force(ucell.nat);
376-
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.force_exx;
375+
exd.cal_exx_force(ucell.nat);
376+
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_force();
377377
}
378378
else
379379
{
380-
exx_lri_complex.cal_exx_force(ucell.nat);
381-
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.force_exx;
380+
exc.cal_exx_force(ucell.nat);
381+
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_force();
382382
}
383383
}
384384
if (isstress)
385385
{
386386
if (GlobalC::exx_info.info_ri.real_number)
387387
{
388-
exx_lri_double.cal_exx_stress(ucell.omega, ucell.lat0);
389-
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.stress_exx;
388+
exd.cal_exx_stress(ucell.omega, ucell.lat0);
389+
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_stress();
390390
}
391391
else
392392
{
393-
exx_lri_complex.cal_exx_stress(ucell.omega, ucell.lat0);
394-
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.stress_exx;
393+
exc.cal_exx_stress(ucell.omega, ucell.lat0);
394+
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_stress();
395395
}
396396
}
397397
}

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "module_io/input_conv.h"
1212
#include "module_psi/psi.h"
1313
#ifdef __EXX
14-
#include "module_ri/Exx_LRI.h"
14+
#include "module_ri/Exx_LRI_interface.h"
1515
#endif
1616
#include "force_stress_arrays.h"
1717
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
@@ -50,8 +50,8 @@ class Force_Stress_LCAO
5050
ModulePW::PW_Basis* rhopw,
5151
surchem& solvent,
5252
#ifdef __EXX
53-
Exx_LRI<double>& exx_lri_double,
54-
Exx_LRI<std::complex<double>>& exx_lri_complex,
53+
Exx_LRI_Interface<T, double>& exd,
54+
Exx_LRI_Interface<T, std::complex<double>>& exc,
5555
#endif
5656
ModuleSymmetry::Symmetry* symm);
5757

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,10 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
233233
{
234234
// if the same kernel is calculated in the esolver_ks, move it
235235
std::string dft_functional = LR_Util::tolower(input.dft_functional);
236-
if (ks_sol.exx_lri_double && std::is_same<T, double>::value && xc_kernel == dft_functional) {
237-
this->move_exx_lri(ks_sol.exx_lri_double);
238-
} else if (ks_sol.exx_lri_complex && std::is_same<T, std::complex<double>>::value && xc_kernel == dft_functional) {
239-
this->move_exx_lri(ks_sol.exx_lri_complex);
236+
if (ks_sol.exd && std::is_same<T, double>::value && xc_kernel == dft_functional) {
237+
this->move_exx_lri(ks_sol.exd->exx_ptr);
238+
} else if (ks_sol.exc && std::is_same<T, std::complex<double>>::value && xc_kernel == dft_functional) {
239+
this->move_exx_lri(ks_sol.exc->exx_ptr);
240240
} else // construct C, V from scratch
241241
{
242242
// set ccp_type according to the xc_kernel

source/module_ri/Exx_LRI.h

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,21 @@
2121
#include "module_exx_symmetry/symmetry_rotation.h"
2222

2323
class Parallel_Orbitals;
24-
24+
2525
template<typename T, typename Tdata>
2626
class RPA_LRI;
2727

2828
template<typename T, typename Tdata>
2929
class Exx_LRI_Interface;
3030

31-
namespace LR
32-
{
33-
template<typename T, typename TR>
34-
class ESolver_LR;
31+
namespace LR
32+
{
33+
template<typename T, typename TR>
34+
class ESolver_LR;
3535

36-
template<typename T>
37-
class OperatorLREXX;
38-
}
36+
template<typename T>
37+
class OperatorLREXX;
38+
}
3939

4040
template<typename Tdata>
4141
class Exx_LRI
@@ -49,37 +49,39 @@ class Exx_LRI
4949
using TatomR = std::array<double,Ndim>; // tmp
5050

5151
public:
52-
Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {}
53-
Exx_LRI operator=(const Exx_LRI&) = delete;
54-
Exx_LRI operator=(Exx_LRI&&);
55-
56-
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
57-
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
58-
59-
void init(const MPI_Comm &mpi_comm_in,
60-
const UnitCell &ucell,
61-
const K_Vectors &kv_in,
62-
const LCAO_Orbitals& orb);
63-
void cal_exx_force(const int& nat);
64-
void cal_exx_stress(const double& omega, const double& lat0);
52+
Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {}
53+
Exx_LRI operator=(const Exx_LRI&) = delete;
54+
Exx_LRI operator=(Exx_LRI&&);
55+
56+
void init(
57+
const MPI_Comm &mpi_comm_in,
58+
const UnitCell &ucell,
59+
const K_Vectors &kv_in,
60+
const LCAO_Orbitals& orb);
6561
void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false);
66-
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
62+
void cal_exx_elec(
63+
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
6764
const UnitCell& ucell,
68-
const Parallel_Orbitals& pv,
69-
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
70-
std::vector<std::vector<int>> get_abfs_nchis() const;
65+
const Parallel_Orbitals& pv,
66+
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
67+
void cal_exx_force(const int& nat);
68+
void cal_exx_stress(const double& omega, const double& lat0);
69+
70+
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
71+
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
72+
//std::vector<std::vector<int>> get_abfs_nchis() const;
7173

7274
std::vector< std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> Hexxs;
73-
double Eexx;
75+
double Eexx;
7476
ModuleBase::matrix force_exx;
7577
ModuleBase::matrix stress_exx;
76-
78+
7779

7880
private:
7981
const Exx_Info::Exx_Info_RI &info;
8082
MPI_Comm mpi_comm;
8183
const K_Vectors *p_kv = nullptr;
82-
std::vector<double> orb_cutoff_;
84+
std::vector<double> orb_cutoff_;
8385

8486
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> lcaos;
8587
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> abfs;
@@ -89,16 +91,16 @@ class Exx_LRI
8991
RI::Exx<TA,Tcell,Ndim,Tdata> exx_lri;
9092

9193
void post_process_Hexx( std::map<TA, std::map<TAC, RI::Tensor<Tdata>>> &Hexxs_io ) const;
92-
double post_process_Eexx(const double& Eexx_in) const;
94+
double post_process_Eexx(const double& Eexx_in) const;
9395

9496
friend class RPA_LRI<double, Tdata>;
9597
friend class RPA_LRI<std::complex<double>, Tdata>;
9698
friend class Exx_LRI_Interface<double, Tdata>;
9799
friend class Exx_LRI_Interface<std::complex<double>, Tdata>;
98-
friend class LR::ESolver_LR<double, double>;
99-
friend class LR::ESolver_LR<std::complex<double>, double>;
100-
friend class LR::OperatorLREXX<double>;
101-
friend class LR::OperatorLREXX<std::complex<double>>;
100+
friend class LR::ESolver_LR<double, double>;
101+
friend class LR::ESolver_LR<std::complex<double>, double>;
102+
friend class LR::OperatorLREXX<double>;
103+
friend class LR::OperatorLREXX<std::complex<double>>;
102104
};
103105

104106
#include "Exx_LRI.hpp"

source/module_ri/Exx_LRI.hpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
#include <string>
2727

2828
template<typename Tdata>
29-
void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in,
29+
void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in,
3030
const UnitCell &ucell,
31-
const K_Vectors &kv_in,
31+
const K_Vectors &kv_in,
3232
const LCAO_Orbitals& orb)
3333
{
3434
ModuleBase::TITLE("Exx_LRI","init");
@@ -130,7 +130,7 @@ void Exx_LRI<Tdata>::cal_exx_ions(const UnitCell& ucell,
130130
this->exx_lri.set_parallel(this->mpi_comm, atoms_pos, latvec, period);
131131

132132
// std::max(3) for gamma_only, list_A2 should contain cell {-1,0,1}. In the future distribute will be neighbour.
133-
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
133+
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
134134
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA,std::array<Tcell,Ndim>>>>>
135135
list_As_Vs = RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Vs, 2, false);
136136

@@ -237,7 +237,7 @@ void Exx_LRI<Tdata>::cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, R
237237
}
238238
this->Eexx = post_process_Eexx(this->Eexx);
239239
this->exx_lri.set_symmetry(false, {});
240-
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");
240+
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");
241241
}
242242

243243
template<typename Tdata>
@@ -283,11 +283,6 @@ void Exx_LRI<Tdata>::cal_exx_force(const int& nat)
283283
ModuleBase::TITLE("Exx_LRI","cal_exx_force");
284284
ModuleBase::timer::tick("Exx_LRI", "cal_exx_force");
285285

286-
if (!this->exx_lri.flag_finish.D)
287-
{
288-
ModuleBase::WARNING_QUIT("Force_Stress_LCAO", "Cannot calculate EXX force when the first PBE loop is not converged.");
289-
}
290-
291286
this->force_exx.create(nat, Ndim);
292287
for(int is=0; is<PARAM.inp.nspin; ++is)
293288
{
@@ -328,6 +323,7 @@ void Exx_LRI<Tdata>::cal_exx_stress(const double& omega, const double& lat0)
328323
ModuleBase::timer::tick("Exx_LRI", "cal_exx_stress");
329324
}
330325

326+
/*
331327
template<typename Tdata>
332328
std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
333329
{
@@ -341,5 +337,6 @@ std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
341337
}
342338
return abfs_nchis;
343339
}
340+
*/
344341

345342
#endif

0 commit comments

Comments
 (0)