# Source code for ehtim.calibrating.network_cal

# network_cal.py
# functions for network-calibration
#
#    Copyright (C) 2018 Andrew Chael
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.


from __future__ import division
from __future__ import print_function

from builtins import str
from builtins import range
from builtins import object

import numpy as np
import scipy.optimize as opt
import time
import copy
from multiprocessing import cpu_count, Pool

import ehtim.obsdata
import ehtim.parloop as parloop
from . import cal_helpers as calh
import ehtim.observing.obs_helpers as obsh
import ehtim.const_def as ehc

# Default maximum uv-distance for a baseline to be treated as a zero baseline
# (used as the default of `zbl_uvdist_max` in network_cal).
ZBLCUTOFF = 1.e7
# Iteration cap handed to scipy.optimize.minimize through its 'maxiter' option.
MAXIT = 5000

###################################################################################################
# Network-Calibration
###################################################################################################


def network_cal(obs, zbl, sites=[], zbl_uvdist_max=ZBLCUTOFF, method="amp",
                minimizer_method='BFGS', pol='I', pad_amp=0., gain_tol=.2,
                solution_interval=0.0, scan_solutions=False,
                caltable=False, processes=-1, show_solution=False, debias=True,
                msgtype='bar'):
    """Network-calibrate a dataset with zero baseline constraints.

       Args:
           obs (Obsdata): The observation to be calibrated
           zbl (float or function): constant zero baseline flux in Jy,
                                    or a function of UT hour.
           sites (list): list of sites to include in the network calibration.
                         An empty list calibrates all sites. The default list
                         is never mutated; it is only rebound locally.
           zbl_uvdist_max (float): maximum uv-distance considered a zero baseline
           method (str): chooses what to calibrate, 'amp', 'phase', or 'both'.
           minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS').
                                   Honoured both with and without multiprocessing.
           pol (str): which visibility to compute gains for
                      ('I', 'Q', 'U', 'V', 'RR', 'LL' or 'RRLL')
           pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
           gain_tol (float): gains that exceed this value will be disfavored by the prior
           solution_interval (float): solution interval in seconds;
                                      one gain is derived for each interval.
                                      If 0.0, a solution is determined for each unique time
           scan_solutions (bool): If True, determine one gain per site per scan.
                                  Supersedes solution_interval
           caltable (bool): if True, returns a Caltable instead of an Obsdata
           processes (int): number of cores to use in multiprocessing.
                            0 uses every core reported by cpu_count(); any negative
                            value disables multiprocessing. The worker count is
                            capped at the number of scans.
           show_solution (bool): if True, display the solution as it is calculated
           debias (bool): If True, debias the amplitudes
           msgtype (str): type of progress message to be printed, default is 'bar'

       Returns:
           (Obsdata): the calibrated observation, if caltable==False
           (Caltable): the derived calibration table, if caltable==True

       Raises:
           Exception: if pol is not supported, or is inconsistent with obs.polrep
    """

    # Here, RRLL means to use both RR and LL (both as proxies for Stokes I)
    # to derive a network calibration solution
    if pol not in ['I', 'Q', 'U', 'V', 'RR', 'LL', 'RRLL']:
        raise Exception("Can only network-calibrate to I, Q, U, V, RR, LL, or RRLL!")
    if pol in ['I', 'Q', 'U', 'V']:
        if obs.polrep != 'stokes':
            raise Exception("netcal pol is a stokes parameter, but obs.polrep!='stokes'")
    elif pol in ['RR', 'LL', 'RRLL']:
        if obs.polrep != 'circ':
            raise Exception("netcal pol is RR or LL or RRLL, but obs.polrep!='circ'")

    # V = model visibility, V' = measured visibility, G_i = site gain
    # G_i * conj(G_j) * V_ij  = V'_ij
    # (len() rather than truthiness: `sites` may be a numpy array)
    if len(sites) == 0:
        print("No stations specified in network cal: defaulting to calibrating all stations!")
        sites = obs.tarr['site']

    # find colocated sites
    cluster_data = calh.make_cluster_data(obs, zbl_uvdist_max)

    # split the observation into the chunks that each get one gain solution
    scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions)
    nscans = len(scans)

    # options shared by every per-scan solve, serial or parallel
    scan_kwargs = dict(polrep=obs.polrep, pol=pol, method=method,
                       minimizer_method=minimizer_method,
                       show_solution=show_solution, caltable=caltable,
                       pad_amp=pad_amp, gain_tol=gain_tol, debias=debias)

    # Make the pool for parallel processing
    pool = None
    if processes == 0:
        processes = int(cpu_count())
    if processes > 0:
        counter = parloop.Counter(initval=0, maxval=nscans)
        processes = min(processes, nscans)
        print("Using Multiprocessing with %d Processes" % processes)
        pool = Pool(processes=processes, initializer=init, initargs=(counter,))
    else:
        print("Not Using Multiprocessing")

    # loop over scans and calibrate
    tstart = time.time()
    try:
        if pool is not None:
            # Keep the results in a plain list: wrapping equal-length record
            # arrays in an object ndarray would flatten them into a 2-D array
            # of scalars and break the concatenation below.
            jobs = [(nscans, msgtype, scan, zbl, sites, cluster_data, scan_kwargs)
                    for scan in scans]
            scans_cal = pool.map(_network_cal_scan_worker, jobs)
        else:
            scans_cal = []
            for i, scan in enumerate(scans):
                obsh.prog_msg(i, nscans, msgtype=msgtype, nscan_last=i - 1)
                scans_cal.append(network_cal_scan(scan, zbl, sites, cluster_data,
                                                  **scan_kwargs))
    finally:
        # release the workers even if a scan failed; only touch the pool
        # when one was actually created
        if pool is not None:
            pool.close()
            pool.join()
    tstop = time.time()
    print("\nnetwork_cal time: %f s" % (tstop - tstart))

    if caltable:
        # merge the per-scan {site: record} dictionaries into one array per site
        allsites = obs.tarr['site']
        caldict = {k: v.reshape(1) for k, v in scans_cal[0].items()}
        for row in scans_cal[1:]:
            if len(row) == 0:
                continue
            for site in allsites:
                if site not in row:
                    continue
                dat = row[site]
                if site in caldict:
                    caldict[site] = np.append(caldict[site], dat)
                else:
                    # first appearance of this site: start a 1-element record
                    # array so every entry of caldict has the same type
                    caldict[site] = np.atleast_1d(dat)

        out = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr,
                                      source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
    else:
        # rebuild the observation around the calibrated data table
        arglist, argdict = obs.obsdata_args()
        arglist[4] = np.concatenate(scans_cal)
        out = ehtim.obsdata.Obsdata(*arglist, **argdict)

    return out


