import xarray as xr
import numpy as np
import pandas as pd

# -----------------------------
# Load BedMachine v4 dataset
# -----------------------------
ds = xr.open_dataset("bedmachine4.nc")

# Data variables consumed by the binning further down.
surface, thickness, mask = (ds[var] for var in ("surface", "thickness", "mask"))

# Coordinate variables; only their first two entries are used below.
x, y = ds["x"], ds["y"]

# -----------------------------
# Grid cell area (m²)
# -----------------------------
# Spacing comes from the first two coordinate values, so a uniform grid is
# assumed. abs() makes the spacing independent of the axis direction.
dx, dy = (np.abs(coord[1] - coord[0]).item() for coord in (x, y))
cell_area = dx * dy  # m²

# -----------------------------
# Helper function
# -----------------------------
def compute_stats(condition, label, *, thickness_field=None, area_per_cell=None):
    """Return area and volume totals for the ice thickness selected by ``condition``.

    Cells where ``condition`` is False are treated as zero thickness. Area
    counts only selected cells whose thickness is strictly positive. Volume is
    the sum of selected thickness values times the area of one cell.

    Parameters
    ----------
    condition : boolean array-like aligned with the thickness field
        Selection of grid cells to include.
    label : str
        Bin name, stored under the ``"Bin"`` key of the result.
    thickness_field : optional
        Ice thickness (metres, given the m³ -> km³ conversion below). Needs
        ``.where`` and ``.sum``. Defaults to the module-level ``thickness``.
    area_per_cell : float, optional
        Area of one grid cell in m². Defaults to the module-level ``cell_area``.

    Returns
    -------
    dict
        ``{"Bin": label, "Area_km2": float, "Volume_km3": float}``.
    """
    # Fall back to the script-level globals so existing callers are unchanged.
    if thickness_field is None:
        thickness_field = thickness
    if area_per_cell is None:
        area_per_cell = cell_area

    ice = thickness_field.where(condition, 0.0)
    # NaN thickness compares False under `> 0`, so such cells add no area.
    area_m2 = ((ice > 0).sum() * area_per_cell).item()
    volume_m3 = (ice.sum() * area_per_cell).item()
    return {
        "Bin": label,
        "Area_km2": area_m2 / 1e6,  # m² -> km²
        "Volume_km3": volume_m3 / 1e9,  # m³ -> km³
    }

# -----------------------------
# BedMachine v4 mask meanings:
# 0 = ocean
# 1 = ice-free land
# 2 = grounded ice
# 3 = floating ice
# NOTE(review): BedMachine products may define further mask values
# (e.g. 4). Such cells fall in none of the bins below; check the file's
# mask attributes to confirm nothing relevant is dropped.
# -----------------------------

# Grounded-ice elevation bins, from the highest surface down:
# (exclusive lower bound m, inclusive upper bound m, label). None = open end.
elevation_bins = [
    (3000, None, ">3000 m (≈100 kyr – 1+ Myr)"),
    (2000, 3000, "2000–3000 m (≈10–100 kyr)"),
    (1000, 2000, "1000–2000 m (≈1–10 kyr)"),
    (None, 1000, "<1000 m grounded (≈<1 kyr)"),
]

results = []
for lower, upper, label in elevation_bins:
    # Each selection is built just before use, so only one is alive at a time.
    selection = mask == 2
    if lower is not None:
        selection = selection & (surface > lower)
    if upper is not None:
        selection = selection & (surface <= upper)
    results.append(compute_stats(selection, label))

# Floating ice shelves form one bin regardless of elevation.
results.append(compute_stats(mask == 3, "Floating ice (≈<1 kyr)"))

# -----------------------------
# Build table
# -----------------------------
df = pd.DataFrame(results)

# Percentages are relative to the sum over all bins, i.e. the ice captured by
# the selections above, not the whole grid.
total_area = df["Area_km2"].sum()
total_volume = df["Volume_km3"].sum()

df["Area_%"] = 100 * df["Area_km2"] / total_area
df["Volume_%"] = 100 * df["Volume_km3"] / total_volume

# Fixed column order for display.
df = df[["Bin", "Area_km2", "Area_%", "Volume_km3", "Volume_%"]]

# NOTE: this sets a pandas global display option for the rest of the process.
pd.set_option("display.float_format", "{:,.2f}".format)
print(df)

print("\nTOTALS")
# Both labels are padded to the same width so the values line up.
print(f"Total area:   {total_area:,.0f} km²")
print(f"Total volume: {total_volume:,.0f} km³")
