#!/usr/bin/env python3
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from math import radians, sin, cos, acos


# ============================================================
# Geometry utilities
# ============================================================

def angular_distance_deg(lat1, lon1, lat2, lon2):
    """
    Return the great-circle angular distance between two points, in degrees.

    All inputs are latitudes/longitudes in degrees. Uses the spherical law
    of cosines; the cosine is clipped to [-1, 1] so floating-point round-off
    cannot push ``acos`` outside its domain.

    If any input is NaN the result is NaN. (Clamping with Python's
    ``max(-1, min(1, x))`` would silently turn NaN into 1.0, i.e. a distance
    of 0, making a missing pole indistinguishable from the reference pole.)
    """
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dot = (
        sin(lat1) * sin(lat2)
        + cos(lat1) * cos(lat2) * cos(lon1 - lon2)
    )
    # np.clip propagates NaN, unlike the builtin min()/max().
    dot = float(np.clip(dot, -1.0, 1.0))
    return np.degrees(acos(dot))


def euler_polar_coords(lat, lon, lat0, lon0):
    """
    Map an Euler pole (lat, lon) to polar coordinates centred on a baseline
    pole (lat0, lon0). All inputs are in degrees.

    Returns ``(theta, r)`` where

    * ``theta`` is the initial great-circle bearing (radians, measured from
      north toward east) from the baseline pole to ``(lat, lon)``;
    * ``r`` is the great-circle angular distance between the two poles,
      in degrees.

    Together these form an azimuthal-equidistant projection about the
    baseline, so the baseline itself maps to ``(0.0, 0.0)``. (Using the
    pole's absolute longitude as ``theta`` would not be relative to the
    baseline at all.)
    """
    r = angular_distance_deg(lat, lon, lat0, lon0)

    phi0 = np.radians(lat0)
    phi = np.radians(lat)
    dlon = np.radians(lon - lon0)
    # Standard initial-bearing formula on the sphere.
    theta = np.arctan2(
        np.sin(dlon) * np.cos(phi),
        np.cos(phi0) * np.sin(phi) - np.sin(phi0) * np.cos(phi) * np.cos(dlon),
    )
    return theta, r


# ============================================================
# Shared plotting routine
# ============================================================

def make_plots(df, baseline_id, b_lat, b_lon, output_prefix):
    """
    Render the three rank diagnostics for one results table and save them
    as PNG files.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns read: ``SCALE_KM``, ``RUN_ID``, ``RANK``, ``EULER_DIST_DEG``,
        ``POLE_LAT`` and ``POLE_LON``.
    baseline_id : str
        ``RUN_ID`` of the baseline run, which is drawn highlighted.
    b_lat, b_lon : float
        Baseline Euler pole in degrees; centre of the polar plot.
    output_prefix : str
        Path prefix. Writes ``{output_prefix}_rank_vs_euler.png``,
        ``{output_prefix}_rank_vs_scale.png`` and
        ``{output_prefix}_euler_polar.png``.

    Raises
    ------
    ValueError
        If ``df`` is empty (there is nothing to plot).
    """
    if df.empty:
        raise ValueError("make_plots() requires a non-empty DataFrame")

    euler_png = f"{output_prefix}_rank_vs_euler.png"
    scale_png = f"{output_prefix}_rank_vs_scale.png"
    polar_png = f"{output_prefix}_euler_polar.png"

    _plot_rank_vs_euler(df, baseline_id, euler_png)
    _plot_rank_vs_scale(df, baseline_id, scale_png)
    _plot_euler_polar(df, baseline_id, b_lat, b_lon, polar_png)

    print("Saved plots:")
    print(f"  {euler_png}")
    print(f"  {scale_png}")
    print(f"  {polar_png}")


