#!/usr/bin/env python3
import numpy as np
import xarray as xr
from scipy.stats import ks_2samp
from tqdm import trange

# ============================================================
# CONFIGURATION
# ============================================================

# Tomography grid to analyse; the file must provide "latitude", "longitude",
# "depth" and "dvs" variables (the name suggests dvs in percent — confirm).
NETCDF_FILE = "SEISGLOB2_percent.nc"

# Depth bands as inclusive (min, max) pairs, in the units of the file's
# "depth" coordinate (presumably km — confirm). Insertion order is the
# order in which bands are reported.
DEPTH_BANDS = dict(
    upper_mantle=(300, 600),
    mid_mantle=(900, 1200),
    lower_mantle=(1800, 2400),
)

# Percentiles of |dvs| (within each band) used as the "extreme" threshold.
PERCENTILES = [90, 95, 97.5]

# Randomisations per null model.
N_NULL = 1_000

# Euler points as (latitude, longitude) in degrees, longitudes on 0-360 E.
# The first is (0 N, 59 W), i.e. 360 - 59 = 301 E; the second is (0 N, 121 E).
EULERS = [
    (0.0, 360.0 - 59.0),
    (0.0, 121.0),
]

# ============================================================
# GEOMETRY
# ============================================================

def great_circle_distance(lat1, lon1, lat2, lon2):
    """
    Angular great-circle distance in degrees.

    Inputs are latitudes/longitudes in degrees. They may be scalars or NumPy
    arrays that broadcast against each other; the result has the broadcast
    shape and lies in [0, 180].

    Uses the atan2 (Vincenty special-case) formulation. The simpler
    arccos(law-of-cosines) form is ill-conditioned near 0 and 180 degrees:
    for very small separations its argument rounds to exactly 1.0 and the
    distance collapses to 0.
    """
    phi1, lam1, phi2, lam2 = map(np.deg2rad, (lat1, lon1, lat2, lon2))
    dlam = lam2 - lam1

    sin_phi1, cos_phi1 = np.sin(phi1), np.cos(phi1)
    sin_phi2, cos_phi2 = np.sin(phi2), np.cos(phi2)
    cos_dlam = np.cos(dlam)

    # |a x b| (sine of the angle) and a . b (cosine of the angle) for the two
    # unit position vectors; atan2 of the pair is well conditioned everywhere.
    sin_d = np.hypot(
        cos_phi2 * np.sin(dlam),
        cos_phi1 * sin_phi2 - sin_phi1 * cos_phi2 * cos_dlam,
    )
    cos_d = sin_phi1 * sin_phi2 + cos_phi1 * cos_phi2 * cos_dlam
    return np.rad2deg(np.arctan2(sin_d, cos_d))


# ============================================================
# DATA EXTRACTION
# ============================================================

def extract_extremes(lat, lon, depth, dvs, depth_range, percentile):
    """
    Extract extreme |dvs| points within a depth band.

    The threshold is the ``percentile``-th percentile of |dvs| over all
    finite values in the band (depths with ``dmin <= depth <= dmax``). Every
    grid cell in the band whose |dvs| reaches the threshold is returned.

    Parameters
    ----------
    lat, lon, depth : 1-D arrays
        Grid coordinates.
    dvs : 3-D array
        Indexed as ``dvs[depth, lat, lon]`` (assumed file layout — confirm).
    depth_range : (float, float)
        Inclusive (dmin, dmax) band limits.
    percentile : float
        Percentile in [0, 100] for the |dvs| threshold.

    Returns
    -------
    numpy.ndarray, shape (n, 4)
        Rows of (lat, lon, depth, dvs), ordered by depth, then lat, then lon.

    Raises
    ------
    ValueError
        If no depth level falls in the band or the band has no finite dvs.
    """
    dmin, dmax = depth_range
    depth_mask = (depth >= dmin) & (depth <= dmax)
    if not np.any(depth_mask):
        raise ValueError(f"No depth levels within {depth_range}")

    dvs_slice = dvs[depth_mask, :, :]
    depth_vals = depth[depth_mask]

    finite_vals = dvs_slice[np.isfinite(dvs_slice)]
    if finite_vals.size == 0:
        raise ValueError(f"No finite dvs values within {depth_range}")
    threshold = np.percentile(np.abs(finite_vals), percentile)

    # NaN cells compare False and are therefore never selected.
    with np.errstate(invalid="ignore"):
        mask = np.abs(dvs_slice) >= threshold

    # nonzero on the 3-D mask yields indices in (depth, lat, lon) row-major
    # order, matching a nested depth -> lat -> lon scan.
    k_idx, i_idx, j_idx = np.nonzero(mask)
    return np.column_stack(
        (lat[i_idx], lon[j_idx], depth_vals[k_idx], dvs_slice[k_idx, i_idx, j_idx])
    )


