import abc
import time
import json
from logging import getLogger
from collections import defaultdict
import caosdb
from django.contrib import messages

# Module-wide logger, named after this module.
LOGGER = getLogger(name=__name__)


class Result:
    """Base class for the result objects handed back by the client.

    A bare ``Result`` behaves as an already-exhausted iterator; subclasses
    override ``__iter__`` to expose their rows.
    """

    def __init__(self):
        # Set on every result; not read anywhere in this class.
        self.alive = True

    def close(self):
        """Release the result. A bare ``Result`` holds nothing to release."""

    def __next__(self):
        """Stop immediately: a bare ``Result`` yields no items."""
        # TODO: real cursor-style iteration
        raise StopIteration

    def __iter__(self):
        """Return ``self``, so iteration is driven by ``__next__``."""
        # TODO
        return self

class CountResult(Result):
    """One-row result carrying a count under a caller-chosen column name.

    Parameters
    ----------
    count :
        The counted value.
    column_name :
        Key under which ``count`` appears in the single result row.
    """

    def __init__(self, count, column_name):
        super().__init__()
        row = {column_name: count}
        self._results = [row]

    def __iter__(self):
        """Iterate over the single ``{column_name: count}`` row."""
        return iter(self._results)

def _sort_key(x, field, names):
    res = x[field] or ""
    if res in names:
        res = names[res]
    return str(res)

class FindResult(Result):
    """Rows of a find/aggregate query, exposed as an iterable of mappings.

    Parameters
    ----------
    rows : iterable
        Row sequences when ``columns`` is given, otherwise ready-made row
        mappings.
    columns : sequence of str or None
        Column names zipped with each row to build a dict per row. ``None``
        means ``rows`` already are mappings.
    sort : dict, optional
        Column -> direction. Only the first key is used; a negative
        direction sorts descending.
    limit, skip : int, optional
        Applied after sorting, like SQL ``LIMIT`` / ``OFFSET``.
    record_types : container, optional
        Record type names. If the sort column is one of them, its values
        are mapped through ``connection.list_record_names`` before sorting
        (presumably record ids -> record names).
    connection : optional
        Object providing ``list_record_names(record_type)``. It is needed
        when the sort column is a record type.
    """

    def __init__(self, rows, columns, sort=None, limit=None, skip=None,
                 record_types=None, connection=None):
        super(FindResult, self).__init__()
        self._sort = sort

        if columns is not None:
            records = [dict(zip(columns, row)) for row in rows]
        else:
            # Rows are already mappings. They still have to honour sort,
            # skip and limit, which were previously ignored on this path.
            records = list(rows)

        if sort:
            sort_field = next(iter(sort))
            names = {}
            if record_types is not None and sort_field in record_types:
                names = connection.list_record_names(sort_field)
            records.sort(key=lambda rec: _sort_key(rec, sort_field, names),
                         reverse=sort[sort_field] < 0)

        # End index of the requested page. ``None`` means "to the end".
        upper = (skip + limit if skip is not None and limit is not None
                 else limit)
        self._results = records[skip:upper]

    def __iter__(self):
        """Iterate over the (sorted, paginated) row mappings."""
        return iter(self._results)

class UpdateManyResult(Result):
    """Outcome of an ``update_many`` call.

    Both ``matched_count`` and ``modified_count`` are ``len(container)``.
    """

    def __init__(self, container):
        super().__init__()
        updated = len(container)
        self.matched_count = updated
        # TODO: every matched entity is counted as modified, changed or not.
        self.modified_count = updated

class DeleteManyResult(Result):
    """Outcome of a ``delete_many`` call.

    ``deleted_count`` is the number of entities in the given container.
    """

    def __init__(self, container):
        super().__init__()
        removed = len(container)
        self.deleted_count = removed

class InsertManyResult(Result):
    """Outcome of an ``insert_many`` call.

    ``inserted_ids`` lists the ``id`` of every entity in the container, in
    container order; iterating the result yields those ids.
    """

    def __init__(self, container):
        super().__init__()
        self.inserted_ids = [entity.id for entity in container]

    def __iter__(self):
        """Iterate over the inserted ids."""
        return iter(self.inserted_ids)

