# Source code for beast.plotting.plot_indiv_fit

#!/usr/bin/env python
"""
Plot the individual fit for a single observed star
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator
from matplotlib.patches import Rectangle
import matplotlib

from astropy.table import Table
from astropy.io import fits

from beast.plotting.beastplotlib import initialize_parser
from beast.tools.symlog import inverse_symlog, symlog_linthreshold


__all__ = ["plot_indiv_fit"]


def disp_str(stats, k, keyname):
    """
    Format the median and 16/84 percentile spread of a parameter as LaTeX.

    Parameters
    ----------
    stats : table-like
        stats table; the columns ``<keyname>_p50``, ``<keyname>_p84`` and
        ``<keyname>_p16`` are read
    k : int
        row (star) index
    keyname : str
        parameter name; ``M_ini`` and ``Z`` are shown as log10, and
        ``distance`` is always divided by 1000 (displayed in kpc, matching
        the "d(kpc)" label and the best-fit value in this module)

    Returns
    -------
    str
        ``"$p50^{+(p84 - p50)}_{-(p50 - p16)}$"`` with two decimals
    """
    vals = np.array(
        [
            stats[keyname + "_p50"][k],
            stats[keyname + "_p84"][k],
            stats[keyname + "_p16"][k],
        ],
        dtype=float,
    )
    if keyname in ("M_ini", "Z"):
        vals = np.log10(vals)
    elif keyname == "distance":
        # always convert: a magnitude-dependent switch would show pc values
        # under a kpc label for nearby objects
        vals = vals / 1000.0

    median, upper, lower = vals
    return f"${median:.2f}^{{+{upper - median:.2f}}}_{{-{median - lower:.2f}}}$"


def plot_1dpdf(ax, pdf1d_hdu, tagname, xlabel, starnum, stats=None, logx=False):
    """
    Plot the peak-normalized 1D pPDF of one parameter for one star.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        axes to draw into
    pdf1d_hdu : astropy.io.fits.HDUList
        opened pdf1d file; the extension named ``tagname`` holds the pPDFs
    tagname : str
        parameter name: pdf1d extension name and prefix of the stats columns
        ``<tagname>_Best``, ``<tagname>_p16``, ``<tagname>_p50``,
        ``<tagname>_p84``
    xlabel : str
        label written in the upper right corner of the axes
    starnum : int
        index of the star
    stats : astropy.table.Table, optional
        stats table; when given, the best-fit value (cyan) and the 16/50/84
        percentiles (magenta) are marked
    logx : bool, optional
        plot log10 of the parameter values (bins and markers)

    Raises
    ------
    ValueError
        if the pdf1d extension is neither 2D nor 3D

    Notes
    -----
    Two layouts of the pdf1d extension are handled:

    * 2D ``(n_stars + 1, n_bins)``: one PDF row per star, and the last row
      holds the bin values.
    * 3D: ``[starnum, :, 0]`` is the PDF and ``[starnum, :, 1]`` the
      per-star bin values; NaN bin values are not counted as bins.

    For ``distance`` the bins, best-fit value and percentiles are all divided
    by 1000 so that everything on the axis uses the same (kpc) scale.
    """
    pdf_data = pdf1d_hdu[tagname].data

    if pdf_data.ndim == 2:
        pdf = pdf_data[starnum, :]
        xvals = pdf_data[-1, :]
        n_bins = pdf_data.shape[1]
    elif pdf_data.ndim == 3:
        pdf = pdf_data[starnum, :, 0]
        xvals = pdf_data[starnum, :, 1]
        n_bins = np.sum(~np.isnan(xvals))
    else:
        raise ValueError(
            f"pdf1d extension {tagname!r} has unsupported ndim={pdf_data.ndim}"
        )

    ax.text(0.95, 0.95, xlabel, transform=ax.transAxes, va="top", ha="right")

    # nothing to show with zero or one bin
    if n_bins < 2:
        ax.text(0.5, 0.5, "unused", transform=ax.transAxes, va="center", ha="center")
        ax.set_yticklabels([])
        return

    def to_axis_units(values):
        """Map parameter values onto this axis (/1000 for distance, log10 if logx)."""
        values = np.asarray(values, dtype=float)
        if tagname == "distance":
            values = values / 1000.0
        if logx:
            values = np.log10(values)
        return values

    xvals = to_axis_units(xvals)

    # NOTE: a model grid with a single "Z" value has been seen to show up
    # here as two bins; likely an upstream issue in the BEAST fit code.
    peak = np.nanmax(pdf)
    ax.plot(xvals, pdf / peak if peak > 0 else pdf, color="k")

    ax.yaxis.set_major_locator(MaxNLocator(6))
    ax.xaxis.set_major_locator(MaxNLocator(4))

    # ignore NaN padding and non-finite values (e.g. log10 of 0) when
    # choosing the x range
    finite_x = xvals[np.isfinite(xvals)]
    if finite_x.size > 0:
        xlo, xhi = finite_x.min(), finite_x.max()
        xpad = 0.05 * (xhi - xlo)
        ax.set_xlim(xlo - xpad, xhi + xpad)
    elif stats is not None:
        # no usable bin values: zoom around the best-fit value instead
        best = float(to_axis_units(stats[tagname + "_Best"][starnum]))
        xpad = 0.05 * abs(best)
        ax.set_xlim(best - xpad, best + xpad)
    ax.set_ylim(0.0, 1.1)
    ax.set_yticklabels([])

    if stats is None:
        return

    ylo, yhi = ax.get_ylim()
    yspan = yhi - ylo

    def yfrac(frac):
        """Data y-coordinate at a fraction of the current y range."""
        return ylo + frac * yspan

    # best-fit value: cyan tick in the upper half
    best = float(to_axis_units(stats[tagname + "_Best"][starnum]))
    ax.plot(np.full(2, best), [yfrac(0.5), yfrac(0.7)], "-", color="c")

    # median (short tick) and 16/84 percentiles joined by a bar, in magenta
    p50, p16, p84 = to_axis_units(
        [
            stats[tagname + "_p50"][starnum],
            stats[tagname + "_p16"][starnum],
            stats[tagname + "_p84"][starnum],
        ]
    )
    y1, y2 = yfrac(0.2), yfrac(0.4)
    ym = 0.5 * (y1 + y2)
    ax.plot(np.full(2, p50), [yfrac(0.25), yfrac(0.35)], "-", color="m")
    ax.plot(np.full(2, p16), [y1, y2], "-", color="m")
    ax.plot(np.full(2, p84), [y1, y2], "-", color="m")
    ax.plot([p16, p84], [ym, ym], "-", color="m")


def _read_filter_info(stats_fname):
    """
    Return the filter names and wavelengths associated with a stats file.

    Stats files with more than two HDUs carry a table (hdu=2) with
    ``filternames`` and ``wavelengths`` columns. Older files lack it, and
    the PHAT filter set is assumed.

    Returns
    -------
    filters : list of str
    waves : numpy.ndarray of float
        wavelengths as stored (presumably Angstrom: the caller multiplies
        by 1e-4 for a micron axis)
    """
    with fits.open(stats_fname) as hdul:
        nhdu = len(hdul)

    if nhdu > 2:
        filter_info = Table.read(stats_fname, hdu=2)
        # names may come back as bytes or str depending on how the column
        # was written/read, so only decode actual bytes
        filters = [
            name.decode("utf-8") if isinstance(name, bytes) else str(name)
            for name in filter_info["filternames"].data
        ]
        waves = np.asarray(filter_info["wavelengths"].data, dtype=float)
    else:
        # PHAT values as default to support old stats files
        filters = [
            "HST_WFC3_F275W",
            "HST_WFC3_F336W",
            "HST_ACS_WFC_F475W",
            "HST_ACS_WFC_F814W",
            "HST_WFC3_F110W",
            "HST_WFC3_F160W",
        ]
        waves = np.asarray([2722.05, 3366.01, 4763.05, 8087.37, 11672.36, 15432.74])
    return filters, waves


def plot_indiv_fit(filebase, starnum=0, savefig=False, plotfig=True):
    """
    Plot the individual fit for a single observed star including best fit &
    percentile parameters and various 1D pPDFs

    Parameters
    ----------
    filebase : str or sequence of two str
        base filename of run (``<filebase>_stats.fits`` and
        ``<filebase>_pdf1d.fits`` are read), or an explicit
        ``(stats_filename, pdf1d_filename)`` pair
    starnum : int
        number of star in the stats file
    savefig : str
        set to the file extension of the desired plot file (e.g., png, pdf, etc)
    plotfig : boolean
        plot the figure to a file or the screen based on savefig
        otherwise return the fig object

    Returns
    -------
    matplotlib.figure.Figure or None
        the figure when ``plotfig`` is False, otherwise None
    """
    starnum = int(starnum)

    # determine how the stats/pdf1d filenames are to be set
    single_base = len(np.atleast_1d(filebase)) == 1
    if single_base:
        stats_fname = f"{filebase}_stats.fits"
        pdf1d_fname = f"{filebase}_pdf1d.fits"
    else:
        stats_fname = filebase[0]
        pdf1d_fname = filebase[1]

    stats = Table.read(stats_fname, hdu=1)
    filters, waves = _read_filter_info(stats_fname)

    # parameters shown in the value table and as 1D pPDFs, in display order:
    # (stats/pdf1d name, label, shown as log10)
    nprim, nsec, nderiv = 4, 3, 3
    params = [
        ("Av", "A(V)", False),
        ("M_ini", "log(M)", True),
        ("logA", "log(t)", False),
        ("distance", "d(kpc)", False),
        ("Rv", "R(V)", False),
        ("f_A", r"f$_\mathcal{A}$", False),
        ("Z", "log(Z)", True),
        ("logT", r"log(T$_\mathrm{eff})$", False),
        ("logg", "log(g)", False),
        ("logL", "log(L)", False),
    ]

    # plt.figure rather than plt.subplots: subplots() would add a
    # full-figure axes that stays visible beneath the gridspec axes
    fig = plt.figure(figsize=(8, 8))

    # setup the plot grid
    gridNrow, gridNcol = 5, 12
    gs = gridspec.GridSpec(
        gridNrow,
        gridNcol,
        height_ratios=[1.0] * gridNrow,
        width_ratios=[1.0] * gridNcol,
    )

    # axes for the big SED plot. Leave empty columns right of the plot to
    # put the legend and values.
    sed_height = 2
    free_cols = 3
    sed_ax = fig.add_subplot(gs[0:sed_height, 0 : -1 - free_cols])

    # one row of 1D pPDF axes per parameter group (primary/secondary/derived)
    pdf_rows = []
    for row, width, n_axes in zip(
        range(sed_height, sed_height + 3), (3, 4, 4), (nprim, nsec, nderiv)
    ):
        pdf_rows.append(
            [
                fig.add_subplot(gs[row, i * width : (i + 1) * width])
                for i in range(n_axes)
            ]
        )
    pdf_axes = [cax for row_axes in pdf_rows for cax in row_axes]

    # ---- SED ----
    n_filters = len(filters)
    waves = waves * 1e-4  # to micron for the wavelength axis

    # model flux arrays: columns are p50, p16, p84
    percentiles = ("p50", "p16", "p84")
    obs_flux = np.zeros(n_filters, dtype=float)
    mod_flux = np.zeros((n_filters, 3), dtype=float)
    mod_flux_nd = np.zeros((n_filters, 3), dtype=float)
    mod_flux_wbias = np.zeros((n_filters, 3), dtype=float)

    corname = stats["Name"][starnum]

    for i, cfilter in enumerate(filters):
        obs_flux[i] = stats[cfilter][starnum]
        fluxname = "log" + cfilter
        has_bias = "sym" + fluxname + "_wd_bias_p50" in stats.colnames
        for j, pct in enumerate(percentiles):
            mod_flux[i, j] = np.power(10.0, stats[f"{fluxname}_wd_{pct}"][starnum])
            mod_flux_nd[i, j] = np.power(10.0, stats[f"{fluxname}_nd_{pct}"][starnum])
            if has_bias:
                mod_flux_wbias[i, j] = inverse_symlog(
                    stats[f"sym{fluxname}_wd_bias_{pct}"][starnum]
                )

    sed_ax.plot(waves, obs_flux, "ko", label="observed")

    if "symlog" + filters[0] + "_wd_bias_p50" in stats.colnames:
        sed_ax.plot(waves, mod_flux_wbias[:, 0], "b-", label="stellar+dust+bias")
        sed_ax.fill_between(
            waves, mod_flux_wbias[:, 1], mod_flux_wbias[:, 2], color="b", alpha=0.3
        )

    sed_ax.plot(waves, mod_flux[:, 0], "r-", label="stellar+dust")
    sed_ax.fill_between(waves, mod_flux[:, 1], mod_flux[:, 2], color="r", alpha=0.2)

    sed_ax.plot(waves, mod_flux_nd[:, 0], "y-", label="stellar only")
    sed_ax.fill_between(
        waves, mod_flux_nd[:, 1], mod_flux_nd[:, 2], color="y", alpha=0.1
    )

    # can introduce a legend loc option if 'best' produces overlap
    sed_ax.legend(loc="best", fontsize=9)

    sed_ax.set_ylabel(r"Flux [ergs s$^{-1}$ cm$^{-2}$ $\AA^{-1}$]")
    sed_ax.set_yscale("symlog", linthresh=symlog_linthreshold)
    sed_ax.grid(True)

    sed_ax.text(
        0.5, -0.07, r"$\lambda$ [$\mu m$]", transform=sed_ax.transAxes, va="top"
    )
    sed_ax.set_xlim(0.2, 2.0)
    sed_ax.set_xscale("log")
    sed_ax.minorticks_off()
    sed_ax.set_xticks([0.2, 0.3, 0.4, 0.5, 0.8, 1.0, 2.0])
    sed_ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    sed_ax.get_xaxis().set_minor_formatter(matplotlib.ticker.ScalarFormatter())

    sed_ax.text(0.05, 0.95, corname, transform=sed_ax.transAxes, va="top", ha="left")

    # ---- text table of results, right of the SED ----
    startsec = nprim
    startderiv = nprim + nsec
    laby = 0.96
    ty = np.linspace(laby - 0.1, 0.1, num=len(params))
    ty[startsec:] -= 0.04  # gap between parameter groups
    ty[startderiv:] -= 0.04
    tx = [1.14, 1.3, 1.47]
    for i, (key, label, is_log) in enumerate(params):
        sed_ax.text(tx[0], ty[i], label, ha="center", transform=sed_ax.transAxes)
        sed_ax.text(
            tx[1],
            ty[i],
            disp_str(stats, starnum, key),
            ha="center",
            color="m",
            transform=sed_ax.transAxes,
        )
        best_val = stats[key + "_Best"][starnum]
        if is_log:
            best_val = np.log10(best_val)
        if key == "distance":
            best_val = best_val / 1000.0
        sed_ax.text(
            tx[2],
            ty[i],
            f"${best_val:.2f}$",
            ha="center",
            color="c",
            transform=sed_ax.transAxes,
        )
    sed_ax.text(
        tx[0], laby, "Param", ha="center", transform=sed_ax.transAxes, fontsize=10
    )
    sed_ax.text(
        tx[1],
        laby,
        r"50$\pm$33%",
        ha="center",
        color="k",
        transform=sed_ax.transAxes,
        fontsize=10,
    )
    sed_ax.text(
        tx[2],
        laby,
        "Best",
        color="k",
        ha="center",
        transform=sed_ax.transAxes,
        fontsize=10,
    )

    # now draw boxes around the different kinds of parameters
    left, right = tx[0], tx[-1]

    def draw_box_around_values(start, stop, ls):
        """Box (sed_ax axes coordinates) around table rows start..stop."""
        deltaline = ty[start] - ty[start + 1]
        top = ty[start] + deltaline  # top border ABOVE the first row's text
        bottom = ty[stop]
        rec = Rectangle(
            (left - 0.1, bottom - 0.02),
            right - left + 0.15,
            top - bottom + 0.01,
            fill=False,
            lw=2,
            transform=sed_ax.transAxes,
            ls=ls,
        )
        rec = sed_ax.add_patch(rec)
        rec.set_clip_on(False)

    draw_box_around_values(0, startsec - 1, ls="dashed")  # primary
    draw_box_around_values(startsec, startderiv - 1, ls="dotted")  # secondary
    draw_box_around_values(startderiv, len(params) - 1, ls="dashdot")  # derived

    # ---- 1D pPDFs, left to right, row by row, in `params` order ----
    with fits.open(pdf1d_fname) as pdf1d_hdu:
        for cax, (key, label, is_log) in zip(pdf_axes, params):
            plot_1dpdf(cax, pdf1d_hdu, key, label, starnum, stats=stats, logx=is_log)

    # A more manual version of tight_layout
    fig.subplots_adjust(
        top=0.95, bottom=0.05, left=0.125, right=0.925, wspace=0.5, hspace=0.5
    )

    # Draw the group boxes only AFTER adjusting the layout: the adjustment
    # moves the axes, so their final positions are read here.
    def rectangle_around_axes(bottomleft_ax, topright_ax, pad, ls, label=None):
        """
        Box (figure coordinates) spanning two axes.

        pad: tuple, (left, right, bottom, top)
        """
        left, bottom = bottomleft_ax.get_position().get_points()[0]
        right, top = topright_ax.get_position().get_points()[1]
        left -= pad[0]
        right += pad[1]
        bottom -= pad[2]
        top += pad[3]
        transf = fig.transFigure
        rec = Rectangle(
            (left, bottom),
            right - left,
            top - bottom,
            transform=transf,
            fill=False,
            lw=2,
            ls=ls,
        )
        rec = bottomleft_ax.add_patch(rec)
        rec.set_clip_on(False)

        if label:
            middle = (top + bottom) / 2.0
            bottomleft_ax.text(
                left,
                middle,
                label,
                transform=transf,
                rotation="vertical",
                fontstyle="oblique",
                va="center",
                ha="right",
            )

    rectangle_padding = (0.03, 0.01, 0.03, 0.01)
    for row_axes, ls, label in zip(
        pdf_rows, ("dashed", "dotted", "dashdot"), ("Primary", "Secondary", "Derived")
    ):
        first_ax = row_axes[0]
        rectangle_around_axes(
            first_ax, row_axes[-1], pad=rectangle_padding, ls=ls, label=label
        )
        first_ax.text(
            0.0,
            0.5,
            "Probability",
            transform=first_ax.transAxes,
            rotation="vertical",
            va="center",
            ha="right",
        )

    # show or save
    if not plotfig:
        return fig

    if savefig:
        if single_base:
            out_root = f"{filebase}"
        else:
            # explicit filenames given: derive the root from the stats file
            # (string concatenation on the filename pair used to fail here)
            out_root = str(stats_fname)
            for suffix in (".fits", "_stats"):
                if out_root.endswith(suffix):
                    out_root = out_root[: -len(suffix)]
        fig.savefig(f"{out_root}_ifit_starnum_{starnum}.{savefig}")
    else:
        plt.show()
if __name__ == "__main__":  # pragma: no cover
    # command line entry point: plot one star's fit from a BEAST run
    cli = initialize_parser()
    cli.add_argument("filebase", type=str, help="base filename of output results")
    cli.add_argument(
        "--starnum", type=int, default=0, help="star number in observed file"
    )
    cli_args = cli.parse_args()

    # make the plot!
    plot_indiv_fit(
        cli_args.filebase, starnum=cli_args.starnum, savefig=cli_args.savefig
    )