def compute_euler_distances(points, eulers=None):
    """
    Compute minimum angular distance to any Euler point.

    Parameters
    ----------
    points : array-like, shape (N, 4)
        Rows of (lat, lon, depth, dvs) in degrees; only lat/lon are used.
    eulers : sequence of (lat, lon), optional
        Reference points in degrees. Defaults to the module-level EULERS,
        read at call time.

    Returns
    -------
    numpy.ndarray, shape (N,)
        Distance in degrees from each point to its nearest Euler point.
    """
    if eulers is None:
        eulers = EULERS

    pts = np.asarray(points, dtype=float)
    if pts.size == 0:
        return np.empty(0)

    ref = np.asarray(eulers, dtype=float)
    # Broadcast (N, 1) points against (1, E) Euler points -> (N, E) matrix.
    dists = great_circle_distance(
        pts[:, 0][:, np.newaxis],
        pts[:, 1][:, np.newaxis],
        ref[:, 0][np.newaxis, :],
        ref[:, 1][np.newaxis, :],
    )
    return dists.min(axis=1)


# ============================================================
# NULL MODELS
# ============================================================

def longitude_randomised_null(points, n_iter, rng=None):
    """
    Longitude-randomised null (latitude preserved).

    In each of ``n_iter`` iterations every point keeps its latitude, gets a
    fresh longitude drawn uniformly from [0, 360), and contributes its
    minimum great-circle distance (degrees) to the points in EULERS.

    Parameters
    ----------
    points : array-like, shape (N, 4)
        Rows of (lat, lon, depth, dvs); only the latitude column is used.
    n_iter : int
        Number of randomisations.
    rng : numpy.random.Generator or RandomState, optional
        Random source. ``None`` uses NumPy's global random state, drawing
        exactly as before. Pass a seeded generator for reproducible nulls.

    Returns
    -------
    numpy.ndarray, shape (n_iter * N,)
        Pooled null distances, iteration-major.
    """
    rs = np.random if rng is None else rng

    lat_p = np.asarray(points, dtype=float).reshape(-1, 4)[:, 0]
    lat_col = lat_p[:, np.newaxis]

    ref = np.asarray(EULERS, dtype=float)
    e_lat = ref[:, 0][np.newaxis, :]
    e_lon = ref[:, 1][np.newaxis, :]

    chunks = []
    for _ in trange(n_iter, desc="Longitude null"):
        rand_lon = rs.uniform(0, 360, size=len(lat_p))
        # One broadcast call per iteration instead of N * len(EULERS) scalar calls.
        d = great_circle_distance(lat_col, rand_lon[:, np.newaxis], e_lat, e_lon)
        chunks.append(d.min(axis=1))

    return np.concatenate(chunks) if chunks else np.empty(0)


