#!/usr/bin/env python3
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from math import radians, sin, cos, acos

# ============================================================
# Geometry utilities
# ============================================================

def angular_distance_deg(lat1, lon1, lat2, lon2):
    """Return the great-circle angular distance between two points, in degrees.

    Uses the atan2 form (sine term from ``hypot`` of the cross-product
    components, cosine term from the dot product). Unlike ``acos(dot)``, this
    is accurate for all separations. ``acos`` loses roughly half the
    significant digits near 0 degrees, so tiny pole offsets were distorted
    and identical points were not guaranteed to give exactly 0.

    Parameters
    ----------
    lat1, lon1, lat2, lon2 : float
        Latitudes and longitudes of the two points, in degrees.

    Returns
    -------
    numpy.float64
        Angular separation in degrees, in [0, 180]. NaN if any input is NaN.
        The previous clamp-based version silently reported NaN inputs as 0.
    """
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dlon = lon2 - lon1

    # Magnitude of the cross product of the two unit vectors (= sin of angle).
    sin_term = np.hypot(
        cos(lat2) * sin(dlon),
        cos(lat1) * sin(lat2) - sin(lat1) * cos(lat2) * cos(dlon),
    )
    # Dot product of the two unit vectors (= cos of angle).
    cos_term = sin(lat1) * sin(lat2) + cos(lat1) * cos(lat2) * cos(dlon)

    return np.degrees(np.arctan2(sin_term, cos_term))


def euler_polar_coords(lat, lon, lat0, lon0):
    """
    Convert Euler pole (lat, lon) into polar coordinates
    relative to baseline pole (lat0, lon0).

    All inputs are in degrees. Returns (theta, r) where:
      r     = great-circle angular distance from the baseline (degrees)
      theta = azimuth around the baseline pole (radians), measured from
              north toward east, in (-pi, pi]

    The distance is computed as atan2(sin, cos) rather than acos(cos). This
    keeps small offsets accurate. It also makes NaN inputs propagate as NaN
    instead of being clamped to a distance of 0.
    """
    lat, lon, lat0, lon0 = map(radians, [lat, lon, lat0, lon0])
    dlon = lon - lon0

    # East and north components of the (unnormalized) direction from the
    # baseline toward the pole; their hypot is sin(r).
    east = sin(dlon) * cos(lat)
    north = cos(lat0) * sin(lat) - sin(lat0) * cos(lat) * cos(dlon)
    # Dot product of the two unit vectors; equals cos(r).
    along = sin(lat) * sin(lat0) + cos(lat) * cos(lat0) * cos(dlon)

    theta = np.arctan2(east, north)
    r = np.arctan2(np.hypot(east, north), along)

    return theta, np.degrees(r)


# ============================================================
# Main
# ============================================================

def _add_legend(ax, **kwargs):
    """Draw a legend on *ax* only when it has at least one labeled artist.

    Calling ``ax.legend()`` with nothing labeled (e.g. the baseline run has no
    rows in the selected region) makes matplotlib emit a
    "No artists with labels found" warning.
    """
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(**kwargs)


def _plot_rank_vs_euler(df, baseline, path):
    """Plot 1: rank vs Euler-pole distance, one panel per SCALE_KM value."""
    scales = sorted(df["SCALE_KM"].unique())
    fig, axes = plt.subplots(
        1, len(scales),
        figsize=(6 * len(scales), 5),
        sharey=True,
        squeeze=False,  # always get a 2-D array, even for a single scale
    )
    axes = axes[0]

    for ax, L in zip(axes, scales):
        d = df[df["SCALE_KM"] == L]
        is_base = d["RUN_ID"] == baseline
        rand = d[~is_base]
        base = d[is_base]

        ax.scatter(
            rand["EULER_DIST_DEG"], rand["RANK"],
            alpha=0.6, label="Random ensembles"
        )
        ax.scatter(
            base["EULER_DIST_DEG"], base["RANK"],
            color="red", s=100, label="Baseline"
        )

        ax.set_title(f"L = {int(L)} km")
        ax.set_xlabel("Euler-pole angular distance (degrees)")
        ax.grid(True, alpha=0.3)

    # All panels share one y-axis, so invert_yaxis() flips the *shared*
    # limits. Calling it once per panel cancels out for an even number of
    # panels; invert exactly once, after all data are plotted.
    axes[0].invert_yaxis()
    axes[0].set_ylabel("Rank (lower = better)")
    _add_legend(axes[0])

    fig.tight_layout()
    fig.savefig(path, dpi=200)
    plt.close(fig)


