import argparse
import gzip
import io
import json
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urljoin

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

try:
    from astropy.io import fits
except ImportError as exc:
    raise SystemExit(
        "astropy is required for PREMOS FITS parsing. Install with: "
        "python -m pip install astropy"
    ) from exc


# Filesystem layout: every path is resolved relative to the directory holding this script.
BASE_DIR = Path(__file__).parent
# Destination folder for the generated chart, CSV and manifest.
OUT_DIR = BASE_DIR / "tsi_satellite_sources"
OUT_PNG = OUT_DIR / "picard_premos_tsi_daily_monthly_annual.png"
OUT_CSV = OUT_DIR / "picard_premos_tsi_daily.csv"
OUT_MANIFEST = OUT_DIR / "picard_premos_tsi_manifest.txt"
# Local cache of the gzip-compressed FITS granules fetched from the archive.
RAW_DIR = BASE_DIR / "tsi_raw" / "picard_premos" / "raw"

# Root of the IDOC PICARD archive; each record's "dir_fits" value is joined onto it.
IDOC_BASE = "http://idoc-picard.ias.u-psud.fr"
# JSON listing endpoint; callers append "&start=...&limit=..." to page through it.
RECORDS_URL = f"{IDOC_BASE}/premosn2a/records?media=json"
# Reference instant that the FITS "Time" column (interpreted as days) is added to.
TIME_EPOCH = pd.Timestamp("2000-01-01T00:00:00Z")


@dataclass(frozen=True)
class PremosRecord:
    """One TSI granule as listed by the PREMOS N2A records API.

    Instances are immutable. The two properties derive where the granule is
    fetched from and where its cached copy is stored.
    """

    filename: str  # granule file name as reported by the API
    dir_fits: str  # server-side path of the FITS file, resolved against IDOC_BASE
    datetime_obs: str  # observation timestamp string from the API

    @property
    def local_path(self) -> Path:
        """Location of the cached copy inside ``RAW_DIR``."""
        return RAW_DIR.joinpath(self.filename)

    @property
    def remote_url(self) -> str:
        """Download URL obtained by resolving ``dir_fits`` against ``IDOC_BASE``."""
        return urljoin(base=IDOC_BASE, url=self.dir_fits)


def _get_json(url: str) -> dict:
    """Fetch *url* and return its response body decoded as JSON.

    The request uses a 90-second timeout. Network and JSON decoding errors
    are not caught here and propagate to the caller.
    """
    with urllib.request.urlopen(url, timeout=90) as response:
        body = response.read()
    return json.loads(body)


def fetch_tsi_records() -> list[PremosRecord]:
    """List every TSI granule advertised by the records API.

    A one-row probe request reads the ``total`` count, then the listing is
    paged through 500 rows at a time. Only rows whose file name contains
    ``_TSI_`` and ends with ``.gz`` are kept.

    Returns:
        The matching records ordered by their ``datetime_obs`` string, or an
        empty list when the API reports no rows.
    """
    page_size = 500

    probe = _get_json(f"{RECORDS_URL}&start=0&limit=1")
    expected_rows = int(probe.get("total", 0))
    if expected_rows <= 0:
        return []

    found: list[PremosRecord] = []
    for offset in range(0, expected_rows, page_size):
        page = _get_json(f"{RECORDS_URL}&start={offset}&limit={page_size}")
        for entry in page.get("data", []):
            name = entry.get("filename", "")
            if "_TSI_" in name and name.endswith(".gz"):
                found.append(
                    PremosRecord(
                        filename=name,
                        dir_fits=entry.get("dir_fits", ""),
                        datetime_obs=entry.get("datetimeobs", ""),
                    )
                )
    return sorted(found, key=lambda rec: rec.datetime_obs)


