#!/usr/bin/env python3

import numpy as np
import pandas as pd
from scipy.interpolate import RegularGridInterpolator
from tqdm import trange
import importlib.util

# ------------------------------------------------------------
# 1. LOAD SKS DATA (Silver 1996)
# ------------------------------------------------------------

# Semicolon-delimited table; the "utf-8-sig" codec drops a leading
# byte-order mark if the file has one.
sks = pd.read_csv("silver1996.csv", encoding="utf-8-sig", sep=";")
# Header cells may carry stray whitespace; normalise them before lookup.
sks.columns = sks.columns.str.strip()

# Station coordinates (presumably decimal degrees -- confirm against the CSV).
lat_obs = sks["Latitude"].to_numpy()
lon_obs = sks["Longitude"].to_numpy()
# Splitting directions are folded into [0, 180) so that phi and phi + 180
# are treated as the same (axial) orientation.
phi_obs = np.mod(sks["phi_deg"].to_numpy(), 180.0)


# ------------------------------------------------------------
# 2. IMPORT SHEAR FIELD FROM shear-map.py (UNCHANGED)
# ------------------------------------------------------------

# The file name contains a hyphen, so a plain `import` statement cannot
# reference it; build the module from its path instead.
module_spec = importlib.util.spec_from_file_location("shear_map", "shear-map.py")
shear_map = importlib.util.module_from_spec(module_spec)
# Runs shear-map.py top to bottom, including any side effects it has.
module_spec.loader.exec_module(shear_map)

# Per the original annotations: lon/lat are 1D arrays in degrees and theta is
# a 2D array in radians (not verifiable from this file -- see shear-map.py).
lon_grid, lat_grid, theta_grid = shear_map.lon, shear_map.lat, shear_map.theta


# ------------------------------------------------------------
# 3. CONVERT TO GEOGRAPHIC AZIMUTH + CONJUGATE SHEAR NETS
# ------------------------------------------------------------

# Convert mathematical angle (CCW from East) to
# geographic azimuth (CW from North), axial
theta0 = (90.0 - np.rad2deg(theta_grid)) % 180.0

# Conjugate shear directions: +/-45 degrees about the base azimuth, each
# folded back into [0, 180).
theta_net1 = (theta0 + 45.0) % 180.0
theta_net2 = (theta0 - 45.0) % 180.0


def _axial_interpolator(azimuth_deg):
    """Return a callable that interpolates an axial azimuth grid.

    Interpolating wrapped azimuths directly in degrees is wrong across the
    0/180 branch cut: neighbouring nodes at 179 and 1 degrees would blend to
    about 90 instead of about 0. The azimuths are therefore mapped to unit
    vectors at twice the angle (which makes phi and phi + 180 coincide), the
    two components are interpolated linearly, and the result is mapped back.

    Parameters
    ----------
    azimuth_deg : array
        Axial azimuths in degrees, laid out on the (lat_grid, lon_grid) mesh.
        NOTE(review): assumes the array is indexed [lat, lon] -- confirm
        against shear-map.py.

    Returns
    -------
    callable
        ``f(xi)`` taking the same query forms as ``RegularGridInterpolator``
        and returning azimuths in degrees, nominally in [0, 180). Queries
        outside the grid give NaN. Where the surrounding nodes cancel exactly
        (orientations 90 degrees apart) the direction is undefined and 0 is
        returned.
    """
    doubled = np.deg2rad(2.0 * azimuth_deg)

    def _component(values):
        # Same out-of-bounds policy for both components: NaN, no exception.
        return RegularGridInterpolator(
            (lat_grid, lon_grid),
            values,
            bounds_error=False,
            fill_value=np.nan,
        )

    cos_interp = _component(np.cos(doubled))
    sin_interp = _component(np.sin(doubled))

    def interpolate(xi):
        # NaN components (outside the grid) propagate through arctan2, so
        # the downstream isfinite() mask keeps working.
        angle = np.arctan2(sin_interp(xi), cos_interp(xi))
        return (0.5 * np.rad2deg(angle)) % 180.0

    return interpolate


interp_net1 = _axial_interpolator(theta_net1)
interp_net2 = _axial_interpolator(theta_net2)

def shear_net_azimuths(lon, lat):
    """Evaluate both conjugate shear-net azimuth fields at a location.

    Parameters
    ----------
    lon, lat
        Query coordinates, passed on to the interpolators in (lat, lon)
        order to match the axis order they were built with.

    Returns
    -------
    tuple
        ``(net1, net2)``: the values returned by ``interp_net1`` and
        ``interp_net2`` for the query.
    """
    query = (lat, lon)
    azimuth_net1 = interp_net1(query)
    azimuth_net2 = interp_net2(query)
    return azimuth_net1, azimuth_net2


# ------------------------------------------------------------
# 4. SAMPLE SHEAR NETS AT SKS SITES + MASK
# ------------------------------------------------------------

# Evaluate both nets at every station with one vectorised call: the
# interpolators accept arrays of coordinates, so a per-station Python loop
# (one interpolator call per record) is unnecessary.
net1_at_sks, net2_at_sks = (
    np.asarray(azimuths, dtype=float)
    for azimuths in shear_net_azimuths(lon_obs, lat_obs)
)

