Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
export_sample_csv.py 25.05 KiB
#!/usr/bin/env python3
# encoding: utf-8
#
# This file is a part of the CaosDB Project.
#
# Copyright (C) 2023 Indiscale GmbH <info@indiscale.com>
# Copyright (C) 2023 Timm Fitschen <t.fitschen@indiscale.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

import json
import logging
import os
import sys
import urllib
import urllib.parse
from datetime import date, datetime
from typing import List

import linkahead as db
import pandas as pd
from caosadvancedtools.datainconsistency import DataInconsistencyError
from caosadvancedtools.serverside import helper
from caosadvancedtools.table_export import BaseTableExporter
from caosadvancedtools.table_importer import CSVImporter
from caoscrawler import Crawler, SecurityMode
from caoscrawler.crawl import ForbiddenTransaction
from caoscrawler.identifiable_adapters import CaosDBIdentifiableAdapter
from caoscrawler.logging import configure_server_side_logging
from dateutil import parser as dateparser
from dateutil.parser import isoparse
from linkahead.cached import cached_get_entity_by, cached_query as cquery
from linkahead.common.datatype import get_id_of_datatype
from linkahead.exceptions import (EmptyUniqueQueryError, QueryNotUniqueError,
                                  TransactionError)

from bis_utils import (create_email_with_link_text,
                       get_description_row, get_email_from_username,
                       get_options_row, send_mail_with_defaults,
                       SPECIAL_TREATMENT_SAMPLE as SPECIAL_TREATMENT)
from export_container_csv import (generate_label_text,
                                  extract_storage_chain as container_storage_chain)
from upload_sample_template import DATATYPE_DEFINITIONS

# Only let errors from linkahead.apiutils through (its diff function would
# otherwise emit warnings).
apilogger = logging.getLogger("linkahead.apiutils")
apilogger.setLevel(logging.ERROR)

logger = logging.getLogger("caosadvancedtools")

# Prefix and suffix wrapped around error messages that are shown to the user.
ERROR_PREFIX = 'Something went wrong: '
ERROR_SUFFIX = ' Please contact <a href="mailto:biosamples@geomar.de">biosamples@geomar.de</a> if you encounter this issue.'


def cached_record(i):
    """Return the entity with id ``i``, served from the LinkAhead cache."""
    entity = cached_get_entity_by(eid=i)
    return entity


def cached_query(query, unique=False):
    """Execute ``query`` using the cached LinkAhead helpers.

    With ``unique=True`` the single matching entity is returned (via
    ``cached_get_entity_by``); otherwise the result of the cached query is
    returned.
    """
    return cached_get_entity_by(query=query) if unique else cquery(query)


def reverse_semicolon_separated_list(value):
    """Turn a list into a ``;``-separated string; pass other values through.

    ``None`` entries of a list are dropped and the remaining entries are
    converted with ``str`` before joining.
    """
    if not isinstance(value, list):
        return value
    parts = (str(item) for item in value if item is not None)
    return ";".join(parts)


def collection_value(vals):
    """Format the value(s) of the ``Collection`` column as CSV cell text."""
    joined = reverse_semicolon_separated_list(vals)
    return joined


def person_value(vals):
    """Format the value(s) of a person column (``PI``, ``Person``) as CSV cell text."""
    joined = reverse_semicolon_separated_list(vals)
    return joined


def retrieve_values(ids, property_name):
    """Return the value of ``property_name`` for each entity in ``ids``.

    ``None`` ids are skipped, and so are entities that lack the property.
    The returned values may themselves be ``None``.
    """
    values = []
    for entity_id in ids:
        if entity_id is None:
            continue
        entity = cached_query(
            f"SELECT '{property_name}' FROM ENTITY WITH id = '{entity_id}'",
            unique=True)
        prop = entity.get_property(property_name)
        if prop is not None:
            values.append(prop.value)
    return values


def get_enum_value(values):
    """Resolve referenced entities to a human-readable representation.

    ``values`` is an id or a list of ids (``None`` entries are skipped).
    For each referenced entity, the first available of these is returned:
    a non-empty ``enumValue`` property value, a non-empty name, or the id.
    """
    if not isinstance(values, list):
        values = [values]
    entities = [cached_record(ref) for ref in values if ref is not None]

    results = []
    for entity in entities:
        enum_prop = entity.get_property("enumValue")
        enum_value = enum_prop.value if enum_prop is not None else None
        if enum_value is not None and len(enum_value) > 0:
            results.append(enum_value)
        elif entity.name is not None and len(entity.name) > 0:
            results.append(entity.name)
        else:
            results.append(entity.id)
    return results


