"""
Business logic — no GUI, no Tkinter. Drop-in for the Flask wrapper.
Dates, account numbers, amounts. That's it.
"""

import os
import sqlite3
import json
import csv
import uuid
from decimal import Decimal, InvalidOperation
from datetime import datetime
from io import StringIO
import threading

# Default database file: ledger.db in the same directory as this module.
# get_db() falls back to it when the calling thread has not chosen another
# path via set_active_db().
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ledger.db")
# Per-thread storage; set_active_db() writes its optional `db_path` attribute
# and get_db() reads it, so different threads can target different files.
_thread_local = threading.local()


def set_active_db(path):
    """Select the SQLite file that get_db() opens for the calling thread.

    Only the current thread is affected; threads that never call this keep
    using DB_PATH.
    """
    setattr(_thread_local, "db_path", path)


def get_db():
    """Open and return a new SQLite connection for the current thread.

    The file is the one chosen with set_active_db() on this thread, or
    DB_PATH when none was chosen. Each call opens a fresh connection that
    the caller is responsible for closing.

    Access control comes from directory/session isolation enforced at login;
    the database file itself is NOT encrypted at rest.

    SQLCipher extension point: to add at-rest encryption later, connect via
    ``pysqlcipher3.dbapi2`` instead of ``sqlite3`` and issue ``PRAGMA key``
    immediately after connecting. PRAGMA statements cannot take bound
    parameters, so escape the key (double any single quotes) rather than
    interpolating it raw.
    """
    try:
        db_file = _thread_local.db_path
    except AttributeError:
        db_file = DB_PATH
    return sqlite3.connect(db_file)


def D(x):
    """Coerce *x* to a Decimal through its string form.

    None, "" and whitespace-only input become Decimal("0"). Anything else is
    stripped and handed to Decimal(), which raises decimal.InvalidOperation
    for non-numeric text.
    """
    text = "" if x is None else str(x).strip()
    return Decimal(text) if text else Decimal("0")


def add_column_if_not_exists(c, table_name, column_name, definition):
    """Add *column_name* to *table_name* through cursor *c* unless it exists.

    ``PRAGMA table_info`` lists the current columns (the name is field 1).
    ALTER TABLE runs only when the column is absent, so repeated calls are
    harmless. Identifiers are interpolated directly into the SQL, so they
    must be trusted constants, never user input.
    """
    c.execute(f"PRAGMA table_info({table_name})")
    existing = {info[1] for info in c.fetchall()}
    if column_name in existing:
        return
    c.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {definition}")


def init_db():
    """Create the ledger schema if missing and apply additive migrations.

    Every CREATE uses IF NOT EXISTS and new columns go through
    add_column_if_not_exists(), so this is safe to run repeatedly against a
    fresh or an existing database. Ends by setting PRAGMA user_version to 1
    and committing.

    NOTE: the FOREIGN KEY clauses are declarations only; SQLite enforces them
    only on connections that run ``PRAGMA foreign_keys = ON``, which get_db()
    does not do.
    """
    conn = get_db()
    cur = conn.cursor()

    # company: at most one row, pinned to id = 1 by the CHECK constraint.
    cur.execute("""
        CREATE TABLE IF NOT EXISTS company (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            company_id TEXT, name TEXT, reg_number TEXT,
            tax_number TEXT, address TEXT, currency TEXT,
            year_end TEXT, period_type TEXT DEFAULT 'monthly'
        )
    """)
    for column, ddl in (
        ("period_type", "TEXT DEFAULT 'monthly'"),
        ("lock_date", "TEXT DEFAULT NULL"),
    ):
        add_column_if_not_exists(cur, "company", column, ddl)

    # uploads: one row per imported batch (COA or ENTRY) plus its raw CSV.
    cur.execute("""
        CREATE TABLE IF NOT EXISTS uploads (
            upload_id TEXT PRIMARY KEY,
            type TEXT NOT NULL,
            filename TEXT,
            uploaded_at TEXT,
            status TEXT DEFAULT 'DRAFT',
            raw_csv TEXT,
            header_json TEXT
        )
    """)
    for column, ddl in (
        ("created_by", "TEXT DEFAULT 'system'"),
        ("source_type", "TEXT DEFAULT NULL"),
    ):
        add_column_if_not_exists(cur, "uploads", column, ddl)

    # Child tables holding the parsed rows of each batch.
    child_tables = (
        """
        CREATE TABLE IF NOT EXISTS coa_rows (
            row_id TEXT PRIMARY KEY,
            upload_id TEXT NOT NULL,
            account_no TEXT NOT NULL,
            description TEXT NOT NULL,
            metadata TEXT,
            FOREIGN KEY (upload_id) REFERENCES uploads(upload_id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS entry_rows (
            row_id TEXT PRIMARY KEY,
            upload_id TEXT NOT NULL,
            date TEXT NOT NULL,
            account_no TEXT NOT NULL,
            contra_account_no TEXT,
            amount TEXT NOT NULL,
            reference TEXT,
            internally_balanced INTEGER NOT NULL DEFAULT 0,
            FOREIGN KEY (upload_id) REFERENCES uploads(upload_id)
        )
        """,
    )
    for ddl in child_tables:
        cur.execute(ddl)

    # Indexes backing the period-aware (date range) trial balance queries.
    for ddl in (
        "CREATE INDEX IF NOT EXISTS idx_entry_rows_acct_date ON entry_rows(account_no, date)",
        "CREATE INDEX IF NOT EXISTS idx_entry_rows_date ON entry_rows(date)",
    ):
        cur.execute(ddl)

    cur.execute("PRAGMA user_version = 1")

    conn.commit()
    conn.close()


