"""
AI Assistant tools — functions the LLM can call to query SPBU data.

Each tool returns a dict that gets serialized to the LLM as context.
"""

from __future__ import annotations

from datetime import date, timedelta
from decimal import Decimal
from typing import Any

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.anomali import AnomaliRecord
from app.models.expenses import Expense
from app.models.operational import (
    LaporanShift,
    PenjualanNozzle,
    StatusLaporan,
    StockAdjustment,
    StockAdjustmentItem,
)
from app.models.penebusan import Penebusan, PenebusanItem, StatusPenebusan
from app.models.penerimaan import Penerimaan, PenerimaanItem
from app.models.penyetoran import Penyetoran
from app.models.product import Produk, ProdukHarga
from app.models.spbu import Nozzle, Shift, Spbu, Tangki


# ── Tool Definitions (sent to LLM) ──────────────────────────────────────────

def _date_range_parameters(**extra: dict[str, str]) -> dict[str, Any]:
    """Build a JSON-schema ``parameters`` object requiring a start/end date pair.

    Any ``extra`` keyword properties are appended (in call order) after the two
    date fields and are not added to ``required``. A fresh dict is built on
    every call, so no two tool schemas share mutable state.
    """
    return {
        "type": "object",
        "properties": {
            "tanggal_mulai": {"type": "string", "description": "Tanggal mulai (YYYY-MM-DD)"},
            "tanggal_akhir": {"type": "string", "description": "Tanggal akhir (YYYY-MM-DD)"},
            **extra,
        },
        "required": ["tanggal_mulai", "tanggal_akhir"],
    }


def _optional_parameters(**props: dict[str, str]) -> dict[str, Any]:
    """Build a JSON-schema ``parameters`` object whose properties are all optional."""
    return {"type": "object", "properties": dict(props)}


# Schemas advertised to the LLM. Each entry's "name" must match a handler key in
# execute_tool; descriptions are user-facing (Indonesian) and sent verbatim.
TOOL_DEFINITIONS: list[dict[str, Any]] = [
    {
        "name": "get_ringkasan_harian",
        "description": "Dapatkan rangkuman operasional SPBU untuk satu hari atau rentang tanggal. Termasuk total penjualan, expenses, penyetoran, dan penerimaan BBM.",
        "parameters": _date_range_parameters(),
    },
    {
        "name": "get_penjualan",
        "description": "Dapatkan data penjualan BBM per produk dalam rentang tanggal tertentu, termasuk volume dan nilai rupiah.",
        "parameters": _date_range_parameters(
            produk_nama={"type": "string", "description": "Filter nama produk (opsional, misal: Pertalite, Pertamax)"},
        ),
    },
    {
        "name": "get_stok",
        "description": "Dapatkan status stok terkini per tangki, termasuk volume terakhir, kapasitas, dan persentase isi.",
        "parameters": _optional_parameters(),
    },
    {
        "name": "get_anomali",
        "description": "Dapatkan daftar anomali aktif (status new atau investigating).",
        "parameters": _optional_parameters(
            tipe={"type": "string", "description": "Filter tipe anomali (opsional: quota_exceeded, meter_discrepancy, losses_exceeded, negative_stock)"},
        ),
    },
    {
        "name": "get_expenses",
        "description": "Dapatkan data pengeluaran operasional dalam rentang tanggal.",
        "parameters": _date_range_parameters(),
    },
    {
        "name": "get_laporan_status",
        "description": "Dapatkan status laporan shift hari ini — mana yang sudah submit, draft, atau belum dibuat.",
        "parameters": _optional_parameters(
            tanggal={"type": "string", "description": "Tanggal (YYYY-MM-DD), default hari ini"},
        ),
    },
    {
        "name": "get_tangki_status",
        "description": "Dapatkan kondisi detail setiap tangki: volume, kapasitas, produk, dan estimasi hari stok tersisa berdasarkan rata-rata penjualan.",
        "parameters": _optional_parameters(),
    },
    {
        "name": "get_penebusan",
        "description": "Dapatkan status penebusan (DO/SO) BBM ke Pertamina — termasuk yang pending dan sudah diterima.",
        "parameters": _optional_parameters(
            status={"type": "string", "description": "Filter status (opsional: draft, waiting_so, submitted, partially_received, fully_received)"},
        ),
    },
    {
        "name": "get_penerimaan",
        "description": "Dapatkan data penerimaan BBM dalam rentang tanggal.",
        "parameters": _date_range_parameters(),
    },
    {
        "name": "get_penyetoran",
        "description": "Dapatkan data penyetoran kas dalam rentang tanggal.",
        "parameters": _date_range_parameters(),
    },
    {
        "name": "get_harga_produk",
        "description": "Dapatkan harga produk yang berlaku saat ini dan riwayat perubahan harga.",
        "parameters": _optional_parameters(
            produk_nama={"type": "string", "description": "Nama produk (opsional, jika kosong tampilkan semua)"},
        ),
    },
    {
        "name": "get_info_spbu",
        "description": "Dapatkan informasi konfigurasi SPBU: nama, nomor Pertamina, jumlah tangki, shift, island, nozzle.",
        "parameters": _optional_parameters(),
    },
]


