From 653ae056934597c456f4d94ded03c12ac6e4b02c Mon Sep 17 00:00:00 2001
From: Joscha Schmiedt <joscha@schmiedt.dev>
Date: Thu, 4 Apr 2024 22:55:49 +0200
Subject: [PATCH] WIP: Fixing mypy errors

---
 src/linkahead/common/models.py | 227 +++++++++++++++++++++++----------
 src/linkahead/exceptions.py    |   3 +
 2 files changed, 161 insertions(+), 69 deletions(-)

diff --git a/src/linkahead/common/models.py b/src/linkahead/common/models.py
index 2828b1ad..16243a06 100644
--- a/src/linkahead/common/models.py
+++ b/src/linkahead/common/models.py
@@ -54,6 +54,9 @@ if TYPE_CHECKING and sys.version_info > (3, 7):
     from datetime import datetime
     from typing import Any, Dict, Optional, Type, Union, List, TextIO, Tuple, Literal
     from .datatype import DATATYPE
+    from tempfile import _TemporaryFileWrapper
+    from io import BufferedWriter
+
 
 from warnings import warn
 
@@ -62,15 +65,26 @@ from lxml import etree
 from ..configuration import get_config
 from ..connection.connection import get_connection
 from ..connection.encode import MultipartParam, multipart_encode
-from ..exceptions import (AmbiguousEntityError, AuthorizationError,
-                          ConsistencyError, EmptyUniqueQueryError,
-                          EntityDoesNotExistError, EntityError,
-                          EntityHasNoDatatypeError, HTTPURITooLongError,
-                          LinkAheadConnectionError, LinkAheadException,
-                          MismatchingEntitiesError, PagingConsistencyError,
-                          QueryNotUniqueError, TransactionError,
-                          UniqueNamesError, UnqualifiedParentsError,
-                          UnqualifiedPropertiesError)
+from ..exceptions import (
+    AmbiguousEntityError,
+    AuthorizationError,
+    ConsistencyError,
+    EmptyUniqueQueryError,
+    EntityDoesNotExistError,
+    EntityError,
+    EntityHasNoAclError,
+    EntityHasNoDatatypeError,
+    HTTPURITooLongError,
+    LinkAheadConnectionError,
+    LinkAheadException,
+    MismatchingEntitiesError,
+    PagingConsistencyError,
+    QueryNotUniqueError,
+    TransactionError,
+    UniqueNamesError,
+    UnqualifiedParentsError,
+    UnqualifiedPropertiesError,
+)
 from .datatype import (
     BOOLEAN,
     DATETIME,
@@ -95,7 +109,8 @@ FIX = "FIX"
 ALL = "ALL"
 NONE = "NONE"
 if TYPE_CHECKING:
-    INHERITANCE = Literal["OBLIGATORY", "SUGGESTED", "RECOMMENDED", "FIX", "ALL", "NONE"]
+    INHERITANCE = Literal["OBLIGATORY", "SUGGESTED", "RECOMMENDED", "ALL", "NONE"]
+    IMPORTANCE = Literal["OBLIGATORY", "RECOMMENDED", "SUGGESTED", "FIX", "NONE"]
 
 SPECIAL_ATTRIBUTES = ["name", "role", "datatype", "description",
                       "id", "path", "checksum", "size", "value"]
@@ -387,6 +402,10 @@ class Entity:
             ACL will be revoked.
         """
         # @review Florian Spreckelsen 2022-03-17
+
+        if self.acl is None:
+            raise EntityHasNoAclError("This entity does not have an ACL (yet).")
+
         self.acl.grant(realm=realm, username=username, role=role,
                        permission=permission, priority=priority,
                        revoke_denial=revoke_denial)
@@ -429,6 +448,9 @@ class Entity:
             ACL will be revoked.
         """
         # @review Florian Spreckelsen 2022-03-17
+        if self.acl is None:
+            raise EntityHasNoAclError("This entity does not have an ACL (yet).")
+
         self.acl.deny(realm=realm, username=username, role=role,
                       permission=permission, priority=priority,
                       revoke_grant=revoke_grant)
@@ -452,12 +474,14 @@ class Entity:
             priority=priority)
 
     def is_permitted(self, permission: Permission, role: Optional[str] = None):
