import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# FILE PATHS
# -----------------------------
BEDMACHINE_FILE = "bedmachine4.nc"
VELOCITY_FILE   = "antarctica_velocity_measures.nc"

# -----------------------------
# Load data
# -----------------------------
def _velocity_component(ds, token):
    """Return the data variable of *ds* for velocity component *token*, as float64.

    An exact (case-insensitive) name match such as ``VX`` is preferred. Otherwise
    the first variable whose lower-cased name contains *token* is used, which
    matches the original selection rule.

    Raises:
        KeyError: if no data variable name contains *token*. The message lists
            the available names.
    """
    names = list(ds.data_vars)
    exact = [n for n in names if str(n).lower() == token]
    partial = [n for n in names if token in str(n).lower()]
    candidates = exact or partial
    if not candidates:
        raise KeyError(
            f"no data variable matching {token!r}; available: {names}"
        )
    return ds[candidates[0]].astype("float64")


bm = xr.open_dataset(BEDMACHINE_FILE, decode_cf=True)
vel = xr.open_dataset(VELOCITY_FILE, decode_cf=True)

thickness = bm["thickness"]
mask = bm["mask"]

x = bm["x"].values
y = bm["y"].values

vx = _velocity_component(vel, "vx")
vy = _velocity_component(vel, "vy")

# Speed magnitude, resampled (nearest neighbour) onto the BedMachine thickness
# grid. interp_like matches coordinates by name, so this assumes both files use
# the same x/y names and projection — confirm for new inputs. Cells outside the
# velocity grid come back as NaN.
speed = np.sqrt(vx * vx + vy * vy).interp_like(thickness, method="nearest")

# -----------------------------
# Residence-time proxy (years)
# -----------------------------
# tau = thickness / speed. With thickness in m and speed in m/yr (assumed from
# the input products — confirm units), tau is in years.
# Speeds below vmin are floored to vmin, so near-stagnant ice gets a large but
# finite tau instead of a division by (almost) zero.
vmin = 0.01
# clip() propagates NaN, so cells with no velocity sample stay NaN and drop out.
# The previous xr.where(speed > vmin, speed, vmin) sent NaN to the vmin branch
# (NaN > vmin is False). Every velocity gap then became thickness / 0.01, so any
# ice thicker than 1 km was flagged "older than 100 kyr".
tau = thickness / speed.clip(min=vmin)

# Keep only cells with positive ice thickness and a non-zero mask value.
valid = (thickness > 0) & (mask != 0)
tau = tau.where(valid)

# -----------------------------
# Age thresholds
# -----------------------------
# One (panel label, minimum residence time in years) pair per subplot,
# ordered youngest to oldest.
thresholds = [(f"{kyr} kyr", kyr * 1_000.0) for kyr in (1, 10, 100)]

# -----------------------------
# Plot
# -----------------------------
def _pixel_edges(coord):
    """Return (low_edge, high_edge) for a regularly spaced 1-D cell-centre axis.

    imshow's ``extent`` gives the outer edges of the image, so half a cell is
    added beyond the outermost centres on each side.
    """
    c = np.asarray(coord, dtype=float)
    half = abs(c[1] - c[0]) / 2.0 if c.size > 1 else 0.5
    return c.min() - half, c.max() + half


fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)

# Put the data in (row=y, col=x) order explicitly and take the plotting axes
# from the grid tau actually lives on.
tau_grid = tau.transpose("y", "x")
valid_grid = valid.transpose("y", "x")
x_coord = tau_grid["x"].values
y_coord = tau_grid["y"].values

# imshow(origin="lower") draws row 0 at the bottom and column 0 at the left.
# Gridded products often store y in decreasing (north-to-south) order, which
# would render the map upside down, so flip any decreasing axis. The -1 slices
# below are numpy views and do not copy the data.
row_step = -1 if y_coord[0] > y_coord[-1] else 1
col_step = -1 if x_coord[0] > x_coord[-1] else 1

x_lo, x_hi = _pixel_edges(x_coord)
y_lo, y_hi = _pixel_edges(y_coord)
extent = [x_lo, x_hi, y_lo, y_hi]

for ax, (label, T) in zip(axes, thresholds):
    # NaN tau (masked or no data) compares False, so it falls into the 0 class.
    # "& valid" keeps the background at 0 even if tau was not pre-masked.
    exceed = ((tau_grid >= T) & valid_grid).values

    # Integer mask: 0 = no, 1 = yes
    img = exceed[::row_step, ::col_step].astype(np.uint8)

    ax.imshow(
        img,
        origin="lower",
        extent=extent,
        cmap="Greys",
        vmin=0,
        vmax=1,
        interpolation="nearest",
    )

    ax.set_title(f"Ice likely older than {label}")
    ax.axis("off")

fig.suptitle(
    "Antarctic Ice Age Exceedance Maps\n"
    "(residence-time proxy: thickness / surface speed)",
    fontsize=14,
)

fig.tight_layout(rect=[0, 0, 1, 0.92])
fig.savefig(
    "antarctica_age_exceedance_maps_fixed.png",
    dpi=300,
    bbox_inches="tight",
)
plt.close(fig)

print("Saved: antarctica_age_exceedance_maps_fixed.png")
