diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d42e0b26a2b493d210ff0a55105d70cfd6c90a3..ebdfab4bc64e8640207e7e678cedb4bd1698fb98 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed ### +- Detection for cyclic references when converting entites using the high level API. + ### Security ### ### Documentation ### diff --git a/src/caosdb/high_level_api.py b/src/caosdb/high_level_api.py index 005a20bbba26fd5bee16eac612bd8ebe81f1294a..82c63feaa0df7e6a57897f3596515387469fc64f 100644 --- a/src/caosdb/high_level_api.py +++ b/src/caosdb/high_level_api.py @@ -265,7 +265,8 @@ class CaosDBPythonEntity(object): self._version = val def _set_property_from_entity(self, ent: db.Entity, importance: str, - references: Optional[db.Container]): + references: Optional[db.Container], + visited: Dict[int, "CaosDBPythonEntity"]): """ Set a new property using an entity from the normal python API. @@ -280,7 +281,7 @@ class CaosDBPythonEntity(object): raise RuntimeError("Multiproperty not implemented yet.") val = self._type_converted_value(ent.value, ent.datatype, - references) + references, visited) self.set_property( ent.name, val, @@ -382,7 +383,8 @@ class CaosDBPythonEntity(object): def _type_converted_list(self, val: List, pr: str, - references: Optional[db.Container]): + references: Optional[db.Container], + visited: Dict[int, "CaosDBPythonEntity"]): """ Convert a list to a python list of the correct type. @@ -396,13 +398,14 @@ class CaosDBPythonEntity(object): raise RuntimeError("Not a list.") return [ - self._type_converted_value(i, get_list_datatype(pr), references - ) for i in val] + self._type_converted_value(i, get_list_datatype(pr), references, + visited) for i in val] def _type_converted_value(self, val: Any, pr: str, - references: Optional[db.Container]): + references: Optional[db.Container], + visited: Dict[int, "CaosDBPythonEntity"]): """ Convert val to the correct type which is indicated by the database type string in pr. @@ -416,9 +419,9 @@ class CaosDBPythonEntity(object): # this needs to be checked as second case as it is the ONLY # case which does not depend on pr # TODO: we might need to pass through the reference container - return convert_to_python_object(val, references) + return convert_to_python_object(val, references, visited) elif isinstance(val, list): - return self._type_converted_list(val, pr, references) + return self._type_converted_list(val, pr, references, visited) elif pr is None: return val elif pr == DOUBLE: @@ -436,7 +439,7 @@ class CaosDBPythonEntity(object): elif pr == DATETIME: return self._parse_datetime(val) elif is_list_datatype(pr): - return self._type_converted_list(val, pr, references) + return self._type_converted_list(val, pr, references, visited) else: # Generic references to entities: return CaosDBPythonUnresolvedReference(val) @@ -561,8 +564,8 @@ class CaosDBPythonEntity(object): return propval def resolve_references(self, deep: bool, references: db.Container, - visited: Dict[Union[str, int], - "CaosDBPythonEntity"] = None): + visited: Optional[Dict[Union[str, int], + "CaosDBPythonEntity"]] = None): """ Resolve this entity's references. This affects unresolved properties as well as unresolved parents. @@ -807,7 +810,9 @@ BASE_ATTRIBUTES = ( def _single_convert_to_python_object(robj: CaosDBPythonEntity, entity: db.Entity, - references: Optional[db.Container] = None): + references: Optional[db.Container] = None, + visited: Optional[Dict[int, + "CaosDBPythonEntity"]] = None): """ Convert a db.Entity from the standard API to a (previously created) CaosDBPythonEntity from the high level API. @@ -822,6 +827,17 @@ def _single_convert_to_python_object(robj: CaosDBPythonEntity, Returns the input object robj. """ + + # This parameter is used in the recursion to keep track of already visited + # entites (in order to detect cycles). + if visited is None: + visited = dict() + + if id(entity) in visited: + return visited[id(entity)] + else: + visited[id(entity)] = robj + for base_attribute in BASE_ATTRIBUTES: val = entity.__getattribute__(base_attribute) if val is not None: @@ -830,7 +846,8 @@ def _single_convert_to_python_object(robj: CaosDBPythonEntity, robj.__setattr__(base_attribute, val) for prop in entity.properties: - robj._set_property_from_entity(prop, entity.get_importance(prop), references) + robj._set_property_from_entity(prop, entity.get_importance(prop), references, + visited) for parent in entity.parents: robj.add_parent(CaosDBPythonUnresolvedParent(id=parent.id, @@ -924,7 +941,9 @@ def convert_to_entity(python_object): def convert_to_python_object(entity: Union[db.Container, db.Entity], - references: Optional[db.Container] = None): + references: Optional[db.Container] = None, + visited: Optional[Dict[int, + "CaosDBPythonEntity"]] = None): """ Convert either a container of CaosDB entities or a single CaosDB entity into the high level representation. @@ -936,15 +955,19 @@ def convert_to_python_object(entity: Union[db.Container, db.Entity], """ if isinstance(entity, db.Container): # Create a list of objects: - return [convert_to_python_object(i, references) for i in entity] + return [convert_to_python_object(i, references, visited) for i in entity] + # TODO: recursion problems? return _single_convert_to_python_object( - high_level_type_for_standard_type(entity)(), entity, references) + high_level_type_for_standard_type(entity)(), + entity, + references, + visited) def new_high_level_entity(entity: db.RecordType, importance_level: str, - name: str = None): + name: Optional[str] = None): """ Create an new record in high level format based on a record type in standard format. @@ -977,7 +1000,7 @@ def new_high_level_entity(entity: db.RecordType, return convert_to_python_object(r) -def create_record(rtname: str, name: str = None, **kwargs): +def create_record(rtname: str, name: Optional[str] = None, **kwargs): """ Create a new record based on the name of a record type. The new record is returned. @@ -1016,7 +1039,9 @@ def create_entity_container(record: CaosDBPythonEntity): return db.Container().extend(lse) -def query(query: str, resolve_references: bool = True, references: db.Container = None): +def query(query: str, + resolve_references: Optional[bool] = True, + references: Optional[db.Container] = None): """ """ diff --git a/unittests/test_high_level_api.py b/unittests/test_high_level_api.py index 51993b78b700236618daef2f07dbe754121384c4..ea5e635eadaa849480de5f3ece10b813a538a1b0 100644 --- a/unittests/test_high_level_api.py +++ b/unittests/test_high_level_api.py @@ -641,3 +641,14 @@ def test_recursion_advanced(get_record_container): r.resolve_references(r, get_record_container) d = r.serialize(True) assert r == r.sources[0] + + +def test_cyclic_references(): + r1 = db.Record() + r2 = db.Record() + r1.add_property(name="ref_to_two", value=r2) + r2.add_property(name="ref_to_one", value=r1) + + # This would have lead to a recursion error before adding the detection for + # cyclic references: + r = convert_to_python_object(r1)