def ensure_downloads(records: list[PremosRecord], max_files: int | None = None) -> tuple[int, int]:
    """Download every record that is not already cached in ``RAW_DIR``.

    Args:
        records: Granules to make available locally, in processing order.
        max_files: When a positive integer, only the first ``max_files``
            records are considered. ``None``, zero or a negative value means
            no limit.

    Returns:
        ``(downloaded, skipped)``: how many files were fetched during this
        call and how many were already present as non-empty files.

    Raises:
        urllib.error.URLError: If a download fails (propagated from urlopen).
        OSError: If a file cannot be written or renamed.
    """
    RAW_DIR.mkdir(parents=True, exist_ok=True)

    to_process = records[:max_files] if max_files and max_files > 0 else records
    downloaded = 0
    skipped = 0
    for rec in to_process:
        target = rec.local_path
        if target.exists() and target.stat().st_size > 0:
            skipped += 1
            continue

        with urllib.request.urlopen(rec.remote_url, timeout=120) as response:
            payload = response.read()

        # Write to a sibling ".part" file and rename it into place. An
        # interrupted write must never leave a truncated file under the final
        # name, because any non-empty file there is treated as complete on the
        # next run. The ".part" suffix also keeps it out of the "*.gz" parse glob.
        partial = target.with_name(target.name + ".part")
        try:
            partial.write_bytes(payload)
            partial.replace(target)
        finally:
            # No-op after a successful rename; removes leftovers after a failure.
            partial.unlink(missing_ok=True)
        downloaded += 1
    return downloaded, skipped


def _pick_tsi_per_row(
    tsi_a: np.ndarray,
    q_a: np.ndarray,
    tsi_b: np.ndarray,
    q_b: np.ndarray,
) -> np.ndarray:
    """Merge channels A and B into one TSI value per row (NaN when neither is usable)."""
    # Plausibility window: only finite readings strictly inside it are ever used.
    a_in_range = np.isfinite(tsi_a) & (tsi_a > 1200) & (tsi_a < 1500)
    b_in_range = np.isfinite(tsi_b) & (tsi_b > 1200) & (tsi_b < 1500)

    # np.select takes the first matching condition per row, so the list order
    # is the priority order: clean A, clean B, then the flagged-but-in-range
    # fallback, again A before B. Keeping one explicit priority list prevents
    # the B fallback from overwriting an A fallback on the same row.
    return np.select(
        [a_in_range & (q_a == 0), b_in_range & (q_b == 0), a_in_range, b_in_range],
        [tsi_a, tsi_b, tsi_a, tsi_b],
        default=np.nan,
    ).astype(float)


def _extract_daily_from_fits(path: Path) -> pd.DataFrame:
    """Read one gzip-compressed FITS granule and reduce it to daily mean TSI.

    The binary table in HDU 1 must provide the columns ``Time``, ``TSI_A``,
    ``Q_A``, ``TSI_B`` and ``Q_B``. ``Time`` is interpreted as days since
    ``TIME_EPOCH``. Each row takes channel A when its flag is 0, otherwise
    channel B when its flag is 0. Rows with no clean channel fall back to an
    in-range reading regardless of its flag, again preferring A over B.

    Args:
        path: Location of the ``.gz`` file to parse.

    Returns:
        A frame with the columns ``date`` (timezone-naive day), ``tsi`` (mean of
        the selected values for that day) and ``source_file`` (``path.name``).
        The frame is empty, with the same columns, when the table or a
        required column is missing or no row yields a usable value.

    Raises:
        OSError: If the file cannot be read or is not valid gzip/FITS data.
    """
    empty = pd.DataFrame(columns=["date", "tsi", "source_file"])

    raw_fits = gzip.decompress(path.read_bytes())
    with fits.open(io.BytesIO(raw_fits), memmap=False) as hdul:
        # A file without a table extension is treated like one with missing columns.
        table = hdul[1].data if len(hdul) > 1 else None
        names = set(getattr(table, "names", None) or [])
        required = {"Time", "TSI_A", "Q_A", "TSI_B", "Q_B"}
        if not required.issubset(names):
            return empty

        # Copy the columns into plain arrays while the file is still open.
        time_days = np.asarray(table["Time"], dtype=float)
        tsi_a = np.asarray(table["TSI_A"], dtype=float)
        q_a = np.asarray(table["Q_A"], dtype=int)
        tsi_b = np.asarray(table["TSI_B"], dtype=float)
        q_b = np.asarray(table["Q_B"], dtype=int)

    tsi = _pick_tsi_per_row(tsi_a, q_a, tsi_b, q_b)

    valid = np.isfinite(tsi) & np.isfinite(time_days)
    if not np.any(valid):
        return empty

    timestamps = TIME_EPOCH + pd.to_timedelta(time_days[valid], unit="D")
    frame = pd.DataFrame(
        {
            # TIME_EPOCH is UTC-aware; drop the timezone so dates are naive.
            "timestamp": timestamps.tz_convert(None),
            "tsi": tsi[valid],
            "source_file": path.name,
        }
    )
    frame["date"] = frame["timestamp"].dt.floor("D")
    return frame.groupby("date", as_index=False).agg(tsi=("tsi", "mean"), source_file=("source_file", "first"))


