#!/usr/bin/env bash
#
# Installer for the TZProxmoxVE noVNC WebSocket proxy helper (systemd service).
#
# Usage: bash <this-file> [install|uninstall|version|--version]
#   install    (default) write the Python proxy, its env file (first run
#              only) and a systemd unit, then enable and (re)start it.
#   uninstall  stop/disable the service and remove the unit, the Python file
#              and the version marker; the env file is left in place.
#   version    print this installer's version and the installed one, if any.
# install and uninstall must run as root on the Proxmox host.
set -euo pipefail

SERVICE_NAME="tzproxmox-console-proxy"
HELPER_VERSION="1.1.0"
PYTHON_FILE="/usr/local/bin/${SERVICE_NAME}.py"
ENV_FILE="/etc/${SERVICE_NAME}.env"
VERSION_FILE="/etc/${SERVICE_NAME}.version"
SYSTEMD_FILE="/etc/systemd/system/${SERVICE_NAME}.service"

ACTION="${1:-install}"

case "${ACTION}" in
    version | --version)
        echo "${SERVICE_NAME} installer ${HELPER_VERSION}"
        if [ -f "${VERSION_FILE}" ]; then
            echo "installed $(cat "${VERSION_FILE}")"
        fi
        exit 0
        ;;
    install | uninstall) ;;
    *)
        # Refuse unknown actions so a typo such as "unistall" does not fall
        # through to a full install.
        echo "Usage: $0 [install|uninstall|version]" >&2
        exit 2
        ;;
esac

# Both remaining actions modify files under /etc and systemd state, so the
# root check must come before uninstall as well as install.
if [ "${EUID}" -ne 0 ]; then
    echo "Run this file as root on the Proxmox host." >&2
    exit 1
fi

if [ "${ACTION}" = "uninstall" ]; then
    # The unit may already be gone; a failed disable is not an error here.
    systemctl disable --now "${SERVICE_NAME}" >/dev/null 2>&1 || true
    rm -f "${SYSTEMD_FILE}" "${PYTHON_FILE}" "${VERSION_FILE}"
    systemctl daemon-reload
    echo "Removed ${SERVICE_NAME}. Config left at ${ENV_FILE}."
    exit 0
fi

if ! command -v python3 >/dev/null 2>&1; then
    echo "python3 is required. Proxmox normally includes it." >&2
    exit 1
fi

# The proxy serves TLS with the Proxmox web UI certificate and key.
if [ ! -f /etc/pve/local/pveproxy-ssl.pem ] || [ ! -f /etc/pve/local/pveproxy-ssl.key ]; then
    echo "Proxmox proxy TLS files were not found under /etc/pve/local/." >&2
    echo "This helper expects to reuse the Proxmox web certificate." >&2
    exit 1
fi

# Files created from here on are private to root. The env file is written
# only on the first install, so later runs keep operator edits and the
# existing shared secret.
umask 077
if [[ ! -f "${ENV_FILE}" ]]; then
    # Prefer openssl for the 32-byte hex secret; fall back to Python's
    # secrets module, which python3 (checked above) always provides.
    if command -v openssl >/dev/null 2>&1; then
        SECRET="$(openssl rand -hex 32)"
    else
        SECRET="$(python3 -c 'import secrets; print(secrets.token_hex(32))')"
    fi
    printf '%s\n' \
        "TZ_PROXY_LISTEN_HOST=0.0.0.0" \
        "TZ_PROXY_LISTEN_PORT=8787" \
        "TZ_PROXY_UPSTREAM_HOST=127.0.0.1" \
        "TZ_PROXY_UPSTREAM_PORT=8006" \
        "TZ_PROXY_TLS_CERT=/etc/pve/local/pveproxy-ssl.pem" \
        "TZ_PROXY_TLS_KEY=/etc/pve/local/pveproxy-ssl.key" \
        "TZ_PROXY_REQUIRE_SIGNATURE=0" \
        "TZ_PROXY_SHARED_SECRET=${SECRET}" >"${ENV_FILE}"
fi

