#!/usr/bin/env python3
"""
mcp_data_server.py -- Acme SaaS data platform MCP server.

Exposes four tools, two resources, and one prompt over stdio transport.
Launched as a subprocess by D1_mcp_client.ipynb to demonstrate MCP client
usage. Also reused as-is in D2_mcp_server.ipynb.

Usage (standalone):  python mcp_data_server.py
Usage (from D1/D2):  started automatically by the notebook via subprocess.

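Datasets (generated by scripts/generate_data.py; paths relative to this file):
  ../data/warehouse_usage.csv         required -- the server exits if missing
  ../data/pipeline_jobs.csv           required -- the server exits if missing
  ../data/runbook_warehouse_cost.md   optional -- served as data://warehouse-runbook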
"""
from __future__ import annotations

import csv
import sys
from datetime import datetime, timedelta
from pathlib import Path

# FastMCP is the high-level server interface bundled with the official mcp SDK (mcp>=1.2.0)
from mcp.server.fastmcp import FastMCP

# --------------------------------------------------------------------------
# Paths
# --------------------------------------------------------------------------
# Directory containing this script (expected to be notebooks/).
_HERE = Path(__file__).resolve().parent
# Sibling data/ directory, populated by scripts/generate_data.py.
DATA_DIR = Path(_HERE.parent, "data")

# --------------------------------------------------------------------------
# Load datasets once at startup
# --------------------------------------------------------------------------
def _load_csv(filename: str) -> list[dict]:
    """Read ``DATA_DIR / filename`` into a list of row dicts keyed by the CSV header.

    If the file is missing, a hint is written to stderr and the process exits
    with status 1. stderr is used because stdout carries the stdio MCP
    transport and must stay free of stray text.

    Args:
        filename: File name relative to DATA_DIR.

    Returns:
        One dict per data row, as produced by ``csv.DictReader`` (all values are str).
    """
    path = DATA_DIR / filename
    if not path.exists():
        sys.stderr.write(
            f"[mcp_data_server] ERROR: {path} not found. "
            "Run: python ../scripts/generate_data.py\n"
        )
        sys.exit(1)
    # Explicit UTF-8: the default text encoding is locale-dependent (e.g. cp1252
    # on Windows) and would mis-decode non-ASCII values such as error messages.
    # newline="" is what the csv module requires for correct quoting handling.
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))


# Both CSVs are parsed once, at import time; the tools below only read these
# module-level lists and never re-open the files.
USAGE_ROWS = _load_csv("warehouse_usage.csv")
JOBS_ROWS = _load_csv("pipeline_jobs.csv")

# --------------------------------------------------------------------------
# Create the FastMCP server
# --------------------------------------------------------------------------
# Server-level guidance handed to connecting clients; it steers the model
# toward the right tool for each kind of question.
_SERVER_INSTRUCTIONS = " ".join([
    "Data platform monitoring server for Acme SaaS Co.",
    "Provides access to Snowflake warehouse usage and pipeline job history.",
    "Use get_warehouse_summary for recent cost trends (last-N-days window),",
    "get_warehouse_spend_range for a specific date range,",
    "get_failing_jobs for pipeline health, and check_job_sla for freshness checks.",
])

mcp = FastMCP("acme-data-platform", instructions=_SERVER_INSTRUCTIONS)

# ==========================================================================
# TOOLS
# ==========================================================================

@mcp.tool()
def get_warehouse_summary(days: int = 7) -> str:
    """Return total and average daily Snowflake credits per warehouse for the last N days.

    Call this to understand current cost distribution across warehouses, identify the
    most expensive warehouses, or compare spend across time windows.

    Args:
        days: Number of days to look back from the most recent date in the dataset.
            Must be at least 1.
    """
    # Reject empty/negative windows up front; otherwise the cutoff lands after
    # the latest date and the caller gets a confusing "no rows" message.
    if days < 1:
        return f"days must be >= 1 (got {days})."

    all_dates = sorted({r["date"] for r in USAGE_ROWS})
    if not all_dates:
        return "No data available."

    latest = all_dates[-1]
    # days - 1 so that "last N days" returns exactly N rows for a daily dataset
    # (>= cutoff with timedelta(days=N) would include N+1 calendar dates).
    try:
        cutoff = (
            datetime.strptime(latest, "%Y-%m-%d") - timedelta(days=days - 1)
        ).strftime("%Y-%m-%d")
    except OverflowError:
        # A window too large for datetime arithmetic covers the whole dataset.
        cutoff = all_dates[0]

    # Zero-padded ISO dates compare correctly as plain strings.
    recent = [r for r in USAGE_ROWS if r["date"] >= cutoff]
    if not recent:
        return f"No rows after cutoff {cutoff}."

    # Aggregate daily credit values by warehouse.
    by_wh: dict[str, list[float]] = {}
    for r in recent:
        by_wh.setdefault(r["warehouse_name"], []).append(float(r["credits_used"]))

    lines = [
        f"Warehouse credit summary -- last {days} days (from {cutoff} to {latest})",
        "",
        f"{'Warehouse':<22} {'Days':>5} {'Total Cr':>10} {'Daily Avg':>10} {'Max Day':>10}",
        "-" * 62,
    ]
    # Most expensive warehouse first.
    for wh, vals in sorted(by_wh.items(), key=lambda x: -sum(x[1])):
        total = sum(vals)
        lines.append(
            f"{wh:<22} {len(vals):>5} {total:>10.2f} "
            f"{total / len(vals):>10.2f} {max(vals):>10.2f}"
        )
    grand_total = sum(sum(v) for v in by_wh.values())
    lines.append(f"\nTotal credits across all warehouses: {grand_total:.2f}")
    return "\n".join(lines)


@mcp.tool()
def get_failing_jobs(days: int = 14) -> str:
    """Return all pipeline job runs with status=failed in the last N days.

    Call this to identify which jobs are failing, view their error messages, and
    find failure clusters. For correlating failures with warehouse cost spikes,
    call alongside get_warehouse_summary.

    Args:
        days: Number of days to look back from the most recent started_at date.
            Must be at least 1.
    """
    # Reject empty/negative windows up front instead of silently returning
    # "no failed jobs" for a cutoff that lies after the latest date.
    if days < 1:
        return f"days must be >= 1 (got {days})."

    # Only the date part (first 10 chars, YYYY-MM-DD) of started_at is used.
    all_dates = sorted({r["started_at"][:10] for r in JOBS_ROWS})
    if not all_dates:
        return "No data available."

    latest = all_dates[-1]
    # days - 1 so the window spans exactly N calendar dates, latest included.
    try:
        cutoff = (
            datetime.strptime(latest, "%Y-%m-%d") - timedelta(days=days - 1)
        ).strftime("%Y-%m-%d")
    except OverflowError:
        # A window too large for datetime arithmetic covers the whole dataset.
        cutoff = all_dates[0]

    failed = [
        r for r in JOBS_ROWS
        if r["status"] == "failed" and r["started_at"][:10] >= cutoff
    ]

    if not failed:
        return f"No failed jobs in the last {days} days (since {cutoff})."

    # Group by job name to show failure counts
    by_job: dict[str, list[dict]] = {}
    for r in failed:
        by_job.setdefault(r["job_name"], []).append(r)

    lines = [f"Failed pipeline jobs -- last {days} days (since {cutoff})", ""]
    lines.append(f"{'Job':<35} {'Failures':>8}")
    lines.append("-" * 45)
    # Most frequently failing job first.
    for job, runs in sorted(by_job.items(), key=lambda x: -len(x[1])):
        lines.append(f"{job:<35} {len(runs):>8}")

    lines.append("\n-- Most recent failure details (up to 15) --")
    # The 15 latest failures, listed oldest to newest.
    for r in sorted(failed, key=lambda x: x["started_at"])[-15:]:
        lines.append(f"\n  [{r['started_at'][:10]}] {r['job_name']}")
        lines.append(f"    Warehouse : {r['warehouse_name']}")
        if r.get("error_message"):
            lines.append(f"    Error     : {r['error_message'][:120]}")

    lines.append(f"\nTotal failed runs: {len(failed)}")
    return "\n".join(lines)


@mcp.tool()
def check_job_sla(job_name: str) -> str:
    """Check whether a named pipeline job is within its SLA window.

    Returns the timestamp of the last successful run and whether more than 26 hours
    have elapsed since then (the SLA threshold for daily jobs at Acme). Call this
    when asked "is this table up to date?" or "when did X last succeed?"

    Args:
        job_name: Exact name of the pipeline job (e.g. 'dbt_fct_subscriptions').
    """
    # Hours allowed since the last successful finish before the SLA is breached.
    sla_hours = 26
    # The dataset is historical, so "now" is pinned to this offset after
    # midnight of the last started_at date rather than taken from the wall clock.
    reference_offset = timedelta(hours=6)

    job_runs = [r for r in JOBS_ROWS if r["job_name"] == job_name]
    successful = [
        r for r in job_runs
        if r["status"] == "success" and r.get("finished_at")
    ]

    if not successful:
        if not job_runs:
            # Unknown job: list valid names so the caller can retry.
            all_job_names = sorted({r["job_name"] for r in JOBS_ROWS})
            return (
                f"Job '{job_name}' not found in dataset.\n"
                f"Available jobs: {', '.join(all_job_names)}"
            )
        return f"No successful runs found for '{job_name}' in the dataset."

    # Use the last date in the dataset as the "reference now"
    all_dates = sorted({r["started_at"][:10] for r in JOBS_ROWS})
    dataset_end = datetime.strptime(all_dates[-1], "%Y-%m-%d") + reference_offset

    last_ok = max(datetime.fromisoformat(r["finished_at"]) for r in successful)
    # A success finishing after the pinned reference time would otherwise show
    # a negative age; treat it as just completed (status is OK either way).
    hours_ago = max((dataset_end - last_ok).total_seconds() / 3600, 0.0)

    if hours_ago <= sla_hours:
        sla_status = f"OK (last success {hours_ago:.1f}h ago, within {sla_hours}h SLA)"
    else:
        sla_status = f"BREACHED ({hours_ago:.1f}h since last success, SLA is {sla_hours}h)"

    total_runs = len(job_runs)
    success_runs = len(successful)
    # Derive the label from sla_hours so it cannot drift from the threshold;
    # padded to the same 17-char column as the other labels.
    sla_label = f"SLA ({sla_hours}h)"

    return (
        f"Job              : {job_name}\n"
        f"Last success     : {last_ok.isoformat(timespec='seconds')}\n"
        f"Hours since      : {hours_ago:.1f}h\n"
        f"{sla_label:<17}: {sla_status}\n"
        f"Success rate     : {success_runs}/{total_runs} "
        f"({100 * success_runs / total_runs:.0f}%)"
    )


@mcp.tool()
def get_warehouse_spend_range(
    warehouse_name: str,
    start_date: str,
    end_date: str,
) -> str:
    """Return daily Snowflake credit spend for a named warehouse over an exact date range.

    Use this when you have a specific incident window (e.g. from an incident_analysis
    prompt) and need to filter to that exact range rather than a rolling N-day window.
    For open-ended "last N days" queries, prefer get_warehouse_summary instead.

    Args:
        warehouse_name: Exact warehouse name (e.g. WH_BI_M), or "*" for all warehouses.
        start_date: Inclusive start date in YYYY-MM-DD format.
        end_date: Inclusive end date in YYYY-MM-DD format.
    """
    # Rows are filtered by plain string comparison below, which is only correct
    # for zero-padded ISO dates. Parse both bounds first so malformed input gets
    # a clear error (instead of a silently empty or wrong result), and inputs
    # such as "2025-7-1" are normalised to "2025-07-01".
    try:
        start = datetime.strptime(start_date.strip(), "%Y-%m-%d").date().isoformat()
        end = datetime.strptime(end_date.strip(), "%Y-%m-%d").date().isoformat()
    except ValueError:
        return (
            f"Invalid date range: start_date={start_date!r}, end_date={end_date!r}. "
            "Both must be in YYYY-MM-DD format."
        )

    if start > end:
        return f"start_date ({start}) must be <= end_date ({end})."

    rows = [
        r for r in USAGE_ROWS
        if (warehouse_name == "*" or r["warehouse_name"] == warehouse_name)
        and start <= r["date"] <= end
    ]

    if not rows:
        all_wh = sorted({r["warehouse_name"] for r in USAGE_ROWS})
        all_dates = sorted({r["date"] for r in USAGE_ROWS})
        coverage = f"{all_dates[0]} to {all_dates[-1]}" if all_dates else "no dates"
        return (
            f"No data for warehouse='{warehouse_name}' between {start} and {end}.\n"
            f"Available warehouses: {', '.join(all_wh)}\n"
            f"Dataset date coverage: {coverage}"
        )

    # Sum credits per date (several warehouses contribute when warehouse_name is "*").
    by_date: dict[str, float] = {}
    for r in rows:
        by_date[r["date"]] = by_date.get(r["date"], 0.0) + float(r["credits_used"])

    label = warehouse_name if warehouse_name != "*" else "ALL warehouses"
    lines = [
        f"Warehouse spend -- {label} -- {start} to {end}",
        "",
        f"{'Date':<12} {'Credits':>10}",
        "-" * 25,
    ]
    for date in sorted(by_date):
        lines.append(f"{date:<12} {by_date[date]:>10.2f}")

    total = sum(by_date.values())
    # n_days >= 1 here because rows is non-empty.
    n_days = len(by_date)
    lines += [
        "-" * 25,
        f"{'Total':<12} {total:>10.2f}",
        f"{'Daily avg':<12} {total / n_days:>10.2f}  ({n_days} days)",
    ]
    return "\n".join(lines)


# ==========================================================================
# RESOURCES
# ==========================================================================

@mcp.resource("data://warehouse-runbook")
def get_warehouse_runbook() -> str:
    """Snowflake warehouse cost runbook. Includes incident patterns, optimization
    procedures, and SQL reference queries for warehouse cost analysis."""
    # Unlike the CSVs, the runbook is optional and read on each request: a
    # missing file yields a hint instead of stopping the server.
    runbook_path = DATA_DIR / "runbook_warehouse_cost.md"
    try:
        # Explicit UTF-8 so decoding does not depend on the host locale.
        return runbook_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return (
            "Runbook file not found at data/runbook_warehouse_cost.md. "
            "Run: python ../scripts/generate_data.py"
        )


# Static schema text served by the data://dataset-schema resource. Row counts,
# date ranges and anomalies describe the output of scripts/generate_data.py.
_DATASET_SCHEMA = """Dataset: warehouse_usage.csv  (630 rows, 90 days x 7 warehouses)
  date                  string  ISO date (YYYY-MM-DD), 2025-07-01 to 2025-09-28
  warehouse_name        string  One of: WH_AD_HOC_S, WH_BI_M, WH_DS_L, WH_ELT_L,
                                        WH_ELT_M, WH_EMBEDDED_M, WH_REVERSE_S
  credits_used          float   Snowflake credits consumed that day
  query_count           int     Number of queries executed
  avg_queue_time_s      float   Average queue wait in seconds
  avg_execution_time_s  float   Average query execution in seconds

Dataset: pipeline_jobs.csv  (668 rows: 6 daily jobs x 90 days + 2 weekday-only jobs x 64 days)
  run_id                int     Primary key (sequential)
  job_name              string  One of: dbt_fct_subscriptions, dbt_fct_mrr,
                                        dbt_fct_events, dbt_dim_accounts,
                                        airflow_sf_ingest, airflow_stripe_ingest,
                                        hightouch_crm_sync, hightouch_marketing
  warehouse_name        string  Warehouse used for this run
  started_at            string  ISO datetime (YYYY-MM-DDTHH:MM:SS)
  finished_at           string  ISO datetime or empty if failed/skipped
  duration_s            int     Run duration in seconds (empty if failed/skipped)
  status                string  success | failed | skipped
  credits_used          float   Credits consumed (0.0 if failed/skipped)
  rows_processed        int     Rows processed (0 if failed/skipped)
  error_message         string  Error details (empty if success/skipped)

Known anomalies:
  warehouse_usage.csv : WH_BI_M +40% credits 2025-07-22 to 2025-08-10
                        WH_DS_L 3.5x spike on 2025-08-30
  pipeline_jobs.csv   : dbt_fct_subscriptions Stripe schema drift cluster 2025-07-15 to 2025-07-17
                          error: KeyError column amount_due not found in stg_stripe__invoices
                        dbt_fct_mrr skipped 2025-07-15 to 2025-07-17 (downstream dependency)
                        hightouch_crm_sync and hightouch_marketing skipped 2025-08-05 to 2025-08-06
                        multi-job failure event 2025-08-20:
                          dbt_fct_subscriptions (Stripe drift recurrence)
                          airflow_stripe_ingest (warehouse timeout)
                          dbt_fct_events (WH_ELT_L unexpected suspend)"""


@mcp.resource("data://dataset-schema")
def get_dataset_schema() -> str:
    """Schema reference for the warehouse_usage.csv and pipeline_jobs.csv datasets."""
    # The docstring above doubles as the MCP resource description; the body
    # just hands back the precomputed text.
    return _DATASET_SCHEMA


# ==========================================================================
# PROMPTS
# ==========================================================================

@mcp.prompt()
def incident_analysis(warehouse_name: str, date_range: str) -> str:
    """Structured prompt template for a Snowflake cost incident analysis.

    Args:
        warehouse_name: The warehouse to analyse (e.g. WH_BI_M).
        date_range: The date range of interest (e.g. '2025-07-20 to 2025-07-25').
    """
    # Step 2 explicitly fetches the baseline window: the anomaly threshold and
    # "excess credits vs baseline" need pre-incident data that step 1 does not
    # return. Step 3 warns that get_failing_jobs counts back from the dataset's
    # latest date, not from the incident window.
    return f"""You are a Snowflake cost analyst performing a structured incident analysis.

Warehouse  : {warehouse_name}
Date range : {date_range}

Perform the following analysis steps in order:
1. Call get_warehouse_spend_range(warehouse_name="{warehouse_name}", start_date="<start>", end_date="<end>") to retrieve the exact daily credit spend for the period. Parse the start and end dates from the date_range string above.
2. Call get_warehouse_spend_range(warehouse_name="{warehouse_name}", start_date="<start minus 30 days>", end_date="<start minus 1 day>") to get a 30-day baseline and compute its median daily spend (if no baseline data exists, use the median of the period itself and say so). Identify any days in the period where spend exceeded 1.25x that 30-day median.
3. Call get_failing_jobs to find correlated pipeline failures. Its days argument counts back from the most recent date in the dataset (not from <end>), so choose a value large enough to reach <start>.
4. Read the resource data://warehouse-runbook and identify any matching incident pattern.
5. Call check_job_sla for any jobs with failure clusters during the period.

Then write a structured incident report with these sections:
  - Summary: what happened and when (2-3 sentences)
  - Credit impact: total excess credits vs the 30-day baseline
  - Root cause hypothesis: most likely explanation based on the runbook
  - Correlated events: pipeline failures, job anomalies
  - Recommended action: specific next steps with owners

Be precise: cite exact dates, credit figures, and error messages from the tool results."""


# ==========================================================================
# Entry point
# ==========================================================================
if __name__ == "__main__":
    # Serve over stdio (FastMCP's default, spelled out here): the D1/D2
    # notebooks start this file as a subprocess and speak MCP over its pipes.
    mcp.run(transport="stdio")
