from collections import defaultdict
from datetime import date
from typing import List, Optional
from urllib.parse import quote

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from sqlalchemy.orm import Session, joinedload

from app.core.database import get_db
from app.models.employee import Employee
from app.models.company import Company
from app.models.loan_transaction import LoanTransaction
from app.models.payroll_run import PayrollRun, PayrollRunItem
from app.models.benefit import Benefit  # noqa: F401 – used via Employee.benefits relationship
from app.schemas.payroll import (
    PayrollPreviewResponse,
    PayrollRow,
    CompanySummary,
    PayrollExportRequest,
    THRExportRequest,
    PayrollExportUnifiedRequest,
    PayrollRunOut,
    THRCashAdvance,  # noqa: F401 – imported for clarity; used via body.thr_cash_advances
)
from app.helpers.crud import get_or_404
from app.helpers.csv_export import build_kopramandiri_csv

# Every endpoint below is mounted under /api/payroll and tagged "payroll" in the OpenAPI schema.
router = APIRouter(tags=["payroll"], prefix="/api/payroll")


def _get_employees_for_export(
    db: Session,
    company_id: str,
    employee_ids: List[str],
) -> List[Employee]:
    """Return the selected active employees of one company, ordered by name.

    The company, loans and benefits relationships are eager-loaded (joinedload)
    so callers can compute deductions and benefits without extra queries.
    """
    eager_loads = (
        joinedload(Employee.company_rel),
        joinedload(Employee.loans),
        joinedload(Employee.benefits),
    )
    criteria = (
        Employee.company_id == company_id,
        Employee.id.in_(employee_ids),
        Employee.active == True,  # noqa: E712
    )
    selection = db.query(Employee).options(*eager_loads).filter(*criteria)
    return selection.order_by(Employee.name).all()


def _active_loan_deduction(emp: Employee) -> int:
    """Return the summed `monthly_deduction` of the employee's active loans (0 if none)."""
    total = 0
    for loan in emp.loans:
        if loan.status == "active":
            total += loan.monthly_deduction
    return total


def _active_benefit(emp: Employee) -> int:
    """Return the summed `monthly_benefit` of the employee's active benefits (0 if none)."""
    total = 0
    for benefit in emp.benefits:
        if benefit.status == "active":
            total += benefit.monthly_benefit
    return total


def _split_proportionally(total: int, weights: List[int]) -> List[int]:
    """Split integer `total` into one integer share per entry of `weights`.

    Uses cumulative rounding: the first i shares together always equal
    round(total * (w_1 + ... + w_i) / sum(weights)), and the last share takes
    whatever is left. With non-negative weights and a non-negative total, the
    running target never decreases, so no share is negative, and the shares
    always add up to exactly `total`. If the weights do not sum to a positive
    number, every entry is weighted equally (shares then differ by at most 1).
    """
    if not weights:
        return []
    weight_sum = sum(weights)
    if weight_sum <= 0:
        weights = [1] * len(weights)
        weight_sum = len(weights)

    shares: List[int] = []
    running_weight = 0
    allocated = 0
    for weight in weights[:-1]:
        running_weight += weight
        target = round(total * running_weight / weight_sum)
        shares.append(target - allocated)
        allocated = target
    shares.append(total - allocated)  # last share absorbs the rounding remainder
    return shares


