#!/usr/bin/env python3
"""
Combined diagnostics for TPW return models:

1. Effective elevation change rate (ΔZeff / Δt)
2. Zero-contour (equilibrium margin) migration velocity
3. Time-integrated |ΔZeff / Δt| as a stability proxy

Inputs:
- tpw_return_sequence_event_guided_highcadence.nc
- tpw_return_continuous_log.nc

Outputs:
- dzdt_event_guided.nc
- dzdt_continuous.nc
- dzdt_difference.nc
- dzdt_integrated_difference.nc
- zerocontour_velocity_diagnostics.nc (event-guided, continuous and
  difference zero-contour velocities, in a single file)
- diagnostic_summary.csv
"""

import numpy as np
import xarray as xr
import pandas as pd
from scipy.spatial import cKDTree
from skimage import measure

# ============================================================
# INPUT FILES
# ============================================================

NC_EVENT = "tpw_return_sequence_event_guided_highcadence.nc"
NC_CONT  = "tpw_return_continuous_log.nc"

# ============================================================
# LOAD DATA
# ============================================================

ds_evt = xr.open_dataset(NC_EVENT)
ds_con = xr.open_dataset(NC_CONT)

Z_evt = ds_evt["effective_elevation"].values
Z_con = ds_con["effective_elevation"].values

# Grid and time axis are taken from the event-guided file. The continuous run
# is differenced against it element-wise, so it must share them (checked below).
lat = ds_evt["lat"].values
lon = ds_evt["lon"].values
time_ka = ds_evt["time_ka"].values

# 2-D coordinate grids (not referenced by the diagnostics below).
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")

# --- Consistency checks ------------------------------------------------------
# Every diagnostic subtracts the continuous run from the event-guided run
# element-wise. A shape mismatch would otherwise fail with an opaque
# broadcasting error further down. For a single-step field it would broadcast
# silently and produce meaningless differences.
if Z_evt.shape != Z_con.shape:
    raise ValueError(
        f"effective_elevation shapes differ: {NC_EVENT} {Z_evt.shape} vs "
        f"{NC_CONT} {Z_con.shape}"
    )

# Where the continuous file carries its own axes, they must match as well.
# Otherwise its rates would be computed with the event run's timestep or
# compared on a different grid.
for _name, _ref in (("time_ka", time_ka), ("lat", lat), ("lon", lon)):
    if _name in ds_con:
        _other = ds_con[_name].values
        if _other.shape != _ref.shape or not np.allclose(_other, _ref):
            raise ValueError(
                f"'{_name}' differs between {NC_EVENT} and {NC_CONT}"
            )

if time_ka.size < 2:
    raise ValueError("at least two time records are needed to form ΔZeff/Δt")

# Uniform timestep (years). time_ka is in ka, hence the factor 1000.
# dt_years is positive when time_ka decreases from one record to the next.
# The downstream diagnostics apply this single value to every step, so a
# non-uniform axis would silently yield wrong rates. Reject it instead.
_steps_ka = time_ka[:-1] - time_ka[1:]
if _steps_ka[0] == 0 or not np.allclose(
    _steps_ka, _steps_ka[0], rtol=1e-3, atol=0.0
):
    raise ValueError(
        "time_ka must be uniformly spaced with a non-zero step; "
        f"found steps between {_steps_ka.min()} and {_steps_ka.max()} ka"
    )
dt_years = (time_ka[0] - time_ka[1]) * 1000.0

# ============================================================
# 1. ΔZeff / Δt DIAGNOSTIC
# ============================================================

# Forward-difference rate for each interval between consecutive records.
# Step k spans record k -> k+1, so these arrays have one fewer step than Z.
dzdt_evt = np.diff(Z_evt, axis=0) / dt_years
dzdt_con = np.diff(Z_con, axis=0) / dt_years
dzdt_diff = dzdt_evt - dzdt_con

