From 338c635af7725c41d62eae378177065a4b41fdfc Mon Sep 17 00:00:00 2001
From: Florian Spreckelsen <florian.spreckelsen@gmx.net>
Date: Tue, 12 Jan 2021 15:02:57 +0000
Subject: [PATCH] ENH: Support in-place updates in assure_object_is_in_list

---
 CHANGELOG.md                                  |  2 +
 integrationtests/test.sh                      |  4 +-
 integrationtests/test_assure_functions.py     | 93 +++++++++++++++++++
 src/caosadvancedtools/cfood.py                | 28 ++++--
 src/caosadvancedtools/scifolder/utils.py      |  8 +-
 src/caosadvancedtools/scifolder/withreadme.py |  3 +
 6 files changed, 127 insertions(+), 11 deletions(-)
 create mode 100644 integrationtests/test_assure_functions.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4aa6af8..f9519355 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,6 +40,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - instead of `get_entity`, type-specific functions are used in
   `cfood.py` when the type of the entity in question is known.
 - Logger is used instead of `print` for errors in `crawler.py`.
+- `caosadvancedtools.cfood.assure_object_is_in_list` conducts in-place
+  updates if no `to_be_updated` object is supplied.
 
 ### Deprecated ###
 
diff --git a/integrationtests/test.sh b/integrationtests/test.sh
index 74ca823f..a56b7584 100755
--- a/integrationtests/test.sh
+++ b/integrationtests/test.sh
@@ -11,6 +11,8 @@ echo "Testing caching"
 python3 -m pytest test_cache.py
 echo "Testing models"
 python3 -m pytest test_data_model.py
+echo "Testing cfood functionality"
+python3 -m pytest test_assure_functions.py
 
 echo "Filling the database"
 ./filldb.sh
@@ -34,7 +36,7 @@ RUN_ID=$(grep "run id:" $OUT | awk '{ print $NF }')
 echo $RUN_ID
 echo "run crawler again"
 echo "./crawl.py -a $RUN_ID /"
-./crawl.py -a $RUN_ID / > $OUT
+./crawl.py -a $RUN_ID / | tee  $OUT
 set +e
 if grep "There where unauthorized changes" $OUT
 then 
diff --git a/integrationtests/test_assure_functions.py b/integrationtests/test_assure_functions.py
new file mode 100644
index 00000000..56f9767a
--- /dev/null
+++ b/integrationtests/test_assure_functions.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python
+# encoding: utf-8
+#
+# ** header v3.0
+# This file is a part of the CaosDB Project.
+#
+# Copyright (C) 2021 University Medical Center Göttingen, Institute for Medical Informatics
+# Copyright (C) 2021 Florian Spreckelsen <florian.spreckelsen@med.uni-goettingen.de>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+# ** end header
+"""Integration tests for the `assure_...` functions from
+`caosadvancedtools.cfood`. They mainly test the in-place updates when
+no `to_be_updated` is specified.
+
+"""
+import caosdb as db
+
+from caosadvancedtools.cfood import (assure_object_is_in_list)
+from caosadvancedtools.guard import (global_guard, RETRIEVE, UPDATE)
+
+
+def setup_module():
+    """Delete all test entities."""
+    db.execute_query("FIND Test*").delete(raise_exception_on_error=False)
+
+
+def setup():
+    """Allow all updates and delete test data"""
+    global_guard.level = UPDATE
+    setup_module()
+
+
+def teardown():
+    """Reset guard level and delete test data."""
+    global_guard.level = RETRIEVE
+    setup_module()
+
+
+def test_assure_list_in_place():
+    """Test an in-place update with `assure_object_is_in_list`."""
+
+    int_list_prop = db.Property(name="TestIntListProperty",
+                                datatype=db.LIST(db.INTEGER)).insert()
+    rt1 = db.RecordType(name="TestType1").add_property(
+        name=int_list_prop.name).insert()
+    rec1 = db.Record(name="TestRecord1").add_parent(rt1)
+    rec1.add_property(name=int_list_prop.name, value=[1]).insert()
+
+    # Nothing should happen:
+    assure_object_is_in_list(1, rec1, int_list_prop.name, to_be_updated=None)
+    assert len(rec1.get_property(int_list_prop.name).value) == 1
+    assert 1 in rec1.get_property(int_list_prop.name).value
+
+    # Insertion should happen in-place
+    assure_object_is_in_list(2, rec1, int_list_prop.name, to_be_updated=None)
+    assert len(rec1.get_property(int_list_prop.name).value) == 2
+    assert 2 in rec1.get_property(int_list_prop.name).value
+
+    # Better safe than sorry -- test for reference properties, too.
+    ref_rt = db.RecordType(name="TestRefType").insert()
+    ref_rec1 = db.Record(name="TestRefRec1").add_parent(ref_rt).insert()
+    ref_rec2 = db.Record(name="TestRefRec2").add_parent(ref_rt).insert()
+    ref_rec3 = db.Record(name="TestRefRec3").add_parent(ref_rt).insert()
+    rt2 = db.RecordType(name="TestType2").add_property(
+        name=ref_rt.name, datatype=db.LIST(ref_rt.name)).insert()
+    rec2 = db.Record(name="TestRecord2").add_parent(rt2)
+    rec2.add_property(name=ref_rt.name, value=[ref_rec1],
+                      datatype=db.LIST(ref_rt.name)).insert()
+
+    # Again, nothing should happen
+    assure_object_is_in_list(ref_rec1, rec2, ref_rt.name, to_be_updated=None)
+    assert len(rec2.get_property(ref_rt.name).value) == 1
+    assert ref_rec1.id in rec2.get_property(ref_rt.name).value
+
+    # In-place update with two additional references
+    assure_object_is_in_list([ref_rec2, ref_rec3],
+                             rec2, ref_rt.name, to_be_updated=None)
+    assert len(rec2.get_property(ref_rt.name).value) == 3
+    assert ref_rec2.id in rec2.get_property(ref_rt.name).value
+    assert ref_rec3.id in rec2.get_property(ref_rt.name).value
diff --git a/src/caosadvancedtools/cfood.py b/src/caosadvancedtools/cfood.py
index 54a6b809..680f592b 100644
--- a/src/caosadvancedtools/cfood.py
+++ b/src/caosadvancedtools/cfood.py
@@ -8,6 +8,8 @@
 # Max-Planck-Institute for Dynamics and Self-Organization Göttingen
 # Copyright (C) 2019,2020 Henrik tom Wörden
 # Copyright (C) 2020 Florian Spreckelsen <f.spreckelsen@indiscale.com>
