"""Operational repository — DB queries for LaporanShift and PenjualanNozzle."""

from datetime import date, datetime, timezone
from decimal import Decimal

from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.operational import LaporanShift, PenjualanNozzle, StatusLaporan
from app.models.spbu import Island, Nozzle, Shift


async def get_all_laporan(
    db: AsyncSession,
    spbu_id: int,
    tanggal: date | None = None,
    shift_id: int | None = None,
    status: StatusLaporan | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[LaporanShift], int]:
    """Fetch one page of laporan_shift rows for an SPBU plus the total match count.

    Args:
        db: Active async session.
        spbu_id: Only rows belonging to this SPBU are considered.
        tanggal: Optional exact-date filter.
        shift_id: Optional shift filter.
        status: Optional status filter.
        skip: Number of matching rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).

    Returns:
        ``(rows, total)``. ``rows`` is the requested page ordered by
        ``tanggal`` descending, then ``id`` descending. Its shift, submitted_by,
        reviewed_by, recalled_by and penjualan_nozzle relationships are eagerly
        loaded via ``selectinload``. ``total`` is the number of rows matching
        the filters, independent of ``skip``/``limit``.
    """
    conditions = [LaporanShift.spbu_id == spbu_id]
    if tanggal is not None:
        conditions.append(LaporanShift.tanggal == tanggal)
    if shift_id is not None:
        conditions.append(LaporanShift.shift_id == shift_id)
    if status is not None:
        conditions.append(LaporanShift.status == status)

    filtered = select(LaporanShift).where(*conditions)

    # Count over the bare filtered select. Ordering and loader options are
    # irrelevant to a count, and an ORDER BY inside a derived table is
    # wasteful (and rejected outright by some backends).
    count_q = await db.execute(select(func.count()).select_from(filtered.subquery()))
    total = count_q.scalar_one()

    page_query = (
        filtered.options(
            selectinload(LaporanShift.shift),
            selectinload(LaporanShift.submitted_by),
            selectinload(LaporanShift.reviewed_by),
            selectinload(LaporanShift.recalled_by),
            selectinload(LaporanShift.penjualan_nozzle),
        )
        .order_by(LaporanShift.tanggal.desc(), LaporanShift.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_query)
    return list(result.scalars().unique().all()), total


async def get_laporan_by_id(
    db: AsyncSession, laporan_id: int, spbu_id: int
) -> LaporanShift | None:
    """Load one LaporanShift by primary key, scoped to an SPBU.

    The shift, submitted_by, reviewed_by and recalled_by relationships are
    eagerly loaded. So are the penjualan_nozzle rows, each with its nozzle and
    that nozzle's island and produk. Returns ``None`` when no row matches both
    ``laporan_id`` and ``spbu_id``.
    """
    eager_loaders = (
        selectinload(LaporanShift.shift),
        selectinload(LaporanShift.submitted_by),
        selectinload(LaporanShift.reviewed_by),
        selectinload(LaporanShift.recalled_by),
        # Two chains share the penjualan_nozzle -> nozzle prefix so that both
        # island and produk are pulled in for every nozzle row.
        selectinload(LaporanShift.penjualan_nozzle)
        .selectinload(PenjualanNozzle.nozzle)
        .selectinload(Nozzle.island),
        selectinload(LaporanShift.penjualan_nozzle)
        .selectinload(PenjualanNozzle.nozzle)
        .selectinload(Nozzle.produk),
    )
    stmt = (
        select(LaporanShift)
        .options(*eager_loaders)
        .where(LaporanShift.id == laporan_id, LaporanShift.spbu_id == spbu_id)
    )
    outcome = await db.execute(stmt)
    return outcome.scalar_one_or_none()


async def create_laporan(
    db: AsyncSession,
    spbu_id: int,
    shift_id: int,
    tanggal,
    nozzle_rows: list[dict],
) -> LaporanShift:
    """Create a DRAFT LaporanShift together with its PenjualanNozzle rows.

    Each dict in ``nozzle_rows`` is expanded as keyword arguments into a
    PenjualanNozzle, with ``laporan_shift_id`` supplied here. The result is
    re-fetched through ``get_laporan_by_id`` so relationships are eagerly
    loaded. Caller owns commit.
    """
    new_laporan = LaporanShift(
        spbu_id=spbu_id,
        shift_id=shift_id,
        tanggal=tanggal,
        status=StatusLaporan.DRAFT,
    )
    db.add(new_laporan)
    # Flush the parent first: the child rows need its generated primary key.
    await db.flush()

    parent_id = new_laporan.id
    db.add_all(
        PenjualanNozzle(laporan_shift_id=parent_id, **fields) for fields in nozzle_rows
    )
    await db.flush()

    return await get_laporan_by_id(db, parent_id, spbu_id)


async def update_laporan_nozzles(
    db: AsyncSession, laporan: LaporanShift, nozzle_rows: list[dict]
) -> LaporanShift:
    """Replace all PenjualanNozzle rows for a laporan (delete-then-insert). Caller owns commit.

    Args:
        db: Active async session.
        laporan: Persisted LaporanShift whose nozzle rows are replaced.
        nozzle_rows: One dict per new row, expanded as keyword arguments into
            PenjualanNozzle. ``laporan_shift_id`` is supplied here, so it must
            not appear in the dicts (that would be a duplicate keyword).

    Returns:
        The laporan re-fetched via ``get_laporan_by_id``, with its
        penjualan_nozzle collection reflecting the newly inserted rows.
    """
    await db.execute(
        delete(PenjualanNozzle).where(PenjualanNozzle.laporan_shift_id == laporan.id)
    )
    db.add_all(
        PenjualanNozzle(laporan_shift_id=laporan.id, **row) for row in nozzle_rows
    )
    await db.flush()

    # The bulk DELETE and the FK-only inserts above never touch the
    # ``laporan.penjualan_nozzle`` collection. ``laporan`` stays in the
    # identity map, and selectinload does not overwrite an already-loaded
    # attribute, so a previously loaded collection would still list the
    # deleted rows. Expire it so the reload below repopulates it from the DB.
    db.expire(laporan, ["penjualan_nozzle"])
    return await get_laporan_by_id(db, laporan.id, laporan.spbu_id)


async def update_status(
    db: AsyncSession,
    laporan: LaporanShift,
    status: StatusLaporan,
    user_id: int | None = None,
    catatan: str | None = None,
    unlock_reason: str | None = None,
    recall: bool = False,
) -> LaporanShift:
    """Transition the status of a LaporanShift, setting relevant audit fields. Caller owns commit."""
    now = datetime.now(timezone.utc)
    laporan.status = status

    if status == StatusLaporan.SUBMITTED and user_id is not None:
        laporan.submitted_by_id = user_id
        laporan.submitted_at = now
    elif status in (StatusLaporan.APPROVED, StatusLaporan.REJECTED) and user_id is not None:
        laporan.reviewed_by_id = user_id
        laporan.reviewed_at = now
        if catatan is not None:
            laporan.catatan_review = catatan
    elif status == StatusLaporan.DRAFT:
        if recall and user_id is not None:
            # Operator recall: clear submitted_* so it can be re-submitted cleanly
            laporan.recalled_by_id = user_id
            laporan.recalled_at = now
            laporan.submitted_by_id = None
            laporan.submitted_at = None
        elif unlock_reason is not None:
            laporan.unlock_reason = unlock_reason

    await db.flush()
    return laporan


async def update_nozzle_teller(
    db: AsyncSession,
    nozzle_id: int,
    teller_terakhir_manual: Decimal,
    teller_terakhir_digital: Decimal,
) -> None:
    """Overwrite a Nozzle's last manual and digital teller readings, then flush.

    Issues a single UPDATE against the Nozzle row with id ``nozzle_id``.
    Caller owns commit.
    """
    stmt = (
        update(Nozzle)
        .where(Nozzle.id == nozzle_id)
        .values(
            teller_terakhir_manual=teller_terakhir_manual,
            teller_terakhir_digital=teller_terakhir_digital,
        )
    )
    await db.execute(stmt)
    await db.flush()


async def get_teller_init(
    db: AsyncSession,
    spbu_id: int,
    prev_tanggal: date,
    prev_shift_id: int,
) -> list[dict]:
    """Return starting teller values for each active nozzle of an SPBU.

    Nozzles are the active, non-deleted ones on the SPBU's islands, ordered by
    island ``urutan`` then nozzle ``nama``. The previous laporan is the one for
    ``(spbu_id, prev_tanggal, prev_shift_id)`` whose status is SUBMITTED,
    APPROVED or LOCKED. DRAFT (and any other status) is ignored.

    Returns:
        One dict per nozzle with these keys:

        - ``nozzle_id``
        - ``teller_terakhir_manual`` and ``teller_terakhir_digital``: the
          previous row's ``teller_akhir_*`` values, or ``None`` when the
          nozzle has no row in the previous laporan.
        - ``primary_teller``: the nozzle's value, falling back to ``"manual"``
          when it is falsy.
        - ``is_first_time``: ``True`` when no previous row exists for that
          nozzle. This includes the case where no qualifying laporan exists.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: if more than one qualifying
            laporan matches, as raised by ``scalar_one_or_none``.
    """
    # 1. All active nozzles, in display order.
    nozzle_result = await db.execute(
        select(Nozzle)
        .join(Island, Nozzle.island_id == Island.id)
        .where(Island.spbu_id == spbu_id, Nozzle.is_active.is_(True), Nozzle.deleted_at.is_(None))
        .order_by(Island.urutan, Nozzle.nama)
    )
    nozzles = nozzle_result.scalars().all()

    # 2. The previous shift's laporan, if it reached submitted/approved/locked.
    laporan_result = await db.execute(
        select(LaporanShift)
        .options(selectinload(LaporanShift.penjualan_nozzle))
        .where(
            LaporanShift.spbu_id == spbu_id,
            LaporanShift.tanggal == prev_tanggal,
            LaporanShift.shift_id == prev_shift_id,
            LaporanShift.status.in_([
                StatusLaporan.SUBMITTED, StatusLaporan.APPROVED, StatusLaporan.LOCKED
            ]),
        )
    )
    prev_laporan = laporan_result.scalar_one_or_none()

    # 3. nozzle_id -> PenjualanNozzle row from that laporan. Later rows win on
    #    duplicate nozzle_id, matching a plain insertion loop.
    rows_by_nozzle: dict[int, PenjualanNozzle] = (
        {row.nozzle_id: row for row in prev_laporan.penjualan_nozzle or []}
        if prev_laporan is not None
        else {}
    )

    result: list[dict] = []
    for nozzle in nozzles:
        prev_row = rows_by_nozzle.get(nozzle.id)  # single lookup per nozzle
        result.append(
            {
                "nozzle_id": nozzle.id,
                "teller_terakhir_manual": prev_row.teller_akhir_manual if prev_row is not None else None,
                "teller_terakhir_digital": prev_row.teller_akhir_digital if prev_row is not None else None,
                "primary_teller": nozzle.primary_teller or "manual",
                "is_first_time": prev_row is None,
            }
        )
    return result
