import json
import multiprocessing as mp
import random

import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import LinearNDInterpolator, griddata
from scipy.stats import spearmanr
from shapely.geometry import MultiPoint, Point, Polygon
from shapely.prepared import prep
from tqdm import tqdm

# ------------------------------------------------------------
# PARAMETERS
# ------------------------------------------------------------

N_ITER = 10_000               # number of null-model (randomized) iterations
RANDOM_SEED = 42              # base seed; null task i is seeded with RANDOM_SEED + i
N_WORKERS = mp.cpu_count()    # one worker process per available CPU

# Seed both global generators so any module-level randomness is reproducible.
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# ------------------------------------------------------------
# LOAD HOMO DATA (AGES FIXED)
# ------------------------------------------------------------

with open("early_homo_distance_to_zero_contour.geojson", "r", encoding="utf-8") as f:
    homo_data = json.load(f)

# Parallel lists: ages[i] is the age assigned to the site at orig_points[i].
ages = []
orig_points = []
n_skipped = 0

for feat in homo_data["features"]:
    props = feat["properties"]

    # Age is the midpoint of the [min_ma, max_ma] bounds (presumably Ma --
    # confirm against the data source).  Records whose bounds are missing,
    # null, or non-numeric cannot be dated and are skipped; a missing key or a
    # null ``properties`` object is treated the same as a null value.
    try:
        age = 0.5 * (float(props["min_ma"]) + float(props["max_ma"]))
    except (KeyError, TypeError, ValueError):
        n_skipped += 1
        continue

    # Keep only lon/lat: a GeoJSON position may carry an optional third
    # (elevation) element, which would break a two-name unpack.
    lon, lat = feat["geometry"]["coordinates"][:2]
    ages.append(age)
    orig_points.append((lon, lat))

if n_skipped:
    print(f"Skipped {n_skipped} Homo records with missing or invalid age bounds")

if not ages:
    raise ValueError("No Homo records with valid age bounds were found")

ages = np.array(ages, dtype=float)
orig_points = np.array(orig_points, dtype=float)

# ------------------------------------------------------------
# GEOGRAPHIC SAMPLING DOMAIN
# ------------------------------------------------------------

# Null locations are drawn from the convex hull of the observed sites.  Building
# the hull from a MultiPoint (rather than Polygon(orig_points)) does not depend
# on the sites' file order happening to form a valid polygon ring.
hull = MultiPoint(orig_points).convex_hull
if hull.area == 0:
    # A point or line "hull" contains no interior points, so the rejection
    # loop below would never terminate.
    raise ValueError(
        "Observed sites are coincident or collinear; their convex hull has "
        "zero area and cannot be sampled"
    )
minx, miny, maxx, maxy = hull.bounds

# Prepared geometry makes the many repeated point-in-polygon tests cheaper.
_prepared_hull = prep(hull)


def random_point_in_hull(rng):
    """Draw one (lon, lat) point uniformly from inside ``hull``.

    Uses rejection sampling: candidates are drawn uniformly over the hull's
    bounding box with ``rng.uniform`` and returned on the first one inside.

    NOTE(review): sampling is uniform in lon/lat degrees, not uniform by
    surface area on the sphere -- confirm that is the intended null model.

    Parameters
    ----------
    rng : random.Random
        Source of randomness; only its ``uniform`` method is used.

    Returns
    -------
    tuple of float
        ``(lon, lat)`` of a point strictly inside ``hull``.
    """
    while True:
        lon = rng.uniform(minx, maxx)
        lat = rng.uniform(miny, maxy)
        if _prepared_hull.contains(Point(lon, lat)):
            return lon, lat

# ------------------------------------------------------------
# LOAD SEA LEVEL FIELD
# ------------------------------------------------------------

with open("TPW_signed_equilibrium_relative_sea_level.geojson", "r", encoding="utf-8") as f:
    sl_data = json.load(f)

# Single pass over the point features: (lon, lat, signed sea level in m) per row.
# Coercing to float gives numeric arrays for interpolation (a null value
# becomes nan instead of producing an object array).  reshape keeps the
# (n, 3) shape even when there are no features.
_sl_table = np.array(
    [
        (
            feat["geometry"]["coordinates"][0],
            feat["geometry"]["coordinates"][1],
            feat["properties"]["net_relative_sea_level_m"],
        )
        for feat in sl_data["features"]
    ],
    dtype=float,
).reshape(-1, 3)

sl_lons = _sl_table[:, 0]
sl_lats = _sl_table[:, 1]
sl_vals = _sl_table[:, 2]

# ------------------------------------------------------------
# DISTANCE PROXY (ABSOLUTE ANOMALY)
# ------------------------------------------------------------

# Triangulate the sea-level sample points once.  griddata(method="linear")
# constructs this same LinearNDInterpolator internally on every call, so the
# original per-point griddata call re-triangulated the whole field for every
# single lookup (N sites x N_ITER iterations).
_sl_interpolator = LinearNDInterpolator(
    np.column_stack((sl_lons, sl_lats)), sl_vals
)


