import xarray as xr
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
from rasterio.features import rasterize
from affine import Affine

# -----------------------------
# FILE PATHS
# -----------------------------
BEDMACHINE_FILE = "bedmachine4.nc"
VELOCITY_FILE = "antarctica_velocity_measures.nc"
BASIN_SHP = "ANT_Basins_IMBIE2_v1.6.shp"

# -----------------------------
# Load BedMachine
# -----------------------------
# Ice thickness and the categorical mask, plus the 1-D grid coordinates that
# every other layer below is aligned to (treated as cell centres, see the
# half-cell offset in `transform`).
bm = xr.open_dataset(BEDMACHINE_FILE, decode_cf=True)
thickness, mask = bm["thickness"], bm["mask"]
x, y = bm["x"].values, bm["y"].values

# Grid spacing from the first two samples of each axis (assumes a uniform grid).
dx, dy = (float(np.abs(axis[1] - axis[0])) for axis in (x, y))

# North-up affine for rasterization: origin is the outer corner half a cell
# left of x.min() and half a cell above y.max(); rows step by -dy.
transform = Affine(dx, 0.0, x.min() - dx / 2, 0.0, -dy, y.max() + dy / 2)

# Area of one grid cell, in projection units squared (presumably m² — confirm).
cell_area = dx * dy

# -----------------------------
# Load IMBIE2 basins
# -----------------------------
# Reproject the basin polygons to EPSG:3031 (Antarctic Polar Stereographic);
# NOTE(review): assumes the BedMachine x/y grid is also EPSG:3031 — confirm.
gdf = gpd.read_file(BASIN_SHP).to_crs("EPSG:3031")

def classify_sheet(region):
    """Map an IMBIE2 region label to an integer ice-sheet code.

    The label is converted with ``str()`` and lower-cased, then matched by
    substring:

    * contains "west" or "peninsula" -> 2 (WAIS)
    * otherwise contains "east"      -> 1 (EAIS)
    * anything else                  -> 0 (unclassified)
    """
    label = str(region).lower()
    if any(token in label for token in ("west", "peninsula")):
        return 2  # WAIS
    return 1 if "east" in label else 0  # EAIS, else unclassified

gdf["sheet"] = gdf["Regions"].apply(classify_sheet)

# Burn each basin's sheet code (0/1/2 from classify_sheet) into a uint8 grid
# with the BedMachine grid's shape; pixels covered by no polygon get 0.
sheet_mask = rasterize(
    zip(gdf.geometry, gdf["sheet"]),
    out_shape=(len(y), len(x)),
    transform=transform,
    fill=0,
    dtype="uint8"
)

# `transform` is north-up: raster row 0 sits at y.max() and column 0 at
# x.min(). The coordinates attached below follow the file order of `y`/`x`,
# so flip the raster when an axis runs the other way. Without this, an
# ascending-`y` (or descending-`x`) grid would silently pair every basin label
# with the mirrored row/column.
if y[0] < y[-1]:
    sheet_mask = sheet_mask[::-1, :]
if x[0] > x[-1]:
    sheet_mask = sheet_mask[:, ::-1]

sheet_mask = xr.DataArray(sheet_mask, coords={"y": y, "x": x}, dims=("y", "x"))

# -----------------------------
# Load velocity
# -----------------------------
vel = xr.open_dataset(VELOCITY_FILE, decode_cf=True)

# First data variable whose lower-cased name contains "vx" / "vy";
# raises IndexError if there is no match.
vx = vel[[v for v in vel.data_vars if "vx" in v.lower()][0]].astype("float64")
vy = vel[[v for v in vel.data_vars if "vy" in v.lower()][0]].astype("float64")

# Speed magnitude resampled onto the BedMachine thickness grid by nearest
# neighbour. NOTE(review): assumes both datasets share "x"/"y" dim names and
# the same projected CRS — confirm.
speed = np.sqrt(vx*vx + vy*vy).interp_like(thickness, method="nearest")

# -----------------------------
# Residence-time proxy (years)
# -----------------------------
vmin = 1.0  # m/yr
# Floor the speed at vmin so near-stagnant ice does not divide by ~0; this
# also caps tau at thickness / vmin. `clip` propagates NaN, so cells without
# velocity data stay NaN and are excluded from the histograms. The previous
# `xr.where(speed > vmin, speed, vmin)` replaced NaN with vmin (NaN > vmin is
# False), silently giving missing-data cells the maximum residence time.
tau = thickness / speed.clip(min=vmin)

# Mask non-ice: keep cells with positive thickness and a non-zero mask code.
valid = (thickness > 0) & (mask != 0)
tau = tau.where(valid)

# -----------------------------
# Volume weights
# -----------------------------
# Ice volume of each cell (thickness × cell area); NaN wherever `valid` is False.
volume = (thickness * cell_area).where(valid)

# -----------------------------
# Prepare histogram bins
# -----------------------------
# Residence times are histogrammed in log10 space.
log_tau = np.log10(tau)

# 60 edges -> 59 equal-width bins in log10(years).
bins = np.linspace(1, 6, 60)  # 10 yr → 1 Myr

