#!/usr/bin/env python3
"""
SYLink PenTest VM — agent principal (heartbeat + poll commands + exec scans).

Boucle infinie :
  - Heartbeat enrichi vers UniSOC toutes les 60s (HW, interfaces, routing)
  - Poll commandes pending toutes les 60s
  - Pour chaque commande : appel API locale `:8888` (pentest-api.service)
                          → poll job local → push result à UniSOC

Service systemd : Type=simple, Restart=always.
"""
from __future__ import annotations

import json
import os
import subprocess
import sys
import time
import urllib.request
from datetime import datetime, timezone
from pathlib import Path

# Filesystem locations used by the agent for its configuration files and log output.
CONFIG_DIR = Path("/opt/sylink-pentest/etc")
LOG_DIR = Path("/opt/sylink-pentest/log")
# Base URL of the remote SOC API; overridable through the PT_API_BASE environment variable.
PT_API_BASE = os.environ.get("PT_API_BASE", "https://pentest.unisoc.fr")
# Base URL of the local pentest API; overridable through LOCAL_PENTEST_API.
LOCAL_API = os.environ.get("LOCAL_PENTEST_API", "http://localhost:8888")
# Loop cadences and limits, in seconds. Each value is parsed with int(), so a
# non-numeric environment override raises ValueError at import time.
HEARTBEAT_INTERVAL = int(os.environ.get("HEARTBEAT_INTERVAL", "60"))
POLL_INTERVAL = int(os.environ.get("POLL_INTERVAL", "60"))
JOB_POLL_TIMEOUT = int(os.environ.get("JOB_POLL_TIMEOUT", "1800"))  # 30 min max per job


