diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 73a9abb0b5a525ddb74ddbf33003b03e35c1cacf..68fb90fc027f1d706430358822b7aa8d5c4e2959 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -129,6 +129,7 @@ unittest: image: $CI_REGISTRY_IMAGE needs: [build-testenv] script: + - python3 -c "import caosdb; print('CaosDB Version:', caosdb.__version__)" - tox # Build the sphinx documentation and make it ready for deployment by Gitlab Pages diff --git a/integrationtests/test.sh b/integrationtests/test.sh index 5bb013db6e70a3a8393e7e3b7c7993a6da6bf9b9..686172800a96bb5ce5f97d8bc3fc9b89f012ab1b 100755 --- a/integrationtests/test.sh +++ b/integrationtests/test.sh @@ -16,7 +16,8 @@ fi OUT=/tmp/crawler.output ls cat pycaosdb.ini -rm -rf cache.db +python3 -c "import caosdb; print('CaosDB Version:', caosdb.__version__)" +rm -rf /tmp/cache.db set -e echo "Clearing database" python3 clear_database.py diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py index a7d1e4526ab0d816f489e991be01ad08a94bc0b0..fcd93e7f7861affc3d81f5c5252711edf5587f7b 100644 --- a/src/caosadvancedtools/cache.py +++ b/src/caosadvancedtools/cache.py @@ -23,11 +23,14 @@ # # ** end header -# TODO this is implementing a cache on client side. Should it be on -# server side? +# Note: This is implementing a cache on client side. It would be great if the server would provide +# something to replace this. import os import sqlite3 +from copy import deepcopy +from abc import ABC, abstractmethod from hashlib import sha256 +import warnings import caosdb as db from lxml import etree @@ -64,61 +67,180 @@ def get_pretty_xml(cont): return etree.tounicode(xml, pretty_print=True) -class Cache(object): - """ - stores identifiables (as a hash of xml) and their respective ID. +class AbstractCache(ABC): + def __init__(self, db_file=None, force_creation=False): + """ + db_file: The path of the database file. - This allows to retrieve the Record corresponding to an indentifiable - without querying. - """ + if force_creation is set to True, the file will be created + regardless of a file at the same path already exists. + """ - def __init__(self, db_file=None, default_name="cache.db"): if db_file is None: tmppath = tempfile.gettempdir() - tmpf = os.path.join(tmppath, default_name) - self.db_file = tmpf + self.db_file = os.path.join(tmppath, self.get_default_file_name()) else: self.db_file = db_file - if not os.path.exists(self.db_file): + if not os.path.exists(self.db_file) or force_creation: self.create_cache() + else: + self.check_cache() + + @abstractmethod + def get_cache_schema_version(self): + """ + A method that has to be overloaded that sets the version of the + SQLITE database schema. The schema is saved in table version column schema. + Increase this variable, when changes to the cache tables are made. + """ + pass + + @abstractmethod def create_cache(self): + """ + Provide an overloaded function here that creates the cache in + the most recent version. + """ + pass + + @abstractmethod + def get_default_file_name(self): + """ + Supply a default file name for the cache here. + """ + pass + + def check_cache(self): + """ + Check whether the cache in db file self.db_file exists and conforms + to the latest database schema. + + If it does not exist, it will be created using the newest database schema. + + If it exists, but the schema is outdated, an exception will be raised. + """ + try: + current_schema = self.get_cache_version() + except sqlite3.OperationalError: + current_schema = 1 + + if current_schema > self.get_cache_schema_version(): + raise RuntimeError( + "Cache is corrupt or was created with a future version of this program.") + elif current_schema < self.get_cache_schema_version(): + raise RuntimeError("Cache version too old.") + + def get_cache_version(self): + """ + Return the version of the cache stored in self.db_file. + The version is stored as the only entry in colum schema of table version. + """ + try: + conn = sqlite3.connect(self.db_file) + c = conn.cursor() + c.execute("SELECT schema FROM version") + version_row = c.fetchall() + + if len(version_row) != 1: + raise RuntimeError("Cache version table broken.") + + return version_row[0][0] + finally: + conn.close() + + def run_sql_commands(self, commands, fetchall=False): + """ + Run a list of SQL commands on self.db_file. + + commands: list of sql commands (tuples) to execute + fetchall: When True, run fetchall as last command and return the results. + Otherwise nothing is returned. + """ conn = sqlite3.connect(self.db_file) c = conn.cursor() - c.execute( - '''CREATE TABLE identifiables (digest text primary key, caosdb_id integer)''') + + for sql in commands: + c.execute(*sql) + + if fetchall: + results = c.fetchall() conn.commit() conn.close() + if fetchall: + return results + + +class IdentifiableCache(AbstractCache): + """ + stores identifiables (as a hash of xml) and their respective ID. + + This allows to retrieve the Record corresponding to an indentifiable + without querying. + """ + + def get_cache_schema_version(self): + return 2 + + def get_default_file_name(self): + return "cache.db" + + def __init__(self, db_file=None, force_creation=False): + super().__init__(db_file, force_creation) + + def create_cache(self): + """ + Create a new SQLITE cache file in self.db_file. + + Two tables will be created: + - identifiables is the actual cache. + - version is a table with version information about the cache. + """ + self.run_sql_commands([ + ('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (self.get_cache_schema_version(),))]) + @staticmethod def hash_entity(ent): - xml = get_pretty_xml(ent) + """ + Format an entity as "pretty" XML and return the SHA256 hash. + """ + xml = get_pretty_xml(deepcopy(ent)) digest = sha256(xml.encode("utf-8")).hexdigest() return digest - def insert(self, ent_hash, ent_id): - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''INSERT INTO identifiables VALUES (?, ?)''', - (ent_hash, ent_id)) - conn.commit() - conn.close() + def insert(self, ent_hash, ent_id, ent_version): + """ + Insert a new cache entry. + + ent_hash: Hash of the entity. Should be generated with Cache.hash_entity + ent_id: ID of the entity + ent_version: Version string of the entity + """ + self.run_sql_commands([ + ('''INSERT INTO identifiables VALUES (?, ?, ?)''', + (ent_hash, ent_id, ent_version))]) def check_existing(self, ent_hash): - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''Select * FROM identifiables WHERE digest=?''', - (ent_hash,)) - res = c.fetchone() - conn.commit() - conn.close() + """ + Check the cache for a hash. + + ent_hash: The hash to search for. - if res is None: - return res + Return the ID and the version ID of the hashed entity. + Return None if no entity with that hash is in the cache. + """ + res = self.run_sql_commands([('''Select * FROM identifiables WHERE digest=?''', + (ent_hash,))], True) + + if len(res) == 0: + return None else: - return res[1] + return res[0][1:] def update_ids_from_cache(self, entities): """ sets ids of those entities that are in cache @@ -133,7 +255,7 @@ class Cache(object): eid = self.check_existing(ehash) if eid is not None: - ent.id = eid + ent.id = eid[0] return hashes @@ -143,12 +265,63 @@ class Cache(object): The hashes must correspond to the entities in the list """ + # Check whether all entities have IDs and versions: + + for ent in entities: + if ent.id is None: + raise RuntimeError("Entity has no ID.") + + if ent.version is None or ent.version.id is None: + raise RuntimeError("Entity has no version ID.") + for ehash, ent in zip(hashes, entities): if self.check_existing(ehash) is None: - self.insert(ehash, ent.id) + self.insert(ehash, ent.id, ent.version.id) + + def validate_cache(self, entities=None): + """ + Runs through all entities stored in the cache and checks + whether the version still matches the most recent version. + Non-matching entities will be removed from the cache. + + entities: When set to a db.Container or a list of Entities + the IDs from the cache will not be retrieved from the CaosDB database, + but the versions from the cache will be checked against the versions + contained in that collection. Only entries in the cache that have + a corresponding version in the collection will be checked, all others + will be ignored. Useful for testing. + + Return a list of invalidated entries or an empty list if no elements have been invalidated. + """ + res = self.run_sql_commands([( + "SELECT caosdb_id, caosdb_version FROM identifiables", ())], True) -class UpdateCache(Cache): + if entities is None: + # TODO this might become a problem. If many entities are cached, + # then all of them are retrieved here... + ids = [c_id for c_id, _ in res] + ids = set(ids) + entities = db.Container() + entities.extend([db.Entity(id=c_id) for c_id in ids]) + entities.retrieve() + + v = {c_id: c_version for c_id, c_version in res} + + invalidate_list = [] + + for ent in entities: + if ent.version.id != v[ent.id]: + invalidate_list.append(ent.id) + + self.run_sql_commands([( + "DELETE FROM identifiables WHERE caosdb_id IN ({})".format( + ", ".join([str(caosdb_id) for caosdb_id in invalidate_list])), ())]) + + return invalidate_list + + +class UpdateCache(AbstractCache): """ stores unauthorized inserts and updates @@ -156,8 +329,11 @@ class UpdateCache(Cache): be stored in this cache such that it can be authorized and performed later. """ - def __init__(self, db_file=None): - super().__init__(db_file=db_file, default_name="crawler_insert_cache.db") + def get_cache_schema_version(self): + return 1 + + def get_default_file_name(self): + return "/tmp/crawler_update_cache.db" @staticmethod def get_previous_version(cont): @@ -199,23 +375,15 @@ class UpdateCache(Cache): else: old_hash = Cache.hash_entity(old_ones) new_hash = Cache.hash_entity(new_ones) - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''INSERT INTO updates VALUES (?, ?, ?, ?, ?)''', - (old_hash, new_hash, str(old_ones), str(new_ones), - str(run_id))) - conn.commit() - conn.close() + self.run_sql_commands([('''INSERT INTO updates VALUES (?, ?, ?, ?, ?)''', + (old_hash, new_hash, str(old_ones), str(new_ones), + str(run_id)))]) def create_cache(self): """ initialize the cache """ - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''CREATE TABLE updates (olddigest text, newdigest text, + self.run_sql_commands([('''CREATE TABLE updates (olddigest text, newdigest text, oldrep text, newrep text, run_id text, - primary key (olddigest, newdigest, run_id))''') - conn.commit() - conn.close() + primary key (olddigest, newdigest, run_id))''', )]) def get(self, run_id, querystring): """ returns the pending updates for a given run id @@ -226,14 +394,7 @@ class UpdateCache(Cache): querystring: the sql query """ - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute(querystring, (str(run_id),)) - res = c.fetchall() - conn.commit() - conn.close() - - return res + return self.run_sql_commands([(querystring, (str(run_id),))], fetchall=True) def get_inserts(self, run_id): """ returns the pending updates for a given run id @@ -254,3 +415,9 @@ class UpdateCache(Cache): """ return self.get(run_id, '''Select * FROM updates WHERE olddigest!='' AND run_id=?''') + + +class Cache(IdentifiableCache): + def __init__(self, *args, **kwargs): + warnings.warn(DeprecationWarning("This class is depricated. Please use IdentifiableCache.")) + super().__init__(*args, **kwargs) diff --git a/src/caosadvancedtools/crawler.py b/src/caosadvancedtools/crawler.py index 9e8f5fb324cccb095f98356b2b5e5aabc98bb383..6a0cdb58f50b718cff3850586720b8e5031e64ea 100644 --- a/src/caosadvancedtools/crawler.py +++ b/src/caosadvancedtools/crawler.py @@ -50,7 +50,7 @@ from sqlite3 import IntegrityError import caosdb as db from caosdb.exceptions import BadQueryError -from .cache import Cache, UpdateCache, get_pretty_xml +from .cache import IdentifiableCache, UpdateCache, get_pretty_xml from .cfood import RowCFood, add_files, get_ids_for_entities_with_names from .datainconsistency import DataInconsistencyError from .datamodel_problems import DataModelProblems @@ -190,7 +190,8 @@ class Crawler(object): self.filterKnown.reset(cat) if self.use_cache: - self.cache = Cache(db_file=cache_file) + self.cache = IdentifiableCache(db_file=cache_file) + self.cache.validate_cache() def iteritems(self): """ generates items to be crawled with an index""" diff --git a/unittests/test_cache.py b/unittests/test_cache.py index 2d7b863fe971dd61a575e24f52853de4f5c4e204..de3430bf2f28a6b05ea36b1047ac11937809ff44 100644 --- a/unittests/test_cache.py +++ b/unittests/test_cache.py @@ -24,31 +24,35 @@ import os import unittest from copy import deepcopy from tempfile import NamedTemporaryFile +import sqlite3 import caosdb as db -from caosadvancedtools.cache import Cache, cleanXML +from caosadvancedtools.cache import IdentifiableCache, cleanXML from lxml import etree +import pytest + class CacheTest(unittest.TestCase): def setUp(self): - self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name) - self.cache.create_cache() + self.cache = IdentifiableCache(db_file=NamedTemporaryFile(delete=False).name, + force_creation=True) def test_hash(self): ent = db.Record() - assert isinstance(Cache.hash_entity(ent), str) - assert (Cache.hash_entity(ent) != - Cache.hash_entity(db.Record().add_parent("lol"))) + assert isinstance(IdentifiableCache.hash_entity(ent), str) + assert (IdentifiableCache.hash_entity(ent) != + IdentifiableCache.hash_entity(db.Record().add_parent("lol"))) def test_insert(self): ent = db.Record() ent2 = db.Record() ent2.add_parent(name="Experiment") - ent_hash = Cache.hash_entity(ent) - ent2_hash = Cache.hash_entity(ent2) - self.cache.insert(ent2_hash, 1235) - assert isinstance(self.cache.check_existing(ent2_hash), int) + ent_hash = IdentifiableCache.hash_entity(ent) + ent2_hash = IdentifiableCache.hash_entity(ent2) + self.cache.insert(ent2_hash, 1235, "ajkfljadsklf") + assert self.cache.check_existing(ent2_hash)[0] == 1235 + assert self.cache.check_existing(ent2_hash)[1] == "ajkfljadsklf" assert self.cache.check_existing(ent_hash) is None def test_hirarchy(self): @@ -64,17 +68,29 @@ class CacheTest(unittest.TestCase): ent3 = db.Record() ent3.add_parent(name="Analysis") test_id = 2353243 - self.cache.insert(Cache.hash_entity(ent2), test_id) + self.cache.insert(IdentifiableCache.hash_entity(ent2), test_id, "ajdsklfjadslf") entities = [ent, ent2, ent3] hashes = self.cache.update_ids_from_cache(entities) + self.assertEqual(ent.id, None) self.assertEqual(ent2.id, test_id) + self.assertEqual(ent3.id, None) + + with pytest.raises(RuntimeError, match=r".*no ID.*"): + self.cache.insert_list(hashes, entities) # test ent.id = 1001 ent3.id = 1003 + with pytest.raises(RuntimeError, match=r".*no version ID.*"): + self.cache.insert_list(hashes, entities) + + ent.version = db.common.versioning.Version("jkadsjfldf") + ent2.version = db.common.versioning.Version("jkadsjfldf") + ent3.version = db.common.versioning.Version("jkadsjfldf") + self.cache.insert_list(hashes, entities) - self.assertEqual(self.cache.check_existing(hashes[0]), 1001) - self.assertEqual(self.cache.check_existing(hashes[2]), 1003) + self.assertEqual(self.cache.check_existing(hashes[0])[0], 1001) + self.assertEqual(self.cache.check_existing(hashes[2])[0], 1003) def test_clean(self): xml = etree.XML( @@ -91,3 +107,138 @@ class CacheTest(unittest.TestCase): """) cleanXML(xml) assert len(xml.findall('TransactionBenchmark')) == 0 + + +def create_sqlite_file(commands): + """ + A temporary file will be used + commands: list of sql commands (tuples) to execute after creation + Name of the file is returned + """ + db_file = NamedTemporaryFile(delete=False).name + conn = sqlite3.connect(db_file) + c = conn.cursor() + for sql in commands: + c.execute(*sql) + conn.commit() + conn.close() + return db_file + + +class CacheTest2(unittest.TestCase): + """ + Test the schema version. + """ + + def setUp(self): + # Correct version: + self.cache = IdentifiableCache(db_file=NamedTemporaryFile(delete=False).name, + force_creation=True) + + self.db_file_defect = [] + self.db_file_defect.extend([ + # Version without version table (old version): + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER)''',)]), + # Version with version table with wrong version: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (1,))]), + # Version with version table with wrong version: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (3,))]), + # Version with version table with missing version: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',)]), + # Version with version table with too many versions: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (1,)), + ("INSERT INTO version VALUES (?)", (3,))])]) + + def test_schema(self): + # Test whether new cache is created correctly: + assert os.path.exists(self.cache.db_file) + # Test whether it can be opened + test_cache_2 = IdentifiableCache(db_file=self.cache.db_file) + assert test_cache_2.get_cache_version() == 2 + + with pytest.raises(RuntimeError, match="Cache version too old.") as e_info: + test_cache_2 = IdentifiableCache(db_file=self.db_file_defect[0]) + + with pytest.raises(RuntimeError, match="Cache version too old.") as e_info: + test_cache_2 = IdentifiableCache(db_file=self.db_file_defect[1]) + + with pytest.raises(RuntimeError, match=r".*future version.*") as e_info: + test_cache_2 = IdentifiableCache(db_file=self.db_file_defect[2]) + + with pytest.raises(RuntimeError, match=r".*table broken.*") as e_info: + test_cache_2 = IdentifiableCache(db_file=self.db_file_defect[3]) + + with pytest.raises(RuntimeError, match=r".*table broken.*") as e_info: + test_cache_2 = IdentifiableCache(db_file=self.db_file_defect[4]) + + def tearDown(self): + os.remove(self.cache.db_file) + + for db_fn_defect in self.db_file_defect: + os.remove(db_fn_defect) + + +class InvalidationTest(unittest.TestCase): + """ + Test invalidation of cache entries. + """ + + def setUp(self): + # Correct version: + self.cache = IdentifiableCache(db_file=NamedTemporaryFile(delete=False).name, + force_creation=True) + + def tearDown(self): + os.remove(self.cache.db_file) + + def test_invalid(self): + ent = db.Record() + ent2 = db.Record() + ent2.add_parent(name="Experiment") + ent3 = db.Record() + ent3.add_parent(name="Analysis") + ent.id = 117 + ent2.id = 328 + ent3.id = 224 + + ent.version = db.common.versioning.Version("a") + ent2.version = db.common.versioning.Version("b") + ent3.version = db.common.versioning.Version("a") + + el = [ent, ent2, ent3] + + for e in el: + self.cache.insert(IdentifiableCache.hash_entity(e), e.id, e.version.id) + + for e in el: + res = self.cache.check_existing(IdentifiableCache.hash_entity(e)) + assert e.id == res[0] + assert e.version.id == res[1] + + ent2.version.id = "c" + ent3.version.id = "b" + + for e in el[1:]: + res = self.cache.check_existing(IdentifiableCache.hash_entity(e)) + assert res is None + + invalidated_entries = self.cache.validate_cache(el) + assert 328 in invalidated_entries + assert 224 in invalidated_entries + assert 117 not in invalidated_entries + + res = self.cache.run_sql_commands([ + ("SELECT * FROM identifiables", ())], fetchall=True) + assert len(res) == 1