#!/usr/bin/env python3
"""
Spatial randomized null test for TPW return diagnostics
(FINAL CONSOLIDATED VERSION).

Features:
- Handles grid mismatch between Z and dz safely
- Interpolates Z onto dz grid once
- Uses flattened grid indexing throughout
- Optimized for large-memory Apple Silicon systems

Inputs:
- tpw_return_sequence_event_guided_highcadence.nc
- dzdt_event_guided.nc
- early_homo_sites.geojson
- early_civilizations.geojson

Outputs:
- null_distributions.csv
- null_summary.csv
"""

import json
import numpy as np
import xarray as xr
import pandas as pd
from scipy.spatial import cKDTree
from skimage import measure
from tqdm import trange

# ============================================================
# PARAMETERS
# ============================================================

# Seed applied to NumPy's global random state below.
SEED = 42
# Number of Monte Carlo draws in the spatial null loop; adjust as desired.
N_ITER = 20_000
# Every later np.random.* call in this process draws from this seeded state.
np.random.seed(SEED)

# ============================================================
# LOAD DATASETS (DZ GRID IS MASTER)
# ============================================================

# Each NetCDF file is opened as a context manager so its file handle is
# released as soon as the needed arrays have been copied into memory.
# Previously both datasets stayed open for the lifetime of the process.
with xr.open_dataset("tpw_return_sequence_event_guided_highcadence.nc") as ds_z:
    # .values materialises the variable; astype yields an in-memory float32 array.
    Z_raw = ds_z["effective_elevation"].values.astype(np.float32)

    # Native coordinates of the Z grid, needed to interpolate Z onto the dz grid.
    lat_z = ds_z["lat"].values
    lon_z = ds_z["lon"].values

with xr.open_dataset("dzdt_event_guided.nc") as ds_dz:
    dz = ds_dz["dzdt_event_guided"].values.astype(np.float32)

    # The dz coordinates define the master grid used by the rest of the script.
    lat = ds_dz["lat"].values.astype(np.float32)
    lon = ds_dz["lon"].values.astype(np.float32)

# dz must be 3-D for this unpacking to succeed.
# NOTE(review): the axis order is presumably (time, lat, lon) - confirm
# against the file's dimension order.
n_t, n_lat, n_lon = dz.shape
n_grid = n_lat * n_lon

# ============================================================
# INTERPOLATE Z ONTO DZ GRID (ONCE)
# ============================================================

print("Interpolating Z onto dz grid...")

# Output buffer with dz's shape; every time slice is overwritten in the loop.
Z = np.empty_like(dz, dtype=np.float32)

for i in trange(n_t, desc="Z interpolation"):
    # Label the raw slice with its native Z-grid coordinates ...
    frame = xr.DataArray(Z_raw[i], coords=[lat_z, lon_z], dims=["lat", "lon"])
    # ... and resample it linearly at the dz-grid coordinates.
    # NOTE(review): target points outside the Z grid presumably come back as
    # NaN from interp - confirm against the xarray documentation.
    resampled = frame.interp(lat=lat, lon=lon, method="linear")
    Z[i] = resampled.values

# ============================================================
# FLATTEN SPATIAL DIMENSIONS
# ============================================================

# Collapse the two spatial axes into one so a single integer addresses a cell.
dz_flat = dz.reshape(n_t, n_grid)
Z_flat = Z.reshape(n_t, n_grid)

# ============================================================
# GRID POINTS
# ============================================================

# "ij" indexing gives arrays of shape (n_lat, n_lon), so their row-major
# flattening follows the same cell order as the reshapes above.
lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")
# One (lat, lon) row per grid cell.
grid_points = np.column_stack((lat2d.reshape(-1), lon2d.reshape(-1)))

# ============================================================
# LAND MASK (STATIC)
# ============================================================

# A cell is treated as land if Z is strictly positive at any time step.
land_mask = (Z > 0).any(axis=0).ravel()
# Flat indices of the land cells; the Monte Carlo draws sample from these.
land_idx = np.where(land_mask)[0]
n_land = land_idx.shape[0]

# ============================================================
# LOAD SITE DATA
# ============================================================

def load_sites(fname):
    """Read point locations from a GeoJSON-style feature collection.

    Parameters
    ----------
    fname : str or path-like
        Path to a JSON file with a top-level ``"features"`` list. Each
        feature must provide ``feature["geometry"]["coordinates"]`` with at
        least two entries.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n_features, 2)``. Each row is
        ``(coordinates[1], coordinates[0])``; GeoJSON positions are ordered
        (longitude, latitude), so rows are (lat, lon). An empty feature list
        gives shape ``(0, 2)``.

    Raises
    ------
    KeyError, TypeError, IndexError
        If a feature lacks a geometry or has too few coordinates.
    """
    # GeoJSON is UTF-8; an explicit encoding avoids depending on the locale.
    with open(fname, encoding="utf-8") as f:
        collection = json.load(f)

    lat_lon = [
        (feat["geometry"]["coordinates"][1],
         feat["geometry"]["coordinates"][0])
        for feat in collection["features"]
    ]
    # reshape keeps the result two-dimensional even with zero features, so
    # downstream code can rely on a (n, 2) layout.
    return np.array(lat_lon, dtype=np.float32).reshape(-1, 2)

# Site coordinates, one (lat, lon) row per feature in each file.
sites_civ = load_sites("early_civilizations.geojson")
sites_homo = load_sites("early_homo_sites.geojson")

