# imager.py
# a general interferometric imager class
#
# Copyright (C) 2018 Andrew Chael
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
from __future__ import print_function
from builtins import str
from builtins import range
from builtins import object
import copy
import time
import numpy as np
import scipy.optimize as opt
import ehtim.scattering as so
import ehtim.imaging.imager_utils as imutils
import ehtim.imaging.pol_imager_utils as polutils
import ehtim.imaging.multifreq_imager_utils as mfutils
import ehtim.image
import ehtim.const_def as ehc
# ---------------------------------------------------------------------------
# Optimizer settings (used directly, or as defaults, for the L-BFGS-B options)
# ---------------------------------------------------------------------------
MAXIT = 200     # number of iterations
NHIST = 50      # number of steps to store for hessian approx
MAXLS = 40      # maximum number of line search steps in BFGS-B
STOP = 1e-6     # convergence criterion
EPS = 1e-8

# ---------------------------------------------------------------------------
# Recognized data terms and regularizers
# ---------------------------------------------------------------------------
DATATERMS = ['vis', 'bs', 'amp',
             'cphase', 'cphase_diag',
             'camp', 'logcamp', 'logcamp_diag']
REGULARIZERS = ['gs', 'tv', 'tvlog', 'tv2', 'tv2log',
                'l1', 'l1w', 'lA', 'patch',
                'flux', 'cm', 'simple', 'compact', 'compact2', 'rgauss']
REGULARIZERS_SPECIND = ['l2_alpha', 'tv_alpha']   # spectral index (multifrequency)
REGULARIZERS_CURV = ['l2_beta', 'tv_beta']        # spectral curvature (multifrequency)
DATATERMS_POL = ['pvis', 'm', 'pbs', 'vvis']
REGULARIZERS_POL = ['msimple', 'hw', 'ptv', 'l1v', 'l2v', 'vtv', 'vtv2', 'vflux']

# ---------------------------------------------------------------------------
# FFT / gridding defaults
# ---------------------------------------------------------------------------
GRIDDER_P_RAD_DEFAULT = 2
GRIDDER_CONV_FUNC_DEFAULT = 'gaussian'
FFT_PAD_DEFAULT = 2
FFT_INTERP_DEFAULT = 3

# ---------------------------------------------------------------------------
# Default objective terms (name -> weight)
# ---------------------------------------------------------------------------
REG_DEFAULT = {'simple': 1}
DAT_DEFAULT = {'vis': 100}

# Polarimetric imaging solves for polarization in the (m, chi) basis.
POL_TRANS = True
# Which polarimetric components are solved for is now derived from 'pol_next'
# (formerly a POL_WHICH_SOLVE = (0, 1, 1) constant: solve m & chi, not I).

# Default multifrequency solve mask: (I0, alpha, beta) -> solve I0 and alpha, not beta.
# This is only the default; the Imager object uses self.mf_which_solve.
MF_WHICH_SOLVE = (1, 1, 0)

# Default regularizer parameters
REGPARAMS_DEFAULT = dict(major=50*ehc.RADPERUAS,
                         minor=50*ehc.RADPERUAS,
                         PA=0.,
                         alpha_A=1.0,
                         epsilon_tv=0.0)

# TODO: treatment of V may be inconsistent
POLARIZATION_MODES = ['P', 'QU', 'IP', 'IQU', 'V', 'IV', 'IQUV', 'IPV']
###################################################################################################
# Imager object
###################################################################################################
[docs]class Imager(object):
"""A general interferometric imager.
"""
def __init__(self, obs_in, init_im,
             prior_im=None, flux=None, data_term=DAT_DEFAULT, reg_term=REG_DEFAULT, **kwargs):
    """Create an imager and prepare it for the first imaging run.

    Args:
        obs_in (Obsdata or list): observation, or list of observations, to image
        init_im (Image): initial image; its ``rf`` becomes the reference frequency
        prior_im (Image): prior image; defaults to ``init_im``
        flux (float): total flux constraint; defaults to the prior's total flux
        data_term (dict): data term name -> weight, e.g. {'amp': 1000, 'cphase': 100}
        reg_term (dict): regularizer name -> weight, e.g. {'simple': 1, 'l1': 10}
        **kwargs: optional settings read below (pflux, vflux, pol, debias, snrcut,
            systematic_noise, systematic_cphase_noise, weighting, maxset, clipfloor,
            maxit, stop, transform, norm_init, norm_reg, ttype, fft_gridder_prad,
            fft_conv_func, fft_pad_factor, fft_interp_order, cp_uv_min,
            scattering_model, alpha_phi, reg_all_freq_mf, mf_which_solve, and the
            keys of REGPARAMS_DEFAULT)

    Raises:
        Exception: if the settings are inconsistent (raised by check_params or
            by the obslist_next setter).
    """
    self.logstr = ""
    self._obs_list = []
    self._init_list = []
    self._prior_list = []
    self._out_list = []
    self._out_list_epsilon = []
    self._out_list_scattered = []
    self._reg_term_list = []
    self._dat_term_list = []
    self._clipfloor_list = []
    self._maxset_list = []
    self._pol_list = []
    self._maxit_list = []
    self._stop_list = []
    self._flux_list = []
    self._pflux_list = []
    self._vflux_list = []
    self._snrcut_list = []
    self._debias_list = []
    self._systematic_noise_list = []
    self._systematic_cphase_noise_list = []
    self._transform_list = []
    self._weighting_list = []

    # Regularizer/data terms for the next imaging iteration.
    # Copy them so that later edits (e.g. imgr.reg_term_next['tv'] = 10) never
    # mutate the module-level REG_DEFAULT / DAT_DEFAULT shared by all imagers.
    self.reg_term_next = copy.copy(reg_term)
    self.dat_term_next = copy.copy(data_term)

    # Observations, frequencies.
    # reffreq must be set first: the obslist_next setter uses it.
    self.reffreq = init_im.rf
    if isinstance(obs_in, list):
        self.obslist_next = obs_in
    else:
        self.obslist_next = [obs_in]

    # Init, prior, flux
    self.init_next = init_im
    if prior_im is None:
        self.prior_next = self.init_next
    else:
        self.prior_next = prior_im

    if flux is None:
        self.flux_next = self.prior_next.total_flux()
    else:
        self.flux_next = flux

    # Set polarimetric flux values equal to the Stokes I flux by default
    # (used in regularizer normalization). Use the resolved flux_next, not the
    # raw `flux` argument, which may be None.
    self.pflux_next = kwargs.get('pflux', self.flux_next)
    self.vflux_next = kwargs.get('vflux', self.flux_next)

    # Polarization
    self.pol_next = kwargs.get('pol', self.init_next.pol_prim)

    # Weighting/debiasing/snr cut/systematic noise
    self.debias_next = kwargs.get('debias', True)
    snrcut = kwargs.get('snrcut', 0.)
    all_terms = set(DATATERMS + DATATERMS_POL)
    if isinstance(snrcut, dict):
        # per-term cuts; terms not mentioned get no cut
        self.snrcut_next = {key: 0. for key in all_terms}
        self.snrcut_next.update(snrcut)
    else:
        # a single cut applied to every data term
        self.snrcut_next = {key: snrcut for key in all_terms}

    self.systematic_noise_next = kwargs.get('systematic_noise', 0.)
    self.systematic_cphase_noise_next = kwargs.get('systematic_cphase_noise', 0.)
    self.weighting_next = kwargs.get('weighting', 'natural')

    # Maximal/minimal closure set
    self.maxset_next = kwargs.get('maxset', False)

    # Clipping
    self.clipfloor_next = kwargs.get('clipfloor', 0.)
    self.maxit_next = kwargs.get('maxit', MAXIT)
    self.stop_next = kwargs.get('stop', STOP)
    self.transform_next = kwargs.get('transform', ['log', 'mcv'])
    # flatten so that a single transform string and a list are handled alike
    self.transform_next = np.array([self.transform_next]).flatten()

    # Normalize or not?
    self.norm_init = kwargs.get('norm_init', True)
    self.norm_reg = kwargs.get('norm_reg', False)
    self.beam_size = self.obslist_next[0].res()
    self.regparams = {k: kwargs.get(k, REGPARAMS_DEFAULT[k]) for k in REGPARAMS_DEFAULT.keys()}

    self.chisq_transform = False
    self.chisq_offset_gradient = 0.0

    # FFT parameters
    self._ttype = kwargs.get('ttype', 'nfft')
    self._fft_gridder_prad = kwargs.get('fft_gridder_prad', GRIDDER_P_RAD_DEFAULT)
    self._fft_conv_func = kwargs.get('fft_conv_func', GRIDDER_CONV_FUNC_DEFAULT)
    self._fft_pad_factor = kwargs.get('fft_pad_factor', FFT_PAD_DEFAULT)
    self._fft_interp_order = kwargs.get('fft_interp_order', FFT_INTERP_DEFAULT)

    # UV minimum for closure phases
    self.cp_uv_min = kwargs.get('cp_uv_min', False)

    # Parameters related to scattering
    self.epsilon_list_next = []
    self.scattering_model = kwargs.get('scattering_model', None)
    self._sqrtQ = None
    self._ea_ker = None
    self._ea_ker_gradient_x = None
    self._ea_ker_gradient_y = None
    self._alpha_phi_list = []
    self.alpha_phi_next = kwargs.get('alpha_phi', 1e4)

    # Imager history
    self._change_imgr_params = True
    self.nruns = 0

    # multifrequency
    self.mf_next = False
    self.reg_all_freq_mf = kwargs.get('reg_all_freq_mf', False)
    self.mf_which_solve = kwargs.get('mf_which_solve', MF_WHICH_SOLVE)

    # Set embedding matrices and prepare imager
    self.check_params()
    self.check_limits()
    self.init_imager()
