Skip to content

Commit 8bba83a

Browse files
authored
Improve gen grids (#67)
* save * implement SMD solvent model * updated readme * fixed issues after v0.6.9 * format examples * output = /dev/null in test_smd.py * evaluate sparse ao directly * consistent grids with PySCF * update README.md
1 parent 5d660d8 commit 8bba83a

24 files changed

+403
-249
lines changed

README.md

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,22 @@ Installation
44
--------
55

66
> [!NOTE]
7-
> The compiled binary packages support compute capability 7.0 and later (Volta and later, such as Tesla V100, RTX 20 series and later). For older GPUs, please compile the package with the source code as follows.
7+
> The compiled binary packages support compute capability 7.0 and later (Volta and later, such as Tesla V100, RTX 20 series and later). For older GPUs (such as the GTX 10 series or Tesla P100), please compile the package from source as follows.
88
9-
For **CUDA 11.x**
10-
```sh
11-
pip3 install gpu4pyscf-cuda11x
12-
```
13-
and install cutensor
14-
```sh
15-
python -m cupyx.tools.install_library --cuda 11.x --library cutensor
16-
```
9+
Run ```nvidia-smi``` in your terminal to check the installed CUDA version.
1710

18-
For **CUDA 12.x**
19-
```sh
20-
pip3 install gpu4pyscf-cuda12x
21-
```
22-
and install cutensor
23-
```sh
24-
python -m cupyx.tools.install_library --cuda 12.x --library cutensor
25-
```
11+
Choose the proper package based on your CUDA environment.
12+
13+
| Platform | Command |
14+
| --------------| --------------------------------------|
15+
| **CUDA 11.x** | ```pip3 install gpu4pyscf-cuda11x``` |
16+
| **CUDA 12.x** | ```pip3 install gpu4pyscf-cuda12x``` |
17+
18+
Installing ```cuTensor``` is **highly recommended** to accelerate tensor contractions.
19+
20+
For **CUDA 11.x**, ```python -m cupyx.tools.install_library --cuda 11.x --library cutensor```
21+
22+
For **CUDA 12.x**, ```python -m cupyx.tools.install_library --cuda 12.x --library cutensor```
2623

2724
Compilation
2825
--------

examples/dft_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
mf_df.nlcgrids.atom_grid = (50,194)
5252
mf_df.direct_scf_tol = 1e-14
5353
mf_df.direct_scf = 1e-14
54-
mf_df.conv_tol = 1e-12
54+
mf_df.conv_tol = 1e-10
5555
e_tot = mf_df.kernel()
5656
scf_time = time.time() - start_time
5757
print(f'compute time for energy: {scf_time:.3f} s')

gpu4pyscf/__config__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
GB = 1024*1024*1024
55
# such as A100-80G
66
if props['totalGlobalMem'] >= 64 * GB:
7-
min_ao_blksize = 256
7+
min_ao_blksize = 128
88
min_grid_blksize = 128*128
99
ao_aligned = 32
1010
grid_aligned = 128

gpu4pyscf/df/tests/test_df_ecp.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@ def setUpModule():
3737
output = '/dev/null'
3838
)
3939

40-
mol.build()
41-
mol.verbose = 3
42-
4340
def tearDownModule():
4441
global mol
4542
mol.stdout.close()

gpu4pyscf/df/tests/test_df_grad.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,18 @@
3535
H 0.7570000000 0.0000000000 -0.4696000000
3636
'''
3737

38-
xc0='B3LYP'
39-
bas0='def2-tzvpp'
40-
auxbasis0='def2-tzvpp-jkfit'
41-
disp0='d3bj'
38+
xc0 = 'B3LYP'
39+
bas0 = 'def2-tzvpp'
40+
auxbasis0 = 'def2-tzvpp-jkfit'
41+
disp0 = 'd3bj'
4242
grids_level = 6
4343
nlcgrids_level = 3
4444
def setUpModule():
4545
global mol
4646
mol = pyscf.M(atom=atom, basis=bas0, max_memory=32000)
4747
mol.output = '/dev/null'
48-
mol.build()
4948
mol.verbose = 1
49+
mol.build()
5050

5151
eps = 1.0/1024
5252

gpu4pyscf/df/tests/test_df_hessian.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
H 0.7570000000 0.0000000000 -0.4696000000
2525
'''
2626

