#!/usr/bin/env python3
"""
Continuous TPW Return with Logarithmic Viscoelastic Relaxation
(first-order physically derived model)

Implements:
- Logarithmic TPW angular return (derived rate law)
- Logarithmic viscoelastic relaxation (99% by 1.8 ka BP)
- Continuous coupling between geometry and relaxation
- NetCDF + PNG outputs
"""

import json
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from pathlib import Path

# ============================================================
# CONFIGURATION
# ============================================================

GEBCO_NC = "GEBCO_2025_sub_ice.nc"
OUTPUT_NC = "tpw_return_continuous_log.nc"
PNG_DIR = Path("tpw_return_png_continuous")
# parents=True so PNG_DIR may be pointed at a nested path without extra setup.
PNG_DIR.mkdir(parents=True, exist_ok=True)

HOMO_GEOJSON = "early_homo_sites.geojson"
CIV_GEOJSON  = "early_civilizations.geojson"

# Physical constants (SI units)
R = 6371000.0              # Earth radius (m)
OMEGA = 7.2921159e-5       # Earth rotation rate (rad/s)
G = 9.81                   # gravitational acceleration (m/s^2)

# TPW geometry
TPW_MERIDIAN_DEG = 31.0    # longitude of the displaced rotation pole (deg)
THETA_0 = 104.0            # initial TPW angle (deg)

# Time parameters
# Model time t runs from 0 (= T_YEARS/1000 ka BP) to T_YEARS (= present).
T_YEARS = 12_000.0
N_STEPS = 48               # number of intervals -> N_STEPS + 1 snapshots
TAU_THETA = 1500.0         # TPW angular timescale (yr)

# Viscoelastic relaxation
TAU_RELAX = 500.0          # relaxation timescale (yr)
RELAX_99PCT_KA = 1.8       # date (ka BP) by which relaxation is complete
# Expressed as model time since onset; derived so it stays tied to RELAX_99PCT_KA
# if T_YEARS changes (10,200 yr for the defaults above).
RELAX_T99_YEARS = T_YEARS - RELAX_99PCT_KA * 1000.0

# Numerical
DOWNSAMPLE = 60            # keep every DOWNSAMPLE-th grid point along lat and lon

# Fixed color scale (m), shared by every map frame so frames are comparable
VMIN = -15000.0
VMAX =  15000.0
LEVELS = np.linspace(VMIN, VMAX, 31)

# ============================================================
# LOAD HUMAN DATA
# ============================================================

def load_points(fname):
    """Read point features from a GeoJSON file as an (N, 2) array of (lat, lon).

    GeoJSON stores positions as ``[lon, lat, ...]``; the order is swapped here
    so column 0 is latitude and column 1 is longitude.

    Args:
        fname: Path to a GeoJSON FeatureCollection whose features have point
            geometries.

    Returns:
        A float array of shape (N, 2). It is (0, 2), not (0,), when the file
        has no features, so callers can always index ``[:, 0]`` / ``[:, 1]``.
    """
    # GeoJSON (RFC 7946) is always UTF-8; don't depend on the platform locale.
    with open(fname, encoding="utf-8") as f:
        data = json.load(f)
    coords = [
        (feat["geometry"]["coordinates"][1],
         feat["geometry"]["coordinates"][0])
        for feat in data["features"]
    ]
    return np.array(coords, dtype=float).reshape(-1, 2)

# Load both site catalogues, in the same order as before (Homo sites first).
homo_sites, civ_sites = (
    load_points(path) for path in (HOMO_GEOJSON, CIV_GEOJSON)
)

# ============================================================
# CONTINUOUS LAWS (DERIVED)
# ============================================================

def tpw_angle_log(t, theta0, tau, T):
    """Logarithmic TPW angular return.

    theta(t) = theta0 * (1 - ln(1 + t/tau) / ln(1 + T/tau))

    The result equals ``theta0`` at t = 0 and 0 at t = T. It is not clipped,
    so t > T gives negative angles. Works element-wise on NumPy arrays.
    """
    progress = np.log1p(t / tau) / np.log1p(T / tau)
    return theta0 * (1.0 - progress)

def log_relaxation(t, tau, T99):
    """Logarithmic viscoelastic relaxation fraction in [0, 1].

    f(t) = clip(ln(1 + t/tau) / ln(1 + T99/tau), 0, 1)

    NOTE(review): despite the parameter name, the curve reaches exactly 1.0
    (100%) at t = T99 and stays clipped at 1.0 afterwards; 99% is reached
    slightly before T99. Works element-wise on NumPy arrays.
    """
    fraction = np.log1p(t / tau) / np.log1p(T99 / tau)
    return np.clip(fraction, 0.0, 1.0)

