"""Expense repository."""

from datetime import date
from decimal import Decimal

from sqlalchemy import func, select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.expenses import Expense, ExpenseKategori


async def get_kategori_list(db: AsyncSession, spbu_id: int) -> list[ExpenseKategori]:
    """List the active expense categories visible to one SPBU.

    A category is included when it is active and its ``spbu_id`` is either
    NULL (a global category) or equal to ``spbu_id``. Rows come back sorted
    by ``urutan``, with ``id`` as the tie-breaker.
    """
    visible_to_spbu = or_(
        ExpenseKategori.spbu_id.is_(None),
        ExpenseKategori.spbu_id == spbu_id,
    )
    stmt = (
        select(ExpenseKategori)
        .where(ExpenseKategori.is_active.is_(True), visible_to_spbu)
        .order_by(ExpenseKategori.urutan, ExpenseKategori.id)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())


async def create_kategori(db: AsyncSession, spbu_id: int, nama: str, urutan: int) -> ExpenseKategori:
    """Add a new expense category for ``spbu_id`` and return it.

    The new row is flushed to the database and then refreshed, so the
    returned object reflects what the database holds. This function does
    not commit.
    """
    new_kategori = ExpenseKategori(
        spbu_id=spbu_id,
        nama=nama,
        urutan=urutan,
    )
    db.add(new_kategori)
    # Flush emits the INSERT; refresh then reloads the row's attributes.
    await db.flush()
    await db.refresh(new_kategori)
    return new_kategori


async def update_kategori(db: AsyncSession, kat_id: int, **kwargs) -> ExpenseKategori | None:
    """Partially update the category whose ``id`` equals ``kat_id``.

    Each keyword argument names an attribute to overwrite. Arguments whose
    value is ``None`` are skipped, so an attribute cannot be reset to
    ``None`` through this function.

    Returns the flushed and refreshed category, or ``None`` when no category
    with that id exists. This function does not commit.
    """
    stmt = select(ExpenseKategori).where(ExpenseKategori.id == kat_id)
    kategori = (await db.execute(stmt)).scalar_one_or_none()
    if kategori is None:
        return None

    for field in kwargs:
        value = kwargs[field]
        if value is None:
            # None means "leave this attribute unchanged".
            continue
        setattr(kategori, field, value)

    await db.flush()
    await db.refresh(kategori)
    return kategori


