"""
CLEAN algorithms.
The CLEAN algorithm solves the deconvolution problem by assuming a model for the true sky intensity
which is a collection of point sources or in the case of multiscale clean a collection of
appropriate component shapes at different scales.
"""
from typing import Union, Optional
from collections.abc import Iterable
import astropy.units as u
import numpy as np
from astropy.convolution import Gaussian2DKernel
from astropy.units import Quantity
from numpy.typing import NDArray
from scipy import signal
from scipy.ndimage import shift
from sunpy.map.map_factory import Map
from xrayvision.imaging import vis_psf_image, vis_to_map
from xrayvision.utils import get_logger
from xrayvision.visibility import Visibilities
__all__ = ["clean", "vis_clean", "ms_clean", "vis_ms_clean"]
logger = get_logger(__name__, "DEBUG")
__common_clean_doc__ = r"""
clean_beam_width :
The width of the gaussian to convolve the model with. If set to 0.0 \
the gaussian to convolution is disabled
gain :
The gain per loop or loop gain
thres :
Terminates clean when ``residual.max() <= thres``
niter :
Maximum number of iterations to perform
Returns
-------
:
The CLEAN image 2D
Notes
-----
The CLEAN algorithm can be summarised in pesudo code as follows:
.. math::
& \textrm{CLEAN} \left (I^{D}(l, m),\ B(l,m),\ \gamma,\ f_{Thresh},\ N \right ) \\
& I^{Res} = I^{D},\ M = \{\},\ i=0 \\
& \textbf{while} \ \operatorname{max} I^{Res} > f_{Thresh} \ \textrm{and} \ i \lt N \
\textbf{do:} \\
& \qquad l_{max}, m_{max} = \underset{l,m}{\operatorname{argmax}} I^{Res}(l,m) \\
& \qquad f_{max} = I^{Res}(l_{max}, m_{max}) \\
& \qquad I^{Res} = I^{Res} - \alpha \cdot f_{max} \cdot \operatorname{shift} \left
( B(l,m), l_{max}, m_{max} \right ) \\
& \qquad M = M + \{ l_{max}, m_{max}: \alpha \cdot f_{max} \} \\
& \qquad i = i + 1 \\
& \textbf{done} \\
& \textbf{return}\ M,\ I^{Res}
"""
[docs]
@u.quantity_input
def clean(
dirty_map: Quantity,
dirty_beam: Quantity,
pixel_size: Quantity[u.arcsec / u.pix] = None,
clean_beam_width: Optional[Quantity[u.arcsec]] = 4.0 * u.arcsec,
gain: Optional[float] = 0.1,
thres: Optional[float] = None,
niter: int = 5000,
) -> Union[Quantity, NDArray[np.float64]]:
r"""
Clean the image using Hogbom's original method.
CLEAN iteratively subtracts the PSF or dirty beam from the dirty map to create the residual.
At each iteration, the location of the maximum residual is found and a shifted dirty beam is
subtracted at that location. This process continues until either `niter` iterations are reached or
the maximum residual <= `thres`.
Parameters
----------
dirty_map :
The dirty map to be cleaned 2D
dirty_beam :
The dirty beam or point spread function (PSF) 2D must
pixel_size :
The pixel size in arcsec
"""
# Ensure both beam and map are even/odd on same axes
# if not [x % 2 == 0 for x in dirty_map.shape] == [x % 2 == 0 for x in dirty_beam.shape]:
# raise ValueError('')
pad = [0 if x % 2 == 0 else 1 for x in dirty_map.shape]
# Assume beam, map phase_center is in middle
beam_center = (dirty_beam.shape[0] - 1) / 2.0, (dirty_beam.shape[1] - 1) / 2.0
map_center = (dirty_map.shape[0] - 1) / 2.0, (dirty_map.shape[1] - 1) / 2.0
# Work out size of map for slicing over-sized dirty beam
shape = dirty_map.shape
height = shape[0] // 2
width = shape[1] // 2
# max_beam = dirty_beam.max()
# Model for sources
model = np.zeros(dirty_map.shape)
components = []
for i in range(niter):
# Find max in dirty map and save to point source
mx, my = np.unravel_index(dirty_map.argmax(), dirty_map.shape)
imax = dirty_map[mx, my]
# TODO check if correct and how to undo
# imax = imax * max_beam
model[mx, my] += gain * imax
if i % 25 == 0:
logger.info(f"Iter: {i}, strength: {imax}, location: {mx, my}")
offset = map_center[0] - mx, map_center[1] - my
shifted_beam_center = int(beam_center[0] + offset[0]), int(beam_center[1] + offset[1])
xr = slice(shifted_beam_center[0] - height, shifted_beam_center[0] + height + pad[0])
yr = slice(shifted_beam_center[1] - width, shifted_beam_center[1] + width + pad[0])
shifted = dirty_beam[xr, yr]
comp = imax * gain * shifted
components.append((mx, my, comp[mx, my]))
dirty_map = np.subtract(dirty_map, comp)
if thres:
if np.abs(dirty_map).max() <= thres:
logger.info("Threshold reached")
break
if np.abs(dirty_map.min()) > dirty_map.max():
logger.info("Largest residual negative")
break
else:
print("Max iterations reached")
if clean_beam_width is not None:
# Convert from FWHM to StDev FWHM = sigma*(8ln2)**0.5 = 2.3548200450309493
x_stdev = (clean_beam_width / pixel_size[1]).to_value(u.pix) / 2.3548200450309493
y_stdev = (clean_beam_width / pixel_size[0]).to_value(u.pix) / 2.3548200450309493
clean_beam = Gaussian2DKernel(x_stdev, y_stdev, x_size=dirty_beam.shape[1], y_size=dirty_beam.shape[0]).array
# Normalise beam
clean_beam = clean_beam / clean_beam.max()
# Convolve clean beam with model and scale
clean_map = signal.convolve2d(model, clean_beam / clean_beam.sum(), mode="same") / (
pixel_size[0].value * pixel_size[1].value
)
# Scale residual map with model and scale
dirty_map = dirty_map / clean_beam.sum() / (pixel_size[0].value * pixel_size[1].value)
return clean_map + dirty_map, model, dirty_map
return model + dirty_map, model, dirty_map
clean.__doc__ += __common_clean_doc__ # type: ignore
[docs]
@u.quantity_input
def vis_clean(
vis: Visibilities,
shape: Quantity[u.pix],
pixel_size: Quantity[u.arcsec / u.pix],
clean_beam_width: Optional[Quantity[u.arcsec]] = 4.0,
niter: Optional[int] = 5000,
map: Optional[bool] = True,
gain: Optional[float] = 0.1,
**kwargs,
):
r"""
Clean the visibilities using Hogbom's original method.
A wrapper around lower level `clean` which calculates the dirty map and psf
Parameters
----------
vis :
The visibilities to clean
shape :
Size of map
pixel_size :
The pixel size in arcsec
map :
Return a `sunpy.map.Map` by default or array only if `False`
"""
dirty_map = vis_to_map(vis, shape=shape, pixel_size=pixel_size, **kwargs)
dirty_beam_shape = [x.value * 3 + 1 if x.value * 3 % 2 == 0 else x.value * 3 for x in shape] * shape.unit
dirty_beam = vis_psf_image(vis, shape=dirty_beam_shape, pixel_size=pixel_size, **kwargs)
clean_map, model, residual = clean(
dirty_map.data,
dirty_beam.value,
pixel_size=pixel_size,
clean_beam_width=clean_beam_width,
gain=gain,
niter=niter,
)
if not map:
return clean_map, model, residual
return [Map((data, dirty_map.meta)) for data in (clean_map, model, residual)]
vis_clean.__doc__ += __common_clean_doc__ # type: ignore
__common_ms_clean_doc__ = r"""
scales : array-like, optional, optional
The scales to use eg ``[1, 2, 4, 8]``
clean_beam_width :
The width of the gaussian to convolve the model with. If set to 0.0 the gaussian \
convolution is disabled
gain :
The gain per loop or loop gain
thres :
Terminates clean when `residuals.max() <= thres``
niter :
Maximum number of iterations to perform
Returns
-------
:
Cleaned image
Notes
-----
This is an implementation of the multiscale clean algorithm as outlined in [R1]_ adapted for \
x-ray Fourier observations.
It is based on the on the implementation in the CASA software which can be found here_.
.. _here: https://github.com/casacore/casacore/blob/f4dc1c36287c766796ce3375cebdfc8af797a388/lattices/LatticeMath/LatticeCleaner.tcc#L956 #noqa
References
----------
.. [R1] Cornwell, T. J., "Multiscale CLEAN Deconvolution of Radio Synthesis Images", IEEE Journal of Selected Topics in Signal Processing, vol 2, p793-801, Paper_ #noqa
.. _Paper: https://ieeexplore.ieee.org/document/4703304/
"""
[docs]
@u.quantity_input
def ms_clean(
dirty_map: Quantity,
dirty_beam: Quantity,
pixel_size: Quantity[u.arcsec / u.pix],
scales: Union[Iterable, NDArray, None] = None,
clean_beam_width: Quantity = 4.0 * u.arcsec,
gain: float = 0.1,
thres: float = 0.01,
niter: int = 5000,
) -> Union[Quantity, NDArray[np.float64]]:
r"""
Clean the map using a multiscale clean algorithm.
Parameters
----------
dirty_map :
The 2D dirty map to be cleaned
dirty_beam :
The 2D dirty beam should have the same dimensions as `dirty_map`
pixel_size :
The pixel size in arcsec
"""
# Compute the number of dyadic scales, their sizes and scale biases
number_of_scales: int = np.floor(np.log2(min(dirty_map.shape))).astype(int)
scale_sizes: NDArray[np.int_] = 2 ** np.arange(number_of_scales)
if scales:
scales = np.array(scales)
number_of_scales = len(scales)
scale_sizes = scales
scale_sizes = np.where(scale_sizes == 0, 1, scale_sizes)
scale_biases = 1 - 0.6 * scale_sizes / scale_sizes.max()
model = np.zeros(dirty_map.shape)
map_center = (dirty_map.shape[0] - 1) / 2.0, (dirty_map.shape[1] - 1) / 2.0
height = dirty_map.shape[0] // 2
width = dirty_map.shape[1] // 2
pad = [0 if x % 2 == 0 else 1 for x in dirty_map.shape]
# Pre-compute scales, residual maps and dirty beams at each scale and dirty beam cross terms
scales = np.zeros((dirty_map.shape[0], dirty_map.shape[1], number_of_scales))
scaled_residuals = np.zeros((dirty_map.shape[0], dirty_map.shape[1], number_of_scales))
scaled_dirty_beams = np.zeros((dirty_beam.shape[0], dirty_beam.shape[1], number_of_scales))
max_scaled_dirty_beams = np.zeros(number_of_scales)
cross_terms = {}
for i, scale in enumerate(scale_sizes):
scales[:, :, i] = _component(scale=scale, shape=dirty_map.shape)
scaled_residuals[:, :, i] = signal.convolve(dirty_map, scales[:, :, i], mode="same")
scaled_dirty_beams[:, :, i] = signal.convolve(dirty_beam, scales[:, :, i], mode="same")
max_scaled_dirty_beams[i] = scaled_dirty_beams[:, :, i].max()
for j in range(i, number_of_scales):
cross_terms[(i, j)] = signal.convolve(
signal.convolve(dirty_beam, scales[:, :, i], mode="same"), scales[:, :, j], mode="same"
)
# Clean loop
for i in range(niter):
# print(f'Clean loop {i}')
# For each scale find the strength and location of max residual
# Chose scale with has maximum strength
max_index: int = np.argmax(scaled_residuals).astype(int)
max_x: int
max_y: int
max_scale: int
max_x, max_y, max_scale = (int(x) for x in np.unravel_index(max_index, scaled_residuals.shape))
strength = scaled_residuals[max_x, max_y, max_scale]
# Adjust for the max of scaled beam
strength = strength / max_scaled_dirty_beams[max_scale]
logger.info(f"Iter: {i}, max scale: {max_scale}, strength: {strength}")
# Loop gain and scale dependent bias
strength = strength * scale_biases[max_scale] * gain
beam_center = [
(scaled_dirty_beams[:, :, max_scale].shape[0] - 1) / 2.0,
(scaled_dirty_beams[:, :, max_scale].shape[1] - 1) / 2.0,
]
offset = map_center[0] - max_x, map_center[1] - max_y
shifted_beam_center = int(beam_center[0] + offset[0]), int(beam_center[1] + offset[1])
xr = slice(shifted_beam_center[0] - height, shifted_beam_center[0] + height + pad[0])
yr = slice(shifted_beam_center[1] - width, shifted_beam_center[1] + width + pad[0])
# shifted = dirty_beam[xr, yr]
comp = strength * shift(scales[:, :, max_scale], (max_x - map_center[0], max_y - map_center[1]), order=0)
# comp = strength * scales[xr, yr]
# Add this component to current model
model = np.add(model, comp)
# Update all images using precomputed terms
for j, _ in enumerate(scale_sizes):
if j > max_scale:
cross_term = cross_terms[(max_scale, j)]
else:
cross_term = cross_terms[(j, max_scale)]
# comp = strength * shift(cross_term[xr, yr],
# (max_x - beam_center[0], max_y - beam_center[1]), order=0)
comp = strength * cross_term[xr, yr]
scaled_residuals[:, :, j] = np.subtract(scaled_residuals[:, :, j], comp)
# End max(res(a)) or niter
if np.abs(scaled_residuals[:, :, max_scale].max()) <= thres:
logger.info("Threshold reached")
break
# Largest scales largest residual is negative
if np.abs(scaled_residuals[:, :, 0].min()) > scaled_residuals[:, :, 0].max():
logger.info("Max scale residual negative")
break
else:
logger.info("Max iterations reached")
# Convolve model with clean beam B_G * I^M
if clean_beam_width is not None:
x_stdev = ((clean_beam_width / pixel_size[0]).to_value(u.pix) / (2.0 * np.sqrt(2.0 * np.log(2.0)))).value
y_stdev = ((clean_beam_width / pixel_size[1]).to_value(u.pix) / (2.0 * np.sqrt(2.0 * np.log(2.0)))).value
clean_beam = Gaussian2DKernel(x_stdev, y_stdev, x_size=dirty_beam.shape[1], y_size=dirty_beam.shape[0]).array
# Normalise beam
clean_beam = clean_beam / clean_beam.max()
clean_map = signal.convolve2d(model, clean_beam, mode="same") / (pixel_size[0] * pixel_size[1])
# Scale residual map with model and scale
dirty_map = (scaled_residuals / clean_beam.sum() / (pixel_size[0] * pixel_size[1])).sum(axis=2)
return clean_map + dirty_map, model, dirty_map
# Add residuals B_G * I^M + I^R
return model, scaled_residuals.sum(axis=2)
ms_clean.__doc__ += __common_ms_clean_doc__ # type: ignore
[docs]
def vis_ms_clean(
vis: Visibilities,
shape: Quantity[u.pix],
pixel_size: Quantity[u.arcsec / u.pix],
scales: Optional[Iterable],
clean_beam_width: Optional[Quantity[u.arcsec]] = 4.0,
niter: Optional[int] = 5000,
map: Optional[bool] = True,
gain: Optional[float] = 0.1,
thres: Optional[float] = 0.01,
) -> Union[Quantity, NDArray[np.float64]]:
r"""
Clean the visibilities using a multiscale clean method.
A wrapper around `ms_clean` which calculates the dirty map and psf.
Parameters
----------
vis :
The visibilities to clean
shape :
Size of map
pixel_size :
The pixel size in arcsec
scales : array-like, optional, optional
The scales to use eg ``[1, 2, 4, 8]``
clean_beam_width :
The width of the gaussian to convolve the model with. If set to 0.0 the gaussian \
convolution is disabled
gain :
The gain per loop or loop gain
thres :
Terminates clean when `residuals.max() <= thres``
niter :
Maximum number of iterations to perform
map :
Return a `sunpy.map.Map` by default or array only if `False`
Returns
-------
:
Cleaned image
"""
dirty_map = vis_to_map(vis, shape=shape, pixel_size=pixel_size)
dirty_beam = vis_psf_image(vis, shape=shape * 3, pixel_size=pixel_size)
clean_map, model, residual = ms_clean(
dirty_map.data,
dirty_beam,
pixel_size=pixel_size,
scales=scales,
clean_beam_width=clean_beam_width,
gain=gain,
thres=thres,
niter=niter,
)
if not map:
return clean_map, model, residual
return [Map((data, dirty_map.meta)) for data in (clean_map, model, residual)]
# vis_ms_clean.__doc__ += __common_ms_clean_doc__
def _radial_prolate_sphereoidal(nu):
r"""
Calculate prolate spheroidal wave function approximation.
Parameters
----------
nu : `float`
The radial value to evaluate the function at
Returns
-------
`float`
The amplitude of the the prolate spheroid function at `nu`
Notes
-----
Note this is a direct translation of the on the implementation the CASA code reference by [1] \
and can be found here Link_
.. _Link: https://github.com/casacore/casacore/blob/f4dc1c36287c766796ce3375cebdfc8af797a388/lattices/LatticeMath/LatticeCleaner.tcc#L956 #noqa
"""
if nu <= 0:
return 1.0
elif nu >= 1.0:
return 0.0
else:
n_p = 5
n_q = 3
p = np.zeros((n_p, 2))
q = np.zeros((n_q, 2))
p[0, 0] = 8.203343e-2
p[1, 0] = -3.644705e-1
p[2, 0] = 6.278660e-1
p[3, 0] = -5.335581e-1
p[4, 0] = 2.312756e-1
p[0, 1] = 4.028559e-3
p[1, 1] = -3.697768e-2
p[2, 1] = 1.021332e-1
p[3, 1] = -1.201436e-1
p[4, 1] = 6.412774e-2
q[0, 0] = 1.0000000e0
q[1, 0] = 8.212018e-1
q[2, 0] = 2.078043e-1
q[0, 1] = 1.0000000e0
q[1, 1] = 9.599102e-1
q[2, 1] = 2.918724e-1
part = 0
nuend = 0.0
if 0.0 <= nu < 0.75:
part = 0
nuend = 0.75
elif 0.75 <= nu <= 1.00:
part = 1
nuend = 1.0
top = p[0, part]
delnusq = np.power(nu, 2.0) - np.power(nuend, 2.0)
for k in range(1, n_p):
top += p[k, part] * np.power(delnusq, k)
bot = q[0, part]
for k in range(1, n_q):
bot += q[k, part] * np.power(delnusq, k)
if bot != 0.0:
return top / bot
else:
return 0
def _vec_radial_prolate_sphereoidal(nu):
r"""
Calculate prolate spheroidal wave function approximation.
Parameters
----------
nu : `float` array
The radial value to evaluate the function at
Returns
-------
`float`
The amplitude of the the prolate spheroid function at `nu`
Notes
-----
Note this is based on the implementation the CASA code reference by [1] and can be found here
Link_
.. _Link: https://github.com/casacore/casacore/blob/f4dc1c36287c766796ce3375cebdfc8af797a388/lattices/LatticeMath/LatticeCleaner.tcc#L956 #noqa
"""
nu = np.array(nu)
n_p = 5
n_q = 3
p = np.zeros((n_p, 2))
q = np.zeros((n_q, 2))
p[0, 0] = 8.203343e-2
p[1, 0] = -3.644705e-1
p[2, 0] = 6.278660e-1
p[3, 0] = -5.335581e-1
p[4, 0] = 2.312756e-1
p[0, 1] = 4.028559e-3
p[1, 1] = -3.697768e-2
p[2, 1] = 1.021332e-1
p[3, 1] = -1.201436e-1
p[4, 1] = 6.412774e-2
q[0, 0] = 1.0000000e0
q[1, 0] = 8.212018e-1
q[2, 0] = 2.078043e-1
q[0, 1] = 1.0000000e0
q[1, 1] = 9.599102e-1
q[2, 1] = 2.918724e-1
lower = np.where((nu >= 0.0) & (nu < 0.75)) # part = 0, nuend = 0.75
upper = np.where((nu >= 0.75) & (nu <= 1.00)) # part = 1, nuend = 1.0
delnusq = np.zeros_like(nu)
delnusq[lower] = np.power(nu[lower], 2.0) - np.power(0.75, 2.0)
delnusq[upper] = np.power(nu[upper], 2.0) - np.power(1.00, 2.0)
top = np.zeros_like(nu, dtype=float)
top[lower] = p[0, 0]
top[upper] = p[0, 1]
k = np.arange(1, n_p)
top[lower] += np.sum(p[k, 0, np.newaxis] * np.power(delnusq[lower], k[..., np.newaxis]), axis=0)
top[upper] += np.sum(p[k, 1, np.newaxis] * np.power(delnusq[upper], k[..., np.newaxis]), axis=0)
bot = np.zeros_like(nu, dtype=float)
bot[lower] = q[0, 0]
bot[upper] = q[0, 1]
j = np.arange(1, n_q)
bot[lower] += np.sum(q[j, 0, np.newaxis] * np.power(delnusq[lower], j[..., np.newaxis]), axis=0)
bot[upper] += np.sum(q[j, 1, np.newaxis] * np.power(delnusq[upper], j[..., np.newaxis]), axis=0)
out = np.zeros(nu.shape)
out[bot != 0] = top[bot != 0] / bot[bot != 0]
out = np.where(nu <= 0, 1.0, out)
out = np.where(nu >= 1, 0.0, out)
return out
def _component(scale, shape):
r"""
Parameters
----------
scale
Returns
-------
"""
# if scale == 0.0:
# out = np.zeros((3, 3))
# out[1,1] = 1.0
# return out
# elif scale % 2 == 0: # Even so keep output even
# shape = np.array((2 * scale + 2, 2 * scale + 2), dtype=int)
# else: # Odd so keep odd
# shape = np.array((2 * scale + 1, 2 * scale + 1), dtype=int)
refx, refy = (np.array(shape) - 1) / 2.0
if scale == 0.0:
wave_amp = np.zeros(shape)
wave_amp[int(refx), int(refy)] = 1
return wave_amp
xy = np.mgrid[0 : shape[0] : 1, 0 : shape[1] : 1]
radii_squared = ((xy[0, :, :] - refx) / scale) ** 2 + ((xy[1, :, :] - refy) / scale) ** 2
rad_zeros_indices = radii_squared <= 0.0
amp_zero_indices = radii_squared >= 1.0
wave_amp = _vec_radial_prolate_sphereoidal(np.sqrt(radii_squared.reshape(radii_squared.size)))
wave_amp = wave_amp.reshape(shape)
wave_amp[rad_zeros_indices] = _vec_radial_prolate_sphereoidal([0])[0]
wave_amp = wave_amp * (1 - radii_squared)
wave_amp[amp_zero_indices] = 0.0
return wave_amp