# ============================================================
# LOAD & DOWNSAMPLE GEBCO
# ============================================================

# Open the bathymetry/topography grid, pick its elevation variable, subsample
# it, and load the (small) subsampled grid into memory so the file can be
# closed immediately.
with xr.open_dataset(GEBCO_NC) as ds:
    for v in ds.data_vars:
        if "elev" in v.lower() or v.lower() in ("z", "bedrock", "height"):
            elev_var = v
            break
    else:
        # Without this, the failure would surface later as a NameError on elev_var.
        raise ValueError(
            f"No elevation variable found in {GEBCO_NC}; "
            f"data variables are {list(ds.data_vars)}"
        )

    # Keep every DOWNSAMPLE-th point; assumes dimensions are named lat/lon.
    Z = ds[elev_var].isel(
        lat=slice(None, None, DOWNSAMPLE),
        lon=slice(None, None, DOWNSAMPLE)
    ).load()

lat = Z.lat.values
lon = Z.lon.values
# 2-D coordinate grids with shape (len(lat), len(lon)), matching Z.
lon2d, lat2d = np.meshgrid(lon, lat)

# ============================================================
# GEOMETRY
# ============================================================

def colatitude(lon, lat, pole_lon, pole_colat):
    """Angular distance (radians) from each point to a given rotation pole.

    Uses the spherical law of cosines.

    Args:
        lon, lat: Point longitude/latitude in degrees (scalars or arrays).
        pole_lon: Pole longitude in degrees.
        pole_colat: Pole colatitude in degrees (0 = geographic north pole).

    Returns:
        The colatitude of each point relative to the pole, in [0, pi].
    """
    lon = np.deg2rad(lon)
    lat = np.deg2rad(lat)
    pole_lon = np.deg2rad(pole_lon)
    pole_lat = np.pi/2 - np.deg2rad(pole_colat)
    cos_dist = (
        np.sin(lat)*np.sin(pole_lat) +
        np.cos(lat)*np.cos(pole_lat)*np.cos(lon-pole_lon)
    )
    # Rounding can push the cosine slightly outside [-1, 1] at or near the
    # pole and its antipode, where arccos would otherwise return NaN.
    return np.arccos(np.clip(cos_dist, -1.0, 1.0))

def centrifugal_potential(theta):
    """Centrifugal potential 0.5 * OMEGA^2 * R^2 * sin^2(theta).

    ``theta`` is the colatitude in radians relative to the rotation pole.
    Uses the module-level OMEGA and R; the result is in m^2/s^2 given those
    constants' SI units.
    """
    sin_theta = np.sin(theta)
    return 0.5 * OMEGA**2 * R**2 * sin_theta**2

# ============================================================
# BUILD CONTINUOUS RETURN SEQUENCE
# ============================================================

# Snapshot indices, model time since onset (yr), and the same instants in ka BP
# (times_ka runs from T_YEARS/1000 down to 0).
steps = np.arange(N_STEPS+1)
times_years = np.linspace(0, T_YEARS, N_STEPS+1)
times_ka = T_YEARS/1000 - times_years/1000

tpw_angles = tpw_angle_log(
    times_years, THETA_0, TAU_THETA, T_YEARS
)

relax_frac = log_relaxation(
    times_years, TAU_RELAX, RELAX_T99_YEARS
)

# Reference potential with the pole at the geographic north pole.
theta_S1 = colatitude(lon2d, lat2d, 0.0, 0.0)
Phi_S1 = centrifugal_potential(theta_S1)

# Potential for the initial, fully displaced pole.
theta_init = colatitude(lon2d, lat2d, TPW_MERIDIAN_DEG, THETA_0)
Phi_init = centrifugal_potential(theta_init)

# Hoisted: the topography does not change between steps, so read it once.
z_base = Z.values
Zeff = np.zeros((len(steps), *Z.shape), dtype=np.float32)

for i, (angle, frac) in enumerate(zip(tpw_angles, relax_frac)):
    # Potential for the instantaneous pole position at this step.
    theta_target = colatitude(
        lon2d, lat2d,
        TPW_MERIDIAN_DEG,
        angle
    )
    Phi_target = centrifugal_potential(theta_target)

    # The shape has only relaxed by `frac` toward the current potential; the
    # remainder is still frozen at the initial configuration.
    Phi_eff = (
        (1.0 - frac) * Phi_init +
        frac * Phi_target
    )

    # The potential difference from the reference is converted to an
    # equivalent height (divide by G) and added to the topography.
    Zeff[i] = z_base + (Phi_S1 - Phi_eff) / G

