import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# FILE PATHS
# -----------------------------
BEDMACHINE_FILE = "bedmachine4.nc"
VELOCITY_FILE   = "antarctica_velocity_measures.nc"

# -----------------------------
# Load data
# -----------------------------
bm = xr.open_dataset(BEDMACHINE_FILE, decode_cf=True)
vel = xr.open_dataset(VELOCITY_FILE, decode_cf=True)

thickness = bm["thickness"]
mask = bm["mask"]


def _velocity_component(ds, key):
    """Return the data variable of *ds* holding velocity component *key*, as float64.

    A variable whose lower-cased name equals *key* exactly (e.g. ``VX`` for
    ``key="vx"``) is preferred. Otherwise the first variable whose lower-cased
    name contains *key* is used, which is the original substring lookup.

    Raises:
        KeyError: if no data variable name contains *key*. The message lists
            the variables that are present so the input file can be checked.
    """
    names = list(ds.data_vars)
    exact = [name for name in names if name.lower() == key]
    partial = [name for name in names if key in name.lower()]
    candidates = exact or partial
    if not candidates:
        raise KeyError(f"no data variable matching {key!r}; available: {names}")
    return ds[candidates[0]].astype("float64")


vx = _velocity_component(vel, "vx")
vy = _velocity_component(vel, "vy")

# Horizontal speed magnitude, resampled onto the BedMachine thickness grid by
# nearest-neighbour lookup. With xarray's default fill, target points outside
# the velocity grid come back as NaN.
speed = np.hypot(vx, vy).interp_like(thickness, method="nearest")

# -----------------------------
# Residence time (years)
# -----------------------------
vmin = 1.0  # m/yr; floor on speed so near-stagnant ice doesn't divide by ~0

# Clamp the speed from below with clip() rather than xr.where(speed > vmin, ...).
# The comparison is False for NaN, so the where() form silently replaced
# missing velocity with vmin. Every no-data cell then got a spurious
# tau = thickness / 1. clip() leaves NaN as NaN, so those cells stay masked.
# (Assumes thickness in m and speed in m/yr, giving tau in years.)
tau = thickness / speed.clip(min=vmin)

# Keep only cells with positive thickness and a non-zero mask value.
tau = tau.where((thickness > 0) & (mask != 0))

# -----------------------------
# Plot
# -----------------------------
OUTPUT_FILE = "antarctica_residence_time_map.png"

log_tau = np.log10(tau)

# With origin="lower", imshow draws array row 0 at the bottom; with "upper",
# at the top. Choose the one that keeps the row coordinate increasing upward,
# so the map is not flipped when the grid stores rows in descending order.
# A dimension without a coordinate yields an ascending integer index here,
# so the result falls back to the previous "lower".
row_coord = log_tau[log_tau.dims[0]].values
origin = "upper" if row_coord.size > 1 and row_coord[0] > row_coord[-1] else "lower"

fig, ax = plt.subplots(figsize=(7, 7))

im = ax.imshow(
    log_tau.values,
    origin=origin,
    cmap="cividis",
    vmin=1,   # 10^1 = 10 years
    vmax=6,   # 10^6 = 1 Myr
)

ax.set_title("Antarctic Ice Residence-Time Proxy\nlog₁₀(thickness / surface speed)")
ax.axis("off")

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
# Colour values are log10(years), but the ticks below are relabelled in plain
# years. The label therefore names the quantity and its scale, not
# "log10(years)", which contradicted tick labels such as "1 kyr".
cbar.set_label("Residence time (years, log scale)")
cbar.set_ticks([1, 2, 3, 4, 5, 6])
cbar.set_ticklabels(["10", "100", "1 kyr", "10 kyr", "100 kyr", "1 Myr"])

fig.savefig(OUTPUT_FILE, dpi=300, bbox_inches="tight")
plt.close(fig)

print(f"Saved: {OUTPUT_FILE}")