@property
def obslist_next(self):
return self._obslist_next
@obslist_next.setter
def obslist_next(self, obslist):
if not isinstance(obslist, list):
raise Exception("obslist_next must be a list!")
self._obslist_next = obslist
self.freq_list = [obs.rf for obs in self.obslist_next]
#self.reffreq = self.freq_list[0] #Changed so that reffreq is determined by initial image/prior rf
self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list]
@property
def obs_next(self):
"""the next Obsdata to be used in imaging
"""
return self.obslist_next[0]
@obs_next.setter
def obs_next(self, obs):
"""the next Obsdata to be used in imaging
"""
self.obslist_next = [obs]
def make_image(self, pol=None, grads=True, mf=False, **kwargs):
    """Make an image using current imager settings.

    Args:
        pol (str): which polarization to image; if None, ``self.pol_next`` is used
        grads (bool): whether or not to use image gradients (``self.objgrad``)
        mf (bool): whether or not to do multifrequency (spectral index only for now)
        **kwargs: may override ``reg_all_freq_mf`` and ``mf_which_solve``; also
            reads ``show_updates`` and ``update_interval``, and is forwarded to
            ``plotcur``.

    Returns:
        (Image): output image
    """
    self.mf_next = mf
    self.reg_all_freq_mf = kwargs.get('reg_all_freq_mf', self.reg_all_freq_mf)
    self.mf_which_solve = kwargs.get('mf_which_solve', self.mf_which_solve)

    if pol is None:
        pol_prim = self.pol_next
    else:
        self.pol_next = pol
        pol_prim = pol

    print("==============================")
    print("Imager run %i " % (int(self.nruns)+1))

    # For polarimetric imaging, switch polrep to Stokes
    if self.pol_next in POLARIZATION_MODES:
        print("Imaging Polarization: switching to Stokes!")
        self.prior_next = self.prior_next.switch_polrep(polrep_out='stokes', pol_prim_out='I')
        self.init_next = self.init_next.switch_polrep(polrep_out='stokes', pol_prim_out='I')
        pol_prim = 'I'

    # Checks and initialize
    self.check_params()
    self.check_limits()
    self.init_imager()

    # Print initial stats
    self._nit = 0
    self._show_updates = kwargs.get('show_updates', True)
    self._update_interval = kwargs.get('update_interval', 1)

    # Plot initial image
    self.plotcur(self._xinit, **kwargs)

    # Minimize
    optdict = {'maxiter': self.maxit_next,
               'ftol': self.stop_next, 'gtol': self.stop_next,
               'maxcor': NHIST, 'maxls': MAXLS}

    def callback_func(xcur):
        self.plotcur(xcur, **kwargs)

    print("Imaging . . .")
    tstart = time.time()
    # jac=None makes scipy estimate the gradient numerically
    jac = self.objgrad if grads else None
    res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B', jac=jac,
                       options=optdict, callback=callback_func)
    tstop = time.time()

    # Format output
    out = res.x[:]
    self.tmpout = res.x
    if self.pol_next in POLARIZATION_MODES:  # polarization
        # Unpack with the same solve mask init_imager used to pack self._xinit.
        out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, self._pol_which_solve)
        if 'mcv' in self.transform_next:
            out = polutils.mcv(out)
        # init_imager log-transforms Stokes I only when I is being solved for
        if 'log' in self.transform_next and 'I' in self.pol_next:
            out[0] = np.exp(out[0])
    elif self.mf_next:  # multi-frequency
        out = mfutils.unpack_mftuple(out, self._xtuple, self._nimage, self.mf_which_solve)
        if 'log' in self.transform_next:
            out[0] = np.exp(out[0])
    elif 'log' in self.transform_next:  # simple single-frequency
        out = np.exp(out)

    # Print final stats
    outstr = ""
    chi2_term_dict = self.make_chisq_dict(out)
    nobs = len(self.obslist_next)
    for dname in sorted(self.dat_term_next.keys()):
        for i in range(nobs):
            # with several observations, chi^2 keys carry an observation index
            dname_key = dname if nobs == 1 else dname + ('_%i' % i)
            outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key])

    print("time: %f s" % (tstop - tstart))
    print("J: %f" % res.fun)
    print(outstr)
    # Depending on the scipy version, res.message may be str or bytes.
    message = res.message
    if isinstance(message, bytes):
        message = message.decode(errors='replace')
    print(message)
    print("==============================")

    # Embed image
    if self.pol_next in POLARIZATION_MODES:  # polarization
        if np.any(np.invert(self._embed_mask)):
            out = polutils.embed_pol(out, self._embed_mask)
        iimage_out = out[0]
        qimage_out = polutils.make_q_image(out, POL_TRANS)
        uimage_out = polutils.make_u_image(out, POL_TRANS)
        vimage_out = polutils.make_v_image(out, POL_TRANS)
    elif self.mf_next:  # multi-frequency
        if np.any(np.invert(self._embed_mask)):
            out = mfutils.embed_mf(out, self._embed_mask)
        iimage_out = out[0]
        specind_out = out[1]
        curv_out = out[2]
    else:  # simple single-pol
        if np.any(np.invert(self._embed_mask)):
            out = imutils.embed(out, self._embed_mask)
        iimage_out = out

    # Return image
    arglist, argdict = self.prior_next.image_args()
    arglist[0] = iimage_out.reshape(self.prior_next.ydim, self.prior_next.xdim)
    argdict['pol_prim'] = pol_prim
    outim = ehtim.image.Image(*arglist, **argdict)

    # Copy over other polarizations
    for pol2 in list(outim._imdict.keys()):
        # Is it the base image?
        if pol2 == outim.pol_prim:
            continue

        # Did we solve for polarimetric image or are we copying over old pols?
        if self.pol_next in POLARIZATION_MODES and pol2 == 'Q':
            polvec = qimage_out
        elif self.pol_next in POLARIZATION_MODES and pol2 == 'U':
            polvec = uimage_out
        elif self.pol_next in POLARIZATION_MODES and pol2 == 'V':
            polvec = vimage_out
        else:
            polvec = self.init_next._imdict[pol2]

        if len(polvec):
            polarr = polvec.reshape(outim.ydim, outim.xdim)
            outim.add_pol_image(polarr, pol2)

    # Copy over spectral index information
    outim._mflist = copy.deepcopy(self.init_next._mflist)
    if self.mf_next:
        outim._mflist[0] = specind_out
        outim._mflist[1] = curv_out

    # Append to history; log the polarization actually imaged (pol may be None)
    logstr = str(self.nruns) + ": make_image(pol=%s)" % self.pol_next
    self._append_image_history(outim, logstr)
    self.nruns += 1

    # Return Image object
    return outim
