"""
Flask web wrapper for the Ledger kernel.
Uploads and downloads go through HTTP; the only disk writes are the engagement
databases, short-lived temp files, and zipped backups under backups/.
"""

import io
import os
import csv
import shutil
import tempfile
from decimal import Decimal
from datetime import datetime
from flask import (Flask, render_template, request, redirect,
                   url_for, flash, send_file, session, send_from_directory)
import ledger_core as core
import registry

app = Flask(__name__)
# Session cookies are signed with this key; anyone who knows it can forge a
# session. Prefer a deployment-specific value from LEDGER_SECRET_KEY and fall
# back to the development placeholder only when it is not set.
app.secret_key = os.environ.get(
    "LEDGER_SECRET_KEY", "ledger-secret-key-change-before-deploy"
)

from pathlib import Path

@app.route("/web-assets/<path:filename>")
def web_assets(filename):
    """Serve a file from the ``web-assets`` folder next to the application."""
    assets_dir = Path(app.root_path).joinpath("web-assets")
    return send_from_directory(assets_dir, filename)


@app.route("/csv-template/<name>")
def csv_template(name):
    """Send a blank CSV template from ``csv_templates/`` as a download attachment."""
    templates_dir = Path(app.root_path).joinpath("csv_templates")
    return send_from_directory(templates_dir, name, as_attachment=True)


@app.before_request
def set_db_context_and_auth_guard():
    """Gate every request and bind the session's engagement database.

    Checks run in this order:
      1. Public endpoints, and requests with no matched endpoint, pass through.
      2. Users without a session username are sent to the login page.
      3. Engagement-management endpoints only need a logged-in user.
      4. Anything else needs a selected engagement (``session['db_path']``),
         which is then handed to ``core.set_active_db``.
    Returning None lets Flask continue to the view; a redirect short-circuits it.
    """
    endpoint = request.endpoint
    if not endpoint or endpoint in {'login', 'register_user', 'static', 'web_assets'}:
        return None

    # Single username + password gate.
    if not session.get('username'):
        return redirect(url_for('login'))

    if endpoint in {'engagements', 'open_engagement', 'register_engagement', 'logout'}:
        return None

    active_db = session.get('db_path')
    if not active_db:
        return redirect(url_for('engagements'))

    # Bind the current thread context to the session's active database path.
    core.set_active_db(active_db)
    return None


# ============================================================================
# AUTHENTICATION DOORWAY
# ============================================================================

@app.route("/login", methods=["GET", "POST"])
def login():
    """Show the login form, or verify credentials and start a session.

    On a successful POST the lower-cased, stripped username is stored in the
    session, any previously selected engagement is forgotten, and the user is
    sent to the engagement hub. Bad credentials flash an error and reload
    the form.
    """
    if request.method != "POST":
        return render_template("login.html", user_count=registry.count_users())

    form = request.form
    username = form.get("username", "")
    password = form.get("password", "")

    if registry.verify_user(username, password):
        session['username'] = username.strip().lower()
        for stale in ('engagement_key', 'db_path'):
            session.pop(stale, None)
        return redirect(url_for("engagements"))

    flash("Invalid username or password.", "danger")
    return redirect(url_for("login"))


@app.route("/engagements")
def engagements():
    """Engagement hub: list all companies and highlight the one currently open."""
    context = {
        "active": "engagements",
        "engagements": registry.get_engagements(),
        "active_key": session.get('engagement_key'),
    }
    return render_template("engagements.html", **context)


@app.route("/engagements/open/<key>", methods=["POST"])
def open_engagement(key):
    """Make engagement ``key`` the active one for this session.

    No per-engagement password is asked for in the simplified flow. When the
    registry knows the key, its DB path is stored in the session, bound in
    ``core``, and initialised with ``core.init_db()``. The open is then
    recorded via ``registry.update_engagement_last_opened`` and the user is
    sent to the Company page. Unknown keys flash an error and return to the hub.
    """
    db_path = registry.get_engagement_db_path(key)
    if db_path:
        session['engagement_key'] = key.strip().lower()
        session['db_path'] = db_path
        core.set_active_db(db_path)
        core.init_db()
        registry.update_engagement_last_opened(key)
        flash(f"Opened engagement: {key}", "success")
        return redirect(url_for("company"))

    flash("Engagement not found.", "danger")
    return redirect(url_for("engagements"))


@app.route("/register-user", methods=["POST"])
def register_user():
    """Create a login account from the registration form, then show the login page."""
    new_username = request.form.get("new_username", "")
    created = registry.create_user(new_username, request.form.get("new_password", ""))
    if not created:
        flash("Registration failed. Username may already be taken.", "danger")
    else:
        flash(f"User '{new_username}' registered successfully.", "success")
    return redirect(url_for("login"))


