Skip to content

Commit 2e5bc47

Browse files
authored
to_gpu and group grids (#69)
* fixed a bug in screen_index * added unit test for to_gpu * new grids group scheme * use grid_aligned in gpu4pyscf.__config__
1 parent 43f21be commit 2e5bc47

18 files changed

+521
-203
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Installation
66
> [!NOTE]
77
> The compiled binary packages support compute capability 7.0 and later (Volta and later, such as Tesla V100, RTX 20 series and later). For older GPUs (GTX 10**, Tesla P100), please compile the package with the source code as follows.
88
9-
Run ```nvidia-smi``` in your terminal to check the installed CUDA version.
9+
Run ```nvidia-smi``` in your terminal to check the installed CUDA version.
1010

1111
Choose the proper package based on your CUDA environment.
1212

@@ -15,7 +15,7 @@ Choose the proper package based on your CUDA environment.
1515
| **CUDA 11.x** | ```pip3 install gpu4pyscf-cuda11x``` |
1616
| **CUDA 12.x** | ```pip3 install gpu4pyscf-cuda12x``` |
1717

18-
```cuTensor``` is **highly recommended** to be installed for accelerating tensor contractions.
18+
```cuTensor``` is **highly recommended** for accelerating tensor contractions.
1919

2020
For **CUDA 11.x**, ```python -m cupyx.tools.install_library --cuda 11.x --library cutensor```
2121

@@ -59,7 +59,7 @@ Limitations
5959
- Rys roots up to 9 for direct scf scheme;
6060
- Atomic basis up to g orbitals;
6161
- Auxiliary basis up to h orbitals;
62-
- Up to ~168 atoms with def2-tzvpd basis, consuming a large amount of CPU memory;
62+
- Density fitting scheme up to ~168 atoms with def2-tzvpd basis, bounded by CPU memory;
6363
- Hessian is unavailable for Direct SCF yet;
6464
- meta-GGA without density laplacian;
6565

examples/00-h2o.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# You should have received a copy of the GNU General Public License
1414
# along with this program. If not, see <http://www.gnu.org/licenses/>.
1515

16+
import numpy as np
17+
import cupy
1618
import pyscf
1719
from pyscf import lib
20+
from pyscf.hessian import thermo
1821
from gpu4pyscf.dft import rks
1922
lib.num_threads(8)
2023

@@ -30,13 +33,14 @@
3033
scf_tol = 1e-10
3134
max_scf_cycles = 50
3235
screen_tol = 1e-14
33-
grids_level = 3
36+
grids_level = 5
3437

35-
mol = pyscf.M(atom=atom, basis=bas, max_memory=32000, output='./pyscf.log')
38+
mol = pyscf.M(atom=atom, basis=bas, max_memory=32000)
3639

3740
mol.verbose = 4
3841
mf_GPU = rks.RKS(mol, xc=xc).density_fit(auxbasis=auxbasis)
3942
mf_GPU.grids.level = grids_level
43+
mf_GPU.grids.atom_grid = (99,590)
4044
mf_GPU.conv_tol = scf_tol
4145
mf_GPU.max_cycle = max_scf_cycles
4246
mf_GPU.screen_tol = screen_tol
@@ -55,3 +59,17 @@
5559
h = mf_GPU.Hessian()
5660
h.auxbasis_response = 2
5761
h_dft = h.kernel()
62+
63+
# harmonic analysis
64+
results = thermo.harmonic_analysis(mol, h_dft)
65+
thermo.dump_normal_mode(mol, results)
66+
67+
results = thermo.thermo(mf_GPU, results['freq_au'], 298.15, 101325)
68+
thermo.dump_thermo(mol, results)
69+
70+
# force translational symmetry
71+
natm = mol.natm
72+
h_dft = h_dft.transpose([0,2,1,3]).reshape(3*natm,3*natm)
73+
h_diag = h_dft.sum(axis=0)
74+
h_dft -= np.diag(h_diag)
75+
print(h_dft[:3,:3])

gpu4pyscf/__config__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,23 @@
55
# such as A100-80G
66
if props['totalGlobalMem'] >= 64 * GB:
77
min_ao_blksize = 128
8-
min_grid_blksize = 128*128
8+
min_grid_blksize = 256*256#128*128
99
ao_aligned = 32
1010
grid_aligned = 128
1111
mem_fraction = 0.9
1212
number_of_threads = 2048 * 108
1313
# such as V100-32G
1414
elif props['totalGlobalMem'] >= 32 * GB:
1515
min_ao_blksize = 128
16-
min_grid_blksize = 128*128
16+
min_grid_blksize = 256*256#128*128
1717
ao_aligned = 32
1818
grid_aligned = 128
1919
mem_fraction = 0.9
2020
number_of_threads = 1024 * 80
2121
# such as A30-24GB
2222
elif props['totalGlobalMem'] >= 16 * GB:
2323
min_ao_blksize = 128
24-
min_grid_blksize = 128*128
24+
min_grid_blksize = 256*256#128*128
2525
ao_aligned = 32
2626
grid_aligned = 128
2727
mem_fraction = 0.9

gpu4pyscf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from . import lib, grad, hessian, solvent, scf, dft
2-
__version__ = '0.6.11'
2+
__version__ = '0.6.12'
33

44
# monkey patch libxc reference due to a bug in nvcc
55
from pyscf.dft import libxc

gpu4pyscf/df/df.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
class DF(df.DF):
3636
from gpu4pyscf.lib.utils import to_gpu, device
3737

38+
_keys = {'intopt'}
39+
3840
def __init__(self, mol, auxbasis=None):
3941
super().__init__(mol, auxbasis)
4042
self.auxmol = None
@@ -210,8 +212,12 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False):
210212
if(not use_gpu_memory):
211213
log.debug("Not enough GPU memory")
212214
# TODO: async allocate memory
213-
mem = cupy.cuda.alloc_pinned_memory(naux * npair * 8)
214-
cderi = np.ndarray([naux, npair], dtype=np.float64, order='C', buffer=mem)
215+
try:
216+
mem = cupy.cuda.alloc_pinned_memory(naux * npair * 8)
217+
cderi = np.ndarray([naux, npair], dtype=np.float64, order='C', buffer=mem)
218+
except Exception:
219+
raise RuntimeError('Out of CPU memory')
220+
215221
data_stream = cupy.cuda.stream.Stream(non_blocking=False)
216222
count = 0
217223
nq = len(intopt.log_qs)

