from __future__ import annotations

import math
import re
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd


# Filesystem layout: every path is resolved relative to this script's directory.
BASE_DIR = Path(__file__).parent
# Directory scanned by gather_input_files() for raw mission files.
INPUT_DIR = BASE_DIR / "pre_sorce_individual_raw"
# Directory that receives the chart, the combined CSV and the manifest.
OUTPUT_DIR = BASE_DIR / "tsi_satellite_sources"
OUTPUT_PNG = OUTPUT_DIR / "pre_sorce_individual_missions_daily_monthly_annual.png"
OUTPUT_CSV = OUTPUT_DIR / "pre_sorce_individual_missions_daily.csv"
OUTPUT_MANIFEST = OUTPUT_DIR / "pre_sorce_individual_missions_manifest.txt"

# Plausibility window used by parse_line(): only numbers inside
# [TSI_MIN, TSI_MAX] are accepted as TSI readings (the chart labels the
# quantity in W m^-2).
TSI_MIN = 1200.0
TSI_MAX = 1500.0

# Whole-token date layouts recognised by parse_date_token():
# YYYY-MM-DD, YYYY/MM/DD and YYYYMMDD.
DATE_PATTERNS = [
    re.compile(r"^(\d{4})-(\d{2})-(\d{2})$"),
    re.compile(r"^(\d{4})/(\d{2})/(\d{2})$"),
    re.compile(r"^(\d{8})$"),
]


@dataclass
class MissionSeries:
    """One parsed mission input file together with its display label."""
    # Human-readable mission name; load_series() fills it via infer_label().
    label: str
    # Base name (no directory) of the input file the rows were read from.
    file_name: str
    # Parsed rows with "date" and "tsi" columns; load_series() leaves one
    # row per date, ordered by date.
    frame: pd.DataFrame


def decimal_year_to_timestamp(value: float) -> pd.Timestamp | None:
    """Convert a decimal year (e.g. ``1979.5``) to a calendar day.

    The integer part selects the year and the fractional part is scaled by
    the length of that year (365 or 366 days) and rounded to a whole day
    offset from 1 January.  The offset is clamped so the result never
    leaves the selected year.

    Returns:
        The corresponding ``pd.Timestamp``, or ``None`` when the year part
        lies outside 1800-2100.
    """
    whole_year = int(value)
    if not 1800 <= whole_year <= 2100:
        return None

    jan_first = pd.Timestamp(year=whole_year, month=1, day=1)
    year_length = 366 if jan_first.is_leap_year else 365

    offset = int(round((value - whole_year) * year_length))
    # Clamp into [0, year_length - 1] so a fraction near 1.0 stays on 31 Dec.
    if offset < 0:
        offset = 0
    elif offset > year_length - 1:
        offset = year_length - 1

    return jan_first + timedelta(days=offset)


def parse_date_token(token: str) -> pd.Timestamp | None:
    """Interpret a single whitespace-free token as a calendar date.

    The token is matched against each entry of ``DATE_PATTERNS`` in turn.
    A pattern with three capture groups supplies year, month and day
    directly; a pattern with one capture group supplies eight digits that
    are sliced as ``YYYYMMDD``.

    Returns:
        The parsed ``pd.Timestamp``; ``None`` when the token is blank,
        matches no pattern, has a year outside 1800-2100, or names an
        impossible calendar date.
    """
    text = token.strip()
    if not text:
        return None

    for regex in DATE_PATTERNS:
        found = regex.match(text)
        if found is None:
            continue

        captured = found.groups()
        if len(captured) == 3:
            year, month, day = (int(piece) for piece in captured)
        elif len(captured) == 1:
            digits = captured[0]
            year, month, day = int(digits[:4]), int(digits[4:6]), int(digits[6:8])
        else:
            continue

        # An implausible year just means this pattern does not apply.
        if not 1800 <= year <= 2100:
            continue

        try:
            return pd.Timestamp(year=year, month=month, day=day)
        except ValueError:
            # Shape matched but the date does not exist (e.g. month 13).
            return None

    return None


def parse_float_token(token: str) -> float | None:
    """Parse one token as a finite float, tolerating Fortran exponents.

    Fortran double-precision output writes the exponent marker as ``D`` or
    ``d`` (``1.366D+03`` / ``1.366d+03``); both spellings are rewritten to
    the ``E``/``e`` form that ``float()`` understands.

    Returns:
        The parsed value, or ``None`` when the token is blank, is not a
        number, or is ``nan``/``inf``.
    """
    text = token.strip()
    if not text:
        return None

    # Handle both cases of the Fortran exponent marker; a token that is not
    # numeric to begin with still fails in float() below.
    text = text.replace("D", "E").replace("d", "e")
    try:
        value = float(text)
    except ValueError:
        return None

    return value if math.isfinite(value) else None


