Skip to content

Commit 4d4029a

Browse files
committed
refactor: create class for coordination number calculation
1 parent ad12c42 commit 4d4029a

File tree

6 files changed

+145
-81
lines changed

6 files changed

+145
-81
lines changed

include/dftd_ncoord.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,79 @@
2828

2929
namespace dftd4 {
3030

31+
32+
class NCoordBase
33+
{
34+
public:
35+
TVector<double> cn;
36+
TMatrix<double> dcndr;
37+
static const double rad[];
38+
double kcn; // Steepness of counting function
39+
double norm_exp;
40+
double cutoff;
41+
// Get the coordination number
42+
int get_ncoord( // with ghost atoms
43+
const TMolecule&,
44+
const TMatrix<double>&,
45+
const double,
46+
TVector<double>&,
47+
TMatrix<double>&,
48+
bool);
49+
int get_ncoord( // without ghost atoms
50+
const TMolecule&,
51+
const TIVector&,
52+
const TMatrix<double>&,
53+
bool);
54+
// Calculate the coordination number using the virtual counting function
55+
int ncoord_base(
56+
const TMolecule&,
57+
const TIVector&,
58+
const TMatrix<double>&);
59+
// Calculate the derivative of the coordination number
60+
int dr_ncoord_base(
61+
const TMolecule&,
62+
const TIVector&,
63+
const TMatrix<double>&);
64+
// Get the DFT-D4 coordination number
65+
int get_ncoord_d4(
66+
const TMolecule&,
67+
const TMatrix<double>&,
68+
bool);
69+
int get_ncoord_d4(
70+
const TMolecule&,
71+
const TIVector&,
72+
const TMatrix<double>&,
73+
bool);
74+
int ncoord_d4(
75+
const TMolecule&,
76+
const TIVector&,
77+
const TMatrix<double>&);
78+
int dncoord_d4(
79+
const TMolecule&,
80+
const TIVector&,
81+
const TMatrix<double>&);
82+
83+
// Counting function
84+
virtual double count_fct(double) const = 0;
85+
// Derivative of the counting function
86+
virtual double dr_count_fct(double) const = 0;
87+
// Constructor
88+
NCoordBase(double, double, double);
89+
// Virtual destructor
90+
virtual ~NCoordBase() = default;
91+
};
92+
93+
class NCoordErf : public NCoordBase {
94+
public:
95+
// erf() based counting function
96+
double count_fct(double) const override;
97+
// derivative of the erf() based counting function
98+
double dr_count_fct(double) const override;
99+
// Constructor
100+
NCoordErf(double optional_kcn = 7.5, double optional_norm_exp = 1.0, double optional_cutoff = 25.0)
101+
: NCoordBase(optional_kcn, optional_norm_exp, optional_cutoff){}
102+
};
103+
31104
/**
32105
* Calculate all distance pairs and store in matrix.
33106
*

src/dftd_dispersion.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,16 @@ int get_dispersion(
102102
dist.NewMatrix(nat, nat);
103103
calc_distances(mol, realIdx, dist);
104104

105-
TVector<double> cn; // D4 coordination number
106105
TVector<double> q; // partial charges from EEQ model
107-
TMatrix<double> dcndr; // derivative of D4-CN
108106
TMatrix<double> dqdr; // derivative of partial charges
109107
TVector<double> gradient; // derivative of dispersion energy
108+
NCoordErf ncoord_erf;
110109
multicharge::EEQModel chrg_model; // Charge model
111110

112-
cn.NewVector(nat);
111+
ncoord_erf.cn.NewVector(nat);
113112
q.NewVector(nat);
114113
if (lgrad) {
115-
dcndr.NewMatrix(3 * nat, nat);
114+
ncoord_erf.dcndr.NewMatrix(3 * nat, nat);
116115
dqdr.NewMatrix(3 * nat, nat);
117116
gradient.NewVector(3 * nat);
118117
}
@@ -122,7 +121,7 @@ int get_dispersion(
122121
if (info != EXIT_SUCCESS) return info;
123122

124123
// get the D4 coordination number
125-
info = get_ncoord_d4(mol, realIdx, dist, cutoff.cn, cn, dcndr, lgrad);
124+
info = ncoord_erf.get_ncoord_d4(mol, realIdx, dist, lgrad);
126125
if (info != EXIT_SUCCESS) return info;
127126

128127
// maximum number of reference systems
@@ -145,7 +144,7 @@ int get_dispersion(
145144
dgwdq.NewMatrix(mref, nat);
146145
}
147146
info = d4.weight_references(
148-
mol, realIdx, cn, q, refq, gwvec, dgwdcn, dgwdq, lgrad
147+
mol, realIdx, ncoord_erf.cn, q, refq, gwvec, dgwdcn, dgwdq, lgrad
149148
);
150149
if (info != EXIT_SUCCESS) return info;
151150

@@ -212,11 +211,11 @@ int get_dispersion(
212211
dgwdq.NewMatrix(mref, nat);
213212
}
214213
info = d4.weight_references(
215-
mol, realIdx, cn, q, refq, gwvec, dgwdcn, dgwdq, lgrad
214+
mol, realIdx, ncoord_erf.cn, q, refq, gwvec, dgwdcn, dgwdq, lgrad
216215
);
217216
if (info != EXIT_SUCCESS) return info;
218217

219-
cn.Delete();
218+
ncoord_erf.cn.Delete();
220219
q.Delete();
221220
refq.Delete();
222221

@@ -253,7 +252,7 @@ int get_dispersion(
253252
);
254253
if (info != EXIT_SUCCESS) return info;
255254
} else {
256-
cn.Delete();
255+
ncoord_erf.cn.Delete();
257256
q.Delete();
258257
refq.Delete();
259258
gwvec.Delete();
@@ -266,9 +265,9 @@ int get_dispersion(
266265
dc6dcn.DelMat();
267266
dc6dq.DelMat();
268267

269-
if (lgrad) { BLAS_Add_Mat_x_Vec(gradient, dcndr, dEdcn, false, 1.0); }
268+
if (lgrad) { BLAS_Add_Mat_x_Vec(gradient, ncoord_erf.dcndr, dEdcn, false, 1.0); }
270269

271-
dcndr.DelMat();
270+
ncoord_erf.dcndr.DelMat();
272271
dEdcn.DelVec();
273272
dEdq.DelVec();
274273

src/dftd_eeq.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,23 +72,22 @@ int ChargeModel::get_charges(
7272
bool lverbose{false};
7373
int nat = realIdx.Max() + 1;
7474

75-
TVector<double> cn; // EEQ cordination number
76-
TMatrix<double> dcndr; // Derivative of EEQ-CN
75+
dftd4::NCoordErf ncoord_erf;
7776

78-
cn.NewVec(nat);
79-
if (lgrad) dcndr.NewMat(nat, 3 * nat);
77+
ncoord_erf.cn.NewVec(nat);
78+
if (lgrad) ncoord_erf.dcndr.NewMat(nat, 3 * nat);
8079

8180
// get the EEQ coordination number
82-
info = get_ncoord_erf(mol, realIdx, dist, cutoff, cn, dcndr, lgrad);
81+
info = ncoord_erf.get_ncoord(mol, realIdx, dist, lgrad);
8382
if (info != EXIT_SUCCESS) return info;
8483

8584
// corresponds to model%solve in Fortran
8685
info =
87-
eeq_chrgeq(mol, realIdx, dist, charge, cn, q, dcndr, dqdr, lgrad, lverbose);
86+
eeq_chrgeq(mol, realIdx, dist, charge, ncoord_erf.cn, q, ncoord_erf.dcndr, dqdr, lgrad, lverbose);
8887
if (info != EXIT_SUCCESS) return info;
8988

90-
dcndr.DelMat();
91-
cn.DelVec();
89+
ncoord_erf.dcndr.DelMat();
90+
ncoord_erf.cn.DelVec();
9291

9392
return EXIT_SUCCESS;
9493
};

0 commit comments

Comments
 (0)