#!/usr/bin/env python3
"""
SYLink HoneyBot — service + disk watchdog.

Every 2 min:
  - For each decoy service: if `systemctl is-active` != active AND the suspended flag is absent → restart
  - If a service is restarted > 10 times within the current UTC clock hour → log a crash-loop warning
  - If / usage > 80% → aggressive purge (truncate cowrie.json/cowrie.log/auth.log, delete old *.gz and cowrie tty logs, delete cowrie downloads)
  - If / usage > 90% → degraded mode: stop rdp-recorder (heavy ffmpeg consumer), then purge as above
"""
from __future__ import annotations

import json
import os
import shutil
import subprocess
import sys
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from pathlib import Path

# Where watchdog.log is written.
LOG_DIR = Path("/opt/sylink-honeypot/log")
# While this file exists, the watchdog performs no restarts.
# NOTE(review): this path says "unisoc" while the other paths say "sylink" —
# confirm it matches what the suspend tooling actually creates.
SUSPENDED_FLAG = Path("/var/lib/unisoc-honeypot.suspended")
# JSON map "<service>|YYYY-mm-dd HH" -> restart count for that UTC hour.
RESTART_COUNTER = Path("/opt/sylink-honeypot/etc/watchdog_restarts.json")

# Decoy services that must stay up.
CRITICAL_SERVICES = [
    "cowrie",
    "opencanary",
    "veeam-fake",
    "proftpd-fake",
    "ssh-tarpit",
    "file-watcher",
    "http-honeytrap",
    "smbd",
]
# Supporting services; restarted too, checked after the critical ones.
NON_CRITICAL = [
    "rdp-recorder",
    "xrdp",
    "xrdp-sesman",
    "nmbd",
]

# Root filesystem usage (percent) above which logs are purged ...
DISK_THRESHOLD_PURGE = 80
# ... and above which rdp-recorder is additionally stopped.
DISK_THRESHOLD_DEGRADED = 90


def log(msg: str):
    """Print *msg* with a UTC ISO-8601 timestamp and append it to ``LOG_DIR/watchdog.log``.

    Writing the log file is best effort: if the directory cannot be created or
    the append fails (disk full, read-only filesystem, ...), the error is
    reported on stderr and swallowed. Otherwise a full disk would make the
    watchdog crash on its very first log line, before it ever gets to purge.
    """
    line = f"[{datetime.now(timezone.utc).isoformat()}] {msg}"
    print(line, flush=True)
    try:
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        # Messages contain non-ASCII (é, ⚠, →): do not depend on the locale.
        with (LOG_DIR / "watchdog.log").open("a", encoding="utf-8") as f:
            f.write(line + "\n")
    except OSError as e:
        print(f"[watchdog] log file write failed: {e}", file=sys.stderr, flush=True)


def is_active(svc: str) -> bool:
    """Return True iff ``systemctl is-active <svc>`` prints exactly "active".

    Raises whatever ``subprocess.run`` raises (e.g. TimeoutExpired after 8 s,
    FileNotFoundError if systemctl is missing).
    """
    cmd = ["systemctl", "is-active", svc]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=8)
    state = proc.stdout.strip()
    return state == "active"


def restart(svc: str) -> bool:
    """Run ``systemctl restart <svc>`` and report whether it exited with status 0.

    Output is captured (not shown). Raises whatever ``subprocess.run`` raises
    (e.g. TimeoutExpired after 30 s).
    """
    cmd = ["systemctl", "restart", svc]
    completed = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    return not completed.returncode


