From 370aafff3800acef98d2e606dfda14d250236351 Mon Sep 17 00:00:00 2001
From: Joscha Schmiedt <joscha@schmiedt.dev>
Date: Wed, 10 Apr 2024 21:23:46 +0200
Subject: [PATCH] Add type hints to state.py

---
 src/linkahead/common/state.py | 86 +++++++++++++++++++++++------------
 1 file changed, 57 insertions(+), 29 deletions(-)

diff --git a/src/linkahead/common/state.py b/src/linkahead/common/state.py
index 82f314e8..42953b1a 100644
--- a/src/linkahead/common/state.py
+++ b/src/linkahead/common/state.py
@@ -19,11 +19,19 @@
 #
 # ** end header
 
+from __future__ import annotations  # Can be removed with 3.10.
 import copy
 from lxml import etree
 
+from typing import TYPE_CHECKING
+import sys
 
-def _translate_to_state_acis(acis):
+if TYPE_CHECKING and sys.version_info > (3, 7):
+    from typing import Optional
+    from linkahead.common.models import ACL, ACI
+
+
+def _translate_to_state_acis(acis: set[ACI]) -> set[ACI]:
     result = set()
     for aci in acis:
         aci = copy.copy(aci)
@@ -50,7 +58,13 @@ class Transition:
         A state name
     """
 
-    def __init__(self, name, from_state, to_state, description=None):
+    def __init__(
+        self,
+        name: Optional[str],
+        from_state: Optional[str],
+        to_state: Optional[str],
+        description: Optional[str] = None,
+    ):
         self._name = name
         self._from_state = from_state
         self._to_state = to_state
@@ -76,25 +90,29 @@ class Transition:
         return f'Transition(name="{self.name}", from_state="{self.from_state}", to_state="{self.to_state}", description="{self.description}")'
 
     def __eq__(self, other):
-        return (isinstance(other, Transition)
-                and other.name == self.name
-                and other.to_state == self.to_state
-                and other.from_state == self.from_state)
+        return (
+            isinstance(other, Transition)
+            and other.name == self.name
+            and other.to_state == self.to_state
+            and other.from_state == self.from_state
+        )
 
     def __hash__(self):
         return 23472 + hash(self.name) + hash(self.from_state) + hash(self.to_state)
 
     @staticmethod
-    def from_xml(xml):
-        to_state = [to.get("name") for to in xml
-                    if to.tag.lower() == "tostate"]
-        from_state = [from_.get("name") for from_ in xml
-                      if from_.tag.lower() == "fromstate"]
-        result = Transition(name=xml.get("name"),
-                            description=xml.get("description"),
-                            from_state=from_state[0] if from_state else None,
-                            to_state=to_state[0] if to_state else None)
-        return result
+    def from_xml(xml: etree._Element) -> "Transition":
+        to_state = [to.get("name")
+                    for to in xml if to.tag.lower() == "tostate"]
+        from_state = [
+            from_.get("name") for from_ in xml if from_.tag.lower() == "fromstate"
+        ]
+        return Transition(
+            name=xml.get("name"),
+            description=xml.get("description"),
+            from_state=from_state[0] if from_state else None,
+            to_state=to_state[0] if to_state else None,
+        )
 
 
 class State:
@@ -119,12 +137,12 @@ class State:
         All transitions which are available from this state (read-only)
     """
 
-    def __init__(self, model, name):
+    def __init__(self, model: str, name: str):
         self.name = name
         self.model = model
-        self._id = None
-        self._description = None
-        self._transitions = None
+        self._id: Optional[str] = None
+        self._description: Optional[str] = None
+        self._transitions: Optional[set[Transition]] = None
 
     @property
     def id(self):
@@ -139,9 +157,11 @@ class State:
         return self._transitions
 
     def __eq__(self, other):
-        return (isinstance(other, State)
-                and self.name == other.name
-                and self.model == other.model)
+        return (
+            isinstance(other, State)
+            and self.name == other.name
+            and self.model == other.model
+        )
 
     def __hash__(self):
         return hash(self.name) + hash(self.model)
@@ -164,7 +184,7 @@ class State:
         return xml
 
     @staticmethod
-    def from_xml(xml):
+    def from_xml(xml: etree._Element):
         """Create a new State instance from an xml Element.
 
         Parameters
@@ -176,23 +196,31 @@ class State:
         state : State
         """
         name = xml.get("name")
+        if name is None:
+            raise ValueError(f"State name is missing from xml:{str(xml)}")
         model = xml.get("model")
+        if model is None:
+            raise ValueError(f"State model is missing from xml:{str(xml)}")
         result = State(name=name, model=model)
         result._id = xml.get("id")
         result._description = xml.get("description")
-        transitions = [Transition.from_xml(t) for t in xml if t.tag.lower() ==
-                       "transition"]
+        transitions = [
+            Transition.from_xml(t) for t in xml if t.tag.lower() == "transition"
+        ]
         if transitions:
             result._transitions = set(transitions)
 
         return result
 
     @staticmethod
-    def create_state_acl(acl):
+    def create_state_acl(acl: ACL):
         from .models import ACL
+
         state_acl = ACL()
         state_acl._grants = _translate_to_state_acis(acl._grants)
         state_acl._denials = _translate_to_state_acis(acl._denials)
-        state_acl._priority_grants = _translate_to_state_acis(acl._priority_grants)
-        state_acl._priority_denials = _translate_to_state_acis(acl._priority_denials)
+        state_acl._priority_grants = _translate_to_state_acis(
+            acl._priority_grants)
+        state_acl._priority_denials = _translate_to_state_acis(
+            acl._priority_denials)
         return state_acl
-- 
GitLab