# Time-integrated magnitude (stability proxy): sum_k |dZ/dt|_k * |dt|, i.e. the
# total absolute change in Z over the record at each grid point.
# abs(dt_years) keeps this a non-negative magnitude even when time_ka is stored
# in ascending order (dt_years < 0). The signed value would flip the proxy's sign.
dzdt_int_evt = np.sum(np.abs(dzdt_evt), axis=0) * abs(dt_years)
dzdt_int_con = np.sum(np.abs(dzdt_con), axis=0) * abs(dt_years)
dzdt_int_diff = dzdt_int_evt - dzdt_int_con

# ============================================================
# 2. ZERO-CONTOUR EXTRACTION
# ============================================================

def extract_zero_contour(z, lat, lon):
    """
    Trace the Z = 0 iso-line of a 2-D field and express it in (lat, lon).

    ``skimage.measure.find_contours`` returns contour vertices in fractional
    (row, column) index space. Each vertex is mapped to coordinates by linear
    interpolation of the ``lat`` (rows) and ``lon`` (columns) arrays.

    Returns an (N, 2) array of [lat, lon] points, with all contour segments
    stacked together. Returns None when find_contours yields no contours.
    """
    row_index = np.arange(len(lat))
    col_index = np.arange(len(lon))

    segments = [
        np.column_stack([
            np.interp(seg[:, 0], row_index, lat),
            np.interp(seg[:, 1], col_index, lon),
        ])
        for seg in measure.find_contours(z, level=0.0)
    ]
    return np.vstack(segments) if segments else None

def contour_migration_speed(contours, dt):
    """
    Mean nearest-neighbour contour migration speed per timestep.

    For each consecutive pair (contours[i], contours[i+1]), every point of the
    later contour is matched to its nearest point on the earlier contour. The
    mean of those distances is then divided by the timestep.

    Parameters
    ----------
    contours : sequence of (N, 2) arrays or None
        Contour point sets, one per time record, in the same units as the
        returned distances (here (lat, lon) degrees). None marks a record
        without a contour.
    dt : float
        Timestep between consecutive records, in years. Only its magnitude is
        used, so speeds are non-negative regardless of the time ordering.

    Returns
    -------
    numpy.ndarray
        Length ``len(contours) - 1`` (empty for fewer than two records), in
        degrees per year. The entry is NaN where either contour of the pair
        is missing.

    Notes
    -----
    Distances are Euclidean in (lat, lon) degree space. Longitude degrees are
    not scaled by cos(lat), and longitude wrap-around is not handled.
    """
    step = abs(dt)  # a speed is a magnitude; dt may be negative for ascending ka
    speeds = []
    for prev, curr in zip(contours[:-1], contours[1:]):
        if prev is None or curr is None:
            speeds.append(np.nan)
            continue
        dist, _ = cKDTree(prev).query(curr)
        speeds.append(np.nanmean(dist) / step)
    return np.array(speeds)

# Zero contour of every time record, for each run (None where a record has none)
contours_evt = [extract_zero_contour(z_slice, lat, lon) for z_slice in Z_evt]
contours_con = [extract_zero_contour(z_slice, lat, lon) for z_slice in Z_con]

# Per-interval migration velocities of each run and their
# event-minus-continuous difference
zc_vel_evt = contour_migration_speed(contours_evt, dt_years)
zc_vel_con = contour_migration_speed(contours_con, dt_years)
zc_vel_diff = zc_vel_evt - zc_vel_con

# ============================================================
# WRITE NETCDF OUTPUTS
# ============================================================

