#!/usr/bin/env python3
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from tqdm import trange

# ============================================================
# CONFIGURATION
# ============================================================

# Tomography model read by main() via xarray.
NETCDF_FILE = "SEISGLOB2_percent.nc"

# Inclusive (min, max) depth window analysed; presumably km -- TODO confirm
# against the units of the file's "depth" coordinate.
MID_MANTLE_RANGE = (900, 1200)

# |dvs| percentiles used as extreme-anomaly thresholds; main() draws one
# panel per entry.
PERCENTILES = [90, 95, 97.5]

# Monte-Carlo iterations for each null model.
N_NULL = 1_000

# Reference poles as (latitude, longitude) in degrees, longitude 0-360 east.
EULERS = [
    (0.0, 301.0),  # 0N, 59W
    (0.0, 121.0),  # 0N, 121E
]

# ============================================================
# GEOMETRY
# ============================================================

def great_circle_distance(lat1, lon1, lat2, lon2):
    """Return the great-circle (angular) distance between points, in degrees.

    All inputs are in degrees and may be scalars or NumPy arrays; arrays are
    broadcast against each other, so one call can measure many points.

    Uses the atan2 (Vincenty, spherical case) formulation, which is accurate
    at every separation. The plain ``arccos`` form of the spherical law of
    cosines is ill-conditioned near 0 and 180 degrees: for separations below
    roughly 1e-6 degrees ``cos`` rounds to exactly 1.0 and the distance
    collapses to 0.

    Returns
    -------
    float or numpy.ndarray
        Angular distance in degrees, in [0, 180].
    """
    phi1, lam1, phi2, lam2 = map(np.deg2rad, (lat1, lon1, lat2, lon2))
    dlam = lam2 - lam1

    sin_phi1, cos_phi1 = np.sin(phi1), np.cos(phi1)
    sin_phi2, cos_phi2 = np.sin(phi2), np.cos(phi2)
    cos_dlam = np.cos(dlam)

    # y is proportional to sin(d) and x to cos(d); atan2 of the pair is
    # well-conditioned everywhere and needs no clipping. y >= 0, so the
    # result lies in [0, pi].
    y = np.hypot(
        cos_phi2 * np.sin(dlam),
        cos_phi1 * sin_phi2 - sin_phi1 * cos_phi2 * cos_dlam,
    )
    x = sin_phi1 * sin_phi2 + cos_phi1 * cos_phi2 * cos_dlam
    return np.rad2deg(np.arctan2(y, x))


# ============================================================
# DATA EXTRACTION
# ============================================================

def extract_extremes(lat, lon, depth, dvs, depth_range, percentile):
    """Return the (lat, lon) location of every voxel whose |dvs| is extreme.

    A single threshold is taken as the given percentile of all finite |dvs|
    values inside the depth window; every voxel in that window meeting the
    threshold contributes one (lat, lon) row. A map location that is extreme
    at several depths therefore appears once per such depth.

    Parameters
    ----------
    lat, lon : 1-D array_like
        Coordinate values indexing the second and third axes of ``dvs``.
    depth : numpy.ndarray
        Coordinate values indexing the first axis of ``dvs``.
    dvs : numpy.ndarray
        3-D anomaly array indexed as ``dvs[depth, lat, lon]``. Non-finite
        values are ignored.
    depth_range : tuple
        Inclusive ``(min, max)`` depth window, in the units of ``depth``.
    percentile : float
        Percentile (0-100) of finite |dvs| values used as the threshold.

    Returns
    -------
    numpy.ndarray
        Shape ``(n, 2)`` array of (lat, lon) rows, ordered by depth index,
        then latitude index, then longitude index. ``n`` may be 0.

    Raises
    ------
    ValueError
        If the depth window contains no finite ``dvs`` values.
    """
    dmin, dmax = depth_range
    depth_mask = (depth >= dmin) & (depth <= dmax)
    abs_slice = np.abs(dvs[depth_mask, :, :])

    finite = np.isfinite(abs_slice)
    if not finite.any():
        # np.percentile on an empty array fails with an unhelpful IndexError.
        raise ValueError(
            f"no finite dvs values in depth range [{dmin}, {dmax}]"
        )
    threshold = np.percentile(abs_slice[finite], percentile)

    # NaN compares False, so NaN voxels are never selected; silence the
    # "invalid value" warning some NumPy versions emit for that comparison.
    with np.errstate(invalid="ignore"):
        extreme = abs_slice >= threshold

    # np.nonzero yields indices in C order (depth, lat, lon), i.e. the same
    # row order as iterating layer by layer.
    _, lat_idx, lon_idx = np.nonzero(extreme)
    lat = np.asarray(lat)
    lon = np.asarray(lon)
    # column_stack keeps the (0, 2) shape even when nothing is selected.
    return np.column_stack((lat[lat_idx], lon[lon_idx]))