def anomaly_magnitude(lon, lat):
    """Return |net relative sea level| linearly interpolated at (lon, lat).

    Interpolation is piecewise-linear over a triangulation of the sea-level
    sample points, identical to ``griddata(..., method="linear")``.  Points
    outside the convex hull of the sample points get the interpolator's
    default fill value, ``nan``.

    Parameters
    ----------
    lon, lat : float
        Query location, in the same coordinates as the sea-level features.

    Returns
    -------
    numpy.float64
        Absolute interpolated value, or ``nan`` outside the data hull.
    """
    return abs(_sl_interpolator([(lon, lat)])[0])

# ------------------------------------------------------------
# OBSERVED STATISTIC
# ------------------------------------------------------------

# Anomaly magnitude at each real site, in the same order as ``ages``.
observed_distances = [
    anomaly_magnitude(lon, lat)
    for lon, lat in orig_points
]

# Sites outside the convex hull of the sea-level samples interpolate to nan,
# and spearmanr's default nan_policy ("propagate") then returns nan for rho.
# Report it instead of letting an undefined statistic flow silently downstream.
n_undefined_sites = int(np.count_nonzero(np.isnan(observed_distances)))
if n_undefined_sites:
    print(
        f"Warning: {n_undefined_sites} of {len(observed_distances)} observed "
        "sites have no interpolated sea-level value; observed rho will be nan"
    )

observed_rho, _ = spearmanr(ages, observed_distances)

print(f"Observed Spearman rho = {observed_rho:.4f}")

# ------------------------------------------------------------
# WORKER FUNCTION
# ------------------------------------------------------------

def worker(task_id):
    """Compute one null-model Spearman rho.

    Re-places every observed site at a random location inside the sampling
    hull, looks up the anomaly magnitude there, and correlates those values
    with the fixed ``ages``.  The RNG is seeded with ``RANDOM_SEED + task_id``,
    so each iteration's result depends only on its index, not on which
    process runs it.

    Parameters
    ----------
    task_id : int
        Iteration index; also offsets the per-task seed.

    Returns
    -------
    float
        Spearman rho between ``ages`` and the randomized anomaly magnitudes
        (nan if any of those magnitudes is nan).
    """
    task_rng = random.Random(RANDOM_SEED + task_id)

    # anomaly_magnitude does not touch task_rng, so interleaving draws and
    # lookups yields exactly the same point sequence as drawing all first.
    null_distances = []
    for _ in range(len(orig_points)):
        point_lon, point_lat = random_point_in_hull(task_rng)
        null_distances.append(anomaly_magnitude(point_lon, point_lat))

    null_rho, _p_value = spearmanr(ages, null_distances)
    return null_rho

# ------------------------------------------------------------
# PARALLEL EXECUTION WITH PROGRESS
# ------------------------------------------------------------

if __name__ == "__main__":

    # A nan observed rho would make every ``null_rho >= observed_rho``
    # comparison below False and report p = 0, a spurious "significant"
    # result.  Stop before the expensive null run instead.
    if not np.isfinite(observed_rho):
        raise SystemExit(
            "Observed Spearman rho is undefined (nan); cannot compute a p-value."
        )

    # Dispatch tasks in batches: with imap's default chunksize=1, every
    # iteration pays for its own inter-process round trip.  imap still
    # returns results in task order.
    chunksize = max(1, N_ITER // (N_WORKERS * 32))

    with mp.Pool(processes=N_WORKERS) as pool:
        null_rhos = list(
            tqdm(
                pool.imap(worker, range(N_ITER), chunksize=chunksize),
                total=N_ITER,
                desc="Running null model",
                smoothing=0.05
            )
        )

    null_rhos = np.array(null_rhos, dtype=float)

    # --------------------------------------------------------
    # EMPIRICAL P-VALUE
    # --------------------------------------------------------

    # An iteration can yield a nan rho (e.g. when a random point lies outside
    # the sea-level field, where interpolation returns nan).  Keeping those in
    # the denominator, as a plain mean over all iterations does, biases p
    # downward, so exclude them and say how many were dropped.
    finite_null = null_rhos[np.isfinite(null_rhos)]
    n_dropped = null_rhos.size - finite_null.size
    if n_dropped:
        print(
            f"Warning: {n_dropped} of {null_rhos.size} null iterations gave an "
            "undefined rho and were excluded from the p-value"
        )

    # One-sided (upper-tail) Monte Carlo p-value with the +1 correction: the
    # observed statistic counts as one draw from the null, so p is never
    # exactly 0 (Phipson & Smyth, 2010).
    n_extreme = int(np.count_nonzero(finite_null >= observed_rho))
    p_empirical = (n_extreme + 1) / (finite_null.size + 1)

    print(f"Empirical p-value = {p_empirical:.6f}")
    # NOTE: this is simply 1 - p, not a confidence level in the formal sense.
    print(f"Confidence level  = {(1 - p_empirical) * 100:.3f}%")

    # --------------------------------------------------------
    # SAVE RESULTS
    # --------------------------------------------------------

    # The raw per-iteration array is saved unfiltered (any nan entries kept).
    np.save("null_rhos.npy", null_rhos)

    with open("observed_rho.txt", "w", encoding="utf-8") as f:
        f.write(str(observed_rho))

    print("Saved null_rhos.npy and observed_rho.txt")
