"""Penyetoran repository — DB queries for the per-shift cash deposit module."""

from datetime import date
from typing import Optional

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.operational import LaporanShift
from app.models.penyetoran import Penyetoran, PenyetoranBatch


# ---------------------------------------------------------------------------
# Eager-load options
# ---------------------------------------------------------------------------

def _eager_opts():
    """Build the loader options used by every single-Penyetoran query.

    The options eagerly load, via ``selectinload``:
    - the ``created_by`` relationship,
    - the ``laporan_shift`` relationship together with that report's ``shift``,
    - the owning ``batch``.

    A new list is built on every call, so callers can unpack it with ``*``
    into ``.options()``.
    """
    creator = selectinload(Penyetoran.created_by)
    shift_chain = selectinload(Penyetoran.laporan_shift).selectinload(
        LaporanShift.shift
    )
    owning_batch = selectinload(Penyetoran.batch)
    return [creator, shift_chain, owning_batch]


def _batch_eager_opts():
    """Build the loader options used by every PenyetoranBatch query.

    The options eagerly load, via ``selectinload``:
    - the ``submitted_by`` and ``reviewed_by`` relationships,
    - each child item in ``items`` with its ``laporan_shift`` and that
      report's ``shift``,
    - each child item's ``created_by``.

    A new list is built on every call.
    """
    items_rel = PenyetoranBatch.items
    opts = [
        selectinload(PenyetoranBatch.submitted_by),
        selectinload(PenyetoranBatch.reviewed_by),
    ]
    # Both item paths share the ``PenyetoranBatch.items`` prefix.
    opts.append(
        selectinload(items_rel)
        .selectinload(Penyetoran.laporan_shift)
        .selectinload(LaporanShift.shift)
    )
    opts.append(selectinload(items_rel).selectinload(Penyetoran.created_by))
    return opts


# ---------------------------------------------------------------------------
# Single penyetoran queries
# ---------------------------------------------------------------------------

