# Source code for ehtim.caltable

# caltable.py
# a calibration table class
#
#    Copyright (C) 2018 Andrew Chael
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
from __future__ import print_function

from builtins import str
from builtins import range
from builtins import object

import copy
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate

import ehtim.io.save
import ehtim.io.load

import ehtim.const_def as ehc
import ehtim.observing.obs_helpers as obsh


##################################################################################################
# Caltable object
##################################################################################################


class Caltable(object):
    """A calibration table: per-site, time-dependent complex R and L gains.

    Attributes:
        source (str): The source name
        ra (float): The source Right Ascension in fractional hours
        dec (float): The source declination in fractional degrees
        mjd (int): The integer MJD of the observation
        rf (float): The observation frequency in Hz
        bw (float): The observation bandwidth in Hz
        timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC'

        tarr (numpy.recarray): The array of telescope data with datatype DTARR
        tkey (dict): A dictionary of rows in the tarr for each site name

        data (dict): keys are sites in tarr, entries are calibration data tables of type DTCAL
    """

    def __init__(self, ra, dec, rf, bw, datadict, tarr,
                 source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC'):
        """A Calibration Table.

        Args:
            ra (float): The source Right Ascension in fractional hours
            dec (float): The source declination in fractional degrees
            rf (float): The observation frequency in Hz
            bw (float): The observation bandwidth in Hz
            datadict (dict): keys are sites in tarr, entries are data tables of type DTCAL
            tarr (numpy.recarray): The array of telescope data with datatype DTARR
            source (str): The source name
            mjd (int): The integer MJD of the observation
            timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC'

        Returns:
            (Caltable): a Caltable object

        Raises:
            Exception: if timetype is not 'GMST' or 'UTC'
        """
        # Set the various parameters
        self.source = str(source)
        self.ra = float(ra)
        self.dec = float(dec)
        self.rf = float(rf)
        self.bw = float(bw)
        self.mjd = int(mjd)

        if timetype not in ['GMST', 'UTC']:
            raise Exception("timetype must be 'GMST' or 'UTC'")
        self.timetype = timetype

        # Dictionary of array indices for site names
        self.tarr = tarr
        self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}

        # Save the data (stored by reference, not copied)
        self.data = datadict

    def copy(self):
        """Copy the Caltable object.

        The per-site data tables and the telescope array are deep-copied, so
        modifying the gains of the copy (as enforce_positive does) leaves this
        table untouched.

        Returns:
            (Caltable): a copy of the Caltable object.
        """
        # `copy` below is the module imported at file level, not this method.
        new_data = copy.deepcopy(self.data)
        new_tarr = copy.deepcopy(self.tarr)
        new_caltable = Caltable(self.ra, self.dec, self.rf, self.bw, new_data, new_tarr,
                                source=self.source, mjd=self.mjd, timetype=self.timetype)
        return new_caltable

    def plot_dterms(self, sites='all', label=None, legend=True, clist=ehc.SCOLORS,
                    rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE,
                    show=True, grid=True, export_pdf=""):
        """Make a plot of the D-terms.

        Args:
            sites (list) : list of sites to plot; 'all' or an empty list plots every site in data
            label (str) : title for plot
            legend (bool) : add telescope legend or not
            clist (list) : list of colors for different stations
            rangex (list) : lower and upper x-axis limits
            rangey (list) : lower and upper y-axis limits
            markersize (float) : marker size
            show (bool) : display the plot or not
            grid (bool) : add a grid to the plot or not
            export_pdf (str) : save a pdf file to this path

        Returns:
            matplotlib.axes
        """
        # sites
        if isinstance(sites, str) and sites.lower() == 'all':
            sites = list(self.data.keys())
        if isinstance(sites, str):
            sites = [sites]
        if len(sites) == 0:
            sites = list(self.data.keys())

        # translate site names into row indices of tarr
        keys = [self.tkey[site] for site in sites]

        axes = plot_tarr_dterms(self.tarr, keys=keys, label=label, legend=legend, clist=clist,
                                rangex=rangex, rangey=rangey, markersize=markersize,
                                show=show, grid=grid, export_pdf=export_pdf)

        return axes

    def plot_gains(self, sites, gain_type='amp', pol='R', label=None,
                   ang_unit='deg', timetype=False, yscale='log', legend=True,
                   clist=ehc.SCOLORS, rangex=False, rangey=False,
                   markersize=[ehc.MARKERSIZE], show=True, grid=False,
                   axislabels=True, axis=False, export_pdf=""):
        """Plot gains on multiple sites vs time.

        Args:
            sites (list): a list of site names for which to plot gains.
                          Empty list or 'all' is all sites.
            gain_type (str): 'amp' or 'phase'
            pol (str): 'R', 'L', or 'both'
            label (str): base label for legend
            ang_unit (str): phase unit 'deg' or 'rad'
            timetype (str): 'GMST' or 'UTC'; defaults to the table's own timetype
            yscale (str): 'log' or 'lin',
            legend (bool): Plot legend if True
            clist (list): list of colors for the plot; cycled if there are more sites than colors
            rangex (list): [xmin, xmax] x-axis (time) limits
            rangey (list): [ymin, ymax] y-axis (gain) limits
            markersize (list or float): size of plot markers, one per site or a single value
            show (bool): Display the plot if true
            grid (bool): Plot gridlines if True
            axislabels (bool): Show axis labels if True
            axis (matplotlib.axes.Axes): add plot to this axis
            export_pdf (str): path to pdf file to save figure

        Returns:
            (matplotlib.axes.Axes): Axes object with the plot

        Raises:
            Exception: on an unrecognized timetype, gain_type, or pol
        """
        if timetype is False:
            timetype = self.timetype
        if timetype not in ['GMST', 'UTC', 'utc', 'gmst']:
            raise Exception("timetype should be 'GMST' or 'UTC'!")
        if gain_type not in ['amp', 'phase']:
            raise Exception("gain_type must be 'amp' or 'phase' ")
        if pol not in ['R', 'L', 'both']:
            raise Exception("pol must be 'R', 'L', or 'both'")
        timetype = timetype.upper()

        if ang_unit == 'deg':
            angle = ehc.DEGREE
        else:
            angle = 1.0

        # axis
        if axis:
            x = axis
        else:
            fig = plt.figure()
            x = fig.add_subplot(1, 1, 1)

        # sites
        if isinstance(sites, str) and sites.lower() == 'all':
            sites = sorted(list(self.data.keys()))
        if isinstance(sites, str):
            sites = [sites]
        if len(sites) == 0:
            sites = sorted(list(self.data.keys()))

        markersize = np.atleast_1d(markersize)
        if len(markersize) == 1:
            markersize = markersize[0] * np.ones(len(sites))

        # (polarization name, DTCAL field, marker) for each polarization to draw
        if pol == 'both':
            pols = [('R', 'rscale', 'o'), ('L', 'lscale', 's')]
        elif pol == 'R':
            pols = [('R', 'rscale', 'o')]
        else:
            pols = [('L', 'lscale', 'o')]

        if gain_type == 'amp':
            ylabel = r'$|G|$'
        elif ang_unit == 'deg':
            ylabel = r'arg($|G|$) ($^\circ$)'
        else:
            ylabel = r'arg($|G|$) (radian)'

        colors = itertools.cycle(clist)

        # One list per quantity: a chained `a = b = []` would make them one shared list.
        tmins, tmaxes, gmins, gmaxes = [], [], [], []
        for s, site in enumerate(sites):
            times = self.data[site]['time']
            if timetype == 'UTC' and self.timetype == 'GMST':
                times = obsh.gmst_to_utc(times, self.mjd)
            elif timetype == 'GMST' and self.timetype == 'UTC':
                times = obsh.utc_to_gmst(times, self.mjd)

            color = next(colors)
            for polname, field, marker in pols:
                gains = self.data[site][field]
                if gain_type == 'amp':
                    gains = np.abs(gains)
                else:
                    gains = np.angle(gains) / angle

                tmins.append(np.min(times))
                tmaxes.append(np.max(times))
                gmins.append(np.min(gains))
                gmaxes.append(np.max(gains))

                # Plot the data
                if label is None:
                    bllabel = str(site)
                else:
                    bllabel = label + ' ' + str(site)
                if pol == 'both':
                    bllabel += ' ' + polname

                x.plot(times, gains, color=color, marker=marker, markersize=markersize[s],
                       label=bllabel, linestyle='none')

        if not rangex:
            rangex = [np.min(tmins) - 0.2 * np.abs(np.min(tmins)),
                      np.max(tmaxes) + 0.2 * np.abs(np.max(tmaxes))]
            if np.any(np.isnan(np.array(rangex))):
                print("Warning: NaN in data x range: specifying rangex to default")
                rangex = [0, 24]
        if not rangey:
            rangey = [np.min(gmins) - 0.2 * np.abs(np.min(gmins)),
                      np.max(gmaxes) + 0.2 * np.abs(np.max(gmaxes))]
            if np.any(np.isnan(np.array(rangey))):
                print("Warning: NaN in data y range: specifying rangey to default")
                rangey = [1.e-2, 1.e2]

        # reference line at unit gain
        x.plot(np.linspace(rangex[0], rangex[1], 5), np.ones(5), 'k--')
        x.set_xlim(rangex)
        x.set_ylim(rangey)

        # labels
        if axislabels:
            x.set_xlabel(timetype + ' (hr)')
            x.set_ylabel(ylabel)
        x.set_title('Caltable gains for %s on day %s' % (self.source, self.mjd))

        if legend:
            x.legend()

        if yscale == 'log':
            x.set_yscale('log')
        if grid:
            x.grid()
        if export_pdf != "" and not axis:
            fig.savefig(export_pdf, bbox_inches='tight')
        if show:
            ehc.show_noblock()
        return x

    def enforce_positive(self, method='median', min_gain=0.9, sites=[], verbose=True):
        """Enforce that caltable gains are not low.

        This guards against sites appearing significantly more sensitive than
        estimated. For each selected site, a summary statistic of |R| and |L|
        gains is computed. If it is below min_gain, both gain curves of that
        site are divided by it.

        Args:
            method (str): 'median', 'mean', or 'min'
            min_gain (float): Site gains above this value are not modified.
            sites (list): List of sites to check and adjust. For sites=[], all sites are fixed.
            verbose (bool): If True, print corrections.

        Returns:
            (Caltable): a rescaled copy of the Caltable; this table is not modified.
                        On an unrecognized method, an unmodified copy is returned.
        """
        reducers = {'min': np.min, 'mean': np.mean, 'median': np.median}

        # copy() deep-copies the data, so the in-place division below only touches the copy
        caltab_pos = self.copy()
        if method not in reducers:
            print('Method ' + method + ' not recognized!')
            return caltab_pos
        reducer = reducers[method]

        if isinstance(sites, str):
            sites = [sites]
        if len(sites) == 0:
            sites = self.data.keys()

        for site in self.data.keys():
            if site not in sites:
                continue
            if len(self.data[site]['rscale']) == 0:
                continue

            sitemin = reducer([np.abs(self.data[site]['rscale']),
                               np.abs(self.data[site]['lscale'])])

            if sitemin < min_gain:
                if verbose:
                    print(method + ' gain for ' + site + ' is ' + str(sitemin) + '. Rescaling.')
                caltab_pos.data[site]['rscale'] /= sitemin
                caltab_pos.data[site]['lscale'] /= sitemin
            else:
                if verbose:
                    print(method + ' gain for ' + site + ' is ' +
                          str(sitemin) + '. Not adjusting.')

        return caltab_pos

    # TODO default extrapolation?
    def pad_scans(self, maxdiff=60, padtype='median'):
        """Pad data points around scans.

        Samples of each site are split into "scans" wherever consecutive times
        differ by more than maxdiff seconds. One extra sample is added maxdiff/2
        seconds before and after every scan.

        Args:
            maxdiff (float): "scan" separation length (seconds)
            padtype (str): padding type, 'endval' (first/last scan value),
                           'median' (scan median); anything else pads with 1.

        Returns:
            (Caltable): a padded caltable object; sites with empty or None tables are dropped
        """
        half_gap = maxdiff / 2. / 3600.  # hours
        outdict = {}
        for scope in list(self.data.keys()):
            table = self.data[scope]
            if table is None or len(table) == 0:
                continue
            caldata = np.array(table, dtype=ehc.DTCAL)  # np.array copies

            # Gather data into "scans"
            # TODO we could use a scan table for this as well!
            gaps = np.diff(caldata['time']) * 3600 > maxdiff
            scans = np.split(caldata, np.nonzero(gaps)[0] + 1)

            # Compute padding values and pad scans
            padded = []
            for scan in scans:
                if padtype == 'median':  # pad with median scan value
                    preR = postR = np.median(scan['rscale'])
                    preL = postL = np.median(scan['lscale'])
                elif padtype == 'endval':  # pad with endpoints
                    preR, postR = scan['rscale'][0], scan['rscale'][-1]
                    preL, postL = scan['lscale'][0], scan['lscale'][-1]
                else:  # pad with ones
                    preR = postR = preL = postL = 1.

                valspre = np.array([(scan['time'][0] - half_gap, preR, preL)], dtype=ehc.DTCAL)
                valspost = np.array([(scan['time'][-1] + half_gap, postR, postL)],
                                    dtype=ehc.DTCAL)
                padded.extend([valspre, scan, valspost])

            outdict[scope] = np.concatenate(padded)

        return Caltable(self.ra, self.dec, self.rf, self.bw, outdict, self.tarr,
                        source=self.source, mjd=self.mjd, timetype=self.timetype)

    def applycal(self, obs, interp='linear', extrapolate=None,
                 force_singlepol=False, copy_closure_tables=True):
        """Apply the calibration table to an observation.

        On baseline (t1, t2), each visibility is multiplied by G1 * conj(G2) for
        its polarization pair, using gains interpolated to the data times. Sigmas
        are multiplied by the modulus of the same factor. Sites without a table
        entry use a gain of 1.

        Args:
            obs (Obsdata): The observation with data to be calibrated
            interp (str): Interpolation method ('linear','nearest','cubic')
            extrapolate (bool): If True, points outside interpolation range will be
                                extrapolated; any other value is passed to the
                                interpolator as its fill_value.
            force_singlepol (str): If 'L' or 'R', will set opposite polarization gains
                                   equal to chosen polarization
            copy_closure_tables (bool): If True, copy the camp, logcamp and cphase
                                        tables from the input obs to the output

        Returns:
            (Obsdata): the calibrated Obsdata object

        Raises:
            Exception: if the Caltable and Obsdata telescope arrays differ
        """
        # Check lengths first: == between arrays of different lengths is not a
        # reliable elementwise comparison.
        same_array = (len(self.tarr) == len(obs.tarr) and np.all(self.tarr == obs.tarr))
        if not same_array:
            raise Exception("The telescope array in the Caltable is not the same as in the Obsdata")

        if extrapolate is True:  # extrapolate can be a tuple or numpy array
            fill_value = "extrapolate"
        else:
            fill_value = extrapolate

        obs_orig = obs.copy()  # Need to do this before switch_polrep to keep tables
        orig_polrep = obs.polrep
        obs = obs.switch_polrep('circ')

        # Build per-site interpolators of the gains as functions of MJD
        rinterp = {}
        linterp = {}
        skipsites = []
        for s in range(len(self.tarr)):
            site = self.tarr[s]['site']
            if site not in self.data:
                skipsites.append(site)
                print("No Calibration Data for %s !" % site)
                continue

            time_mjd = self.data[site]['time'] / 24.0 + self.mjd
            rinterp[site] = relaxed_interp1d(time_mjd, self.data[site]['rscale'],
                                             kind=interp, fill_value=fill_value,
                                             bounds_error=False)
            linterp[site] = relaxed_interp1d(time_mjd, self.data[site]['lscale'],
                                             kind=interp, fill_value=fill_value,
                                             bounds_error=False)

        blocks = []
        for bl_obs in obs.bllist():
            t1 = bl_obs['t1'][0]
            t2 = bl_obs['t2'][0]
            time_mjd = bl_obs['time'] / 24.0 + obs.mjd

            if t1 in skipsites:
                rscale1 = lscale1 = np.array(1.)
            else:
                rscale1 = rinterp[t1](time_mjd)
                lscale1 = linterp[t1](time_mjd)
            if t2 in skipsites:
                rscale2 = lscale2 = np.array(1.)
            else:
                rscale2 = rinterp[t2](time_mjd)
                lscale2 = linterp[t2](time_mjd)

            if force_singlepol == 'R':
                lscale1 = rscale1
                lscale2 = rscale2
            if force_singlepol == 'L':
                rscale1 = lscale1
                rscale2 = lscale2

            rrscale = rscale1 * rscale2.conj()
            llscale = lscale1 * lscale2.conj()
            rlscale = rscale1 * lscale2.conj()
            lrscale = lscale1 * rscale2.conj()

            bl_obs['rrvis'] = (bl_obs['rrvis']) * rrscale
            bl_obs['llvis'] = (bl_obs['llvis']) * llscale
            bl_obs['rlvis'] = (bl_obs['rlvis']) * rlscale
            bl_obs['lrvis'] = (bl_obs['lrvis']) * lrscale

            bl_obs['rrsigma'] = bl_obs['rrsigma'] * np.abs(rrscale)
            bl_obs['llsigma'] = bl_obs['llsigma'] * np.abs(llscale)
            bl_obs['rlsigma'] = bl_obs['rlsigma'] * np.abs(rlscale)
            bl_obs['lrsigma'] = bl_obs['lrsigma'] * np.abs(lrscale)

            blocks.append(bl_obs)

        # concatenate once rather than re-growing the table for every baseline
        datatable = np.hstack(blocks) if blocks else []

        calobs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, np.array(datatable),
                                       obs.tarr, polrep=obs.polrep, scantable=obs.scans,
                                       source=obs.source, mjd=obs.mjd,
                                       ampcal=obs.ampcal, phasecal=obs.phasecal,
                                       opacitycal=obs.opacitycal, dcal=obs.dcal,
                                       frcal=obs.frcal, timetype=obs.timetype)
        calobs = calobs.switch_polrep(orig_polrep)

        if copy_closure_tables:
            calobs.camp = obs_orig.camp
            calobs.logcamp = obs_orig.logcamp
            calobs.cphase = obs_orig.cphase

        return calobs

    def merge(self, caltablelist, interp='linear', extrapolate=1):
        """Merge the calibration table with a list of other calibration tables.

        For a site present in both tables, the gains are interpolated onto the
        union of their sample times and multiplied together. Sites present in
        only one table are carried over unchanged.

        Args:
            caltablelist (list): The list of caltables to be merged
                                 (a single Caltable is also accepted)
            interp (str): Interpolation method ('linear','nearest','cubic')
            extrapolate (bool): If True, points outside interpolation range will be
                                extrapolated; any other value is used as the fill value.

        Returns:
            (Caltable): the merged Caltable object
        """
        if extrapolate is True:  # extrapolate can be a tuple or numpy array
            fill_value = "extrapolate"
        else:
            fill_value = extrapolate

        if not hasattr(caltablelist, '__iter__'):
            caltablelist = [caltablelist]

        tarr1 = self.tarr.copy()
        tkey1 = self.tkey.copy()
        # deep copy so the merged table never shares gain arrays with its inputs
        data1 = copy.deepcopy(self.data)
        for caltable in caltablelist:

            # TODO check metadata!

            # TODO CHECK ARE THEY ALL REFERENCED TO SAME MJD???
            tarr2 = caltable.tarr.copy()
            tkey2 = caltable.tkey.copy()
            data2 = caltable.data
            sites1 = list(data1.keys())
            for site in list(data2.keys()):
                if site in sites1:  # if site in both tables
                    # merge the data by interpolating
                    time1 = data1[site]['time']
                    time2 = data2[site]['time']

                    rinterp1 = relaxed_interp1d(time1, data1[site]['rscale'], kind=interp,
                                                fill_value=fill_value, bounds_error=False)
                    linterp1 = relaxed_interp1d(time1, data1[site]['lscale'], kind=interp,
                                                fill_value=fill_value, bounds_error=False)
                    rinterp2 = relaxed_interp1d(time2, data2[site]['rscale'], kind=interp,
                                                fill_value=fill_value, bounds_error=False)
                    linterp2 = relaxed_interp1d(time2, data2[site]['lscale'], kind=interp,
                                                fill_value=fill_value, bounds_error=False)

                    times_merge = np.unique(np.hstack((time1, time2)))

                    # put the merged data back in data1
                    merged = np.empty(len(times_merge), dtype=ehc.DTCAL)
                    merged['time'] = times_merge
                    merged['rscale'] = rinterp1(times_merge) * rinterp2(times_merge)
                    merged['lscale'] = linterp1(times_merge) * linterp2(times_merge)
                    data1[site] = merged

                # sites not in both caltables
                else:
                    if site not in tkey1.keys():
                        tarr1 = np.append(tarr1, tarr2[tkey2[site]])
                        # refresh the site -> row index map after growing tarr1
                        tkey1 = {tarr1[i]['site']: i for i in range(len(tarr1))}
                    data1[site] = copy.deepcopy(data2[site])

        new_caltable = Caltable(self.ra, self.dec, self.rf, self.bw, data1, tarr1,
                                source=self.source, mjd=self.mjd, timetype=self.timetype)
        return new_caltable

    def save_txt(self, obs, datadir='.', sqrt_gains=False):
        """Saves a Caltable object to text files in the given directory.

        Args:
            obs (Obsdata): The observation object associated with the Caltable
            datadir (str): directory to save caltable in
            sqrt_gains (bool): If True, we square gains before saving.

        Returns:
            None
        """
        return save_caltable(self, obs, datadir=datadir, sqrt_gains=sqrt_gains)

    def scan_avg(self, obs, incoherent=True):
        """Average the gains across scans.

        obs.add_scans() is called, then every gain sample is assigned to the first
        scan in obs.scans whose [start, stop] interval contains it. The samples of
        each scan are averaged. The output has one entry per scan, stamped with
        the scan start time. A scan with no samples for a site gets NaN gains, so
        tables averaged over the same obs stay aligned.

        Args:
            obs (ehtim.Obsdata) : input observation (obs.add_scans() is called on it)
            incoherent (bool) : True to average gain amps, False to average amps+phase

        Returns:
            (Caltable): the averaged Caltable object, or False if the table has no sites
        """
        if len(self.data) == 0:
            return False

        obs.add_scans()
        scans = obs.scans

        datatables = {}
        for site in self.data.keys():
            times = self.data[site]['time']

            # Label each sample with the start time of the first scan containing it;
            # samples outside every scan keep their own time.
            scan_start = times.copy()
            unassigned = np.ones(len(times), dtype=bool)
            for scan in scans:
                hit = unassigned & (times >= scan[0]) & (times <= scan[1])
                scan_start[hit] = scan[0]
                unassigned &= ~hit

            gains_l = self.data[site]['lscale']
            gains_r = self.data[site]['rscale']
            # if incoherent average then average the magnitude of gains
            if incoherent:
                gains_l = np.abs(gains_l)
                gains_r = np.abs(gains_r)

            datatable = []
            for scan in scans:
                in_scan = np.array(scan_start == scan[0])
                gains_l_avg = np.mean(gains_l[in_scan])
                gains_r_avg = np.mean(gains_r[in_scan])
                datatable.append(np.array((scan[0], gains_r_avg, gains_l_avg), dtype=ehc.DTCAL))

            datatables[site] = np.array(datatable)

        return Caltable(obs.ra, obs.dec, obs.rf, obs.bw, datatables, obs.tarr,
                        source=obs.source, mjd=obs.mjd, timetype=obs.timetype)

    def invert_gains(self):
        """Replace every site's R and L gains by their reciprocals, in place.

        Returns:
            (Caltable): this same Caltable object (modified in place)
        """
        for site in self.data.keys():
            self.data[site]['rscale'] = 1 / self.data[site]['rscale']
            self.data[site]['lscale'] = 1 / self.data[site]['lscale']
        return self