-        if role is None:
-            # pylint: disable=unsupported-membership-test
+        if self.acl is None:
+            raise EntityHasNoAclError("This entity does not have an ACL (yet).")
 
+        if role is None and self.permissions is not None:
+            # pylint: disable=unsupported-membership-test
             return permission in self.permissions
-        else:
-            self.acl.is_permitted(permission=permission)
+
+        self.acl.is_permitted(role, permission=permission)
 
     def get_all_messages(self) -> Messages:
         ret = Messages()
@@ -534,20 +558,26 @@ class Entity:
 
         """
 
-        if self.get_property(property_name) is None:
+        prop = self.get_property(property_name)
+        if prop is None:
             return self
-        if self.get_property(property_name).value is None:
+
+        property_value = prop.value
+
+        if property_value is None:
             remove_if_empty_afterwards = False
+
         empty_afterwards = False
-        if isinstance(self.get_property(property_name).value, list):
-            if value in self.get_property(property_name).value:
-                self.get_property(property_name).value.remove(value)
-                if self.get_property(property_name).value == []:
-                    self.get_property(property_name).value = None
+        if isinstance(property_value, list):
+            if value in property_value:
+                property_value.remove(value)
+                if property_value == []:
+                    property_value = None
                     empty_afterwards = True
-        elif self.get_property(property_name).value == value:
-            self.get_property(property_name).value = None
+        elif property_value == value:
+            property_value = None
             empty_afterwards = True
+
         if remove_if_empty_afterwards and empty_afterwards:
             self.remove_property(property_name)
 
@@ -576,10 +606,10 @@ class Entity:
         id: Optional[int] = None,
         name: Optional[str] = None,
         description: Optional[str] = None,
-        datatype: Optional[str] = None,
+        datatype: Optional[DATATYPE] = None,
         unit: Optional[str] = None,
-        importance: Optional[str] = None,
-        inheritance: Union[str, INHERITANCE, None] = None,
+        importance: Optional[IMPORTANCE] = None,
+        inheritance: Optional[INHERITANCE] = None,
     ) -> Entity:  # @ReservedAssignment
         """Add a property to this entity.
 
@@ -755,7 +785,7 @@ class Entity:
         parent: Union[Entity, int, str, None] = None,
         id: Optional[int] = None,
         name: Optional[str] = None,
-        inheritance: Union[INHERITANCE, str, None] = None,
+        inheritance: INHERITANCE = "NONE",
     ):  # @ReservedAssignment
         """Add a parent to this entity.
 
@@ -771,7 +801,7 @@ class Entity:
             Name of the parent entity. Ignored if `parent is not
             none`.
         inheritance : str, INHERITANCE
-            One of ``obligatory``, ``recommended``, ``suggested``, or ``fix``. Specifies the
+            One of ``obligatory``, ``recommended``, ``suggested``, or ``all``. Specifies the
             minimum importance which parent properties need to have to be inherited by this
             entity. If no `inheritance` is given, no properties will be inherited by the child.
             This parameter is case-insensitive.
@@ -1018,7 +1048,7 @@ out: List[Entity]
 
                     return p
         else:
-            raise ValueError("argument should be entity, int , string")
+            raise ValueError("pattern argument should be Entity, int or str")
 
         return None
 
@@ -1035,8 +1065,10 @@ out: List[Entity]
         """
         SPECIAL_SELECTORS = ["unit", "value", "description", "id", "name"]
 
-        if not isinstance(selector, (tuple, list)):
+        if isinstance(selector, str):
             selector = [selector]
