Skip to content

Commit

Permalink
Refactor:Replace the current fft with templates and polymorphism (#5410)
Browse files Browse the repository at this point in the history
* add the basic func of the file

* modify the Makefile

* delete file

* modify the position of the new fft

* modify the Makefile

* [pre-commit.ci lite] apply automatic fixes

* add the cpu float in the fft floder

* change the test file

* [pre-commit.ci lite] apply automatic fixes

* add the func in test

* add the float fft

* change ft into ft1

* add the file of the float_define and the device set

* delete the memory allocate in the ft

* [pre-commit.ci lite] apply automatic fixes

* add the Smart Pointer and the logic gate

* modify the position of the FFT

* change fft_bundle name

* save version of the pw_test and single version

* fix complie bug and change the fftwf logic

* add comments for the fft class

* modify the fft name and add comments

* modify the Makefile

* update the file

* update the format

* update the shared_ptr

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent 8d6e593 commit 7c46674
Show file tree
Hide file tree
Showing 27 changed files with 1,875 additions and 114 deletions.
5 changes: 4 additions & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ VPATH=./src_global:\
./module_base/module_mixing:\
./module_md:\
./module_basis/module_pw:\
./module_basis/module_pw/module_fft:\
./module_esolver:\
./module_hsolver:\
./module_hsolver/kernels:\
Expand Down Expand Up @@ -168,7 +169,6 @@ OBJS_BASE=abfs-vector3_order.o\
memory_op.o\
device.o\


OBJS_CELL=atom_pseudo.o\
atom_spec.o\
pseudo.o\
Expand Down Expand Up @@ -414,6 +414,9 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\
psi_initializer_nao_random.o\

OBJS_PW=fft.o\
fft_bundle.o\
fft_base.o\
fft_cpu.o\
pw_basis.o\
pw_basis_k.o\
pw_basis_sup.o\
Expand Down
1 change: 0 additions & 1 deletion source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ list (APPEND LIBM_SRC
libm/sincos.cpp
)
endif()

add_library(
base
OBJECT
Expand Down
9 changes: 9 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
if (ENABLE_FLOAT_FFTW)
list (APPEND FFT_SRC
module_fft/fft_cpu_float.cpp
)
endif()
list(APPEND objects
fft.cpp
pw_basis.cpp
Expand All @@ -10,6 +15,10 @@ list(APPEND objects
pw_init.cpp
pw_transform.cpp
pw_transform_k.cpp
module_fft/fft_base.cpp
module_fft/fft_bundle.cpp
module_fft/fft_cpu.cpp
${FFT_SRC}
)

add_library(
Expand Down
55 changes: 28 additions & 27 deletions source/module_basis/module_pw/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
this->fftny = this->ny = ny_in;
if (this->gamma_only)
{
if (xprime)
if (xprime) {
this->fftnx = int(nx / 2) + 1;
else
} else {
this->fftny = int(ny / 2) + 1;
}
}
this->nz = nz_in;
this->ns = ns_in;
Expand All @@ -92,10 +93,10 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
int maxgrids = (nsz > nrxx) ? nsz : nrxx;
if (!this->mpifft)
{
z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
d_rspace = (double*)z_auxg;
// z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
// d_rspace = (double*)z_auxg;
// auxr_3d = static_cast<std::complex<double> *>(
// fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz)));
#if defined(__CUDA) || defined(__ROCM)
Expand All @@ -105,15 +106,15 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
#endif // defined(__CUDA) || defined(__ROCM)
#if defined(__ENABLE_FLOAT_FFTW)
if (this->precision == "single")
{
c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
s_rspace = (float*)c_auxg;
}
#endif // defined(__ENABLE_FLOAT_FFTW)
// #if defined(__ENABLE_FLOAT_FFTW)
// if (this->precision == "single")
// {
// c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
// s_rspace = (float*)c_auxg;
// }
// #endif // defined(__ENABLE_FLOAT_FFTW)
}
else
{
Expand Down Expand Up @@ -353,62 +354,62 @@ void FFT::cleanFFT()
if (planzfor)
{
fftw_destroy_plan(planzfor);
planzfor = NULL;
planzfor = nullptr;
}
if (planzbac)
{
fftw_destroy_plan(planzbac);
planzbac = NULL;
planzbac = nullptr;
}
if (planxfor1)
{
fftw_destroy_plan(planxfor1);
planxfor1 = NULL;
planxfor1 = nullptr;
}
if (planxbac1)
{
fftw_destroy_plan(planxbac1);
planxbac1 = NULL;
planxbac1 = nullptr;
}
if (planxfor2)
{
fftw_destroy_plan(planxfor2);
planxfor2 = NULL;
planxfor2 = nullptr;
}
if (planxbac2)
{
fftw_destroy_plan(planxbac2);
planxbac2 = NULL;
planxbac2 = nullptr;
}
if (planyfor)
{
fftw_destroy_plan(planyfor);
planyfor = NULL;
planyfor = nullptr;
}
if (planybac)
{
fftw_destroy_plan(planybac);
planybac = NULL;
planybac = nullptr;
}
if (planxr2c)
{
fftw_destroy_plan(planxr2c);
planxr2c = NULL;
planxr2c = nullptr;
}
if (planxc2r)
{
fftw_destroy_plan(planxc2r);
planxc2r = NULL;
planxc2r = nullptr;
}
if (planyr2c)
{
fftw_destroy_plan(planyr2c);
planyr2c = NULL;
planyr2c = nullptr;
}
if (planyc2r)
{
fftw_destroy_plan(planyc2r);
planyc2r = NULL;
planyc2r = nullptr;
}

// fftw_destroy_plan(this->plan3dforward);
Expand Down
8 changes: 8 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include "fft_base.h"
namespace ModulePW
{
template FFT_BASE<float>::FFT_BASE();
template FFT_BASE<double>::FFT_BASE();
template FFT_BASE<float>::~FFT_BASE();
template FFT_BASE<double>::~FFT_BASE();
}
163 changes: 163 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#include <complex>
#include <string>
#include "fftw3.h"
#ifndef FFT_BASE_H
#define FFT_BASE_H
namespace ModulePW
{
template <typename FPTYPE>
class FFT_BASE
{
public:

FFT_BASE(){};
virtual ~FFT_BASE(){};

/**
* @brief Initialize the fft parameters As virtual function.
*
* The function is used to initialize the fft parameters.
*/
virtual __attribute__((weak))
void initfft(int nx_in,
int ny_in,
int nz_in,
int lixy_in,
int rixy_in,
int ns_in,
int nplane_in,
int nproc_in,
bool gamma_only_in,
bool xprime_in = true);

/**
* @brief Setup the fft Plan and data As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to setup the fft Plan and data.
*/
virtual void setupFFT()=0;

/**
* @brief Clean the fft Plan As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to clean the fft Plan.
*/
virtual void cleanFFT()=0;

/**
* @brief Clear the fft data As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to clear the fft data.
*/
virtual void clear()=0;

/**
* @brief Get the real space data in cpu-like fft
*
* The function is used to get the real space data.While the
* FFT_BASE is an abstract class,the function will be override,
* The attribute weak is used to avoid define the function.
*/
virtual __attribute__((weak))
FPTYPE* get_rspace_data() const;

virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_data() const;

virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxg_data() const;

/**
* @brief Get the auxiliary real space data in 3D
*
* The function is used to get the auxiliary real space data in 3D.
* While the FFT_BASE is an abstract class,the function will be override,
* The attribute weak is used to avoid define the function.
*/
virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_3d_data() const;

//forward fft in x-y direction

/**
* @brief Forward FFT in x-y direction
* @param in input data
* @param out output data
*
* This function performs the forward FFT in the x-y direction.
* It involves two axes, x and y. The FFT is applied multiple times
* along the left and right boundaries in the primary direction(which is
* determined by the xprime flag).Notably, the Y axis operates in
* "many-many-FFT" mode.
*/
virtual __attribute__((weak))
void fftxyfor(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftxybac(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

/**
* @brief Forward FFT in z direction
* @param in input data
* @param out output data
*
* This function performs the forward FFT in the z direction.
* It involves only one axis, z. The FFT is applied only once.
* Notably, the Z axis operates in many FFT with nz*ns.
*/
virtual __attribute__((weak))
void fftzfor(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftzbac(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

/**
* @brief Forward FFT in x-y direction with real to complex
* @param in input data, real type
* @param out output data, complex type
*
* This function performs the forward FFT in the x-y direction
* with real to complex.There is no difference between fftxyfor.
*/
virtual __attribute__((weak))
void fftxyr2c(FPTYPE* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftxyc2r(std::complex<FPTYPE>* in,
FPTYPE* out) const;

/**
* @brief Forward FFT in 3D
* @param in input data
* @param out output data
*
* This function performs the forward FFT for gpu-like fft.
* It involves three axes, x, y, and z. The FFT is applied multiple times
* for fft3D_forward.
*/
virtual __attribute__((weak))
void fft3D_forward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fft3D_backward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

protected:
int nx=0;
int ny=0;
int nz=0;
};
}
#endif // FFT_BASE_H
Loading

0 comments on commit 7c46674

Please sign in to comment.