"""Rekonsiliasi Harian repository — DB queries for daily reconciliation."""

from datetime import date

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.rekonsiliasi import RekonsiliasiHarian, RekonsiliasiTangki


def _eager_opts():
    """Return the loader options shared by the rekonsiliasi read queries.

    A fresh list is built on every call. It select-in loads:

    * ``RekonsiliasiHarian.items`` and, for each item, its ``tangki``
    * ``RekonsiliasiHarian.run_by``
    * ``RekonsiliasiHarian.approved_by``
    """
    items_with_tangki = selectinload(RekonsiliasiHarian.items).selectinload(
        RekonsiliasiTangki.tangki
    )
    run_by = selectinload(RekonsiliasiHarian.run_by)
    approved_by = selectinload(RekonsiliasiHarian.approved_by)
    return [items_with_tangki, run_by, approved_by]


async def get_by_id(
    db: AsyncSession, id: int, spbu_id: int
) -> RekonsiliasiHarian | None:
    """Fetch a single rekonsiliasi by ``id``, restricted to one SPBU.

    Args:
        db: Async session used to run the query.
        id: Value matched against ``RekonsiliasiHarian.id``.
        spbu_id: Value matched against ``RekonsiliasiHarian.spbu_id``; a row
            with the right ``id`` but a different ``spbu_id`` is not returned.

    Returns:
        The matching row with the relationships from ``_eager_opts()``
        loaded, or ``None`` when nothing matches.
    """
    stmt = select(RekonsiliasiHarian).where(
        RekonsiliasiHarian.id == id,
        RekonsiliasiHarian.spbu_id == spbu_id,
    )
    stmt = stmt.options(*_eager_opts())
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_by_tanggal(
    db: AsyncSession, spbu_id: int, tanggal: date
) -> RekonsiliasiHarian | None:
    """Fetch the rekonsiliasi of one SPBU for an exact date.

    Args:
        db: Async session used to run the query.
        spbu_id: Value matched against ``RekonsiliasiHarian.spbu_id``.
        tanggal: Date matched (by equality) against
            ``RekonsiliasiHarian.tanggal``.

    Returns:
        The matching row with the relationships from ``_eager_opts()``
        loaded, or ``None`` when nothing matches.
    """
    same_spbu = RekonsiliasiHarian.spbu_id == spbu_id
    same_date = RekonsiliasiHarian.tanggal == tanggal
    stmt = select(RekonsiliasiHarian).where(same_spbu, same_date)
    stmt = stmt.options(*_eager_opts())
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_all(
    db: AsyncSession,
    spbu_id: int,
    tanggal_mulai: date | None = None,
    tanggal_akhir: date | None = None,
    status: str | None = None,
    skip: int = 0,
    limit: int = 50,
) -> tuple[list[RekonsiliasiHarian], int]:
    """List the rekonsiliasi rows of one SPBU with optional filters and paging.

    Args:
        db: Async session used to run the queries.
        spbu_id: Only rows with this ``spbu_id`` are considered.
        tanggal_mulai: If given, keep rows with ``tanggal >= tanggal_mulai``.
        tanggal_akhir: If given, keep rows with ``tanggal <= tanggal_akhir``.
        status: If given, keep rows whose ``status`` equals this value.
        skip: Number of rows to skip (SQL ``OFFSET``).
        limit: Maximum number of rows to return (SQL ``LIMIT``).

    Returns:
        ``(rows, total)`` where ``rows`` is the requested page ordered by
        ``tanggal`` descending, with the relationships from ``_eager_opts()``
        loaded, and ``total`` is the number of rows matching the filters
        before paging.
    """
    # Collect the filter criteria once so the count and the page query are
    # guaranteed to agree on which rows match.
    criteria = [RekonsiliasiHarian.spbu_id == spbu_id]
    if tanggal_mulai is not None:
        criteria.append(RekonsiliasiHarian.tanggal >= tanggal_mulai)
    if tanggal_akhir is not None:
        criteria.append(RekonsiliasiHarian.tanggal <= tanggal_akhir)
    if status is not None:
        criteria.append(RekonsiliasiHarian.status == status)

    # Count directly against the table: no ORDER BY, no loader options and no
    # wrapping subquery are needed just to count the matching rows.
    count_stmt = (
        select(func.count()).select_from(RekonsiliasiHarian).where(*criteria)
    )
    total = (await db.execute(count_stmt)).scalar_one()

    page_stmt = (
        select(RekonsiliasiHarian)
        .where(*criteria)
        .options(*_eager_opts())
        .order_by(RekonsiliasiHarian.tanggal.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_stmt)
    return list(result.scalars().all()), total


async def upsert(
    db: AsyncSession, rekon: RekonsiliasiHarian, items_data: list[dict]
) -> RekonsiliasiHarian:
    """Replace the items of ``rekon`` with new ones built from ``items_data``.

    Any items currently on ``rekon.items`` are deleted and flushed first, then
    one ``RekonsiliasiTangki`` per entry of ``items_data`` is added, pointing
    at ``rekon.id``. The parent row itself is not created here.

    Args:
        db: Async session holding ``rekon``.
        rekon: Parent rekonsiliasi whose items are replaced.
            NOTE(review): ``rekon.items`` is read directly — presumably callers
            pass an instance whose ``items`` are already loaded; confirm.
        items_data: Keyword arguments for each new ``RekonsiliasiTangki``
            (``rekonsiliasi_harian_id`` is supplied by this function).

    Returns:
        The rekonsiliasi as re-read through ``get_by_id`` after flush and
        refresh.
    """
    # Snapshot the collection before deleting so we never iterate a list
    # that is being changed underneath us.
    previous_items = list(rekon.items or ())
    if previous_items:
        for stale in previous_items:
            await db.delete(stale)
        # Push the deletes out before the replacement rows are inserted.
        await db.flush()

    parent_id = rekon.id
    for payload in items_data:
        db.add(RekonsiliasiTangki(rekonsiliasi_harian_id=parent_id, **payload))

    await db.flush()
    await db.refresh(rekon)
    reloaded = await get_by_id(db, rekon.id, rekon.spbu_id)
    return reloaded  # type: ignore[return-value]


async def create(db: AsyncSession, data: dict) -> RekonsiliasiHarian:
    """Insert a new rekonsiliasi built from ``data`` and return it reloaded.

    Args:
        db: Async session the new row is added to. Only ``flush`` is called
            here; no commit is issued by this function.
        data: Keyword arguments passed to the ``RekonsiliasiHarian``
            constructor.

    Returns:
        The new row as re-read through ``get_by_id`` after flush and refresh.
    """
    new_rekon = RekonsiliasiHarian(**data)
    db.add(new_rekon)
    await db.flush()
    await db.refresh(new_rekon)
    reloaded = await get_by_id(db, new_rekon.id, new_rekon.spbu_id)
    return reloaded  # type: ignore[return-value]


async def update_status(
    db: AsyncSession, rekon: RekonsiliasiHarian, data: dict
) -> RekonsiliasiHarian:
    """Apply the attribute values in ``data`` to ``rekon`` and return it reloaded.

    Args:
        db: Async session holding ``rekon``. Only ``flush`` is called here;
            no commit is issued by this function.
        rekon: Instance to modify in place.
        data: Mapping of attribute name to new value; every entry is set on
            ``rekon`` with ``setattr``.

    Returns:
        The row as re-read through ``get_by_id`` after flush and refresh.
    """
    for field_name, new_value in data.items():
        setattr(rekon, field_name, new_value)

    await db.flush()
    await db.refresh(rekon)
    reloaded = await get_by_id(db, rekon.id, rekon.spbu_id)
    return reloaded  # type: ignore[return-value]
