From 32b3c6cb1f79a56426b7947990a40cf4ce3f70b1 Mon Sep 17 00:00:00 2001
From: Joscha Schmiedt <joscha@schmiedt.dev>
Date: Sat, 20 Apr 2024 21:39:50 +0200
Subject: [PATCH] Add type hints to apiutils.py

---
 src/linkahead/apiutils.py | 61 ++++++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 23 deletions(-)

diff --git a/src/linkahead/apiutils.py b/src/linkahead/apiutils.py
index e2ed0fac..d05c03d5 100644
--- a/src/linkahead/apiutils.py
+++ b/src/linkahead/apiutils.py
@@ -25,11 +25,11 @@
 """API-Utils: Some simplified functions for generation of records etc.
 
 """
-
+from __future__ import annotations
 import logging
 import warnings
 from collections.abc import Iterable
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union, Optional, Tuple
 
 from .common.datatype import is_reference
 from .common.models import (SPECIAL_ATTRIBUTES, Container, Entity, File,
@@ -47,12 +47,14 @@ class EntityMergeConflictError(LinkAheadException):
     """
 
 
-def new_record(record_type, name=None, description=None,
-               tempid=None, insert=False, **kwargs):
+def new_record(record_type: Union[str],
+               name: Optional[str] = None,
+               description: Optional[str] = None,
+               tempid: Optional[int] = None,
+               insert: bool = False, **kwargs) -> Record:
     """Function to simplify the creation of Records.
 
-    record_type: The name of the RecordType to use for this record.
-                 (ids should also work.)
+    record_type: The name of the RecordType to use for this record.                 
     name: Name of the new Record.
     kwargs: Key-value-pairs for the properties of this Record.
 
@@ -92,19 +94,19 @@ def new_record(record_type, name=None, description=None,
     return r
 
 
-def id_query(ids):
+def id_query(ids: List[int]) -> Container:
     warnings.warn("Please use 'create_id_query', which only creates"
                   "the string.", DeprecationWarning)
 
-    return execute_query(create_id_query(ids))
+    return execute_query(create_id_query(ids))  # type: ignore
 
 
-def create_id_query(ids):
+def create_id_query(ids: List[int]) -> str:
     return "FIND ENTITY WITH " + " OR ".join(
         ["ID={}".format(id) for id in ids])
 
 
-def get_type_of_entity_with(id_):
+def get_type_of_entity_with(id_: int):
     objs = retrieve_entities_with_ids([id_])
 
     if len(objs) == 0:
@@ -127,11 +129,11 @@ def get_type_of_entity_with(id_):
         return Entity
 
 
-def retrieve_entity_with_id(eid):
+def retrieve_entity_with_id(eid: int):
     return execute_query("FIND ENTITY WITH ID={}".format(eid), unique=True)
 
 
-def retrieve_entities_with_ids(entities):
+def retrieve_entities_with_ids(entities: List) -> Container:
     collection = Container()
     step = 20
 
@@ -175,7 +177,10 @@ def getCommitIn(folder):
     return get_commit_in(folder)
 
 
-def compare_entities(old_entity: Entity, new_entity: Entity, compare_referenced_records: bool = False):
+def compare_entities(old_entity: Entity,
+                     new_entity: Entity,
+                     compare_referenced_records: bool = False
+                     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Compare two entites.
 
     Return a tuple of dictionaries, the first index belongs to additional information for old
@@ -328,7 +333,7 @@ def compare_entities(old_entity: Entity, new_entity: Entity, compare_referenced_
     return (olddiff, newdiff)
 
 
-def empty_diff(old_entity: Entity, new_entity: Entity, compare_referenced_records: bool = False):
+def empty_diff(old_entity: Entity, new_entity: Entity, compare_referenced_records: bool = False) -> bool:
     """Check whether the `compare_entities` found any differences between
     old_entity and new_entity.
 
@@ -357,8 +362,12 @@ def empty_diff(old_entity: Entity, new_entity: Entity, compare_referenced_record
     return True
 
 
-def merge_entities(entity_a: Entity, entity_b: Entity, merge_references_with_empty_diffs=True,
-                   force=False, merge_id_with_resolved_entity: bool = False):
+def merge_entities(entity_a: Entity,
+                   entity_b: Entity,
+                   merge_references_with_empty_diffs=True,
+                   force=False,
+                   merge_id_with_resolved_entity: bool = False
+                   ) -> Entity:
     """Merge entity_b into entity_a such that they have the same parents and properties.
 
     datatype, unit, value, name and description will only be changed in entity_a
@@ -441,8 +450,12 @@ def merge_entities(entity_a: Entity, entity_b: Entity, merge_references_with_emp
                         if merge_id_with_resolved_entity is True and attribute == "value":
                             # Do a special check for the case of an id value on the
                             # one hand, and a resolved entity on the other side.
-                            this = entity_a.get_property(key).value
-                            that = entity_b.get_property(key).value
+                            prop_a = entity_a.get_property(key)
+                            assert prop_a is not None, f"Property {key} not found in entity_a"
+                            prop_b = entity_b.get_property(key)
+                            assert prop_b is not None, f"Property {key} not found in entity_b"
+                            this = prop_a.value
+                            that = prop_b.value
                             same = False
                             if isinstance(this, list) and isinstance(that, list):
                                 if len(this) == len(that):
@@ -465,11 +478,13 @@ def merge_entities(entity_a: Entity, entity_b: Entity, merge_references_with_emp
         else:
             # TODO: This is a temporary FIX for
             #       https://gitlab.indiscale.com/caosdb/src/caosdb-pylib/-/issues/105
-            entity_a.add_property(id=entity_b.get_property(key).id,
-                                  name=entity_b.get_property(key).name,
-                                  datatype=entity_b.get_property(key).datatype,
-                                  value=entity_b.get_property(key).value,
-                                  unit=entity_b.get_property(key).unit,
+            prop_b = entity_b.get_property(key)
+            assert prop_b is not None, f"Property {key} not found in entity_b"
+            entity_a.add_property(id=prop_b.id,
+                                  name=prop_b.name,
+                                  datatype=prop_b.datatype,
+                                  value=prop_b.value,
+                                  unit=prop_b.unit,
                                   importance=entity_b.get_importance(key))
             # entity_a.add_property(
             #     entity_b.get_property(key),
-- 
GitLab