import xarray as xr
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
from rasterio.features import rasterize
from affine import Affine

# -----------------------------
# FILE PATHS
# -----------------------------
# Input datasets. The paths are relative, so they are resolved against the
# current working directory when the script runs.

# NetCDF file opened below for the "surface", "thickness" and "mask" grids.
BEDMACHINE_FILE = "bedmachine4.nc"
# NetCDF file opened below for the vx / vy velocity components.
VELOCITY_FILE   = "antarctica_velocity_measures.nc"
# Basin polygons; their "Regions" attribute drives the ice-sheet classification.
BASIN_SHP       = "ANT_Basins_IMBIE2_v1.6.shp"

# -----------------------------
# Load BedMachine
# -----------------------------
# Open the BedMachine grid and keep handles to the three fields used later.
bm = xr.open_dataset(BEDMACHINE_FILE, decode_cf=True)
surface, thickness, mask = bm["surface"], bm["thickness"], bm["mask"]

# Grid axes as plain NumPy arrays.
x, y = bm["x"].values, bm["y"].values

# Cell size from the first two coordinate values on each axis; this treats
# the grid as regularly spaced.
dx = float(np.abs(x[1] - x[0]))
dy = float(np.abs(y[1] - y[0]))

# Pixel -> map transform for a raster whose first row sits at y.max() and
# whose first column sits at x.min(). Coordinates are cell centres, so the
# origin is pushed out by half a cell to land on the outer corner. Written
# as the six affine coefficients (a, b, c, d, e, f), which is what
# translation(c, f) * scale(a, e) evaluates to.
transform = Affine(
    dx, 0.0, x.min() - dx / 2,
    0.0, -dy, y.max() + dy / 2,
)

# -----------------------------
# Load IMBIE2 basins
# -----------------------------
# Read the basin polygons, then reproject them to EPSG:3031
# (presumably the projection of the BedMachine grid — confirm).
gdf = gpd.read_file(BASIN_SHP)
gdf = gdf.to_crs("EPSG:3031")

def classify_sheet(region):
    """Map a region label to an integer ice-sheet code.

    The label is passed through ``str`` and lower-cased, then matched by
    substring. "west" and "peninsula" are tested first, so a label that
    contains both "east" and "west" is classified as 2.

    Parameters
    ----------
    region : object
        Region label; any value is accepted because it is stringified.

    Returns
    -------
    int
        2 (WAIS) if the label contains "west" or "peninsula",
        1 (EAIS) if it contains "east", otherwise 0.
    """
    label = str(region).lower()
    if any(token in label for token in ("west", "peninsula")):
        return 2  # WAIS
    return 1 if "east" in label else 0  # EAIS, else unclassified

# Tag every basin polygon with its ice-sheet code (0 = unclassified).
gdf["sheet"] = gdf["Regions"].apply(classify_sheet)

# Burn the codes onto a raster with the same shape as the BedMachine grid.
# Cells not covered by any polygon keep the fill value 0.
sheet_mask = rasterize(
    zip(gdf.geometry, gdf["sheet"]),
    out_shape=(len(y), len(x)),
    transform=transform,
    fill=0,
    dtype="uint8",
)

# `transform` always describes a raster whose first row is at y.max() and
# whose first column is at x.min(). The array is labelled below with `y` and
# `x` in file order, so reverse an axis whenever the file stores it the other
# way round; otherwise the mask would be mirrored relative to the data.
# For a descending-y / ascending-x grid both checks are no-ops.
if y[0] < y[-1]:
    sheet_mask = sheet_mask[::-1, :]
if x[0] > x[-1]:
    sheet_mask = sheet_mask[:, ::-1]

sheet_mask = xr.DataArray(sheet_mask, coords={"y": y, "x": x}, dims=("y", "x"))

# -----------------------------
# Load velocity
# -----------------------------
vel = xr.open_dataset(VELOCITY_FILE, decode_cf=True)


def _first_var_matching(dataset, token):
    """Return the first data variable whose lower-cased name contains *token*.

    Raises IndexError when no variable name matches.
    """
    matches = [name for name in dataset.data_vars if token in name.lower()]
    return dataset[matches[0]]


# Velocity components, promoted to float64 before they are squared.
vx = _first_var_matching(vel, "vx").astype("float64")
vy = _first_var_matching(vel, "vy").astype("float64")

# Speed magnitude on the velocity grid, then sampled onto the thickness grid
# by nearest-neighbour lookup.
speed = np.sqrt(vx * vx + vy * vy).interp_like(
    thickness,
    method="nearest",
)

# -----------------------------
# Velocity-weighted effective volume
# -----------------------------
# Area of one grid cell and the floor applied to the speed so the division
# below never sees values at or below it.
cell_area = dx * dy
vmin = 1.0