def write_nc(fname, varname, data, units, desc):
    """
    Write a single (step, lat, lon) field to its own netCDF file.

    Parameters
    ----------
    fname : str
        Output path.
    varname : str
        Name of the data variable in the file.
    data : numpy.ndarray
        Array of shape (n_step, len(lat), len(lon)). The ``step`` coordinate
        is simply 0 .. n_step-1.
    units : str
        Units string for ``data``.
    desc : str
        Human-readable description of ``data``.

    Notes
    -----
    ``units`` and ``desc`` are attached to the data variable itself, as
    ``units`` and ``long_name``. That is where netCDF/CF tooling and
    ``ds[varname].attrs`` look for them. They are also kept as global
    attributes (``units`` / ``description``) for backward compatibility.
    The module-level ``lat`` and ``lon`` arrays supply the horizontal grid.
    """
    variable_attrs = {"units": units, "long_name": desc}
    ds = xr.Dataset(
        data_vars={varname: (("step", "lat", "lon"), data, variable_attrs)},
        coords=dict(
            step=np.arange(data.shape[0]),
            lat=lat,
            lon=lon,
        ),
        attrs=dict(description=desc, units=units),
    )
    ds.to_netcdf(fname)

# Per-step rate fields: (output file, variable name, data, description)
_rate_outputs = (
    ("dzdt_event_guided.nc", "dzdt_event_guided", dzdt_evt,
     "ΔZeff/Δt (event-guided)"),
    ("dzdt_continuous.nc", "dzdt_continuous", dzdt_con,
     "ΔZeff/Δt (continuous)"),
    ("dzdt_difference.nc", "dzdt_difference", dzdt_diff,
     "ΔZeff/Δt difference"),
)
for _fname, _varname, _field, _desc in _rate_outputs:
    write_nc(_fname, _varname, _field, "m/yr", _desc)

# Time-integrated |ΔZeff/Δt| difference: a single (lat, lon) map
_integrated_ds = xr.Dataset(
    data_vars={
        "dzdt_integrated_difference": (("lat", "lon"), dzdt_int_diff),
    },
    coords={"lat": lat, "lon": lon},
    attrs={
        "description": "Time-integrated |ΔZeff/Δt| difference (event - continuous)",
        "units": "m",
    },
)
_integrated_ds.to_netcdf("dzdt_integrated_difference.nc")

# Zero-contour velocities of both runs plus their difference, in one file
_zc_ds = xr.Dataset(
    data_vars={
        "zerocontour_velocity_event": ("step", zc_vel_evt),
        "zerocontour_velocity_continuous": ("step", zc_vel_con),
        "zerocontour_velocity_difference": ("step", zc_vel_diff),
    },
    coords={"step": np.arange(len(zc_vel_evt))},
    attrs={
        "description": "Zero-contour migration velocity diagnostics",
        "units": "deg/yr",
    },
)
_zc_ds.to_netcdf("zerocontour_velocity_diagnostics.nc")

# ============================================================
# SUMMARY METRICS
# ============================================================

# One CSV row per interval k (record k -> k+1), labelled with time_ka[k]:
# spatial RMS and 95th percentile of the rate difference, plus the
# zero-contour velocities of both runs and their difference.
summary = [
    {
        "step": k,
        "time_ka": time_ka[k],
        "rms_dzdt_diff_m_per_yr": np.sqrt(np.nanmean(dzdt_diff[k] ** 2)),
        "p95_dzdt_diff_m_per_yr": np.nanpercentile(np.abs(dzdt_diff[k]), 95),
        "zerocontour_velocity_event_deg_per_yr": zc_vel_evt[k],
        "zerocontour_velocity_continuous_deg_per_yr": zc_vel_con[k],
        "zerocontour_velocity_diff_deg_per_yr": v_diff,
    }
    for k, v_diff in enumerate(zc_vel_diff)
]

pd.DataFrame(summary).to_csv("diagnostic_summary.csv", index=False)

# Final status report listing what was produced
_done_lines = [
    "✓ Combined diagnostics complete",
    "  - ΔZeff/Δt fields and differences",
    "  - Time-integrated stability proxy",
    "  - Zero-contour migration velocities",
    "  - diagnostic_summary.csv",
]
print("\n".join(_done_lines))
