#!/usr/bin/env python3
import argparse
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ---------------------------------------------------------------------
# Robust finite-difference slope estimator (dI / d log L)
# ---------------------------------------------------------------------
def finite_difference_slopes(L, I):
    """
    Compute dI/d(log L) at every scale using finite differences.

    For each index ``i`` the estimator tries, in order:

    1. a central difference between ``i-1`` and ``i+1``,
    2. a forward difference between ``i`` and ``i+1``,
    3. a backward difference between ``i-1`` and ``i``,

    and keeps the first one that is well defined.  A difference is well
    defined only when both endpoints have a finite ``I`` value and the
    log-scale step between them is finite and non-zero.  Duplicate scales
    (zero log step), missing values (NaN) and non-positive scales (log
    undefined) therefore fall through to the next option instead of
    producing NaN / inf / spurious-zero slopes.

    The input is assumed to be ordered by scale already (the callers in this
    script sort by SCALE_KM); it is not reordered here.

    Parameters
    ----------
    L : array-like
        Scale values.  Entries that are <= 0 or NaN have no logarithm and
        are treated as missing.
    I : array-like
        Observed quantity at each scale, same length as ``L``.

    Returns
    -------
    numpy.ndarray
        One slope per scale value; NaN where no valid difference exists
        (every entry when fewer than two scales are supplied).

    Raises
    ------
    ValueError
        If ``L`` and ``I`` do not have the same shape.
    """
    L = np.asarray(L, dtype=float)
    I = np.asarray(I, dtype=float)
    if L.shape != I.shape:
        raise ValueError(
            f"L and I must have the same shape, got {L.shape} and {I.shape}"
        )

    # log(L) is undefined for L <= 0.  Map such scales to NaN rather than
    # -inf: an infinite log step would otherwise turn any difference into a
    # slope of exactly 0 instead of flagging it as unusable.
    with np.errstate(divide="ignore", invalid="ignore"):
        logL = np.where(L > 0, np.log(L), np.nan)

    slopes = np.full(I.shape, np.nan)

    n = len(L)
    if n < 2:
        return slopes

    def _slope(a, b):
        """Return (I[b] - I[a]) / (logL[b] - logL[a]), or None if undefined."""
        dlog = logL[b] - logL[a]
        dI = I[b] - I[a]
        if np.isfinite(dlog) and dlog != 0 and np.isfinite(dI):
            return dI / dlog
        return None

    for i in range(n):
        # Candidate (left, right) index pairs in order of preference:
        # central, then forward, then backward.
        candidates = []
        if 0 < i < n - 1:
            candidates.append((i - 1, i + 1))
        if i < n - 1:
            candidates.append((i, i + 1))
        if i > 0:
            candidates.append((i - 1, i))

        for a, b in candidates:
            s = _slope(a, b)
            if s is not None:
                slopes[i] = s
                break

    return slopes


# ---------------------------------------------------------------------
# Load and merge coherence + Moran’s outputs
# ---------------------------------------------------------------------
def load_data(coherence_path, morans_path):
    """
    Read the coherence and Moran's I CSVs and join them per run and region.

    A derived column ``I_EXCESS_MEAN = I_OBS - I_NULL_MEAN`` is added to the
    Moran's table.  Then only ``RUN_ID``, ``REGION``, ``SCALE_KM``, ``I_OBS``,
    ``I_NULL_MEAN`` and ``I_EXCESS_MEAN`` are inner-joined onto the coherence
    table using the keys ``RUN_ID`` and ``REGION``.  Rows without a match on
    both sides are dropped.

    NOTE(review): if the coherence CSV also carries any of the non-key Moran
    columns (e.g. SCALE_KM), pandas will suffix both copies with _x/_y and
    downstream lookups of the bare name will fail. Confirm the coherence
    schema.

    Parameters
    ----------
    coherence_path : str
        Path to the coherence CSV.
    morans_path : str
        Path to the Moran's I CSV.

    Returns
    -------
    pandas.DataFrame
        The merged table.
    """
    key_cols = ["RUN_ID", "REGION"]
    moran_cols = key_cols + ["SCALE_KM", "I_OBS", "I_NULL_MEAN", "I_EXCESS_MEAN"]

    coherence = pd.read_csv(coherence_path)
    morans = pd.read_csv(morans_path)

    # assign() returns a new frame, so the freshly read table is not mutated.
    morans = morans.assign(I_EXCESS_MEAN=morans["I_OBS"] - morans["I_NULL_MEAN"])

    return coherence.merge(morans[moran_cols], on=key_cols, how="inner")