# ============================================================================
# VALIDATION & PARSING
# ============================================================================

def parse_csv(text):
    """Parse CSV *text* into ``(rows, fieldnames)``.

    ``rows`` is a list of dicts keyed by header name, following
    csv.DictReader semantics: cells missing from short rows are None.
    Returns ``([], [])`` when there is no header row at all (including empty
    or None input).

    A leading byte-order mark (U+FEFF, written by e.g. Excel's "CSV UTF-8"
    export) is removed, and whitespace around header names is stripped.
    Without this, the first column name would carry the BOM and padded
    headers would never match the expected column names.
    """
    text = text or ""
    if text.startswith("\ufeff"):
        text = text[1:]
    reader = csv.DictReader(StringIO(text))
    if reader.fieldnames is None:
        return [], []
    # Assigning fieldnames after the header has been read makes every
    # subsequent row use the cleaned names as dict keys.
    reader.fieldnames = [name.strip() for name in reader.fieldnames]
    rows = list(reader)
    return rows, list(reader.fieldnames)


def validate_coa_csv(rows):
    """Validate parsed chart-of-accounts rows.

    Every row needs a non-blank ``account_no`` and ``description``. Account
    numbers, compared after stripping whitespace, must be unique within the
    file.

    Returns:
        ``(ok, error_message, account_nos)``. On success this is
        ``(True, "", [...])`` with the stripped account numbers in file
        order. On failure it is ``(False, message, [])`` describing the first
        problem found.
    """
    if not rows:
        return False, "Empty CSV — no rows found.", []

    def _cell(row, key):
        # csv.DictReader yields None for cells missing from short rows; treat
        # that as blank instead of the literal text "None".
        value = row.get(key)
        return "" if value is None else str(value).strip()

    account_nos = []
    for i, row in enumerate(rows, 1):
        account_no = _cell(row, "account_no")
        if not account_no:
            return False, f"Row {i}: account_no is required.", []
        if not _cell(row, "description"):
            return False, f"Row {i}: description is required.", []
        account_nos.append(account_no)

    # Single pass with sets instead of list.count() per value (O(n^2)).
    seen, dups = set(), set()
    for acct in account_nos:
        if acct in seen:
            dups.add(acct)
        seen.add(acct)
    if dups:
        return False, f"Duplicate account_no within file: {', '.join(sorted(dups))}", []
    return True, "", account_nos


def parse_and_normalize_date(date_str):
    """Return *date_str* (whitespace-stripped) rewritten as ISO ``YYYY-MM-DD``.

    Formats are tried in this order and the first match wins: YYYY-MM-DD,
    DD/MM/YYYY, DD-MM-YYYY, MM/DD/YYYY, YYYY/MM/DD. Because the day-first
    slash form is tried first, an ambiguous "03/04/2024" is read as 3 April.
    MM/DD/YYYY only matches when the day-first reading is invalid.

    Raises:
        ValueError: if none of the formats match.
    """
    candidate = date_str.strip()
    for pattern in ("%Y-%m-%d", "%d/%m/%Y", "%d-%m-%Y", "%m/%d/%Y", "%Y/%m/%d"):
        try:
            parsed = datetime.strptime(candidate, pattern)
        except ValueError:
            continue
        return parsed.strftime("%Y-%m-%d")
    raise ValueError(f"Unrecognized date format: '{candidate}'")