# Per-cell ice volume divided by the floored speed, kept only where there is
# ice (thickness > 0); everything else becomes NaN.
# NOTE(review): cells whose speed is NaN fail the `speed > vmin` test and so
# are also assigned vmin — confirm that is the intended treatment of gaps.
eff_vol = (
    (thickness * cell_area) / xr.where(speed > vmin, speed, vmin)
).where(thickness > 0)

# -----------------------------
# Age bins
# -----------------------------
# Mask classes used by the bins; going by the bin labels, code 2 is grounded
# ice and code 3 is floating ice.
_grounded = mask == 2
_floating = mask == 3

# (label, boolean condition) pairs. Grounded ice is split into surface
# elevation bands with an exclusive lower and inclusive upper bound;
# floating ice forms a single bin regardless of elevation.
bins = [
    (">3000 m (~100 kyr–1+ Myr)", _grounded & (surface > 3000)),
    ("2000–3000 m (~10–100 kyr)", _grounded & (surface > 2000) & (surface <= 3000)),
    ("1000–2000 m (~1–10 kyr)", _grounded & (surface > 1000) & (surface <= 2000)),
    ("<1000 m grounded (~<1 kyr)", _grounded & (surface <= 1000)),
    ("Floating ice (~<1 kyr)", _floating),
]

def fractions(sheet_id):
    """Return each bin's percentage share of one sheet's effective volume.

    Reads the module-level ``eff_vol``, ``sheet_mask`` and ``bins``. The
    denominator is the effective volume of every cell in the sheet, so the
    shares add up to 100 only if the bins cover all contributing cells.

    Parameters
    ----------
    sheet_id : int
        Code to select in ``sheet_mask``.

    Returns
    -------
    numpy.ndarray
        One percentage per entry of ``bins``, in the same order. If the
        selected sheet has zero total effective volume (for example because
        no polygon was classified with ``sheet_id``), a ``RuntimeWarning`` is
        issued and every entry is NaN instead of failing with a
        ``ZeroDivisionError``.
    """
    import warnings

    # Select the sheet once; each bin then only narrows this selection
    # instead of re-masking the full grid on every iteration.
    in_sheet = eff_vol.where(sheet_mask == sheet_id)
    total = in_sheet.sum().item()

    if total == 0:
        warnings.warn(
            f"No effective volume found for sheet_id={sheet_id}; "
            "returning NaN shares.",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.full(len(bins), np.nan)

    return np.array(
        [100 * in_sheet.where(cond).sum().item() / total for _, cond in bins]
    )

# Percentage shares per ice sheet; ids 1 and 2 are the codes that
# classify_sheet assigns to "east" and to "west"/"peninsula" labels.
EAIS, WAIS = fractions(1), fractions(2)

# Bin labels and one integer y position per label for the plot.
labels = [name for name, _ in bins]
ypos = np.arange(len(labels))

# -----------------------------
# Plot
# -----------------------------
# Two side-by-side lollipop charts (one per ice sheet) sharing the y axis.
fig, axes = plt.subplots(1, 2, figsize=(10, 4.5), sharey=True)

# Keep the 0–70 % window whenever every share fits inside it; widen it (with
# a 5 % margin) only if a share exceeds 70, so that no point is silently
# clipped. Both panels use the same limit so they stay directly comparable.
_finite_shares = [v for shares in (EAIS, WAIS) for v in shares if np.isfinite(v)]
_peak_share = max(_finite_shares, default=0.0)
x_upper = 70.0 if _peak_share <= 70.0 else 1.05 * _peak_share

for ax, data, title in zip(
    axes,
    [EAIS, WAIS],
    ["East Antarctica (EAIS)", "West Antarctica (WAIS)"]
):
    # Stem from zero to the value, then a dot at the value.
    ax.hlines(ypos, 0, data, color="black", linewidth=2)
    ax.plot(data, ypos, "o", color="black")
    ax.set_title(title)
    ax.set_xlim(0, x_upper)
    ax.grid(axis="x", linestyle=":", linewidth=0.8)

# The y axis is shared, so labelling and inverting the first panel is enough;
# inversion puts the first bin at the top.
axes[0].set_yticks(ypos)
axes[0].set_yticklabels(labels)
axes[0].invert_yaxis()

fig.supxlabel("Velocity-weighted share of ice volume (%)")
fig.suptitle("Velocity-weighted Antarctic Ice Volume by Elevation and Ice Sheet")

# Leave the top 7 % of the figure free for the suptitle.
fig.tight_layout(rect=[0, 0, 1, 0.93])
fig.savefig("EAIS_WAIS_velocity_weighted_comparison.png", dpi=300)
plt.close(fig)

print("Saved: EAIS_WAIS_velocity_weighted_comparison.png")
