#!/usr/bin/env python3
import argparse
import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from math import radians, sin, cos
from scipy.spatial import cKDTree

# Earth radius in kilometres; used to convert between surface distances
# (km) and central angles on the unit sphere.
EARTH_RADIUS_KM = 6371.0
# Minimum number of rows a (run, region) subset must contain before it is
# analysed; smaller subsets are skipped and reported at the end of main().
MIN_POINTS = 500

# ============================================================
# Geometry utilities
# ============================================================

def latlon_to_unit(lat, lon):
    """Convert latitude/longitude in degrees to unit vectors.

    Parameters
    ----------
    lat, lon : array-like or scalar
        Latitude and longitude in degrees.

    Returns
    -------
    numpy.ndarray
        Array with one row per point and three columns (x, y, z), where
        z points towards latitude +90 and x towards (lat=0, lon=0).
    """
    phi = np.radians(lat)
    lam = np.radians(lon)
    cos_phi = np.cos(phi)

    x = cos_phi * np.cos(lam)
    y = cos_phi * np.sin(lam)
    z = np.sin(phi)
    return np.column_stack((x, y, z))

def chord_radius(L_km):
    """Return the unit-sphere chord length matching a surface distance.

    Two points separated by a great-circle distance ``L_km`` on a sphere of
    radius ``EARTH_RADIUS_KM`` are separated by a straight-line (chord)
    distance of ``2 * sin(theta / 2)`` on the unit sphere, where
    ``theta = L_km / EARTH_RADIUS_KM`` is the central angle in radians.

    Parameters
    ----------
    L_km : float or array-like
        Great-circle distance in kilometres.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Chord length on the unit sphere, at most 2.0.
    """
    # No two points on a sphere are more than half a circumference apart.
    # Without the cap, sin(theta / 2) starts decreasing again for theta > pi
    # and a very large L_km would yield a *smaller* search radius.
    theta = np.minimum(L_km / EARTH_RADIUS_KM, np.pi)
    return 2.0 * np.sin(theta / 2.0)

# ============================================================
# Neighbor construction (cached)
# ============================================================

def build_neighbors(lat, lon, L_km, cache_file):
    """Return, for every point, the indices of all other points within L_km.

    The result is cached as a pickle at ``cache_file``. An existing cache is
    reused only if it can be loaded and holds exactly one entry per input
    point; otherwise it is treated as stale and rebuilt.

    Parameters
    ----------
    lat, lon : array-like
        Point coordinates in degrees.
    L_km : float
        Neighbourhood radius as a great-circle distance in kilometres.
    cache_file : str
        Path of the pickle file used to load/store the neighbour lists.

    Returns
    -------
    list of numpy.ndarray
        ``neighbors[i]`` is an ``int32`` array with the indices of the points
        within ``L_km`` of point ``i`` (point ``i`` itself excluded).
    """
    n_points = np.size(lat)

    if os.path.exists(cache_file):
        # NOTE: unpickling can execute arbitrary code. Only point the cache
        # at files this script wrote itself.
        try:
            with open(cache_file, "rb") as f:
                cached = pickle.load(f)
        except (EOFError, pickle.UnpicklingError):
            # Truncated or corrupt file, e.g. from an interrupted run.
            cached = None

        # A cache built for a different point set would silently produce
        # wrong neighbours (or index errors), so require a matching size.
        if isinstance(cached, list) and len(cached) == n_points:
            return cached
        print(f"Ignoring stale or unreadable neighbor cache: {cache_file}")

    print(f"Precomputing neighbors (L={L_km} km)")
    xyz = latlon_to_unit(lat, lon)
    tree = cKDTree(xyz)
    r = chord_radius(L_km)

    neighbors = []
    for i in tqdm(range(len(xyz))):
        idx = np.asarray(tree.query_ball_point(xyz[i], r), dtype=np.int32)
        # The query point is always inside its own ball; drop it.
        neighbors.append(idx[idx != i])

    # Write to a temporary file and rename it into place so an interrupted
    # run never leaves a partially written cache behind.
    tmp_file = f"{cache_file}.tmp.{os.getpid()}"
    with open(tmp_file, "wb") as f:
        pickle.dump(neighbors, f)
    os.replace(tmp_file, cache_file)

    return neighbors

# ============================================================
# Moran’s I (sparse)
# ============================================================

def morans_I_sparse(values, neighbors):
    """Compute Moran's I with binary weights given as neighbour lists.

    Parameters
    ----------
    values : numpy.ndarray
        One observation per point.
    neighbors : sequence of integer arrays
        ``neighbors[i]`` holds the indices of the points adjacent to point
        ``i``; every listed pair contributes a weight of 1.

    Returns
    -------
    float
        ``(N / W) * sum_ij(w_ij * x_i * x_j) / sum_i(x_i ** 2)`` where ``x``
        is ``values`` minus its mean and ``W`` is the total number of listed
        neighbour pairs. ``nan`` if there are no neighbour pairs or the
        values have zero variance.
    """
    centered = values - values.mean()

    cross_sum = 0.0
    weight_total = 0
    for point, adjacent in enumerate(neighbors):
        count = len(adjacent)
        if count:
            cross_sum += centered[point] * centered[adjacent].sum()
            weight_total += count

    sum_sq = np.sum(centered * centered)
    if weight_total == 0 or sum_sq == 0:
        return np.nan

    scale = len(values) / weight_total
    return scale * (cross_sum / sum_sq)