def converge(self, niter, blur_frac, pol, grads=True, **kwargs):
blur = blur_frac * self.obs_next.res()
for repeat in range(niter-1):
init = self.out_last()
init = init.blur_circ(blur, blur)
self.init_next = init
self.make_image(pol=pol, grads=grads, **kwargs)
[docs] def make_image_I(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I image using current imager settings.
"""
pol = 'I'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_P(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes P polarimetric image using current imager settings.
"""
pol = 'P'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_IP(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I and P polarimetric image simultaneously using current imager settings.
"""
pol = 'IP'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_V(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I image using current imager settings.
"""
pol = 'V'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def make_image_IV(self, grads=True, niter=1, blur_frac=1, **kwargs):
"""Make Stokes I image using current imager settings.
"""
pol = 'IV'
self.make_image(pol=pol, grads=grads, **kwargs)
self.converge(niter, blur_frac, pol, grads, **kwargs)
return self.out_last()
[docs] def set_embed(self):
"""Set embedding matrix.
"""
self._embed_mask = self.prior_next.imvec > self.clipfloor_next
if not np.any(self._embed_mask):
raise Exception("clipfloor_next too large: all prior pixels have been clipped!")
xmax = self.prior_next.xdim//2
ymax = self.prior_next.ydim//2
if self.prior_next.xdim % 2: xmin=-xmax-1
else: xmin=-xmax
if self.prior_next.ydim % 2: ymin=-ymax-1
else: ymin=-ymax
coord = np.array([[[x, y]
for x in np.arange(xmax, xmin, -1)]
for y in np.arange(ymax, ymin, -1)])
coord = coord.reshape(self.prior_next.ydim * self.prior_next.xdim, 2)
coord = coord * self.prior_next.psize
self._coord_matrix = coord[self._embed_mask]
return
def check_params(self):
    """Check parameter consistency and flag whether imager data must be recomputed.

    Raises:
        Exception: if the initial and prior images are incompatible, if the
            requested polarization / transform / ttype combination is not
            supported, if multifrequency imaging lacks two distinct frequencies,
            or if the data or regularizer terms are missing or invalid.

    On runs after the first, also sets ``self._change_imgr_params = True``
    (printing the reason) when a setting differs from the previous run.
    """
    prior = self.prior_next
    init = self.init_next

    if (prior.psize != init.psize or
            prior.xdim != init.xdim or
            prior.ydim != init.ydim):
        raise Exception("Initial image does not match dimensions of the prior image!")
    if prior.rf != init.rf:
        raise Exception("Initial image does not have same frequency as prior image!")
    if prior.polrep != init.polrep:
        raise Exception(
            "Initial image polrep does not match prior polrep!")
    if prior.polrep == 'circ' and self.pol_next not in ['RR', 'LL']:
        raise Exception("Initial image polrep is 'circ': pol_next must be 'RR' or 'LL'")
    if (prior.polrep == 'stokes' and
            self.pol_next not in ['I', 'Q', 'U', 'V', 'P', 'IP', 'IQU', 'IV', 'IQUV']):
        raise Exception(
            "Initial image polrep is 'stokes': pol_next must be in 'I', 'Q', 'U', 'V', 'P','IP','IQU','IV','IQUV'!")

    # TODO single-polarization imaging. should we still support?
    if 'log' in self.transform_next and self.pol_next in ['Q', 'U', 'V']:
        raise Exception("Cannot image Stokes Q, U, V with log image transformation!")
    if (self.pol_next in ['Q', 'U', 'V'] and
            ('gs' in self.reg_term_next.keys() or 'simple' in self.reg_term_next.keys())):
        raise Exception(
            "'simple' and 'gs' methods do not work with Stokes Q, U, or V images!")

    if self._ttype not in ['fast', 'direct', 'nfft']:
        raise Exception("Possible ttype values are 'fast', 'direct','nfft'!")

    # Catch errors in multifrequency imaging setup
    if self.mf_next and len(set(self.freq_list)) < 2:
        raise Exception(
            "must have observations at at least two frequencies for multifrequency imaging!")

    # Catch errors for polarimetric imaging setup and pick the allowed term lists
    if self.pol_next in POLARIZATION_MODES:
        if 'mcv' not in self.transform_next:
            raise Exception("Polarimetric imaging needs 'mcv' transform!")
        if self._ttype not in ["direct", "nfft"]:
            raise Exception("FFT not yet implemented in polarimetric imaging -- use NFFT!")
        if 'I' in self.pol_next:
            rlist = REGULARIZERS + REGULARIZERS_POL
            dlist = DATATERMS + DATATERMS_POL
        else:
            rlist = REGULARIZERS_POL
            dlist = DATATERMS_POL
    else:
        rlist = REGULARIZERS + REGULARIZERS_SPECIND + REGULARIZERS_CURV
        dlist = DATATERMS

    # Catch errors in general imaging setup
    dat_terms = sorted(self.dat_term_next.keys())
    dt_here = any(term is not None and term is not False for term in dat_terms)
    dt_type = all(term in dlist or term is False for term in dat_terms)

    reg_terms = sorted(self.reg_term_next.keys())
    st_here = any(term is not None and term is not False for term in reg_terms)
    st_type = all(term in rlist or term is False for term in reg_terms)

    if not dt_here:
        raise Exception("Must have at least one data term!")
    if not st_here:
        raise Exception("Must have at least one regularizer term!")
    if not dt_type:
        raise Exception("Invalid data term: valid data terms are: " + ','.join(dlist))
    if not st_type:
        raise Exception("Invalid regularizer: valid regularizers are: " + ','.join(rlist))

    # Determine if we need to recompute the saved imager parameters on the next run.
    # The first detected change sets the flag and returns.
    if self.nruns == 0:
        return

    if self.pol_next != self.pol_last():
        print("changed polarization!")
        self._change_imgr_params = True
        return

    if self.obslist_next != self.obslist_last():
        print("changed observation!")
        self._change_imgr_params = True
        return

    if len(self.reg_term_next) != len(self.reg_terms_last()):
        print("changed number of regularizer terms!")
        self._change_imgr_params = True
        return

    if len(self.dat_term_next) != len(self.dat_terms_last()):
        print("changed number of data terms!")
        self._change_imgr_params = True
        return

    for term in dat_terms:
        if term not in self.dat_terms_last().keys():
            print("added %s to data terms" % term)
            self._change_imgr_params = True
            return

    for term in reg_terms:
        if term not in self.reg_terms_last().keys():
            print("added %s to regularizers!" % term)
            self._change_imgr_params = True
            return

    last_prior = self.prior_last()
    if (prior.psize != last_prior.psize or
            prior.xdim != last_prior.xdim or
            prior.ydim != last_prior.ydim):
        print("changed prior dimensions!")
        self._change_imgr_params = True
        return

    if self.debias_next != self.debias_last():
        print("changed debiasing!")
        self._change_imgr_params = True
        return

    if self.snrcut_next != self.snrcut_last():
        print("changed snrcut!")
        self._change_imgr_params = True
        return

    if self.weighting_next != self.weighting_last():
        print("changed data weighting!")
        self._change_imgr_params = True
        return

    if self.systematic_noise_next != self.systematic_noise_last():
        print("changed systematic noise!")
        self._change_imgr_params = True
        return

    if self.systematic_cphase_noise_next != self.systematic_cphase_noise_last():
        print("changed systematic cphase noise!")
        self._change_imgr_params = True
        return
[docs] def check_limits(self):
"""Check image parameter consistency with observation.
"""
uvmax = 1.0/self.prior_next.psize
uvmin = 1.0/(self.prior_next.psize*np.max((self.prior_next.xdim, self.prior_next.ydim)))
uvdists = self.obs_next.unpack('uvdist')['uvdist']
maxbl = np.max(uvdists)
minbl = np.max(uvdists[uvdists > 0])
if uvmax < maxbl:
print("Warning! Pixel size is larger than smallest spatial wavelength!")
if uvmin > minbl:
print("Warning! Field of View is smaller than largest nonzero spatial wavelength!")
if self.pol_next in ['I', 'RR', 'LL']:
maxamp = np.max(np.abs(self.obs_next.unpack('amp')['amp']))
if self.flux_next > 1.2*maxamp:
print("Warning! Specified flux is > 120% of maximum visibility amplitude!")
if self.flux_next < .8*maxamp:
print("Warning! Specified flux is < 80% of maximum visibility amplitude!")
[docs] def reg_terms_last(self):
"""Return last used regularizer terms.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._reg_term_list[-1]
[docs] def dat_terms_last(self):
"""Return last used data terms.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._dat_term_list[-1]
[docs] def obslist_last(self):
"""Return last used observation.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._obs_list[-1]
[docs] def obs_last(self):
"""Return last used observation.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._obs_list[-1][0]
[docs] def prior_last(self):
"""Return last used prior image.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._prior_list[-1]
[docs] def out_last(self):
"""Return last result.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._out_list[-1]
[docs] def out_scattered_last(self):
"""Return last result with scattering.
"""
if self.nruns == 0 or len(self._out_list_scattered) == 0:
print("No stochastic optics imager runs yet!")
return
return self._out_list_scattered[-1]
[docs] def out_epsilon_last(self):
"""Return last result with scattering.
"""
if self.nruns == 0 or len(self._out_list_epsilon) == 0:
print("No stochastic optics imager runs yet!")
return
return self._out_list_epsilon[-1]
[docs] def init_last(self):
"""Return last initial image.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._init_list[-1]
[docs] def flux_last(self):
"""Return last total flux constraint.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._flux_list[-1]
[docs] def pflux_last(self):
"""Return last total linear polarimetric flux constraint.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._pflux_list[-1]
[docs] def vflux_last(self):
"""Return last total circular polarimetric flux constraint.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._vflux_list[-1]
[docs] def clipfloor_last(self):
"""Return last clip floor.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._clipfloor_list[-1]
[docs] def pol_last(self):
"""Return last polarization imaged.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._pol_list[-1]
[docs] def maxit_last(self):
"""Return last max_iterations value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._maxit_list[-1]
[docs] def debias_last(self):
"""Return last debias value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._debias_list[-1]
[docs] def snrcut_last(self):
"""Return last snrcut value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._snrcut_list[-1]
[docs] def weighting_last(self):
"""Return last weighting value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._weighting_list[-1]
[docs] def systematic_noise_last(self):
"""Return last systematic_noise value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._systematic_noise_list[-1]
[docs] def systematic_cphase_noise_last(self):
"""Return last closure phase systematic noise value (in degree).
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._systematic_cphase_noise_list[-1]
[docs] def stop_last(self):
"""Return last convergence value.
"""
if self.nruns == 0:
print("No imager runs yet!")
return
return self._stop_list[-1]
def init_imager(self):
    """Set up the imager state before a minimization run.

    Calls ``self.set_embed()``, then builds the prior and initial image
    vectors for the current imaging mode:

    * polarimetric (``pol_next`` in POLARIZATION_MODES): Stokes I, fractional
      linear polarization m = |Q + iU| / I, angle chi = arctan2(U, Q) / 2 and,
      if 'V' is solved for, V / I; mapped through the 'mcv' transform and
      packed with ``polutils.pack_poltuple``;
    * multifrequency (``mf_next``): reference-frequency image, spectral index
      and spectral curvature maps, packed with ``mfutils.pack_mftuple``;
    * otherwise a single image vector (single Stokes or RR/LL).

    The starting solver vector is stored in ``self._xinit``. If
    ``self._change_imgr_params`` is set, the data-term tuples in
    ``self._data_tuples`` are (re)computed and the flag is cleared.

    Raises:
        Exception: if polarimetric imaging is requested without the 'mcv'
            transform, if ``pol_next`` has no known solve pattern, if a
            Stokes-I data term is used with a polarimetric mode lacking 'I',
            or if a data term name is not recognized.
    """
    # Set embedding
    self.set_embed()

    # Set prior & initial image vectors for polarimetric imaging
    if self.pol_next in POLARIZATION_MODES:

        # initial I image
        if self.norm_init and ('I' in self.pol_next):
            self._nprior = (self.flux_next * self.prior_next.imvec /
                            np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask]
            iinit = (self.flux_next * self.init_next.imvec /
                     np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask]
        else:
            self._nprior = self.prior_next.imvec[self._embed_mask]
            iinit = self.init_next.imvec[self._embed_mask]

        self._nimage = len(iinit)

        # Initialize fractional polarization m and angle chi
        if (len(self.init_next.qvec) and
                (np.any(self.init_next.qvec != 0) or np.any(self.init_next.uvec != 0))):
            init1 = np.abs(self.init_next.qvec + 1j*self.init_next.uvec) / self.init_next.imvec
            init1 = init1[self._embed_mask]
            init2 = (np.arctan2(self.init_next.uvec, self.init_next.qvec) / 2.0)
            init2 = init2[self._embed_mask]
        else:
            # !AC TODO get the actual zero baseline polarization fraction from the data?
            print("No polarimetric image in init_next!")
            print("--initializing with 20% pol and random orientation!")
            init1 = 0.2 * (np.ones(self._nimage) + 1e-2 * np.random.rand(self._nimage))
            init2 = np.zeros(self._nimage) + 1e-2 * np.random.rand(self._nimage)

        # Initialize v (fractional circular polarization) if solving for V
        if 'V' in self.pol_next:
            if len(self.init_next.vvec) and (np.any(self.init_next.vvec != 0)):
                init3 = self.init_next.vvec / self.init_next.imvec
                init3 = init3[self._embed_mask]
            else:
                # !AC TODO get the actual zero baseline polarization fraction from the data?
                print("No V polarimetric image in init_next!")
                print("--initializing with random vector")
                init3 = 0.01 * (np.ones(self._nimage) + 1e-2 * np.random.rand(self._nimage))
            self._inittuple = np.array((iinit, init1, init2, init3))
        else:
            self._inittuple = np.array((iinit, init1, init2))

        # Change of variables
        if 'mcv' in self.transform_next:
            self._xtuple = polutils.mcv_r(self._inittuple)
        else:
            raise Exception("Polarimetric imaging only works with mcv transform!")

        # Only apply log transformation to Stokes I if simultaneous imaging
        if ('log' in self.transform_next) and ('I' in self.pol_next):
            self._xtuple[0] = np.log(self._xtuple[0])

        # Determine which tuple components the solver varies
        if self.pol_next in ['P', 'QU']:
            self._pol_which_solve = (0, 1, 1)
        elif self.pol_next in ['IP', 'IQU']:
            self._pol_which_solve = (1, 1, 1)
        elif self.pol_next in ['V']:
            self._pol_which_solve = (0, 0, 0, 1)
        elif self.pol_next in ['IV']:
            self._pol_which_solve = (1, 0, 0, 1)
        elif self.pol_next in ['IQUV']:
            self._pol_which_solve = (1, 1, 1, 1)
        else:
            raise Exception("Do not know correct pol_which_solve for self.pol_next=%s!"%self.pol_next)

        # Pack into single vector
        self._xinit = polutils.pack_poltuple(self._xtuple, self._pol_which_solve)

    # Set prior & initial image vectors for multifrequency imaging
    elif self.mf_next:

        # Reference frequency is taken from the initial image
        self.reffreq = self.init_next.rf
        # reset logfreqratios in case reference frequency changed
        self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list]

        if self.norm_init:
            nprior_I = (self.flux_next * self.prior_next.imvec /
                        np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask]
            ninit_I = (self.flux_next * self.init_next.imvec /
                       np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask]
        else:
            nprior_I = self.prior_next.imvec[self._embed_mask]
            ninit_I = self.init_next.imvec[self._embed_mask]

        # Length of the embedded (masked) image; zero-filled spectral maps must
        # have this length so they line up with ninit_I / nprior_I.
        nimage = len(ninit_I)

        if len(self.init_next.specvec):
            ninit_a = self.init_next.specvec[self._embed_mask]
        else:
            ninit_a = np.zeros(nimage)
        if len(self.prior_next.specvec):
            nprior_a = self.prior_next.specvec[self._embed_mask]
        else:
            nprior_a = np.zeros(nimage)

        if len(self.init_next.curvvec):
            ninit_b = self.init_next.curvvec[self._embed_mask]
        else:
            ninit_b = np.zeros(nimage)
        # The prior curvature map comes from the prior image
        if len(self.prior_next.curvvec):
            nprior_b = self.prior_next.curvvec[self._embed_mask]
        else:
            nprior_b = np.zeros(nimage)

        self._nimage = nimage
        self.inittuple = np.array((ninit_I, ninit_a, ninit_b))
        self.priortuple = np.array((nprior_I, nprior_a, nprior_b))

        # Change of variables (log applies to the reference-frequency image only)
        if 'log' in self.transform_next:
            self._xtuple = np.array((np.log(ninit_I), ninit_a, ninit_b))
        else:
            self._xtuple = self.inittuple

        # Pack into single vector
        self._xinit = mfutils.pack_mftuple(self._xtuple, self.mf_which_solve)

    # Set prior & initial image vectors for single stokes or RR, LL imaging
    else:

        if self.norm_init:
            self._nprior = (self.flux_next * self.prior_next.imvec /
                            np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask]
            ninit = (self.flux_next * self.init_next.imvec /
                     np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask]
        else:
            self._nprior = self.prior_next.imvec[self._embed_mask]
            ninit = self.init_next.imvec[self._embed_mask]

        self._nimage = len(ninit)

        # Change of variables
        if 'log' in self.transform_next:
            self._xinit = np.log(ninit)
        else:
            self._xinit = ninit

    # Make data term tuples (only when imager parameters changed)
    if self._change_imgr_params:
        if self.nruns == 0:
            print("Initializing imager data products . . .")
        if self.nruns > 0:
            print("Recomputing imager data products . . .")

        self._data_tuples = {}

        # Loop over all data term types
        for dname in sorted(self.dat_term_next.keys()):

            # Loop over all observations in the list
            for i, obs in enumerate(self.obslist_next):
                # With several observations, keys get a _<index> suffix
                if len(self.obslist_next) == 1:
                    dname_key = dname
                else:
                    dname_key = dname + ('_%i' % i)

                # Polarimetric data products
                if dname in DATATERMS_POL:
                    tup = polutils.polchisqdata(obs, self.prior_next, self._embed_mask, dname,
                                                ttype=self._ttype,
                                                fft_pad_factor=self._fft_pad_factor,
                                                conv_func=self._fft_conv_func,
                                                p_rad=self._fft_gridder_prad)

                # Single polarization data products
                elif dname in DATATERMS:
                    if self.pol_next in POLARIZATION_MODES:
                        if 'I' not in self.pol_next:
                            raise Exception("cannot use dterm %s with pol=%s"%(dname,self.pol_next))
                        pol_next = 'I'
                    else:
                        pol_next = self.pol_next

                    tup = imutils.chisqdata(obs, self.prior_next, self._embed_mask, dname,
                                            pol=pol_next, maxset=self.maxset_next,
                                            debias=self.debias_next,
                                            snrcut=self.snrcut_next[dname],
                                            weighting=self.weighting_next,
                                            systematic_noise=self.systematic_noise_next,
                                            systematic_cphase_noise=self.systematic_cphase_noise_next,
                                            ttype=self._ttype, order=self._fft_interp_order,
                                            fft_pad_factor=self._fft_pad_factor,
                                            conv_func=self._fft_conv_func,
                                            p_rad=self._fft_gridder_prad,
                                            cp_uv_min=self.cp_uv_min)
                else:
                    raise Exception("data term %s not recognized!" % dname)

                self._data_tuples[dname_key] = tup

        self._change_imgr_params = False
def init_imager_scattering(self):
    """Prepare the stochastic-optics (scattering) imager.

    Creates a default ``so.ScatteringModel`` when none is set, precomputes the
    ensemble-average kernel, its sign-flipped x/y gradients and the real part
    of the sqrt(Q) power-spectrum matrix, then appends the scattering-screen
    parameters to ``self._xinit`` (zeros unless ``epsilon_list_next`` is given).
    """
    if self.scattering_model is None:
        self.scattering_model = so.ScatteringModel()
    model = self.scattering_model
    prior = self.prior_next
    npix = prior.xdim

    # Observing wavelength [cm]
    wavelength = ehc.C/self.obs_next.rf*100.0
    # Field of view at the scattering screen [cm]
    fov = prior.psize * npix * model.observer_screen_distance

    # Ensemble-average convolution kernel and its gradients
    self._ea_ker = model.Ensemble_Average_Kernel(prior, wavelength_cm=wavelength)
    ker_grad = so.Wrapped_Gradient(self._ea_ker/(fov/npix))
    self._ea_ker_gradient_x = -ker_grad[1]
    self._ea_ker_gradient_y = -ker_grad[0]

    # Power spectrum. Rotation is not implemented; the gradients
    # would need to be modified slightly to support it.
    self._sqrtQ = np.real(model.sqrtQ_Matrix(prior, t_hr=0.0))

    # Append the screen parameters to the image vector; by default the
    # screen restarts from zero on every call.
    if len(self.epsilon_list_next) == 0:
        screen_init = np.zeros(npix**2-1)
    else:
        screen_init = self.epsilon_list_next
    self._xinit = np.concatenate((self._xinit, screen_init))
def make_chisq_dict(self, imcur):
    """Return a dict mapping each data-term key to its current chi^2 value.

    Keys are the data-term names; when ``self.obslist_next`` holds more than
    one observation, each name gets an ``_<index>`` suffix for the matching
    observation.

    Raises:
        Exception: if a data term name is not recognized.
    """
    chi2_dict = {}
    for dname in sorted(self.dat_term_next.keys()):
        n_obs = len(self.obslist_next)
        for idx, _obs in enumerate(self.obslist_next):
            key = dname if n_obs == 1 else (dname + ('_%i' % idx))
            data, sigma, A = self._data_tuples[key]

            if dname in DATATERMS_POL:
                value = polutils.polchisq(imcur, A, data, sigma, dname,
                                          ttype=self._ttype, mask=self._embed_mask,
                                          pol_trans=POL_TRANS)
            elif dname in DATATERMS:
                if self.mf_next:
                    # Evaluate the multifrequency model at this observation's frequency
                    im_eval = mfutils.imvec_at_freq(imcur, self._logfreqratio_list[idx])
                elif self.pol_next in POLARIZATION_MODES:
                    # First component of the polarimetric tuple (Stokes I)
                    im_eval = imcur[0]
                else:
                    im_eval = imcur
                value = imutils.chisq(im_eval, A, data, sigma, dname,
                                      ttype=self._ttype, mask=self._embed_mask)
            else:
                raise Exception("data term %s not recognized!" % dname)

            chi2_dict[key] = value
    return chi2_dict
def make_chisqgrad_dict(self, imcur, i=0):
    """Return a dict mapping each data-term key to its current chi^2 gradient.

    Keys are the data-term names, suffixed with ``_<index>`` when
    ``self.obslist_next`` holds more than one observation. In multifrequency
    mode the image gradient is chained onto every solved quantity; in
    polarimetric mode Stokes-I data-term gradients are padded with zero
    gradients for the polarization components.

    The ``i`` argument is accepted for backward compatibility and is unused.

    Raises:
        Exception: if a data term name is not recognized.
    """
    chi2grad_dict = {}
    for dname in sorted(self.dat_term_next.keys()):
        n_obs = len(self.obslist_next)
        for idx, _obs in enumerate(self.obslist_next):
            key = dname if n_obs == 1 else (dname + ('_%i' % idx))
            data, sigma, A = self._data_tuples[key]

            if dname in DATATERMS_POL:
                grad = polutils.polchisqgrad(imcur, A, data, sigma, dname,
                                             ttype=self._ttype, mask=self._embed_mask,
                                             pol_solve=self._pol_which_solve,
                                             pol_trans=POL_TRANS)
            elif dname in DATATERMS:
                if self.mf_next:
                    lfr = self._logfreqratio_list[idx]
                    imref = imcur[0]
                    im_eval = mfutils.imvec_at_freq(imcur, lfr)
                elif self.pol_next in POLARIZATION_MODES:
                    im_eval = imcur[0]
                else:
                    im_eval = imcur

                grad = imutils.chisqgrad(im_eval, A, data, sigma, dname,
                                         ttype=self._ttype, mask=self._embed_mask)

                if self.mf_next:
                    # Chain rule onto each solved multifrequency quantity
                    grad = mfutils.mf_all_grads_chain(grad, im_eval, imref, lfr)

                if self.pol_next in POLARIZATION_MODES:
                    # Stokes-I data terms have no gradient w.r.t. the pol components
                    n_extra = 3 if 'V' in self.pol_next else 2
                    pad = tuple(np.zeros(self._nimage) for _ in range(n_extra))
                    grad = np.array((grad,) + pad)
            else:
                raise Exception("data term %s not recognized!" % dname)

            chi2grad_dict[key] = np.array(grad)
    return chi2grad_dict
def make_reg_dict(self, imcur):
    """Return a dict of the current value of every regularizer term.

    Keys are the regularizer names in ``self.reg_term_next``. In
    multifrequency imaging with ``self.reg_all_freq_mf`` set, each image
    regularizer in REGULARIZERS is instead evaluated on the image at every
    observing frequency and stored under ``<name>_<index>``.

    Raises:
        Exception: if a regularizer name is not recognized for the current
            imaging mode.
    """
    reg_dict = {}
    for regname in sorted(self.reg_term_next.keys()):

        # Polarimetric regularizer
        if regname in REGULARIZERS_POL:
            reg = polutils.polregularizer(imcur, self._embed_mask,
                                          self.flux_next, self.pflux_next, self.vflux_next,
                                          self.prior_next.xdim, self.prior_next.ydim,
                                          self.prior_next.psize, regname,
                                          norm_reg=self.norm_reg, beam_size=self.beam_size,
                                          pol_trans=POL_TRANS)

        # Multifrequency regularizers
        elif self.mf_next:

            # Image regularizer(s)
            if regname in REGULARIZERS:
                if self.reg_all_freq_mf:
                    # Regularize the image at every observing frequency
                    # TODO total fluxes not right?
                    for i in range(len(self.obslist_next)):
                        logfreqratio = self._logfreqratio_list[i]
                        imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio)
                        prior_nu = mfutils.imvec_at_freq(self.priortuple, logfreqratio)
                        reg_dict[regname + ('_%i' % i)] = imutils.regularizer(
                            imcur_nu, prior_nu, self._embed_mask,
                            self.flux_next, self.prior_next.xdim,
                            self.prior_next.ydim, self.prior_next.psize,
                            regname,
                            norm_reg=self.norm_reg, beam_size=self.beam_size,
                            **self.regparams)
                    # The per-frequency entries are already stored
                    continue

                # Normally only the reference-frequency image is regularized
                reg = imutils.regularizer(imcur[0], self.priortuple[0], self._embed_mask,
                                          self.flux_next, self.prior_next.xdim,
                                          self.prior_next.ydim, self.prior_next.psize,
                                          regname,
                                          norm_reg=self.norm_reg, beam_size=self.beam_size,
                                          **self.regparams)

            # Spectral index regularizer(s)
            elif regname in REGULARIZERS_SPECIND:
                reg = mfutils.regularizer_mf(imcur[1], self.priortuple[1], self._embed_mask,
                                             self.flux_next, self.prior_next.xdim,
                                             self.prior_next.ydim, self.prior_next.psize,
                                             regname,
                                             norm_reg=self.norm_reg, beam_size=self.beam_size,
                                             **self.regparams)

            # Curvature index regularizer(s)
            elif regname in REGULARIZERS_CURV:
                reg = mfutils.regularizer_mf(imcur[2], self.priortuple[2], self._embed_mask,
                                             self.flux_next, self.prior_next.xdim,
                                             self.prior_next.ydim, self.prior_next.psize,
                                             regname,
                                             norm_reg=self.norm_reg, beam_size=self.beam_size,
                                             **self.regparams)

            # Without this, ``reg`` would be unbound or stale from the previous term
            else:
                raise Exception("regularizer term %s not recognized!" % regname)

        # Normal, single polarization, single-frequency regularizer
        elif regname in REGULARIZERS:
            if self.pol_next in POLARIZATION_MODES:
                imcur0 = imcur[0]
            else:
                imcur0 = imcur
            reg = imutils.regularizer(imcur0, self._nprior, self._embed_mask,
                                      self.flux_next, self.prior_next.xdim,
                                      self.prior_next.ydim, self.prior_next.psize,
                                      regname,
                                      norm_reg=self.norm_reg, beam_size=self.beam_size,
                                      **self.regparams)
        else:
            raise Exception("regularizer term %s not recognized!" % regname)

        reg_dict[regname] = reg

    return reg_dict
