#!/usr/bin/env python3
"""
Generate ALL figures for the TPW return companion paper.

Conventions:
- Observed (baseline) values are derived from site_sampling_timeseries.csv
- Baselines are plotted as RED DASHED vertical lines
- Baselines are always layered ABOVE null distributions (high z-order)

Required input files:
- model_event_rate_envelope.csv
- dzdt_continuous.nc
- dzdt_event_guided.nc (listed as an input, but not currently read by this script)
- site_sampling_timeseries.csv
- null_distributions.csv
- temporal_null_distributions.csv
"""

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# Global matplotlib defaults applied to every figure created below:
# figure resolution (dpi) and base font size (points).
plt.rcParams["figure.dpi"] = 300
plt.rcParams["font.size"] = 10

# ============================================================
# LOAD SITE METRICS (SOURCE OF OBSERVED BASELINES)
# ============================================================

df_sites = pd.read_csv("site_sampling_timeseries.csv")


def _observed_baseline(group, column):
    """Return the mean of ``column`` over every row of ``df_sites`` in ``group``.

    Rows are selected with a boolean mask on the "group" column. The result is
    NaN when no row belongs to ``group``, because pandas returns NaN for the
    mean of an empty selection.
    """
    in_group = df_sites["group"] == group
    return df_sites.loc[in_group, column].mean()


# Observed baselines are derived from the site table, not read from a file.
# There is one value per (group, metric), averaged over all of that group's rows.
obs_dzdt_homo = _observed_baseline("Homo", "mean_abs_dzdt")
obs_dzdt_civ = _observed_baseline("Civilization", "mean_abs_dzdt")
obs_var_homo = _observed_baseline("Homo", "var_zero_contour_distance_deg2")
obs_var_civ = _observed_baseline("Civilization", "var_zero_contour_distance_deg2")

# ============================================================
# UTILITY: NULL PLOT WITH BASELINE
# ============================================================

