import xarray as xr
import numpy as np
import geopandas as gpd
from rasterio.features import rasterize
from affine import Affine

# -----------------------------
# FILE PATHS
# -----------------------------
BEDMACHINE_FILE = "bedmachine4.nc"
VELOCITY_FILE   = "antarctica_velocity_measures.nc"
BASIN_SHP       = "ANT_Basins_IMBIE2_v1.6.shp"

# -----------------------------
# Load BedMachine
# -----------------------------
bm = xr.open_dataset(BEDMACHINE_FILE, decode_cf=True)

surface = bm["surface"]
thickness = bm["thickness"]
mask = bm["mask"]

x = bm["x"].values
y = bm["y"].values

if x.size < 2 or y.size < 2:
    raise ValueError(
        f"BedMachine grid needs at least 2 points per axis to infer cell "
        f"size (got x={x.size}, y={y.size})."
    )

# Signed coordinate steps: negative when an axis is stored in descending
# order (BedMachine normally stores y top-down, i.e. descending).
step_x = float(x[1] - x[0])
step_y = float(y[1] - y[0])

# Absolute cell sizes; used for the cell area further down.
dx = abs(step_x)
dy = abs(step_y)

# Affine transform mapping raster (col, row) to (x, y). It is anchored on the
# outer edge of the first stored cell and steps in stored order, so row i /
# column j of any raster built with it lines up with y[i] / x[j] whether each
# axis is ascending or descending. For the usual layout (x ascending,
# y descending) this equals translation(x.min() - dx/2, y.max() + dy/2) *
# scale(dx, -dy).
transform = (
    Affine.translation(x[0] - step_x / 2, y[0] - step_y / 2)
    * Affine.scale(step_x, step_y)
)

out_shape = (len(y), len(x))

# -----------------------------
# Load IMBIE2 basins
# -----------------------------
# Read the basin polygons, then reproject them to EPSG:3031. The projection is
# assumed to match the BedMachine x/y coordinates used for rasterizing below
# (confirm against the BedMachine file's metadata).
gdf = gpd.read_file(BASIN_SHP)
gdf = gdf.to_crs("EPSG:3031")

# -----------------------------
# EAIS vs WAIS classification using IMBIE2 "Regions"
# -----------------------------
# Region labels are matched by case-insensitive substring. Short labels
# ("West", "East", "Peninsula") and long ones ("West Antarctica",
# "Antarctic Peninsula", ...) are both handled.
# NOTE(review): confirm the exact labels present in the shapefile. Any label
# matching none of the keywords is coded 0 and is not counted in either sheet.

def classify_sheet(region):
    """Map an IMBIE2 region label to an ice-sheet code.

    The label is stringified and lower-cased, then checked against keyword
    rules in order:

    * contains "west" or "peninsula" -> 2 (WAIS, Antarctic Peninsula included)
    * contains "east"                -> 1 (EAIS)
    * anything else                  -> 0 (unclassified)

    The WAIS rule is tried first, so a label containing both "west" and
    "east" is classed as WAIS. Missing values (None/NaN) stringify to
    "none"/"nan" and therefore land in 0.
    """
    label = str(region).lower()
    rules = (
        (("west", "peninsula"), 2),  # WAIS (including Antarctic Peninsula)
        (("east",), 1),              # EAIS
    )
    for keywords, code in rules:
        if any(word in label for word in keywords):
            return code
    return 0  # fallback (should be rare)

gdf["sheet"] = gdf["Regions"].apply(classify_sheet)

# -----------------------------
# Rasterize basin mask
# -----------------------------
# Burn only polygons that belong to a sheet and actually have a geometry.
# Sheet 0 equals the fill value, so burning it adds nothing except the risk of
# overwriting an earlier sheet polygon where outlines overlap. Null/empty
# geometries cannot be rasterized (rasterio skips or rejects them depending
# on version), so they are dropped explicitly.
burnable = gdf[
    (gdf["sheet"] > 0)
    & gdf.geometry.notna()
    & ~gdf.geometry.is_empty
]
shapes = zip(burnable.geometry, burnable["sheet"])

# Codes: 1 = EAIS, 2 = WAIS, 0 = no sheet polygon.
sheet_mask = rasterize(
    shapes=shapes,
    out_shape=out_shape,
    transform=transform,
    fill=0,
    dtype="uint8",
)

# Label raster rows/cols with the BedMachine coordinates. This relies on
# `transform` mapping row i / col j onto y[i] / x[j].
sheet_mask = xr.DataArray(
    sheet_mask,
    coords={"y": y, "x": x},
    dims=("y", "x"),
)