+        elif isinstance(selector, tuple):
+            selector = list(selector)
 
         ref = self
 
@@ -1051,7 +1083,7 @@ out: List[Entity]
             special_selector = None
 
         # iterating through the entity tree according to the selector
-
+        prop: Optional[Property] = None
         for subselector in selector:
             # selector does not match the structure, we cannot get a
             # property of non-entity
@@ -1077,8 +1109,7 @@ out: List[Entity]
             else:
                 ref = prop
 
-        # if we saved a special selector before, apply it
-
+        # if we saved a special selector before, apply it                
         if special_selector is None:
             return prop.value
         else:
@@ -1195,7 +1226,7 @@ out: List[Entity]
     def to_xml(
         self,
         xml: Optional[etree._Element] = None,
-        add_properties: Optional[INHERITANCE] = ALL,
+        add_properties: INHERITANCE = "ALL",
         local_serialization: bool = False,
     ) -> etree._Element:
         """Generate an xml representation of this entity. If the parameter xml
@@ -1207,6 +1238,10 @@ out: List[Entity]
         @param xml: an xml element to which all attributes, parents,
             properties, and messages
             are to be added.
+
+        FIXME: Add documentation for the add_properties parameter.
+        FIXME: Add docuemntation for the local_serialization parameter.
+
         @return: xml representation of this entity.
         """
 
@@ -1849,7 +1884,12 @@ class Parent(Entity):
             self.set_flag("inheritance", inheritance)
         self.__affiliation = None
 
-    def to_xml(self, xml: Optional[etree._Element] = None, add_properties=None):
+    def to_xml(
+        self,
+        xml: Optional[etree._Element] = None,
+        add_properties: INHERITANCE = "NONE",
+        local_serialization: bool = False,
+    ):
         if xml is None:
             xml = etree.Element("Parent")
 
@@ -1919,11 +1959,20 @@ class Property(Entity):
                         datatype=datatype, value=value, role="Property")
         self.unit = unit
 
-    def to_xml(self, xml: Optional[etree._Element] = None, add_properties=ALL):
+    def to_xml(
+        self,
+        xml: Optional[etree._Element] = None,
+        add_properties: INHERITANCE = "ALL",
+        local_serialization: bool = False,
+    ):
         if xml is None:
             xml = etree.Element("Property")
 
-        return super(Property, self).to_xml(xml, add_properties)
+        return super(Property, self).to_xml(
+            xml=xml,
+            add_properties=add_properties,
+            local_serialization=local_serialization,
+        )
 
     def is_reference(self, server_retrieval=False):
         """Returns whether this Property is a reference
@@ -2028,7 +2077,7 @@ class RecordType(Entity):
         parent: Union[Entity, int, str, None] = None,
         id: Optional[int] = None,
         name: Optional[str] = None,
-        inheritance: Union[INHERITANCE, str, None] = OBLIGATORY,
+        inheritance: INHERITANCE = "OBLIGATORY",
     ):
         """Add a parent to this RecordType
 
@@ -2048,8 +2097,8 @@ class RecordType(Entity):
         name : str
             Name of the parent entity. Ignored if `parent is not
             none`.
-        inheritance : str, default OBLIGATORY
-            One of ``obligatory``, ``recommended``, ``suggested``, or ``fix``. Specifies the
+        inheritance : INHERITANCE, default OBLIGATORY
+            One of ``obligatory``, ``recommended``, ``suggested``, or ``all``. Specifies the
             minimum importance which parent properties need to have to be inherited by this
             entity. If no `inheritance` is given, no properties will be inherited by the child.
             This parameter is case-insensitive.
@@ -2075,12 +2124,18 @@ class RecordType(Entity):
     def to_xml(
         self,
         xml: Optional[etree._Element] = None,
-        add_properties: Optional[INHERITANCE] = ALL,
+        add_properties: INHERITANCE = "ALL",
+        local_serialization: bool = False,
     ) -> etree._Element:
         if xml is None:
             xml = etree.Element("RecordType")
 
