#!/usr/bin/env python3

import json
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature

from matplotlib.colors import BoundaryNorm

# ============================================================
# CONFIG
# ============================================================

# Model output: effective elevation for every relaxation step.
NC_FILE = "tpw_centrifugal_relaxation_95pct_10steps.nc"

# Point overlays (GeoJSON FeatureCollections).
HOMO_GEOJSON = "early_homo_sites.geojson"
CIV_GEOJSON = "early_civilizations.geojson"

# Frames are written as "<OUT_PREFIX>_NN.png".
OUT_PREFIX = "tpw_relaxation_step"

# Stride applied to BOTH lat and lon before plotting: 1 keeps every grid
# point, 4 keeps every 4th point along each axis (1/16 of the points).
# Increase it if rendering is still too slow.
DS_FACTOR = 4

# Contour boundaries in metres, symmetric about sea level: fine spacing near
# 0 m, coarse spacing at large |elevation|. The negative half is the mirror
# image of the positive half.
_POSITIVE_LEVELS = [250, 500, 1000, 2500, 5000, 7500, 10000, 12500, 15000]
LEVELS = np.array(
    [-v for v in reversed(_POSITIVE_LEVELS)] + [0] + _POSITIVE_LEVELS
)

# ============================================================
# LOAD NC DATA
# ============================================================

# Open the dataset in a context manager so the file handle is released once
# the downsampled field has been read into memory (the render loop then works
# on in-memory data instead of re-reading the file for every step).
with xr.open_dataset(NC_FILE) as ds:
    # Expecting: effective_elevation(step, lat, lon). transpose() enforces that
    # order (and raises if any of these dims is missing) so every per-step
    # slice is (lat, lon), matching the LON/LAT meshgrid built below.
    field = (
        ds["effective_elevation"]
        .isel(
            # Downsample for speed: keep every DS_FACTOR-th point on each axis.
            lat=slice(None, None, DS_FACTOR),
            lon=slice(None, None, DS_FACTOR),
        )
        .transpose("step", "lat", "lon")
        .load()
    )

lats = field.lat.values
lons = field.lon.values
steps = field.step.values

LON, LAT = np.meshgrid(lons, lats)

# The norm's extend must match contourf(extend="both"): it reserves distinct
# under/over colours so values beyond the outer LEVELS are not silently
# clipped into the end bins. (BoundaryNorm rejects clip=True with extend.)
norm = BoundaryNorm(LEVELS, ncolors=256, extend="both")

# ============================================================
# LOAD POINT DATA
# ============================================================

def load_points(fname):
    """Read point locations from a GeoJSON FeatureCollection.

    Parameters
    ----------
    fname : str or path-like
        Path to a GeoJSON file whose top-level object has a ``"features"``
        list.

    Returns
    -------
    tuple of numpy.ndarray
        ``(lons, lats)`` as float arrays with one entry per point. Features
        with a ``null`` geometry are skipped; ``MultiPoint`` features
        contribute one entry per member point. Any altitude (third position
        element) is ignored.

    Raises
    ------
    ValueError
        If a feature's geometry type is neither ``Point`` nor ``MultiPoint``.
    """
    lons, lats = [], []
    # GeoJSON is UTF-8 by specification; don't depend on the platform default.
    with open(fname, "r", encoding="utf-8") as f:
        data = json.load(f)

    for feat in data["features"]:
        geom = feat.get("geometry")
        if geom is None:
            # Unlocated feature (valid GeoJSON) -- nothing to plot.
            continue

        gtype = geom.get("type")
        if gtype == "Point":
            positions = [geom["coordinates"]]
        elif gtype == "MultiPoint":
            positions = geom["coordinates"]
        else:
            # Reject explicitly: e.g. a 2-vertex LineString would otherwise
            # unpack into nonsense lon/lat values without any error.
            raise ValueError(
                f"{fname}: unsupported geometry type {gtype!r}; "
                "expected Point or MultiPoint"
            )

        for pos in positions:
            # Positions may carry an optional altitude; keep lon/lat only.
            lons.append(float(pos[0]))
            lats.append(float(pos[1]))

    return np.array(lons, dtype=float), np.array(lats, dtype=float)