def compute_euler_distances(points, *, eulers=None):
    """Return each point's angular distance (degrees) to its nearest pole.

    Parameters
    ----------
    points : array_like
        (lat, lon) pairs in degrees, shape ``(n, 2)``. An empty input,
        including a 1-D empty array, yields an empty result.
    eulers : sequence of (lat, lon), optional
        Reference poles in degrees; defaults to the module-level ``EULERS``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n,)`` array of minimum great-circle distances in degrees.

    Raises
    ------
    ValueError
        If ``eulers`` is empty.
    """
    if eulers is None:
        eulers = EULERS
    if len(eulers) == 0:
        raise ValueError("at least one Euler pole is required")

    # reshape(-1, 2) also normalises an empty (0,) array to (0, 2).
    pts = np.asarray(points, dtype=float).reshape(-1, 2)
    lats, lons = pts[:, 0], pts[:, 1]

    # One broadcast distance array per pole, reduced over poles; this
    # replaces a per-point Python loop with a handful of vectorised calls.
    per_pole = [
        great_circle_distance(lats, lons, e_lat, e_lon)
        for e_lat, e_lon in eulers
    ]
    return np.min(per_pole, axis=0)


# ============================================================
# NULL MODELS
# ============================================================

def longitude_randomised_null(points, n_iter, *, rng=None, eulers=None):
    """Monte-Carlo null: keep each point's latitude, randomise its longitude.

    In each of ``n_iter`` iterations, every point gets a fresh uniform
    longitude in [0, 360). The distance to its nearest Euler pole is then
    recorded.

    Parameters
    ----------
    points : array_like
        (lat, lon) pairs in degrees, shape ``(n, 2)``; only latitudes are used.
    n_iter : int
        Number of randomisation iterations.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global legacy ``np.random``
        state is used, so ``np.random.seed`` still controls the output.
    eulers : sequence of (lat, lon), optional
        Reference poles in degrees; defaults to the module-level ``EULERS``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_iter * n,)`` distances in degrees, iteration-major: all
        points of iteration 0, then all points of iteration 1, and so on.

    Raises
    ------
    ValueError
        If ``eulers`` is empty.
    """
    if eulers is None:
        eulers = EULERS
    if len(eulers) == 0:
        raise ValueError("at least one Euler pole is required")
    rand_src = np.random if rng is None else rng

    lats = np.asarray(points, dtype=float).reshape(-1, 2)[:, 0]
    n = lats.size
    null_dists = np.empty(n_iter * n)

    for it in trange(n_iter, desc="Longitude null"):
        rand_lon = rand_src.uniform(0, 360, size=n)
        # Vectorised over points: one broadcast call per pole rather than
        # one scalar call per (point, pole) pair.
        null_dists[it * n:(it + 1) * n] = np.min(
            [
                great_circle_distance(lats, rand_lon, e_lat, e_lon)
                for e_lat, e_lon in eulers
            ],
            axis=0,
        )
    return null_dists


def lat_symmetrised_null(points, n_iter, *, rng=None, eulers=None):
    """Monte-Carlo null: randomise longitude and hemisphere, keep |latitude|.

    In each of ``n_iter`` iterations, every point gets a fresh uniform
    longitude in [0, 360) and a random latitude sign. The distance to its
    nearest Euler pole is then recorded.

    NOTE(review): for poles on the equator (as in the current ``EULERS``),
    cos(d) = cos(lat) * cos(dlon), which is unchanged by flipping the sign of
    lat. This null is then distributionally identical to
    ``longitude_randomised_null``; it only differs for off-equator poles.

    Parameters
    ----------
    points : array_like
        (lat, lon) pairs in degrees, shape ``(n, 2)``; only latitudes are used.
    n_iter : int
        Number of randomisation iterations.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global legacy ``np.random``
        state is used, so ``np.random.seed`` still controls the output.
    eulers : sequence of (lat, lon), optional
        Reference poles in degrees; defaults to the module-level ``EULERS``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_iter * n,)`` distances in degrees, iteration-major.

    Raises
    ------
    ValueError
        If ``eulers`` is empty.
    """
    if eulers is None:
        eulers = EULERS
    if len(eulers) == 0:
        raise ValueError("at least one Euler pole is required")
    rand_src = np.random if rng is None else rng

    abs_lats = np.abs(np.asarray(points, dtype=float).reshape(-1, 2)[:, 0])
    n = abs_lats.size
    null_dists = np.empty(n_iter * n)

    for it in trange(n_iter, desc="Lat-sym null"):
        # Longitudes are drawn before signs, so seeded runs are reproducible.
        rand_lon = rand_src.uniform(0, 360, size=n)
        rand_sign = rand_src.choice([-1, 1], size=n)
        lat_r = rand_sign * abs_lats
        null_dists[it * n:(it + 1) * n] = np.min(
            [
                great_circle_distance(lat_r, rand_lon, e_lat, e_lon)
                for e_lat, e_lon in eulers
            ],
            axis=0,
        )
    return null_dists