def parse_downloaded_daily() -> pd.DataFrame:
    """Parse every cached TSI granule in ``RAW_DIR`` into one daily table.

    Files matching ``PIC_PRE_N2A_TSI_*.fits.*.gz`` are parsed in sorted name
    order. When several files contribute a row for the same date, the row
    from the file that sorts last is kept.

    Returns:
        A frame with the columns ``date``, ``tsi`` and ``source_file``, ordered
        by date with one row per date. It is empty when no file yields data.
    """
    frames: list[pd.DataFrame] = []
    for path in sorted(RAW_DIR.glob("PIC_PRE_N2A_TSI_*.fits.*.gz")):
        parsed = _extract_daily_from_fits(path)
        if not parsed.empty:
            frames.append(parsed)

    if not frames:
        return pd.DataFrame(columns=["date", "tsi", "source_file"])

    merged = pd.concat(frames, ignore_index=True)
    # The sort must be stable: the default quicksort may reorder rows that
    # share a date, which would make keep="last" pick an arbitrary file.
    merged = merged.sort_values("date", kind="mergesort").drop_duplicates(subset=["date"], keep="last")
    return merged


def build_chart(frame: pd.DataFrame) -> None:
    """Draw the daily, monthly-mean and annual-mean TSI chart and save it to ``OUT_PNG``.

    Args:
        frame: Daily table with a datetime ``date`` column and a numeric
            ``tsi`` column. It must contain at least one row.
    """
    tsi_by_date = frame.set_index("date")["tsi"]
    monthly = tsi_by_date.resample("MS").mean().reset_index(name="tsi")
    annual = tsi_by_date.resample("YS").mean().reset_index(name="tsi")
    # Annual means are stamped 1 January; shift them 181 days so they plot near mid-year.
    annual["plot_date"] = annual["date"] + pd.offsets.Day(181)

    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(figsize=(12.5, 6.8), dpi=170)

    # Draw order (daily, monthly, annual) also fixes the legend order.
    daily_style = {
        "s": 12,
        "color": "#22c55e",
        "alpha": 0.45,
        "edgecolors": "none",
        "label": "Daily values (PREMOS N2A)",
    }
    ax.scatter(frame["date"], frame["tsi"], **daily_style)
    ax.plot(monthly["date"], monthly["tsi"], color="#16a34a", linewidth=1.8, label="Monthly means")
    annual_style = {
        "color": "#14532d",
        "linewidth": 2.4,
        "marker": "o",
        "markersize": 3.4,
        "label": "Annual means",
    }
    ax.plot(annual["plot_date"], annual["tsi"], **annual_style)

    ax.set_title("PICARD/PREMOS Total Solar Irradiance (N2A FITS Daily Product)", fontsize=16, pad=14)
    ax.set_ylabel("Total Solar Irradiance (W m^-2)")
    ax.set_xlabel("Year")

    first_day = frame["date"].min()
    last_day = frame["date"].max()
    span_years = (last_day - first_day).days / 365.25
    x_axis = ax.xaxis
    if span_years <= 5.0:
        # Short records: label every second month and tick every month.
        x_axis.set_major_locator(mdates.MonthLocator(interval=2))
        x_axis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
        x_axis.set_minor_locator(mdates.MonthLocator(interval=1))
        plt.setp(ax.get_xticklabels(), rotation=60, ha="right", fontsize=8.5)
    else:
        # Long records: one label per year.
        x_axis.set_major_locator(mdates.YearLocator(1))
        x_axis.set_major_formatter(mdates.DateFormatter("%Y"))
    ax.set_xlim(first_day, last_day)
    ax.legend(loc="upper right", frameon=True)

    note_text = (
        "Files discovered from idoc-picard PREMOS N2A records API\n"
        "Per-row TSI chooses channel A (Q_A=0) then B (Q_B=0), with finite-value fallback"
    )
    note_box = {"boxstyle": "round,pad=0.35", "facecolor": "white", "edgecolor": "#cbd5e1", "alpha": 0.96}
    # Place the note in axes coordinates so it stays in the lower-left corner.
    ax.text(0.015, 0.03, note_text, transform=ax.transAxes, fontsize=9.5, bbox=note_box)

    fig.tight_layout()
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PNG, bbox_inches="tight")
    plt.close(fig)