@app.route("/register-engagement", methods=["POST"])
def register_engagement():
    """Create a new engagement from the hub form and return to the hub.

    Form fields:
        eng_name: display name of the engagement.
        eng_key: identifier, which must be one or more ASCII letters, digits
            or underscores. Engagement keys end up in filesystem paths
            (e.g. ``backups/<key>``), so they are validated strictly.

    Engagement passwords are deprecated in the simplified flow, so the fixed
    placeholder ``"_open_"`` is stored. Success or the registry's error is
    reported via flash.
    """
    import re

    name = request.form.get("eng_name", "")
    key = request.form.get("eng_key", "")

    # fullmatch instead of match(r"^...$"): "$" also matches just before a
    # trailing newline, so a key like "abc\n" used to pass validation.
    if not re.fullmatch(r"[a-zA-Z0-9_]+", key):
        flash("Engagement key must contain only letters, numbers, and underscores.", "danger")
        return redirect(url_for("engagements"))

    ok, err_or_path = registry.create_engagement(key, name, "_open_")
    if ok:
        flash(f"Engagement '{name}' created successfully.", "success")
    else:
        flash(f"Failed to create engagement: {err_or_path}", "danger")
    return redirect(url_for("engagements"))


@app.route("/logout")
def logout():
    """Forget the user and engagement, then return to the login page."""
    session.clear()
    # Flash after clearing so the message survives into the next request.
    flash("Session closed.", "success")
    login_page = url_for("login")
    return redirect(login_page)



# ============================================================================
# COMPANY
# ============================================================================

@app.route("/")
def index():
    """Root URL: send the user to the Company page."""
    target = url_for("company")
    return redirect(target)


@app.route("/company", methods=["GET", "POST"])
def company():
    """Show (GET) or save (POST) the active engagement's company profile.

    The GET view also shows the COA/entry status counts and the posted accounts.
    """
    if request.method == "POST":
        form = request.form
        text_fields = ("company_id", "name", "reg_number", "tax_number",
                       "address", "currency", "year_end")
        values = [form.get(field, "") for field in text_fields]
        values.append(form.get("period_type", "monthly"))
        # An empty lock date is stored as None rather than "".
        values.append(form.get("lock_date") or None)
        core.save_company(*values)
        flash("Company saved.", "success")
        return redirect(url_for("company"))

    co = core.company_to_dict(core.get_company())
    coa_count, entry_count = core.get_status_counts()
    return render_template(
        "company.html",
        co=co,
        active="company",
        coa_count=coa_count,
        entry_count=entry_count,
        accounts=core.get_posted_accounts(),
    )


@app.route("/company/export-csv")
def company_export_csv():
    """Download the company profile as a two-column ``key,value`` CSV file."""
    co = core.company_to_dict(core.get_company())
    field_order = ("company_id", "name", "reg_number", "tax_number",
                   "address", "currency", "year_end", "period_type", "lock_date")
    text_buf = io.StringIO()
    csv.writer(text_buf).writerows([field, co.get(field, "")] for field in field_order)
    payload = io.BytesIO(text_buf.getvalue().encode("utf-8"))
    return send_file(payload, mimetype="text/csv", as_attachment=True,
                     download_name="company.csv")


@app.route("/company/import-csv", methods=["POST"])
def company_import_csv():
    """Import the company profile from an uploaded ``key,value`` CSV.

    The file format matches what ``company_export_csv`` produces.
    - Keys in the first column are stripped and lower-cased.
    - A missing second column becomes "".
    - Blank keys are skipped.
    - A missing or empty ``period_type`` falls back to "monthly".
    - An empty ``lock_date`` is stored as None.

    Any decode, parse or save error is flashed; the user always returns to
    the Company page.
    """
    f = request.files.get("csv_file")
    if not f:
        flash("No file selected.", "danger")
        return redirect(url_for("company"))
    try:
        # utf-8-sig strips the BOM that spreadsheet tools often prepend.
        text = f.read().decode("utf-8-sig")
        data = {
            row[0].strip().lower(): (row[1] if len(row) > 1 else "")
            for row in csv.reader(io.StringIO(text))
            if row and row[0].strip()
        }
        core.save_company(
            data.get("company_id", ""), data.get("name", ""),
            data.get("reg_number", ""), data.get("tax_number", ""),
            data.get("address", ""), data.get("currency", ""),
            data.get("year_end", ""),
            # `or`, not a .get() default: an empty cell would otherwise
            # store "" instead of the documented "monthly" default.
            data.get("period_type") or "monthly",
            data.get("lock_date") or None
        )
        flash("Company imported and saved.", "success")
    except Exception as e:
        flash(f"Import error: {e}", "danger")
    return redirect(url_for("company"))


