diff --git a/src/caoscrawler/exceptions.py b/src/caoscrawler/exceptions.py
index b271cf5bc77b9d709b11ef3e46b95755f71ec41c..8066c7f4d2197ac53280fe47d3adb82def310b62 100644
--- a/src/caoscrawler/exceptions.py
+++ b/src/caoscrawler/exceptions.py
@@ -31,7 +31,15 @@ class MissingReferencingEntityError(Exception):
 
 
 class ImpossibleMergeError(Exception):
-    pass
+    """Raised when two values of the same property cannot be merged.
+
+    ``pname`` is the name of the conflicting property, ``values`` the
+    conflicting values.
+    """
+    def __init__(self, *args, pname, values, **kwargs):
+        self.pname = pname
+        self.values = values
+        super().__init__(*args, **kwargs)
 
 
 class MissingIdentifyingProperty(Exception):
diff --git a/src/caoscrawler/identifiable_adapters.py b/src/caoscrawler/identifiable_adapters.py
index 2451db648d99ce2c09dd3d9ec1563fbbfa3a664a..43d4eeafde79a513450c7a3385d6c3b04ad4ea48 100644
--- a/src/caoscrawler/identifiable_adapters.py
+++ b/src/caoscrawler/identifiable_adapters.py
@@ -238,7 +238,7 @@ startswith: bool, optional
         return query_string[:-4]
 
     @abstractmethod
-    def get_registered_identifiable(self, record: db.Record):
+    def get_registered_identifiable(self, record: db.Entity):
         """
         Check whether an identifiable is registered for this record and return its definition.
         If there is no identifiable registered, return None.
@@ -274,7 +274,7 @@ startswith: bool, optional
                     refs.append(val)
         return refs
 
-    def get_identifiable(self, se: SyncNode, identifiable_backrefs):
+    def get_identifiable(self, se: SyncNode, identifiable_backrefs) -> Identifiable:
         """
         Retrieve the registered identifiable and fill the property values to create an
         identifiable.
@@ -312,7 +312,8 @@ startswith: bool, optional
                         raise MissingReferencingEntityError(
                             f"Could not find referencing entities of type(s): {prop.value}\n"
                             f"for registered identifiable:\n{se.registered_identifiable}\n"
-                            f"There were {len(identifiable_backrefs)} referencing entities to choose from.\n"
+                            f"There were {len(identifiable_backrefs)} referencing entities "
+                            f"to choose from.\n"
                             f"This error can also occur in case of merge conflicts in the referencing entities."
                         )
                     elif len([e.id for e in identifiable_backrefs if el.id is None]) > 0:
@@ -479,7 +480,7 @@ class LocalStorageIdentifiableAdapter(IdentifiableAdapter):
                 return False
         return True
 
-    def get_registered_identifiable(self, record: db.Record):
+    def get_registered_identifiable(self, record: db.Entity):
         identifiable_candidates = []
         for _, definition in self._registered_identifiables.items():
             if self.is_identifiable_for_record(definition, record):
@@ -599,7 +600,7 @@ class CaosDBIdentifiableAdapter(IdentifiableAdapter):
             return None
         return candidates[0]
 
-    def get_registered_identifiable(self, record: db.Record):
+    def get_registered_identifiable(self, record: db.Entity):
         """
         returns the registered identifiable for the given Record
 
diff --git a/src/caoscrawler/sync_graph.py b/src/caoscrawler/sync_graph.py
index 45059ec8b63573b29783128d155f050c43240834..89ad3fb08e262d88a5dfe47beaf55e046391fc2c 100644
--- a/src/caoscrawler/sync_graph.py
+++ b/src/caoscrawler/sync_graph.py
@@ -27,7 +27,7 @@ crawler.
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union, Callable
 
 import linkahead as db
 from linkahead.apiutils import (EntityMergeConflictError, compare_entities,
@@ -37,12 +37,14 @@ from linkahead.exceptions import EmptyUniqueQueryError
 
 from .exceptions import ImpossibleMergeError, MissingReferencingEntityError
 from .identifiable_adapters import IdentifiableAdapter
-from .sync_node import SyncNode
+from .identifiable import Identifiable
+from .sync_node import SyncNode, TempID
 
 logger = logging.getLogger(__name__)
 
 
-def _for_each_scalar_value(node: SyncNode, condition: callable, kind: str, value: Any = None):
+def _for_each_scalar_value(node: SyncNode, condition: Callable[[Any], bool], kind: str,
+                           value: Any = None):
     """ helper function that performs an action on each value element of each property of a node
 
     The action (remove or set) is performed on each property value of each property: in case on