def _apply_loan_deduction(
    db: Session,
    emp: Employee,
    total_to_apply: int,
    transaction_date: date,
    description: str,
) -> None:
    """Distribute `total_to_apply` across the employee's active loans proportionally.

    Each active loan gets a share proportional to its `monthly_deduction`, or
    an even split when those deductions sum to zero or less. The shares always
    add up to exactly `total_to_apply`. For every loan with a positive credit,
    this function:
      - adds a credit `LoanTransaction` to the session,
      - reduces `remaining_balance`,
      - marks the loan "paid" once the balance reaches zero.
    Nothing is committed here; the caller commits.

    Args:
        db: Session the new LoanTransaction rows are added to.
        emp: Employee whose `loans` are updated in place.
        total_to_apply: Total amount deducted this period; <= 0 is a no-op.
        transaction_date: Date stamped on each transaction.
        description: Description stamped on each transaction.
    """
    active_loans = [ln for ln in emp.loans if ln.status == "active"]
    if not active_loans or total_to_apply <= 0:
        return

    shares = _split_proportionally(
        total_to_apply, [ln.monthly_deduction for ln in active_loans]
    )
    for loan, share in zip(active_loans, shares):
        # A share larger than the outstanding balance is capped at the balance.
        # NOTE(review): the capped excess is not moved to another loan — confirm
        # this is the intended policy.
        credit_amt = min(share, loan.remaining_balance)
        if credit_amt <= 0:
            continue
        new_balance = loan.remaining_balance - credit_amt
        tx = LoanTransaction(
            loan_id=loan.id,
            employee_id=emp.id,
            transaction_date=transaction_date,
            description=description,
            debit=0,
            credit=credit_amt,
            balance=new_balance,
        )
        loan.remaining_balance = new_balance
        if new_balance <= 0:
            loan.status = "paid"
        db.add(tx)


@router.get("/preview", response_model=PayrollPreviewResponse)
def preview_payroll(
    month: int = Query(..., ge=1, le=12),
    year: int = Query(..., ge=2000),
    company_id: Optional[str] = Query(None),
    employee_ids: Optional[str] = Query(None),
    db: Session = Depends(get_db),
):
    """Preview net pay for active employees, with per-company subtotals.

    Net pay = base salary + active benefits - active loan deductions.

    Query params:
        month, year: validated (1-12, >= 2000) but not used in the calculation.
        company_id: if given, only that company's employees are included.
        employee_ids: optional comma-separated list of employee IDs to include.

    Returns one row per employee (ordered by name), one summary per company
    (in order of first appearance), and the grand total of all net pay.
    """
    conditions = [Employee.active == True]  # noqa: E712
    if company_id:
        conditions.append(Employee.company_id == company_id)
    if employee_ids:
        selected = [part.strip() for part in employee_ids.split(",") if part.strip()]
        conditions.append(Employee.id.in_(selected))

    employees = (
        db.query(Employee)
        .options(
            joinedload(Employee.company_rel),
            joinedload(Employee.loans),
            joinedload(Employee.benefits),
        )
        .filter(*conditions)
        .order_by(Employee.name)
        .all()
    )

    rows = []
    per_company = {}  # company key -> {"name", "count", "total"}; insertion-ordered
    for emp in employees:
        loan_ded = _active_loan_deduction(emp)
        benefit = _active_benefit(emp)
        net = emp.base_salary + benefit - loan_ded
        company = emp.company_rel
        cname = company.name if company else "No Company"

        totals = per_company.setdefault(
            emp.company_id or "none", {"name": cname, "count": 0, "total": 0}
        )
        totals["name"] = cname
        totals["count"] += 1
        totals["total"] += net

        rows.append(
            PayrollRow(
                employee_id=emp.id,
                name=emp.name,
                bank_name=company.bank_name if company else None,
                account_number=emp.account_number,
                account_name=emp.account_name,
                base_salary=emp.base_salary,
                benefit=benefit,
                loan_deduction=loan_ded,
                net_salary=net,
                company_id=emp.company_id,
                company_name=cname,
            )
        )

    summaries = []
    for cid, totals in per_company.items():
        summaries.append(
            CompanySummary(
                company_id=cid,
                company_name=totals["name"],
                employee_count=totals["count"],
                total_net=totals["total"],
            )
        )
    grand_total = sum(s.total_net for s in summaries)

    return PayrollPreviewResponse(rows=rows, summary=summaries, grand_total=grand_total)


