#!/usr/bin/env python3

import numpy as np
import xarray as xr
from scipy.ndimage import zoom

# ===============================
# USER SETTINGS
# ===============================

INPUT_EQ_NC = "gebco_tpw_effective_elevation.nc"

# Number of relaxation steps after the initial state (the output holds
# N_STEPS + 1 frames, step 0 included).
N_STEPS = 10
# Fraction of the inherited anomaly relaxed away by the final step. The
# timescale tau = -N_STEPS / ln(1 - TARGET_RELAX) is finite and positive only
# for 0 < TARGET_RELAX < 1.
TARGET_RELAX = 0.95

# Output name is derived from the two settings above so it cannot drift out of
# sync with them (defaults give
# "tpw_S2_viscoelastic_relaxation_95pct_10steps.nc").
OUTPUT_NC = (
    f"tpw_S2_viscoelastic_relaxation_{round(TARGET_RELAX * 100)}pct_{N_STEPS}steps.nc"
)

# target resolution (degrees)
TARGET_RES_DEG = 1.0

# ===============================
# LOAD INITIAL (INHERITED) FIELD
# ===============================

# Open the input in a context manager so the file handle is released, and
# pull the needed variable into memory before the file is closed (later code
# reads A_rot.values after this block).
with xr.open_dataset(INPUT_EQ_NC) as ds:
    A_rot = ds["effective_elevation"].load()  # inherited rotated S1 bulge

    lat = ds["lat"].values
    lon = ds["lon"].values

# ===============================
# DOWNSAMPLE TO TARGET_RES_DEG GRID (1° BY DEFAULT)
# ===============================

# Native grid spacing, taken from the first two coordinates; this assumes a
# regular (uniformly spaced) lat/lon grid.
dlat = abs(lat[1] - lat[0])
dlon = abs(lon[1] - lon[0])

# Zoom factors < 1 shrink the grid (e.g. 0.25° input -> factor 0.25).
zoom_lat = dlat / TARGET_RES_DEG
zoom_lon = dlon / TARGET_RES_DEG

# NOTE(review): order=1 is bilinear *sampling*, not area averaging, so a large
# reduction factor can alias short-wavelength structure — confirm this is
# acceptable for the input resolution in use.
A_rot_ds = zoom(
    A_rot.values,
    zoom=(zoom_lat, zoom_lon),
    order=1
)

# With scipy's default grid_mode=False, zoom aligns the first/last output
# samples with the first/last input samples. The output coordinates therefore
# span exactly the input's endpoints, in the input's own order and longitude
# convention (hard-coding -90..90 / -180..180 mislabels descending-latitude or
# 0..360-longitude inputs and cell-centre grids).
lat_ds = np.linspace(float(lat[0]), float(lat[-1]), A_rot_ds.shape[0])
lon_ds = np.linspace(float(lon[0]), float(lon[-1]), A_rot_ds.shape[1])

# ===============================
# RELAXATION FUNCTION
# ===============================

# Fail early on settings that make tau degenerate. TARGET_RELAX outside (0, 1)
# gives a zero, infinite, negative or NaN tau, and N_STEPS < 1 gives tau == 0.
# Either would silently fill the output with NaN/inf or an amplifying "relaxation".
if not 0.0 < TARGET_RELAX < 1.0:
    raise ValueError(f"TARGET_RELAX must be in (0, 1), got {TARGET_RELAX!r}")
if N_STEPS < 1:
    raise ValueError(f"N_STEPS must be >= 1, got {N_STEPS!r}")

# Exponential relaxation r(t) = 1 - exp(-t / tau), with tau chosen so that the
# final step reaches the target: 1 - exp(-T / tau) = R  =>  tau = -T / ln(1 - R).
T = float(N_STEPS)
tau = -T / np.log(1 - TARGET_RELAX)

# Step indices 0..N_STEPS inclusive; step 0 has r = 0 (no relaxation yet).
steps = np.arange(N_STEPS + 1)

relax_fraction = 1 - np.exp(-steps / tau)

# ===============================
# BUILD TIME SEQUENCE
# ===============================

# One float32 frame per relaxation step: frame k is the downsampled inherited
# field scaled by the fraction of the anomaly still remaining,
# (1 - relax_fraction[k]).
A_time = np.zeros(
    (steps.size, lat_ds.size, lon_ds.size),
    dtype=np.float32,
)
for frame, fraction in zip(A_time, relax_fraction):
    # frame is a view into A_time; the in-place write casts to float32.
    frame[...] = (1 - fraction) * A_rot_ds

# ===============================
# OUTPUT DATASET
# ===============================

# Assemble the (time, lat, lon) relaxation series with coordinates and
# metadata. The "time" coordinate holds relaxation step indices, not dates.
ds_out = xr.Dataset(
    data_vars=dict(
        effective_elevation=(
            ("time", "lat", "lon"),
            A_time,
            {
                "units": "meters",
                "long_name": "Effective elevation relative to sea level",
                "description": (
                    "Inherited S1 centrifugal bulge after IITPW, "
                    "progressively relaxed toward circular S2 geometry"
                ),
            },
        )
    ),
    coords=dict(
        time=("time", steps, {"long_name": "Relaxation step"}),
        lat=("lat", lat_ds, {"units": "degrees_north"}),
        lon=("lon", lon_ds, {"units": "degrees_east"}),
    ),
    attrs=dict(
        title="Viscoelastic Relaxation of Inherited Centrifugal Bulge After IITPW",
        model="Rapid IITPW followed by viscoelastic centrifugal relaxation",
        relaxation_target=TARGET_RELAX,
        steps=N_STEPS,
        relaxation_timescale_tau=tau,
        # The final-step percentage is formatted from TARGET_RELAX so the
        # metadata stays correct when the setting changes (0.95 -> "95%").
        notes=(
            "Step 0 = immediate post-IITPW state with full inherited S1 oblateness.\n"
            f"Final step ≈{TARGET_RELAX:.0%} relaxation toward circular S2 geometry."
        ),
    ),
)

# Chunk one time step at a time, capping the spatial chunk at the actual grid
# size. netCDF rejects chunk sizes larger than a fixed dimension, so a
# hard-coded (1, 180, 360) fails for any grid with fewer than 180 x 360 points
# (e.g. TARGET_RES_DEG = 2).
chunk_lat = min(180, ds_out.sizes["lat"])
chunk_lon = min(360, ds_out.sizes["lon"])

ds_out.to_netcdf(
    OUTPUT_NC,
    encoding={
        "effective_elevation": {
            "dtype": "float32",
            "zlib": True,
            "complevel": 4,
            "chunksizes": (1, chunk_lat, chunk_lon),
        }
    }
)

print(f"Written: {OUTPUT_NC}")
print(f"Relaxation timescale tau = {tau:.2f}")
