# imager.py
# a general interferometric imager class
#
# Copyright (C) 2018 Andrew Chael
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# TODO FIX INDEXING MESS FOR MULTIFREQUENCY POLARIZATION
# TODO for polarization imaging give better control than just initializing at 20%
# better initialization for all terms!! in init_imager
from __future__ import division
from __future__ import print_function
from builtins import str
from builtins import range
from builtins import object
import copy
import time
import numpy as np
import scipy.optimize as opt
import ehtim.imaging.imager_utils as imutils
import ehtim.imaging.pol_imager_utils as polutils
import ehtim.imaging.multifreq_imager_utils as mfutils
import ehtim.image
import ehtim.const_def as ehc
MAXIT = 200 # number of iterations
NHIST = 50 # number of steps to store for hessian approx
MAXLS = 40 # maximum number of line search steps in BFGS-B
STOP = 1e-6 # convergence criterion
EPS = 1e-8
DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag']
DATATERMS_POL = ['pvis', 'm', 'vvis']
REGULARIZERS = ['gs', 'tv', 'tvlog','tv2', 'tv2log', 'l1', 'l1w', 'lA', 'patch',
'flux', 'cm', 'simple', 'compact', 'compact2', 'rgauss']
REGULARIZERS_POL = ['msimple', 'hw', 'ptv','l1v','l2v','vtv','vtv2','vflux']
REGULARIZERS_ALLFREQS_I = ['flux_mf']
REGULARIZERS += REGULARIZERS_ALLFREQS_I
REGULARIZERS_SPECIND = ['l2_alpha', 'tv_alpha']
REGULARIZERS_CURV = ['l2_beta', 'tv_beta']
REGULARIZERS_SPECIND_P = ['l2_alphap', 'tv_alphap']
REGULARIZERS_CURV_P = ['l2_betap', 'tv_betap']
REGULARIZERS_RM = ['l2_rm', 'tv_rm']
REGULARIZERS_CM = ['l2_cm', 'tv_cm']
REGULARIZERS_ISPECTRAL = REGULARIZERS_SPECIND + REGULARIZERS_CURV
REGULARIZERS_POLSPECTRAL = REGULARIZERS_SPECIND_P + REGULARIZERS_CURV_P + REGULARIZERS_RM + REGULARIZERS_CM
REGULARIZERS_SPECTRAL = REGULARIZERS_ISPECTRAL + REGULARIZERS_POLSPECTRAL
GRIDDER_P_RAD_DEFAULT = 2
GRIDDER_CONV_FUNC_DEFAULT = 'gaussian'
FFT_PAD_DEFAULT = 2
FFT_INTERP_DEFAULT = 3
REG_DEFAULT = {'simple': 1}
DAT_DEFAULT = {'vis': 100}
REGPARAMS_DEFAULT = {'major':50*ehc.RADPERUAS,
'minor':50*ehc.RADPERUAS,
'PA':0.,
'alpha_A':1.0,
'epsilon_tv':0.0}
POLARIZATION_MODES = ['P','QU','IP','IQU','V','IV','IQUV','IPV'] # TODO: treatment of V may be inconsistent
MEANPOL_INIT = 0.2 # mean initial polarization if not in initial image
SIGMAPOL_INIT = 1.e-2 # perturbations to initial polarization if not in initial image
###################################################################################################
# Imager object
###################################################################################################
[docs]class Imager(object):
"""A general interferometric imager.
"""
def __init__(self, obs_in, init_im,
prior_im=None, flux=None, data_term=DAT_DEFAULT, reg_term=REG_DEFAULT, **kwargs):
self.logstr = ""
self._out_list = []
self._obs_list = []
self._init_list = []
self._prior_list = []
self._reg_term_list = []
self._dat_term_list = []
self._maxit_list = []
self._stop_list = []
self._pol_list = []
self._flux_list = []
self._pflux_list = []
self._vflux_list = []
self._clipfloor_list = []
self._snrcut_list = []
self._debias_list = []
self._systematic_noise_list = []
self._systematic_cphase_noise_list = []
self._transform_list = []
self._weighting_list = []
self._maxset_list = []
self._cp_uv_min_list = []
self._reffreq_list = []
self._mf_list = []
self._mf_order_list = []
self._mf_order_pol_list = []
self._mf_rm_list = []
self._mf_cm_list = []
# iterations / convergence
self.maxit_next = kwargs.get('maxit', MAXIT)
self.stop_next = kwargs.get('stop', STOP)
# Regularizer/data terms for the next imaging iteration
self.reg_term_next = reg_term # e.g. [('simple',1), ('l1',10), ('flux',500), ('cm',500)]
self.dat_term_next = data_term # e.g. [('amp', 1000), ('cphase',100)]
# Observations, frequencies
self.reffreq = init_im.rf
if isinstance(obs_in, list) or isinstance(obs_in, np.ndarray):
self._obslist_next = obs_in
self.obslist_next = obs_in
else:
self._obslist_next = [obs_in]
self.obslist_next = [obs_in]
# Initial image, prior, flux
self.init_next = init_im
if prior_im is None:
self.prior_next = self.init_next
else:
self.prior_next = prior_im
if flux is None:
self.flux_next = self.prior_next.total_flux()
else:
self.flux_next = flux
# set polarimetric flux values equal to Stokes I flux by default
# used in polarimetric regularizer normalization
self.pflux_next = kwargs.get('pflux', flux)
self.vflux_next = kwargs.get('vflux', flux)
# Polarization and image transforms
self.pol_next = kwargs.get('pol', self.init_next.pol_prim)
self.transform_next = kwargs.get('transform', ['log','mcv'])
self.transform_next = np.array([self.transform_next]).flatten() #so we can handle multiple transforms
# Weighting/debiasing/snr cut/systematic noise
self.debias_next = kwargs.get('debias', True)
snrcut = kwargs.get('snrcut', 0.)
self.snrcut_next = {key: 0. for key in set(DATATERMS+DATATERMS_POL)}
if type(snrcut) is dict:
for key in snrcut.keys():
self.snrcut_next[key] = snrcut[key]
else:
for key in self.snrcut_next.keys():
self.snrcut_next[key] = snrcut
self.systematic_noise_next = kwargs.get('systematic_noise', 0.)
self.systematic_cphase_noise_next = kwargs.get('systematic_cphase_noise', 0.)
self.weighting_next = kwargs.get('weighting', 'natural')
# Maximal/minimal closure set
self.maxset_next = kwargs.get('maxset', False)
self.cp_uv_min = kwargs.get('cp_uv_min', False) # UV minimum for closure phase
# clipfloor/initial image normalization
self.clipfloor_next = kwargs.get('clipfloor', 0.)
self.norm_init = kwargs.get('norm_init', True)
self.norm_reg = kwargs.get('norm_reg', False)
self.beam_size = self.obslist_next[0].res()
self.regparams = {k: kwargs.get(k, REGPARAMS_DEFAULT[k]) for k in REGPARAMS_DEFAULT.keys()}
self.chisq_transform = False
self.chisq_offset_gradient = 0.0
# FFT parameters
self._ttype = kwargs.get('ttype', 'nfft')
self._fft_gridder_prad = kwargs.get('fft_gridder_prad', GRIDDER_P_RAD_DEFAULT)
self._fft_conv_func = kwargs.get('fft_conv_func', GRIDDER_CONV_FUNC_DEFAULT)
self._fft_pad_factor = kwargs.get('fft_pad_factor', FFT_PAD_DEFAULT)
self._fft_interp_order = kwargs.get('fft_interp_order', FFT_INTERP_DEFAULT)
# multifrequency
self.mf_next = kwargs.get('mf',False)
self.mf_order = kwargs.get('mf_order',0)
self.mf_order_pol = kwargs.get('mf_order_pol',0)
self.mf_rm = kwargs.get('mf_rm',0)
self.mf_cm = kwargs.get('mf_cm',0)
self.mf_flux = kwargs.get('mf_flux',[self.flux_next]) # TODO: merge these
if kwargs.get('mf_which_solve') is not None:
raise Exception("'mf_which_solve' argument for multifrequency imaging is deprecated -- use 'mf_order' instead!")
if kwargs.get('reg_all_freq_mf') is not None:
raise Exception("'reg_all_freq_mf' argument for multifrequency imaging is deprecated")
# Imager history
self._change_imgr_params = True
self.nruns = 0
# Set embedding matrices and prepare imager
self.check_params()
self.check_limits()
self.init_imager()
@property
def obslist_next(self):
return self._obslist_next
@obslist_next.setter
def obslist_next(self, obslist):
if not isinstance(obslist, list):
raise Exception("obslist_next must be a list!")
self._obslist_next = obslist
self.freq_list = [obs.rf for obs in self.obslist_next]
self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list]
@property
def obs_next(self):
"""the next Obsdata to be used in imaging
"""
return self.obslist_next[0]
@obs_next.setter
def obs_next(self, obs):
"""the next Obsdata to be used in imaging
"""
self.obslist_next = [obs]
[docs] def reg_terms_last(self):
"""Return last used regularizer terms.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._reg_term_list[-1]
[docs] def dat_terms_last(self):
"""Return last used data terms.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._dat_term_list[-1]
[docs] def obslist_last(self):
"""Return last used observation.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._obs_list[-1]
[docs] def obs_last(self):
"""Return last used observation.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._obs_list[-1][0]
[docs] def prior_last(self):
"""Return last used prior image.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._prior_list[-1]
[docs] def out_last(self):
"""Return last result.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._out_list[-1]
[docs] def init_last(self):
"""Return last initial image.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._init_list[-1]
[docs] def flux_last(self):
"""Return last total flux constraint.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._flux_list[-1]
[docs] def pflux_last(self):
"""Return last total linear polarimetric flux constraint.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._pflux_list[-1]
[docs] def vflux_last(self):
"""Return last total circular polarimetric flux constraint.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._vflux_list[-1]
[docs] def clipfloor_last(self):
"""Return last clip floor.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._clipfloor_list[-1]
[docs] def pol_last(self):
"""Return last polarization imaged.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._pol_list[-1]
[docs] def maxit_last(self):
"""Return last max_iterations value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._maxit_list[-1]
[docs] def stop_last(self):
"""Return last convergence value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._stop_list[-1]
[docs] def debias_last(self):
"""Return last debias value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._debias_list[-1]
[docs] def snrcut_last(self):
"""Return last snrcut value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._snrcut_list[-1]
[docs] def weighting_last(self):
"""Return last weighting value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._weighting_list[-1]
[docs] def systematic_noise_last(self):
"""Return last systematic_noise value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._systematic_noise_list[-1]
[docs] def systematic_cphase_noise_last(self):
"""Return last closure phase systematic noise value (in degree).
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._systematic_cphase_noise_list[-1]
[docs] def maxset_last(self):
"""Return last choice of closure phase maximal/minimal set
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._maxset_list[-1]
[docs] def cp_uv_min_last(self):
"""Return last choice of minimal uvdistance for closure phases
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._cp_uv_min_list[-1]
[docs] def mf_last(self):
"""Return last choice for multifrequency imaging
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._mf_list[-1]
[docs] def reffreq_last(self):
"""Return last choice of order for multifrequency imaging
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._reffreq_list[-1]
[docs] def mf_order_last(self):
"""Return last choice of order for multifrequency imaging
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._mf_order_list[-1]
[docs] def mf_order_pol_last(self):
"""Return last choice of order for multifrequency polarimetric imaging
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._mf_order_pol_list[-1]
[docs] def mf_rm_last(self):
"""Return last choice for RM imaging
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._mf_rm_list[-1]
[docs] def mf_cm_last(self):
"""Return last choice for CM imaging
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._mf_cm_list[-1]
def converge(self, niter, blur_frac, pol, grads=True, **kwargs):
blur = blur_frac * self.obs_next.res()
for repeat in range(niter-1):
init = self.out_last()
init = init.blur_circ(blur, blur)
self.init_next = init
self.make_image(pol=pol, grads=grads, **kwargs)
[docs] def make_image_I(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I image using current imager settings.
"""
pol = 'I'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_P(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes P polarimetric image using current imager settings.
"""
pol = 'P'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_IP(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I and P polarimetric image simultaneously using current imager settings.
"""
pol = 'IP'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_V(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I image using current imager settings.
"""
pol = 'V'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_IV(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I image using current imager settings.
"""
pol = 'IV'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image(self, pol=None, grads=True, mf=False, **kwargs):
"""Make an image using current imager settings.
Args:
pol (str): which polarization to image
grads (bool): whether or not to use image gradients
show_updates (bool): whether or not to show imager progress
update_interval (int): step interval for plotting if show_updates=True
mf (bool): whether or not to do multifrequency (spectral index only for now)
mf_order (int): order for multifrequency spectral index imaging
mf_order_pol (int): order for multifrequency polarization fraction imaging
mf_rm (int): order for rotation measure imaging
mf_cm (int): order for conversion measure imaging
Returns:
(Image): output image
"""
print("==============================")
print("Imager run %i " % (int(self.nruns)+1))
# multifrequency parameters
self.mf_next = mf
self.mf_order = kwargs.get('mf_order', self.mf_order)
self.mf_order_pol = kwargs.get('mf_order_pol', self.mf_order_pol)
self.mf_rm = kwargs.get('mf_rm', self.mf_rm)
self.mf_cm = kwargs.get('mf_cm', self.mf_cm)
if kwargs.get('mf_which_solve') is not None:
raise Exception("'mf_which_solve' argument for multifrequency imaging is deprecated -- use 'mf_order' instead!")
if kwargs.get('reg_all_freq_mf') is not None:
raise Exception("'reg_all_freq_mf' argument for multifrequency imaging is deprecated")
# polarization parameters
if pol is None:
pol_prim = self.pol_next
else:
self.pol_next = pol
pol_prim = pol
# For polarimetric imaging, we must switch polrep to Stokes
if self.pol_next in POLARIZATION_MODES:
print("Imaging Polarization: switching image polrep to Stokes!")
self.prior_next = self.prior_next.switch_polrep(polrep_out='stokes', pol_prim_out='I')
self.init_next = self.init_next.switch_polrep(polrep_out='stokes', pol_prim_out='I')
pol_prim = 'I'
# Checks and initialize
self.check_params()
self.check_limits()
self.init_imager()
# Print initial stats
self._nit = 0
self._show_updates = kwargs.get('show_updates', True)
self._update_interval = kwargs.get('update_interval', 1)
# Plot initial image
self.plotcur(self._xinit, **kwargs)
# Minimize
print("Imaging . . .")
optdict = {'maxiter': self.maxit_next,
'ftol': self.stop_next, 'gtol': self.stop_next,
'maxcor': NHIST, 'maxls': MAXLS}
def callback_func(xcur):
self.plotcur(xcur, **kwargs)
tstart = time.time()
if grads:
res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B', jac=self.objgrad,
options=optdict, callback=callback_func)
else:
res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B',
options=optdict, callback=callback_func)
tstop = time.time()
# Format output
out = res.x[:]
self.tmpout = res.x
# unpack output vector into outarr
outarr = unpack_imarr(out, self._xarr, self._which_solve)
# apply image transform to bounded values
outarr = transform_imarr(outarr, self.transform_next, self._which_solve)
# get and print final statistics
outstr = ""
chi2_term_dict = self.make_chisq_dict(outarr)
for dname in sorted(self.dat_term_next.keys()):
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key])
try:
print("time: %f s" % (tstop - tstart))
print("J: %f" % res.fun)
print(outstr)
if isinstance(res.message,str): print(res.message)
else: print(res.message.decode())
except: # TODO -- issues for some users with res.message
pass
print("==============================")
# Format final image object
outim = self.format_outim(outarr, pol_prim=pol_prim)
# Append to history
logstr = str(self.nruns) + ": make_image(pol=%s)" % pol
self._append_image_history(outim, logstr)
self.nruns += 1
# Return Image object
return outim
[docs] def check_params(self):
"""Check parameter consistency.
"""
if ((self.prior_next.psize != self.init_next.psize) or
(self.prior_next.xdim != self.init_next.xdim) or
(self.prior_next.ydim != self.init_next.ydim)):
raise Exception("Initial image does not match dimensions of the prior image!")
if ((self.prior_next.rf != self.init_next.rf)):
raise Exception("Initial image does not have same frequency as prior image!")
if (self.prior_next.polrep != self.init_next.polrep):
raise Exception(
"Initial image polrep does not match prior polrep!")
if (self.prior_next.polrep == 'circ' and not(self.pol_next in ['RR', 'LL'])):
raise Exception("Initial image polrep is 'circ': pol_next must be 'RR' or 'LL'")
if (self.prior_next.polrep == 'stokes'
and not(self.pol_next in ['I', 'Q', 'U', 'V', 'P','IP','IQU','IV','IQUV'])):
raise Exception(
"Initial image polrep is 'stokes': pol_next must be in 'I', 'Q', 'U', 'V', 'P','IP','IQU','IV','IQUV'!")
if ('log' in self.transform_next and self.pol_next in ['Q', 'U', 'V']):
raise Exception("Cannot image Stokes Q, U, V with log image transformation!")
if(self.pol_next in ['Q', 'U', 'V'] and
('gs' in self.reg_term_next.keys() or 'simple' in self.reg_term_next.keys())):
raise Exception(
"'simple' and 'gs' regularizers do not work with Stokes Q, U, or V images!")
if self._ttype not in ['fast', 'direct', 'nfft']:
raise Exception("Possible ttype values are 'fast', 'direct','nfft'!")
# Catch errors in multifrequency imaging setup
if self.mf_next:
if len(set(self.freq_list)) < 2:
raise Exception("Must have observations at at least two frequencies for multifrequency imaging!")
if self.mf_order not in [0,1,2]:
raise Exception("mf_order must be in [0,1,2]!")
if (self.pol_next in POLARIZATION_MODES):
if not (self.pol_next in ['P','QU']):
raise Exception("Currently we only support pol_next=P for polarization multifrequency imaging!")
if self.mf_order_pol not in [0,1,2]:
raise Exception("mf_order_pol must be in [0,1,2]!")
# Catch errors for polarimetric imaging setup
if self.pol_next in POLARIZATION_MODES:
if (self.pol_next in ['P', 'QU','IP','IQU']):
if 'mcv' not in self.transform_next:
raise Exception("%s imaging requires 'mcv' transform!"%self.pol_next)
if 'vcv' in self.transform_next:
raise Exception("Cannot do %s imaging with 'vcv' transform!"%self.pol_next)
if 'polcv' in self.transform_next:
raise Exception("Cannot do %s imaging only with 'polcv' transform!"%self.pol_next)
if (self.pol_next in ['V','IV']):
if 'vcv' not in self.transform_next:
raise Exception("%s imaging requires 'vcv' transform!"%self.pol_next)
if 'mcv' in self.transform_next:
raise Exception("Cannot do %s imaging only with 'mcv' transform!"%self.pol_next)
if 'polcv' in self.transform_next:
raise Exception("Cannot do %s imaging only with 'polcv' transform!"%self.pol_next)
if self.pol_next in ['IPV','IQUV']:
if 'polcv' not in self.transform_next:
raise Exception("Linear+Circular polarization imaging requires 'polcv' transform!")
if (self._ttype not in ["direct", "nfft"]):
raise Exception("FFT not yet implemented in polarimetric imaging -- use NFFT!")
# catch errors in general imaging setup
if self.mf_next:
if self.pol_next in POLARIZATION_MODES:
if 'I' in self.pol_next:
rlist = REGULARIZERS + REGULARIZERS_POL + REGULARIZERS_SPECTRAL
dlist = DATATERMS + DATATERMS_POL
else:
rlist = REGULARIZERS_POL + REGULARIZERS_POLSPECTRAL
dlist = DATATERMS_POL
else:
rlist = REGULARIZERS + REGULARIZERS_ISPECTRAL
dlist = DATATERMS
else:
if self.pol_next in POLARIZATION_MODES:
if 'I' in self.pol_next:
rlist = REGULARIZERS + REGULARIZERS_POL
dlist = DATATERMS + DATATERMS_POL
else:
rlist = REGULARIZERS_POL
dlist = DATATERMS_POL
else:
rlist = REGULARIZERS
dlist = DATATERMS
dt_here = False
dt_type = True
for term in sorted(self.dat_term_next.keys()):
if (term is not None) and (term is not False):
dt_here = True
if not ((term in dlist) or (term is False)):
dt_type = False
st_here = False
st_type = True
for term in sorted(self.reg_term_next.keys()):
if (term is not None) and (term is not False):
st_here = True
if not ((term in rlist) or (term is False)):
st_type = False
if not dt_here:
raise Exception("Must have at least one data term!")
if not st_here:
raise Exception("Must have at least one regularizer term!")
if not dt_type:
raise Exception("Invalid data term: valid data terms are: " + ','.join(dlist))
if not st_type:
raise Exception("Invalid regularizer: valid regularizers are: " + ','.join(rlist))
# Determine if we need to recompute the saved imager parameters on the next imager run
if self.nruns == 0:
return
if self.pol_next != self.pol_last():
print("changed polarization!")
self._change_imgr_params = True
return
if self.obslist_next != self.obslist_last():
print("changed observation!")
self._change_imgr_params = True
return
if len(self.reg_term_next) != len(self.reg_terms_last()):
print("changed number of regularizer terms!")
self._change_imgr_params = True
return
if len(self.dat_term_next) != len(self.dat_terms_last()):
print("changed number of data terms!")
self._change_imgr_params = True
return
for term in sorted(self.dat_term_next.keys()):
if term not in self.dat_terms_last().keys():
print("added %s to data terms" % term)
self._change_imgr_params = True
return
for term in sorted(self.reg_term_next.keys()):
if term not in self.reg_terms_last().keys():
print("added %s to regularizers!" % term)
self._change_imgr_params = True
return
if ((self.prior_next.psize != self.prior_last().psize) or
(self.prior_next.xdim != self.prior_last().xdim) or
(self.prior_next.ydim != self.prior_last().ydim)):
print("changed prior dimensions!")
self._change_imgr_params = True
return
if self.debias_next != self.debias_last():
print("changed debiasing!")
self._change_imgr_params = True
return
if self.snrcut_next != self.snrcut_last():
print("changed snrcut!")
self._change_imgr_params = True
return
if self.weighting_next != self.weighting_last():
print("changed data weighting!")
self._change_imgr_params = True
return
if self.systematic_noise_next != self.systematic_noise_last():
print("changed systematic noise!")
self._change_imgr_params = True
return
if self.systematic_cphase_noise_next != self.systematic_cphase_noise_last():
print("changed systematic cphase noise!")
self._change_imgr_params = True
return
if self.cp_uv_min != self.cp_uv_min_last():
print("changed cphase maximal/minimal set!")
self._change_imgr_params = True
return
if self.reffreq != self.reffreq_last():
print("changed refrence frequency!")
self._change_imgr_params = True
return
if self.mf_next != self.mf_last():
print("changed multifrequncy strategy!")
self._change_imgr_params = True
return
if self.mf_order != self.mf_order_last():
print("changed multifrequncy order!")
self._change_imgr_params = True
return
if self.mf_order_pol != self.mf_order_pol_last():
print("changed pol. multifrequncy order!")
self._change_imgr_params = True
return
if self.mf_rm != self.mf_rm_last():
print("changed pol. rm imaging order!")
self._change_imgr_params = True
return
if self.mf_cm != self.mf_cm_last():
print("changed pol. cm imaging order!")
self._change_imgr_params = True
return
return
[docs] def check_limits(self):
"""Check image parameter consistency with observation.
"""
for i,obs in enumerate(self.obslist_next):
uvmax = 1.0/self.prior_next.psize
uvmin = 1.0/(self.prior_next.psize*np.max((self.prior_next.xdim, self.prior_next.ydim)))
uvdists = obs.unpack('uvdist')['uvdist']
maxbl = np.max(uvdists)
minbl = np.max(uvdists[uvdists > 0])
if uvmax < maxbl:
print("Warning! Pixel size is larger than smallest spatial wavelength for freq %.1f GHz!"%(obs.rf/1.e9))
if uvmin > minbl:
print("Warning! Field of View is smaller than largest nonzero spatial wavelength for freq %.1f GHz!"%(obs.rf/1.e9))
if self.pol_next in ['I', 'RR', 'LL']:
maxamp = np.max(np.abs(obs.unpack('amp')['amp']))
# TODO: better handling of mf fluxes
if len(self.mf_flux)==len(self.obslist_next):
flux = self.mf_flux[i]
else:
flux = self.flux_next
if flux > 1.2*maxamp:
print("Warning! Specified flux %.1f is > 120%% of maximum visibility amplitude for freq %.1f GHz!"%(flux,obs.rf/1.e9))
if flux < .8*maxamp:
print("Warning! Specified flux %.1f is < 80%% of maximum visibility amplitude for freq %.1f GHz!"%(flux,obs.rf/1.e9))
return
[docs] def init_imager(self):
"""Set up Stokes I imager.
"""
# Set embedding mask
self.set_embed()
self._nimage = np.sum(self._embed_mask)
# Set prior & initial image vectors for multifrequency imaging
if self.mf_next:
# set reference frequency to same as prior
self.reffreq = self.init_next.rf
# reset logfreqratios in case the reference frequency changed
self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list]
# determine self._which_solve
# TODO is there a nicer way to do this?
if self.mf_order==2:
do_a = 1; do_b = 1;
elif self.mf_order==1:
do_a = 1; do_b = 0;
elif self.mf_order == 0:
do_a = 0; do_b = 0;
else:
raise Exception("Imager.mf_order must be 0, 1, or 2!")
# polarization multi-frequency
if self.pol_next in POLARIZATION_MODES:
# determine self._which_solve
# TODO is there a nicer way to do this?
if self.mf_order_pol == 2:
do_ap=1; do_bp=1
elif self.mf_order_pol == 1:
do_ap=1; do_bp=0
elif self.mf_order_pol == 0:
do_ap=0; do_bp=0
else:
raise Exception("Imager.mf_order_pol must be 0, 1, or 2!")
if self.mf_rm:
do_rm = 1
else:
do_rm = 0
if self.mf_cm:
do_cm = 1
else:
do_cm = 0
if 'I' in self.pol_next:
do_i = 1
else:
do_i = 0
# TODO: No Stokes V imaging for multifrequency yet
do_rho = 1; do_phi=1
do_psi = 0
if not (('P' in self.pol_next) or ('QU' in self.pol_next)):
raise Exception("Multifrequency polarization imaging currently requires pol_next=P!")
if ('V' in self.pol_next):
raise Exception("Stokes V not yet implemented in multifrequency polarization imaging!")
# set_which_solve vector
self._which_solve = np.array([do_i,do_rho,do_phi,do_psi,
do_a,do_b,do_ap,do_bp,
do_rm,do_cm])
# make initial and prior images
randompol_circ = randompol_lin=False
if 'V' in self.pol_next:
randompol_circ=True
if ('P' in self.pol_next) or ('QU' in self.pol_next):
randompol_lin=True
initarr = make_initarr(self.init_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=True, pol=True,
randompol_lin=randompol_lin, randompol_circ=randompol_circ,
meanpol=MEANPOL_INIT, sigmapol=SIGMAPOL_INIT)
priorarr = make_initarr(self.prior_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=True, pol=True,
randompol_lin=False, randompol_circ=False)
# prior
self._xprior = priorarr
# transform the initial image to solved for values and assign to self._xarr
initarr = transform_imarr_inverse(initarr, self.transform_next, self._which_solve)
self._xarr = initarr
# Stokes I multi-frequency
else:
# set_which_solve vector
self._which_solve = np.array([1,do_a,do_b])
# make initial and prior images
initarr = make_initarr(self.init_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=True, pol=False)
priorarr = make_initarr(self.prior_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=True, pol=False)
# prior
self._xprior = priorarr
# transform the initial image to solved for values and assign to self._xarr
initarr = transform_imarr_inverse(initarr, self.transform_next, self._which_solve)
self._xarr = initarr
# Pack multi-frequency tuple into single vector
self._xinit = pack_imarr(self._xarr, self._which_solve)
# Set prior & initial image vectors for single-frequency imaging
else:
# single-frequency polarimetric imaging
if self.pol_next in POLARIZATION_MODES:
# Determine self._which_solve
if ('I' in self.pol_next):
do_i = 1
else:
do_i = 0
if ('P' in self.pol_next) or ('QU' in self.pol_next):
do_rho = 1; do_phi=1
else:
do_rho = 0; do_phi=0
if ('V' in self.pol_next):
do_psi = 1;
else:
do_psi = 0;
# set self._which_solve vector
self._which_solve = np.array([do_i,do_rho,do_phi,do_psi])
# make initial and prior images
randompol_circ = randompol_lin=False
if 'V' in self.pol_next:
randompol_circ=True
if ('P' in self.pol_next) or ('QU' in self.pol_next):
randompol_lin=True
initarr = make_initarr(self.init_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=False, pol=True,
randompol_lin=randompol_lin, randompol_circ=randompol_circ,
meanpol=MEANPOL_INIT, sigmapol=SIGMAPOL_INIT)
priorarr = make_initarr(self.prior_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=False, pol=True,
randompol_lin=False, randompol_circ=False)
# prior
self._xprior = priorarr
# transform the initial image to solved for values and assign to self._xarr
initarr = transform_imarr_inverse(initarr, self.transform_next, self._which_solve)
self._xarr = initarr
# Pack into single vector
self._xinit = pack_imarr(self._xarr, self._which_solve)
# regular single-frequency single-stokes (or RR, LL) imaging
else:
# set self._which_solve vector
self._which_solve = np.array([1])
# make initial and prior images
initarr = make_initarr(self.init_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=False, pol=False)
priorarr = make_initarr(self.prior_next, self._embed_mask,
norm_init=self.norm_init, flux=self.flux_next,
mf=False, pol=False)
# Prior
self._xprior = priorarr
# transform the initial image to solved for values and assign to self._xarr
initarr = transform_imarr_inverse(initarr, self.transform_next, self._which_solve)
self._xarr = initarr
# Pack into single vector
self._xinit = pack_imarr(self._xarr, self._which_solve)
# Make data term tuples
if self._change_imgr_params:
if self.nruns == 0:
print("Initializing imager data products . . .")
if self.nruns > 0:
print("Recomputing imager data products . . .")
self._data_tuples = {}
# Loop over all data term types
for dname in sorted(self.dat_term_next.keys()):
# Loop over all observations in the list
for i, obs in enumerate(self.obslist_next):
# Each entry in the dterm dictionary past the first has an appended number
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
# Polarimetric data products
if dname in DATATERMS_POL:
tup = polutils.polchisqdata(obs, self.prior_next, self._embed_mask, dname,
ttype=self._ttype,
fft_pad_factor=self._fft_pad_factor,
conv_func=self._fft_conv_func,
p_rad=self._fft_gridder_prad)
# Single polarization data products
elif dname in DATATERMS:
if self.pol_next in POLARIZATION_MODES:
if not 'I' in self.pol_next:
raise Exception("cannot use dterm %s with pol=%s"%(dname,self.pol_next))
pol_next = 'I'
else:
pol_next = self.pol_next
tup = imutils.chisqdata(obs, self.prior_next, self._embed_mask, dname,
pol=pol_next, maxset=self.maxset_next,
debias=self.debias_next,
snrcut=self.snrcut_next[dname],
weighting=self.weighting_next,
systematic_noise=self.systematic_noise_next,
systematic_cphase_noise=self.systematic_cphase_noise_next,
ttype=self._ttype, order=self._fft_interp_order,
fft_pad_factor=self._fft_pad_factor,
conv_func=self._fft_conv_func,
p_rad=self._fft_gridder_prad,
cp_uv_min=self.cp_uv_min)
else:
raise Exception("data term %s not recognized!" % dname)
self._data_tuples[dname_key] = tup
self._change_imgr_params = False
return
[docs] def set_embed(self):
"""Set embedding matrix.
"""
self._embed_mask = (self.prior_next.imvec > self.clipfloor_next)
if not np.any(self._embed_mask):
raise Exception("clipfloor_next too large: all prior pixels have been clipped!")
xmax = self.prior_next.xdim//2
ymax = self.prior_next.ydim//2
if self.prior_next.xdim % 2: xmin=-xmax-1
else: xmin=-xmax
if self.prior_next.ydim % 2: ymin=-ymax-1
else: ymin=-ymax
coord = np.array([[[x, y]
for x in np.arange(xmax, xmin, -1)]
for y in np.arange(ymax, ymin, -1)])
coord = coord.reshape(self.prior_next.ydim * self.prior_next.xdim, 2)
coord = coord * self.prior_next.psize
self._coord_matrix = coord[self._embed_mask]
return
[docs] def make_chisq_dict(self, imcur):
"""Make a dictionary of current chi^2 term values
input is image array transformed to bounded values
"""
chi2_dict = {}
for dname in sorted(self.dat_term_next.keys()):
# Loop over all observations in the list
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
# get data products
(data, sigma, A) = self._data_tuples[dname_key]
# get current multifrequency image
if self.mf_next:
logfreqratio = self._logfreqratio_list[i]
imcur_nu = mfutils.image_at_freq(imcur, logfreqratio)
else:
imcur_nu = imcur
# Polarization chi^2 terms
if dname in DATATERMS_POL:
chi2 = polutils.polchisq(imcur_nu, A, data, sigma, dname,
ttype=self._ttype, mask=self._embed_mask)
# Single Polarization chi^2 terms
elif dname in DATATERMS:
if self.pol_next in POLARIZATION_MODES:
imcur_nu_I = imcur_nu[0]
else:
imcur_nu_I = imcur_nu
chi2 = imutils.chisq(imcur_nu, A, data, sigma, dname,
ttype=self._ttype, mask=self._embed_mask)
else:
raise Exception("data term %s not recognized!" % dname)
chi2_dict[dname_key] = chi2
return chi2_dict
[docs] def make_chisqgrad_dict(self, imcur):
"""Make a dictionary of current chi^2 term gradient values
input is image array transformed to bounded values
"""
chi2grad_dict = {}
for dname in sorted(self.dat_term_next.keys()):
# Loop over all observations in the list
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
# get data products
(data, sigma, A) = self._data_tuples[dname_key]
# get current multifrequency image
if self.mf_next:
logfreqratio = self._logfreqratio_list[i]
imcur_nu = mfutils.image_at_freq(imcur, logfreqratio)
else:
imcur_nu = imcur
# Polarimetric chi^2 gradients
if dname in DATATERMS_POL:
if self.mf_next:
pol_solve = self._which_solve[0:4]
else:
pol_solve = self._which_solve
chi2grad = polutils.polchisqgrad(imcur_nu, A, data, sigma, dname,
ttype=self._ttype, mask=self._embed_mask,
pol_solve=pol_solve)
# Single polarization chi^2 gradients
elif dname in DATATERMS:
if self.pol_next in POLARIZATION_MODES: # polarization
imcur_nu_I = imcur_nu[0]
else:
imcur_nu_I = imcur_nu
chi2grad = imutils.chisqgrad(imcur_nu_I, A, data, sigma, dname,
ttype=self._ttype, mask=self._embed_mask)
# If imaging Stokes I with polarization simultaneously, bundle the gradient
if self.pol_next in POLARIZATION_MODES:
chi2grad = np.array((chi2grad,
np.zeros(self._nimage),
np.zeros(self._nimage),
np.zeros(self._nimage)))
else:
raise Exception("data term %s not recognized!" % dname)
# If multifrequency imaging,
# transform the image gradients for all the solved quantities
if self.mf_next:
logfreqratio = self._logfreqratio_list[i]
chi2grad = mfutils.mf_all_grads_chain(chi2grad, imcur_nu, imcur, logfreqratio)
chi2grad_dict[dname_key] = np.array(chi2grad)
return chi2grad_dict
[docs] def make_reg_dict(self, imcur):
"""Make a dictionary of current regularizer values
input is image array transformed to bounded values
"""
reg_dict = {}
for regname in sorted(self.reg_term_next.keys()):
# Multifrequency regularizers
if self.mf_next:
# Polarimetric regularizers
if regname in REGULARIZERS_POL:
# we only regularize the reference frequency image
imcur_pol = imcur[0:4]
prior_pol = self._xprior[0:4]
reg = polutils.polregularizer(imcur_pol, prior_pol, self._embed_mask,
self.flux_next, self.pflux_next, self.vflux_next,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
# Stokes I regularizers
elif regname in REGULARIZERS:
# here we regularize images at each frequency
if regname in REGULARIZERS_ALLFREQS_I:
# TODO move this to checks?
if (not isinstance(self.mf_flux, list)) or len(self.mf_flux)!=len(self.obslist_next):
raise Exception("when using regularizer '%s', "%regname
+"self.mf_flux must be a list of same length as self.obslist_next!")
regname_base = '_'.join(regname.split('_')[:-1]) # remove the '_mf' tag
for i in range(len(self.obslist_next)): # sum up regularizer gradients at each frequency
logfreqratio = self._logfreqratio_list[i]
flux_nu = self.mf_flux[i]
imcur_nu = mfutils.image_at_freq(imcur, logfreqratio)
prior_nu = mfutils.image_at_freq(self._xprior, logfreqratio)
regi = imutils.regularizer(imcur_nu, prior_nu, self._embed_mask,
flux_nu, self.prior_next.xdim,
self.prior_next.ydim, self.prior_next.psize,
regname_base,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
if i==0:
reg = regi
else:
reg += regi
# here we only regularize the reference frequency image
else:
reg = imutils.regularizer(imcur[0], self._xprior[0], self._embed_mask,
self.flux_next, self.prior_next.xdim,
self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
# Spectral regularizers
elif regname in REGULARIZERS_SPECTRAL:
if regname in REGULARIZERS_SPECIND:
if len(imcur)==10: idx = 4
else: idx=1
elif regname in REGULARIZERS_CURV:
if len(imcur)==10: idx = 5
else: idx=2
elif regname in REGULARIZERS_SPECIND_P:
idx = 6
elif regname in REGULARIZERS_CURV_P:
idx = 7
elif regname in REGULARIZERS_RM:
idx = 8
elif regname in REGULARIZERS_CM:
idx = 9
reg = mfutils.regularizer_mf(imcur[idx], self._xprior[idx], self._embed_mask,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
else:
raise Exception("regularizer term %s not recognized!" % regname)
# Single-frequency polarimetric regularizer
elif regname in REGULARIZERS_POL:
reg = polutils.polregularizer(imcur, self._xprior, self._embed_mask,
self.flux_next, self.pflux_next, self.vflux_next,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
# Single-frequency, single-polarization regularizer
elif regname in REGULARIZERS:
if self.pol_next in POLARIZATION_MODES:
imcur0 = imcur[0]
prior0 = self._xprior[0]
else:
imcur0 = imcur
prior0 = self._xprior
reg = imutils.regularizer(imcur0, prior0, self._embed_mask,
self.flux_next,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
else:
raise Exception("regularizer term %s not recognized!" % regname)
# put regularizer terms in the dictionary
reg_dict[regname] = reg
return reg_dict
[docs] def make_reggrad_dict(self, imcur):
"""Make a dictionary of current regularizer gradient values
input is image array transformed to bounded values
"""
reggrad_dict = {}
for regname in sorted(self.reg_term_next.keys()):
# Multifrequency regularizers
if self.mf_next:
# Polarimetric regularizers
if regname in REGULARIZERS_POL:
# we only regularize reference frequency image
imcur_pol = imcur[0:4]
prior_pol = self._xprior[0:4]
pol_solve = self._which_solve[0:4]
regp = polutils.polregularizergrad(imcur_pol, prior_pol, self._embed_mask,
self.flux_next, self.pflux_next, self.vflux_next,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
pol_solve=pol_solve,
**self.regparams)
reggrad = np.zeros((len(imcur), self._nimage))
reggrad[0:4] = regp
# Stokes I regularizers
elif regname in REGULARIZERS:
# here we regularize images at each frequency
if regname in REGULARIZERS_ALLFREQS_I:
# TODO move this to checks?
if (not isinstance(self.mf_flux, list)) or len(self.mf_flux)!=len(self.obslist_next):
raise Exception("when using regularizer '%s', "%regname
+"self.mf_flux must be a list of same length as self.obslist_next!")
regname_base = '_'.join(regname.split('_')[:-1]) # remove the '_mf' tag
for i in range(len(self.obslist_next)): # sum up regularizer gradients at each frequency
logfreqratio = self._logfreqratio_list[i]
flux_nu = self.mf_flux[i]
imcur_nu = mfutils.image_at_freq(imcur, logfreqratio)
prior_nu = mfutils.image_at_freq(self._xprior, logfreqratio)
regi = imutils.regularizergrad(imcur_nu, prior_nu, self._embed_mask,
flux_nu, self.prior_next.xdim,
self.prior_next.ydim, self.prior_next.psize,
regname_base,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
reggrad_i = mfutils.mf_all_grads_chain(regi, imcur_nu, imcur, logfreqratio)
if i==0:
reggrad = reggrad_i
else:
reggrad += reggrad_i
# here we only regularize the reference frequency image
else:
regi = imutils.regularizergrad(imcur[0], self._xprior[0],
self._embed_mask, self.flux_next,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
reggrad = np.zeros((len(imcur), self._nimage))
reggrad[0] = regi
elif regname in REGULARIZERS_SPECTRAL:
if regname in REGULARIZERS_SPECIND:
if len(imcur)==10: idx = 4
else: idx=1
elif regname in REGULARIZERS_CURV:
if len(imcur)==10: idx = 5
else: idx=2
elif regname in REGULARIZERS_SPECIND_P:
idx = 6
elif regname in REGULARIZERS_CURV_P:
idx = 7
elif regname in REGULARIZERS_RM:
idx = 8
elif regname in REGULARIZERS_CM:
idx = 9
regmf = mfutils.regularizergrad_mf(imcur[idx], self._xprior[idx], self._embed_mask,
self.prior_next.xdim, self.prior_next.ydim, self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
reggrad = np.zeros((len(imcur), self._nimage))
reggrad[idx] = regmf
else:
raise Exception("regularizer term %s not recognized!" % regname)
else:
# Single-frequency polarimetric regularizer
if regname in REGULARIZERS_POL:
reggrad = polutils.polregularizergrad(imcur, self._xprior, self._embed_mask,
self.flux_next, self.pflux_next, self.vflux_next,
self.prior_next.xdim, self.prior_next.ydim,
self.prior_next.psize, regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
pol_solve=self._which_solve,
**self.regparams)
# Single-frequency, single polarization regularizer
elif regname in REGULARIZERS:
if self.pol_next in POLARIZATION_MODES:
imcur0 = imcur[0]
prior0 = self._xprior[0]
else:
imcur0 = imcur
prior0 = self._xprior
reggrad = imutils.regularizergrad(imcur0, prior0, self._embed_mask, self.flux_next,
self.prior_next.xdim, self.prior_next.ydim,
self.prior_next.psize,
regname,
norm_reg=self.norm_reg, beam_size=self.beam_size,
**self.regparams)
# If imaging Stokes I with polarization simultaneously, bundle the gradient
if self.pol_next in POLARIZATION_MODES:
reggrad = np.array((reggrad,
np.zeros(self._nimage),
np.zeros(self._nimage),
np.zeros(self._nimage)))
else:
raise Exception("regularizer term %s not recognized!" % regname)
# put regularizer terms in the dictionary
reggrad_dict[regname] = reggrad
return reggrad_dict
[docs] def objfunc(self, imvec):
"""Current objective function.
"""
# Unpack polarimetric/multifrequency vector into an array
imcur = unpack_imarr(imvec, self._xarr, self._which_solve)
# apply image transform to bounded values
imcur = transform_imarr(imcur, self.transform_next, self._which_solve)
# Data terms
datterm = 0.
chi2_term_dict = self.make_chisq_dict(imcur)
for dname in sorted(self.dat_term_next.keys()):
hyperparameter = self.dat_term_next[dname]
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
chi2 = chi2_term_dict[dname_key]
if self.chisq_transform:
datterm += hyperparameter * (chi2 + 1./chi2 - 1.)
else:
datterm += hyperparameter * (chi2 - 1.)
# Regularizer terms
regterm = 0
reg_term_dict = self.make_reg_dict(imcur)
for regname in sorted(self.reg_term_next.keys()):
hyperparameter = self.reg_term_next[regname]
regularizer = reg_term_dict[regname]
regterm += hyperparameter * regularizer
# Total cost
cost = datterm + regterm
return cost
[docs] def objgrad(self, imvec):
"""Current objective function gradient.
"""
# Unpack polarimetric/multifrequency vector into an array
imcur = unpack_imarr(imvec, self._xarr, self._which_solve)
# apply image transform to bounded values
imcur_prime = imcur.copy()
imcur = transform_imarr(imcur, self.transform_next, self._which_solve)
# Data terms
datterm = 0.
chi2_term_dict = self.make_chisqgrad_dict(imcur)
if self.chisq_transform:
chi2_value_dict = self.make_chisq_dict(imcur)
for dname in sorted(self.dat_term_next.keys()):
hyperparameter = self.dat_term_next[dname]
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
chi2_grad = chi2_term_dict[dname_key]
if self.chisq_transform:
chi2_val = chi2_value_dict[dname]
datterm += hyperparameter * chi2_grad * (1. - 1./(chi2_val**2))
else:
datterm += hyperparameter * (chi2_grad + self.chisq_offset_gradient)
# Regularizer terms
regterm = 0
reg_term_dict = self.make_reggrad_dict(imcur)
for regname in sorted(self.reg_term_next.keys()):
hyperparameter = self.reg_term_next[regname]
regularizer_grad = reg_term_dict[regname]
regterm += hyperparameter * regularizer_grad
# Total gradient
grad = datterm + regterm
# Chain rule term for change of variables
grad = transform_gradients(grad, imcur_prime, self.transform_next, self._which_solve)
# repack gradient
grad = pack_imarr(grad, self._which_solve)
return grad
[docs] def plotcur(self, imvec, **kwargs):
"""Plot current image.
"""
if self._show_updates:
if self._nit % self._update_interval == 0:
# Unpack polarimetric/multifrequency vector into an array
imcur = unpack_imarr(imvec, self._xarr, self._which_solve)
# apply image transform to bounded values
imcur_prime = imcur.copy()
imcur = transform_imarr(imcur, self.transform_next, self._which_solve)
# Get chi^2 and regularizer
chi2_term_dict = self.make_chisq_dict(imcur)
reg_term_dict = self.make_reg_dict(imcur)
# Format print string
outstr = "------------------------------------------------------------------"
outstr += "\n%4d | " % self._nit
for dname in sorted(self.dat_term_next.keys()):
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key])
outstr += "\n "
for dname in sorted(self.dat_term_next.keys()):
for i, obs in enumerate(self.obslist_next):
if len(self.obslist_next)==1:
dname_key = dname
else:
dname_key = dname + ('_%i' % i)
dval = chi2_term_dict[dname_key]*self.dat_term_next[dname]
outstr += "%s : %0.1f " % (dname_key, dval)
outstr += "\n "
for regname in sorted(self.reg_term_next.keys()):
rval = reg_term_dict[regname]*self.reg_term_next[regname]
outstr += "%s : %0.1f " % (regname, rval)
# Embed and plot the image
if not self.mf_next: # TODO plot multi-frequency?
if np.any(np.invert(self._embed_mask)):
implot = embed_imarr(imcur, self._embed_mask)
else:
implot = imcur
if self.pol_next in POLARIZATION_MODES:
polutils.plot_m(implot, self.prior_next, self._nit, chi2_term_dict, **kwargs)
else:
imutils.plot_i(implot, self.prior_next, self._nit,
chi2_term_dict, pol=self.pol_next, **kwargs)
if self._nit == 0:
print()
print(outstr)
self._nit += 1
def _append_image_history(self, outim, logstr):
self.logstr += (logstr + "\n")
self._obs_list.append(self.obslist_next)
self._init_list.append(self.init_next)
self._prior_list.append(self.prior_next)
self._debias_list.append(self.debias_next)
self._weighting_list.append(self.weighting_next)
self._systematic_noise_list.append(self.systematic_noise_next)
self._systematic_cphase_noise_list.append(self.systematic_cphase_noise_next)
self._snrcut_list.append(self.snrcut_next)
self._flux_list.append(self.flux_next)
self._pflux_list.append(self.pflux_next)
self._vflux_list.append(self.vflux_next)
self._pol_list.append(self.pol_next)
self._clipfloor_list.append(self.clipfloor_next)
self._maxset_list.append(self.clipfloor_next)
self._maxit_list.append(self.maxit_next)
self._stop_list.append(self.stop_next)
self._cp_uv_min_list.append(self.cp_uv_min)
self._mf_list.append(self.mf_next)
self._reffreq_list.append(self.reffreq)
self._mf_order_list.append(self.mf_order)
self._mf_order_pol_list.append(self.mf_order_pol)
self._mf_rm_list.append(self.mf_rm)
self._mf_cm_list.append(self.mf_cm)
self._transform_list.append(self.transform_next)
self._reg_term_list.append(self.reg_term_next)
self._dat_term_list.append(self.dat_term_next)
self._out_list.append(outim)
return
#############################################################
# Helper functions
#############################################################
[docs]def embed_imarr(imarr, mask, clipfloor=0., randomfloor=False):
"""Embeds a multidimensional image array into the size of boolean embed mask
"""
imarrdim = len(imarr.shape)
if imarrdim==2:
nsolve = imarr.shape[0]
nimage = imarr.shape[1]
elif imarrdim==1:
nsolve = 1
nimage = imarr.shape[0]
imarr = imarr.reshape((nsolve,nimage))
else:
raise Exception("in embed_imarr, imarr should have one or two dimensions!")
if nimage!=np.sum(mask):
raise Exception("in embed_imarr, number of masked pixels is not consistent with imarr shape!")
nimage_out = len(mask)
outarr = np.empty((nsolve,nimage_out))
# TODO does this require the for loop?
for kk in range(nsolve):
outarr[kk] = imutils.embed(imarr[kk], mask, clipfloor=clipfloor, randomfloor=randomfloor)
if imarrdim==1:
outarr = outarr[0]
return outarr
[docs]def pack_imarr(imarr, which_solve):
"""pack image array imarr into 1D array vec for minimizaiton
ignore quantities not solved for
"""
imarrdim = len(imarr.shape)
if imarrdim==2:
nsolve = imarr.shape[0]
nimage = imarr.shape[1]
elif imarrdim==1:
nsolve = 1
nimage = imarr.shape[0]
imarr = imarr.reshape((nsolve,nimage))
else:
raise Exception("in pack_imarr, imarr should have one or two dimensions!")
if nsolve != len(which_solve):
raise Exception("in pack_imarr, imarr has inconsistent shape with which_solve!")
vec = np.array([])
for kk in range(nsolve):
if which_solve[kk]!=0:
vec = np.hstack((vec,imarr[kk]))
return vec
[docs]def unpack_imarr(vec, priorarr, which_solve):
"""unpack minimized vector vec into array,
replace quantities not solved for with their initial values
"""
imarrdim = len(priorarr.shape)
if imarrdim==2:
nsolve = priorarr.shape[0]
nimage = priorarr.shape[1]
elif imarrdim==1:
nsolve = 1
nimage = priorarr.shape[0]
imarr = priorarr.reshape((nsolve,nimage))
else:
raise Exception("in unpack_imarr, priorarr should have one or two dimensions !")
if nsolve != len(which_solve):
raise Exception("in unpack_imarr, priorarr has inconsistent shape with which_solve!")
imct = 0
imarr = np.empty((nsolve, nimage))
for kk in range(nsolve):
if which_solve[kk]==0:
imarr[kk] = priorarr[kk]
else:
imarr[kk] = vec[imct*nimage:(imct+1)*nimage]
imct += 1
if imarrdim==1:
imarr = imarr[0]
return imarr
[docs]def make_initarr(image, mask, norm_init=False, flux=1,
mf=False, pol=False,
randompol_lin=False, randompol_circ=False,
meanpol=0.2, sigmapol=1.e-2):
"""Make initial image array from image object, or initialize with default values"""
# set initial and prior images
init_I = image.imvec[mask]
nimage = len(init_I)
if norm_init:
normfac = flux / (np.sum(init_I))
init_I = normfac * init_I
else:
normfac = 1
# TODO -- apply a floor to init_I?
# single-frequency, single-polarization
if not(pol) and not(mf):
initarr = np.array(init_I)
# polarization
if pol:
if len(image.qvec):
init_q = normfac*image.qvec[mask]
else:
init_q = np.zeros(nimage)
if len(image.uvec):
init_u = normfac*image.uvec[mask]
else:
init_u = np.zeros(nimage)
if len(image.vvec):
init_v = normfac*image.vvec[mask]
else:
init_v = normfac*np.zeros(nimage)
init_P = np.sqrt(init_q**2 + init_u**2)
init_rho = np.sqrt(init_q**2 + init_u**2 + init_v**2) / init_I
init_phi = np.arctan2(init_u, init_q)
init_psi = np.arctan2(init_v, init_P)
if not(np.any(init_rho!=0)) and randompol_lin:
print("No polarimetric image in init!")
print("--initializing with 20% pol and random orientation!")
init_rho = meanpol * (np.ones(nimage) + sigmapol * np.random.rand(nimage))
init_phi = np.zeros(nimage) + sigmapol * np.random.rand(nimage)
if not(np.any(init_psi!=0)) and randompol_circ:
print("No circular polarization image in init!")
print("--initializing with random values!")
init_rho = meanpol * (np.ones(nimage) + sigmapol * np.random.rand(nimage))
init_psi = np.zeros(nimage) + sigmapol * np.random.rand(nimage)
if not(mf):
initarr = np.array((init_I, init_rho, init_phi, init_psi))
# multi-frequency
if mf:
if len(image.specvec):
init_a = image.specvec[mask]
else:
init_a = np.zeros(nimage)
if len(image.curvvec):
init_b = image.curvvec[mask]
else:
init_b = np.zeros(nimage)
# multi-frequency, multi-polarization
if pol:
if len(image.specvec_pol):
init_ap = image.specvec_pol[mask]
else:
init_ap = np.zeros(nimage)
if len(image.curvvec_pol):
init_bp = image.curvvec_pol[mask]
else:
init_bp = np.zeros(nimage)
# TODO what do we want to initialize RM and CM to?
if len(image.rmvec):
init_rm = image.rmvec[mask]
else:
init_rm = np.zeros(nimage)
if len(image.cmvec):
init_cm = image.cmvec[mask]
else:
init_cm = np.zeros(nimage)
initarr = np.array((init_I, init_rho, init_phi, init_psi,
init_a, init_b, init_ap, init_bp,
init_rm, init_cm))
else:
initarr = np.array((init_I, init_a, init_b))
return initarr