27-
xc0='B3LYP'
28-
bas0='def2-tzvpp'
29-
auxbasis0='def2-tzvpp-jkfit'
30-
disp0='d3bj'
27+
xc0 = 'B3LYP'
28+
bas0 = 'def2-tzvpp'
29+
auxbasis0 = 'def2-tzvpp-jkfit'
30+
disp0 = 'd3bj'
3131
grids_level = 6
3232
eps = 1e-3
3333

gpu4pyscf/df/tests/test_int3c2e.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
libgint = load_library('libgint')
2727

2828
'''
29-
compare int3c2e by pyscf and gpu4pyscf
29+
check int3c2e consistency between pyscf and gpu4pyscf
3030
'''
3131

3232
def setUpModule():
@@ -41,13 +41,13 @@ def setUpModule():
4141
output='/dev/null')
4242
auxmol = df.addons.make_auxmol(mol, auxbasis='def2-tzvpp-jkfit')
4343
auxmol.output = '/dev/null'
44-
44+
4545
def tearDownModule():
4646
global mol, auxmol
4747
mol.stdout.close()
4848
auxmol.stdout.close()
4949
del mol, auxmol
50-
50+
5151
omega = 0.2
5252

5353
def check_int3c2e_derivatives(ip_type):
@@ -69,40 +69,40 @@ def check_int3c2e_derivatives(ip_type):
6969
int3c_pyscf = getints(intor, pmol._atm, pmol._bas, pmol._env, shls_slice, aosym='s1', cintopt=opt)
7070
int3c_gpu = int3c2e.get_int3c2e_general(mol, auxmol, ip_type=ip_type, omega=omega).get()
7171
assert np.linalg.norm(int3c_pyscf - int3c_gpu) < 1e-9
72-
72+
7373
class KnownValues(unittest.TestCase):
7474
def test_int3c2e(self):
7575
get_int3c = _int3c_wrapper(mol, auxmol, 'int3c2e', 's1')
7676
int3c_pyscf = get_int3c((0, mol.nbas, 0, mol.nbas, 0, auxmol.nbas))
7777
int3c_gpu = int3c2e.get_int3c2e(mol, auxmol, aosym='s1').get()
7878
assert np.linalg.norm(int3c_gpu - int3c_pyscf) < 1e-9
79-
79+
8080
def test_int3c2e_omega(self):
8181
omega = 0.2
8282
with mol.with_range_coulomb(omega):
8383
get_int3c = _int3c_wrapper(mol, auxmol, 'int3c2e', 's1')
8484
int3c_pyscf = get_int3c((0, mol.nbas, 0, mol.nbas, 0, auxmol.nbas))
8585
int3c_gpu = int3c2e.get_int3c2e(mol, auxmol, aosym='s1', omega=omega).get()
8686
assert np.linalg.norm(int3c_gpu[0,0,:] - int3c_pyscf[0,0,:]) < 1e-9
87-
87+
8888
def test_int3c2e_ip1(self):
8989
check_int3c2e_derivatives('ip1')
90-
90+
9191
def test_int3c2e_ip2(self):
9292
check_int3c2e_derivatives('ip2')
93-
93+
9494
def test_int3c2e_ipip1(self):
9595
check_int3c2e_derivatives('ipip1')
9696

9797
def test_int3c2e_ipip2(self):
9898
check_int3c2e_derivatives('ipip2')
99-
99+
100100
def test_int3c2e_ip1ip2(self):
101101
check_int3c2e_derivatives('ip1ip2')
102102

103103
def test_int3c2e_ipvip1(self):
104104
check_int3c2e_derivatives('ipvip1')
105-
105+
106106
def test_int1e_iprinv(self):
107107
from pyscf import gto
108108
coords = mol.atom_coords()
@@ -115,7 +115,7 @@ def test_int1e_iprinv(self):
115115
mol.set_rinv_origin(coords[i])
116116
h1ao = mol.intor('int1e_iprinv', comp=3) # <\nabla|1/r|>
117117
assert np.linalg.norm(int3c[:,:,:,i] - h1ao) < 1e-8
118-
118+
119119
if __name__ == "__main__":
120120
print("Full Tests for int3c")
121121
unittest.main()

gpu4pyscf/df/tests/test_jk.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@ def setUpModule():
4141
mol.build()
4242
mol.verbose = 1
4343
auxmol = df.addons.make_auxmol(mol, auxbasis='sto3g')
44-
44+
4545
def tearDownModule():
4646
global mol, auxmol
47+
mol.stdout.close()
4748
del mol, auxmol
4849