async def get_all_expenses(
    db: AsyncSession,
    spbu_id: int,
    tanggal: date | None = None,
    tanggal_from: date | None = None,
    tanggal_to: date | None = None,
    laporan_shift_id: int | None = None,
    kategori_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
) -> tuple[list[Expense], int]:
    """Return one page of an SPBU's expenses plus the total match count.

    Args:
        db: Session used to run both the count and the page query.
        spbu_id: Only expenses with this ``spbu_id`` are considered.
        tanggal: When given, keep only expenses on exactly this date.
        tanggal_from: When given, keep only expenses on or after this date.
        tanggal_to: When given, keep only expenses on or before this date.
        laporan_shift_id: When truthy, keep only expenses of this shift report.
        kategori_id: When truthy, keep only expenses in this category.
        skip: Number of matching rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).

    Returns:
        ``(expenses, total)``: the requested page, ordered by ``tanggal``
        descending then ``id`` descending, and the number of rows matching
        the filters regardless of ``skip``/``limit``.

    Filters are applied on truthiness, so ``None`` (and ``0`` for the id
    filters) means "do not filter on this field".
    """
    # Build the filter once so the count and the page query cannot drift apart.
    conditions = [Expense.spbu_id == spbu_id]
    if tanggal:
        conditions.append(Expense.tanggal == tanggal)
    if tanggal_from:
        conditions.append(Expense.tanggal >= tanggal_from)
    if tanggal_to:
        conditions.append(Expense.tanggal <= tanggal_to)
    if laporan_shift_id:
        conditions.append(Expense.laporan_shift_id == laporan_shift_id)
    if kategori_id:
        conditions.append(Expense.kategori_id == kategori_id)

    # Count directly from the table. Wrapping the page query in a subquery
    # would carry its ORDER BY and loader options into a statement that only
    # needs a row count.
    count_result = await db.execute(
        select(func.count()).select_from(Expense).where(*conditions)
    )
    total = count_result.scalar_one()

    page_query = (
        select(Expense)
        .where(*conditions)
        .options(
            selectinload(Expense.kategori),
            selectinload(Expense.created_by),
            selectinload(Expense.submitted_by),
            selectinload(Expense.reviewed_by),
            selectinload(Expense.recalled_by),
            selectinload(Expense.unlocked_by),
        )
        .order_by(Expense.tanggal.desc(), Expense.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_query)
    return list(result.scalars().unique().all()), total


async def get_expense_by_id(db: AsyncSession, expense_id: int, spbu_id: int) -> Expense | None:
    """Fetch one expense by id, restricted to the given SPBU.

    The category and the created/submitted/reviewed/recalled/unlocked-by
    relationships are loaded eagerly with ``selectinload``. Returns ``None``
    when no expense matches both ``expense_id`` and ``spbu_id``.
    """
    eager_relations = (
        selectinload(Expense.kategori),
        selectinload(Expense.created_by),
        selectinload(Expense.submitted_by),
        selectinload(Expense.reviewed_by),
        selectinload(Expense.recalled_by),
        selectinload(Expense.unlocked_by),
    )
    stmt = (
        select(Expense)
        .where(Expense.id == expense_id, Expense.spbu_id == spbu_id)
        .options(*eager_relations)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def create_expense(db: AsyncSession, spbu_id: int, user_id: int, data: dict) -> Expense:
    """Insert a new expense and return it with its relationships loaded.

    ``data`` is expanded into keyword arguments for ``Expense`` alongside
    ``spbu_id`` and ``created_by_id=user_id``. The row is flushed and
    refreshed, then fetched again through ``get_expense_by_id``. This
    function does not commit.
    """
    new_expense = Expense(spbu_id=spbu_id, created_by_id=user_id, **data)
    db.add(new_expense)
    await db.flush()
    await db.refresh(new_expense)
    # Re-query so the result carries the eager-loaded relationships.
    reloaded = await get_expense_by_id(db, new_expense.id, spbu_id)
    return reloaded  # type: ignore[return-value]


async def update_expense(db: AsyncSession, expense: Expense, data: dict) -> Expense:
    """Partially update ``expense`` from ``data`` and return the reloaded row.

    Entries in ``data`` whose value is ``None`` are skipped; every other
    entry is assigned to the attribute of the same name. The change is
    flushed and refreshed, then the expense is fetched again through
    ``get_expense_by_id``. This function does not commit.
    """
    changes = {field: value for field, value in data.items() if value is not None}
    for field, value in changes.items():
        setattr(expense, field, value)

    await db.flush()
    await db.refresh(expense)
    # Re-query so the result carries the eager-loaded relationships.
    reloaded = await get_expense_by_id(db, expense.id, expense.spbu_id)
    return reloaded  # type: ignore[return-value]


async def delete_expense(db: AsyncSession, expense: Expense) -> None:
    """Delete ``expense`` through the session and flush the change.

    This function does not commit.
    """
    # Mark the row for deletion, then flush so the session emits the change now.
    await db.delete(expense)
    await db.flush()


async def update_bukti_url(db: AsyncSession, expense: Expense, url: str) -> Expense:
    """Set ``expense.bukti_url`` to ``url`` and return the reloaded expense.

    The change is flushed and refreshed, then the expense is fetched again
    through ``get_expense_by_id``. This function does not commit.
    """
    setattr(expense, "bukti_url", url)
    await db.flush()
    await db.refresh(expense)
    # Re-query so the result carries the eager-loaded relationships.
    reloaded = await get_expense_by_id(db, expense.id, expense.spbu_id)
    return reloaded  # type: ignore[return-value]


async def get_total_by_shift(db: AsyncSession, laporan_shift_id: int) -> Decimal:
    """Return the summed ``jumlah`` of all expenses for one laporan_shift.

    The SUM is wrapped in COALESCE with 0, so the query yields 0 rather than
    NULL when no expense matches ``laporan_shift_id``.
    """
    total_expr = func.coalesce(func.sum(Expense.jumlah), 0)
    stmt = select(total_expr).where(Expense.laporan_shift_id == laporan_shift_id)
    rows = await db.execute(stmt)
    return rows.scalar_one()