# ============================================================
# CDF UTILITIES
# ============================================================

def empirical_cdf(data):
    """Return ``(x, y)`` step coordinates of the empirical CDF of *data*.

    ``x`` is *data* sorted ascending. ``y[i] = (i + 1) / len(x)`` is the
    rank-based cumulative fraction, so the last entry is always 1.0.
    """
    sorted_vals = np.sort(data)
    count = len(sorted_vals)
    ranks = np.arange(1, count + 1)
    return sorted_vals, ranks / count


# ============================================================
# MAIN FIGURE
# ============================================================

def _load_model(path):
    """Read lat, lon, depth and dvs from *path*, closing the file afterwards."""
    with xr.open_dataset(path) as ds:
        dvs_da = ds["dvs"]
        # extract_extremes indexes dvs as [depth, lat, lon]; enforce that
        # order when the variable carries those dimension names rather than
        # trusting the on-disk layout.
        wanted = ("depth", "latitude", "longitude")
        if set(dvs_da.dims) == set(wanted):
            dvs_da = dvs_da.transpose(*wanted)
        # .values materialises NumPy arrays, which stay valid after closing.
        return (
            ds["latitude"].values,
            ds["longitude"].values,
            ds["depth"].values,
            dvs_da.values,
        )


def _plot_panel(ax, lat, lon, depth, dvs, p, label):
    """Draw observed and null distance CDFs for percentile *p* on *ax*."""
    print(f"\nProcessing {p}th percentile")

    points = extract_extremes(lat, lon, depth, dvs, MID_MANTLE_RANGE, p)
    obs = compute_euler_distances(points)
    null_lon = longitude_randomised_null(points, N_NULL)
    null_lat = lat_symmetrised_null(points, N_NULL)

    ax.plot(*empirical_cdf(obs), linewidth=2, label="Observed")
    ax.plot(*empirical_cdf(null_lon), linestyle="--", label="Longitude null")
    ax.plot(*empirical_cdf(null_lat), linestyle=":", label="Lat-sym null")

    ax.set_title(f"{label} {p}th percentile")
    ax.grid(alpha=0.3)


def main():
    """Plot observed vs. null distance-to-Euler CDFs, one panel per percentile.

    Reads ``NETCDF_FILE`` and extracts extreme |dvs| voxels within
    ``MID_MANTLE_RANGE`` for each entry of ``PERCENTILES``. It then compares
    the CDF of their nearest-Euler distances against ``N_NULL`` iterations
    of both null models.
    """
    lat, lon, depth, dvs = _load_model(NETCDF_FILE)

    # Size the figure from PERCENTILES so extra entries are not silently
    # dropped; 3 panels reproduce the original 14 x 4 inch layout.
    n_panels = len(PERCENTILES)
    fig, axes = plt.subplots(
        1, n_panels,
        figsize=(14 * n_panels / 3, 4),
        sharex=True, sharey=True,
        squeeze=False,  # always a 2-D array, even for a single panel
    )
    axes = axes[0]

    for i, (ax, p) in enumerate(zip(axes, PERCENTILES)):
        label = f"({chr(ord('a') + i)})"
        _plot_panel(ax, lat, lon, depth, dvs, p, label)

    axes[0].set_ylabel("Cumulative probability")
    for ax in axes:
        ax.set_xlabel("Angular distance to nearest Euler (degrees)")

    # Single legend shared by all panels (they use identical labels).
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles, labels,
        loc="lower center",
        ncol=3,
        frameon=False,
    )

    # Reserve the bottom 12% of the figure for the shared legend.
    fig.tight_layout(rect=[0, 0.12, 1, 1])
    plt.show()


# Run the analysis and show the figure only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()
