#!/usr/bin/env python3
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp
from tqdm import trange

# ============================================================
# CONFIGURATION
# ============================================================

# Path of the NetCDF tomography model opened by main().
NETCDF_FILE = "SEISGLOB2_percent.nc"

# Inclusive (min, max) depth window, in the units of the file's "depth"
# coordinate, from which extreme anomalies are selected.
MID_MANTLE_RANGE = (900, 1200)
# Percentile cut-offs applied to |dvs| inside that window.
PERCENTILES = [90, 95, 97.5]
# Number of randomised realisations drawn for each null model.
N_NULL = 1000

# Euler points as (latitude, longitude) pairs in degrees, with longitudes
# expressed in the 0-360 E convention:
#   (0N, 59W)  -> 301E
#   (0N, 121E)
EULERS = [(0.0, 301.0), (0.0, 121.0)]

# ============================================================
# GEOMETRY
# ============================================================

def great_circle_distance(lat1, lon1, lat2, lon2):
    """Return the angular separation, in degrees, between two points on a sphere.

    The separation is computed with the spherical law of cosines. All four
    arguments are latitudes/longitudes in degrees; they are handed directly
    to NumPy ufuncs, so scalars and broadcast-compatible arrays both work.
    """
    phi_a, lam_a, phi_b, lam_b = (
        np.deg2rad(angle) for angle in (lat1, lon1, lat2, lon2)
    )

    sine_term = np.sin(phi_a) * np.sin(phi_b)
    cosine_term = np.cos(phi_a) * np.cos(phi_b) * np.cos(lam_a - lam_b)

    # Rounding can push the sum marginally outside [-1, 1], where arccos
    # would return NaN, so clamp before inverting.
    cos_angle = np.clip(sine_term + cosine_term, -1.0, 1.0)
    return np.rad2deg(np.arccos(cos_angle))


# ============================================================
# DATA EXTRACTION
# ============================================================

def extract_extremes(lat, lon, depth, dvs, depth_range, percentile):
    """Select grid cells whose |dvs| reaches a percentile threshold.

    Parameters
    ----------
    lat, lon, depth : 1-D arrays
        Coordinate vectors of the grid.
    dvs : 3-D array
        Anomaly values indexed as ``dvs[depth, lat, lon]``.
    depth_range : (min, max)
        Inclusive depth window to analyse.
    percentile : float
        Percentile (0-100) of the finite ``|dvs|`` values inside the window
        that is used as the selection threshold.

    Returns
    -------
    ndarray of shape (n_points, 4)
        One row per selected cell: ``(lat, lon, depth, dvs)``. Rows are
        ordered by depth level, then latitude index, then longitude index.
        The array has shape ``(0, 4)`` when the window contains no finite
        values.
    """
    lat = np.asarray(lat)
    lon = np.asarray(lon)
    depth = np.asarray(depth)
    dvs = np.asarray(dvs)

    dmin, dmax = depth_range
    depth_mask = (depth >= dmin) & (depth <= dmax)

    dvs_slice = dvs[depth_mask, :, :]
    depth_vals = depth[depth_mask]

    abs_slice = np.abs(dvs_slice)
    finite_abs = abs_slice[np.isfinite(abs_slice)]
    if finite_abs.size == 0:
        # Empty depth window or all-NaN data: there is nothing to take a
        # percentile of, so report "no points" with a consistent 2-D shape.
        return np.empty((0, 4))

    threshold = np.percentile(finite_abs, percentile)

    # NaN cells compare False and are therefore never selected. np.nonzero
    # walks the mask in row-major order, i.e. depth -> lat -> lon.
    with np.errstate(invalid="ignore"):
        k_idx, lat_idx, lon_idx = np.nonzero(abs_slice >= threshold)

    return np.column_stack(
        (
            lat[lat_idx],
            lon[lon_idx],
            depth_vals[k_idx],
            dvs_slice[k_idx, lat_idx, lon_idx],
        )
    )


def compute_euler_distances(points):
    """Return each point's angular distance (degrees) to its nearest Euler point.

    Parameters
    ----------
    points : array-like of shape (n_points, 4)
        Rows of ``(lat, lon, depth, value)``; only the first two columns
        are used.

    Returns
    -------
    ndarray of shape (n_points,)
        Minimum great-circle distance from each point to any entry of
        ``EULERS``. Empty input yields an empty array.
    """
    coords = np.asarray(points, dtype=float)
    if coords.size == 0:
        return np.array([])

    lats, lons = coords[:, 0], coords[:, 1]

    # great_circle_distance broadcasts, so evaluate all points against one
    # Euler point at a time and keep the element-wise minimum.
    per_euler = [
        great_circle_distance(lats, lons, e_lat, e_lon)
        for e_lat, e_lon in EULERS
    ]
    return np.min(per_euler, axis=0)


# ============================================================
# NULL MODELS
# ============================================================

def longitude_randomised_null(points, n_iter):
    """Build a null distance sample by randomising point longitudes.

    For each of ``n_iter`` realisations every point keeps its latitude but
    receives a longitude drawn uniformly from [0, 360); the distance to the
    nearest entry of ``EULERS`` is then computed.

    Parameters
    ----------
    points : array-like of shape (n_points, 4)
        Rows of ``(lat, lon, depth, value)``; only the latitude is used.
    n_iter : int
        Number of randomised realisations.

    Returns
    -------
    ndarray of shape (n_iter * n_points,)
        Distances (degrees) from all realisations pooled into one flat
        sample, ordered realisation by realisation.
    """
    coords = np.asarray(points, dtype=float)
    lats = coords[:, 0] if coords.size else np.empty(0)
    n_points = lats.size

    null_dists = np.empty((n_iter, n_points))

    for it in trange(n_iter, desc="Longitude null"):
        rand_lon = np.random.uniform(0, 360, size=n_points)
        # One vectorised evaluation per Euler point instead of one scalar
        # call per (point, Euler) pair.
        null_dists[it] = np.min(
            [
                great_circle_distance(lats, rand_lon, e_lat, e_lon)
                for e_lat, e_lon in EULERS
            ],
            axis=0,
        )

    return null_dists.ravel()