4950
class KnownValues(unittest.TestCase):
50-
51+
5152
def test_vj_incore(self):
5253
int3c_gpu = int3c2e.get_int3c2e(mol, auxmol, aosym=True, direct_scf_tol=1e-14)
5354
intopt = int3c2e.VHFOpt(mol, auxmol, 'int2e')
@@ -66,7 +67,7 @@ def test_vj_incore(self):
6667
vj_outcore = cupy.einsum('ijL,L->ij', int3c_gpu, rhoj_outcore)
6768
vj_incore = int3c2e.get_j_int3c2e_pass2(intopt, rhoj_incore)
6869
assert cupy.linalg.norm(vj_outcore - vj_incore) < 1e-9
69-
70+
7071
def test_j_outcore(self):
7172
cupy.random.seed(1)
7273
nao = mol.nao
@@ -77,7 +78,7 @@ def test_j_outcore(self):
7778
vj0, _ = mf.get_jk(dm=dm, with_j=True, with_k=False)
7879
vj = df_jk.get_j(mf.with_df, dm)
7980
assert cupy.linalg.norm(vj - vj0) < 1e-9
80-
81+
8182

8283
if __name__ == "__main__":
8384
print("Full Tests for DF JK")

gpu4pyscf/dft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .uks import UKS
44
from .gks import GKS
55
from .roks import ROKS
6+
from gpu4pyscf.dft.gen_grid import Grids