def log(msg: str):
    """Print ``msg`` prefixed with a UTC ISO-8601 timestamp and append it to ``agent.log``.

    The log directory is created on demand before the line is emitted.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(timezone.utc).isoformat()
    entry = f"[{stamp}] {msg}"
    print(entry)
    logfile = LOG_DIR / "agent.log"
    with logfile.open("a") as handle:
        handle.write(f"{entry}\n")


def _read_json_object(path: Path) -> dict | None:
    """Return the JSON object stored at ``path``, or ``None`` if it is unusable.

    A missing file is silent; an unreadable file, invalid JSON or a JSON value
    that is not an object is logged and treated as absent.
    """
    try:
        data = json.loads(path.read_text())
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        log(f"load_config: cannot read {path} : {e}")
        return None
    if not isinstance(data, dict):
        log(f"load_config: {path} does not contain a JSON object")
        return None
    return data


def load_config() -> dict:
    """Build the agent configuration from ``license.json`` and ``enroll.json``.

    ``license.json`` provides the base values. ``enroll.json`` only fills in
    ``machine_fingerprint`` and ``device_id`` (first 24 characters of the
    fingerprint) when the license did not already define them.

    A corrupt or partially written file is skipped instead of raising, because
    this function is called on every loop iteration and the license file may be
    rewritten by a command handler running in another thread.

    Returns:
        The merged configuration; empty when no usable file exists.
    """
    cfg: dict = {}
    license_data = _read_json_object(CONFIG_DIR / "license.json")
    if license_data is not None:
        cfg.update(license_data)
    enrollment = _read_json_object(CONFIG_DIR / "enroll.json")
    if enrollment is not None:
        fingerprint = enrollment.get("machine_fingerprint")
        cfg.setdefault("machine_fingerprint", fingerprint)
        # A null/absent fingerprint yields an empty device_id rather than a TypeError.
        cfg.setdefault("device_id", str(fingerprint or "")[:24])
    return cfg


def http_post(url: str, body: dict, headers: dict, timeout: int = 30) -> tuple[int, dict]:
    """POST ``body`` as JSON to ``url`` and return ``(status, decoded_json)``.

    Args:
        url: Target URL.
        body: JSON-serialisable payload.
        headers: Extra request headers, merged over the JSON ``Content-Type``.
        timeout: Timeout in seconds handed to ``urlopen``.

    Returns:
        ``(status, payload)``. For an HTTP error response the status is the
        error code and the payload is the decoded error body, or
        ``{"error": <body excerpt>}`` when that body is not valid JSON. Any
        other failure is logged and reported as ``(0, {"error": <message>})``.

    Raises:
        TypeError, ValueError: if ``body`` cannot be serialised to JSON
            (serialisation happens before the guarded section).
    """
    data = json.dumps(body).encode()
    req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json", **headers}, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, json.loads(resp.read() or b"{}")
    except urllib.error.HTTPError as e:
        # Gateways and proxies often answer with HTML or plain text. Decoding
        # that body must not raise: an exception here would escape this
        # function, since sibling ``except`` clauses do not cover a handler.
        raw = b""
        try:
            raw = e.read() or b"{}"
            return e.code, json.loads(raw.decode())
        except Exception:
            return e.code, {"error": raw[:500].decode("utf-8", "replace")}
    except Exception as e:
        log(f"http_post err {url} : {e}")
        return 0, {"error": str(e)}


def http_get(url: str, headers: dict, timeout: int = 30) -> tuple[int, dict]:
    """GET ``url`` and return ``(status, decoded_json)``.

    Args:
        url: Target URL.
        headers: Request headers.
        timeout: Timeout in seconds handed to ``urlopen``.

    Returns:
        ``(status, payload)``. For an HTTP error response the status is the
        error code and the payload is the decoded error body, or
        ``{"error": <body excerpt>}`` when that body is not valid JSON. Any
        other failure is logged and reported as ``(0, {"error": <message>})``.
    """
    req = urllib.request.Request(url, headers=headers, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, json.loads(resp.read() or b"{}")
    except urllib.error.HTTPError as e:
        # Error pages are frequently HTML or plain text. Decoding them must not
        # raise, because an exception raised inside this handler would not be
        # caught by the generic ``except`` below.
        raw = b""
        try:
            raw = e.read() or b"{}"
            return e.code, json.loads(raw.decode())
        except Exception:
            return e.code, {"error": raw[:500].decode("utf-8", "replace")}
    except Exception as e:
        log(f"http_get err {url} : {e}")
        return 0, {"error": str(e)}


# ─── Collecte HW/Réseau ──────────────────────────────────────────────────────


def collect_interfaces() -> list[dict]:
    """Describe the network interfaces reported by ``ip -j addr``.

    The loopback interface and interfaces whose name starts with ``veth``,
    ``docker``, ``br-`` or ``podman`` are left out. Each remaining interface is
    reported with its MAC, first IPv4 address, first IPv6 address not starting
    with ``fe80`` (both in CIDR form), MTU and operational state.

    Any failure is logged; the entries gathered so far are still returned.
    """
    ignored_prefixes = ("veth", "docker", "br-", "podman")

    def first_cidr(entries, accept):
        # First accepted address rendered as "<address>/<prefix length>", else None.
        for entry in entries:
            if accept(entry):
                return entry["local"] + "/" + str(entry["prefixlen"])
        return None

    def is_v4(entry):
        return entry.get("family") == "inet"

    def is_routable_v6(entry):
        return entry.get("family") == "inet6" and not entry["local"].startswith("fe80")

    found = []
    try:
        raw = subprocess.run(["ip", "-j", "addr"], capture_output=True, text=True, timeout=5).stdout
        for link in json.loads(raw):
            ifname = link.get("ifname")
            if ifname == "lo" or ifname.startswith(ignored_prefixes):
                continue
            addresses = link.get("addr_info", [])
            v4 = first_cidr(addresses, is_v4)
            v6 = first_cidr(addresses, is_routable_v6)
            found.append({
                "name": ifname,
                "mac": link.get("address"),
                "ipv4": v4,
                "ipv6": v6,
                "mtu": link.get("mtu"),
                "operstate": link.get("operstate"),
            })
    except Exception as e:
        log(f"collect_interfaces err : {e}")
    return found


def collect_routes() -> list[dict]:
    """Return the routing table reported by ``ip -j route``.

    Each route is reduced to its destination (``"default"`` when the entry has
    no ``dst``), gateway, device and metric.

    Failures are logged, as in ``collect_interfaces``, rather than silently
    swallowed; the routes gathered so far are still returned.
    """
    out = []
    try:
        rt = subprocess.run(["ip", "-j", "route"], capture_output=True, text=True, timeout=5).stdout
        for r in json.loads(rt):
            out.append({
                "destination": r.get("dst", "default"),
                "via": r.get("gateway"),
                "dev": r.get("dev"),
                "metric": r.get("metric"),
            })
    except Exception as e:
        log(f"collect_routes err : {e}")
    return out


def collect_metrics() -> dict:
    """Gather basic host metrics, each source being best-effort.

    Reads load averages from ``/proc/loadavg``, memory figures from
    ``/proc/meminfo``, root filesystem usage through ``os.statvfs("/")`` and
    uptime from ``/proc/uptime``. A source that fails is skipped without
    affecting the others.
    """
    metrics = {}

    # Load averages: the first three fields of /proc/loadavg.
    try:
        with open("/proc/loadavg") as fh:
            fields = fh.read().split()
            for position, key in enumerate(("load_1", "load_5", "load_15")):
                metrics[key] = float(fields[position])
    except Exception:
        pass

    # Memory: every parsed value is multiplied by 1024 before use.
    try:
        with open("/proc/meminfo") as fh:
            meminfo = {}
            for row in fh:
                label, value = row.split(":", 1)
                meminfo[label.strip()] = int(value.split()[0]) * 1024
            mem_total = meminfo.get("MemTotal", 1)
            mem_available = meminfo.get("MemAvailable", 0)
            metrics["mem_total_bytes"] = mem_total
            metrics["mem_used_pct"] = round(100 * (mem_total - mem_available) / mem_total, 1)
    except Exception:
        pass

    # Root filesystem usage.
    try:
        vfs = os.statvfs("/")
        consumed = (vfs.f_blocks - vfs.f_bavail) * vfs.f_frsize
        capacity = vfs.f_blocks * vfs.f_frsize
        metrics["disk_used_pct_root"] = round(100 * consumed / capacity, 1)
    except Exception:
        pass

    # Uptime: first field of /proc/uptime, truncated to whole seconds.
    try:
        with open("/proc/uptime") as fh:
            metrics["uptime_seconds"] = int(float(fh.read().split()[0]))
    except Exception:
        pass

    return metrics


def detect_virtualization() -> str:
    """Return the output of ``systemd-detect-virt``.

    Returns ``"none"`` when the command prints nothing and ``"unknown"`` when
    it cannot be run or times out.
    """
    try:
        probe = subprocess.run(["systemd-detect-virt"], capture_output=True, text=True, timeout=2)
        detected = probe.stdout.strip()
    except Exception:
        return "unknown"
    return detected if detected else "none"


# ─── Heartbeat ───────────────────────────────────────────────────────────────


def heartbeat(cfg: dict):
    """Send one heartbeat describing this host to the SOC and log the outcome.

    The payload carries identity (device id, hostname, OS), the collected
    interfaces, routes, system metrics and virtualization type. ``local_ip`` is
    added from the first interface that has a non-127.x IPv4 address.
    """
    uname = os.uname()
    interfaces = collect_interfaces()
    payload = {
        "device_id": cfg.get("device_id", "unknown"),
        "hostname": uname.nodename,
        "os": f"{uname.sysname} {uname.release}",
        "version": "1.0.0",
        "current_scan_status": "idle",
        "interfaces": interfaces,
        "routing_table": collect_routes(),
        "system_metrics": collect_metrics(),
        "virtualization": detect_virtualization(),
    }

    # local_ip: address part of the first non-loopback IPv4 in CIDR form.
    primary_cidr = next(
        (nic["ipv4"] for nic in interfaces if nic.get("ipv4") and not nic["ipv4"].startswith("127.")),
        None,
    )
    if primary_cidr is not None:
        payload["local_ip"] = primary_cidr.split("/")[0]

    auth_headers = {"X-PT-License": cfg.get("api_key", "")}
    status, resp = http_post(f"{PT_API_BASE}/api/pentest/vm/heartbeat", payload, auth_headers, timeout=15)
    if status != 200:
        log(f"heartbeat FAIL {status} : {resp}")
        return
    log(f"heartbeat OK paused={resp.get('paused')}")


# ─── Poll commands + exec ────────────────────────────────────────────────────


def _local_result(status: str, **kw) -> dict:
    """Helper pour formatter un résultat handler local."""
    started = kw.pop("started_at", datetime.now(timezone.utc).isoformat())
    return {
        "vm_job_id": kw.pop("vm_job_id", "local"),
        "status": status,
        "findings": kw.pop("findings", []),
        "started_at": started,
        "finished_at": datetime.now(timezone.utc).isoformat(),
        "report_json": kw.pop("report_json", None),
        "raw_output": kw.pop("raw_output", None),
        "error": kw.pop("error", None),
    }


def handle_update_os(params: dict) -> dict:
    """Run ``apt-get update`` then ``apt-get upgrade`` and report the outcome.

    The status is ``"done"`` when the upgrade exits with 0, ``"error"``
    otherwise. The last 8000 characters of the combined output are returned as
    ``raw_output``. When ``params["reboot_after"]`` is truthy and the upgrade
    succeeded, a reboot is scheduled through ``shutdown -r +1``.
    """
    started = datetime.now(timezone.utc).isoformat()
    job_name = "local_update_os"
    try:
        log("  apt-get update...")
        refresh = subprocess.run(["apt-get", "update", "-y"], capture_output=True, text=True, timeout=600)

        log("  apt-get upgrade...")
        # The upgrade runs with a non-interactive debconf frontend.
        upgrade_env = dict(os.environ, DEBIAN_FRONTEND="noninteractive")
        upgrade = subprocess.run(
            ["apt-get", "upgrade", "-y"],
            capture_output=True, text=True, timeout=1800, env=upgrade_env,
        )

        combined = "".join((refresh.stdout, refresh.stderr, upgrade.stdout, upgrade.stderr))
        tail = combined[-8000:]

        if params.get("reboot_after") and upgrade.returncode == 0:
            log("  reboot_after=true — schedule reboot in 60s")
            subprocess.Popen(["shutdown", "-r", "+1", "Update OS terminé — reboot programmé"])

        if upgrade.returncode == 0:
            return _local_result("done", vm_job_id=job_name, started_at=started,
                                 raw_output=tail, error=None)
        return _local_result("error", vm_job_id=job_name, started_at=started,
                             raw_output=tail, error=upgrade.stderr)
    except Exception as e:
        log(f"  update_os err : {e}")
        return _local_result("error", vm_job_id=job_name, started_at=started, error=str(e))


def handle_update_tools(params: dict) -> dict:
    """Refresh nuclei templates and upgrade the pentest tool packages.

    ``params["tools"]`` overrides the default package list. Each package is
    upgraded with ``apt-get install --only-upgrade``; return codes are gathered
    in ``raw_output`` (last 8000 characters). The status is ``"done"`` unless
    an exception occurs.
    """
    started = datetime.now(timezone.utc).isoformat()
    job_name = "local_update_tools"
    report_lines = []
    try:
        # Nuclei templates
        templates = subprocess.run(["nuclei", "-update-templates"], capture_output=True, text=True, timeout=600)
        report_lines.append(f"nuclei -update-templates: rc={templates.returncode}\n{templates.stdout[-2000:]}")

        # Packages: caller-supplied list, or the default tool set
        packages = params.get("tools")
        if not packages:
            packages = ["nmap", "masscan", "smbclient", "ldap-utils", "nikto", "sqlmap", "hydra"]
        apt_env = dict(os.environ, DEBIAN_FRONTEND="noninteractive")
        for package in packages:
            upgrade = subprocess.run(
                ["apt-get", "install", "-y", "--only-upgrade", package],
                capture_output=True, text=True, timeout=300, env=apt_env,
            )
            report_lines.append(f"{package}: rc={upgrade.returncode}")

        summary = "\n".join(report_lines)
        return _local_result("done", vm_job_id=job_name, started_at=started, raw_output=summary[-8000:])
    except Exception as e:
        log(f"  update_tools err : {e}")
        return _local_result("error", vm_job_id=job_name, started_at=started, error=str(e))


def _verify_hmac_signature(content: bytes, signature_b64: str) -> bool:
    """Vérifie la signature HMAC d'un script/binaire reçu du SOC."""
    import base64, hmac, hashlib
    key = os.environ.get("UNISOC_AGENT_SIGNING", "unisoc_agent_signing_v1_change_me_in_prod").encode()
    try:
        expected = hmac.new(key, content, hashlib.sha256).digest()
        given = base64.urlsafe_b64decode(signature_b64 + "==")
        return hmac.compare_digest(expected, given)
    except Exception:
        return False