def lat_symmetrised_null(points, n_iter):
    """Build a null distance sample with random longitudes and hemispheres.

    For each of ``n_iter`` realisations every point keeps ``|lat|`` but is
    assigned a random sign (north/south) and a longitude drawn uniformly
    from [0, 360); the distance to the nearest entry of ``EULERS`` is then
    computed.

    Parameters
    ----------
    points : array-like of shape (n_points, 4)
        Rows of ``(lat, lon, depth, value)``; only the latitude is used.
    n_iter : int
        Number of randomised realisations.

    Returns
    -------
    ndarray of shape (n_iter * n_points,)
        Distances (degrees) from all realisations pooled into one flat
        sample, ordered realisation by realisation.
    """
    coords = np.asarray(points, dtype=float)
    abs_lats = np.abs(coords[:, 0]) if coords.size else np.empty(0)
    n_points = abs_lats.size

    null_dists = np.empty((n_iter, n_points))

    for it in trange(n_iter, desc="Lat-sym null"):
        # Draw longitudes first, then signs, so seeded runs consume the
        # random stream in the same order as before.
        rand_lon = np.random.uniform(0, 360, size=n_points)
        rand_sign = np.random.choice([-1, 1], size=n_points)
        rand_lat = rand_sign * abs_lats

        null_dists[it] = np.min(
            [
                great_circle_distance(rand_lat, rand_lon, e_lat, e_lon)
                for e_lat, e_lon in EULERS
            ],
            axis=0,
        )

    return null_dists.ravel()


# ============================================================
# CDF UTILITIES
# ============================================================

def empirical_cdf(data):
    """Return ``(x, y)`` describing the empirical CDF of *data*.

    ``x`` holds the values sorted in ascending order and ``y`` the matching
    cumulative fractions ``1/n, 2/n, ..., 1``.
    """
    ordered = np.sort(data)
    n = len(ordered)
    cumulative = np.arange(1, n + 1) / n
    return ordered, cumulative


def plot_cdf(obs, null_lon, null_lat, percentile):
    """Plot the observed and both null empirical CDFs on one figure and show it.

    *obs*, *null_lon* and *null_lat* are distance samples in degrees;
    *percentile* is only used to label the figure title.
    """
    series = (
        (obs, {"linewidth": 2, "label": "Observed"}),
        (null_lon, {"linestyle": "--", "label": "Longitude null"}),
        (null_lat, {"linestyle": ":", "label": "Lat-sym null"}),
    )
    # Compute every curve before the figure is created.
    curves = [(empirical_cdf(sample), style) for sample, style in series]

    plt.figure(figsize=(7, 5))
    for (xs, ys), style in curves:
        plt.plot(xs, ys, **style)

    plt.xlabel("Angular distance to nearest Euler (degrees)")
    plt.ylabel("Cumulative probability")
    plt.title(f"Mid-mantle Euler proximity CDF ({percentile}th percentile)")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()


# ============================================================
# MAIN ANALYSIS
# ============================================================

def main():
    """Run the Euler-proximity analysis for every configured percentile.

    Loads the coordinate and ``dvs`` arrays from ``NETCDF_FILE``, then for
    each entry of ``PERCENTILES``: selects the extreme anomalies inside
    ``MID_MANTLE_RANGE``, measures their distance to the nearest Euler
    point, generates both null samples, prints the mean distances and the
    two-sample KS statistics, and shows the CDF comparison plot.
    """
    # .values pulls each variable into memory as a NumPy array, so the file
    # handle can be released before the analysis starts.
    with xr.open_dataset(NETCDF_FILE) as ds:
        lat = ds["latitude"].values
        lon = ds["longitude"].values
        depth = ds["depth"].values
        dvs = ds["dvs"].values

    for p in PERCENTILES:
        print(f"\n=== MID-MANTLE | {p}th percentile ===")

        points = extract_extremes(
            lat, lon, depth, dvs, MID_MANTLE_RANGE, p
        )
        obs_dists = compute_euler_distances(points)

        print(f"N anomalies: {len(points)}")
        if len(points) == 0:
            # Means and KS tests are undefined on empty samples
            # (ks_2samp raises), so skip this percentile.
            print("No anomalies selected; skipping statistics and plot.")
            continue

        print(f"Observed mean distance: {np.mean(obs_dists):.2f}°")

        null_lon = longitude_randomised_null(points, N_NULL)
        null_lat = lat_symmetrised_null(points, N_NULL)

        print(f"Longitude null mean: {np.mean(null_lon):.2f}°")
        print(f"Lat-sym null mean:   {np.mean(null_lat):.2f}°")

        ks_lon = ks_2samp(obs_dists, null_lon)
        ks_lat = ks_2samp(obs_dists, null_lat)

        print(f"KS (lon null): {ks_lon.statistic:.3f}")
        print(f"KS (lat null): {ks_lat.statistic:.3f}")

        plot_cdf(obs_dists, null_lon, null_lat, p)


if __name__ == "__main__":
    main()
