@@ -74,7 +74,7 @@ def add_hookean_constraint(image, des_rt = 2., rec_rt = 1., spring_constant=7.5,
74
74
image .set_constraint (cons )
75
75
76
76
77
- def cal_slab_energy (data , calc , traj_output = False , debug = False ,refixed = False ):
77
+ def cal_slab_energy (data , calc , traj_output = False , debug = False ,refixed = False , add_spring = True ):
78
78
79
79
testobj = Atoms (data .atomic_numbers , positions = data .pos , tags = data .tags ,
80
80
cell = data .cell .squeeze (), pbc = True )
@@ -95,9 +95,10 @@ def cal_slab_energy(data, calc, traj_output=False, debug=False,refixed=False):
95
95
c = FixAtoms (mask = data .fixed )
96
96
testobj .set_constraint (c )
97
97
98
- # added spring constant to prevent massive
99
- # surface reconstruction and desorption
100
- add_hookean_constraint (testobj )
98
+ if add_spring :
99
+ # added spring constant to prevent massive
100
+ # surface reconstruction and desorption
101
+ add_hookean_constraint (testobj )
101
102
102
103
if traj_output == True :
103
104
os .makedirs ("./trajs" , exist_ok = True )
@@ -120,10 +121,10 @@ def cal_slab_energy(data, calc, traj_output=False, debug=False,refixed=False):
120
121
return unrelax_slab_energy , relax_slab_energy , forces , pos_relaxed
121
122
122
123
123
- def add_info (data , calc , debug = False , traj_output = False ,refixed = False ):
124
+ def add_info (data , calc , debug = False , traj_output = False ,refixed = False , add_spring = True ):
124
125
125
126
unrelax_slab_energy , relax_slab_energy ,forces , pos_relaxed = \
126
- cal_slab_energy (data , calc , traj_output = traj_output , debug = debug ,refixed = refixed )
127
+ cal_slab_energy (data , calc , traj_output = traj_output , debug = debug ,refixed = refixed , add_spring = add_spring )
127
128
128
129
data .y = relax_slab_energy
129
130
data .unrelax_energy = unrelax_slab_energy
@@ -140,7 +141,7 @@ class MyThread(threading.Thread):
140
141
max_threads = 8
141
142
thread_limiter = threading .BoundedSemaphore (max_threads )
142
143
143
- def __init__ (self , datalist , pathname , gpus = 0 , debug = False , skip_ads = None ,refixed = False ):
144
+ def __init__ (self , datalist , pathname , gpus = 0 , debug = False , skip_ads = None ,refixed = False , add_spring = True ):
144
145
145
146
threading .Thread .__init__ (self )
146
147
@@ -149,6 +150,7 @@ def __init__(self, datalist, pathname, gpus=0, debug=False, skip_ads=None,refixe
149
150
self .gpus = gpus
150
151
self .debug = debug
151
152
self .refixed = refixed
153
+ self .add_spring = add_spring
152
154
153
155
# Make a list of rid that have already been done and converged to be skipped
154
156
if os .path .isfile (pathname ):
@@ -182,7 +184,7 @@ def run(self):
182
184
continue
183
185
# run predictions here
184
186
try :
185
- data = add_info (data , calc , debug = self .debug ,refixed = self .refixed )
187
+ data = add_info (data , calc , debug = self .debug ,refixed = self .refixed , add_spring = self . add_spring )
186
188
except RuntimeError :
187
189
continue
188
190
data_list_E .append (data )
0 commit comments