def handle_update_scripts(params: dict) -> dict:
    """Self-update: fetch the agent bundle, verify its HMAC, replace scripts, restart services.

    The bundle is a JSON object ``{scripts: {name: content_b64}, signature_b64}``.
    The signature covers the canonical JSON (sorted keys) of ``scripts``. Only
    scripts that already exist in ``/opt/sylink-pentest/bin`` are overwritten.

    Args:
        params: ``restart_services`` (default ``True``) controls whether the
            sylink-pentest units are restarted after the update.

    Returns:
        A local result whose ``report_json`` lists the updated script names,
        or an error result when the fetch, the signature check or the update
        fails.
    """
    started = datetime.now(timezone.utc).isoformat()
    job_name = "local_update_scripts"
    bundle_url = f"{PT_API_BASE}/api/pentest/vm/agent-bundle"
    bin_dir = Path("/opt/sylink-pentest/bin")
    try:
        import base64
        # Pull bundle JSON {scripts: {name: content_b64}, signature_b64}
        cfg = load_config()
        headers = {"X-PT-License": cfg.get("api_key", "")}
        st, bundle = http_get(bundle_url, headers, timeout=60)
        if st != 200 or not bundle or not isinstance(bundle, dict):
            return _local_result("error", vm_job_id=job_name, started_at=started,
                                 error=f"bundle fetch failed status={st}")
        # Verify the global signature
        bundle_canonical = json.dumps(bundle.get("scripts", {}), sort_keys=True).encode()
        if not _verify_hmac_signature(bundle_canonical, bundle.get("signature_b64", "")):
            return _local_result("error", vm_job_id=job_name, started_at=started,
                                 error="signature HMAC invalide — refus")
        scripts = bundle.get("scripts") or {}
        if not isinstance(scripts, dict):
            return _local_result("error", vm_job_id=job_name, started_at=started,
                                 error="bundle 'scripts' is not an object")
        updated = []
        for name, content_b64 in scripts.items():
            # Defence in depth: a name must be a plain file name, so that an
            # entry such as "../../etc/passwd" can never escape bin_dir.
            if name in ("", ".", "..") or Path(name).name != name:
                log(f"  update_scripts: refused script name {name!r}")
                continue
            target = bin_dir / name
            if not target.exists():
                continue  # not a known script, skip
            target.write_bytes(base64.urlsafe_b64decode(content_b64 + "=="))
            os.chmod(target, 0o755)
            updated.append(name)
        if params.get("restart_services", True):
            for svc in ["sylink-pentest-init", "sylink-pentest-setup-ui", "sylink-pentest-network"]:
                subprocess.run(["systemctl", "restart", f"{svc}.service"], timeout=30)
            # sylink-pentest-agent is presumably this process's own unit (TODO
            # confirm): restarting it synchronously would stop this process
            # before the remaining units are restarted and before the result
            # is pushed. It is therefore restarted last and after a short delay,
            # using the same deferred pattern as handle_reboot.
            subprocess.Popen(["bash", "-c", "sleep 15 && systemctl restart sylink-pentest-agent.service"])
        return _local_result("done", vm_job_id=job_name, started_at=started,
                             report_json={"updated_scripts": updated})
    except Exception as e:
        log(f"  update_scripts err : {e}")
        return _local_result("error", vm_job_id=job_name, started_at=started, error=str(e))


