#!/usr/bin/env python3
"""
Adaptive, sparse, multiprocessing Moran’s I for TPW ensemble stress maps.

Features:
• O(N·k) sparse neighbor formulation
• Disk-cached neighbors (per scale)
• Ensemble-safe (per RUN_ID)
• Adaptive permutation stopping
• Multiprocessing (Apple Silicon safe)
• Progress indicators
• Ensemble ranking + confidence intervals
• TPW geometry metadata preserved in outputs
"""

import argparse
import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.spatial import cKDTree
import multiprocessing as mp

EARTH_RADIUS_KM = 6371.0

# ============================================================
# Geometry utilities
# ============================================================

def haversine_km(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in kilometres between two points.

    All four arguments are in degrees and may be scalars or array-likes;
    they are combined with normal NumPy broadcasting, so one point can be
    compared against a whole array of points in a single call.

    The distance is computed with the haversine formula on a sphere of
    radius ``EARTH_RADIUS_KM``.
    """
    phi1, lam1, phi2, lam2 = (
        np.radians(angle) for angle in (lat1, lon1, lat2, lon2)
    )
    sin_half_dphi = np.sin((phi2 - phi1) / 2)
    sin_half_dlam = np.sin((lam2 - lam1) / 2)
    a = sin_half_dphi**2 + np.cos(phi1) * np.cos(phi2) * sin_half_dlam**2
    # Rounding can push `a` marginally outside [0, 1] (e.g. for antipodal
    # points), which would make sqrt/arcsin return NaN. Clamp it first.
    a = np.clip(a, 0.0, 1.0)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def build_neighbors(lat, lon, L_km):
    """Return, for every point, the indices of all other points within L_km.

    Parameters
    ----------
    lat, lon : array-like of float
        Point coordinates in degrees (same length).
    L_km : float
        Neighbourhood radius as a great-circle distance in kilometres.

    Returns
    -------
    list of numpy.ndarray
        One sorted ``int32`` index array per input point. A point is never
        its own neighbour.

    Notes
    -----
    Candidates come from a KD-tree built on 3-D unit vectors rather than on
    raw (lat, lon) degrees. In degree space a fixed search radius misses
    neighbours at mid/high latitudes (a degree of longitude shrinks by
    cos(lat)) and across the +/-180 degree meridian. On the unit sphere, a
    great-circle distance d corresponds exactly to the straight-line chord
    ``2 * sin(d / (2 R))``, so a Euclidean ball query has neither problem.
    """
    lat = np.asarray(lat, dtype=float)
    lon = np.asarray(lon, dtype=float)

    lat_rad = np.radians(lat)
    lon_rad = np.radians(lon)
    cos_lat = np.cos(lat_rad)
    xyz = np.column_stack((
        cos_lat * np.cos(lon_rad),
        cos_lat * np.sin(lon_rad),
        np.sin(lat_rad),
    ))
    tree = cKDTree(xyz)

    # Half the central angle, capped at pi/2: beyond that the radius covers
    # the whole sphere and the chord saturates at the diameter (2).
    half_angle = min(L_km / (2.0 * EARTH_RADIUS_KM), np.pi / 2.0)
    chord = 2.0 * np.sin(half_angle)
    # Slightly inflate the search radius so rounding never drops a point on
    # the boundary; the exact haversine test below makes the final decision.
    search_radius = chord * (1.0 + 1e-9) + 1e-12

    neighbors = []

    for i in tqdm(range(len(lat)),
                  desc=f"Precomputing neighbors (L={int(L_km)} km)"):
        candidates = np.asarray(
            tree.query_ball_point(xyz[i], r=search_radius), dtype=np.intp
        )
        candidates = candidates[candidates != i]
        if candidates.size:
            # One vectorised distance evaluation per point instead of one
            # Python-level call per candidate pair.
            dist = haversine_km(lat[i], lon[i], lat[candidates], lon[candidates])
            candidates = candidates[dist <= L_km]
        neighbors.append(np.sort(candidates).astype(np.int32))

    return neighbors


# ============================================================
# Moran’s I (sparse)
# ============================================================

def morans_I_sparse(values, neighbors):
    """Compute global Moran's I with binary weights from neighbour lists.

    Parameters
    ----------
    values : array-like of float, shape (N,)
        Observed value at each point.
    neighbors : sequence of N integer index arrays
        ``neighbors[i]`` holds the indices of the points adjacent to point
        ``i``; every listed pair carries weight 1.

    Returns
    -------
    float
        ``(N / W) * sum_i v_i * sum_{j in nbrs(i)} v_j / sum_i v_i**2`` where
        ``v`` is ``values`` minus its mean and ``W`` is the total number of
        neighbour links. Returns NaN when the statistic is undefined, i.e.
        when there are no links at all or the values have zero variance.

    Raises
    ------
    ValueError
        If ``neighbors`` does not have one entry per value.
    """
    values = np.asarray(values, dtype=float)
    if len(neighbors) != len(values):
        raise ValueError(
            f"values has {len(values)} entries but neighbors has "
            f"{len(neighbors)}; they must describe the same points"
        )

    v = values - values.mean()
    num = 0.0
    wsum = 0

    for i, nbrs in enumerate(neighbors):
        degree = len(nbrs)
        if degree == 0:
            continue
        num += v[i] * v[nbrs].sum()
        wsum += degree

    den = np.sum(v * v)
    # Without links (W == 0) or without variance (den == 0) the ratio is
    # undefined; report NaN instead of raising ZeroDivisionError or relying
    # on a 0/0 warning.
    if wsum == 0 or den == 0:
        return np.nan
    return (len(values) / wsum) * (num / den)


# ============================================================
# Multiprocessing worker state
# ============================================================

# Module-level slots filled in by _init_worker and read by _perm_moran.
_WORK_VALUES = None
_WORK_NEIGHBORS = None


def _init_worker(values, neighbors):
    """Store the value array and neighbour lists in this module's globals.

    Used as the ``initializer`` of the permutation pool so that the data is
    handed over once per worker rather than once per task.
    """
    global _WORK_VALUES, _WORK_NEIGHBORS
    _WORK_VALUES, _WORK_NEIGHBORS = values, neighbors


def _perm_moran(seed):
    """Return Moran's I for one random shuffle of the worker's values.

    ``seed`` initialises a fresh generator, so a given seed always yields
    the same permutation. The values and neighbour lists are read from the
    module globals set by ``_init_worker``.
    """
    shuffled = np.random.default_rng(seed).permutation(_WORK_VALUES)
    return morans_I_sparse(shuffled, _WORK_NEIGHBORS)


# ============================================================
# Main
# ============================================================

def _geometry_fingerprint(lat, lon):
    """Return a short hex digest identifying the exact LAT/LON point set."""
    import hashlib  # local import: only needed for cache-file naming

    digest = hashlib.sha1()
    digest.update(np.ascontiguousarray(lat, dtype=np.float64).tobytes())
    digest.update(np.ascontiguousarray(lon, dtype=np.float64).tobytes())
    return digest.hexdigest()[:12]


def _load_or_build_neighbors(lat, lon, scales, cache_dir):
    """Return ``{scale: neighbour lists}``, reusing on-disk caches if valid."""
    # The cache name carries both the exact scale and a geometry digest, so a
    # cache built for another grid (or for 250 km when 250.5 km is requested)
    # can never be picked up by mistake.
    fingerprint = _geometry_fingerprint(lat, lon)
    neighbor_cache = {}

    for L in scales:
        cache_file = os.path.join(
            cache_dir, f"neighbors_L{L:g}_{fingerprint}.pkl"
        )
        nbrs = None

        if os.path.exists(cache_file):
            print(f"Loading cached neighbors for L={L} km")
            # NOTE(security): unpickling can execute arbitrary code. Only
            # point --cache-dir at a directory whose contents you trust.
            try:
                with open(cache_file, "rb") as f:
                    nbrs = pickle.load(f)
            except (EOFError, pickle.UnpicklingError):
                print(f"Cache file {cache_file} is unreadable; rebuilding")
                nbrs = None
            if nbrs is not None and len(nbrs) != len(lat):
                print(f"Cache file {cache_file} does not match the grid; "
                      "rebuilding")
                nbrs = None

        if nbrs is None:
            nbrs = build_neighbors(lat, lon, L)
            # Write to a temporary name and rename, so an interrupted run
            # cannot leave a truncated cache behind.
            tmp_file = cache_file + ".tmp"
            with open(tmp_file, "wb") as f:
                pickle.dump(nbrs, f)
            os.replace(tmp_file, cache_file)

        neighbor_cache[L] = nbrs

    return neighbor_cache


def _permutation_test(values, neighbors, I_obs, rng, ctx, n_workers,
                      min_permutations, max_permutations, alpha, label):
    """Run the adaptive permutation test; return ``(null_I, k, n)``.

    ``null_I`` is the array of permuted Moran's I values, ``k`` the number
    whose magnitude is at least ``|I_obs|`` and ``n`` the number of
    permutations actually evaluated. Permutations are drawn in batches of up
    to 100; after at least ``min_permutations`` the loop stops as soon as the
    running p-value ``(k + 1) / (n + 1)`` falls below ``alpha``.
    """
    k = 0
    n = 0
    null_I = []

    with ctx.Pool(
        processes=n_workers,
        initializer=_init_worker,
        initargs=(values, neighbors),
    ) as pool:
        # The context manager closes the bar even if a worker raises.
        with tqdm(total=max_permutations, desc=label, leave=False) as pbar:
            while n < max_permutations:
                batch = min(100, max_permutations - n)
                seeds = rng.integers(0, 2**32 - 1, size=batch)
                batch_I = pool.map(_perm_moran, seeds)

                null_I.extend(batch_I)
                k += int(np.sum(np.abs(batch_I) >= abs(I_obs)))
                n += batch
                pbar.update(batch)

                if n >= min_permutations and (k + 1) / (n + 1) < alpha:
                    break

    return np.array(null_I), k, n


def _z_score(I_obs, null_mean, null_std):
    """Standardise I_obs against the null distribution, sign-aware."""
    if np.isfinite(null_std) and null_std > 0:
        return (I_obs - null_mean) / null_std
    diff = I_obs - null_mean
    if null_std == 0 and diff != 0:
        # Degenerate null: infinitely far away, on the side I_obs lies on.
        return np.sign(diff) * np.inf
    return np.nan


def main():
    """Command-line entry point.

    Reads the ensemble CSV, builds (or loads) the neighbour lists for every
    requested scale, evaluates Moran's I with an adaptive permutation test
    for each (RUN_ID, scale) pair, and writes the per-scale results plus a
    per-run ranking to CSV.

    Raises
    ------
    RuntimeError
        If required columns are missing, the input has no rows, or a run's
        LAT/LON rows do not match those of the first run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Ensemble CSV")
    ap.add_argument("--misfit-column", required=True, help="Misfit column name")
    ap.add_argument("--scales", nargs="+", type=float,
                    default=[250, 500, 1000, 2000, 4000])
    ap.add_argument("--alpha", type=float, default=0.01)
    ap.add_argument("--min-permutations", type=int, default=100)
    ap.add_argument("--max-permutations", type=int, default=5000)
    ap.add_argument("--workers", type=int, default=12)
    ap.add_argument("--cache-dir", default="neighbor_cache")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--output", default="morans_ensemble_results.csv")
    ap.add_argument("--ranking-output", default="morans_ensemble_ranking.csv",
                    help="Per-run ranking CSV")
    args = ap.parse_args()

    if args.max_permutations < 1:
        ap.error("--max-permutations must be at least 1")
    if any(L <= 0 for L in args.scales):
        ap.error("--scales must all be positive")

    rng = np.random.default_rng(args.seed)
    os.makedirs(args.cache_dir, exist_ok=True)

    df = pd.read_csv(args.input)

    # --------------------------------------------------------
    # Required columns
    # --------------------------------------------------------

    required = {"RUN_ID", "LAT", "LON", args.misfit_column}
    missing = required - set(df.columns)
    if missing:
        raise RuntimeError(f"Missing required columns: {missing}")

    if df.empty:
        raise RuntimeError(f"No data rows found in {args.input}")

    # --------------------------------------------------------
    # Extract TPW metadata per RUN_ID
    # --------------------------------------------------------

    TPW_META_FIELDS = ["POLE_LAT", "POLE_LON", "ROT_DEG", "ROT_SIGN"]
    missing_meta = set(TPW_META_FIELDS) - set(df.columns)
    if missing_meta:
        raise RuntimeError(f"Missing TPW metadata columns: {missing_meta}")

    tpw_meta = (
        df[["RUN_ID"] + TPW_META_FIELDS]
        .drop_duplicates(subset="RUN_ID")
        .set_index("RUN_ID")
    )

    # --------------------------------------------------------
    # Geometry (identical across ensemble)
    # --------------------------------------------------------

    first_run = df["RUN_ID"].iloc[0]
    ref = df[df["RUN_ID"] == first_run]
    lat = ref["LAT"].values
    lon = ref["LON"].values

    # The neighbour lists are index-based, so every run must list the same
    # points in the same order as the reference run. Check this up front
    # rather than failing (or silently mis-pairing points) hours later.
    run_groups = list(df.groupby("RUN_ID"))
    for run_id, group in run_groups:
        same_geometry = (
            len(group) == len(lat)
            and np.allclose(group["LAT"].values, lat, equal_nan=True)
            and np.allclose(group["LON"].values, lon, equal_nan=True)
        )
        if not same_geometry:
            raise RuntimeError(
                f"RUN_ID {run_id!r} does not share the LAT/LON rows of "
                f"reference run {first_run!r}"
            )

    # --------------------------------------------------------
    # Neighbor caching
    # --------------------------------------------------------

    neighbor_cache = _load_or_build_neighbors(
        lat, lon, args.scales, args.cache_dir
    )
    # Scales at which no point has any neighbour: Moran's I is undefined.
    has_links = {
        L: any(len(nbrs) for nbrs in neighbor_cache[L]) for L in args.scales
    }

    # --------------------------------------------------------
    # Moran's I evaluation
    # --------------------------------------------------------

    records = []
    ctx = mp.get_context("spawn")
    n_workers = max(1, min(mp.cpu_count(), args.workers))

    for run_id, group in tqdm(run_groups,
                              desc="Evaluating ensemble members",
                              unit="run"):

        values = group[args.misfit_column].values
        meta = tpw_meta.loc[run_id]

        for L in tqdm(args.scales,
                      desc=f"Scales for {run_id}",
                      leave=False):

            neighbors = neighbor_cache[L]
            I_obs = (
                morans_I_sparse(values, neighbors) if has_links[L] else np.nan
            )

            if np.isfinite(I_obs):
                null_I, k, n = _permutation_test(
                    values, neighbors, I_obs, rng, ctx, n_workers,
                    args.min_permutations, args.max_permutations,
                    args.alpha, f"Permutations (L={L} km)",
                )
                null_mean = null_I.mean()
                null_std = null_I.std(ddof=1) if n > 1 else np.nan
                z = _z_score(I_obs, null_mean, null_std)
                pval = (k + 1) / (n + 1)
            else:
                # An undefined statistic (no links, constant or NaN values)
                # cannot be tested. Comparing against NaN would count zero
                # exceedances and report a spuriously small p-value.
                null_mean = null_std = z = pval = np.nan
                n = 0

            records.append({
                "RUN_ID": run_id,
                "POLE_LAT": meta["POLE_LAT"],
                "POLE_LON": meta["POLE_LON"],
                "ROT_DEG": meta["ROT_DEG"],
                "ROT_SIGN": meta["ROT_SIGN"],
                "SCALE_KM": L,
                "I_OBS": I_obs,
                "I_NULL_MEAN": null_mean,
                "I_NULL_STD": null_std,
                "Z_SCORE": z,
                "P_VALUE": pval,
                "PERMUTATIONS_USED": n
            })

    results = pd.DataFrame(records)
    results.to_csv(args.output, index=False)

    # --------------------------------------------------------
    # Ensemble ranking + confidence intervals
    # --------------------------------------------------------

    summary = (
        results.groupby("RUN_ID")
        .agg(
            MEAN_I=("I_OBS", "mean"),
            MEDIAN_I=("I_OBS", "median"),
            I_5P=("I_OBS", lambda x: np.percentile(x, 5)),
            I_95P=("I_OBS", lambda x: np.percentile(x, 95))
        )
        .reset_index()
        .merge(tpw_meta.reset_index(), on="RUN_ID", how="left")
        .sort_values("MEAN_I", ascending=False)
    )

    summary["RANK"] = np.arange(1, len(summary) + 1)
    summary.to_csv(args.ranking_output, index=False)

    print("\nSaved:")
    print(" •", args.output)
    print(" •", args.ranking_output)


# Entry-point guard: run the CLI only when executed as a script, not on import.
# main() builds its pools from a "spawn" context, and spawned workers import
# this module, so the guard keeps them from re-running main().
if __name__ == "__main__":
    main()