def default_find(r, e):
    """Return the value of property ``e`` of record ``r``.

    Reference values are resolved through ``get_enum_value``. A missing
    property yields ``None``.
    """
    prop = r.get_property(e)
    if prop is None:
        return None
    if prop.value is not None and prop.is_reference():
        return get_enum_value(prop.value)
    return prop.value


def extract_value_as_list(record, key):
    """Return the value of property ``key`` of ``record`` as a list.

    A missing property or a ``None`` value gives ``[]``. A scalar value is
    wrapped in a one-element list, and a list value is returned as-is.
    """
    prop = record.get_property(key)
    if prop is None:
        return []
    value = prop.value
    if isinstance(value, list):
        return value
    return [] if value is None else [value]


def extract_storage_id(record, key):
    """Return the ids stored in the ``Container`` property (``key`` is unused)."""
    container_key = "Container"
    return extract_value_as_list(record, container_key)


def extract_pdf_id(record, key):
    """Return the raw value of property ``key``, or ``None`` if it is absent."""
    prop = record.get_property(key)
    if prop is None:
        return None
    return prop.value


def extract_storage_container_label(record, key):
    """Return the ``BIS Label`` of each container referenced by ``Container``.

    ``key`` is unused.
    """
    container_ids = extract_value_as_list(record, "Container")
    labels = retrieve_values(container_ids, 'BIS Label')
    return labels


def extract_nagoya_case_number(record, key):
    """Return property ``key`` of each entity referenced by ``NagoyaCase``."""
    case_ids = extract_value_as_list(record, "NagoyaCase")
    case_numbers = retrieve_values(case_ids, key)
    return case_numbers


def extract_person(record, key):
    """Return the ``Abbreviation`` of each person referenced by property ``key``."""
    person_ids = extract_value_as_list(record, key)
    abbreviations = retrieve_values(person_ids, 'Abbreviation')
    return abbreviations


def extract_parent_sample(record, key):
    """Return the value of ``Parent sample``, or ``None`` (``key`` is unused)."""
    parent = record.get_property("Parent sample")
    return None if parent is None else parent.value


def extract_reference_name(record, key):
    """Return the names of the entities referenced by property ``key``."""
    names = []
    for ref_id in extract_value_as_list(record, key):
        if ref_id is None:
            continue
        entity = cached_query(f"SELECT 'name' FROM ENTITY WITH id = '{ref_id}'", unique=True)
        names.append(entity.name)
    return names


def retrieve_source_event(record):
    """Return the event record(s) referenced by ``record``.

    The reference is normally stored in ``SourceEvent``. When that property
    is absent, ``Event`` is used instead, since some records name it that
    way.

    Returns
    -------
    list
        The referenced event records (possibly empty).
    """
    key = "SourceEvent" if record.get_property("SourceEvent") is not None else "Event"
    event_ids = extract_value_as_list(record, key)
    # Skip empty list entries, as the other id-based lookups in this module do;
    # looking up id None would only fail.
    return [cached_record(i) for i in event_ids if i is not None]


def retrieve_gear(record):
    """Return the gear entities referenced by the source event(s) of ``record``.

    Only the parents and the ``Configuration`` property of each gear are
    selected. Events without a ``Gear`` property, or with an empty one, are
    skipped.
    """
    gear_ids = []
    for event in retrieve_source_event(record):
        gear = event.get_property("Gear")
        # An empty Gear property would otherwise produce a query for id 'None',
        # which fails and blanks the gear cell for all events of the sample.
        if gear is not None and gear.value is not None:
            gear_ids.append(gear.value)
    return [cached_query(f"SELECT 'parent', 'Configuration' FROM ENTITY WITH id = '{i}'", unique=True)
            for i in gear_ids]


def extract_gear(record, key):
    """Return the name of the first parent of each gear used for ``record``."""
    gear_types = []
    for gear in retrieve_gear(record):
        first_parent = gear.get_parents()[0]
        gear_types.append(first_parent.name)
    return gear_types