# ---------------------------------------------------------------------
# Global scatter — baseline highlighted
# ---------------------------------------------------------------------
def plot_global(df, outdir):
    """
    Scatter SLOPE_MEAN against I_EXCESS_MEAN for every row, highlighting baseline.

    Rows whose ``ROLE`` is not ``"BASELINE"`` are drawn as grey crosses
    labelled "NULL" in the legend. ``BASELINE`` rows are drawn on top as red
    dots. A horizontal reference line marks zero excess. The figure is saved
    to ``<outdir>/global_coherence_vs_I.png`` and then closed.

    NOTE(review): the x axis is labelled as the finite-difference slope, but
    the plotted column is SLOPE_MEAN from the coherence CSV. Confirm that
    the two are the same quantity.
    """
    is_baseline = df["ROLE"] == "BASELINE"

    fig, ax = plt.subplots(figsize=(10, 6))

    # Draw order matters for the legend: non-baseline first, baseline second.
    layers = (
        (
            df[~is_baseline],
            dict(c="0.5", marker="x", alpha=1.0, s=26, label="NULL", zorder=1),
        ),
        (
            df[is_baseline],
            dict(c="red", s=50, zorder=4, label="BASELINE"),
        ),
    )
    for subset, style in layers:
        ax.scatter(subset["SLOPE_MEAN"], subset["I_EXCESS_MEAN"], **style)

    ax.axhline(0, color="black", lw=0.8)
    ax.set_title("Global coherence vs spatial clustering strength (per-scale)")
    ax.set_xlabel("Coherence (finite-difference slope dI / d log L)")
    ax.set_ylabel("Moran’s I excess vs null (I_obs − I_null_mean)")
    ax.legend(frameon=False)

    fig.tight_layout()
    fig.savefig(f"{outdir}/global_coherence_vs_I.png", dpi=220)
    plt.close(fig)


# ---------------------------------------------------------------------
# Combined mean baseline trajectory (BLUE LINE)
# ---------------------------------------------------------------------
def plot_combined_mean(df, outdir):
    """
    Plot the scale-averaged baseline trajectory in coherence / Moran's I space.

    Rows with ``ROLE == "BASELINE"`` are averaged per ``SCALE_KM`` (columns
    ``I_OBS`` and ``I_EXCESS_MEAN``) and ordered by scale.  The coherence
    slope dI/d(log L) is computed from the averaged ``I_OBS`` curve via
    ``finite_difference_slopes``.  It is plotted against the averaged
    ``I_EXCESS_MEAN`` as a blue line, and each point is annotated with its
    scale.

    Parameters
    ----------
    df : pandas.DataFrame
        Merged table (as produced by ``load_data``) with ``ROLE``,
        ``SCALE_KM``, ``I_OBS`` and ``I_EXCESS_MEAN`` columns.
    outdir : str
        Directory that receives ``combined_mean_coherence_vs_I.png``.
    """

    def _scale_label(scale):
        """Format a scale for annotation without truncating fractional km."""
        scale = float(scale)
        # int() alone turned e.g. 0.5 into "0 km"; integer scales keep the
        # exact same label as before.
        return f"{int(scale)} km" if scale.is_integer() else f"{scale:g} km"

    base = df[df["ROLE"] == "BASELINE"].copy()

    mean_curve = (
        base.groupby("SCALE_KM", as_index=False)[["I_OBS", "I_EXCESS_MEAN"]]
        .mean()
        .sort_values("SCALE_KM")
    )

    # finite_difference_slopes expects scale-ordered input, hence the sort above.
    mean_curve["COHERENCE_SLOPE"] = finite_difference_slopes(
        mean_curve["SCALE_KM"].values,
        mean_curve["I_OBS"].values,
    )

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(
        mean_curve["COHERENCE_SLOPE"],
        mean_curve["I_EXCESS_MEAN"],
        "-o",
        color="blue",
        lw=2.2,
        ms=6,
        zorder=4,
        label="Combined Mean (Baseline)",
    )

    for _, r in mean_curve.iterrows():
        x, y = r["COHERENCE_SLOPE"], r["I_EXCESS_MEAN"]
        # The slope is NaN where no finite difference exists (e.g. only one
        # scale). There is no finite position to anchor such a label at.
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        ax.text(
            x,
            y,
            _scale_label(r["SCALE_KM"]),
            fontsize=9,
            color="blue",
            ha="left",
            va="bottom",
        )

    ax.axhline(0, color="black", lw=0.8)
    ax.set_title("Combined mean coherence vs Moran’s I (baseline only)")
    ax.set_xlabel("Coherence (finite-difference slope dI / d log L)")
    ax.set_ylabel("Moran’s I excess vs null")

    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(f"{outdir}/combined_mean_coherence_vs_I.png", dpi=220)
    plt.close(fig)


