From ed64bbce9cc0117d52be3c4efda2c7ff7809ba81 Mon Sep 17 00:00:00 2001
From: Xiaojie Wu
Date: Wed, 3 Jan 2024 17:43:12 -0800
Subject: [PATCH 1/3] Bugfix transpose sum (#73)

* fixed a bug in screen_index
* added unit test for to_gpu
* new grids group scheme
* use grid_aligned in gpu4pyscf.__config__
* fixed a bug in eval_ao
* fixed a bug in transpose_sum
* remove print
---
 gpu4pyscf/__config__.py                 |  6 +--
 gpu4pyscf/__init__.py                   |  2 +-
 gpu4pyscf/df/df.py                      |  5 ++-
 gpu4pyscf/df/df_jk.py                   |  5 +--
 gpu4pyscf/dft/gen_grid.py               | 21 +---------
 gpu4pyscf/dft/tests/test_ao_values.py   |  1 -
 gpu4pyscf/dft/tests/test_grids.py       |  2 +-
 gpu4pyscf/lib/cupy_helper.py            |  4 +-
 gpu4pyscf/lib/cupy_helper/transpose.cu  | 53 ++++++++++++++-----------
 gpu4pyscf/lib/gdft/gen_grids.cu         |  1 -
 gpu4pyscf/lib/gdft/nr_eval_gto.cu       | 52 ++++++++++++------------
 gpu4pyscf/lib/gdft/vv10.cu              | 20 +++-------
 gpu4pyscf/lib/tests/test_cupy_helper.py |  4 +-
 13 files changed, 76 insertions(+), 100 deletions(-)

diff --git a/gpu4pyscf/__config__.py b/gpu4pyscf/__config__.py
index 987ca87c..5b740207 100644
--- a/gpu4pyscf/__config__.py
+++ b/gpu4pyscf/__config__.py
@@ -13,7 +13,7 @@
 # such as V100-32G
 elif props['totalGlobalMem'] >= 32 * GB:
     min_ao_blksize = 128
-    min_grid_blksize = 256*256#128*128
+    min_grid_blksize = 128*128
     ao_aligned = 32
     grid_aligned = 128
     mem_fraction = 0.9
@@ -21,7 +21,7 @@
 # such as A30-24GB
 elif props['totalGlobalMem'] >= 16 * GB:
     min_ao_blksize = 128
-    min_grid_blksize = 256*256#128*128
+    min_grid_blksize = 128*128
     ao_aligned = 32
     grid_aligned = 128
     mem_fraction = 0.9
@@ -35,4 +35,4 @@
     mem_fraction = 0.9
     number_of_threads = 1024 * 80
 
-cupy.get_default_memory_pool().set_limit(fraction=mem_fraction)
\ No newline at end of file
+cupy.get_default_memory_pool().set_limit(fraction=mem_fraction)

diff --git a/gpu4pyscf/__init__.py b/gpu4pyscf/__init__.py
index 004dfb05..5fd45d2f 100644
--- a/gpu4pyscf/__init__.py
+++ b/gpu4pyscf/__init__.py
@@ -1,5 +1,5 @@
 from . import lib, grad, hessian, solvent, scf, dft
-__version__ = '0.6.13'
+__version__ = '0.6.14'
 
 # monkey patch libxc reference due to a bug in nvcc
 from pyscf.dft import libxc

diff --git a/gpu4pyscf/df/df.py b/gpu4pyscf/df/df.py
index 0b40d26e..ff3c6877 100644
--- a/gpu4pyscf/df/df.py
+++ b/gpu4pyscf/df/df.py
@@ -22,7 +22,7 @@
 from pyscf import lib
 from pyscf.df import df, addons
 from gpu4pyscf.lib.cupy_helper import (
-    cholesky, tag_array, get_avail_mem, cart2sph, take_last2d)
+    cholesky, tag_array, get_avail_mem, cart2sph, take_last2d, transpose_sum)
 from gpu4pyscf.df import int3c2e, df_jk
 from gpu4pyscf.lib import logger
 from gpu4pyscf import __config__
@@ -262,7 +262,8 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False):
             row = intopt.ao_pairs_row[cp_ij_id] - i0
             col = intopt.ao_pairs_col[cp_ij_id] - j0
             if cpi == cpj:
-                ints_slices = ints_slices + ints_slices.transpose([0,2,1])
+                #ints_slices = ints_slices + ints_slices.transpose([0,2,1])
+                transpose_sum(ints_slices)
             ints_slices = ints_slices[:,col,row]
 
             if cd_low.tag == 'eig':

diff --git a/gpu4pyscf/df/df_jk.py b/gpu4pyscf/df/df_jk.py
index 23a46850..ff0cbd7e 100644
--- a/gpu4pyscf/df/df_jk.py
+++ b/gpu4pyscf/df/df_jk.py
@@ -290,15 +290,14 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-
                 vj_packed += cupy.dot(rhoj, cderi_sparse.T)
             if with_k:
                 rhok = contract('Lij,jk->Lki', cderi, occ_coeff)
-                #vk[0] += contract('Lki,Lkj->ij', rhok, rhok)
-                cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=1.0, beta=1.0, lower=True)
+                #vk[0] += 2.0 * contract('Lki,Lkj->ij', rhok, rhok)
+                cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=2.0, beta=1.0, lower=True)
         if with_j:
             vj[:,rows,cols] = vj_packed
             vj[:,cols,rows] = vj_packed
         if with_k:
             vk[0][numpy.diag_indices(nao)] *= 0.5
             transpose_sum(vk)
-            vk *= 2.0
     # CP-HF K matrix
     elif hasattr(dms_tag, 'mo1'):
         if with_j:

diff --git a/gpu4pyscf/dft/gen_grid.py b/gpu4pyscf/dft/gen_grid.py
index 40c6947a..5ced5d94 100644
--- a/gpu4pyscf/dft/gen_grid.py
+++ b/gpu4pyscf/dft/gen_grid.py
@@ -186,27 +186,8 @@ def gen_grids_partition(atm_coords, coords, a):
     natm = atm_coords.shape[0]
     ngrids = coords.shape[0]
     assert ngrids < 65535 * 16
-    #x_i = cupy.expand_dims(atm_coords, axis=1)
-    #x_g = cupy.expand_dims(coords, axis=0)
-    #squared_diff = (x_i - x_g)**2
-    #dist_ig = cupy.sum(squared_diff, axis=2)**0.5
-    #x_j = cupy.expand_dims(atm_coords, axis=0)
-    #squared_diff = (x_i - x_j)**2
-    #dist_ij = cupy.sum(squared_diff, axis=2)**0.5
-
-    pbecke = cupy.ones([natm, ngrids], order='C')
-    '''
-    err = libgdft.GDFTgen_grid_partition(
-        ctypes.cast(stream.ptr, ctypes.c_void_p),
-        ctypes.cast(pbecke.data.ptr, ctypes.c_void_p),
-        ctypes.cast(dist_ig.data.ptr, ctypes.c_void_p),
-        ctypes.cast(dist_ij.data.ptr, ctypes.c_void_p),
-        ctypes.cast(a.data.ptr, ctypes.c_void_p),
-        ctypes.c_int(ngrids),
-        ctypes.c_int(natm)
-    )
-    '''
+    pbecke = cupy.empty([natm, ngrids], order='C')
     atm_coords = cupy.asarray(atm_coords, order='F')
     err = libgdft.GDFTgen_grid_partition(
         ctypes.cast(stream.ptr, ctypes.c_void_p),

diff --git a/gpu4pyscf/dft/tests/test_ao_values.py b/gpu4pyscf/dft/tests/test_ao_values.py
index 78ae6c98..5ee875b6 100644
--- a/gpu4pyscf/dft/tests/test_ao_values.py
+++ b/gpu4pyscf/dft/tests/test_ao_values.py
@@ -73,7 +73,6 @@ def test_ao_sph_deriv2(self):
         ao_cpu = cupy.asarray(ao)
         ni = NumInt(xc='LDA')
         ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=2)
-        #idx = cupy.argwhere(cupy.abs(ao_gpu - ao_cpu) > 1e-10)
         assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8
 
     def test_ao_sph_deriv3(self):

diff --git a/gpu4pyscf/dft/tests/test_grids.py b/gpu4pyscf/dft/tests/test_grids.py
index 534ece58..d675393b 100644
--- a/gpu4pyscf/dft/tests/test_grids.py
+++ b/gpu4pyscf/dft/tests/test_grids.py
@@ -32,7 +32,7 @@ def setUpModule():
     O       0.000000   0.000000   0.117790
     H       0.000000   0.755453  -0.471161
     H       0.000000  -0.755453  -0.471161''',
-        basis = 'ccpvdz',
+        basis = 'ccpvqz',
         charge = 1,
         spin = 1,  # = 2S = spin_up - spin_down
         output = '/dev/null')

diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py
index f52f68f7..cebbbfd3 100644
--- a/gpu4pyscf/lib/cupy_helper.py
+++ b/gpu4pyscf/lib/cupy_helper.py
@@ -264,9 +264,9 @@ def take_last2d(a, indices, out=None):
         raise RuntimeError('failed in take_last2d kernel')
     return out
 
-def transpose_sum(a):
+def transpose_sum(a, stream=None):
     '''
-    transpose (0,2,1)
+    return a + a.transpose(0,2,1)
     '''
     assert a.flags.c_contiguous
     assert a.ndim == 3

diff --git a/gpu4pyscf/lib/cupy_helper/transpose.cu b/gpu4pyscf/lib/cupy_helper/transpose.cu
index 7e927059..78b3dd39 100644
--- a/gpu4pyscf/lib/cupy_helper/transpose.cu
+++ b/gpu4pyscf/lib/cupy_helper/transpose.cu
@@ -35,36 +35,41 @@ static void _dsymm_triu(double *a, int n)
         a[off + j * N + i] = a[off + i * N + j];
 }
 
-__global__
+__global__
 void _transpose_sum(double *a, int n)
 {
+    if(blockIdx.x > blockIdx.y){
+        return;
+    }
     __shared__ double block[BLOCK_DIM][BLOCK_DIM+1];
-
-    // read the matrix tile into shared memory
-    // load one element per thread from device memory (idata) and store it
-    // in transposed order in block[][]
-    unsigned int xIndex = blockIdx.x * BLOCK_DIM + threadIdx.x;
-    unsigned int yIndex = blockIdx.y * BLOCK_DIM + threadIdx.y;
-    unsigned int zIndex = blockIdx.z;
-    unsigned int off = zIndex * n * n;
-    if((xIndex < n) && (yIndex < n))
-    {
-        unsigned int index_in = yIndex * n + xIndex + off;
-        block[threadIdx.y][threadIdx.x] = a[index_in];
-    }
+    unsigned int blockx_off = blockIdx.x * BLOCK_DIM;
+    unsigned int blocky_off = blockIdx.y * BLOCK_DIM;
+    unsigned int x0 = blockx_off + threadIdx.x;
+    unsigned int y0 = blocky_off + threadIdx.y;
+    unsigned int x1 = blocky_off + threadIdx.x;
+    unsigned int y1 = blockx_off + threadIdx.y;
+    unsigned int z = blockIdx.z;
+
+    unsigned int off = n * n * z;
+    unsigned int xy0 = y0 * n + x0 + off;
+    unsigned int xy1 = y1 * n + x1 + off;
 
-    // synchronise to ensure all writes to block[][] have completed
-    __syncthreads();
+    if (x0 < n && y0 < n){
+        block[threadIdx.y][threadIdx.x] = a[xy0];
+    }
+    __syncthreads();
+    if (x1 < n && y1 < n){
+        block[threadIdx.x][threadIdx.y] += a[xy1];
+    }
+    __syncthreads();
 
-    // write the transposed matrix tile to global memory (odata) in linear order
-    xIndex = blockIdx.y * BLOCK_DIM + threadIdx.x;
-    yIndex = blockIdx.x * BLOCK_DIM + threadIdx.y;
-    if((xIndex < n) && (yIndex < n))
-    {
-        unsigned int index_out = yIndex * n + xIndex + off;
-        a[index_out] += block[threadIdx.x][threadIdx.y];
-    }
+    if(x0 < n && y0 < n){
+        a[xy0] = block[threadIdx.y][threadIdx.x];
+    }
+    if(x1 < n && y1 < n){
+        a[xy1] = block[threadIdx.x][threadIdx.y];
+    }
 }
 
 extern "C" {

diff --git a/gpu4pyscf/lib/gdft/gen_grids.cu b/gpu4pyscf/lib/gdft/gen_grids.cu
index 2c484ae3..d71eedf9 100644
--- a/gpu4pyscf/lib/gdft/gen_grids.cu
+++ b/gpu4pyscf/lib/gdft/gen_grids.cu
@@ -40,7 +40,6 @@ int ngrids, int natm)
     __shared__ double zj[NATOM_PER_BLOCK];
     __shared__ double a_smem[NATOM_PER_BLOCK];
     __shared__ double dij_smem[NATOM_PER_BLOCK];
-
     const int tx = threadIdx.x;
 
     for (int atom_i = 0; atom_i < natm; atom_i++){
diff --git a/gpu4pyscf/lib/gdft/nr_eval_gto.cu b/gpu4pyscf/lib/gdft/nr_eval_gto.cu
index 0295c0c0..48ff5729 100644
--- a/gpu4pyscf/lib/gdft/nr_eval_gto.cu
+++ b/gpu4pyscf/lib/gdft/nr_eval_gto.cu
@@ -28,7 +28,7 @@
 #include "nr_eval_gto.cuh"
 #include "contract_rho.cuh"
 
-#define THREADS 128
+#define NG_PER_BLOCK 128
 #define LMAX 8
 #define GTO_MAX_CART 15
 
@@ -1121,15 +1121,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         g12 = ax * ry * ry * rz * rz;
         g13 = ax * ry * rz * rz * rz;
         g14 = ax * rz * rz * rz * rz;
-        gtox[         grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ;
+        gtox[         grid_id] = 2.503342941796704538 * (g1 - g6) ;
         gtox[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
-        gtox[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ;
-        gtox[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11;
-        gtox[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14;
-        gtox[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ;
-        gtox[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12;
+        gtox[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
+        gtox[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
+        gtox[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
+        gtox[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
+        gtox[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
         gtox[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
-        gtox[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10;
+        gtox[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;
 
         double ay = ce_2a * ry;
         g0 = ay * rx * rx * rx * rx;
@@ -1147,15 +1147,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         g12 = (ay * ry + 2 * ce) * ry * rz * rz;
         g13 = (ay * ry + ce) * rz * rz * rz;
         g14 = ay * rz * rz * rz * rz;
-        gtoy[         grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ;
+        gtoy[         grid_id] = 2.503342941796704538 * (g1 - g6) ;
         gtoy[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
-        gtoy[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ;
-        gtoy[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11;
-        gtoy[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14;
-        gtoy[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ;
-        gtoy[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12;
+        gtoy[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
+        gtoy[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
+        gtoy[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
+        gtoy[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
+        gtoy[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
         gtoy[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
-        gtoy[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10;
+        gtoy[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;
 
         double az = ce_2a * rz;
         g0 = az * rx * rx * rx * rx;
@@ -1173,15 +1173,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         g12 = (az * rz + 2 * ce) * ry * ry * rz;
         g13 = (az * rz + 3 * ce) * ry * rz * rz;
         g14 = (az * rz + 4 * ce) * rz * rz * rz;
-        gtoz[         grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ;
+        gtoz[         grid_id] = 2.503342941796704538 * (g1 - g6) ;
         gtoz[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
-        gtoz[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ;
-        gtoz[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11;
-        gtoz[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14;
-        gtoz[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ;
-        gtoz[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12;
+        gtoz[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
+        gtoz[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
+        gtoz[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
+        gtoz[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
+        gtoz[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
         gtoz[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
-        gtoz[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10;
+        gtoz[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;
     }
 }
 
@@ -1559,8 +1559,8 @@ int GDFTeval_gto(cudaStream_t stream, double *ao, int deriv, int cart,
     offsets.bas_indices = bas_indices;
     offsets.nbas = local_ctr_offsets[nctr];
     offsets.nao = nao;
-    dim3 threads(THREADS);
-    dim3 blocks((ngrids+THREADS-1)/THREADS);
+    dim3 threads(NG_PER_BLOCK);
+    dim3 blocks((ngrids+NG_PER_BLOCK-1)/NG_PER_BLOCK);
 
     for (int ictr = 0; ictr < nctr; ++ictr) {
         int local_ish = local_ctr_offsets[ictr];
@@ -1706,8 +1706,8 @@ int GDFTeval_gto(cudaStream_t stream, double *ao, int deriv, int cart,
 int GDFTscreen_index(cudaStream_t stream, int *non0shl_idx, double cutoff,
                      double *grids, int ngrids, int *bas_loc, int nbas, int *bas)
 {
-    dim3 threads(THREADS);
-    dim3 blocks((ngrids+THREADS-1)/THREADS);
+    dim3 threads(NG_PER_BLOCK);
+    dim3 blocks((ngrids+NG_PER_BLOCK-1)/NG_PER_BLOCK);
 
     for (int shl_id = 0; shl_id < nbas; ++shl_id) {
         int l = bas[ANG_OF+shl_id*BAS_SLOTS];

diff --git a/gpu4pyscf/lib/gdft/vv10.cu b/gpu4pyscf/lib/gdft/vv10.cu
index b7be564d..bdcb01da 100644
--- a/gpu4pyscf/lib/gdft/vv10.cu
+++ b/gpu4pyscf/lib/gdft/vv10.cu
@@ -168,28 +168,20 @@ static void vv10_grad_kernel(double *Fvec, const double *vvcoords, const double
         __syncthreads();
         for (int l = 0, M = min(NG_PER_BLOCK, vvngrids - j); l < M; ++l){
             double3 xj_tmp = xj_t[l];
-            double pjx = xj_tmp.x;
-            double pjy = xj_tmp.y;
-            double pjz = xj_tmp.z;
-
-            // about 23 operations for each pair
-            double DX = pjx - xi;
-            double DY = pjy - yi;
-            double DZ = pjz - zi;
+            double DX = xj_tmp.x - xi;
+            double DY = xj_tmp.y - yi;
+            double DZ = xj_tmp.z - zi;
             double R2 = DX*DX + DY*DY + DZ*DZ;
 
             double3 kp_tmp = kp_t[l];
-            double Kpj = kp_tmp.x;
-            double W0pj = kp_tmp.y;
-            double RpWj = kp_tmp.z;
-
-            double gp = R2*W0pj + Kpj;
+            double gp = R2*kp_tmp.y + kp_tmp.x;
             double g = R2*W0i + Ki;
             double gt = g + gp;
             double ggp = g * gp;
             double ggt_gp = gt * ggp;
-            double T = RpWj / (ggt_gp * ggt_gp);
-            double Q = T * ((W0i*gp + W0pj*g)*gt + (W0i+W0pj)*ggp);
+            double T = kp_tmp.z / (ggt_gp * ggt_gp);
+            double Q = T * ((W0i*gp + kp_tmp.y*g)*gt + (W0i+kp_tmp.y)*ggp);
 
             FX += Q * DX;
             FY += Q * DY;

diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py
index 513f7c38..614e0fa3 100644
--- a/gpu4pyscf/lib/tests/test_cupy_helper.py
+++ b/gpu4pyscf/lib/tests/test_cupy_helper.py
@@ -31,8 +31,8 @@ def test_take_last2d(self):
         assert(cupy.linalg.norm(a[:,indices][:,:,indices] - b) < 1e-10)
 
     def test_transpose_sum(self):
-        n = 3
-        count = 4
+        n = 1287
+        count = 127
         a = cupy.random.rand(count,n,n)
         b = a + a.transpose(0,2,1)
         transpose_sum(a)

From e19d51e6f1f6f5e22023dc5dac48ca8ecd388822 Mon Sep 17 00:00:00 2001
From: Xiaojie Wu
Date: Fri, 5 Jan 2024 15:59:02 -0800
Subject: [PATCH 2/3] Get vxc (#74)

* fixed a bug in screen_index
* added unit test for to_gpu
* new grids group scheme
* use grid_aligned in gpu4pyscf.__config__
* fixed a bug in eval_ao
* fixed a bug in transpose_sum
* remove print
* new get_vxc scheme
---
 examples/dft_driver.py                   |   1 -
 gpu4pyscf/dft/numint.py                  |  80 ++++++---
 gpu4pyscf/lib/cupy_helper.py             |   4 +
 gpu4pyscf/lib/cupy_helper/add_sparse.cu  |   4 +-
 gpu4pyscf/lib/cupy_helper/block_diag.cu  |   2 +-
 gpu4pyscf/lib/cupy_helper/take_last2d.cu |   2 +-
 gpu4pyscf/lib/cupy_helper/transpose.cu   |   4 +-
 gpu4pyscf/lib/gdft/contract_rho.cu       |   1 -
 gpu4pyscf/lib/gdft/nr_eval_gto.cu        | 199 +++++++++++++++++------
 9 files changed, 216 insertions(+), 81 deletions(-)

diff --git a/examples/dft_driver.py b/examples/dft_driver.py
index 2b2c8cd6..80830e6c 100644
--- a/examples/dft_driver.py
+++ b/examples/dft_driver.py
@@ -34,7 +34,6 @@
                     basis=bas,
                     max_memory=32000)
 # set verbose >= 6 for debugging timer
-
 mol.verbose = 4
 mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)

diff --git a/gpu4pyscf/dft/numint.py b/gpu4pyscf/dft/numint.py
index f7758f1c..0e46c611 100644
--- a/gpu4pyscf/dft/numint.py
+++ b/gpu4pyscf/dft/numint.py
@@ -24,7 +24,8 @@
 from pyscf.dft import numint
 from pyscf.gto.eval_gto import NBINS, CUTOFF, make_screen_index
 from gpu4pyscf.scf.hf import basis_seg_contraction
-from gpu4pyscf.lib.cupy_helper import contract, get_avail_mem, load_library, add_sparse, release_gpu_stack, take_last2d
+from gpu4pyscf.lib.cupy_helper import (
+    contract, get_avail_mem, load_library, add_sparse, release_gpu_stack, take_last2d, transpose_sum)
 from gpu4pyscf.dft import xc_deriv, xc_alias, libxc
 from gpu4pyscf import __config__
 from gpu4pyscf.lib import logger
@@ -83,11 +84,17 @@ def eval_ao(ni, mol, coords, deriv=0, shls_slice=None, nao_slice=None, ao_loc_sl
     comp = (deriv+1)*(deriv+2)*(deriv+3)//6
     stream = cupy.cuda.get_current_stream()
+    # ao must be set to zero due to implementation
+    if deriv > 1:
+        ao = cupy.zeros((comp, nao_slice, ngrids), order='C')
+    else:
+        ao = cupy.empty((comp, nao_slice, ngrids), order='C')
+
+    #ao = cupy.zeros((comp, nao_slice, ngrids), order='C')
     if not with_opt:
         # mol may be different to _GDFTOpt.mol.
         # nao should be consistent with the _GDFTOpt.mol object
         coeff = cupy.asarray(opt.coeff)
-        ao = cupy.zeros((comp, nao_slice, ngrids), order='C')
         with opt.gdft_envs_cache():
             err = libgdft.GDFTeval_gto(
                 ctypes.cast(stream.ptr, ctypes.c_void_p),
@@ -102,7 +109,6 @@ def eval_ao(ni, mol, coords, deriv=0, shls_slice=None, nao_slice=None, ao_loc_sl
                 mol._bas.ctypes.data_as(ctypes.c_void_p))
         ao = contract('nig,ij->njg', ao, coeff).transpose([0,2,1])
     else:
-        ao = cupy.zeros((comp, nao_slice, ngrids), order='C')
         err = libgdft.GDFTeval_gto(
             ctypes.cast(stream.ptr, ctypes.c_void_p),
             ctypes.cast(ao.data.ptr, ctypes.c_void_p),
@@ -174,7 +180,7 @@ def eval_rho1(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA',
     raise NotImplementedError
 
 def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA',
-              with_lapl=True, verbose=None):
+              with_lapl=True, verbose=None, out=None):
     xctype = xctype.upper()
     if xctype == 'LDA' or xctype == 'HF':
         _, ngrids = ao.shape
@@ -467,23 +473,48 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
         ao_deriv = 0
     else:
         ao_deriv = 1
+    ngrids = grids.weights.size
+    if xctype == 'LDA':
+        rho_tot = cupy.empty([nset,1,ngrids])
+    elif xctype == 'GGA':
+        rho_tot = cupy.empty([nset,4,ngrids])
+    else:
+        rho_tot = cupy.empty([nset,6,ngrids])
+    p0 = p1 = 0
     for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv):
+        p1 = p0 + weight.size
         for i in range(nset):
            t0 = log.init_timer()
            if mo_coeff is None:
-                rho = eval_rho(mol, ao_mask, dms[i][np.ix_(idx,idx)], xctype=xctype, hermi=1)
+                rho_tot[i,:,p0:p1] = eval_rho(mol, ao_mask, dms[i][np.ix_(idx,idx)], xctype=xctype, hermi=1)
            else:
                mo_coeff_mask = mo_coeff[idx,:]
-                rho = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, xctype)
+                rho_tot[i,:,p0:p1] = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, xctype)
            t1 = log.timer_debug1('eval rho', *t0)
-            exc, vxc = ni.eval_xc_eff(xc_code, rho, deriv=1, xctype=xctype)[:2]
-            vxc = cupy.asarray(vxc, order='C')
-            exc = cupy.asarray(exc, order='C')
-            t1 = log.timer_debug1('eval vxc', *t1)
+        p0 = p1
+
+    vxc_tot = []
+    for i in range(nset):
+        if xctype == 'LDA':
+            exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i][0], deriv=1, xctype=xctype)[:2]
+        else:
+            exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i], deriv=1, xctype=xctype)[:2]
+        vxc = cupy.asarray(vxc, order='C')
+        exc = cupy.asarray(exc, order='C')
+        den = rho_tot[i][0] * grids.weights
+        nelec[i] = den.sum()
+        excsum[i] = cupy.sum(den * exc[:,0])
+        vxc_tot.append(vxc)
+        t1 = log.timer_debug1('eval vxc', *t1)
+
+    p0 = p1 = 0
+    for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv):
+        p1 = p0 + weight.size
+        for i in range(nset):
+            vxc = vxc_tot[i][:,p0:p1]
            if xctype == 'LDA':
-                den = rho * weight
+                #den = rho * weight
                wv = weight * vxc[0]
                '''
                if USE_SPARSITY == 0:
@@ -499,7 +530,7 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
                else:
                    raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented')
            elif xctype == 'GGA':
-                den = rho[0] * weight
+                #den = rho[0] * weight
                wv = vxc * weight
                wv[0] *= .5
                '''
@@ -512,14 +543,13 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
                '''
                if USE_SPARSITY == 2:
                    aow = _scale_ao(ao_mask, wv)
-                    #vmat[i][cupy.ix_(mask, mask)] += ao_mask[0].dot(aow.T)
                    add_sparse(vmat[i], ao_mask[0].dot(aow.T), idx)
                else:
                    raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented')
            elif xctype == 'NLC':
                raise NotImplementedError('NLC')
            elif xctype == 'MGGA':
-                den = rho[0] * weight
+                #den = rho[0] * weight
                wv = vxc * weight
                wv[[0, 4]] *= .5  # *.5 for v+v.T
                '''
@@ -545,9 +575,8 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
                pass
            else:
                raise NotImplementedError(f'numint.nr_rks for functional {xc_code}')
-            nelec[i] += den.sum()
-            excsum[i] += cupy.dot(den, exc)[0]
            t1 = log.timer_debug1('integration', *t1)
+        p0 = p1
 
     vmat = contract('pi,npq->niq', coeff, vmat)
     vmat = contract('qj,niq->nij', coeff, vmat)
@@ -555,8 +584,8 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
     #vmat = take_last2d(vmat, rev_ao_idx)
 
     if xctype != 'LDA':
-        #transpose_sum(vmat)
         vmat = vmat + vmat.transpose([0,2,1])
+        #transpose_sum(vmat)
 
     if FREE_CUPY_CACHE:
         dms = None
@@ -567,7 +596,7 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
         excsum = excsum[0]
         vmat = vmat[0]
 
-    return nelec, excsum, vmat#np.asarray(vmat)
+    return nelec, excsum, vmat
 
 def nr_uks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
            max_memory=2000, verbose=None):
@@ -1286,13 +1315,12 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000,
                 zero_idx = cupy.asarray(zero_idx, dtype=np.int32)
                 pad = (len(idx) + AO_ALIGNMENT - 1) // AO_ALIGNMENT * AO_ALIGNMENT - len(idx)
                 idx = cupy.hstack([idx, zero_idx[:pad]])
+                pad = min(pad, len(zero_idx))
                 non0shl_idx = cupy.asarray(np.where(non0shl_idx)[0], dtype=np.int32)
-
-                ni.non0ao_idx[deriv, block_id, blksize, ngrids] = (idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice)
+                ni.non0ao_idx[deriv, block_id, blksize, ngrids] = (pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice)
                 log.timer_debug1('init ao sparsity', *t0)
             else:
-                idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice = ni.non0ao_idx[deriv, block_id, blksize, ngrids]
-
+                pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice = ni.non0ao_idx[deriv, block_id, blksize, ngrids]
             t0 = log.init_timer()
             ao_mask = eval_ao(
                 ni, mol, coords, deriv,
                 nao_slice=len(idx),
                 shls_slice=non0shl_idx,
                 ao_loc_slice=ao_loc_slice,
                 ctr_offsets_slice=ctr_offsets_slice)
-
+            if pad > 0:
+                if deriv == 0:
+                    ao_mask[-pad:,:] = 0.0
+                else:
+                    ao_mask[:,-pad:,:] = 0.0
             block_id += 1
             log.timer_debug1('evaluate ao slice', *t0)
             yield ao_mask, idx, weight, coords

diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py
index cebbbfd3..3d0d25f5 100644
--- a/gpu4pyscf/lib/cupy_helper.py
+++ b/gpu4pyscf/lib/cupy_helper.py
@@ -154,7 +154,9 @@ def add_sparse(a, b, indices):
         count = 1
     else:
         raise RuntimeError('add_sparse only supports 2d or 3d tensor')
+    stream = cupy.cuda.get_current_stream()
     err = libcupy_helper.add_sparse(
+        ctypes.cast(stream.ptr, ctypes.c_void_p),
         ctypes.cast(a.data.ptr, ctypes.c_void_p),
         ctypes.cast(b.data.ptr, ctypes.c_void_p),
         ctypes.cast(indices.data.ptr, ctypes.c_void_p),
@@ -272,7 +274,9 @@ def transpose_sum(a, stream=None):
     assert a.ndim == 3
     n = a.shape[-1]
     count = a.shape[0]
+    stream = cupy.cuda.get_current_stream()
     err = libcupy_helper.transpose_sum(
+        ctypes.cast(stream.ptr, ctypes.c_void_p),
         ctypes.cast(a.data.ptr, ctypes.c_void_p),
         ctypes.c_int(n),
         ctypes.c_int(count)

diff --git a/gpu4pyscf/lib/cupy_helper/add_sparse.cu b/gpu4pyscf/lib/cupy_helper/add_sparse.cu
index d8033015..edccf7e1 100644
--- a/gpu4pyscf/lib/cupy_helper/add_sparse.cu
+++ b/gpu4pyscf/lib/cupy_helper/add_sparse.cu
@@ -39,11 +39,11 @@ void _add_sparse(double *a, double *b, int *indices, int n, int m, int count)
 extern "C" {
 __host__
-int add_sparse(double *a, double *b, int *indices, int n, int m, int count){
+int add_sparse(cudaStream_t stream, double *a, double *b, int *indices, int n, int m, int count){
     int ntile = (m + THREADS - 1) / THREADS;
     dim3 threads(THREADS, THREADS);
     dim3 blocks(ntile, ntile);
-    _add_sparse<<<blocks, threads>>>(a, b, indices, n, m, count);
+    _add_sparse<<<blocks, threads, 0, stream>>>(a, b, indices, n, m, count);
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         return 1;

diff --git a/gpu4pyscf/lib/cupy_helper/block_diag.cu b/gpu4pyscf/lib/cupy_helper/block_diag.cu
index f2176c67..145fe45c 100644
--- a/gpu4pyscf/lib/cupy_helper/block_diag.cu
+++ b/gpu4pyscf/lib/cupy_helper/block_diag.cu
@@ -24,7 +24,7 @@ static void _block_diag(double *out, int m, int n, double *diags, int ndiags, in
     int i = threadIdx.x;
     int j = threadIdx.y;
     int r = blockIdx.x;
-    
+
     if (r >= ndiags){
         return;
     }

diff --git a/gpu4pyscf/lib/cupy_helper/take_last2d.cu b/gpu4pyscf/lib/cupy_helper/take_last2d.cu
index 36342013..26a1a6a6 100644
--- a/gpu4pyscf/lib/cupy_helper/take_last2d.cu
+++ b/gpu4pyscf/lib/cupy_helper/take_last2d.cu
@@ -27,7 +27,7 @@ static void _take(double *a, const double *b, int *indices, int n)
     if (j >= n || k >= n) {
         return;
     }
-    
+
     int j_b = indices[j];
     int k_b = indices[k];
     int off = i * n * n;

diff --git a/gpu4pyscf/lib/cupy_helper/transpose.cu b/gpu4pyscf/lib/cupy_helper/transpose.cu
index 78b3dd39..748c83a8 100644
--- a/gpu4pyscf/lib/cupy_helper/transpose.cu
+++ b/gpu4pyscf/lib/cupy_helper/transpose.cu
@@ -88,11 +88,11 @@ int CPdsymm_triu(double *a, int n, int counts)
 }
 
 __host__
-int transpose_sum(double *a, int n, int counts){
+int transpose_sum(cudaStream_t stream, double *a, int n, int counts){
     int ntile = (n + THREADS - 1) / THREADS;
     dim3 threads(THREADS, THREADS);
     dim3 blocks(ntile, ntile, counts);
-    _transpose_sum<<<blocks, threads>>>(a, n);
+    _transpose_sum<<<blocks, threads, 0, stream>>>(a, n);
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         return 1;

diff --git a/gpu4pyscf/lib/gdft/contract_rho.cu b/gpu4pyscf/lib/gdft/contract_rho.cu
index 1957928e..5c6dbd1c 100644
--- a/gpu4pyscf/lib/gdft/contract_rho.cu
+++ b/gpu4pyscf/lib/gdft/contract_rho.cu
@@ -30,7 +30,6 @@ void GDFTcontract_rho_kernel(double *rho, double *bra, double *ket, int ngrids,
 {
     int grid_id = blockIdx.x * blockDim.x + threadIdx.x;
     const bool active = grid_id < ngrids;
-    size_t Ngrids = ngrids;
 
     double v = 0;
     if (active){

diff --git a/gpu4pyscf/lib/gdft/nr_eval_gto.cu b/gpu4pyscf/lib/gdft/nr_eval_gto.cu
index 48ff5729..1e2689d7 100644
--- a/gpu4pyscf/lib/gdft/nr_eval_gto.cu
+++ b/gpu4pyscf/lib/gdft/nr_eval_gto.cu
@@ -86,28 +86,28 @@ void _screen_index(int *non0shl_idx, double cutoff, int l, int ish, int nprim, d
 template <int ANG> __device__
 static void _cart2sph(double g_cart[GTO_MAX_CART], double *g_sph, int stride, int grid_id){
     if (ANG == 0) {
-        g_sph[grid_id + 0*stride] += g_cart[0];
+        g_sph[grid_id          ] += g_cart[0];
     } else if (ANG == 1){
-        g_sph[grid_id + 0*stride] += g_cart[0];
-        g_sph[grid_id + 1*stride] += g_cart[1];
-        g_sph[2*stride] += g_cart[2];
+        g_sph[grid_id          ] += g_cart[0];
+        g_sph[grid_id +   stride] += g_cart[1];
+        g_sph[grid_id + 2*stride] += g_cart[2];
     } else if (ANG == 2){
-        g_sph[grid_id + 0*stride] += 1.092548430592079070 * g_cart[1];
-        g_sph[grid_id + 1*stride] += 1.092548430592079070 * g_cart[4];
+        g_sph[grid_id          ] += 1.092548430592079070 * g_cart[1];
+        g_sph[grid_id +   stride] += 1.092548430592079070 * g_cart[4];
         g_sph[grid_id + 2*stride] += 0.630783130505040012 * g_cart[5] - 0.315391565252520002 * (g_cart[0] + g_cart[3]);
         g_sph[grid_id + 3*stride] += 1.092548430592079070 * g_cart[2];
         g_sph[grid_id + 4*stride] += 0.546274215296039535 * (g_cart[0] - g_cart[3]);
     } else if (ANG == 3){
-        g_sph[grid_id + 0*stride] += 1.770130769779930531 * g_cart[1] - 0.590043589926643510 * g_cart[6];
-        g_sph[grid_id + 1*stride] += 2.890611442640554055 * g_cart[4];
+        g_sph[grid_id          ] += 1.770130769779930531 * g_cart[1] - 0.590043589926643510 * g_cart[6];
+        g_sph[grid_id +   stride] += 2.890611442640554055 * g_cart[4];
         g_sph[grid_id + 2*stride] += 1.828183197857862944 * g_cart[8] - 0.457045799464465739 * (g_cart[1] + g_cart[6]);
         g_sph[grid_id + 3*stride] += 0.746352665180230782 * g_cart[9] - 1.119528997770346170 * (g_cart[2] + g_cart[7]);
         g_sph[grid_id + 4*stride] += 1.828183197857862944 * g_cart[5] - 0.457045799464465739 * (g_cart[0] + g_cart[3]);
         g_sph[grid_id + 5*stride] += 1.445305721320277020 * (g_cart[2] - g_cart[7]);
         g_sph[grid_id + 6*stride] += 0.590043589926643510 * g_cart[0] - 1.770130769779930530 * g_cart[3];
     } else if (ANG == 4){
-        g_sph[grid_id + 0*stride] += 2.503342941796704538 * (g_cart[1] - g_cart[6]) ;
-        g_sph[grid_id + 1*stride] += 5.310392309339791593 * g_cart[4] - 1.770130769779930530 * g_cart[11];
+        g_sph[grid_id          ] += 2.503342941796704538 * (g_cart[1] - g_cart[6]) ;
+        g_sph[grid_id +   stride] += 5.310392309339791593 * g_cart[4] - 1.770130769779930530 * g_cart[11];
         g_sph[grid_id + 2*stride] += 5.677048174545360108 * g_cart[8] - 0.946174695757560014 * (g_cart[1] + g_cart[6]);
         g_sph[grid_id + 3*stride] += 2.676186174229156671 * g_cart[13]- 2.007139630671867500 * (g_cart[4] + g_cart[11]);
         g_sph[grid_id + 4*stride] += 0.317356640745612911 * (g_cart[0] + g_cart[10]) + 0.634713281491225822 * g_cart[3] - 2.538853125964903290 * (g_cart[5] + g_cart[12]) + 0.846284375321634430 * g_cart[14];
@@ -118,6 +118,92 @@ static void _cart2sph(double g_cart[GTO_MAX_CART], double *g_sph, int stride, in
     }
 }
 
+template <int ANG> __device__
+static void _memset_cart(double *g_cart, int stride, int grid_id){
+    if (ANG == 0){
+        g_cart[grid_id] = 0.0;
+    } else if (ANG == 1){
+        g_cart[grid_id          ] = 0.0;
+        g_cart[grid_id +   stride] = 0.0;
+        g_cart[grid_id + 2*stride] = 0.0;
+    } else if (ANG == 2){
+        g_cart[grid_id          ] = 0.0;
+        g_cart[grid_id +   stride] = 0.0;
+        g_cart[grid_id + 2*stride] = 0.0;
+        g_cart[grid_id + 3*stride] = 0.0;
+        g_cart[grid_id + 4*stride] = 0.0;
+        g_cart[grid_id + 5*stride] = 0.0;
+    } else if (ANG == 3){
+        g_cart[grid_id          ] = 0.0;
+        g_cart[grid_id +   stride] = 0.0;
+        g_cart[grid_id + 2*stride] = 0.0;
+        g_cart[grid_id + 3*stride] = 0.0;
+        g_cart[grid_id + 4*stride] = 0.0;
+        g_cart[grid_id + 5*stride] = 0.0;
+        g_cart[grid_id + 6*stride] = 0.0;
+        g_cart[grid_id + 7*stride] = 0.0;
+        g_cart[grid_id + 8*stride] = 0.0;
+        g_cart[grid_id + 9*stride] = 0.0;
+    } else if (ANG == 4){
+        g_cart[grid_id          ] = 0.0;
+        g_cart[grid_id +   stride] = 0.0;
+        g_cart[grid_id + 2*stride] = 0.0;
+        g_cart[grid_id + 3*stride] = 0.0;
+        g_cart[grid_id + 4*stride] = 0.0;
+        g_cart[grid_id + 5*stride] = 0.0;
+        g_cart[grid_id + 6*stride] = 0.0;
+        g_cart[grid_id + 7*stride] = 0.0;
+        g_cart[grid_id + 8*stride] = 0.0;
+        g_cart[grid_id + 9*stride] = 0.0;
+        g_cart[grid_id +10*stride] = 0.0;
+        g_cart[grid_id +11*stride] = 0.0;
+        g_cart[grid_id +12*stride] = 0.0;
+        g_cart[grid_id +14*stride] = 0.0;
+    } else {
+        int i = 0;
+        for (int lx = ANG; lx >= 0; lx--){
+            for (int ly = ANG - lx; ly >= 0; ly--, i++){
+                g_cart[grid_id + i*stride] = 0.0;
+            }
+        }
+    }
+}
+
+template <int ANG> __device__
+static void _memset_sph(double *g_sph, int stride, int grid_id){
+    if (ANG == 0){
+        g_sph[grid_id] = 0.0;
+    } else if (ANG == 1){
+        g_sph[grid_id          ] = 0.0;
+        g_sph[grid_id +   stride] = 0.0;
+        g_sph[grid_id + 2*stride] = 0.0;
+    } else if (ANG == 2){
+        g_sph[grid_id          ] = 0.0;
+        g_sph[grid_id +   stride] = 0.0;
+        g_sph[grid_id + 2*stride] = 0.0;
+        g_sph[grid_id + 3*stride] = 0.0;
+        g_sph[grid_id + 4*stride] = 0.0;
+    } else if (ANG == 3){
+        g_sph[grid_id          ] = 0.0;
+        g_sph[grid_id +   stride] = 0.0;
+        g_sph[grid_id + 2*stride] = 0.0;
+        g_sph[grid_id + 3*stride] = 0.0;
+        g_sph[grid_id + 4*stride] = 0.0;
+        g_sph[grid_id + 5*stride] = 0.0;
+        g_sph[grid_id + 6*stride] = 0.0;
+    } else if (ANG == 4){
+        g_sph[grid_id          ] = 0.0;
+        g_sph[grid_id +   stride] = 0.0;
+        g_sph[grid_id + 2*stride] = 0.0;
+        g_sph[grid_id + 3*stride] = 0.0;
+        g_sph[grid_id + 4*stride] = 0.0;
+        g_sph[grid_id + 5*stride] = 0.0;
+        g_sph[grid_id + 6*stride] = 0.0;
+        g_sph[grid_id + 7*stride] = 0.0;
+        g_sph[grid_id + 8*stride] = 0.0;
+    }
+}
+
 template <int ANG> __device__
 static void _cart_gto(double *g, double ce, double *fx, double *fy, double *fz){
     for (int lx = ANG, i = 0; lx >= 0; lx--){
@@ -962,46 +1048,52 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         */
         gto[         grid_id] = 1.092548430592079070 * g1;
         gto[1*ngrids+grid_id] = 1.092548430592079070 * g4;
-        gto[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3);
+        gto[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3);
         gto[3*ngrids+grid_id] = 1.092548430592079070 * g2;
         gto[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3);
 
         double ax = ce_2a * rx;
-        g0 = (ax * rx + 2 * ce) * rx;
-        g1 = (ax * rx + ce) * ry;
-        g2 = (ax * rx + ce) * rz;
+        double ax_ce = ax * rx + ce;
+        double ax_2ce = ax_ce + ce;
+        g0 = ax_2ce * rx;
+        g1 = ax_ce * ry;
+        g2 = ax_ce * rz;
         g3 = ax * ry * ry;
         g4 = ax * ry * rz;
         g5 = ax * rz * rz;
         gtox[         grid_id] = 1.092548430592079070 * g1;
         gtox[1*ngrids+grid_id] = 1.092548430592079070 * g4;
-        gtox[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3);
+        gtox[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3);
         gtox[3*ngrids+grid_id] = 1.092548430592079070 * g2;
         gtox[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3);
 
         double ay = ce_2a * ry;
+        double ay_ce = ay * ry + ce;
+        double ay_2ce = ay_ce + ce;
         g0 = ay * rx * rx;
-        g1 = (ay * ry + ce) * rx;
+        g1 = ay_ce * rx;
         g2 = ay * rx * rz;
-        g3 = (ay * ry + 2 * ce) * ry;
-        g4 = (ay * ry + ce) * rz;
+        g3 = ay_2ce * ry;
+        g4 = ay_ce * rz;
         g5 = ay * rz * rz;
         gtoy[         grid_id] = 1.092548430592079070 * g1;
         gtoy[1*ngrids+grid_id] = 1.092548430592079070 * g4;
-        gtoy[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3);
+        gtoy[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3);
         gtoy[3*ngrids+grid_id] = 1.092548430592079070 * g2;
         gtoy[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3);
 
         double az = ce_2a * rz;
+        double az_ce = az * rz + ce;
+        double az_2ce = az_ce + ce;
         g0 = az * rx * rx;
         g1 = az * rx * ry;
-        g2 = (az * rz + ce) * rx;
+        g2 = az_ce * rx;
         g3 = az * ry * ry;
-        g4 = (az * rz + ce) * ry;
-        g5 = (az * rz + 2 * ce) * rz;
+        g4 = az_ce * ry;
+        g5 = az_2ce * rz;
         gtoz[         grid_id] = 1.092548430592079070 * g1;
         gtoz[1*ngrids+grid_id] = 1.092548430592079070 * g4;
-        gtoz[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3);
+        gtoz[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3);
         gtoz[3*ngrids+grid_id] = 1.092548430592079070 * g2;
         gtoz[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3);
     } else if (ANG == 3) {
@@ -1024,12 +1116,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         gto[6*ngrids+grid_id] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3;
 
         double ax = ce_2a * rx;
-        g0 = (ax * rx + 3 * ce) * rx * rx;
-        g1 = (ax * rx + 2 * ce) * rx * ry;
-        g2 = (ax * rx + 2 * ce) * rx * rz;
-        g3 = (ax * rx + ce) * ry * ry;
-        g4 = (ax * rx + ce) * ry * rz;
-        g5 = (ax * rx + ce) * rz * rz;
+        double ax_ce = ax * rx + ce;
+        double ax_2ce = ax_ce + ce;
+        double ax_3ce = ax_2ce + ce;
+        g0 = ax_3ce * rx * rx;
+        g1 = ax_2ce * rx * ry;
+        g2 = ax_2ce * rx * rz;
+        g3 = ax_ce * ry * ry;
+        g4 = ax_ce * ry * rz;
+        g5 = ax_ce * rz * rz;
         g6 = ax * ry * ry * ry;
         g7 = ax * ry * ry * rz;
         g8 = ax * ry * rz * rz;
@@ -1043,16 +1138,19 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         gtox[6*ngrids+grid_id] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3;
 
         double ay = ce_2a * ry;
-        g0 = ay * rx * rx * rx;
-        g1 = (ay * ry + ce) * rx * rx;
-        g2 = ay * rx * rx * rz;
-        g3 = (ay * ry + 2 * ce) * rx * ry;
-        g4 = (ay * ry + ce) * rx * rz;
-        g5 = ay * rx * rz * rz;
-        g6 = (ay * ry + 3 * ce) * ry * ry;
-        g7 = (ay * ry + 2 * ce) * ry * rz;
-        g8 = (ay * ry + ce) * rz * rz;
-        g9 = ay * rz * rz * rz;
+        double ay_ce = ay * ry + ce;
+        double ay_2ce = ay_ce + ce;
+        double ay_3ce = ay_2ce + ce;
+        g0 = ay * rx * rx * rx;
+        g1 = ay_ce * rx * rx;
+        g2 = ay * rx * rx * rz;
+        g3 = ay_2ce * rx * ry;
+        g4 = ay_ce * rx * rz;
+        g5 = ay * rx * rz * rz;
+        g6 = ay_3ce * ry * ry;
+        g7 = ay_2ce * ry * rz;
+        g8 = ay_ce * rz * rz;
+        g9 = ay * rz * rz * rz;
         gtoy[         grid_id] = 1.770130769779930531 * g1 - 0.590043589926643510 * g6;
         gtoy[1*ngrids+grid_id] = 2.890611442640554055 * g4;
         gtoy[2*ngrids+grid_id] = 1.828183197857862944 * g8 - 0.457045799464465739 * (g1 + g6);
@@ -1062,16 +1160,19 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
         gtoy[6*ngrids+grid_id] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3;
 
         double az = ce_2a * rz;
-        g0 = az * rx * rx * rx;
-        g1 = az * rx * rx * ry;
-        g2 = (az * rz + ce) * rx * rx;
-        g3 = az * rx * ry * ry;
-        g4 = (az * rz + ce) * rx * ry;
-        g5 = (az * rz + 2 * ce) * rx * rz;
-        g6 = az * ry * ry * ry;
-        g7 = (az * rz + ce) * ry * ry;
-        g8 = (az * rz + 2 * ce) * ry * rz;
-        g9 = (az * rz + 3 * ce) * rz * rz;
+        double az_ce = az * rz + ce;
+        double az_2ce = az_ce + ce;
+        double az_3ce = az_2ce + ce;
+        g0 = az * rx * rx * rx;
+        g1 = az * rx * rx * ry;
+        g2 = az_ce * rx * rx;
+        g3 = az * rx * ry * ry;
+        g4 = az_ce * rx * ry;
+        g5 = az_2ce * rx * rz;
+        g6 = az * ry * ry * ry;
+        g7 = az_ce * ry * ry;
+        g8 = az_2ce * ry * rz;
+        g9 = az_3ce * rz * rz;
         gtoz[         grid_id] = 1.770130769779930531 * g1 - 0.590043589926643510 * g6;
         gtoz[1*ngrids+grid_id] = 2.890611442640554055 * g4;
         gtoz[2*ngrids+grid_id] = 1.828183197857862944 * g8 - 0.457045799464465739 * (g1 + g6);

From 6cf2ee3b6219bad9c0796cf18ee9c33f086e8e1c Mon Sep 17 00:00:00 2001
From: Qiming Sun
Date: Wed, 10 Jan 2024 16:04:47 -0800
Subject: [PATCH 3/3] zero-copy np arrays (#76)

* Add empty_mapped for zero-copy
* Add takebak
* Add test
---
 gpu4pyscf/lib/cupy_helper.py             | 47 +++++++++++++++++++++-
 gpu4pyscf/lib/cupy_helper/take_last2d.cu | 50 ++++++++++++++++++++++--
 gpu4pyscf/lib/tests/test_cupy_helper.py  | 13 +++++-
 3 files changed, 103 insertions(+), 7 deletions(-)

diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py
index 3d0d25f5..bbaf2ebc 100644
--- a/gpu4pyscf/lib/cupy_helper.py
+++ b/gpu4pyscf/lib/cupy_helper.py
@@ -240,12 +240,12 @@ def block_diag(blocks, out=None):
 def take_last2d(a, indices, out=None):
     '''
-    reorder the last 2 dimensions with 'indices', the first n-2 indices do not change
-    shape in the last 2 dimensions have to be the same
+    Reorder the last 2 dimensions as a[..., indices[:,None], indices]
     '''
     assert a.flags.c_contiguous
     assert a.shape[-1] == a.shape[-2]
     nao = a.shape[-1]
+    assert len(indices) == nao
     if a.ndim == 2:
         count = 1
     else:
@@ -266,6 +266,35 @@ def take_last2d(a, indices, out=None):
         raise RuntimeError('failed in take_last2d kernel')
     return out
 
+def takebak(out, a, indices, axis=-1):
+    '''(experimental)
+    Take elements from a NumPy array along an axis and write to CuPy array.
+    out[..., indices] = a
+    '''
+    assert axis == -1
+    assert isinstance(a, np.ndarray)
+    assert isinstance(out, cupy.ndarray)
+    assert out.ndim == a.ndim
+    assert a.shape[-1] == len(indices)
+    if a.ndim == 1:
+        count = 1
+    else:
+        assert out.shape[:-1] == a.shape[:-1]
+        count = np.prod(a.shape[:-1])
+    n_a = a.shape[-1]
+    n_o = out.shape[-1]
+    indices_int32 = cupy.asarray(indices, dtype=cupy.int32)
+    stream = cupy.cuda.get_current_stream()
+    err = libcupy_helper.takebak(
+        ctypes.c_void_p(stream.ptr),
+        ctypes.c_void_p(out.data.ptr), a.ctypes,
+        ctypes.c_void_p(indices_int32.data.ptr),
+        ctypes.c_int(count), ctypes.c_int(n_o), ctypes.c_int(n_a)
+    )
+    if err != 0: # Not the mapped host memory
+        out[...,indices] = cupy.asarray(a)
+    return out
+
 def transpose_sum(a, stream=None):
     '''
     return a + a.transpose(0,2,1)
@@ -497,3 +526,17 @@ def _qr(xs, dot, lindep=1e-14):
 
 def _gen_x0(v, xs):
     return cupy.dot(v.T, xs)
+
+def empty_mapped(shape, dtype=float, order='C'):
+    '''(experimental)
+    Returns a new, uninitialized NumPy array with the given shape and dtype.
+
+    This is a convenience function which is just :func:`numpy.empty`,
+    except that the underlying buffer is a pinned and mapped memory.
+    This array can be used as the buffer of zero-copy memory.
+    '''
+    nbytes = np.prod(shape) * np.dtype(dtype).itemsize
+    mem = cupy.cuda.PinnedMemoryPointer(
+        cupy.cuda.PinnedMemory(nbytes, cupy.cuda.runtime.hostAllocMapped), 0)
+    out = np.ndarray(shape, dtype=dtype, buffer=mem, order=order)
+    return out

diff --git a/gpu4pyscf/lib/cupy_helper/take_last2d.cu b/gpu4pyscf/lib/cupy_helper/take_last2d.cu
index 26a1a6a6..4b671211 100644
--- a/gpu4pyscf/lib/cupy_helper/take_last2d.cu
+++ b/gpu4pyscf/lib/cupy_helper/take_last2d.cu
@@ -17,11 +17,12 @@
 #include
 #include
 #define THREADS 32
+#define COUNT_BLOCK 80
 
 __global__
-static void _take(double *a, const double *b, int *indices, int n)
+static void _take_last2d(double *a, const double *b, int *indices, int n)
 {
-    int i = blockIdx.z;
+    size_t i = blockIdx.z;
     int j = blockIdx.x * blockDim.x + threadIdx.x;
     int k = blockIdx.y * blockDim.y + threadIdx.y;
     if (j >= n || k >= n) {
         return;
     }
@@ -35,6 +36,27 @@ static void _take(double *a, const double *b, int *indices, int n)
     a[off + j * n + k] = b[off + j_b * n + k_b];
 }
 
+__global__
+static void _takebak(double *out, double *a, int *indices,
+                     int count, int n_o, int n_a)
+{
+    int i0 = blockIdx.y * COUNT_BLOCK;
+    int j = blockIdx.x * blockDim.x + threadIdx.x;
+    if (j > n_a) {
+        return;
+    }
+
+    // a is on host with zero-copy memory. We need enough iterations for
+    // data prefetch to hide latency
+    int i1 = i0 + COUNT_BLOCK;
+    if (i1 > count) i1 = count;
+    int jp = indices[j];
+#pragma unroll
+    for (size_t i = i0; i < i1; ++i) {
+        out[i * n_o + jp] = a[i * n_a + j];
+    }
+}
+
 extern "C" {
 int take_last2d(cudaStream_t stream, double *a, const double *b, int *indices, int blk_size, int n)
 {
@@ -42,11 +64,33 @@ int take_last2d(cudaStream_t stream, double *a, const double *b, int *indices, i
     int ntile = (n + THREADS - 1) / THREADS;
     dim3 threads(THREADS, THREADS);
     dim3 blocks(ntile, ntile, blk_size);
-    _take<<<blocks, threads, 0, stream>>>(a, b, indices, n);
+    _take_last2d<<<blocks, threads, 0, stream>>>(a, b, indices, n);
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         return 1;
     }
     return 0;
 }
+
+int takebak(cudaStream_t stream, double *out, double *a_h, int *indices,
+            int count, int n_o, int n_a)
+{
+    double *a_d;
+    cudaError_t err;
+    err = cudaHostGetDevicePointer(&a_d, a_h, 0); // zero-copy check
+    if (err != cudaSuccess) {
+        return 1;
+    }
+
+    int ntile = (n_a + THREADS*THREADS - 1) / (THREADS*THREADS);
+    int ncount = (count + COUNT_BLOCK - 1) / COUNT_BLOCK;
+    dim3 threads(THREADS*THREADS);
+    dim3 blocks(ntile, ncount);
+    _takebak<<<blocks, threads, 0, stream>>>(out, a_d, indices, count, n_o, n_a);
+    err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        return 1;
+    }
+    return 0;
+}
 }

diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py
index 614e0fa3..fed01456 100644
--- a/gpu4pyscf/lib/tests/test_cupy_helper.py
+++ b/gpu4pyscf/lib/tests/test_cupy_helper.py
@@ -18,7 +18,7 @@
 import cupy
 from gpu4pyscf.lib.cupy_helper import (
     take_last2d, transpose_sum, krylov, unpack_sparse,
-    add_sparse)
+    add_sparse, takebak, empty_mapped)
 
 class KnownValues(unittest.TestCase):
     def test_take_last2d(self):
@@ -69,6 +69,15 @@ def test_sparse(self):
         add_sparse(a, b, indices)
         assert cupy.linalg.norm(a - a0) < 1e-10
 
+    def test_takebak(self):
+        a = empty_mapped((5, 8))
+        a[:] = 1.
+        idx = numpy.arange(8) * 2
+        out = cupy.zeros((5, 16))
+        takebak(out, a, idx)
+        out[:,idx] -= 1.
+        assert abs(out).sum() == 0.
+
 if __name__ == "__main__":
     print("Full tests for cupy helper module")
-    unittest.main()
\ No newline at end of file
+    unittest.main()
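
A quick NumPy sketch (illustrative only, not part of the patch series) of the symmetrization
trick used in df_jk.get_jk above: accumulate only one triangle of 2 * rhok.T @ rhok with a
syrk-style call (alpha=2.0, lower=True), halve the diagonal, then add the transpose, exactly
as transpose_sum does for the CuPy array. The sizes nao/naux and the name rhok below are
illustrative; the real code works on CuPy arrays via cublas.syrk and transpose_sum.

    import numpy as np

    nao, naux = 6, 10                  # small illustrative sizes
    rhok = np.random.rand(naux, nao)

    ref = 2.0 * rhok.T @ rhok          # full symmetric exchange contribution

    vk = np.tril(2.0 * rhok.T @ rhok)  # lower triangle only, like cublas.syrk(..., alpha=2.0, lower=True)
    vk[np.diag_indices(nao)] *= 0.5    # vk[0][numpy.diag_indices(nao)] *= 0.5
    vk = vk + vk.T                     # transpose_sum(vk)

    assert abs(vk - ref).max() < 1e-12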