# Source code for ehtim.plotting.comparisons

# comparisons.py
# Image Consistency Comparisons
#
#    Copyright (C) 2018 Katie Bouman
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
from __future__ import print_function

from builtins import str
from builtins import range
from builtins import object

import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)
from itertools import cycle

try:
    import networkx as nx
except ImportError:
    print("Warning: networkx not installed! Cannot use image_agreements()")


import ehtim.const_def as ehc


def image_consistency(imarr, beamparams, metric='nxcorr',
                      blursmall=True, beam_max=1.0, beam_steps=5, savepath=None):
    """Compute a pairwise comparison metric between images at several blur levels.

    For each blur level every pair (i, j) with i < j is compared with
    ``compare_images`` on a common pixel size and field of view obtained
    from ``get_psize_fov(imarr)``.

    Args:
        imarr: sequence of image objects providing ``blur_gauss`` and
            ``compare_images``.
        beamparams: beam parameters forwarded unchanged to ``blur_gauss``
            and ``compare_images``.
        metric (str): name of the metric requested from ``compare_images``.
        blursmall (bool): currently unused; kept for interface compatibility.
        beam_max (float): largest blur level; levels are
            ``np.linspace(0, beam_max, beam_steps)``.
        beam_steps (int): number of blur levels.
        savepath (str): if non-empty, directory in which the aligned images
            of each comparison are saved as ``<index>_<level>.fits``.

    Returns:
        (metric_mtx, fracsteps): ``metric_mtx`` has shape
        ``(len(imarr), len(imarr), beam_steps)``; only entries with i < j
        are filled, all others stay zero. ``fracsteps`` is the array of
        blur levels.
    """

    # get the pixel sizes and fov to compare images at
    (min_psize, max_fov) = get_psize_fov(imarr)

    nimages = len(imarr)

    # initialize metric matrix
    metric_mtx = np.zeros([nimages, nimages, beam_steps])

    # get the different fracsteps
    fracsteps = np.linspace(0, beam_max, beam_steps)

    # loop over the different beam sizes
    for fracidx, frac in enumerate(fracsteps):
        # blur every image once per level instead of once per pair;
        # a level of 0 means the images are compared unblurred
        if frac > 0:
            blurred = [im.blur_gauss(beamparams, frac) for im in imarr]
        else:
            blurred = list(imarr)

        # look at every pair of images and compute their beam convolved metrics
        for i in range(nimages):
            img1 = blurred[i]

            for j in range(i+1, nimages):
                img2 = blurred[j]

                # the images are already blurred, so no extra blur is requested here
                (error, im1_pad, im2_shift) = img1.compare_images(img2, metric=[metric],
                                                                  psize=min_psize,
                                                                  target_fov=max_fov,
                                                                  blur_frac=0.0,
                                                                  beamparams=beamparams)

                # if specified save the shifted images used for comparison
                # NOTE: the file for image i is rewritten for every partner j
                if savepath:
                    im1_pad.save_fits(savepath + '/' + str(i) + '_' + str(fracidx) + '.fits')
                    im2_shift.save_fits(savepath + '/' + str(j) + '_' + str(fracidx) + '.fits')

                # save the metric value in a matrix
                metric_mtx[i, j, fracidx] = error[0]

    return (metric_mtx, fracsteps)


def get_psize_fov(imarr):
    """Find the min pixel size and max field of view across a set of images.

    Args:
        imarr: non-empty sequence of image objects exposing ``psize``,
            ``xdim`` and ``ydim``.

    Returns:
        (min_psize, max_fov): the smallest ``psize`` of any image, and the
        largest of ``psize*xdim`` and ``psize*ydim`` over all images.

    Raises:
        ValueError: if ``imarr`` is empty.
    """
    # len() rather than truthiness so that numpy arrays of images also work
    if len(imarr) == 0:
        raise ValueError("get_psize_fov() requires at least one image")

    min_psize = np.min([im.psize for im in imarr])
    max_fov = np.max([[im.psize*im.xdim, im.psize*im.ydim] for im in imarr])

    return (min_psize, max_fov)
def image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=0.95):
    """Group images into cliques of mutually consistent images per blur level.

    At each blur level, image pairs whose metric value is at least ``cutoff``
    are connected in a graph; the cliques reported by ``nx.find_cliques`` are
    collected and one combined image is built for each clique.

    Args:
        imarr: sequence of image objects providing ``blur_gauss`` and
            ``compare_images``.
        beamparams: beam parameters forwarded to ``blur_gauss`` and
            ``compare_images``.
        metric_mtx: array indexed as ``[i, j, level]``, e.g. the matrix
            returned by ``image_consistency``.
        fracsteps: sequence of blur levels, one per slice of ``metric_mtx``.
        cutoff (float): minimum metric value for two images to be linked.

    Returns:
        (cliques_fraclevels, im_cliques_fraclevels): for every blur level, the
        list of cliques (lists of indices into ``imarr``) and the list of
        combined images, one per clique.

    Raises:
        Exception: if networkx has not been imported.
    """
    if 'networkx' not in sys.modules:
        raise Exception("networkx not installed: cannot use image_agreements()!")

    (min_psize, max_fov) = get_psize_fov(imarr)

    im_cliques_fraclevels = []
    cliques_fraclevels = []
    for fracidx, frac in enumerate(fracsteps):
        slice_metric_mtx = metric_mtx[:, :, fracidx]
        (rows, cols) = np.where(slice_metric_mtx >= cutoff)

        # make graph; iterate zip() directly since it is not indexable on Python 3
        G = nx.Graph()
        for (node1, node2) in zip(rows, cols):
            G.add_edge(node1, node2)

        # find all cliques
        cliques = list(nx.find_cliques(G))
        cliques_fraclevels.append(cliques)

        im_clique = []
        for clique in cliques:
            im_avg = imarr[clique[0]].blur_gauss(beamparams, frac)

            for member in clique[1:]:
                imcomp = imarr[member].blur_gauss(beamparams, frac)
                (_, im_avg, im2_shift) = im_avg.compare_images(imcomp, metric=['xcorr'],
                                                               psize=min_psize,
                                                               target_fov=max_fov,
                                                               blur_frac=0.0,
                                                               beamparams=beamparams)
                # NOTE(review): this is a running pairwise average, so later
                # clique members weigh more than earlier ones - confirm intended
                im_avg.imvec = (im_avg.imvec + im2_shift.imvec) / 2.0

            im_clique.append(im_avg.copy())

        im_cliques_fraclevels.append(im_clique)

    return (cliques_fraclevels, im_cliques_fraclevels)


