#!/usr/bin/env python3

import numpy as np
import xarray as xr

# -------------------------
# CONFIG
# -------------------------

GEBCO_NC = "GEBCO_2025_sub_ice.nc"

DOWNSAMPLE = 60     # keep every DOWNSAMPLE-th row/column of the input grid
N_STEPS = 10        # relaxation increments; the output holds N_STEPS + 1 frames
RELAX_FRAC = 0.95   # fraction of the centrifugal mismatch removed by the last step

# Derive the output name from the parameters so it cannot drift out of sync
# with them when they are changed (yields "..._95pct_10steps.nc" here).
OUTPUT_NC = (
    f"tpw_centrifugal_relaxation_{RELAX_FRAC * 100:g}pct_{N_STEPS}steps.nc"
)

R = 6371000.0          # planetary radius [m]
OMEGA = 7.2921159e-5   # rotation rate [rad/s]
G = 9.81               # surface gravity [m/s^2]

TPW_LON = 31.0      # longitude of the new rotation pole [degrees east]
TPW_COLAT = 104.0   # degrees from old pole (colatitude of the new pole)

# -------------------------
# LOAD GEBCO SAFELY
# -------------------------

# Open inside a context manager so the file handle is released once the
# downsampled grid has been pulled into memory.
with xr.open_dataset(GEBCO_NC, engine="netcdf4") as ds:
    # Detect elevation variable: the first data variable whose name contains
    # "elev" or matches one of a few common aliases.
    for v in ds.data_vars:
        if "elev" in v.lower() or v.lower() in ("z", "bedrock", "height"):
            elev_var = v
            break
    else:
        raise ValueError("Could not find elevation variable in GEBCO file")

    # Downsample by striding along lat/lon, then load the reduced grid so it
    # stays usable after the file is closed and is read from disk only once.
    Z = ds[elev_var].isel(
        lat=slice(None, None, DOWNSAMPLE),
        lon=slice(None, None, DOWNSAMPLE),
    ).load()

lat = Z.lat.values
lon = Z.lon.values

# 2-D coordinate grids shaped (lat.size, lon.size), matching Z.
lon2d, lat2d = np.meshgrid(lon, lat)

# -------------------------
# ROTATION GEOMETRY
# -------------------------

def colat(lon, lat, pole_lon, pole_colat):
    """Return the angular distance, in radians, from each point to a pole.

    Equivalently, this is each point's colatitude in a frame whose north pole
    is the given pole.

    Parameters
    ----------
    lon, lat : array_like
        Point longitude and latitude in degrees (broadcastable together).
    pole_lon : float or array_like
        Pole longitude in degrees.
    pole_colat : float or array_like
        Pole colatitude in degrees, measured from the grid's north pole
        (0 means the pole coincides with the grid's north pole).

    Returns
    -------
    ndarray or float
        Great-circle angle in radians, guaranteed to lie in [0, pi].
    """
    lon_r = np.deg2rad(lon)
    lat_r = np.deg2rad(lat)
    pole_lon_r = np.deg2rad(pole_lon)
    pole_lat_r = np.pi / 2 - np.deg2rad(pole_colat)
    cos_angle = (
        np.sin(lat_r) * np.sin(pole_lat_r)
        + np.cos(lat_r) * np.cos(pole_lat_r) * np.cos(lon_r - pole_lon_r)
    )
    # Rounding can push the spherical law of cosines a hair outside [-1, 1]
    # for points at (or antipodal to) the pole; arccos would then give NaN.
    return np.arccos(np.clip(cos_angle, -1.0, 1.0))

# Colatitude of every grid point relative to the old (grid north) pole and
# relative to the new TPW pole, in radians.
theta_S1 = colat(lon2d, lat2d, 0.0, 0.0)
theta_S2 = colat(lon2d, lat2d, TPW_LON, TPW_COLAT)

# -------------------------
# CENTRIFUGAL POTENTIAL
# -------------------------


def Phi(th):
    """Return the centrifugal potential 0.5 * OMEGA**2 * R**2 * sin(th)**2.

    ``th`` is a colatitude in radians measured from the rotation pole.
    """
    return 0.5 * OMEGA**2 * R**2 * np.sin(th) ** 2


# Height anomaly from swapping the old rotation pole for the new one;
# dividing by G turns the potential difference into a length (written out
# below with units "m").
delta_h = (Phi(theta_S1) - Phi(theta_S2)) / G

# -------------------------
# RELAXATION SEQUENCE
# -------------------------

if N_STEPS < 1:
    # alpha below divides by N_STEPS; 0 would silently produce NaN frames.
    raise ValueError("N_STEPS must be >= 1")

steps = np.arange(N_STEPS + 1)
# alpha[i] is the fraction of the centrifugal mismatch still present at step
# i: 1.0 at step 0, falling linearly to 1 - RELAX_FRAC at step N_STEPS.
alpha = 1.0 - RELAX_FRAC * steps / N_STEPS

# Materialise the bedrock grid once (as float64, the precision the sum below
# is computed in) instead of re-reading Z.values on every iteration.
z_base = np.asarray(Z.values, dtype=np.float64)

# Filled step by step into a float32 cube to keep peak memory low.
Z_eff = np.empty((len(steps), lat.size, lon.size), dtype=np.float32)
for i, a in enumerate(alpha):
    Z_eff[i] = z_base + a * delta_h

# -------------------------
# WRITE NETCDF (CF SAFE)
# -------------------------

# One relaxation step per chunk. The spatial chunk is clamped to the grid so a
# coarser DOWNSAMPLE cannot request chunks larger than a fixed-size dimension,
# which netCDF4/HDF5 rejects.
chunksizes = (1, min(180, lat.size), min(360, lon.size))

out = xr.Dataset(
    data_vars=dict(
        effective_elevation=(
            ("step", "lat", "lon"),
            Z_eff,
            {
                "units": "m",
                "long_name": "Effective elevation relative to equilibrium sea level",
            },
        )
    ),
    coords=dict(
        step=("step", steps, {"long_name": "Relaxation step index"}),
        # Record the remaining-mismatch fraction so each frame is self-describing.
        alpha=(
            "step",
            alpha,
            {
                "units": "1",
                "long_name": "Fraction of centrifugal potential mismatch remaining",
            },
        ),
        lat=("lat", lat, {"units": "degrees_north"}),
        lon=("lon", lon, {"units": "degrees_east"}),
    ),
    attrs=dict(
        # Taken from the config so the metadata cannot disagree with the run.
        title=(
            "Centrifugal Viscoelastic Relaxation After "
            f"{TPW_COLAT:g}° True Polar Wander"
        ),
        model="Fixed bedrock + relaxing centrifugal potential mismatch",
        TPW_angle_deg=TPW_COLAT,
        TPW_pole_lon_deg=TPW_LON,
        relaxation_fraction=RELAX_FRAC,
    ),
)

out.to_netcdf(
    OUTPUT_NC,
    format="NETCDF4",
    engine="netcdf4",
    encoding={
        "effective_elevation": {
            "zlib": True,
            "complevel": 4,
            "chunksizes": chunksizes,
            # _FillValue belongs in the encoding rather than the attrs in xarray.
            "_FillValue": np.float32(-9.96921e36),
        }
    },
)

print(f"✔ Saved valid NetCDF: {OUTPUT_NC}")
