"""Product repository — all DB queries for products (Produk) and their price history (ProdukHarga)."""

from datetime import date

from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.product import Produk, ProdukHarga


async def get_all(
    db: AsyncSession, skip: int = 0, limit: int = 50
) -> tuple[list[Produk], int]:
    """Return one page of non-deleted products together with the total count.

    Args:
        db: Session used to run both queries.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        A ``(products, total)`` tuple: the requested page ordered by ``nama``
        (ties broken by ``id``) and the number of non-deleted products across
        all pages.
    """
    not_deleted = Produk.deleted_at.is_(None)

    # Count straight from the table; no need to wrap a full-row SELECT in a subquery.
    count_result = await db.execute(
        select(func.count()).select_from(Produk).where(not_deleted)
    )
    total = count_result.scalar_one()

    # ``nama`` alone is not unique, so OFFSET/LIMIT could repeat or skip rows
    # between pages when names tie. The id tie-breaker makes the order total.
    page_result = await db.execute(
        select(Produk)
        .where(not_deleted)
        .order_by(Produk.nama, Produk.id)
        .offset(skip)
        .limit(limit)
    )
    return list(page_result.scalars().all()), total


async def get_by_id(db: AsyncSession, produk_id: int) -> Produk | None:
    """Look up one non-deleted product by id, eagerly loading its price history.

    Returns ``None`` when no matching, non-deleted product exists.
    """
    stmt = select(Produk).where(
        Produk.id == produk_id,
        Produk.deleted_at.is_(None),
    )
    # Load harga_history with a follow-up SELECT so it is available without lazy loading.
    stmt = stmt.options(selectinload(Produk.harga_history))
    found = await db.execute(stmt)
    return found.scalar_one_or_none()


async def create(db: AsyncSession, **kwargs) -> Produk:
    """Build a product from the given field values and stage it in the session.

    The row is flushed and refreshed but not committed — the caller commits.
    """
    new_produk = Produk(**kwargs)
    db.add(new_produk)
    # Flush sends the INSERT; refresh reloads the instance's attributes from the database.
    await db.flush()
    await db.refresh(new_produk)
    return new_produk


async def update(db: AsyncSession, produk: Produk, **kwargs) -> Produk:
    """Assign the given field values to ``produk`` and flush the change.

    Each keyword argument is set as an attribute on the instance. The session
    is flushed and the instance refreshed, but not committed — the caller commits.
    """
    for field_name in kwargs:
        setattr(produk, field_name, kwargs[field_name])
    await db.flush()
    await db.refresh(produk)
    return produk


async def get_harga_history(db: AsyncSession, produk_id: int) -> list[ProdukHarga]:
    """List every price entry of a product, ordered by ``berlaku_mulai`` descending."""
    stmt = select(ProdukHarga).where(ProdukHarga.produk_id == produk_id)
    stmt = stmt.order_by(ProdukHarga.berlaku_mulai.desc())
    entries = await db.execute(stmt)
    return list(entries.scalars().all())


async def get_current_harga(
    db: AsyncSession, produk_id: int, tanggal: date | None = None
) -> ProdukHarga | None:
    """Find the price entry in effect for a product on ``tanggal``.

    ``tanggal`` defaults to today's date. An entry is in effect when it started
    on or before that date and either has no end date or ends on or after it.
    When several entries match, the one with the latest ``berlaku_mulai`` wins.
    Returns ``None`` when nothing matches.
    """
    on_date = date.today() if tanggal is None else tanggal

    # Open-ended entries (NULL end date) count as still running.
    not_yet_ended = or_(
        ProdukHarga.berlaku_sampai.is_(None),
        ProdukHarga.berlaku_sampai >= on_date,
    )
    active = and_(
        ProdukHarga.produk_id == produk_id,
        ProdukHarga.berlaku_mulai <= on_date,
        not_yet_ended,
    )
    stmt = (
        select(ProdukHarga)
        .where(active)
        .order_by(ProdukHarga.berlaku_mulai.desc())
        .limit(1)
    )
    found = await db.execute(stmt)
    return found.scalar_one_or_none()


async def get_current_harga_bulk(
    db: AsyncSession, produk_ids: list[int], tanggal: date | None = None
) -> dict[int, ProdukHarga]:
    """Look up the price in effect on ``tanggal`` for many products with one query.

    ``tanggal`` defaults to today's date. Returns ``{produk_id: ProdukHarga}``;
    products with no matching entry are absent from the mapping. An empty
    ``produk_ids`` yields an empty dict without touching the database.
    """
    if not produk_ids:
        return {}
    on_date = tanggal if tanggal is not None else date.today()

    not_yet_ended = or_(
        ProdukHarga.berlaku_sampai.is_(None),
        ProdukHarga.berlaku_sampai >= on_date,
    )
    active = and_(
        ProdukHarga.produk_id.in_(produk_ids),
        ProdukHarga.berlaku_mulai <= on_date,
        not_yet_ended,
    )
    stmt = select(ProdukHarga).where(active).order_by(ProdukHarga.berlaku_mulai.desc())
    fetched = await db.execute(stmt)

    # Rows arrive newest-first, so the first row seen for a product is its
    # latest entry; setdefault keeps that one and ignores the older ones.
    latest_by_produk: dict[int, ProdukHarga] = {}
    for entry in fetched.scalars().all():
        latest_by_produk.setdefault(entry.produk_id, entry)
    return latest_by_produk


async def close_current_harga(
    db: AsyncSession, produk_id: int, berlaku_sampai: date
) -> None:
    """Set ``berlaku_sampai`` on every open price entry of a product.

    An entry is open when its ``berlaku_sampai`` is NULL. Normally there is at
    most one, but if more than one exists they are all closed rather than
    raising ``MultipleResultsFound``, so a product with duplicate open entries
    can still receive a new price.

    Flush only — the caller is responsible for commit/rollback. Nothing is
    flushed when the product has no open entry.

    Args:
        db: Session whose transaction the change is flushed into.
        produk_id: Product whose open entries are closed.
        berlaku_sampai: End date to store on the open entries.
    """
    result = await db.execute(
        select(ProdukHarga).where(
            and_(
                ProdukHarga.produk_id == produk_id,
                ProdukHarga.berlaku_sampai.is_(None),
            )
        )
    )
    open_entries = result.scalars().all()
    if not open_entries:
        return
    for entry in open_entries:
        entry.berlaku_sampai = berlaku_sampai
    await db.flush()


async def create_harga(
    db: AsyncSession, produk_id: int, harga, berlaku_mulai: date
) -> ProdukHarga:
    """Stage a new price entry for a product starting on ``berlaku_mulai``.

    The row is flushed and refreshed but not committed, so the caller can
    commit it in the same transaction as ``close_current_harga``.
    """
    new_entry = ProdukHarga(
        produk_id=produk_id,
        harga=harga,
        berlaku_mulai=berlaku_mulai,
    )
    db.add(new_entry)
    await db.flush()
    await db.refresh(new_entry)
    return new_entry
