import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from beast.physicsmodel.priormodel import PriorDustModel

# Plot the "step" dust prior as an image over a (distance, A(V)) grid.

fig, ax = plt.subplots()

# Distance grid with linear spacing.  The numeric values span 50e3-70e3,
# which brackets dist0 = 60 kpc below only if they are read as pc.
# NOTE(review): assumes PriorDustModel expects bare distances in pc -- confirm.
d1, d2 = (50.e3, 70.e3)
dists = np.linspace(d1, d2, num=100)

# A(V) grid; np.arange excludes the av2 endpoint.
av1, av2 = (0.0, 2.0)
avs = np.arange(av1, av2, 0.025)

# Default "xy" indexing: both images have shape (len(avs), len(dists)),
# i.e. rows follow A(V) and columns follow distance.
distim, avim = np.meshgrid(dists, avs)

# Parameters handed to the "step" dust prior model.
dustmod = {
    "name": "step",
    "dist0": 60 * u.kpc,
    "amp1": 0.1,
    "damp2": 1.0,
    "lgsigma1": 0.05,
    "lgsigma2": 0.05,
}

# Evaluate the prior at every (A(V), distance) grid point.
dustprior = PriorDustModel(dustmod)
probim = dustprior(avim, y=distim)

# The x-axis is labelled in kpc, so convert the pc grid limits to kpc for the
# image extent; otherwise the ticks would read 50000-70000 under a kpc label.
pc_per_kpc = 1.0e3
ax.imshow(
    probim,
    origin="lower",  # row 0 (lowest A(V)) is drawn at the bottom
    aspect="auto",
    extent=[d1 / pc_per_kpc, d2 / pc_per_kpc, av1, av2],
    norm="log",  # logarithmic color scaling
)

ax.set_ylabel("A(V)")
ax.set_xlabel("distance [kpc]")
ax.set_title("step")
plt.tight_layout()
plt.show()