@@ -68,12 +70,12 @@ def _for_each_scalar_value(node: SyncNode, condition: callable, kind: str, value
                 p.value = value(p.value)
 
 
-def _remove_each_scalar_value(node: SyncNode, condition: callable):
+def _remove_each_scalar_value(node: SyncNode, condition: Callable[[Any], bool]):
     """ "remove" version of _for_each_scalar_value """
     _for_each_scalar_value(node, condition, "remove")
 
 
-def _set_each_scalar_value(node: SyncNode, condition: callable, value: Any):
+def _set_each_scalar_value(node: SyncNode, condition: Callable[[Any], bool], value: Any):
     """ "set" version of _for_each_scalar_value """
     _for_each_scalar_value(node, condition, "set", value=value)
 
@@ -124,18 +126,18 @@ class SyncGraph():
     export_record_lists.
     """
 
-    def __init__(self, entities: List[db.Entity], identifiableAdapter: IdentifiableAdapter):
+    def __init__(self, entities: list[db.Entity], identifiableAdapter: IdentifiableAdapter):
         self.identifiableAdapter = identifiableAdapter
-        self._id_look_up: Dict[int, SyncNode] = {}
-        self._path_look_up: Dict[str, SyncNode] = {}
-        self._identifiable_look_up: Dict[str, SyncNode] = {}
-        self._missing: Dict[int, SyncNode] = {}
-        self._existing: Dict[int, SyncNode] = {}
-        self._nonidentifiable: Dict[int, SyncNode] = {}
+        self._id_look_up: dict[Union[int, TempID, str], SyncNode] = {}
+        self._path_look_up: dict[str, SyncNode] = {}
+        self._identifiable_look_up: dict[str, SyncNode] = {}
+        self._missing: dict[int, SyncNode] = {}
+        self._existing: dict[int, SyncNode] = {}
+        self._nonidentifiable: dict[int, SyncNode] = {}
         # entities that are missing get negative IDs to allow identifiable creation
         self._remote_missing_counter = -1
 
-        self.nodes: List[SyncNode] = []
+        self.nodes: list[SyncNode] = []
         self._initialize_nodes(entities)  # list of all SemanticEntities
         # list all SemanticEntities that have not yet been checked
         self.unchecked = list(self.nodes)
@@ -170,13 +172,13 @@ class SyncGraph():
             raise RuntimeError('Cannot update ID.\n'
                                f'It already is {node.id} and shall be set to {node_id}.')
         if node_id is None:
-            node_id = self._get_new_id()
+            node_id = TempID(self._get_new_id())
         node.id = node_id
         if node_id in self._id_look_up:
             self._merge_into(node, self._id_look_up[node.id])
         else:
             self._id_look_up[node.id] = node
-            if node.id < 0:
+            if isinstance(node.id, TempID):
                 self._mark_missing(node)
             else:
                 self._mark_existing(node)
@@ -270,6 +272,7 @@ class SyncGraph():
             candidate = self._identifiable_look_up[entity.identifiable.get_representation()]
             if candidate is not entity:
                 return candidate
+        return None
 
     def _get_new_id(self):
         self._remote_missing_counter -= 1
@@ -291,7 +294,7 @@ class SyncGraph():
         self._identifiable_look_up[node.identifiable.get_representation()] = node
 
     @staticmethod
-    def _sanity_check(entities: List[db.Entity]):
+    def _sanity_check(entities: list[db.Entity]):
         for ent in entities:
             if ent.role == "Record" and len(ent.parents) == 0:
                 raise RuntimeError(f"Records must have a parent.\n{ent}")
@@ -301,7 +304,7 @@ class SyncGraph():
                 self.forward_id_referenced_by[id(node)]))
 
     @staticmethod
-    def _create_flat_list(ent_list: List[db.Entity], flat: Optional[List[db.Entity]] = None):
+    def _create_flat_list(ent_list: list[db.Entity], flat: Optional[list[db.Entity]] = None):
         """
         Recursively adds entities and all their properties contained in ent_list to
         the output list flat.
@@ -331,7 +334,7 @@ class SyncGraph():
         return flat
 
     @staticmethod
-    def _create_reference_mapping(flat: List[SyncNode]):
+    def _create_reference_mapping(flat: list[SyncNode]):
         """
         TODO update docstring
         Create a dictionary of dictionaries of the form:
@@ -346,12 +349,12 @@ class SyncGraph():
         to them.
         """
         # TODO we need to treat children of RecordTypes somehow.
-        forward_references: Dict[str, set[SyncNode]] = {}
-        backward_references: Dict[str, set[SyncNode]] = {}
-        forward_id_references: Dict[str, set[SyncNode]] = {}
-        backward_id_references: Dict[str, set[SyncNode]] = {}
-        forward_id_referenced_by: Dict[str, set[SyncNode]] = {}
-        backward_id_referenced_by: Dict[str, set[SyncNode]] = {}
+        forward_references: dict[int, set[SyncNode]] = {}
+        backward_references: dict[int, set[SyncNode]] = {}
+        forward_id_references: dict[int, set[SyncNode]] = {}
+        backward_id_references: dict[int, set[SyncNode]] = {}
+        forward_id_referenced_by: dict[int, set[SyncNode]] = {}
+        backward_id_referenced_by: dict[int, set[SyncNode]] = {}
 
         # initialize with empty lists/dict
         for node in flat:
@@ -514,7 +517,7 @@ class SyncGraph():
         """ create initial set of SyncNodes from provided Entity list"""
         entities = self._create_flat_list(entities)
         self._sanity_check(entities)
-        se_lookup: Dict[str, SyncNode] = {}  # lookup: python id -> SyncNode
+        se_lookup: dict[int, SyncNode] = {}  # lookup: python id -> SyncNode
         for el in entities:
             self.nodes.append(SyncNode(
                 el,
@@ -545,8 +548,8 @@ class SyncGraph():
             self.set_id_of_node(other_node)
 
     def _mark_existing(self, node: SyncNode):
-        if node.id <= 0:
-            raise ValueError("ID must be positive for existing entities")
+        if isinstance(node.id, TempID):
+            raise ValueError("ID must be valid for existing entities, not TempID")
         self._existing[id(node)] = node
         self.unchecked.remove(node)
         # This is one of three cases that affect other nodes:
diff --git a/src/caoscrawler/sync_node.py b/src/caoscrawler/sync_node.py
index 7b80d5d5aaad6b6c8a708c19f4a8432f57c81be9..2669d47d0c7ec8b947db9fcc81c28293635dafae 100644
--- a/src/caoscrawler/sync_node.py
+++ b/src/caoscrawler/sync_node.py
@@ -27,13 +27,18 @@ from typing import Any, Dict, List, Optional, Union
 
 import linkahead as db
 import yaml
-from linkahead.common.models import _ParentList, _Properties
+from linkahead.common.models import _ParentList, _Properties, Parent
 
 from .exceptions import ImpossibleMergeError
+from .identifiable import Identifiable
 
 logger = logging.getLogger(__name__)
 
 
+class TempID(int):
+    """Marker type for placeholder IDs that SyncGraph assigns to nodes without an ID."""
+
+
 class SyncNode():
     """ represents the information of an Entity as it shall be created in LinkAhead
 
@@ -57,7 +62,7 @@ class SyncNode():
     def __init__(self, entity: db.Entity,
                  registered_identifiable: Optional[db.RecordType] = None) -> None:
         # db.Entity properties
-        self.id = entity.id
+        self.id: Union[int, TempID, str] = entity.id
         self.role = entity.role
         self.path = entity.path
         self.file = entity.file
@@ -66,9 +71,8 @@ class SyncNode():
         self.parents = _ParentList().extend(entity.parents)
         self.properties = _Properties().extend(entity.properties)
         # other members
-        self.identifiable = None
+        self.identifiable: Optional[Identifiable] = None
         self.registered_identifiable = registered_identifiable
-        self.other = []
 
     def update(self, other: SyncNode) -> None:
         """update this node with information of given ``other`` SyncNode.
@@ -118,11 +122,14 @@ class SyncNode():
         for p in self.parents:
             ent.add_parent(p)
         for p in self.properties:
-            if ent.get_property(p) is None:
+            if p is None:
+                continue
+            entval: Any = ent.get_property(p)
+            if entval is None:
                 ent.add_property(id=p.id, name=p.name, value=p.value)
             else:
+                entval = entval.value
                 unequal = False
-                entval = ent.get_property(p).value
                 pval = p.value
                 if isinstance(entval, list) != isinstance(pval, list):
                     unequal = True
@@ -145,13 +152,12 @@ class SyncNode():
                     logger.error("The Crawler is trying to create an entity,"
                                  " but there are have conflicting property values."
                                  f"Problematic Property: {p.name}\n"
-                                 f"First value:\n{ent.get_property(p).value}\n"
+                                 f"First value:\n{entval}\n"
                                  f"Second value:\n{p.value}\n"
                                  f"{self}"
                                  )
-                    ime = ImpossibleMergeError("Cannot merge Entities")
-                    ime.pname = p.name
-                    ime.values = (ent.get_property(p).value, p.value)
+                    ime = ImpossibleMergeError("Cannot merge Entities", pname=p.name,
+                                               values=(entval, p.value))
                     raise ime
         return ent
 
@@ -165,7 +171,7 @@ class SyncNode():
                              "parents": [el.name for el in self.parents]}, allow_unicode=True)
         res += "---------------------------------------------------\n"
         res += "properties:\n"
-        d = {}
+        d: dict[str, Any] = {}
         for p in self.properties:
             v = p.value
             d[p.name] = []
@@ -183,10 +189,11 @@ class SyncNode():
 
     def is_unidentifiable(self) -> bool:
         """returns whether this is an unidentifiable Node"""
-        return self.registered_identifiable.get_property("no-ident") is not None
+        return (self.registered_identifiable is not None and
+                self.registered_identifiable.get_property("no-ident") is not None)
 
 
-def parent_in_list(parent: db.Parent, plist: _ParentList) -> bool:
+def parent_in_list(parent: Parent, plist: _ParentList) -> bool:
     """helper function that checks whether a parent with the same name or ID is in the plist"""
     missing = False
     if parent.name is not None:
diff --git a/unittests/test_crawler.py b/unittests/test_crawler.py
index 29e6dac02bf58a167c8aa1adba2d87eaf857e02e..cf6129c39a62778c26b6ae78710bb79461c6e30f 100644
--- a/unittests/test_crawler.py
+++ b/unittests/test_crawler.py
@@ -410,8 +410,8 @@ def test_split_into_inserts_and_updates_with_copy_attr(crawler_mocked_identifiab
     crawler.identifiableAdapter.retrieve_identified_record_for_identifiable.assert_called()
 
 
-@ patch("caoscrawler.crawl.cached_get_entity_by",
-        new=Mock(side_effect=mock_get_entity_by))
+@patch("caoscrawler.crawl.cached_get_entity_by",
+       new=Mock(side_effect=mock_get_entity_by))
 @patch("caoscrawler.identifiable_adapters.cached_query",
        new=Mock(side_effect=mock_cached_only_rt))
 def test_split_iiau_with_unmergeable_list_items():