def make_reggrad_dict(self, imcur):
    """Return a dict of the current gradient of every regularizer term.

    Keys are the regularizer names in ``self.reg_term_next``. In
    multifrequency imaging with ``self.reg_all_freq_mf`` set, each image
    regularizer in REGULARIZERS is evaluated at every observing frequency,
    chained onto the solved multifrequency quantities, and stored under
    ``<name>_<index>``. Gradients of regularizers that act on only one tuple
    component are padded with zeros for the other components.

    Raises:
        Exception: if a regularizer name is not recognized for the current
            imaging mode.
    """
    reggrad_dict = {}
    for regname in sorted(self.reg_term_next.keys()):

        # Polarimetric regularizer
        if regname in REGULARIZERS_POL:
            reg = polutils.polregularizergrad(imcur, self._embed_mask,
                                              self.flux_next, self.pflux_next, self.vflux_next,
                                              self.prior_next.xdim, self.prior_next.ydim,
                                              self.prior_next.psize, regname,
                                              norm_reg=self.norm_reg, beam_size=self.beam_size,
                                              pol_solve=self._pol_which_solve,
                                              pol_trans=POL_TRANS)

        # Multifrequency regularizer
        elif self.mf_next:

            # Image regularizer(s)
            if regname in REGULARIZERS:
                if self.reg_all_freq_mf:
                    # Regularize the image at every observing frequency
                    # TODO total fluxes not right?
                    imref = imcur[0]
                    for i in range(len(self.obslist_next)):
                        logfreqratio = self._logfreqratio_list[i]
                        imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio)
                        prior_nu = mfutils.imvec_at_freq(self.priortuple, logfreqratio)
                        grad_nu = imutils.regularizergrad(imcur_nu, prior_nu,
                                                          self._embed_mask, self.flux_next,
                                                          self.prior_next.xdim,
                                                          self.prior_next.ydim,
                                                          self.prior_next.psize, regname,
                                                          norm_reg=self.norm_reg,
                                                          beam_size=self.beam_size,
                                                          **self.regparams)
                        # Store in the dict this method returns (was an undefined name)
                        reggrad_dict[regname + ('_%i' % i)] = mfutils.mf_all_grads_chain(
                            grad_nu, imcur_nu, imref, logfreqratio)
                    # The per-frequency entries are already stored
                    continue

                # Normally only the reference-frequency image is regularized
                reg = imutils.regularizergrad(imcur[0], self.priortuple[0],
                                              self._embed_mask, self.flux_next,
                                              self.prior_next.xdim, self.prior_next.ydim,
                                              self.prior_next.psize, regname,
                                              norm_reg=self.norm_reg,
                                              beam_size=self.beam_size,
                                              **self.regparams)
                reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage)))

            # Spectral index regularizer(s)
            elif regname in REGULARIZERS_SPECIND:
                reg = mfutils.regularizergrad_mf(imcur[1], self.priortuple[1],
                                                 self._embed_mask, self.flux_next,
                                                 self.prior_next.xdim, self.prior_next.ydim,
                                                 self.prior_next.psize, regname,
                                                 norm_reg=self.norm_reg,
                                                 beam_size=self.beam_size,
                                                 **self.regparams)
                reg = np.array((np.zeros(self._nimage), reg, np.zeros(self._nimage)))

            # Curvature index regularizer(s)
            elif regname in REGULARIZERS_CURV:
                reg = mfutils.regularizergrad_mf(imcur[2], self.priortuple[2],
                                                 self._embed_mask, self.flux_next,
                                                 self.prior_next.xdim, self.prior_next.ydim,
                                                 self.prior_next.psize, regname,
                                                 norm_reg=self.norm_reg,
                                                 beam_size=self.beam_size,
                                                 **self.regparams)
                reg = np.array((np.zeros(self._nimage), np.zeros(self._nimage), reg))

            # Without this, ``reg`` would be unbound or stale from the previous term
            else:
                raise Exception("regularizer term %s not recognized!" % regname)

        # Normal, single polarization, single-frequency regularizer
        elif regname in REGULARIZERS:
            if self.pol_next in POLARIZATION_MODES:
                imcur0 = imcur[0]
            else:
                imcur0 = imcur
            reg = imutils.regularizergrad(imcur0, self._nprior, self._embed_mask, self.flux_next,
                                          self.prior_next.xdim, self.prior_next.ydim,
                                          self.prior_next.psize,
                                          regname,
                                          norm_reg=self.norm_reg, beam_size=self.beam_size,
                                          **self.regparams)
            # Stokes-I regularizers have zero gradient w.r.t. the pol components
            if self.pol_next in POLARIZATION_MODES:
                if 'V' in self.pol_next:
                    reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage),
                                    np.zeros(self._nimage)))
                else:
                    reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage)))
        else:
            raise Exception("regularizer term %s not recognized!" % regname)

        reggrad_dict[regname] = reg

    return reggrad_dict