# Write the proxy daemon. The heredoc delimiter is quoted ('PY'), so the
# Python source below is written verbatim with no shell expansion.
cat > "${PYTHON_FILE}" <<'PY'
#!/usr/bin/env python3
# TLS WebSocket relay that forwards noVNC console connections to the Proxmox
# API's vncwebsocket endpoints. All settings come from TZ_PROXY_* environment
# variables (the installer's systemd unit loads them from an EnvironmentFile).
import base64
import hashlib
import hmac
import os
import secrets
import select
import signal
import socket
import ssl
import sys
import threading
import time
from urllib.parse import parse_qs, urlsplit

# Runtime configuration, read once at import time from the environment.
LISTEN_HOST = os.environ.get("TZ_PROXY_LISTEN_HOST", "0.0.0.0")
# The systemd unit passes the installer's version; keep this fallback in step
# with the installer (it previously lagged at "1.0.1").
VERSION = os.environ.get("TZ_PROXY_HELPER_VERSION", "1.1.0")
LISTEN_PORT = int(os.environ.get("TZ_PROXY_LISTEN_PORT", "8787"))
UPSTREAM_HOST = os.environ.get("TZ_PROXY_UPSTREAM_HOST", "127.0.0.1")
UPSTREAM_PORT = int(os.environ.get("TZ_PROXY_UPSTREAM_PORT", "8006"))
# Certificate/key presented to clients; defaults are the Proxmox web UI pair.
TLS_CERT = os.environ.get("TZ_PROXY_TLS_CERT", "/etc/pve/local/pveproxy-ssl.pem")
TLS_KEY = os.environ.get("TZ_PROXY_TLS_KEY", "/etc/pve/local/pveproxy-ssl.key")
# Only the exact string "1" enables HMAC signature enforcement.
REQUIRE_SIGNATURE = os.environ.get("TZ_PROXY_REQUIRE_SIGNATURE", "0") == "1"
SHARED_SECRET = os.environ.get("TZ_PROXY_SHARED_SECRET", "")
# Maximum bytes buffered while waiting for an HTTP header terminator.
MAX_HEADER = 65536
# Fixed GUID from RFC 6455 used to derive Sec-WebSocket-Accept.
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

# Cleared by the signal handler to stop the accept and relay loops.
running = True


def log(message):
    """Write *message* to stderr as one line prefixed with a local timestamp.

    The stream is flushed right away so the line is not held in a buffer.
    """
    line = time.strftime("[%Y-%m-%d %H:%M:%S] ") + message + "\n"
    err = sys.stderr
    err.write(line)
    err.flush()


def http_response(conn, status, body, extra_headers=None):
    """Send a complete plain-text HTTP/1.1 response marked Connection: close.

    conn: socket-like object with sendall().
    status: status code and reason phrase, e.g. "200 OK".
    body: text, encoded as UTF-8 (unencodable characters replaced).
    extra_headers: optional iterable of complete "Name: value" lines appended
        after the fixed headers.

    The permissive CORS headers let browser pages on other origins reach
    the helper.
    """
    payload = body.encode("utf-8", "replace")
    header_lines = [
        "HTTP/1.1 %s" % status,
        "Content-Type: text/plain",
        "Content-Length: %d" % len(payload),
        "Connection: close",
        "Access-Control-Allow-Origin: *",
        "Access-Control-Allow-Methods: GET, OPTIONS",
        "Access-Control-Allow-Headers: Content-Type, Cookie",
    ] + list(extra_headers or [])
    head = "".join(line + "\r\n" for line in header_lines) + "\r\n"
    conn.sendall(head.encode("ascii") + payload)


def read_headers(conn):
    """Read one HTTP request head from *conn*.

    Returns (method, raw_target, headers, rest):
      method / raw_target: the first two space-separated fields of the
          request line, exactly as sent;
      headers: dict keyed by lower-cased header name with stripped values
          (a repeated header keeps its last value);
      rest: bytes received after the blank line that ends the head.

    Raises RuntimeError if the peer closes before the head is complete, if
    more than MAX_HEADER bytes accumulate, or if the request line is empty
    or has fewer than two fields.
    """
    buf = bytearray()
    while buf.find(b"\r\n\r\n") < 0:
        piece = conn.recv(4096)
        if not piece:
            raise RuntimeError("client closed before headers")
        buf += piece
        if len(buf) > MAX_HEADER:
            raise RuntimeError("request headers too large")

    head, _sep, rest = bytes(buf).partition(b"\r\n\r\n")
    request_line, *header_lines = head.decode("iso-8859-1").split("\r\n")
    if not request_line:
        raise RuntimeError("empty request")

    fields = request_line.split(" ")
    if len(fields) < 2:
        raise RuntimeError("invalid request line")

    headers = {
        name.strip().lower(): value.strip()
        for name, value in (line.split(":", 1) for line in header_lines if ":" in line)
    }
    return fields[0], fields[1], headers, rest


