#!/usr/bin/env python3
"""
Diagnostic comparison of effective elevation change rates
between event-guided and continuous TPW return models.

Requires:
- tpw_return_sequence_event_guided_highcadence.nc
- tpw_return_continuous_log.nc

Outputs:
- dzdt_event_guided.nc
- dzdt_continuous.nc
- dzdt_difference.nc
- dzdt_summary.csv
"""

import numpy as np
import xarray as xr
import pandas as pd

# ============================================================
# INPUT FILES
# ============================================================

# Paths of the two model outputs being compared (relative to the working dir).
NC_EVENT, NC_CONT = (
    "tpw_return_sequence_event_guided_highcadence.nc",  # event-guided model
    "tpw_return_continuous_log.nc",                      # continuous logarithmic model
)

# ============================================================
# LOAD DATA
# ============================================================

# Open both model outputs in context managers so the file handles are released
# once the arrays needed below have been read into memory (``.values`` loads).
with xr.open_dataset(NC_EVENT) as ds_evt, xr.open_dataset(NC_CONT) as ds_con:
    Z_evt = ds_evt["effective_elevation"].values
    Z_con = ds_con["effective_elevation"].values

    # The grid and time axis of the event-guided file are reused to label every
    # output and the two fields are differenced element-wise, so the continuous
    # file must agree on any of these coordinates that it also carries.
    for coord_name in ("time_ka", "lat", "lon"):
        if coord_name not in ds_con:
            continue
        ref_vals = ds_evt[coord_name].values
        con_vals = ds_con[coord_name].values
        if ref_vals.shape != con_vals.shape or not np.allclose(ref_vals, con_vals):
            raise ValueError(
                f"{NC_CONT!r} and {NC_EVENT!r} disagree on coordinate {coord_name!r}"
            )

    lat = ds_evt["lat"].values
    lon = ds_evt["lon"].values
    time_ka = ds_evt["time_ka"].values

if Z_evt.shape != Z_con.shape:
    raise ValueError(
        f"effective_elevation shapes differ: {Z_evt.shape} (event-guided) "
        f"vs {Z_con.shape} (continuous)"
    )
# Axis 0 of effective_elevation is differenced as time below, so it must line
# up with time_ka, and at least two samples are needed to form a difference.
if time_ka.size < 2 or Z_evt.shape[0] != time_ka.size:
    raise ValueError(
        f"need >= 2 time samples matching axis 0 of effective_elevation; "
        f"got time_ka of size {time_ka.size} and field shape {Z_evt.shape}"
    )

# Time step in years. time_ka is presumably "thousands of years before
# present", so (earlier - later) * 1000 is the forward-time step; the sign
# makes the rates forward-in-time for either ordering of time_ka.
# The downstream code divides by a single scalar, so enforce the uniform
# spacing it relies on instead of trusting only the first interval.
step_lengths_years = (time_ka[:-1] - time_ka[1:]) * 1000.0
dt_years = step_lengths_years[0]
if dt_years == 0 or not np.allclose(step_lengths_years, dt_years):
    raise ValueError("time_ka must be uniformly spaced with a non-zero step")

# ============================================================
# COMPUTE ΔZ/Δt
# ============================================================

# Forward differences along axis 0 (time), divided by the step length in
# years, giving per-interval rates of change for each model.
dzdt_evt, dzdt_con = (np.diff(z, axis=0) / dt_years for z in (Z_evt, Z_con))

# Signed difference: positive where the event-guided rate exceeds the
# continuous one.
dzdt_diff = np.subtract(dzdt_evt, dzdt_con)

# ============================================================
# WRITE NETCDF OUTPUTS
# ============================================================

def _write_rate_netcdf(path, var_name, rates, description, lat, lon):
    """Write one (step, lat, lon) rate field to a NetCDF file.

    ``step`` indexes the forward difference between consecutive samples of
    the input time axis. ``units`` and a descriptive ``long_name`` are set on
    the variable itself, which is where CF-aware readers look for them. The
    original dataset-level ``description``/``units`` attributes are kept for
    backward compatibility.
    """
    xr.Dataset(
        data_vars={
            var_name: (
                ("step", "lat", "lon"),
                rates,
                {"long_name": description, "units": "m/yr"},
            ),
        },
        coords={
            "step": np.arange(rates.shape[0]),
            "lat": lat,
            "lon": lon,
        },
        attrs={"description": description, "units": "m/yr"},
    ).to_netcdf(path)


_write_rate_netcdf(
    "dzdt_event_guided.nc",
    "dzdt_event_guided",
    dzdt_evt,
    "Effective elevation change rate (event-guided)",
    lat,
    lon,
)
_write_rate_netcdf(
    "dzdt_continuous.nc",
    "dzdt_continuous",
    dzdt_con,
    "Effective elevation change rate (continuous logarithmic)",
    lat,
    lon,
)
_write_rate_netcdf(
    "dzdt_difference.nc",
    "dzdt_difference",
    dzdt_diff,
    "ΔZ/Δt difference (event-guided minus continuous)",
    lat,
    lon,
)

# ============================================================
# GLOBAL SUMMARY METRICS
# ============================================================

def _step_metrics(diff):
    """Return NaN-aware RMS, 95th-percentile |diff| and max |diff| for one slice.

    ``diff`` is a single step of the rate-difference field; NaN cells are
    ignored by every statistic.
    """
    abs_diff = np.abs(diff)
    return {
        "rms_diff_m_per_yr": np.sqrt(np.nanmean(diff**2)),
        "p95_abs_diff_m_per_yr": np.nanpercentile(abs_diff, 95),
        "max_abs_diff_m_per_yr": np.nanmax(abs_diff),
    }


# Step ``i`` is the forward difference from time_ka[i] to time_ka[i + 1].
# "time_ka" keeps its original meaning (interval start); "time_ka_end" is
# appended as the last column so existing column positions are unchanged.
summary = [
    {
        "step": i,
        "time_ka": time_ka[i],
        **_step_metrics(diff),
        "time_ka_end": time_ka[i + 1],
    }
    for i, diff in enumerate(dzdt_diff)
]

df = pd.DataFrame(summary)
df.to_csv("dzdt_summary.csv", index=False)

print("✓ ΔZ/Δt diagnostics complete")
for output_path in (
    "dzdt_event_guided.nc",
    "dzdt_continuous.nc",
    "dzdt_difference.nc",
    "dzdt_summary.csv",
):
    print(f"  - {output_path}")