@router.post("/export-csv")
def export_payroll_csv(body: PayrollExportRequest, db: Session = Depends(get_db)):
    """Export a regular payroll bank-transfer CSV (legacy endpoint; no run record).

    Each row's amount is base salary + active benefits - active loan
    deductions. This is the same net shown by GET /preview and used by
    POST /export. Employees whose net is zero or negative get no row, because
    there is nothing to transfer. The transfer date is `body.transfer_date`,
    or today when it is not given.
    """
    company = get_or_404(db, Company, body.company_id, "Company")
    employees = _get_employees_for_export(db, body.company_id, body.employee_ids)

    transfer = body.transfer_date or date.today()
    rows = []
    for emp in employees:
        amount = emp.base_salary + _active_benefit(emp) - _active_loan_deduction(emp)
        if amount <= 0:
            continue  # a bank transfer file must not contain zero/negative amounts
        rows.append({
            "account_number": emp.account_number,
            "name": emp.account_name,
            "amount": amount,
        })

    csv_content = build_kopramandiri_csv(transfer, company, rows)
    filename = f"payroll_{body.year}{body.month:02d}_{company.name}.csv"

    # Quoted ASCII fallback plus RFC 5987 `filename*`. This keeps the header
    # valid for company names with spaces or quotes. Characters outside
    # Latin-1 cannot be sent in a raw header value, so they are percent-encoded.
    ascii_name = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in filename
    )
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )
    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )


@router.post("/export-thr")
def export_thr_csv(body: THRExportRequest, db: Session = Depends(get_db)):
    """Export a THR bank-transfer CSV (legacy endpoint; no run record is created).

    Each selected active employee is paid their full base salary. No loan
    deductions, benefits or cash advances are applied. The transfer date is
    `body.transfer_date`, or today when it is not given.
    """
    company = get_or_404(db, Company, body.company_id, "Company")
    employees = _get_employees_for_export(db, body.company_id, body.employee_ids)

    transfer = body.transfer_date or date.today()
    rows = [
        {
            "account_number": e.account_number,
            "name": e.account_name,
            "amount": e.base_salary,  # THR = base salary only, no deductions
        }
        for e in employees
    ]

    csv_content = build_kopramandiri_csv(transfer, company, rows)
    filename = f"thr_{body.year}_{company.name}.csv"

    # Quoted ASCII fallback plus RFC 5987 `filename*`. This keeps the header
    # valid for company names with spaces or quotes. Characters outside
    # Latin-1 cannot be sent in a raw header value, so they are percent-encoded.
    ascii_name = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in filename
    )
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )
    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )


# ── Unified export endpoint ────────────────────────────────────────────────────