+# Copyright (C) 2021 University Medical Center Göttingen, Institute for Medical Informatics
+# Copyright (C) 2021 Florian Spreckelsen <florian.spreckelsen@med.uni-goettingen.de>
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as
@@ -349,19 +351,22 @@ class AbstractFileCFood(AbstractCFood):
 
 
 def assure_object_is_in_list(obj, containing_object, property_name,
-                             to_be_updated, datatype=None):
-    """
-    Checks whether `obj` is one of the values in the list property
+                             to_be_updated=None, datatype=None):
+    """Checks whether `obj` is one of the values in the list property
     `property_name` of the supplied entity  containing_object`.
 
-    If this is the case this function returns. Otherwise the entity is added to
-    the property `property_name` and the entity `containing_object` is added to
-    the supplied list to_be_updated in order to indicate, that the entity
-    `containing_object` should be updated.
+    If this is the case this function returns. Otherwise the entity is
+    added to the property `property_name` and the entity
+    `containing_object` is added to the supplied list to_be_updated in
+    order to indicate, that the entity `containing_object` should be
+    updated. If none is submitted the update will be conducted
+    in-place.
 
-    If the property is missing, it is added first and then the entity is added.
+    If the property is missing, it is added first and then the entity
+    is added/updated.
 
     If obj is a list, every element is added
+
     """
 
     if datatype is None:
@@ -409,7 +414,12 @@ def assure_object_is_in_list(obj, containing_object, property_name,
             update = True
 
     if update:
-        to_be_updated.append(containing_object)
+        if to_be_updated is not None:
+            to_be_updated.append(containing_object)
+        else:
+            get_ids_for_entities_with_names([containing_object])
+
+            guard.safe_update(containing_object)
 
 
 def assure_special_is(entity, value, kind, to_be_updated=None, force=False):
diff --git a/src/caosadvancedtools/scifolder/utils.py b/src/caosadvancedtools/scifolder/utils.py
index 3241764f..afa671af 100644
--- a/src/caosadvancedtools/scifolder/utils.py
+++ b/src/caosadvancedtools/scifolder/utils.py
@@ -17,6 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
+import logging
 import os
 from itertools import chain
 
@@ -28,6 +29,8 @@ from caosadvancedtools.utils import (find_records_that_reference_ids,
                                      return_field_or_property,
                                      string_to_person)
 
+logger = logging.getLogger("caosadvancedtools")
+
 
 def parse_responsibles(header):
     """
@@ -71,6 +74,7 @@ def get_files_referenced_by_field(globs, prefix="", final_glob=None):
             glob = os.path.normpath(glob)
 
         query_string = "FIND file which is stored at {}".format(glob)
+        logger.debug(query_string)
 
         el = db.execute_query(query_string)
 
@@ -126,13 +130,15 @@ def reference_records_corresponding_to_files(record, recordtypes, globs, path,
         files_in_folders = list(chain(*get_files_referenced_by_field(
             globs,
             prefix=os.path.dirname(path),
-            final_glob="**")))
+            final_glob="/**")))
         files = [f for f in directly_named_files + files_in_folders if
                  is_filename_allowed(f.path, recordtype=recordtype)]
+        logger.debug("Referenced files:\n" + str(files))
         entities = find_records_that_reference_ids(
             list(set([
                 fi.id for fi in files])),
             rt=recordtype)
+        logger.debug("Referencing entities:\n" + str(entities))
 
         if len(entities) == 0:
             continue
diff --git a/src/caosadvancedtools/scifolder/withreadme.py b/src/caosadvancedtools/scifolder/withreadme.py
index b3eb1095..c36fb736 100644
--- a/src/caosadvancedtools/scifolder/withreadme.py
+++ b/src/caosadvancedtools/scifolder/withreadme.py
@@ -86,6 +86,9 @@ def get_glob(field):
     if it is a dict, it must have either an include or a file key"""
     globs = []
 
+    if not isinstance(field, list):
+        field = [field]
+
     for value in field:
 
         if isinstance(value, dict) and INCLUDE.key in value:
-- 
GitLab