def lat_symmetrised_null(points, n_iter, rng=None):
    """
    Latitude-symmetrised + longitude-randomised null.
    Preserves |latitude|.

    In each of ``n_iter`` iterations every point gets a fresh longitude drawn
    uniformly from [0, 360) and a random hemisphere (latitude sign), keeping
    |latitude|. It then contributes its minimum great-circle distance
    (degrees) to the points in EULERS.

    Note: when every Euler point lies on the equator (as the current EULERS
    do), the distance depends on latitude only through cos(lat). Flipping
    the sign then cannot change it, so this null has the same distribution
    as the longitude-randomised null. It only differs for off-equator Euler
    points.

    Parameters
    ----------
    points : array-like, shape (N, 4)
        Rows of (lat, lon, depth, dvs); only the latitude column is used.
    n_iter : int
        Number of randomisations.
    rng : numpy.random.Generator or RandomState, optional
        Random source. ``None`` uses NumPy's global random state, drawing
        exactly as before. Pass a seeded generator for reproducible nulls.

    Returns
    -------
    numpy.ndarray, shape (n_iter * N,)
        Pooled null distances, iteration-major.
    """
    rs = np.random if rng is None else rng

    abs_lat = np.abs(np.asarray(points, dtype=float).reshape(-1, 4)[:, 0])
    n_pts = len(abs_lat)

    ref = np.asarray(EULERS, dtype=float)
    e_lat = ref[:, 0][np.newaxis, :]
    e_lon = ref[:, 1][np.newaxis, :]

    chunks = []
    for _ in trange(n_iter, desc="Lat-sym null"):
        # Same draw order as before: longitudes first, then signs.
        rand_lon = rs.uniform(0, 360, size=n_pts)
        rand_sign = rs.choice([-1, 1], size=n_pts)
        lat_r = rand_sign * abs_lat
        d = great_circle_distance(
            lat_r[:, np.newaxis], rand_lon[:, np.newaxis], e_lat, e_lon
        )
        chunks.append(d.min(axis=1))

    return np.concatenate(chunks) if chunks else np.empty(0)


# ============================================================
# ANALYSIS DRIVER
# ============================================================

def _report_null(label, obs_dists, null_dists):
    """Print observed-vs-null comparison for one null model and return its stats."""
    obs_mean = np.mean(obs_dists)
    null_mean = np.mean(null_dists)
    ks = ks_2samp(obs_dists, null_dists)

    print(f"\n[{label}]")
    print(f"Null mean dist: {null_mean:.2f}°")
    print(f"Δmean (obs-null): {obs_mean - null_mean:.2f}°")
    print(f"KS: {ks.statistic:.3f} | p = {ks.pvalue:.3e}")

    return {
        "null_mean": float(null_mean),
        "delta_mean": float(obs_mean - null_mean),
        "ks_stat": float(ks.statistic),
        "ks_p": float(ks.pvalue),
    }


def run_test(lat, lon, depth, dvs, band_name, depth_range, percentile):
    """
    Run both null-model comparisons for one depth band and percentile.

    Prints a report and returns a dict with ``n_points`` and, when any
    anomalies were selected, ``obs_mean`` plus per-null statistics under
    ``"longitude"`` and ``"lat_symmetrised"``.

    NOTE(review): selected points include neighbouring grid cells and depth
    levels, which are unlikely to be independent samples. KS p-values computed
    from their raw count are therefore probably far too small; treat them as
    descriptive rather than as calibrated significance levels.
    """
    print(f"\n=== {band_name.upper()} | {percentile}th percentile ===")

    points = extract_extremes(lat, lon, depth, dvs, depth_range, percentile)
    obs_dists = compute_euler_distances(points)

    print(f"N anomalies: {len(points)}")
    results = {"n_points": len(points)}

    # ks_2samp raises on an empty sample and np.mean([]) is nan; skip cleanly.
    if len(obs_dists) == 0:
        print("No anomalies selected; skipping null tests.")
        return results

    obs_mean = np.mean(obs_dists)
    print(f"Observed mean dist: {obs_mean:.2f}°")
    results["obs_mean"] = float(obs_mean)

    results["longitude"] = _report_null(
        "Longitude-randomised null",
        obs_dists,
        longitude_randomised_null(points, N_NULL),
    )
    results["lat_symmetrised"] = _report_null(
        "Latitude-symmetrised null",
        obs_dists,
        lat_symmetrised_null(points, N_NULL),
    )
    return results


# ============================================================
# MAIN
# ============================================================

def main():
    """
    Load the grid from NETCDF_FILE and run every depth band x percentile test.
    """
    # The context manager closes the underlying file. `.values` materialises
    # each variable as an in-memory ndarray, so the arrays stay valid after
    # the dataset is closed.
    with xr.open_dataset(NETCDF_FILE) as ds:
        lat = ds["latitude"].values
        lon = ds["longitude"].values
        depth = ds["depth"].values
        # NOTE(review): extract_extremes indexes dvs as [depth, lat, lon];
        # this assumes the file stores "dvs" in that dimension order — confirm.
        dvs = ds["dvs"].values

    for band, depth_range in DEPTH_BANDS.items():
        for p in PERCENTILES:
            run_test(lat, lon, depth, dvs, band, depth_range, p)


# Run the analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
