import numpy as np
from scipy.interpolate import interp1d
from scipy.integrate import quad
from beast.physicsmodel.grid_weights_stars import compute_bin_boundaries
import beast.physicsmodel.priormodel_functions as pmfuncs
# Public names exported by this module (what `from <module> import *` exposes).
# NOTE(review): "PriorMetallicityModel" is listed here but is not defined in
# the portion of the file reviewed -- confirm it is defined elsewhere in
# this module.
__all__ = [
"PriorModel",
"PriorDustModel",
"PriorAgeModel",
"PriorMassModel",
"PriorMetallicityModel",
"PriorDistanceModel",
]
class PriorModel:
    """
    Compute the priors as weights given the input grid

    The prior is described by a dict whose "name" entry selects the
    functional form; the remaining entries hold that form's parameters.
    """

    def __init__(self, model, allowed_models=None):
        """
        Initialize with basic information

        Parameters
        ----------
        model : dict
            Choice of model type.  Must contain a "name" key; the other
            keys required depend on the chosen name (see `__call__`).
        allowed_models : list of str, optional
            Model names accepted by this prior.  When None (default) the
            name is not checked here.

        Raises
        ------
        NotImplementedError
            If `allowed_models` is given and model["name"] is not in it.
        """
        if (allowed_models is not None) and (model["name"] not in allowed_models):
            modname = model["name"]
            raise NotImplementedError(f"{modname} is not an allowed model")
        # save the model
        self.model = model

    def __call__(self, x):
        """
        Weights based on input model choice

        Parameters
        ----------
        x : float or array-like
            values for model evaluation

        Returns
        -------
        float or numpy.ndarray
            Prior weight(s) evaluated at `x`.

        Raises
        ------
        KeyError
            If a bins_histo/bins_interp model lacks "x" or "values".
        ValueError
            If any requested x lies outside the tabulated "x" range of a
            bins model, or if a lognormal, two_lognormal, or exponential
            model lacks one of its required parameters.
        NotImplementedError
            If model["name"] is not one of the forms handled here.
        """
        name = self.model["name"]

        if name == "flat":
            amp = self.model.get("amp", 1.0)
            # arrays get an array of the same shape, scalars get a scalar
            if hasattr(x, "shape"):
                return np.full(x.shape, amp)
            return amp

        elif name in ["bins_histo", "bins_interp"]:
            for ckey in ["x", "values"]:
                if ckey not in self.model:
                    raise KeyError(f"{ckey} not in prior model keys")

            # Work on array copies/views so the user's model dict is never
            # modified and list or ndarray inputs are handled alike.
            mod_x = np.asarray(self.model["x"])
            mod_values = np.asarray(self.model["values"])

            # check that all requested values are within the tabulated range
            # (np.asarray lets a scalar x through the same vectorized test)
            xvals = np.asarray(x)
            if np.any((xvals > np.max(mod_x)) | (xvals < np.min(mod_x))):
                raise ValueError("requested x outside of model x range")

            if name == "bins_histo":
                # value is constant from x[i] to x[i+1]; when "x" holds the
                # bin edges there is one fewer value than edges, so pad the
                # final edge with 0.0 -- on a local copy, not in the model
                if len(mod_values) == len(mod_x) - 1:
                    mod_values = np.append(mod_values, 0.0)
                interfunc = interp1d(mod_x, mod_values, kind="zero")
                return interfunc(x)

            # bins_interp: linear interpolation between the tabulated points
            return np.interp(x, mod_x, mod_values)

        elif name == "lognormal":
            for ckey in ["mean", "sigma"]:
                if ckey not in self.model:
                    raise ValueError(f"{ckey} not in prior model keys")
            return pmfuncs._lognorm(x, self.model["mean"], sigma=self.model["sigma"])

        elif name == "two_lognormal":
            # N1_to_N2 is read below, so it is validated with the other keys
            for ckey in ["mean1", "sigma1", "mean2", "sigma2", "N1_to_N2"]:
                if ckey not in self.model:
                    raise ValueError(f"{ckey} not in prior model keys")
            return pmfuncs._two_lognorm(
                x,
                self.model["mean1"],
                self.model["mean2"],
                sigma1=self.model["sigma1"],
                sigma2=self.model["sigma2"],
                N1=self.model["N1_to_N2"],
                N2=1.0,
            )

        elif name == "exponential":
            for ckey in ["tau"]:
                if ckey not in self.model:
                    raise ValueError(f"{ckey} not in prior model keys")
            return pmfuncs._exponential(x, tau=self.model["tau"])

        else:
            raise NotImplementedError(f"{name} is not an allowed model")
