Skip to content

Commit af2f777

Browse files
committed
Add hardcore potential only on exchange term
1 parent 6ae17c7 commit af2f777

File tree

2 files changed

+40
-77
lines changed

2 files changed

+40
-77
lines changed

dmff/admp/pairwise.py

100755100644
+29-8
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def generate_pairwise_interaction(pair_int_kernel, static_args):
6363
with the order in kernel
6464
'''
6565

66-
def pair_int(positions, box, pairs, mScales, *atomic_params):
66+
def pair_int(positions, box, pairs, mScales, *atomic_params):
6767
# pairs = regularize_pairs(pairs)
6868
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
6969

@@ -77,7 +77,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params):
7777
buffer_scales = pair_buffer_scales(pairs)
7878
mscales = mscales * buffer_scales
7979
# mscales = mScales[nbonds-1]
80-
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
80+
box_inv = jnp.linalg.inv(box)
8181
dr = ri - rj
8282
dr = v_pbc_shift(dr, box, box_inv)
8383
dr = jnp.linalg.norm(dr, axis=1)
@@ -89,7 +89,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params):
8989
# pair_params.append(param[pairs[:, 0]])
9090
# pair_params.append(param[pairs[:, 1]])
9191

92-
energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales)
92+
energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales)
9393
return energy
9494

9595
return pair_int
@@ -155,7 +155,9 @@ def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j):
155155

156156
@vmap
157157
@jit_condition(static_argnums=())
158-
def slater_sr_kernel(dr, m, ai, aj, bi, bj):
158+
# with hardcore potential
159+
def slater_sr_hc_kernel(dr, m, ai, aj, bi, bj):
160+
159161
'''
160162
Slater-ISA type short range terms
161163
see jctc 12 3851
@@ -165,11 +167,30 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj):
165167
br = b * dr
166168
br2 = br * br
167169
P = 1/3 * br2 + br + 1
168-
# hard core potential
169-
x = 3.9/br
170+
171+
alpha = 0.24
172+
beta = 14
173+
x = alpha * br
170174
x2 = x * x
171175
x4 = x2 * x2
172176
x8 = x4 * x4
173-
x14 = x8 * x2 * x4
174-
return a * (P * jnp.exp(-br) + x14) * m
177+
x12 = x4 * x8
178+
x14 = x12 * x2
179+
HardCorePotential = a / x14 * m
180+
return a * P * jnp.exp(-br) * m + HardCorePotential
181+
182+
@vmap
183+
@jit_condition(static_argnums=())
184+
def slater_sr_kernel(dr, m, ai, aj, bi, bj):
185+
186+
'''
187+
Slater-ISA type short range terms
188+
see jctc 12 3851
189+
'''
190+
b = jnp.sqrt(bi * bj)
191+
a = ai * aj
192+
br = b * dr
193+
br2 = br * br
194+
P = 1/3 * br2 + br + 1
175195

196+
return a * P * jnp.exp(-br) * m

dmff/generators/admp.py

+11-69
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
TT_damping_qq_c6_kernel,
1414
generate_pairwise_interaction,
1515
slater_disp_damping_kernel,
16-
slater_sr_kernel,
16+
slater_sr_kernel, ## no Hard Core Potential
17+
slater_sr_hc_kernel, ## added Hard Core Potential
1718
TT_damping_qq_kernel,
1819
)
1920
from ..admp.pme import ADMPPmeForce
@@ -759,20 +760,21 @@ def createPotential(
759760

760761
topdata._meta[self.name+"_map_atomtype"] = map_atomtype
761762

762-
pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={})
763+
pot_fn_sr = generate_pairwise_interaction(slater_sr_hc_kernel, static_args={})
764+
#slater_ex_sr_kernel: added Hard Core Potential
763765

764766
has_aux = False
765767
if "has_aux" in kwargs and kwargs["has_aux"]:
766768
has_aux = True
767769

768-
def potential_fn(positions, box, pairs, params, aux=None):
770+
def potential_fn(positions, box, pairs, params, aux=None):
769771
positions = positions * 10
770772
box = box * 10
771773
params = params[self.name]
772774
a_list = params["A"][map_atomtype]
773775
b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1
774776

775-
energy = pot_fn_sr(positions, box, pairs, self.mScales, a_list, b_list)
777+
energy = pot_fn_sr(positions, box, pairs, self.mScales, a_list, b_list)
776778
if has_aux:
777779
return energy, aux
778780
else:
@@ -790,6 +792,7 @@ def getJaxPotential(self):
790792
_DMFFGenerators["SlaterExForce"] = SlaterExGenerator
791793

792794

795+
793796
# Here are all the short range "charge penetration" terms
794797
# They all have the exchange form with minus sign
795798
class SlaterSrEsGenerator(SlaterExGenerator):
@@ -798,7 +801,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet, default_name=None):
798801
super().__init__(ffinfo, paramset, default_name="SlaterSrEsForce")
799802
else:
800803
super().__init__(ffinfo, paramset, default_name=default_name)
801-
802804
def createPotential(
803805
self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs
804806
):
@@ -812,14 +814,14 @@ def createPotential(
812814

813815
topdata._meta[self.name+"_map_atomtype"] = map_atomtype
814816

815-
pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel,
816-
static_args={})
817+
pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={})
818+
## slater_sr_others_kernel: no Hard Core Potential
817819

818820
has_aux = False
819821
if "has_aux" in kwargs and kwargs["has_aux"]:
820822
has_aux = True
821823

