import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from beast.physicsmodel.dust.extinction_extension import F19_D03_extension
from dust_extinction.grain_models import D03

fig, ax = plt.subplots()

# temporary model, only used to read the valid x range of the extension model
temp_model = F19_D03_extension()

# x values to evaluate every curve at; the x_range bounds are treated as
# inverse microns (hence the division by u.micron)
x = np.arange(temp_model.x_range[0], temp_model.x_range[1], 0.1) / u.micron

# F19_D03 extension curves for several R(V) values
Rvs = [2.0, 3.0, 4.0, 5.0, 6.0]
for cur_Rv in Rvs:
    ext_model = F19_D03_extension(Rv=cur_Rv)
    ax.plot(1.0 / x, ext_model(x), label=f"F19_D03_ext R(V) = {cur_Rv}")

# D03 grain-model curves for comparison; all are drawn in black, so give each
# one its own line style so the legend entries can be told apart
pmods = ["MWRV31", "MWRV40", "MWRV55"]
pstyles = ["dashed", "dotted", "dashdot"]
for cmod, cstyle in zip(pmods, pstyles):
    dmod = D03(modelname=cmod)
    ax.plot(1.0 / x, dmod(x), label=f"D03 {cmod}", linestyle=cstyle, color="black")

# the x axis is wavelength (1/x), so label the ratio in terms of lambda
ax.set_xlabel(r"$\lambda$ [$\mu m$]")
ax.set_ylabel(r"$A(\lambda)/A(V)$")

ax.set_xscale("log")

ax.legend(loc="best")
plt.show()