Skip to content

Commit 864c22d

Browse files
authored
Merge pull request #4819 from camelto2/change_VariableSet_to_RealType
Convert VariableSet to RealType only
2 parents d27cab8 + cf0c220 commit 864c22d

10 files changed

+43
-67
lines changed

src/Optimize/OptimizeBase.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class CostFunctionBase
4242

4343
virtual int getNumParams() const = 0;
4444

45-
virtual Return_t& Params(int i) = 0;
45+
virtual Return_rt& Params(int i) = 0;
4646

4747
virtual Return_t Params(int i) const = 0;
4848

src/QMCDrivers/WFOpt/QMCCostFunctionBase.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ void QMCCostFunctionBase::reportParametersH5()
251251
if (!myComm->rank())
252252
{
253253
int ci_size = 0;
254-
std::vector<opt_variables_type::value_type> CIcoeff;
254+
std::vector<opt_variables_type::real_type> CIcoeff;
255255
for (int i = 0; i < OptVariables.size(); i++)
256256
{
257257
std::array<char, 128> Coeff;
@@ -271,7 +271,7 @@ void QMCCostFunctionBase::reportParametersH5()
271271
if (ci_size > 0)
272272
{
273273
CI_Opt = true;
274-
newh5 = RootName + ".opt.h5";
274+
newh5 = RootName + ".opt.h5";
275275
*msg_stream << " <Ci Coeffs saved in opt_coeffs=\"" << newh5 << "\">" << std::endl;
276276
hdf_archive hout;
277277
hout.create(newh5, H5F_ACC_TRUNC);

src/QMCDrivers/WFOpt/QMCCostFunctionBase.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class QMCCostFunctionBase : public CostFunctionBase<QMCTraits::RealType>, public
6868
SUM_INDEX_SIZE
6969
};
7070

71-
using EffectiveWeight = QMCTraits::QTFull::RealType;
71+
using EffectiveWeight = QMCTraits::QTFull::RealType;
7272
using FullPrecRealType = QMCTraits::FullPrecRealType;
7373
///Constructor.
7474
QMCCostFunctionBase(ParticleSet& w, TrialWaveFunction& psi, QMCHamiltonian& h, Communicate* comm);
@@ -85,7 +85,7 @@ class QMCCostFunctionBase : public CostFunctionBase<QMCTraits::RealType>, public
8585
///Path and name of the HDF5 prefix where CI coeffs are saved
8686
std::string newh5;
8787
///assign optimization parameter i
88-
Return_t& Params(int i) override { return OptVariables[i]; }
88+
Return_rt& Params(int i) override { return OptVariables[i]; }
8989
///return optimization parameter i
9090
Return_t Params(int i) const override { return OptVariables[i]; }
9191
int getType(int i) const { return OptVariables.getType(i); }

src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimize.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,7 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
10541054
<< std::endl;
10551055

10561056
// for each set of shifts, solve the linear method equations for the parameter update direction
1057-
std::vector<std::vector<ValueType>> parameterDirections;
1057+
std::vector<std::vector<RealType>> parameterDirections;
10581058
#ifdef HAVE_LMY_ENGINE
10591059
// call the engine to perform update
10601060
EngineObj->wfn_update_compute();
@@ -1070,7 +1070,7 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
10701070
if (true)
10711071
{
10721072
for (int j = 0; j < N; j++)
1073-
parameterDirections.at(i).at(j) = EngineObj->wfn_update().at(i * N + j);
1073+
parameterDirections.at(i).at(j) = std::real(EngineObj->wfn_update().at(i * N + j));
10741074
}
10751075
else
10761076
parameterDirections.at(i).at(0) = 1.0;
@@ -1080,7 +1080,7 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
10801080
optTarget->setneedGrads(false);
10811081

10821082
// prepare vectors to hold the initial and current parameters
1083-
std::vector<ValueType> currParams(numParams, 0.0);
1083+
std::vector<RealType> currParams(numParams, 0.0);
10841084

10851085
// initialize the initial and current parameter vectors
10861086
for (int i = 0; i < numParams; i++)
@@ -1168,8 +1168,8 @@ bool QMCFixedSampleLinearOptimize::adaptive_three_shift_run()
11681168
}
11691169

11701170
// find the best shift and the corresponding update direction
1171-
const std::vector<ValueType>* bestDirection = 0;
1172-
int best_shift = -1;
1171+
const std::vector<RealType>* bestDirection = 0;
1172+
int best_shift = -1;
11731173
for (int k = 0; k < costValues.size() && std::abs((initCost - initCost) / initCost) < max_relative_cost_change; k++)
11741174
if (is_best_cost(k, costValues, shifts_i, initCost) && good_update.at(k))
11751175
{
@@ -1440,7 +1440,7 @@ bool QMCFixedSampleLinearOptimize::descent_run()
14401440

14411441
for (int i = 0; i < results.size(); i++)
14421442
{
1443-
optTarget->Params(i) = results[i];
1443+
optTarget->Params(i) = std::real(results[i]);
14441444
}
14451445

14461446
//If descent is being run as part of a hybrid optimization, need to check if a vector of
@@ -1488,7 +1488,7 @@ bool QMCFixedSampleLinearOptimize::hybrid_run()
14881488
app_log() << "Update descent engine parameter values after Blocked LM step" << std::endl;
14891489
for (int i = 0; i < optTarget->getNumParams(); i++)
14901490
{
1491-
ValueType val = optTarget->Params(i);
1491+
RealType val = optTarget->Params(i);
14921492
descentEngineObj->setParamVal(i, val);
14931493
}
14941494
}

src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimizeBatched.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
13441344
<< std::endl;
13451345

13461346
// for each set of shifts, solve the linear method equations for the parameter update direction
1347-
std::vector<std::vector<ValueType>> parameterDirections;
1347+
std::vector<std::vector<RealType>> parameterDirections;
13481348
#ifdef HAVE_LMY_ENGINE
13491349
// call the engine to perform update
13501350
EngineObj->wfn_update_compute();
@@ -1360,7 +1360,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
13601360
if (true)
13611361
{
13621362
for (int j = 0; j < N; j++)
1363-
parameterDirections.at(i).at(j) = EngineObj->wfn_update().at(i * N + j);
1363+
parameterDirections.at(i).at(j) = std::real(EngineObj->wfn_update().at(i * N + j));
13641364
}
13651365
else
13661366
parameterDirections.at(i).at(0) = 1.0;
@@ -1370,7 +1370,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
13701370
//There will be updates of 0 for parameters that were filtered out before derivative ratios were used by the engine.
13711371
if (options_LMY_.filter_param)
13721372
{
1373-
std::vector<std::vector<ValueType>> tmpParameterDirections;
1373+
std::vector<std::vector<RealType>> tmpParameterDirections;
13741374
tmpParameterDirections.resize(shifts_i.size());
13751375

13761376
for (int i = 0; i < shifts_i.size(); i++)
@@ -1400,7 +1400,7 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
14001400
optTarget->setneedGrads(false);
14011401

14021402
// prepare vectors to hold the initial and current parameters
1403-
std::vector<ValueType> currParams(numParams, 0.0);
1403+
std::vector<RealType> currParams(numParams, 0.0);
14041404

14051405
// initialize the initial and current parameter vectors
14061406
for (int i = 0; i < numParams; i++)
@@ -1493,8 +1493,8 @@ bool QMCFixedSampleLinearOptimizeBatched::adaptive_three_shift_run()
14931493
}
14941494

14951495
// find the best shift and the corresponding update direction
1496-
const std::vector<ValueType>* bestDirection = 0;
1497-
int best_shift = -1;
1496+
const std::vector<RealType>* bestDirection = 0;
1497+
int best_shift = -1;
14981498
for (int k = 0;
14991499
k < costValues.size() && std::abs((initCost - initCost) / initCost) < options_LMY_.max_relative_cost_change; k++)
15001500
if (is_best_cost(k, costValues, shifts_i, initCost) && good_update.at(k))
@@ -1778,7 +1778,7 @@ bool QMCFixedSampleLinearOptimizeBatched::descent_run()
17781778

17791779
for (int i = 0; i < results.size(); i++)
17801780
{
1781-
optTarget->Params(i) = results[i];
1781+
optTarget->Params(i) = std::real(results[i]);
17821782
}
17831783

17841784
//If descent is being run as part of a hybrid optimization, need to check if a vector of

src/QMCDrivers/tests/test_DescentEngine.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ TEST_CASE("DescentEngine RMSprop update", "[drivers][descent]")
4444
optimize::VariableSet myVars;
4545

4646
//Two fake parameters are specified
47-
optimize::VariableSet::value_type first_param(1.0);
48-
optimize::VariableSet::value_type second_param(-2.0);
47+
optimize::VariableSet::real_type first_param(1.0);
48+
optimize::VariableSet::real_type second_param(-2.0);
4949

5050
myVars.insert("first", first_param);
5151
myVars.insert("second", second_param);

src/QMCWaveFunctions/Fermion/SlaterDetBuilder.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,10 @@ std::unique_ptr<MultiSlaterDetTableMethod> SlaterDetBuilder::createMSDFast(
561561
Optimizable = CI_Optimizable = true;
562562
if (csf_data_ptr)
563563
for (int i = 1; i < csf_data_ptr->coeffs.size(); i++)
564-
myVars.insert(CItags[i], csf_data_ptr->coeffs[i], true, optimize::LINEAR_P);
564+
myVars.insert(CItags[i], std::real(csf_data_ptr->coeffs[i]), true, optimize::LINEAR_P);
565565
else
566566
for (int i = 1; i < C.size(); i++)
567-
myVars.insert(CItags[i], C[i], true, optimize::LINEAR_P);
567+
myVars.insert(CItags[i], std::real(C[i]), true, optimize::LINEAR_P);
568568
}
569569
else
570570
{

src/QMCWaveFunctions/VariableSet.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void VariableSet::insertFrom(const VariableSet& input)
5454

5555
void VariableSet::insertFromSum(const VariableSet& input_1, const VariableSet& input_2)
5656
{
57-
value_type sum_val;
57+
real_type sum_val;
5858
std::string vname;
5959

6060
// Check that objects to be summed together have the same number of active
@@ -94,7 +94,7 @@ void VariableSet::insertFromSum(const VariableSet& input_1, const VariableSet& i
9494

9595
void VariableSet::insertFromDiff(const VariableSet& input_1, const VariableSet& input_2)
9696
{
97-
value_type diff_val;
97+
real_type diff_val;
9898
std::string vname;
9999

100100
// Check that objects to be subtracted have the same number of active
@@ -259,7 +259,7 @@ void VariableSet::writeToHDF(const std::string& filename, qmcplusplus::hdf_archi
259259

260260
hout.push("name_value_lists");
261261

262-
std::vector<qmcplusplus::QMCTraits::ValueType> param_values;
262+
std::vector<qmcplusplus::QMCTraits::RealType> param_values;
263263
std::vector<std::string> param_names;
264264
for (auto& pair_it : NameAndValue)
265265
{
@@ -292,7 +292,7 @@ void VariableSet::readFromHDF(const std::string& filename, qmcplusplus::hdf_arch
292292
throw std::runtime_error(err_msg.str());
293293
}
294294

295-
std::vector<qmcplusplus::QMCTraits::ValueType> param_values;
295+
std::vector<qmcplusplus::QMCTraits::RealType> param_values;
296296
hin.read(param_values, "parameter_values");
297297

298298
std::vector<std::string> param_names;

src/QMCWaveFunctions/VariableSet.h

+6-7
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ enum
4848
*/
4949
struct VariableSet
5050
{
51-
using value_type = qmcplusplus::QMCTraits::ValueType;
52-
using real_type = qmcplusplus::QMCTraits::RealType;
51+
using real_type = qmcplusplus::QMCTraits::RealType;
5352

54-
using pair_type = std::pair<std::string, value_type>;
53+
using pair_type = std::pair<std::string, real_type>;
5554
using index_pair_type = std::pair<std::string, int>;
5655
using iterator = std::vector<pair_type>::iterator;
5756
using const_iterator = std::vector<pair_type>::const_iterator;
@@ -131,7 +130,7 @@ struct VariableSet
131130
return -1;
132131
}
133132

134-
inline void insert(const std::string& vname, value_type v, bool enable = true, int type = OTHER_P)
133+
inline void insert(const std::string& vname, real_type v, bool enable = true, int type = OTHER_P)
135134
{
136135
iterator loc = find(vname);
137136
int ind_loc = loc - NameAndValue.begin();
@@ -169,7 +168,7 @@ struct VariableSet
169168

170169
/** equivalent to std::map<std::string,T>[string] operator
171170
*/
172-
inline value_type& operator[](const std::string& vname)
171+
inline real_type& operator[](const std::string& vname)
173172
{
174173
iterator loc = find(vname);
175174
if (loc == NameAndValue.end())
@@ -192,12 +191,12 @@ struct VariableSet
192191
/** return the i-th value
193192
* @param i index
194193
*/
195-
inline value_type operator[](int i) const { return NameAndValue[i].second; }
194+
inline real_type operator[](int i) const { return NameAndValue[i].second; }
196195

197196
/** assign the i-th value
198197
* @param i index
199198
*/
200-
inline value_type& operator[](int i) { return NameAndValue[i].second; }
199+
inline real_type& operator[](int i) { return NameAndValue[i].second; }
201200

202201
/** get the i-th parameter's type
203202
* @param i index

src/QMCWaveFunctions/tests/test_variable_set.cpp

+10-33
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212

1313
#include "catch.hpp"
14-
#include "complex_approx.hpp"
1514

1615
#include "VariableSet.h"
1716
#include "io/hdf/hdf_archive.h"
@@ -20,7 +19,6 @@
2019
#include <string>
2120

2221
using std::string;
23-
using qmcplusplus::ValueApprox;
2422

2523
namespace optimize
2624
{
@@ -37,7 +35,7 @@ TEST_CASE("VariableSet empty", "[optimize]")
3735
TEST_CASE("VariableSet one", "[optimize]")
3836
{
3937
VariableSet vs;
40-
VariableSet::value_type first_val(1.123456789);
38+
VariableSet::real_type first_val(1.123456789);
4139
vs.insert("first", first_val);
4240
std::vector<std::string> names{"first"};
4341
vs.activate(names.begin(), names.end(), true);
@@ -47,44 +45,31 @@ TEST_CASE("VariableSet one", "[optimize]")
4745
REQUIRE(vs.getIndex("first") == 0);
4846
REQUIRE(vs.name(0) == "first");
4947
double first_val_real = 1.123456789;
50-
CHECK(std::real(vs[0] ) == Approx(first_val_real));
48+
CHECK(vs[0] == Approx(first_val_real));
5149

5250
std::ostringstream o;
5351
vs.print(o, 0, false);
5452
//std::cout << o.str() << std::endl;
55-
#ifdef QMC_COMPLEX
56-
REQUIRE(o.str() == "first (1.123457e+00,0.000000e+00) 0 1 ON 0\n");
57-
#else
5853
REQUIRE(o.str() == "first 1.123457e+00 0 1 ON 0\n");
59-
#endif
6054

6155
std::ostringstream o2;
6256
vs.print(o2, 1, true);
6357
//std::cout << o2.str() << std::endl;
6458

65-
#ifdef QMC_COMPLEX
66-
char formatted_output[] = " Name Value Type Recompute Use Index\n"
67-
" ----- ---------------------------- ---- --------- --- -----\n"
68-
" first (1.123457e+00,0.000000e+00) 0 1 ON 0\n";
69-
70-
71-
REQUIRE(o2.str() == formatted_output);
72-
#else
7359
char formatted_output[] = " Name Value Type Recompute Use Index\n"
7460
" ----- ---------------------------- ---- --------- --- -----\n"
7561
" first 1.123457e+00 0 1 ON 0\n";
7662

7763

7864
REQUIRE(o2.str() == formatted_output);
79-
#endif
8065
}
8166

8267
TEST_CASE("VariableSet output", "[optimize]")
8368
{
8469
VariableSet vs;
85-
VariableSet::value_type first_val(11234.56789);
86-
VariableSet::value_type second_val(0.000256789);
87-
VariableSet::value_type third_val(-1.2);
70+
VariableSet::real_type first_val(11234.56789);
71+
VariableSet::real_type second_val(0.000256789);
72+
VariableSet::real_type third_val(-1.2);
8873
vs.insert("s", first_val);
8974
vs.insert("second", second_val);
9075
vs.insert("really_long_name", third_val);
@@ -95,29 +80,21 @@ TEST_CASE("VariableSet output", "[optimize]")
9580
vs.print(o, 0, true);
9681
//std::cout << o.str() << std::endl;
9782

98-
#ifdef QMC_COMPLEX
99-
char formatted_output[] = " Name Value Type Recompute Use Index\n"
100-
"---------------- ---------------------------- ---- --------- --- -----\n"
101-
" s (1.123457e+04,0.000000e+00) 0 1 ON 0\n"
102-
" second (2.567890e-04,0.000000e+00) 0 1 ON 1\n"
103-
"really_long_name (-1.200000e+00,0.000000e+00) 0 1 ON 2\n";
104-
#else
10583
char formatted_output[] = " Name Value Type Recompute Use Index\n"
10684
"---------------- ---------------------------- ---- --------- --- -----\n"
10785
" s 1.123457e+04 0 1 ON 0\n"
10886
" second 2.567890e-04 0 1 ON 1\n"
10987
"really_long_name -1.200000e+00 0 1 ON 2\n";
110-
#endif
11188

11289
REQUIRE(o.str() == formatted_output);
11390
}
11491

11592
TEST_CASE("VariableSet HDF output and input", "[optimize]")
11693
{
11794
VariableSet vs;
118-
VariableSet::value_type first_val(11234.56789);
119-
VariableSet::value_type second_val(0.000256789);
120-
VariableSet::value_type third_val(-1.2);
95+
VariableSet::real_type first_val(11234.56789);
96+
VariableSet::real_type second_val(0.000256789);
97+
VariableSet::real_type third_val(-1.2);
12198
vs.insert("s", first_val);
12299
vs.insert("second", second_val);
123100
vs.insert("really_really_really_long_name", third_val);
@@ -129,8 +106,8 @@ TEST_CASE("VariableSet HDF output and input", "[optimize]")
129106
vs2.insert("second", 0.0);
130107
qmcplusplus::hdf_archive hin;
131108
vs2.readFromHDF("vp.h5", hin);
132-
CHECK(vs2.find("s")->second == ValueApprox(first_val));
133-
CHECK(vs2.find("second")->second == ValueApprox(second_val));
109+
CHECK(vs2.find("s")->second == Approx(first_val));
110+
CHECK(vs2.find("second")->second == Approx(second_val));
134111
// This value as in the file, but not in the VariableSet that loaded the file,
135112
// so the value does not get added.
136113
CHECK(vs2.find("really_really_really_long_name") == vs2.end());

0 commit comments

Comments
 (0)