import os
import logging
import urllib.request
import ssl
import datetime as Date
import threading
import json
import time
import re

from flask.logging import default_handler
from logging import StreamHandler
from logging.handlers import RotatingFileHandler


class UploadHandler(StreamHandler):
  """Logging handler that buffers records and POSTs them to the Kasm API.

  Records are kept in memory (only the newest ``MAX_BATCH_SIZE``) and a daemon
  thread started in ``__init__`` uploads them every
  ``BATCH_UPLOAD_INTERVAL_IN_S`` seconds, but only while both ``api_hostname``
  and ``api_token`` are non-empty. ``api_hostname`` is read from
  ``/var/run/kasm-sidecar/api_hostname``; ``api_token`` is never set by this
  class and must be assigned by its owner.
  """
  MAX_BATCH_SIZE = 1000
  BATCH_UPLOAD_INTERVAL_IN_S = 5

  def __init__(self):
    StreamHandler.__init__(self)
    self.api_hostname = ""
    self.api_token = ""
    self.request = None
    self.request_context = None
    self.batch = []
    self.upload_thread = threading.Thread(target=self._process_upload_queue)
    self.upload_thread.daemon = True
    self._local_logger = logging.getLogger(__name__)
    self._local_logger.addHandler(default_handler)
    self.upload_thread.start()

  def emit(self, record):
    """Queue *record* for the next upload.

    ``record.msg`` is expected to be a mapping with at least ``kasm_id`` and
    ``message`` keys. A record that cannot be formatted is reported through
    ``handleError`` instead of raising into the caller's logging call.
    """
    try:
      entry = {
        "application": "network_sidecar",
        "levelname": record.levelname,
        "message": "[{kasm_id}] {message}".format(**record.msg),
      }
    except Exception:
      self.handleError(record)
      return

    with self.lock:
      self.batch.append(entry)
      # Keep only the newest MAX_BATCH_SIZE entries so an unreachable API
      # cannot grow the buffer without bound.
      if len(self.batch) > self.MAX_BATCH_SIZE:
        del self.batch[:-self.MAX_BATCH_SIZE]

  def reset(self):
    """Clear the API hostname and token.

    Uploads stop until both are non-empty again. The upload loop re-reads the
    hostname from disk; the token must be assigned again by the owner.
    """
    with self.lock:
      self.api_hostname = ""
      self.api_token = ""

  def _extract_hostname(self):
    """Load ``host:port`` from the sidecar file and prepare the upload request.

    Does nothing if the file is missing or its content does not match
    ``host:port``. In that case any previously loaded hostname is kept.
    """
    hostname_file = "/var/run/kasm-sidecar/api_hostname"
    if not os.path.exists(hostname_file):
      return

    with open(hostname_file, "r") as file:
      hostname = file.read().strip()

    if not re.match(r"^[a-zA-Z0-9.-]+:\d+$", hostname):
      return

    self.api_hostname = hostname
    self.request = urllib.request.Request(f"https://{self.api_hostname}/api/component_log", method="POST")
    self.request.add_header("Content-Type", "application/json")
    # TLS verification is deliberately disabled: neither the hostname nor the
    # certificate chain is checked. check_hostname must be cleared before
    # verify_mode can be set to CERT_NONE.
    self.request_context = ssl.create_default_context()
    self.request_context.check_hostname = False
    self.request_context.verify_mode = ssl.CERT_NONE

  def _process_upload_queue(self):
    """Background loop: refresh the API hostname and upload queued records.

    Runs forever on the daemon thread started in ``__init__``. Every failure is
    logged and the loop continues, so one bad upload or unreadable hostname
    file cannot stop log shipping.
    """
    # Starting at 0 makes the first iteration read the hostname immediately.
    last_hostname_check_at = 0
    last_upload_check_at = 0

    while True:
      try:
        # Wait for a configuration.
        # After initialization, refresh the api_hostname periodically. The associated network can be reset
        # and/or a new IP address assigned to the API host without restarting the plugin container.
        now = time.time()
        hostname_check_elapsed = now - last_hostname_check_at
        if (not self.api_hostname and hostname_check_elapsed > 5) or hostname_check_elapsed > 30:
          last_hostname_check_at = now
          self._extract_hostname()

        if not self.api_hostname or not self.api_token:
          time.sleep(1)
          continue

        # Wait for a batch.
        if time.time() - last_upload_check_at < self.BATCH_UPLOAD_INTERVAL_IN_S:
          time.sleep(1)
          continue

        last_upload_check_at = time.time()
        with self.lock:
          # Re-check under the lock: reset() may have run since the check above.
          if not self.api_hostname or not self.api_token:
            continue
          items = self.batch.copy()
          self.batch = []

        if not items:
          continue

        # Upload the batch.
        payload = json.dumps({
          'logs': items,
          'token': self.api_token
        }).encode('utf-8')
        self.request.add_header("Content-Length", str(len(payload)))
        self._local_logger.info(f'Sending {len(items)} log records to {self.api_hostname}')
        # Use the response as a context manager so the connection is released.
        with urllib.request.urlopen(self.request, payload, context=self.request_context, timeout=30) as response:
          status = response.status
          body = response.read()

        if status > 299:
          self._local_logger.error(f"HTTP Logging failed: Invalid response code {status}")

        resp_json = json.loads(body)
        if resp_json.get('error_message'):
          self._local_logger.error(f"HTTP Logging failed: {resp_json.get('error_message')}")

      except Exception as e:
        # Logger.error() takes no `flush` argument; passing one raised TypeError
        # here and terminated this thread.
        self._local_logger.error(f"Failed to upload logs to the proxy at hostname({self.api_hostname}): {e}")

