Skip to content

Commit

Permalink
fixed a bug in transpose_sum
Browse files (browse the repository at this point in the history)
  • Loading branch information
wxj6000 committed Jan 4, 2024
1 parent 69e9fa4 commit fd4d5f8
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 98 deletions.
6 changes: 3 additions & 3 deletions gpu4pyscf/__config__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# such as V100-32G
elif props['totalGlobalMem'] >= 32 * GB:
min_ao_blksize = 128
min_grid_blksize = 256*256#128*128
min_grid_blksize = 128*128
ao_aligned = 32
grid_aligned = 128
mem_fraction = 0.9
number_of_threads = 1024 * 80
# such as A30-24GB
elif props['totalGlobalMem'] >= 16 * GB:
min_ao_blksize = 128
min_grid_blksize = 256*256#128*128
min_grid_blksize = 128*128
ao_aligned = 32
grid_aligned = 128
mem_fraction = 0.9
Expand All @@ -35,4 +35,4 @@
mem_fraction = 0.9
number_of_threads = 1024 * 80

cupy.get_default_memory_pool().set_limit(fraction=mem_fraction)
cupy.get_default_memory_pool().set_limit(fraction=mem_fraction)
5 changes: 3 additions & 2 deletions gpu4pyscf/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pyscf import lib
from pyscf.df import df, addons
from gpu4pyscf.lib.cupy_helper import (
cholesky, tag_array, get_avail_mem, cart2sph, take_last2d)
cholesky, tag_array, get_avail_mem, cart2sph, take_last2d, transpose_sum)
from gpu4pyscf.df import int3c2e, df_jk
from gpu4pyscf.lib import logger
from gpu4pyscf import __config__
Expand Down Expand Up @@ -262,7 +262,8 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False):
row = intopt.ao_pairs_row[cp_ij_id] - i0
col = intopt.ao_pairs_col[cp_ij_id] - j0
if cpi == cpj:
ints_slices = ints_slices + ints_slices.transpose([0,2,1])
#ints_slices = ints_slices + ints_slices.transpose([0,2,1])
transpose_sum(ints_slices)
ints_slices = ints_slices[:,col,row]

if cd_low.tag == 'eig':
Expand Down
5 changes: 2 additions & 3 deletions gpu4pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,14 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-
vj_packed += cupy.dot(rhoj, cderi_sparse.T)
if with_k:
rhok = contract('Lij,jk->Lki', cderi, occ_coeff)
#vk[0] += contract('Lki,Lkj->ij', rhok, rhok)
cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=1.0, beta=1.0, lower=True)
#vk[0] += 2.0 * contract('Lki,Lkj->ij', rhok, rhok)
cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=2.0, beta=1.0, lower=True)
if with_j:
vj[:,rows,cols] = vj_packed
vj[:,cols,rows] = vj_packed
if with_k:
vk[0][numpy.diag_indices(nao)] *= 0.5
transpose_sum(vk)
vk *= 2.0
# CP-HF K matrix
elif hasattr(dms_tag, 'mo1'):
if with_j:
Expand Down
21 changes: 1 addition & 20 deletions gpu4pyscf/dft/gen_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,27 +186,8 @@ def gen_grids_partition(atm_coords, coords, a):
natm = atm_coords.shape[0]
ngrids = coords.shape[0]
assert ngrids < 65535 * 16
#x_i = cupy.expand_dims(atm_coords, axis=1)
#x_g = cupy.expand_dims(coords, axis=0)
#squared_diff = (x_i - x_g)**2
#dist_ig = cupy.sum(squared_diff, axis=2)**0.5

#x_j = cupy.expand_dims(atm_coords, axis=0)
#squared_diff = (x_i - x_j)**2
#dist_ij = cupy.sum(squared_diff, axis=2)**0.5

