"""Stock adjustment repository — DB queries for StockAdjustment and StockAdjustmentItem."""

from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.operational import StockAdjustment, StockAdjustmentItem, StockAdjustmentItemFoto, StatusLaporan
from app.models.spbu import Tangki


async def get_all(
    db: AsyncSession,
    spbu_id: int,
    tanggal=None,
    shift_id: int | None = None,
    status: StatusLaporan | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[StockAdjustment], int]:
    """Fetch one page of StockAdjustment rows for an SPBU plus the total match count.

    Args:
        db: Async session used to run both the count and the page query.
        spbu_id: Only adjustments with this ``spbu_id`` are considered.
        tanggal: When not None, restrict to rows whose ``tanggal`` equals it.
        shift_id: When not None, restrict to rows with this ``shift_id``.
        status: When not None, restrict to rows with this status.
        skip: Number of matching rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).

    Returns:
        ``(rows, total)``: ``rows`` is the requested page ordered by
        ``tanggal`` then ``id`` (both descending) with shift, submitter,
        reviewer, recaller and items eagerly loaded; ``total`` is the number
        of rows matching the filters, ignoring ``skip``/``limit``.
    """
    # Build the filter list once so the count and the page query can never drift apart.
    filters = [StockAdjustment.spbu_id == spbu_id]
    if tanggal is not None:
        filters.append(StockAdjustment.tanggal == tanggal)
    if shift_id is not None:
        filters.append(StockAdjustment.shift_id == shift_id)
    if status is not None:
        filters.append(StockAdjustment.status == status)

    # Count directly against the table: wrapping the ordered, option-laden page
    # query in a subquery would select and sort every column just to count rows.
    count_result = await db.execute(
        select(func.count()).select_from(StockAdjustment).where(*filters)
    )
    total = count_result.scalar_one()

    page_query = (
        select(StockAdjustment)
        .where(*filters)
        .options(
            selectinload(StockAdjustment.shift),
            selectinload(StockAdjustment.submitted_by),
            selectinload(StockAdjustment.reviewed_by),
            selectinload(StockAdjustment.recalled_by),
            selectinload(StockAdjustment.items),
        )
        .order_by(StockAdjustment.tanggal.desc(), StockAdjustment.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_query)
    return list(result.scalars().unique().all()), total


async def get_by_id(
    db: AsyncSession, adj_id: int, spbu_id: int
) -> StockAdjustment | None:
    """Load one StockAdjustment matching both ``adj_id`` and ``spbu_id``.

    The shift, submitter, reviewer and recaller relationships are eagerly
    loaded, as are the items together with each item's tangki (and that
    tangki's produk) and fotos.

    Returns:
        The matching adjustment, or None when no row has that id within the
        given SPBU.
    """
    # Nested loader: items -> (tangki -> produk, fotos).
    items_loader = selectinload(StockAdjustment.items).options(
        selectinload(StockAdjustmentItem.tangki).selectinload(Tangki.produk),
        selectinload(StockAdjustmentItem.fotos),
    )
    stmt = (
        select(StockAdjustment)
        .where(StockAdjustment.id == adj_id, StockAdjustment.spbu_id == spbu_id)
        .options(
            selectinload(StockAdjustment.shift),
            selectinload(StockAdjustment.submitted_by),
            selectinload(StockAdjustment.reviewed_by),
            selectinload(StockAdjustment.recalled_by),
            items_loader,
        )
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def create(
    db: AsyncSession,
    spbu_id: int,
    shift_id: int,
    tanggal,
    items: list[dict],
) -> StockAdjustment:
    """Create a DRAFT StockAdjustment and one StockAdjustmentItem per entry in ``items``.

    Each dict in ``items`` is expanded as keyword arguments to
    ``StockAdjustmentItem``. Everything is flushed but not committed here, so
    the rows share whatever transaction the caller's session is running.

    Returns:
        The new adjustment as re-read through ``get_by_id``.
    """
    header = StockAdjustment(
        spbu_id=spbu_id,
        shift_id=shift_id,
        tanggal=tanggal,
        status=StatusLaporan.DRAFT,
    )
    db.add(header)
    # Flush first: the child rows need the generated primary key.
    await db.flush()

    db.add_all(
        StockAdjustmentItem(stock_adjustment_id=header.id, **fields)
        for fields in items
    )
    await db.flush()

    return await get_by_id(db, header.id, spbu_id)


async def update_items(
    db: AsyncSession, adj: StockAdjustment, items: list[dict]
) -> StockAdjustment:
    """Replace every StockAdjustmentItem of ``adj`` with the rows described by ``items``.

    Existing item rows are removed with a single bulk DELETE, then one new
    ``StockAdjustmentItem`` is added per dict in ``items`` (each dict is
    expanded as keyword arguments). Changes are flushed but not committed, so
    they share the caller's transaction.

    Args:
        db: Async session that ``adj`` belongs to.
        adj: The adjustment whose items are replaced.
        items: Field dicts for the replacement items.

    Returns:
        The adjustment as re-read through ``get_by_id``, with its ``items``
        collection reflecting the replacement rows.
    """
    # NOTE(review): a bulk DELETE does not run ORM relationship cascades;
    # confirm that dependent photo rows are removed by the database FK.
    await db.execute(
        delete(StockAdjustmentItem).where(StockAdjustmentItem.stock_adjustment_id == adj.id)
    )
    for item in items:
        db.add(StockAdjustmentItem(stock_adjustment_id=adj.id, **item))
    await db.flush()

    # ``adj`` usually arrives with ``items`` already loaded. Eager loaders do not
    # overwrite an attribute that is already populated on an instance in the
    # identity map, so without this the re-fetch below would hand back the old,
    # now-deleted items. Expiring only ``items`` keeps ``adj.id`` and
    # ``adj.spbu_id`` loaded for the call that follows.
    db.expire(adj, ["items"])
    return await get_by_id(db, adj.id, adj.spbu_id)


async def update_status(
    db: AsyncSession,
    adj: StockAdjustment,
    status: StatusLaporan,
    user_id: int | None = None,
    catatan: str | None = None,
    unlock_reason: str | None = None,
    recall: bool = False,
) -> StockAdjustment:
    """Set ``adj.status`` and stamp the audit columns that belong to that transition.

    * SUBMITTED with a ``user_id``: records submitter and submission time.
    * APPROVED / REJECTED with a ``user_id``: records reviewer and review time,
      plus ``catatan`` as the review note when one is given.
    * DRAFT with ``recall`` and a ``user_id``: records who recalled it and
      when, and clears the submitter fields.
    * DRAFT otherwise: stores ``unlock_reason`` when one is given.

    Any other combination only changes the status itself. The change is
    flushed, ``adj`` is refreshed, and the adjustment is returned as re-read
    through ``get_by_id``.
    """
    timestamp = datetime.now(timezone.utc)
    actor_known = user_id is not None
    adj.status = status

    if status == StatusLaporan.SUBMITTED:
        if actor_known:
            adj.submitted_by_id = user_id
            adj.submitted_at = timestamp
    elif status in (StatusLaporan.APPROVED, StatusLaporan.REJECTED):
        if actor_known:
            adj.reviewed_by_id = user_id
            adj.reviewed_at = timestamp
            if catatan is not None:
                adj.catatan_review = catatan
    elif status == StatusLaporan.DRAFT:
        if recall and actor_known:
            # Recall: remember who pulled it back and wipe the submission stamp.
            adj.recalled_by_id = user_id
            adj.recalled_at = timestamp
            adj.submitted_by_id = None
            adj.submitted_at = None
        elif unlock_reason is not None:
            adj.unlock_reason = unlock_reason

    await db.flush()
    await db.refresh(adj)
    return await get_by_id(db, adj.id, adj.spbu_id)


async def add_item_foto(
    db: AsyncSession,
    item_id: int,
    tipe: str,
    url: str,
) -> StockAdjustmentItemFoto:
    """Create a StockAdjustmentItemFoto row linked to the item ``item_id``.

    The row is flushed and refreshed (not committed) before being returned.
    """
    new_foto = StockAdjustmentItemFoto(
        stock_adjustment_item_id=item_id, tipe=tipe, url=url
    )
    db.add(new_foto)
    await db.flush()
    await db.refresh(new_foto)
    return new_foto


async def delete_item_foto(
    db: AsyncSession,
    foto_id: int,
    item_id: int,
) -> StockAdjustmentItemFoto | None:
    """Remove the photo ``foto_id`` if it is attached to item ``item_id``.

    Returns:
        The deleted photo object, or None when no photo has that id under the
        given item (in which case nothing is deleted).
    """
    lookup = select(StockAdjustmentItemFoto).where(
        StockAdjustmentItemFoto.id == foto_id,
        StockAdjustmentItemFoto.stock_adjustment_item_id == item_id,
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is not None:
        await db.delete(target)
        await db.flush()
    return target


async def get_item_by_id(
    db: AsyncSession,
    item_id: int,
    adj_id: int,
) -> StockAdjustmentItem | None:
    """Load the StockAdjustmentItem ``item_id`` if it belongs to adjustment ``adj_id``.

    The item's fotos are eagerly loaded. Returns None when no such item
    exists under that adjustment.
    """
    stmt = (
        select(StockAdjustmentItem)
        .where(
            StockAdjustmentItem.id == item_id,
            StockAdjustmentItem.stock_adjustment_id == adj_id,
        )
        .options(selectinload(StockAdjustmentItem.fotos))
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_stock_init(db: AsyncSession, spbu_id: int) -> list[dict]:
    """
    Build one pre-fill row per active, non-deleted tank of an SPBU.

    For each tank (ordered by name) the most recent StockAdjustmentItem whose
    parent adjustment is APPROVED or LOCKED is looked up, newest ``tanggal``
    first and highest adjustment id as tie-breaker. One query is issued per
    tank. Tanks without such an item get None readings and
    ``is_first_time`` set to True.
    """
    tank_stmt = (
        select(Tangki)
        .where(
            Tangki.spbu_id == spbu_id,
            Tangki.is_active.is_(True),
            Tangki.deleted_at.is_(None),
        )
        .options(selectinload(Tangki.produk))
        .order_by(Tangki.nama)
    )
    active_tanks = (await db.execute(tank_stmt)).scalars().all()

    prefill = []
    for tank in active_tanks:
        # Latest reading for this tank among approved/locked adjustments only.
        latest_stmt = (
            select(StockAdjustmentItem)
            .join(StockAdjustment, StockAdjustmentItem.stock_adjustment_id == StockAdjustment.id)
            .where(
                StockAdjustmentItem.tangki_id == tank.id,
                StockAdjustment.status.in_([StatusLaporan.APPROVED, StatusLaporan.LOCKED]),
            )
            .order_by(StockAdjustment.tanggal.desc(), StockAdjustment.id.desc())
            .limit(1)
        )
        latest = (await db.execute(latest_stmt)).scalar_one_or_none()
        never_measured = latest is None

        prefill.append({
            "tangki_id": tank.id,
            "tangki_nama": tank.nama,
            "produk_nama": tank.produk.nama if tank.produk else "",
            "kapasitas_liter": tank.kapasitas_liter,
            "last_volume_final": None if never_measured else latest.volume_final_liter,
            "last_dipstick_digital_mm": None if never_measured else latest.dipstick_digital_mm,
            "is_first_time": never_measured,
        })
    return prefill
