From e59b78bc7bfbc488ef5fd44df47b54f6afeaa74a Mon Sep 17 00:00:00 2001
From: Alexander Schlemmer <alexander.schlemmer@ds.mpg.de>
Date: Mon, 23 Nov 2020 17:31:59 +0100
Subject: [PATCH] ENH: added schema version to sql file and unit tests for the
 schema

---
 src/caosadvancedtools/cache.py | 53 +++++++++++++++++++-
 unittests/test_cache.py        | 88 +++++++++++++++++++++++++++++++++-
 2 files changed, 137 insertions(+), 4 deletions(-)

diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py
index bde76abf..63ee5f6e 100644
--- a/src/caosadvancedtools/cache.py
+++ b/src/caosadvancedtools/cache.py
@@ -50,6 +50,8 @@ def get_pretty_xml(cont):
     return etree.tounicode(cont.to_xml(
         local_serialization=True), pretty_print=True)
 
+# Increase this, when changes to the cache tables are made:
+CACHE_SCHEMA_VERSION = 2
 
 class Cache(object):
     """
@@ -59,20 +61,67 @@ class Cache(object):
     without querying.
     """
 
-    def __init__(self, db_file=None):
+    def __init__(self, db_file=None, force_creation=False):
+        """
+        db_file: The path of the database file.
+        
+        if force_creation is set to True, the file will be created
+        regardless of a file at the same path already exists.
+        """
         if db_file is None:
             self.db_file = "cache.db"
         else:
             self.db_file = db_file
 
+        if force_creation:
+            self.create_cache()
+        else:
+            self.check_cache()
+
+    def check_cache(self):
+        """
+        Check whether the cache in db file self.db_file exists and conforms
+        to the latest database schema.
+
+        If it does not exist, it will be created using the newest database schema.
+
+        If it exists, but the schema is outdated, an exception will be raised.
+        """
         if not os.path.exists(self.db_file):
             self.create_cache()
+        else:
+            try:
+                current_schema = self.get_cache_version()
+            except sqlite3.OperationalError:
+                current_schema = 1
+            # TODO: Write unit tests for too old, too new and non-existent version of cache.
+
+            if current_schema > CACHE_SCHEMA_VERSION:
+                raise RuntimeError("Cache is corrupt or was created with a future version of this program.")
+            elif current_schema < CACHE_SCHEMA_VERSION:
+                raise RuntimeError("Cache version too old.")
+
+    def get_cache_version(self):
+        try:
+            conn = sqlite3.connect(self.db_file)
+            c = conn.cursor()
+            c.execute("SELECT schema FROM version")
+            version_row = c.fetchall()
+            if len(version_row) != 1:
+                raise RuntimeError("Cache version table broken.")
+            return version_row[0][0]
+        finally:
+            conn.close()
 
     def create_cache(self):
         conn = sqlite3.connect(self.db_file)
         c = conn.cursor()
         c.execute(
-            '''CREATE TABLE identifiables (digest text primary key, caosdb_id integer)''')
+            '''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''')
+        c.execute(
+            '''CREATE TABLE version (schema INTEGER)''')
+        c.execute("INSERT INTO version VALUES (?)", (CACHE_SCHEMA_VERSION,))
+        
         conn.commit()
         conn.close()
 
diff --git a/unittests/test_cache.py b/unittests/test_cache.py
index 985ac15c..662f27a3 100644
--- a/unittests/test_cache.py
+++ b/unittests/test_cache.py
@@ -24,16 +24,19 @@ import os
 import unittest
 from copy import deepcopy
 from tempfile import NamedTemporaryFile
+import sqlite3
 
 import caosdb as db
 
 from caosadvancedtools.cache import Cache
 
+import pytest
+
 
 class CacheTest(unittest.TestCase):
     def setUp(self):
-        self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name)
-        self.cache.create_cache()
+        self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name,
+                           force_creation)
 
     def test_hash(self):
         ent = db.Record()
@@ -75,3 +78,84 @@ class CacheTest(unittest.TestCase):
         self.cache.insert_list(hashes, entities)
         self.assertEqual(self.cache.check_existing(hashes[0]), 1001)
         self.assertEqual(self.cache.check_existing(hashes[2]), 1003)
+
+def create_sqlite_file(commands):
+    """
+    A temporary file will be used
+    commands: list of sql commands (tuples) to execute after creation
+    Name of the file is returned
+    """
+    db_file = NamedTemporaryFile(delete=False).name
+    conn = sqlite3.connect(db_file)
+    c = conn.cursor()
+    for sql in commands:
+        c.execute(*sql)
+    conn.commit()
+    conn.close()
+    return db_file
+        
+class CacheTest2(unittest.TestCase):
+    """
+    Test the schema version.
+    """
+    
+    def setUp(self):
+        # Correct version:
+        self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name,
+                           force_creation=True)
+
+        self.db_file_defect = []
+        self.db_file_defect.extend([
+            # Version without version table (old version):
+            create_sqlite_file(
+                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER)''',)]),
+            # Version with version table with wrong version:
+            create_sqlite_file(
+                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
+                 ('''CREATE TABLE version (schema INTEGER)''',),
+                 ("INSERT INTO version VALUES (?)", (1,))]),
+            # Version with version table with wrong version:
+            create_sqlite_file(
+                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
+                 ('''CREATE TABLE version (schema INTEGER)''',),
+                 ("INSERT INTO version VALUES (?)", (3,))]),
+            # Version with version table with missing version:
+            create_sqlite_file(
+                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
+                 ('''CREATE TABLE version (schema INTEGER)''',)]),
+            # Version with version table with too many versions:
+            create_sqlite_file(
+                [('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',),
+                 ('''CREATE TABLE version (schema INTEGER)''',),
+                 ("INSERT INTO version VALUES (?)", (1,)),
+                 ("INSERT INTO version VALUES (?)", (3,))])])
+        
+
+
+    def test_schema(self):
+        # Test whether new cache is created correctly:
+        assert os.path.exists(self.cache.db_file)
+        # Test whether it can be opened
+        test_cache_2 = Cache(db_file=self.cache.db_file)
+        assert test_cache_2.get_cache_version() == 2
+
+        with pytest.raises(RuntimeError, match="Cache version too old.") as e_info:
+            test_cache_2 = Cache(db_file=self.db_file_defect[0])
+
+        with pytest.raises(RuntimeError, match="Cache version too old.") as e_info:
+            test_cache_2 = Cache(db_file=self.db_file_defect[1])
+
+        with pytest.raises(RuntimeError, match=r".*future version.*") as e_info:
+            test_cache_2 = Cache(db_file=self.db_file_defect[2])
+
+        with pytest.raises(RuntimeError, match=r".*table broken.*") as e_info:
+            test_cache_2 = Cache(db_file=self.db_file_defect[3])
+
+        with pytest.raises(RuntimeError, match=r".*table broken.*") as e_info:
+            test_cache_2 = Cache(db_file=self.db_file_defect[4])
+
+    def tearDown(self):
+        os.remove(self.cache.db_file)
+
+        for db_fn_defect in self.db_file_defect:
+            os.remove(db_fn_defect)
-- 
GitLab