def normalize_year(value: float) -> int | None:
    """Map a numeric value to a four-digit year, if it looks like one.

    The value is rounded to the nearest integer first.  Two-digit years
    are expanded with a pivot at 70: 70-99 become 1970-1999 and 0-69
    become 2000-2069.  Four-digit years are accepted only within
    1800-2100.

    Returns:
        The four-digit year, or ``None`` when the value is not a plausible
        year.
    """
    candidate = int(round(value))
    if candidate < 0:
        return None
    if candidate <= 99:
        century = 1900 if candidate >= 70 else 2000
        return century + candidate
    return candidate if 1800 <= candidate <= 2100 else None


def parse_line(line: str) -> tuple[pd.Timestamp, float] | None:
    """Extract a ``(date, tsi)`` pair from one line of a raw mission file.

    The line is split on whitespace, commas and semicolons.  A TSI reading
    is any numeric token inside ``[TSI_MIN, TSI_MAX]``.  The date is
    recovered by the first strategy that succeeds:

    1. an explicit date token recognised by ``parse_date_token`` (paired
       with the *last* in-range TSI token on the line);
    2. an integer year column plus a day-of-year column (paired with the
       first in-range TSI value);
    3. a decimal-year column (paired with the first in-range TSI value).

    Args:
        line: One raw text line, with or without its trailing newline.

    Returns:
        ``(date, tsi)`` on success; ``None`` for blank lines, comment lines
        (starting with ``;``, ``#``, ``%``, ``//`` or ``!``), lines with
        fewer than two tokens, and lines from which no date or no plausible
        TSI value can be recovered.
    """
    raw = line.strip()
    if not raw:
        return None
    if raw.startswith((";", "#", "%", "//", "!")):
        return None

    parts = [p for p in re.split(r"[\s,;\t]+", raw) if p]
    if len(parts) < 2:
        return None

    numbers: list[float] = []
    for token in parts:
        value = parse_float_token(token)
        if value is not None:
            numbers.append(value)

    # Every strategy needs a plausible TSI value, so compute the candidates
    # once and bail out early on header or text lines that have none.
    tsi_candidates = [value for value in numbers if TSI_MIN <= value <= TSI_MAX]
    if not tsi_candidates:
        return None

    # Strategy 1: explicit date token + a plausible TSI token.
    for token in parts:
        date = parse_date_token(token)
        if date is not None:
            return date, float(tsi_candidates[-1])

    if len(numbers) < 2:
        return None

    # Strategy 2: year + doy + tsi (common in mission files).
    for i, year_value in enumerate(numbers):
        # A calendar-year column is integral.  Without this check a decimal
        # year such as 1978.875 would round to 1979 and pair with any small
        # number on the line as a bogus day of year; such lines belong to
        # Strategy 3.
        if not year_value.is_integer():
            continue
        year = normalize_year(year_value)
        if year is None:
            continue
        year_start = pd.Timestamp(year=year, month=1, day=1)
        for j, doy_value in enumerate(numbers):
            if i == j:
                continue
            doy = int(round(doy_value))
            if not (1 <= doy <= 366):
                continue
            date = year_start + timedelta(days=doy - 1)
            # Day 366 of a non-leap year would roll over into 1 January of
            # the following year; reject it instead of mis-dating the row.
            if date.year != year:
                continue
            for k, tsi_value in enumerate(numbers):
                if k in (i, j):
                    continue
                if TSI_MIN <= tsi_value <= TSI_MAX:
                    return date, float(tsi_value)

    # Strategy 3: decimal year + tsi.
    for i, year_value in enumerate(numbers):
        if not (1800.0 <= year_value <= 2100.999):
            continue
        date = decimal_year_to_timestamp(year_value)
        if date is None:
            continue
        for j, tsi_value in enumerate(numbers):
            if i == j:
                continue
            if TSI_MIN <= tsi_value <= TSI_MAX:
                return date, float(tsi_value)

    return None


def infer_label(file_name: str) -> str:
    """Derive a mission display label from an input file name.

    The lower-cased file name is checked against an ordered list of
    keywords; the first keyword found anywhere in the name decides the
    label.  When nothing matches, the file stem with underscores replaced
    by spaces is used.
    """
    lowered = file_name.lower()

    # Order matters: earlier rows win when several keywords are present.
    keyword_table = (
        (("nimbus",), "NIMBUS-7 ERB Ch10C"),
        (("acrim1", "acrim-1"), "ACRIM-1"),
        (("acrim2", "acrim-2"), "ACRIM-2"),
        (("erbs", "erbe"), "ERBS/ERBE"),
        (("virgo",), "SOHO/VIRGO"),
        (("hf",), "HF-style mission series"),
    )
    for keywords, label in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            return label

    return Path(file_name).stem.replace("_", " ")