def plot_null_with_baseline(
    null_values,
    baseline_value,
    title,
    xlabel,
    outfile,
    bins=40
):
    """Save a histogram of a null distribution with the observed value overlaid.

    The null sample is drawn as a density-normalised histogram at z-order 1.
    The observed value is drawn on top of it as a red dashed vertical line at
    z-order 10, following the baseline conventions in the module docstring.

    Parameters
    ----------
    null_values : 1-D array-like of float
        Null-model sample, e.g. one column of a null-distribution CSV.
        Non-finite entries (NaN, +/-inf) are dropped before binning. NaN
        padding appears when CSV columns have different lengths, and
        ``np.histogram`` cannot auto-range data containing NaN.
    baseline_value : float
        Observed statistic to mark with the vertical line.
    title, xlabel : str
        Axes title and x-axis label.
    outfile : str or path-like
        Destination passed to ``Figure.savefig``.
    bins : int or sequence, optional
        Forwarded to ``Axes.hist`` (default 40).

    Raises
    ------
    ValueError
        If ``null_values`` has no finite entries, or if ``baseline_value`` is
        not finite. A non-finite line would be silently missing from the figure
        while the legend still advertised it.
    """
    values = np.asarray(null_values, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        raise ValueError(f"{title!r}: null distribution has no finite values")
    if not np.isfinite(baseline_value):
        raise ValueError(
            f"{title!r}: observed baseline is not finite ({baseline_value!r})"
        )

    fig, ax = plt.subplots(figsize=(5, 3))
    try:
        ax.hist(
            finite,
            bins=bins,
            density=True,
            alpha=0.85,
            zorder=1
        )

        # High z-order keeps the observed line above the histogram bars.
        ax.axvline(
            baseline_value,
            color="red",
            linestyle="--",
            linewidth=2.0,
            zorder=10,
            label="Observed"
        )

        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Probability density")
        ax.legend(frameon=False)
        fig.tight_layout()
        fig.savefig(outfile)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


# ============================================================
# FIGURE 1 — EVENT-RATE ENVELOPE COMPARISON
# ============================================================

df_evt = pd.read_csv("model_event_rate_envelope.csv")

time_evt = df_evt["time_ka"].values
# Event-driven curve: the first column of the envelope CSV other than "time_ka".
rate_evt = df_evt.drop(columns=["time_ka"]).iloc[:, 0].values

# NOTE(review): the event-driven curve comes from the envelope CSV, while the
# continuous curve is recomputed from gridded dz/dt. dzdt_event_guided.nc is
# listed as an input in the module docstring but is never opened here. Confirm
# both curves are the same quantity before they share one y axis.

# A context manager releases the NetCDF file handle once the arrays have been
# pulled into memory with ``.values``.
with xr.open_dataset("dzdt_continuous.nc") as ds_cont:
    # Use the first data variable that is not the time axis itself, in case
    # "time_ka" is stored as a data variable rather than a coordinate.
    dzdt_names = [name for name in ds_cont.data_vars if name != "time_ka"]
    if not dzdt_names:
        raise ValueError("dzdt_continuous.nc contains no dz/dt data variable")
    dzdt_cont = ds_cont[dzdt_names[0]].values

    # Fall back to the envelope CSV's time axis when the NetCDF has none.
    time_cont = (
        ds_cont["time_ka"].values
        if "time_ka" in ds_cont
        else time_evt
    )

# Mean |dz/dt| over every non-leading axis. For 3-D input this is
# axis=(1, 2), as before. The leading axis is assumed to be time; TODO confirm
# against the NetCDF dimension order.
rate_cont = np.mean(np.abs(dzdt_cont), axis=tuple(range(1, dzdt_cont.ndim)))

if len(time_cont) != len(rate_cont):
    raise ValueError(
        f"time axis has {len(time_cont)} points but the continuous rate has "
        f"{len(rate_cont)}; add a 'time_ka' coordinate to dzdt_continuous.nc"
    )

plt.figure(figsize=(8, 4))
plt.plot(time_cont, rate_cont, label="Continuous return")
plt.plot(time_evt, rate_evt, label="Event-driven return")
plt.xlabel("Time (ka BP)")
plt.ylabel("Event-rate proxy")
plt.legend()
plt.tight_layout()
plt.savefig("fig_model_event_rate_envelope.png")
plt.close()


# ============================================================
# FIGURE 2 — SITE-CONDITIONED |dz/dt| TIME SERIES
# ============================================================

plt.figure(figsize=(8, 4))
for g in ["Homo", "Civilization"]:
    # Sort by time: plt.plot joins points in row order, so unsorted CSV rows
    # would be drawn as a zig-zag instead of a time series.
    sub = df_sites[df_sites["group"] == g].sort_values("time_ka")
    plt.plot(sub["time_ka"], sub["mean_abs_dzdt"], label=g)

plt.xlabel("Time (ka BP)")
plt.ylabel("Mean |ΔZeff / Δt|")
plt.legend()
plt.tight_layout()
plt.savefig("fig_sites_dzdt_timeseries.png")
plt.close()


# ============================================================
# FIGURE 3 — COMPOSITE RETURN DIAGNOSTICS
# ============================================================

fig, axes = plt.subplots(4, 1, figsize=(8, 10), sharex=True)

# Panel 0: model event-rate envelope (row order of the envelope CSV).
axes[0].plot(time_evt, rate_evt)
axes[0].set_ylabel("Event-rate")

# Panels 1-2: per-group mean |dz/dt|. Rows are sorted by time because
# Axes.plot connects points in row order.
for i, g in enumerate(["Homo", "Civilization"]):
    sub = df_sites[df_sites["group"] == g].sort_values("time_ka")
    axes[i + 1].plot(sub["time_ka"], sub["mean_abs_dzdt"])
    axes[i + 1].set_ylabel(f"{g} |ΔZeff / Δt|")

# Panel 3: Civilization zero-contour-distance variance, also time-sorted.
civ = df_sites[df_sites["group"] == "Civilization"].sort_values("time_ka")
axes[3].plot(civ["time_ka"], civ["var_zero_contour_distance_deg2"])
axes[3].set_ylabel("Civ. stability variance")
axes[3].set_xlabel("Time (ka BP)")

fig.tight_layout()
fig.savefig("fig_composite_return_diagnostics.png")
plt.close(fig)


# ============================================================
# SPATIAL NULL DISTRIBUTIONS
# ============================================================

df_null = pd.read_csv("null_distributions.csv")

# One row per spatial-null figure, drawn in this order:
# (null column, observed baseline, title, x-axis label, output file)
spatial_null_specs = [
    (
        "dz_homo",
        obs_dzdt_homo,
        "Spatial null: Homo |ΔZeff / Δt|",
        "Mean |ΔZeff / Δt|",
        "fig_null_homo_dzdt.png",
    ),
    (
        "var_homo",
        obs_var_homo,
        "Spatial null: Homo variance",
        "Variance of equilibrium-margin distance (deg$^2$)",
        "fig_null_homo_variance.png",
    ),
    (
        "dz_civ",
        obs_dzdt_civ,
        "Spatial null: Civilization |ΔZeff / Δt|",
        "Mean |ΔZeff / Δt|",
        "fig_null_civ_dzdt.png",
    ),
    (
        "var_civ",
        obs_var_civ,
        "Spatial null: Civilization variance",
        "Variance of equilibrium-margin distance (deg$^2$)",
        "fig_null_civ_variance.png",
    ),
]

for null_column, baseline, fig_title, axis_label, png_path in spatial_null_specs:
    plot_null_with_baseline(
        df_null[null_column],
        baseline,
        fig_title,
        axis_label,
        png_path
    )


# ============================================================
# TEMPORAL NULL DISTRIBUTIONS
# ============================================================

df_t = pd.read_csv("temporal_null_distributions.csv")

plot_null_with_baseline(
    df_t["homo_dzdt"],
    obs_dzdt_homo,
    "Temporal null: Homo |ΔZeff / Δt|",
    "Mean |ΔZeff / Δt|",
    "fig_temporal_null_homo_dzdt.png"
)

# The variance panels overlay obs_var_* (the mean of
# var_zero_contour_distance_deg2), so the x axis is labelled with that
# quantity. This matches the panel titles and the spatial-null variance figures.
# NOTE(review): if homo_var/civ_var in the temporal-null CSV really are a
# phase-alignment metric, the baseline (not this label) is what must change.
plot_null_with_baseline(
    df_t["homo_var"],
    obs_var_homo,
    "Temporal null: Homo variance",
    "Variance of equilibrium-margin distance (deg$^2$)",
    "fig_temporal_null_homo_variance.png"
)

plot_null_with_baseline(
    df_t["civ_dzdt"],
    obs_dzdt_civ,
    "Temporal null: Civilization |ΔZeff / Δt|",
    "Mean |ΔZeff / Δt|",
    "fig_temporal_null_civ_dzdt.png"
)

plot_null_with_baseline(
    df_t["civ_var"],
    obs_var_civ,
    "Temporal null: Civilization variance",
    "Variance of equilibrium-margin distance (deg$^2$)",
    "fig_temporal_null_civ_variance.png"
)

print("✓ ALL FIGURES GENERATED SUCCESSFULLY (baseline = red dashed)")
