import xarray as xr
import numpy as np
import geopandas as gpd
from rasterio.features import rasterize
from affine import Affine
import pandas as pd

# -----------------------------
# INPUT FILES
# -----------------------------
BEDMACHINE_FILE = "bedmachine4.nc"
VELOCITY_FILE = "antarctica_velocity_measures.nc"
BASIN_SHP = "ANT_Basins_IMBIE2_v1.6.shp"

# -----------------------------
# BedMachine: fields and grid geometry
# -----------------------------
bm = xr.open_dataset(BEDMACHINE_FILE, decode_cf=True)
thickness, mask = bm["thickness"], bm["mask"]
x, y = bm["x"].values, bm["y"].values

# Grid spacing is read from the first two coordinate values only
# (assumes a uniformly spaced grid -- TODO confirm for other inputs).
dx = float(abs(x[1] - x[0]))
dy = float(abs(y[1] - y[0]))
cell_area = dx * dy

# Pixel -> map transform: origin half a cell outside (x.min(), y.max()),
# column step +dx, row step -dy, no rotation.
# NOTE(review): the half-cell shift presumes x/y hold cell centres -- confirm.
transform = Affine(dx, 0.0, x.min() - dx / 2, 0.0, -dy, y.max() + dy / 2)

# -----------------------------
# IMBIE2 drainage basins
# -----------------------------
# Read the basin polygons, then reproject them to EPSG:3031 in a second step.
gdf = gpd.read_file(BASIN_SHP)
gdf = gdf.to_crs("EPSG:3031")

def classify_sheet(region):
    """Map a basin region label to an integer ice-sheet id.

    The label is converted with ``str()`` and lower-cased before matching,
    so any input type is accepted.

    Returns:
        2 if the label contains "west" or "peninsula" (checked first),
        1 if it contains "east",
        0 otherwise.
    """
    label = str(region).lower()
    if any(token in label for token in ("west", "peninsula")):
        return 2
    return 1 if "east" in label else 0

# Tag every basin polygon with the integer id returned by classify_sheet.
gdf["sheet"] = gdf["Regions"].apply(classify_sheet)

# Burn the ids onto a raster with the BedMachine grid shape; cells not
# covered by any polygon keep the fill value 0.
sheet_mask = rasterize(
    zip(gdf.geometry, gdf["sheet"]),
    out_shape=(len(y), len(x)),
    transform=transform,
    fill=0,
    dtype="uint8",
)

# `transform` starts at (x.min(), y.max()) with a negative row step, so the
# raster comes back with row 0 at the largest y and column 0 at the smallest
# x. Reorder it to match the coordinate order stored in the file; otherwise
# labelling it with `y` / `x` below would silently mirror the mask whenever
# y is ascending or x is descending.
if y[0] < y[-1]:
    sheet_mask = sheet_mask[::-1, :]
if x[0] > x[-1]:
    sheet_mask = sheet_mask[:, ::-1]

sheet_mask = xr.DataArray(
    sheet_mask, coords={"y": y, "x": x}, dims=("y", "x")
)

# -----------------------------
# Velocity: components and speed on the BedMachine grid
# -----------------------------
vel = xr.open_dataset(VELOCITY_FILE, decode_cf=True)

# Take the first data variable whose lower-cased name contains "vx" / "vy".
vx_name = [name for name in vel.data_vars if "vx" in name.lower()][0]
vy_name = [name for name in vel.data_vars if "vy" in name.lower()][0]
vx = vel[vx_name].astype("float64")
vy = vel[vy_name].astype("float64")

# Speed magnitude is computed on the velocity grid first, then sampled at the
# `thickness` coordinates with nearest-neighbour lookup.
speed = (
    np.sqrt(vx * vx + vy * vy)
    .interp_like(thickness, method="nearest")
)

# -----------------------------
# Residence-time proxy (years)
# -----------------------------
vmin = 0.01  # m/yr  <-- CRITICAL CHANGE

# Floor the speed at vmin so very slow ice does not blow up the division.
# The comparison is written as `speed < vmin` on purpose: NaN compares False,
# so cells with no velocity data keep NaN instead of being replaced by vmin
# (which would label unmeasured ice as the oldest ice in the domain).
speed_floored = xr.where(speed < vmin, vmin, speed)
tau = thickness / speed_floored

# A cell takes part in the statistics only if it has positive thickness, a
# non-zero mask value, and an actual velocity value. Cells without velocity
# data are dropped from both the exceedance volume and the total volume.
valid = (thickness > 0) & (mask != 0) & speed.notnull()
tau = tau.where(valid)

# Per-cell ice volume (thickness times cell area); NaN outside `valid`.
volume = (thickness * cell_area).where(valid)

# -----------------------------
# Residence-time thresholds (years)
# -----------------------------
thresholds = [
    1_000.0,
    5_000.0,
    10_000.0,
    20_000.0,
    50_000.0,
    100_000.0,
]

# -----------------------------
# Computation function
# -----------------------------
def compute_fraction(tau_thresh, sheet_id=None):
    """Return the fraction of selected ice volume with ``tau >= tau_thresh``.

    Reads the module-level ``valid``, ``volume``, ``tau`` and ``sheet_mask``
    arrays.

    Args:
        tau_thresh: Threshold compared against ``tau``.
        sheet_id: Value of ``sheet_mask`` to restrict the selection to, or
            ``None`` to use every ``valid`` cell.

    Returns:
        The ratio of the volume at or above the threshold to the total
        selected volume, or ``nan`` when the selection contains no volume
        (for example a ``sheet_id`` that matches no cell).
    """
    sel = valid if sheet_id is None else valid & (sheet_mask == sheet_id)

    # Mask once and reuse for both sums.
    selected_volume = volume.where(sel)
    vol_total = selected_volume.sum().item()
    if vol_total == 0:
        # Nothing selected: the fraction is undefined. Return NaN so the
        # caller's table shows a gap instead of dying on ZeroDivisionError.
        return float("nan")

    # Cells where tau is NaN compare False and so never count as exceeding.
    vol_exceed = selected_volume.where(tau >= tau_thresh).sum().item()
    return vol_exceed / vol_total

# -----------------------------
# Build table
# -----------------------------
# Output column -> sheet id handed to compute_fraction (None = no sheet filter).
sheet_columns = {"Antarctica": None, "EAIS": 1, "WAIS": 2}

# One row per threshold; fractions are converted to percent.
rows = [
    {
        "Threshold_yr": int(T),
        **{
            label: 100 * compute_fraction(T, sheet_id=sid)
            for label, sid in sheet_columns.items()
        },
    }
    for T in thresholds
]

df = pd.DataFrame(rows)

# Show floats with one decimal when printing (affects display only, not the CSV).
pd.set_option("display.float_format", "{:.1f}".format)

print("\nVolume fraction (%) exceeding residence-time thresholds:\n")
print(df)

df.to_csv("antarctica_age_exceedance_volume_table_corrected.csv", index=False)

print("\nSaved: antarctica_age_exceedance_volume_table_corrected.csv")