def _plot_rank_vs_scale(df, baseline, path):
    """Plot 2: one rank-vs-wavelength line per run, baseline highlighted."""
    fig, ax = plt.subplots(figsize=(7, 5))

    for run_id, g in df.groupby("RUN_ID"):
        # Draw each line in wavelength order; CSV row order would zigzag.
        g = g.sort_values("SCALE_KM")
        if run_id == baseline:
            ax.plot(
                g["SCALE_KM"], g["RANK"],
                "-o", color="red", linewidth=3, label="Baseline"
            )
        else:
            ax.plot(
                g["SCALE_KM"], g["RANK"],
                color="gray", alpha=0.3
            )

    ax.set_xlabel("Wavelength (km)")
    ax.set_ylabel("Rank")
    ax.set_xscale("log")
    ax.invert_yaxis()
    ax.grid(True, alpha=0.3)
    _add_legend(ax)

    fig.tight_layout()
    fig.savefig(path, dpi=200)
    plt.close(fig)


def _plot_euler_polar(df, baseline, b_lat, b_lon, path):
    """Plot 3: each run's Euler pole in polar coordinates about the baseline."""
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111, polar=True)

    for run_id, g in df.groupby("RUN_ID"):
        # A run has one row per scale; plot each distinct pole once so that
        # repeated rows don't stack (and darken) semi-transparent markers.
        poles = g[["POLE_LAT", "POLE_LON"]].drop_duplicates()
        theta, r = zip(*(
            euler_polar_coords(lat, lon, b_lat, b_lon)
            for lat, lon in poles.itertuples(index=False)
        ))

        if run_id == baseline:
            ax.scatter(theta, r, color="red", s=120, label="Baseline")
        else:
            ax.scatter(theta, r, color="gray", alpha=0.6)

    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)
    ax.set_rlabel_position(135)
    ax.set_title("Euler pole space (relative to baseline)", pad=20)

    _add_legend(ax, loc="upper right", bbox_to_anchor=(1.3, 1.1))

    fig.tight_layout()
    # The legend is anchored outside the axes; bbox_inches="tight" keeps it
    # from being clipped out of the saved image.
    fig.savefig(path, dpi=200, bbox_inches="tight")
    plt.close(fig)


def main():
    """Command-line entry point: plot Moran rankings for one region.

    Reads the rankings CSV given by ``--input``. The columns used are RUN_ID,
    REGION, SCALE_KM, RANK, POLE_LAT and POLE_LON. Each run's Euler pole is
    measured against the pole of the ``--baseline`` run.

    Writes three PNGs:
      ``<output-prefix>_rank_vs_euler.png``
      ``<output-prefix>_rank_vs_scale.png``
      ``<output-prefix>_euler_polar.png``

    Raises:
        RuntimeError: if required columns are missing, the baseline run is
            absent from the CSV, or no rows match ``--region``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Rankings CSV")
    ap.add_argument("--baseline", default="TPW_BASELINE",
                    help="RUN_ID of the reference run")
    ap.add_argument("--region", default="Arctic",
                    help="REGION value to plot")
    ap.add_argument("--output-prefix", default="morans_rank",
                    help="Prefix for the output PNG filenames")
    args = ap.parse_args()

    df = pd.read_csv(args.input)

    # --------------------------------------------------------
    # Validate inputs
    # --------------------------------------------------------
    if not {"POLE_LAT", "POLE_LON"}.issubset(df.columns):
        raise RuntimeError(
            "CSV must include POLE_LAT and POLE_LON for Euler-space plotting"
        )
    missing = {"RUN_ID", "REGION", "SCALE_KM", "RANK"} - set(df.columns)
    if missing:
        raise RuntimeError(
            f"CSV is missing required column(s): {', '.join(sorted(missing))}"
        )

    # Baseline pole (looked up across all regions, before region filtering).
    baseline_rows = df[df["RUN_ID"] == args.baseline]
    if baseline_rows.empty:
        raise RuntimeError(
            f"Baseline run '{args.baseline}' not found in {args.input}"
        )
    b = baseline_rows.iloc[0]
    b_lat, b_lon = b["POLE_LAT"], b["POLE_LON"]

    # Filter region; copy so the new column below is written to an
    # independent frame, not a slice of the original.
    df = df[df["REGION"] == args.region].copy()
    if df.empty:
        raise RuntimeError(
            f"No valid Moran results available for region '{args.region}'. "
            "This usually means the region was skipped due to insufficient "
            "data or invalid neighbor geometry."
        )

    # Euler angular distance (degrees) from the baseline pole.
    df["EULER_DIST_DEG"] = df.apply(
        lambda r: angular_distance_deg(
            r["POLE_LAT"], r["POLE_LON"], b_lat, b_lon
        ),
        axis=1
    )

    euler_png = f"{args.output_prefix}_rank_vs_euler.png"
    scale_png = f"{args.output_prefix}_rank_vs_scale.png"
    polar_png = f"{args.output_prefix}_euler_polar.png"

    _plot_rank_vs_euler(df, args.baseline, euler_png)
    _plot_rank_vs_scale(df, args.baseline, scale_png)
    _plot_euler_polar(df, args.baseline, b_lat, b_lon, polar_png)

    print("Saved plots:")
    print(f"  {euler_png}")
    print(f"  {scale_png}")
    print(f"  {polar_png}")


if __name__ == "__main__":
    # Run the plotting CLI only when executed as a script, not when imported.
    main()
