#!/usr/bin/env python3
"""
shear-sac-permutation.py

Permutation-based Moran's I spatial autocorrelation test
for WSM–shear misfit fields.

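Significance is assessed against a random-permutation null model: the observed
misfit values are shuffled over the fixed site locations. This is the
appropriate null for a scalar spatial statistic such as Moran's I.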

Author: Craig Stone
"""

import numpy as np
import pandas as pd
from pyproj import Geod
from scipy.spatial import cKDTree
from tqdm import tqdm

# --------------------------------------------------
# CONFIGURATION
# --------------------------------------------------

# Input table; must provide LAT, LON, NET1_MISFIT and NET2_MISFIT columns.
CSV = "wsm_shear_with_null.csv"

# Decay length scales L (km) for the exponential weights w = exp(-d / L).
L_SCALES = [250, 500, 1000, 2000, 3000, 4000]

# Neighbour cutoff as a multiple of L: pairs farther apart than
# D_MAX_FACTOR * L receive no weight.
D_MAX_FACTOR = 3.0

# Random permutations used to build the null distribution at each scale.
N_PERM = 1_000

# --------------------------------------------------
# LOAD DATA
# --------------------------------------------------

df = pd.read_csv(CSV)

_REQUIRED_COLS = ["LAT", "LON", "NET1_MISFIT", "NET2_MISFIT"]

# Fail with a clear message instead of a bare KeyError from pandas.
_missing = [col for col in _REQUIRED_COLS if col not in df.columns]
if _missing:
    raise SystemExit(f"{CSV}: missing required column(s): {', '.join(_missing)}")

# Coerce to numeric; unparseable entries become NaN and are dropped below.
for col in _REQUIRED_COLS:
    df[col] = pd.to_numeric(df[col], errors="coerce")

df = df.dropna(subset=_REQUIRED_COLS)

# Per-site misfit: the smaller of NET1_MISFIT and NET2_MISFIT.
df["OBS_MISFIT"] = np.minimum(df["NET1_MISFIT"], df["NET2_MISFIT"])

lat = df["LAT"].to_numpy(dtype=float)
lon = df["LON"].to_numpy(dtype=float)
values_obs = df["OBS_MISFIT"].to_numpy(dtype=float)

N = len(df)
print(f"Using {N} spatial points")

# Moran's I needs at least two points (with non-constant values) to be
# defined; stop here instead of failing later with a division by zero.
if N < 2:
    raise SystemExit(f"Need at least 2 valid spatial points in {CSV}, found {N}")

# --------------------------------------------------
# SPATIAL INDEX
# --------------------------------------------------

geod = Geod(ellps="WGS84")

lat_rad, lon_rad = np.radians(lat), np.radians(lon)

# Map each (lat, lon) onto a unit vector in 3-D. The KD-tree indexes these
# points, so distances it measures are straight-line chord lengths through
# the unit sphere, not arc lengths.
xyz = np.stack(
    (
        np.cos(lat_rad) * np.cos(lon_rad),
        np.cos(lat_rad) * np.sin(lon_rad),
        np.sin(lat_rad),
    ),
    axis=1,
)

tree = cKDTree(xyz)

# --------------------------------------------------
# PRECOMPUTE NEIGHBORS + WEIGHTS
# --------------------------------------------------

def precompute_neighbors_weights(L_km):
    """Build a sparse, row-wise exponential-decay weight matrix at scale ``L_km``.

    For every point i, every other point j whose WGS84 geodesic distance d_ij
    (km) is at most ``D_MAX_FACTOR * L_km`` gets weight w_ij = exp(-d_ij / L_km).
    Candidates are found with the unit-sphere KD-tree ``tree`` (built on
    ``xyz``) and then filtered by the exact ellipsoidal distance from ``geod``.

    Parameters
    ----------
    L_km : float
        Decay length scale in kilometres.

    Returns
    -------
    neighbors : list of int32 arrays, one per point (self excluded)
    weights : list of float32 arrays, aligned element-wise with ``neighbors``
    """
    d_max = D_MAX_FACTOR * L_km

    # The KD-tree measures chord length between unit vectors, not arc length.
    # Convert the cutoff to an angle on a sphere of mean Earth radius, widen it
    # by 2% because WGS84 geodesic distances can differ from spherical ones by
    # roughly half a percent, then convert that angle to a chord length.
    # Extra candidates are harmless: the exact geodesic filter removes them.
    theta = min(1.02 * d_max / 6371.0, np.pi)
    chord_radius = 2.0 * np.sin(theta / 2.0)

    neighbors = []
    weights = []

    for i in tqdm(range(N), desc=f"Precomputing (L={L_km} km)"):
        cand = np.asarray(tree.query_ball_point(xyz[i], r=chord_radius), dtype=np.intp)
        cand = cand[cand != i]  # a point is never its own neighbour

        if cand.size == 0:
            neighbors.append(np.empty(0, dtype=np.int32))
            weights.append(np.empty(0, dtype=np.float32))
            continue

        # One vectorised geodesic call per row instead of one call per pair.
        _, _, d = geod.inv(
            np.full(cand.size, lon[i]),
            np.full(cand.size, lat[i]),
            lon[cand],
            lat[cand],
        )
        d_km = np.asarray(d, dtype=np.float64) / 1000.0

        keep = d_km <= d_max
        neighbors.append(cand[keep].astype(np.int32))
        weights.append(np.exp(-d_km[keep] / L_km).astype(np.float32))

    return neighbors, weights

# --------------------------------------------------
# MORAN'S I USING PRECOMPUTED WEIGHTS
# --------------------------------------------------

def morans_I(values, neighbors, weights):
    """Return global Moran's I of ``values`` under a sparse spatial weight matrix.

    The weight matrix is supplied row by row: ``neighbors[i]`` holds the column
    indices j of the nonzero weights in row i and ``weights[i]`` the matching
    w_ij values (self-pairs are expected to be absent).

        I = (n / S0) * sum_ij w_ij z_i z_j / sum_i z_i**2,   z = values - mean

    where S0 is the sum of *all* weights in the matrix.

    Parameters
    ----------
    values : array-like of float, length n
    neighbors : sequence of n integer index arrays
    weights : sequence of n float arrays, ``len(weights[i]) == len(neighbors[i])``

    Returns
    -------
    float
        Moran's I, or NaN when it is undefined (no values, zero variance, or
        zero total weight).
    """
    values = np.asarray(values, dtype=np.float64)
    n = len(values)
    if n == 0:
        return float("nan")

    z = values - values.mean()
    var = float(np.dot(z, z))
    if var == 0.0:
        return float("nan")

    # Flatten the row-wise lists into (row, col, weight) triplets.
    counts = np.array([len(js) for js in neighbors], dtype=np.intp)
    rows = np.repeat(np.arange(n), counts)
    cols = np.concatenate(neighbors).astype(np.intp, copy=False)
    w = np.concatenate(weights).astype(np.float64)

    # S0 includes every row's weights, even rows whose z_i is zero; skipping
    # those rows would make the normaliser depend on the data values and
    # change from one permutation to the next.
    s0 = float(w.sum())
    if s0 == 0.0:
        return float("nan")

    num = float(np.sum(w * z[rows] * z[cols]))
    return (n / s0) * (num / var)

# --------------------------------------------------
# MAIN ANALYSIS
# --------------------------------------------------

results = []

print("\nRunning permutation-based Moran's I tests")

# Dedicated, seeded generator so the permutation null (and therefore every
# z-score and p-value written out) is reproducible from run to run.
RNG_SEED = 12345
rng = np.random.default_rng(RNG_SEED)

for L in L_SCALES:
    print(f"\n=== Scale L = {L} km ===")

    # Weights depend only on the locations and L, so they are built once per
    # scale and reused for the observed and all permuted statistics.
    neighbors, weights = precompute_neighbors_weights(L)

    # Observed Moran's I
    I_obs = morans_I(values_obs, neighbors, weights)

    # Permutation null: shuffle the values over the fixed locations, which
    # destroys spatial arrangement while keeping the value distribution.
    I_perm = np.array([
        morans_I(rng.permutation(values_obs), neighbors, weights)
        for _ in tqdm(range(N_PERM), desc="Permutations")
    ])

    null_mean = I_perm.mean()
    null_std = I_perm.std()
    z = (I_obs - null_mean) / null_std if null_std > 0 else float("nan")

    # One-sided test for positive autocorrelation. The observed statistic is
    # counted as one of the permutations, p = (k + 1) / (n + 1), so a finite
    # permutation test never reports p == 0.
    n_extreme = np.count_nonzero(I_perm >= I_obs)
    p = (n_extreme + 1) / (I_perm.size + 1)

    print(f"Observed I: {I_obs:.5f}")
    print(f"Null mean: {null_mean:.5f}")
    print(f"Null std: {null_std:.5f}")
    print(f"z-score: {z:.2f}")
    print(f"p-value: {p:.5f}")

    results.append({
        "L_km": L,
        "I_obs": I_obs,
        "I_null_mean": null_mean,
        "I_null_std": null_std,
        "z_score": z,
        "p_value": p,
    })

# --------------------------------------------------
# SAVE RESULTS
# --------------------------------------------------

# Persist one row per length scale with the observed statistic and its null summary.
out_path = "spatial_autocorrelation_permutation_results.csv"
out = pd.DataFrame(results)
out.to_csv(out_path, index=False)

print(f"\nSaved {out_path}")
