Skip to content

Commit

Permalink
Merge pull request #15 from zonca/almxfl
Browse files Browse the repository at this point in the history
bugfix and almxfl
  • Loading branch information
mreineck authored Jan 17, 2018
2 parents d24879b + fc32c39 commit 593d4eb
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 18 deletions.
15 changes: 13 additions & 2 deletions python/libsharp/libsharp.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
cdef extern from "sharp.h":
ctypedef long ptrdiff_t

void sharp_legendre_transform_s(float *bl, float *recfac, ptrdiff_t lmax, float *x,
float *out, ptrdiff_t nx)
Expand All @@ -11,7 +10,19 @@ cdef extern from "sharp.h":

# sharp_lowlevel.h
ctypedef struct sharp_alm_info:
pass
# Maximum \a l index of the array
int lmax
# Number of different \a m values in this object
int nm
# Array with \a nm entries containing the individual m values
int *mval
# Combination of flags from sharp_almflags
int flags
# Array with \a nm entries containing the (hypothetical) indices of
# the coefficients with quantum numbers 0,\a mval[i]
long *mvstart
# Stride between a_lm and a_(l+1),m
long stride

ctypedef struct sharp_geom_info:
pass
Expand Down
59 changes: 56 additions & 3 deletions python/libsharp/libsharp.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
cimport numpy as np
cimport cython

__all__ = ['legendre_transform', 'legendre_roots', 'sht', 'synthesis', 'adjoint_synthesis',
Expand Down Expand Up @@ -62,7 +63,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
cdef int r
cdef sharp_jobtype jobtype_i
cdef double[:, :, ::1] output_buf
cdef int ntrans = input.shape[0] * input.shape[1]
cdef int ntrans = input.shape[0]
cdef int ntotcomp = ntrans * input.shape[1]
cdef int i, j

if spin == 0 and input.shape[1] != 1:
Expand All @@ -71,9 +73,9 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
raise ValueError('For spin != 0, we need input.shape[1] == 2')


cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp)
cdef size_t[::1] ptrbuf = np.empty(2 * ntotcomp, dtype=np.uintp)
cdef double **alm_ptrs = <double**>&ptrbuf[0]
cdef double **map_ptrs = <double**>&ptrbuf[ntrans]
cdef double **map_ptrs = <double**>&ptrbuf[ntotcomp]

try:
jobtype_i = JOBTYPE_TO_CONST[jobtype]
Expand Down Expand Up @@ -230,11 +232,62 @@ cdef class alm_info:
raise NotInitializedError()
return sharp_alm_count(self.ainfo)

def mval(self):
if self.ainfo == NULL:
raise NotInitializedError()
return np.asarray(<int[:self.ainfo.nm]> self.ainfo.mval)

def mvstart(self):
if self.ainfo == NULL:
raise NotInitializedError()
return np.asarray(<long[:self.ainfo.nm]> self.ainfo.mvstart)

def __dealloc__(self):
if self.ainfo != NULL:
sharp_destroy_alm_info(self.ainfo)
self.ainfo = NULL

@cython.boundscheck(False)
def almxfl(self, np.ndarray[double, ndim=3, mode='c'] alm, np.ndarray[double, ndim=2, mode='c'] fl):
"""Multiply Alm by a Ell based array
Parameters
----------
alm : np.ndarray
input alm, 3 dimensions = (different signal x polarizations x lm-ordering)
fl : np.ndarray
either 1 dimension, e.g. gaussian beam, or 2 dimensions e.g. a polarized beam
Returns
-------
None, it modifies alms in-place
"""
cdef int mvstart = 0
cdef bint has_multiple_beams = alm.shape[2] > 1 and fl.shape[1] > 1
cdef int f, i_m, m, num_ells, i_l, i_signal, i_pol, i_mv

for i_m in range(self.ainfo.nm):
m = self.ainfo.mval[i_m]
f = 1 if (m==0) else 2
num_ells = self.ainfo.lmax + 1 - m