def objfunc(self, imvec):
    """Return the current value of the objective function being minimized.

    The cost is the sum over data terms of ``hyperparameter * (chi2 - 1)``
    (or ``hyperparameter * (chi2 + 1/chi2 - 1)`` when ``self.chisq_transform``
    is set) plus the sum over regularizers of ``hyperparameter * value``.
    """
    pol_mode = self.pol_next in POLARIZATION_MODES

    # Unpack the solver vector into image component(s)
    if pol_mode:
        imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve)
    elif self.mf_next:
        imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve)
    else:
        imcur = imvec

    # Undo the change of variables
    if pol_mode and 'mcv' in self.transform_next:
        imcur = polutils.mcv(imcur)
    if 'log' in self.transform_next:
        if pol_mode or self.mf_next:
            imcur[0] = np.exp(imcur[0])
        else:
            imcur = np.exp(imcur)

    # Data terms
    datterm = 0.
    chi2_values = self.make_chisq_dict(imcur)
    for dname in sorted(self.dat_term_next.keys()):
        weight = self.dat_term_next[dname]
        n_obs = len(self.obslist_next)
        for idx, _obs in enumerate(self.obslist_next):
            key = dname if n_obs == 1 else (dname + ('_%i' % idx))
            chi2 = chi2_values[key]
            if self.chisq_transform:
                datterm += weight * (chi2 + 1./chi2 - 1.)
            else:
                datterm += weight * (chi2 - 1.)

    # Regularizer terms (one per frequency when regularizing all mf images)
    regterm = 0
    reg_values = self.make_reg_dict(imcur)
    for regname in sorted(self.reg_term_next.keys()):
        weight = self.reg_term_next[regname]
        if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS):
            keys = [regname + ('_%i' % idx) for idx in range(len(self.obslist_next))]
        else:
            keys = [regname]
        for key in keys:
            regterm += weight * reg_values[key]

    return datterm + regterm
