import numpy as np
import matplotlib.pyplot as plt

plotdir = "/eos/home-s/spalluot/www/MTD/Production/LYSO/"

import mplhep as hep

hep.style.use("CMS")  # Stile CMS mplhep

def plot_cms(x, y, yerr, ylabel, filename, horizontal=None, batches=None,
             ymin=None, ymax=None, outdir=None):
    """Draw one per-batch measurement plot and save it as PNG and PDF.

    The figure shows the points with error bars and their inverse-variance
    weighted mean (solid green line, legend "Average"). A shaded band of
    +/- 0.5 * std(y) is drawn around that mean. Optional dashed red
    reference lines can be added; only the first is labelled "Tender spec".

    Parameters
    ----------
    x : array-like
        x positions of the points (e.g. ``np.arange(len(batches))``).
    y, yerr : array-like
        Measured values and their uncertainties, same length as ``x``.
        Uncertainties are used as ``1 / yerr**2`` weights, so they must be
        strictly positive.
    ylabel : str
        Y-axis label.
    filename : str
        Output file stem; ``<outdir>/<filename>.png`` and ``.pdf`` are written.
    horizontal : float or sequence of float, optional
        Y value(s) at which to draw reference lines. Python and NumPy
        scalars are both accepted.
    batches : sequence of str, optional
        Tick labels placed at the ``x`` positions.
    ymin, ymax : float, optional
        Y-axis limits. Either may be given on its own; ``None`` keeps the
        automatically chosen limit.
    outdir : str, optional
        Output directory. Defaults to the module-level ``plotdir``.

    Raises
    ------
    ValueError
        If ``y`` is empty or any uncertainty is not strictly positive.
    """
    x = np.asarray(x)
    y = np.asarray(y, dtype=float)
    yerr = np.asarray(yerr, dtype=float)
    if y.size == 0:
        raise ValueError("no data points to plot")
    # Zero/negative errors would give infinite weights and a NaN average.
    if np.any(yerr <= 0):
        raise ValueError("all uncertainties in yerr must be strictly positive")
    if outdir is None:
        outdir = plotdir

    fig, ax = plt.subplots(figsize=(8, 7))
    try:
        ax.errorbar(x, y, yerr=yerr, fmt='o', color='k', capsize=3,
                    markersize=6, label='Data')

        # Inverse-variance weighted mean of the points.
        y_fit = np.average(y, weights=1.0 / yerr**2)
        ax.axhline(y_fit, color='green', linestyle='-', label='Average')
        # Band of +/- half the (population) standard deviation of y around
        # the mean, extending half a unit past the outermost points.
        half_spread = 0.5 * np.std(y)
        ax.fill_between([x.min() - 0.5, x.max() + 0.5],
                        y_fit - half_spread, y_fit + half_spread,
                        color='green', alpha=0.2)

        if horizontal is not None:
            # atleast_1d turns any scalar (incl. NumPy ints) into a 1-element array.
            for i, level in enumerate(np.atleast_1d(horizontal)):
                # Label only the first reference line so the legend has one entry.
                extra = {"label": "Tender spec"} if i == 0 else {}
                ax.axhline(level, color='red', linestyle='--', linewidth=1.5,
                           **extra)

        if batches is not None:
            ax.set_xticks(x)
            ax.set_xticklabels(batches, rotation=45, ha='right')
        ax.set_xlabel("Batch")

        # Either bound may be set alone; a None bound is left unchanged.
        if ymin is not None or ymax is not None:
            ax.set_ylim(bottom=ymin, top=ymax)
        ax.set_ylabel(ylabel)

        ax.legend(frameon=False, loc='upper right')
        fig.tight_layout()
        fig.savefig(f"{outdir}/{filename}.png")
        fig.savefig(f"{outdir}/{filename}.pdf")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

# Production batches on the x axis; points sit at integer positions 0..N-1.
batches = ["PB1", "PB2", "PB3", "PB4", "PB5", "PB6", "PB7", "PB8", "PB9"]
x = np.arange(len(batches))

# --- Single-bar measurements: one value and one uncertainty per batch ---
LO = np.array([5500, 5400, 5200, 5150, 5400, 5450, 5250, 5700, 6200])
LO_err = np.array([350, 300, 350, 250, 500, 350, 400, 350, 300])
DT = np.array([41.5, 42.2, 43.2, 43.4, 42.7, 42.9, 42.2, 42.74, 43.4])
DT_err = np.array([1.0, 1.0, 0.9, 0.9, 0.9, 0.7, 0.9, 0.9, 0.7])
FOM = np.array([134, 128, 120, 119, 126, 127, 124.4, 132.9, 144])
FOM_err = np.array([9, 7, 7, 7, 10, 7, 7.9, 7.3, 7.5])

