import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import copy
from scipy.stats import binned_statistic, binned_statistic_2d
from astropy.table import Table, vstack
from beast.physicsmodel.grid import SEDGrid
import beast.observationmodel.noisemodel.generic_noisemodel as noisemodel
__all__ = ["plot_completeness"]
[docs]
def plot_completeness(
physgrid_list,
noise_model_list,
output_plot_filename,
param_list=["Av", "Rv", "logA", "f_A", "M_ini", "Z", "distance"],
compl_filter="F475W",
):
"""
Make visualization of the completeness
Parameters
----------
physgrid_list : string or list of strings
Name of the physics model file. If there are multiple physics model
grids (i.e., if there are subgrids), list them all here.
noise_model_list : string or list of strings
Name of the noise model file. If there are multiple files for
physgrid_list (because of subgrids), list the noise model file
associated with each physics model file.
param_list : list of strings
names of the parameters to plot. Two or more parameters need to be specified.
compl_filter : str
filter to use for completeness (required for toothpick model)
output_plot_filename : string
name of the file in which to save the output plot
"""
n_params = len(param_list)
# If there are subgrids, we can't read them all into memory. Therefore,
# we'll go through each one and just grab the relevant parts.
compl_table_list = []
# make a table for each physics model + noise model
for physgrid, noise_model in zip(
np.atleast_1d(physgrid_list), np.atleast_1d(noise_model_list)
):
# get the physics model grid - includes priors
modelsedgrid = SEDGrid(str(physgrid))
# get list of filters
short_filters = [
filter.split(sep="_")[-1].upper() for filter in modelsedgrid.filters
]
if compl_filter.upper() not in short_filters:
raise ValueError("requested completeness filter not present")
filter_k = short_filters.index(compl_filter.upper())
print("Completeness from {0}".format(modelsedgrid.filters[filter_k]))
# read in the noise model
noisegrid = noisemodel.get_noisemodelcat(str(noise_model))
# get the completeness
model_compl = noisegrid["completeness"]
# put it all into a table
table_dict = {x: modelsedgrid[x] for x in param_list}
table_dict["compl"] = model_compl[:, filter_k]
# append to the list
compl_table_list.append(Table(table_dict))
# stack all the tables into one
compl_table = vstack(compl_table_list)
# import pdb; pdb.set_trace()
# figure
fig = plt.figure(figsize=(4 * n_params, 4 * n_params))
# label font sizes
label_font = 25
tick_font = 22
# load in color map
cmap = matplotlib.cm.get_cmap("magma")
# iterate through the panels
for i, pi in enumerate(param_list):
for j, pj in enumerate(param_list[i:], i):
print("plotting {0} and {1}".format(pi, pj))
# not along diagonal
if i != j:
# set up subplot
plt.subplot(n_params, n_params, i + j * (n_params) + 1)
ax = plt.gca()
# create image and labels
x_col, x_bins, x_label = setup_axis(compl_table, pi)
y_col, y_bins, y_label = setup_axis(compl_table, pj)
compl_image, _, _, _ = binned_statistic_2d(
x_col,
y_col,
compl_table["compl"],
statistic="mean",
bins=(x_bins, y_bins),
)
# plot points
im = plt.imshow(
compl_image.T,
# np.random.random((4,4)),
extent=(
np.min(x_bins),
np.max(x_bins),
np.min(y_bins),
np.max(y_bins),
),
cmap="magma",
vmin=0,
vmax=1,
aspect="auto",
origin="lower",
)
ax.tick_params(
axis="both",
which="both",
direction="in",
labelsize=tick_font,
bottom=True,
top=True,
left=True,
right=True,
)
# axis labels and ticks
if i == 0:
ax.set_ylabel(y_label, fontsize=label_font)
# ax.get_yaxis().set_label_coords(-0.35,0.5)
else:
ax.set_yticklabels([])
if j == n_params - 1:
ax.set_xlabel(x_label, fontsize=label_font)
plt.xticks(rotation=-45)
else:
ax.set_xticklabels([])
# along diagonal
if i == j:
# set up subplot
plt.subplot(n_params, n_params, i + j * (n_params) + 1)
ax = plt.gca()
# create histogram and labels
x_col, x_bins, x_label = setup_axis(compl_table, pi)
y_col, y_bins, y_label = setup_axis(compl_table, pj)
compl_hist, _, _ = binned_statistic(
x_col, compl_table["compl"], statistic="mean", bins=x_bins,
)
# make histogram
_, _, patches = plt.hist(x_bins[:-1], x_bins, weights=compl_hist)
# color each bar by its completeness
for c, comp in enumerate(compl_hist):
patches[c].set_color(cmap(comp))
patches[c].set_linewidth = 0.1
# make a black outline so it stands out as a histogram
plt.hist(
x_bins[:-1], x_bins, weights=compl_hist, histtype="step", color="k"
)
# axis ranges
plt.xlim(np.min(x_bins), np.max(x_bins))
plt.ylim(0, 1.05)
ax.tick_params(axis="y", which="both", length=0, labelsize=tick_font)
ax.tick_params(
axis="x", which="both", direction="in", labelsize=tick_font
)
# axis labels and ticks
ax.set_yticklabels([])
if i < n_params - 1:
ax.set_xticklabels([])
if i == n_params - 1:
ax.set_xlabel(x_label, fontsize=label_font)
plt.xticks(rotation=-45)
if i == 0:
ax.set_ylabel(y_label, fontsize=label_font)
# plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.tight_layout()
# add a colorbar
gs = GridSpec(nrows=20, ncols=n_params)
cax = fig.add_subplot(gs[0, 1:])
cbar = plt.colorbar(im, cax=cax, orientation="horizontal")
cbar.set_label("Completeness", fontsize=label_font)
cbar.ax.tick_params(labelsize=tick_font)
gs.tight_layout(fig)
fig.savefig(output_plot_filename)
plt.close(fig)
def setup_axis(compl_table, param):
"""
Set up the bins and labels for a parameter
Parameters
----------
compl_table : astropy table
table with each set of physical parameters and their completeness
param : string
name of the parameter we're binning/labeling
Returns
-------
col : numpy array
column to plot
bins : numpy array
bin edges
label : string
the axis label to use
"""
# mass isn't reguarly spaced, so take log and manually define bins
if "M_" in param:
col = np.log10(compl_table[param])
bins = np.linspace(np.min(col), np.max(col), 20)
label = "log " + param
# metallicity just needs to be log
elif param == "Z":
col = np.log10(compl_table[param])
bins = np.linspace(np.min(col), np.max(col), len(np.unique(col)) + 1)
label = "log " + param
# for all others, standard linear spacing is ok
else:
col = copy.copy(compl_table[param])
bins = np.linspace(np.min(col), np.max(col), len(np.unique(col)) + 1)
label = copy.copy(param)
return col, bins, label