def objgrad(self, imvec):
    """Return the gradient of the current objective function.

    Mirrors ``objfunc``: unpacks ``imvec`` into image component(s), undoes the
    'mcv'/'log' change of variables, sums hyperparameter-weighted data-term
    and regularizer gradients, applies the change-of-variables chain rule and
    repacks the result into the solver's vector layout.
    """
    # Unpack polarimetric/multifrequency vector into an array
    if self.pol_next in POLARIZATION_MODES:
        imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve)
    elif self.mf_next:
        imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve)
    else:
        imcur = imvec

    # Image change of variables
    if 'mcv' in self.transform_next:
        if self.pol_next in POLARIZATION_MODES:
            # Keep the pre-transform values for the chain rule below
            cvcur = imcur.copy()
            imcur = polutils.mcv(imcur)

    if 'log' in self.transform_next:
        if self.pol_next in POLARIZATION_MODES:
            imcur[0] = np.exp(imcur[0])
        elif self.mf_next:
            imcur[0] = np.exp(imcur[0])
        else:
            imcur = np.exp(imcur)

    # Data terms
    datterm = 0.
    chi2_term_dict = self.make_chisqgrad_dict(imcur)
    if self.chisq_transform:
        chi2_value_dict = self.make_chisq_dict(imcur)
    for dname in sorted(self.dat_term_next.keys()):
        hyperparameter = self.dat_term_next[dname]
        for i, obs in enumerate(self.obslist_next):
            if len(self.obslist_next) == 1:
                dname_key = dname
            else:
                dname_key = dname + ('_%i' % i)

            chi2_grad = chi2_term_dict[dname_key]
            if self.chisq_transform:
                # d/dchi2 of (chi2 + 1/chi2 - 1) is (1 - 1/chi2^2). Values are
                # keyed per observation, so look up with dname_key, not dname.
                chi2_val = chi2_value_dict[dname_key]
                datterm += hyperparameter * chi2_grad * (1. - 1./(chi2_val**2))
            else:
                datterm += hyperparameter * (chi2_grad + self.chisq_offset_gradient)

    # Regularizer terms
    regterm = 0
    reg_term_dict = self.make_reggrad_dict(imcur)
    for regname in sorted(self.reg_term_next.keys()):
        hyperparameter = self.reg_term_next[regname]
        # multifrequency imaging, regularize every frequency
        if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS):
            for i in range(len(self.obslist_next)):
                regname_key = regname + ('_%i' % i)
                regterm += hyperparameter * reg_term_dict[regname_key]
        # but normally just one regularizer
        else:
            regterm += hyperparameter * reg_term_dict[regname]

    # Total gradient
    grad = datterm + regterm

    # Chain rule term for change of variables
    if 'mcv' in self.transform_next:
        if self.pol_next in POLARIZATION_MODES:
            grad *= polutils.mchain(cvcur)

    if 'log' in self.transform_next:
        if self.pol_next in POLARIZATION_MODES:
            grad[0] *= imcur[0]
        elif self.mf_next:
            grad[0] *= imcur[0]
        else:
            grad *= imcur

    # Repack gradient for polarimetric imaging
    if self.pol_next in POLARIZATION_MODES:
        grad = polutils.pack_poltuple(grad, self._pol_which_solve)
    # repack gradient for multifrequency imaging
    elif self.mf_next:
        grad = mfutils.pack_mftuple(grad, self.mf_which_solve)

    return grad