def handle_update_packs(params: dict) -> dict:
    """Download and extract the resource packs (CVE/KEV/Nuclei/wordlists/ioc-blocklist).

    For each pack type, ``<base>/<type>/latest.tar.zst`` is downloaded into
    ``/opt/pentest-resources/<type>/`` and extracted there with ``tar -I zstd``.

    Args:
        params: ``pack_types`` optionally overrides the default list.

    Returns:
        A local result whose ``raw_output`` has one line per pack. The status
        is ``"error"`` (with the failed packs named in ``error``) when at least
        one pack could not be downloaded or extracted, ``"done"`` otherwise.
    """
    import urllib.request

    started = datetime.now(timezone.utc).isoformat()
    pack_types = params.get("pack_types") or ["cve", "kev", "nuclei-templates", "wordlists", "ioc-blocklist"]
    base = "https://pentest.unisoc.fr/packs"
    target_dir = Path("/opt/pentest-resources")
    target_dir.mkdir(parents=True, exist_ok=True)
    socket_timeout = 60  # seconds; a stalled download must not block a pool worker forever
    output = []
    failed = []
    for pt in pack_types:
        try:
            # The pack type becomes a path component and a URL segment: only
            # accept a plain name so that it cannot escape target_dir.
            if not isinstance(pt, str) or pt in ("", ".", "..") or Path(pt).name != pt:
                raise ValueError("invalid pack type name")
            url = f"{base}/{pt}/latest.tar.zst"
            local_dir = target_dir / pt
            local_dir.mkdir(parents=True, exist_ok=True)
            tar_path = local_dir / "latest.tar.zst"
            # download (streamed in chunks, with a socket timeout)
            with urllib.request.urlopen(url, timeout=socket_timeout) as resp, tar_path.open("wb") as dst:
                for chunk in iter(lambda: resp.read(65536), b""):
                    dst.write(chunk)
            # extract
            r = subprocess.run(["tar", "-I", "zstd", "-xf", str(tar_path), "-C", str(local_dir)],
                               capture_output=True, text=True, timeout=120)
            output.append(f"{pt}: ok={r.returncode == 0} size={tar_path.stat().st_size}")
            if r.returncode != 0:
                failed.append(str(pt))
        except Exception as e:
            output.append(f"{pt}: ERROR {e}")
            failed.append(str(pt))
    return _local_result("error" if failed else "done",
                         vm_job_id="local_update_packs", started_at=started,
                         raw_output="\n".join(output),
                         error=("failed packs: " + ", ".join(failed)) if failed else None)


