#!/usr/bin/env python3

import numpy as np
import pandas as pd
import xarray as xr
from shapely.geometry import LineString, Point
from shapely.ops import unary_union
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
from pyproj import Geod
import matplotlib.contour as mcontour

# --------------------------------------------------
# INPUT FILES
# --------------------------------------------------

TPW_NC = "tpw_centrifugal_relaxation_95pct_10steps.nc"
HOMO_CSV = "pbdb_data.csv"

# --------------------------------------------------
# LOAD TPW DATA
# --------------------------------------------------

ds = xr.open_dataset(TPW_NC)
# Field contoured at level 0 for every relaxation step in the main loop.
E = ds["effective_elevation"]

# --------------------------------------------------
# FIND PBDB HEADER ROW
# --------------------------------------------------
# The table header is the first line whose first CSV field is "occurrence_no".
# Any lines before it (presumably download metadata) are skipped later via
# skiprows. The first field is compared with whitespace and surrounding
# quotes removed, so quoted and unquoted headers both match. "utf-8-sig"
# reads plain UTF-8 and also drops a leading byte-order mark. Without it, a
# BOM would hide a header that sits on the very first line.

header_line = None
with open(HOMO_CSV, "r", encoding="utf-8-sig") as f:
    for i, line in enumerate(f):
        first_field = line.split(",", 1)[0].strip().strip('"')
        if first_field == "occurrence_no":
            header_line = i
            break

if header_line is None:
    raise RuntimeError("PBDB header row not found")

# --------------------------------------------------
# LOAD PBDB DATA TABLE
# --------------------------------------------------

homo = pd.read_csv(
    HOMO_CSV,
    skiprows=header_line,
    engine="python"
)

# --------------------------------------------------
# REQUIRED COLUMNS (PBDB STANDARD)
# --------------------------------------------------

# lat/lng: site coordinates in degrees (they are fed to Geod.inv).
# max_ma: the age value ranked against distance (presumably Ma, per the name).
required_cols = ["lat", "lng", "max_ma"]

