#!/usr/bin/env python3
"""
SYLink PenTest VM — applique config réseau pushée depuis le portail.

Boucle :
  - Poll /api/pentest/vm/network-config-pending toutes les 60s
  - Si pending_network_config présent :
      1. Backup netplan courant → /opt/sylink-pentest/etc/netplan_backup.yaml
      2. Génère nouveau netplan + apply
      3. Démarre timer rollback 5 min :
         si pas de heartbeat OK pendant 5 min → restore netplan_backup
      4. POST /network-config-applied avec success/error/rollback_performed
  - Si pending_routes : `ip route add` (route non persistée dans netplan → perdue au reboot)
  - Si applied_routes en status removal_pending : `ip route del`

Service systemd : Type=simple, Restart=always.
"""
from __future__ import annotations

import json
import os
import shutil
import subprocess
import sys
import time
import urllib.request
from datetime import datetime, timezone, timedelta
from pathlib import Path

CONFIG_DIR = Path("/opt/sylink-pentest/etc")  # license.json, enroll.json, netplan backup
LOG_DIR = Path("/opt/sylink-pentest/log")  # network.log is written here; agent.log is read as heartbeat
NETPLAN_DIR = Path("/etc/netplan")
NETPLAN_FILE = NETPLAN_DIR / "99-sylink-pentest.yaml"  # the netplan file this service writes/restores
NETPLAN_BACKUP = CONFIG_DIR / "netplan_backup.yaml"  # copy restored on rollback
# Portal base URL. A trailing slash is stripped so f"{PT_API_BASE}/api/..." never yields "//api".
PT_API_BASE = os.environ.get("PT_API_BASE", "https://pentest.unisoc.fr").rstrip("/")
POLL_INTERVAL = int(os.environ.get("NETWORK_POLL_INTERVAL", "60"))  # seconds between polls
# Seconds to wait for agent activity after applying a config before rolling back
# (default 300 = 5 min; overridable via the environment, like the poll interval).
ROLLBACK_TIMEOUT_SECONDS = int(os.environ.get("NETWORK_ROLLBACK_TIMEOUT", "300"))


def log(msg: str) -> None:
    """Print a UTC-timestamped line to stdout and append it to LOG_DIR/network.log.

    Writing the log file is best-effort. An OSError (unwritable directory,
    full disk) is reported on stderr instead of propagating. A logging failure
    therefore cannot abort a network apply or rollback halfway through.
    """
    line = f"[{datetime.now(timezone.utc).isoformat()}] {msg}"
    # Flush so the line reaches the service log promptly: stdout is
    # block-buffered when it is not a TTY (e.g. when run as a systemd service).
    print(line, flush=True)
    try:
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: messages contain non-ASCII text ("→", "é"), and the
        # locale default encoding of a service environment may not handle it.
        with (LOG_DIR / "network.log").open("a", encoding="utf-8") as f:
            f.write(line + "\n")
    except OSError as e:
        print(f"log write failed: {e}", file=sys.stderr, flush=True)


def _read_json_object(path: Path) -> dict | None:
    """Return the JSON object stored in *path*, or None if absent, unreadable or not an object."""
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError / bad UTF-8
        log(f"cannot read {path} : {e}")
        return None
    if not isinstance(data, dict):
        log(f"ignoring {path} : expected a JSON object")
        return None
    return data


def load_config(config_dir: Path | None = None) -> dict:
    """Merge license.json and enroll.json from the config directory into one dict.

    license.json is loaded first. If it did not set ``device_id``, that key
    defaults to the first 24 characters of enroll.json's
    ``machine_fingerprint`` (an empty string if missing or null). Missing
    files are skipped. A file that cannot be read, or is not a JSON object,
    is logged and skipped instead of raising. The caller's poll loop then keeps
    running; it simply waits while ``api_key`` is absent.

    Args:
        config_dir: directory holding the JSON files; defaults to CONFIG_DIR.

    Returns:
        The merged configuration (possibly empty).
    """
    base = CONFIG_DIR if config_dir is None else Path(config_dir)
    cfg: dict = {}
    license_data = _read_json_object(base / "license.json")
    if license_data is not None:
        cfg.update(license_data)
    enroll_data = _read_json_object(base / "enroll.json")
    if enroll_data is not None:
        # `or ""` also covers an explicit null fingerprint, which used to raise TypeError.
        fingerprint = str(enroll_data.get("machine_fingerprint") or "")
        cfg.setdefault("device_id", fingerprint[:24])
    return cfg