def _network_cal_scan_worker(args):
    """Calibrate one scan inside a multiprocessing worker.

       Args:
           args (tuple): (nscans, msgtype, scan, zbl, sites, cluster_data, scan_kwargs),
                         where scan_kwargs holds the keyword arguments forwarded
                         to network_cal_scan.

       Returns:
           whatever network_cal_scan returns for this scan
    """
    nscans, msgtype, scan, zbl, sites, cluster_data, scan_kwargs = args
    if nscans > 1:
        # `counter` is the module-level global installed in each worker by init()
        counter.increment()
        obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1)
    return network_cal_scan(scan, zbl, sites, cluster_data, **scan_kwargs)
def network_cal_scan(scan, zbl, sites, clustered_sites, polrep='stokes', pol='I',
                     zbl_uvidst_max=ZBLCUTOFF, method="both", minimizer_method='BFGS',
                     show_solution=False, pad_amp=0., gain_tol=.2, caltable=False,
                     debias=True):
    """Network-calibrate a scan with zero baseline constraints.

       Args:
           scan (np.recarray): data table rows of the scan to be calibrated.
                               Modified in place when caltable==False.
           zbl (float or function): constant zero baseline flux in Jy,
                                    or a function of UT hour.
           sites (list): list of sites to include in the network calibration.
                         empty list calibrates all sites
           clustered_sites (tuple): information on clustered sites,
                                    returned by make_cluster_data
           polrep (str): 'stokes' or 'circ' to specify the polarization products in scan
           pol (str): which visibility to compute gains for
           zbl_uvidst_max (float): unused here (clustering is taken from
                                   clustered_sites); the misspelled name is kept
                                   because callers pass it by keyword
           method (str): chooses what to calibrate, 'amp', 'phase', or 'both'
           minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
           show_solution (bool): if True, display the solution as it is calculated
           pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
           gain_tol (float): gains that exceed this value will be disfavored by the prior
           caltable (bool): if True, returns gain corrections instead of calibrated data
           debias (bool): If True, debias the amplitudes

       Returns:
           (np.recarray): the calibrated scan, if caltable==False
           (dict): site name -> record (time, rscale, lscale) of dtype ehc.DTCAL,
                   if caltable==True
    """

    # determine the zero-baseline flux of the scan
    if callable(zbl):
        zbl_scan = np.median(zbl(scan['time']))
    else:
        zbl_scan = zbl

    # clustered site information
    allclusters = clustered_sites[0]
    clusterkey = clustered_sites[1]
    clusterbls = clustered_sites[2]

    # all the sites in the scan
    allsites = list(set(np.hstack((scan['t1'], scan['t2']))))
    if len(sites) == 0:
        print("No stations specified in network cal: defaulting to calibrating all !")
        sites = allsites

    # only include sites that are present
    sites = [s for s in sites if s in allsites]

    # create a dictionary to keep track of gains;
    # sites that aren't network calibrated (no co-located partners) get a value of -1
    # so that they won't be network calibrated; other sites get a unique number
    tkey = {site: idx for idx, site in enumerate(sites)}
    for cluster in allclusters:
        if len(cluster) == 1:
            tkey[cluster[0]] = -1

    # restrict solved cluster visibilities to ones present in the scan
    # (this is much faster than allowing many unconstrained variables)
    clusterbls_scan = []
    for row in scan:
        pair = set([clusterkey[row['t1']], clusterkey[row['t2']]])
        if len(pair) == 2:
            clusterbls_scan.append(pair)
    clusterbls = [cluster for cluster in clusterbls if cluster in clusterbls_scan]

    # make two lists of gain keys that relate scan bl gains to solved site ones
    # (-1 means that this station does not have a gain that is being solved for)
    # and make one list of scan keys that relates scan bl visibilities to solved cluster ones
    # (-1 means it's a zero baseline!)
    g1_keys = []
    g2_keys = []
    scan_keys = []
    for row in scan:
        g1_keys.append(tkey.get(row['t1'], -1))
        g2_keys.append(tkey.get(row['t2'], -1))

        clusternum1 = clusterkey[row['t1']]
        clusternum2 = clusterkey[row['t2']]
        if clusternum1 == clusternum2:
            # sites are in the same cluster
            scan_keys.append(-1)
        else:
            # sites are not in the same cluster
            scan_keys.append(clusterbls.index(set((clusternum1, clusternum2))))

    # Index arrays are built once here so the objective function, which the
    # minimizer evaluates many times, does not re-convert the lists.
    g1_idx = np.asarray(g1_keys, dtype=int)
    g2_idx = np.asarray(g2_keys, dtype=int)
    bl_idx = np.asarray(scan_keys, dtype=int)

    # Restrict to visibilities on baselines to a site with a zero-baseline,
    # i.e. rows where at least one end has a gain that is being solved for
    vis_mask = (g1_idx != -1) | (g2_idx != -1)

    # get scan visibilities of the specified polarization
    if pol != 'RRLL':
        vis = scan[ehc.vis_poldict[pol]]
        sigma = scan[ehc.sig_poldict[pol]]
    else:
        # stack RR on top of LL; both halves share the same gains and mask
        vis = np.concatenate([scan[ehc.vis_poldict['RR']], scan[ehc.vis_poldict['LL']]])
        sigma = np.concatenate([scan[ehc.sig_poldict['RR']], scan[ehc.sig_poldict['LL']]])
        vis_mask = np.concatenate([vis_mask, vis_mask])

    if method == 'amp':
        if debias:
            vis = obsh.amp_debias(np.abs(vis), np.abs(sigma))
        else:
            vis = np.abs(vis)

    sigma_inv = 1.0 / np.sqrt(sigma**2 + (pad_amp * np.abs(vis))**2)

    # initial guesses for parameters: unit gains, and for every inter-cluster
    # baseline a measured visibility on that baseline (if one is not NaN)
    n_gains = len(sites)
    n_clusterbls = len(clusterbls)
    if show_solution:
        print('%d Gains; %d Clusters' % (n_gains, n_clusterbls))
    gpar_guess = np.ones(n_gains, dtype=np.complex128).view(dtype=np.float64)
    vpar_guess = np.ones(n_clusterbls, dtype=np.complex128)
    for bl_key, vis_i in zip(scan_keys, vis):
        if bl_key >= 0 and not np.isnan(vis_i):
            vpar_guess[bl_key] = vis_i
    vpar_guess = vpar_guess.view(dtype=np.float64)
    gvpar_guess = np.hstack((gpar_guess, vpar_guess))

    # error function
    def errfunc(gvpar):
        # all the forward site gains (complex)
        g = gvpar[0:2 * n_gains].astype(np.float64).view(dtype=np.complex128)
        # all the intercluster visibilities (complex)
        v = gvpar[2 * n_gains:].astype(np.float64).view(dtype=np.complex128)

        # choose to only scale amplitudes or phases
        if method == "phase":
            g = g / np.abs(g)
        elif method == "amp":
            g = np.abs(np.real(g))

        # append the default values to g for missing points
        # and to v for the zero baseline points (both are reached via key -1)
        g = np.append(g, 1.)
        v = np.append(v, zbl_scan)

        # scan visibilities are either an intercluster visibility or the fixed zbl
        v_scan = v[bl_idx]
        g1 = g[g1_idx]
        g2 = g[g2_idx]
        if pol == 'RRLL':
            v_scan = np.concatenate([v_scan, v_scan])
            g1 = np.concatenate([g1, g1])
            g2 = np.concatenate([g2, g2])

        if method == 'amp':
            verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan)
        else:
            verr = vis - g1 * g2.conj() * v_scan

        chi = np.abs(verr) * sigma_inv
        chisq = np.sum((chi * chi)[np.isfinite(chi) & vis_mask])

        # prior on the gains
        g_fracerr = gain_tol
        if method == "phase":
            chisq_g = 0  # because |g| == 1 so log(|g|) = 0
        elif method == "amp":
            logg = np.log(g)
            chisq_g = np.sum(logg * logg) / (g_fracerr * g_fracerr)
        else:
            logabsg = np.log(np.abs(g))
            chisq_g = np.sum(logabsg * logabsg) / (g_fracerr * g_fracerr)

        # prior on the inter-cluster visibility amplitudes, scaled by the zbl flux
        absv = np.abs(v)
        vv = absv * absv
        chisq_v = np.sum(vv * vv) / zbl_scan**4

        return chisq + chisq_g + chisq_v

    # run the minimizer to get a solution (but only if there's at least one gain
    # to fit; .any() also covers an empty scan, where np.max would raise)
    if (g1_idx > -1).any() or (g2_idx > -1).any():
        optdict = {'maxiter': MAXIT}  # minimizer params
        res = opt.minimize(errfunc, gvpar_guess, method=minimizer_method, options=optdict)

        # get solution
        g_fit = res.x[0:2 * n_gains].view(np.complex128)
        v_fit = res.x[2 * n_gains:].view(np.complex128)

        if method == "phase":
            g_fit = g_fit / np.abs(g_fit)
        if method == "amp":
            g_fit = np.abs(np.real(g_fit))

        if show_solution:
            print(np.abs(g_fit))
            print(np.abs(v_fit))
    else:
        g_fit = []
        v_fit = []

    # trailing unit gain / zbl flux, addressed by key -1
    g_fit = np.append(g_fit, 1.)
    v_fit = np.append(v_fit, zbl_scan)

    # Derive a calibration table or apply the solution to the scan
    if caltable:
        allsites = list(set(scan['t1']).union(set(scan['t2'])))

        caldict = {}
        for site in allsites:
            # sites without a solved gain map to -1, i.e. the trailing unit gain
            site_key = tkey.get(site, -1)

            # We will *always* set the R and L gain corrections to be equal in
            # network calibration, to avoid breaking polarization consistency
            # relationships
            rscale = g_fit[site_key]**-1
            lscale = g_fit[site_key]**-1

            # Note: we may want to give two entries for the start/stop times
            # when a non-zero solution interval is used
            caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL)

        out = caldict

    else:
        g1_fit = g_fit[g1_idx]
        g2_fit = g_fit[g2_idx]
        gij_inv = (g1_fit * g2_fit.conj())**(-1)

        if polrep == 'stokes':
            # scale visibilities
            for vistype in ['vis', 'qvis', 'uvis', 'vvis']:
                scan[vistype] *= gij_inv
            # scale sigmas
            for sigtype in ['sigma', 'qsigma', 'usigma', 'vsigma']:
                scan[sigtype] *= np.abs(gij_inv)
        elif polrep == 'circ':
            # scale visibilities
            for vistype in ['rrvis', 'llvis', 'rlvis', 'lrvis']:
                scan[vistype] *= gij_inv
            # scale sigmas
            for sigtype in ['rrsigma', 'llsigma', 'rlsigma', 'lrsigma']:
                scan[sigtype] *= np.abs(gij_inv)

        out = scan

    return out