def validate_entry_csv(rows, source_type=None, pl_starts_from="4000"):
    """Validate parsed journal-entry rows and convert them for storage.

    Columns ``date``, ``account`` and ``amount`` are required; presence is
    checked against the first row's keys. ``contra`` and ``reference`` are
    optional. For each row:

    * the date is normalized to ISO with parse_and_normalize_date();
    * the amount must parse as a *finite* Decimal.

    A row with a contra account balances itself and is excluded from the
    net-zero check. All other rows must sum to exactly zero.

    When *source_type* is 'OPENING', a non-zero amount on a P&L account
    (account number >= *pl_starts_from*) adds a warning but does not fail
    validation. The comparison is numeric when both values are integers and
    textual otherwise.

    Args:
        rows: list of row dicts, typically from parse_csv(); cells missing
            from short rows may be None.
        source_type: upload source tag; only 'OPENING' changes behavior.
        pl_starts_from: first P&L account number. The default "4000" matches
            the previously hard-coded threshold.

    Returns:
        ``(ok, error_message, parsed_rows, warnings)``. On failure both lists
        are empty and error_message describes the first bad row.
    """
    if not rows:
        return False, "Empty CSV — no rows found.", [], []
    required = ["date", "account", "amount"]
    missing = [r for r in required if r not in (rows[0] or {})]
    if missing:
        return False, f"Missing required column(s): {', '.join(missing)}", [], []

    def _cell(row, key):
        # csv.DictReader fills cells missing from short rows with None. Treat
        # that as blank rather than the literal text "None", which would
        # otherwise become a bogus contra account and mark the row balanced.
        value = row.get(key)
        return "" if value is None else str(value).strip()

    def _is_pl_account(account):
        try:
            return int(account) >= int(pl_starts_from)
        except ValueError:
            return account >= str(pl_starts_from)

    parsed = []
    net = Decimal("0")
    warnings = []
    for i, row in enumerate(rows, 1):
        date = _cell(row, "date")
        account = _cell(row, "account")
        contra = _cell(row, "contra")
        amt_raw = _cell(row, "amount")
        ref = _cell(row, "reference")
        if not date:
            return False, f"Row {i}: date is required.", [], []
        try:
            normalized_date = parse_and_normalize_date(date)
        except ValueError as e:
            return False, f"Row {i}: {e}", [], []
        if not account:
            return False, f"Row {i}: account is required.", [], []
        if not amt_raw:
            return False, f"Row {i}: amount is required.", [], []
        try:
            amount = Decimal(amt_raw)
        except InvalidOperation:
            amount = None
        # Decimal() also accepts "NaN", "Infinity" and "sNaN". Such values
        # would poison the net-zero check and every balance derived from the
        # row, so they are rejected like any other non-numeric text.
        if amount is None or not amount.is_finite():
            return False, f"Row {i}: amount '{amt_raw}' is not numeric.", [], []

        # Anti-double-count safeguard: P&L accounts are expected to open at zero.
        if source_type == 'OPENING' and amount != Decimal("0") and _is_pl_account(account):
            warnings.append(
                f"Warning: Row {i} - Account '{account}' is a temporary P&L account "
                f"but has a non-zero opening balance of {amount}."
            )

        internally_balanced = bool(contra)
        parsed.append({
            "date": normalized_date, "account": account, "contra": contra or None,
            "amount": amount, "reference": ref,
            "internally_balanced": internally_balanced,
        })
        if not internally_balanced:
            net += amount
    if net != Decimal("0"):
        return False, (
            f"Non-contra rows do not net to zero. Net = {net}. "
            "Rows with both account and contra are excluded from this check."
        ), [], []
    return True, "", parsed, warnings


# ============================================================================
# DATA ACCESS
# ============================================================================