# ── Tool Implementations ─────────────────────────────────────────────────────

async def execute_tool(
    db: AsyncSession,
    spbu_id: int,
    tool_name: str,
    arguments: dict[str, Any],
) -> dict[str, Any]:
    """Route an LLM tool call to its implementation.

    Args:
        db: Async SQLAlchemy session handed to the handler.
        spbu_id: ID of the SPBU the call is scoped to; passed to the handler.
        tool_name: Tool name, matching a ``"name"`` entry in ``TOOL_DEFINITIONS``.
        arguments: Arguments decoded from the model's tool call. ``None`` is
            accepted and treated as "no arguments"; keys whose value is
            ``None`` are treated as omitted.

    Returns:
        The handler's result dict. Instead of raising, returns
        ``{"error": ...}`` when the tool is unknown or the model supplied
        missing, wrongly-typed or malformed arguments, so the model can see
        the problem and retry.
    """
    handlers = {
        "get_ringkasan_harian": _get_ringkasan_harian,
        "get_penjualan": _get_penjualan,
        "get_stok": _get_stok,
        "get_anomali": _get_anomali,
        "get_expenses": _get_expenses,
        "get_laporan_status": _get_laporan_status,
        "get_tangki_status": _get_tangki_status,
        "get_penebusan": _get_penebusan,
        "get_penerimaan": _get_penerimaan,
        "get_penyetoran": _get_penyetoran,
        "get_harga_produk": _get_harga_produk,
        "get_info_spbu": _get_info_spbu,
    }

    handler = handlers.get(tool_name)
    if handler is None:
        return {"error": f"Tool '{tool_name}' not found"}

    if arguments is None:
        arguments = {}
    if not isinstance(arguments, dict):
        return {"error": f"Arguments for tool '{tool_name}' must be an object"}

    # Models often send explicit nulls for optional parameters; drop them so the
    # handlers' `args.get(..., default)` fallbacks take effect.
    args = {key: value for key, value in arguments.items() if value is not None}

    # Check required/typed parameters against the schema the model was given,
    # rather than letting a bare KeyError/TypeError escape from a handler.
    schema = next(
        (tool.get("parameters", {}) for tool in TOOL_DEFINITIONS if tool["name"] == tool_name),
        {},
    )
    missing = [name for name in schema.get("required", []) if name not in args]
    if missing:
        return {"error": f"Missing required argument(s) for tool '{tool_name}': {', '.join(missing)}"}
    for name, spec in schema.get("properties", {}).items():
        if spec.get("type") == "string" and name in args and not isinstance(args[name], str):
            return {"error": f"Argument '{name}' for tool '{tool_name}' must be a string"}

    try:
        return await handler(db, spbu_id, args)
    except ValueError as exc:
        # date.fromisoformat() rejects malformed dates (e.g. "kemarin", "2024-13-01").
        return {"error": f"Invalid argument for tool '{tool_name}': {exc}"}