@app.route("/company/rollover", methods=["POST"])
def company_rollover():
    """Generate a year-end rollover batch (as DRAFT) from the Company page form.

    Form fields:
        rollover_year_end, rollover_opening, retained_earnings_acct: required
            (surrounding whitespace is ignored).
        pl_starts_from: passed through to ``core.generate_rollover_batch``.
            It presumably is the first P&L account number; it defaults to
            "4000" when missing or blank.

    On success the user goes to Entries to review and post the batch. On any
    failure the reason is flashed and the user returns to Company.
    """
    form = request.form
    year_end_date = form.get("rollover_year_end", "").strip()
    opening_date = form.get("rollover_opening", "").strip()
    retained_earnings_acct = form.get("retained_earnings_acct", "").strip()
    # `or` rather than a .get() default: a submitted-but-empty field yields ""
    # and would otherwise bypass the "4000" fallback.
    pl_starts_from = form.get("pl_starts_from", "").strip() or "4000"

    if not (year_end_date and opening_date and retained_earnings_acct):
        flash("Year End Date, Opening Date, and Retained Earnings account are required.", "danger")
        return redirect(url_for("company"))

    try:
        ok, result = core.generate_rollover_batch(
            year_end_date=year_end_date,
            opening_date=opening_date,
            retained_earnings_acct=retained_earnings_acct,
            pl_starts_from=pl_starts_from,
            created_by=session.get('username', 'system')
        )
    except Exception as e:
        flash(f"Rollover failed: {e}", "danger")
        return redirect(url_for("company"))

    if not ok:
        flash(f"Rollover failed: {result}", "danger")
        return redirect(url_for("company"))

    # str() so a non-string batch id cannot raise here. The batch already
    # exists at this point, so reporting it as a failure would be wrong.
    flash(f"Rollover batch generated as DRAFT with ID: {str(result)[:8]}. Please review and post it under Entries.", "success")
    return redirect(url_for("entries"))


# ============================================================================
# CHART OF ACCOUNTS
# ============================================================================

@app.route("/coa")
def coa():
    """List every Chart of Accounts upload batch."""
    return render_template(
        "coa.html",
        batches=core.get_uploads_by_type("COA"),
        active="coa",
    )


@app.route("/coa/upload", methods=["POST"])
def coa_upload():
    """Validate an uploaded Chart of Accounts CSV and store it as a DRAFT batch.

    Missing files, validation failures and unexpected errors are all reported
    via flash; the user always ends up back on the COA page.
    """
    upload = request.files.get("csv_file")
    if not upload:
        flash("No file selected.", "danger")
    else:
        try:
            raw = upload.read().decode("utf-8-sig")
            rows, headers = core.parse_csv(raw)
            valid, err, _ = core.validate_coa_csv(rows)
            if valid:
                core.upload_coa_batch(upload.filename, raw, headers, rows,
                                      created_by=session.get('username', 'system'))
                flash(f"{len(rows)} rows uploaded as DRAFT.", "success")
            else:
                flash(f"Validation error: {err}", "danger")
        except Exception as e:
            flash(f"Error: {e}", "danger")
    return redirect(url_for("coa"))


@app.route("/coa/post/<uid>", methods=["POST"])
def coa_post(uid):
    """Post COA batch ``uid``, flashing success or the reason core refused."""
    ok, err = core.post_coa_batch(uid)
    if ok:
        flash("COA batch posted.", "success")
    else:
        flash(f"Cannot post: {err}", "danger")
    return redirect(url_for("coa"))


@app.route("/coa/unpost/<uid>", methods=["POST"])
def coa_unpost(uid):
    """Return COA batch ``uid`` to DRAFT, flashing the outcome."""
    ok, err = core.unpost_batch(uid)
    flash("COA batch set back to DRAFT." if ok else f"Cannot unpost: {err}",
          "success" if ok else "danger")
    return redirect(url_for("coa"))


@app.route("/coa/export/<uid>")
def coa_export(uid):
    """Download the raw CSV originally uploaded for COA batch ``uid``."""
    response = _stream_raw_csv(uid)
    return response


@app.route("/coa/view/<uid>")
def coa_view(uid):
    """Show the stored rows of COA batch ``uid``."""
    coa_rows = core.get_coa_rows(uid)
    return render_template("view_rows.html", active="coa", upload_type="COA", rows=coa_rows)


@app.route("/coa/delete/<uid>", methods=["POST"])
def coa_delete(uid):
    """Delete COA batch ``uid``; only batches that core accepts as DRAFT are deleted."""
    allowed, reason = core.assert_draft(uid)
    if not allowed:
        flash(reason, "danger")
    else:
        core.delete_batch(uid)
        flash("Batch deleted.", "success")
    return redirect(url_for("coa"))


# ============================================================================
# ENTRIES
# ============================================================================

@app.route("/entries")
def entries():
    """List every ENTRY upload batch."""
    entry_batches = core.get_uploads_by_type("ENTRY")
    return render_template("entries.html", batches=entry_batches, active="entries")


