#!/usr/bin/python3

# Copyright 2021 Michal Arbet
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import yaml

# Where the YAML inputs are read from and where the generated
# ProxySQL configuration file is written when run as a script.
PROXYSQL_CONFIG_DIR = "/etc/proxysql"
PROXYSQL_CONFIG = "/etc/proxysql.cnf"

# Root logger setup: timestamped records, everything from DEBUG upwards.
log_format = "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s"
logging.basicConfig(level=logging.DEBUG,
                    datefmt="%H:%M:%S",
                    format=log_format)
# Logger shared by everything in this module.
LOG = logging.getLogger("proxysql_config_sync")


class ProxySQLConfig:
    """Merge ProxySQL YAML inputs into a single configuration file.

    Global settings are read from ``<conf_dir>/proxysql.yaml``.  The
    ``<conf_dir>/users`` and ``<conf_dir>/rules`` directories hold YAML
    fragments whose ``mysql_users`` / ``mysql_query_rules`` lists are
    concatenated, de-duplicated and finally written to ``conf_file``.

    NOTE(review): string values are written verbatim between double
    quotes; embedded quotes or backslashes are not escaped.  Confirm the
    inputs never contain them.
    """

    # Fragment directories (keys of ``self.configs``) and the top-level
    # configuration key each of them contributes to.
    _FRAGMENT_KEYS = {
        'users': 'mysql_users',
        'rules': 'mysql_query_rules',
    }

    def __init__(self, conf_dir, conf_file):
        """Locate the inputs under *conf_dir* and load them.

        :param conf_dir: directory holding ``proxysql.yaml`` and the
                         ``users`` / ``rules`` fragment directories
        :param conf_file: path of the configuration file to generate
        :raises OSError: if an input file cannot be read
        """
        self.configs = {
            'global': "{}/proxysql.yaml".format(conf_dir),
            'users': "{}/users".format(conf_dir),
            'rules': "{}/rules".format(conf_dir),
        }
        self.config = dict()
        self.conf_file = conf_file
        self._load_config()

    @staticmethod
    def _load_yaml(path):
        """Parse the YAML file at *path*; an empty document yields {}."""
        with open(path) as config_file:
            # safe_load() returns None for an empty document.
            return yaml.safe_load(config_file) or {}

    def _load_fragments(self, directory, section):
        """Concatenate the *section* lists of every file in *directory*.

        Files are visited in sorted name order so that the result (and
        therefore rule numbering and duplicate resolution) does not
        depend on the order in which the filesystem lists them.
        """
        entries = []
        for file_name in sorted(os.listdir(directory)):
            fragment = self._load_yaml(os.path.join(directory, file_name))
            # Tolerate fragments that lack the section or leave it null.
            entries.extend(fragment.get(section) or [])
        return entries

    def _load_config(self):
        """Populate ``self.config`` from every path in ``self.configs``.

        A path that is not a directory is parsed as one YAML mapping and
        merged into the configuration.  A fragment directory replaces
        the matching section with the concatenation of its fragments.
        Each directory is read exactly once.
        """
        for name, path in self.configs.items():
            if not os.path.isdir(path):
                self.config.update(self._load_yaml(path))
            elif name in self._FRAGMENT_KEYS:
                section = self._FRAGMENT_KEYS[name]
                self.config[section] = self._load_fragments(path, section)
            else:
                LOG.warning("Expected a file but found a directory, "
                            "skipping: %s", path)
        self._sanity()

    def _sanity(self):
        """De-duplicate the users and query rules that were loaded."""
        self._users_sanity()
        self._rules_sanity()

    def _users_sanity(self):
        """Keep only the first entry for every ``username``.

        A missing or null ``mysql_users`` section becomes an empty list.
        """
        seen = set()
        unique = list()
        for user in self.config.get('mysql_users') or []:
            username = user['username']
            if username in seen:
                LOG.warning("User %s already exist, ignoring.", username)
                continue
            seen.add(username)
            unique.append(user)
        self.config['mysql_users'] = unique

    def _rules_sanity(self):
        """Drop duplicate query rules and renumber the survivors.

        A rule is identified by its ``schemaname`` or, failing that, by
        its ``username``; rules with neither are skipped.  The first
        rule seen for an identity wins and ``rule_id`` is assigned
        sequentially starting at 1.  A missing or null
        ``mysql_query_rules`` section becomes an empty list.
        """
        seen = set()
        unique = list()
        for rule in self.config.get('mysql_query_rules') or []:
            if 'schemaname' in rule:
                key = f"schema:{rule['schemaname']}"
            elif 'username' in rule:
                key = f"user:{rule['username']}"
            else:
                LOG.warning(
                    "Rule without schemaname or username found, "
                    "skipping: %s", rule)
                continue

            if key in seen:
                LOG.warning("Duplicate rule for %s, ignoring.", key)
                continue
            seen.add(key)
            # Ids follow the position among the rules that are kept.
            rule['rule_id'] = len(unique) + 1
            unique.append(rule)

        self.config['mysql_query_rules'] = unique

    @staticmethod
    def _format_value(value):
        """Return *value* as written in the file; strings get quoted."""
        if isinstance(value, str):
            return '"{}"'.format(value)
        return "{}".format(value)

    @classmethod
    def _render_dict(cls, key, value):
        """Render one mapping, or a list of mappings, as ``{ ... }`` groups.

        The ``key =`` header is emitted only when *key* is truthy.
        Groups are separated by a trailing comma on all but the last.
        """
        groups = value if isinstance(value, list) else [value]
        parts = []
        if key:
            parts.append("{} =\n".format(key))
        last = len(groups) - 1
        for index, group in enumerate(groups):
            parts.append("  {\n")
            for name, item in group.items():
                parts.append(
                    "     {} = {}\n".format(name, cls._format_value(item)))
            parts.append("  }\n" if index == last else "  },\n")
        return "".join(parts)

    @classmethod
    def _render_list(cls, key, values):
        """Render *values* as a parenthesised list of groups under *key*."""
        return "{} =\n(\n{})\n".format(key, cls._render_dict(None, values))

    @classmethod
    def _render_scalar(cls, key, value):
        """Render a single ``key = value`` line."""
        return "{} = {}\n".format(key, cls._format_value(value))

    def _append(self, text):
        """Append *text* to the output file, creating it if needed."""
        with open(self.conf_file, "a") as f:
            f.write(text)

    def _write_dict(self, key, value):
        """Append a mapping (or list of mappings) to the output file."""
        self._append(self._render_dict(key, value))

    def _write_list(self, key, values):
        """Append a parenthesised list of mappings to the output file."""
        self._append(self._render_list(key, values))

    def _write(self, key, value):
        """Append a scalar ``key = value`` line to the output file."""
        self._append(self._render_scalar(key, value))

    def write_config(self):
        """Replace the output file with the rendered configuration.

        Everything is rendered in memory before the file is opened, so a
        rendering error leaves the existing file untouched instead of
        truncated or half written.
        """
        LOG.info("Writing config to %s", self.conf_file)
        chunks = []
        for key, value in self.config.items():
            if isinstance(value, dict):
                chunks.append(self._render_dict(key, value))
            elif isinstance(value, list):
                chunks.append(self._render_list(key, value))
            else:
                chunks.append(self._render_scalar(key, value))
        with open(self.conf_file, "w") as f:
            f.write("".join(chunks))


if __name__ == "__main__":
    # Script entry point: load the inputs from the default directory and
    # regenerate the default ProxySQL configuration file.
    ProxySQLConfig(
        conf_dir=PROXYSQL_CONFIG_DIR,
        conf_file=PROXYSQL_CONFIG,
    ).write_config()
