#!/usr/bin/env python3
# encoding: utf-8
#
# This file is a part of the CaosDB Project.
#
# Copyright (C) 2023 Indiscale GmbH <info@indiscale.com>
# Copyright (C) 2023 Timm Fitschen <t.fitschen@indiscale.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

import json
import logging
import sys
import urllib.parse

import linkahead as db
import pandas as pd
from caosadvancedtools.serverside import helper
from caosadvancedtools.table_export import BaseTableExporter
from caoscrawler.config import get_config_setting
from caoscrawler.logging import configure_server_side_logging
from linkahead.cached import cached_get_entity_by, cached_query as cquery
from linkahead.exceptions import EmptyUniqueQueryError

from bis_utils import (create_email_with_link_text,
                       get_description_row, get_email_from_username,
                       get_options_row, send_mail_with_defaults)
from export_container_csv import (generate_label_text,
                                  extract_storage_chain as container_storage_chain)
from sample_helpers.sample_registration_get_person_identifier import get_person_identifier_from_rec
from sample_helpers.sample_upload_column_definitions import (
    DATATYPE_DEFINITIONS, SPECIAL_TREATMENT_SAMPLE as
    SPECIAL_TREATMENT, use_custom_names)
from sample_helpers.utils import (CONSTANTS, get_column_header_name,
                                  get_entity_name)


# suppress warnings from the diff function
apilogger = logging.getLogger("linkahead.apiutils")
apilogger.setLevel(logging.ERROR)

logger = logging.getLogger("caosadvancedtools")

ERROR_PREFIX = CONSTANTS["error_prefix"]
ERROR_SUFFIX = CONSTANTS["error_suffix"]


def cached_record(i):

    return cached_get_entity_by(eid=i)


def cached_query(query, unique=False):

    if unique:
        return cached_get_entity_by(query=query)

    return cquery(query)


def reverse_semicolon_separated_list(value):
    if isinstance(value, list):
        ret_str = ";".join([str(val) for val in value if val is not None])
        if ',' in ret_str:
            ret_str = f"\"{ret_str}\""
        return ret_str
    else:
        return value
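

# Example (illustrative): reverse_semicolon_separated_list([1, "a", None])
# returns "1;a", while ["x,y", "z"] returns '"x,y;z"' so that the embedded
# comma cannot break the CSV row; non-list values are returned unchanged.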


def collection_value(vals):
    return reverse_semicolon_separated_list(vals)


def person_value(vals):
    return reverse_semicolon_separated_list(vals)


def retrieve_values(ids, property_name):
    properties = [cached_query(
        f"SELECT '{property_name}' FROM ENTITY WITH id = '{i}'",
        unique=True).get_property(property_name) for i in ids if i is not None]
    return [p.value for p in properties if p is not None]


def get_enum_value(values):
    values = values if isinstance(values, list) else [values]
    referenced = [cached_record(i) for i in values if i is not None]
    results = []
    for e in referenced:
        enum_prop = e.get_property("enumValue")
        if enum_prop is not None and enum_prop.value:
            results.append(enum_prop.value)
        elif e.name:
            results.append(e.name)
        else:
            results.append(e.id)
    return results
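

# get_enum_value falls back from the "enumValue" property to the record name
# and finally to the entity id. Example (illustrative): a referenced record
# with enumValue "fresh water" yields "fresh water"; one without it yields
# its name, or its id as a last resort.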


def default_find(r, e):
    p = r.get_property(e)
    if p is not None and p.value is not None and p.is_reference():
        return get_enum_value(p.value)
    v = p.value if p is not None else None
    return v


def extract_value_as_list(record, key):
    p = record.get_property(key)
    values = p.value if p is not None else []
    if not isinstance(values, list):
        if values is None:
            return []
        values = [values]
    return values
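

# Example: extract_value_as_list always returns a list -- a missing property
# yields [], a scalar value v yields [v], and list values pass through.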


def extract_storage_id(record, key):
    # The storage id is always read from the container property,
    # regardless of the column key.
    return extract_value_as_list(record, get_entity_name("container_rt"))


def extract_pdf_id(record, key):
    prop = record.get_property(key)
    return prop.value if prop is not None else None


def extract_person(record, key):
    ids = extract_value_as_list(record, key)
    person_recs = [cached_record(i) for i in ids]
    return [get_person_identifier_from_rec(r) for r in person_recs]


def extract_event_responsible(record, key):
    evt = retrieve_event(record)
    if len(evt) == 0:
        return None
    elif len(evt) > 1:
        logger.debug(f"Sample {record.id} references more than one event.")
        return None
    return extract_person(evt[0], get_entity_name("responsible_person_event"))


def extract_parent_sample(record, key):
    p = record.get_property("Parent sample")
    if p is not None:
        return p.value


def extract_reference_name(record, key):
    ids = extract_value_as_list(record, key)
    return [cached_query(f"SELECT 'name' FROM ENTITY WITH id = '{i}'", unique=True).name
            for i in ids if i is not None]


def retrieve_event(record):
    ids = extract_value_as_list(record, get_entity_name("event_rt"))
    if record.get_property(get_entity_name("event_rt")) is None:
        # there are cases where this property is named "Event"
        ids = extract_value_as_list(record, "Event")
    return [cached_record(i) for i in ids]


def retrieve_positions(source_ev):
    pos_ids = extract_value_as_list(source_ev, "Position")
    return [cached_record(i) for i in pos_ids]


def has_parent(r, par):
    pars = [p.name for p in r.get_parents()]
    return par in pars


def extract_position(record, position, component):
    source_evs = retrieve_event(record)
    result = []
    for ev in source_evs:
        _pos = retrieve_positions(ev)
        # Legacy events use a generic "Position" parent; by convention, the
        # first entry is the start position and the last one the stop position.
        old_pos = any(has_parent(pos, "Position") for pos in _pos)
        if old_pos:
            if position == "StartPosition":
                result.append(_pos[0])
            elif len(_pos) > 1:
                result.append(_pos[-1])
        else:
            result.extend([pos for pos in _pos if has_parent(pos, position)])
    return [pos.get_property(component).value for pos in result
            if pos.get_property(component) is not None]


def extract_ele_start(record, key):
    return extract_position(record, get_entity_name("StartPosition"), get_entity_name("elevation"))


def extract_ele_stop(record, key):
    return extract_position(record, get_entity_name("StopPosition"), get_entity_name("elevation"))


def extract_lat_start(record, key):
    return extract_position(record, get_entity_name("StartPosition"), get_entity_name("latitude"))


def extract_lat_stop(record, key):
    return extract_position(record, get_entity_name("StopPosition"), get_entity_name("latitude"))


def extract_lng_start(record, key):
    return extract_position(record, get_entity_name("StartPosition"), get_entity_name("longitude"))


def extract_lng_stop(record, key):
    return extract_position(record, get_entity_name("StopPosition"), get_entity_name("longitude"))


def extract_linkahead_url(record, key):
    # base_uri = db.get_config().get("Connection", "url")
    base_uri = get_config_setting("public_host_url")
    return urllib.parse.urljoin(base_uri, f"Entity/{record.id}")


def extract_doi(record, key):
    source_evs = retrieve_event(record)
    if len(source_evs) > 1:
        logger.error(
            f"Sample {record.id} references more than one event so no unique DOI can be exported.")
        return None
    elif len(source_evs) == 0:
        return None
    prop = source_evs[0].get_property(get_entity_name("igsn_doi_prop"))
    return prop.value if prop is not None else None


def _extract_event_prop(record, key, ref=False):

    evt = retrieve_event(record)
    if len(evt) == 0:
        return None
    elif len(evt) > 1:
        logger.debug(f"Sample {record.id} references more than one event.")
        return None

    if ref:
        return extract_reference_name(evt[0], key)

    return extract_value_as_list(evt[0], key)


def extract_biome(record, key):

    return _extract_event_prop(record, get_entity_name("Biome"), ref=True)


def extract_campaign(record, key):

    return _extract_event_prop(record, get_entity_name("Campaign"), ref=True)


def extract_device(record, key):

    return _extract_event_prop(record, get_entity_name("Device"), ref=True)


def extract_end_date(record, key):

    return _extract_event_prop(record, get_entity_name("end_date_prop"))


def extract_start_date(record, key):

    return _extract_event_prop(record, get_entity_name("start_date_prop"))


def extract_level(record, key):

    return _extract_event_prop(record, get_entity_name("level"))


def extract_sphere(record, key):

    return _extract_event_prop(record, get_entity_name("Sphere"), ref=True)


def extract_locality_descr(record, key):

    return _extract_event_prop(record, get_entity_name("locality_description_prop"))


def extract_locality_name(record, key):

    return _extract_event_prop(record, get_entity_name("locality_name_prop"))


def extract_storage_chain(record, key):

    cont_prop = record.get_property(get_entity_name("container_rt"))
    if cont_prop is not None and cont_prop.value:
        cont_id = cont_prop.value
        if isinstance(cont_id, list):
            if len(cont_id) > 1:
                logger.debug(f"Sample {record.id} has multiple containers.")
                return None
            if len(cont_id) == 0:
                return None
            cont_id = cont_id[0]
        container = cached_get_entity_by(eid=cont_id)
        container_chain = container_storage_chain(container, key)
        return f"{container_chain} → {generate_label_text(record)}"

    return None
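

# Example (illustrative): a sample stored in box "B1" inside freezer "F2"
# would be exported as "F2 → B1 → <label text of the sample>".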


def extract_event_url(record, key):

    events = retrieve_event(record)
    if not events:
        return None
    if len(events) == 1:
        return urllib.parse.urljoin(get_config_setting("public_host_url"), f"Entity/{events[0].id}")
    logger.debug(f"Sample {record.id} has multiple events.")
    return None


def extract_event_name(record, key):

    events = retrieve_event(record)
    if not events:
        return None
    if len(events) == 1:
        return events[0].name
    logger.debug(f"Sample {record.id} has multiple events.")
    return None


def extract_sample_name(record, key):

    return record.name


# must include all keys from SPECIAL_TREATMENT
EXTRACTORS = use_custom_names({
    "entity_id": lambda record, key: record.id,
    "Main User": extract_person,
    "Biome": extract_biome,
    "Campaign": extract_campaign,
    "Collector": extract_person,
    "Curator": extract_person,
    "Device": extract_device,
    "Elevation start": extract_ele_start,
    "Elevation stop": extract_ele_stop,
    "Embargo": default_find,
    "End date": extract_end_date,
    "event_name": extract_event_name,
    "Latitude start": extract_lat_start,
    "Latitude stop": extract_lat_stop,
    "Level": extract_level,
    "LinkAhead URL": extract_linkahead_url,
    "Longitude start": extract_lng_start,
    "Longitude stop": extract_lng_stop,
    "PDFReport": extract_pdf_id,
    "PI": extract_person,
    "sample_name": extract_sample_name,
    "Sampling method": default_find,
    "Sphere": extract_sphere,
    "Start date": extract_start_date,
    "Storage ID": extract_storage_id,
    "Storage chain": extract_storage_chain,
    "URL Event": extract_event_url,
    "igsn_doi_prop": extract_doi,
    "locality_description_prop": extract_locality_descr,
    "locality_name_prop": extract_locality_name,
    "parent_sample_prop": extract_parent_sample,
    "responsible_person_event": extract_event_responsible
})
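
# Every extractor has the signature ``extractor(record, key)`` and returns a
# scalar, a list (later joined with ";" for the CSV), or None for missing
# values.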

REVERSE_COLUMN_CONVERTER = use_custom_names({
})

# List of sample properties to be ignored because they are handled
# separately. Similar, but not identical, to SPECIAL_TREATMENT.
IGNORE_KEYS = use_custom_names([
    "parent_sample_prop",
    "container_rt",
    "event_rt",
])

# Additional list of keys to be ignored when extracting parent sample information
IGNORE_KEYS_PARENT = IGNORE_KEYS + use_custom_names([
    "entity_id",
])

# List of columns that are exported even though they are unknown to, or
# ignored by, the import.
ADDITIONAL_EXPORTS = use_custom_names([
    "LinkAhead URL",
    "parent_sample_prop",
    "Storage chain",
    "sample_name",
    "event_name"
])


def extract_value(r, e):
    e = _extract_key_from_parent_key(e)
    if e in EXTRACTORS:
        v = EXTRACTORS[e](r, e)
    else:
        v = default_find(r, e)
    if isinstance(v, str) and (',' in v or '\n' in v):
        # Quote text fields that contain commas or newlines
        v = f"\"{v}\""
    return v if v is not None else ""
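

# Example (assuming default column names): extract_value(record,
# "Collector_parent") strips the suffix, dispatches to
# EXTRACTORS["Collector"] (extract_person), and falls back to default_find
# for keys without a dedicated extractor.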


class TableExporter(BaseTableExporter):
    pass


def _extract_key_from_parent_key(parent_key, parent_suffix="_parent"):

    while parent_key.endswith(parent_suffix):
        parent_key = parent_key[:-len(parent_suffix)]

    return parent_key
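

# Example: _extract_key_from_parent_key("Collector_parent_parent") returns
# "Collector", i.e., the original column key of a grandparent column.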


def gather_parent_information(parent_id, export_dict, level=1, parent_suffix="_parent"):

    # Recursively collect parent-sample values and export them into columns
    # with one parent_suffix appended per level.
    parent_dict = {}
    for key, val in export_dict.items():
        if key.lower() not in [ign.lower() for ign in IGNORE_KEYS_PARENT]:
            parent_dict[key+parent_suffix*level] = val
    parent_rec = cached_get_entity_by(eid=parent_id)
    table_exporter = TableExporter(parent_dict, record=parent_rec)
    table_exporter.keys = list(parent_dict)
    table_exporter.collect_information()
    for e in table_exporter.export_dict:
        if _extract_key_from_parent_key(e, parent_suffix) in REVERSE_COLUMN_CONVERTER:
            table_exporter.info[e] = REVERSE_COLUMN_CONVERTER[_extract_key_from_parent_key(
                e, parent_suffix)](table_exporter.info[e])
        else:
            table_exporter.info[e] = reverse_semicolon_separated_list(table_exporter.info[e])

    parent_info = table_exporter.prepare_csv_export(print_header=False)
    parent_keys = list(parent_dict.keys())
    if parent_rec.get_property("Parent sample") is not None and parent_rec.get_property("Parent sample").value is not None:
        if isinstance(parent_rec.get_property("Parent sample").value, list):
            logger.warning(
                f"Sample {parent_rec.id} has multiple parent samples. Export not supported, skipping.")
        else:
            next_parent_info, next_parent_keys = gather_parent_information(
                parent_rec.get_property("Parent sample").value, export_dict, level=level+1)
            parent_info += next_parent_info
            parent_keys += next_parent_keys

    if len(parent_info) > 0:
        return ',' + parent_info, parent_keys

    return '', []
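

# Example: with the default suffix, a parent's "Collector" value is exported
# in column "Collector_parent", a grandparent's in "Collector_parent_parent",
# and so on (one suffix appended per level).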


def to_csv(samples):

    export_dict = {}
    for c in DATATYPE_DEFINITIONS:
        export_dict[c] = {}
    for c in ADDITIONAL_EXPORTS:
        export_dict[c] = {}

    lower_case_keys = [e.lower() for e in export_dict]

    for s in samples:
        # collect other properties
        for p in s.get_properties():
            if (p.name.lower() not in lower_case_keys
                    and p.name.lower() not in [ign.lower() for ign in IGNORE_KEYS]):
                export_dict[p.name] = {}
                lower_case_keys.append(p.name.lower())

    for c in export_dict:
        export_dict[c]["find_func"] = extract_value
        export_dict[c]["optional"] = True

    keys = list(export_dict)
    csv = []
    parent_csv_keys = []
    for s in samples:
        table_exporter = TableExporter(export_dict, record=s)
        table_exporter.all_keys = keys
        table_exporter.collect_information()
        logger.debug('<code>' + str(table_exporter.info) + '</code>')

        # Post-processing to values (e.g. list to string)
        for e in table_exporter.export_dict:
            if e in table_exporter.info:

                if e in REVERSE_COLUMN_CONVERTER:
                    table_exporter.info[e] = REVERSE_COLUMN_CONVERTER[e](table_exporter.info[e])
                else:
                    table_exporter.info[e] = reverse_semicolon_separated_list(
                        table_exporter.info[e])

        sample_info = table_exporter.prepare_csv_export(print_header=False)
        if s.get_property("Parent sample") is not None and s.get_property("Parent sample").value is not None:
            if isinstance(s.get_property("Parent sample").value, list):
                logger.warning(
                    f"Sample {s.id} has multiple parent samples. Export not supported, skipping.")
            else:
                parent_info, parent_keys = gather_parent_information(
                    s.get_property("Parent sample").value, export_dict, level=1)
                # Save the longest parent keys
                if len(parent_csv_keys) < len(parent_keys):
                    parent_csv_keys = parent_keys
                sample_info += parent_info
        csv.append(sample_info)

    # Extend header rows in case of parents
    csv_keys = keys + parent_csv_keys
    csv_descr = get_description_row([_extract_key_from_parent_key(k) for k in csv_keys])
    csv_options = get_options_row([_extract_key_from_parent_key(k) for k in csv_keys])
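    # Resulting layout: column-name row, description row, options row, then
    # one data row per sample.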

    return ",".join(csv_keys) + "\n" + ",".join(csv_descr) + '\n' + ",".join(csv_options) + '\n' + "\n".join(csv)


def retrieve_samples(data):
    container = []
    not_found = []
    for eid in data:
        if isinstance(eid, int):
            try:
                container.append(
                    cached_get_entity_by(query=f"FIND RECORD Sample WITH id='{eid}'"))
            except EmptyUniqueQueryError:
                # we want to warn about these
                not_found.append(eid)
        else:
            found_at_least_one_in_range = False
            for next_eid in eid:
                try:
                    container.append(
                        cached_get_entity_by(query=f"FIND RECORD Sample WITH id='{next_eid}'"))
                    found_at_least_one_in_range = True
                except EmptyUniqueQueryError:
                    pass
            if not found_at_least_one_in_range:
                not_found.append(f"{eid.start}-{eid.stop-1}")
    return container, not_found
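

# Example (illustrative): retrieve_samples([3, range(5, 8)]) retrieves the
# Sample records with ids 3, 5, 6 and 7; ids or whole ranges without any
# matching Sample record are returned in the second return value.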


def sanity_check():
    for key in SPECIAL_TREATMENT:
        if key not in EXTRACTORS:
            raise Exception(f"No extraction method defined for key '{key}'.")


def write_csv(file_name, csv, no_empty_columns):
    """Write the csv data in ``csv`` to with given ``file_name`` to the shared
    resource. Drop empy columns before writing if ``no_empty_columns`` is
    ``True``.

    """
    display_path, internal_path = helper.get_shared_filename(file_name)
    with open(internal_path, "w") as csv_file:
        csv_file.write(csv)
    if no_empty_columns:
        # Pandas seems to have problems with commas and quotation marks in the
        # description rows when loading the csv without ignoring comment
        # lines. So we need to restore the descriptions manually further down
        # the line.
        tmp = pd.read_csv(internal_path, comment='#', dtype=str)
        # drop all empty columns
        tmp.dropna(axis=1, inplace=True, how="all")
        # generate new description row and insert as the first "data row"
        new_descriptions = get_description_row(
            [_extract_key_from_parent_key(cname) for cname in tmp.columns])
        description_row_dict = {cname: descr for (
            cname, descr) in zip(tmp.columns, new_descriptions)}
        tmp.loc[-1] = description_row_dict
        tmp.index += 1
        tmp.sort_index(inplace=True)
        tmp.to_csv(internal_path, index=False)

    return display_path


def main():
    sanity_check()
    parser = helper.get_argument_parser()
    args = parser.parse_args()
    # Check whether the script is executed locally or as a server-side
    # script (SSS), depending on the auth_token argument.
    if hasattr(args, "auth_token") and args.auth_token:
        db.configure_connection(auth_token=args.auth_token)
        debug_file = configure_server_side_logging()
    else:
        rootlogger = logging.getLogger()
        rootlogger.setLevel(logging.INFO)
        logger.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        rootlogger.addHandler(handler)
        debug_file = "/tmp/upload_sample_debug.log"

    if hasattr(args, "filename") and args.filename:
        # Read the input from the form (form.json)
        with open(args.filename) as form_json:
            form_data = json.load(form_json)

            no_empty_columns = False
            # NB: "noEmpyColumns" (sic) is the field name sent by the form.
            if form_data.get("noEmpyColumns") == "on":
                logger.info("Removing empty columns from export")
                no_empty_columns = True

            if "query_string" in form_data and form_data["query_string"]:
                query_string = form_data["query_string"]
                if not query_string.lower().startswith("find ") and not query_string.lower().startswith("select "):
                    logger.error(
                        f"The query '{query_string}' doesn't seem to be a valid select or find query.")
                    return
                if query_string.lower().startswith("find "):
                    # transform to select query for performance
                    query_string = "SELECT id FROM" + query_string[4:]
                try:
                    data = [el.id for el in db.execute_query(query_string)]
                except db.TransactionError as te:
                    logger.error(
                        f"There is a problem with the given query '{query_string}':\n"
                        f"```\n{str(te)}\n```"
                    )
                    return
            else:
                if not form_data["ids"]:
                    logger.error(
                        "Please specify the samples to be exported either by query or by id(s).")
                    return
                tmp = form_data["ids"].split(",")
                data = []
                for d in tmp:
                    if "-" in d:
                        bound = [int(b) for b in d.split("-")]
                        data.append(range(min(bound), max(bound) + 1))
                    else:
                        data.append(int(d.strip()))
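                # Example: an "ids" field of "3, 5-7" yields
                # data == [3, range(5, 8)].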

        samples, not_found = retrieve_samples(data)

        if len(samples) == 0:
            logger.error("No samples in the given range.")
            return

        for s in samples:
            logger.debug("Found sample " + str(s.id))
        for s in not_found:
            logger.warning("No samples found: " + str(s))

        csv = to_csv(samples)

        max_id = max([s.id for s in samples])
        min_id = min([s.id for s in samples])
        file_name = f"samples_export_(IDs_{min_id}_to_{max_id}).csv"
        display_path = write_csv(file_name, csv, no_empty_columns)
        logger.info("Your CSV-Export has been prepared successfully.\n" +
                    f"Download the file <a href=/Shared/{display_path}>here</a>.")
        try:
            send_mail_with_defaults(
                to=get_email_from_username(),
                subject=f"BIS sample export {file_name}",
                body=create_email_with_link_text("sample export", display_path)
            )
        except KeyError as ke:
            logger.error(
                "There is a problem with the server's email configuration:\n\n"
                f"{ke}\n\nPlease contact your admin."
            )
    else:
        msg = "{}export_sample_csv.py was called without the JSON file in args.{}".format(
            ERROR_PREFIX, ERROR_SUFFIX)
        logger.error(msg)


if __name__ == "__main__":
    main()