def handle_reboot(params: dict) -> dict:
    """Schedule a reboot after an optional delay, so the command can be acknowledged first.

    Args:
        params: ``delay_seconds`` (default 30) is the delay before
            ``systemctl reboot`` runs. It must convert to a non-negative
            integer.

    Returns:
        A "done" result carrying ``reboot_in_seconds``, or an error result
        when the delay is invalid or the reboot cannot be scheduled.
    """
    started = datetime.now(timezone.utc).isoformat()
    raw_delay = params.get("delay_seconds", 30)
    # The delay is interpolated into a shell command line. Forcing it to an
    # integer prevents a value such as "1; rm -rf /" from being executed.
    try:
        delay = int(raw_delay)
        if delay < 0:
            raise ValueError("negative delay")
    except (TypeError, ValueError):
        return _local_result("error", vm_job_id="local_reboot", started_at=started,
                             error=f"invalid delay_seconds: {raw_delay!r}")
    try:
        # Deferred reboot so that the acknowledgement can still be sent
        subprocess.Popen(["bash", "-c", f"sleep {delay} && systemctl reboot"])
        log(f"  reboot scheduled in {delay}s")
        return _local_result("done", vm_job_id="local_reboot", started_at=started,
                             report_json={"reboot_in_seconds": delay})
    except Exception as e:
        return _local_result("error", vm_job_id="local_reboot", started_at=started, error=str(e))