-        return Entity.to_xml(self, xml, add_properties)
+        return Entity.to_xml(
+            self,
+            xml=xml,
+            add_properties=add_properties,
+            local_serialization=local_serialization,
+        )
 
 
 class Record(Entity):
@@ -2104,11 +2159,20 @@ class Record(Entity):
         Entity.__init__(self, name=name, id=id, description=description,
                         role="Record")
 
-    def to_xml(self, xml=None, add_properties=ALL):
+    def to_xml(
+        self,
+        xml: Optional[etree._Element] = None,
+        add_properties: INHERITANCE = "ALL",
+        local_serialization: bool = False,
+    ):
         if xml is None:
             xml = etree.Element("Record")
 
-        return Entity.to_xml(self, xml, add_properties=ALL)
+        return super().to_xml(
+            xml=xml,
+            add_properties=add_properties,
+            local_serialization=local_serialization,
+        )
 
 
 class File(Record):
@@ -2175,7 +2239,7 @@ class File(Record):
     def to_xml(
         self,
         xml: Optional[etree._Element] = None,
-        add_properties: Optional[INHERITANCE] = ALL,
+        add_properties: INHERITANCE = "ALL",
         local_serialization: bool = False,
     ):
         """Convert this file to an xml element.
@@ -2189,7 +2253,7 @@ class File(Record):
         return Entity.to_xml(self, xml=xml, add_properties=add_properties,
                              local_serialization=local_serialization)
 
-    def download(self, target=None):
+    def download(self, target: Optional[str] = None):
         """Download this file-entity's actual file from the file server. It
         will be stored to the target or will be hold as a temporary file.
 
@@ -2199,7 +2263,7 @@ class File(Record):
         self.clear_server_messages()
 
         if target:
-            file_ = open(target, 'wb')
+            file_: Union[BufferedWriter, _TemporaryFileWrapper] = open(target, "wb")
         else:
             file_ = NamedTemporaryFile(mode='wb', delete=False)
         checksum = File.download_from_path(file_, self.path)
@@ -2340,9 +2404,7 @@ class _Properties(list):
 
         return self
 
-    def to_xml(
-        self, add_to_element: etree._Element, add_properties: Union[str, INHERITANCE]
-    ):
+    def to_xml(self, add_to_element: etree._Element, add_properties: INHERITANCE):
         for p in self:
             importance = self._importance.get(p)
 
@@ -4152,11 +4214,17 @@ class ACI():
         if self.role is not None:
             e.set("role", self.role)
         else:
+            if self.username is None:
+                raise LinkAheadException(
+                    "An ACI must have either a role or a username."
+                )
             e.set("username", self.username)
 
             if self.realm is not None:
                 e.set("realm", self.realm)
         p = etree.Element("Permission")
+        if self.permission is None:
+            raise LinkAheadException("An ACI must have a permission.")
         p.set("name", self.permission)
         e.append(p)
 
@@ -4206,7 +4274,7 @@ class ACL():
             role = e.get("role")
             username = e.get("username")
             realm = e.get("realm")
-            priority = e.get("priority")
+            priority = self._get_boolean_priority(e.get("priority"))
 
             for p in e:
                 if p.tag == "Permission":
@@ -4564,12 +4632,18 @@ class Query():
 
         if isinstance(q, etree._Element):
             self.q = q.get("string")
-            self.results = int(q.get("results"))
+            results = q.get("results")
+            if results is None:
+                raise LinkAheadException(
+                    "The query result count is not available in the response."
+                )
+            self.results = int(results)
 
-            if q.get("cached") is None:
+            cached_value = q.get("cached")
+            if cached_value is None:
                 self.cached = False
             else:
-                self.cached = q.get("cached").lower() == "true"
+                self.cached = cached_value.lower() == "true"
             self.etag = q.get("etag")
 
             for m in q:
