diff --git a/src/caoscrawler/crawl.py b/src/caoscrawler/crawl.py index a3983515ef503b571bfa344421465331a7d0e394..dd95f32e6d0c7d23eb7bef42ee5c9894579426d7 100644 --- a/src/caoscrawler/crawl.py +++ b/src/caoscrawler/crawl.py @@ -555,14 +555,15 @@ class Crawler(object): return False @staticmethod - def create_flat_list(ent_list: list[db.Entity], flat: list[db.Entity]): + def create_flat_list(ent_list: list[db.Entity], flat: set[db.Entity]): """ - Recursively adds all properties contained in entities from ent_list to - the output list flat. Each element will only be added once to the list. + Recursively adds the entities in ent_list and all entities referenced by their + properties to the output set flat. TODO: This function will be moved to pylib as it is also needed by the high level API. """ + flat.update(ent_list) for ent in ent_list: for p in ent.properties: # For lists append each element that is of type Entity to flat: @@ -570,11 +571,11 @@ class Crawler(object): for el in p.value: if isinstance(el, db.Entity): if el not in flat: - flat.append(el) + flat.add(el) Crawler.create_flat_list([el], flat) elif isinstance(p.value, db.Entity): if p.value not in flat: - flat.append(p.value) + flat.add(p.value) Crawler.create_flat_list([p.value], flat) def _has_missing_object_in_references(self, ident: Identifiable, referencing_entities: list): @@ -745,8 +746,7 @@ class Crawler(object): def split_into_inserts_and_updates(self, ent_list: list[db.Entity]): to_be_inserted: list[db.Entity] = [] to_be_updated: list[db.Entity] = [] - flat = list(ent_list) - # assure all entities are direct members TODO Can this be removed at some point?Check only?
+ flat = set() Crawler.create_flat_list(ent_list, flat) # TODO: can the following be removed at some point diff --git a/unittests/test_tool.py b/unittests/test_tool.py index 71180b17e22409bc2491a51d4cdd45ed6f4aa346..ff8998b94d352727b284bd09abdcdcd8c3f80646 100755 --- a/unittests/test_tool.py +++ b/unittests/test_tool.py @@ -710,8 +710,24 @@ def test_create_reference_mapping(): def test_create_flat_list(): a = db.Record() + b = db.Record() a.add_property(name="a", value=a) - Crawler.create_flat_list([a], []) + a.add_property(name="b", value=b) + flat = set() + Crawler.create_flat_list([a], flat) + assert len(flat) == 2 + assert a in flat + assert b in flat + c = db.Record() + c.add_property(name="a", value=a) + # This would cause infinite recursion if it is not dealt with properly. + a.add_property(name="c", value=c) + flat = set() + Crawler.create_flat_list([c], flat) + assert len(flat) == 3 + assert a in flat + assert b in flat + assert c in flat @pytest.fixture