def get_company():
    """Return the single company row (id = 1) as a raw tuple, or None if unset.

    Column order follows the table definition; company_to_dict() maps it to
    names.
    """
    conn = get_db()
    company_row = conn.execute("SELECT * FROM company WHERE id = 1").fetchone()
    conn.close()
    return company_row


def company_to_dict(row):
    """Map a raw company row (see get_company) to a dict keyed by column name.

    A falsy *row* (None or empty) yields every known key mapped to "".
    Otherwise keys are paired positionally with the row's values. A shorter
    row yields only the leading keys, and extra values beyond the known
    columns are ignored.
    """
    columns = ("id", "company_id", "name", "reg_number", "tax_number",
               "address", "currency", "year_end", "period_type", "lock_date")
    if not row:
        return dict.fromkeys(columns, "")
    return {column: value for column, value in zip(columns, row)}


def save_company(company_id, name, reg_number, tax_number,
                 address, currency, year_end, period_type="monthly", lock_date=None):
    """Replace the stored company profile with exactly the values given.

    The company table is cleared and a fresh id = 1 row is inserted. Every
    column therefore takes the value passed here, including lock_date, whose
    None default removes any existing lock.
    """
    conn = get_db()
    cur = conn.cursor()
    cur.execute("DELETE FROM company")
    profile = (company_id, name, reg_number, tax_number, address,
               currency, year_end, period_type, lock_date)
    cur.execute(
        """
        INSERT INTO company
            (id, company_id, name, reg_number, tax_number, address, currency, year_end, period_type, lock_date)
        VALUES (1, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        profile,
    )
    conn.commit()
    conn.close()


def get_uploads_by_type(type_str):
    """List uploads of one type, newest first.

    Each tuple is (upload_id, type, filename, uploaded_at, status,
    row_count, source_type). row_count counts coa_rows for 'COA' uploads and
    entry_rows for any other type.
    """
    listing_sql = """
        SELECT upload_id, type, filename, uploaded_at, status,
               CASE type
                   WHEN 'COA' THEN (SELECT COUNT(*) FROM coa_rows   WHERE upload_id = u.upload_id)
                   ELSE            (SELECT COUNT(*) FROM entry_rows WHERE upload_id = u.upload_id)
               END,
               source_type
        FROM uploads u
        WHERE type = ?
        ORDER BY uploaded_at DESC
    """
    conn = get_db()
    uploads = conn.execute(listing_sql, (type_str,)).fetchall()
    conn.close()
    return uploads


def upload_coa_batch(filename, raw_csv, headers, rows, created_by='system'):
    """Store a chart-of-accounts upload as a new DRAFT batch and return its id.

    The raw CSV text and the JSON-encoded header list are kept on the uploads
    row. Each data row becomes a coa_rows record with stripped account_no and
    description; all other columns of the row are JSON-encoded into
    ``metadata``.
    """
    upload_id = str(uuid.uuid4())
    conn = get_db()
    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO uploads (upload_id, type, filename, uploaded_at, status, raw_csv, header_json, created_by, source_type)
        VALUES (?, 'COA', ?, ?, 'DRAFT', ?, ?, ?, NULL)
        """,
        (upload_id, filename, datetime.now().isoformat(), raw_csv,
         json.dumps(headers), created_by),
    )
    core_columns = ("account_no", "description")
    records = (
        (
            str(uuid.uuid4()),
            upload_id,
            str(row["account_no"]).strip(),
            str(row["description"]).strip(),
            json.dumps({k: v for k, v in row.items() if k not in core_columns}),
        )
        for row in rows
    )
    cur.executemany(
        """
        INSERT INTO coa_rows (row_id, upload_id, account_no, description, metadata)
        VALUES (?, ?, ?, ?, ?)
        """,
        records,
    )
    conn.commit()
    conn.close()
    return upload_id