def parse_cookie(headers, name):
    """Return the value of cookie *name* from the request's Cookie header.

    headers: dict keyed by lower-cased header name (as built by read_headers).
    The first matching cookie wins; its value is whitespace-stripped but not
    unquoted or decoded. Returns "" if there is no Cookie header or no match.
    """
    for pair in headers.get("cookie", "").split(";"):
        cookie_name, has_eq, cookie_value = pair.partition("=")
        if has_eq and cookie_name.strip() == name:
            return cookie_value.strip()
    return ""


def websocket_accept(key):
    """Compute the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key.

    Per RFC 6455 this is base64(SHA-1(key + GUID)), returned as ASCII text.
    """
    sha1 = hashlib.sha1()
    sha1.update(key.encode("ascii"))
    sha1.update(GUID.encode("ascii"))
    return base64.b64encode(sha1.digest()).decode("ascii")


def verify_signature(params, target, pveticket):
    """Check the optional HMAC signature on a console request.

    If REQUIRE_SIGNATURE is off, every request passes. Otherwise the query
    must carry ``exp`` (integer unix time, not yet passed) and ``sig``, the
    hex HMAC-SHA256 of "exp|target|pveticket" keyed with SHARED_SECRET.

    params: parse_qs-style dict of query values (lists of str).
    Returns (ok, reason); reason is "" when ok is True.
    """
    if not REQUIRE_SIGNATURE:
        return True, ""
    if not SHARED_SECRET:
        return False, "proxy requires a shared secret but none is configured"

    exp = params.get("exp", [""])[0]
    sig = params.get("sig", [""])[0]
    try:
        if not exp or int(exp) < int(time.time()):
            return False, "signature expired"
    except ValueError:
        return False, "invalid signature expiry"

    payload = "%s|%s|%s" % (exp, target, pveticket)
    expected = hmac.new(SHARED_SECRET.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).hexdigest()
    # Compare bytes: compare_digest() raises TypeError for non-ASCII str
    # arguments, and ``sig`` is arbitrary client-supplied text.
    if not hmac.compare_digest(expected.encode("ascii"), sig.encode("utf-8", "replace")):
        return False, "invalid signature"
    return True, ""


def build_upstream(raw_target, client_headers):
    """Derive the upstream vncwebsocket path and PVE ticket for a request.

    The Proxmox target comes from the ``target`` (or ``path``) query
    parameter. If neither is given and the request path itself starts with
    /api2/json/, that path is used, and every query parameter except the
    helper-only ones (pveticket, pve_ticket, sig, exp) is forwarded. The
    ticket comes from ``pveticket`` / ``pve_ticket`` or the client's
    PVEAuthCookie cookie.

    Returns ("/" + upstream_target, pveticket).
    Raises RuntimeError if any of these hold:
      - the target is not an api2/json/nodes/.../vncwebsocket endpoint;
      - the ticket is missing;
      - either value contains characters that would break the upstream
        request line or Cookie header;
      - signature verification fails.
    """
    # Local import: the module header only brings in parse_qs and urlsplit.
    from urllib.parse import urlencode

    parsed = urlsplit(raw_target)
    params = parse_qs(parsed.query, keep_blank_values=True)

    target = params.get("target", [""])[0] or params.get("path", [""])[0]
    # The string actually sent upstream; differs from ``target`` only in
    # path mode, where the forwarded query must be re-encoded.
    upstream_target = None

    pveticket = (
        params.get("pveticket", [""])[0]
        or params.get("pve_ticket", [""])[0]
        or parse_cookie(client_headers, "PVEAuthCookie")
    )

    if not target and parsed.path.startswith("/api2/json/"):
        helper_keys = {"pveticket", "pve_ticket", "sig", "exp"}
        forwarded = [
            (key, value)
            for key, values in params.items()
            if key not in helper_keys
            for value in values
        ]
        target = parsed.path.lstrip("/")
        upstream_target = target
        if forwarded:
            # The signed form keeps the decoded values (unchanged signing
            # contract). The upstream form is percent-encoded so characters
            # such as '+' in vncticket survive upstream query decoding.
            target += "?" + "&".join("%s=%s" % pair for pair in forwarded)
            upstream_target += "?" + urlencode(forwarded)

    if target.startswith("/"):
        target = target[1:]
    if upstream_target is None:
        upstream_target = target

    # Only the path part may name the endpoint; "/vncwebsocket" appearing in
    # the query string must not satisfy the check.
    endpoint = target.split("?", 1)[0].rstrip("/")
    if not target.startswith("api2/json/nodes/") or not endpoint.endswith("/vncwebsocket"):
        raise RuntimeError("invalid or missing Proxmox vncwebsocket target")
    # Values come percent-decoded from the query; whitespace or control
    # characters would let a client inject into the upstream request line.
    if any(ord(ch) <= 0x20 or ord(ch) == 0x7F for ch in upstream_target):
        raise RuntimeError("invalid characters in Proxmox target")
    if not pveticket:
        raise RuntimeError("missing PVEAuthCookie/pveticket")
    # The ticket is placed in a Cookie header: CR/LF would inject headers,
    # ';' would inject extra cookies.
    if any(ch == ";" or ord(ch) < 0x20 or ord(ch) == 0x7F for ch in pveticket):
        raise RuntimeError("invalid characters in PVE ticket")

    ok, reason = verify_signature(params, target, pveticket)
    if not ok:
        raise RuntimeError(reason)

    return "/" + upstream_target, pveticket