def handle_logs_collect(params: dict) -> dict:
    """Collect recent journal entries for the sylink-pentest units.

    Runs ``journalctl`` over the last ``params["since_hours"]`` hours (default
    24) and returns the last 50000 characters of its output as ``raw_output``.
    Nothing is uploaded separately: the logs travel in the command result.
    """
    started = datetime.now(timezone.utc).isoformat()
    since_hours = params.get("since_hours", 24)
    units = (
        "sylink-pentest-agent.service",
        "sylink-pentest-init.service",
        "sylink-pentest-network.service",
        "sylink-pentest-setup-ui.service",
        "pentest-api.service",
    )
    try:
        argv = ["journalctl", f"--since={since_hours} hours ago"]
        for unit in units:
            argv.extend(("-u", unit))
        argv.append("--no-pager")
        journal = subprocess.run(argv, capture_output=True, text=True, timeout=60)
        # Keep only the tail so the result payload stays bounded.
        tail = journal.stdout[-50000:]
        return _local_result("done", vm_job_id="local_logs_collect", started_at=started,
                             raw_output=tail)
    except Exception as e:
        return _local_result("error", vm_job_id="local_logs_collect", started_at=started, error=str(e))


def handle_apply_license(params: dict) -> dict:
    """Write ``license.json`` from the command parameters and re-run the init script.

    Args:
        params: must contain ``api_key`` and ``tenant_id``; may contain
            ``enabled_services`` and ``expires_at``.

    Returns:
        A "done" result whose ``report_json`` names the action and tenant, or
        an error result (for instance when a required parameter is missing).
    """
    started_at = datetime.now(timezone.utc).isoformat()
    try:
        license_path = CONFIG_DIR / "license.json"
        license_data = {
            "api_key": params["api_key"],
            "tenant_id": params["tenant_id"],
            "enabled_services": params.get("enabled_services") or [],
            "expires_at": params.get("expires_at"),
            "applied_at": datetime.now(timezone.utc).isoformat(),
        }
        license_path.parent.mkdir(parents=True, exist_ok=True)
        # The file holds the API key and is re-read by the main loop while this
        # handler runs in a worker thread. It is therefore created with mode
        # 0600 from the start and moved into place atomically, so a reader
        # never sees a partial or world-readable license.
        tmp_path = license_path.with_name(license_path.name + ".tmp")
        fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        try:
            with os.fdopen(fd, "w") as fh:
                fh.write(json.dumps(license_data, indent=2))
            os.chmod(tmp_path, 0o600)  # in case a stale temp file had wider permissions
            os.replace(tmp_path, license_path)
        except Exception:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
        log(f"  apply_license OK : tenant={params['tenant_id']}")
        # Re-run the init script (regenerates the /etc/issue banner); best-effort.
        try:
            subprocess.run(["/usr/bin/python3", "/opt/sylink-pentest/bin/init.py"], timeout=15)
        except Exception as e:
            log(f"  init re-run failed: {e}")
        return _local_result("done", vm_job_id="local_apply_license", started_at=started_at,
                             report_json={"action": "apply_license", "tenant_id": params["tenant_id"]})
    except Exception as e:
        log(f"  apply_license error: {e}")
        return _local_result("error", vm_job_id="local_apply_license", started_at=started_at,
                             error=str(e))


