"""SPBU repository — all DB queries for SPBU and its sub-resources (shifts, islands, nozzles, tangkis, tenants, contracts)."""

from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.spbu import Island, KalibrasiTangki, KontrakSewa, Nozzle, Shift, Spbu, Tangki, Tenant


async def get_all(
    db: AsyncSession,
    skip: int = 0,
    limit: int = 50,
    spbu_ids: set[int] | None = None,
) -> tuple[list[Spbu], int]:
    """Fetch one page of non-deleted SPBUs plus the total number of matches.

    Args:
        db: Active async session.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        spbu_ids: When given, restrict results to these SPBU ids. An empty set
            matches nothing.

    Returns:
        ``(items, total)``. ``items`` is the requested page, ordered by name
        with id as a tiebreaker, and each SPBU's non-deleted shifts are eagerly
        loaded. ``total`` counts every matching SPBU regardless of paging.
    """
    if spbu_ids is not None and not spbu_ids:
        # An empty id filter can never match; skip both database round-trips.
        return [], 0

    query = (
        select(Spbu)
        .where(Spbu.deleted_at.is_(None))
        # Exclude soft-deleted shifts so the collection agrees with get_shifts().
        .options(selectinload(Spbu.shifts.and_(Shift.deleted_at.is_(None))))
        # The id tiebreaker gives a total order, so OFFSET/LIMIT paging is stable
        # when several SPBUs share a name.
        .order_by(Spbu.name, Spbu.id)
    )
    if spbu_ids is not None:
        query = query.where(Spbu.id.in_(spbu_ids))

    # Ordering is irrelevant to a count, so strip it from the subquery.
    count_q = await db.execute(
        select(func.count()).select_from(query.order_by(None).subquery())
    )
    total = count_q.scalar_one()
    result = await db.execute(query.offset(skip).limit(limit))
    return list(result.scalars().all()), total


async def get_by_id(db: AsyncSession, spbu_id: int) -> Spbu | None:
    """Fetch a single non-deleted SPBU by primary key.

    The ``shifts`` collection is eagerly loaded. Soft-deleted shifts are
    excluded, so the collection matches what ``get_shifts`` returns.

    Returns:
        The SPBU, or ``None`` when no non-deleted SPBU has that id.
    """
    result = await db.execute(
        select(Spbu)
        .where(Spbu.id == spbu_id, Spbu.deleted_at.is_(None))
        .options(selectinload(Spbu.shifts.and_(Shift.deleted_at.is_(None))))
    )
    return result.scalar_one_or_none()


async def create(db: AsyncSession, **kwargs) -> Spbu:
    """Create a new SPBU together with its initial shifts. Flush only — caller commits.

    Args:
        db: Active async session.
        **kwargs: Column values for ``Spbu``. The optional ``shifts`` key holds
            an iterable of mappings, each unpacked into a ``Shift``. ``shifts``
            may be omitted or ``None``.

    Returns:
        The new SPBU, reloaded via ``get_by_id`` so ``shifts`` is loaded.
    """
    # ``pop(..., [])`` would still return None for an explicit ``shifts=None``,
    # and iterating None raises TypeError, so normalise falsy values to [].
    shifts_data = kwargs.pop("shifts", None) or []
    spbu = Spbu(**kwargs)
    db.add(spbu)
    await db.flush()  # assigns spbu.id, which the shifts below need as their FK
    db.add_all([Shift(spbu_id=spbu.id, **shift_data) for shift_data in shifts_data])
    await db.flush()
    return await get_by_id(db, spbu.id)  # reload with relationships


async def update(db: AsyncSession, spbu: Spbu, **kwargs) -> Spbu:
    """Assign each keyword argument as an attribute of ``spbu``, then flush.

    The caller owns the commit. The return value is the SPBU as re-fetched by
    ``get_by_id``, so its relationships are loaded.
    """
    for field in kwargs:
        setattr(spbu, field, kwargs[field])
    await db.flush()
    refreshed = await get_by_id(db, spbu.id)
    return refreshed


async def delete_spbu(db: AsyncSession, spbu: Spbu, hard_delete: bool = True) -> None:
    """Remove an SPBU, either physically or by soft delete. Flush only — caller commits.

    With ``hard_delete`` (the default) the row is deleted. Otherwise
    ``deleted_at`` is set to the current UTC time.
    """
    if not hard_delete:
        spbu.deleted_at = datetime.now(timezone.utc)
    else:
        await db.delete(spbu)
    await db.flush()


# --- Shifts ---

async def get_shifts(db: AsyncSession, spbu_id: int) -> list[Shift]:
    """List the non-deleted shifts of one SPBU, earliest ``jam_mulai`` first."""
    stmt = select(Shift).where(Shift.spbu_id == spbu_id, Shift.deleted_at.is_(None))
    stmt = stmt.order_by(Shift.jam_mulai)
    shifts = (await db.execute(stmt)).scalars().all()
    return list(shifts)