# A station is usable only when both nets returned a finite azimuth; the
# interpolators are configured to return NaN for points outside the grid.
valid = np.isfinite(net1_at_sks) & np.isfinite(net2_at_sks)

print(f"Total SKS records: {len(valid)}")
print(f"Valid shear-field overlaps: {valid.sum()}")
print(f"Excluded (outside shear grid): {(~valid).sum()}")

# Restrict every per-station array to the usable subset so that they stay
# aligned with each other from here on.
lon_obs = lon_obs[valid]
lat_obs = lat_obs[valid]
phi_obs = phi_obs[valid]
net1_at_sks = net1_at_sks[valid]
net2_at_sks = net2_at_sks[valid]


# ------------------------------------------------------------
# 5. AXIAL MISFIT FUNCTION
# ------------------------------------------------------------

def axial_misfit(theta_obs, theta_model):
    """Return the smallest angle between two axial orientations, in degrees.

    Orientations are undirected, so theta and theta + 180 are equivalent and
    the result always lies in [0, 90].

    Parameters
    ----------
    theta_obs, theta_model
        Orientations in degrees; scalars or NumPy arrays that broadcast
        against each other.

    Returns
    -------
    Acute angular difference(s) in degrees, with the broadcast shape of the
    inputs.
    """
    # Fold the absolute difference into [0, 180) ...
    separation = np.mod(np.abs(theta_obs - theta_model), 180.0)
    # ... then take whichever of the angle and its supplement is smaller.
    supplement = 180.0 - separation
    return np.minimum(separation, supplement)


# ------------------------------------------------------------
# 6. OBSERVED MISFIT (MINIMUM OF NET-1 / NET-2)
# ------------------------------------------------------------

# Misfit of every station against each conjugate net separately.
misfit1, misfit2 = (
    axial_misfit(phi_obs, net_azimuth)
    for net_azimuth in (net1_at_sks, net2_at_sks)
)

# Each station is scored against whichever net it matches best.
obs_misfit = np.minimum(misfit1, misfit2)

# Summary statistics of the observed misfit distribution. The "axial
# variance" is one minus the mean cosine of the doubled misfit angle.
obs_mean = np.mean(obs_misfit)
obs_median = np.median(obs_misfit)
obs_variance = 1.0 - np.cos(np.radians(2.0 * obs_misfit)).mean()

for label, value in (
    ("Observed mean misfit:", obs_mean),
    ("Observed median misfit:", obs_median),
    ("Observed axial variance:", obs_variance),
):
    print(label, value)


# ------------------------------------------------------------
# 7. NULL HYPOTHESIS: GLOBAL AXIAL ROTATION
# ------------------------------------------------------------

# Number of random rotations used to build the null distribution.
K = 10000
null_mean, null_variance = np.zeros(K), np.zeros(K)


def _rotated_best_misfit(offset_deg):
    """Best-net misfit after rotating every observation by ``offset_deg``."""
    turned = (phi_obs + offset_deg) % 180.0
    return np.minimum(
        axial_misfit(turned, net1_at_sks),
        axial_misfit(turned, net2_at_sks),
    )


# NOTE(review): no RNG seed is set in this file, so the null distribution
# (and the p-values derived from it) will differ slightly between runs.
for k in trange(K):
    # One rigid rotation applied to all stations at once: it preserves the
    # relative pattern of the observations but breaks their alignment with
    # the model field.
    m = _rotated_best_misfit(np.random.uniform(0.0, 180.0))

    null_mean[k] = np.mean(m)
    null_variance[k] = 1.0 - np.cos(np.radians(2.0 * m)).mean()


# ------------------------------------------------------------
# 8. EMPIRICAL P-VALUES
# ------------------------------------------------------------

# One-sided test: a smaller misfit (or variance) means better agreement with
# the model, so count the null draws that do at least as well as the data.
#
# The (b + 1) / (K + 1) estimator is used instead of the plain proportion
# b / K: the observed (unrotated) configuration is itself one member of the
# null ensemble, and a finite Monte Carlo sample can never justify reporting
# p = 0 exactly.
p_mean = (np.count_nonzero(null_mean <= obs_mean) + 1) / (null_mean.size + 1)
p_var = (np.count_nonzero(null_variance <= obs_variance) + 1) / (
    null_variance.size + 1
)

print("p-value (mean misfit):", p_mean)
print("p-value (axial variance):", p_var)


# ------------------------------------------------------------
# 9. SAVE RESULTS
# ------------------------------------------------------------

# Keep only the stations that passed the grid mask and append the model
# azimuths of both nets plus the best-net misfit for each of them.
out = sks.loc[valid].assign(
    shear_net1_azimuth=net1_at_sks,
    shear_net2_azimuth=net2_at_sks,
    angular_misfit_deg=obs_misfit,
)

# index=False: the original row labels are not written to the file.
out.to_csv("Silver1996_shear_misfit_results.csv", index=False)
print("Saved: Silver1996_shear_misfit_results.csv")
