From 93106b9efe95aa20c0912c196ded40fcad4c5231 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20tom=20W=C3=B6rden?= <h.tomwoerden@indiscale.com> Date: Wed, 16 Dec 2020 11:37:00 +0000 Subject: [PATCH] Suggestions for merge --- src/caosadvancedtools/cache.py | 69 +++++++++++++++++----------------- unittests/test_cache.py | 2 - 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py index 5dc4f847..eb9678ac 100644 --- a/src/caosadvancedtools/cache.py +++ b/src/caosadvancedtools/cache.py @@ -27,12 +27,11 @@ # server side? import os import sqlite3 -from hashlib import sha256 - -from lxml import etree from abc import ABC, abstractmethod +from hashlib import sha256 import caosdb as db +from lxml import etree def put_in_container(stuff): @@ -81,16 +80,17 @@ class AbstractCache(ABC): def __init__(self, db_file=None, force_creation=False): """ db_file: The path of the database file. - + if force_creation is set to True, the file will be created regardless of a file at the same path already exists. """ + if db_file is None: self.db_file = self.get_default_file_name() else: self.db_file = db_file - if force_creation: + if not os.path.exists(self.db_file) or force_creation: self.create_cache() else: self.check_cache() @@ -104,19 +104,15 @@ class AbstractCache(ABC): If it exists, but the schema is outdated, an exception will be raised. """ - print(self.db_file) - if not os.path.exists(self.db_file): - self.create_cache() - else: - try: - current_schema = self.get_cache_version() - except sqlite3.OperationalError: - current_schema = 1 - - if current_schema > self.get_cache_schema_version(): - raise RuntimeError("Cache is corrupt or was created with a future version of this program.") - elif current_schema < self.get_cache_schema_version(): - raise RuntimeError("Cache version too old.") + try: + current_schema = self.get_cache_version() + except sqlite3.OperationalError: + current_schema = 1 + + if current_schema > self.get_cache_schema_version(): + raise RuntimeError("Cache is corrupt or was created with a future version of this program.") + elif current_schema < self.get_cache_schema_version(): + raise RuntimeError("Cache version too old.") def get_cache_version(self): """ @@ -128,8 +124,10 @@ class AbstractCache(ABC): c = conn.cursor() c.execute("SELECT schema FROM version") version_row = c.fetchall() + if len(version_row) != 1: raise RuntimeError("Cache version table broken.") + return version_row[0][0] finally: conn.close() @@ -137,22 +135,25 @@ class AbstractCache(ABC): def run_sql_commands(self, commands, fetchall=False): """ Run a list of SQL commands on self.db_file. - + commands: list of sql commands (tuples) to execute fetchall: When True, run fetchall as last command and return the results. Otherwise nothing is returned. """ conn = sqlite3.connect(self.db_file) c = conn.cursor() + for sql in commands: c.execute(*sql) + if fetchall: results = c.fetchall() conn.commit() conn.close() + if fetchall: return results - + # TODO: A better name would be IdentifiablesCache class Cache(AbstractCache): @@ -250,9 +251,11 @@ class Cache(AbstractCache): """ # Check whether all entities have IDs and versions: + for ent in entities: if ent.id is None: raise RuntimeError("Entity has no ID.") + if ent.version is None or ent.version.id is None: raise RuntimeError("Entity has no version ID.") @@ -264,8 +267,7 @@ class Cache(AbstractCache): """ Runs through all entities stored in the cache and checks whether the version still matches the most recent version. - Non-matching entities will be removed - from the cache. + Non-matching entities will be removed from the cache. entities: When set to a db.Container or a list of Entities the IDs from the cache will not be retrieved from the CaosDB database, @@ -281,27 +283,24 @@ class Cache(AbstractCache): "SELECT caosdb_id, caosdb_version FROM identifiables", ())], True) if entities is None: - c = db.Container() - else: - c = entities - v = dict() - for c_id, c_version in res: - if entities is None: - c.append(db.Entity(id=c_id)) - v[c_id] = c_version - if entities is None: - c.retrieve() - + # TODO this might become a problem. If many entities are cached, + # then all of them are retrieved here... + entities = db.Container() + entities.extend([db.Entity(id=c_id) for c_id, _ in res]) + entities.retrieve() + + v = {c_id: c_version for c_id, c_version in res} invalidate_list = [] - for ent in c: + + for ent in entities: if ent.version.id != v[ent.id]: invalidate_list.append(ent.id) self.run_sql_commands([( "DELETE FROM identifiables WHERE caosdb_id IN ({})".format( ", ".join([str(caosdb_id) for caosdb_id in invalidate_list])), ())]) - + return invalidate_list diff --git a/unittests/test_cache.py b/unittests/test_cache.py index 0e38201f..53fc1174 100644 --- a/unittests/test_cache.py +++ b/unittests/test_cache.py @@ -190,8 +190,6 @@ class InvalidationTest(unittest.TestCase): def test_invalid(self): - assert len(self.cache.validate_cache()) == 0 - ent = db.Record() ent2 = db.Record() ent2.add_parent(name="Experiment") -- GitLab