@@ -4883,8 +4957,12 @@ class Permissions():
 
         for e in xml:
             if e.tag == "Permission":
-                self._perms.add(Permission(name=e.get("name"),
-                                           description=e.get("description")))
+                name = e.get("name")
+                if name is None:
+                    raise LinkAheadException(
+                        "The permission element has no name attribute."
+                    )
+                self._perms.add(Permission(name=name, description=e.get("description")))
 
     def __contains__(self, p):
         if isinstance(p, Permission):
@@ -4917,15 +4995,18 @@ def parse_xml(xml: Union[str, etree._Element]):
 
 def _parse_single_xml_element(elem: etree._Element):
     classmap = {
-        'record': Record,
-        'recordtype': RecordType,
-        'property': Property,
-        'file': File,
-        'parent': Parent,
-        'entity': Entity}
+        "record": Record,
+        "recordtype": RecordType,
+        "property": Property,
+        "file": File,
+        "parent": Parent,
+        "entity": Entity,
+    }
 
     if elem.tag.lower() in classmap:
         klass = classmap.get(elem.tag.lower())
+        if klass is None:
+            raise LinkAheadException("No class for tag '{}' found.".format(elem.tag))
         entity = klass()
         Entity._from_xml(entity, elem)
 
@@ -4953,8 +5034,8 @@ def _parse_single_xml_element(elem: etree._Element):
         return Message(type='History', description=elem.get("transaction"))
     elif elem.tag.lower() == 'stats':
         counts = elem.find("counts")
-
-        return Message(type="Counts", description=None, body=counts.attrib)
+        attrib = str(counts.attrib) if counts is not None else None
+        return Message(type="Counts", description=None, body=attrib)
     elif elem.tag == "EntityACL":
         return ACL(xml=elem)
     elif elem.tag == "Permissions":
@@ -4962,11 +5043,19 @@ def _parse_single_xml_element(elem: etree._Element):
     elif elem.tag == "UserInfo":
         return UserInfo(xml=elem)
     elif elem.tag == "TimeZone":
-        return TimeZone(zone_id=elem.get("id"), offset=elem.get("offset"),
-                        display_name=elem.text.strip())
+        return TimeZone(
+            zone_id=elem.get("id"),
+            offset=elem.get("offset"),
+            display_name=elem.text.strip() if elem.text is not None else "",
+        )
     else:
-        return Message(type=elem.tag, code=elem.get(
-            "code"), description=elem.get("description"), body=elem.text)
+        code = elem.get("code")
+        return Message(
+            type=elem.tag,
+            code=int(code) if code is not None else None,
+            description=elem.get("description"),
+            body=elem.text,
+        )
 
 
 def _evaluate_and_add_error(parent_error: TransactionError, ent: Union[Entity, Container]):
@@ -5000,7 +5089,7 @@ def _evaluate_and_add_error(parent_error: TransactionError, ent: Union[Entity, C
 
             if err.code is not None:
                 if int(err.code) == 101:  # ent doesn't exist
-                    new_exc = EntityDoesNotExistError(entity=ent,
+                    new_exc: EntityError = EntityDoesNotExistError(entity=ent,
                                                       error=err)
                 elif int(err.code) == 110:  # ent has no data type
                     new_exc = EntityHasNoDatatypeError(entity=ent,
diff --git a/src/linkahead/exceptions.py b/src/linkahead/exceptions.py
index a6abe09e..e702dba4 100644
--- a/src/linkahead/exceptions.py
+++ b/src/linkahead/exceptions.py
@@ -353,6 +353,9 @@ class UnqualifiedPropertiesError(EntityError):
 
     """
 
+class EntityHasNoAclError(EntityError):
+    """This entity has no ACL (yet)."""
+
 
 class EntityDoesNotExistError(EntityError):
     """This entity does not exist."""
-- 
GitLab