def permutation_null(values, neighbors, n_perm, rng):
    """Build a permutation null distribution of Moran's I.

    Parameters
    ----------
    values : numpy.ndarray
        Observed values, one per point.
    neighbors : sequence of integer arrays
        Neighbour lists passed through to ``morans_I_sparse``.
    n_perm : int
        Number of random permutations to draw.
    rng : numpy.random.Generator
        Source of randomness; ``rng.permutation`` is called once per draw.

    Returns
    -------
    numpy.ndarray
        Array of length ``n_perm`` holding Moran's I for each shuffled copy
        of ``values``.
    """
    samples = np.empty(n_perm)
    samples[:] = [
        morans_I_sparse(rng.permutation(values), neighbors)
        for _ in range(n_perm)
    ]
    return samples

# ============================================================
# Region filtering
# ============================================================

def filter_region(df, region):
    """Restrict a data frame to the rows that fall inside a region.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``LAT`` and ``LON`` columns (degrees) unless the region
        is global.
    region : dict
        Either ``{"type": "global", ...}``, in which case ``df`` is returned
        unchanged, or a circular region with keys ``lat``, ``lon`` (centre,
        degrees) and ``radius_km``.

    Returns
    -------
    pandas.DataFrame
        The rows whose great-circle distance to the region centre is at most
        ``radius_km``. Rows with missing (NaN) coordinates are never selected.
    """
    if region["type"] == "global":
        return df

    lat0 = np.radians(region["lat"])
    lon0 = np.radians(region["lon"])
    lat = np.radians(df["LAT"].to_numpy(dtype=float))
    lon = np.radians(df["LON"].to_numpy(dtype=float))

    # Spherical law of cosines, evaluated for all rows at once.
    cos_angle = (
        np.sin(lat) * np.sin(lat0)
        + np.cos(lat) * np.cos(lat0) * np.cos(lon - lon0)
    )

    # Clipping guards arccos against values a hair outside [-1, 1] caused by
    # rounding. np.clip propagates NaN, so rows with missing coordinates get
    # a NaN distance and fail the comparison below. (Python's built-in
    # min/max would instead turn NaN into 1.0, i.e. a distance of zero.)
    d = EARTH_RADIUS_KM * np.arccos(np.clip(cos_angle, -1.0, 1.0))

    return df[d <= region["radius_km"]]

# ============================================================
# Main
# ============================================================

def _parse_region(spec):
    """Parse a ``NAME:LAT,LON,RADIUS_KM`` string into a circular region dict."""
    try:
        name, rest = spec.split(":")
        lat, lon, rad = map(float, rest.split(","))
    except ValueError as err:
        raise ValueError(
            f"Invalid --region {spec!r}; expected NAME:LAT,LON,RADIUS_KM"
        ) from err
    return {
        "name": name,
        "type": "circle",
        "lat": lat,
        "lon": lon,
        "radius_km": rad
    }


def _coords_fingerprint(lat, lon):
    """Return a short hex digest identifying an exact set of coordinates."""
    import hashlib

    coords = np.column_stack((lat, lon)).astype(np.float64)
    return hashlib.sha256(coords.tobytes()).hexdigest()[:16]


def _rankings_path(output):
    """Derive the rankings file name from the results file name."""
    root, ext = os.path.splitext(output)
    return f"{root}_rankings{ext or '.csv'}"


