import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

from beast.physicsmodel.dust.extinction_extension import G03_SMCBar_WD01_extension
from dust_extinction.grain_models import WD01

fig, ax = plt.subplots()

# define the extinction model
ext_model = G03_SMCBar_WD01_extension()

# Sample the model's valid range in inverse wavelength (1/micron).
# The plot's x axis is log-scaled in wavelength, so the samples are
# logarithmically spaced. A fixed linear step (np.arange(..., 0.1)) leaves
# the long-wavelength / small-x end coarsely sampled, and arange excludes
# the stop value so the upper bound is never plotted. np.geomspace
# includes both endpoints exactly; it requires x_range[0] > 0.
n_samples = 1000
x_min, x_max = ext_model.x_range[0], ext_model.x_range[1]
x = np.geomspace(x_min, x_max, n_samples) / u.micron

# the extension model under test, plotted against wavelength (1 / x)
ax.plot(1.0 / x, ext_model(x), label="G03 SMCBar WD01 ext")

# WD01 SMCBar grain model evaluated on the same grid, for comparison
dmod = WD01(modelname="SMCBar")
ax.plot(
    1.0 / x, dmod(x), label="WD01 SMCBar", linestyle="dashed", color="black"
)

ax.set_xlabel(r"$\lambda$ [$\mu m$]")
ax.set_ylabel(r"$A(x)/A(V)$")

ax.set_xscale("log")

ax.legend(loc="best")
plt.show()