# ============================================================
# WRITE NETCDF
# ============================================================

# Write the full time series. Metadata is derived from the configuration
# constants rather than hard-coded, so it stays correct if they change.
xr.Dataset(
    data_vars=dict(
        effective_elevation=(
            ("step", "lat", "lon"), Zeff, {"units": "m"}
        )
    ),
    coords=dict(
        step=steps,
        time_ka=("step", times_ka, {"units": "ka BP"}),
        lat=lat,
        lon=lon,
    ),
    attrs=dict(
        model="Continuous logarithmic TPW return with viscoelastic relaxation",
        tpw_timescale_years=TAU_THETA,
        # Previously a literal 1.8; (12000 - 10200) / 1000 with the default config.
        relaxation_99pct_by_ka=(T_YEARS - RELAX_T99_YEARS) / 1000.0,
        relaxation_timescale_years=TAU_RELAX,
        initial_tpw_angle_deg=THETA_0,
        tpw_pole_longitude_deg=TPW_MERIDIAN_DEG,
    )
).to_netcdf(OUTPUT_NC)

# ============================================================
# DIAGNOSTIC PLOTS
# ============================================================

# Two small line plots, described as (x, y, xlabel, ylabel, title, output file).
_diagnostic_specs = (
    (times_ka, tpw_angles,
     "ka BP", "TPW angle (deg)",
     "Logarithmic TPW return", "tpw_angle_vs_time.png"),
    (tpw_angles, relax_frac*100,
     "TPW angle (deg)", "Relaxation (%)",
     "Relaxation vs TPW angle", "relaxation_vs_tpw_continuous.png"),
)

for x_vals, y_vals, x_label, y_label, plot_title, out_png in _diagnostic_specs:
    diag_fig, diag_ax = plt.subplots(figsize=(6,4))
    diag_ax.plot(x_vals, y_vals)
    diag_ax.set_xlabel(x_label)
    diag_ax.set_ylabel(y_label)
    diag_ax.set_title(plot_title)
    diag_ax.grid(alpha=0.3)
    diag_fig.savefig(out_png, dpi=200)
    plt.close(diag_fig)

# ============================================================
# MAP OUTPUT
# ============================================================

# One CRS object for every layer: the data grid and the site coordinates are
# plain lon/lat. Built once instead of four times per frame.
data_crs = ccrs.PlateCarree()

# Zero-pad frame numbers to the width of the last step index so files sort in
# time order even past 99 steps (still 2 digits for the default 48 steps).
frame_width = max(2, len(str(int(steps[-1]))))

# (N, 2) arrays of (lat, lon); reshape keeps an empty catalogue at shape (0, 2)
# so the [:, k] indexing below does not fail.
homo_pts = np.asarray(homo_sites, dtype=float).reshape(-1, 2)
civ_pts = np.asarray(civ_sites, dtype=float).reshape(-1, 2)

for i in steps:
    fig = plt.figure(figsize=(14,7))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson())
    ax.set_global()

    # Fixed LEVELS/VMIN/VMAX make the colors comparable across frames.
    cf = ax.contourf(
        lon, lat, Zeff[i],
        levels=LEVELS, cmap="RdBu_r",
        vmin=VMIN, vmax=VMAX,
        transform=data_crs, extend="both"
    )

    # Zero contour of the effective elevation.
    ax.contour(
        lon, lat, Zeff[i],
        levels=[0], colors="black", linewidths=1.2,
        transform=data_crs
    )

    ax.scatter(homo_pts[:,1], homo_pts[:,0],
               s=35, color="black", edgecolor="white",
               transform=data_crs, label="Early Homo")

    ax.scatter(civ_pts[:,1], civ_pts[:,0],
               s=40, color="gold", edgecolor="black",
               transform=data_crs, label="Early Civilizations")

    ax.add_feature(cfeature.COASTLINE, linewidth=0.6)

    ax.set_title(
        f"{times_ka[i]:.1f} ka | "
        f"TPW {tpw_angles[i]:.1f}° | "
        f"{relax_frac[i]*100:.1f}% relaxation"
    )

    fig.colorbar(cf, ax=ax, orientation="horizontal", pad=0.05,
                 label="Effective elevation relative to equilibrium (m)")

    ax.legend(loc="lower left")
    fig.savefig(PNG_DIR / f"tpw_continuous_{i:0{frame_width}d}.png",
                dpi=200, bbox_inches="tight")
    # Close this frame's figure explicitly to avoid accumulating figures.
    plt.close(fig)