def _parse_date(s: str) -> date:
    """Parse a ``YYYY-MM-DD`` date string supplied by the model.

    Leading/trailing whitespace and a trailing time component
    (``"2024-05-01T08:00:00"`` or ``"2024-05-01 08:00"``) are ignored, since
    models sometimes send a full timestamp where a date is expected.

    Raises:
        TypeError: if ``s`` is not a string.
        ValueError: if the date part is not a valid ISO date.
    """
    if not isinstance(s, str):
        raise TypeError(f"expected a date string, got {type(s).__name__}")
    # Neither "T" nor a space can appear in an ISO calendar date, so anything
    # after the first one is a time component.
    date_part = s.strip().split("T", 1)[0].split(" ", 1)[0]
    return date.fromisoformat(date_part)


def _dec(v: Any) -> str:
    """Stringify a numeric value for the LLM payload; ``None`` becomes ``"0"``.

    Values go through ``str`` unchanged, so a ``Decimal`` keeps its exact digits.
    """
    if v is None:
        return "0"
    return str(v)


async def _get_ringkasan_harian(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Summarise one SPBU's operations over a date range.

    ``args`` must contain ``tanggal_mulai`` and ``tanggal_akhir`` (ISO dates);
    both ends are included (SQL ``BETWEEN``). Returns total sales volume and
    value, total expenses, total cash deposits (penyetoran), total fuel
    received (penerimaan) and ``net_kas`` = sales value - expenses, each as a
    string.
    """
    start = _parse_date(args["tanggal_mulai"])
    end = _parse_date(args["tanggal_akhir"])

    sales_stmt = (
        select(
            func.coalesce(func.sum(PenjualanNozzle.volume), 0),
            func.coalesce(func.sum(PenjualanNozzle.nilai), 0),
        )
        .join(LaporanShift)
        .where(LaporanShift.spbu_id == spbu_id, LaporanShift.tanggal.between(start, end))
    )
    expense_stmt = select(func.coalesce(func.sum(Expense.jumlah), 0)).where(
        Expense.spbu_id == spbu_id, Expense.tanggal.between(start, end)
    )
    deposit_stmt = select(func.coalesce(func.sum(Penyetoran.jumlah_setor), 0)).where(
        Penyetoran.spbu_id == spbu_id, Penyetoran.tanggal.between(start, end)
    )
    received_stmt = (
        select(func.coalesce(func.sum(PenerimaanItem.volume_diterima), 0))
        .join(Penerimaan)
        .where(Penerimaan.spbu_id == spbu_id, Penerimaan.tanggal.between(start, end))
    )

    # Queries run one after another on the same session: sales, expenses,
    # deposits, receipts.
    sales_volume, sales_value = (await db.execute(sales_stmt)).one()
    expense_total = (await db.execute(expense_stmt)).scalar_one()
    deposit_total = (await db.execute(deposit_stmt)).scalar_one()
    received_total = (await db.execute(received_stmt)).scalar_one()

    net_cash = (sales_value or 0) - (expense_total or 0)

    return {
        "periode": f"{start} s/d {end}",
        "penjualan_volume_liter": _dec(sales_volume),
        "penjualan_nilai_rupiah": _dec(sales_value),
        "expenses_rupiah": _dec(expense_total),
        "penyetoran_rupiah": _dec(deposit_total),
        "penerimaan_bbm_liter": _dec(received_total),
        "net_kas": _dec(net_cash),
    }


async def _get_penjualan(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Aggregate fuel sales per product over an inclusive date range.

    ``args`` needs ``tanggal_mulai``/``tanggal_akhir``; an optional
    ``produk_nama`` keeps only products whose name contains it
    (case-insensitive ``ILIKE``). Products are ordered by sales value,
    highest first, and grand totals of volume and value are included.
    """
    start = _parse_date(args["tanggal_mulai"])
    end = _parse_date(args["tanggal_akhir"])
    name_filter = args.get("produk_nama")

    conditions = [LaporanShift.spbu_id == spbu_id, LaporanShift.tanggal.between(start, end)]
    if name_filter:
        conditions.append(Produk.nama.ilike(f"%{name_filter}%"))

    stmt = (
        select(
            Produk.nama,
            Produk.kode,
            func.coalesce(func.sum(PenjualanNozzle.volume), 0),
            func.coalesce(func.sum(PenjualanNozzle.nilai), 0),
        )
        .join(Nozzle, PenjualanNozzle.nozzle_id == Nozzle.id)
        .join(Produk, Nozzle.produk_id == Produk.id)
        .join(LaporanShift, PenjualanNozzle.laporan_shift_id == LaporanShift.id)
        .where(*conditions)
        .group_by(Produk.nama, Produk.kode)
        .order_by(func.sum(PenjualanNozzle.nilai).desc())
    )

    rows = (await db.execute(stmt)).all()
    per_product = [
        {"nama": nama, "kode": kode, "volume_liter": _dec(volume), "nilai_rupiah": _dec(nilai)}
        for nama, kode, volume, nilai in rows
    ]

    return {
        "periode": f"{start} s/d {end}",
        "produk": per_product,
        "total_volume": _dec(sum(row[2] for row in rows)),
        "total_nilai": _dec(sum(row[3] for row in rows)),
    }


async def _get_stok(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Report the latest recorded stock of every active tank of the SPBU.

    For each active, non-deleted tank the volume comes from the most recent
    ``StockAdjustmentItem`` (newest ``StockAdjustment.tanggal``, then highest
    ``StockAdjustment.id``). ``persen_isi`` is volume / capacity. ``status`` is
    "Kritis" below 15 %, "Rendah" below 30 %, otherwise "Normal", or
    "Tidak ada data" when the tank has no stock-adjustment reading at all.
    ``args`` is unused.
    """
    tangkis = (await db.execute(
        select(Tangki).where(Tangki.spbu_id == spbu_id, Tangki.is_active.is_(True), Tangki.deleted_at.is_(None))
    )).scalars().all()

    # Resolve every product name with one query instead of one query per tank.
    produk_ids = sorted({t.produk_id for t in tangkis if t.produk_id})
    produk_nama_by_id: dict = {}
    if produk_ids:
        produk_nama_by_id = dict((await db.execute(
            select(Produk.id, Produk.nama).where(Produk.id.in_(produk_ids))
        )).all())

    result = []
    for t in tangkis:
        vol_q = (
            select(StockAdjustmentItem.volume_final_liter)
            .join(StockAdjustment)
            .where(StockAdjustment.spbu_id == spbu_id, StockAdjustmentItem.tangki_id == t.id)
            .order_by(StockAdjustment.tanggal.desc(), StockAdjustment.id.desc())
            .limit(1)
        )
        vol = (await db.execute(vol_q)).scalar_one_or_none()
        pct = round(float(vol / t.kapasitas_liter * 100), 1) if vol and t.kapasitas_liter else 0

        if vol is None:
            # No reading recorded yet: an unknown level must not be reported as
            # an empty tank, or the assistant raises a false "Kritis" alarm.
            status = "Tidak ada data"
        elif pct < 15:
            status = "Kritis"
        elif pct < 30:
            status = "Rendah"
        else:
            status = "Normal"

        result.append({
            "tangki": t.nama,
            "produk": produk_nama_by_id.get(t.produk_id) or "—",
            "volume_liter": _dec(vol),
            "kapasitas_liter": _dec(t.kapasitas_liter),
            "persen_isi": f"{pct}%",
            "status": status,
        })

    return {"tangki": result}


async def _get_anomali(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """List the SPBU's active anomalies (status "new" or "investigating").

    ``args["tipe"]``, when non-empty, restricts results to that anomaly type.
    ``anomali`` holds at most the 20 newest records (by ``created_at``);
    ``total_aktif`` is the full number of matching records and
    ``jumlah_ditampilkan`` how many are listed.
    """
    max_rows = 20

    filters = [
        AnomaliRecord.spbu_id == spbu_id,
        AnomaliRecord.status.in_(["new", "investigating"]),
    ]
    tipe = args.get("tipe")
    if tipe:
        filters.append(AnomaliRecord.tipe == tipe)

    # Count separately: len() of the LIMITed result would cap the total at max_rows.
    total = (await db.execute(
        select(func.count()).select_from(AnomaliRecord).where(*filters)
    )).scalar_one()

    rows = (await db.execute(
        select(AnomaliRecord)
        .where(*filters)
        .order_by(AnomaliRecord.created_at.desc())
        .limit(max_rows)
    )).scalars().all()

    return {
        "total_aktif": total,
        "jumlah_ditampilkan": len(rows),
        "anomali": [
            {
                "id": r.id,
                "tipe": r.tipe,
                "status": r.status,
                "detail": r.detail,
                "created_at": r.created_at.isoformat() if r.created_at else None,
            }
            for r in rows
        ],
    }


async def _get_expenses(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Return operational expenses recorded in an inclusive date range.

    ``total_rupiah`` and ``jumlah_item`` cover every matching expense;
    ``items`` lists only the first 20 of them, newest ``tanggal`` first.
    """
    start = _parse_date(args["tanggal_mulai"])
    end = _parse_date(args["tanggal_akhir"])
    detail_limit = 20  # keep the LLM context small

    stmt = (
        select(Expense)
        .where(Expense.spbu_id == spbu_id, Expense.tanggal.between(start, end))
        .order_by(Expense.tanggal.desc())
    )
    expenses = (await db.execute(stmt)).scalars().all()

    return {
        "periode": f"{start} s/d {end}",
        "total_rupiah": _dec(sum(e.jumlah for e in expenses)),
        "jumlah_item": len(expenses),
        "items": [
            {"tanggal": str(e.tanggal), "keterangan": e.keterangan, "jumlah": _dec(e.jumlah)}
            for e in expenses[:detail_limit]
        ],
    }


async def _get_laporan_status(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Report, per active shift, the status of its sales report and stock adjustment.

    ``args["tanggal"]`` (ISO date) selects the day; when it is absent, null or
    empty, today's date is used. For every active, non-deleted shift (ordered
    by ``jam_mulai``) the ``LaporanShift`` and ``StockAdjustment`` status for
    that day is returned, or "belum_dibuat" when no record exists.
    """
    raw_tanggal = args.get("tanggal")
    tgl = _parse_date(raw_tanggal) if raw_tanggal else date.today()

    shifts = (await db.execute(
        select(Shift).where(Shift.spbu_id == spbu_id, Shift.is_active.is_(True), Shift.deleted_at.is_(None))
        .order_by(Shift.jam_mulai)
    )).scalars().all()

    result = []
    for s in shifts:
        # Newest-first + limit(1): should duplicate records ever exist for the
        # same shift/day, report the latest instead of letting
        # scalar_one_or_none() raise MultipleResultsFound.
        lap = (await db.execute(
            select(LaporanShift.status)
            .where(
                LaporanShift.spbu_id == spbu_id, LaporanShift.shift_id == s.id, LaporanShift.tanggal == tgl
            )
            .order_by(LaporanShift.id.desc())
            .limit(1)
        )).scalar_one_or_none()

        sa = (await db.execute(
            select(StockAdjustment.status)
            .where(
                StockAdjustment.spbu_id == spbu_id, StockAdjustment.shift_id == s.id, StockAdjustment.tanggal == tgl
            )
            .order_by(StockAdjustment.id.desc())
            .limit(1)
        )).scalar_one_or_none()

        result.append({
            "shift": s.nama,
            "laporan_penjualan": lap or "belum_dibuat",
            "stock_adjustment": sa or "belum_dibuat",
        })

    return {"tanggal": str(tgl), "shifts": result}


async def _get_tangki_status(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Extend ``_get_stok`` output with an ``estimasi_hari`` stock-cover estimate.

    For each tank, ``estimasi_hari`` = tank volume / average daily sales of the
    tank's product over the last 7 calendar days (today included). It is "—"
    when the tank has no product or that product had no sales in the window.
    The sales figure covers all nozzles of the SPBU selling that product, not
    only those fed by this tank.
    """
    stok = await _get_stok(db, spbu_id, args)

    window_days = 7
    d_end = date.today()
    # between() is inclusive on both ends, so a 7-day window starts 6 days back
    # (starting 7 days back would sum 8 days of sales and divide by 7).
    d_start = d_end - timedelta(days=window_days - 1)

    # Window sales per product name in a single grouped query rather than one
    # query per tank.
    sales_q = (
        select(Produk.nama, func.coalesce(func.sum(PenjualanNozzle.volume), 0))
        .select_from(PenjualanNozzle)
        .join(Nozzle, PenjualanNozzle.nozzle_id == Nozzle.id)
        .join(Produk, Nozzle.produk_id == Produk.id)
        .join(LaporanShift, PenjualanNozzle.laporan_shift_id == LaporanShift.id)
        .where(
            LaporanShift.spbu_id == spbu_id,
            LaporanShift.tanggal.between(d_start, d_end),
        )
        .group_by(Produk.nama)
    )
    terjual_per_produk = dict((await db.execute(sales_q)).all())

    for t in stok["tangki"]:
        produk_nama = t["produk"]
        if produk_nama == "—":
            t["estimasi_hari"] = "—"
            continue

        total_window = terjual_per_produk.get(produk_nama) or 0
        daily_avg = float(total_window) / window_days
        vol = float(t["volume_liter"])
        t["estimasi_hari"] = f"{round(vol / daily_avg, 1)} hari" if daily_avg > 0 else "—"

    return stok


async def _get_penebusan(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """List the SPBU's 10 most recent penebusan (DO/SO) records by ``tanggal``.

    ``args["status"]``, when non-empty, restricts the list to that status value.
    """
    stmt = select(Penebusan).where(Penebusan.spbu_id == spbu_id)
    wanted_status = args.get("status")
    if wanted_status:
        stmt = stmt.where(Penebusan.status == wanted_status)
    stmt = stmt.order_by(Penebusan.tanggal.desc()).limit(10)

    records = (await db.execute(stmt)).scalars().all()

    def _as_dict(p) -> dict:
        return {
            "id": p.id,
            "no_do": p.no_do,
            "no_so": p.no_so,
            "tanggal": str(p.tanggal),
            "status": p.status,
        }

    return {"penebusan": [_as_dict(p) for p in records]}


async def _get_penerimaan(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """List fuel receipts (penerimaan) in an inclusive date range, newest first.

    Each receipt reports the sum of its items' ``volume_diterima``. The inner
    join on ``PenerimaanItem`` means receipts with no items are not listed.
    ``total_volume`` sums all listed receipts.
    """
    start = _parse_date(args["tanggal_mulai"])
    end = _parse_date(args["tanggal_akhir"])

    stmt = (
        select(
            Penerimaan.id,
            Penerimaan.tanggal,
            Penerimaan.no_segel,
            func.coalesce(func.sum(PenerimaanItem.volume_diterima), 0),
        )
        .join(PenerimaanItem, PenerimaanItem.penerimaan_id == Penerimaan.id)
        .where(Penerimaan.spbu_id == spbu_id, Penerimaan.tanggal.between(start, end))
        .group_by(Penerimaan.id, Penerimaan.tanggal, Penerimaan.no_segel)
        .order_by(Penerimaan.tanggal.desc())
    )
    receipts = (await db.execute(stmt)).all()

    entries = []
    grand_total = 0
    for receipt_id, tanggal, no_segel, volume in receipts:
        entries.append(
            {"id": receipt_id, "tanggal": str(tanggal), "no_segel": no_segel, "volume_liter": _dec(volume)}
        )
        grand_total += volume

    return {
        "periode": f"{start} s/d {end}",
        "penerimaan": entries,
        "total_volume": _dec(grand_total),
    }


async def _get_penyetoran(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """List cash deposits (penyetoran) in an inclusive date range, newest first.

    ``total_rupiah`` is the sum of ``jumlah_setor`` over every listed deposit.
    """
    start = _parse_date(args["tanggal_mulai"])
    end = _parse_date(args["tanggal_akhir"])

    stmt = (
        select(Penyetoran)
        .where(Penyetoran.spbu_id == spbu_id, Penyetoran.tanggal.between(start, end))
        .order_by(Penyetoran.tanggal.desc())
    )
    deposits = (await db.execute(stmt)).scalars().all()

    return {
        "periode": f"{start} s/d {end}",
        "total_rupiah": _dec(sum(d.jumlah_setor for d in deposits)),
        "penyetoran": [
            {"tanggal": str(d.tanggal), "jumlah": _dec(d.jumlah_setor), "status": d.status}
            for d in deposits
        ],
    }


async def _get_harga_produk(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Return the price currently in effect and recent price history per product.

    Products are all non-deleted ``Produk`` rows, optionally narrowed by a
    case-insensitive substring match (``ILIKE``) on ``args["produk_nama"]``.
    For each product the five newest ``ProdukHarga`` rows by ``berlaku_mulai``
    form ``riwayat``. ``harga_terkini``/``berlaku_mulai`` come from the newest
    of those rows whose ``berlaku_mulai`` is not after today, so a price
    scheduled for a future date is not presented as the one in effect ("—"
    if none applies yet). ``spbu_id`` is not used by these queries.
    """
    from datetime import datetime  # only needed to normalise berlaku_mulai

    today = date.today()

    def _is_effective(h) -> bool:
        # NOTE(review): berlaku_mulai may be a date or a datetime column —
        # compare by calendar date either way; confirm against the model.
        start = h.berlaku_mulai
        if isinstance(start, datetime):
            start = start.date()
        return start is not None and start <= today

    q = select(Produk).where(Produk.deleted_at.is_(None))
    produk_filter = args.get("produk_nama")
    if produk_filter:
        q = q.where(Produk.nama.ilike(f"%{produk_filter}%"))

    produk_list = (await db.execute(q)).scalars().all()
    result = []
    for p in produk_list:
        harga_q = (
            select(ProdukHarga)
            .where(ProdukHarga.produk_id == p.id)
            .order_by(ProdukHarga.berlaku_mulai.desc())
            .limit(5)
        )
        harga_rows = (await db.execute(harga_q)).scalars().all()
        # Rows are newest-first, so the first effective one is the current price.
        current = next((h for h in harga_rows if _is_effective(h)), None)
        result.append({
            "produk": p.nama,
            "kode": p.kode,
            "is_subsidi": p.is_subsidi,
            "harga_terkini": _dec(current.harga) if current is not None else "—",
            "berlaku_mulai": str(current.berlaku_mulai) if current is not None else "—",
            "riwayat": [
                {"harga": _dec(h.harga), "berlaku_mulai": str(h.berlaku_mulai)}
                for h in harga_rows
            ],
        })

    return {"produk": result}


async def _get_info_spbu(db: AsyncSession, spbu_id: int, args: dict) -> dict:
    """Describe the SPBU's configuration: identity, shifts, tanks and nozzle count.

    Only active, non-deleted shifts, tanks and nozzles are included. Shifts are
    ordered by ``jam_mulai`` and tanks by id so the output is stable between
    calls. Returns ``{"error": ...}`` if the SPBU does not exist. ``args`` is
    unused.
    """
    spbu = (await db.execute(select(Spbu).where(Spbu.id == spbu_id))).scalar_one_or_none()
    if not spbu:
        return {"error": "SPBU tidak ditemukan"}

    shifts = (await db.execute(
        select(Shift)
        .where(Shift.spbu_id == spbu_id, Shift.is_active.is_(True), Shift.deleted_at.is_(None))
        .order_by(Shift.jam_mulai)
    )).scalars().all()

    tangkis = (await db.execute(
        select(Tangki)
        .where(Tangki.spbu_id == spbu_id, Tangki.is_active.is_(True), Tangki.deleted_at.is_(None))
        .order_by(Tangki.id)
    )).scalars().all()

    # Only the number of nozzles is reported, so count in SQL instead of
    # loading every Nozzle row.
    jumlah_nozzle = (await db.execute(
        select(func.count())
        .select_from(Nozzle)
        .where(Nozzle.island.has(spbu_id=spbu_id), Nozzle.is_active.is_(True), Nozzle.deleted_at.is_(None))
    )).scalar_one()

    return {
        "nama": spbu.name,
        "nomor_pertamina": spbu.nomor_pertamina,
        "alamat": spbu.alamat,
        "jumlah_shift": len(shifts),
        "shifts": [{"nama": s.nama, "jam": f"{s.jam_mulai}-{s.jam_selesai}"} for s in shifts],
        "jumlah_tangki": len(tangkis),
        "tangki": [{"nama": t.nama, "kapasitas": _dec(t.kapasitas_liter)} for t in tangkis],
        "jumlah_nozzle": jumlah_nozzle,
    }