def exec_command_local(cmd: dict) -> dict:
    """Run one command through an internal handler or the local pentest API.

    An endpoint of the form ``__local__:<name>`` is dispatched to the matching
    internal handler. Any other endpoint is POSTed to the local API; the
    returned job is polled every 10 seconds until it reports ``done`` or
    ``error``, or until ``JOB_POLL_TIMEOUT`` elapses. The job's JSON report is
    then fetched on a best-effort basis.

    Args:
        cmd: command dict; ``vm_endpoint`` and ``params`` are read. A missing
            or null value is treated as empty.

    Returns:
        A result dict with ``status``, ``findings``, ``started_at`` and
        ``finished_at``, plus ``vm_job_id``, ``report_json`` and ``raw_output``
        when a job was started. A job that did not finish in time has status
        ``"timeout"`` and an ``error`` message.
    """
    vm_endpoint = cmd.get("vm_endpoint") or ""
    params = cmd.get("params") or {}
    started_at = datetime.now(timezone.utc).isoformat()

    # Internal handlers (prefix __local__:)
    if vm_endpoint.startswith("__local__:"):
        handler_name = vm_endpoint.split(":", 1)[1]
        handlers = {
            "apply_license":   handle_apply_license,
            "update_os":       handle_update_os,
            "update_tools":    handle_update_tools,
            "update_scripts":  handle_update_scripts,
            "update_packs":    handle_update_packs,
            "reboot":          handle_reboot,
            "logs_collect":    handle_logs_collect,
        }
        if handler_name in handlers:
            return handlers[handler_name](params)
        return _local_result("error", vm_job_id="local_unknown",
                             started_at=started_at,
                             error=f"Handler local inconnu : {handler_name}")

    # 1. Initial POST, expected to return {job_id, ...}
    url = f"{LOCAL_API}{vm_endpoint}"
    log(f"  exec_local POST {url} params={params}")
    status, resp = http_post(url, params, headers={}, timeout=30)
    if status not in (200, 202):
        return {
            "status": "error", "error": f"local API returned {status}: {resp}",
            "started_at": started_at, "finished_at": datetime.now(timezone.utc).isoformat(),
            "findings": [],
        }
    job_id = (resp.get("id") or resp.get("job_id")) if isinstance(resp, dict) else None
    if not job_id:
        return {"status": "error", "error": "no job_id in local response",
                "started_at": started_at, "finished_at": datetime.now(timezone.utc).isoformat(),
                "findings": []}

    # 2. Poll the local job until it finishes
    deadline = time.monotonic() + JOB_POLL_TIMEOUT
    job = {}
    completed = False
    while time.monotonic() < deadline:
        time.sleep(10)
        st, polled = http_get(f"{LOCAL_API}/job/{job_id}", headers={}, timeout=15)
        if st != 200 or not isinstance(polled, dict):
            # Keep the last good snapshot instead of replacing it with an error payload.
            log(f"    poll job err {st}")
            continue
        job = polled
        if job.get("status") in ("done", "error"):
            completed = True
            break
    finished_at = datetime.now(timezone.utc).isoformat()

    # 3. Fetch the JSON report (best-effort)
    report_json = None
    try:
        st, report = http_get(f"{LOCAL_API}/job/{job_id}/report/json", headers={}, timeout=15)
        if st == 200:
            report_json = report
    except Exception as e:
        log(f"    report fetch err : {e}")

    steps = job.get("steps")
    last_step = steps[-1] if isinstance(steps, list) and steps and isinstance(steps[-1], dict) else {}
    raw_output = str(last_step.get("output") or "")[:10000]

    result = {
        "vm_job_id": job_id,
        # A job still "running" at the deadline must be reported as a timeout,
        # not with its last in-progress status.
        "status": job.get("status") if completed else "timeout",
        "findings": job.get("findings") or [],
        "started_at": started_at,
        "finished_at": finished_at,
        "report_json": report_json,
        "raw_output": raw_output,
    }
    if not completed:
        result["error"] = f"job {job_id} did not finish within {JOB_POLL_TIMEOUT}s"
    return result


from concurrent.futures import ThreadPoolExecutor

# Pool d'exécution parallèle des commandes (4 scans simultanés max)
_EXEC_POOL = ThreadPoolExecutor(max_workers=4, thread_name_prefix="pt-exec")
# Suivi des cmd_id en cours pour ne pas les re-soumettre si la prochaine pull les retourne
_INFLIGHT_CMDS: set[str] = set()
_INFLIGHT_LOCK = __import__("threading").Lock()