def main():
    """Command-line entry point.

    Reads the input CSV, and for every RUN_ID, region and length scale
    computes Moran's I of the chosen misfit column together with a
    permutation null distribution. One row per (run, region, scale) is
    written to ``--output``; a second file with per-(region, scale) rankings
    by observed Moran's I is written next to it.

    Raises
    ------
    RuntimeError
        If required columns are missing or no result could be produced.
    ValueError
        If a ``--region`` or ``--scales`` value cannot be parsed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True)
    ap.add_argument("--misfit-column", required=True)
    ap.add_argument("--scales", required=True)
    ap.add_argument("--permutations", type=int, default=1000)
    ap.add_argument("--region", action="append")
    ap.add_argument("--output", default="morans_results.csv")
    ap.add_argument("--cache-dir", default="neighbor_cache")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # An empty null distribution has no mean, spread or percentiles.
    if args.permutations < 1:
        ap.error("--permutations must be a positive integer")

    os.makedirs(args.cache_dir, exist_ok=True)
    scales = [float(s) for s in args.scales.split(",")]
    rng = np.random.default_rng(args.seed)

    df = pd.read_csv(args.input)

    missing = [
        col for col in ("LAT", "LON", args.misfit_column)
        if col not in df.columns
    ]
    if missing:
        raise RuntimeError(
            f"Input CSV is missing required column(s): {', '.join(missing)}"
        )

    # --------------------------------------------------------
    # Euler pole validation / normalization
    # --------------------------------------------------------
    if {"POLE_LAT", "POLE_LON"}.issubset(df.columns):
        pole_lat_col = "POLE_LAT"
        pole_lon_col = "POLE_LON"
    elif {"EULER_LAT", "EULER_LON"}.issubset(df.columns):
        pole_lat_col = "EULER_LAT"
        pole_lon_col = "EULER_LON"
    else:
        raise RuntimeError(
            "Input CSV must contain POLE_LAT/POLE_LON or EULER_LAT/EULER_LON"
        )

    if "RUN_ID" not in df.columns:
        df["RUN_ID"] = "DEFAULT"

    # One pole per run: the first value found in that run's rows.
    pole_lookup = (
        df.groupby("RUN_ID")[[pole_lat_col, pole_lon_col]]
          .first()
          .rename(columns={
              pole_lat_col: "POLE_LAT",
              pole_lon_col: "POLE_LON"
          })
    )

    # --------------------------------------------------------
    # Regions
    # --------------------------------------------------------
    if args.region:
        regions = [_parse_region(spec) for spec in args.region]
    else:
        regions = [{"name": "GLOBAL", "type": "global"}]

    rows = []
    skipped_regions = set()

    for run_id, df_run in df.groupby("RUN_ID"):
        print(f"\nRUN_ID: {run_id}")

        pole_lat = pole_lookup.loc[run_id, "POLE_LAT"]
        pole_lon = pole_lookup.loc[run_id, "POLE_LON"]

        for region in regions:
            df_r = filter_region(df_run, region)

            if len(df_r) < MIN_POINTS:
                skipped_regions.add(region["name"])
                continue

            lat = df_r["LAT"].values
            lon = df_r["LON"].values
            vals = df_r[args.misfit_column].values

            # Neighbour lists depend only on the point set and the scale.
            # Keying the cache on a fingerprint of the coordinates lets runs
            # that share a point set reuse it, while runs (or inputs) with
            # different points can never pick up each other's neighbours.
            coord_tag = _coords_fingerprint(lat, lon)

            for L in scales:
                # repr(L) keeps fractional scales distinct (0.5 vs 0.9).
                cache_file = os.path.join(
                    args.cache_dir,
                    f"neighbors_{region['name']}_L{L!r}_{coord_tag}.pkl"
                )

                neighbors = build_neighbors(lat, lon, L, cache_file)

                I_obs = morans_I_sparse(vals, neighbors)
                if not np.isfinite(I_obs):
                    skipped_regions.add(region["name"])
                    continue

                null = permutation_null(vals, neighbors, args.permutations, rng)
                null_mean = null.mean()
                null_std = null.std()

                # A degenerate null (all draws identical) has no z-score.
                if null_std > 0:
                    z_score = (I_obs - null_mean) / null_std
                else:
                    z_score = np.nan

                rows.append({
                    "RUN_ID": run_id,
                    "REGION": region["name"],
                    "SCALE_KM": L,
                    "N_POINTS": len(vals),
                    "I_OBS": I_obs,
                    "I_NULL_MEAN": null_mean,
                    "I_NULL_STD": null_std,
                    "Z_SCORE": z_score,
                    # One-sided: share of null draws at or above I_obs.
                    "P_VALUE": (null >= I_obs).mean(),
                    "I_P05": np.percentile(null, 5),
                    "I_P50": np.percentile(null, 50),
                    "I_P95": np.percentile(null, 95),
                    "POLE_LAT": pole_lat,
                    "POLE_LON": pole_lon
                })

    if not rows:
        raise RuntimeError("No valid Moran results were produced for any region.")

    results = pd.DataFrame(rows)

    results.to_csv(args.output, index=False)
    print(f"\nSaved results → {args.output}")

    # Rank runs within each (region, scale) by descending observed Moran's I.
    rankings = (
        results
        .sort_values(["REGION", "SCALE_KM", "I_OBS"],
                     ascending=[True, True, False])
        .assign(RANK=lambda d:
            d.groupby(["REGION", "SCALE_KM"]).cumcount() + 1)
    )

    # Built from the path's stem so that an --output without a ".csv"
    # suffix does not make the rankings overwrite the results file.
    rank_file = _rankings_path(args.output)
    rankings.to_csv(rank_file, index=False)
    print(f"Saved rankings → {rank_file}")

    if skipped_regions:
        print("\nSkipped regions due to insufficient or invalid data:")
        for r in sorted(skipped_regions):
            print(f"  - {r}")

# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()
