#!/usr/bin/env python3
import argparse
import os

import numpy as np
import pandas as pd
from scipy.signal import savgol_filter

# Minimum number of scale points a (REGION, RUN_ID) group must have; groups
# with fewer are written with STATUS="INSUFFICIENT_POINTS" instead of metrics.
MIN_POINTS = 3  # require ≥3 wavelength scales


def _count_sign_changes(values):
    """Count strict sign reversals in *values*, skipping zeros and NaNs.

    Zeros are skipped so that ``+, 0, +`` is no reversal and ``+, 0, -`` is
    exactly one; NaNs are skipped because ``nan != nan`` would otherwise make
    every NaN look like a reversal.
    """
    signs = np.sign(np.asarray(values, float))
    signs = signs[np.isfinite(signs) & (signs != 0)]
    return int(np.count_nonzero(signs[1:] != signs[:-1]))


def curvature_metrics(L, I):
    """Compute slope, curvature, monotonicity, plateau index, inflections.

    Parameters
    ----------
    L : array-like
        Scale coordinate of each sample. It is rescaled to [0, 1] internally,
        so derivatives are expressed per unit of normalized scale.
        Assumed to be sorted by the caller; it is not re-sorted here.
    I : array-like
        Profile value at each scale, same length as ``L``.

    Returns
    -------
    dict
        ``SLOPE_MEAN`` / ``SLOPE_VAR``: mean and variance of the first
        derivative. ``CURVATURE_MEAN`` / ``CURVATURE_STD``: mean and standard
        deviation of the second derivative. ``MONOTONICITY``: 1 minus the
        fraction of adjacent first-derivative pairs that reverse sign.
        ``PLATEAU_INDEX``: fraction of samples whose |slope| is below the
        profile's own 25th percentile of |slope|. ``INFLECTIONS``: number of
        sign reversals of the second derivative. ``I_RANGE``: max - min of
        the usable ``I`` values.

        Samples where ``L`` or ``I`` is not finite are ignored. With fewer
        than two usable samples every derivative-based entry is NaN.
    """
    L = np.asarray(L, float)
    I = np.asarray(I, float)

    # Non-finite samples would propagate through the smoothing and gradients
    # and invalidate every metric, so drop them up front.
    usable = np.isfinite(L) & np.isfinite(I)
    L, I = L[usable], I[usable]

    if L.size < 2:
        # np.gradient needs at least two samples; report "no shape" instead
        # of raising so batch callers keep going.
        return dict(
            SLOPE_MEAN=np.nan,
            SLOPE_VAR=np.nan,
            CURVATURE_MEAN=np.nan,
            CURVATURE_STD=np.nan,
            MONOTONICITY=np.nan,
            PLATEAU_INDEX=np.nan,
            INFLECTIONS=np.nan,
            I_RANGE=float(I.max() - I.min()) if I.size else np.nan,
        )

    # normalize wavelength spacing (epsilon guards a zero-width range)
    x = (L - L.min()) / (L.max() - L.min() + 1e-12)

    # smoothing only when enough points exist for the 5-sample window
    # NOTE(review): savgol_filter treats samples as evenly spaced — confirm
    # this is acceptable for the scale grids used.
    if len(I) >= 5:
        I_sm = savgol_filter(I, window_length=5, polyorder=2, mode="interp")
    else:
        I_sm = I.copy()

    d1 = np.gradient(I_sm, x, edge_order=1)
    d2 = np.gradient(d1, x, edge_order=1)

    # monotonicity: sign consistency in first derivative
    monotonicity = 1.0 - _count_sign_changes(d1) / max(len(d1) - 1, 1)

    # plateau index: fraction where |slope| ≈ 0
    # NOTE(review): the threshold is the profile's own 25th percentile of
    # |slope|, not an absolute tolerance, so a uniformly flat profile scores
    # 0 — confirm this relative definition is intended.
    abs_slope = np.abs(d1)
    plateau = np.mean(abs_slope < np.percentile(abs_slope, 25))

    return dict(
        SLOPE_MEAN=np.nanmean(d1),
        SLOPE_VAR=np.nanvar(d1),
        CURVATURE_MEAN=np.nanmean(d2),
        CURVATURE_STD=np.nanstd(d2),
        MONOTONICITY=monotonicity,
        PLATEAU_INDEX=plateau,
        # inflections = zero crossings of curvature
        INFLECTIONS=_count_sign_changes(d2),
        I_RANGE=float(I.max() - I.min()),
    )


def _summary_path(out_path):
    """Return the companion summary path ``<root>_summary<ext>`` for *out_path*."""
    # splitext only touches the final extension, so the result always differs
    # from out_path (no overwrite) and directory names are left alone.
    root, ext = os.path.splitext(str(out_path))
    return f"{root}_summary{ext}"