# Each entry: (values, errors, y-axis label, output file stem,
#              tender-spec reference level, (y_low, y_high)).
bar_plot_specs = [
    (LO, LO_err, "Light Output [ph/MeV]", "bar_LO", 4000, (3800, 7500)),
    (DT, DT_err, "Decay Time [ns]", "bar_DT", 45, (40, 50)),
    (FOM, FOM_err, "LO/DT [ph/(MeV ns)]", "bar_tres", 105, (100, 170)),
]
for values, errors, axis_label, stem, spec_level, (y_low, y_high) in bar_plot_specs:
    plot_cms(x, values, errors, axis_label, stem,
             horizontal=spec_level, batches=batches, ymin=y_low, ymax=y_high)



# --- Array measurements: one value and one uncertainty per batch ---
LO_opt = np.array([5200, 5200, 5250, 5100, 5200, 5250, 5250, 5250, 5350])
LO_opt_err = np.array([100, 100, 100, 150, 100, 150, 150, 200, 150])
time_res = np.array([131, 131, 132, 133, 133, 134, 133, 133, 133])
time_res_err = np.array([3, 2, 2, 2, 2, 2, 3, 2, 2])
cross_talk = np.array([9.5, 9.4, 8.1, 7.6, 8.5, 9.5, 9.2, 8.7, 9.3])
cross_talk_err = np.array([2, 2, 2, 2, 2, 2, 3, 2, 2])
LO_loss = np.array([-8.6, -6.8, -9.2, -5.1, -8.5, -8.4, -9.0, -8.7, -7.1])
LO_loss_err = np.array([1.3, 2.7, 2.5, 2.8, 1.7, 2.7, 1.9, 1.8, 2.8])

# Each entry: (values, errors, y-axis label, output file stem,
#              tender-spec reference level, (y_low, y_high)).
array_plot_specs = [
    (LO_opt, LO_opt_err, "Light Output [ph/MeV]", "array_LO", 4000, (3800, 6500)),
    (time_res, time_res_err, "Time Resolution [ps]", "array_tres", 140, (125, 145)),
    (cross_talk, cross_talk_err, "Cross Talk [%]", "array_XT", 25, (0, 30)),
    (LO_loss, LO_loss_err, "LO Loss after Irr [%]", "array_LO_loss", -20, (-22, 2)),
]
for values, errors, axis_label, stem, spec_level, (y_low, y_high) in array_plot_specs:
    plot_cms(x, values, errors, axis_label, stem,
             horizontal=spec_level, batches=batches, ymin=y_low, ymax=y_high)


# --- Array dimensions (mm): one value and one uncertainty per batch ---
length = np.array([54.691, 54.687, 54.690, 54.687, 54.683, 54.679, 54.691, 54.688, 54.688])
length_err = np.array([0.008, 0.008, 0.008, 0.009, 0.008, 0.006, 0.006, 0.004, 0.007])
non_planarity = np.array([0.023, 0.022, 0.023, 0.022, 0.023, 0.025, 0.019, 0.023, 0.022])
non_planarity_err = np.array([0.01, 0.007, 0.009, 0.01, 0.01, 0.014, 0.006, 0.007, 0.006])
width = np.array([51.37, 51.38, 51.38, 51.38, 51.36, 51.36, 51.37, 51.36, 51.36])
width_err = np.array([0.05, 0.04, 0.03, 0.03, 0.02, 0.03, 0.03, 0.03, 0.03])
thickness = np.array([4.11, 4.11, 4.12, 4.11, 4.106, 4.108, 4.108, 4.105, 4.11])
thickness_err = np.array([0.01, 0.01, 0.02, 0.01, 0.009, 0.012, 0.011, 0.013, 0.01])

# Each entry: (values, errors, y-axis label, output file stem,
#              tender-spec reference levels or None, (y_low, y_high)).
# Non-planarity has no reference line, so its spec entry is None.
geometry_plot_specs = [
    (length, length_err, "Length [mm]", "array_length",
     [54.65, 54.75], (54.645, 54.755)),
    (non_planarity, non_planarity_err, "Non-Planarity [mm]", "array_non_planarity",
     None, (0.0, 0.040)),
    (width, width_err, "Width [mm]", "array_width",
     [51.23, 51.53], (51.20, 51.55)),
    (thickness, thickness_err, "Thickness [mm]", "array_thickness",
     [4.01, 4.21], (4.00, 4.22)),
]
for values, errors, axis_label, stem, spec_levels, (y_low, y_high) in geometry_plot_specs:
    plot_cms(x, values, errors, axis_label, stem,
             batches=batches, ymax=y_high, ymin=y_low, horizontal=spec_levels)
