Skip to content

Commit

Permalink
Support complex128 arrays in cupy_helper.contract function (pyscf#245)
Browse files Browse the repository at this point in the history
* Support complex128 arrays in cupy_helper.contract function

* Update cusolver for complex matrices

---------

Co-authored-by: Qiming Sun <[email protected]>
  • Loading branch information
sunqm and Qiming Sun authored Nov 13, 2024
1 parent 6f45bd7 commit ec72d46
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 34 deletions.
69 changes: 57 additions & 12 deletions gpu4pyscf/lib/cusolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,59 @@
ctypes.c_void_p # *devInfo
]

# https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdn-t-sygvd
libcusolver.cusolverDnZhegvd_bufferSize.argtypes = [
ctypes.c_void_p, # handle
ctypes.c_int, # itype
ctypes.c_int, # jobz
ctypes.c_int, # uplo
ctypes.c_int, # n
ctypes.c_void_p, # *A
ctypes.c_int, # lda
ctypes.c_void_p, # *B
ctypes.c_int, # ldb
ctypes.c_void_p, # *w
ctypes.c_void_p # *lwork
]

libcusolver.cusolverDnZhegvd.argtypes = [
ctypes.c_void_p, # handle
ctypes.c_int, # itype
ctypes.c_int, # jobz
ctypes.c_int, # uplo
ctypes.c_int, # n
ctypes.c_void_p, # *A
ctypes.c_int, # lda
ctypes.c_void_p, # *B
ctypes.c_int, # ldb
ctypes.c_void_p, # *w
ctypes.c_void_p, # *work
ctypes.c_int, # lwork
ctypes.c_void_p # *devInfo
]

def eigh(h, s):
'''
solve generalized eigenvalue problem
'''
assert h.dtype == s.dtype
assert h.dtype in (np.float64, np.complex128)
n = h.shape[0]
w = cupy.zeros(n)
A = h.copy()
B = s.copy()
_handle = device.get_cusolver_handle()

