Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_cache.py 8.90 KiB
#!/usr/bin/env python
# encoding: utf-8
#
# ** header v3.0
# This file is a part of the LinkAhead project.
#
# Copyright (C) 2019 Henrik tom Wörden
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# ** end header
import os
import unittest
from copy import deepcopy
from tempfile import NamedTemporaryFile
import sqlite3

import linkahead as db
from caosadvancedtools.cache import IdentifiableCache, cleanXML
from lxml import etree

import pytest


class CacheTest(unittest.TestCase):
    """Test hashing, insertion and lookup on a freshly created cache."""

    def setUp(self):
        """Create an ``IdentifiableCache`` backed by a new temporary file."""
        # Only the path is needed.  Close the handle right away instead of
        # leaving the open file object to the garbage collector.
        with NamedTemporaryFile(delete=False) as tmp_file:
            db_file = tmp_file.name
        self.cache = IdentifiableCache(db_file=db_file, force_creation=True)

    def tearDown(self):
        """Remove the database file created in ``setUp``."""
        os.remove(self.cache.db_file)

    def test_hash(self):
        """The hash is a string and differs once a parent is added."""
        ent = db.Record()
        assert isinstance(IdentifiableCache.hash_entity(ent), str)
        assert (IdentifiableCache.hash_entity(ent) !=
                IdentifiableCache.hash_entity(db.Record().add_parent("lol")))

    def test_insert(self):
        """An inserted hash yields its ID and version; an unknown hash yields None."""
        ent = db.Record()
        ent2 = db.Record()
        ent2.add_parent(name="Experiment")
        ent_hash = IdentifiableCache.hash_entity(ent)
        ent2_hash = IdentifiableCache.hash_entity(ent2)
        self.cache.insert(ent2_hash, 1235, "ajkfljadsklf")

        cached = self.cache.check_existing(ent2_hash)
        assert cached[0] == 1235
        assert cached[1] == "ajkfljadsklf"
        # ``ent`` was never inserted, so there must be no entry for it.
        assert self.cache.check_existing(ent_hash) is None

    def test_hirarchy(self):
        """A ``Record`` is an ``Entity``."""
        assert isinstance(db.Record(), db.Entity)

    def test_update_ids_from_cache(self):
        """IDs are filled in from the cache; ``insert_list`` needs ID and version."""
        ent = db.Record()
        ent2 = db.Record()
        ent2.add_parent(name="Experiment")
        ent3 = db.Record()
        ent3.add_parent(name="Analysis")
        test_id = 2353243
        self.cache.insert(IdentifiableCache.hash_entity(ent2), test_id, "ajdsklfjadslf")
        entities = [ent, ent2, ent3]
        hashes = self.cache.update_ids_from_cache(entities)
        # Only ``ent2`` is in the cache, so only it receives an ID.
        self.assertIsNone(ent.id)
        self.assertEqual(ent2.id, test_id)
        self.assertIsNone(ent3.id)

        # Entities without an ID are rejected.
        with pytest.raises(RuntimeError, match=r".*no ID.*"):
            self.cache.insert_list(hashes, entities)

        # With IDs set but no version, the insertion is still rejected.
        ent.id = 1001
        ent3.id = 1003
        with pytest.raises(RuntimeError, match=r".*no version ID.*"):
            self.cache.insert_list(hashes, entities)

        ent.version = db.common.versioning.Version("jkadsjfldf")
        ent2.version = db.common.versioning.Version("jkadsjfldf")
        ent3.version = db.common.versioning.Version("jkadsjfldf")

        # Now that every entity has an ID and a version, insertion succeeds.
        self.cache.insert_list(hashes, entities)
        self.assertEqual(self.cache.check_existing(hashes[0])[0], 1001)
        self.assertEqual(self.cache.check_existing(hashes[2])[0], 1003)

    def test_clean(self):
        """``cleanXML`` leaves no ``TransactionBenchmark`` child on the root."""
        xml = etree.XML(
            """\
            <Entities>
  <TransactionBenchmark>
    </TransactionBenchmark>
  <RecordType id="110" name="Guitar">
    <Version id="eb8c7527980e598b887e84d055db18cfc3806ce6" head="true"/>
    <Parent id="108" name="MusicalInstrument" flag="inheritance:OBLIGATORY,"/>
    <Property id="106" name="electric" datatype="BOOLEAN" importance="RECOMMENDED" flag="inheritance:FIX"/>
  </RecordType>
    </Entities>
""")
        cleanXML(xml)
        assert len(xml.findall('TransactionBenchmark')) == 0


def create_sqlite_file(commands):
    """Create a temporary sqlite file and run the given commands on it.

    Parameters
    ----------
    commands : iterable of tuple
        SQL commands to execute after creation.  Each tuple is unpacked into
        the arguments of ``cursor.execute``, i.e. ``(sql,)`` or
        ``(sql, parameters)``.

    Returns
    -------
    str
        Name of the created file.  The caller is responsible for removing it.

    Raises
    ------
    Exception
        Whatever connecting to or executing on the database raises.  In that
        case the connection is closed and the temporary file is removed
        before the exception propagates.
    """
    # Only the path is needed.  Close the handle right away instead of
    # leaving the open file object to the garbage collector.
    with NamedTemporaryFile(delete=False) as tmp_file:
        db_file = tmp_file.name

    try:
        conn = sqlite3.connect(db_file)
        try:
            cursor = conn.cursor()
            for sql in commands:
                cursor.execute(*sql)
            conn.commit()
        finally:
            # Release the connection even if a command failed.
            conn.close()
    except Exception:
        # Do not leave an orphaned temporary file behind on failure.
        os.remove(db_file)
        raise
    return db_file


