diff --git a/src/caosadvancedtools/cache.py b/src/caosadvancedtools/cache.py index 63ee5f6e24b03dc8a9b2ae6edd49ac840ef4fadb..620777bf882fc7a08f3d20975af5976a9ea61ffc 100644 --- a/src/caosadvancedtools/cache.py +++ b/src/caosadvancedtools/cache.py @@ -87,6 +87,7 @@ class Cache(object): If it exists, but the schema is outdated, an exception will be raised. """ + print(self.db_file) if not os.path.exists(self.db_file): self.create_cache() else: @@ -94,14 +95,17 @@ class Cache(object): current_schema = self.get_cache_version() except sqlite3.OperationalError: current_schema = 1 - # TODO: Write unit tests for too old, too new and non-existent version of cache. - + if current_schema > CACHE_SCHEMA_VERSION: raise RuntimeError("Cache is corrupt or was created with a future version of this program.") elif current_schema < CACHE_SCHEMA_VERSION: raise RuntimeError("Cache version too old.") def get_cache_version(self): + """ + Return the version of the cache stored in self.db_file. + The version is stored as the only entry in colum schema of table version. + """ try: conn = sqlite3.connect(self.db_file) c = conn.cursor() @@ -114,6 +118,13 @@ class Cache(object): conn.close() def create_cache(self): + """ + Create a new SQLITE cache file in self.db_file. + + Two tables will be created: + - identifiables is the actual cache. + - version is a table with version information about the cache. + """ conn = sqlite3.connect(self.db_file) c = conn.cursor() c.execute( @@ -127,23 +138,41 @@ class Cache(object): @staticmethod def hash_entity(ent): + """ + Format an entity as "pretty" XML and return the SHA256 hash. + """ xml = get_pretty_xml(ent) digest = sha256(xml.encode("utf-8")).hexdigest() return digest - def insert(self, ent_hash, ent_id): + def insert(self, ent_hash, ent_id, ent_version): + """ + Insert a new cache entry. + + ent_hash: Hash of the entity. Should be generated with Cache.hash_entity + ent_id: ID of the entity + ent_version: Version string of the entity + """ conn = sqlite3.connect(self.db_file) c = conn.cursor() - c.execute('''INSERT INTO identifiables VALUES (?, ?)''', - (ent_hash, ent_id)) + c.execute('''INSERT INTO identifiables VALUES (?, ?, ?)''', + (ent_hash, ent_id, ent_version)) conn.commit() conn.close() def check_existing(self, ent_hash): + """ + Check the cache for a hash. + + ent_hash: The hash to search for. + + Return the ID and the version ID of the hashed entity. + Return None if no entity with that hash is in the cache. + """ conn = sqlite3.connect(self.db_file) c = conn.cursor() - c.execute('''Select * FROM identifiables WHERE digest=?''', + c.execute('''Select * FROM identifiables WHERE digest=?''', (ent_hash,)) res = c.fetchone() conn.commit() @@ -152,7 +181,7 @@ class Cache(object): if res is None: return res else: - return res[1] + return res[1:] def update_ids_from_cache(self, entities): """ sets ids of those entities that are in cache @@ -167,7 +196,7 @@ class Cache(object): eid = self.check_existing(ehash) if eid is not None: - ent.id = eid + ent.id = eid[0] return hashes @@ -177,9 +206,16 @@ class Cache(object): The hashes must correspond to the entities in the list """ + # Check whether all entities have IDs and versions: + for ent in entities: + if ent.id is None: + raise RuntimeError("Entity has no ID.") + if ent.version is None or ent.version.id is None: + raise RuntimeError("Entity has no version ID.") + for ehash, ent in zip(hashes, entities): if self.check_existing(ehash) is None: - self.insert(ehash, ent.id) + self.insert(ehash, ent.id, ent.version.id) class UpdateCache(Cache): @@ -192,6 +228,8 @@ class UpdateCache(Cache): def __init__(self, db_file=None): if db_file is None: + # TODO: check whether a hardcoded temp file is really wanted + # Why not crawler_update_cache.db in current working directory? db_file = "/tmp/crawler_update_cache.db" super().__init__(db_file=db_file) diff --git a/unittests/test_cache.py b/unittests/test_cache.py index 662f27a3b57439a46249f7b6c607b4c571709be1..aa1655dae3bc8ca2eb89a5efe927a3073b6cb878 100644 --- a/unittests/test_cache.py +++ b/unittests/test_cache.py @@ -36,7 +36,7 @@ import pytest class CacheTest(unittest.TestCase): def setUp(self): self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name, - force_creation) + force_creation=True) def test_hash(self): ent = db.Record() @@ -50,8 +50,8 @@ class CacheTest(unittest.TestCase): ent2.add_parent(name="Experiment") ent_hash = Cache.hash_entity(ent) ent2_hash = Cache.hash_entity(ent2) - self.cache.insert(ent2_hash, 1235) - assert isinstance(self.cache.check_existing(ent2_hash), int) + self.cache.insert(ent2_hash, 1235, "ajkfljadsklf") + assert isinstance(self.cache.check_existing(ent2_hash)[0], int) assert self.cache.check_existing(ent_hash) is None def test_hirarchy(self): @@ -67,17 +67,32 @@ class CacheTest(unittest.TestCase): ent3 = db.Record() ent3.add_parent(name="Analysis") test_id = 2353243 - self.cache.insert(Cache.hash_entity(ent2), test_id) + self.cache.insert(Cache.hash_entity(ent2), test_id, "ajdsklfjadslf") entities = [ent, ent2, ent3] hashes = self.cache.update_ids_from_cache(entities) self.assertEqual(ent2.id, test_id) + # TODO: is that wanted? + self.assertEqual(ent.id, -1) + self.assertEqual(ent3.id, -1) + + # TODO: I expected this instead: + # with pytest.raises(RuntimeError, match=r".*no ID.*"): + # self.cache.insert_list(hashes, entities) + # test ent.id = 1001 ent3.id = 1003 + with pytest.raises(RuntimeError, match=r".*no version ID.*"): + self.cache.insert_list(hashes, entities) + + ent.version = db.common.versioning.Version("jkadsjfldf") + ent2.version = db.common.versioning.Version("jkadsjfldf") + ent3.version = db.common.versioning.Version("jkadsjfldf") + self.cache.insert_list(hashes, entities) - self.assertEqual(self.cache.check_existing(hashes[0]), 1001) - self.assertEqual(self.cache.check_existing(hashes[2]), 1003) + self.assertEqual(self.cache.check_existing(hashes[0])[0], 1001) + self.assertEqual(self.cache.check_existing(hashes[2])[0], 1003) def create_sqlite_file(commands): """