@@ -519,6 +519,41 @@ void LCAOrbitalSetT<T>::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSetT
519
519
}
520
520
}
521
521
522
+ template <typename T>
523
+ void LCAOrbitalSetT<T>::mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
524
+ const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
525
+ OffloadMWVArray& vp_phi_v) const
526
+ {
527
+ assert (this == &spo_list.getLeader ());
528
+ auto & spo_leader = spo_list.template getCastedLeader <LCAOrbitalSetT<T>>();
529
+ // const size_t nw = spo_list.size();
530
+ auto & vp_basis_v_mw = spo_leader.mw_mem_handle_ .getResource ().vp_basis_v_mw ;
531
+ // Splatter basis_v
532
+ const size_t nVPs = vp_phi_v.size (0 );
533
+ vp_basis_v_mw.resize (nVPs, BasisSetSize);
534
+
535
+ auto basis_list = spo_leader.extractBasisRefList (spo_list);
536
+ myBasisSet->mw_evaluateValueVPs (basis_list, vp_list, vp_basis_v_mw);
537
+ vp_basis_v_mw.updateFrom (); // TODO: remove this when gemm is implemented
538
+
539
+ if (Identity)
540
+ {
541
+ std::copy_n (vp_basis_v_mw.data_at (0 , 0 ), this ->OrbitalSetSize * nVPs, vp_phi_v.data_at (0 , 0 ));
542
+ }
543
+ else
544
+ {
545
+ const size_t requested_orb_size = vp_phi_v.size (1 );
546
+ assert (requested_orb_size <= this ->OrbitalSetSize );
547
+ ValueMatrix C_partial_view (C->data (), requested_orb_size, BasisSetSize);
548
+ BLAS::gemm (' T' , ' N' ,
549
+ requested_orb_size, // MOs
550
+ nVPs, // walkers * Virtual Particles
551
+ BasisSetSize, // AOs
552
+ 1 , C_partial_view.data (), BasisSetSize, vp_basis_v_mw.data (), BasisSetSize, 0 , vp_phi_v.data (),
553
+ requested_orb_size);
554
+ }
555
+ }
556
+
522
557
template <class T >
523
558
void LCAOrbitalSetT<T>::mw_evaluateValue(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
524
559
const RefVectorWithLeader<ParticleSetT<T>>& P_list,
@@ -579,15 +614,20 @@ void LCAOrbitalSetT<T>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOSetT<T
579
614
const std::vector<const T*>& invRow_ptr_list,
580
615
std::vector<std::vector<T>>& ratios_list) const
581
616
{
582
- const size_t nw = spo_list.size ();
583
- for (size_t iw = 0 ; iw < nw; iw++)
584
- {
617
+ assert (this == &spo_list.getLeader ());
618
+ auto & spo_leader = spo_list.template getCastedLeader <LCAOrbitalSetT<T>>();
619
+ auto & vp_phi_v = spo_leader.mw_mem_handle_ .getResource ().vp_phi_v ;
620
+
621
+ const size_t nVPs = VirtualParticleSetT<T>::countVPs (vp_list);
622
+ const size_t requested_orb_size = psi_list[0 ].get ().size ();
623
+ vp_phi_v.resize (nVPs, requested_orb_size);
624
+
625
+ mw_evaluateValueVPsImplGEMM (spo_list, vp_list, vp_phi_v);
626
+
627
+ size_t index = 0 ;
628
+ for (size_t iw = 0 ; iw < vp_list.size (); iw++)
585
629
for (size_t iat = 0 ; iat < vp_list[iw].getTotalNum (); iat++)
586
- {
587
- spo_list[iw].evaluateValue (vp_list[iw], iat, psi_list[iw]);
588
- ratios_list[iw][iat] = simd::dot (psi_list[iw].get ().data (), invRow_ptr_list[iw], psi_list[iw].get ().size ());
589
- }
590
- }
630
+ ratios_list[iw][iat] = simd::dot (vp_phi_v.data_at (index ++, 0 ), invRow_ptr_list[iw], requested_orb_size);
591
631
}
592
632
593
633
template <class T >
0 commit comments