def open_upstream(upstream_path, pveticket):
    """Open a TLS WebSocket connection to the upstream Proxmox API.

    Sends an upgrade request for *upstream_path* (requesting the "binary"
    subprotocol) with *pveticket* as the PVEAuthCookie, then reads the
    response head. The upstream certificate is not verified (check_hostname
    off, CERT_NONE); the default upstream is 127.0.0.1.

    Returns (sock, rest): the connected TLS socket and any bytes received
    after the response head. Raises RuntimeError if the upstream closes
    early, sends oversized headers, or answers with a status other than 101.
    The socket is closed before any exception propagates.
    """
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    raw = socket.create_connection((UPSTREAM_HOST, UPSTREAM_PORT), timeout=10)
    try:
        sock = ctx.wrap_socket(raw, server_hostname=UPSTREAM_HOST)
    except BaseException:
        raw.close()
        raise

    try:
        key = base64.b64encode(secrets.token_bytes(16)).decode("ascii")
        request = (
            "GET %s HTTP/1.1\r\n"
            "Host: %s:%d\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Key: %s\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "Sec-WebSocket-Protocol: binary\r\n"
            "Cookie: PVEAuthCookie=%s\r\n"
            "\r\n"
        ) % (upstream_path, UPSTREAM_HOST, UPSTREAM_PORT, key, pveticket)
        sock.sendall(request.encode("utf-8"))

        data = b""
        while b"\r\n\r\n" not in data:
            chunk = sock.recv(4096)
            if not chunk:
                raise RuntimeError("upstream closed during websocket handshake")
            data += chunk
            if len(data) > MAX_HEADER:
                raise RuntimeError("upstream response headers too large")

        header_blob, rest = data.split(b"\r\n\r\n", 1)
        first = header_blob.split(b"\r\n", 1)[0].decode("iso-8859-1", "replace")
        # Status line is "HTTP/1.1 <code> [reason]"; compare the code field
        # exactly rather than searching for " 101 " anywhere in the line.
        fields = first.split(" ", 2)
        if len(fields) < 2 or fields[1] != "101":
            raise RuntimeError("upstream websocket handshake failed: " + first)
        return sock, rest
    except BaseException:
        # Without this the TLS socket would leak: the caller never receives
        # it, so it cannot close it.
        sock.close()
        raise


def proxy_loop(client, upstream, initial_upstream_data=b"", initial_client_data=b""):
    """Relay bytes between *client* and *upstream* until one side stops.

    Bytes already read past each handshake are flushed first: upstream
    leftovers go to the client, client leftovers go upstream. Returns when
    either peer reaches EOF, a recv/sendall fails, or the global ``running``
    flag is cleared (re-checked at least every 60 seconds).
    """
    for destination, leftover in ((client, initial_upstream_data), (upstream, initial_client_data)):
        if leftover:
            destination.sendall(leftover)

    while running:
        ready, _writable, _errored = select.select([client, upstream], [], [], 60)
        for source in ready:
            destination = client if source is upstream else upstream
            try:
                chunk = source.recv(65536)
                if not chunk:
                    return
                destination.sendall(chunk)
            except Exception:
                return


def handle_client(conn, addr):
    """Serve one TLS client connection; always close it (and any upstream).

    Requests without a WebSocket upgrade get a short plain-text reply:
      - 204 for a CORS preflight (OPTIONS);
      - 200 with the version for /, /health and /version;
      - a PVEAuthCookie Set-Cookie for /set-cookie;
      - 426 for anything else.
    A WebSocket GET is resolved with build_upstream, connected with
    open_upstream, answered with 101, and then relayed by proxy_loop.

    conn: accepted TLS socket for the client.
    addr: peer address tuple; only addr[0] is used, for logging.
    """
    upstream = None
    # After the 101 reply the connection carries WebSocket frames, so an HTTP
    # error response must no longer be written to it.
    upgraded = False
    try:
        method, raw_target, headers, rest = read_headers(conn)
        parsed_target = urlsplit(raw_target)

        # Handle CORS preflight
        if method.upper() == "OPTIONS":
            http_response(conn, "204 No Content", "")
            return

        if headers.get("upgrade", "").lower() != "websocket":
            if parsed_target.path in ("/", "/health", "/version"):
                http_response(conn, "200 OK", "tzproxmox-console-proxy %s\n" % VERSION)
                return

            # Support setting cookie via proxy for cross-domain support
            if parsed_target.path == "/set-cookie":
                params = parse_qs(parsed_target.query)
                ticket = params.get("ticket", [""])[0] or params.get("pveticket", [""])[0]
                if not ticket:
                    http_response(conn, "400 Bad Request", "Missing ticket parameter.")
                elif any(ch == ";" or ord(ch) < 0x20 or ord(ch) == 0x7F for ch in ticket):
                    # parse_qs percent-decodes the value, so %0D%0A or %3B would
                    # otherwise inject response headers or cookie attributes.
                    http_response(conn, "400 Bad Request", "Invalid ticket parameter.")
                else:
                    # SameSite=None and Secure are critical for cross-domain cookies
                    cookie = "Set-Cookie: PVEAuthCookie=%s; Path=/; Secure; SameSite=None" % ticket
                    http_response(conn, "200 OK", "Auth cookie has been set.", [cookie])
                return

            http_response(conn, "426 Upgrade Required", "WebSocket upgrade required or invalid path.")
            return

        if method.upper() != "GET":
            http_response(conn, "405 Method Not Allowed", "Only websocket GET is supported.")
            return

        client_key = headers.get("sec-websocket-key", "")
        if not client_key:
            http_response(conn, "400 Bad Request", "Missing Sec-WebSocket-Key.")
            return

        upstream_path, pveticket = build_upstream(raw_target, headers)
        upstream, upstream_rest = open_upstream(upstream_path, pveticket)

        # Echo the "binary" subprotocol only if the client offered it.
        protocol = ""
        if "binary" in headers.get("sec-websocket-protocol", ""):
            protocol = "Sec-WebSocket-Protocol: binary\r\n"

        response = (
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Accept: %s\r\n"
            "%s"
            "\r\n"
        ) % (websocket_accept(client_key), protocol)
        conn.sendall(response.encode("ascii"))
        upgraded = True

        # Log only the path, not the query (it may carry tickets).
        log("%s connected to %s" % (addr[0], upstream_path.split("?", 1)[0]))
        proxy_loop(conn, upstream, upstream_rest, rest)
    except Exception as exc:
        if not upgraded:
            try:
                http_response(conn, "502 Bad Gateway", str(exc))
            except Exception:
                pass  # best effort: the client may already be gone
        log("%s error: %s" % (addr[0], exc))
    finally:
        for sock in (upstream, conn):
            if sock:
                try:
                    sock.close()
                except Exception:
                    pass