def http_get(url: str, headers: dict, timeout: int = 15) -> tuple[int, dict]:
    """GET *url* and decode the JSON reply; never raises.

    Returns:
        (status, decoded_json) on success (an empty body decodes to {}).
        (http_status, {"error": ...}) when the server answers 4xx/5xx.
        (0, {"error": ...}) on any other failure: network error, timeout,
        invalid JSON, malformed URL.
    """
    import urllib.error  # local import: keeps the file's top-level imports unchanged

    try:
        req = urllib.request.Request(url, headers=headers, method="GET")
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, json.loads(resp.read() or b"{}")
    except urllib.error.HTTPError as e:
        # Keep the real status (e.g. 401/404/500) instead of collapsing it to 0,
        # so callers and logs can tell a server refusal from a network failure.
        e.close()  # HTTPError carries the open response body
        return e.code, {"error": str(e)}
    except Exception as e:  # deliberate boundary: the poll loop must never crash here
        return 0, {"error": str(e)}


def http_post(url: str, body: dict, headers: dict, timeout: int = 15) -> tuple[int, dict]:
    """POST *body* as JSON to *url* and decode the JSON reply; never raises.

    Caller *headers* are merged after ``Content-Type: application/json`` and
    may override it.

    Returns:
        (status, decoded_json) on success (an empty body decodes to {}).
        (http_status, {"error": ...}) when the server answers 4xx/5xx.
        (0, {"error": ...}) on any other failure: unserialisable body, network
        error, timeout, invalid JSON, malformed URL.
    """
    import urllib.error  # local import: keeps the file's top-level imports unchanged

    try:
        data = json.dumps(body).encode()
        req = urllib.request.Request(
            url,
            data=data,
            headers={"Content-Type": "application/json", **headers},
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, json.loads(resp.read() or b"{}")
    except urllib.error.HTTPError as e:
        # Keep the real status instead of collapsing it to 0 (see http_get).
        e.close()  # HTTPError carries the open response body
        return e.code, {"error": str(e)}
    except Exception as e:  # deliberate boundary: reporting must never crash the loop
        return 0, {"error": str(e)}


def primary_interface() -> str:
    """Name of the interface that carries the default route.

    Parses the token following "dev" in `ip -o route show default`. Falls back
    to "enp6s18" when there is no default route, the output is unexpected, or
    running `ip` fails for any reason.
    """
    fallback = "enp6s18"
    try:
        proc = subprocess.run(
            ["ip", "-o", "route", "show", "default"],
            capture_output=True,
            text=True,
            timeout=5,
        )
        # typical line: "default via 192.168.1.1 dev enp6s18 ..."
        tokens = proc.stdout.split()
        if "dev" not in tokens:
            return fallback
        # An IndexError here ("dev" as last token) lands in the handler below.
        return tokens[tokens.index("dev") + 1]
    except Exception:
        return fallback


def backup_current_netplan():
    """Copy the netplan config about to be replaced to NETPLAN_BACKUP.

    The source is this service's own NETPLAN_FILE when it exists. Otherwise it
    is the first ``*.yaml`` in NETPLAN_DIR by file name, e.g. the cloud-init
    generated 50-cloud-init.yaml. Only that single file is copied.

    If no source exists, nothing is copied and this is logged.
    NOTE(review): a NETPLAN_BACKUP left by an earlier run is then kept, and a
    later rollback would restore that older file.

    The backup is made owner-only (0600), matching the mode write_netplan
    sets on NETPLAN_FILE.
    """
    NETPLAN_BACKUP.parent.mkdir(parents=True, exist_ok=True)
    if NETPLAN_FILE.exists():
        source = NETPLAN_FILE
    else:
        # glob() order is filesystem-dependent; sort so the choice is deterministic.
        candidates = sorted(NETPLAN_DIR.glob("*.yaml"))
        source = candidates[0] if candidates else None

    if source is None:
        log(f"no netplan file in {NETPLAN_DIR} — nothing to back up")
        return

    shutil.copyfile(source, NETPLAN_BACKUP)
    # copyfile() does not carry permissions over; keep the backup private.
    os.chmod(NETPLAN_BACKUP, 0o600)
    if source == NETPLAN_FILE:
        log(f"backup netplan : {NETPLAN_BACKUP}")
    else:
        log(f"backup default netplan {source} → {NETPLAN_BACKUP}")


def write_netplan(cfg: dict, routes: list[dict]) -> str:
    """Render the netplan YAML for *cfg*, write it to NETPLAN_FILE and return it.

    When ``cfg["mode"] == "static"``, a static block is rendered from:
      - ``ipv4_address``: must include a prefix length, e.g. "192.0.2.10/24";
      - ``ipv4_gateway``: used for the default route;
      - ``dns_servers``: defaults to ["9.9.9.9", "1.1.1.1"].
    Any other mode renders ``dhcp4: true``. The block targets the interface
    returned by primary_interface().

    Every address is validated and normalised with :mod:`ipaddress` before
    anything is written. Values coming from the portal therefore cannot inject
    extra YAML keys, and IPv6 DNS servers are emitted quoted.

    *routes* is currently ignored: additional routes are applied elsewhere with
    ``ip route`` and are not persisted here. The parameter is kept for
    interface compatibility.

    The file is written to a temporary sibling created with mode 0600 and then
    atomically renamed over NETPLAN_FILE. The target is thus never half-written
    and never world-readable.

    Raises:
        KeyError: static mode without ``ipv4_address`` or ``ipv4_gateway``.
        ValueError: an address, gateway or DNS server is not a valid IP, or the
            address lacks a prefix length.
    """
    import ipaddress  # local import: only this function validates addresses

    iface = primary_interface()
    if cfg.get("mode") == "static":
        raw_address = str(cfg["ipv4_address"]).strip()
        if "/" not in raw_address:
            raise ValueError(
                f"ipv4_address must include a prefix length (e.g. 192.0.2.10/24): {raw_address!r}"
            )
        address = str(ipaddress.ip_interface(raw_address))
        gateway = str(ipaddress.ip_address(str(cfg["ipv4_gateway"]).strip()))
        dns = cfg.get("dns_servers") or ["9.9.9.9", "1.1.1.1"]
        if isinstance(dns, str):
            dns = [dns]  # a bare string would otherwise be iterated character by character
        # json.dumps() emits double-quoted strings, which are valid YAML scalars.
        nameservers = ", ".join(
            json.dumps(str(ipaddress.ip_address(str(server).strip()))) for server in dns
        )
        eth_block = (
            "      dhcp4: false\n"
            f"      addresses: [{json.dumps(address)}]\n"
            "      routes:\n"
            "        - to: default\n"
            f"          via: {json.dumps(gateway)}\n"
            "      nameservers:\n"
            f"        addresses: [{nameservers}]"
        )
    else:
        eth_block = "      dhcp4: true"

    yaml_content = f"""network:
  version: 2
  renderer: networkd
  ethernets:
    {iface}:
{eth_block}
"""
    # The temp name does not end in ".yaml", so netplan never picks it up.
    tmp = NETPLAN_FILE.with_name(f".{NETPLAN_FILE.name}.tmp")
    fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(yaml_content)
    os.chmod(tmp, 0o600)  # the O_CREAT mode is ignored if tmp already existed
    os.replace(tmp, NETPLAN_FILE)
    log(f"netplan écrit : {NETPLAN_FILE}\n{yaml_content}")
    return yaml_content


def apply_netplan() -> tuple[bool, str]:
    """Run `netplan apply` with a 30 s timeout.

    Returns:
        (True, "OK") when netplan exits with status 0. Otherwise (False, reason),
        where reason is netplan's stderr, else its stdout, else the exit status.
        The reason is never empty, because it is posted to the portal as the
        failure ``error``. A missing binary or a timeout is returned as
        (False, str(exc)).
    """
    try:
        proc = subprocess.run(["netplan", "apply"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.SubprocessError) as e:  # FileNotFoundError, TimeoutExpired
        return False, str(e)
    if proc.returncode != 0:
        reason = (
            proc.stderr.strip()
            or proc.stdout.strip()
            or f"netplan apply exited with status {proc.returncode}"
        )
        return False, reason
    return True, "OK"


def apply_route(cidr: str, via: str) -> tuple[bool, str]:
    """Add a route with `ip route add <cidr> via <via>` (10 s timeout); never raises.

    Returns:
        (True, "OK") on success. Otherwise (False, reason), where reason is
        ip's stderr or, if that is empty, the exit status. Exceptions, e.g. a
        missing `ip` binary or non-string arguments from the portal, are
        returned as (False, str(exc)).
    """
    try:
        proc = subprocess.run(
            ["ip", "route", "add", cidr, "via", via],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except Exception as e:  # deliberate boundary: portal data must not crash the loop
        return False, str(e)
    if proc.returncode != 0:
        return False, proc.stderr.strip() or f"ip route add exited with status {proc.returncode}"
    return True, "OK"


def remove_route(cidr: str, via: str) -> tuple[bool, str]:
    """Delete a route with `ip route del <cidr> via <via>` (10 s timeout); never raises.

    Returns:
        (True, "OK") on success. Otherwise (False, reason), where reason is
        ip's stderr or, if that is empty, the exit status. The previous version
        reported the message "OK" for a failure with empty stderr. Exceptions
        are returned as (False, str(exc)).
    """
    try:
        proc = subprocess.run(
            ["ip", "route", "del", cidr, "via", via],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except Exception as e:  # deliberate boundary: portal data must not crash the loop
        return False, str(e)
    if proc.returncode != 0:
        return False, proc.stderr.strip() or f"ip route del exited with status {proc.returncode}"
    return True, "OK"


def restore_netplan() -> bool:
    """Roll back by copying NETPLAN_BACKUP over NETPLAN_FILE and running `netplan apply`.

    Returns:
        True only if the backup existed and `netplan apply` succeeded.

    The backup is copied to a temporary sibling, set to 0600 and then renamed
    over NETPLAN_FILE. The file is therefore never half-copied, and it keeps
    the owner-only mode write_netplan uses. (copyfile alone would create it
    with default, too-open permissions.)

    NOTE(review): when the backup was taken from another netplan file (e.g.
    50-cloud-init.yaml), the rollback leaves a copy of it in NETPLAN_FILE next
    to the original; confirm this duplication is acceptable.
    """
    if not NETPLAN_BACKUP.exists():
        log("no netplan backup — cannot rollback")
        return False
    # The temp name does not end in ".yaml", so netplan never picks it up.
    tmp = NETPLAN_FILE.with_name(f".{NETPLAN_FILE.name}.tmp")
    shutil.copyfile(NETPLAN_BACKUP, tmp)
    os.chmod(tmp, 0o600)
    os.replace(tmp, NETPLAN_FILE)
    success, out = apply_netplan()
    log(f"rollback netplan : {success} ({out})")
    return success


def has_active_heartbeat(cfg: dict, since_seconds: float = 30, log_file: Path | None = None) -> bool:
    """Heuristic liveness check: was the agent log modified less than *since_seconds* ago?

    Any write to the file counts, not only successful heartbeats. This is
    therefore only a proxy for "the agent is running and logging".

    Args:
        cfg: currently unused; kept for interface compatibility.
        since_seconds: maximum accepted age of the file's mtime, in seconds.
        log_file: file to inspect; defaults to LOG_DIR / "agent.log".

    Returns:
        False if the file does not exist, otherwise whether its age is
        below *since_seconds*.
    """
    path = LOG_DIR / "agent.log" if log_file is None else Path(log_file)
    try:
        # A single stat() avoids the exists()/stat() race when the log is rotated.
        mtime = path.stat().st_mtime
    except (FileNotFoundError, NotADirectoryError):
        return False
    return time.time() - mtime < since_seconds


def _report_network_result(device_id: str, headers: dict, payload: dict) -> None:
    """POST the outcome of a network-config application to the portal; log if it fails."""
    status, resp = http_post(
        f"{PT_API_BASE}/api/pentest/vm/network-config-applied",
        {"device_id": device_id, **payload},
        headers,
    )
    if not 200 <= status < 300:
        # NOTE(review): the report is not retried. Presumably the portal keeps the
        # config "pending" and it is re-applied on a later poll; confirm server-side.
        log(f"network-config-applied report failed : status={status} {resp}")


def _wait_for_heartbeat(cfg: dict, applied_at: float) -> bool:
    """Check every 30 s, for up to ROLLBACK_TIMEOUT_SECONDS, for agent activity after *applied_at*."""
    deadline = time.monotonic() + ROLLBACK_TIMEOUT_SECONDS
    while time.monotonic() < deadline:
        time.sleep(30)
        # Only agent-log writes made after the new config went live may count.
        # The window is capped at the previous 120 s. A fixed 120 s window
        # accepted writes from before the apply, so the very first check could
        # pass without the new config ever being exercised.
        window = min(120.0, time.time() - applied_at)
        if has_active_heartbeat(cfg, since_seconds=window):
            return True
    return False


def _apply_network_config(net_cfg: dict, cfg: dict, device_id: str, headers: dict) -> None:
    """Back up, write and apply *net_cfg*; roll back on failure or missing heartbeat; report.

    The report's ``success`` is True only when the config was applied and a
    post-apply heartbeat was seen. ``rollback_performed`` reflects whether
    restore_netplan() actually succeeded.
    """
    log(f"applying network_config : {net_cfg}")
    applied_cfg = {"mode": net_cfg.get("mode")}
    applied_cfg.update(
        {k: v for k, v in net_cfg.items() if k in ("ipv4_address", "ipv4_gateway", "dns_servers")}
    )
    written = False  # True once NETPLAN_FILE holds the new config
    try:
        backup_current_netplan()
        write_netplan(net_cfg, [])
        written = True
        ok, msg = apply_netplan()
        if not ok:
            log(f"apply failed : {msg} → rollback immédiat")
            payload = {"success": False, "error": msg, "rollback_performed": restore_netplan()}
        elif _wait_for_heartbeat(cfg, applied_at=time.time()):
            payload = {"success": True, "config_applied": applied_cfg, "rollback_performed": False}
        else:
            log(f"no heartbeat in {ROLLBACK_TIMEOUT_SECONDS}s → rollback")
            # Never report success here, whether or not the rollback itself worked.
            payload = {
                "success": False,
                "error": f"no heartbeat in {ROLLBACK_TIMEOUT_SECONDS}s",
                "config_applied": applied_cfg,
                "rollback_performed": restore_netplan(),
            }
    except Exception as e:
        log(f"network apply err : {e}")
        rolled_back = False
        if written:
            # The new file is in place but was not validated: try to restore.
            try:
                rolled_back = restore_netplan()
            except Exception as rollback_err:
                log(f"rollback err : {rollback_err}")
        payload = {"success": False, "error": str(e), "rollback_performed": rolled_back}
    _report_network_result(device_id, headers, payload)


def _process_routes(resp: dict) -> None:
    """Add resp["pending_routes"]; delete resp["applied_routes"] entries marked removal_pending.

    A malformed entry is logged and skipped instead of crashing the poll loop.

    NOTE(review): nothing is reported back to the portal for routes, and routes
    are not persisted in netplan. A pending route is presumably re-sent on every
    poll, and added routes are lost on reboot; confirm the intended contract.
    """
    for r in resp.get("pending_routes") or []:
        if not isinstance(r, dict) or "cidr" not in r or "via" not in r:
            log(f"route add skipped, malformed entry : {r!r}")
            continue
        ok, msg = apply_route(r["cidr"], r["via"])
        log(f"route add {r['cidr']} via {r['via']} : ok={ok} {msg}")

    for r in resp.get("applied_routes") or []:
        if not isinstance(r, dict) or r.get("status") != "removal_pending":
            continue
        if "cidr" not in r or "via" not in r:
            log(f"route del skipped, malformed entry : {r!r}")
            continue
        ok, msg = remove_route(r["cidr"], r["via"])
        log(f"route del {r['cidr']} via {r['via']} : ok={ok} {msg}")


def main() -> int:
    """Poll the portal forever, applying pending network config and routes.

    Each iteration performs these steps:
      1. Reload the local config, waiting 30 s while there is no ``api_key``.
      2. GET network-config-pending for this device.
      3. Apply a ``network_config`` whose status is "pending".
      4. Process route additions and removals.
      5. Sleep POLL_INTERVAL.

    It never returns normally.
    """
    import urllib.parse  # local import: only needed to encode the query string

    log("== sylink-pentest-network-applier start ==")
    while True:
        cfg = load_config()
        if not cfg.get("api_key"):
            time.sleep(30)
            continue
        device_id = cfg.get("device_id", "unknown")
        headers = {"X-PT-License": cfg.get("api_key", "")}

        query = urllib.parse.urlencode({"device_id": device_id})
        status, resp = http_get(
            f"{PT_API_BASE}/api/pentest/vm/network-config-pending?{query}",
            headers,
        )
        if status != 200 or not isinstance(resp, dict):
            time.sleep(POLL_INTERVAL)
            continue

        # 1. Network config (DHCP/static)
        net_cfg = resp.get("network_config")
        if isinstance(net_cfg, dict) and net_cfg.get("status") == "pending":
            _apply_network_config(net_cfg, cfg, device_id, headers)

        # 2. Routes to add / remove
        _process_routes(resp)

        time.sleep(POLL_INTERVAL)


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