def _plot_rank_vs_euler(df, baseline_id, path):
    """Save rank vs. Euler distance, faceted into one panel per SCALE_KM."""
    scales = sorted(df["SCALE_KM"].unique())
    # squeeze=False always returns a 2-D array of Axes, so a single scale
    # needs no special-casing.
    fig, axes = plt.subplots(
        1, len(scales),
        figsize=(6 * len(scales), 5),
        sharey=True,
        squeeze=False,
    )
    axes = axes[0]
    is_base = df["RUN_ID"] == baseline_id

    for ax, L in zip(axes, scales):
        at_scale = df["SCALE_KM"] == L
        rand = df[at_scale & ~is_base]
        base = df[at_scale & is_base]

        ax.scatter(
            rand["EULER_DIST_DEG"], rand["RANK"],
            s=12, alpha=0.6, label="Random"
        )
        if not base.empty:
            ax.scatter(
                base["EULER_DIST_DEG"], base["RANK"],
                s=50, c="red", label="Baseline"
            )

        ax.set_title(f"Scale = {int(L)} km")
        ax.set_xlabel("Euler distance to baseline (deg)")
        ax.grid(True, alpha=0.3)

    # The panels share one y-axis, so inverting any panel inverts them all.
    # Calling invert_yaxis() once per panel toggles the orientation and
    # leaves it un-inverted whenever the panel count is even; invert exactly
    # once so the smallest rank is drawn at the top.
    axes[0].invert_yaxis()
    axes[0].set_ylabel("Rank")

    # Merge legend entries from all panels (the baseline may be absent at
    # some scales) into a single legend on the first panel.
    legend_items = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_items.setdefault(label, handle)
    axes[0].legend(list(legend_items.values()), list(legend_items.keys()))

    fig.tight_layout()
    fig.savefig(path, dpi=200)
    plt.close(fig)


def _plot_rank_vs_scale(df, baseline_id, path):
    """Save one rank-vs-wavelength line per run, highlighting the baseline."""
    fig, ax = plt.subplots(figsize=(7, 5))

    for run_id, g in df.groupby("RUN_ID"):
        # Rows are not guaranteed to be ordered by scale; sort so each line
        # runs left-to-right instead of zig-zagging in file order.
        g = g.sort_values("SCALE_KM")
        if run_id == baseline_id:
            # zorder keeps the baseline above gray runs drawn after it.
            ax.plot(
                g["SCALE_KM"], g["RANK"],
                linewidth=2.5, label="Baseline", zorder=3
            )
        else:
            ax.plot(
                g["SCALE_KM"], g["RANK"],
                color="gray", alpha=0.3
            )

    ax.set_xlabel("Wavelength (km)")
    ax.set_ylabel("Rank")
    ax.set_xscale("log")
    ax.invert_yaxis()
    ax.grid(True, alpha=0.3)
    # Only the baseline is labelled; skip the legend when it is absent to
    # avoid matplotlib's "no artists with labels" warning.
    if ax.get_legend_handles_labels()[1]:
        ax.legend()

    fig.tight_layout()
    fig.savefig(path, dpi=200)
    plt.close(fig)


def _plot_euler_polar(df, baseline_id, b_lat, b_lon, path):
    """Save a polar scatter of each run's Euler pole about the baseline pole."""
    fig, ax = plt.subplots(figsize=(7, 7), subplot_kw={"projection": "polar"})
    # Compass orientation: theta = 0 at the top, increasing clockwise.
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)

    # Plot each distinct (run, pole) pair once so a pole repeated across
    # several SCALE_KM rows is not overdrawn.
    poles = df[["RUN_ID", "POLE_LAT", "POLE_LON"]].drop_duplicates()
    theta = np.empty(len(poles))
    r = np.empty(len(poles))
    for i, (lat, lon) in enumerate(zip(poles["POLE_LAT"], poles["POLE_LON"])):
        theta[i], r[i] = euler_polar_coords(lat, lon, b_lat, b_lon)
    is_base = (poles["RUN_ID"] == baseline_id).to_numpy()

    # Scatter rather than a line plot: a run with a single pole is a single
    # point, which a marker-less ax.plot() line does not draw at all.
    ax.scatter(theta[~is_base], r[~is_base], s=12, alpha=0.6, label="Random")
    if is_base.any():
        ax.scatter(
            theta[is_base], r[is_base],
            s=50, c="red", label="Baseline", zorder=3
        )

    ax.set_title("Euler-space scatter (baseline at origin)")
    ax.set_rlabel_position(225)
    ax.grid(True)
    ax.legend(loc="upper right")

    fig.tight_layout()
    fig.savefig(path, dpi=200)
    plt.close(fig)