822-
def potential_fn(positions, box, pairs, params, aux=None):
824+
def potential_fn(positions, box, pairs, params, aux=None):
823825
positions = positions * 10
824826
box = box * 10
825827
params = params[self.name]
@@ -934,10 +936,7 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
934936
for node in self.ffinfo["Forces"][self.name]["node"]
935937
if node["name"] in ["Multipole", "Atom"]
936938
]
937-
c0, dX, dY, dZ, qXX, qYY, qZZ, qXY, qXZ, qYZ, oXXX, oXXY, oXYY, oYYY, oXXZ, oXYZ, oYYZ, oXZZ, oYZZ, oZZZ = (
938-
[],
939-
[],
940-
[],
939+
c0, dX, dY, dZ, qXX, qYY, qZZ, qXY, qXZ, qYZ = (
941940
[],
942941
[],
943942
[],
@@ -948,13 +947,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
948947
[],
949948
[],
950949
[],
951-
[],
952-
[],
953-
[],
954-
[],
955-
[],
956-
[],
957-
[]
958950
)
959951
kxs, kys, kzs = [], [], []
960952
multipole_masks = []
@@ -997,29 +989,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
997989
qXY.append(0.0)
998990
qXZ.append(0.0)
999991
qYZ.append(0.0)
1000-
if self.lmax >= 3:
1001-
oXXX.append(float(attribs["oXXX"]))
1002-
oXXY.append(float(attribs["oXXY"]))
1003-
oXYY.append(float(attribs["oXYY"]))
1004-
oYYY.append(float(attribs["oYYY"]))
1005-
oXXZ.append(float(attribs["oXXZ"]))
1006-
oXYZ.append(float(attribs["oXYZ"]))
1007-
oYYZ.append(float(attribs["oYYZ"]))
1008-
oXZZ.append(float(attribs["oXZZ"]))
1009-
oYZZ.append(float(attribs["oYZZ"]))
1010-
oZZZ.append(float(attribs["oZZZ"]))
1011-
else:
1012-
oXXX.append(0.0)
1013-
oXXY.append(0.0)
1014-
oXYY.append(0.0)
1015-
oYYY.append(0.0)
1016-
oXXZ.append(0.0)
1017-
oXYZ.append(0.0)
1018-
oYYZ.append(0.0)
1019-
oXZZ.append(0.0)
1020-
oYZZ.append(0.0)
1021-
oZZZ.append(0.0)
1022-
1023992
mask = 1.0
1024993
if "mask" in attribs and attribs["mask"].upper() == "TRUE":
1025994
mask = 0.0
@@ -1077,8 +1046,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
10771046
n_mtps = 4
10781047
elif self.lmax == 2:
10791048
n_mtps = 10
1080-
elif self.lmax == 3:
1081-
n_mtps = 20
10821049
Q = np.zeros((n_atoms, n_mtps))
10831050

10841051
# TDDO: unit conversion
@@ -1096,19 +1063,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
10961063
Q[:, 8] = qXZ
10971064
Q[:, 9] = qYZ
10981065
Q[:, 4:10] *= 300
1099-
if self.lmax >= 3:
1100-
Q[:, 10] = oXXX
1101-
Q[:, 11] = oXXY
1102-
Q[:, 12] = oXYY
1103-
Q[:, 13] = oYYY
1104-
Q[:, 14] = oXXZ
1105-
Q[:, 15] = oXYZ
1106-
Q[:, 16] = oYYZ
1107-
Q[:, 17] = oXZZ
1108-
Q[:, 18] = oYZZ
1109-
Q[:, 19] = oZZZ
1110-
# TO DO: To be decided
1111-
Q[:, 10:20] *= 15000
11121066

11131067
# add all differentiable params to self.params
11141068
Q_local = convert_cart2harm(jnp.array(Q), self.lmax)
@@ -1138,18 +1092,6 @@ def overwrite(self, paramset):
11381092
node["attrib"]["qXY"] = Q_global[n_multipole, 7] / 300.0
11391093
node["attrib"]["qXZ"] = Q_global[n_multipole, 8] / 300.0
11401094
node["attrib"]["qYZ"] = Q_global[n_multipole, 9] / 300.0
1141-
if self.lmax >= 3:
1142-
node["attrib"]["oXXX"] = Q_global[n_multipole, 10] / 15000.0
1143-
node["attrib"]["oXXY"] = Q_global[n_multipole, 11] / 15000.0
1144-
node["attrib"]["oXYY"] = Q_global[n_multipole, 12] / 15000.0
1145-
node["attrib"]["oYYY"] = Q_global[n_multipole, 13] / 15000.0
1146-
node["attrib"]["oXXZ"] = Q_global[n_multipole, 14] / 15000.0
1147-
node["attrib"]["oXYZ"] = Q_global[n_multipole, 15] / 15000.0
1148-
node["attrib"]["oYYZ"] = Q_global[n_multipole, 16] / 15000.0
1149-
node["attrib"]["oXZZ"] = Q_global[n_multipole, 17] / 15000.0
1150-
node["attrib"]["oYZZ"] = Q_global[n_multipole, 18] / 15000.0
1151-
node["attrib"]["oZZZ"] = Q_global[n_multipole, 19] / 15000.0
1152-
11531095
if q_local_masks[n_multipole] < 0.999:
11541096
node["mask"] = "true"
11551097
n_multipole += 1

0 commit comments

Comments
 (0)