async def get_by_id(db: AsyncSession, id: int, spbu_id: int) -> Penyetoran | None:
    """Return the Penyetoran with primary key ``id`` that belongs to ``spbu_id``.

    The relationships from ``_eager_opts`` are loaded up front. Returns
    ``None`` when no row matches both the id and the SPBU.
    """
    stmt = select(Penyetoran).where(
        Penyetoran.id == id,
        Penyetoran.spbu_id == spbu_id,
    )
    stmt = stmt.options(*_eager_opts())
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_by_laporan_shift_id(
    db: AsyncSession,
    laporan_shift_id: int,
    *,
    spbu_id: int | None = None,
) -> Penyetoran | None:
    """Fetch the Penyetoran attached to a given laporan shift.

    ``laporan_shift_id`` is expected to identify at most one Penyetoran;
    ``scalar_one_or_none`` raises if more than one row matches.

    Args:
        db: Active async session.
        laporan_shift_id: ID of the LaporanShift the deposit belongs to.
        spbu_id: Optional SPBU scope. When given, a row from a different SPBU
            is treated as not found, matching the SPBU scoping used by
            ``get_by_id``. The default ``None`` keeps the unscoped lookup.

    Returns:
        The matching Penyetoran with the standard relationships eager-loaded,
        or ``None`` if nothing matches.
    """
    stmt = (
        select(Penyetoran)
        .where(Penyetoran.laporan_shift_id == laporan_shift_id)
        .options(*_eager_opts())
    )
    if spbu_id is not None:
        # Tenant guard: keeps cross-SPBU lookups from leaking another SPBU's row.
        stmt = stmt.where(Penyetoran.spbu_id == spbu_id)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def get_by_tanggal(
    db: AsyncSession, spbu_id: int, tanggal: date
) -> Penyetoran | None:
    """Return the lowest-id Penyetoran for ``spbu_id`` on ``tanggal``.

    This is kept for backward compatibility with callers that assumed one
    deposit per day. It returns ``None`` when the SPBU has no deposit on
    that date.
    """
    same_day = (Penyetoran.spbu_id == spbu_id, Penyetoran.tanggal == tanggal)
    stmt = (
        select(Penyetoran)
        .where(*same_day)
        .options(*_eager_opts())
        .order_by(Penyetoran.id)
        .limit(1)  # at most one row, so scalar_one_or_none cannot raise
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_all(
    db: AsyncSession,
    spbu_id: int,
    skip: int = 0,
    limit: int = 50,
    tanggal_from: Optional[date] = None,
    tanggal_to: Optional[date] = None,
    status: Optional[str] = None,
) -> tuple[list[Penyetoran], int]:
    """Return one page of Penyetoran rows for an SPBU plus the total match count.

    Args:
        db: Active async session.
        spbu_id: SPBU whose deposits are listed.
        skip: Number of matching rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        tanggal_from: If given, only rows with ``tanggal >= tanggal_from``.
        tanggal_to: If given, only rows with ``tanggal <= tanggal_to``.
        status: If given, only rows whose ``status`` equals this value.

    Returns:
        ``(items, total)``. ``items`` is the requested page, ordered by
        ``tanggal`` descending, then ``id`` descending. ``total`` is the
        number of rows matching the filters, ignoring ``skip``/``limit``.
    """
    conditions = [Penyetoran.spbu_id == spbu_id]
    if tanggal_from is not None:
        conditions.append(Penyetoran.tanggal >= tanggal_from)
    if tanggal_to is not None:
        conditions.append(Penyetoran.tanggal <= tanggal_to)
    if status is not None:
        conditions.append(Penyetoran.status == status)

    # Count the filtered rows directly. This avoids wrapping an ORDER BY in a
    # derived table: it is useless for COUNT and some backends reject it there.
    count_q = await db.execute(
        select(func.count()).select_from(Penyetoran).where(*conditions)
    )
    total = count_q.scalar_one()

    page_q = (
        select(Penyetoran)
        .where(*conditions)
        .options(*_eager_opts())
        .order_by(Penyetoran.tanggal.desc(), Penyetoran.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_q)
    return list(result.scalars().all()), total


async def update(db: AsyncSession, p: Penyetoran, data: dict) -> Penyetoran:
    """Assign each key/value in ``data`` onto ``p``, flush, and return it reloaded.

    Nothing is committed here; the change is only flushed to the current
    transaction. The returned object is re-fetched through ``get_by_id`` so
    it carries the standard eager-loaded relationships.
    """
    for field_name in data:
        setattr(p, field_name, data[field_name])
    await db.flush()
    # Reload p's column state from the database before re-querying it.
    await db.refresh(p)
    reloaded = await get_by_id(db, p.id, p.spbu_id)
    return reloaded  # type: ignore[return-value]


async def update_bukti(db: AsyncSession, p: Penyetoran, url: str) -> Penyetoran:
    """Store ``url`` as the proof (``bukti_url``) of ``p`` and return it reloaded.

    The change is flushed, not committed. The result is re-fetched via
    ``get_by_id`` so the standard relationships are eager-loaded.
    """
    setattr(p, "bukti_url", url)
    await db.flush()
    await db.refresh(p)
    reloaded = await get_by_id(db, p.id, p.spbu_id)
    return reloaded  # type: ignore[return-value]


# ---------------------------------------------------------------------------
# Batch queries
# ---------------------------------------------------------------------------

async def get_batch_by_id(
    db: AsyncSession, batch_id: int, spbu_id: int
) -> PenyetoranBatch | None:
    """Return the PenyetoranBatch ``batch_id`` owned by ``spbu_id``.

    The relationships from ``_batch_eager_opts`` are loaded eagerly. Returns
    ``None`` when no batch matches both the id and the SPBU.
    """
    matches_batch = PenyetoranBatch.id == batch_id
    matches_spbu = PenyetoranBatch.spbu_id == spbu_id
    stmt = (
        select(PenyetoranBatch)
        .where(matches_batch, matches_spbu)
        .options(*_batch_eager_opts())
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_all_batches(
    db: AsyncSession,
    spbu_id: int,
    skip: int = 0,
    limit: int = 50,
    status: Optional[str] = None,
) -> tuple[list[PenyetoranBatch], int]:
    """Return one page of PenyetoranBatch rows for an SPBU plus the total count.

    Args:
        db: Active async session.
        spbu_id: SPBU whose batches are listed.
        skip: Number of matching rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        status: If given, only batches whose ``status`` equals this value.

    Returns:
        ``(items, total)``. ``items`` is the requested page ordered by ``id``
        descending. ``total`` is the number of batches matching the filters,
        ignoring ``skip``/``limit``.
    """
    conditions = [PenyetoranBatch.spbu_id == spbu_id]
    if status is not None:
        conditions.append(PenyetoranBatch.status == status)

    # Count the filtered rows directly. This avoids wrapping an ORDER BY in a
    # derived table: it is useless for COUNT and some backends reject it there.
    count_q = await db.execute(
        select(func.count()).select_from(PenyetoranBatch).where(*conditions)
    )
    total = count_q.scalar_one()

    page_q = (
        select(PenyetoranBatch)
        .where(*conditions)
        .options(*_batch_eager_opts())
        .order_by(PenyetoranBatch.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_q)
    return list(result.scalars().all()), total


async def create_batch(
    db: AsyncSession, spbu_id: int, data: dict
) -> PenyetoranBatch:
    """Create a PenyetoranBatch for ``spbu_id`` from ``data`` and return it loaded.

    ``data`` is passed as keyword arguments to the model constructor alongside
    ``spbu_id``. The new row is flushed (not committed) to obtain its id, then
    re-fetched with the standard batch eager-load options.
    """
    new_batch = PenyetoranBatch(spbu_id=spbu_id, **data)
    db.add(new_batch)
    await db.flush()
    await db.refresh(new_batch)
    loaded = await get_batch_by_id(db, new_batch.id, spbu_id)
    return loaded  # type: ignore[return-value]


async def update_batch(
    db: AsyncSession, batch: PenyetoranBatch, data: dict
) -> PenyetoranBatch:
    """Assign each key/value in ``data`` onto ``batch``, flush, and return it reloaded.

    Nothing is committed here. The result is re-fetched through
    ``get_batch_by_id`` so it carries the standard batch relationships.
    """
    for field_name in data:
        setattr(batch, field_name, data[field_name])
    await db.flush()
    # Reload the batch's column state from the database before re-querying it.
    await db.refresh(batch)
    loaded = await get_batch_by_id(db, batch.id, batch.spbu_id)
    return loaded  # type: ignore[return-value]