if not has_multiple_beams:
for i_signal in range(alm.shape[0]):
for i_pol in range(alm.shape[1]):
for i_l in range(num_ells):
l = m + i_l
for i_mv in range(mvstart + f*i_l, mvstart + f*i_l +f):
alm[i_signal, i_pol, i_mv] *= fl[l, 0]
else:
for i_signal in range(alm.shape[0]):
for i_pol in range(alm.shape[1]):
for i_l in range(num_ells):
l = m + i_l
for i_mv in range(mvstart + f*i_l, mvstart + f*i_l +f):
alm[i_signal, i_pol, i_mv] *= fl[l, i_pol]
mvstart += f * num_ells

cdef class triangular_order(alm_info):
def __init__(self, int lmax, mmax=None, stride=1):
Expand Down
12 changes: 5 additions & 7 deletions python/libsharp/tests/test_sht.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import numpy as np
import healpy
from scipy.special import legendre
from scipy.special import p_roots
from numpy.testing import assert_allclose
import libsharp

Expand All @@ -28,7 +25,8 @@ def test_basic():
map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD)
assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :])
map = map[0, 0, :]
if rank == 0:
healpy.mollzoom(map)
from matplotlib.pyplot import show
show()
print(rank, "shape", map.shape)
print(rank, "mean", map.mean())

if __name__=="__main__":
test_basic()
137 changes: 137 additions & 0 deletions python/libsharp/tests/test_smoothing_noise_pol_mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# This test needs to be run with:

# mpirun -np X python test_smoothing_noise_pol_mpi.py

from mpi4py import MPI

import numpy as np

import healpy as hp

import libsharp

mpi = True
rank = MPI.COMM_WORLD.Get_rank()

nside = 256
npix = hp.nside2npix(nside)

np.random.seed(100)
input_map = np.random.normal(size=(3, npix))
fwhm_deg = 10
lmax = 512

nrings = 4 * nside - 1 # four missing pixels

if rank == 0:
print("total rings", nrings)

n_mpi_processes = MPI.COMM_WORLD.Get_size()
rings_per_process = nrings // n_mpi_processes + 1
# ring indices are 1-based

ring_indices_emisphere = np.arange(2*nside, dtype=np.int32) + 1
local_ring_indices = ring_indices_emisphere[rank::n_mpi_processes]

# to improve performance, simmetric rings north/south need to be in the same rank
# therefore we use symmetry to create the full ring indexing

if local_ring_indices[-1] == 2 * nside:
# has equator ring
local_ring_indices = np.concatenate(
[local_ring_indices[:-1],
nrings - local_ring_indices[::-1] + 1]
)
else:
# does not have equator ring
local_ring_indices = np.concatenate(
[local_ring_indices,
nrings - local_ring_indices[::-1] + 1]
)

print("rank", rank, "n_rings", len(local_ring_indices))

if not mpi:
local_ring_indices = None
grid = libsharp.healpix_grid(nside, rings=local_ring_indices)

# returns start index of the ring and number of pixels
startpix, ringpix, _, _, _ = hp.ringinfo(nside, local_ring_indices.astype(np.int64))

local_npix = grid.local_size()

def expand_pix(startpix, ringpix, local_npix):
"""Turn first pixel index and number of pixel in full array of pixels
to be optimized with cython or numba
"""
local_pix = np.empty(local_npix, dtype=np.int64)
i = 0
for start, num in zip(startpix, ringpix):
local_pix[i:i+num] = np.arange(start, start+num)
i += num
return local_pix

local_pix = expand_pix(startpix, ringpix, local_npix)

local_map = input_map[:, local_pix]

local_hitmap = np.zeros(npix)
local_hitmap[local_pix] = 1
hp.write_map("hitmap_{}.fits".format(rank), local_hitmap, overwrite=True)

print("rank", rank, "npix", npix, "local_npix", local_npix, "local_map len", len(local_map), "unique pix", len(np.unique(local_pix)))