def load_series(path: Path) -> MissionSeries:
    """Parse one raw mission file into a ``MissionSeries``.

    Every line is passed through ``parse_line``; lines that yield nothing
    are skipped.  When a date occurs more than once, the occurrence that
    appears last in the file wins.  The resulting frame has ``date`` and
    ``tsi`` columns, ordered by date.

    Args:
        path: Location of the raw input file.

    Returns:
        The parsed series, labelled via ``infer_label(path.name)``.

    Raises:
        ValueError: If no line of the file could be parsed.
    """
    # Pin the encoding so the result does not depend on the platform's
    # locale; bytes that are not valid UTF-8 are dropped.
    text = path.read_text(encoding="utf-8", errors="ignore")

    rows: list[tuple[pd.Timestamp, float]] = []
    for line in text.splitlines():
        parsed = parse_line(line)
        if parsed is not None:
            rows.append(parsed)

    frame = pd.DataFrame(rows, columns=["date", "tsi"]).dropna()
    if frame.empty:
        raise ValueError(f"No parseable rows in {path.name}")

    # De-duplicate while the rows are still in file order, then sort.
    # sort_values() defaults to a non-stable algorithm, so sorting first
    # could shuffle rows sharing a date and make keep="last" pick an
    # arbitrary one.
    frame = frame.drop_duplicates(subset=["date"], keep="last").sort_values("date")
    return MissionSeries(label=infer_label(path.name), file_name=path.name, frame=frame)


def gather_input_files() -> list[Path]:
    """List the raw mission files found directly inside ``INPUT_DIR``.

    Entries are visited in sorted order.  Hidden entries (leading dot),
    entries with a ``.source`` or ``.md`` suffix, ``*.source.txt``
    entries and anything that is not a regular file are left out.

    Returns:
        The accepted paths, or an empty list when ``INPUT_DIR`` does not
        exist.
    """
    if not INPUT_DIR.exists():
        return []

    def is_mission_file(entry: Path) -> bool:
        name = entry.name
        if name.startswith("."):
            return False
        if entry.suffix.lower() in {".source", ".md"}:
            return False
        if name.endswith(".source.txt"):
            return False
        return entry.is_file()

    return [entry for entry in sorted(INPUT_DIR.iterdir()) if is_mission_file(entry)]


def build_chart(series_list: list[MissionSeries]) -> None:
    """Render all mission series into one figure and save it to ``OUTPUT_PNG``.

    For every series three layers are drawn: the individual readings as a
    faint scatter, the calendar-month means as a thin line, and the
    calendar-year means as a thicker line with markers.  The x-axis is
    limited to the overall date span of the plotted readings.

    Args:
        series_list: Parsed mission series to plot; each frame needs
            ``date`` and ``tsi`` columns.
    """
    # One (daily, monthly, annual) colour triple per series, reused
    # cyclically when there are more series than triples.
    colors = [
        ("#93c5fd", "#3b82f6", "#1d4ed8"),
        ("#fca5a5", "#ef4444", "#991b1b"),
        ("#a7f3d0", "#10b981", "#047857"),
        ("#fde68a", "#f59e0b", "#b45309"),
        ("#c4b5fd", "#8b5cf6", "#5b21b6"),
        ("#f9a8d4", "#ec4899", "#9d174d"),
    ]

    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(figsize=(13.0, 7.2), dpi=170)

    min_date: pd.Timestamp | None = None
    max_date: pd.Timestamp | None = None

    for idx, series in enumerate(series_list):
        c_daily, c_monthly, c_annual = colors[idx % len(colors)]
        frame = series.frame

        monthly = frame.set_index("date").resample("MS").mean().reset_index()
        annual = frame.set_index("date").resample("YS").mean().reset_index()
        # Annual bins are stamped 1 January; shift the marker to mid-year.
        annual["plot_date"] = annual["date"] + pd.offsets.Day(181)

        ax.scatter(
            frame["date"],
            frame["tsi"],
            s=7,
            color=c_daily,
            alpha=0.22,
            edgecolors="none",
            label=f"{series.label} daily",
        )
        ax.plot(monthly["date"], monthly["tsi"], color=c_monthly, linewidth=1.3, label=f"{series.label} monthly")
        ax.plot(
            annual["plot_date"],
            annual["tsi"],
            color=c_annual,
            linewidth=2.0,
            marker="o",
            markersize=3.2,
            label=f"{series.label} annual",
        )

        series_min = frame["date"].min()
        series_max = frame["date"].max()
        min_date = series_min if min_date is None else min(min_date, series_min)
        max_date = series_max if max_date is None else max(max_date, series_max)

    # Name the mission in the title only when the figure shows exactly one;
    # a fixed "NIMBUS-7" title would mislabel any other or any mixed set.
    distinct_labels = list(dict.fromkeys(series.label for series in series_list))
    if len(distinct_labels) == 1:
        title_subject = distinct_labels[0]
    else:
        title_subject = "Pre-SORCE Individual Missions"
    ax.set_title(f"{title_subject} TSI (Raw Inputs)", fontsize=16, pad=14)
    ax.set_ylabel("Total Solar Irradiance (W m^-2)")
    ax.set_xlabel("Year")

    if min_date is not None and max_date is not None:
        ax.set_xlim(mdates.date2num(min_date), mdates.date2num(max_date))

    ax.legend(loc="upper left", ncol=2, frameon=True, fontsize=9)
    ax.text(
        0.015,
        0.02,
        "Separate mission streams only (no NOAA CDR blend in this figure)",
        transform=ax.transAxes,
        fontsize=9.5,
        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="#cbd5e1", alpha=0.96),
    )

    fig.tight_layout()
    fig.savefig(OUTPUT_PNG, bbox_inches="tight")
    # Close explicitly so pyplot does not keep the figure alive.
    plt.close(fig)