# ---------------------------------------------------------------------
# Regional plot — baseline trajectory vs null ensemble
# ---------------------------------------------------------------------
def plot_region(df_region, outdir):
    """
    Plot one region's baseline trajectory against its non-baseline ensemble.

    Non-baseline rows (``ROLE != "BASELINE"``) are scattered as grey crosses
    (``SLOPE_MEAN`` vs ``I_EXCESS_MEAN``) and labelled "NULL".  Baseline rows
    are ordered by ``SCALE_KM``, and their coherence slope dI/d(log L) is
    computed from ``I_OBS`` with ``finite_difference_slopes``.  The slope is
    drawn as a red line against ``I_EXCESS_MEAN``, with each point annotated
    with its scale.  The figure is saved as
    ``<outdir>/<region>_coherence_vs_I.png``.

    NOTE(review): the NULL points use SLOPE_MEAN from the coherence CSV, while
    the baseline line uses the slope recomputed from I_OBS. Confirm the two
    are comparable on the same axis.
    NOTE(review): if a region holds several BASELINE runs, their rows
    interleave after sorting and the finite differences mix runs. Presumably
    there is one baseline run per region; verify.

    Parameters
    ----------
    df_region : pandas.DataFrame
        Rows of the merged table for a single REGION (non-empty).
    outdir : str
        Directory that receives the PNG.
    """

    def _scale_label(scale):
        """Format a scale for annotation without truncating fractional km."""
        scale = float(scale)
        # int() alone turned e.g. 0.5 into "0 km"; integer scales keep the
        # exact same label as before.
        return f"{int(scale)} km" if scale.is_integer() else f"{scale:g} km"

    region = df_region["REGION"].iloc[0]

    # Rows without a scale cannot be placed on the log-scale axis. Left in,
    # they would turn neighbouring slopes into NaN and make int(NaN) raise
    # while labelling, so drop them up front.
    baseline = (
        df_region[
            (df_region["ROLE"] == "BASELINE") & df_region["SCALE_KM"].notna()
        ]
        .sort_values("SCALE_KM")
        .copy()
    )
    peers = df_region[df_region["ROLE"] != "BASELINE"]

    baseline["COHERENCE_SLOPE"] = finite_difference_slopes(
        baseline["SCALE_KM"].values,
        baseline["I_OBS"].values,
    )

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.scatter(
        peers["SLOPE_MEAN"],
        peers["I_EXCESS_MEAN"],
        s=26,
        c="0.5",
        marker="x",
        alpha=1.0,
        label="NULL",
        zorder=1,
    )

    ax.plot(
        baseline["COHERENCE_SLOPE"],
        baseline["I_EXCESS_MEAN"],
        "-o",
        color="red",
        lw=2.0,
        ms=6,
        zorder=4,
        label="BASELINE",
    )

    for _, r in baseline.iterrows():
        x, y = r["COHERENCE_SLOPE"], r["I_EXCESS_MEAN"]
        # The slope is NaN where no finite difference exists (e.g. only one
        # scale). There is no finite position to anchor such a label at.
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        ax.text(
            x,
            y,
            _scale_label(r["SCALE_KM"]),
            fontsize=9,
            color="red",
            ha="left",
            va="bottom",
        )

    ax.axhline(0, color="black", lw=0.8)
    ax.set_title(f"{region}: Coherence vs Moran’s I (per-scale)")
    ax.set_xlabel("Coherence (finite-difference slope dI / d log L)")
    ax.set_ylabel("Moran’s I excess vs null")

    ax.legend(frameon=False)
    fig.tight_layout()
    # REGION is free text from the CSV. A path separator in it would make
    # savefig target a sub-directory that does not exist.
    safe_region = str(region).replace("/", "_").replace(os.sep, "_")
    fig.savefig(f"{outdir}/{safe_region}_coherence_vs_I.png", dpi=220)
    plt.close(fig)


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
def main():
    """
    Command-line entry point.

    Parses ``--coherence``, ``--morans`` and ``--output`` (all required) and
    creates the output directory if needed.  It then loads and merges the two
    CSVs, prints the ROLE counts, and writes the global, combined-mean and
    per-region plots into the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Compare coherence descriptors vs Moran’s I clustering strength"
    )
    for flag in ("--coherence", "--morans", "--output"):
        parser.add_argument(flag, required=True)

    args = parser.parse_args()
    outdir = args.output
    os.makedirs(outdir, exist_ok=True)

    df = load_data(args.coherence, args.morans)

    print("\nDetected role distribution:\n", df["ROLE"].value_counts(), "\n")

    # Whole-table plots first, then one figure per region.
    for plotter in (plot_global, plot_combined_mean):
        plotter(df, outdir)

    for _region, region_rows in df.groupby("REGION"):
        plot_region(region_rows, outdir)

    print(f"\nPlots written to {os.path.abspath(outdir)}")


if __name__ == "__main__":
    main()
