import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# -----------------------------
# Load BedMachine v4
# -----------------------------
# Open the file in a context manager so the handle is released as soon as the
# three fields used below are in memory. xarray reads variables lazily, so
# each one is loaded explicitly before the file is closed.
with xr.open_dataset("bedmachine4.nc") as ds:
    surface = ds["surface"].load()
    thickness = ds["thickness"].load()
    mask = ds["mask"].load()

# -----------------------------
# Age-proxy classification
# -----------------------------
# Class codes written into age_class:
#   1: grounded ice, surface > 3000
#   2: grounded ice, surface > 2000 (and not > 3000)
#   3: grounded ice, surface > 1000 (and not > 2000)
#   4: remaining grounded ice
#   5: floating ice
# Cells matching none of these, or with thickness <= 0, stay NaN.
#
# DataArray.where(cond, other) keeps values where cond is True and writes
# `other` where it is False, so every condition below is phrased as
# "cells that are NOT in this class".
#
# An explicit float dtype keeps NaN representable even if `surface` were
# stored as an integer type; float32 is ample for the codes 1-5.
age_class = xr.full_like(surface, np.nan, dtype=np.float32)

# Grounded ice (mask code 2). Thresholds are applied in ascending order so
# that each higher elevation band overwrites the one below it.
# NOTE(review): any other mask codes are left unclassified (NaN) -- confirm
# against the mask definition of the file in use.
grounded = mask == 2
age_class = age_class.where(~grounded, 4)
for threshold, code in ((1000, 3), (2000, 2), (3000, 1)):
    age_class = age_class.where(~(grounded & (surface > threshold)), code)

# Floating ice (mask code 3)
age_class = age_class.where(mask != 3, 5)

# Mask non-ice
age_class = age_class.where(thickness > 0)

# -----------------------------
# Colormap (ordered, subdued)
# -----------------------------
# One colour per age class, listed in class-code order (1 .. 5).
cmap = ListedColormap([
    "#3b0f70",
    "#8c2981",
    "#de4968",
    "#fe9f6d",
    "#f0f921",
])

# -----------------------------
# Plot
# -----------------------------
class_codes = [1, 2, 3, 4, 5]
class_labels = [
    ">3000 m  (~100 kyr – 1+ Myr)",
    "2000–3000 m  (~10–100 kyr)",
    "1000–2000 m  (~1–10 kyr)",
    "<1000 m grounded  (~<1 kyr)",
    "Floating ice  (~<1 kyr)",
]

# Choose the image origin from the direction of the row coordinate so the
# map is not drawn upside down: a descending coordinate means row 0 belongs
# at the top ("upper"); otherwise row 0 goes at the bottom ("lower").
# NOTE(review): BedMachine grids are usually stored with y descending --
# confirm for the file in use.
row_coord = age_class[age_class.dims[0]].values
descending_rows = row_coord.size > 1 and row_coord[0] > row_coord[-1]
origin = "upper" if descending_rows else "lower"

fig, ax = plt.subplots(figsize=(7, 7))

# vmin/vmax sit half a step outside the code range so each of the five
# colour bands is centred on its integer class code and the colourbar ticks
# land mid-band. Nearest-neighbour resampling keeps the categorical codes
# from being averaged into classes that do not exist when the grid is
# downsampled to figure resolution.
im = ax.imshow(
    age_class.values,
    origin=origin,
    cmap=cmap,
    vmin=class_codes[0] - 0.5,
    vmax=class_codes[-1] + 0.5,
    interpolation="nearest",
)
ax.set_title("Antarctic Ice Age Structure (Elevation Proxy)", fontsize=12)
ax.axis("off")

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_ticks(class_codes)
cbar.ax.set_yticklabels(class_labels)
cbar.set_label("Dominant ice age range")

fig.savefig(
    "antarctica_ice_age_map.png",
    dpi=300,
    bbox_inches="tight",
)
plt.close(fig)

print("Saved figure: antarctica_ice_age_map.png")


# -----------------------------
# Data (from BedMachine v4 analysis)
# -----------------------------
# One row per age class: (tick label, central share, low bound, high bound),
# all shares in percent of total ice volume.
_rows = [
    (">3000 m\n(~100 kyr – 1+ Myr)", 35.5, 32.0, 39.0),
    ("2000–3000 m\n(~10–100 kyr)", 37.8, 34.0, 41.0),
    ("1000–2000 m\n(~1–10 kyr)", 17.2, 14.0, 20.0),
    ("<1000 m grounded\n(~<1 kyr)", 6.7, 5.0, 9.0),
    ("Floating ice\n(~<1 kyr)", 2.7, 2.0, 4.0),
]

labels = [row[0] for row in _rows]
central = np.array([row[1] for row in _rows])
low = np.array([row[2] for row in _rows])
high = np.array([row[3] for row in _rows])

# One horizontal slot per class.
y = np.arange(len(labels))

# -----------------------------
# Plot
# -----------------------------
fig, ax = plt.subplots(figsize=(7, 4.5))

# Asymmetric error bars: distance from the central value down to the low
# bound and up to the high bound.
_lower_err = central - low
_upper_err = high - central
ax.errorbar(
    central,
    y,
    xerr=[_lower_err, _upper_err],
    fmt="o",
    color="black",
    ecolor="black",
    elinewidth=2,
    capsize=4,
)

# Label each slot and flip the axis so the first class is drawn at the top.
ax.set_yticks(y)
ax.set_yticklabels(labels)
ax.invert_yaxis()

ax.set_title(
    "Antarctic Ice Volume by Age Proxy\n(conservative robustness ranges)",
    fontsize=12,
)
ax.set_xlabel("Share of total Antarctic ice volume (%)")
ax.set_xlim(0, 45)

ax.grid(axis="x", linestyle=":", linewidth=0.8)

fig.savefig("antarctica_ice_volume_ranges.png", dpi=300, bbox_inches="tight")
plt.close(fig)

print("Saved figure: antarctica_ice_volume_ranges.png")
