Source code for beast.plotting.plot_param_recovery
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from astropy.io import fits
__all__ = ["plot_param_recovery"]
def _read_param_columns(sim_file, stats_file, param_list):
    # Read simulated and recovered (p50) values for each parameter into plain
    # float arrays, copied so they stay valid after the FITS files are closed.
    with fits.open(sim_file) as hdu_sim, fits.open(stats_file) as hdu_recov:
        sim_table = hdu_sim[1].data
        recov_table = hdu_recov[1].data
        return {
            param: (
                np.array(sim_table[param], dtype=float),
                np.array(recov_table[param + "_p50"], dtype=float),
            )
            for param in param_list
        }


def plot_param_recovery(
    sim_data_list,
    stats_file_list,
    output_plot_filename,
    file_label_list=None,
    max_nbins=20,
):
    """
    Make plots comparing the physical parameters from simulated data to the
    recovered physical parameters.

    If there are multiple files input, it is presumably because they are from
    different noise models. If that's the case, you may want to assign labels
    for each of them (file_label_list).

    Parameters
    ----------
    sim_data_list : string or list of strings
        File(s) of simulated data from beast.tools.simulate_obs, which have both
        the photometry and physical parameters

    stats_file_list : string or list of strings
        File(s) of the corresponding stats files with the fit statistics

    output_plot_filename : string
        name of the file in which to save the output plot

    file_label_list : string or list of strings (default=None)
        Labels to use for each of the files (e.g., their source density ranges).
        If given, each label is used as the title of the top panel in that
        file's column.

    max_nbins : int (default=20)
        maximum number of bins to use in each dimension of the 2D histogram
        (fewer will be used if there are fewer unique values)

    Raises
    ------
    ValueError
        If the number of simulated-data files, stats files, and (when given)
        labels do not all match.
    """
    # parameters to plot (one row of panels per parameter)
    param_list = ["Av", "logA", "M_ini", "Rv", "f_A", "Z", "distance"]
    n_param = len(param_list)

    # treat a single filename as a one-element list; calling len() on a bare
    # string would count its characters instead of files
    sim_files = np.atleast_1d(sim_data_list)
    stats_files = np.atleast_1d(stats_file_list)
    if len(sim_files) != len(stats_files):
        raise ValueError(
            "sim_data_list and stats_file_list must have the same length "
            f"(got {len(sim_files)} and {len(stats_files)})"
        )
    # number of files (one column of panels per file)
    n_stat = len(sim_files)

    labels = None
    if file_label_list is not None:
        labels = np.atleast_1d(file_label_list)
        if len(labels) != n_stat:
            raise ValueError(
                "file_label_list must have one label per file "
                f"(got {len(labels)} labels for {n_stat} files)"
            )

    fig = plt.figure(figsize=(5 * n_stat, 4 * n_param))

    for i, (sim_file, stats_file) in enumerate(zip(sim_files, stats_files)):
        columns = _read_param_columns(sim_file, stats_file, param_list)

        for p, param in enumerate(param_list):
            # panels are laid out row = parameter, column = file
            ax = fig.add_subplot(n_param, n_stat, 1 + n_stat * p + i)

            plot_x, plot_y = columns[param]
            use_log = ("M_" in param) or (param == "Z")
            if use_log:
                # non-positive values become -inf/nan and are dropped below
                with np.errstate(divide="ignore", invalid="ignore"):
                    plot_x = np.log10(plot_x)
                    plot_y = np.log10(plot_y)

            # hist2d cannot auto-detect a range from non-finite values
            good = np.isfinite(plot_x) & np.isfinite(plot_y)
            plot_x = plot_x[good]
            plot_y = plot_y[good]

            if plot_x.size > 0:
                # bin count follows the number of distinct simulated values
                # (simulated parameters are typically gridded), capped at
                # max_nbins; the recovered axis gets up to 3x finer bins
                n_uniq = len(np.unique(plot_x))
                n_bins = [min(n_uniq, max_nbins), min(3 * n_uniq, max_nbins)]
                ax.hist2d(
                    plot_x,
                    plot_y,
                    bins=n_bins,
                    cmap="magma",
                    norm=matplotlib.colors.LogNorm(),
                )
            else:
                ax.text(
                    0.5,
                    0.5,
                    "no finite values",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )

            # axis labels
            ax.tick_params(axis="both", which="major", labelsize=13)
            param_label = "log " + param if use_log else param
            ax.set_xlabel("Simulated " + param_label, fontsize=14)
            ax.set_ylabel("Recovered " + param_label, fontsize=14)

            # label each file's column on its top panel
            if labels is not None and p == 0:
                ax.set_title(str(labels[i]), fontsize=14)

    fig.tight_layout()
    fig.savefig(output_plot_filename)
    plt.close(fig)