async def get_shift_by_id(db: AsyncSession, shift_id: int) -> Shift | None:
    """Look up one shift by id, ignoring soft-deleted rows.

    Returns ``None`` when no non-deleted shift has that id.
    """
    stmt = select(Shift).where(Shift.id == shift_id, Shift.deleted_at.is_(None))
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def create_shift(db: AsyncSession, spbu_id: int, **kwargs) -> Shift:
    """Insert a shift owned by ``spbu_id``. Flush only — caller commits.

    After the flush the instance is refreshed from the database before it is
    returned.
    """
    new_shift = Shift(spbu_id=spbu_id, **kwargs)
    db.add(new_shift)
    await db.flush()
    await db.refresh(new_shift)
    return new_shift


async def update_shift(db: AsyncSession, shift: Shift, **kwargs) -> Shift:
    """Assign each keyword argument onto ``shift``. Flush only — caller commits.

    The shift is refreshed from the database after the flush and then returned.
    """
    for field in kwargs:
        setattr(shift, field, kwargs[field])
    await db.flush()
    await db.refresh(shift)
    return shift


async def delete_shift(db: AsyncSession, shift: Shift, hard_delete: bool = True) -> None:
    """Remove a shift physically (default) or by stamping ``deleted_at`` with UTC now.

    Flush only — caller commits.
    """
    if hard_delete:
        await db.delete(shift)
        await db.flush()
        return
    shift.deleted_at = datetime.now(timezone.utc)
    await db.flush()


# --- Islands ---

async def get_islands(db: AsyncSession, spbu_id: int) -> list[Island]:
    """Return the non-deleted islands of an SPBU, ordered by ``urutan``.

    Each island's ``nozzles`` collection is eagerly loaded with soft-deleted
    nozzles excluded. Every loaded nozzle also has ``produk`` and ``tangki``
    loaded.
    """
    result = await db.execute(
        select(Island)
        .where(Island.spbu_id == spbu_id, Island.deleted_at.is_(None))
        .options(
            # Filter soft-deleted nozzles out of the eager load so the listing
            # agrees with get_nozzle_by_id, which never returns them. A single
            # loader path with sub-options keeps the criteria defined once.
            selectinload(Island.nozzles.and_(Nozzle.deleted_at.is_(None))).options(
                selectinload(Nozzle.produk),
                selectinload(Nozzle.tangki),
            ),
        )
        .order_by(Island.urutan)
    )
    return list(result.scalars().all())


async def get_island_by_id(db: AsyncSession, island_id: int) -> Island | None:
    """Fetch a single non-deleted island by primary key.

    The island's non-deleted nozzles are eagerly loaded, each with ``produk``
    and ``tangki``.

    Returns:
        The island, or ``None`` when no non-deleted island has that id.
    """
    result = await db.execute(
        select(Island)
        .where(Island.id == island_id, Island.deleted_at.is_(None))
        .options(
            # Exclude soft-deleted nozzles so this agrees with get_nozzle_by_id.
            selectinload(Island.nozzles.and_(Nozzle.deleted_at.is_(None))).options(
                selectinload(Nozzle.produk),
                selectinload(Nozzle.tangki),
            ),
        )
    )
    return result.scalar_one_or_none()


async def create_island(db: AsyncSession, spbu_id: int, **kwargs) -> Island:
    """Insert an island owned by ``spbu_id``. Flush only — caller commits.

    Returns the island re-fetched via ``get_island_by_id`` so its nozzles are
    loaded.
    """
    new_island = Island(spbu_id=spbu_id, **kwargs)
    db.add(new_island)
    await db.flush()
    reloaded = await get_island_by_id(db, new_island.id)
    return reloaded


async def update_island(db: AsyncSession, island: Island, **kwargs) -> Island:
    """Assign each keyword argument onto ``island``. Flush only — caller commits.

    Returns the island re-fetched via ``get_island_by_id``.
    """
    for field in kwargs:
        setattr(island, field, kwargs[field])
    await db.flush()
    reloaded = await get_island_by_id(db, island.id)
    return reloaded


async def delete_island(db: AsyncSession, island: Island, hard_delete: bool = True) -> None:
    """Remove an island physically (default) or by stamping ``deleted_at`` with UTC now.

    Flush only — caller commits.
    """
    if not hard_delete:
        island.deleted_at = datetime.now(timezone.utc)
    else:
        await db.delete(island)
    await db.flush()


# --- Nozzles ---