def extract_gear_configuration(record, key):
    """Return the ``Configuration`` value of each gear that has one."""
    configurations = []
    for gear in retrieve_gear(record):
        config = gear.get_property("Configuration")
        if config is not None:
            configurations.append(config.value)
    return configurations


def extract_date_time(record, p):
    """Return the value(s) of date/time property ``p`` as a list.

    ``Time start`` and ``Time stop`` (case-insensitive) are read from the
    source event(s), skipping missing or empty values. Any other property is
    read from ``record`` itself.
    """
    if p.lower() not in ("time start", "time stop"):
        return extract_value_as_list(record, p)

    # these are attached to the source event directly
    times = []
    for event in retrieve_source_event(record):
        prop = event.get_property(p)
        if prop is not None and prop.value is not None:
            times.append(prop.value)
    return times


def extract_station_number(record, key):
    """Return property ``key`` of each source event that has it."""
    numbers = []
    for event in retrieve_source_event(record):
        prop = event.get_property(key)
        if prop is not None:
            numbers.append(prop.value)
    return numbers


def extract_station_id(record, key):
    """Return property ``key`` of each source event that has it."""
    station_ids = []
    for event in retrieve_source_event(record):
        prop = event.get_property(key)
        if prop is not None:
            station_ids.append(prop.value)
    return station_ids


def retrieve_positions(source_ev):
    """Return the records referenced by the ``Position`` property of ``source_ev``."""
    position_ids = extract_value_as_list(source_ev, "Position")
    return list(map(cached_record, position_ids))


def has_parent(r, par):
    """Return ``True`` if any parent of ``r`` is named ``par``."""
    return any(parent.name == par for parent in r.get_parents())


def extract_position(record, position, component):
    """Return the ``component`` values of the ``position`` records of ``record``.

    Parameters
    ----------
    record : the sample record whose source event(s) are inspected.
    position : parent name of the wanted position records, e.g.
        ``"StartPosition"`` or ``"StopPosition"``.
    component : property to read from each position, e.g. ``"Latitude"``.

    If any position of an event has the parent ``Position``, all of that
    event's positions are treated positionally instead: the first one is the
    start position, and the last one is the stop position (only when there is
    more than one position).

    Returns
    -------
    list
        Values of ``component`` for the positions found. Positions lacking
        that property are skipped.
    """
    selected = []
    for event in retrieve_source_event(record):
        # Retrieve once and reuse in both branches below.
        event_positions = retrieve_positions(event)

        generic_positions = any(has_parent(pos, "Position") for pos in event_positions)
        if generic_positions:
            if position == "StartPosition":
                selected.append(event_positions[0])
            elif len(event_positions) > 1:
                selected.append(event_positions[-1])
        else:
            selected.extend(pos for pos in event_positions if has_parent(pos, position))

    values = []
    for pos in selected:
        prop = pos.get_property(component)
        if prop is not None:
            values.append(prop.value)
    return values


def extract_lat_start(record, key):
    """Latitude(s) of the start position of the sample's source event(s)."""
    return extract_position(record, position="StartPosition", component="Latitude")


def extract_lat_stop(record, key):
    """Latitude(s) of the stop position of the sample's source event(s)."""
    return extract_position(record, position="StopPosition", component="Latitude")


def extract_lng_start(record, key):
    """Longitude(s) of the start position of the sample's source event(s)."""
    return extract_position(record, position="StartPosition", component="Longitude")


def extract_lng_stop(record, key):
    """Longitude(s) of the stop position of the sample's source event(s)."""
    return extract_position(record, position="StopPosition", component="Longitude")


def extract_sampling_depth_start(record, key):
    """Sampling depth(s) at the start position of the sample's source event(s)."""
    return extract_position(record, position="StartPosition", component="Sampling depth")


def extract_sampling_depth_stop(record, key):
    """Sampling depth(s) at the stop position of the sample's source event(s)."""
    return extract_position(record, position="StopPosition", component="Sampling depth")


def extract_water_depth_start(record, key):
    """Water depth(s) at the start position of the sample's source event(s)."""
    return extract_position(record, position="StartPosition", component="Water depth")


def extract_water_depth_stop(record, key):
    """Water depth(s) at the stop position of the sample's source event(s)."""
    return extract_position(record, position="StopPosition", component="Water depth")