gpu4pyscf/df/grad/rhf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# You should have received a copy of the GNU General Public License
1414
# along with this program. If not, see <http://www.gnu.org/licenses/>.
1515

16-
16+
import copy
1717
import numpy
1818
import cupy
1919
from cupyx.scipy.linalg import solve_triangular
@@ -44,9 +44,13 @@ def get_jk(mf_grad, mol=None, dm0=None, hermi=0, with_j=True, with_k=True, omega
4444
if key in mf_grad.base.with_df._rsh_df:
4545
with_df = mf_grad.base.with_df._rsh_df[key]
4646
else:
47-
raise RuntimeError(f'omega={omega} is not calculated in SCF')
47+
dfobj = mf_grad.base.with_df
48+
with_df = dfobj._rsh_df[key] = copy.copy(dfobj).reset()
49+
#raise RuntimeError(f'omega={omega} is not calculated in SCF')
4850

4951
auxmol = with_df.auxmol
52+
if not hasattr(with_df, 'intopt') or with_df._cderi is None:
53+
with_df.build(omega=omega)
5054
intopt = with_df.intopt
5155

5256
naux = with_df.naux

gpu4pyscf/df/hessian/rhf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,21 @@ def make_h1(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None):
383383
for ia, h1, vj1, vk1 in _gen_jk(hessobj, mo_coeff, mo_occ, chkfile,
384384
atmlst, verbose, True):
385385
h1 += vj1 - vk1 * .5
386+
h1ao[ia] = h1
387+
'''
386388
if chkfile is None:
387389
h1ao[ia] = h1
388390
else:
389391
key = 'scf_f1ao/%d' % ia
390392
lib.chkfile.save(chkfile, key, h1)
393+
'''
394+
return h1ao
395+
'''
391396
if chkfile is None:
392397
return h1ao
393398
else:
394399
return chkfile
395-
400+
'''
396401
def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,
397402
verbose=None, with_k=True, omega=None):
398403
log = logger.new_logger(hessobj, verbose)

gpu4pyscf/df/hessian/rks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,17 @@ def make_h1(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None):
103103
for ia, h1, vj1_lr, vk1_lr in df_rhf_hess._gen_jk(hessobj, mo_coeff, mo_occ, chkfile,
104104
atmlst, verbose, True, omega=omega):
105105
h1ao[ia] -= .5 * (alpha - hyb) * vk1_lr
106+
return h1ao
107+
108+
# support chkfile?
109+
'''
106110
if chkfile is None:
107111
return h1ao
108112
else:
109113
for ia in atmlst:
110114
lib.chkfile.save(chkfile, 'scf_f1ao/%d'%ia, h1ao[ia])
111115
return chkfile
112-
116+
'''
113117

114118
class Hessian(rks_hess.Hessian):
115119
'''Non-relativistic RKS hessian'''

gpu4pyscf/dft/gen_grid.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@
3737
from cupyx.scipy.spatial.distance import cdist
3838
from gpu4pyscf.dft import radi
3939
from gpu4pyscf.lib.cupy_helper import load_library
40-
40+
from gpu4pyscf import __config__ as __gpu4pyscf_config__
4141
libdft = lib.load_library('libdft')
4242
libgdft = load_library('libgdft')
4343

4444
from pyscf.dft.gen_grid import GROUP_BOUNDARY_PENALTY, NELEC_ERROR_TOL, LEBEDEV_ORDER, LEBEDEV_NGRID
4545

4646
GROUP_BOX_SIZE = 3.0
47-
ALIGNMENT_UNIT = 32
47+
ALIGNMENT_UNIT = getattr(__gpu4pyscf_config__, 'grid_aligned', 128)
4848
# SG0
4949
# S. Chien and P. Gill, J. Comput. Chem. 27 (2006) 730-739.
5050

