diff --git a/CHANGELOG.md b/CHANGELOG.md index d950e343c9523d3df8a28b13533aa0c50364ada2..abca90fe2f0dc1ebc5a22b7aa83f182b6ff7d280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added ### +- Unauthorized inserts can now be cached. Note that the Crawler cannot postpone + inserts but the Cache has the functionality now. + ### Changed ### ### Deprecated ### diff --git a/integrationtests/test.sh b/integrationtests/test.sh index 700d88160b08652b0c5257d8ba819e277edb2971..36730cc948d308659f01f6153f86a917ab1909d0 100755 --- a/integrationtests/test.sh +++ b/integrationtests/test.sh @@ -17,7 +17,7 @@ OUT=/tmp/crawler.output ls cat pycaosdb.ini python3 -c "import caosdb; print('CaosDB Version:', caosdb.__version__)" -rm -rf cache.db +rm -rf /tmp/caosdb_identifiable_cache.db set -e echo "Clearing database" python3 clear_database.py diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py index 993868c48e1f88373cca8475ee832aeee9999545..9a3fe0fd975eeb393d78f5f0ece086607ef71afb 100644 --- a/src/caosadvancedtools/cache.py +++ b/src/caosadvancedtools/cache.py @@ -23,8 +23,8 @@ # # ** end header -# TODO this is implementing a cache on client side. Should it be on -# server side? +# Note: This is implementing a cache on client side. It would be great if the server would provide +# something to replace this. import os import sqlite3 from copy import deepcopy @@ -68,6 +68,25 @@ def get_pretty_xml(cont): class AbstractCache(ABC): + def __init__(self, db_file=None, force_creation=False): + """ + db_file: The path of the database file. + + If force_creation is set to True, the file will be created + regardless of whether a file at the same path already exists. 
+ """ + + if db_file is None: + tmppath = tempfile.gettempdir() + self.db_file = os.path.join(tmppath, self.get_default_file_name()) + else: + self.db_file = db_file + + if not os.path.exists(self.db_file) or force_creation: + self.create_cache() + else: + self.check_cache() + @abstractmethod def get_cache_schema_version(self): """ @@ -93,24 +112,6 @@ class AbstractCache(ABC): """ pass - def __init__(self, db_file=None, force_creation=False): - """ - db_file: The path of the database file. - - if force_creation is set to True, the file will be created - regardless of a file at the same path already exists. - """ - - if db_file is None: - self.db_file = self.get_default_file_name() - else: - self.db_file = db_file - - if not os.path.exists(self.db_file) or force_creation: - self.create_cache() - else: - self.check_cache() - def check_cache(self): """ Check whether the cache in db file self.db_file exists and conforms @@ -172,7 +173,6 @@ class AbstractCache(ABC): return results -# TODO: A better name would be IdentifiablesCache class IdentifiableCache(AbstractCache): """ stores identifiables (as a hash of xml) and their respective ID. @@ -185,7 +185,7 @@ class IdentifiableCache(AbstractCache): return 2 def get_default_file_name(self): - return "cache.db" + return "caosdb_identifiable_cache.db" def __init__(self, db_file=None, force_creation=False): super().__init__(db_file, force_creation) @@ -198,8 +198,6 @@ class IdentifiableCache(AbstractCache): - identifiables is the actual cache. - version is a table with version information about the cache. 
""" - conn = sqlite3.connect(self.db_file) - c = conn.cursor() self.run_sql_commands([ ('''CREATE TABLE identifiables (digest TEXT PRIMARY KEY, caosdb_id INTEGER, caosdb_version TEXT)''',), ('''CREATE TABLE version (schema INTEGER)''',), @@ -325,25 +323,18 @@ class IdentifiableCache(AbstractCache): class UpdateCache(AbstractCache): """ - stores unauthorized updates + stores unauthorized inserts and updates - If the Guard is set to a mode that does not allow an update, the update can - be stored in this cache such that it can be authorized and done later. + If the Guard is set to a mode that does not allow an insert or update, the insert or update can + be stored in this cache such that it can be authorized and performed later. """ def get_cache_schema_version(self): - return 1 + return 3 def get_default_file_name(self): return "/tmp/crawler_update_cache.db" - def __init__(self, db_file=None, force_creation=False): - if db_file is None: - tmppath = tempfile.gettempdir() - tmpf = os.path.join(tmppath, "crawler_update_cache.db") - db_file = tmpf - super().__init__(db_file=db_file, force_creation=force_creation) - @staticmethod def get_previous_version(cont): """ Retrieve the current, unchanged version of the entities that shall @@ -357,59 +348,75 @@ class UpdateCache(AbstractCache): return old_ones - def insert(self, cont, run_id): - """Insert a pending, unauthorized update + def insert(self, cont, run_id, insert=False): + """Insert a pending, unauthorized insert or update Parameters ---------- - cont: Container with the records to be updated containing the desired + cont: Container with the records to be inserted or updated containing the desired version, i.e. the state after the update. run_id: int The id of the crawler run + insert: bool + Whether the entities in the container shall be inserted or updated. 
""" cont = put_in_container(cont) - old_ones = UpdateCache.get_previous_version(cont) + + if insert: + old_ones = "" + else: + old_ones = UpdateCache.get_previous_version(cont) new_ones = cont - old_hash = Cache.hash_entity(old_ones) + if insert: + old_hash = "" + else: + old_hash = Cache.hash_entity(old_ones) new_hash = Cache.hash_entity(new_ones) - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''INSERT INTO updates VALUES (?, ?, ?, ?, ?)''', - (old_hash, new_hash, str(old_ones), str(new_ones), - str(run_id))) - conn.commit() - conn.close() + self.run_sql_commands([('''INSERT INTO updates VALUES (?, ?, ?, ?, ?)''', + (old_hash, new_hash, str(old_ones), str(new_ones), + str(run_id)))]) def create_cache(self): """ initialize the cache """ - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''CREATE TABLE updates (olddigest text, newdigest text, - oldrep text, newrep text, run_id text, - primary key (olddigest, newdigest, run_id))''') - conn.commit() - conn.close() + self.run_sql_commands([ + ('''CREATE TABLE updates (olddigest TEXT, newdigest TEXT, oldrep TEXT, + newrep TEXT, run_id TEXT, primary key (olddigest, newdigest, run_id))''',), + ('''CREATE TABLE version (schema INTEGER)''',), + ("INSERT INTO version VALUES (?)", (self.get_cache_schema_version(),))]) - def get_updates(self, run_id): + def get(self, run_id, querystring): """ returns the pending updates for a given run id Parameters: ----------- run_id: the id of the crawler run + querystring: the sql query """ - conn = sqlite3.connect(self.db_file) - c = conn.cursor() - c.execute('''Select * FROM updates WHERE run_id=?''', - (str(run_id),)) - res = c.fetchall() - conn.commit() - conn.close() + return self.run_sql_commands([(querystring, (str(run_id),))], fetchall=True) + + def get_inserts(self, run_id): + """ returns the pending updates for a given run id + + Parameters: + ----------- + run_id: the id of the crawler run + """ + + return self.get(run_id, '''Select 
* FROM updates WHERE olddigest='' AND run_id=?''') + + def get_updates(self, run_id): + """ returns the pending updates for a given run id + + Parameters: + ----------- + run_id: the id of the crawler run + """ - return res + return self.get(run_id, '''Select * FROM updates WHERE olddigest!='' AND run_id=?''') class Cache(IdentifiableCache): diff --git a/src/caosadvancedtools/crawler.py b/src/caosadvancedtools/crawler.py index 5affebe8b11a580d2f94771ac0de5ee8bea76ea0..085cd8d27f261644b38061d26fb10e37ac5465fd 100644 --- a/src/caosadvancedtools/crawler.py +++ b/src/caosadvancedtools/crawler.py @@ -209,28 +209,70 @@ class Crawler(object): run_id: the id of the crawler run """ cache = UpdateCache() + inserts = cache.get_inserts(run_id) + all_inserts = 0 + all_updates = 0 + for _, _, _, new, _ in inserts: + new_cont = db.Container() + new_cont = new_cont.from_xml(new) + new_cont.insert(unique=False) + logger.info("Successfully inserted {} records!".format(len(new_cont))) + all_inserts += len(new_cont) + logger.info("Finished with authorized updates.") + changes = cache.get_updates(run_id) for _, _, old, new, _ in changes: - current = db.Container() new_cont = db.Container() new_cont = new_cont.from_xml(new) + ids = [] + tmp = [] + update_incomplete = False + # remove duplicate entities + for el in new_cont: + if el.id not in ids: + ids.append(el.id) + tmp.append(el) + else: + update_incomplete = True + new_cont = tmp + if new[0].version: + valids = db.Container() + nonvalids = db.Container() + + for ent in new_cont: + remote_ent = db.Entity(id=ent.id).retrieve() + if ent.version == remote_ent.version: + valids.append(remote_ent) + else: + update_incomplete = True + nonvalids.append(remote_ent) + valids.update(unique=False) + logger.info("Successfully updated {} records!".format( + len(valids))) + logger.info("{} Records were not updated because the version in the server " + "changed!".format(len(nonvalids))) + all_updates += len(valids) + else: + current = 
db.Container() - for ent in new_cont: - current.append(db.execute_query("FIND {}".format(ent.id), - unique=True)) - current_xml = get_pretty_xml(current) + for ent in new_cont: + current.append(db.Entity(id=ent.id).retrieve()) + current_xml = get_pretty_xml(current) - # check whether previous version equals current version - # if not, the update must not be done + # check whether previous version equals current version + # if not, the update must not be done - if current_xml != old: - continue + if current_xml != old: + continue - new_cont.update(unique=False) - logger.info("Successfully updated {} records!".format( - len(new_cont))) + new_cont.update(unique=False) + logger.info("Successfully updated {} records!".format( + len(new_cont))) + all_updates += len(new_cont) + logger.info("Some updates could not be applied. Crawler has to rerun.") logger.info("Finished with authorized updates.") + return all_inserts, all_updates def collect_cfoods(self): """ diff --git a/unittests/test_update_cache.py b/unittests/test_update_cache.py index 4720f23de0b651b90e3b74ee13e06088462c5e31..8376da482b4828dd09de2ac6f3aca4fb9617c08d 100644 --- a/unittests/test_update_cache.py +++ b/unittests/test_update_cache.py @@ -42,8 +42,8 @@ class CacheTest(unittest.TestCase): return c def setUp(self): - self.cache = UpdateCache(db_file=NamedTemporaryFile(delete=False).name) - self.cache.create_cache() + self.cache = UpdateCache(db_file=NamedTemporaryFile(delete=False).name, + force_creation=True) self.run_id = "235234" def test_insert(self):