diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py index bde76abf08cb5a8a7835746d1dcd827cc4b2c071..63ee5f6e24b03dc8a9b2ae6edd49ac840ef4fadb 100644 --- a/src/caosadvancedtools/cache.py +++ b/src/caosadvancedtools/cache.py @@ -50,6 +50,8 @@ def get_pretty_xml(cont): return etree.tounicode(cont.to_xml( local_serialization=True), pretty_print=True) +# Increase this, when changes to the cache tables are made: +CACHE_SCHEMA_VERSION = 2 class Cache(object): """ @@ -59,20 +61,67 @@ class Cache(object): without querying. """ - def __init__(self, db_file=None): + def __init__(self, db_file=None, force_creation=False): + """ + db_file: The path of the database file. + + if force_creation is set to True, the file will be created + regardless of a file at the same path already exists. + """ if db_file is None: self.db_file = "cache.db" else: self.db_file = db_file + if force_creation: + self.create_cache() + else: + self.check_cache() + + def check_cache(self): + """ + Check whether the cache in db file self.db_file exists and conforms + to the latest database schema. + + If it does not exist, it will be created using the newest database schema. + + If it exists, but the schema is outdated, an exception will be raised. + """ if not os.path.exists(self.db_file): self.create_cache() + else: + try: + current_schema = self.get_cache_version() + except sqlite3.OperationalError: + current_schema = 1 + # TODO: Write unit tests for too old, too new and non-existent version of cache. + + if current_schema > CACHE_SCHEMA_VERSION: + raise RuntimeError("Cache is corrupt or was created with a future version of this program.") + elif current_schema < CACHE_SCHEMA_VERSION: + raise RuntimeError("Cache version too old.") + + def get_cache_version(self): + try: + conn = sqlite3.connect(self.db_file) + c = conn.cursor() + c.execute("SELECT schema FROM version") + version_row = c.fetchall() + if len(version_row) != 1: + raise RuntimeError("Cache version table broken.") + return version_row[0][0] + finally: + conn.close() def create_cache(self): conn = sqlite3.connect(self.db_file) c = conn.cursor() c.execute( - '''CREATE TABLE identifiables (digest text primary key, caosdb_id integer)''') + '''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''') + c.execute( + '''CREATE TABLE version (schema INTEGER)''') + c.execute("INSERT INTO version VALUES (?)", (CACHE_SCHEMA_VERSION,)) + conn.commit() conn.close() diff --git a/unittests/test_cache.py b/unittests/test_cache.py index 985ac15ca52a06c6e00c13c6d87adcb8d21f1595..662f27a3b57439a46249f7b6c607b4c571709be1 100644 --- a/unittests/test_cache.py +++ b/unittests/test_cache.py @@ -24,16 +24,19 @@ import os import unittest from copy import deepcopy from tempfile import NamedTemporaryFile +import sqlite3 import caosdb as db from caosadvancedtools.cache import Cache +import pytest + class CacheTest(unittest.TestCase): def setUp(self): - self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name) - self.cache.create_cache() + self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name, + force_creation) def test_hash(self): ent = db.Record() @@ -75,3 +78,84 @@ class CacheTest(unittest.TestCase): self.cache.insert_list(hashes, entities) self.assertEqual(self.cache.check_existing(hashes[0]), 1001) self.assertEqual(self.cache.check_existing(hashes[2]), 1003) + +def create_sqlite_file(commands): + """ + A temporary file will be used + commands: list of sql commands (tuples) to execute after creation + Name of the file is returned + """ + db_file = NamedTemporaryFile(delete=False).name + conn = sqlite3.connect(db_file) + c = conn.cursor() + for sql in commands: + c.execute(*sql) + conn.commit() + conn.close() + return db_file + +class CacheTest2(unittest.TestCase): + """ + Test the schema version. + """ + + def setUp(self): + # Correct version: + self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name, + force_creation=True) + + self.db_file_defect = [] + self.db_file_defect.extend([ + # Version without version table (old version): + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER)''',)]), + # Version with version table with wrong version: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (1,))]), + # Version with version table with wrong version: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (3,))]), + # Version with version table with missing version: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',)]), + # Version with version table with too many versions: + create_sqlite_file( + [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (1,)), + ("INSERT INTO version VALUES (?)", (3,))])]) + + + + def test_schema(self): + # Test whether new cache is created correctly: + assert os.path.exists(self.cache.db_file) + # Test whether it can be opened + test_cache_2 = Cache(db_file=self.cache.db_file) + assert test_cache_2.get_cache_version() == 2 + + with pytest.raises(RuntimeError, match="Cache version too old.") as e_info: + test_cache_2 = Cache(db_file=self.db_file_defect[0]) + + with pytest.raises(RuntimeError, match="Cache version too old.") as e_info: + test_cache_2 = Cache(db_file=self.db_file_defect[1]) + + with pytest.raises(RuntimeError, match=r".*future version.*") as e_info: + test_cache_2 = Cache(db_file=self.db_file_defect[2]) + + with pytest.raises(RuntimeError, match=r".*table broken.*") as e_info: + test_cache_2 = Cache(db_file=self.db_file_defect[3]) + + with pytest.raises(RuntimeError, match=r".*table broken.*") as e_info: + test_cache_2 = Cache(db_file=self.db_file_defect[4]) + + def tearDown(self): + os.remove(self.cache.db_file) + + for db_fn_defect in self.db_file_defect: + os.remove(db_fn_defect)