Skip to content

Commit 808af53

Browse files
authored
Feature: Implement cal_force_op for sincos parallel (#6265)
* implement gpu op for sincos loops * add cpu kernel for cal_force_loc & cal_force_ew * fix sincos op for gpu&cpu * fix vloc computation in cal_force_loc_sincos_op * fix cal_force_ew * fix malloc error
1 parent ede1437 commit 808af53

File tree

5 files changed

+773
-83
lines changed

5 files changed

+773
-83
lines changed

source/module_hamilt_pw/hamilt_pwdft/forces.cpp

Lines changed: 215 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "module_hamilt_general/module_ewald/H_Ewald_pw.h"
1616
#include "module_hamilt_general/module_surchem/surchem.h"
1717
#include "module_hamilt_general/module_vdw/vdw.h"
18+
#include "kernels/force_op.h"
1819

1920
#ifdef _OPENMP
2021
#include <omp.h>
@@ -531,31 +532,110 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
531532
// to G space. maybe need fftw with OpenMP
532533
rho_basis->real2recip(aux, aux);
533534

534-
#ifdef _OPENMP
535-
#pragma omp parallel for
536-
#endif
537-
for (int iat = 0; iat < this->nat; ++iat)
538-
{
539-
// read `it` `ia` from the table
535+
// sincos op for G space
536+
537+
538+
// data preparation
539+
std::vector<FPTYPE> tau_flat(this->nat * 3);
540+
std::vector<FPTYPE> gcar_flat(rho_basis->npw * 3);
541+
542+
543+
for (int iat = 0; iat < this->nat; iat++) {
544+
int it = ucell.iat2it[iat];
545+
int ia = ucell.iat2ia[iat];
546+
547+
tau_flat[iat * 3 + 0] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][0]);
548+
tau_flat[iat * 3 + 1] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][1]);
549+
tau_flat[iat * 3 + 2] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][2]);
550+
}
551+
552+
for (int ig = 0; ig < rho_basis->npw; ig++) {
553+
gcar_flat[ig * 3 + 0] = static_cast<FPTYPE>(rho_basis->gcar[ig][0]);
554+
gcar_flat[ig * 3 + 1] = static_cast<FPTYPE>(rho_basis->gcar[ig][1]);
555+
gcar_flat[ig * 3 + 2] = static_cast<FPTYPE>(rho_basis->gcar[ig][2]);
556+
}
557+
558+
// calculate vloc_factors for all atom types
559+
std::vector<FPTYPE> vloc_per_type_host(this->nat * rho_basis->npw);
560+
for (int iat = 0; iat < this->nat; iat++) {
540561
int it = ucell.iat2it[iat];
541-
int ia = ucell.iat2ia[iat];
542-
for (int ig = 0; ig < rho_basis->npw; ig++)
543-
{
544-
const double phase = ModuleBase::TWO_PI * (rho_basis->gcar[ig] * ucell.atoms[it].tau[ia]);
545-
double sinp, cosp;
546-
ModuleBase::libm::sincos(phase, &sinp, &cosp);
547-
const double factor
548-
= vloc(it, rho_basis->ig2igg[ig]) * (cosp * aux[ig].imag() + sinp * aux[ig].real());
549-
forcelc(iat, 0) += rho_basis->gcar[ig][0] * factor;
550-
forcelc(iat, 1) += rho_basis->gcar[ig][1] * factor;
551-
forcelc(iat, 2) += rho_basis->gcar[ig][2] * factor;
562+
for (int ig = 0; ig < rho_basis->npw; ig++) {
563+
vloc_per_type_host[iat * rho_basis->npw + ig] = static_cast<FPTYPE>(vloc(it, rho_basis->ig2igg[ig]));
552564
}
553-
forcelc(iat, 0) *= (ucell.tpiba * ucell.omega);
554-
forcelc(iat, 1) *= (ucell.tpiba * ucell.omega);
555-
forcelc(iat, 2) *= (ucell.tpiba * ucell.omega);
565+
}
566+
567+
std::vector<std::complex<FPTYPE>> aux_fptype(rho_basis->npw);
568+
for (int ig = 0; ig < rho_basis->npw; ig++) {
569+
aux_fptype[ig] = static_cast<std::complex<FPTYPE>>(aux[ig]);
570+
}
571+
572+
FPTYPE* d_gcar = gcar_flat.data();
573+
FPTYPE* d_tau = tau_flat.data();
574+
FPTYPE* d_vloc_per_type = vloc_per_type_host.data();
575+
std::complex<FPTYPE>* d_aux = aux_fptype.data();
576+
FPTYPE* d_force = nullptr;
577+
std::vector<FPTYPE> force_host(this->nat * 3);
578+
579+
if (this->device == base_device::GpuDevice)
580+
{
581+
d_gcar = nullptr;
582+
d_tau = nullptr;
583+
d_vloc_per_type = nullptr;
584+
d_aux = nullptr;
585+
586+
resmem_var_op()(this->ctx, d_gcar, rho_basis->npw * 3);
587+
resmem_var_op()(this->ctx, d_tau, this->nat * 3);
588+
resmem_var_op()(this->ctx, d_vloc_per_type, this->nat * rho_basis->npw);
589+
resmem_complex_op()(this->ctx, d_aux, rho_basis->npw);
590+
resmem_var_op()(this->ctx, d_force, this->nat * 3);
591+
592+
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, d_gcar, gcar_flat.data(), rho_basis->npw * 3);
593+
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, d_tau, tau_flat.data(), this->nat * 3);
594+
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, d_vloc_per_type, vloc_per_type_host.data(), this->nat * rho_basis->npw);
595+
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, d_aux, aux_fptype.data(), rho_basis->npw);
596+
597+
base_device::memory::set_memory_op<FPTYPE, Device>()(this->ctx, d_force, 0.0, this->nat * 3);
598+
}
599+
else
600+
{
601+
d_force = force_host.data();
602+
std::fill(force_host.begin(), force_host.end(), static_cast<FPTYPE>(0.0));
603+
}
604+
605+
const FPTYPE scale_factor = static_cast<FPTYPE>(ucell.tpiba * ucell.omega);
606+
607+
// call op for sincos calculation
608+
hamilt::cal_force_loc_sincos_op<FPTYPE, Device>()(
609+
this->ctx,
610+
this->nat,
611+
rho_basis->npw,
612+
this->nat,
613+
d_gcar,
614+
d_tau,
615+
d_vloc_per_type,
616+
d_aux,
617+
scale_factor,
618+
d_force
619+
);
620+
621+
if (this->device == base_device::GpuDevice)
622+
{
623+
syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, force_host.data(), d_force, this->nat * 3);
624+
625+
delmem_var_op()(this->ctx, d_gcar);
626+
delmem_var_op()(this->ctx, d_tau);
627+
delmem_var_op()(this->ctx, d_vloc_per_type);
628+
delmem_complex_op()(this->ctx, d_aux);
629+
delmem_var_op()(this->ctx, d_force);
630+
}
631+
632+
for (int iat = 0; iat < this->nat; iat++) {
633+
forcelc(iat, 0) = static_cast<double>(force_host[iat * 3 + 0]);
634+
forcelc(iat, 1) = static_cast<double>(force_host[iat * 3 + 1]);
635+
forcelc(iat, 2) = static_cast<double>(force_host[iat * 3 + 2]);
556636
}
557637