def upload_entry_batch(filename, raw_csv, headers, parsed_rows, created_by='system', source_type=None):
    """Store validated journal rows as a new DRAFT ENTRY batch and return its id.

    *parsed_rows* are dicts with keys date, account, contra, amount,
    reference and internally_balanced, the shape validate_entry_csv()
    produces. Amounts are stored as their string form and
    internally_balanced as 0/1.
    """
    upload_id = str(uuid.uuid4())
    conn = get_db()
    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO uploads (upload_id, type, filename, uploaded_at, status, raw_csv, header_json, created_by, source_type)
        VALUES (?, 'ENTRY', ?, ?, 'DRAFT', ?, ?, ?, ?)
        """,
        (upload_id, filename, datetime.now().isoformat(), raw_csv,
         json.dumps(headers), created_by, source_type),
    )
    records = (
        (
            str(uuid.uuid4()),
            upload_id,
            row["date"],
            row["account"],
            row["contra"],
            str(row["amount"]),
            row["reference"],
            int(bool(row["internally_balanced"])),
        )
        for row in parsed_rows
    )
    cur.executemany(
        """
        INSERT INTO entry_rows
            (row_id, upload_id, date, account_no, contra_account_no,
             amount, reference, internally_balanced)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """,
        records,
    )
    conn.commit()
    conn.close()
    return upload_id


def is_batch_locked(upload_id):
    """Return True if an ENTRY batch has any row dated on or before the lock date.

    Always False when the company has no lock_date, when the upload does not
    exist, or when it is not an ENTRY batch. Dates are compared as TEXT,
    which orders correctly only if lock_date is also ISO ``YYYY-MM-DD``
    (assumed; entry dates are normalized on upload).
    """
    conn = get_db()
    cur = conn.cursor()
    locked = False
    cur.execute("SELECT lock_date FROM company WHERE id=1")
    company_row = cur.fetchone()
    lock_date = company_row[0] if company_row else None
    if lock_date:
        cur.execute("SELECT type FROM uploads WHERE upload_id=?", (upload_id,))
        upload_row = cur.fetchone()
        if upload_row and upload_row[0] == "ENTRY":
            cur.execute(
                "SELECT 1 FROM entry_rows WHERE upload_id=? AND date <= ? LIMIT 1",
                (upload_id, lock_date),
            )
            locked = cur.fetchone() is not None
    conn.close()
    return locked


def post_coa_batch(upload_id):
    """Post a chart-of-accounts batch so its accounts become usable by entries.

    Returns ``(True, "")`` after setting the upload's status to POSTED.
    Returns ``(False, message)`` and changes nothing when:

    * the upload does not exist (rather than reporting success);
    * the upload is not a COA batch. Posting an ENTRY batch here would skip
      the lock-date and account checks done by post_entry_batch();
    * any of its account numbers already appears in a POSTED COA batch. This
      includes the batch itself if it is already posted.
    """
    conn = get_db()
    try:
        c = conn.cursor()
        c.execute("SELECT type FROM uploads WHERE upload_id=?", (upload_id,))
        upload = c.fetchone()
        if upload is None:
            return False, "Upload not found."
        if upload[0] != "COA":
            return False, "Upload is not a COA batch."

        c.execute("SELECT account_no FROM coa_rows WHERE upload_id = ?", (upload_id,))
        new_accounts = {r[0] for r in c.fetchall()}
        c.execute("""
            SELECT DISTINCT account_no FROM coa_rows
            WHERE upload_id IN (SELECT upload_id FROM uploads WHERE type='COA' AND status='POSTED')
        """)
        posted = {r[0] for r in c.fetchall()}
        dups = new_accounts & posted
        if dups:
            return False, f"Accounts already posted: {', '.join(sorted(dups))}"

        c.execute("UPDATE uploads SET status='POSTED' WHERE upload_id=?", (upload_id,))
        conn.commit()
        return True, ""
    finally:
        # Close on every path, including unexpected sqlite3 errors.
        conn.close()


def post_entry_batch(upload_id):
    """Post a journal-entry batch so it counts toward balances.

    Returns ``(True, "")`` after setting the upload's status to POSTED.
    Returns ``(False, message)`` and changes nothing when:

    * the batch has a row dated on or before the company lock date;
    * the upload does not exist (rather than reporting success);
    * the upload is not an ENTRY batch. Posting a COA batch here would skip
      the duplicate-account check done by post_coa_batch();
    * an account or contra account used by the batch is not in any POSTED
      COA batch.
    """
    if is_batch_locked(upload_id):
        return False, "This batch contains transaction dates on or before the company lock date."
    conn = get_db()
    try:
        c = conn.cursor()
        c.execute("SELECT type FROM uploads WHERE upload_id=?", (upload_id,))
        upload = c.fetchone()
        if upload is None:
            return False, "Upload not found."
        if upload[0] != "ENTRY":
            return False, "Upload is not an entry batch."

        c.execute("""
            SELECT account_no FROM entry_rows WHERE upload_id=?
            UNION
            SELECT contra_account_no FROM entry_rows
            WHERE upload_id=? AND contra_account_no IS NOT NULL AND contra_account_no <> ''
        """, (upload_id, upload_id))
        used = {r[0] for r in c.fetchall() if r[0]}
        c.execute("""
            SELECT DISTINCT account_no FROM coa_rows
            WHERE upload_id IN (SELECT upload_id FROM uploads WHERE type='COA' AND status='POSTED')
        """)
        coa = {r[0] for r in c.fetchall()}
        missing = used - coa
        if missing:
            return False, f"Unknown accounts (not in posted COA): {', '.join(sorted(missing))}"

        c.execute("UPDATE uploads SET status='POSTED' WHERE upload_id=?", (upload_id,))
        conn.commit()
        return True, ""
    finally:
        # Close on every path, including unexpected sqlite3 errors.
        conn.close()


def unpost_batch(upload_id):
    """Return a batch to DRAFT status; result is ``(ok, message)``.

    Refused for ENTRY batches with rows dated on or before the company lock
    date, and for upload ids that do not exist. The UPDATE's rowcount
    detects a missing id, instead of reporting success for an unknown id.
    """
    if is_batch_locked(upload_id):
        return False, "This batch contains transaction dates on or before the company lock date."
    conn = get_db()
    try:
        c = conn.cursor()
        c.execute("UPDATE uploads SET status='DRAFT' WHERE upload_id=?", (upload_id,))
        if c.rowcount == 0:
            return False, "Upload not found."
        conn.commit()
        return True, ""
    finally:
        conn.close()


def delete_batch(upload_id):
    """Delete an upload together with its child rows; do nothing if it is unknown.

    No status or lock checks happen here. Callers are expected to gate this
    with assert_draft() first. Child rows live in coa_rows for COA uploads
    and in entry_rows for any other type.
    """
    conn = get_db()
    cur = conn.cursor()
    cur.execute("SELECT type FROM uploads WHERE upload_id=?", (upload_id,))
    found = cur.fetchone()
    if found is None:
        conn.close()
        return
    child_table = {"COA": "coa_rows"}.get(found[0], "entry_rows")
    cur.execute(f"DELETE FROM {child_table} WHERE upload_id=?", (upload_id,))
    cur.execute("DELETE FROM uploads WHERE upload_id=?", (upload_id,))
    conn.commit()
    conn.close()


def assert_draft(upload_id):
    """Check whether an upload may be deleted; return ``(ok, message)``.

    Fails when the batch is date-locked, when the upload does not exist, or
    when its status is anything other than DRAFT.
    """
    if is_batch_locked(upload_id):
        return False, "This batch contains transaction dates on or before the company lock date."
    conn = get_db()
    status_row = conn.execute(
        "SELECT status FROM uploads WHERE upload_id=?", (upload_id,)
    ).fetchone()
    conn.close()
    if not status_row:
        return False, "Upload not found."
    if status_row[0] == "DRAFT":
        return True, ""
    return False, "Unpost the batch before deleting."


def get_coa_rows(upload_id):
    """Return ``(account_no, description)`` tuples for one COA upload, sorted by account_no."""
    conn = get_db()
    accounts = conn.execute(
        "SELECT account_no, description FROM coa_rows WHERE upload_id=? ORDER BY account_no",
        (upload_id,),
    ).fetchall()
    conn.close()
    return accounts


def get_entry_rows(upload_id):
    """Return the stored rows of one entry upload in rowid order.

    Each tuple is (date, account_no, contra_account_no, amount, reference,
    internally_balanced). Rowid order normally matches insertion order.
    """
    entry_sql = """
        SELECT date, account_no, contra_account_no, amount, reference, internally_balanced
        FROM entry_rows WHERE upload_id=? ORDER BY rowid
    """
    conn = get_db()
    entries = conn.execute(entry_sql, (upload_id,)).fetchall()
    conn.close()
    return entries


def get_trial_balance(date_from=None, date_to=None):
    """Aggregate signed balances from POSTED entry batches.

    Each row adds its amount to ``account_no``. If the row has a contra
    account, the same amount is subtracted from it. Optional inclusive
    bounds ``date_from`` and ``date_to`` are compared as text against the
    stored ISO dates; falsy bounds are ignored.

    Returns a list of ``(account_no, description, Decimal balance)`` sorted
    by account_no. Accounts missing from the posted COA get
    "(no description)".
    """
    clauses = ["upload_id IN (SELECT upload_id FROM uploads WHERE type='ENTRY' AND status='POSTED')"]
    args = []
    for condition, bound in (("date >= ?", date_from), ("date <= ?", date_to)):
        if bound:
            clauses.append(condition)
            args.append(bound)
    entry_sql = (
        "SELECT account_no, contra_account_no, amount FROM entry_rows WHERE "
        + " AND ".join(clauses)
    )

    conn = get_db()
    cur = conn.cursor()
    cur.execute(entry_sql, args)
    balances = {}
    for account, contra, raw_amount in cur.fetchall():
        value = D(raw_amount)
        balances[account] = balances.get(account, Decimal("0")) + value
        if contra:
            balances[contra] = balances.get(contra, Decimal("0")) - value
    cur.execute("""
        SELECT account_no, description FROM coa_rows
        WHERE upload_id IN (SELECT upload_id FROM uploads WHERE type='COA' AND status='POSTED')
    """)
    descriptions = dict(cur.fetchall())
    conn.close()

    return [
        (account, descriptions.get(account, "(no description)"), balances[account])
        for account in sorted(balances)
    ]


def get_composition():
    """List every POSTED upload, oldest first.

    Each tuple is (type, filename, status, row_count). row_count counts
    coa_rows for COA uploads and entry_rows otherwise.
    """
    composition_sql = """
        SELECT type, filename, status,
               CASE type
                   WHEN 'COA' THEN (SELECT COUNT(*) FROM coa_rows   WHERE upload_id = u.upload_id)
                   ELSE            (SELECT COUNT(*) FROM entry_rows WHERE upload_id = u.upload_id)
               END
        FROM uploads u WHERE status='POSTED' ORDER BY uploaded_at ASC
    """
    conn = get_db()
    posted_uploads = conn.execute(composition_sql).fetchall()
    conn.close()
    return posted_uploads


def get_status_counts():
    """Return ``(posted_coa_batches, posted_entry_batches)`` as two integers."""
    conn = get_db()
    cur = conn.cursor()
    totals = []
    for count_sql in (
        "SELECT COUNT(*) FROM uploads WHERE type='COA'   AND status='POSTED'",
        "SELECT COUNT(*) FROM uploads WHERE type='ENTRY' AND status='POSTED'",
    ):
        cur.execute(count_sql)
        totals.append(cur.fetchone()[0])
    conn.close()
    coa_total, entry_total = totals
    return coa_total, entry_total


def backup_db(dest_path):
    """Copy the calling thread's active database to *dest_path*.

    Uses sqlite3.Connection.backup (SQLite's online backup API), which can
    copy a database that is currently open. The ``with`` block commits on
    the destination connection after the copy.
    """
    source = get_db()
    target = sqlite3.connect(dest_path)
    with target:
        source.backup(target)
    target.close()
    source.close()


def get_backup_manifest_stats():
    """Collect the values written into a backup's manifest.json.

    Returns a dict with the company name ("Unnamed Company" when no company
    row exists), the schema version from ``PRAGMA user_version``, and row
    counts for the uploads, coa_rows and entry_rows tables.
    """
    conn = get_db()
    cur = conn.cursor()

    cur.execute("SELECT name FROM company WHERE id=1")
    company = cur.fetchone()

    cur.execute("PRAGMA user_version")
    version = cur.fetchone()[0]

    counts = {}
    for table in ("uploads", "coa_rows", "entry_rows"):
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        counts[table] = cur.fetchone()[0]

    conn.close()
    return {
        "company_name": company[0] if company else "Unnamed Company",
        "schema_version": version,
        "row_counts": counts,
    }


def get_posted_accounts():
    """Return distinct ``(account_no, description)`` pairs from POSTED COA batches.

    Results are ordered by account_no and used for the Retained Earnings
    account lookup.
    """
    accounts_sql = """
        SELECT DISTINCT account_no, description FROM coa_rows
        WHERE upload_id IN (SELECT upload_id FROM uploads WHERE type='COA' AND status='POSTED')
        ORDER BY account_no
    """
    conn = get_db()
    accounts = conn.execute(accounts_sql).fetchall()
    conn.close()
    return accounts


def generate_rollover_batch(year_end_date, opening_date, retained_earnings_acct, pl_starts_from, created_by='system'):
    """Create a DRAFT opening-balance batch that carries balances forward.

    Closing balances come from get_trial_balance() over every POSTED entry
    dated on or before *year_end_date*.

    * Balance-sheet accounts (below *pl_starts_from*) with a non-zero
      balance are carried forward one row each.
    * P&L accounts (at or above *pl_starts_from*) are netted into a single
      row on *retained_earnings_acct*.

    The comparison is numeric when both values are integers and textual
    otherwise. Every generated row is dated *opening_date*, and the batch is
    stored with source_type 'OPENING'.

    Both dates are first normalized with parse_and_normalize_date(), because
    entry dates are stored and compared as ISO text (trial balance filters,
    lock-date checks, and the duplicate-opening check below).

    Returns:
        ``(True, upload_id)`` on success. ``(False, message)`` when a date
        cannot be parsed, an OPENING batch already exists for *opening_date*,
        there is nothing non-zero to roll forward, or P&L balances exist but
        no retained earnings account was given.
    """
    try:
        if year_end_date:
            year_end_date = parse_and_normalize_date(str(year_end_date))
        opening_date = parse_and_normalize_date(str(opening_date or ""))
    except ValueError as err:
        return False, f"Invalid date: {err}"

    # str() keeps the text fallback well-typed. Comparing a str account with
    # an int threshold would raise TypeError.
    threshold = str(pl_starts_from)

    def _is_pl_account(acct):
        try:
            return int(acct) >= int(threshold)
        except ValueError:
            return acct >= threshold

    # 1. Block double-counting of opening balances for this date.
    conn = get_db()
    try:
        clash = conn.execute(
            """
            SELECT 1 FROM uploads
            WHERE type='ENTRY' AND source_type='OPENING'
              AND upload_id IN (SELECT DISTINCT upload_id FROM entry_rows WHERE date = ?)
            """,
            (opening_date,),
        ).fetchone()
    finally:
        conn.close()
    if clash:
        return False, f"An opening balance batch for {opening_date} already exists."

    # 2. Closing balances from POSTED entries up to year_end_date.
    tb = get_trial_balance(None, year_end_date)
    if not tb:
        return False, "No posted balances found to roll forward."

    # 3. Carry balance-sheet accounts forward; net P&L into retained earnings.
    parsed_rows = []
    retained_earnings_impact = Decimal("0")
    for acct, _desc, bal in tb:
        if _is_pl_account(acct):
            retained_earnings_impact += bal
        elif bal != Decimal("0"):
            parsed_rows.append({
                "date": opening_date,
                "account": acct,
                "contra": None,
                "amount": bal,
                "reference": "Opening Balance",
                "internally_balanced": False,
            })

    if retained_earnings_impact != Decimal("0"):
        re_acct = "" if retained_earnings_acct is None else str(retained_earnings_acct).strip()
        if not re_acct:
            # A blank account would be stored as '' and then skipped by the
            # unknown-account check in post_entry_batch().
            return False, "A retained earnings account is required to roll forward P&L balances."
        parsed_rows.append({
            "date": opening_date,
            "account": re_acct,
            "contra": None,
            "amount": retained_earnings_impact,
            "reference": "Opening Balance - P&L Rollover",
            "internally_balanced": False,
        })

    if not parsed_rows:
        return False, "No non-zero balances to roll forward."

    # 4. CSV rendering of the batch, stored as the upload's raw_csv.
    headers = ["date", "account", "contra", "amount", "reference"]
    buf = StringIO()
    writer = csv.writer(buf)
    writer.writerow(headers)
    for r in parsed_rows:
        writer.writerow([r["date"], r["account"], "", f"{r['amount']:.2f}", r["reference"]])
    raw_csv = buf.getvalue()

    upload_id = upload_entry_batch(
        filename=f"Rollover_{year_end_date}_to_{opening_date}.csv",
        raw_csv=raw_csv,
        headers=headers,
        parsed_rows=parsed_rows,
        created_by=created_by,
        source_type='OPENING'
    )
    return True, upload_id