# ============================================================
# Main
# ============================================================

def _circular_mean_deg(angles_deg):
    """
    Return the circular mean of a pandas Series of angles in degrees.

    The result lies in [-180, 180]. NaNs are skipped, matching pandas'
    default ``mean``. An arithmetic mean is wrong across the +/-180 seam
    (e.g. 179 and -179 would average to 0 instead of 180).
    """
    rad = np.radians(angles_deg.astype(float))
    return float(np.degrees(np.arctan2(np.sin(rad).mean(), np.cos(rad).mean())))


def main():
    """
    Command-line entry point: plot Moran-rank diagnostics from a CSV.

    Reads the rankings CSV given by ``--input``, computes each row's Euler
    distance to the baseline run's pole, and calls ``make_plots`` for the
    rows of ``--region``. With ``--all-regions-mean`` it also plots each
    (run, scale) averaged over all regions, using the prefix
    ``{output-prefix}_ALLMEAN``.

    Raises
    ------
    RuntimeError
        If required columns are missing, the baseline run is not present,
        or the selected region has no rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Rankings CSV")
    ap.add_argument("--baseline", default="TPW_BASELINE")
    ap.add_argument("--region", default="Arctic")
    ap.add_argument("--output-prefix", default="morans_rank")
    ap.add_argument(
        "--all-regions-mean",
        action="store_true",
        help="Also generate plots for the mean across all regions combined"
    )
    args = ap.parse_args()

    df = pd.read_csv(args.input)

    # --------------------------------------------------------
    # Column validation
    # --------------------------------------------------------
    if not {"POLE_LAT", "POLE_LON"}.issubset(df.columns):
        raise RuntimeError(
            "CSV must include POLE_LAT and POLE_LON for Euler-space plotting"
        )
    missing = {"RUN_ID", "REGION", "SCALE_KM", "RANK"} - set(df.columns)
    if missing:
        raise RuntimeError(
            f"CSV is missing required column(s): {', '.join(sorted(missing))}"
        )

    # --------------------------------------------------------
    # Baseline pole (taken from the first baseline row)
    # --------------------------------------------------------
    baseline_rows = df[df["RUN_ID"] == args.baseline]
    if baseline_rows.empty:
        # Without this check .iloc[0] fails with an opaque IndexError.
        raise RuntimeError(
            f"Baseline run '{args.baseline}' not found in '{args.input}'."
        )
    b = baseline_rows.iloc[0]
    b_lat, b_lon = b["POLE_LAT"], b["POLE_LON"]

    # Euler angular distance (degrees) of every row's pole from the baseline.
    df["EULER_DIST_DEG"] = [
        angular_distance_deg(lat, lon, b_lat, b_lon)
        for lat, lon in zip(df["POLE_LAT"], df["POLE_LON"])
    ]

    # ========================================================
    # Region-specific plots
    # ========================================================
    df_region = df[df["REGION"] == args.region]
    if df_region.empty:
        raise RuntimeError(
            f"No valid Moran results available for region '{args.region}'."
        )

    make_plots(
        df_region,
        baseline_id=args.baseline,
        b_lat=b_lat,
        b_lon=b_lon,
        output_prefix=args.output_prefix
    )

    # ========================================================
    # Optional: averaged results across ALL regions
    # ========================================================
    if args.all_regions_mean:
        df_mean = (
            df.groupby(["RUN_ID", "SCALE_KM"], as_index=False)
            .agg({
                "RANK": "mean",
                "EULER_DIST_DEG": "mean",
                "POLE_LAT": "mean",
                # Longitudes wrap at +/-180, so average them on the circle.
                "POLE_LON": _circular_mean_deg,
            })
        )
        df_mean["REGION"] = "ALL_MEAN"

        make_plots(
            df_mean,
            baseline_id=args.baseline,
            b_lat=b_lat,
            b_lon=b_lon,
            output_prefix=f"{args.output_prefix}_ALLMEAN"
        )


# Run the command-line interface only when executed as a script, not when
# this module is imported.
if __name__ == "__main__":
    main()
