#!/usr/bin/env python3
"""
Generate an ensemble of conjugate shear stress maps from TPW-style rotations
and sample them against the World Stress Map (WSM).

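Baseline:
  - Rotation = 104°
  - Axis = 31°E meridian (northward); the Euler pole is placed at 0°N, 31°E

Randomized alternates:
  - Random Euler poles (uniformly distributed over the sphere)
  - Random rotation magnitudes (uniform between --rot-min and --rot-max)
  - Random rotation sense (±1)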
"""

import argparse
import numpy as np
import pandas as pd

# Presumably the mean Earth radius in km; not referenced by the code shown here.
EARTH_RADIUS = 6371.0
# Degrees -> radians factor (multiply to convert to radians, divide to go back).
DEG2RAD = np.pi / 180.0

# ---------------------------------------------------------
# Geometry utilities
# ---------------------------------------------------------

def sph_to_cart(lat, lon):
    """
    Convert geographic coordinates to a unit vector in Cartesian coordinates.

    Parameters
    ----------
    lat, lon : float or array_like
        Latitude and longitude in degrees.

    Returns
    -------
    numpy.ndarray
        ``[x, y, z]`` with x toward (0°N, 0°E), y toward (0°N, 90°E) and
        z toward the north pole. For array inputs of a common shape S the
        result has shape ``(3, *S)``.
    """
    # np.radians returns a new object, so the caller's arrays are never
    # modified. An in-place ``lat *= ...`` would silently rescale any ndarray
    # passed in, and would raise on integer arrays.
    lat_r = np.radians(lat)
    lon_r = np.radians(lon)
    cos_lat = np.cos(lat_r)
    return np.array([
        cos_lat * np.cos(lon_r),
        cos_lat * np.sin(lon_r),
        np.sin(lat_r),
    ])

def cart_to_sph(v):
    """
    Convert a Cartesian vector to geographic latitude and longitude.

    The vector is normalized first, so any non-zero length is accepted. The
    normalized z-component is then clipped to [-1, 1], so rounding error on
    a nominally unit vector (e.g. ``z = 1 + 1e-16``) cannot turn ``arcsin``
    into NaN.

    Parameters
    ----------
    v : array_like
        ``[x, y, z]`` in the same frame as ``sph_to_cart``. A ``(3, N)``
        array is handled column-wise.

    Returns
    -------
    (lat, lon)
        Latitude in [-90, 90] and longitude in [-180, 180], in degrees.
        A zero vector yields ``(0.0, 0.0)``.
    """
    x, y, z = np.asarray(v, dtype=float)
    norm = np.sqrt(x * x + y * y + z * z)
    # Divide by 1 where the norm is zero so a zero vector stays zero
    # instead of becoming NaN.
    scale = np.where(norm > 0, norm, 1.0)
    x, y, z = x / scale, y / scale, z / scale
    lat = np.arcsin(np.clip(z, -1.0, 1.0))
    lon = np.arctan2(y, x)
    return np.degrees(lat), np.degrees(lon)

def rotate_vector(v, axis, angle_deg):
    """
    Rotate ``v`` about ``axis`` by ``angle_deg`` degrees (Rodrigues' formula).

    ``axis`` is normalized internally, so any non-zero length works. Positive
    angles follow the right-hand rule about ``axis``. Both ``v`` and ``axis``
    are expected to be NumPy arrays of length 3.
    """
    unit_axis = axis / np.linalg.norm(axis)
    angle_rad = np.radians(angle_deg)
    c, s = np.cos(angle_rad), np.sin(angle_rad)

    # Component swept around the axis, and the component along it that
    # is left unchanged by the rotation.
    perpendicular = np.cross(unit_axis, v) * s
    parallel = unit_axis * np.dot(unit_axis, v) * (1 - c)
    return v * c + perpendicular + parallel

# ---------------------------------------------------------
# TPW stress model (simplified kinematic model: azimuths depend only on the Euler-pole position)
# ---------------------------------------------------------

def compute_shear_azimuth(lat, lon, pole_lat, pole_lon, tol=1e-12):
    """
    Compute a pair of orthogonal shear azimuths from TPW geometry.

    The first azimuth is the local direction of the tangential velocity
    ``pole x p`` at the site ``p`` for rotation about the Euler pole. The
    second is that direction plus 90° (turned clockwise). Only the pole
    position enters the calculation; rotation magnitude and sense do not.

    Parameters
    ----------
    lat, lon : float
        Site latitude and longitude in degrees.
    pole_lat, pole_lon : float
        Euler-pole latitude and longitude in degrees.
    tol : float, optional
        Norm below which a cross product is treated as zero. It guards the
        two cases where a direction is undefined: the site coincides with
        (or is antipodal to) the Euler pole, or the site sits on a
        geographic pole where "east" is undefined.

    Returns
    -------
    (float, float)
        Both azimuths in degrees clockwise from north, reduced modulo 360.
        Returns ``(nan, nan)`` when the geometry is degenerate or an input
        is NaN.
    """
    p = sph_to_cart(lat, lon)
    pole = sph_to_cart(pole_lat, pole_lon)

    # Tangential velocity direction of the site rotating about the pole.
    v = np.cross(pole, p)
    v_norm = np.linalg.norm(v)
    # Use a tolerance rather than exact zero: for (anti)parallel vectors,
    # rounding leaves a tiny non-zero cross product whose normalized
    # direction is noise. `not >` also catches a NaN norm.
    if not v_norm > tol:
        return np.nan, np.nan
    v = v / v_norm

    # Local east/north basis of the tangent plane at p.
    east = np.cross([0.0, 0.0, 1.0], p)
    east_norm = np.linalg.norm(east)
    if not east_norm > tol:
        # At a geographic pole, east is undefined; avoid dividing by zero.
        return np.nan, np.nan
    east = east / east_norm
    north = np.cross(p, east)

    az = np.degrees(np.arctan2(np.dot(v, east), np.dot(v, north))) % 360.0

    # NOTE(review): the original comment here said "Conjugate shears ±45°",
    # but this returns the velocity azimuth and its perpendicular
    # (az, az + 90°), not az ± 45°. Confirm which convention is intended.
    return az, (az + 90.0) % 360.0

# ---------------------------------------------------------
# Main
# ---------------------------------------------------------

# WSM columns the per-station loop cannot run without.
_REQUIRED_WSM_COLUMNS = ("LAT", "LON", "AZI")

# Output schema. It is passed to DataFrame explicitly so the CSV keeps its
# header and column order even when no station yields a valid prediction.
_OUTPUT_COLUMNS = [
    "RUN_ID", "POLE_LAT", "POLE_LON", "ROT_DEG", "ROT_SIGN",
    "LAT", "LON", "WSM_AZ",
    "NET1_AZ", "NET2_AZ",
    "NET1_MISFIT", "NET2_MISFIT", "BEST_MISFIT",
    "QUALITY", "PLATE",
]


def _parse_args(argv=None):
    """Parse and sanity-check the command line (``argv=None`` reads sys.argv)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--wsm", required=True, help="WSM CSV")
    ap.add_argument("--output", default="shear_ensemble.csv")
    ap.add_argument("--random", type=int, default=0,
                    help="Number of randomized TPW realizations")
    ap.add_argument("--rot-min", type=float, default=30)
    ap.add_argument("--rot-max", type=float, default=150)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args(argv)

    if args.random < 0:
        ap.error("--random must be >= 0")
    # numpy does not guarantee sensible behaviour for uniform(low, high)
    # with low > high, so reject an inverted range up front.
    if args.rot_min > args.rot_max:
        ap.error("--rot-min must not exceed --rot-max")
    return args


def _build_ensemble(rng, n_random, rot_min, rot_max):
    """Return the baseline TPW configuration followed by ``n_random`` random ones."""
    # NOTE(review): the baseline is described as the "31°E meridian
    # (northward)", but the Euler pole sits on the equator at 31°E.
    # Rotating about that pole moves the geographic north pole along the
    # 59°W / 121°E meridian, not along 31°E. Confirm the intended pole.
    ensemble = [{
        "run_id": "TPW_BASELINE",
        "pole_lat": 0.0,
        "pole_lon": 31.0,
        "rot_deg": 104.0,
        "rot_sign": +1,
    }]

    for i in range(n_random):
        # Keep the draw order (u, lon, magnitude, sign) fixed so a given
        # --seed always reproduces the same ensemble.
        # arcsin of a uniform variate on [-1, 1) gives poles uniformly
        # distributed over the sphere rather than clustered at the poles.
        u = rng.uniform(-1, 1)
        pole_lat = np.arcsin(u) / DEG2RAD
        pole_lon = rng.uniform(0, 360)
        rot_deg = rng.uniform(rot_min, rot_max)
        rot_sign = rng.choice([-1, +1])

        ensemble.append({
            "run_id": f"TPW_RANDOM_{i:03d}",
            "pole_lat": pole_lat,
            "pole_lon": pole_lon,
            "rot_deg": rot_deg,
            "rot_sign": rot_sign,
        })
    return ensemble


def _sample_run(cfg, stations):
    """Return output rows comparing one realization against every WSM station."""
    # NOTE(review): cfg["rot_deg"] and cfg["rot_sign"] are only copied to the
    # output. compute_shear_azimuth depends on the pole position alone, so
    # realizations that differ only in magnitude or sense give identical
    # azimuths and misfits.
    rows = []
    for st in stations:
        lat, lon = st["LAT"], st["LON"]
        # NOTE(review): if the WSM export uses a numeric sentinel (e.g. 999)
        # for missing azimuths, it passes through `% 360` as a bogus value;
        # consider filtering such rows upstream.
        wsm_az = st["AZI"] % 360.0

        az1, az2 = compute_shear_azimuth(lat, lon,
                                         cfg["pole_lat"],
                                         cfg["pole_lon"])
        if np.isnan(az1):
            continue  # degenerate geometry: no defined shear direction

        # Stress orientations are axial (θ and θ + 180° are equivalent), so
        # wrap the difference into [-90, 90) and keep its magnitude (0-90°).
        mis1 = abs((az1 - wsm_az + 90) % 180 - 90)
        mis2 = abs((az2 - wsm_az + 90) % 180 - 90)

        rows.append({
            "RUN_ID": cfg["run_id"],
            "POLE_LAT": cfg["pole_lat"],
            "POLE_LON": cfg["pole_lon"],
            "ROT_DEG": cfg["rot_deg"],
            "ROT_SIGN": cfg["rot_sign"],
            "LAT": lat,
            "LON": lon,
            "WSM_AZ": wsm_az,
            "NET1_AZ": az1,
            "NET2_AZ": az2,
            "NET1_MISFIT": mis1,
            "NET2_MISFIT": mis2,
            "BEST_MISFIT": min(mis1, mis2),
            "QUALITY": st.get("QUALITY", None),
            "PLATE": st.get("PLATE", None),
        })
    return rows


def main(argv=None):
    """
    Command-line entry point.

    Builds the TPW ensemble (the baseline plus ``--random`` randomized
    realizations) and predicts shear azimuths at every WSM station for each
    realization. Writes one CSV row per (realization, station) pair with a
    defined prediction to ``--output``.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    rng = np.random.default_rng(args.seed)
    wsm = pd.read_csv(args.wsm)

    missing = [c for c in _REQUIRED_WSM_COLUMNS if c not in wsm.columns]
    if missing:
        raise SystemExit(
            f"{args.wsm}: missing required column(s): {', '.join(missing)}"
        )

    # Convert the table once. Calling DataFrame.iterrows() for every ensemble
    # member would rebuild a Series per row per run.
    stations = wsm.to_dict("records")

    results = []
    for cfg in _build_ensemble(rng, args.random, args.rot_min, args.rot_max):
        print(f"Generating {cfg['run_id']}")
        results.extend(_sample_run(cfg, stations))

    out = pd.DataFrame(results, columns=_OUTPUT_COLUMNS)
    out.to_csv(args.output, index=False)
    print(f"Saved {args.output}")

# ---------------------------------------------------------
# Script entry point: main() returns None, so the process exits with status 0
# unless main() itself exits (e.g. on an argparse error).
if __name__ == "__main__":
    raise SystemExit(main())