def extract_source_event_name(record, key):
    """Return the name of each source event of ``record``."""
    names = []
    for event in retrieve_source_event(record):
        names.append(event.name)
    return names


def extract_hol(record, key):
    """Return property ``key`` (``Hol``) of each source event that has it."""
    hols = []
    for event in retrieve_source_event(record):
        prop = event.get_property(key)
        if prop is not None:
            hols.append(prop.value)
    return hols


def extract_bis_url(record, key):
    """Return the BIS web URL of ``record``.

    NOTE: the public BIS address is hard-coded rather than read from the
    connection configuration.
    """
    bis_base = "https://biosamples.geomar.de/"
    entity_path = f"Entity/{record.id}"
    return urllib.parse.urljoin(bis_base, entity_path)


def extract_igsn(record, key):
    """Return property ``key`` (the IGSN) of the sample's single source event.

    Returns ``None`` when there is no source event, when the event lacks the
    property, or when there are several events (an error is logged in that
    case).
    """
    events = retrieve_source_event(record)
    if not events:
        return None
    if len(events) > 1:
        logger.error(
            f"Sample {record.id} references more than one SourceEvent so no unique IGSN can be exported.")
        return None
    igsn = events[0].get_property(key)
    return None if igsn is None else igsn.value


def extract_doi(record, key):
    """Return the ``DOI`` property of the sample's single source event.

    ``key`` is unused. Returns ``None`` when there is no source event, when
    the event has no DOI, or when there are several events (an error is
    logged in that case).
    """
    events = retrieve_source_event(record)
    if not events:
        return None
    if len(events) > 1:
        logger.error(
            f"Sample {record.id} references more than one SourceEvent so no unique DOI can be exported.")
        return None
    doi = events[0].get_property("DOI")
    return None if doi is None else doi.value


def extract_storage_chain(record, key):
    """Return the storage chain text of the sample's container plus the sample label.

    Returns ``None`` when the sample has no (or an empty) ``Container``
    value, or when it references more than one container.
    """
    container_prop = record.get_property("Container")
    if container_prop is None or not container_prop.value:
        return None

    cont_id = container_prop.value
    if isinstance(cont_id, list):
        # An empty list was already rejected by the truthiness check above.
        if len(cont_id) > 1:
            logger.debug(f"Sample {record.id} has multiple containers.")
            return None
        cont_id = cont_id[0]

    container = cached_get_entity_by(eid=cont_id)
    return f"{container_storage_chain(container, key)}{generate_label_text(record)}"


def extract_event_url(record, key):
    """Return the BIS URL of the sample's source event.

    Returns ``None`` if there is no event, or if there are several (a debug
    message is logged in that case).
    """
    events = retrieve_source_event(record)
    if len(events) == 1:
        return urllib.parse.urljoin("https://biosamples.geomar.de", f"Entity/{events[0].id}")
    if events:
        logger.debug(f"Sample {record.id} has multiple events.")
    return None


def extract_date(record, key):
    """Return the value(s) of date property ``key`` reduced to ISO dates.

    ``datetime`` values become ``YYYY-MM-DD`` and ``date`` values are
    ISO-formatted. Strings longer than a full date (i.e. carrying a time
    part) are parsed as ISO 8601 and reduced to their date. Shorter strings
    (e.g. ``"2023"`` or ``"2023-05-01"``) and unparsable or non-date values
    are passed through unchanged, so no information is invented or lost.

    Returns
    -------
    list
        One entry per value of the property (empty if the property is
        missing or has no value).
    """
    dates = []
    for val in extract_value_as_list(record, key):
        if isinstance(val, datetime):
            dates.append(val.date().isoformat())
        elif isinstance(val, date):
            dates.append(val.isoformat())
        elif isinstance(val, str) and len(val) > len("YYYY-MM-DD"):
            try:
                dates.append(isoparse(val).date().isoformat())
            except ValueError:
                dates.append(val)
        else:
            dates.append(val)
    return dates


