import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from beast.physicsmodel.priormodel import PriorDustModel

fig, ax = plt.subplots()

# distance grid with linear spacing, in pc: the 50-70 kpc range brackets the
# step location dist0 = 60 kpc (= 60e3 pc) set in the prior model below
d1, d2 = (50.e3, 70.e3)
dists = np.linspace(d1, d2, num=100)
# f_A grid from fA1 up to (but not including) fA2 in steps of dfA
fA1, fA2 = (0.0, 1.0)
dfA = 0.01
fAs = np.arange(fA1, fA2, dfA)
# both arrays have shape (len(fAs), len(dists)): rows follow f_A,
# columns follow distance
distim, fAim = np.meshgrid(dists, fAs)

# "step" prior on f_A that also depends on distance; the meaning of each
# parameter is defined by beast's PriorDustModel
dustmod = {
    "name": "step",
    "dist0": 60 * u.kpc,
    "amp1": 0.1,
    "damp2": 0.8,
    "lgsigma1": 0.1,
    "lgsigma2": 0.01,
}

dustprior = PriorDustModel(dustmod)
# evaluate the prior on the 2D grid (distance passed as the second variable)
probim = dustprior(fAim, y=distim)

# imshow's extent gives the outer pixel *edges*, so pad by half a grid step
# to center each pixel on the grid point where the prior was evaluated.
# The distance axis is converted from pc to kpc so it matches the x label.
dists_kpc = (dists * u.pc).to_value(u.kpc)
half_dist = 0.5 * (dists_kpc[1] - dists_kpc[0])
half_fA = 0.5 * dfA
extent = [
    dists_kpc[0] - half_dist,
    dists_kpc[-1] + half_dist,
    fAs[0] - half_fA,
    fAs[-1] + half_fA,
]

im = ax.imshow(
    probim, origin="lower", aspect="auto", extent=extent, norm="log"
)
# with a log color scale the image is unreadable without a colorbar
fig.colorbar(im, ax=ax, label="prior")

ax.set_ylabel(r"$f_A$")
ax.set_xlabel("distance [kpc]")
plt.tight_layout()
plt.show()