class JsonFormatter(logging.Formatter):
  """Render a dict-valued ``record.msg`` as a single-line JSON object.

  Only the four keys in ``_FIELDS`` are emitted, in that order. Any other keys
  in ``record.msg`` are ignored, and a missing key raises ``KeyError``.
  """

  # Keys copied from record.msg, in output order.
  _FIELDS = ("asctime", "levelname", "message", "kasm_id")

  def format(self, record):
    source = record.msg
    selected = {key: source[key] for key in self._FIELDS}
    return json.dumps(selected)

class JsonFileHandler(RotatingFileHandler):
  """Rotating file handler that writes each record as one JSON line.

  Accepts the same arguments as ``RotatingFileHandler`` and installs a
  ``JsonFormatter`` at construction. ``emit`` is inherited unchanged.
  """

  def __init__(self, filename, *args, **kwargs):
    RotatingFileHandler.__init__(self, filename, *args, **kwargs)
    formatter = JsonFormatter()
    self.setFormatter(formatter)

class TextFileHandler(RotatingFileHandler):
  """Rotating file handler that writes dict-valued records as pipe-separated text.

  Each line has the form ``"<asctime> | <levelname> | <kasm_id> | <message>"``,
  built from the matching keys of ``record.msg``.
  """

  def emit(self, record):
    """Write *record* as one text line without modifying the caller's record.

    The same record object is passed to every handler on the logger, so the
    formatted line is placed on a copy instead of overwriting ``record.msg``.
    Records whose ``msg`` lacks the expected keys are reported through
    ``handleError`` rather than raising into the logging call.
    """
    try:
      fields = record.msg
      line = f"{fields['asctime']} | {fields['levelname']} | {fields['kasm_id']} | {fields['message']}"
      line_record = logging.makeLogRecord(dict(record.__dict__, msg=line))
    except Exception:
      self.handleError(record)
      return
    super().emit(line_record)

class ScriptLogger(logging.Logger):
  """Logger fed by JSON lines written to an OS pipe.

  Each line written to ``pipe`` should be a JSON object with ``levelname``,
  ``message`` (a list of lines or a single string) and ``kasm_id`` keys. A
  daemon thread parses each line and re-logs it through this logger. The
  logger's handlers are an ``UploadHandler``, a JSON log file and a text log
  file.
  """

  # Accepted "levelname" values (matched case-insensitively); each is also the
  # name of the Logger method used to emit the record.
  _LEVELS = ("debug", "info", "warning", "error", "critical")

  def __init__(self, name, log_file_path):
    super().__init__(name)

    self.upload_handler = UploadHandler()
    self.addHandler(self.upload_handler)

    self.json_handler = JsonFileHandler(self._json_log_path(log_file_path))
    self.addHandler(self.json_handler)

    self.text_handler = TextFileHandler(log_file_path)
    self.addHandler(self.text_handler)

    self._local_logger = logging.getLogger(__name__)
    self._local_logger.addHandler(default_handler)

    self.read_pipe, self.write_pipe = os.pipe()
    self.log_thread = threading.Thread(target=self._read_logs)
    self.log_thread.daemon = True
    self.log_thread.start()

  @staticmethod
  def _json_log_path(log_file_path):
    """Return the JSON log path for *log_file_path* (``x.log`` -> ``x_json.log``).

    Only a trailing ``.log`` is replaced, so directory names containing
    ``.log`` are left intact. A path without that suffix gets ``_json.log``
    appended instead of colliding with the text log file.
    """
    suffix = ".log"
    if log_file_path.endswith(suffix):
      base = log_file_path[:-len(suffix)]
    else:
      base = log_file_path
    return f"{base}_json.log"

  @property
  def pipe(self):
    """Write end of the pipe; JSON log lines written here are re-logged."""
    return self.write_pipe

  def set_api_token(self, token):
    """Set the token the upload handler sends with each batch."""
    self.upload_handler.api_token = token

  def reset(self):
    """Clear the upload handler's API hostname and token."""
    self.upload_handler.reset()

  def _read_logs(self):
    """Reader-thread loop: parse each pipe line and re-log it at its level.

    Malformed lines (invalid JSON, missing keys, wrong types) are reported on
    ``_local_logger`` and skipped. An uncaught exception would end this
    thread and leave the pipe unread.
    """
    with os.fdopen(self.read_pipe) as lines:
      for line in lines:
        try:
          log = json.loads(line)
          level = log["levelname"].lower()
          message = log["message"]
          # "\n".join on a str would put a newline between every character.
          if not isinstance(message, str):
            message = "\n".join(message)
          payload = {
            "asctime": Date.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "levelname": log["levelname"].upper(),
            "message": message,
            "kasm_id": log["kasm_id"]
          }
        except (ValueError, KeyError, TypeError, AttributeError) as e:
          # json.JSONDecodeError is a subclass of ValueError.
          self._local_logger.error(f"Failed to parse script log line on \"{line}\": {e}")
          continue

        if level in self._LEVELS:
          getattr(self, level)(payload)
        else:
          self._local_logger.warning(f"Ignoring script log line with unknown level \"{level}\": {line}")