@router.post("/export")
def export_unified(body: PayrollExportUnifiedRequest, db: Session = Depends(get_db)):
    """
    Unified payroll/THR export endpoint.

    - export_type "thr": base_salary minus any cash advance already paid
      (from `body.thr_cash_advances`). No loan deductions, no benefits.
    - any other export_type (regular): base_salary + active benefits minus
      loan deductions. An entry in `body.overrides` replaces the standard
      deduction for that employee.
    - Employees whose net is zero or negative get no CSV row, but are still
      stored as run items for audit.
    - Transfer date = today (date of actual export).
    - Loan transactions are NOT recorded here; use POST /runs/{id}/run to record them.

    Always creates and commits a PayrollRun, with one PayrollRunItem per
    employee, for audit history, then returns the bank CSV.
    """
    company = get_or_404(db, Company, body.company_id, "Company")
    employees = _get_employees_for_export(db, body.company_id, body.employee_ids)

    # Build a map of employee_id -> overridden loan_deduction amount (regular payroll)
    override_map: dict[str, int] = {o.employee_id: o.loan_deduction for o in body.overrides}

    # Build a map of employee_id -> cash advance already paid (THR only)
    cash_advance_map: dict[str, int] = {
        ca.employee_id: ca.cash_advance for ca in body.thr_cash_advances
    }

    # Transfer date = today (actual bank transfer date)
    transfer = date.today()

    # Calculate each employee's amounts and build CSV rows
    csv_rows = []
    run_items_data = []
    for emp in employees:
        if body.export_type == "thr":
            emp_cash_advance = cash_advance_map.get(emp.id, 0)
            loan_ded = 0                     # no loan deduction for THR
            benefit = 0                      # THR does not include monthly benefits
            amount = emp.base_salary - emp_cash_advance
            is_override = emp_cash_advance > 0
        else:
            emp_cash_advance = 0
            standard_ded = _active_loan_deduction(emp)
            benefit = _active_benefit(emp)
            if emp.id in override_map:
                loan_ded = override_map[emp.id]
                is_override = True
            else:
                loan_ded = standard_ded
                is_override = False
            amount = emp.base_salary + benefit - loan_ded

        # THR: skip CSV transfer if cash advance >= base salary (net <= 0)
        # Still record in run_items for audit; net_amount stored as 0
        net_amount = max(0, amount) if body.export_type == "thr" else amount
        if net_amount > 0:
            csv_rows.append({
                "account_number": emp.account_number,
                "name": emp.account_name,
                "amount": net_amount,
            })
        run_items_data.append({
            "employee_id": emp.id,
            "employee_name": emp.name,
            "account_name": emp.account_name,
            "account_number": emp.account_number,
            "base_salary": emp.base_salary,
            "benefit": benefit,
            "loan_deduction": loan_ded,
            "cash_advance": emp_cash_advance,
            "net_amount": net_amount,
            "has_override": is_override,
        })

    # Create PayrollRun record (transactions recorded separately via /run endpoint)
    # NOTE(review): for regular runs a negative net is stored as-is, so
    # total_amount can be lower than the CSV total (which skips rows with
    # net <= 0) — confirm which total is intended.
    total_amount = sum(r["net_amount"] for r in run_items_data)
    run = PayrollRun(
        month=body.month,
        year=body.year,
        export_type=body.export_type,
        company_id=body.company_id,
        company_name=company.name,
        total_amount=total_amount,
        employee_count=len(run_items_data),
        recorded_transactions=False,
    )
    db.add(run)
    db.flush()  # Get run.id before creating items

    for item_data in run_items_data:
        item = PayrollRunItem(run_id=run.id, **item_data)
        db.add(item)

    db.commit()

    csv_content = build_kopramandiri_csv(transfer, company, csv_rows)
    if body.export_type == "thr":
        filename = f"thr_{body.year}_{company.name}.csv"
    else:
        filename = f"payroll_{body.year}{body.month:02d}_{company.name}.csv"

    # Quoted ASCII fallback plus RFC 5987 `filename*`. This keeps the header
    # valid for company names with spaces or quotes. Characters outside
    # Latin-1 cannot be sent in a raw header value, so they are percent-encoded.
    ascii_name = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in filename
    )
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )
    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )


# ── Payroll run history ────────────────────────────────────────────────────────

@router.get("/runs", response_model=List[PayrollRunOut])
def list_payroll_runs(
    year: Optional[int] = Query(None),
    db: Session = Depends(get_db),
):
    """List payroll export runs with their items, newest `run_date` first.

    When a (truthy) `year` is given, only runs for that year are returned.
    """
    runs_query = db.query(PayrollRun).options(joinedload(PayrollRun.items))
    if year:
        runs_query = runs_query.filter(PayrollRun.year == year)
    newest_first = PayrollRun.run_date.desc()
    return runs_query.order_by(newest_first).all()


@router.get("/runs/{run_id}", response_model=PayrollRunOut)
def get_payroll_run(run_id: str, db: Session = Depends(get_db)):
    """Fetch one payroll run, with its items eager-loaded.

    Raises:
        HTTPException 404: no run with this id.
    """
    matching = (
        db.query(PayrollRun)
        .options(joinedload(PayrollRun.items))
        .filter(PayrollRun.id == run_id)
    )
    run = matching.first()
    if run:
        return run
    raise HTTPException(status_code=404, detail="Payroll run not found")