# Read both point layers once (Homo sites first, then civilizations); they are
# re-plotted unchanged on every frame.
(homo_lons, homo_lats), (civ_lons, civ_lats) = (
    load_points(path) for path in (HOMO_GEOJSON, CIV_GEOJSON)
)

# ============================================================
# LOOP OVER RELAXATION STEPS
# ============================================================

# Render one PNG frame per relaxation step.
# PlateCarree describes the lon/lat inputs of every overlay; build it once.
data_crs = ccrs.PlateCarree()
n_steps = len(steps)

for i in range(n_steps):

    # Select by position rather than by label so float or repeated step
    # coordinates cannot break the lookup; transpose() guarantees Z is
    # (lat, lon), matching the LON/LAT meshgrid.
    Z = field.isel(step=i).transpose("lat", "lon").values

    fig = plt.figure(figsize=(17, 9))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson())
    ax.set_global()

    ax.set_title(
        "Centrifugal Viscoelastic Relaxation After 104° TPW\n"
        f"Relaxation step {i} of {n_steps-1}",
        fontsize=14,
        pad=14
    )

    # --------------------------------------------------------
    # Filled contours
    # --------------------------------------------------------

    cf = ax.contourf(
        LON,
        LAT,
        Z,
        levels=LEVELS,
        cmap="RdBu_r",
        norm=norm,
        transform=data_crs,
        extend="both",
        alpha=0.95
    )

    # --------------------------------------------------------
    # Zero-elevation contour (effective shoreline; expected to
    # migrate between relaxation steps)
    # --------------------------------------------------------

    # A 0 m line only exists when 0 lies strictly inside the data range;
    # skip it otherwise instead of letting matplotlib warn that no contour
    # levels were found.
    finite = Z[np.isfinite(Z)]
    if finite.size and finite.min() < 0.0 < finite.max():
        ax.contour(
            LON,
            LAT,
            Z,
            levels=[0.0],
            colors="black",
            linewidths=1.2,
            transform=data_crs
        )

    # --------------------------------------------------------
    # Homo sites
    # --------------------------------------------------------

    ax.scatter(
        homo_lons,
        homo_lats,
        s=35,
        color="black",
        edgecolor="white",
        linewidth=0.5,
        transform=data_crs,
        label="Early Homo sites",
        zorder=5
    )

    # --------------------------------------------------------
    # Civilization sites (drawn above the Homo sites)
    # --------------------------------------------------------

    ax.scatter(
        civ_lons,
        civ_lats,
        s=40,
        color="gold",
        edgecolor="black",
        linewidth=0.5,
        transform=data_crs,
        label="Early civilizations",
        zorder=6
    )

    # --------------------------------------------------------
    # Geographic reference
    # --------------------------------------------------------

    ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
    ax.add_feature(cfeature.BORDERS, linewidth=0.4)
    ax.add_feature(cfeature.LAND, facecolor="none", edgecolor="black", linewidth=0.3)

    # --------------------------------------------------------
    # Colorbar
    # --------------------------------------------------------

    cbar = fig.colorbar(
        cf,
        ax=ax,
        orientation="horizontal",
        pad=0.06,
        fraction=0.06,
        ticks=LEVELS
    )
    cbar.set_label(
        "Effective Elevation Relative to Equilibrium Sea Level (m)",
        fontsize=10
    )
    cbar.ax.tick_params(labelsize=8)

    # --------------------------------------------------------
    # Legend
    # --------------------------------------------------------

    ax.legend(
        loc="lower left",
        frameon=True,
        framealpha=0.9,
    )

    # --------------------------------------------------------
    # Save frame and release this figure explicitly
    # --------------------------------------------------------

    outfile = f"{OUT_PREFIX}_{i:02d}.png"
    fig.savefig(outfile, dpi=300, bbox_inches="tight")
    plt.close(fig)

    print(f"Saved {outfile}")

print("Sequence complete.")
