#!/usr/bin/env python3

import numpy as np
import xarray as xr

# -----------------------------
# USER INPUTS
# -----------------------------

N_STEPS = 10
TARGET_FRACTION = 0.95   # 95% relaxation

# Fail fast on parameters that make the relaxation constant meaningless:
#   TARGET_FRACTION >= 1 -> log(0) = -inf -> tau = 0 (NaN/inf factors),
#   TARGET_FRACTION <= 0 -> tau <= 0 (growth instead of decay),
#   N_STEPS < 1          -> T = 0 -> tau = 0.
if not 0.0 < TARGET_FRACTION < 1.0:
    raise ValueError(f"TARGET_FRACTION must be in (0, 1), got {TARGET_FRACTION!r}")
if N_STEPS < 1:
    raise ValueError(f"N_STEPS must be >= 1, got {N_STEPS!r}")

INPUT_NC = "gebco_tpw_effective_elevation.nc"
# Derived from the parameters so the file name cannot drift out of sync with
# them; for 10 steps / 0.95 this is "..._95pct_10steps.nc" as before.
OUTPUT_NC = (
    "effective_elevation_viscoelastic_"
    f"{TARGET_FRACTION * 100:g}pct_{N_STEPS}steps.nc"
)

# -----------------------------
# LOAD EQUILIBRIUM DATA
# -----------------------------

# Read the equilibrium field and its coordinates fully into memory, then
# release the file handle instead of holding the input open for the rest of
# the script (including the NetCDF write at the end).
with xr.open_dataset(INPUT_NC) as ds_eq:
    # Expect a variable named exactly this (KeyError otherwise).
    E_eq = ds_eq["effective_elevation"].load()

    lat = ds_eq["lat"].load()
    lon = ds_eq["lon"].load()

# -----------------------------
# COMPUTE RELAXATION CONSTANT
# -----------------------------
# Pick tau so that a Maxwell-type exponential reaches TARGET_FRACTION of its
# final value after a total (normalized, unitless) time T = N_STEPS:
#     1 - exp(-T / tau) = TARGET_FRACTION  =>  tau = T / -ln(1 - TARGET_FRACTION)
T = float(N_STEPS)
tau = T / -np.log(1.0 - TARGET_FRACTION)

# -----------------------------
# GENERATE TIME STEPS
# -----------------------------

# Integer steps 0, 1, ..., N_STEPS (N_STEPS + 1 frames in total).
steps = np.arange(N_STEPS + 1)

# Fraction of the equilibrium response reached at each step:
# exactly 0 at step 0 and ~TARGET_FRACTION at step N_STEPS.
relaxation_factor = 1.0 - np.exp(steps / -tau)

# -----------------------------
# BUILD 4D DATA ARRAY
# -----------------------------

# Put lat/lon on the trailing axes explicitly so the array layout matches the
# ("step", "lat", "lon") dims declared for the output variable. Without this a
# file stored as (lon, lat) fails to broadcast or, on a square grid, is
# silently written transposed. Any other dims (e.g. a leading singleton) stay
# in front and broadcast into each slice as before.
E_eq_grid = E_eq.transpose(..., "lat", "lon").values

E_time = np.zeros((len(steps), len(lat), len(lon)), dtype=np.float32)

# NOTE(review): step 0 has factor 0, so E_time[0] is 0 * E_eq, which is
# identically zero apart from NaNs. That is "present-day geometry" only if
# effective_elevation is an anomaly/displacement field rather than absolute
# elevation. Confirm against how the input file was produced.
for i, f in enumerate(relaxation_factor):
    E_time[i, :, :] = f * E_eq_grid

# -----------------------------
# CREATE OUTPUT DATASET
# -----------------------------

# Assemble the output: one (step, lat, lon) variable plus a "step" coordinate,
# reusing the input's lat/lon coordinates unchanged.
ds_out = xr.Dataset(
    data_vars=dict(
        effective_elevation=(
            ["step", "lat", "lon"],  # E_time must be laid out in this order
            E_time,
            {
                "units": "meters",
                "description": "Effective elevation relative to equilibrium sea level",
            },
        )
    ),
    coords=dict(
        step=("step", steps, {"description": "Viscoelastic relaxation step"}),
        lat=lat,
        lon=lon,
    ),
    attrs=dict(
        title="Viscoelastic Relaxation Toward Equilibrium Sea Level After TPW",
        relaxation_model="Exponential Maxwell-type relaxation",
        equilibrium_fraction=TARGET_FRACTION,
        number_of_steps=N_STEPS,
        relaxation_timescale_tau=tau,
        # The percentage is derived from TARGET_FRACTION so the metadata stays
        # truthful if the input parameter is changed ("~95%" for 0.95).
        note=(
            "Step 0 represents present-day geometry (0% relaxation). "
            f"Final step reaches ~{TARGET_FRACTION * 100:g}% of equilibrium configuration."
        ),
    ),
)

# -----------------------------
# WRITE NETCDF
# -----------------------------

# Store the single data variable zlib-compressed (level 4) and as float32 on
# disk, then report where it went and the timescale that was used.
ds_out.to_netcdf(
    OUTPUT_NC,
    encoding=dict(
        effective_elevation=dict(zlib=True, complevel=4, dtype="float32"),
    ),
)

print("Written:", OUTPUT_NC)
print(f"Relaxation timescale tau = {tau:.2f} steps")