pbecke = cupy.ones([natm, ngrids], order='C')
'''
err = libgdft.GDFTgen_grid_partition(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(pbecke.data.ptr, ctypes.c_void_p),
ctypes.cast(dist_ig.data.ptr, ctypes.c_void_p),
ctypes.cast(dist_ij.data.ptr, ctypes.c_void_p),
ctypes.cast(a.data.ptr, ctypes.c_void_p),
ctypes.c_int(ngrids),
ctypes.c_int(natm)
)
'''
pbecke = cupy.empty([natm, ngrids], order='C')
atm_coords = cupy.asarray(atm_coords, order='F')
err = libgdft.GDFTgen_grid_partition(
ctypes.cast(stream.ptr, ctypes.c_void_p),
Expand Down
11 changes: 10 additions & 1 deletion gpu4pyscf/dft/tests/test_ao_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_ao_sph_deriv0(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=0)
print('sph_deriv0', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_sph_deriv1(self):
Expand All @@ -65,6 +66,7 @@ def test_ao_sph_deriv1(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=1)
print('sph_deriv1', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_sph_deriv2(self):
Expand All @@ -73,7 +75,7 @@ def test_ao_sph_deriv2(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=2)
#idx = cupy.argwhere(cupy.abs(ao_gpu - ao_cpu) > 1e-10)
print('sph_deriv2', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_sph_deriv3(self):
Expand All @@ -82,6 +84,7 @@ def test_ao_sph_deriv3(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=3)
print('sph_deriv3', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_sph_deriv4(self):
Expand All @@ -90,6 +93,7 @@ def test_ao_sph_deriv4(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=4)
print('sph_deriv4', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

# cart mol
Expand All @@ -99,6 +103,7 @@ def test_ao_cart_deriv0(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_cart, coords, deriv=0)
print('cart_deriv0', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_cart_deriv1(self):
Expand All @@ -107,6 +112,7 @@ def test_ao_cart_deriv1(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_cart, coords, deriv=1)
print('cart_deriv1', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_cart_deriv2(self):
Expand All @@ -115,6 +121,7 @@ def test_ao_cart_deriv2(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_cart, coords, deriv=2)
print('cart_deriv2', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_cart_deriv3(self):
Expand All @@ -123,6 +130,7 @@ def test_ao_cart_deriv3(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = ni.eval_ao(mol_cart, coords, deriv=3)
print('cart_deriv3', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

def test_ao_cart_deriv4(self):
Expand All @@ -131,6 +139,7 @@ def test_ao_cart_deriv4(self):
ao_cpu = cupy.asarray(ao)
ni = NumInt(xc='LDA')
ao_gpu = numint.eval_ao(ni, mol_cart, coords, deriv=4)
print('cart_deriv4', cupy.linalg.norm(ao_cpu - ao_gpu))
assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion gpu4pyscf/dft/tests/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def setUpModule():
O 0.000000 0.000000 0.117790
H 0.000000 0.755453 -0.471161
H 0.000000 -0.755453 -0.471161''',
basis = 'ccpvdz',
basis = 'ccpvqz',
charge = 1,
spin = 1, # = 2S = spin_up - spin_down
output = '/dev/null')
Expand Down
4 changes: 2 additions & 2 deletions gpu4pyscf/lib/cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ def take_last2d(a, indices, out=None):
raise RuntimeError('failed in take_last2d kernel')
return out

def transpose_sum(a):
def transpose_sum(a, stream=None):
'''
transpose (0,2,1)
return a + a.transpose(0,2,1)
'''
assert a.flags.c_contiguous
assert a.ndim == 3
Expand Down
53 changes: 29 additions & 24 deletions gpu4pyscf/lib/cupy_helper/transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,41 @@ static void _dsymm_triu(double *a, int n)
a[off + j * N + i] = a[off + i * N + j];
}

__global__
__global__
void _transpose_sum(double *a, int n)
{
if(blockIdx.x > blockIdx.y){
return;
}
__shared__ double block[BLOCK_DIM][BLOCK_DIM+1];

// read the matrix tile into shared memory
// load one element per thread from device memory (idata) and store it
// in transposed order in block[][]
unsigned int xIndex = blockIdx.x * BLOCK_DIM + threadIdx.x;
unsigned int yIndex = blockIdx.y * BLOCK_DIM + threadIdx.y;
unsigned int zIndex = blockIdx.z;
unsigned int off = zIndex * n * n;

if((xIndex < n) && (yIndex < n))
{
unsigned int index_in = yIndex * n + xIndex + off;
block[threadIdx.y][threadIdx.x] = a[index_in];
}
unsigned int blockx_off = blockIdx.x * BLOCK_DIM;
unsigned int blocky_off = blockIdx.y * BLOCK_DIM;
unsigned int x0 = blockx_off + threadIdx.x;
unsigned int y0 = blocky_off + threadIdx.y;
unsigned int x1 = blocky_off + threadIdx.x;
unsigned int y1 = blockx_off + threadIdx.y;
unsigned int z = blockIdx.z;

unsigned int off = n * n * z;
unsigned int xy0 = y0 * n + x0 + off;
unsigned int xy1 = y1 * n + x1 + off;

// synchronise to ensure all writes to block[][] have completed
__syncthreads();
if (x0 < n && y0 < n){
block[threadIdx.y][threadIdx.x] = a[xy0];
}
__syncthreads();
if (x1 < n && y1 < n){
block[threadIdx.x][threadIdx.y] += a[xy1];
}
__syncthreads();

// write the transposed matrix tile to global memory (odata) in linear order
xIndex = blockIdx.y * BLOCK_DIM + threadIdx.x;
yIndex = blockIdx.x * BLOCK_DIM + threadIdx.y;
if((xIndex < n) && (yIndex < n))
{
unsigned int index_out = yIndex * n + xIndex + off;
a[index_out] += block[threadIdx.x][threadIdx.y];
}
if(x0 < n && y0 < n){
a[xy0] = block[threadIdx.y][threadIdx.x];
}
if(x1 < n && y1 < n){
a[xy1] = block[threadIdx.x][threadIdx.y];
}
}

extern "C" {
Expand Down
52 changes: 26 additions & 26 deletions gpu4pyscf/lib/gdft/nr_eval_gto.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include "nr_eval_gto.cuh"
#include "contract_rho.cuh"

#define THREADS 128
#define NG_PER_BLOCK 128
#define LMAX 8
#define GTO_MAX_CART 15

Expand Down Expand Up @@ -1121,15 +1121,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
g12 = ax * ry * ry * rz * rz;
g13 = ax * ry * rz * rz * rz;
g14 = ax * rz * rz * rz * rz;
gtox[ grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ;
gtox[ grid_id] = 2.503342941796704538 * (g1 - g6) ;
gtox[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
gtox[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ;
gtox[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11;
gtox[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14;
gtox[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ;
gtox[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12;
gtox[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
gtox[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
gtox[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
gtox[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
gtox[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
gtox[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
gtox[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10;
gtox[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;

double ay = ce_2a * ry;
g0 = ay * rx * rx * rx * rx;
Expand All @@ -1147,15 +1147,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
g12 = (ay * ry + 2 * ce) * ry * rz * rz;
g13 = (ay * ry + ce) * rz * rz * rz;
g14 = ay * rz * rz * rz * rz;
gtoy[ grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ;
gtoy[ grid_id] = 2.503342941796704538 * (g1 - g6) ;
gtoy[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
gtoy[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ;
gtoy[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11;
gtoy[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14;
gtoy[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ;
gtoy[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12;
gtoy[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
gtoy[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
gtoy[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
gtoy[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
gtoy[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
gtoy[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
gtoy[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10;
gtoy[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;

double az = ce_2a * rz;
g0 = az * rx * rx * rx * rx;
Expand All @@ -1173,15 +1173,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets)
g12 = (az * rz + 2 * ce) * ry * ry * rz;
g13 = (az * rz + 3 * ce) * ry * rz * rz;
g14 = (az * rz + 4 * ce) * rz * rz * rz;
gtoz[ grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ;
gtoz[ grid_id] = 2.503342941796704538 * (g1 - g6) ;
gtoz[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
gtoz[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ;
gtoz[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11;
gtoz[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14;
gtoz[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ;
gtoz[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12;
gtoz[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
gtoz[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
gtoz[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
gtoz[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
gtoz[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
gtoz[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
gtoz[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10;
gtoz[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;
}
}

Expand Down Expand Up @@ -1559,8 +1559,8 @@ int GDFTeval_gto(cudaStream_t stream, double *ao, int deriv, int cart,
offsets.bas_indices = bas_indices;
offsets.nbas = local_ctr_offsets[nctr];
offsets.nao = nao;
dim3 threads(THREADS);
dim3 blocks((ngrids+THREADS-1)/THREADS);
dim3 threads(NG_PER_BLOCK);
dim3 blocks((ngrids+NG_PER_BLOCK-1)/NG_PER_BLOCK);

for (int ictr = 0; ictr < nctr; ++ictr) {
int local_ish = local_ctr_offsets[ictr];
Expand Down Expand Up @@ -1706,8 +1706,8 @@ int GDFTeval_gto(cudaStream_t stream, double *ao, int deriv, int cart,
int GDFTscreen_index(cudaStream_t stream, int *non0shl_idx, double cutoff,
double *grids, int ngrids, int *bas_loc, int nbas, int *bas)
{
dim3 threads(THREADS);
dim3 blocks((ngrids+THREADS-1)/THREADS);
dim3 threads(NG_PER_BLOCK);
dim3 blocks((ngrids+NG_PER_BLOCK-1)/NG_PER_BLOCK);

for (int shl_id = 0; shl_id < nbas; ++shl_id) {
int l = bas[ANG_OF+shl_id*BAS_SLOTS];
Expand Down
20 changes: 6 additions & 14 deletions gpu4pyscf/lib/gdft/vv10.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,28 +168,20 @@ static void vv10_grad_kernel(double *Fvec, const double *vvcoords, const double
__syncthreads();
for (int l = 0, M = min(NG_PER_BLOCK, vvngrids - j); l < M; ++l){
double3 xj_tmp = xj_t[l];
double pjx = xj_tmp.x;
double pjy = xj_tmp.y;
double pjz = xj_tmp.z;

// about 23 operations for each pair
double DX = pjx - xi;
double DY = pjy - yi;
double DZ = pjz - zi;
double DX = xj_tmp.x - xi;
double DY = xj_tmp.y - yi;
double DZ = xj_tmp.z - zi;
double R2 = DX*DX + DY*DY + DZ*DZ;

double3 kp_tmp = kp_t[l];
double Kpj = kp_tmp.x;
double W0pj = kp_tmp.y;
double RpWj = kp_tmp.z;

double gp = R2*W0pj + Kpj;
double gp = R2*kp_tmp.y + kp_tmp.x;
double g = R2*W0i + Ki;
double gt = g + gp;
double ggp = g * gp;
double ggt_gp = gt * ggp;
double T = RpWj / (ggt_gp * ggt_gp);
double Q = T * ((W0i*gp + W0pj*g)*gt + (W0i+W0pj)*ggp);
double T = kp_tmp.z / (ggt_gp * ggt_gp);
double Q = T * ((W0i*gp + kp_tmp.y*g)*gt + (W0i+kp_tmp.y)*ggp);

FX += Q * DX;
FY += Q * DY;
Expand Down
Loading

0 comments on commit fd4d5f8

Please sign in to comment.