class CachedRecordTypes(set):
    """Lazily filled set of record type names.

    On a membership miss the names are reloaded from
    ``connection.list_record_type_names()`` before answering. This picks
    up record types created since the last load.
    """

    def __init__(self, connection):
        super().__init__()
        self.connection = connection

    def __contains__(self, item):
        if super().__contains__(item):
            return True
        # Miss: refresh from the server, then answer from the updated set.
        self.update(self.connection.list_record_type_names())
        return super().__contains__(item)

class CachedProperties(set):
    """Lazily filled set of property names.

    On a membership miss the names are reloaded from
    ``connection.list_property_names()`` before answering. This picks up
    properties created since the last load.
    """

    def __init__(self, connection):
        super().__init__()
        self.connection = connection

    def __contains__(self, item):
        if super().__contains__(item):
            return True
        # Miss: refresh from the server, then answer from the updated set.
        self.update(self.connection.list_property_names())
        return super().__contains__(item)

class DefaultCaosDBClientDelegate:
    """Default backend for ``CaosDBClient`` that talks to a CaosDB server.

    Filter dicts are translated into CaosDB query strings. The queries run
    through ``self._caosdb``, which is the ``caosdb`` module unless another
    object is passed as ``caosdb=...``. Query results are memoized in a
    class-level cache that all instances share. The cache is cleared
    whenever a newly executed query reports a different etag.
    """

    # Shared by all instances: cache key (the query string) -> result.
    _query_cache = {}
    # Etag reported by the most recent uncached query; the entries in
    # ``_query_cache`` belong to this server state.
    _etag = None

    def __init__(self, **kwargs):
        """Create the name caches and configure the CaosDB connection.

        Parameters
        ----------
        **kwargs :
            Passed to ``configure_connection``. An optional ``caosdb`` entry
            replaces the ``caosdb`` module used for all server calls.
        """
        self.cached_record_types = CachedRecordTypes(self)
        self.cached_properties = CachedProperties(self)
        if "caosdb" in kwargs:
            self._caosdb = kwargs["caosdb"]
        else:
            self._caosdb = caosdb

        # NOTE(review): a ``caosdb`` entry is forwarded to
        # configure_connection as well -- confirm that is intended.
        self._caosdb.configure_connection(**kwargs)

    @staticmethod
    def _quote_property(name):
        """Return ``name`` double-quoted for a query, except the bare ``id``."""
        return name if name == "id" else f'"{name}"'

    def _get_reference_clause(self, type, p, v, negation, sub=None):
        """Build a ``REFERENCES`` / ``IS REFERENCED BY`` filter fragment.

        This method is called with the keys of a filter dict
        (``**fil``), so the parameter names mirror those keys.

        Parameters
        ----------
        type : str
            ``"reference"`` yields ``REFERENCES v``; any other value
            (``"back_reference"``) yields ``IS REFERENCED BY v``.
        p : str
            Property name; currently unused.
        v :
            Inserted verbatim after the keyword.
        negation :
            Currently ignored.
        sub : dict, optional
            Nested filter, appended as ``WITH (...)``.

        Returns
        -------
        str
            Parenthesized fragment starting with a space.
        """
        # NOTE(review): ``p`` and ``negation`` are ignored, so a negated
        # reference filter is translated as a positive one.
        clause = (f" REFERENCES {v}"
                  if type == "reference"
                  else f" IS REFERENCED BY {v}")
        if sub is not None:
            clause += " WITH (" + self._get_filter_clause(sub) + ")"
        return " (" + clause + " )"

    def _get_filter_clause(self, fil):
        """Translate a filter dict into a CaosDB query filter fragment.

        Supported ``fil["type"]`` values:

        * ``"and"`` / ``"or"``: ``fil["elements"]`` holds nested filters.
          They are joined and wrapped in parentheses.
        * ``"reference"`` / ``"back_reference"``: see
          ``_get_reference_clause``.
        * ``"in"``: ``fil["p"]`` equals one of ``fil["v"]``. Values are
          inserted unquoted.
        * ``"pov"``: property ``fil["p"]``, operator ``fil["o"]`` and quoted
          value ``fil["v"]``, prefixed with ``NOT`` when
          ``fil["negation"] is True``.

        Returns
        -------
        str
            The fragment; it always starts with a space.

        Raises
        ------
        NotImplementedError
            If ``fil`` has no ``"type"`` key or an unsupported type.
        """
        LOGGER.debug("enter _get_filter_clause(%s)", fil)
        if "type" not in fil:
            raise NotImplementedError(f"_get_filter_clause({fil})")

        filter_type = fil["type"]
        if filter_type in ("and", "or"):
            joiner = " AND" if filter_type == "and" else " OR"
            components = [self._get_filter_clause(element)
                          for element in fil["elements"]]
            return " (" + joiner.join(components) + " )"
        if filter_type in ("reference", "back_reference"):
            return self._get_reference_clause(**fil)
        if filter_type == "in":
            prop = self._quote_property(fil["p"])
            components = [f" {prop} = {val!s}" for val in fil["v"]]
            # Parenthesize like "and"/"or" so the disjunction stays grouped
            # when it is nested inside an enclosing AND.
            return " (" + " OR".join(components) + " )"
        if filter_type == "pov":
            negation = "NOT " if fil["negation"] is True else ""
            prop = self._quote_property(fil["p"])
            operator = fil["o"]
            value = fil["v"]
            return f' {negation}{prop}{operator}"{value}"'
        raise NotImplementedError(f"_get_filter_clause({fil})")

    def _cache_locally(self, etag, key, value):
        """Store ``value`` under ``key`` in the shared query cache.

        The whole cache is dropped first if ``etag`` differs from the etag of
        the cached entries. A ``key`` of ``None`` only performs that check.
        """
        if etag != DefaultCaosDBClientDelegate._etag:
            DefaultCaosDBClientDelegate._query_cache = {}
            DefaultCaosDBClientDelegate._etag = etag
        if key is not None:
            LOGGER.debug("caching results with etag '%s' for key '%s'",
                         str(etag), str(key))
            DefaultCaosDBClientDelegate._query_cache[key] = value

    def _query(self, query, key=None):
        """Execute ``query``, reusing a cached result for ``key`` if present.

        Parameters
        ----------
        query : str
            CaosDB query string.
        key : hashable, optional
            Cache key. ``None`` skips both the cache lookup and the storage.
        """
        if key is not None and key in DefaultCaosDBClientDelegate._query_cache:
            LOGGER.debug("using cached results for key '%s'", str(key))
            return DefaultCaosDBClientDelegate._query_cache[key]

        result, etag = self._query_no_local_cache(query)
        self._cache_locally(etag, key, result)

        return result

    def _query_no_local_cache(self, query):
        """Execute ``query`` on the server and return ``(result, etag)``."""
        start = time.time()
        caosdb_query = self._caosdb.Query(q=query)
        result = caosdb_query.execute()
        etag = caosdb_query.etag
        duration = time.time() - start
        # Log the query string, not the Query object the name used to be
        # rebound to.
        LOGGER.debug("[%s] execute_query(%s) - %s", str(start), query,
                     str(duration))
        return result, etag

    def _find(self, record_type, *args, **kwargs):
        """Return the records of ``record_type`` matching ``kwargs["filter"]``.

        Other arguments are accepted and ignored. The query string doubles as
        the cache key.
        """
        filter_clause = ""
        if "filter" in kwargs:
            filter_clause = " WITH " + self._get_filter_clause(kwargs["filter"])

        query = f'FIND RECORD "{record_type}"{filter_clause}'
        return self._query(query, key=query)

    @staticmethod
    def _to_column_value(value):
        """Serialize list values to a JSON string; return others unchanged."""
        return json.dumps(value) if isinstance(value, list) else value

    def _get_property_values(self, container, projection):
        """Turn query results into rows.

        With ``projection=None`` each entity becomes a defaultdict (missing
        keys read as ``None``) containing ``id``, ``name``, ``description``
        and one entry per property. Otherwise each row is a list of the
        projected values, as returned by
        ``container.get_property_values(*projection)``. In both cases list
        values are JSON-encoded.
        """
        if projection is None:
            rows = []
            for entity in container:
                row = defaultdict(lambda: None)
                row.update({"id": entity.id, "name": entity.name,
                            "description": entity.description})
                for prop in entity.get_properties():
                    row[prop.name] = self._to_column_value(prop.value)
                rows.append(row)
            return rows
        return [[self._to_column_value(val) for val in values]
                for values in container.get_property_values(*projection)]

    def find(self, record_type, limit=None, sort=None, skip=None, *args, **kwargs):
        """Find records of ``record_type`` and wrap them in a ``FindResult``.

        ``kwargs`` may contain ``filter`` (see ``_get_filter_clause``) and
        ``projection`` (column names). Without a projection, whole entities
        are returned as rows.
        """
        res = self._find(record_type, *args, **kwargs)
        # ``projection`` is optional: _get_property_values handles None.
        projection = kwargs.get("projection")
        rows = self._get_property_values(res, projection)
        return FindResult(rows, projection, sort, limit, skip,
                          record_types=self.cached_record_types, connection=self)

    def find_auth(self, record_type, *args, **kwargs):
        """Answer find queries against the ``caosdb_auth_*`` tables.

        Only ``caosdb_auth_user`` is handled. The user name is taken from
        ``kwargs["filter"]["elements"][-1]["v"]`` and looked up on the
        server. Any other auth table yields an empty result.
        """
        if record_type == "caosdb_auth_user":
            name = kwargs["filter"]["elements"][-1]["v"]
            # NOTE(review): the retrieved user is discarded and a hard-coded
            # placeholder row is returned instead -- this is a stub.
            self._caosdb.administration._retrieve_user(name)
            columns = ["id", "password", "last_login", "is_superuser",
                       "username", "first_name", "last_name", "email",
                       "is_staff", "is_active", "date_joined"]
            rows = [(0, None, None, True, "tf", "t", "f", "tf@example.com",
                     True, True, None,)]
            return FindResult(rows, columns)
        return FindResult([], [])

    def list_record_names(self, record_type):
        """Return ``{id: name}`` for all named records of ``record_type``."""
        res = self._query(f"SELECT name FROM RECORD '{record_type}'")
        return {e.id: e.name for e in res if e.name is not None}

    def list_record_type_names(self):
        """Return all record type names plus the virtual ``caosdb_auth_user``."""
        res = self._query("SELECT name FROM RECORDTYPE")
        return [e.name for e in res if e.name is not None] + ["caosdb_auth_user"]

    def list_property_names(self):
        """Return the names of all properties and all record types."""
        res1 = self._query("SELECT name FROM PROPERTY")
        res2 = self._query("SELECT name FROM RECORDTYPE")
        return [e.name for e in res1 + res2 if e.name is not None]

    def create_record_type(self, name: str, properties: list, parents: list):
        """Insert a new record type together with any missing properties.

        Parameters
        ----------
        name : str
            Name of the new record type.
        properties : list of dict
            ``{"name": ..., "datatype": ...}`` entries. Properties that are
            not known yet are inserted in the same container.
        parents : list of dict
            ``{"name": ...}`` entries naming the parent record types.
        """
        container = self._caosdb.Container()
        record_type = self._caosdb.RecordType(name)
        container.append(record_type)
        for parent in parents:
            # Unknown parents are not handled here; the server is left to
            # reject the insert.
            record_type.add_parent(parent["name"])
        for prop in properties:
            prop_name = prop["name"]
            datatype = prop["datatype"]
            if prop_name not in self.cached_properties:
                container.append(
                    self._caosdb.Property(name=prop_name, datatype=datatype))
            record_type.add_property(name=prop_name, datatype=datatype)

        container.insert()

    def insert_auth(self, record_type: str, records: list):
        """Accept inserts into auth tables without storing anything (stub)."""
        pass

    def insert_many(self, record_type: str, records: list):
        """Insert one record of ``record_type`` per dict in ``records``.

        Each dict may carry ``name``, ``description`` and ``properties`` (a
        list of ``add_property`` keyword dicts). Properties whose name is
        neither a known property nor a record type are dropped with a
        warning.

        Returns
        -------
        InsertManyResult
        """
        container = self._caosdb.Container()
        # A set gives O(1) membership tests inside the loop below.
        property_names = set(self.list_property_names())
        for rec in records:
            new_rec = self._caosdb.Record(name=rec.get("name"),
                                          description=rec.get("description"))
            new_rec.add_parent(record_type)
            for prop in rec.get("properties", ()):
                if prop["name"] in property_names:
                    new_rec.add_property(**prop)
                else:
                    LOGGER.warning("%s has not been stored due to changes to the schema.", prop["name"])
            container.append(new_rec)
        container.insert(unique=False)
        return InsertManyResult(container)

    def update_many_auth(self, record_type: str, filter, update):
        """Pretend to update one auth record (stub)."""
        return UpdateManyResult([1])

    def update_many(self, record_type: str, filter, update):
        """Apply ``update["$set"]`` to every record matching ``filter``.

        Values that have an ``isoformat`` method (dates/datetimes) are stored
        as their ISO string. All matched entities are written back in one
        container update, which is skipped when nothing matched.
        """
        res = self._find(record_type, filter=filter)
        update_set = update["$set"]
        for entity in res:
            for prop_name, val in update_set.items():
                if hasattr(val, "isoformat"):
                    val = val.isoformat()
                # NOTE(review): presumably get_property returns None for a
                # property the entity lacks, which would raise here -- verify.
                entity.get_property(prop_name).value = val
        if res:
            res.update()
        return UpdateManyResult(res)

    def add_foreign_key(self, record_type, *args, **kwargs):
        """Turn a property of ``record_type`` into a reference.

        The property ``kwargs["property"][0][0]`` gets
        ``kwargs["target_record_type"]`` as its datatype.

        Raises
        ------
        NotImplementedError
            If the key does not target exactly one column named ``id``.
        """
        target_property = kwargs["target_property"]
        # Validate before touching the server. The length check comes first
        # so an empty target list raises NotImplementedError, not IndexError.
        if len(target_property) != 1 or target_property[0][0] != "id":
            raise NotImplementedError(
                f"FOREIGN KEY to non-id field: {target_property}")
        container = self._caosdb.Container()
        rt = self._caosdb.RecordType(record_type).retrieve()
        container.append(rt)
        rt.get_property(kwargs["property"][0][0]).datatype = kwargs["target_record_type"]
        container.update()

    def _process_subselect(self, record_type, subselect, count=False, **kwargs):
        """Evaluate a counting subselect.

        ``subselect`` is ``(callback, record_type, params)``. The callback's
        result rows are counted and returned as a ``CountResult`` whose
        column name is ``count``.

        Raises
        ------
        NotImplementedError
            For subselects with filters or without ``count``.
        """
        if "filter" in kwargs:
            raise NotImplementedError("subselect with filters")
        if count is not False:
            callback, sub_record_type, params = subselect[0], subselect[1], subselect[2]
            result = callback(sub_record_type, **params)
            return CountResult(len(result._results), count)
        raise NotImplementedError("subselect with anything but count")

    def aggregate(self, record_type, **kwargs):
        """Run an aggregate query, handling ``subselect`` specially.

        ``distinct`` is dropped because it is not needed with CaosDB. The
        remaining keyword arguments go to ``_aggregate``.
        """
        subselect = kwargs.pop("subselect", None)
        if subselect is not None:
            return self._process_subselect(record_type, subselect,
                                           **kwargs)
        kwargs.pop("distinct", None)
        return self._aggregate(record_type, **kwargs)

    def _aggregate(self, record_type, sort=None, projection=None, limit=None,
                   filter=None, count=False, skip=None):
        """Execute the query built by ``_generate_query``.

        Returns a ``CountResult`` when ``count`` is truthy, otherwise a
        ``FindResult``.
        """
        query, key = self._generate_query(record_type, sort, projection, filter,
                                          count)

        res = self._query(query, key=key)
        if count:
            return CountResult(res, count)
        rows = self._get_property_values(res, projection)
        return FindResult(rows, projection, sort, limit, skip,
                          record_types=self.cached_record_types,
                          connection=self)

    def _generate_query(self, record_type, sort, projection, filter, count):
        """Build ``(query, cache_key)`` for ``_aggregate``.

        The query uses ``COUNT`` when ``count`` is truthy. It uses
        ``SELECT <projection + sort keys> FROM`` when both ``sort`` and
        ``projection`` are given, and ``FIND`` otherwise. COUNT queries get a
        ``None`` key and are therefore not cached.
        """
        prefix = 'FIND'
        filter_clause = ""
        if filter is not None:
            filter_clause += " WITH" + self._get_filter_clause(filter)
        if count:
            prefix = "COUNT"
        elif sort is not None and projection is not None:
            prefix = 'SELECT ' + ", ".join(list(projection) +
                                           list(sort.keys())) + " FROM"
        query = f'{prefix} RECORD "{record_type}"{filter_clause}'
        key = query if not count else None
        return query, key

    def delete_many(self, record_type, filter):
        """Delete all records of ``record_type`` matching ``filter``.

        Returns
        -------
        DeleteManyResult
        """
        container = self._find(record_type, filter=filter)
        if len(container) > 0:
            container.delete()
        return DeleteManyResult(container)