def plotcur(self, imvec, **kwargs):
    """Print a progress summary and plot the current image.

    Output is produced only when ``self._show_updates`` is set and the
    iteration counter ``self._nit`` is a multiple of
    ``self._update_interval``. The counter is incremented on every call.
    Extra keyword arguments are forwarded to the plotting helper.
    """
    if self._show_updates and self._nit % self._update_interval == 0:
        pol_mode = self.pol_next in POLARIZATION_MODES

        # Unpack the solver vector into image component(s)
        if pol_mode:
            imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve)
        elif self.mf_next:
            imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve)
        else:
            imcur = imvec

        # Undo the change of variables
        if 'mcv' in self.transform_next and pol_mode:
            imcur = polutils.mcv(imcur)
        if 'log' in self.transform_next:
            if pol_mode or self.mf_next:
                imcur[0] = np.exp(imcur[0])
            else:
                imcur = np.exp(imcur)

        # Current chi^2 and regularizer values
        chi2_term_dict = self.make_chisq_dict(imcur)
        reg_term_dict = self.make_reg_dict(imcur)

        # (data-term name, dictionary key) pairs in display order
        n_obs = len(self.obslist_next)
        dat_keys = [(dname, dname if n_obs == 1 else (dname + ('_%i' % idx)))
                    for dname in sorted(self.dat_term_next.keys())
                    for idx in range(n_obs)]

        pieces = ["------------------------------------------------------------------",
                  "\n%4d | " % self._nit]
        for _dname, key in dat_keys:
            pieces.append("chi2_%s : %0.2f " % (key, chi2_term_dict[key]))
        pieces.append("\n ")
        for dname, key in dat_keys:
            pieces.append("%s : %0.1f " % (key, chi2_term_dict[key]*self.dat_term_next[dname]))
        pieces.append("\n ")
        for regname in sorted(self.reg_term_next.keys()):
            weight = self.reg_term_next[regname]
            if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS):
                reg_keys = [regname + ('_%i' % idx) for idx in range(n_obs)]
            else:
                reg_keys = [regname]
            for key in reg_keys:
                pieces.append("%s : %0.1f " % (key, reg_term_dict[key]*weight))
        outstr = "".join(pieces)

        # Embed (if masked) and plot the image
        if pol_mode:
            if np.any(np.invert(self._embed_mask)):
                imcur = polutils.embed_pol(imcur, self._embed_mask)
            polutils.plot_m(imcur, self.prior_next, self._nit, chi2_term_dict, **kwargs)
        else:
            implot = imcur[0] if self.mf_next else imcur
            if np.any(np.invert(self._embed_mask)):
                implot = imutils.embed(implot, self._embed_mask)
            imutils.plot_i(implot, self.prior_next, self._nit,
                           chi2_term_dict, pol=self.pol_next, **kwargs)

        if self._nit == 0:
            print()
        print(outstr)

    self._nit += 1
def objfunc_scattering(self, minvec):
    """Return the stochastic-optics objective for an image+screen vector.

    ``minvec`` holds the N*N unscattered image pixels (N = ``prior_next.xdim``)
    followed by the scattering-screen parameters. Data chi^2 terms and the
    'rgauss' regularizer are evaluated on the scattered image; all other
    regularizers use the unscattered image. A screen term
    ``alpha_phi_next * (chisq_epsilon - 1)`` is added.
    """
    npix = self.prior_next.xdim
    imvec = minvec[:npix**2]
    eps_list = minvec[npix**2:]
    if 'log' in self.transform_next:
        imvec = np.exp(imvec)

    prior = self.prior_next
    unscattered = ehtim.image.Image(imvec.reshape(npix, npix), prior.psize,
                                    prior.ra, prior.dec,
                                    prior.pa, rf=self.obs_next.rf,
                                    source=prior.source, mjd=prior.mjd)

    # Scatter the image through the current screen (linearized approximation)
    screen = so.MakeEpsilonScreenFromList(eps_list, npix)
    scattered = self.scattering_model.Scatter(unscattered, Epsilon_Screen=screen,
                                              ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
                                              Linearized_Approximation=True)
    scatt_im = scattered.imvec

    # Data terms, evaluated on the scattered image
    datterm = 0.
    chi2_vals = self.make_chisq_dict(scatt_im)
    for dname in sorted(self.dat_term_next.keys()):
        datterm += self.dat_term_next[dname] * (chi2_vals[dname] - 1.)

    # Regularizers: 'rgauss' uses the scattered image, the rest the unscattered one
    regterm = 0
    reg_unscatt = self.make_reg_dict(imvec)
    reg_scatt = self.make_reg_dict(scatt_im)
    for regname in sorted(self.reg_term_next.keys()):
        terms = reg_scatt if regname == 'rgauss' else reg_unscatt
        regterm += self.reg_term_next[regname] * terms[regname]

    # Screen regularization: normalized sum of squared screen parameters
    chisq_epsilon = sum(eps_list*eps_list)/((npix*npix-1.0)/2.0)
    regterm_scattering = self.alpha_phi_next * (chisq_epsilon - 1.0)

    return datterm + regterm + regterm_scattering