def volume_weighted_pdf_cdf(log_values, weights, bin_edges):
    """Return a weighted histogram and exceedance curve for log-space values.

    Parameters
    ----------
    log_values : array_like
        Values to bin (here log10 residence times); any shape, flattened.
    weights : array_like
        Per-value weights (here cell volumes); same size as ``log_values``.
        Entries where either array is non-finite are ignored.
    bin_edges : array_like
        Monotonically increasing histogram edges in the same units as
        ``log_values``.

    Returns
    -------
    centers : ndarray
        Bin centres (``len(bin_edges) - 1`` values).
    pdf : ndarray
        Weight in each bin divided by the TOTAL finite weight. Weight lying
        outside ``bin_edges`` still counts in the denominator, so ``pdf`` sums
        to less than 1 when part of the volume is outside the binned range.
    cdf : ndarray
        Exceedance fraction P(value > centre), evaluated exactly at each centre
        and including weight above the last edge.

    When there is no finite weight (or it sums to <= 0), ``pdf`` and ``cdf``
    are all zeros instead of NaN from a 0/0 division.
    """
    lt = np.asarray(log_values, dtype="float64").ravel()
    w = np.asarray(weights, dtype="float64").ravel()
    edges = np.asarray(bin_edges, dtype="float64")
    centers = 0.5 * (edges[:-1] + edges[1:])

    ok = np.isfinite(lt) & np.isfinite(w)
    lt, w = lt[ok], w[ok]
    total = w.sum()
    if lt.size == 0 or total <= 0:
        return centers, np.zeros(centers.size), np.zeros(centers.size)

    hist, _ = np.histogram(lt, bins=edges, weights=w)
    pdf = hist / total

    # For each value, count the centres strictly below it; a value with k such
    # centres exceeds centres 0..k-1. Summing weight per k and taking the
    # reverse cumulative sum gives P(value > centre_i) for every i at once.
    n_below = np.searchsorted(centers, lt, side="left")
    mass = np.bincount(n_below, weights=w, minlength=centers.size + 1)
    cdf = np.cumsum(mass[::-1])[::-1][1:] / total
    return centers, pdf, cdf


def weighted_histogram(sheet_id):
    """Volume-weighted log10(tau) PDF and exceedance curve for one ice sheet.

    Selects cells where the module-level ``sheet_mask`` equals ``sheet_id``
    (2 = WAIS, 1 = EAIS per classify_sheet). It histograms the module-level
    ``log_tau`` over ``bins``, weighted by ``volume``.

    Returns ``(centers, pdf, cdf)`` as produced by ``volume_weighted_pdf_cdf``:
    ``pdf`` is the fraction of the sheet's volume per bin, and ``cdf`` is the
    fraction with residence time above each bin centre.
    """
    sel = (sheet_mask == sheet_id)
    lt = log_tau.where(sel).values
    w = volume.where(sel).values
    return volume_weighted_pdf_cdf(lt, w, bins)

EAIS_x, EAIS_pdf, EAIS_cdf = weighted_histogram(1)
WAIS_x, WAIS_pdf, WAIS_cdf = weighted_histogram(2)

# -----------------------------
# Plot: PDF + CDF
# -----------------------------
# 2×2 layout: column 0 is EAIS, column 1 is WAIS; the top row shows the
# per-bin volume fraction and the bottom row the exceedance curve.
fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex=True)

panel_specs = (
    (EAIS_x, EAIS_pdf, EAIS_cdf,
     "EAIS – volume-weighted residence-time PDF",
     "EAIS – P(residence time > τ)"),
    (WAIS_x, WAIS_pdf, WAIS_cdf,
     "WAIS – volume-weighted residence-time PDF",
     "WAIS – P(residence time > τ)"),
)

for col, (centres, pdf_vals, cdf_vals, pdf_title, cdf_title) in enumerate(panel_specs):
    pdf_ax, cdf_ax = axes[0, col], axes[1, col]
    pdf_ax.plot(centres, pdf_vals, color="black")
    pdf_ax.set_title(pdf_title)
    cdf_ax.plot(centres, cdf_vals, color="black")
    cdf_ax.set_title(cdf_title)
    cdf_ax.set_xlabel("log₁₀(thickness / speed) [years]")

for ax in axes.flat:
    ax.grid(True, linestyle=":", linewidth=0.8)

axes[0, 0].set_ylabel("Volume fraction per bin")
axes[1, 0].set_ylabel("Cumulative volume fraction")

plt.tight_layout()
plt.savefig("antarctica_residence_time_histograms_EAIS_WAIS.png", dpi=300, bbox_inches="tight")
plt.close()

print("Saved: antarctica_residence_time_histograms_EAIS_WAIS.png")

# -----------------------------
# Print key cumulative fractions
# -----------------------------
thresholds = [3, 4, 5]  # log10(τ / yr): 1 kyr, 10 kyr, 100 kyr


def _value_nearest(xs, ys, target):
    """Return the entry of ys at the position where xs is closest to target."""
    return ys[np.abs(xs - target).argmin()]


print("\nCumulative volume fractions exceeding τ:\n")

for t in thresholds:
    # Read each exceedance curve at the bin centre nearest to log10(τ) = t.
    e = _value_nearest(EAIS_x, EAIS_cdf, t)
    w = _value_nearest(WAIS_x, WAIS_cdf, t)
    print(f">10^{t:.0f} yr  |  EAIS: {100*e:5.1f}%   WAIS: {100*w:5.1f}%")
