diff --git a/src/caoscrawler/crawl.py b/src/caoscrawler/crawl.py index 3385a1b5a82ed6b577836422411dfb4e0e73efbc..3b8ba58d7fee9a30ed55665ca95334745627ba8b 100644 --- a/src/caoscrawler/crawl.py +++ b/src/caoscrawler/crawl.py @@ -555,7 +555,7 @@ class Crawler(object): return False @staticmethod - def create_flat_list(ent_list: list[db.Entity], flat: set[db.Entity]): + def create_flat_list(ent_list: list[db.Entity], flat: Optional[set[db.Entity]]): """ Recursively adds entities and all their properties contained in ent_list to the output set flat. @@ -563,6 +563,8 @@ class Crawler(object): TODO: This function will be moved to pylib as it is also needed by the high level API. """ + if flat is None: + flat = set() flat.update(ent_list) for ent in ent_list: for p in ent.properties: @@ -577,6 +579,7 @@ class Crawler(object): if p.value not in flat: flat.add(p.value) Crawler.create_flat_list([p.value], flat) + return list(flat) def _has_missing_object_in_references(self, ident: Identifiable, referencing_entities: list): """ @@ -747,8 +750,7 @@ class Crawler(object): to_be_inserted: list[db.Entity] = [] to_be_updated: list[db.Entity] = [] flat = set() - Crawler.create_flat_list(ent_list, flat) - flat = list(flat) + flat = Crawler.create_flat_list(ent_list, flat) # TODO: can the following be removed at some point for ent in flat: diff --git a/unittests/test_tool.py b/unittests/test_tool.py index ff8998b94d352727b284bd09abdcdcd8c3f80646..e73477e56fcacc841aa4f4acbc058cd470b7b042 100755 --- a/unittests/test_tool.py +++ b/unittests/test_tool.py @@ -713,8 +713,7 @@ def test_create_flat_list(): b = db.Record() a.add_property(name="a", value=a) a.add_property(name="b", value=b) - flat = set() - Crawler.create_flat_list([a], flat) + flat = Crawler.create_flat_list([a]) assert len(flat) == 2 assert a in flat assert b in flat @@ -722,8 +721,7 @@ def test_create_flat_list(): c.add_property(name="a", value=a) # This would caus recursion if it is not dealt with properly. a.add_property(name="c", value=c) - flat = set() - Crawler.create_flat_list([c], flat) + flat = Crawler.create_flat_list([c]) assert len(flat) == 3 assert a in flat assert b in flat