missing = [c for c in required_cols if c not in homo.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

# Coerce to numbers before dropping rows. Non-numeric cells become NaN and
# are removed here, instead of making the float conversion below raise.
homo[required_cols] = homo[required_cols].apply(pd.to_numeric, errors="coerce")
homo = homo.dropna(subset=required_cols)

ages = homo["max_ma"].to_numpy(dtype=float)
site_lons = homo["lng"].to_numpy(dtype=float)
site_lats = homo["lat"].to_numpy(dtype=float)

# The zero contour is built in the grid's own lon/lat plane, and nearest
# points are searched in that plane. Site longitudes must therefore follow
# the grid's convention: 0..360 if the grid extends past 180, otherwise
# -180..180 if the grid has negative longitudes. Geodesic distances are
# unaffected by the wrap; only the planar search is.
grid_lons = np.asarray(ds.lon.values, dtype=float)
if np.nanmax(grid_lons) > 180.0:
    site_lons = np.mod(site_lons, 360.0)
elif np.nanmin(grid_lons) < 0.0:
    site_lons = (site_lons + 180.0) % 360.0 - 180.0

# (lon, lat) pairs, matching the (x, y) order of shapely Points.
coords = list(zip(site_lons, site_lats))

geod = Geod(ellps="WGS84")

# --------------------------------------------------
# FUNCTION: extract zero contour
# --------------------------------------------------

def zero_contour_geometry(field, lats, lons):
    """
    Extract the zero-elevation contour of ``field`` as a Shapely geometry.

    Parameters
    ----------
    field : 2-D array
        Gridded values passed to ``Axes.contour`` as Z. Its shape must match
        X=lons, Y=lats (i.e. rows follow ``lats``, columns follow ``lons``).
    lats, lons : array-like
        Grid coordinates, passed to ``Axes.contour`` as Y and X.

    Returns
    -------
    shapely geometry
        Union of all level-0 contour segments, in (lon, lat) coordinates.
        Segments with fewer than two vertices are skipped. The result is an
        empty geometry when nothing remains.
    """
    # contour() needs an Axes. Use a throwaway figure and close it even if
    # contouring raises, so repeated calls never accumulate open figures.
    fig = plt.figure()
    try:
        ax = fig.add_subplot()
        cs = ax.contour(lons, lats, field, levels=[0.0])
        # allsegs[0] holds the vertex arrays of the single requested level.
        segments = cs.allsegs[0]
    finally:
        plt.close(fig)

    lines = [LineString(seg) for seg in segments if len(seg) > 1]

    # unary_union([]) already yields an empty geometry, so no special case.
    return unary_union(lines)

# --------------------------------------------------
# FUNCTION: distance to contour (km)
# --------------------------------------------------

def distance_km(point, contour):
    """
    Return the WGS84 geodesic distance in km from ``point`` to ``contour``.

    Parameters
    ----------
    point : shapely Point
        Query location with x=lon, y=lat in degrees.
    contour : shapely geometry
        Line geometry in the same (lon, lat) plane, e.g. the output of
        ``zero_contour_geometry``.

    Returns
    -------
    float
        Distance in kilometres, or NaN if ``contour`` is empty.
    """
    if contour.is_empty:
        return np.nan

    # Shapely's project/interpolate measure Euclidean distance in lon/lat
    # degrees, ignoring that a degree of longitude shrinks with latitude.
    # The planar nearest point can therefore be far from the geodesically
    # nearest one. Every contour vertex is also considered, and the smallest
    # geodesic distance is kept. All candidates lie on the contour, so the
    # result is never below the true distance and never above the
    # planar-only estimate.
    nearest = contour.interpolate(contour.project(point))
    cand_lons = [np.array([nearest.x], dtype=float)]
    cand_lats = [np.array([nearest.y], dtype=float)]
    for part in getattr(contour, "geoms", [contour]):
        if part.is_empty:
            continue
        xy = np.asarray(part.coords, dtype=float)
        cand_lons.append(xy[:, 0])
        cand_lats.append(xy[:, 1])

    lons2 = np.concatenate(cand_lons)
    lats2 = np.concatenate(cand_lats)

    # Geod.inv gets equal-length arrays: the query point repeated per candidate.
    _, _, dist = geod.inv(
        np.full(lons2.shape, point.x, dtype=float),
        np.full(lats2.shape, point.y, dtype=float),
        lons2,
        lats2,
    )
    return float(np.min(np.abs(dist))) / 1000.0

# --------------------------------------------------
# MAIN LOOP
# --------------------------------------------------

results = []

for step in ds.step.values:
    step_field = E.sel(step=step)
    # Axes.contour(X, Y, Z) expects Z shaped (len(Y), len(X)) = (lat, lon).
    # The stored dimension order is not guaranteed, so enforce it whenever
    # lat/lon are 1-D coordinates. Otherwise a transposed square grid would
    # be contoured silently in the wrong orientation.
    if ds.lat.ndim == 1 and ds.lon.ndim == 1:
        step_field = step_field.transpose(ds.lat.dims[0], ds.lon.dims[0])
    contour = zero_contour_geometry(
        step_field.values, ds.lat.values, ds.lon.values
    )

    distances = np.array(
        [distance_km(Point(lon, lat), contour) for lon, lat in coords],
        dtype=float,
    )

    # With no zero contour (distance_km returns NaN) or no sites, there is
    # nothing to correlate. Record NaN explicitly rather than passing an
    # all-NaN or empty vector to spearmanr/nanmean.
    if np.isfinite(distances).any():
        rho, p = spearmanr(ages, distances)
        mean_distance = np.nanmean(distances)
    else:
        rho = p = mean_distance = np.nan

    results.append({
        "step": int(step),
        "spearman_rho": rho,
        "p_value": p,
        "mean_distance_km": mean_distance
    })

results_df = pd.DataFrame(results)
results_df.to_csv("hominin_viscoelastic_step_correlations.csv", index=False)

print(results_df)

# --------------------------------------------------
# PLOT
# --------------------------------------------------

# One line per relaxation step: Spearman rho between site age and distance
# to the zero contour, saved to PNG and then shown interactively.
rho_fig, rho_ax = plt.subplots(figsize=(7, 4))
rho_ax.plot(
    results_df["step"],
    results_df["spearman_rho"],
    marker="o",
    linewidth=2,
)

# Dashed reference at rho = 0 (no monotonic association).
rho_ax.axhline(0, color="gray", linestyle="--", alpha=0.5)
rho_ax.set_xlabel("Viscoelastic relaxation step")
rho_ax.set_ylabel("Spearman ρ (age vs distance)")
rho_ax.set_title("Hominin age–distance correlation across centrifugal relaxation")
rho_ax.grid(alpha=0.3)

rho_fig.tight_layout()
rho_fig.savefig("hominin_vs_relaxation_step.png", dpi=300)
plt.show()