async def get_nozzle_by_id(db: AsyncSession, nozzle_id: int) -> Nozzle | None:
    """Look up one non-deleted nozzle by id, eagerly loading ``produk`` and ``tangki``.

    Returns ``None`` when no non-deleted nozzle has that id.
    """
    stmt = (
        select(Nozzle)
        .options(selectinload(Nozzle.produk), selectinload(Nozzle.tangki))
        .where(Nozzle.id == nozzle_id, Nozzle.deleted_at.is_(None))
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def create_nozzle(db: AsyncSession, island_id: int, **kwargs) -> Nozzle:
    """Insert a nozzle under ``island_id``. Flush only — caller commits.

    Returns the nozzle re-fetched via ``get_nozzle_by_id``.
    """
    new_nozzle = Nozzle(island_id=island_id, **kwargs)
    db.add(new_nozzle)
    await db.flush()
    reloaded = await get_nozzle_by_id(db, new_nozzle.id)
    return reloaded


async def update_nozzle(db: AsyncSession, nozzle: Nozzle, **kwargs) -> Nozzle:
    """Assign each keyword argument onto ``nozzle``. Flush only — caller commits.

    Returns the nozzle re-fetched via ``get_nozzle_by_id``.
    """
    for field in kwargs:
        setattr(nozzle, field, kwargs[field])
    await db.flush()
    reloaded = await get_nozzle_by_id(db, nozzle.id)
    return reloaded


async def delete_nozzle(db: AsyncSession, nozzle: Nozzle, hard_delete: bool = True) -> None:
    """Remove a nozzle physically (default) or by stamping ``deleted_at`` with UTC now.

    Flush only — caller commits.
    """
    if hard_delete:
        await db.delete(nozzle)
        await db.flush()
        return
    nozzle.deleted_at = datetime.now(timezone.utc)
    await db.flush()


# --- Tangkis ---

async def get_tangkis(db: AsyncSession, spbu_id: int) -> list[Tangki]:
    """List the non-deleted tanks of one SPBU, sorted by ``nama``.

    Each tank has its ``kalibrasi`` and ``produk`` relationships eagerly loaded.
    """
    eager = (selectinload(Tangki.kalibrasi), selectinload(Tangki.produk))
    stmt = (
        select(Tangki)
        .where(Tangki.spbu_id == spbu_id, Tangki.deleted_at.is_(None))
        .options(*eager)
        .order_by(Tangki.nama)
    )
    tangkis = (await db.execute(stmt)).scalars().all()
    return list(tangkis)


async def get_tangki_by_id(db: AsyncSession, tangki_id: int) -> Tangki | None:
    """Look up one non-deleted tank by id, with ``kalibrasi`` and ``produk`` loaded.

    Returns ``None`` when no non-deleted tank has that id.
    """
    stmt = (
        select(Tangki)
        .options(selectinload(Tangki.kalibrasi), selectinload(Tangki.produk))
        .where(Tangki.id == tangki_id, Tangki.deleted_at.is_(None))
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def create_tangki(db: AsyncSession, spbu_id: int, **kwargs) -> Tangki:
    """Insert a tank owned by ``spbu_id``. Flush only — caller commits.

    Returns the tank re-fetched via ``get_tangki_by_id``.
    """
    new_tangki = Tangki(spbu_id=spbu_id, **kwargs)
    db.add(new_tangki)
    await db.flush()
    reloaded = await get_tangki_by_id(db, new_tangki.id)
    return reloaded


async def update_tangki(db: AsyncSession, tangki: Tangki, **kwargs) -> Tangki:
    """Assign each keyword argument onto ``tangki``. Flush only — caller commits.

    Returns the tank re-fetched via ``get_tangki_by_id``.
    """
    for field in kwargs:
        setattr(tangki, field, kwargs[field])
    await db.flush()
    reloaded = await get_tangki_by_id(db, tangki.id)
    return reloaded


async def delete_tangki(db: AsyncSession, tangki: Tangki, hard_delete: bool = True) -> None:
    """Remove a tank physically (default) or by stamping ``deleted_at`` with UTC now.

    Flush only — caller commits.
    """
    if not hard_delete:
        tangki.deleted_at = datetime.now(timezone.utc)
    else:
        await db.delete(tangki)
    await db.flush()


async def replace_kalibrasi(
    db: AsyncSession, tangki_id: int, rows: list[dict]
) -> list[KalibrasiTangki]:
    """Replace all calibration rows for a tangki (delete-all then insert-all). Flush only — caller commits.

    Args:
        db: Active async session.
        tangki_id: Tank whose calibration rows are replaced.
        rows: Mappings, each unpacked into a new ``KalibrasiTangki``. An empty
            list simply clears the calibration table for the tank.

    Returns:
        The newly added ``KalibrasiTangki`` instances.
    """
    existing = await db.execute(
        select(KalibrasiTangki).where(KalibrasiTangki.tangki_id == tangki_id)
    )
    for old_row in existing.scalars().all():
        await db.delete(old_row)
    # Within a single flush, SQLAlchemy's unit of work emits INSERTs for a mapper
    # before its DELETEs. Flush the deletes first so replacement rows cannot
    # collide with old rows that are still present (e.g. on a unique key over
    # tangki/level, if the schema defines one — NOTE(review): confirm).
    await db.flush()
    new_rows = [KalibrasiTangki(tangki_id=tangki_id, **r) for r in rows]
    db.add_all(new_rows)
    await db.flush()
    return new_rows


# --- Tenants & Kontrak ---

async def get_tenants(db: AsyncSession, spbu_id: int) -> list[Tenant]:
    """Return the non-deleted tenants of an SPBU, ordered by ``nama``.

    The ``kontrak`` relationship is eagerly loaded with soft-deleted contracts
    excluded, consistent with ``get_kontrak_by_id``.
    """
    result = await db.execute(
        select(Tenant)
        .where(Tenant.spbu_id == spbu_id, Tenant.deleted_at.is_(None))
        .options(selectinload(Tenant.kontrak.and_(KontrakSewa.deleted_at.is_(None))))
        .order_by(Tenant.nama)
    )
    return list(result.scalars().all())


async def get_tenant_by_id(db: AsyncSession, tenant_id: int) -> Tenant | None:
    """Fetch a single non-deleted tenant by primary key.

    The ``kontrak`` relationship is eagerly loaded with soft-deleted contracts
    excluded, consistent with ``get_kontrak_by_id``.

    Returns:
        The tenant, or ``None`` when no non-deleted tenant has that id.
    """
    result = await db.execute(
        select(Tenant)
        .where(Tenant.id == tenant_id, Tenant.deleted_at.is_(None))
        .options(selectinload(Tenant.kontrak.and_(KontrakSewa.deleted_at.is_(None))))
    )
    return result.scalar_one_or_none()


async def create_tenant(db: AsyncSession, spbu_id: int, **kwargs) -> Tenant:
    """Insert a tenant under ``spbu_id``. Flush only — caller commits.

    Returns the tenant re-fetched via ``get_tenant_by_id``.
    """
    new_tenant = Tenant(spbu_id=spbu_id, **kwargs)
    db.add(new_tenant)
    await db.flush()
    reloaded = await get_tenant_by_id(db, new_tenant.id)
    return reloaded


async def update_tenant(db: AsyncSession, tenant: Tenant, **kwargs) -> Tenant:
    """Assign each keyword argument onto ``tenant``. Flush only — caller commits.

    Returns the tenant re-fetched via ``get_tenant_by_id``.
    """
    for field in kwargs:
        setattr(tenant, field, kwargs[field])
    await db.flush()
    reloaded = await get_tenant_by_id(db, tenant.id)
    return reloaded


async def delete_tenant(db: AsyncSession, tenant: Tenant, hard_delete: bool = True) -> None:
    """Remove a tenant physically (default) or by stamping ``deleted_at`` with UTC now.

    Flush only — caller commits.
    """
    if hard_delete:
        await db.delete(tenant)
        await db.flush()
        return
    tenant.deleted_at = datetime.now(timezone.utc)
    await db.flush()


async def get_kontrak_by_id(db: AsyncSession, kontrak_id: int) -> KontrakSewa | None:
    """Look up one non-deleted lease contract by id.

    Returns ``None`` when no non-deleted contract has that id.
    """
    stmt = select(KontrakSewa).where(
        KontrakSewa.id == kontrak_id, KontrakSewa.deleted_at.is_(None)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()


async def create_kontrak(db: AsyncSession, tenant_id: int, **kwargs) -> KontrakSewa:
    """Insert a lease contract for ``tenant_id``. Flush only — caller commits.

    The instance is refreshed from the database after the flush and returned.
    """
    new_kontrak = KontrakSewa(tenant_id=tenant_id, **kwargs)
    db.add(new_kontrak)
    await db.flush()
    await db.refresh(new_kontrak)
    return new_kontrak


async def update_kontrak(db: AsyncSession, kontrak: KontrakSewa, **kwargs) -> KontrakSewa:
    """Assign each keyword argument onto ``kontrak``. Flush only — caller commits.

    The contract is refreshed from the database after the flush and returned.
    """
    for field in kwargs:
        setattr(kontrak, field, kwargs[field])
    await db.flush()
    await db.refresh(kontrak)
    return kontrak


async def delete_kontrak(db: AsyncSession, kontrak: KontrakSewa, hard_delete: bool = True) -> None:
    """Remove a lease contract physically (default) or by stamping ``deleted_at`` with UTC now.

    Flush only — caller commits.
    """
    if not hard_delete:
        kontrak.deleted_at = datetime.now(timezone.utc)
    else:
        await db.delete(kontrak)
    await db.flush()