# Maps export column names to extraction functions ``f(record, key)``.
# Must include all keys from SPECIAL_TREATMENT (enforced by sanity_check).
EXTRACTORS = {
    "BIS ID": lambda record, key: record.id,
    "Parent BIS ID": extract_parent_sample,
    "AphiaID": default_find,
    "Collection": extract_reference_name,
    "Date collected start": extract_date,
    "Date collected stop": extract_date,
    "Date sampled start": extract_date,
    "Date sampled stop": extract_date,
    "Main User": extract_person,
    "Sampling Person": extract_person,
    "PI": extract_person,
    "Person": extract_person,
    "Gear": extract_gear,
    "Gear configuration": extract_gear_configuration,
    "Latitude start": extract_lat_start,
    "Longitude start": extract_lng_start,
    "Storage ID": extract_storage_id,
    "Nagoya case number": extract_nagoya_case_number,
    "PDFReport": extract_pdf_id,
    "Subevent": extract_source_event_name,
    "Station ID": extract_station_id,
    "Station number": extract_station_number,
    "Sampling depth start": extract_sampling_depth_start,
    "Sampling depth stop": extract_sampling_depth_stop,
    "Water depth start": extract_water_depth_start,
    "Water depth stop": extract_water_depth_stop,
    "Latitude stop": extract_lat_stop,
    "Longitude stop": extract_lng_stop,
    "Storage chain": extract_storage_chain,
    "Storage Container Label": extract_storage_container_label,
    "Hol": extract_hol,
    "Sampling method": default_find,
    # "Publications": TODO never used
    # "NCBI BioProject": TODO never used
    # "NCBI BioSample": TODO never used
    # "NCBI Accession": TODO never used
    "BIS URL": extract_bis_url,
    "IGSN": extract_igsn,
    "IGSN URL": extract_doi,
    "Sphere": default_find,
    "URL SourceEvent": extract_event_url,
}

# Column-specific converters that turn collected values into CSV cell text.
# Columns not listed here are converted with reverse_semicolon_separated_list.
REVERSE_COLUMN_CONVERTER = dict(
    Collection=collection_value,
    PI=person_value,
    Person=person_value,
)

# Sample properties that are not exported as columns of their own because
# they are handled elsewhere. Similar, but not identical, to
# SPECIAL_TREATMENT.
IGNORE_KEYS = ["Parent Sample", "Container", "Event"]

# Additionally ignored when exporting the columns of parent samples.
IGNORE_KEYS_PARENT = [*IGNORE_KEYS, "BIS ID"]

# Columns that are exported even though the import does not know them or
# ignores them.
ADDITIONAL_EXPORTS = ["BIS URL", "Parent BIS ID", "Storage chain"]


def extract_value(r, e):
    """Find function used by the table exporter for every export column.

    ``e`` may carry ``_parent`` suffixes (parent-sample columns); these are
    stripped before looking up the extractor in EXTRACTORS, and
    ``default_find`` is used when there is none. List results are joined
    with ``;``. The resulting text is CSV-quoted if it contains a comma,
    double quote, or line break.

    Returns
    -------
    The cell value, or ``""`` if nothing was found.
    """
    e = _extract_key_from_parent_key(e)
    if e in EXTRACTORS:
        v = EXTRACTORS[e](r, e)
    else:
        v = default_find(r, e)
    if isinstance(v, list):
        # Join here rather than only in the later post-processing so that the
        # quoting below also covers list values; joining a string again later
        # is a no-op.
        v = reverse_semicolon_separated_list(v)
    if isinstance(v, str) and any(c in v for c in (',', '"', '\n', '\r')):
        # Quote text fields containing delimiters, quotes, or line breaks, and
        # escape embedded quotes by doubling them (RFC 4180).
        v = '"' + v.replace('"', '""') + '"'
    return v if v is not None else ""


class TableExporter(BaseTableExporter):
    """Exporter for sample rows; currently adds nothing to BaseTableExporter."""


def _extract_key_from_parent_key(parent_key, parent_suffix="_parent"):

    while parent_key.endswith(parent_suffix):
        parent_key = parent_key[:-len(parent_suffix)]

    return parent_key