gpu4pyscf/dft/gen_grid.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,22 @@ def gen_grids_partition(atm_coords, coords, a):
185185
stream = cupy.cuda.get_current_stream()
186186
natm = atm_coords.shape[0]
187187
ngrids = coords.shape[0]
188-
pbecke = cupy.ones([natm, ngrids], order='C')
189188
assert ngrids < 65535 * 16
189+
x_i = cupy.expand_dims(atm_coords, axis=1)
190+
x_g = cupy.expand_dims(coords, axis=0)
191+
squared_diff = (x_i - x_g)**2
192+
dist_ig = cupy.sum(squared_diff, axis=2)**0.5
193+
194+
x_j = cupy.expand_dims(atm_coords, axis=0)
195+
squared_diff = (x_i - x_j)**2
196+
dist_ij = cupy.sum(squared_diff, axis=2)**0.5
197+
198+
pbecke = cupy.ones([natm, ngrids], order='C')
190199
err = libgdft.GDFTgen_grid_partition(
191200
ctypes.cast(stream.ptr, ctypes.c_void_p),
192201
ctypes.cast(pbecke.data.ptr, ctypes.c_void_p),
193-
ctypes.cast(coords.data.ptr, ctypes.c_void_p),
194-
ctypes.cast(atm_coords.data.ptr, ctypes.c_void_p),
202+
ctypes.cast(dist_ig.data.ptr, ctypes.c_void_p),
203+
ctypes.cast(dist_ij.data.ptr, ctypes.c_void_p),
195204
ctypes.cast(a.data.ptr, ctypes.c_void_p),
196205
ctypes.c_int(ngrids),
197206
ctypes.c_int(natm)
@@ -243,13 +252,6 @@ def gen_atomic_grids(mol, atom_grid={}, radi_method=radi.gauss_chebyshev,
243252
logger.debug(mol, 'atom %s rad-grids = %d, ang-grids = %s',
244253
symb, n_rad, angs)
245254

246-
ang_grids = {}
247-
for n in sorted(set(angs)):
248-
grid = numpy.empty((n,4))
249-
libdft.MakeAngularGrid(grid.ctypes.data_as(ctypes.c_void_p),
250-
ctypes.c_int(n))
251-
ang_grids[n] = grid
252-
253255
angs = numpy.array(angs)
254256
coords = []
255257
vol = []
@@ -258,8 +260,13 @@ def gen_atomic_grids(mol, atom_grid={}, radi_method=radi.gauss_chebyshev,
258260
libdft.MakeAngularGrid(grid.ctypes.data_as(ctypes.c_void_p),
259261
ctypes.c_int(n))
260262
idx = numpy.where(angs==n)[0]
261-
coords.append(cupy.einsum('i,jk->jik', rad[idx], grid[:,:3]).reshape(-1,3))
262-
vol.append(cupy.einsum('i,j->ji', rad_weight[idx], grid[:,3]).ravel())
263+
for i0, i1 in lib.prange(0, len(idx), 12): # 12 radi-grids as a group
264+
coords.append(numpy.einsum('i,jk->jik',rad[idx[i0:i1]],
265+
grid[:,:3]).reshape(-1,3))
266+
vol.append(numpy.einsum('i,j->ji', rad_weight[idx[i0:i1]],
267+
grid[:,3]).ravel())
268+
#coords.append(cupy.einsum('i,jk->jik', rad[idx], grid[:,:3]).reshape(-1,3))
269+
#vol.append(cupy.einsum('i,j->ji', rad_weight[idx], grid[:,3]).ravel())
263270

264271
atom_grids_tab[symb] = (cupy.vstack(coords), cupy.hstack(vol))
265272

@@ -327,6 +334,7 @@ def gen_grid_partition(coords):
327334

328335
coords_all = []
329336
weights_all = []
337+
assert radii_adjust == radi.treutler_atomic_radii_adjust
330338
a = -radi.get_treutler_fac(mol, atomic_radii)
331339
for ia in range(mol.natm):
332340
coords, vol = atom_grids_tab[mol.atom_symbol(ia)]

gpu4pyscf/dft/numint.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
GRID_BLKSIZE = 32
3535
MIN_BLK_SIZE = getattr(__config__, 'min_grid_blksize', 64*64)
3636
ALIGNED = getattr(__config__, 'grid_aligned', 16*16)
37+
AO_ALIGNMENT = getattr(__config__, 'ao_aligned', 16)
3738
AO_THRESHOLD = 1e-12
38-
AO_ALIGNMENT = 32
3939

4040
# Should we release the cupy cache?
4141
FREE_CUPY_CACHE = False
@@ -564,6 +564,7 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
564564
nelec = nelec[0]
565565
excsum = excsum[0]
566566
vmat = vmat[0]
567+
567568
return nelec, excsum, vmat#np.asarray(vmat)
568569

569570
def nr_uks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
@@ -984,16 +985,31 @@ def nr_nlc_vxc(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
984985
if opt is None:
985986
ni.build(mol, grids.coords)
986987
opt = ni.gdftopt
988+
989+
mo_coeff = getattr(dms, 'mo_coeff', None)
990+
mo_occ = getattr(dms,'mo_occ', None)
991+
987992
nao, nao0 = opt.coeff.shape
988993
mol = opt.mol
989994
coeff = cupy.asarray(opt.coeff)
990995
dms = [coeff @ dm @ coeff.T for dm in dms.reshape(-1,nao0,nao0)]
996+
assert len(dms) == 1
997+
998+
if mo_coeff is not None:
999+
mo_coeff = coeff @ mo_coeff
1000+
9911001
ao_deriv = 1
9921002
vvrho = []
993-
for ao, mask, weight, coords \
1003+
for ao, idx, weight, coords \
9941004
in ni.block_loop(mol, grids, nao, ao_deriv, max_memory=max_memory):
995-
rho = eval_rho(opt.mol, ao, dms[0][np.ix_(mask,mask)], xctype='GGA', hermi=1)
1005+
#rho = eval_rho(opt.mol, ao, dms[0][np.ix_(mask,mask)], xctype='GGA', hermi=1)
1006+
if mo_coeff is None:
1007+
rho = eval_rho(mol, ao, dms[0][np.ix_(idx,idx)], xctype='GGA', hermi=1)
1008+
else:
1009+
mo_coeff_mask = mo_coeff[idx,:]
1010+
rho = eval_rho2(mol, ao, mo_coeff_mask, mo_occ, None, 'GGA')
9961011
vvrho.append(rho)
1012+
9971013
rho = cupy.hstack(vvrho)
9981014
t1 = log.timer_debug1('eval rho', *t0)
9991015
exc = 0
@@ -1227,7 +1243,7 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000,
12271243
# cache ao indices
12281244
if (deriv, block_id, blksize, ngrids) not in ni.non0ao_idx:
12291245
stream = cupy.cuda.get_current_stream()
1230-
cutoff = 1e-12
1246+
cutoff = AO_THRESHOLD
12311247
ng = ip1 - ip0
12321248
ao_loc = mol.ao_loc_nr()
12331249
nbas = mol.nbas

gpu4pyscf/dft/rks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def initialize_grids(ks, mol=None, dm=None):
9393
# Filter grids the first time setup grids
9494
ks.nlcgrids = prune_small_rho_grids_(ks, ks.mol, dm, ks.nlcgrids)
9595
t0 = logger.timer_debug1(ks, 'setting up nlc grids', *t0)
96-
9796
return ks
9897

9998
def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):

0 commit comments

Comments
 (0)