@app.route("/entries/upload", methods=["POST"])
def entries_upload():
    """Validate an uploaded entries CSV and store it as a DRAFT batch.

    The optional ``source_type`` form field ("" is treated as None) is passed
    to both validation and storage. Validator warnings are flashed before the
    batch is saved. Any error is flashed and the user returns to Entries.
    """
    upload = request.files.get("csv_file")
    if not upload:
        flash("No file selected.", "danger")
        return redirect(url_for("entries"))
    try:
        raw = upload.read().decode("utf-8-sig")
        rows, headers = core.parse_csv(raw)
        source_type = request.form.get("source_type") or None

        valid, err, parsed, warnings = core.validate_entry_csv(rows, source_type)
        if valid:
            for warning in warnings:
                flash(warning, "warning")
            core.upload_entry_batch(upload.filename, raw, headers, parsed,
                                    created_by=session.get('username', 'system'),
                                    source_type=source_type)
            flash(f"{len(parsed)} rows uploaded as DRAFT.", "success")
        else:
            flash(f"Validation error: {err}", "danger")
    except Exception as e:
        flash(f"Error: {e}", "danger")
    return redirect(url_for("entries"))


@app.route("/entries/post/<uid>", methods=["POST"])
def entries_post(uid):
    """Post entry batch ``uid``, flashing success or the reason core refused."""
    ok, err = core.post_entry_batch(uid)
    if not ok:
        flash(f"Cannot post: {err}", "danger")
    else:
        flash("Entry batch posted.", "success")
    return redirect(url_for("entries"))


@app.route("/entries/unpost/<uid>", methods=["POST"])
def entries_unpost(uid):
    """Return entry batch ``uid`` to DRAFT, flashing the outcome."""
    ok, err = core.unpost_batch(uid)
    message, category = (("Entry batch set back to DRAFT.", "success") if ok
                         else (f"Cannot unpost: {err}", "danger"))
    flash(message, category)
    return redirect(url_for("entries"))


@app.route("/entries/export/<uid>")
def entries_export(uid):
    """Download the raw CSV originally uploaded for entry batch ``uid``."""
    return _stream_raw_csv(uid=uid)


@app.route("/entries/export-ledger")
def entries_export_ledger():
    """Download the general ledger of the active engagement.

    ``exports.export_ledger`` writes into an in-memory buffer and returns the
    file extension and MIME type to use. The download name carries a
    timestamp. Failures are flashed and the user returns to Entries.
    """
    db_path = session.get('db_path')
    if not db_path:
        flash("No active session database found.", "danger")
        return redirect(url_for("entries"))
    try:
        import exports
        buffer = io.BytesIO()
        ext, mimetype = exports.export_ledger(db_path, buffer)
        buffer.seek(0)
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        return send_file(buffer, mimetype=mimetype, as_attachment=True,
                         download_name=f"general_ledger_{stamp}.{ext}")
    except Exception as e:
        flash(f"Export failed: {e}", "danger")
        return redirect(url_for("entries"))


@app.route("/entries/view/<uid>")
def entries_view(uid):
    """Show the stored rows of entry batch ``uid``."""
    entry_rows = core.get_entry_rows(uid)
    return render_template("view_rows.html", active="entries", upload_type="ENTRY", rows=entry_rows)


@app.route("/entries/delete/<uid>", methods=["POST"])
def entries_delete(uid):
    """Delete entry batch ``uid``; only batches that core accepts as DRAFT are deleted."""
    is_draft, reason = core.assert_draft(uid)
    if is_draft:
        core.delete_batch(uid)
        flash("Batch deleted.", "success")
    else:
        flash(reason, "danger")
    return redirect(url_for("entries"))


# ============================================================================
# TRIAL BALANCE
# ============================================================================

def _parse_year_end(year_end_str):
    """Parse a financial year-end such as ``"31 Dec"`` or ``"June 30"``.

    Whitespace-separated tokens are scanned left to right; later tokens
    override earlier ones.
    - An all-digit token sets the day.
    - A token starting with a three-letter English month abbreviation sets
      the month.

    Missing parts default to 31 December; an empty or None input also means
    31 December.

    Returns:
        ``(month, day)``. The day is returned as written: it may be 0 or
        exceed the month's length, so callers must clamp it.
    """
    month_prefixes = ("jan", "feb", "mar", "apr", "may", "jun",
                      "jul", "aug", "sep", "oct", "nov", "dec")
    month, day = 12, 31
    for token in (year_end_str or "31 Dec").lower().split():
        # isdecimal (not isdigit) so tokens like "²" cannot crash int().
        if token.isdecimal():
            day = int(token)
            continue
        for number, prefix in enumerate(month_prefixes, start=1):
            if token.startswith(prefix):
                month = number
                break
    return month, day


