Skip to content

Commit

Permalink
First round of mixed precision trial
Browse files Browse the repository at this point in the history
  • Loading branch information
henryw7 committed Feb 17, 2025
1 parent 7a6c92e commit 9b01e19
Show file tree
Hide file tree
Showing 8 changed files with 446 additions and 110 deletions.
220 changes: 220 additions & 0 deletions gpu4pyscf/lib/gvhf-mixed-precision/create_tasks.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cuda_runtime.h>

#include "vhf.cuh"
#include "mixed_precision_helper.h"

// 8-fold symmery
__device__
Expand Down Expand Up @@ -177,6 +178,225 @@ static int _fill_jk_tasks(ShellQuartet *shl_quartet_idx,
return ntasks;
}

// 8-fold symmery
template<typename FloatType>
__device__
static int _fill_jk_tasks_mixed_precision(ShellQuartet *shl_quartet_idx,
RysIntEnvVars envs, JKMatrix jk, BoundsInfo bounds,
int batch_ij, int batch_kl)
{
const int nbas = envs.nbas;
const int *tile_ij_mapping = bounds.tile_ij_mapping;
const int *tile_kl_mapping = bounds.tile_kl_mapping;
const float *q_cond = bounds.q_cond;
const float *tile_q_cond = bounds.tile_q_cond;
const float *dm_cond = bounds.dm_cond;

const float single_cutoff = bounds.cutoff;
const float double_cutoff = bounds.single_double_cutoff;
const float log_max_dm = bounds.log_max_dm;
float early_cutoff;
if constexpr (IS_DOUBLE(FloatType)) {
early_cutoff = double_cutoff - log_max_dm;
} else {
early_cutoff = single_cutoff - log_max_dm;
}

const int t_id = threadIdx.y * blockDim.x + threadIdx.x;
const int t_kl0 = batch_kl * TILES_IN_BATCH;
const int t_kl1 = MIN(t_kl0 + TILES_IN_BATCH, bounds.ntile_kl_pairs);
const int threads = blockDim.x * blockDim.y;

const int tile_ij = tile_ij_mapping[batch_ij];
const int nbas_tiles = nbas / TILE;
const int tile_i = tile_ij / nbas_tiles;
const int tile_j = tile_ij % nbas_tiles;
const int ish0 = tile_i * TILE;
const int jsh0 = tile_j * TILE;
const int ish1 = ish0 + TILE;
const int jsh1 = jsh0 + TILE;
const int do_j = jk.vj != NULL;
const int do_k = jk.vk != NULL;

int count = 0;
const float tile_q_ij = tile_q_cond[tile_ij];
for (int t_kl_id = t_kl0+t_id; t_kl_id < t_kl1; t_kl_id += threads) {
const int tile_kl = tile_kl_mapping[t_kl_id];
if (tile_q_ij + tile_q_cond[tile_kl] < early_cutoff) {
break;
}
const int tile_k = tile_kl / nbas_tiles;
const int tile_l = tile_kl % nbas_tiles;
const int ksh0 = tile_k * TILE;
const int lsh0 = tile_l * TILE;
const int ksh1 = ksh0 + TILE;
const int lsh1 = lsh0 + TILE;
for (int ish = ish0; ish < ish1; ++ish) {
for (int jsh = jsh0; jsh < MIN(ish+1, jsh1); ++jsh) {
const int bas_ij = ish * nbas + jsh;
const float q_ij = q_cond [bas_ij];
const float d_ij = dm_cond[bas_ij];
for (int ksh = ksh0; ksh < MIN(ish+1, ksh1); ++ksh) {
const float d_ik = dm_cond[ish*nbas+ksh];
const float d_jk = dm_cond[jsh*nbas+ksh];
for (int lsh = lsh0; lsh < MIN(ksh+1, lsh1); ++lsh) {
const int bas_kl = ksh * nbas + lsh;
if (bas_ij < bas_kl) {
continue;
}
const float q_ijkl = q_ij + q_cond[bas_kl];
if (q_ijkl < early_cutoff) {
continue;
}

if constexpr (IS_DOUBLE(FloatType)) {
const float density_cutoff_double = double_cutoff - q_ijkl;
const bool if_double_precision_job =
(do_k && (d_ik > density_cutoff_double ||
d_jk > density_cutoff_double ||
dm_cond[ish*nbas+lsh] > density_cutoff_double ||
dm_cond[jsh*nbas+lsh] > density_cutoff_double)) ||
(do_j && (d_ij > density_cutoff_double ||
dm_cond[bas_kl ] > density_cutoff_double));
if (if_double_precision_job) {
++count;
}
} else {
const float density_cutoff_double = double_cutoff - q_ijkl;
const bool if_double_precision_job =
(do_k && (d_ik > density_cutoff_double ||
d_jk > density_cutoff_double ||
dm_cond[ish*nbas+lsh] > density_cutoff_double ||
dm_cond[jsh*nbas+lsh] > density_cutoff_double)) ||
(do_j && (d_ij > density_cutoff_double ||
dm_cond[bas_kl ] > density_cutoff_double));
const float density_cutoff_single = single_cutoff - q_ijkl;
const bool if_single_precision_job =
(do_k && (d_ik > density_cutoff_single ||
d_jk > density_cutoff_single ||
dm_cond[ish*nbas+lsh] > density_cutoff_single ||
dm_cond[jsh*nbas+lsh] > density_cutoff_single)) ||
(do_j && (d_ij > density_cutoff_single ||
dm_cond[bas_kl ] > density_cutoff_single));
if (!if_double_precision_job && if_single_precision_job) {
++count;
}
}
}
}
}
}
}

// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
extern __shared__ int cum_count[];
cum_count[t_id] = count;
// Up-sweep phase
for (int stride = 1; stride < threads; stride *= 2) {
__syncthreads();
int index = (t_id + 1) * stride * 2 - 1;
if (index < threads) {
cum_count[index] += cum_count[index-stride];
}
}
__syncthreads();
// Down-sweep phase
for (int stride = threads/4; stride > 0; stride /= 2) {
__syncthreads();
int index = (t_id + 1) * stride * 2 - 1;
if (index + stride < threads) {
cum_count[index + stride] += cum_count[index];
}
}
__syncthreads();
int ntasks = cum_count[threads-1];
if (ntasks == 0) {
return ntasks;
}

int offset = 0;
if (t_id > 0) {
offset = cum_count[t_id-1];
}
for (int t_kl_id = t_kl0+t_id; t_kl_id < t_kl1; t_kl_id += threads) {
const int tile_kl = tile_kl_mapping[t_kl_id];
if (tile_q_ij + tile_q_cond[tile_kl] < early_cutoff) {
break;
}
const int tile_k = tile_kl / nbas_tiles;
const int tile_l = tile_kl % nbas_tiles;
const int ksh0 = tile_k * TILE;
const int lsh0 = tile_l * TILE;
const int ksh1 = ksh0 + TILE;
const int lsh1 = lsh0 + TILE;
ShellQuartet sq;
for (int ish = ish0; ish < ish1; ++ish) {
for (int jsh = jsh0; jsh < MIN(ish+1, jsh1); ++jsh) {
const int bas_ij = ish * nbas + jsh;
const float q_ij = q_cond [bas_ij];
const float d_ij = dm_cond[bas_ij];
sq.i = ish;
sq.j = jsh;
for (int ksh = ksh0; ksh < MIN(ish+1, ksh1); ++ksh) {
const float d_ik = dm_cond[ish*nbas+ksh];
const float d_jk = dm_cond[jsh*nbas+ksh];
for (int lsh = lsh0; lsh < MIN(ksh+1, lsh1); ++lsh) {
const int bas_kl = ksh * nbas + lsh;
if (bas_ij < bas_kl) {
continue;
}
const float q_ijkl = q_ij + q_cond[bas_kl];
if (q_ijkl < early_cutoff) {
continue;
}

if constexpr (IS_DOUBLE(FloatType)) {
const float density_cutoff_double = double_cutoff - q_ijkl;
const bool if_double_precision_job =
(do_k && (d_ik > density_cutoff_double ||
d_jk > density_cutoff_double ||
dm_cond[ish*nbas+lsh] > density_cutoff_double ||
dm_cond[jsh*nbas+lsh] > density_cutoff_double)) ||
(do_j && (d_ij > density_cutoff_double ||
dm_cond[bas_kl ] > density_cutoff_double));
if (if_double_precision_job) {
sq.k = ksh;
sq.l = lsh;
shl_quartet_idx[offset] = sq;
++offset;
}
} else {
const float density_cutoff_double = double_cutoff - q_ijkl;
const bool if_double_precision_job =
(do_k && (d_ik > density_cutoff_double ||
d_jk > density_cutoff_double ||
dm_cond[ish*nbas+lsh] > density_cutoff_double ||
dm_cond[jsh*nbas+lsh] > density_cutoff_double)) ||
(do_j && (d_ij > density_cutoff_double ||
dm_cond[bas_kl ] > density_cutoff_double));
const float density_cutoff_single = single_cutoff - q_ijkl;
const bool if_single_precision_job =
(do_k && (d_ik > density_cutoff_single ||
d_jk > density_cutoff_single ||
dm_cond[ish*nbas+lsh] > density_cutoff_single ||
dm_cond[jsh*nbas+lsh] > density_cutoff_single)) ||
(do_j && (d_ij > density_cutoff_single ||
dm_cond[bas_kl ] > density_cutoff_single));
if (!if_double_precision_job && if_single_precision_job) {
sq.k = ksh;
sq.l = lsh;
shl_quartet_idx[offset] = sq;
++offset;
}
}
}
}
}
}
}
return ntasks;
}

// 8-fold symmery
__device__
static int _fill_sr_jk_tasks(ShellQuartet *shl_quartet_idx,
Expand Down
5 changes: 5 additions & 0 deletions gpu4pyscf/lib/gvhf-mixed-precision/mixed_precision_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

#pragma once

#define IS_FLOAT(type) (sizeof(type) == sizeof(float))
#define IS_DOUBLE(type) (sizeof(type) == sizeof(double))
Loading

0 comments on commit 9b01e19

Please sign in to comment.