@@ -334,6 +334,7 @@ def gen_grid_partition(coords):
334334

335335
coords_all = []
336336
weights_all = []
337+
# support atomic_radii_adjust = None
337338
assert radii_adjust == radi.treutler_atomic_radii_adjust
338339
a = -radi.get_treutler_fac(mol, atomic_radii)
339340
for ia in range(mol.natm):
@@ -377,6 +378,49 @@ def make_mask(mol, coords, relativity=0, shls_slice=None, cutoff=CUTOFF,
377378
'''
378379
return make_screen_index(mol, coords, shls_slice, cutoff)
379380

381+
def atomic_group_grids(mol, coords):
    '''
    Return the permutation that orders grid points by atomic group.

    The atoms are first arranged along a greedy nearest-neighbour path
    starting from the atom with the smallest x coordinate.  The reordered
    atom coordinates and the grid coordinates are then handed to
    libgdft.GDFTgroup_grids, which fills one int32 group id per grid point.
    NOTE(review): the group id is presumably the position (along the path)
    of the atom closest to the grid point, as in the disabled CuPy
    reference kept below -- confirm against the kernel.

    Args:
        mol: object providing ``natm`` and ``atom_coords()``.
        coords: grid coordinates; ``coords.shape[0]`` is used as the number
            of grid points (assumed to be an (ngrids, 3) array -- confirm).

    Returns:
        ``group_ids.argsort()``: indices that sort the grid points by
        their group id.

    Raises:
        RuntimeError: if GDFTgroup_grids returns a non-zero status.
    '''
    from scipy.spatial import distance_matrix

    natm = mol.natm
    ngrids = coords.shape[0]
    atom_xyz = mol.atom_coords()
    pair_dist = distance_matrix(atom_xyz, atom_xyz)

    # Greedy walk over the atoms: always hop to the closest atom that has
    # not been visited yet.
    seen = numpy.zeros(natm, dtype=bool)
    node = numpy.argmin(atom_xyz[:,0])
    order = [node]
    while len(order) < natm:
        seen[node] = True
        # Visited atoms get an infinite distance so argmin never picks them.
        candidates = numpy.where(seen, numpy.inf, pair_dist[node])
        node = numpy.argmin(candidates)
        order.append(node)

    path_xyz = cupy.asarray(atom_xyz[order])
    # Disabled pure-CuPy reference for the kernel call below:
    #dij = cupy.sum((atom_coords[:,None,:] - coords[None,:,:])**2, axis=2)
    #group_ids = cupy.argmin(dij, axis=0)

    # Both coordinate arrays are passed to the kernel in Fortran order.
    grid_xyz_f = cupy.asarray(coords, order='F')
    path_xyz_f = cupy.asarray(path_xyz, order='F')
    group_ids = cupy.empty([ngrids], dtype=numpy.int32)
    stream = cupy.cuda.get_current_stream()

    def _void_p(address):
        # Wrap a raw device/stream address for the ctypes call.
        return ctypes.cast(address, ctypes.c_void_p)

    err = libgdft.GDFTgroup_grids(
        _void_p(stream.ptr),
        _void_p(group_ids.data.ptr),
        _void_p(path_xyz_f.data.ptr),
        _void_p(grid_xyz_f.data.ptr),
        ctypes.c_int(natm),
        ctypes.c_int(ngrids)
    )
    if err != 0:
        raise RuntimeError('CUDA Error')

    return group_ids.argsort()
422+
423+
380424
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
381425
'''
382426
Partition the entire space into small boxes according to the input box_size.
@@ -510,11 +554,6 @@ def build(self, mol=None, with_non0tab=False, sort_grids=True, **kwargs):
510554
self.coords, self.weights = self.get_partition(
511555
mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme)
512556

513-
if sort_grids:
514-
idx = arg_group_grids(mol, self.coords)
515-
self.coords = self.coords[idx]
516-
self.weights = self.weights[idx]
517-
518557
if self.alignment > 1:
519558
padding = _padding_size(self.size, self.alignment)
520559
logger.debug(self, 'Padding %d grids', padding)
@@ -523,6 +562,13 @@ def build(self, mol=None, with_non0tab=False, sort_grids=True, **kwargs):
523562
self.coords = cupy.vstack(
524563
[self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
525564
self.weights = cupy.hstack([self.weights, numpy.zeros(padding)])
565+
566+
if sort_grids:
567+
#idx = arg_group_grids(mol, self.coords)
568+
idx = atomic_group_grids(mol, self.coords)
569+
self.coords = self.coords[idx]
570+
self.weights = self.weights[idx]
571+
526572
if with_non0tab:
527573
self.non0tab = self.make_mask(mol, self.coords)
528574
self.screen_index = self.non0tab

gpu4pyscf/dft/numint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,8 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000,
13061306
class NumInt(numint.NumInt):
13071307
from gpu4pyscf.lib.utils import to_cpu, to_gpu, device
13081308

1309+
_keys = {'screen_idx', 'xcfuns', 'gdftopt'}
1310+
13091311
def __init__(self, xc='LDA'):
13101312
super().__init__()
13111313
self.gdftopt = None

0 commit comments

Comments
 (0)