#!/usr/bin/env python3
"""
Plotting script for combined TPW return diagnostics.

Inputs:
- dzdt_event_guided.nc
- dzdt_continuous.nc
- dzdt_difference.nc
- dzdt_integrated_difference.nc
- zerocontour_velocity_diagnostics.nc
- diagnostic_summary.csv

Outputs:
- fig_dzdt_maps.png
- fig_integrated_stability_diff.png
- fig_zerocontour_velocity.png
- fig_diagnostics_timeseries.png
"""

import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

# ============================================================
# LOAD DATA
# ============================================================

# Gridded ΔZeff/Δt fields: event-guided, continuous, and their difference,
# opened in that order.
dz_evt, dz_con, dz_dif = (
    xr.open_dataset(path)
    for path in (
        "dzdt_event_guided.nc",
        "dzdt_continuous.nc",
        "dzdt_difference.nc",
    )
)

# Time-integrated difference field, plotted directly as a map in Figure 2.
dz_int = xr.open_dataset("dzdt_integrated_difference.nc")

# NOTE(review): `zc` is not referenced by any figure in this script; Figure 3
# reads its velocities from the summary CSV. It is still opened so that a
# missing file fails exactly as before.
zc = xr.open_dataset("zerocontour_velocity_diagnostics.nc")

# Tabular diagnostics (one row per time) used by Figures 3 and 4.
summary = pd.read_csv("diagnostic_summary.csv")

# 2-D coordinate grids for pcolormesh, taken from the event-guided file.
# np.meshgrid(lon, lat) returns arrays of shape (len(lat), len(lon)).
lon, lat = dz_evt["lon"].values, dz_evt["lat"].values
lon2d, lat2d = np.meshgrid(lon, lat)

# Representative timestep: the middle index along "step".
MID_STEP = len(dz_evt["step"]) // 2

# ============================================================
# FIGURE 1 — ΔZeff / Δt MAP COMPARISON
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)

# Fields at the representative timestep. Positional indexing assumes the
# leading dimension of each variable is the step axis — TODO confirm.
evt_field = dz_evt["dzdt_event_guided"][MID_STEP]
con_field = dz_con["dzdt_continuous"][MID_STEP]
dif_field = dz_dif["dzdt_difference"][MID_STEP]

# The two model fields share one symmetric colour limit that is at least as
# wide as each field's own 95th-percentile |value|. This keeps them directly
# comparable without saturating the continuous field (the previous limit was
# taken from the event-guided field alone).
vlim = np.nanmax([
    np.nanpercentile(np.abs(np.asarray(evt_field)), 95),
    np.nanpercentile(np.abs(np.asarray(con_field)), 95),
])
# The difference gets its own symmetric limit so its structure stays visible
# even when it is small relative to the fields themselves.
vlim_dif = np.nanpercentile(np.abs(np.asarray(dif_field)), 95)

panels = (
    (evt_field, vlim, "ΔZeff/Δt (event-guided)"),
    (con_field, vlim, "ΔZeff/Δt (continuous)"),
    (dif_field, vlim_dif, "ΔZeff/Δt difference"),
)
for ax, (field, lim, title) in zip(axes, panels):
    mesh = ax.pcolormesh(
        lon2d, lat2d, field,
        cmap="RdBu_r", vmin=-lim, vmax=lim,
    )
    ax.set_title(title)
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    plt.colorbar(mesh, ax=ax, label="m / yr")

plt.savefig("fig_dzdt_maps.png", dpi=300)
plt.close()

# ============================================================
# FIGURE 2 — TIME-INTEGRATED STABILITY DIFFERENCE
# ============================================================

fig, ax = plt.subplots(figsize=(6, 4))

integrated = dz_int["dzdt_integrated_difference"]
# Symmetric colour range at the 95th percentile of |value| (NaNs ignored), so
# a few extreme cells do not compress the colour scale for the rest of the map.
int_limit = np.nanpercentile(np.abs(integrated), 95)

mesh = ax.pcolormesh(
    lon2d, lat2d, integrated,
    cmap="RdBu_r", vmin=-int_limit, vmax=int_limit,
)

ax.set(
    title="Time-integrated |ΔZeff/Δt| difference\n(event − continuous)",
    xlabel="Longitude",
    ylabel="Latitude",
)
fig.colorbar(mesh, ax=ax, label="m")

fig.savefig("fig_integrated_stability_diff.png", dpi=300)
plt.close(fig)

# ============================================================
# FIGURE 3 — ZERO-CONTOUR MIGRATION VELOCITY
# ============================================================

fig, ax = plt.subplots(figsize=(7, 4))

# (summary column, line style) per series; the plotting order is preserved so
# the colour cycle assigns the same colours as before.
velocity_series = (
    ("zerocontour_velocity_event_deg_per_yr",
     {"label": "Event-guided", "linewidth": 2}),
    ("zerocontour_velocity_continuous_deg_per_yr",
     {"label": "Continuous", "linewidth": 2}),
    ("zerocontour_velocity_diff_deg_per_yr",
     {"label": "Difference", "linestyle": "--"}),
)
time_ka = summary["time_ka"]
for column, line_style in velocity_series:
    ax.plot(time_ka, summary[column], **line_style)

# Older ages on the left: the x-axis runs in ka BP.
ax.invert_xaxis()
ax.set(
    xlabel="ka BP",
    ylabel="Zero-contour velocity (deg / yr)",
    title="Equilibrium margin migration rate",
)
ax.grid(alpha=0.3)
ax.legend(frameon=False)

fig.savefig("fig_zerocontour_velocity.png", dpi=300)
plt.close(fig)

# ============================================================
# FIGURE 4 — GLOBAL DIAGNOSTIC TIMESERIES
# ============================================================

fig, axes = plt.subplots(2, 1, figsize=(7, 6), sharex=True)

time_ka = summary["time_ka"]

axes[0].plot(time_ka, summary["rms_dzdt_diff_m_per_yr"], linewidth=2)
axes[0].set_ylabel("RMS ΔZeff/Δt diff (m/yr)")
axes[0].set_title("Global ΔZeff/Δt divergence")

axes[1].plot(time_ka, summary["p95_dzdt_diff_m_per_yr"], linewidth=2)
axes[1].set_ylabel("95th percentile |ΔZeff/Δt| diff (m/yr)")
axes[1].set_xlabel("ka BP")

for ax in axes:
    ax.grid(alpha=0.3)

# Both panels share ONE x-axis (sharex=True), so inverting it once flips both.
# The previous code called invert_xaxis() on each panel, which toggled the
# shared axis twice and left it un-inverted, unlike Figure 3. The guard makes
# this idempotent.
if not axes[1].xaxis_inverted():
    axes[1].invert_xaxis()

plt.tight_layout()
plt.savefig("fig_diagnostics_timeseries.png", dpi=300)
plt.close()

print("✓ Diagnostic figures generated")