class PriorDustModel(PriorModel):
    """
    Prior on a dust parameter; a PriorModel restricted to the model names
    accepted for dust.
    """

    def __init__(self, model):
        """
        Set up the dust prior, rejecting unsupported model names

        Parameters
        ----------
        model : dict
            Prior description.  Its "name" must be one of flat, lognormal,
            two_lognormal, or exponential.
        """
        dust_models = ["flat", "lognormal", "two_lognormal", "exponential"]
        super().__init__(model, allowed_models=dust_models)
class PriorAgeModel(PriorModel):
    """
    Prior on the stellar age; a PriorModel restricted to the model names
    accepted for age, with age-specific handling of two of them.
    """

    def __init__(self, model):
        """
        Set up the age prior, rejecting unsupported model names

        Parameters
        ----------
        model : dict
            Prior description.  Its "name" must be one of flat, flat_log,
            bins_histo, bins_interp, or exponential.
        """
        age_models = ["flat", "flat_log", "bins_histo", "bins_interp", "exponential"]
        super().__init__(model, allowed_models=age_models)

    def __call__(self, x):
        """
        Evaluate the age prior weights

        Parameters
        ----------
        x : float
            values for model evaluation
            NOTE(review): the 10**x below suggests these are log10 ages --
            confirm against the caller.

        Returns
        -------
        Prior weights at `x`; for flat_log they are normalized to sum to 1.
        """
        name = self.model["name"]

        if name == "flat_log":
            # weight each point by the inverse width of its bin after the
            # bin boundaries are exponentiated (10**boundary)
            bin_edges = 10 ** compute_bin_boundaries(x)
            inv_widths = 1.0 / np.diff(bin_edges)
            return inv_widths / np.sum(inv_widths)

        if name == "exponential":
            # x is exponentiated and tau scaled by 1e9 before evaluation
            # (presumably log10(age/yr) and tau in Gyr -- TODO confirm)
            linear_x = 10.0 ** x
            scaled_tau = self.model["tau"] * 1e9
            return pmfuncs._exponential(linear_x, tau=scaled_tau)

        # every other allowed form is handled by the generic implementation
        return super().__call__(x)
class PriorDistanceModel(PriorModel):
    """
    Prior on the distance; a PriorModel restricted to the model names
    accepted for distance.
    """

    def __init__(self, model):
        """
        Set up the distance prior, rejecting unsupported model names

        Parameters
        ----------
        model : dict
            Prior description.  Its "name" must be flat.
        """
        distance_models = ["flat"]
        super().__init__(model, allowed_models=distance_models)
class PriorMassModel(PriorModel):
    """
    Prior on the initial mass; a PriorModel restricted to the model names
    accepted for mass, evaluated as bin-averaged IMF values.
    """

    def __init__(self, model):
        """
        Set up the initial mass prior, rejecting unsupported model names

        Parameters
        ----------
        model : dict
            Prior description.  Its "name" must be one of flat, salpeter,
            or kroupa.
        """
        mass_models = ["flat", "salpeter", "kroupa"]
        super().__init__(model, allowed_models=mass_models)

    def __call__(self, x):
        """
        Evaluate the mass prior weights

        Parameters
        ----------
        x : float
            values for model evaluation
            NOTE(review): indexed with an index array below, so this is
            expected to be a numpy array of masses -- confirm.

        Returns
        -------
        numpy.ndarray
            One weight per entry of `x`, in the input order, scaled so
            that the average weight is 1.
        """
        # rank the masses so that neighbouring bins are contiguous
        order = np.argsort(x)
        # boundaries of the bin around each (sorted) mass
        edges = compute_bin_boundaries(x[order])

        # choose the IMF and any user-supplied shape parameters
        name = self.model["name"]
        imf_params = None
        if name == "kroupa":
            if "alpha0" in self.model:  # assume other alphas also present
                imf_params = tuple(
                    self.model[key]
                    for key in ("alpha0", "alpha1", "alpha2", "alpha3")
                )
            imf_func = pmfuncs._imf_kroupa
        elif name == "salpeter":
            if "slope" in self.model:
                imf_params = (self.model["slope"],)
            imf_func = pmfuncs._imf_salpeter
        elif name == "flat":
            imf_func = pmfuncs._imf_flat

        # an empty tuple is quad's default, i.e. "no extra arguments"
        quad_args = () if imf_params is None else imf_params

        # average prior in each mass bin: integral of the IMF over the bin
        # divided by the bin width, stored at the original grid position
        mass_weights = np.zeros(len(x))
        for rank, grid_idx in enumerate(order):
            lower = edges[rank]
            upper = edges[rank + 1]
            mass_weights[grid_idx] = quad(imf_func, lower, upper, args=quad_args)[0]
            mass_weights[grid_idx] /= upper - lower

        # normalize to avoid numerical issues (too small or too large)
        mass_weights /= np.average(mass_weights)
        return mass_weights