def _execute_and_push(cmd: dict, cfg: dict):
    """Run one command in the calling worker thread and push its result to the SOC.

    An exception raised by the execution is converted into an error result.
    Whatever the outcome of the push, the command id is removed from the
    in-flight set afterwards.
    """
    cmd_id = cmd["cmd_id"]
    auth_headers = {"X-PT-License": cfg.get("api_key", "")}
    log(f"  [pool] start cmd_id={cmd_id} action={cmd.get('action')}")

    try:
        outcome = exec_command_local(cmd)
    except Exception as exc:
        outcome = dict(
            status="error",
            error=str(exc),
            findings=[],
            started_at=datetime.now(timezone.utc).isoformat(),
            finished_at=datetime.now(timezone.utc).isoformat(),
        )

    outcome["cmd_id"] = cmd_id
    outcome.setdefault("vm_job_id", "unknown")

    try:
        result_url = f"{PT_API_BASE}/api/pentest/vm/commands/{cmd_id}/result"
        push_status, _ack = http_post(result_url, outcome, auth_headers, timeout=30)
        log(f"  [pool] done cmd_id={cmd_id} push={push_status}")
    except Exception as exc:
        log(f"  [pool] push fail cmd_id={cmd_id}: {exc}")
    finally:
        with _INFLIGHT_LOCK:
            _INFLIGHT_CMDS.discard(cmd_id)


def poll_and_execute(cfg: dict):
    """Fetch pending commands from the SOC and dispatch the new ones to the pool.

    Commands whose ``cmd_id`` is already in flight are skipped, as are entries
    that are not dicts or have no ``cmd_id``. Nothing happens when the poll
    does not return status 200.
    """
    from urllib.parse import quote

    headers = {"X-PT-License": cfg.get("api_key", "")}
    # Percent-encode the id: it comes from configuration files and must not be
    # able to break the query string.
    device_id = quote(str(cfg.get("device_id", "unknown")), safe="")
    status, resp = http_get(
        f"{PT_API_BASE}/api/pentest/vm/commands/pending/poll?device_id={device_id}",
        headers, timeout=15,
    )
    if status != 200 or not isinstance(resp, dict):
        return
    commands = resp.get("commands") or []
    if not commands:
        return
    new_count = 0
    with _INFLIGHT_LOCK:
        for cmd in commands:
            if not isinstance(cmd, dict):
                continue  # malformed entry
            cmd_id = cmd.get("cmd_id")
            if not cmd_id or cmd_id in _INFLIGHT_CMDS:
                continue  # already running, skip
            _INFLIGHT_CMDS.add(cmd_id)
            _EXEC_POOL.submit(_execute_and_push, cmd, cfg)
            new_count += 1
        # Read under the lock: worker threads modify the set concurrently.
        inflight = len(_INFLIGHT_CMDS)
    if new_count:
        log(f"received {len(commands)} cmd(s), dispatched {new_count} to pool ({inflight} inflight)")


# ─── Loop principale ─────────────────────────────────────────────────────────


def main() -> int:
    """Run the agent loop forever: heartbeat and command polling at their intervals.

    The configuration is reloaded on every iteration. Without an ``api_key``
    the loop waits 30 seconds and retries. Errors raised by a heartbeat or a
    poll are logged and the loop continues. The function does not return
    under normal operation.
    """
    log("== sylink-pentest-agent start ==")
    # -inf makes the first heartbeat and poll fire immediately. time.monotonic()
    # has an arbitrary origin and can be below the interval (e.g. right after
    # boot), so starting from 0 could delay them.
    last_hb = float("-inf")
    last_poll = float("-inf")
    while True:
        # A configuration error must not take the whole agent down.
        try:
            cfg = load_config()
        except Exception as e:
            log(f"load_config err : {e}")
            cfg = {}
        if not cfg.get("api_key"):
            log(f"pas de licence active ({CONFIG_DIR / 'license.json'} absent ou invalide) — sleep 30s")
            time.sleep(30)
            continue
        now = time.monotonic()
        try:
            if now - last_hb >= HEARTBEAT_INTERVAL:
                heartbeat(cfg)
                last_hb = now
            if now - last_poll >= POLL_INTERVAL:
                poll_and_execute(cfg)
                last_poll = now
        except Exception as e:
            log(f"loop err : {e}")
        time.sleep(5)


# Script entry point: run the agent loop and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