class CacheTest2(unittest.TestCase):
    """
    Test the schema version.
    """

    def setUp(self):
        """Create one valid cache and several sqlite files with defect schemas."""
        # Correct version.  Only the path of the temporary file is needed, so
        # close the handle right away instead of leaving it to the garbage
        # collector.
        with NamedTemporaryFile(delete=False) as tmp_file:
            db_file = tmp_file.name
        self.cache = IdentifiableCache(db_file=db_file, force_creation=True)

        # The order matters: ``test_schema`` addresses these files by index.
        self.db_file_defect = []
        self.db_file_defect.extend([
            # [0] Version without version table (old version):
            create_sqlite_file(
                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER)''',)]),
            # [1] Version table with a too old version:
            create_sqlite_file(
                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
                 ('''CREATE TABLE version (schema INTEGER)''',),
                 ("INSERT INTO version VALUES (?)", (1,))]),
            # [2] Version table with a future version:
            create_sqlite_file(
                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
                 ('''CREATE TABLE version (schema INTEGER)''',),
                 ("INSERT INTO version VALUES (?)", (3,))]),
            # [3] Version table with missing version:
            create_sqlite_file(
                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
                 ('''CREATE TABLE version (schema INTEGER)''',)]),
            # [4] Version table with too many versions:
            create_sqlite_file(
                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
                 ('''CREATE TABLE version (schema INTEGER)''',),
                 ("INSERT INTO version VALUES (?)", (1,)),
                 ("INSERT INTO version VALUES (?)", (3,))])])

    def tearDown(self):
        """Remove the valid cache file and all defect files."""
        os.remove(self.cache.db_file)

        for db_fn_defect in self.db_file_defect:
            os.remove(db_fn_defect)

    def test_schema(self):
        """A valid cache reports version 2; each defect file raises RuntimeError."""
        # Test whether new cache is created correctly:
        assert os.path.exists(self.cache.db_file)
        # Test whether it can be opened
        test_cache_2 = IdentifiableCache(db_file=self.cache.db_file)
        assert test_cache_2.get_cache_version() == 2

        # No version table at all.
        with pytest.raises(RuntimeError, match="Cache version too old."):
            IdentifiableCache(db_file=self.db_file_defect[0])

        # Schema version 1.
        with pytest.raises(RuntimeError, match="Cache version too old."):
            IdentifiableCache(db_file=self.db_file_defect[1])

        # Schema version 3.
        with pytest.raises(RuntimeError, match=r".*future version.*"):
            IdentifiableCache(db_file=self.db_file_defect[2])

        # Version table without any entry.
        with pytest.raises(RuntimeError, match=r".*table broken.*"):
            IdentifiableCache(db_file=self.db_file_defect[3])

        # Version table with two entries.
        with pytest.raises(RuntimeError, match=r".*table broken.*"):
            IdentifiableCache(db_file=self.db_file_defect[4])


class InvalidationTest(unittest.TestCase):
    """
    Test invalidation of cache entries.
    """

    def setUp(self):
        """Create an ``IdentifiableCache`` backed by a new temporary file."""
        # Correct version.  Only the path of the temporary file is needed, so
        # close the handle right away instead of leaving it to the garbage
        # collector.
        with NamedTemporaryFile(delete=False) as tmp_file:
            db_file = tmp_file.name
        self.cache = IdentifiableCache(db_file=db_file, force_creation=True)

    def tearDown(self):
        """Remove the database file created in ``setUp``."""
        os.remove(self.cache.db_file)

    def test_invalid(self):
        """Entries whose version ID changed are invalidated by ``validate_cache``."""
        ent = db.Record()
        ent2 = db.Record()
        ent2.add_parent(name="Experiment")
        ent3 = db.Record()
        ent3.add_parent(name="Analysis")
        ent.id = 117
        ent2.id = 328
        ent3.id = 224

        ent.version = db.common.versioning.Version("a")
        ent2.version = db.common.versioning.Version("b")
        ent3.version = db.common.versioning.Version("a")

        el = [ent, ent2, ent3]

        for e in el:
            self.cache.insert(IdentifiableCache.hash_entity(e), e.id, e.version.id)

        # Directly after insertion every entity is found with its ID and version.
        for e in el:
            res = self.cache.check_existing(IdentifiableCache.hash_entity(e))
            assert e.id == res[0]
            assert e.version.id == res[1]

        # Change the version IDs of the second and third entity only.
        ent2.version.id = "c"
        ent3.version.id = "b"

        # The changed entities are no longer found in the cache.
        for e in el[1:]:
            res = self.cache.check_existing(IdentifiableCache.hash_entity(e))
            assert res is None

        # Validation reports exactly the two changed entities as invalidated.
        invalidated_entries = self.cache.validate_cache(el)
        assert 328 in invalidated_entries
        assert 224 in invalidated_entries
        assert 117 not in invalidated_entries

        # Only the entry of the unchanged entity is left in the table.
        res = self.cache.run_sql_commands([
            ("SELECT * FROM identifiables", ())], fetchall=True)
        assert len(res) == 1