local_m_indices = np.arange(rank, lmax + 1, MPI.COMM_WORLD.Get_size(), dtype=np.int32)
if not mpi:
local_m_indices = None

order = libsharp.packed_real_order(lmax, ms=local_m_indices)
local_nl = order.local_size()
print("rank", rank, "local_nl", local_nl, "mval", order.mval())

mpi_comm = MPI.COMM_WORLD if mpi else None

# map2alm
# maps in libsharp are 3D, 2nd dimension is IQU, 3rd is pixel

alm_sharp_I = libsharp.analysis(grid, order,
np.ascontiguousarray(local_map[0].reshape((1, 1, -1))),
spin=0, comm=mpi_comm)
alm_sharp_P = libsharp.analysis(grid, order,
np.ascontiguousarray(local_map[1:].reshape((1, 2, -1))),
spin=2, comm=mpi_comm)

beam = hp.gauss_beam(fwhm=np.radians(fwhm_deg), lmax=lmax, pol=True)

print("Smooth")
# smooth in place (zonca implemented this function)
order.almxfl(alm_sharp_I, np.ascontiguousarray(beam[:, 0:1]))
order.almxfl(alm_sharp_P, np.ascontiguousarray(beam[:, (1, 2)]))

# alm2map

new_local_map_I = libsharp.synthesis(grid, order, alm_sharp_I, spin=0, comm=mpi_comm)
new_local_map_P = libsharp.synthesis(grid, order, alm_sharp_P, spin=2, comm=mpi_comm)

# Transfer map to first process for writing

local_full_map = np.zeros(input_map.shape, dtype=np.float64)
local_full_map[0, local_pix] = new_local_map_I
local_full_map[1:, local_pix] = new_local_map_P

output_map = np.zeros(input_map.shape, dtype=np.float64) if rank == 0 else None
mpi_comm.Reduce(local_full_map, output_map, root=0, op=MPI.SUM)

if rank == 0:
# hp.write_map("sharp_smoothed_map.fits", output_map, overwrite=True)
# hp_smoothed = hp.alm2map(hp.map2alm(input_map, lmax=lmax), nside=nside) # transform only
hp_smoothed = hp.smoothing(input_map, fwhm=np.radians(fwhm_deg), lmax=lmax)
std_diff = (hp_smoothed-output_map).std()
print("Std of difference between libsharp and healpy", std_diff)
# hp.write_map(
# "healpy_smoothed_map.fits",
# hp_smoothed,
# overwrite=True
# )
assert std_diff < 1e-5
11 changes: 5 additions & 6 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
sys.path.append(os.path.join(project_path, 'fake_pyrex'))

from setuptools import setup, find_packages, Extension
from Cython.Distutils import build_ext
from Cython.Build import cythonize
import numpy as np

libsharp = os.environ.get('LIBSHARP', None)
Expand Down Expand Up @@ -64,21 +64,20 @@
'Intended Audience :: Science/Research',
'License :: OSI Approved :: GNU General Public License (GPL)',
'Topic :: Scientific/Engineering'],
cmdclass = {"build_ext": build_ext},
ext_modules = [
ext_modules = cythonize([
Extension("libsharp.libsharp",
["libsharp/libsharp.pyx"],
libraries=["sharp", "fftpack", "c_utils"],
include_dirs=[libsharp_include],
include_dirs=[libsharp_include, np.get_include()],
library_dirs=[libsharp_lib],
extra_link_args=["-fopenmp"],
),
Extension("libsharp.libsharp_mpi",
["libsharp/libsharp_mpi.pyx"],
libraries=["sharp", "fftpack", "c_utils"],
include_dirs=[libsharp_include],
include_dirs=[libsharp_include, np.get_include()],
library_dirs=[libsharp_lib],
extra_link_args=["-fopenmp"],
),
],
]),
)

0 comments on commit 593d4eb

Please sign in to comment.