# -----------------------------
# Load velocity
# -----------------------------
def _find_velocity_var(ds, component):
    """Return the name of the data variable in *ds* holding *component*.

    A variable whose lower-cased name equals *component* (e.g. "VX" for
    "vx") is preferred. Only if none exists is the first variable whose name
    merely contains *component* used. This keeps an auxiliary field that also
    contains the component string (for example an error/uncertainty layer)
    from being picked just because it is listed first.

    Raises
    ------
    KeyError
        If no data variable name contains *component*.
    """
    names = list(ds.data_vars)
    lowered = [str(n).lower() for n in names]
    for name, low in zip(names, lowered):
        if low == component:
            return name
    for name, low in zip(names, lowered):
        if component in low:
            return name
    raise KeyError(
        f"No data variable matching {component!r} in velocity dataset; "
        f"available: {names}"
    )


vel = xr.open_dataset(VELOCITY_FILE, decode_cf=True)

vx_name = _find_velocity_var(vel, "vx")
vy_name = _find_velocity_var(vel, "vy")

vx = vel[vx_name].astype("float64")
vy = vel[vy_name].astype("float64")

# Speed magnitude, resampled (nearest neighbour) onto the BedMachine grid.
# interp_like fills target points outside the source coordinates with NaN by
# default, so areas without velocity coverage come through as NaN.
speed = np.sqrt(vx * vx + vy * vy)
speed = speed.interp_like(thickness, method="nearest")

# -----------------------------
# Velocity-weighted effective volume
# -----------------------------
cell_area = dx * dy  # grid-cell area in x/y units squared (presumably m^2)
vmin = 1.0  # m/yr

# Floor speeds at vmin so near-stagnant cells don't blow up the ratio.
# clip() leaves NaN (no velocity data) as NaN, so those cells drop out of the
# NaN-skipping sums downstream. The previous xr.where(speed > vmin, speed,
# vmin) turned NaN into vmin (NaN > vmin is False), which silently gave cells
# with missing velocity the largest possible weight.
eff_vol = (thickness * cell_area) / speed.clip(min=vmin)

# Keep only cells that actually hold ice; everything else becomes NaN.
eff_vol = eff_vol.where(thickness > 0)

# -----------------------------
# Age-proxy bins
# -----------------------------
# Label -> boolean grid. Grounded ice (mask == 2) is split into surface-
# elevation bands used as an age proxy; floating ice (mask == 3) is a single
# bin. Insertion order is the reporting order.
# NOTE(review): cells with any other mask value fall into no bin. If this
# BedMachine version flags additional ice classes (e.g. a separate
# subglacial-lake class), that ice is excluded. Confirm against the file's
# mask flag metadata.
age_bins = {}
age_bins[">3000 m"] = (mask == 2) & (surface > 3000)
age_bins["2000–3000 m"] = (mask == 2) & (surface > 2000) & (surface <= 3000)
age_bins["1000–2000 m"] = (mask == 2) & (surface > 1000) & (surface <= 2000)
age_bins["<1000 m grounded"] = (mask == 2) & (surface <= 1000)
age_bins["Floating ice"] = (mask == 3)

# -----------------------------
# Compute fractions
# -----------------------------
def compute_sheet(sheet_id):
    """Return each age-proxy bin's share of velocity-weighted volume for a sheet.

    Parameters
    ----------
    sheet_id : int
        Code in ``sheet_mask`` selecting the ice sheet (1 = EAIS, 2 = WAIS,
        as assigned by ``classify_sheet``).

    Returns
    -------
    dict[str, float]
        Maps every ``age_bins`` label, in ``age_bins`` order, to its
        percentage of the summed ``eff_vol`` over all bins for this sheet.

    Raises
    ------
    RuntimeError
        If the bins contain no volume for this sheet, e.g. because the
        basin raster does not line up with the BedMachine grid.
    """
    # Restrict to the sheet once, instead of re-evaluating the full-grid
    # basin comparison for every bin. Cells outside the sheet contribute 0.
    sheet_vol = eff_vol.where(sheet_mask == sheet_id, 0.0)

    # NaN cells (no ice / no velocity) are skipped by .sum().
    vals = {
        label: sheet_vol.where(cond, 0.0).sum().item()
        for label, cond in age_bins.items()
    }

    total = sum(vals.values())

    if total == 0:
        raise RuntimeError(
            f"No ice volume found for sheet_id={sheet_id}. "
            f"Check basin rasterization alignment."
        )

    return {k: 100 * v / total for k, v in vals.items()}

EAIS = compute_sheet(1)
WAIS = compute_sheet(2)

# -----------------------------
# Output
# -----------------------------
import matplotlib.pyplot as plt

# Print the numbers before plotting. In non-interactive mode plt.show()
# blocks until the figure window is closed, so printing afterwards hid the
# results until then.
for sheet_name, fractions in (("EAIS", EAIS), ("WAIS", WAIS)):
    print(f"\n{sheet_name} velocity-weighted fractions:\n")
    for label, pct in fractions.items():
        print(f"{label:20s}: {pct:6.2f} %")

# Visual check of the basin raster (0 = none, 1 = EAIS, 2 = WAIS).
sheet_mask.plot()
plt.title("EAIS (1) vs WAIS (2) basin mask")
plt.show()