class CaosDBClient:
    """Client facade that forwards every operation to a delegate object.

    ``find``, ``aggregate``, ``insert_many`` and ``update_many`` send record
    types whose name starts with ``caosdb_auth_`` to the delegate's dedicated
    auth handlers. All other calls go to the delegate's same-named methods.

    Parameters
    ----------
    enforce_schema : bool, optional
        Stored as ``self.enforce_schema``; not read by this class.
    delegate : callable, optional
        Called with ``**kwargs`` to build the backend object. Defaults to
        ``DefaultCaosDBClientDelegate``.
    **kwargs :
        Forwarded to ``delegate``.
    """

    def __init__(self, enforce_schema=True,
                 delegate=DefaultCaosDBClientDelegate, **kwargs):
        self.enforce_schema = enforce_schema
        self._delegate = delegate(**kwargs)

    @property
    def cached_record_types(self):
        """The delegate's cached set of record type names."""
        return self._delegate.cached_record_types

    @property
    def cached_properties(self):
        """The delegate's cached set of property names."""
        return self._delegate.cached_properties

    def list_record_type_names(self):
        """List all record type names.

        Returns
        -------
        record_types : list of str
        """
        LOGGER.debug("list_record_type_names")
        return self._delegate.list_record_type_names()

    def close(self):
        """Close the connection to the CaosDB server.

        Currently this only logs the call; nothing is released.
        """
        LOGGER.debug("close")

    def create_record_type(self, name: str, properties: list, parents: list):
        """Create a new record type.

        Parameters
        ----------
        name : str
            Name of the new record type.
        properties : list of dict
            ``{"name": str, "datatype": str}`` dicts. The default delegate
            also creates properties that do not exist yet.
        parents : list of dict
            ``{"name": str}`` dicts naming the parent record types.

        Returns
        -------
        None
        """
        LOGGER.debug("create_record_type(%s, %s, %s)", name, properties, parents)
        self._delegate.create_record_type(name, properties, parents)

    def create_index(self, *args, **kwargs):
        """Create an index (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO
        raise NotImplementedError("NOT IMPLEMENTED")

    def add_foreign_key(self, *args, **kwargs):
        """Add a foreign key by forwarding to the delegate.

        The default delegate expects ``record_type`` plus the keyword
        arguments ``property``, ``target_record_type`` and
        ``target_property``.

        Returns
        -------
        Whatever the delegate returns (``None`` for the default delegate).
        """
        LOGGER.debug("add_foreign_key(%s, %s)", args, kwargs)
        return self._delegate.add_foreign_key(*args, **kwargs)

    def update_one(self, *args, **kwargs):
        """Update a single record (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO
        raise NotImplementedError("NOT IMPLEMENTED")

    def update_many(self, record_type, *args, **kwargs):
        """Update all records of ``record_type`` that match a filter.

        Parameters
        ----------
        record_type : str
            Target record type. Names starting with ``caosdb_auth_`` go to
            the delegate's ``update_many_auth``.
        *args, **kwargs :
            Forwarded to the delegate. The default delegate expects
            ``filter`` and ``update`` (a dict with a ``"$set"`` mapping).

        Returns
        -------
        Whatever the delegate returns (``UpdateManyResult`` by default).
        """
        # ``record_type`` is named explicitly so it may also be passed by
        # keyword; it used to be read from ``args[0]``.
        LOGGER.debug("update_many(%s, %s, %s)", record_type, args, kwargs)
        if record_type.startswith("caosdb_auth_"):
            return self._delegate.update_many_auth(record_type, *args, **kwargs)
        return self._delegate.update_many(record_type, *args, **kwargs)

    def drop_record_type(self, *args, **kwargs):
        """Drop a record type (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO
        raise NotImplementedError("NOT IMPLEMENTED")

    def aggregate(self, record_type, *args, **kwargs):
        """Run an aggregate query for ``record_type`` through the delegate.

        Parameters
        ----------
        record_type : str
            Queried record type.
        *args, **kwargs :
            Forwarded to the delegate. The default delegate understands
            ``sort``, ``projection``, ``limit``, ``skip``, ``filter``,
            ``count``, ``distinct`` and ``subselect``.

        Returns
        -------
        FindResult or CountResult, as produced by the delegate.

        Raises
        ------
        NotImplementedError
            For ``caosdb_auth_*`` record types when the delegate has no
            ``aggregate_auth`` method.
        """
        LOGGER.debug("aggregate(%s, %s, %s)", record_type, args, kwargs)
        if record_type.startswith("caosdb_auth_"):
            # Delegates are pluggable and the default one has no
            # ``aggregate_auth``, so fail clearly instead of with a
            # NameError/AttributeError.
            aggregate_auth = getattr(self._delegate, "aggregate_auth", None)
            if aggregate_auth is None:
                raise NotImplementedError(
                    f"aggregate on auth record type {record_type!r}")
            return aggregate_auth(record_type, *args, **kwargs)
        return self._delegate.aggregate(record_type, *args, **kwargs)

    def insert_many(self, record_type: str, records: list):
        """Insert new records.

        Parameters
        ----------
        record_type : str
            Name of the record type.
        records : list of dict
            Description of each new record (name, description, properties).

        Returns
        -------
        result : InsertManyResult
            With the default delegate, ``caosdb_auth_*`` inserts return
            ``None``.
        """
        LOGGER.debug("insert_many(%s, %s)", record_type, records)
        if record_type.startswith("caosdb_auth_"):
            return self._delegate.insert_auth(record_type, records)
        return self._delegate.insert_many(record_type, records)

    def update(self, *args, **kwargs):
        """Generic update (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO
        raise NotImplementedError("NOT IMPLEMENTED")

    def rename(self, *args, **kwargs):
        """Rename a record type or property (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO
        raise NotImplementedError("NOT IMPLEMENTED")

    def delete_many(self, *args, **kwargs):
        """Delete all records matching a filter by forwarding to the delegate.

        The default delegate expects ``record_type`` and ``filter``.

        Returns
        -------
        Whatever the delegate returns (``DeleteManyResult`` by default).
        """
        LOGGER.debug("delete_many(%s, %s)", args, kwargs)
        return self._delegate.delete_many(*args, **kwargs)

    def drop_index(self, *args, **kwargs):
        """Drop an index (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO
        raise NotImplementedError("NOT IMPLEMENTED")

    def find(self, record_type, *args, **kwargs):
        """Find records of ``record_type``.

        Parameters
        ----------
        record_type : str
            Queried record type. Names starting with ``caosdb_auth_`` go to
            the delegate's ``find_auth``.
        *args, **kwargs :
            Forwarded to the delegate (e.g. ``limit``, ``sort``, ``skip``,
            ``filter`` and ``projection`` for the default delegate).

        Returns
        -------
        result : FindResult
        """
        LOGGER.debug("find(%s, %s, %s)", record_type, args, kwargs)
        if record_type.startswith("caosdb_auth_"):
            return self._delegate.find_auth(record_type, *args, **kwargs)
        return self._delegate.find(record_type, *args, **kwargs)