# Sample sizes reused by the Monte Carlo draws.
n_homo, n_civ = len(sites_homo), len(sites_civ)

# ============================================================
# MAP SITES TO GRID INDICES
# ============================================================

# Nearest-neighbour lookup in plain (lat, lon) coordinate space.
# NOTE(review): this assumes site longitudes use the same convention as the
# grid's lon axis (e.g. both -180..180) - confirm against the input files.
grid_tree = cKDTree(grid_points)
# query returns (distances, indices); only the flat cell indices are kept.
homo_grid_idx = grid_tree.query(sites_homo)[1]
civ_grid_idx = grid_tree.query(sites_civ)[1]

# ============================================================
# PRECOMPUTE ZERO-CONTOUR DISTANCE MAPS
# ============================================================

print("Precomputing zero-contour distance maps...")

# One row per time step; rows stay NaN when no zero contour is found.
zero_dist_maps = np.full((n_t, n_grid), np.nan, dtype=np.float32)

# Integer index axes used to turn fractional contour indices into coordinates.
row_axis = np.arange(n_lat)
col_axis = np.arange(n_lon)

for i in trange(n_t, desc="Zero-contours"):
    contours = measure.find_contours(Z[i], 0.0)
    if contours:
        # Each contour holds fractional (row, col) positions; interpolate them
        # along the lat / lon axes and pool all contours into one point set.
        pts = np.vstack([
            np.column_stack([
                np.interp(c[:, 0], row_axis, lat),
                np.interp(c[:, 1], col_axis, lon),
            ])
            for c in contours
        ])

        # Distance from every grid cell to its nearest contour point,
        # measured in the same (lat, lon) coordinate space as grid_points.
        tree = cKDTree(pts)
        nearest_dist, _ = tree.query(grid_points)
        zero_dist_maps[i] = nearest_dist

# ============================================================
# OBSERVED METRICS
# ============================================================

def observed_metrics(grid_idx):
    """Compute the two summary statistics for a set of grid cells.

    Reads the module-level arrays ``dz_flat`` and ``zero_dist_maps``, both
    indexed as ``[time step, flat grid cell]``.

    Parameters
    ----------
    grid_idx : array-like of int
        Flat grid-cell indices to evaluate.

    Returns
    -------
    tuple
        ``(dz_mean, zvar)``. ``dz_mean`` is the mean of ``|dz_flat|`` over
        all time steps and the selected cells. ``zvar`` is the NaN-ignoring
        mean over time steps of the NaN-ignoring variance of
        ``zero_dist_maps`` across the selected cells.
    """
    dz_mean = np.mean(np.abs(dz_flat[:, grid_idx]))

    # One axis-wise nanvar call replaces a Python loop over time steps.
    # Rows that are all NaN (time steps without a zero contour) give a NaN
    # variance, which nanmean then skips, as the per-row loop did.
    per_step_var = np.nanvar(zero_dist_maps[:, grid_idx], axis=1)
    zvar = np.nanmean(per_step_var)
    return dz_mean, zvar

# Statistics at the real site locations; each value is a (dz_mean, zvar) pair.
obs_homo, obs_civ = (
    observed_metrics(homo_grid_idx),
    observed_metrics(civ_grid_idx),
)

# ============================================================
# MONTE CARLO SPATIAL NULL
# ============================================================

# One row per Monte Carlo draw: [dz_homo, var_homo, dz_civ, var_civ].
records = np.zeros((N_ITER, 4), dtype=np.float32)

for k in trange(N_ITER, desc="Spatial null"):
    # Draw as many distinct land cells as there are real sites. The homo
    # draw comes first, then civ, so the random stream for a fixed seed is
    # consumed in the same order on every run.
    idx_h = np.random.choice(land_idx, n_homo, replace=False)
    idx_c = np.random.choice(land_idx, n_civ, replace=False)

    # Reuse observed_metrics so the null statistics are computed by exactly
    # the same code as the observed ones, instead of a duplicated inline copy.
    dz_h, var_h = observed_metrics(idx_h)
    dz_c, var_c = observed_metrics(idx_c)

    records[k] = [dz_h, var_h, dz_c, var_c]

df_null = pd.DataFrame(
    records,
    columns=["dz_homo", "var_homo", "dz_civ", "var_civ"]
)
df_null.to_csv("null_distributions.csv", index=False)

# ============================================================
# EMPIRICAL P-VALUES
# ============================================================

def pval(obs, null):
    """Return the fraction of ``null`` values less than or equal to ``obs``.

    This is a one-sided (lower-tail) empirical proportion: small values mean
    the observed statistic is low relative to the null draws.
    """
    at_or_below = null <= obs
    return np.mean(at_or_below)

# Lower-tail empirical proportions, in the same order as the "metric" labels.
p_homo = [
    pval(obs_homo[0], df_null["dz_homo"]),
    pval(obs_homo[1], df_null["var_homo"]),
]
p_civ = [
    pval(obs_civ[0], df_null["dz_civ"]),
    pval(obs_civ[1], df_null["var_civ"]),
]

# Two rows, one per metric; obs_* are the (dz_mean, zvar) pairs.
summary = pd.DataFrame({
    "metric": ["mean|dz/dt|", "variance(zero-distance)"],
    "obs_homo": obs_homo,
    "p_homo": p_homo,
    "obs_civ": obs_civ,
    "p_civ": p_civ,
})

summary.to_csv("null_summary.csv", index=False)

for line in (
    "✓ Spatial null test complete",
    "  - null_distributions.csv",
    "  - null_summary.csv",
):
    print(line)