def gather_parent_information(parent_id, export_dict, level=1, parent_suffix="_parent"):
    """Export the values of a parent sample and, recursively, of its ancestors.

    Parameters
    ----------
    parent_id : id of the parent sample record.
    export_dict : column definitions of the sample export. Columns in
        IGNORE_KEYS_PARENT are skipped.
    level : nesting depth. Column names get ``parent_suffix`` appended
        ``level`` times.
    parent_suffix : suffix marking parent columns.

    Recursion stops at a sample without a parent, or at one with several
    parents (a warning is logged in that case).

    Returns
    -------
    tuple(str, list)
        The CSV fragment (starting with a comma, to be appended to the
        child's row) and the corresponding column names, or ``('', [])``
        if nothing was exported.
    """
    ignored = {ign.lower() for ign in IGNORE_KEYS_PARENT}
    parent_dict = {}
    for key, val in export_dict.items():
        if key.lower() not in ignored:
            parent_dict[key + parent_suffix * level] = val
    parent_rec = cached_get_entity_by(eid=parent_id)
    table_exporter = TableExporter(parent_dict, record=parent_rec)
    table_exporter.keys = list(parent_dict)
    table_exporter.collect_information()
    for e in table_exporter.export_dict:
        # Columns whose find function failed are absent from ``info``; skip
        # them as to_csv does instead of raising a KeyError.
        if e not in table_exporter.info:
            continue
        converter = REVERSE_COLUMN_CONVERTER.get(
            _extract_key_from_parent_key(e, parent_suffix), reverse_semicolon_separated_list)
        table_exporter.info[e] = converter(table_exporter.info[e])

    parent_info = table_exporter.prepare_csv_export(print_header=False)
    parent_keys = list(parent_dict.keys())
    grandparent = parent_rec.get_property("Parent sample")
    if grandparent is not None and grandparent.value is not None:
        if isinstance(grandparent.value, list):
            logger.warning(
                f"Sample {parent_rec.id} has multiple parent samples. Export not supported, skipping.")
        else:
            # Pass the suffix on so that all levels use the same naming.
            next_parent_info, next_parent_keys = gather_parent_information(
                grandparent.value, export_dict, level=level + 1, parent_suffix=parent_suffix)
            parent_info += next_parent_info
            parent_keys += next_parent_keys

    if len(parent_info) > 0:
        return ',' + parent_info, parent_keys

    return '', []


def to_csv(samples):
    """Build the CSV export text for ``samples``.

    The columns are the keys of DATATYPE_DEFINITIONS and ADDITIONAL_EXPORTS,
    followed by any further property of any sample. Names are compared
    case-insensitively, and properties in IGNORE_KEYS are excluded. For
    samples with a single parent sample, the parent's (and its ancestors')
    values are appended as additional ``_parent``-suffixed columns.

    Returns
    -------
    str
        Header row, description row, options row, and one row per sample,
        separated by newlines.
    """
    export_dict = {}
    for column in DATATYPE_DEFINITIONS:
        export_dict[column] = {}
    for column in ADDITIONAL_EXPORTS:
        export_dict[column] = {}

    # Lower-cased names, built once, for case-insensitive membership tests.
    known_keys = {key.lower() for key in export_dict}
    ignored_keys = {ign.lower() for ign in IGNORE_KEYS}

    for s in samples:
        # collect other properties
        for p in s.get_properties():
            name = p.name.lower()
            if name not in known_keys and name not in ignored_keys:
                export_dict[p.name] = {}
                known_keys.add(name)

    for column_def in export_dict.values():
        column_def["find_func"] = extract_value
        column_def["optional"] = True

    keys = list(export_dict)
    csv = []
    parent_csv_keys = []
    for s in samples:
        table_exporter = TableExporter(export_dict, record=s)
        table_exporter.all_keys = keys
        table_exporter.collect_information()
        logger.debug('<code>' + str(table_exporter.info) + '</code>')

        # Post-processing to values (e.g. list to string)
        for e in table_exporter.export_dict:
            if e in table_exporter.info:
                converter = REVERSE_COLUMN_CONVERTER.get(e, reverse_semicolon_separated_list)
                table_exporter.info[e] = converter(table_exporter.info[e])

        sample_info = table_exporter.prepare_csv_export(print_header=False)
        parent = s.get_property("Parent sample")
        if parent is not None and parent.value is not None:
            if isinstance(parent.value, list):
                logger.warning(
                    f"Sample {s.id} has multiple parent samples. Export not supported, skipping.")
            else:
                parent_info, parent_keys = gather_parent_information(
                    parent.value, export_dict, level=1)
                # Keep the longest list of parent columns for the header
                if len(parent_csv_keys) < len(parent_keys):
                    parent_csv_keys = parent_keys
                sample_info += parent_info
        csv.append(sample_info)

    # Extend header rows in case of parents
    csv_keys = keys + parent_csv_keys
    csv_descr = get_description_row([_extract_key_from_parent_key(k) for k in csv_keys])
    csv_options = get_options_row([_extract_key_from_parent_key(k) for k in csv_keys])

    return ",".join(csv_keys) + "\n" + ",".join(csv_descr) + '\n' + ",".join(csv_options) + '\n' + "\n".join(csv)


