Skip to content

Commit

Permalink
tag_array for tuple of arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
wxj6000 committed Jan 21, 2024
1 parent 0728389 commit 3c1b461
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
4 changes: 3 additions & 1 deletion gpu4pyscf/lib/cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ class CPArrayWithTag(cupy.ndarray):
@functools.wraps(lib.tag_array)
def tag_array(a, **kwargs):
'''
a should be cupy/numpy array or tuple of cupy/numpy array
attach attributes to cupy ndarray for cupy array
attach attributes to numpy ndarray for numpy array
'''
if isinstance(a, cupy.ndarray):
if isinstance(a, cupy.ndarray) or isinstance(a[0], cupy.ndarray):
t = cupy.asarray(a).view(CPArrayWithTag)
if isinstance(a, CPArrayWithTag):
t.__dict__.update(a.__dict__)
Expand Down
39 changes: 19 additions & 20 deletions gpu4pyscf/solvent/tests/test_smd_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,25 @@
from gpu4pyscf.solvent import smd

def setUpModule():
global mol_list
xyz_list = ['CH4.xyz', 'CO2.xyz', 'H2SO4.xyz', 'H3PO4.xyz', 'HNO3.xyz', 'NaCl.xyz', 'NH4.xyz']
mol_list = {}
for xyz in xyz_list:
mol = gto.Mole()
mol.atom = xyz
mol.basis = 'sto3g'
mol.output = '/dev/null'
mol.build(verbose=0)
mol_list[xyz] = mol
global mol
mol = gto.Mole()
mol.atom = '''P 0.000 0.000 0.000
O 1.500 0.000 0.000
O -1.500 0.000 0.000
O 0.000 1.500 0.000
O 0.000 -1.500 0.000
H 1.000 1.000 0.000
H -1.000 -1.000 0.000
H 0.000 -2.000 0.000
'''
mol.basis = 'sto3g'
mol.output = '/dev/null'
mol.build(verbose=0)

def tearDownModule():
global mol_list
for xyz in mol_list:
mol_list[xyz].stdout.close()
del mol_list
global mol
mol.stdout.close()
del mol

def _check_grad(mol, solvent='water'):
natm = mol.natm
Expand Down Expand Up @@ -74,14 +77,10 @@ def _check_grad(mol, solvent='water'):

class KnownValues(unittest.TestCase):
def test_grad_water(self):
for xyz in mol_list:
print(f'running {xyz} with water')
_check_grad(mol_list[xyz], solvent='water')
_check_grad(mol, solvent='water')

def test_grad_solvent(self):
for xyz in mol_list:
print(f'running {xyz} with ethanol')
_check_grad(mol_list[xyz], solvent='ethanol')
_check_grad(mol, solvent='ethanol')

if __name__ == "__main__":
print("Full Tests for Gradient of SMD")
Expand Down

0 comments on commit 3c1b461

Please sign in to comment.