Skip to content

Commit aefa730

Browse files
committed
simplify code
1 parent 991a8c2 commit aefa730

File tree

1 file changed

+38
-58
lines changed

1 file changed

+38
-58
lines changed

src/utilities/nep_utilities.cuh

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,25 @@ static __device__ __forceinline__ void get_f12_5body(
498498
f12[2] += tmp1 * r12[2];
499499
}
500500

501+
template <int L>
502+
static __device__ __forceinline__ void calculate_s_one(
503+
const int n,
504+
const int n_max_angular_plus_1,
505+
const float* Fp,
506+
const float* sum_fxyz,
507+
float* s)
508+
{
509+
const int L_minus_1 = L - 1;
510+
const int L_twice_plus_1 = 2 * L + 1;
511+
const int L_square_minus_1 = L * L - 1;
512+
float Fp_factor = 2.0f * Fp[L_minus_1 * n_max_angular_plus_1 + n];
513+
s[0] = sum_fxyz[n * NUM_OF_ABC + L_square_minus_1] * C3B[L_square_minus_1] * Fp_factor;
514+
Fp_factor *= 2.0f;
515+
for (int k = 1; k < L_twice_plus_1; ++k) {
516+
s[k] = sum_fxyz[n * NUM_OF_ABC + L_square_minus_1 + k] * C3B[L_square_minus_1 + k] * Fp_factor;
517+
}
518+
}
519+
501520
template <int L>
502521
static __device__ __forceinline__ void accumulate_f12_one(
503522
const float d12inv,
@@ -611,111 +630,72 @@ static __device__ __forceinline__ void accumulate_f12(
611630
const float fnp_original = fnp;
612631
const float d12inv = 1.0f / d12;
613632
const float r12unit[3] = {r12[0]*d12inv, r12[1]*d12inv, r12[2]*d12inv};
614-
// l = 1
633+
615634
fnp = fnp * d12inv - fn * d12inv * d12inv;
616635
fn = fn * d12inv;
617-
float s1[3] = {
618-
sum_fxyz[n * NUM_OF_ABC + 0], sum_fxyz[n * NUM_OF_ABC + 1], sum_fxyz[n * NUM_OF_ABC + 2]};
619636
if (num_L >= L_max + 2) {
637+
float s1[3] = {
638+
sum_fxyz[n * NUM_OF_ABC + 0], sum_fxyz[n * NUM_OF_ABC + 1], sum_fxyz[n * NUM_OF_ABC + 2]};
620639
get_f12_5body(d12, d12inv, fn, fnp, Fp[(L_max + 1) * n_max_angular_plus_1 + n], s1, r12, f12);
621640
}
622641

623642
if (L_max >= 1) {
624-
float Fp_factor = 2.0f * Fp[n];
625-
s1[0] *= C3B[0] * Fp_factor;
626-
Fp_factor *= 2.0f;
627-
s1[1] *= C3B[1] * Fp_factor;
628-
s1[2] *= C3B[2] * Fp_factor;
643+
float s1[3];
644+
calculate_s_one<1>(n, n_max_angular_plus_1, Fp, sum_fxyz, s1);
629645
accumulate_f12_one<1>(d12inv, fn_original, fnp_original, s1, r12unit, f12);
630646
}
631647

632-
// l = 2
633648
fnp = fnp * d12inv - fn * d12inv * d12inv;
634649
fn = fn * d12inv;
635-
float s2[5] = {
636-
sum_fxyz[n * NUM_OF_ABC + 3],
637-
sum_fxyz[n * NUM_OF_ABC + 4],
638-
sum_fxyz[n * NUM_OF_ABC + 5],
639-
sum_fxyz[n * NUM_OF_ABC + 6],
640-
sum_fxyz[n * NUM_OF_ABC + 7]};
641650
if (num_L >= L_max + 1) {
651+
float s2[5] = {
652+
sum_fxyz[n * NUM_OF_ABC + 3],
653+
sum_fxyz[n * NUM_OF_ABC + 4],
654+
sum_fxyz[n * NUM_OF_ABC + 5],
655+
sum_fxyz[n * NUM_OF_ABC + 6],
656+
sum_fxyz[n * NUM_OF_ABC + 7]};
642657
get_f12_4body(d12, d12inv, fn, fnp, Fp[L_max * n_max_angular_plus_1 + n], s2, r12, f12);
643658
}
644659

645660
if (L_max >= 2) {
646-
float Fp_factor = 2.0f * Fp[n_max_angular_plus_1 + n];
647-
s2[0] *= C3B[3] * Fp_factor;
648-
Fp_factor *= 2;
649-
s2[1] *= C3B[4] * Fp_factor;
650-
s2[2] *= C3B[5] * Fp_factor;
651-
s2[3] *= C3B[6] * Fp_factor;
652-
s2[4] *= C3B[7] * Fp_factor;
661+
float s2[5];
662+
calculate_s_one<2>(n, n_max_angular_plus_1, Fp, sum_fxyz, s2);
653663
accumulate_f12_one<2>(d12inv, fn_original, fnp_original, s2, r12unit, f12);
654664
}
655665

656666
if (L_max >= 3) {
657667
float s3[7];
658-
float Fp_factor = 2.0f * Fp[2 * n_max_angular_plus_1 + n];
659-
s3[0] = sum_fxyz[n * NUM_OF_ABC + 8] * C3B[8] * Fp_factor;
660-
Fp_factor *= 2.0f;
661-
for (int k = 1; k < 7; ++k) {
662-
s3[k] = sum_fxyz[n * NUM_OF_ABC + 8 + k] * C3B[8 + k] * Fp_factor;
663-
}
668+
calculate_s_one<3>(n, n_max_angular_plus_1, Fp, sum_fxyz, s3);
664669
accumulate_f12_one<3>(d12inv, fn_original, fnp_original, s3, r12unit, f12);
665670
}
666671

667672
if (L_max >= 4) {
668673
float s4[9];
669-
float Fp_factor = 2.0f * Fp[3 * n_max_angular_plus_1 + n];
670-
s4[0] = sum_fxyz[n * NUM_OF_ABC + 15] * C3B[15] * Fp_factor;
671-
Fp_factor *= 2.0f;
672-
for (int k = 1; k < 9; ++k) {
673-
s4[k] = sum_fxyz[n * NUM_OF_ABC + 15 + k] * C3B[15 + k] * Fp_factor;
674-
}
674+
calculate_s_one<4>(n, n_max_angular_plus_1, Fp, sum_fxyz, s4);
675675
accumulate_f12_one<4>(d12inv, fn_original, fnp_original, s4, r12unit, f12);
676676
}
677677

678678
if (L_max >= 5) {
679679
float s5[11];
680-
float Fp_factor = 2.0f * Fp[4 * n_max_angular_plus_1 + n];
681-
s5[0] = sum_fxyz[n * NUM_OF_ABC + 24] * C3B[24] * Fp_factor;
682-
Fp_factor *= 2.0f;
683-
for (int k = 1; k < 11; ++k) {
684-
s5[k] = sum_fxyz[n * NUM_OF_ABC + 24 + k] * C3B[24 + k] * Fp_factor;
685-
}
680+
calculate_s_one<5>(n, n_max_angular_plus_1, Fp, sum_fxyz, s5);
686681
accumulate_f12_one<5>(d12inv, fn_original, fnp_original, s5, r12unit, f12);
687682
}
688683

689684
if (L_max >= 6) {
690685
float s6[13];
691-
float Fp_factor = 2.0f * Fp[5 * n_max_angular_plus_1 + n];
692-
s6[0] = sum_fxyz[n * NUM_OF_ABC + 35] * C3B[35] * Fp_factor;
693-
Fp_factor *= 2.0f;
694-
for (int k = 1; k < 13; ++k) {
695-
s6[k] = sum_fxyz[n * NUM_OF_ABC + 35 + k] * C3B[35 + k] * Fp_factor;
696-
}
686+
calculate_s_one<6>(n, n_max_angular_plus_1, Fp, sum_fxyz, s6);
697687
accumulate_f12_one<6>(d12inv, fn_original, fnp_original, s6, r12unit, f12);
698688
}
699689

700690
if (L_max >= 7) {
701691
float s7[15];
702-
float Fp_factor = 2.0f * Fp[6 * n_max_angular_plus_1 + n];
703-
s7[0] = sum_fxyz[n * NUM_OF_ABC + 48] * C3B[48] * Fp_factor;
704-
Fp_factor *= 2.0f;
705-
for (int k = 1; k < 15; ++k) {
706-
s7[k] = sum_fxyz[n * NUM_OF_ABC + 48 + k] * C3B[48 + k] * Fp_factor;
707-
}
692+
calculate_s_one<7>(n, n_max_angular_plus_1, Fp, sum_fxyz, s7);
708693
accumulate_f12_one<7>(d12inv, fn_original, fnp_original, s7, r12unit, f12);
709694
}
710695

711696
if (L_max >= 8) {
712697
float s8[17];
713-
float Fp_factor = 2.0f * Fp[7 * n_max_angular_plus_1 + n];
714-
s8[0] = sum_fxyz[n * NUM_OF_ABC + 63] * C3B[63] * Fp_factor;
715-
Fp_factor *= 2.0f;
716-
for (int k = 1; k < 17; ++k) {
717-
s8[k] = sum_fxyz[n * NUM_OF_ABC + 63 + k] * C3B[63 + k] * Fp_factor;
718-
}
698+
calculate_s_one<8>(n, n_max_angular_plus_1, Fp, sum_fxyz, s8);
719699
accumulate_f12_one<8>(d12inv, fn_original, fnp_original, s8, r12unit, f12);
720700
}
721701
}

0 commit comments

Comments
 (0)