From 49a118d52be72a8c51506ad48f1b4d3d7afe9678 Mon Sep 17 00:00:00 2001
From: Alexander Schlemmer <alexander.schlemmer@ds.mpg.de>
Date: Tue, 24 Nov 2020 12:00:35 +0100
Subject: [PATCH] ENH: added cache validation before loading to crawler

---
 src/caosadvancedtools/crawler.py |  1 +
 unittests/test_cache.py          | 56 ++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/src/caosadvancedtools/crawler.py b/src/caosadvancedtools/crawler.py
index 6d706bb3..2877cce5 100644
--- a/src/caosadvancedtools/crawler.py
+++ b/src/caosadvancedtools/crawler.py
@@ -114,6 +114,7 @@ class Crawler(object):
 
         if self.use_cache:
             self.cache = Cache(db_file=cache_file)
+            self.cache.validate_cache()
 
     def iteritems(self):
         """ generates items to be crawled with an index"""
diff --git a/unittests/test_cache.py b/unittests/test_cache.py
index aa1655da..0e38201f 100644
--- a/unittests/test_cache.py
+++ b/unittests/test_cache.py
@@ -174,3 +174,59 @@ class CacheTest2(unittest.TestCase):
 
         for db_fn_defect in self.db_file_defect:
             os.remove(db_fn_defect)
+
+class InvalidationTest(unittest.TestCase):
+    """
+    Test invalidation of cache entries.
+    """
+    
+    def setUp(self):
+        # Correct version:
+        self.cache = Cache(db_file=NamedTemporaryFile(delete=False).name,
+                           force_creation=True)
+
+    def tearDown(self):
+        os.remove(self.cache.db_file)
+
+
+    def test_invalid(self):
+        assert len(self.cache.validate_cache()) == 0
+        
+        ent = db.Record()
+        ent2 = db.Record()
+        ent2.add_parent(name="Experiment")
+        ent3 = db.Record()
+        ent3.add_parent(name="Analysis")
+        ent.id = 117
+        ent2.id = 328
+        ent3.id = 224
+
+        ent.version = db.common.versioning.Version("a")
+        ent2.version = db.common.versioning.Version("b")
+        ent3.version = db.common.versioning.Version("a")
+
+        el = [ent, ent2, ent3]
+
+        for e in el:
+            self.cache.insert(Cache.hash_entity(e), e.id, e.version.id)
+
+        for e in el:
+            res = self.cache.check_existing(Cache.hash_entity(e))
+            assert e.id == res[0]
+            assert e.version.id == res[1]
+
+        ent2.version.id = "c"
+        ent3.version.id = "b"
+
+        for e in el[1:]:
+            res = self.cache.check_existing(Cache.hash_entity(e))
+            assert res is None
+
+        invalidated_entries = self.cache.validate_cache(el)
+        assert 328 in invalidated_entries
+        assert 224 in invalidated_entries
+        assert 117 not in invalidated_entries
+
+        res = self.cache.run_sql_commands([
+            ("SELECT * FROM identifiables", ())], fetchall=True)
+        assert len(res) == 1
-- 
GitLab