def load_caltable(obs, datadir, sqrt_gains=False):
    """Load apriori Caltable object from text files in the given directory.

    For each site in the telescope array, ``<datadir>/<source>_<site>.txt`` is
    read. If that file is missing, ``<datadir>/<site>.txt`` and finally the legacy
    ``datadir + site + '.txt'`` are tried. Each row holds an MJD time followed by
    either two real gains (R, L) or four numbers (Re R, Im R, Re L, Im L). If
    ``<datadir>/array.txt`` exists, the telescope array is read from it;
    otherwise obs.tarr is used.

    Args:
        obs (Obsdata): The observation object associated with the Caltable
        datadir (str): directory to load caltable from
        sqrt_gains (bool): If True, we take the sqrt of table gains before loading.

    Returns:
        (Caltable): a caltable object, or False if no site files were found

    Raises:
        Exception: if a row does not have 3 or 5 columns
    """
    tarr = obs.tarr
    array_filename = os.path.join(datadir, 'array.txt')
    if os.path.exists(array_filename):
        tarr = ehtim.io.load.load_array_txt(array_filename).tarr

    datatables = {}
    for s in range(0, len(tarr)):
        site = tarr[s]['site']
        candidates = (os.path.join(datadir, obs.source + '_' + site + '.txt'),
                      os.path.join(datadir, site + '.txt'),
                      datadir + site + '.txt')  # legacy prefix-style lookup

        data = None
        for filename in candidates:
            try:
                data = np.loadtxt(filename, dtype=bytes).astype(str)
            except IOError:
                continue
            break
        if data is None:
            continue

        # A single-row file loads as a 1-D array; make it 2-D so rows iterate correctly.
        if data.ndim == 1 and data.size > 0:
            data = data[np.newaxis, :]

        datatable = []
        for row in data:
            time = (float(row[0]) - obs.mjd) * 24.0  # time is given in mjd

            if len(row) == 3:
                rscale = float(row[1])
                lscale = float(row[2])
            elif len(row) == 5:
                rscale = float(row[1]) + 1j * float(row[2])
                lscale = float(row[3]) + 1j * float(row[4])
            else:
                raise Exception("cannot load caltable -- format unknown!")
            if sqrt_gains:
                rscale = rscale**.5
                lscale = lscale**.5
            datatable.append(np.array((time, rscale, lscale), dtype=ehc.DTCAL))

        datatables[site] = np.array(datatable)

    if len(datatables) > 0:
        caltable = Caltable(obs.ra, obs.dec, obs.rf, obs.bw, datatables, tarr,
                            source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
    else:
        print("COULD NOT FIND CALTABLE IN DIRECTORY %s" % datadir)
        caltable = False
    return caltable
def save_caltable(caltable, obs, datadir='.', sqrt_gains=False):
    """Saves a Caltable object to text files in the given directory.

    Writes the telescope array of obs to ``<datadir>/array.txt``. Then, for every
    site in caltable.tarr with a non-empty table, writes
    ``<datadir>/<source>_<site>.txt`` with one line per sample:
    MJD, Re(R), Im(R), Re(L), Im(L).

    Args:
        caltable (Caltable): The Caltable to save
        obs (Obsdata): The observation object associated with the Caltable
        datadir (str): directory to save caltable in; created if missing
        sqrt_gains (bool): If True, we square gains before saving.

    Returns:
        None
    """
    if not os.path.exists(datadir):
        os.makedirs(datadir)

    ehtim.io.save.save_array_txt(obs.tarr, os.path.join(datadir, 'array.txt'))

    datatables = caltable.data
    src = caltable.source
    for site_info in caltable.tarr:
        site = site_info['site']
        if len(datatables.get(site, [])) == 0:
            continue

        filename = os.path.join(datadir, src + '_' + site + '.txt')
        # the with-block closes the file even if a write fails
        with open(filename, 'w') as outfile:
            for entry in datatables[site]:
                time = entry['time'] / 24.0 + obs.mjd
                if sqrt_gains:
                    rscale = np.square(entry['rscale'])
                    lscale = np.square(entry['lscale'])
                else:
                    rscale = entry['rscale']
                    lscale = entry['lscale']

                values = (time, np.real(rscale), np.imag(rscale),
                          np.real(lscale), np.imag(lscale))
                outfile.write(' '.join(str(float(v)) for v in values) + '\n')
    return
def make_caltable(obs, gains, sites, times):
    """Create a Caltable object for an observation.

    Args:
        obs (Obsdata): The observation object associated with the Caltable
        gains (list): flat list of len(sites) * len(times) gains, site-major:
                      gains[s * len(times) + t] is the gain of sites[s] at times[t].
                      Each gain is used for both R and L.
        sites (list): list of sites
        times (list): list of times

    Returns:
        (Caltable): a caltable object, or False if sites is empty
    """
    ntimes = len(times)

    datatables = {}
    for s, site in enumerate(sites):
        datatable = []
        for t in range(ntimes):
            # site-major flat index: the stride is the number of times, not of sites
            gain = gains[s * ntimes + t]
            datatable.append(np.array((times[t], gain, gain), dtype=ehc.DTCAL))
        datatables[site] = np.array(datatable)

    if len(datatables) > 0:
        caltable = Caltable(obs.ra, obs.dec, obs.rf, obs.bw, datatables, obs.tarr,
                            source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
    else:
        caltable = False
    return caltable
def relaxed_interp1d(x, y, **kwargs):
    """Build a scipy 1-D interpolator that also accepts a single sample.

    scipy.interpolate.interp1d needs at least two points. A scalar or length-1
    input is widened to two points at x[0] - 0.5 and x[0] + 0.5 with the same
    y value, so the interpolator is constant there.

    Args:
        x (float or array-like): sample abscissae
        y (float or array-like): sample values at x
        **kwargs: passed through to scipy.interpolate.interp1d

    Returns:
        (scipy.interpolate.interp1d): the interpolating function
    """
    try:
        len(x)
    except TypeError:
        # allows to run on a single float number
        x = np.asarray([x])
        y = np.asarray([y])

    if len(x) == 1:
        x = np.array([-0.5, 0.5]) + x[0]
        y = np.array([1.0, 1.0]) * y[0]

    return scipy.interpolate.interp1d(x, y, **kwargs)


def plot_tarr_dterms(tarr, keys=None, label=None, legend=True, clist=ehc.SCOLORS,
                     rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE,
                     show=True, grid=True, export_pdf="", auto_order=True):
    """Plot the right and left D-terms of stations in a telescope array.

    Args:
        tarr (numpy.recarray): telescope array with 'site', 'dr' and 'dl' fields
        keys (list): row indices of tarr to plot; None plots every row
        label (str): prefix for legend labels
        legend (bool): add telescope legend or not
        clist (list): colors for different stations; cycled if there are more stations
        rangex (list): lower and upper x-axis limits
        rangey (list): lower and upper y-axis limits
        markersize (float): marker size
        show (bool): display the plot or not
        grid (bool): add a grid to the plot or not
        export_pdf (str): save a pdf file to this path
        auto_order (bool): if True, plot the selected stations in alphabetical order

    Returns:
        (numpy.ndarray): the two matplotlib Axes (right D-terms, left D-terms)
    """
    if keys is None:
        keys = range(len(tarr))
    keys = list(keys)
    if auto_order:
        # Ensure that the plot will put the stations in alphabetical order
        keys = sorted(keys, key=lambda k: str(tarr[k]['site']))

    colors = itertools.cycle(clist)

    if export_pdf != "":
        fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True, figsize=(16, 8))
    else:
        fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True)

    for key in keys:
        # get the label
        site = str(tarr[key]['site'])
        if label is None:
            bllabel = str(site)
        else:
            bllabel = label + ' ' + str(site)
        color = next(colors)

        axes[0].plot(np.real(tarr[key]['dr']), np.imag(tarr[key]['dr']),
                     color=color, marker='o', markersize=markersize,
                     label=bllabel, linestyle='none')
        axes[1].plot(np.real(tarr[key]['dl']), np.imag(tarr[key]['dl']),
                     color=color, marker='o', markersize=markersize,
                     label=bllabel, linestyle='none')

    axes[0].set_title("Right D-terms")
    axes[0].set_xlabel("Real")
    axes[0].set_ylabel("Imaginary")
    axes[1].set_title("Left D-terms")
    axes[1].set_xlabel("Real")
    axes[1].set_ylabel("Imaginary")

    axes[0].axhline(y=0, color='k')
    axes[0].axvline(x=0, color='k')
    axes[1].axhline(y=0, color='k')
    axes[1].axvline(x=0, color='k')

    if grid:
        axes[0].grid()
        axes[1].grid()

    if rangex:
        axes[0].set_xlim(rangex)
        axes[1].set_xlim(rangex)

    if rangey:
        axes[0].set_ylim(rangey)
        axes[1].set_ylim(rangey)

    if legend:
        axes[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))

    if export_pdf != "":
        fig.savefig(export_pdf, bbox_inches='tight')

    if show:
        ehc.show_noblock()

    return axes