def write_outputs(series_list: list[MissionSeries]) -> None:
    """Write the combined CSV and the plain-text manifest.

    The CSV (``OUTPUT_CSV``) stacks all series with ``date``, ``tsi``,
    ``mission`` and ``source_file`` columns, ordered by mission and date,
    with dates rendered as ``YYYY-MM-DD``.  The manifest
    (``OUTPUT_MANIFEST``) lists each input file with its label, row count
    and date span, names a ``<stem>.source.txt`` sidecar when one exists
    in ``INPUT_DIR``, and ends with the output file names.

    Args:
        series_list: Parsed mission series to export.
    """
    manifest = [
        "Pre-SORCE individual mission raw ingestion manifest",
        "",
        "Input directory:",
        str(INPUT_DIR),
        "",
        "Included files:",
    ]
    tables = []

    for entry in series_list:
        # Work on a copy so the caller's frame gains no extra columns.
        table = entry.frame.copy()
        table["mission"] = entry.label
        table["source_file"] = entry.file_name
        tables.append(table[["date", "tsi", "mission", "source_file"]])

        first_day = entry.frame["date"].min().date()
        last_day = entry.frame["date"].max().date()
        manifest += [
            f"- {entry.file_name}",
            f"  label: {entry.label}",
            f"  rows: {len(entry.frame)}",
            f"  span: {first_day} to {last_day}",
        ]

        sidecar = INPUT_DIR / f"{Path(entry.file_name).stem}.source.txt"
        if sidecar.exists():
            manifest.append(f"  sidecar: {sidecar.name}")

    combined = pd.concat(tables, ignore_index=True)
    combined = combined.sort_values(["mission", "date"])
    combined["date"] = combined["date"].dt.strftime("%Y-%m-%d")
    combined.to_csv(OUTPUT_CSV, index=False)

    manifest += [
        "",
        f"Output chart: {OUTPUT_PNG.name}",
        f"Output data: {OUTPUT_CSV.name}",
    ]
    OUTPUT_MANIFEST.write_text("\n".join(manifest) + "\n")


def main() -> None:
    """Run the ingestion: load every input file, then chart and export.

    Both the input and the output directory are created if missing.  Each
    file is loaded independently; a file that cannot be read or parsed is
    recorded and reported at the end instead of aborting the run.  When no
    file is found, or none can be parsed, a message is printed and nothing
    is written.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    INPUT_DIR.mkdir(parents=True, exist_ok=True)

    files = gather_input_files()
    if not files:
        print("No mission files found in", INPUT_DIR)
        print("Add raw files and rerun this script.")
        return

    loaded: list[MissionSeries] = []
    failures: list[str] = []

    for path in files:
        try:
            loaded.append(load_series(path))
        # OSError covers unreadable files (permissions, vanished entries) so
        # one bad file lands in the skipped list like any parse failure.
        except (OSError, ValueError, TypeError, RuntimeError) as exc:
            failures.append(f"{path.name}: {exc}")

    if not loaded:
        print("No files could be parsed.")
        for item in failures:
            print("  -", item)
        return

    build_chart(loaded)
    write_outputs(loaded)

    print("saved", OUTPUT_PNG)
    print("saved", OUTPUT_CSV)
    print("saved", OUTPUT_MANIFEST)
    print("missions", len(loaded))
    for series in loaded:
        print(
            "series",
            series.label,
            "rows",
            len(series.frame),
            "span",
            series.frame["date"].min().date(),
            series.frame["date"].max().date(),
        )

    if failures:
        print("skipped files:")
        for item in failures:
            print("  -", item)

# Run the ingestion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