def load_counter() -> dict:
    """Load the restart counter from ``RESTART_COUNTER``.

    Returns:
        A dict mapping ``"<service>|YYYY-mm-dd HH"`` to an int restart count.
        An empty dict is returned when the file is missing, unreadable, not
        valid JSON, or not a JSON object, so a damaged counter resets instead
        of aborting the watchdog. Entries whose value is not an int are dropped.
    """
    try:
        data = json.loads(RESTART_COUNTER.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing/unreadable file or corrupt JSON (JSONDecodeError is a ValueError).
        return {}
    if not isinstance(data, dict):
        return {}
    # bool is a subclass of int; exclude it so counts stay real integers.
    return {k: v for k, v in data.items() if isinstance(v, int) and not isinstance(v, bool)}


def save_counter(c: dict):
    """Persist the restart counter *c* as indented JSON at ``RESTART_COUNTER``.

    The data is written to a sibling ``.tmp`` file and then moved over the
    target with ``os.replace``, so an interrupted write (crash, disk full)
    never leaves a truncated counter file behind.

    Raises:
        OSError: if the directory cannot be created or the write/rename fails
            (the temporary file is removed first).
    """
    RESTART_COUNTER.parent.mkdir(parents=True, exist_ok=True)
    tmp = RESTART_COUNTER.with_name(RESTART_COUNTER.name + ".tmp")
    try:
        tmp.write_text(json.dumps(c, indent=2), encoding="utf-8")
        os.replace(tmp, RESTART_COUNTER)
    except OSError:
        try:
            tmp.unlink()
        except OSError:
            pass  # nothing to clean up, or cleanup impossible; report the original error
        raise


def disk_usage_pct() -> int:
    """Return how much of ``/`` is not reported free by ``shutil.disk_usage``, in truncated whole percent.

    NOTE(review): "free" from ``shutil.disk_usage`` is likely the space available
    to unprivileged users, so root-reserved blocks count as used here; the value
    may read a little higher than df's Use% — confirm if exact parity matters.
    """
    usage = shutil.disk_usage("/")
    free_ratio = usage.free / usage.total
    return int(100 * (1 - free_ratio))


def _run_find(args: list[str], *, quiet: bool) -> None:
    """Run ``find`` with *args* directly (no shell); log instead of raising on timeout/launch failure."""
    try:
        subprocess.run(
            ["find", *args],
            stderr=subprocess.DEVNULL if quiet else None,
            timeout=20,
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        log(f"  find {args[0]} échec : {e}")


def purge_logs():
    """Free disk space by emptying forwarded logs and deleting old artefacts.

    Each step is best effort: a failure is logged and the next step still runs.
      1. Empty cowrie.json, cowrie.log and auth.log in place (truncate to 0 bytes,
         not delete).
      2. Delete ``*.gz`` files under /var/log matching ``find -mtime +1``.
      3. Delete cowrie tty files matching ``find -mtime +1``.
      4. Delete every file under cowrie's downloads directory, regardless of age.
    """
    log("DISK PURGE — vidage logs forwardés")
    # These logs are assumed to be already forwarded off-box — TODO confirm with the forwarder.
    for p in ("/var/log/cowrie/cowrie.json", "/var/log/cowrie/cowrie.log", "/var/log/auth.log"):
        if not Path(p).exists():
            continue
        try:
            os.truncate(p, 0)
            log(f"  truncated {p}")
        except OSError as e:
            log(f"  truncate {p} échec : {e}")
    # Old compressed logs (find's stderr is left visible for this one).
    _run_find(["/var/log", "-name", "*.gz", "-mtime", "+1", "-delete"], quiet=False)
    # Cowrie tty logs can be huge as well.
    _run_find(["/var/log/cowrie/tty", "-type", "f", "-mtime", "+1", "-delete"], quiet=True)
    # Captured downloads: all removed, whatever their age.
    _run_find(["/var/log/cowrie/downloads", "-type", "f", "-delete"], quiet=True)


def degraded_mode():
    """Stop the rdp-recorder service to relieve a nearly full root filesystem.

    Failures (non-zero exit, 10 s timeout, missing systemctl) are logged rather
    than raised, so the caller can still go on to purge logs.
    """
    log("DISK >90% — mode dégradé : stop rdp-recorder (vidéos ffmpeg)")
    try:
        r = subprocess.run(["systemctl", "stop", "rdp-recorder"], timeout=10)
    except (subprocess.TimeoutExpired, OSError) as e:
        log(f"  stop rdp-recorder échec : {e}")
        return
    if r.returncode != 0:
        log(f"  stop rdp-recorder échec : code {r.returncode}")


def _compact_counter(counter: dict, now: datetime) -> dict:
    """Return *counter* without hour buckets older than 24 h (or with malformed keys)."""
    floor = now.replace(minute=0, second=0, microsecond=0)
    cutoff = (floor - timedelta(hours=24)).strftime("%Y-%m-%d %H")
    # Keys are "<svc>|YYYY-mm-dd HH"; zero-padded, so string order == time order.
    # partition() yields "" for a key without "|", which sorts before any cutoff → dropped.
    return {k: v for k, v in counter.items() if k.partition("|")[2] >= cutoff}


def _check_services(counter: dict, now: datetime) -> dict:
    """Restart inactive services, count restarts per UTC hour, and return the compacted counter."""
    now_h = now.strftime("%Y-%m-%d %H")
    for svc in CRITICAL_SERVICES + NON_CRITICAL:
        try:
            active = is_active(svc)
        except (subprocess.SubprocessError, OSError) as e:
            # A hung/missing systemctl for one unit must not abort the whole pass.
            log(f"  is-active {svc} échec : {e}")
            continue
        if active:
            continue
        log(f"  service {svc} INACTIF → restart")
        try:
            ok = restart(svc)
        except (subprocess.SubprocessError, OSError) as e:
            log(f"  restart {svc} exception : {e}")
            ok = False
        # Counted per UTC clock hour (bucket key), not over a sliding 60-min window.
        key = f"{svc}|{now_h}"
        counter[key] = counter.get(key, 0) + 1
        if counter[key] > 10:
            log(f"  ⚠ {svc} > 10 restarts dans cette heure — crash loop ?")
        if not ok:
            log(f"  restart {svc} ÉCHEC")
    return _compact_counter(counter, now)


def _check_disk() -> None:
    """Purge logs above DISK_THRESHOLD_PURGE; above DISK_THRESHOLD_DEGRADED, stop rdp-recorder first."""
    pct = disk_usage_pct()
    log(f"disk / : {pct}% utilisé")
    if pct > DISK_THRESHOLD_DEGRADED:
        try:
            degraded_mode()
        except (subprocess.SubprocessError, OSError) as e:
            # Still purge even if stopping rdp-recorder failed.
            log(f"  mode dégradé échec : {e}")
        purge_logs()
    elif pct > DISK_THRESHOLD_PURGE:
        purge_logs()


def main() -> int:
    """Run one watchdog pass: service health (unless suspended), then disk health.

    Returns:
        0, used as the process exit status.
    """
    log("== watchdog start ==")
    now = datetime.now(timezone.utc)

    # ──── 1. Service health ────
    if SUSPENDED_FLAG.exists():
        log("VM en mode SUSPENDED — pas de restart")
    else:
        counter = _check_services(load_counter(), now)
        try:
            save_counter(counter)
        except OSError as e:
            # e.g. disk full: persisting the counter must not block the disk purge below.
            log(f"  sauvegarde compteur échec : {e}")

    # ──── 2. Disk health ────
    _check_disk()

    log("== watchdog done ==")
    return 0


if __name__ == "__main__":
    # Exit status is main()'s return value (0 on a completed pass).
    raise SystemExit(main())