def retrieve_samples(data):
    """Retrieve the sample records for the ids and id ranges in ``data``.

    Each entry of ``data`` is an ``int`` id or an iterable of ids (a
    ``range``). A single id without a sample is reported as not found. A
    range is reported (as ``"start-end"``) only if none of its ids is a
    sample.

    Returns
    -------
    tuple(list, list)
        The found sample records and the not-found ids/ranges.
    """
    samples = []
    missing = []
    for entry in data:
        if isinstance(entry, int):
            try:
                samples.append(
                    cached_get_entity_by(query=f"FIND RECORD SAMPLE WITH id='{entry}'"))
            except EmptyUniqueQueryError:
                # we want to warn about these
                missing.append(entry)
            continue

        hits = 0
        for bis_id in entry:
            try:
                samples.append(
                    cached_get_entity_by(query=f"FIND RECORD Sample WITH id='{bis_id}'"))
            except EmptyUniqueQueryError:
                continue
            hits += 1
        if hits == 0:
            missing.append(f"{entry.start}-{entry.stop-1}")
    return samples, missing


def sanity_check():
    """Ensure that every key in SPECIAL_TREATMENT has an entry in EXTRACTORS.

    Raises
    ------
    Exception
        Naming the first key without an extraction method.
    """
    for key in SPECIAL_TREATMENT:
        if key not in EXTRACTORS:
            raise Exception(f"No extraction method defined for key '{key}'.")


def write_csv(file_name, csv, no_empty_columns):
    """Write the CSV text ``csv`` to the shared resource ``file_name``.

    If ``no_empty_columns`` is ``True``, columns without any value are
    dropped and the file is rewritten.

    Returns
    -------
    The display path of the written file, as returned by
    ``helper.get_shared_filename``.
    """
    display_path, internal_path = helper.get_shared_filename(file_name)
    # Write as UTF-8 explicitly: pandas reads the file back as UTF-8 below,
    # and the locale's default encoding may differ.
    with open(internal_path, "w", encoding="utf-8") as csv_file:
        csv_file.write(csv)
    if no_empty_columns:
        # Pandas seems to have problems with commas and quotation marks in the
        # description rows when loading the csv without ignoring comment
        # lines. So we need to restore the descriptions manually further down
        # the line.
        # NOTE(review): ``comment='#'`` also truncates any data field at a '#'
        # character, and the options row is not restored here. Confirm both
        # are acceptable.
        tmp = pd.read_csv(internal_path, comment='#', dtype=str)
        # drop all empty columns
        tmp.dropna(axis=1, inplace=True, how="all")
        # generate new description row and insert as the first "data row"
        new_descriptions = get_description_row(
            [_extract_key_from_parent_key(cname) for cname in tmp.columns])
        description_row_dict = dict(zip(tmp.columns, new_descriptions))
        tmp.loc[-1] = description_row_dict
        tmp.index += 1
        tmp.sort_index(inplace=True)
        tmp.to_csv(internal_path, index=False)

    return display_path


def _parse_bis_ids(bis_ids):
    """Parse the comma-separated ``bis_ids`` form field.

    Each entry is a single id (``"12"``) or an inclusive range (``"12-15"``,
    bounds in any order). Whitespace around entries is ignored, and empty
    entries (e.g. from a trailing comma) are skipped.

    Returns
    -------
    list
        ``int`` for single ids and ``range`` for ranges, in input order.

    Raises
    ------
    ValueError
        If an entry is not an integer or a range of integers.
    """
    data = []
    for entry in bis_ids.split(","):
        entry = entry.strip()
        if not entry:
            continue
        if "-" in entry:
            bound = [int(b) for b in entry.split("-")]
            data.append(range(min(bound), max(bound) + 1))
        else:
            data.append(int(entry))
    return data