558-
// this->print(GlobalV::ofs_running, "local forces", forcelc);
638+
// this->print(GlobalV: :ofs_running, "local forces", forcelc);
559639
Parallel_Reduce::reduce_pool(forcelc.c, forcelc.nr * forcelc.nc);
560640
delete[] aux;
561641
ModuleBase::timer::tick("Forces", "cal_force_loc");
@@ -665,6 +745,119 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
665745
aux[rho_basis->ig_gge0] = std::complex<double>(0.0, 0.0);
666746
}
667747

748+
// sincos op for cal_force_ew
749+
750+
std::vector<FPTYPE> it_facts_host(this->nat);
751+
std::vector<FPTYPE> tau_flat(this->nat * 3);
752+
753+
// iterate over by lookup table
754+
for (int iat = 0; iat < this->nat; iat++) {
755+
int it = ucell.iat2it[iat];
756+
int ia = ucell.iat2ia[iat];
757+
758+
double zv;
759+
if (PARAM.inp.use_paw)
760+
{
761+
#ifdef USE_PAW
762+
zv = GlobalC::paw_cell.get_val(it);
763+
#endif
764+
}
765+
else
766+
{
767+
zv = ucell.atoms[it].ncpp.zv;
768+
}
769+
770+
it_facts_host[iat] = static_cast<FPTYPE>(zv * ModuleBase::e2 * ucell.tpiba *
771+
ModuleBase::TWO_PI / ucell.omega * fact);
772+
773+
tau_flat[iat * 3 + 0] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][0]);
774+
tau_flat[iat * 3 + 1] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][1]);
775+
tau_flat[iat * 3 + 2] = static_cast<FPTYPE>(ucell.atoms[it].tau[ia][2]);
776+
}
777+
778+
std::vector<FPTYPE> gcar_flat(rho_basis->npw * 3);
779+
for (int ig = 0; ig < rho_basis->npw; ig++) {
780+
gcar_flat[ig * 3 + 0] = static_cast<FPTYPE>(rho_basis->gcar[ig][0]);
781+
gcar_flat[ig * 3 + 1] = static_cast<FPTYPE>(rho_basis->gcar[ig][1]);
782+
gcar_flat[ig * 3 + 2] = static_cast<FPTYPE>(rho_basis->gcar[ig][2]);
783+
}
784+
785+
std::vector<std::complex<FPTYPE>> aux_fptype(rho_basis->npw);
786+
for (int ig = 0; ig < rho_basis->npw; ig++) {
787+
aux_fptype[ig] = static_cast<std::complex<FPTYPE>>(aux[ig]);
788+
}
789+
790+
FPTYPE* d_gcar = gcar_flat.data();
791+
FPTYPE* d_tau = tau_flat.data();
792+
FPTYPE* d_it_facts = it_facts_host.data();
793+
std::complex<FPTYPE>* d_aux = aux_fptype.data();
794+
FPTYPE* d_force_g = nullptr;
795+
std::vector<FPTYPE> force_g_host(this->nat * 3);
796+
797+
if (this->device == base_device::GpuDevice)
798+
{
799+
d_gcar = nullptr;
800+
d_tau = nullptr;
801+
d_it_facts = nullptr;
802+
d_aux = nullptr;
803+
804+
resmem_var_op()(this->ctx, d_gcar, rho_basis->npw * 3);
805+
resmem_var_op()(this->ctx, d_tau, this->nat * 3);
806+
resmem_var_op()(this->ctx, d_it_facts, this->nat);
807+
resmem_complex_op()(this->ctx, d_aux, rho_basis->npw);
808+
resmem_var_op()(this->ctx, d_force_g, this->nat * 3);
809+
810+
811+
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, d_gcar, gcar_flat.data(), rho_basis->npw * 3);
812+
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, d_tau, tau_flat.data(), this->nat * 3);
813+
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, d_it_facts, it_facts_host.data(), this->nat);
814+
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, d_aux, aux_fptype.data(), rho_basis->npw);
815+
816+
817+
base_device::memory::set_memory_op<FPTYPE, Device>()(this->ctx, d_force_g, 0.0, this->nat * 3);
818+
}
819+
else
820+
{
821+
d_force_g = force_g_host.data();
822+
std::fill(force_g_host.begin(), force_g_host.end(), static_cast<FPTYPE>(0.0));
823+
}
824+
825+
// call op for sincos calculation
826+
hamilt::cal_force_ew_sincos_op<FPTYPE, Device>()(
827+
this->ctx,
828+
this->nat,
829+
rho_basis->npw,
830+
rho_basis->ig_gge0,
831+
d_gcar,
832+
d_tau,
833+
d_it_facts,
834+
d_aux,
835+
d_force_g
836+
);
837+
838+
839+
if (this->device == base_device::GpuDevice)
840+
{
841+
842+
syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, force_g_host.data(), d_force_g, this->nat * 3);
843+
844+
845+
delmem_var_op()(this->ctx, d_gcar);
846+
delmem_var_op()(this->ctx, d_tau);
847+
delmem_var_op()(this->ctx, d_it_facts);
848+
delmem_complex_op()(this->ctx, d_aux);
849+
delmem_var_op()(this->ctx, d_force_g);
850+
}
851+
852+
853+
for (int iat = 0; iat < this->nat; iat++) {
854+
forceion(iat, 0) += static_cast<double>(force_g_host[iat * 3 + 0]);
855+
forceion(iat, 1) += static_cast<double>(force_g_host[iat * 3 + 1]);
856+
forceion(iat, 2) += static_cast<double>(force_g_host[iat * 3 + 2]);
857+
}
858+
859+
860+
// calculate real space force
668861
#ifdef _OPENMP
669862
#pragma omp parallel
670863
{
@@ -688,66 +881,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
688881
iat_end = iat_beg + iat_end;
689882
ucell.iat2iait(iat_beg, &ia_beg, &it_beg);
690883

691-
int iat = iat_beg;
692-
int it = it_beg;
693-
int ia = ia_beg;
694-
695-
// preprocess ig_gap for skipping the ig point
696-
int ig_gap = (rho_basis->ig_gge0 >= 0 && rho_basis->ig_gge0 < rho_basis->npw) ? rho_basis->ig_gge0 : -1;
697-
698-
double it_fact = 0.;
699-
int last_it = -1;
700-
701-
// iterating atoms
702-
while (iat < iat_end)
703-
{
704-
if (it != last_it)
705-
{ // calculate it_tact when it is changed
706-
double zv;
707-
if (PARAM.inp.use_paw)
708-
{
709-
#ifdef USE_PAW
710-
zv = GlobalC::paw_cell.get_val(it);
711-
#endif
712-
}
713-
else
714-
{
715-
zv = ucell.atoms[it].ncpp.zv;
716-
}
717-
it_fact = zv * ModuleBase::e2 * ucell.tpiba * ModuleBase::TWO_PI / ucell.omega * fact;
718-
last_it = it;
719-
}
720-
721-
if (ucell.atoms[it].na != 0)
722-
{
723-
const auto ig_loop = [&](int ig_beg, int ig_end) {
724-
for (int ig = ig_beg; ig < ig_end; ig++)
725-
{
726-
const ModuleBase::Vector3<double> gcar = rho_basis->gcar[ig];
727-
const double arg = ModuleBase::TWO_PI * (gcar * ucell.atoms[it].tau[ia]);
728-
double sinp, cosp;
729-
ModuleBase::libm::sincos(arg, &sinp, &cosp);
730-
double sumnb = -cosp * aux[ig].imag() + sinp * aux[ig].real();
731-
forceion(iat, 0) += gcar[0] * sumnb;
732-
forceion(iat, 1) += gcar[1] * sumnb;
733-
forceion(iat, 2) += gcar[2] * sumnb;
734-
}
735-
};
736-
737-
// skip ig_gge0 point by separating ig loop into two part
738-
ig_loop(0, ig_gap);
739-
ig_loop(ig_gap + 1, rho_basis->npw);
740-
741-
forceion(iat, 0) *= it_fact;
742-
forceion(iat, 1) *= it_fact;
743-
forceion(iat, 2) *= it_fact;
744-
745-
++iat;
746-
ucell.step_iait(&ia, &it);
747-
}
748-
}
749-
750-
// means that the processor contains G=0 term.
884+
751885
if (rho_basis->ig_gge0 >= 0)
752886
{
753887
double rmax = 5.0 / (sqrt(alpha) * ucell.lat0);

0 commit comments

Comments
 (0)