@@ -498,6 +498,25 @@ static __device__ __forceinline__ void get_f12_5body(
498
498
f12[2 ] += tmp1 * r12[2 ];
499
499
}
500
500
501
+ template <int L>
502
+ static __device__ __forceinline__ void calculate_s_one (
503
+ const int n,
504
+ const int n_max_angular_plus_1,
505
+ const float * Fp,
506
+ const float * sum_fxyz,
507
+ float * s)
508
+ {
509
+ const int L_minus_1 = L - 1 ;
510
+ const int L_twice_plus_1 = 2 * L + 1 ;
511
+ const int L_square_minus_1 = L * L - 1 ;
512
+ float Fp_factor = 2 .0f * Fp[L_minus_1 * n_max_angular_plus_1 + n];
513
+ s[0 ] = sum_fxyz[n * NUM_OF_ABC + L_square_minus_1] * C3B[L_square_minus_1] * Fp_factor;
514
+ Fp_factor *= 2 .0f ;
515
+ for (int k = 1 ; k < L_twice_plus_1; ++k) {
516
+ s[k] = sum_fxyz[n * NUM_OF_ABC + L_square_minus_1 + k] * C3B[L_square_minus_1 + k] * Fp_factor;
517
+ }
518
+ }
519
+
501
520
template <int L>
502
521
static __device__ __forceinline__ void accumulate_f12_one (
503
522
const float d12inv,
@@ -611,111 +630,72 @@ static __device__ __forceinline__ void accumulate_f12(
611
630
const float fnp_original = fnp;
612
631
const float d12inv = 1 .0f / d12;
613
632
const float r12unit[3 ] = {r12[0 ]*d12inv, r12[1 ]*d12inv, r12[2 ]*d12inv};
614
- // l = 1
633
+
615
634
fnp = fnp * d12inv - fn * d12inv * d12inv;
616
635
fn = fn * d12inv;
617
- float s1[3 ] = {
618
- sum_fxyz[n * NUM_OF_ABC + 0 ], sum_fxyz[n * NUM_OF_ABC + 1 ], sum_fxyz[n * NUM_OF_ABC + 2 ]};
619
636
if (num_L >= L_max + 2 ) {
637
+ float s1[3 ] = {
638
+ sum_fxyz[n * NUM_OF_ABC + 0 ], sum_fxyz[n * NUM_OF_ABC + 1 ], sum_fxyz[n * NUM_OF_ABC + 2 ]};
620
639
get_f12_5body (d12, d12inv, fn, fnp, Fp[(L_max + 1 ) * n_max_angular_plus_1 + n], s1, r12, f12);
621
640
}
622
641
623
642
if (L_max >= 1 ) {
624
- float Fp_factor = 2 .0f * Fp[n];
625
- s1[0 ] *= C3B[0 ] * Fp_factor;
626
- Fp_factor *= 2 .0f ;
627
- s1[1 ] *= C3B[1 ] * Fp_factor;
628
- s1[2 ] *= C3B[2 ] * Fp_factor;
643
+ float s1[3 ];
644
+ calculate_s_one<1 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s1);
629
645
accumulate_f12_one<1 >(d12inv, fn_original, fnp_original, s1, r12unit, f12);
630
646
}
631
647
632
- // l = 2
633
648
fnp = fnp * d12inv - fn * d12inv * d12inv;
634
649
fn = fn * d12inv;
635
- float s2[5 ] = {
636
- sum_fxyz[n * NUM_OF_ABC + 3 ],
637
- sum_fxyz[n * NUM_OF_ABC + 4 ],
638
- sum_fxyz[n * NUM_OF_ABC + 5 ],
639
- sum_fxyz[n * NUM_OF_ABC + 6 ],
640
- sum_fxyz[n * NUM_OF_ABC + 7 ]};
641
650
if (num_L >= L_max + 1 ) {
651
+ float s2[5 ] = {
652
+ sum_fxyz[n * NUM_OF_ABC + 3 ],
653
+ sum_fxyz[n * NUM_OF_ABC + 4 ],
654
+ sum_fxyz[n * NUM_OF_ABC + 5 ],
655
+ sum_fxyz[n * NUM_OF_ABC + 6 ],
656
+ sum_fxyz[n * NUM_OF_ABC + 7 ]};
642
657
get_f12_4body (d12, d12inv, fn, fnp, Fp[L_max * n_max_angular_plus_1 + n], s2, r12, f12);
643
658
}
644
659
645
660
if (L_max >= 2 ) {
646
- float Fp_factor = 2 .0f * Fp[n_max_angular_plus_1 + n];
647
- s2[0 ] *= C3B[3 ] * Fp_factor;
648
- Fp_factor *= 2 ;
649
- s2[1 ] *= C3B[4 ] * Fp_factor;
650
- s2[2 ] *= C3B[5 ] * Fp_factor;
651
- s2[3 ] *= C3B[6 ] * Fp_factor;
652
- s2[4 ] *= C3B[7 ] * Fp_factor;
661
+ float s2[5 ];
662
+ calculate_s_one<2 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s2);
653
663
accumulate_f12_one<2 >(d12inv, fn_original, fnp_original, s2, r12unit, f12);
654
664
}
655
665
656
666
if (L_max >= 3 ) {
657
667
float s3[7 ];
658
- float Fp_factor = 2 .0f * Fp[2 * n_max_angular_plus_1 + n];
659
- s3[0 ] = sum_fxyz[n * NUM_OF_ABC + 8 ] * C3B[8 ] * Fp_factor;
660
- Fp_factor *= 2 .0f ;
661
- for (int k = 1 ; k < 7 ; ++k) {
662
- s3[k] = sum_fxyz[n * NUM_OF_ABC + 8 + k] * C3B[8 + k] * Fp_factor;
663
- }
668
+ calculate_s_one<3 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s3);
664
669
accumulate_f12_one<3 >(d12inv, fn_original, fnp_original, s3, r12unit, f12);
665
670
}
666
671
667
672
if (L_max >= 4 ) {
668
673
float s4[9 ];
669
- float Fp_factor = 2 .0f * Fp[3 * n_max_angular_plus_1 + n];
670
- s4[0 ] = sum_fxyz[n * NUM_OF_ABC + 15 ] * C3B[15 ] * Fp_factor;
671
- Fp_factor *= 2 .0f ;
672
- for (int k = 1 ; k < 9 ; ++k) {
673
- s4[k] = sum_fxyz[n * NUM_OF_ABC + 15 + k] * C3B[15 + k] * Fp_factor;
674
- }
674
+ calculate_s_one<4 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s4);
675
675
accumulate_f12_one<4 >(d12inv, fn_original, fnp_original, s4, r12unit, f12);
676
676
}
677
677
678
678
if (L_max >= 5 ) {
679
679
float s5[11 ];
680
- float Fp_factor = 2 .0f * Fp[4 * n_max_angular_plus_1 + n];
681
- s5[0 ] = sum_fxyz[n * NUM_OF_ABC + 24 ] * C3B[24 ] * Fp_factor;
682
- Fp_factor *= 2 .0f ;
683
- for (int k = 1 ; k < 11 ; ++k) {
684
- s5[k] = sum_fxyz[n * NUM_OF_ABC + 24 + k] * C3B[24 + k] * Fp_factor;
685
- }
680
+ calculate_s_one<5 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s5);
686
681
accumulate_f12_one<5 >(d12inv, fn_original, fnp_original, s5, r12unit, f12);
687
682
}
688
683
689
684
if (L_max >= 6 ) {
690
685
float s6[13 ];
691
- float Fp_factor = 2 .0f * Fp[5 * n_max_angular_plus_1 + n];
692
- s6[0 ] = sum_fxyz[n * NUM_OF_ABC + 35 ] * C3B[35 ] * Fp_factor;
693
- Fp_factor *= 2 .0f ;
694
- for (int k = 1 ; k < 13 ; ++k) {
695
- s6[k] = sum_fxyz[n * NUM_OF_ABC + 35 + k] * C3B[35 + k] * Fp_factor;
696
- }
686
+ calculate_s_one<6 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s6);
697
687
accumulate_f12_one<6 >(d12inv, fn_original, fnp_original, s6, r12unit, f12);
698
688
}
699
689
700
690
if (L_max >= 7 ) {
701
691
float s7[15 ];
702
- float Fp_factor = 2 .0f * Fp[6 * n_max_angular_plus_1 + n];
703
- s7[0 ] = sum_fxyz[n * NUM_OF_ABC + 48 ] * C3B[48 ] * Fp_factor;
704
- Fp_factor *= 2 .0f ;
705
- for (int k = 1 ; k < 15 ; ++k) {
706
- s7[k] = sum_fxyz[n * NUM_OF_ABC + 48 + k] * C3B[48 + k] * Fp_factor;
707
- }
692
+ calculate_s_one<7 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s7);
708
693
accumulate_f12_one<7 >(d12inv, fn_original, fnp_original, s7, r12unit, f12);
709
694
}
710
695
711
696
if (L_max >= 8 ) {
712
697
float s8[17 ];
713
- float Fp_factor = 2 .0f * Fp[7 * n_max_angular_plus_1 + n];
714
- s8[0 ] = sum_fxyz[n * NUM_OF_ABC + 63 ] * C3B[63 ] * Fp_factor;
715
- Fp_factor *= 2 .0f ;
716
- for (int k = 1 ; k < 17 ; ++k) {
717
- s8[k] = sum_fxyz[n * NUM_OF_ABC + 63 + k] * C3B[63 + k] * Fp_factor;
718
- }
698
+ calculate_s_one<8 >(n, n_max_angular_plus_1, Fp, sum_fxyz, s8);
719
699
accumulate_f12_one<8 >(d12inv, fn_original, fnp_original, s8, r12unit, f12);
720
700
}
721
701
}
0 commit comments