def change_cut_off(metric_mtx, fracsteps, imarr, beamparams, cutoff=0.95, zoom=0.1, fov=1):
    """Recompute the image cliques for a new cutoff and plot them.

    Args:
        metric_mtx: array indexed as ``[i, j, level]`` of pairwise metric values.
        fracsteps: sequence of blur levels, one per slice of ``metric_mtx``.
        imarr: sequence of image objects.
        beamparams: beam parameters forwarded to ``image_agreements``.
        cutoff (float): minimum metric value for two images to be linked.
        zoom (float): zoom of the image thumbnails in the plot.
        fov (float): factor applied to each image's ``fovx()`` when regridding
            the thumbnails.
    """
    (cliques_fraclevels, im_cliques_fraclevels) = image_agreements(
        imarr, beamparams, metric_mtx, fracsteps, cutoff=cutoff)

    # only pass keywords that generate_consistency_plot accepts, and forward
    # the cutoff so that it appears in the plot title
    generate_consistency_plot(cliques_fraclevels, im_cliques_fraclevels,
                              fracsteps=fracsteps, cutoff=cutoff,
                              zoom=zoom, fov=fov)


def generate_consistency_plot(clique_fraclevels, im_clique_fraclevels, zoom=0.1, fov=1,
                              show=True, framesize=(20, 10), fracsteps=None, cutoff=None,
                              r_offset=1):
    """Plot the cliques of consistent images found at every blur level.

    Each blur level is one column and each clique one row within it. A clique
    is drawn as its combined image with the sorted member indices as text, and
    an arrow points to the first clique of the next column that contains it.

    Args:
        clique_fraclevels: per blur level, a list of cliques (lists of indices).
            Each clique list is sorted in place.
        im_clique_fraclevels: per blur level, a list of images, one per clique.
        zoom (float): zoom of the image thumbnails.
        fov (float): factor applied to each image's ``fovx()`` when regridding
            the thumbnail.
        show (bool): call ``plt.show()`` only if this is exactly ``True``.
        framesize (tuple): figure size passed to ``plt.subplots``.
        fracsteps: x tick labels, one per blur level; no labels are set if None.
        cutoff: value reported in the plot title.
        r_offset (float): horizontal offset of the thumbnails and arrows.
    """
    fig, ax = plt.subplots(figsize=framesize)
    cycol = cycle('bgrcmk')

    # the grid dimensions do not change while plotting
    lenx = len(clique_fraclevels)
    leny = max(len(li) for li in clique_fraclevels) if lenx else 0

    x_loc = []

    for c, column in enumerate(clique_fraclevels):
        # next() works on Python 2 and 3, unlike the .next() method
        colorc = next(cycol)
        x_loc.append((20./lenx)*c)
        if len(column) == 0:
            continue

        for r, row in enumerate(column):

            # adding the images
            clique_image = im_clique_fraclevels[c][r]
            sample_image = clique_image.regrid_image(fov*clique_image.fovx(), 512)
            arr_img = sample_image.imvec.reshape(sample_image.xdim, sample_image.ydim)
            imagebox = OffsetImage(arr_img, zoom=zoom, cmap='afmhot')
            imagebox.image.axes = ax

            ab = AnnotationBbox(imagebox, ((20./lenx)*c+r_offset, (20./leny)*r),
                                xycoords='data',
                                pad=0.0,
                                arrowprops=None)

            ax.add_artist(ab)

            # adding the arrows
            if c+1 != lenx:
                for a, ro in enumerate(clique_fraclevels[c+1]):
                    if set(row).issubset(ro):
                        # draw only when a containing clique exists, so the
                        # target is never undefined or left over from another row
                        px = (20./lenx)*(c+1) + r_offset
                        py = (20./leny)*a
                        xx = (20./lenx)*c + (8./lenx) + r_offset
                        yy = (20./leny)*r
                        ax.arrow(xx, yy, px - xx - (9./lenx), py - yy,
                                 head_width=0.05,
                                 head_length=0.1,
                                 color=colorc
                                 )
                        break

            # sorts the caller's clique list in place (after the subset test)
            row.sort()

            # adding the text
            txtstring = str(row)
            ax.text((20./lenx)*c, (20./leny)*(r-0.5), txtstring,
                    fontsize=10, horizontalalignment='center', color='black', zorder=1000)

    ax.set_xlim(0, 22)
    ax.set_ylim(-10, 22)
    ax.set_xticks(x_loc)
    if fracsteps is not None:
        ax.set_xticklabels(fracsteps)
    ax.set_yticks([])
    ax.set_yticklabels([])
    for side in ('right', 'top', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    ax.set_title('Blurred comparison of all images; cutoff={0}, fov (uas)={1}'.format(
        str(cutoff), str(im_clique_fraclevels[-1][-1].fovx()/ehc.RADPERUAS)))

    if show is True:
        plt.show()