def plot_compare_gains(caltab1, caltab2, obs, sites='all', pol='R', gain_type='amp',
                       ang_unit='deg', scan_avg=True, site_name_dict=None, fontsize=13,
                       legend_fontsize=13, yscale='log', legend=True, clist=ehc.SCOLORS,
                       rangex=False, rangey=False, scalefac=[0.9, 1.1],
                       markersize=[2 * ehc.MARKERSIZE], show=True, grid=False,
                       axislabels=True, remove_ticks=False, axis=False, export_pdf=""):
    """Scatter-plot the gains of one Caltable against another, site by site.

    caltab1 gains go on the x-axis (labelled 'Ground Truth') and caltab2 gains on
    the y-axis (labelled 'Recovered'). A grey diagonal marks equality. Gains are
    paired sample by sample, so both tables need the same number of samples per
    site. With scan_avg=True, both are first averaged onto the scans of obs.

    Args:
        caltab1 (Caltable): reference gains (x-axis)
        caltab2 (Caltable): gains to compare (y-axis)
        obs (Obsdata): observation whose scans are used when scan_avg is True
        sites (list or str): sites to plot; 'all' or [] plots sites present in both tables
        pol (str): 'R' or 'L'
        gain_type (str): 'amp' or 'phase'
        ang_unit (str): phase unit 'deg' or 'rad'
        scan_avg (bool): if True, incoherently scan-average both tables first
        site_name_dict (dict): maps site names to legend labels; missing sites use their name
        fontsize (int): axis-label and tick-label font size
        legend_fontsize (int): legend font size
        yscale (str): 'log' puts both axes on a log scale
        legend (bool): plot legend if True
        clist (list): list of colors, cycled over sites
        rangex (list): x-axis limits; default is the data range scaled by scalefac
        rangey (list): y-axis limits; default is the data range scaled by scalefac
        scalefac (list): [low, high] multipliers applied to the data range for default limits
        markersize (list or float): marker size, one per site or a single value
        show (bool): display the plot if True
        grid (bool): plot gridlines if True
        axislabels (bool): show axis labels if True; otherwise ticks are drawn inside
        remove_ticks (bool): hide tick labels if True
        axis (matplotlib.axes.Axes): add plot to this axis
        export_pdf (str): path to pdf file to save figure

    Returns:
        (matplotlib.axes.Axes): Axes object with the plot

    Raises:
        Exception: on an unrecognized pol or gain_type
    """
    if pol not in ['R', 'L']:
        raise Exception("pol must be 'R' or 'L'")
    if gain_type not in ['amp', 'phase']:
        raise Exception("gain_type must be 'amp' or 'phase' ")

    colors = itertools.cycle(clist)

    if ang_unit == 'deg':
        angle = ehc.DEGREE
    else:
        angle = 1.0

    # axis
    if axis:
        x = axis
    else:
        fig = plt.figure()
        x = fig.add_subplot(1, 1, 1)

    if scan_avg:
        caltab1 = caltab1.scan_avg(obs, incoherent=True)
        caltab2 = caltab2.scan_avg(obs, incoherent=True)

    # sites (sorted so colors and legend order are reproducible)
    common_sites = sorted(set(caltab1.data.keys()).intersection(caltab2.data.keys()))
    if isinstance(sites, str) and sites.lower() == 'all':
        sites = common_sites
    if isinstance(sites, str):
        sites = [sites]
    if len(sites) == 0:
        sites = common_sites

    if site_name_dict is None:
        site_name_dict = {}

    markersize = np.atleast_1d(markersize)
    if len(markersize) == 1:
        markersize = markersize[0] * np.ones(len(sites))

    field = 'rscale' if pol == 'R' else 'lscale'
    if gain_type == 'amp':
        ylabel = 'Amplitudes'  # r'$|G|$'
    elif ang_unit == 'deg':
        ylabel = r'arg($|G|$) ($^\circ$)'
    else:
        ylabel = 'Phases (radian)'  # r'arg($|G|$) (radian)'

    maxgain = 0.0
    mingain = 10000
    for s, site in enumerate(sites):
        gains1 = caltab1.data[site][field]
        gains2 = caltab2.data[site][field]

        if gain_type == 'amp':
            gains1 = np.abs(gains1)
            gains2 = np.abs(gains2)
        else:
            gains1 = np.angle(gains1) / angle
            gains2 = np.angle(gains2) / angle

        maxgain = np.nanmax([maxgain, np.nanmax(gains1), np.nanmax(gains2)])
        mingain = np.nanmin([mingain, np.nanmin(gains1), np.nanmin(gains2)])

        # mark the gains on the plot
        x.plot(gains1, gains2, marker='.', linestyle='None', color=next(colors),
               markersize=markersize[s], label=site_name_dict.get(site, site))

    x.tick_params(axis='both', labelsize=fontsize)
    # set the aspect on this Axes; plt.axes() would create a new overlapping Axes
    x.set_aspect('equal')

    low = mingain * scalefac[0]
    high = maxgain * scalefac[1]
    if rangex:
        x.set_xlim(rangex)
    else:
        x.set_xlim([low, high])

    if rangey:
        x.set_ylim(rangey)
    else:
        x.set_ylim([low, high])

    # equality line
    x.plot([low, high], [low, high], 'grey', linewidth=1)

    # labels
    if axislabels:
        x.set_xlabel('Ground Truth Gain ' + ylabel, fontsize=fontsize)
        x.set_ylabel('Recovered Gain ' + ylabel, fontsize=fontsize)
    else:
        x.tick_params(axis="y", direction="in", pad=-30)
        x.tick_params(axis="x", direction="in", pad=-18)

    if remove_ticks:
        plt.setp(x.get_xticklabels(), visible=False)
        plt.setp(x.get_yticklabels(), visible=False)

    if legend:
        x.legend(frameon=False, fontsize=legend_fontsize)

    if yscale == 'log':
        x.set_yscale('log')
        x.set_xscale('log')
    if grid:
        x.grid()
    if export_pdf != "" and not axis:
        fig.savefig(export_pdf, bbox_inches='tight')
    if show:
        ehc.show_noblock()

    return x