"""End-to-End repository — DB queries for the E2E reconciliation module."""

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.end_to_end import EndToEndCycle, StatusEndToEnd


def _eager_opts():
    """Build the loader options shared by every cycle query in this module.

    The ``tangki``, ``started_by`` and ``closed_by`` relationships are each
    loaded with ``selectinload``. Returned objects therefore arrive with those
    attributes populated, avoiding implicit lazy loads, which AsyncSession
    does not support. A fresh list is returned on every call.
    """
    relationships = (
        EndToEndCycle.tangki,
        EndToEndCycle.started_by,
        EndToEndCycle.closed_by,
    )
    return [selectinload(rel) for rel in relationships]


async def get_by_id(db: AsyncSession, id: int, spbu_id: int) -> EndToEndCycle | None:
    """Return the cycle with primary key ``id`` that belongs to ``spbu_id``.

    Filtering on ``spbu_id`` as well as the key means a cycle owned by a
    different SPBU is never returned. Relationships listed in
    ``_eager_opts`` are loaded with the row.

    Returns ``None`` when no matching row exists.
    """
    stmt = (
        select(EndToEndCycle)
        .options(*_eager_opts())
        .where(
            EndToEndCycle.id == id,
            EndToEndCycle.spbu_id == spbu_id,
        )
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_open_by_tangki(db: AsyncSession, tangki_id: int) -> EndToEndCycle | None:
    """Return the OPEN cycle for a tank, or ``None`` if the tank has none.

    Callers presumably expect at most one OPEN cycle per tank. If duplicate
    OPEN rows do exist, the newest one is returned rather than raising
    ``MultipleResultsFound``. The caller can then still detect, and close,
    the open cycle.

    NOTE(review): not filtered by ``spbu_id``; this assumes ``tangki_id`` is
    globally unique across SPBUs. Confirm against the model.
    """
    result = await db.execute(
        select(EndToEndCycle)
        .where(
            EndToEndCycle.tangki_id == tangki_id,
            EndToEndCycle.status == StatusEndToEnd.OPEN,
        )
        .options(*_eager_opts())
        # Same newest-first ordering as get_all, so a duplicate is resolved
        # deterministically.
        .order_by(EndToEndCycle.tanggal_mulai.desc(), EndToEndCycle.id.desc())
        .limit(1)
    )
    return result.scalars().first()


async def get_all(
    db: AsyncSession,
    spbu_id: int,
    tangki_id: int | None = None,
    status: str | None = None,
    skip: int = 0,
    limit: int = 50,
) -> tuple[list[EndToEndCycle], int]:
    """Return one page of E2E cycles for an SPBU, newest first, plus the total.

    Args:
        db: Async session used for both queries.
        spbu_id: SPBU whose cycles are listed.
        tangki_id: If given, only cycles for this tank are included.
        status: If given, only cycles whose ``status`` equals this value.
        skip: Number of matching rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        ``(items, total)``. ``items`` is the requested page, ordered by
        ``tanggal_mulai`` then ``id``, both descending. ``total`` is the number
        of rows matching the filters, ignoring ``skip``/``limit``.
    """
    filters = [EndToEndCycle.spbu_id == spbu_id]
    if tangki_id is not None:
        filters.append(EndToEndCycle.tangki_id == tangki_id)
    if status is not None:
        filters.append(EndToEndCycle.status == status)

    # Plain filtered COUNT(*). Wrapping the page query in a subquery would
    # carry its ORDER BY along and make the database sort rows just to
    # count them.
    count_result = await db.execute(
        select(func.count()).select_from(EndToEndCycle).where(*filters)
    )
    total = count_result.scalar_one()

    page_query = (
        select(EndToEndCycle)
        .where(*filters)
        .options(*_eager_opts())
        # ``id`` breaks ties so pagination is stable across pages.
        .order_by(EndToEndCycle.tanggal_mulai.desc(), EndToEndCycle.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_query)
    return list(result.scalars().all()), total


async def create(db: AsyncSession, data: dict) -> EndToEndCycle:
    """Insert a new cycle built from ``data`` and return it fully loaded.

    ``data`` is passed as keyword arguments to ``EndToEndCycle``. The session
    is flushed so the primary key is assigned, and the instance is refreshed.
    It is then re-read through ``get_by_id`` so the relationships in
    ``_eager_opts`` are loaded. This function does not commit.
    """
    new_cycle = EndToEndCycle(**data)
    db.add(new_cycle)
    await db.flush()
    await db.refresh(new_cycle)
    loaded = await get_by_id(db, new_cycle.id, new_cycle.spbu_id)
    return loaded  # type: ignore[return-value]


async def update(db: AsyncSession, cycle: EndToEndCycle, data: dict) -> EndToEndCycle:
    """Apply ``data`` to ``cycle``, flush, and return the reloaded cycle.

    Each key/value pair in ``data`` is assigned onto ``cycle`` with
    ``setattr``. The session is then flushed, not committed; committing is
    the caller's responsibility. The instance is refreshed so the flushed
    state is re-read, then fetched again through ``get_by_id`` so the
    relationships in ``_eager_opts`` are loaded.

    NOTE(review): keys are not validated; an unknown key is set as a plain
    Python attribute and never persisted.
    """
    for k, v in data.items():
        setattr(cycle, k, v)
    await db.flush()
    await db.refresh(cycle)
    return await get_by_id(db, cycle.id, cycle.spbu_id)  # type: ignore[return-value]
