Commit e58ada4

Merge pull request brucefan1983#907 from Yi-FanLi/avirial
Support fitting atomic virial, dipole, and polarizability
2 parents fc74dbe + 68ca346 commit e58ada4

File tree

10 files changed: +392 -20 lines changed

doc/nep/input_files/nep_in.rst

+2
@@ -66,6 +66,8 @@ Keywords
      - weight of force loss term
    * - :ref:`lambda_v <kw_lambda_v>`
      - weight of virial loss term
+   * - :ref:`atomic_v <kw_atomic_v>`
+     - fit atomic or global virial
    * - :ref:`force_delta <kw_force_delta>`
      - bias term that can be used to make smaller forces more accurate
    * - :ref:`batch <kw_batch>`

doc/nep/input_parameters/atomic_v.rst

+20
@@ -0,0 +1,20 @@
+.. _kw_atomic_v:
+.. index::
+   single: atomic_v (keyword in nep.in)
+
+:attr:`atomic_v`
+================
+
+This keyword selects whether to fit atomic or global quantities for dipole (`model_type = 1`) or polarizability (`model_type = 2`) models. Only one of the two can be fitted at a time; fitting both simultaneously is not supported. For the virial tensor (`model_type = 0`), only the global model is supported.
+The syntax is::
+
+    atomic_v <mode>
+
+where :attr:`<mode>` must be an integer that can assume one of the following values.
+
+===== ===========================
+Value Mode
+----- ---------------------------
+0     fit global tensor (default)
+1     fit atomic tensor
+===== ===========================
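For concreteness, a minimal pair of nep.in lines that would request an atomic polarizability fit (a sketch only; the other keywords required for a complete input are omitted):

    model_type 2
    atomic_v 1

With `atomic_v 0`, or with the keyword omitted, the same model is instead fitted to the global polarizability tensor.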

src/main_nep/dataset.cu

+204
@@ -21,6 +21,8 @@
 #include "utilities/gpu_macro.cuh"
 #include "utilities/nep_utilities.cuh"
 #include <cstring>
+#include <iostream>
+#include <stdexcept>

 void Dataset::copy_structures(std::vector<Structure>& structures_input, int n1, int n2)
 {
@@ -32,6 +34,8 @@ void Dataset::copy_structures(std::vector<Structure>& structures_input, int n1,
     structures[n].num_atom = structures_input[n_input].num_atom;
     structures[n].weight = structures_input[n_input].weight;
     structures[n].has_virial = structures_input[n_input].has_virial;
+    structures[n].has_atomic_virial = structures_input[n_input].has_atomic_virial;
+    structures[n].atomic_virial_diag_only = structures_input[n_input].atomic_virial_diag_only;
     structures[n].charge = structures_input[n_input].charge;
     structures[n].energy = structures_input[n_input].energy;
     structures[n].energy_weight = structures_input[n_input].energy_weight;
@@ -68,6 +72,33 @@ void Dataset::copy_structures(std::vector<Structure>& structures_input, int n1,
       structures[n].fy[na] = structures_input[n_input].fy[na];
       structures[n].fz[na] = structures_input[n_input].fz[na];
     }
+
+    if (structures[n].has_atomic_virial != structures[0].has_atomic_virial) {
+      throw std::runtime_error("All structures must have the same has_atomic_virial flag.");
+    }
+    if (structures[n].atomic_virial_diag_only != structures[0].atomic_virial_diag_only) {
+      throw std::runtime_error("All structures must have the same atomic_virial_diag_only flag.");
+    }
+    if (structures[n].has_atomic_virial) {
+      structures[n].avirialxx.resize(structures[n].num_atom);
+      structures[n].avirialyy.resize(structures[n].num_atom);
+      structures[n].avirialzz.resize(structures[n].num_atom);
+      for (int na = 0; na < structures[n].num_atom; ++na) {
+        structures[n].avirialxx[na] = structures_input[n_input].avirialxx[na];
+        structures[n].avirialyy[na] = structures_input[n_input].avirialyy[na];
+        structures[n].avirialzz[na] = structures_input[n_input].avirialzz[na];
+      }
+      if (!structures[n].atomic_virial_diag_only) {
+        structures[n].avirialxy.resize(structures[n].num_atom);
+        structures[n].avirialyz.resize(structures[n].num_atom);
+        structures[n].avirialzx.resize(structures[n].num_atom);
+        for (int na = 0; na < structures[n].num_atom; ++na) {
+          structures[n].avirialxy[na] = structures_input[n_input].avirialxy[na];
+          structures[n].avirialyz[na] = structures_input[n_input].avirialyz[na];
+          structures[n].avirialzx[na] = structures_input[n_input].avirialzx[na];
+        }
+      }
+    }
   }
 }

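The two checks added above enforce a batch-wide invariant: every structure must agree on whether atomic virials are present and whether they are diagonal-only, since later allocations and kernel selection read these flags from `structures[0]` alone. A trimmed-down, standalone illustration of the same invariant (the `S` struct and `check_batch_flags` helper are hypothetical, not repository code):

#include <stdexcept>
#include <vector>

struct S {
  bool has_atomic_virial;
  bool atomic_virial_diag_only;
};

// Hypothetical checker mirroring the invariant above: every entry must
// match the first one, otherwise the batch is rejected.
void check_batch_flags(const std::vector<S>& batch)
{
  for (const S& s : batch) {
    if (s.has_atomic_virial != batch[0].has_atomic_virial) {
      throw std::runtime_error("All structures must have the same has_atomic_virial flag.");
    }
    if (s.atomic_virial_diag_only != batch[0].atomic_virial_diag_only) {
      throw std::runtime_error("All structures must have the same atomic_virial_diag_only flag.");
    }
  }
}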
@@ -142,6 +173,9 @@ void Dataset::initialize_gpu_data(Parameters& para)
   energy_weight_cpu.resize(Nc);
   virial_ref_cpu.resize(Nc * 6);
   force_ref_cpu.resize(N * 3);
+  if (structures[0].has_atomic_virial) {
+    avirial_ref_cpu.resize(N * (structures[0].atomic_virial_diag_only ? 3 : 6));
+  }
   temperature_ref_cpu.resize(N);

   for (int n = 0; n < Nc; ++n) {
@@ -170,6 +204,16 @@
       force_ref_cpu[Na_sum_cpu[n] + na + N] = structures[n].fy[na];
       force_ref_cpu[Na_sum_cpu[n] + na + N * 2] = structures[n].fz[na];
       temperature_ref_cpu[Na_sum_cpu[n] + na] = structures[n].temperature;
+      if (structures[n].has_atomic_virial) {
+        avirial_ref_cpu[Na_sum_cpu[n] + na] = structures[n].avirialxx[na];
+        avirial_ref_cpu[Na_sum_cpu[n] + na + N] = structures[n].avirialyy[na];
+        avirial_ref_cpu[Na_sum_cpu[n] + na + N * 2] = structures[n].avirialzz[na];
+        if (!structures[n].atomic_virial_diag_only) {
+          avirial_ref_cpu[Na_sum_cpu[n] + na + N * 3] = structures[n].avirialxy[na];
+          avirial_ref_cpu[Na_sum_cpu[n] + na + N * 4] = structures[n].avirialyz[na];
+          avirial_ref_cpu[Na_sum_cpu[n] + na + N * 5] = structures[n].avirialzx[na];
+        }
+      }
     }
   }

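The stores above fix the memory layout of `avirial_ref_cpu`: it is component-major, with each of the 3 (diagonal-only) or 6 components occupying a contiguous block of length N (the total atom count of the batch), in the order xx, yy, zz, xy, yz, zx. Component c of the atom with global index i therefore lives at `c * N + i`. A one-line hypothetical helper making that convention explicit (not part of the repository):

#include <vector>

// Component-major layout: component c (0..2 or 0..5) of global atom i,
// with stride N between consecutive components.
inline float& avirial_component(std::vector<float>& buf, int c, int i, int N)
{
  return buf[c * N + i];
}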
@@ -179,13 +223,19 @@
   energy_weight_gpu.resize(Nc);
   virial_ref_gpu.resize(Nc * 6);
   force_ref_gpu.resize(N * 3);
+  if (structures[0].has_atomic_virial) {
+    avirial_ref_gpu.resize(N * (structures[0].atomic_virial_diag_only ? 3 : 6));
+  }
   temperature_ref_gpu.resize(N);
   type_weight_gpu.copy_from_host(para.type_weight_cpu.data());
   charge_ref_gpu.copy_from_host(charge_ref_cpu.data());
   energy_ref_gpu.copy_from_host(energy_ref_cpu.data());
   energy_weight_gpu.copy_from_host(energy_weight_cpu.data());
   virial_ref_gpu.copy_from_host(virial_ref_cpu.data());
   force_ref_gpu.copy_from_host(force_ref_cpu.data());
+  if (structures[0].has_atomic_virial) {
+    avirial_ref_gpu.copy_from_host(avirial_ref_cpu.data());
+  }
   temperature_ref_gpu.copy_from_host(temperature_ref_cpu.data());

   box.resize(Nc * 18);
@@ -447,6 +497,157 @@ std::vector<float> Dataset::get_rmse_force(Parameters& para, const bool use_weig
   return rmse_array;
 }

+static __global__ void gpu_sum_avirial_diag_only_error(
+  const int N,
+  int* g_Na,
+  int* g_Na_sum,
+  int* g_type,
+  float* g_type_weight,
+  float* g_virial,
+  float* g_avxx_ref,
+  float* g_avyy_ref,
+  float* g_avzz_ref,
+  float* error_gpu)
+{
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int N1 = g_Na_sum[bid];
+  int N2 = N1 + g_Na[bid];
+  extern __shared__ float s_error[];
+  s_error[tid] = 0.0f;
+
+  for (int n = N1 + tid; n < N2; n += blockDim.x) {
+    float avxx_ref = g_avxx_ref[n];
+    float avyy_ref = g_avyy_ref[n];
+    float avzz_ref = g_avzz_ref[n];
+    float dxx = g_virial[n] - avxx_ref;
+    float dyy = g_virial[1 * N + n] - avyy_ref;
+    float dzz = g_virial[2 * N + n] - avzz_ref;
+    float diff_square = dxx * dxx + dyy * dyy + dzz * dzz;
+    s_error[tid] += diff_square;
+  }
+  __syncthreads();
+
+  for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      s_error[tid] += s_error[tid + offset];
+    }
+    __syncthreads();
+  }
+
+  if (tid == 0) {
+    error_gpu[bid] = s_error[0];
+  }
+}
+
+static __global__ void gpu_sum_avirial_error(
+  const int N,
+  int* g_Na,
+  int* g_Na_sum,
+  int* g_type,
+  float* g_type_weight,
+  float* g_virial,
+  float* g_avxx_ref,
+  float* g_avyy_ref,
+  float* g_avzz_ref,
+  float* g_avxy_ref,
+  float* g_avyz_ref,
+  float* g_avzx_ref,
+  float* error_gpu)
+{
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int N1 = g_Na_sum[bid];
+  int N2 = N1 + g_Na[bid];
+  extern __shared__ float s_error[];
+  s_error[tid] = 0.0f;
+
+  for (int n = N1 + tid; n < N2; n += blockDim.x) {
+    float avxx_ref = g_avxx_ref[n];
+    float avyy_ref = g_avyy_ref[n];
+    float avzz_ref = g_avzz_ref[n];
+    float avxy_ref = g_avxy_ref[n];
+    float avyz_ref = g_avyz_ref[n];
+    float avzx_ref = g_avzx_ref[n];
+    float dxx = g_virial[n] - avxx_ref;
+    float dyy = g_virial[1 * N + n] - avyy_ref;
+    float dzz = g_virial[2 * N + n] - avzz_ref;
+    float dxy = g_virial[3 * N + n] - avxy_ref;
+    float dyz = g_virial[4 * N + n] - avyz_ref;
+    float dzx = g_virial[5 * N + n] - avzx_ref;
+    float diff_square = dxx * dxx + dyy * dyy + dzz * dzz + dxy * dxy + dyz * dyz + dzx * dzx;
+    s_error[tid] += diff_square;
+  }
+  __syncthreads();
+
+  for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      s_error[tid] += s_error[tid + offset];
+    }
+    __syncthreads();
+  }
+
+  if (tid == 0) {
+    error_gpu[bid] = s_error[0];
+  }
+}
+
+std::vector<float> Dataset::get_rmse_avirial(Parameters& para, const bool use_weight, int device_id)
+{
+  CHECK(gpuSetDevice(device_id));
+  const int block_size = 256;
+
+  if (structures[0].atomic_virial_diag_only) {
+    gpu_sum_avirial_diag_only_error<<<Nc, block_size, sizeof(float) * block_size>>>(
+      N,
+      Na.data(),
+      Na_sum.data(),
+      type.data(),
+      type_weight_gpu.data(),
+      virial.data(),
+      avirial_ref_gpu.data(),
+      avirial_ref_gpu.data() + N,
+      avirial_ref_gpu.data() + N * 2,
+      error_gpu.data());
+  } else {
+    gpu_sum_avirial_error<<<Nc, block_size, sizeof(float) * block_size>>>(
+      N,
+      Na.data(),
+      Na_sum.data(),
+      type.data(),
+      type_weight_gpu.data(),
+      virial.data(),
+      avirial_ref_gpu.data(),
+      avirial_ref_gpu.data() + N,
+      avirial_ref_gpu.data() + N * 2,
+      avirial_ref_gpu.data() + N * 3,
+      avirial_ref_gpu.data() + N * 4,
+      avirial_ref_gpu.data() + N * 5,
+      error_gpu.data());
+  }
+  int mem = sizeof(float) * Nc;
+  CHECK(gpuMemcpy(error_cpu.data(), error_gpu.data(), mem, gpuMemcpyDeviceToHost));
+
+  std::vector<float> rmse_array(para.num_types + 1, 0.0f);
+  std::vector<int> count_array(para.num_types + 1, 0);
+  for (int n = 0; n < Nc; ++n) {
+    float rmse_temp = use_weight ? weight_cpu[n] * weight_cpu[n] * error_cpu[n] : error_cpu[n];
+    for (int t = 0; t < para.num_types + 1; ++t) {
+      if (has_type[t * Nc + n]) {
+        rmse_array[t] += rmse_temp;
+        count_array[t] += Na_cpu[n];
+      }
+    }
+  }
+
+  for (int t = 0; t <= para.num_types; ++t) {
+    if (count_array[t] > 0) {
+      rmse_array[t] = sqrt(rmse_array[t] / (count_array[t] * 6));
+    }
+  }
+  return rmse_array;
+}
+
 static __global__ void
 gpu_get_energy_shift(
   int* g_Na,
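Both kernels above share one structure: one thread block per training structure (the launch uses `Nc` blocks), a strided loop in which each thread accumulates squared component errors for a subset of the structure's atoms, and a shared-memory tree reduction, valid because the launch uses a power-of-two `block_size`. (The `g_type` and `g_type_weight` parameters are accepted but unused in the bodies shown.) A self-contained CUDA sketch of just that segmented-reduction skeleton, with hypothetical names:

__global__ void block_segment_sum(const float* g_data, const int* g_begin, const int* g_count, float* g_out)
{
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int n1 = g_begin[bid];        // first element of this block's segment
  int n2 = n1 + g_count[bid];   // one past the last element
  extern __shared__ float s_sum[];
  s_sum[tid] = 0.0f;
  for (int n = n1 + tid; n < n2; n += blockDim.x) {
    s_sum[tid] += g_data[n];    // per-thread partial sum over the segment
  }
  __syncthreads();
  // tree reduction; assumes blockDim.x is a power of two
  for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
    if (tid < offset) {
      s_sum[tid] += s_sum[tid + offset];
    }
    __syncthreads();
  }
  if (tid == 0) {
    g_out[bid] = s_sum[0];      // one result per block/segment
  }
}

It would be launched as `block_segment_sum<<<num_segments, 256, sizeof(float) * 256>>>(...)`, matching the dynamic shared-memory size to the block size, exactly as the kernels above do.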
@@ -625,6 +826,9 @@ static __global__ void gpu_sum_virial_error(

 std::vector<float> Dataset::get_rmse_virial(Parameters& para, const bool use_weight, int device_id)
 {
+  if (para.atomic_v) {
+    return get_rmse_avirial(para, use_weight, device_id);
+  }
   CHECK(gpuSetDevice(device_id));

   std::vector<float> rmse_array(para.num_types + 1, 0.0f);
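This early return routes the existing virial-RMSE path to the atomic quantities whenever `atomic_v` is enabled, so callers of `get_rmse_virial` need no changes. Reading off `get_rmse_avirial` above, the reported quantity for type t in full-tensor mode is

    \mathrm{RMSE}_t = \sqrt{\frac{1}{6 N_t} \sum_{c \ni t} w_c^2 \sum_{i \in c} \sum_{\alpha\beta} \left( W_i^{\alpha\beta} - W_{i,\mathrm{ref}}^{\alpha\beta} \right)^2}

where the outer sum runs over structures c containing type t, N_t is the total atom count of those structures, the inner sum runs over the stored tensor components, and the per-structure weight w_c enters only when use_weight is true. As written, the normalization divides by six components in both modes, even though only three components enter the sum in the diagonal-only case.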

src/main_nep/dataset.cuh

+5
@@ -43,22 +43,26 @@ public:
   GPU_Vector<float> energy;  // calculated energy in GPU
   GPU_Vector<float> virial;  // calculated virial in GPU
   GPU_Vector<float> force;   // calculated force in GPU
+  GPU_Vector<float> avirial; // calculated atomic virial in GPU
   std::vector<float> charge_cpu; // calculated charge in CPU
   std::vector<float> energy_cpu; // calculated energy in CPU
   std::vector<float> virial_cpu; // calculated virial in CPU
   std::vector<float> force_cpu;  // calculated force in CPU
+  std::vector<float> avirial_cpu; // calculated atomic virial in CPU

   GPU_Vector<float> energy_weight_gpu; // energy weight in GPU
   GPU_Vector<float> charge_ref_gpu;    // reference charge in GPU
   GPU_Vector<float> energy_ref_gpu;    // reference energy in GPU
   GPU_Vector<float> virial_ref_gpu;    // reference virial in GPU
   GPU_Vector<float> force_ref_gpu;     // reference force in GPU
+  GPU_Vector<float> avirial_ref_gpu;   // reference atomic virial in GPU
   GPU_Vector<float> temperature_ref_gpu; // reference temperature in GPU
   std::vector<float> energy_weight_cpu;  // energy weight in CPU
   std::vector<float> charge_ref_cpu;     // reference charge in CPU
   std::vector<float> energy_ref_cpu;     // reference energy in CPU
   std::vector<float> virial_ref_cpu;     // reference virial in CPU
   std::vector<float> force_ref_cpu;      // reference force in CPU
+  std::vector<float> avirial_ref_cpu;    // reference atomic virial in CPU
   std::vector<float> weight_cpu;         // configuration weight in CPU
   std::vector<float> temperature_ref_cpu; // reference temperature in CPU

@@ -81,6 +85,7 @@ public:
     const bool do_shift,
     int device_id);
   std::vector<float> get_rmse_virial(Parameters& para, const bool use_weight, int device_id);
+  std::vector<float> get_rmse_avirial(Parameters& para, const bool use_weight, int device_id);
   std::vector<float> get_rmse_charge(Parameters& para, int device_id);

 private:
