|
10 | 10 | //////////////////////////////////////////////////////////////////////////////////////
|
11 | 11 |
|
12 | 12 |
|
| 13 | +#include "OptimizableObject.h" |
13 | 14 | #include "catch.hpp"
|
14 | 15 |
|
15 | 16 | #include "OhmmsData/Libxml2Doc.h"
|
@@ -104,4 +105,184 @@ TEST_CASE("kspace jastrow", "[wavefunction]")
|
104 | 105 | double logpsi_real = std::real(jas->evaluateLog(elec_, elec_.G, elec_.L));
|
105 | 106 | CHECK(logpsi_real == Approx(-4.4088303951)); // !!!! value not checked
|
106 | 107 | }
|
| 108 | + |
| 109 | +TEST_CASE("kspace jastrow derivatives", "[wavefunction]") |
| 110 | +{ |
| 111 | + Communicate* c = OHMMS::Controller; |
| 112 | + |
| 113 | + // initialize simulationcell for kvectors |
| 114 | + const char* xmltext = R"(<tmp> |
| 115 | + <simulationcell> |
| 116 | + <parameter name="lattice" units="bohr"> |
| 117 | + 1.00000000 0.00000000 0.00000000 |
| 118 | + -0.27834620 1.98053614 0.00000000 |
| 119 | + 0.31358539 1.32438627 2.67351177 |
| 120 | + </parameter> |
| 121 | + <parameter name="bconds"> |
| 122 | + p p p |
| 123 | + </parameter> |
| 124 | + <parameter name="LR_dim_cutoff"> 15 </parameter> |
| 125 | + </simulationcell> |
| 126 | +</tmp>)"; |
| 127 | + Libxml2Document doc; |
| 128 | + bool okay = doc.parseFromString(xmltext); |
| 129 | + REQUIRE(okay); |
| 130 | + |
| 131 | + xmlNodePtr root = doc.getRoot(); |
| 132 | + xmlNodePtr part1 = xmlFirstElementChild(root); |
| 133 | + |
| 134 | + // read lattice |
| 135 | + Lattice lattice; |
| 136 | + LatticeParser lp(lattice); |
| 137 | + lp.put(part1); |
| 138 | + lattice.print(app_log(), 0); |
| 139 | + |
| 140 | + const SimulationCell simulation_cell(lattice); |
| 141 | + ParticleSet ions_(simulation_cell); |
| 142 | + ParticleSet elec_(simulation_cell); |
| 143 | + |
| 144 | + ions_.setName("ion"); |
| 145 | + ions_.create({1}); |
| 146 | + ions_.R[0] = {0.7, 0.8, 0.9}; |
| 147 | + elec_.setName("elec"); |
| 148 | + elec_.create({2, 0}); |
| 149 | + elec_.R[0] = {0.1, 0.2, 0.3}; |
| 150 | + elec_.R[1] = {0.4, 0.5, 0.6}; |
| 151 | + SpeciesSet& tspecies = elec_.getSpeciesSet(); |
| 152 | + int upIdx = tspecies.addSpecies("u"); |
| 153 | + int downIdx = tspecies.addSpecies("d"); |
| 154 | + int chargeIdx = tspecies.addAttribute("charge"); |
| 155 | + tspecies(chargeIdx, upIdx) = -1; |
| 156 | + tspecies(chargeIdx, downIdx) = -1; |
| 157 | + // initialize SK |
| 158 | + elec_.createSK(); |
| 159 | + |
| 160 | + |
| 161 | + const char* jk1input = R"(<tmp> |
| 162 | +<jastrow name="Jk" type="kSpace" source="ion0"> |
| 163 | + <correlation kc="5.0" type="One-Body" symmetry="crystal"> |
| 164 | + <coefficients id="cG1" type="Array"> |
| 165 | +0.8137632591137914 0.44470998182408406 0.5401857355121155 0.25752329923480577 0.8232477489081528 0.44767385249164526 0.6464929258779923 0.4929171993220475 0.7846576485476944 0.662660459182017 |
| 166 | + </coefficients> |
| 167 | + </correlation> |
| 168 | +</jastrow> |
| 169 | +</tmp> |
| 170 | +)"; |
| 171 | + okay = doc.parseFromString(jk1input); |
| 172 | + REQUIRE(okay); |
| 173 | + |
| 174 | + root = doc.getRoot(); |
| 175 | + xmlNodePtr jas1 = xmlFirstElementChild(root); |
| 176 | + |
| 177 | + kSpaceJastrowBuilder jastrow(c, elec_, ions_); |
| 178 | + std::unique_ptr<WaveFunctionComponent> wfc(jastrow.buildComponent(jas1)); |
| 179 | + kSpaceJastrow* jas = dynamic_cast<kSpaceJastrow*>(wfc.get()); |
| 180 | + |
| 181 | + // update all distance tables |
| 182 | + elec_.update(); |
| 183 | + |
| 184 | + double logpsi_real = std::real(jas->evaluateLog(elec_, elec_.G, elec_.L)); |
| 185 | + CHECK(logpsi_real == Approx(0.7137163755813973)); |
| 186 | + CHECK(std::real(elec_.G[0][0]) == Approx(0.0)); |
| 187 | + CHECK(std::real(elec_.G[0][1]) == Approx(-0.35067397)); |
| 188 | + CHECK(std::real(elec_.G[0][2]) == Approx(-1.18746358)); |
| 189 | + CHECK(std::real(elec_.G[1][0]) == Approx(0.0)); |
| 190 | + CHECK(std::real(elec_.G[1][1]) == Approx(-0.5971649)); |
| 191 | + CHECK(std::real(elec_.G[1][2]) == Approx(-1.30622405)); |
| 192 | + CHECK(std::real(elec_.L[0]) == Approx(-9.23735526)); |
| 193 | + CHECK(std::real(elec_.L[1]) == Approx(2.37396869)); |
| 194 | + |
| 195 | + opt_variables_type opt_vars; |
| 196 | + jas->checkInVariablesExclusive(opt_vars); |
| 197 | + opt_vars.resetIndex(); |
| 198 | + jas->checkOutVariables(opt_vars); |
| 199 | + jas->resetParametersExclusive(opt_vars); |
| 200 | + |
| 201 | + const int nopt = opt_vars.size(); |
| 202 | + Vector<ParticleSet::ValueType> dlogpsi(nopt); |
| 203 | + Vector<ParticleSet::ValueType> dhpsioverpsi(nopt); |
| 204 | + |
| 205 | + std::vector<double> refvals1 = {0.17404927, 0.30881786, 0.03441212, 0.31140348, 0.3374628, |
| 206 | + 0.14393708, -0.17959687, 0.12910817, -0.14896848, 0.2460836}; |
| 207 | + std::vector<double> refvals2 = {-0.433200292526692, 1.3469998192405797, -0.5970340626326731, 1.6549156762746264, |
| 208 | + 2.155972109951776, 0.7377064401570266, -2.33721875453341, 0.20791372534745278, |
| 209 | + -3.0523924081649056, 1.7868809500942717}; |
| 210 | + |
| 211 | + |
| 212 | + dlogpsi = 0; |
| 213 | + jas->evaluateDerivativesWF(elec_, opt_vars, dlogpsi); |
| 214 | + for (int i = 0; i < nopt; i++) |
| 215 | + CHECK(std::real(dlogpsi[i]) == Approx(refvals1[i])); |
| 216 | + |
| 217 | + dlogpsi = 0; |
| 218 | + jas->evaluateDerivatives(elec_, opt_vars, dlogpsi, dhpsioverpsi); |
| 219 | + for (int i = 0; i < nopt; i++) |
| 220 | + CHECK(std::real(dlogpsi[i]) == Approx(refvals1[i])); |
| 221 | + for (int i = 0; i < nopt; i++) |
| 222 | + CHECK(std::real(dhpsioverpsi[i]) == Approx(refvals2[i])); |
| 223 | + |
| 224 | + // Twobody check |
| 225 | + const char* jk2input = R"(<tmp> |
| 226 | +<jastrow name="Jk" type="kSpace" source="ion0"> |
| 227 | + <correlation kc="5.0" type="Two-Body" symmetry="crystal"> |
| 228 | + <coefficients id="cG2" type="Array"> |
| 229 | + 0.9531019536367156 0.18148850587408794 0.20772539137877666 0.9340655612889098 0.50544913283957 |
| 230 | + </coefficients> |
| 231 | + </correlation> |
| 232 | +</jastrow> |
| 233 | +</tmp> |
| 234 | +)"; |
| 235 | + okay = doc.parseFromString(jk2input); |
| 236 | + REQUIRE(okay); |
| 237 | + |
| 238 | + root = doc.getRoot(); |
| 239 | + xmlNodePtr jas2 = xmlFirstElementChild(root); |
| 240 | + |
| 241 | + kSpaceJastrowBuilder jastrow2(c, elec_, ions_); |
| 242 | + std::unique_ptr<WaveFunctionComponent> wfc2(jastrow2.buildComponent(jas2)); |
| 243 | + kSpaceJastrow* j2 = dynamic_cast<kSpaceJastrow*>(wfc2.get()); |
| 244 | + |
| 245 | + // update all distance tables |
| 246 | + elec_.update(); |
| 247 | + elec_.G = 0; |
| 248 | + elec_.L = 0; |
| 249 | + //reference values from python code in test directory |
| 250 | + logpsi_real = std::real(j2->evaluateLog(elec_, elec_.G, elec_.L)); |
| 251 | + CHECK(logpsi_real == Approx(1.3399793683)); |
| 252 | + CHECK(std::real(elec_.G[0][0]) == Approx(0.0)); |
| 253 | + CHECK(std::real(elec_.G[0][1]) == Approx(1.37913407)); |
| 254 | + CHECK(std::real(elec_.G[0][2]) == Approx(2.47457664)); |
| 255 | + CHECK(std::real(elec_.G[1][0]) == Approx(0.0)); |
| 256 | + CHECK(std::real(elec_.G[1][1]) == Approx(-1.37913407)); |
| 257 | + CHECK(std::real(elec_.G[1][2]) == Approx(-2.47457664)); |
| 258 | + CHECK(std::real(elec_.L[0]) == Approx(-1.13586493)); |
| 259 | + CHECK(std::real(elec_.L[1]) == Approx(-1.13586493)); |
| 260 | + |
| 261 | + opt_variables_type opt_vars2; |
| 262 | + j2->checkInVariablesExclusive(opt_vars2); |
| 263 | + opt_vars2.resetIndex(); |
| 264 | + j2->checkOutVariables(opt_vars2); |
| 265 | + j2->resetParametersExclusive(opt_vars2); |
| 266 | + |
| 267 | + const int nopt2 = opt_vars2.size(); |
| 268 | + Vector<ParticleSet::ValueType> dlogpsi2(nopt2); |
| 269 | + Vector<ParticleSet::ValueType> dhpsioverpsi2(nopt2); |
| 270 | + |
| 271 | + std::vector<double> refvals3 = {0.66537659, 0.51973641, 0.71270004, 0.25905169, 0.4381535}; |
| 272 | + std::vector<double> refvals4 = {-1.2583630528695267, -2.8959032092323866, 4.02906470009512, -11.046492939481567, |
| 273 | + -7.338195726624974}; |
| 274 | + |
| 275 | + dlogpsi2 = 0.0; |
| 276 | + j2->evaluateDerivativesWF(elec_, opt_vars2, dlogpsi2); |
| 277 | + for (int i = 0; i < nopt2; i++) |
| 278 | + CHECK(std::real(dlogpsi2[i]) == Approx(refvals3[i])); |
| 279 | + |
| 280 | + dlogpsi2 = 0.0; |
| 281 | + dhpsioverpsi2 = 0.0; |
| 282 | + j2->evaluateDerivatives(elec_, opt_vars2, dlogpsi2, dhpsioverpsi2); |
| 283 | + for (int i = 0; i < nopt2; i++) |
| 284 | + CHECK(std::real(dlogpsi2[i]) == Approx(refvals3[i])); |
| 285 | + for (int i = 0; i < nopt2; i++) |
| 286 | + CHECK(std::real(dhpsioverpsi2[i]) == Approx(refvals4[i])); |
| 287 | +} |
107 | 288 | } // namespace qmcplusplus
|
0 commit comments