import numpy as np
import matplotlib.pyplot as plt

from beast.physicsmodel.priormodel import PriorDustModel

# Plot several R(V) dust prior models built with PriorDustModel, each
# evaluated on the same R(V) grid so their shapes can be compared directly.
fig, ax = plt.subplots()

# R(V) grid with linear spacing: 200 points from 2.0 to 6.0 (inclusive)
rvs = np.linspace(2.0, 6.0, num=200)

# Each dictionary is passed as-is to PriorDustModel; its "name" entry is
# also reused as the legend label for that curve.
dust_prior_models = [
    {"name": "flat"},
    {"name": "lognormal", "mean": 3.1, "sigma": 0.25},
    {
        "name": "two_lognormal",
        "mean1": 3.1,
        "mean2": 4.5,
        "sigma1": 0.1,
        "sigma2": 0.2,
        # NOTE(review): presumably the relative weight of component 1 to
        # component 2 — confirm against PriorDustModel
        "N1_to_N2": 2.0 / 5.0,
    },
]

for dmod in dust_prior_models:
    pmod = PriorDustModel(dmod)
    # the model object is called on the grid to get the prior values
    ax.plot(rvs, pmod(rvs), label=dmod["name"])

ax.set_ylabel("probability")
ax.set_xlabel("R(V)")
# Title names what is actually plotted. The previous "step" title did not
# match any of the models shown.
ax.set_title("R(V) prior models")
ax.legend(loc="best")
plt.tight_layout()
plt.show()