def init(x):
    """Pool initializer: store the shared progress counter in this process.

       Args:
           x: the counter object to publish as the module-level `counter`
    """
    global counter
    counter = x


def get_network_scan_cal(args):
    """Unpack a single argument sequence for get_network_scan_cal2.

       Pool.map hands each worker exactly one argument, so the per-scan
       arguments travel as one sequence and are expanded here.
    """
    return get_network_scan_cal2(*args)


def get_network_scan_cal2(i, n, scan, zbl, sites, cluster_data, polrep, pol, method,
                          pad_amp, gain_tol, caltable, show_solution, debias, msgtype,
                          minimizer_method='BFGS'):
    """Network-calibrate one scan in a worker process and report progress.

       Args:
           i (int): index of the scan (not used in the body)
           n (int): total number of scans; progress is only reported when n > 1
           scan: the scan data passed on to network_cal_scan
           zbl (float or function): zero baseline flux, passed on to network_cal_scan
           sites (list): sites to calibrate, passed on to network_cal_scan
           cluster_data (tuple): clustered-site information, passed on to network_cal_scan
           polrep (str): polarization representation of the scan
           pol (str): which visibility to compute gains for
           method (str): 'amp', 'phase', or 'both'
           pad_amp (float): fractional uncertainty added to amplitude sigmas
           gain_tol (float): gain tolerance used by the prior
           caltable (bool): if True, return gain corrections instead of calibrated data
           show_solution (bool): if True, display the solution as it is calculated
           debias (bool): if True, debias the amplitudes
           msgtype (str): type of progress message to be printed
           minimizer_method (str): Method for scipy.optimize.minimize. Optional
                                   trailing argument, so existing 15-argument
                                   callers keep the 'BFGS' default.

       Returns:
           whatever network_cal_scan returns for this scan
    """
    global counter
    if n > 1:
        # `counter` is installed in each worker process by init()
        counter.increment()
        obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1)

    return network_cal_scan(scan, zbl, sites, cluster_data,
                            polrep=polrep, pol=pol, zbl_uvidst_max=ZBLCUTOFF,
                            method=method, minimizer_method=minimizer_method,
                            caltable=caltable, show_solution=show_solution,
                            pad_amp=pad_amp, gain_tol=gain_tol, debias=debias)