@router.post("/runs/{run_id}/run")
def run_payroll_transactions(run_id: str, db: Session = Depends(get_db)):
    """
    Record loan deduction transactions for a previously exported payroll run.

    For a "regular" run, each item's stored `loan_deduction` is spread across
    the employee's current active loans proportionally, via
    `_apply_loan_deduction`. Items are skipped when they have no employee id,
    a deduction <= 0, or an employee that no longer exists. Any other export
    type (THR) records nothing and is only marked as executed.

    A run can be executed only once. Later calls are rejected with 400 because
    `recorded_transactions` is already set.

    Raises:
        HTTPException 404: run not found.
        HTTPException 400: transactions already recorded for this run.
    """
    run = (
        db.query(PayrollRun)
        .options(joinedload(PayrollRun.items))
        .filter(PayrollRun.id == run_id)
        .first()
    )
    if not run:
        raise HTTPException(status_code=404, detail="Payroll run not found")
    if run.recorded_transactions:
        raise HTTPException(status_code=400, detail="Transactions have already been recorded for this run")

    tx_date = run.run_date.date()

    if run.export_type == "regular":
        # Regular payroll: record loan deduction transactions for each employee
        items = [
            item for item in run.items
            if item.employee_id and item.loan_deduction > 0
        ]

        # Load all affected employees (with loans) in one query instead of one per item.
        wanted_ids = list({item.employee_id for item in items})
        employees = (
            db.query(Employee)
            .options(joinedload(Employee.loans))
            .filter(Employee.id.in_(wanted_ids))
            .all()
        ) if wanted_ids else []
        employees_by_id = {emp.id: emp for emp in employees}

        for item in items:
            emp = employees_by_id.get(item.employee_id)
            if emp is None:
                continue  # Employee was deleted — skip gracefully

            suffix = " (adjusted)" if item.has_override else ""
            _apply_loan_deduction(
                db,
                emp,
                item.loan_deduction,
                tx_date,
                f"Payroll deduction {run.year}/{run.month:02d}{suffix}",
            )
        message = f"Loan transactions recorded for {run.year}/{run.month:02d}"
    else:
        # THR: no loan transactions to record — just mark as executed
        message = f"THR marked as executed for {run.year}"

    run.recorded_transactions = True
    db.commit()

    return {"message": message}


@router.delete("/runs/{run_id}", status_code=204)
def delete_payroll_run(run_id: str, db: Session = Depends(get_db)):
    """
    Delete a payroll export run that has not been executed yet.

    Runs with `recorded_transactions` set are locked and cannot be deleted.
    Whether the run's items are removed with it depends on the model's
    cascade configuration, which is not visible here.

    Raises:
        HTTPException 404: run not found.
        HTTPException 400: run already executed.
    """
    doomed = (
        db.query(PayrollRun)
        .filter(PayrollRun.id == run_id)
        .first()
    )
    if not doomed:
        raise HTTPException(status_code=404, detail="Payroll run not found")

    already_executed = doomed.recorded_transactions
    if already_executed:
        raise HTTPException(
            status_code=400,
            detail="Cannot delete a run that has already been executed (loan transactions recorded)",
        )

    db.delete(doomed)
    db.commit()


@router.get("/runs/{run_id}/download")
def download_payroll_run(run_id: str, db: Session = Depends(get_db)):
    """Regenerate and return the CSV for a previously recorded payroll run.

    The bank file is rebuilt from the stored run items and dated with the
    run's `run_date`. Only items with a positive `net_amount` are included.
    This matches the file POST /export produced, which never writes rows for
    net <= 0 (e.g. a THR fully covered by a cash advance).

    Raises:
        HTTPException 404: run not found, or its company no longer exists.
    """
    run = (
        db.query(PayrollRun)
        .options(joinedload(PayrollRun.items))
        .filter(PayrollRun.id == run_id)
        .first()
    )
    if not run:
        raise HTTPException(status_code=404, detail="Payroll run not found")

    company = db.query(Company).filter(Company.id == run.company_id).first()
    if not company:
        raise HTTPException(status_code=404, detail="Company no longer exists")

    transfer = run.run_date.date()
    csv_rows = [
        {
            "account_number": item.account_number or "",
            "name": item.account_name or item.employee_name,
            "amount": item.net_amount,
        }
        for item in run.items
        if item.net_amount > 0  # audit-only items never had a transfer row
    ]

    csv_content = build_kopramandiri_csv(transfer, company, csv_rows)
    if run.export_type == "thr":
        filename = f"thr_{run.year}_{company.name}.csv"
    else:
        filename = f"payroll_{run.year}{run.month:02d}_{company.name}.csv"

    # Quoted ASCII fallback plus RFC 5987 `filename*`. This keeps the header
    # valid for company names with spaces or quotes. Characters outside
    # Latin-1 cannot be sent in a raw header value, so they are percent-encoded.
    ascii_name = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in filename
    )
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(filename, safe='')}"
    )
    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )
