import base64
import json
import logging
import os
import re
import socket
import tempfile

from flask import Flask, jsonify, request
from plumbum import BG, local

from logger import ScriptLogger

network_dir = "/root/networks/"
log_dir = os.getenv("LOG_DIR", "/var/log/kasm-sidecar")

# Create every directory the plugin writes into up front (no-op when present):
# per-network JSON configs, log files, and the runtime dir holding api_hostname.
for _required_dir in (network_dir, log_dir, "/var/run/kasm-sidecar"):
  os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Receives the stdout/stderr of the nssetup/nscleanup helper scripts.
script_logger = ScriptLogger("kasm-sidecar-plugin", os.path.join(log_dir, "network_sidecar.log"))

def read_api_hostname(path="/var/run/kasm-sidecar/api_hostname"):
  """Return the Kasm API hostname stored in ``path``, stripped of surrounding whitespace.

  Returns "" when the file does not exist. Handlers pass the result to the
  nssetup / nscleanup scripts as ``KASM_API_HOST``.
  """
  # EAFP: open directly instead of exists()-then-open(), which races with the
  # file being removed in between.
  try:
    with open(path, "r") as file:
      return file.read().strip()
  except FileNotFoundError:
    return ""

def load_network_config(network_id):
  """Load and return the persisted JSON config for ``network_id``.

  Raises:
    FileNotFoundError: if no config has been saved for this network.
    json.JSONDecodeError: if the file does not contain valid JSON.
  """
  # The ``with`` block closes the file; no explicit close() is needed.
  with open(network_dir + network_id + ".json") as f:
    return json.load(f)

def save_network_config(network_id, network_config):
  """Persist ``network_config`` as pretty-printed JSON for ``network_id``.

  The JSON is written to a temporary file in the same directory and then moved
  over the real file with ``os.replace``. A crash mid-write, or a concurrent
  reader (the app is served with ``threaded=True``), therefore never sees a
  truncated config.
  """
  path = network_dir + network_id + ".json"
  # NOTE(review): mkstemp creates the file with mode 0600 (the old open() used
  # the umask); presumably only this root-run plugin reads these files — confirm.
  fd, tmp_path = tempfile.mkstemp(dir=network_dir, prefix=network_id + ".", suffix=".tmp")
  try:
    with os.fdopen(fd, "w") as f:
      json.dump(network_config, f, indent=2)
    os.replace(tmp_path, path)
  finally:
    # After a successful os.replace the temp file is gone; otherwise clean it up.
    if os.path.exists(tmp_path):
      os.remove(tmp_path)
    
def delete_network_config(network_id):
  """Delete the persisted JSON config file for ``network_id``.

  Raises FileNotFoundError if no config exists for that network.
  """
  config_path = network_dir + network_id + ".json"
  os.unlink(config_path)

def lookup_hostname_ip(hostname):
  """Resolve ``hostname`` (optionally ``host:port``) to an IPv4 address.

  A dotted-quad address, with or without a ``:port`` suffix, is returned
  unchanged. Otherwise the host part is resolved with ``socket.gethostbyname``
  and any non-empty port is re-attached as ``"ip:port"``.

  Raises:
    socket.gaierror: if the host name cannot be resolved.
  """
  # fullmatch, not match: a name such as "10.0.0.1.nip.io" only *starts* with a
  # dotted quad and must still be resolved.
  if re.fullmatch(r"\d{1,3}(?:\.\d{1,3}){3}(?::\d*)?", hostname):
    return hostname
  host, port = hostname.split(":") if ":" in hostname else (hostname, None)
  ip = socket.gethostbyname(host)
  return f"{ip}:{port}" if port else ip

# Diagnostics from the request handlers (e.g. caught exceptions) go to plugin.log.
log_file = os.path.join(log_dir, "plugin.log")
logging.basicConfig(
  filename=log_file,
  format='%(asctime)s - %(levelname)s - %(message)s',
  level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Flask application serving the Docker plugin HTTP API; routes are registered below.
app = Flask(__name__)

@app.route("/Plugin.Activate", methods=["POST"])
def activate():
  """Answer Docker's activation handshake by declaring the NetworkDriver interface.

  Calls ``script_logger.reset()`` first (presumably so each activation starts a
  fresh script log — confirm in ScriptLogger).
  """
  script_logger.reset()
  manifest = {"Implements": ["NetworkDriver"]}
  return jsonify(manifest)

@app.route("/NetworkDriver.GetCapabilities", methods=["POST"])
def get_capabilities():
  """Advertise the driver's capabilities to Docker: local scope."""
  capabilities = {"Scope": "local"}
  return jsonify(capabilities)

@app.route("/NetworkDriver.CreateNetwork", methods=["POST"])
def create_network():
  """Handle Docker's CreateNetwork call by persisting an initial config for the network.

  The config records the bridge name, the first IPv4 pool and its gateway, and
  the optional ``name`` / ``endpoints-configs`` generic driver options. It is
  saved with ``save_network_config`` under the Docker network ID.

  Returns ``{"Err": ""}`` on success and ``{"Err": <message>}`` on failure.
  """
  try:
    data = request.get_json(force=True)
    # Generic driver options; tolerate a missing or null Options / generic map.
    options = (data.get("Options") or {}).get("com.docker.network.generic") or {}
    ipv4_data = data["IPv4Data"][0]

    network_config = {
      "bridge-name": "br_kasm_sidecar",
      # NOTE(review): Docker sends the gateway in CIDR form ("172.18.0.1/16").
      # The prefix is deliberately kept: this value has always reached nssetup
      # with it (the earlier attempt to strip it used str.replace with a regex
      # pattern and never matched). Confirm what nssetup expects before stripping.
      "bridge-gateway": ipv4_data["Gateway"],
      "pool": ipv4_data["Pool"],
      "name": options.get("name", ""),
      "endpoints-configs": options.get("endpoints-configs", {}),
      "endpoints-processes": {},
    }
    save_network_config(data["NetworkID"], network_config)

    return jsonify({"Err": ""})
  except Exception as e:
    logger.exception(e)
    return jsonify({"Err": str(e)})

@app.route("/NetworkDriver.DeleteNetwork", methods=["POST"])
def delete_network():
  """Handle Docker's DeleteNetwork call by removing the network's persisted config.

  Without this, the JSON file written by CreateNetwork would be left behind for
  every deleted network. A config that is already gone is not an error, so
  repeated deletes succeed. Any other failure is logged and reported as ``Err``.
  """
  try:
    data = request.get_json(force=True)
    delete_network_config(data["NetworkID"])
  except FileNotFoundError:
    pass
  except Exception as e:
    logger.exception(e)
    return jsonify({"Err": str(e)}), 500
  return jsonify({"Err": ""})

def _copy_egress_metadata(script_data, endpoint_config):
  """Copy the egress description and status-display flags from endpoint_config into script_data.

  Missing egress fields default to "n/a"; missing show_* flags default to "".
  """
  for key in ("egress_provider", "egress_provider_type", "egress_gateway", "egress_country", "egress_city"):
    script_data[key] = endpoint_config.get(key, "n/a")
  for key in ("show_ip_status", "show_vpn_status"):
    script_data[key] = endpoint_config.get(key, "")

def _prepare_openvpn_script_data(script_data, endpoint_config):
  """Pin the OpenVPN ``remote`` host to an IP and record server, egress metadata and DNS servers."""
  config = script_data["config"]
  match = re.search(r"remote (\S+)", config)
  if match is None:
    raise ValueError("OpenVPN config has no 'remote' directive")
  remote_host = match.group(1)
  remote_host_ip = lookup_hostname_ip(remote_host)
  # Escape the host (its dots are regex metacharacters) and stop at the end of
  # the token. A function replacement keeps the IP from being read as a
  # backreference template.
  script_data["config"] = re.sub(
    rf"remote {re.escape(remote_host)}(?!\S)",
    lambda _: f"remote {remote_host_ip}",
    config)
  script_data["server"] = remote_host_ip
  _copy_egress_metadata(script_data, endpoint_config)
  # DNS servers declared with "dhcp-option DNS <ip>" lines.
  script_data["dns"] = re.findall(r"dhcp-option DNS (\S+)", script_data["config"])

def _prepare_wireguard_script_data(script_data, endpoint_config):
  """Pin WireGuard peer endpoints to IPs and record server, egress metadata, DNS and private key."""
  resolved_ips = []

  def _resolve_endpoint(match):
    ip = lookup_hostname_ip(match.group(1))
    resolved_ips.append(ip)
    return f"Endpoint = {ip}"

  # Each peer's Endpoint host is resolved on its own, so multi-peer configs keep
  # distinct servers. A ":port" suffix is outside the match and is preserved.
  config = re.sub(r"Endpoint\s*=\s*([\w.\-_]+)", _resolve_endpoint, script_data["config"])
  if not resolved_ips:
    raise ValueError("WireGuard config has no 'Endpoint' entry")
  script_data["config"] = config
  script_data["server"] = resolved_ips[0]
  _copy_egress_metadata(script_data, endpoint_config)

  # DNS may list several servers separated by commas and/or spaces
  # ("DNS = 10.2.0.1, 8.8.8.8"), so take the rest of the line (up to an inline
  # '#' comment) and split it, rather than grabbing a single token.
  dns_match = re.search(r"DNS[ \t]*=[ \t]*([^\r\n#]+)", config)
  script_data["dns"] = [s for s in re.split(r"[\s,]+", dns_match.group(1)) if s] if dns_match else []

  # Use the credential's private key instead of any key embedded in the config.
  if "private_key" in script_data:
    key_line = f"PrivateKey = {script_data['private_key']}"
    if re.search(r"PrivateKey\s*=\s*(\S+)", config):
      config = re.sub(r"PrivateKey\s*=\s*(\S+)", lambda _: key_line, config)
    else:
      # Insert the key directly after the [Interface] header.
      config, count = re.subn(r"\[Interface\]\s*", lambda _: f"[Interface]\n{key_line}\n", config, count=1)
      if count == 0:
        raise ValueError("WireGuard config has no [Interface] section")
    script_data["config"] = config

@app.route("/NetworkDriver.CreateEndpoint", methods=["POST"])
def create_endpoint():
  """Handle Docker's CreateEndpoint call by storing the endpoint's config.

  The config comes from the ``com.kasmweb.network.endpoint.config`` option as a
  JSON string. When it carries base64-encoded ``script_data`` for the OpenVPN or
  WireGuard egress scripts, the VPN config inside is rewritten: the remote
  host(s) are replaced with resolved IPs, and server, egress metadata and DNS
  fields are filled in. The result is re-encoded and saved under the endpoint ID
  in the network config.

  Returns ``{"Err": ""}`` on success. Errors are logged and returned as
  ``{"Err": ...}`` with HTTP 500, as in the other handlers.
  """
  try:
    data = request.get_json(force=True)

    # load network config from storage
    network_id = data["NetworkID"]
    network_config = load_network_config(network_id)

    endpoint_id = data["EndpointID"]
    options = data.get("Options") or {}
    endpoint_config = json.loads(options.get("com.kasmweb.network.endpoint.config", "{}"))

    if "api_token" in endpoint_config:
      script_logger.set_api_token(endpoint_config["api_token"])

    if "script_data" in endpoint_config:
      script_data = json.loads(base64.b64decode(endpoint_config["script_data"]).decode("utf-8"))

      script_name = endpoint_config.get("script_name")
      if script_name == "egress_openvpn_setup":
        _prepare_openvpn_script_data(script_data, endpoint_config)
      elif script_name == "egress_wireguard_setup":
        _prepare_wireguard_script_data(script_data, endpoint_config)

      endpoint_config["script_data"] = base64.b64encode(json.dumps(script_data).encode("utf-8")).decode("utf-8")

    network_config["endpoints-configs"][endpoint_id] = endpoint_config
    save_network_config(network_id, network_config)

    return jsonify({"Err": ""})
  except Exception as e:
    logger.exception(e)
    return jsonify({"Err": str(e)}), 500

@app.route("/NetworkDriver.Join", methods=["POST"])
def join():
  """Handle Docker's Join call by launching ``nssetup`` for the endpoint's sandbox.

  Looks up the config stored for the endpoint, records the ``SandboxKey`` path
  Docker supplies as ``ns``, and starts ``nssetup`` in the background with the
  bridge name/gateway, kasm id and script details.

  Returns ``{"Err": ""}`` when the endpoint has no stored config. Otherwise it
  returns a Join response whose ``DisableGatewayService`` is false only for
  ``sysbox-runc`` containers. Errors are logged and returned as
  ``{"Err": ...}`` with HTTP 500.
  """
  try:
    data = request.get_json(force=True)
    network_id = data["NetworkID"]

    # load the endpoint config; .get() so an endpoint without stored config
    # reaches the early return below instead of raising KeyError.
    network_config = load_network_config(network_id)
    endpoint_config = network_config["endpoints-configs"].get(data["EndpointID"])
    if endpoint_config is None:
      return jsonify({"Err": ""})

    ns_file = data["SandboxKey"]
    # Saved so the DeleteEndpoint handler can pass it to nscleanup later.
    endpoint_config["ns"] = ns_file

    bridge_name = network_config["bridge-name"]
    bridge_gateway = network_config["bridge-gateway"]
    kasm_id = endpoint_config.get("kasm_id", "")
    script_name = endpoint_config.get("script_name", "")
    script_data = endpoint_config.get("script_data", {})

    nssetup = local["nssetup"].with_env(
      KASM_API_HOST=read_api_hostname(),
      KASM_API_JWT=endpoint_config.get("api_token", ""))
    # `& BG` starts nssetup without waiting for it; its output goes to script_logger.
    (nssetup[ns_file, bridge_name, bridge_gateway, kasm_id, script_name, script_data]) & BG(stdout=script_logger.pipe, stderr=script_logger.pipe)

    save_network_config(network_id, network_config)

    # note: for any containers running with sysbox runtime we need to actually
    # start the container with default network interfaces
    runtime = endpoint_config.get("runtime", "")
    return jsonify({"DisableGatewayService": runtime != "sysbox-runc"})
  except Exception as e:
    # exception() keeps the traceback that error(e) used to drop.
    logger.exception(e)
    return jsonify({"Err": str(e)}), 500
  
@app.route("/NetworkDriver.ProgramExternalConnectivity", methods=["POST"])
def program_external_connectivity():
  """Acknowledge ProgramExternalConnectivity; this driver programs nothing here."""
  no_error = {"Err": ""}
  return jsonify(no_error)

@app.route("/NetworkDriver.Leave", methods=["POST"])
def leave():
  """Acknowledge Leave; namespace cleanup (nscleanup) is run from the DeleteEndpoint handler."""
  no_error = {"Err": ""}
  return jsonify(no_error)

@app.route("/NetworkDriver.DeleteEndpoint", methods=["POST"])
def delete_endpoint():
  """Handle Docker's DeleteEndpoint call: clean up and forget an endpoint.

  If Join recorded a namespace path (``ns``) for the endpoint, ``nscleanup`` is
  run synchronously with it. The endpoint's entry is then removed from the
  network config and the config is saved. Unknown endpoints are ignored.

  Errors, including a failing nscleanup, are logged with a traceback and
  returned as ``{"Err": ...}`` with HTTP 500. In that case the endpoint entry
  stays in the config.
  """
  try:
    data = request.get_json(force=True)
    network_id = data["NetworkID"]
    network_config = load_network_config(network_id)
    endpoint_configs = network_config["endpoints-configs"]

    endpoint_id = data["EndpointID"]
    if endpoint_id not in endpoint_configs:
      return jsonify({"Err": ""})

    endpoint_config = endpoint_configs[endpoint_id]
    if "ns" in endpoint_config:
      # NOTE(review): kasm_id defaults to "kasm_proxy" here but to "" in Join — confirm intended.
      kasm_id = endpoint_config.get("kasm_id", "kasm_proxy")
      nscleanup = local["nscleanup"].with_env(
        KASM_API_HOST=read_api_hostname(),
        KASM_API_JWT=endpoint_config.get("api_token", ""))
      nscleanup(endpoint_config["ns"], kasm_id, stdout=script_logger.pipe, stderr=script_logger.pipe)

    del endpoint_configs[endpoint_id]
    save_network_config(network_id, network_config)
    return jsonify({"Err": ""})
  except Exception as e:
    # exception() keeps the traceback that error(e) used to drop.
    logger.exception(e)
    return jsonify({"Err": str(e)}), 500

@app.route("/NetworkDriver.EndpointOperInfo", methods=["POST"])
def info():
  """Answer EndpointOperInfo with an empty-error body and no operational details."""
  no_error = {"Err": ""}
  return jsonify(no_error)

if __name__ == "__main__":
  # Serve the plugin API on a Unix socket under /run/docker/plugins, where Docker
  # discovers plugins. Guarded so importing this module does not start the server.
  app.run(host="unix:///run/docker/plugins/kasmnetwork.sock", threaded=True)