def main():
    """Export the samples selected in the form JSON (``args.filename``) as CSV.

    Samples are selected by ``from_date`` (inserted or updated since), by a
    FIND/SELECT ``query_string``, or by the ``bis_ids`` list. The CSV is
    written to the shared resource. A download link is logged and mailed
    to the user.
    """
    sanity_check()
    parser = helper.get_argument_parser()
    args = parser.parse_args()
    # Check whether executed locally or as an SSS depending on
    # auth_token argument.
    if hasattr(args, "auth_token") and args.auth_token:
        db.configure_connection(auth_token=args.auth_token)
        configure_server_side_logging()
    else:
        rootlogger = logging.getLogger()
        rootlogger.setLevel(logging.INFO)
        logger.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        rootlogger.addHandler(handler)

    if not (hasattr(args, "filename") and args.filename):
        msg = "{}export_sample_csv.py was called without the JSON file in args.{}".format(
            ERROR_PREFIX, ERROR_SUFFIX)
        logger.error(msg)
        return

    # Read the input from the form (form.json)
    with open(args.filename, encoding="utf-8") as form_json:
        form_data = json.load(form_json)

    no_empty_columns = False
    if "noEmpyColumns" in form_data and form_data["noEmpyColumns"] == "on":
        logger.info("Removing empty columns from export")
        no_empty_columns = True

    if "from_date" in form_data:
        # Inserted after ...
        data = [el.id for el in db.execute_query(
            "SELECT id FROM sample WHICH REFERENCES A SourceEvent "
            "WHICH HAS AN IGSN AND "
            f"(WHICH WAS INSERTED SINCE {form_data['from_date']})")
        ]
        # ... + update after
        data += [el.id for el in db.execute_query(
            "SELECT id FROM sample WHICH REFERENCES A SourceEvent "
            "WHICH HAS AN IGSN AND "
            f"(WHICH WAS UPDATED SINCE {form_data['from_date']})")
        ]
    elif "query_string" in form_data and form_data["query_string"]:
        query_string = form_data["query_string"]
        if not query_string.lower().startswith("find ") and not query_string.lower().startswith("select "):
            logger.error(
                f"The query '{query_string}' doesn't seem to be a valid select or find query.")
            return
        if query_string.lower().startswith("find "):
            # transform to select query for performance
            query_string = "SELECT id FROM" + query_string[4:]
        try:
            data = [el.id for el in db.execute_query(query_string)]
        except db.TransactionError as te:
            logger.error(
                f"There is a problem with the given query '{query_string}':\n"
                f"```\n{str(te)}\n```"
            )
            return
    else:
        bis_ids = form_data.get("bis_ids")
        if not bis_ids:
            logger.error(
                "Please specify the samples to be exported either by query or by id(s).")
            return
        try:
            data = _parse_bis_ids(bis_ids)
        except ValueError:
            logger.error(
                f"The given BIS IDs '{bis_ids}' could not be parsed. Please give single "
                "IDs or ranges (e.g. 12-15) separated by commas.")
            return

    samples, not_found = retrieve_samples(data)

    # A sample may be selected more than once (it can appear in both the
    # "inserted since" and the "updated since" results, or in overlapping
    # id ranges); export every sample only once, keeping the first position.
    seen_ids = set()
    unique_samples = []
    for s in samples:
        if s.id not in seen_ids:
            seen_ids.add(s.id)
            unique_samples.append(s)
    samples = unique_samples

    if len(samples) == 0:
        logger.error("No samples in the given range.")
        return

    for s in samples:
        logger.debug("Found sample " + str(s.id))
    for s in not_found:
        logger.warning("No samples found: " + str(s))

    csv = to_csv(samples)

    max_id = max(s.id for s in samples)
    min_id = min(s.id for s in samples)
    file_name = f"samples_export_(IDs_{min_id}_to_{max_id}).csv"
    display_path = write_csv(file_name, csv, no_empty_columns)
    logger.info("Your CSV-Export has been prepared successfully.\n" +
                f"Download the file <a href=/Shared/{display_path}>here</a>.")
    try:
        send_mail_with_defaults(
            to=get_email_from_username(),
            subject=f"BIS sample export {file_name}",
            body=create_email_with_link_text("sample export", display_path)
        )
    except KeyError as ke:
        logger.error(
            "There is a problem with the server's email configuration:\n\n"
            f"{ke}\n\nPlease contact your admin."
        )


# Entry point when run as a script (locally or as a server-side script).
if __name__ == "__main__":
    main()