Skip to content

Commit

Permalink
Merge pull request #231 from denghuilu/master
Browse files Browse the repository at this point in the history
Increase the maximum neighbor size from 256 to 1024
  • Loading branch information
amcadmus authored Jun 19, 2020
2 parents 07c42c1 + 7f2bc85 commit 3718c7c
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion source/lib/src/NNPInter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "SimulationRegion.h"
#include <stdexcept>

#define MAGIC_NUMBER 256
#define MAGIC_NUMBER 1024
typedef double compute_t;

#ifdef USE_CUDA_TOOLKIT
Expand Down
6 changes: 3 additions & 3 deletions source/op/cuda/descrpt_se_a.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
#include <cub/block/block_radix_sort.cuh>
#include <cuda_runtime.h>

#define MAGIC_NUMBER 256
#define MAGIC_NUMBER 1024

#ifdef HIGH_PREC
typedef double VALUETYPE;
Expand Down Expand Up @@ -339,8 +339,8 @@ void DescrptSeALauncher(const VALUETYPE* coord,
key,
i_idx
);
const int ITEMS_PER_THREAD = 4;
const int BLOCK_THREADS = 64;
const int ITEMS_PER_THREAD = 8;
const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
// BlockSortKernel<NeighborInfo, BLOCK_THREADS, ITEMS_PER_THREAD><<<g_grid_size, BLOCK_THREADS>>> (
BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER);

Expand Down
6 changes: 3 additions & 3 deletions source/op/cuda/descrpt_se_r.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License.
#include <cuda_runtime.h>
#include <fstream>

#define MAGIC_NUMBER 256
#define MAGIC_NUMBER 1024

#ifdef HIGH_PREC
typedef double VALUETYPE;
Expand Down Expand Up @@ -311,8 +311,8 @@ void DescrptSeRLauncher(const VALUETYPE* coord,
key,
i_idx
);
const int ITEMS_PER_THREAD = 4;
const int BLOCK_THREADS = 64;
const int ITEMS_PER_THREAD = 8;
const int BLOCK_THREADS = MAGIC_NUMBER / ITEMS_PER_THREAD;
BlockSortKernel<int_64, BLOCK_THREADS, ITEMS_PER_THREAD> <<<nloc, BLOCK_THREADS>>> (key, key + nloc * MAGIC_NUMBER);
format_nlist_fill_b_se_r<<<nblock, LEN>>> (
nlist,
Expand Down
3 changes: 2 additions & 1 deletion source/op/descrpt_se_a_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow; // NOLINT(build/namespaces)
#define MAGIC_NUMBER 256
#define MAGIC_NUMBER 1024

#ifdef HIGH_PREC
typedef double VALUETYPE ;
Expand Down Expand Up @@ -159,6 +159,7 @@ class DescrptSeAOp : public OpKernel {

OP_REQUIRES (context, (ntypes == int(sel_a.size())), errors::InvalidArgument ("number of types should match the length of sel array"));
OP_REQUIRES (context, (ntypes == int(sel_r.size())), errors::InvalidArgument ("number of types should match the length of sel array"));
OP_REQUIRES (context, (nnei <= 1024), errors::InvalidArgument ("Assert failed, max neighbor size of atom(nnei) " + std::to_string(nnei) + " is larger than 1024, which currently is not supported by deepmd-kit."));

// Create output tensors
TensorShape descrpt_shape ;
Expand Down
1 change: 1 addition & 0 deletions source/op/descrpt_se_r_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class DescrptSeROp : public OpKernel {
OP_REQUIRES (context, (9 == box_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of box should be 9"));
OP_REQUIRES (context, (ndescrpt == avg_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of avg should be ndescrpt"));
OP_REQUIRES (context, (ndescrpt == std_tensor.shape().dim_size(1)), errors::InvalidArgument ("number of std should be ndescrpt"));
OP_REQUIRES (context, (nnei <= 1024), errors::InvalidArgument ("Assert failed, max neighbor size of atom(nnei) " + std::to_string(nnei) + " is larger than 1024, which currently is not supported by deepmd-kit."));

// Create output tensors
TensorShape descrpt_shape ;
Expand Down

0 comments on commit 3718c7c

Please sign in to comment.