def objgrad_scattering(self, minvec):
    """Return the gradient of ``objfunc_scattering`` for an image+screen vector.

    ``minvec`` holds the N*N unscattered image pixels (N = ``prior_next.xdim``)
    followed by the scattering-screen parameters. Terms evaluated on the
    scattered image (the data chi^2 terms and the 'rgauss' regularizer) are
    differentiated w.r.t. the scattered image. That combined gradient is then
    propagated to both the unscattered image and the screen parameters
    through the linearized scattering model. All other regularizers depend
    only on the unscattered image.
    """
    wavelength = ehc.C/self.obs_next.rf*100.0  # Observing wavelength [cm]
    wavelengthbar = wavelength/(2.0*np.pi)  # lambda/(2pi) [cm]
    N = self.prior_next.xdim
    # Field of view, in cm, at the scattering screen
    FOV = self.prior_next.psize * N * self.scattering_model.observer_screen_distance
    rF = self.scattering_model.rF(wavelength)

    imvec = minvec[:N**2]
    EpsilonList = minvec[N**2:]
    if 'log' in self.transform_next:
        imvec = np.exp(imvec)
    IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize,
                           self.prior_next.ra, self.prior_next.dec,
                           self.prior_next.pa, rf=self.obs_next.rf,
                           source=self.prior_next.source, mjd=self.prior_next.mjd)

    # The scattered image vector
    screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
    scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen,
                                             ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
                                             Linearized_Approximation=True)
    scatt_im = scatt_im.imvec

    EA_Image = self.scattering_model.Ensemble_Average_Blur(IM, ker=self._ea_ker)
    EA_Gradient = so.Wrapped_Gradient((EA_Image.imvec/(FOV/N)).reshape(N, N))
    # The gradient signs don't actually matter, but let's make them match intuition
    # (i.e., right to left, bottom to top)
    EA_Gradient_x = -EA_Gradient[1]
    EA_Gradient_y = -EA_Gradient[0]

    Epsilon_Screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
    phi_scr = self.scattering_model.MakePhaseScreen(Epsilon_Screen, IM,
                                                    obs_frequency_Hz=self.obs_next.rf,
                                                    sqrtQ_init=self._sqrtQ)
    phi = phi_scr.imvec.reshape((N, N))
    phi_Gradient = so.Wrapped_Gradient(phi/(FOV/N))
    phi_Gradient_x = -phi_Gradient[1]
    phi_Gradient_y = -phi_Gradient[0]

    # Regularizer gradients. Terms on the unscattered image go straight into
    # regterm; 'rgauss' acts on the scattered image, so its gradient is
    # accumulated into dscatt and chained through the scattering model below.
    regterm = 0
    dscatt = np.zeros(N**2)
    reg_term_dict = self.make_reggrad_dict(imvec)
    if 'rgauss' in self.reg_term_next:
        reg_term_dict_scatt = self.make_reggrad_dict(scatt_im)
    for regname in sorted(self.reg_term_next.keys()):
        if regname == 'rgauss':
            dscatt = dscatt + self.reg_term_next[regname] * reg_term_dict_scatt[regname]
        else:
            regterm += self.reg_term_next[regname] * reg_term_dict[regname]

    # Data chi^2 gradients w.r.t. the scattered image
    chi2_term_dict = self.make_chisqgrad_dict(scatt_im)
    for dname in sorted(self.dat_term_next.keys()):
        dscatt = dscatt + self.dat_term_next[dname] * chi2_term_dict[dname]
    dscatt_2d = dscatt.reshape((N, N))

    # Chain rule: gradient w.r.t. the unscattered image
    gx = so.Wrapped_Convolve(self._ea_ker_gradient_x[::-1, ::-1], phi_Gradient_x * dscatt_2d)
    gx = (rF**2.0 * gx).flatten()
    gy = so.Wrapped_Convolve(self._ea_ker_gradient_y[::-1, ::-1], phi_Gradient_y * dscatt_2d)
    gy = (rF**2.0 * gy).flatten()
    scatt_grad_im = so.Wrapped_Convolve(self._ea_ker[::-1, ::-1], dscatt_2d).flatten() + gx + gy

    # Chain rule: gradient w.r.t. the screen parameters. Parameter order is
    # the real part (row s=0 with t>=1, then rows s>=1 with all t), followed
    # by the imaginary part in the same order.
    ell_mat, m_mat = np.indices((N, N), dtype=float)
    half = (N+1)//2
    modes = ([(0, t) for t in range(1, half)] +
             [(s, t) for s in range(1, half) for t in range(N)])
    chisq_grad_epsilon = np.zeros(N**2-1)
    i_grad = 0
    for sign, trig in ((1.0, np.cos), (-1.0, np.sin)):
        for s, t in modes:
            grad_term = (sign*wavelengthbar/FOV*self._sqrtQ[s][t] *
                         2.0*trig(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
            grad_term = so.Wrapped_Gradient(grad_term)
            grad_term_x = -grad_term[1]
            grad_term_y = -grad_term[0]
            cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y)
            chisq_grad_epsilon[i_grad] = np.sum(dscatt_2d * rF**2 * cge_term)
            i_grad += 1

    # Gradient of the chi^2 regularization term for the epsilon screen
    chisq_epsilon_grad = self.alpha_phi_next * 2.0*EpsilonList/((N*N-1)/2.0)

    # Chain rule term for the log change of variables (image part only)
    if 'log' in self.transform_next:
        regterm *= imvec
        scatt_grad_im *= imvec

    return np.concatenate(((regterm + scatt_grad_im), (chisq_grad_epsilon + chisq_epsilon_grad)))
def plotcur_scattering(self, minvec):
    """Print a progress line for the stochastic-optics imager.

    Output is produced only when ``self._show_updates`` is set and
    ``self._nit`` is a multiple of ``self._update_interval``. The iteration
    counter is incremented on every call. The line reports the unweighted
    chi^2 of each data term (on the scattered image), each regularizer value
    (evaluated as in ``objfunc_scattering``: 'rgauss' on the scattered image,
    the rest on the unscattered image), the screen chi^2 and max |epsilon|.
    """
    if self._show_updates:
        if self._nit % self._update_interval == 0:
            N = self.prior_next.xdim
            imvec = minvec[:N**2]
            EpsilonList = minvec[N**2:]
            if 'log' in self.transform_next:
                imvec = np.exp(imvec)
            IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize,
                                   self.prior_next.ra, self.prior_next.dec,
                                   self.prior_next.pa, rf=self.obs_next.rf,
                                   source=self.prior_next.source, mjd=self.prior_next.mjd)

            # The scattered image vector
            screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
            scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen,
                                                     ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
                                                     Linearized_Approximation=True)
            scatt_im = scatt_im.imvec

            # chi^2 on the scattered image; regularizers on the unscattered
            # image, except 'rgauss', which the objective applies to the
            # scattered image
            chi2_term_dict = self.make_chisq_dict(scatt_im)
            reg_term_dict = self.make_reg_dict(imvec)
            if 'rgauss' in self.reg_term_next:
                reg_term_dict['rgauss'] = self.make_reg_dict(scatt_im)['rgauss']

            # Scattering screen regularization term
            chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0)

            outstr = "i: %d " % self._nit
            for dname in sorted(self.dat_term_next.keys()):
                outstr += "%s : %0.2f " % (dname, chi2_term_dict[dname])
            for regname in sorted(self.reg_term_next.keys()):
                outstr += "%s : %0.2f " % (regname, reg_term_dict[regname])
            outstr += "Epsilon chi^2 : %0.2f " % (chisq_epsilon)
            outstr += "Max |Epsilon| : %0.2f " % (max(abs(EpsilonList)))
            print(outstr)

    self._nit += 1
def make_image_I_stochastic_optics(self, grads=True, **kwargs):
    """Reconstruct an image of total flux density using stochastic optics.

    Uses the stochastic optics scattering mitigation technique with the
    scattering model in Imager.scattering_model (if none has been specified,
    defaults to the standard model for Sgr A*). The image pixels and the
    scattering-screen parameters are solved for jointly with L-BFGS-B.

    Args:
        grads (bool): If True, pass the analytic gradient
            ``self.objgrad_scattering`` to the optimizer; otherwise scipy
            estimates gradients by finite differences.
        **kwargs: Optional display settings:
            show_updates (bool): print progress during the minimization
                (default True).
            update_interval (int): print every this many iterations
                (default 1).

    Returns:
        out (Image): The estimated *unscattered* image.

    Raises:
        Exception: if the imager's embedding mask excludes any pixel;
            embedding is not implemented for stochastic optics.
    """
    N = self.prior_next.xdim

    # Checks and initialize
    self.check_params()
    self.check_limits()
    self.init_imager()
    self.init_imager_scattering()
    self._nit = 0

    # Fail before the (potentially long) minimization rather than after it.
    if np.any(np.invert(self._embed_mask)):
        raise Exception("Embedding is not currently implemented!")

    # Print stats
    self._show_updates = kwargs.get('show_updates', True)
    self._update_interval = kwargs.get('update_interval', 1)
    self.plotcur_scattering(self._xinit)

    # Minimize. jac=None is scipy's default and selects finite differences.
    optdict = {'maxiter': self.maxit_next, 'ftol': self.stop_next, 'maxcor': NHIST}
    jac = self.objgrad_scattering if grads else None
    tstart = time.time()
    res = opt.minimize(self.objfunc_scattering, self._xinit, method='L-BFGS-B',
                       jac=jac, options=optdict,
                       callback=self.plotcur_scattering)
    tstop = time.time()

    # Format output: the first N**2 entries are the image, the rest the screen
    out = res.x[:N**2]
    if 'log' in self.transform_next:
        out = np.exp(out)
    out = imutils.embed(out, self._embed_mask)
    outim = ehtim.image.Image(out.reshape(N, N), self.prior_next.psize,
                              self.prior_next.ra, self.prior_next.dec, self.prior_next.pa,
                              rf=self.prior_next.rf, source=self.prior_next.source,
                              mjd=self.prior_next.mjd, pulse=self.prior_next.pulse)
    outep = res.x[N**2:]
    screen = so.MakeEpsilonScreenFromList(outep, N)
    outscatt = self.scattering_model.Scatter(outim,
                                             Epsilon_Screen=screen,
                                             ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
                                             Linearized_Approximation=True)

    # Preserve the prior image's complex polarization fractions. Pixels where
    # the prior has zero total intensity get zero Q/U instead of inf/nan.
    if len(self.prior_next.qvec):
        prior_ivec = self.prior_next.imvec
        scale = np.divide(out, prior_ivec, out=np.zeros(len(out)),
                          where=(prior_ivec != 0))
        qvec = self.prior_next.qvec * scale
        uvec = self.prior_next.uvec * scale
        outim.add_qu(qvec.reshape(N, N),
                     uvec.reshape(N, N))

    # Print stats
    print("time: %f s" % (tstop - tstart))
    print("J: %f" % res.fun)
    print(res.message)

    # Append to history
    logstr = str(self.nruns) + ": make_image_I_stochastic_optics()"
    self._append_image_history(outim, logstr)
    self._out_list_epsilon.append(outep)
    self._out_list_scattered.append(outscatt)
    self.nruns += 1

    # Return Image object
    return outim
def _append_image_history(self, outim, logstr):
    """Record the settings and result of one imaging run.

    Appends ``logstr`` plus a newline to ``self.logstr``, pushes the current
    value of each ``*_next`` setting onto its matching ``_*_list`` history,
    and pushes the output image onto ``self._out_list``.

    Args:
        outim: The image produced by the run (stored as-is).
        logstr (str): One-line description of the run.
    """
    self.logstr += (logstr + "\n")
    self._obs_list.append(self.obslist_next)
    self._init_list.append(self.init_next)
    self._prior_list.append(self.prior_next)
    self._debias_list.append(self.debias_next)
    self._weighting_list.append(self.weighting_next)
    self._systematic_noise_list.append(self.systematic_noise_next)
    self._systematic_cphase_noise_list.append(self.systematic_cphase_noise_next)
    self._snrcut_list.append(self.snrcut_next)
    self._flux_list.append(self.flux_next)
    self._pflux_list.append(self.pflux_next)
    self._vflux_list.append(self.vflux_next)
    self._pol_list.append(self.pol_next)
    self._clipfloor_list.append(self.clipfloor_next)
    # The max-set history previously recorded clipfloor_next (copy-paste slip).
    # NOTE(review): assumes the class defines maxset_next alongside
    # _maxset_list; records None if it does not.
    self._maxset_list.append(getattr(self, 'maxset_next', None))
    self._maxit_list.append(self.maxit_next)
    self._stop_list.append(self.stop_next)
    self._transform_list.append(self.transform_next)
    self._reg_term_list.append(self.reg_term_next)
    self._dat_term_list.append(self.dat_term_next)
    self._alpha_phi_list.append(self.alpha_phi_next)
    self._out_list.append(outim)
    return