def calculate_preset_dates(preset, anchor_date_str, year_end_str=None):
    """Return the inclusive ``(date_from, date_to)`` range for a period preset.

    Args:
        preset: one of
            - "day": the anchor date only;
            - "week": Monday-Sunday;
            - "month": calendar month;
            - "quarter": calendar quarter;
            - "year": the financial year ending on ``year_end_str`` that
              contains the anchor.
        anchor_date_str: ``YYYY-MM-DD`` date the period must contain. Today
            is used if it is missing or unparseable.
        year_end_str: financial year-end such as ``"31 Dec"``. Only used by
            "year"; defaults to 31 December.

    Returns:
        Two zero-padded ``YYYY-MM-DD`` strings, or ``("", "")`` for an
        unknown preset.
    """
    import calendar
    from datetime import datetime, timedelta

    fmt = "%Y-%m-%d"
    try:
        anchor = datetime.strptime(anchor_date_str, fmt)
    except (ValueError, TypeError):
        # Truncate to midnight. The "year" preset compares the anchor with
        # midnight year-end dates, so a time of day would wrongly push a
        # year-end that falls today into the next financial year.
        anchor = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)

    def clamped(year, month, day):
        # Force day into 1..last-day-of-month (e.g. 31 Feb -> 28/29 Feb).
        last_day = calendar.monthrange(year, month)[1]
        return datetime(year, month, min(max(day, 1), last_day))

    if preset == "day":
        start = end = anchor
    elif preset == "week":
        start = anchor - timedelta(days=anchor.weekday())
        end = start + timedelta(days=6)
    elif preset == "month":
        start = anchor.replace(day=1)
        end = clamped(anchor.year, anchor.month, 31)
    elif preset == "quarter":
        first_month = ((anchor.month - 1) // 3) * 3 + 1
        start = datetime(anchor.year, first_month, 1)
        end = clamped(anchor.year, first_month + 2, 31)
    elif preset == "year":
        month, day = _parse_year_end(year_end_str)
        end = clamped(anchor.year, month, day)
        if anchor > end:
            end = clamped(anchor.year + 1, month, day)
        # Clamp the previous year-end from the configured day, not end.day.
        # For a 29 Feb year-end, end may be 28 Feb in a non-leap year while
        # the previous year-end is 29 Feb.
        start = clamped(end.year - 1, month, day) + timedelta(days=1)
    else:
        return "", ""
    return start.strftime(fmt), end.strftime(fmt)


def get_prior_period(date_from_str, date_to_str):
    from datetime import datetime, timedelta
    try:
        d_from = datetime.strptime(date_from_str, "%Y-%m-%d")
        d_to = datetime.strptime(date_to_str, "%Y-%m-%d")
    except (ValueError, TypeError):
        return None, None
    delta = (d_to - d_from).days + 1
    prior_to = d_from - timedelta(days=1)
    prior_from = prior_to - timedelta(days=delta - 1)
    return prior_from.strftime("%Y-%m-%d"), prior_to.strftime("%Y-%m-%d")


@app.route("/trial-balance")
def trial_balance():
    """Render the trial balance for a preset or custom period, with comparatives.

    Query args:
        preset: defaults to "month"; "custom" reads ``date_from``/``date_to``
            directly, any other value goes through ``calculate_preset_dates``.
        anchor_date: defaults to today.

    When both bounds are known, the balances for the preceding period of
    equal length are merged in by account. Each merged row carries the
    variance and the % variance; % variance is None when the prior balance
    is zero.
    """
    args = request.args
    anchor_date = args.get("anchor_date", "") or datetime.now().strftime("%Y-%m-%d")
    preset = args.get("preset", "month")

    year_end = core.company_to_dict(core.get_company()).get("year_end", "31 Dec")

    if preset != "custom":
        date_from, date_to = calculate_preset_dates(preset, anchor_date, year_end)
    else:
        date_from, date_to = args.get("date_from", ""), args.get("date_to", "")

    tb = core.get_trial_balance(date_from or None, date_to or None)
    composition = core.get_composition()
    total = sum(bal for _, _, bal in tb)

    comparative_tb = []
    prior_from = prior_to = None
    if date_from and date_to:
        prior_from, prior_to = get_prior_period(date_from, date_to)
    if prior_from and prior_to:
        tb_prior = core.get_trial_balance(prior_from, prior_to)
        zero = Decimal("0")

        by_account = {}
        for acct, desc, bal in tb:
            by_account[acct] = {"desc": desc, "current_bal": bal, "prior_bal": zero}
        for acct, desc, bal in tb_prior:
            # Accounts only present in the prior period keep its description.
            slot = by_account.setdefault(
                acct, {"desc": desc, "current_bal": zero, "prior_bal": zero})
            slot["prior_bal"] = bal

        for acct in sorted(by_account):
            slot = by_account[acct]
            curr, prior = slot["current_bal"], slot["prior_bal"]
            variance = curr - prior
            comparative_tb.append({
                "account_no": acct,
                "description": slot["desc"],
                "current_bal": curr,
                "prior_bal": prior,
                "variance": variance,
                "pct_variance": (variance / abs(prior)) * 100 if prior != zero else None,
            })

    return render_template("trial_balance.html",
                           tb=tb, composition=composition, total=total,
                           date_from=date_from, date_to=date_to,
                           anchor_date=anchor_date, preset=preset, active="tb",
                           comparative_tb=comparative_tb,
                           prior_from=prior_from, prior_to=prior_to)


@app.route("/trial-balance/export")
def trial_balance_export():
    """Download the trial balance as CSV.

    When ``date_from`` and ``date_to`` are both given and parseable, the CSV
    is comparative. It has one row per account with:
    - the current balance;
    - the balance for the preceding period of equal length;
    - the variance;
    - the % variance (blank when the prior balance is zero).

    Otherwise a single-period ``account_no,description,balance`` CSV is
    written. It uses whichever bound was supplied, the same way the on-screen
    trial balance does.
    """
    date_from = request.args.get("date_from", "")
    date_to = request.args.get("date_to", "")

    prior_from, prior_to = (get_prior_period(date_from, date_to)
                            if date_from and date_to else (None, None))

    buf = io.StringIO()
    w = csv.writer(buf)

    # Only build a comparative when the prior period could be computed. If
    # the dates are unparseable, the prior query would be an all-time
    # "None to None" balance.
    if prior_from and prior_to:
        tb_curr = core.get_trial_balance(date_from, date_to)
        tb_prior = core.get_trial_balance(prior_from, prior_to)

        zero = Decimal("0")
        merged = {}
        for acct, desc, bal in tb_curr:
            merged[acct] = {"desc": desc, "current_bal": bal, "prior_bal": zero}
        for acct, desc, bal in tb_prior:
            if acct in merged:
                merged[acct]["prior_bal"] = bal
            else:
                merged[acct] = {"desc": desc, "current_bal": zero, "prior_bal": bal}

        w.writerow([
            "account_no", "description",
            f"current_balance ({date_from} to {date_to})",
            f"prior_balance ({prior_from} to {prior_to})",
            "variance", "percentage_variance"
        ])
        for acct in sorted(merged):
            item = merged[acct]
            curr = item["current_bal"]
            prior = item["prior_bal"]
            var = curr - prior
            pct_str = ""
            if prior != zero:
                pct = (var / abs(prior)) * 100
                pct_str = f"{pct:.2f}%"
            w.writerow([acct, item["desc"], f"{curr:.2f}", f"{prior:.2f}", f"{var:.2f}", pct_str])

        download_name = "trial_balance_comparative.csv"
    else:
        # Pass through whichever bound was supplied. Previously a one-sided
        # range was ignored here and the all-time balance was exported.
        tb = core.get_trial_balance(date_from or None, date_to or None)
        w.writerow(["account_no", "description", "balance"])
        for acct, desc, bal in tb:
            w.writerow([acct, desc, f"{bal:.2f}"])
        download_name = "trial_balance.csv"

    return send_file(
        io.BytesIO(buf.getvalue().encode("utf-8")),
        mimetype="text/csv",
        as_attachment=True,
        download_name=download_name,
    )


# ============================================================================
# BACKUP / RESTORE
# ============================================================================

def validate_db_file(db_path):
    """Check that ``db_path`` is a readable SQLite database with the Ledger schema.

    The file is opened read-only. It must contain the tables ``company``,
    ``uploads``, ``coa_rows`` and ``entry_rows``, and ``PRAGMA user_version``
    must be non-negative.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, message)``. Open and
        query errors (missing file, non-SQLite content, ...) are reported in
        the message rather than raised.
    """
    import sqlite3
    from pathlib import Path

    required = {"company", "uploads", "coa_rows", "entry_rows"}
    conn = None
    try:
        # Build the URI with Path.as_uri(), which percent-encodes characters
        # such as '?', '#' and '%'. Splicing the raw path into "file:...?..."
        # lets those characters be read as URI syntax, which drops mode=ro and
        # can open or create a different file.
        uri = Path(db_path).resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
        tables = {row[0] for row in
                  conn.execute("SELECT name FROM sqlite_master WHERE type='table'")}
        missing = required - tables
        if missing:
            return False, (f"Missing required tables: {', '.join(sorted(missing))}. "
                           f"Found: {', '.join(sorted(tables))}")
        ver = conn.execute("PRAGMA user_version").fetchone()[0]
        if ver < 0:
            return False, "Invalid user_version."
        return True, ""
    except Exception as e:
        return False, f"Not a valid database: {e}"
    finally:
        if conn is not None:
            conn.close()


@app.route("/backup")
def backup():
    """Zip the active engagement database and send it as a download.

    The archive contains ``ledger.db`` (a copy made by ``core.backup_db``)
    and ``manifest.json``. It is also kept on disk as
    ``backups/<engagement_key>/backup_<timestamp>.zip``. The intermediate temp
    copy of the database is always deleted, even if the backup fails part-way.
    Failures are flashed and the user returns to Company.
    """
    key = session.get('engagement_key')
    db_path = session.get('db_path')
    if not key or not db_path:
        flash("No active engagement session.", "danger")
        return redirect(url_for("login"))

    import json
    import zipfile

    tmp_path = None
    try:
        backup_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "backups", key))
        os.makedirs(backup_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"backup_{timestamp}.zip"
        zip_path = os.path.join(backup_dir, zip_filename)

        # 1. Copy the database to a temp file via core's backup API.
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db")
        os.close(tmp_fd)
        core.backup_db(tmp_path)

        # 2. Describe the backup.
        stats = core.get_backup_manifest_stats()
        manifest = {
            "engagement_name": stats["company_name"],
            "engagement_key": key,
            "schema_version": stats["schema_version"],
            "row_counts": stats["row_counts"],
            "created_at": datetime.now().isoformat()
        }

        # 3. Zip database copy and manifest together.
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
            z.write(tmp_path, "ledger.db")
            z.writestr("manifest.json", json.dumps(manifest, indent=2))

        return send_file(zip_path, as_attachment=True, download_name=zip_filename)
    except Exception as e:
        flash(f"Backup failed: {e}", "danger")
        return redirect(url_for("company"))
    finally:
        # Previously the temp copy leaked whenever a step above raised.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)


@app.route("/restore", methods=["POST"])
def restore():
    """Replace the active engagement database with an uploaded backup.

    Accepts either a backup zip containing ``ledger.db`` or a bare database
    file. Steps:
      1. Save the upload and extract or copy the candidate DB to a temp file.
      2. Check the candidate with ``validate_db_file``.
      3. Zip a safety copy of the current DB into
         ``backups/<key>/before_restore_<timestamp>.zip``.
      4. Copy the candidate over the active DB with SQLite's online backup API.

    Temp files and SQLite connections are released on every path, including
    early returns and errors. The outcome is flashed and the user returns to
    Company.
    """
    key = session.get('engagement_key')
    active_db_path = session.get('db_path')
    if not key or not active_db_path:
        flash("No active engagement session.", "danger")
        return redirect(url_for("login"))

    f = request.files.get("db_file")
    if not f:
        flash("No file selected.", "danger")
        return redirect(url_for("company"))

    import json
    import sqlite3
    import zipfile
    from contextlib import closing

    tmp_zip_path = None
    tmp_db_path = None
    safety_db_path = None
    try:
        # Save the uploaded file (zip or raw db) to a temp location.
        tmp_zip_fd, tmp_zip_path = tempfile.mkstemp()
        os.close(tmp_zip_fd)
        f.save(tmp_zip_path)

        tmp_db_fd, tmp_db_path = tempfile.mkstemp(suffix=".db")
        os.close(tmp_db_fd)

        if zipfile.is_zipfile(tmp_zip_path):
            with zipfile.ZipFile(tmp_zip_path, "r") as z:
                if "ledger.db" not in z.namelist():
                    flash("Invalid backup zip: missing ledger.db", "danger")
                    return redirect(url_for("company"))
                # Stream the member to disk instead of reading it all into memory.
                with z.open("ledger.db") as src, open(tmp_db_path, "wb") as dst:
                    shutil.copyfileobj(src, dst)
        else:
            shutil.copy2(tmp_zip_path, tmp_db_path)

        ok, err = validate_db_file(tmp_db_path)
        if not ok:
            flash(f"Validation failed: {err}", "danger")
            return redirect(url_for("company"))

        # Safety backup of the current database before it is overwritten.
        backup_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "backups", key))
        os.makedirs(backup_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safety_zip_path = os.path.join(backup_dir, f"before_restore_{timestamp}.zip")

        safety_db_fd, safety_db_path = tempfile.mkstemp(suffix=".db")
        os.close(safety_db_fd)
        core.backup_db(safety_db_path)

        stats = core.get_backup_manifest_stats()
        safety_manifest = {
            "engagement_name": stats["company_name"],
            "engagement_key": key,
            "schema_version": stats["schema_version"],
            "row_counts": stats["row_counts"],
            "created_at": datetime.now().isoformat(),
            "notes": "Safety backup created automatically before restore"
        }
        with zipfile.ZipFile(safety_zip_path, "w", zipfile.ZIP_DEFLATED) as z:
            z.write(safety_db_path, "ledger.db")
            z.writestr("manifest.json", json.dumps(safety_manifest, indent=2))

        # Overwrite the active database using the online backup API in reverse.
        # closing() guarantees both connections are released even if the copy
        # raises.
        with closing(sqlite3.connect(active_db_path)) as active_conn, \
                closing(sqlite3.connect(tmp_db_path)) as restore_source:
            with active_conn:
                restore_source.backup(active_conn)

        flash("Database restored successfully. A safety backup was saved.", "success")
    except Exception as e:
        flash(f"Restore failed: {e}", "danger")
    finally:
        for path in (tmp_db_path, tmp_zip_path, safety_db_path):
            if path and os.path.exists(path):
                os.unlink(path)

    return redirect(url_for("company"))


# ============================================================================
# HELPER
# ============================================================================

def _stream_raw_csv(uid):
    """Send the raw CSV text stored for upload ``uid`` as a file download.

    Reads ``raw_csv`` and ``filename`` from the ``uploads`` table through
    ``core.get_db()``. When no row matches, it flashes "Upload not found." and
    redirects to the index. The connection is closed even if the query raises.
    """
    conn = core.get_db()
    try:
        cur = conn.cursor()
        cur.execute("SELECT raw_csv, filename FROM uploads WHERE upload_id=?", (uid,))
        result = cur.fetchone()
    finally:
        conn.close()

    if not result:
        flash("Upload not found.", "danger")
        return redirect(url_for("index"))
    raw_csv, filename = result
    return send_file(
        io.BytesIO(raw_csv.encode("utf-8")),
        mimetype="text/csv",
        as_attachment=True,
        download_name=filename,
    )


# ============================================================================
# AUTO-BACKUP ON APP CLOSE
# ============================================================================

def _auto_backup_engagement(key, name, db_path):
    """Write ``backups/<key>/auto_close_<timestamp>.zip`` for one engagement.

    Steps:
    - Copy ``db_path`` into a temp file with SQLite's online backup API.
    - Read the schema version and per-table row counts from that copy. A
      table that cannot be counted reports 0.
    - Zip the copy with a manifest.

    Connections and the temp file are always released. Errors propagate to
    the caller.
    """
    import json
    import sqlite3
    import zipfile
    from contextlib import closing

    backup_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "backups", key))
    os.makedirs(backup_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_path = os.path.join(backup_dir, f"auto_close_{timestamp}.zip")

    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db")
    os.close(tmp_fd)
    try:
        with closing(sqlite3.connect(db_path)) as src_conn, \
                closing(sqlite3.connect(tmp_path)) as dest_conn:
            with dest_conn:
                src_conn.backup(dest_conn)
            schema_ver = dest_conn.execute("PRAGMA user_version").fetchone()[0]

            def table_count(table):
                # Table names are fixed literals below, never user input.
                try:
                    return dest_conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
                except sqlite3.Error:
                    return 0

            row_counts = {t: table_count(t) for t in ("uploads", "coa_rows", "entry_rows")}

        manifest = {
            "engagement_name": name,
            "engagement_key": key,
            "schema_version": schema_ver,
            "row_counts": row_counts,
            "created_at": datetime.now().isoformat(),
            "notes": "Automatic backup triggered on application exit"
        }
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
            z.write(tmp_path, "ledger.db")
            z.writestr("manifest.json", json.dumps(manifest, indent=2))
    finally:
        os.unlink(tmp_path)


def auto_backup_all():
    """Back up every registered engagement database on app close.

    Engagements whose database file no longer exists are skipped. Each
    engagement is handled independently: an error on one is printed and the
    loop moves on. Previously, one bad database aborted the backups of all
    the engagements after it. Nothing is raised, since this runs as an exit
    hook.
    """
    try:
        engs = registry.get_engagements()
    except Exception as e:
        print(f"Auto-backup on exit failed: {e}")
        return

    for eng in engs:
        try:
            key, name, db_path, _last_opened = eng
            if not os.path.exists(db_path):
                continue
            _auto_backup_engagement(key, name, db_path)
        except Exception as e:
            print(f"Auto-backup on exit failed: {e}")


import atexit
# Run auto_backup_all() when the interpreter exits normally.
# NOTE(review): atexit handlers do not run when the process is killed by an
# unhandled signal or via os._exit(). Also, with app.run(debug=True) the
# reloader may import this module in more than one process; confirm that
# backups are not duplicated.
atexit.register(auto_backup_all)


if __name__ == "__main__":
    # debug=True enables the Werkzeug interactive debugger, which can run
    # arbitrary code from the browser. LEDGER_DEBUG=0 switches it off and
    # LEDGER_PORT changes the port, without editing code. The defaults match
    # the previous hard-coded values (debug on, port 5001).
    app.run(
        debug=os.environ.get("LEDGER_DEBUG", "1") != "0",
        port=int(os.environ.get("LEDGER_PORT", "5001")),
    )
