"""Anomali repository — DB queries for anomaly records."""

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.anomali import AnomaliRecord


def _eager_opts():
    """Build the loader options shared by every query returning AnomaliRecord rows.

    Each related attribute is loaded with ``selectinload`` (one extra SELECT
    per relationship) rather than lazily on attribute access. Presumably this
    is so callers can read the relationships after the query without implicit
    IO, which AsyncSession does not support — confirm with callers.

    Returns a fresh list on every call, suitable for ``.options(*...)``.
    """
    eager_relations = (
        AnomaliRecord.produk,
        AnomaliRecord.nozzle,
        AnomaliRecord.assigned_user,
        AnomaliRecord.resolved_by,
    )
    return [selectinload(relation) for relation in eager_relations]


async def get_by_id(db: AsyncSession, id: int, spbu_id: int) -> AnomaliRecord | None:
    """Fetch a single anomaly by primary key, scoped to one SPBU.

    Args:
        db: Active async session.
        id: Primary key of the anomaly record.
        spbu_id: SPBU the record must belong to; a record with a matching
            ``id`` but a different ``spbu_id`` is treated as not found.

    Returns:
        The record with the relationships from ``_eager_opts()`` loaded, or
        ``None`` if no row matches both conditions.
    """
    stmt = select(AnomaliRecord).options(*_eager_opts())
    stmt = stmt.where(AnomaliRecord.id == id, AnomaliRecord.spbu_id == spbu_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def get_all(
    db: AsyncSession,
    spbu_id: int,
    tipe: str | None = None,
    status: str | None = None,
    skip: int = 0,
    limit: int = 50,
) -> tuple[list[AnomaliRecord], int]:
    """Return one page of anomalies for an SPBU together with the total count.

    Args:
        db: Active async session.
        spbu_id: SPBU whose anomalies are listed; always applied as a filter.
        tipe: If not ``None``, only records with this ``tipe`` are included.
        status: If not ``None``, only records with this ``status`` are included.
        skip: Number of matching rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        ``(records, total)``. ``records`` is the requested page, ordered by
        ``created_at`` descending (ties broken by ``id`` descending), with the
        relationships from ``_eager_opts()`` loaded. ``total`` is the number of
        rows matching the filters, independent of ``skip``/``limit``.
    """
    filters = [AnomaliRecord.spbu_id == spbu_id]
    if tipe is not None:
        filters.append(AnomaliRecord.tipe == tipe)
    if status is not None:
        filters.append(AnomaliRecord.status == status)

    # Count directly off the table with the same filters instead of wrapping
    # the ordered page query in a derived table. This avoids a needless
    # subquery, and some backends (e.g. SQL Server) reject ORDER BY inside a
    # subquery.
    count_result = await db.execute(
        select(func.count()).select_from(AnomaliRecord).where(*filters)
    )
    total = count_result.scalar_one()

    page_query = (
        select(AnomaliRecord)
        .where(*filters)
        .options(*_eager_opts())
        # created_at alone is not unique. Without the id tiebreaker, rows with
        # equal timestamps have no defined order, so OFFSET/LIMIT pages can
        # skip or repeat rows.
        .order_by(AnomaliRecord.created_at.desc(), AnomaliRecord.id.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(page_query)
    return list(result.scalars().all()), total


async def count_active(db: AsyncSession, spbu_id: int) -> int:
    """Count the anomalies of one SPBU whose status is still open.

    A record counts when its ``status`` is ``"new"`` or ``"investigating"``.

    Args:
        db: Active async session.
        spbu_id: SPBU whose records are counted.

    Returns:
        The number of matching rows.
    """
    open_statuses = ["new", "investigating"]
    stmt = (
        select(func.count())
        .select_from(AnomaliRecord)
        .where(AnomaliRecord.spbu_id == spbu_id)
        .where(AnomaliRecord.status.in_(open_statuses))
    )
    outcome = await db.execute(stmt)
    return outcome.scalar_one()


async def create(db: AsyncSession, data: dict) -> AnomaliRecord:
    """Insert a new anomaly and return it with its relationships loaded.

    Args:
        db: Active async session. The INSERT is only flushed here; committing
            is left to the caller.
        data: Keyword arguments passed straight to the ``AnomaliRecord``
            constructor.

    Returns:
        The persisted record, re-fetched through ``get_by_id`` so the
        relationships from ``_eager_opts()`` are loaded.
    """
    new_record = AnomaliRecord(**data)
    db.add(new_record)
    # Flush sends the INSERT so database-generated values such as the primary
    # key are available; refresh then reloads the row's attributes.
    await db.flush()
    await db.refresh(new_record)
    reloaded = await get_by_id(db, new_record.id, new_record.spbu_id)
    return reloaded  # type: ignore[return-value]


async def update(db: AsyncSession, record: AnomaliRecord, data: dict) -> AnomaliRecord:
    """Apply field changes to an existing anomaly and return the reloaded record.

    Args:
        db: Active async session. Changes are only flushed here; committing is
            left to the caller.
        record: The record to modify in place.
        data: Mapping of attribute name to new value; each pair is applied with
            ``setattr``. NOTE(review): keys that are not mapped columns are
            still set as plain attributes and would not be persisted — confirm
            callers only pass column names.

    Returns:
        The record re-fetched through ``get_by_id`` (using its post-update
        ``spbu_id``), with the relationships from ``_eager_opts()`` loaded.
    """
    for field_name, new_value in data.items():
        setattr(record, field_name, new_value)
    # Flush the UPDATE, then refresh so the instance reflects the database row
    # before re-selecting it with eager loading.
    await db.flush()
    await db.refresh(record)
    reloaded = await get_by_id(db, record.id, record.spbu_id)
    return reloaded  # type: ignore[return-value]