def _summarize_regions(desc):
    """Build the per-region baseline-vs-peer summary table from *desc*."""
    metrics = ("SLOPE_MEAN", "CURVATURE_STD", "MONOTONICITY", "PLATEAU_INDEX")
    columns = ["REGION", "N_PEERS"]
    for col in metrics:
        columns += [f"BASE_{col}", f"ENSEMBLE_{col}", f"DELTA_{col}"]

    summaries = []
    if not desc.empty:
        # Only rows with computed metrics can be compared; this also
        # guarantees the metric columns exist whenever the loop body runs.
        ok = desc[desc["STATUS"] == "OK"]
        for region, g in ok.groupby("REGION", dropna=False):
            base = g[g["ROLE"] == "BASELINE"]
            peers = g[g["ROLE"] == "PEER"]
            if base.empty or peers.empty:
                continue

            b = base.iloc[0]
            row = dict(REGION=region, N_PEERS=len(peers))
            for col in metrics:
                row[f"BASE_{col}"] = float(b[col])
                row[f"ENSEMBLE_{col}"] = float(peers[col].mean())
                row[f"DELTA_{col}"] = float((peers[col] - b[col]).mean())
            summaries.append(row)

    # Explicit columns keep the header present even when no region qualifies.
    return pd.DataFrame(summaries, columns=columns)


def compute_descriptors(df, out_path):
    """Compute per-run shape descriptors and a per-region ensemble summary.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``REGION``, ``RUN_ID``, ``SCALE_KM``,
        ``I_OBS`` and ``Z_SCORE``. Rows are grouped by (REGION, RUN_ID) and
        ordered by ``SCALE_KM`` within each group.
    out_path : str
        Destination CSV for the per-run descriptors. The summary is written
        next to it as ``<root>_summary<ext>``.

    Notes
    -----
    The run whose ``RUN_ID`` is ``"TPW_BASELINE"`` gets ROLE ``BASELINE``;
    every other run is a ``PEER``. Rows where ``SCALE_KM`` or ``I_OBS`` is
    not finite are ignored; a group with fewer than ``MIN_POINTS`` usable
    rows is written with STATUS ``INSUFFICIENT_POINTS`` and no metrics, and
    is excluded from the summary. A region appears in the summary only if it
    has an OK baseline and at least one OK peer.
    """
    rows = []
    for (region, run_id), g in df.groupby(["REGION", "RUN_ID"], dropna=False):
        g = g.sort_values("SCALE_KM")
        L = np.asarray(g["SCALE_KM"], dtype=float)
        I = np.asarray(g["I_OBS"], dtype=float)

        # Count only rows that can actually contribute to the profile.
        usable = np.isfinite(L) & np.isfinite(I)
        L, I = L[usable], I[usable]

        role = "BASELINE" if run_id == "TPW_BASELINE" else "PEER"

        if len(L) < MIN_POINTS:
            rows.append(dict(
                REGION=region, RUN_ID=run_id, ROLE=role,
                STATUS="INSUFFICIENT_POINTS",
            ))
            continue

        Z = np.asarray(g["Z_SCORE"], dtype=float)
        rows.append(dict(
            REGION=region,
            RUN_ID=run_id,
            ROLE=role,
            STATUS="OK",
            N_SCALES=len(L),
            **curvature_metrics(L, I),
            I_MEAN=float(np.mean(I)),
            Z_MEAN=float(np.nanmean(Z)),
        ))

    desc = pd.DataFrame(rows)

    # ---- Baseline vs Ensemble Coherence ----
    summary = _summarize_regions(desc)

    summary_path = _summary_path(out_path)
    desc.to_csv(out_path, index=False)
    summary.to_csv(summary_path, index=False)

    print(f"Saved descriptors → {out_path}")
    print(f"Saved ensemble summary → {summary_path}")


def main():
    """Parse CLI arguments, validate the input CSV, and run the descriptors.

    ``--input`` (required) is the CSV to read; ``--output`` is the descriptor
    CSV path and defaults to ``coherence_shape_descriptors.csv``.

    Raises
    ------
    RuntimeError
        If the input CSV lacks any of the required columns.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", default="coherence_shape_descriptors.csv")
    opts = parser.parse_args()

    frame = pd.read_csv(opts.input)

    # Fail early, before any grouping, if the schema is incomplete.
    required = {"RUN_ID", "REGION", "SCALE_KM", "I_OBS", "Z_SCORE"}
    missing = required - set(frame.columns)
    if missing:
        raise RuntimeError(f"Missing columns: {missing}")

    compute_descriptors(frame, opts.output)


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
