From 93106b9efe95aa20c0912c196ded40fcad4c5231 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Henrik=20tom=20W=C3=B6rden?= <h.tomwoerden@indiscale.com>
Date: Wed, 16 Dec 2020 11:37:00 +0000
Subject: [PATCH] Suggestions for merge

---
 src/caosadvancedtools/cache.py | 69 +++++++++++++++++-----------------
 unittests/test_cache.py        |  2 -
 2 files changed, 34 insertions(+), 37 deletions(-)

diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py
index 5dc4f847..eb9678ac 100644
--- a/src/caosadvancedtools/cache.py
+++ b/src/caosadvancedtools/cache.py
@@ -27,12 +27,11 @@
 # server side?
 import os
 import sqlite3
-from hashlib import sha256
-
-from lxml import etree
 from abc import ABC, abstractmethod
+from hashlib import sha256
 
 import caosdb as db
+from lxml import etree
 
 
 def put_in_container(stuff):
@@ -81,16 +80,17 @@ class AbstractCache(ABC):
     def __init__(self, db_file=None, force_creation=False):
         """
         db_file: The path of the database file.
-        
+
         if force_creation is set to True, the file will be created
         regardless of a file at the same path already exists.
         """
+
         if db_file is None:
             self.db_file = self.get_default_file_name()
         else:
             self.db_file = db_file
 
-        if force_creation:
+        if not os.path.exists(self.db_file) or force_creation:
             self.create_cache()
         else:
             self.check_cache()
@@ -104,19 +104,15 @@ class AbstractCache(ABC):
 
         If it exists, but the schema is outdated, an exception will be raised.
         """
-        print(self.db_file)
-        if not os.path.exists(self.db_file):
-            self.create_cache()
-        else:
-            try:
-                current_schema = self.get_cache_version()
-            except sqlite3.OperationalError:
-                current_schema = 1
-                
-            if current_schema > self.get_cache_schema_version():
-                raise RuntimeError("Cache is corrupt or was created with a future version of this program.")
-            elif current_schema < self.get_cache_schema_version():
-                raise RuntimeError("Cache version too old.")
+        try:
+            current_schema = self.get_cache_version()
+        except sqlite3.OperationalError:
+            current_schema = 1
+
+        if current_schema > self.get_cache_schema_version():
+            raise RuntimeError("Cache is corrupt or was created with a future version of this program.")
+        elif current_schema < self.get_cache_schema_version():
+            raise RuntimeError("Cache version too old.")
 
     def get_cache_version(self):
         """
@@ -128,8 +124,10 @@ class AbstractCache(ABC):
             c = conn.cursor()
             c.execute("SELECT schema FROM version")
             version_row = c.fetchall()
+
             if len(version_row) != 1:
                 raise RuntimeError("Cache version table broken.")
+
             return version_row[0][0]
         finally:
             conn.close()
@@ -137,22 +135,25 @@ class AbstractCache(ABC):
     def run_sql_commands(self, commands, fetchall=False):
         """
         Run a list of SQL commands on self.db_file.
-        
+
         commands: list of sql commands (tuples) to execute
         fetchall: When True, run fetchall as last command and return the results.
                   Otherwise nothing is returned.
         """
         conn = sqlite3.connect(self.db_file)
         c = conn.cursor()
+
         for sql in commands:
             c.execute(*sql)
+
         if fetchall:
             results = c.fetchall()
         conn.commit()
         conn.close()
+
         if fetchall:
             return results
-    
+
 
 # TODO: A better name would be IdentifiablesCache
 class Cache(AbstractCache):
@@ -250,9 +251,11 @@ class Cache(AbstractCache):
         """
 
         # Check whether all entities have IDs and versions:
+
         for ent in entities:
             if ent.id is None:
                 raise RuntimeError("Entity has no ID.")
+
             if ent.version is None or ent.version.id is None:
                 raise RuntimeError("Entity has no version ID.")
 
@@ -264,8 +267,7 @@ class Cache(AbstractCache):
         """
         Runs through all entities stored in the cache and checks
         whether the version still matches the most recent version.
-        Non-matching entities will be removed
-        from the cache.
+        Non-matching entities will be removed from the cache.
 
         entities: When set to a db.Container or a list of Entities
                   the IDs from the cache will not be retrieved from the CaosDB database,
@@ -281,27 +283,24 @@ class Cache(AbstractCache):
             "SELECT caosdb_id, caosdb_version FROM identifiables", ())], True)
 
         if entities is None:
-            c = db.Container()
-        else:
-            c = entities
-        v = dict()
-        for c_id, c_version in res:
-            if entities is None:
-                c.append(db.Entity(id=c_id))
-            v[c_id] = c_version
-        if entities is None:
-            c.retrieve()
-        
+            # TODO: This might become a problem: if many entities are cached,
+            # all of them are retrieved from CaosDB at once here.
+            entities = db.Container()
+            entities.extend([db.Entity(id=c_id) for c_id, _ in res])
+            entities.retrieve()
+
+        v = {c_id: c_version for c_id, c_version in res}
 
         invalidate_list = []
-        for ent in c:
+
+        for ent in entities:
             if ent.version.id != v[ent.id]:
                 invalidate_list.append(ent.id)
 
         self.run_sql_commands([(
             "DELETE FROM identifiables WHERE caosdb_id IN ({})".format(
                 ", ".join([str(caosdb_id) for caosdb_id in invalidate_list])), ())])
-                
+
         return invalidate_list
 
 
diff --git a/unittests/test_cache.py b/unittests/test_cache.py
index 0e38201f..53fc1174 100644
--- a/unittests/test_cache.py
+++ b/unittests/test_cache.py
@@ -190,8 +190,6 @@ class InvalidationTest(unittest.TestCase):
 
 
     def test_invalid(self):
-        assert len(self.cache.validate_cache()) == 0
-        
         ent = db.Record()
         ent2 = db.Record()
         ent2.add_parent(name="Experiment")
-- 
GitLab