def write_outputs(frame: pd.DataFrame, records: list[PremosRecord], downloaded: int, skipped: int) -> None:
    """Write the daily CSV (``OUT_CSV``) and the text manifest (``OUT_MANIFEST``).

    Args:
        frame: Daily table with a datetime ``date`` column and a ``tsi`` column.
            It is not modified; the CSV is produced from a copy.
        records: Records discovered from the API in this run (only counted).
        downloaded: Number of files fetched in this run.
        skipped: Number of files that were already present in this run.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Format dates on a copy so the caller's frame keeps its datetime column.
    csv_frame = frame.copy()
    csv_frame["date"] = csv_frame["date"].dt.strftime("%Y-%m-%d")
    csv_frame.to_csv(OUT_CSV, index=False)

    raw_file_count = len(list(RAW_DIR.glob("PIC_PRE_N2A_TSI_*.fits.*.gz")))
    first_day = frame["date"].min().date()
    last_day = frame["date"].max().date()
    tsi_low = frame["tsi"].min()
    tsi_high = frame["tsi"].max()

    description = [
        "PICARD PREMOS TSI manifest",
        "",
        "Dataset title: PREMOS N2A TSI FITS granules",
        "Platform: PICARD",
        "Instrument: PREMOS",
        "Collection endpoint: http://idoc-picard.ias.u-psud.fr/premosn2a/records?media=json",
        "Download root: http://idoc-picard.ias.u-psud.fr/sitools/datastorage/user/storagefits/products/premos/n2/tsi/",
        "",
        "Method note:",
        "- Parse Time, TSI_A/Q_A, TSI_B/Q_B from each FITS binary table.",
        "- Convert Time using epoch 2000-01-01 UTC and aggregate to daily means.",
        "- Prefer channel A where Q_A=0, otherwise channel B where Q_B=0, with finite-value fallback.",
        "",
    ]
    run_summary = [
        f"Records discovered (TSI): {len(records)}",
        f"Files downloaded this run: {downloaded}",
        f"Files already present this run: {skipped}",
        f"Raw files parsed: {raw_file_count}",
        f"Rows plotted: {len(frame)}",
        f"Date span plotted: {first_day} to {last_day}",
        f"TSI range plotted: {tsi_low:.4f} to {tsi_high:.4f} W m^-2",
    ]
    OUT_MANIFEST.write_text("\n".join(description + run_summary) + "\n")


def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``max_files`` (``int`` or ``None``) and
        ``skip_download`` (``bool``).
    """
    cli = argparse.ArgumentParser(description="Build PICARD PREMOS N2A TSI daily/monthly/annual outputs")
    option_specs = (
        (
            "--max-files",
            {
                "type": int,
                "default": None,
                "help": "Optionally limit how many discovered records are downloaded/processed for a partial run",
            },
        ),
        (
            "--skip-download",
            {
                "action": "store_true",
                "help": "Parse local raw files only (do not query/download records)",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def main() -> None:
    """Run the pipeline: discover, download, parse, chart, write outputs, print a summary.

    With ``--skip-download`` the API is not contacted and only files already
    in the raw directory are parsed.

    Raises:
        RuntimeError: If the API lists no TSI records, or if no daily rows
            could be parsed from the cached files.
    """
    args = parse_args()

    if args.skip_download:
        records, downloaded, skipped = [], 0, 0
    else:
        records = fetch_tsi_records()
        if not records:
            raise RuntimeError("No PREMOS N2A TSI records discovered from API")
        downloaded, skipped = ensure_downloads(records, max_files=args.max_files)

    frame = parse_downloaded_daily()
    if frame.empty:
        raise RuntimeError("No PREMOS daily rows parsed from FITS files")

    build_chart(frame)
    write_outputs(frame, records, downloaded, skipped)

    # Recompute the resampled lengths only to report them in the summary.
    tsi_by_date = frame.set_index("date")["tsi"]
    monthly_rows = len(tsi_by_date.resample("MS").mean())
    annual_rows = len(tsi_by_date.resample("YS").mean())

    for saved_path in (OUT_PNG, OUT_CSV, OUT_MANIFEST):
        print("saved", saved_path)
    print("daily_rows", len(frame), "monthly_rows", monthly_rows, "annual_rows", annual_rows)
    print("date_span", frame["date"].min().date(), frame["date"].max().date())
    print("tsi_range", round(frame["tsi"].min(), 4), round(frame["tsi"].max(), 4))


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
