Skip to content

Commit

Permalink
Mixed precision for unrolled kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
henryw7 committed Feb 19, 2025
1 parent 9b01e19 commit 257b3be
Show file tree
Hide file tree
Showing 5 changed files with 10,550 additions and 10,379 deletions.
1 change: 1 addition & 0 deletions gpu4pyscf/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ else()
set(CMAKE_CUDA_ARCHITECTURES "60-real;70-real;80-real;90-real")
endif()
message("CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
set(CMAKE_CUDA_FLAGS "--use_fast_math ${CMAKE_CUDA_FLAGS}")

if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE RELWITHDEBINFO)
Expand Down
19 changes: 19 additions & 0 deletions gpu4pyscf/lib/gvhf-mixed-precision/mixed_precision_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,22 @@

// Compile-time dispatch helpers: decide the precision tier of `type`.
// NOTE(review): these compare only sizes, so any 4-byte type satisfies
// IS_FLOAT and any 8-byte type satisfies IS_DOUBLE — confirm callers only
// ever pass float/double (std::is_same would be the strict form).
#define IS_FLOAT(type) (sizeof(type) == sizeof(float))
#define IS_DOUBLE(type) (sizeof(type) == sizeof(double))

// Primary template is intentionally empty: only the double and float
// specializations below supply the fp_* math wrappers, so using
// MixedPrecisionOperator<T>::fp_*() with any other T fails to compile.
template<typename FloatType>
class MixedPrecisionOperator {};

// Double-precision tier: each fp_* wrapper forwards to the corresponding
// double overload of the math routine, so kernels templated on FloatType
// can call MixedPrecisionOperator<FloatType>::fp_*() uniformly.
template<>
class MixedPrecisionOperator<double> {
public:
    // sqrt(x) evaluated in double precision.
    static __device__ double fp_sqrt(const double x) { return sqrt(x); }
    // erf(x) evaluated in double precision.
    static __device__ double fp_erf(const double x) { return erf(x); }
    // exp(x) evaluated in double precision.
    static __device__ double fp_exp(const double x) { return exp(x); }
};

// Single-precision tier: each fp_* wrapper forwards to the dedicated
// float routine (expf/erff/sqrtf), avoiding a silent promotion to double
// inside performance-critical device code.
template<>
class MixedPrecisionOperator<float> {
public:
    // sqrtf(x): single-precision square root.
    static __device__ float fp_sqrt(const float x) { return sqrtf(x); }
    // erff(x): single-precision error function.
    static __device__ float fp_erf(const float x) { return erff(x); }
    // expf(x): single-precision exponential.
    static __device__ float fp_exp(const float x) { return expf(x); }
};
6 changes: 3 additions & 3 deletions gpu4pyscf/lib/gvhf-mixed-precision/rys_contract_jk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ static void rys_jk_general(RysIntEnvVars envs, JKMatrix jk, BoundsInfo bounds,
const FloatType aj = static_cast<FloatType>(expj[jp]);
const FloatType aij = ai + aj;
const FloatType theta_ij = ai * aj / aij;
const FloatType Kab = exp(-theta_ij * rr_ij);
const FloatType Kab = MixedPrecisionOperator<FloatType>::fp_exp(-theta_ij * rr_ij);
cicj_cache[ij*nsq_per_block] = fac_sym * static_cast<FloatType>(ci[ip]) * static_cast<FloatType>(cj[jp]) * Kab;
}
for (int gout_start = 0; gout_start < nfij*nfkl; gout_start+=gout_stride*GOUT_WIDTH) {
Expand All @@ -162,7 +162,7 @@ static void rys_jk_general(RysIntEnvVars envs, JKMatrix jk, BoundsInfo bounds,
const FloatType zlzk = rlrk[2*nsq_per_block];
const FloatType rr_kl = xlxk*xlxk + ylyk*ylyk + zlzk*zlzk;
const FloatType theta_kl = ak * al / akl;
const FloatType Kcd = exp(-theta_kl * rr_kl);
const FloatType Kcd = MixedPrecisionOperator<FloatType>::fp_exp(-theta_kl * rr_kl);
const FloatType ckcl = static_cast<FloatType>(ck[kp]) * static_cast<FloatType>(cl[lp]) * Kcd;
gx[0] = ckcl;
}
Expand All @@ -188,7 +188,7 @@ static void rys_jk_general(RysIntEnvVars envs, JKMatrix jk, BoundsInfo bounds,
Rpq[1*nsq_per_block] = ypq;
Rpq[2*nsq_per_block] = zpq;
const FloatType cicj = cicj_cache[ijp*nsq_per_block];
gy[0] = cicj / (aij*akl*sqrt(aij+akl));
gy[0] = cicj / (aij*akl* MixedPrecisionOperator<FloatType>::fp_sqrt(aij+akl));
}
const FloatType rr = xpq*xpq + ypq*ypq + zpq*zpq;
const FloatType theta = aij * akl / (aij + akl);
Expand Down
2 changes: 1 addition & 1 deletion gpu4pyscf/lib/gvhf-mixed-precision/rys_jk_driver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ __global__ void rys_jk_kernel(RysIntEnvVars envs, JKMatrix jk, BoundsInfo bounds
// extern int rys_j_unrolled(RysIntEnvVars *envs, JKMatrix *jk, BoundsInfo *bounds,
// ShellQuartet *pool, uint32_t *batch_head, int *scheme, int workers);
extern int rys_jk_unrolled(RysIntEnvVars *envs, JKMatrix *jk, BoundsInfo *bounds,
ShellQuartet *pool, uint32_t *batch_head, int *scheme, int workers);
ShellQuartet *pool, uint32_t *batch_head, int *scheme, int workers);
extern int os_jk_unrolled(RysIntEnvVars *envs, JKMatrix *jk, BoundsInfo *bounds,
ShellQuartet *pool, uint32_t *batch_head,
int *scheme, int workers, double omega);
Expand Down
Loading

0 comments on commit 257b3be

Please sign in to comment.