Skip to content

Commit 06e69d0

Browse files
quantumstevewilliamfgc
authored andcommitted
Remove test differences and fix build
Signed-off-by: Steven Hahn <[email protected]>
1 parent cb3d07c commit 06e69d0

File tree

4 files changed

+58
-13
lines changed

4 files changed

+58
-13
lines changed

src/QMCWaveFunctions/LCAO/LCAOrbitalBuilderT.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ typename LCAOrbitalBuilderT<T>::BasisSet_t* LCAOrbitalBuilderT<T>::createBasisSe
468468
return mBasisSet;
469469
}
470470
#ifndef QMC_COMPLEX
471+
#ifndef MIXED_PRECISION
471472
template<>
472473
std::unique_ptr<SPOSetT<double>> LCAOrbitalBuilderT<double>::createWithCuspCorrection(
473474
xmlNodePtr cur,
@@ -484,7 +485,6 @@ std::unique_ptr<SPOSetT<double>> LCAOrbitalBuilderT<double>::createWithCuspCorre
484485
lcwc->setOrbitalSetSize(lcwc->lcao.getOrbitalSetSize());
485486
sposet = std::move(lcwc);
486487
}
487-
#ifndef MIXED_PRECISION
488488
// Create a temporary particle set to use for cusp initialization.
489489
// The particle coordinates left at the end are unsuitable for further
490490
// computations. The coordinates get set to nuclear positions, which

src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.cpp

+48-8
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,41 @@ void LCAOrbitalSetT<T>::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSetT
519519
}
520520
}
521521

522+
template<typename T>
523+
void LCAOrbitalSetT<T>::mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
524+
const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
525+
OffloadMWVArray& vp_phi_v) const
526+
{
527+
assert(this == &spo_list.getLeader());
528+
auto& spo_leader = spo_list.template getCastedLeader<LCAOrbitalSetT<T>>();
529+
//const size_t nw = spo_list.size();
530+
auto& vp_basis_v_mw = spo_leader.mw_mem_handle_.getResource().vp_basis_v_mw;
531+
//Splatter basis_v
532+
const size_t nVPs = vp_phi_v.size(0);
533+
vp_basis_v_mw.resize(nVPs, BasisSetSize);
534+
535+
auto basis_list = spo_leader.extractBasisRefList(spo_list);
536+
myBasisSet->mw_evaluateValueVPs(basis_list, vp_list, vp_basis_v_mw);
537+
vp_basis_v_mw.updateFrom(); // TODO: remove this when gemm is implemented
538+
539+
if (Identity)
540+
{
541+
std::copy_n(vp_basis_v_mw.data_at(0, 0), this->OrbitalSetSize * nVPs, vp_phi_v.data_at(0, 0));
542+
}
543+
else
544+
{
545+
const size_t requested_orb_size = vp_phi_v.size(1);
546+
assert(requested_orb_size <= this->OrbitalSetSize);
547+
ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
548+
BLAS::gemm('T', 'N',
549+
requested_orb_size, // MOs
550+
nVPs, // walkers * Virtual Particles
551+
BasisSetSize, // AOs
552+
1, C_partial_view.data(), BasisSetSize, vp_basis_v_mw.data(), BasisSetSize, 0, vp_phi_v.data(),
553+
requested_orb_size);
554+
}
555+
}
556+
522557
template<class T>
523558
void LCAOrbitalSetT<T>::mw_evaluateValue(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
524559
const RefVectorWithLeader<ParticleSetT<T>>& P_list,
@@ -579,15 +614,20 @@ void LCAOrbitalSetT<T>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOSetT<T
579614
const std::vector<const T*>& invRow_ptr_list,
580615
std::vector<std::vector<T>>& ratios_list) const
581616
{
582-
const size_t nw = spo_list.size();
583-
for (size_t iw = 0; iw < nw; iw++)
584-
{
617+
assert(this == &spo_list.getLeader());
618+
auto& spo_leader = spo_list.template getCastedLeader<LCAOrbitalSetT<T>>();
619+
auto& vp_phi_v = spo_leader.mw_mem_handle_.getResource().vp_phi_v;
620+
621+
const size_t nVPs = VirtualParticleSetT<T>::countVPs(vp_list);
622+
const size_t requested_orb_size = psi_list[0].get().size();
623+
vp_phi_v.resize(nVPs, requested_orb_size);
624+
625+
mw_evaluateValueVPsImplGEMM(spo_list, vp_list, vp_phi_v);
626+
627+
size_t index = 0;
628+
for (size_t iw = 0; iw < vp_list.size(); iw++)
585629
for (size_t iat = 0; iat < vp_list[iw].getTotalNum(); iat++)
586-
{
587-
spo_list[iw].evaluateValue(vp_list[iw], iat, psi_list[iw]);
588-
ratios_list[iw][iat] = simd::dot(psi_list[iw].get().data(), invRow_ptr_list[iw], psi_list[iw].get().size());
589-
}
590-
}
630+
ratios_list[iw][iat] = simd::dot(vp_phi_v.data_at(index++, 0), invRow_ptr_list[iw], requested_orb_size);
591631
}
592632

593633
template<class T>

src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.h

+5
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ class LCAOrbitalSetT : public SPOSetT<T>
358358
int iat,
359359
OffloadMWVArray& phi_v) const;
360360

361+
/// packed walker GEMM implementation with multi virtual particle sets
362+
void mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
363+
const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
364+
OffloadMWVArray& phi_v) const;
365+
361366
/// helper function for extracting a list of basis sets from a list of LCAOrbitalSet
362367
RefVectorWithLeader<basis_type> extractBasisRefList(const RefVectorWithLeader<SPOSetT<T>>& spo_list) const;
363368

src/QMCWaveFunctions/tests/test_LCAO_diamondC_2x1x1.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
#include "DistanceTable.h"
2424
#include "QMCWaveFunctions/SPOSet.h"
2525
#include "QMCWaveFunctions/LCAO/LCAOrbitalSet.h"
26-
#include <stdio.h>
26+
27+
#include <cstdio>
2728
#include <string>
2829
#include <limits>
2930

@@ -338,10 +339,9 @@ void test_LCAO_DiamondC_2x1x1_real()
338339
ratios_list[iw].resize(nvp_list[iw]);
339340

340341
// just need dummy refvec with correct size
341-
SPOSet::ValueVector tmp_psi_list(norb), tmp_psi_list_2(norb);
342+
SPOSet::ValueVector tmp_psi_list(norb);
342343
spo->mw_evaluateDetRatios(spo_list, RefVectorWithLeader<const VirtualParticleSet>(VP_, {VP_, VP_2}),
343-
RefVector<SPOSet::ValueVector>{tmp_psi_list, tmp_psi_list_2}, invRow_ptr_list,
344-
ratios_list);
344+
RefVector<SPOSet::ValueVector>{tmp_psi_list}, invRow_ptr_list, ratios_list);
345345

346346
std::vector<SPOSet::ValueType> ratios_ref_0(nvp_);
347347
std::vector<SPOSet::ValueType> ratios_ref_1(nvp_2);

0 commit comments

Comments
 (0)