def shutdown(_signum, _frame):
    """SIGTERM/SIGINT handler: clear the module-level ``running`` flag.

    The signal number and frame are ignored.
    """
    globals()["running"] = False


# Seconds a client socket may stall in one blocking operation (TLS handshake,
# header read, send) before the connection is dropped.
CLIENT_TIMEOUT = float(os.environ.get("TZ_PROXY_CLIENT_TIMEOUT", "30"))


def _serve_connection(ctx, conn, addr):
    """Worker-thread entry: finish the TLS handshake, then run handle_client.

    Doing the handshake here rather than in the accept loop means a peer that
    connects and never completes it ties up only its own thread.
    """
    try:
        tls_conn = ctx.wrap_socket(conn, server_side=True)
    except Exception as exc:
        log("TLS handshake failed from %s: %s" % (addr[0], exc))
        conn.close()
        return
    handle_client(tls_conn, addr)


def main():
    """Run the TLS listener until SIGTERM/SIGINT clears ``running``.

    Each accepted connection is handed to a daemon worker thread. The 1 s
    accept timeout lets the loop notice a shutdown request promptly.
    """
    signal.signal(signal.SIGTERM, shutdown)
    signal.signal(signal.SIGINT, shutdown)

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(TLS_CERT, TLS_KEY)

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind((LISTEN_HOST, LISTEN_PORT))
    server.listen(100)
    server.settimeout(1)
    log("listening on %s:%d" % (LISTEN_HOST, LISTEN_PORT))

    try:
        while running:
            try:
                conn, addr = server.accept()
            except socket.timeout:
                continue
            except OSError as exc:
                log("accept failed, stopping: %s" % exc)
                break

            # Bound every blocking call on this socket so a silent peer
            # cannot hold its worker (or, previously, the accept loop) forever.
            conn.settimeout(CLIENT_TIMEOUT)
            worker = threading.Thread(target=_serve_connection, args=(ctx, conn, addr), daemon=True)
            worker.start()
    finally:
        server.close()


if __name__ == "__main__":
    main()
PY

chmod 700 "${PYTHON_FILE}"

cat > "${SYSTEMD_FILE}" <<EOF
[Unit]
Description=TZProxmoxVE noVNC WebSocket proxy helper
After=network-online.target pveproxy.service
Wants=network-online.target

[Service]
Type=simple
EnvironmentFile=${ENV_FILE}
Environment=TZ_PROXY_HELPER_VERSION=${HELPER_VERSION}
ExecStart=/usr/bin/python3 ${PYTHON_FILE}
Restart=always
RestartSec=2
User=root
Group=root
NoNewPrivileges=true
PrivateTmp=true

[Install]
WantedBy=multi-user.target
EOF
# umask 077 makes the unit 0600, which systemd warns about as
# world-inaccessible. It holds no secrets (those live in ENV_FILE).
chmod 644 "${SYSTEMD_FILE}"

systemctl daemon-reload
systemctl enable "${SERVICE_NAME}" >/dev/null
systemctl restart "${SERVICE_NAME}"
echo "${HELPER_VERSION}" > "${VERSION_FILE}"

# Read back the effective settings. A hand-edited env file missing a key must
# not abort the script (set -e/pipefail) after the service is installed.
PORT="$(grep '^TZ_PROXY_LISTEN_PORT=' "${ENV_FILE}" | cut -d= -f2 || true)"
PORT="${PORT:-8787}"
SECRET="$(grep '^TZ_PROXY_SHARED_SECRET=' "${ENV_FILE}" | cut -d= -f2- || true)"

echo
echo "Installed ${SERVICE_NAME} ${HELPER_VERSION}."
echo "Status:"
# `systemctl status` exits non-zero when the unit is not active (e.g. it is
# restart-looping); still show it and go on to print the connection details.
systemctl --no-pager --full status "${SERVICE_NAME}" | sed -n '1,12p' || true
echo
echo "Helper WebSocket endpoint: wss://<this-proxmox-host>:${PORT}/tzproxmox-console"
echo "Helper health check: https://<this-proxmox-host>:${PORT}/health"
echo "Shared secret saved in: ${ENV_FILE}"
echo "Shared secret: ${SECRET}"
echo
echo "Open firewall/TCP ${PORT} from your clients/WHMCS network if needed."
echo "Logs: journalctl -u ${SERVICE_NAME} -f"
