#!/usr/bin/env python
# encoding: utf-8
#
# ** header v3.0
# This file is a part of the CaosDB Project.
#
# Copyright (C) 2019 Henrik tom Wörden
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# ** end header
import os
import unittest
from tempfile import NamedTemporaryFile

import caosdb as db
import pandas as pd
from caosdb.apiutils import compare_entities
from numpy import nan
from caosadvancedtools.table_converter import (from_table, from_tsv, to_table,
                                               to_tsv)

TEST_TABLE = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                          "test.tsv")


class TableTest(unittest.TestCase):
    """Tests for ``from_table`` and ``to_table``."""

    def test_basic(self):
        """The test TSV file, read as a data frame, converts to a container."""
        df = pd.read_csv(TEST_TABLE, sep="\t")
        assert isinstance(from_table(df, "Measurement"), db.Container)

    def test_empty(self):
        """An empty container yields a data frame without rows or columns."""
        c = db.Container()
        df = to_table(c)
        assert df.shape == (0, 0)

    def test_different_props(self):
        """Records with differing property sets can be converted.

        This is a smoke test: it passes as long as ``to_table`` does not
        raise; the resulting table is not inspected.
        """
        r1 = db.Record()
        r1.add_parent("no1")
        r1.add_property("p1")
        r2 = db.Record()
        r2.add_parent("no1")
        r2.add_property("p1")
        r2.add_property("p2")
        c = db.Container()
        c.extend([r1, r2])
        to_table(c)

    def test_parents(self):
        """Records with different parents make ``to_table`` raise ValueError."""
        r1 = db.Record()
        r1.add_parent("no1")
        r2 = db.Record()
        r2.add_parent("no2")
        c = db.Container()
        c.extend([r1, r2])
        with self.assertRaises(ValueError):
            to_table(c)

    def test_list(self):
        """List-valued properties end up as lists in the table cells.

        Cells of properties that a record lacks (or that were added without
        a value) are expected to be NaN.
        """
        r1 = db.Record()
        r1.add_parent("no1")
        r1.add_property("p1", value=1)
        r1.add_property("p3", value=23)
        r1.add_property("p4", value=[1])
        r2 = db.Record()
        r2.add_parent("no1")
        r2.add_property("p1")
        r2.add_property("p2", value=[20, 21])
        r2.add_property("p3", value=[30, 31])
        r2.add_property("p4", value=[40.0, 41.0])
        r3 = db.Record()
        r3.add_parent("no1")
        r3.add_property("p5", value=[50, 51])
        c = db.Container()
        c.extend([r1, r2, r3])
        result = to_table(c)
        # NaN is hard to compare, so we replace it by -999
        # autopep8: off
        assert result.replace(to_replace=nan, value=-999).to_dict() == {
            'p1': {0: 1,    1: -999,         2: -999},  # noqa: E202
            'p3': {0: 23,   1: [30, 31],     2: -999},  # noqa: E202
            'p4': {0: [1],  1: [40.0, 41.0], 2: -999},  # noqa: E202
            'p2': {0: -999, 1: [20, 21],     2: -999},  # noqa: E202
            'p5': {0: -999, 1: -999,         2: [50, 51]}
        }
        # autopep8: on
        # p1 holds only a number and NaN, hence a float column; every other
        # column contains at least one list and is therefore an object column.
        assert list(result.dtypes) == [float, object, object, object, object]

class FromTsvTest(unittest.TestCase):
    """Smoke test for ``from_tsv``."""

    def test_basic(self):
        """Calling ``from_tsv`` on the test TSV file must not raise."""
        tsv_path = TEST_TABLE
        from_tsv(tsv_path, "Measurement")


class ToTsvTest(unittest.TestCase):
    """Smoke test for ``to_tsv``."""

    def test_basic(self):
        """Writing a container with a single record must not raise."""
        r = db.Record()
        r.add_property("ha", 5)
        r.add_parent("hu")
        c = db.Container()
        c.append(r)
        # Create the target with delete=False and close our handle right away
        # so that to_tsv can open the path itself on any platform.  Using
        # ``NamedTemporaryFile().name`` directly would delete the file before
        # to_tsv runs and leave the file written by to_tsv behind.
        tmp = NamedTemporaryFile(delete=False)
        tmp.close()
        try:
            to_tsv(tmp.name, c)
        finally:
            # Clean up explicitly, whether or not to_tsv succeeded.
            os.remove(tmp.name)

# TODO: Reactivate this test; the `compare_entities` import is kept for it.
# class IntegrationTest(unittest.TestCase):
#    """Convert a TSV file to a container and back, then compare the
#    original with the result."""
#
#    def test_backandforth(self):
#        cont = from_tsv(TEST_TABLE, "Measurement")
#        tempfile = NamedTemporaryFile(delete=False)
#        to_tsv(tempfile.name, cont)
#        cont_new = from_tsv(tempfile.name, "Measurement")
#
#        for ent1, ent2 in zip(cont_new, cont):
#            assert compare_entities(ent1, ent2) == ([], [])