# TODO: reuse workspace
if n in _buffersize:
lwork = _buffersize[n]
if (h.dtype, n) in _buffersize:
lwork = _buffersize[h.dtype, n]
else:
lwork = ctypes.c_int()
status = libcusolver.cusolverDnDsygvd_bufferSize(
lwork = ctypes.c_int(0)
if h.dtype == np.float64:
fn = libcusolver.cusolverDnDsygvd_bufferSize
else:
fn = libcusolver.cusolverDnZhegvd_bufferSize
status = fn(
_handle,
CUSOLVER_EIG_TYPE_1,
CUSOLVER_EIG_MODE_VECTOR,
Expand All @@ -98,10 +135,14 @@ def eigh(h, s):

if status != 0:
raise RuntimeError("failed in buffer size")

work = cupy.empty(lwork)

if h.dtype == np.float64:
fn = libcusolver.cusolverDnDsygvd
else:
fn = libcusolver.cusolverDnZhegvd
work = cupy.empty(lwork, dtype=h.dtype)
devInfo = cupy.empty(1, dtype=np.int32)
status = libcusolver.cusolverDnDsygvd(
status = fn(
_handle,
CUSOLVER_EIG_TYPE_1,
CUSOLVER_EIG_MODE_VECTOR,
Expand All @@ -116,7 +157,7 @@ def eigh(h, s):
lwork,
devInfo.data.ptr
)

if status != 0:
raise RuntimeError("failed in eigh kernel")
return w, A.T
Expand All @@ -126,15 +167,19 @@ def cholesky(A):
assert A.flags['C_CONTIGUOUS']
x = A.copy()
handle = device.get_cusolver_handle()
potrf = cusolver.dpotrf
potrf_bufferSize = cusolver.dpotrf_bufferSize
if A.dtype == np.float64:
potrf = cusolver.dpotrf
potrf_bufferSize = cusolver.dpotrf_bufferSize
else:
potrf = cusolver.zpotrf
potrf_bufferSize = cusolver.zpotrf_bufferSize
buffersize = potrf_bufferSize(handle, cublas.CUBLAS_FILL_MODE_UPPER, n, x.data.ptr, n)
workspace = cupy.empty(buffersize)
workspace = cupy.empty(buffersize, dtype=A.dtype)
dev_info = cupy.empty(1, dtype=np.int32)
potrf(handle, cublas.CUBLAS_FILL_MODE_UPPER, n, x.data.ptr, n,
workspace.data.ptr, buffersize, dev_info.data.ptr)

if dev_info[0] != 0:
raise RuntimeError('failed to perform Cholesky Decomposition')
cupy.linalg._util._tril(x,k=0)
return x
return x
42 changes: 21 additions & 21 deletions gpu4pyscf/lib/cutensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,20 @@ def _auto_create_mode(array, mode):
'ndim mismatch: {} != {}'.format(array.ndim, mode.ndim))
return mode

def _create_tensor_descriptor(a):
handle = cutensor._get_handle()
key = (handle.ptr, a.dtype, tuple(a.shape), tuple(a.strides))
# hard coded
alignment_req = 8
if key not in _tensor_descriptors:
num_modes = a.ndim
extent = np.array(a.shape, dtype=np.int64)
stride = np.array(a.strides, dtype=np.int64) // a.itemsize
cutensor_dtype = cutensor._get_cutensor_dtype(a.dtype)
_tensor_descriptors[key] = cutensor.TensorDescriptor(
handle.ptr, num_modes, extent.ctypes.data, stride.ctypes.data,
cutensor_dtype, alignment_req=alignment_req)
return _tensor_descriptors[key]
#def _create_tensor_descriptor(a):
# handle = cutensor._get_handle()
# key = (handle.ptr, a.dtype, tuple(a.shape), tuple(a.strides))
# # hard coded
# alignment_req = 8
# if key not in _tensor_descriptors:
# num_modes = a.ndim
# extent = np.array(a.shape, dtype=np.int64)
# stride = np.array(a.strides, dtype=np.int64) // a.itemsize
# cutensor_dtype = cutensor._get_cutensor_dtype(a.dtype)
# _tensor_descriptors[key] = cutensor.TensorDescriptor(
# handle.ptr, num_modes, extent.ctypes.data, stride.ctypes.data,
# cutensor_dtype, alignment_req=alignment_req)
# return _tensor_descriptors[key]

def contraction(
pattern, a, b, alpha, beta,
Expand All @@ -80,14 +80,14 @@ def contraction(
mode_b = list(str_b)
mode_c = list(str_c)

if(out is not None):
c = out
else:
c = cupy.empty([shape[k] for k in str_c], order='C')
if out is None:
dtype = np.result_type(a, b, alpha)
out = cupy.empty([shape[k] for k in str_c], order='C', dtype=dtype)
c = out

desc_a = _create_tensor_descriptor(a)
desc_b = _create_tensor_descriptor(b)
desc_c = _create_tensor_descriptor(c)
desc_a = cutensor.create_tensor_descriptor(a)
desc_b = cutensor.create_tensor_descriptor(b)
desc_c = cutensor.create_tensor_descriptor(c)

mode_a = _auto_create_mode(a, mode_a)
mode_b = _auto_create_mode(b, mode_b)
Expand Down
64 changes: 64 additions & 0 deletions gpu4pyscf/lib/tests/test_cusolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2024 The GPU4PySCF Authors. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import unittest
import numpy as np
import scipy.linalg
import cupy as cp
from gpu4pyscf.lib.cusolver import eigh, cholesky

def test_eigh_real():
np.random.seed(6)
n = 12
a = np.random.rand(n, n)
a = a + a.T
b = np.random.rand(n, n)
b = b.dot(b.T)
ref = scipy.linalg.eigh(a, b)
e, c = eigh(cp.asarray(a), cp.asarray(b))
assert abs(e.get() - ref[0]).max() < 1e-10
ovlp = c.get().T.dot(b).dot(ref[1])
assert abs(abs(ovlp) - np.eye(n)).max() < 1e-10

def test_eigh_cmplx():
np.random.seed(6)
n = 12
a = np.random.rand(n, n) + np.random.rand(n, n) * 1j
a = a + a.conj().T
b = np.random.rand(n, n) + np.random.rand(n, n) * 1j
b = b.dot(b.conj().T)
ref = scipy.linalg.eigh(a, b)
e, c = eigh(cp.asarray(a), cp.asarray(b))
assert abs(e.get() - ref[0]).max() < 1e-10
ovlp = c.get().T.dot(b).dot(ref[1])
assert abs(abs(ovlp) - np.eye(n)).max() < 1e-10

def test_cholesky_real():
np.random.seed(6)
n = 12
a = np.random.rand(n, n)
a = a.dot(a.T)
ref = np.linalg.cholesky(a)
x = cholesky(cp.asarray(a))
assert abs(x.get() - ref).max() < 1e-12

def test_cholesky_cmplx():
np.random.seed(6)
n = 12
a = np.random.rand(n, n) + np.random.rand(n, n) * 1j
a = a.dot(a.conj().T)
ref = np.linalg.cholesky(a)
x = cholesky(cp.asarray(a))
assert abs(x.get() - ref).max() < 1e-12
9 changes: 8 additions & 1 deletion gpu4pyscf/lib/tests/test_cutensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def test_contract(self):
c_contract = contract('lkji,jk->il', a, b[10:20,10:20])
assert cupy.linalg.norm(c_einsum - c_contract) < 1e-10

def test_complex_valued(self):
a = cupy.random.rand(10,9,11) + cupy.random.rand(10,9,11)*1j
b = cupy.random.rand(11,7,13) + cupy.random.rand(11,7,13)*1j
c_einsum = cupy.einsum('ijk,ikl->jl', a[3:9,:,4:10], b[3:9,:6, 7:13])
c_contract = contract('ijk,ikl->jl', a[3:9,:,4:10], b[3:9,:6, 7:13])
assert cupy.linalg.norm(c_einsum - c_contract) < 1e-10

def test_cache(self):
a = cupy.random.rand(20,20,20,20)
b = cupy.random.rand(20,20)
Expand All @@ -52,4 +59,4 @@ def test_cache(self):

if __name__ == "__main__":
print("Full tests for cutensor module")
unittest.main()
unittest.main()

0 comments on commit ec72d46

Please sign in to comment.