diff --git a/src/linkahead/common/state.py b/src/linkahead/common/state.py index 82f314e80191163f14a5c4babdd749f977f2901b..42953b1ad871ed6953248d93fd17b24b2485c5b5 100644 --- a/src/linkahead/common/state.py +++ b/src/linkahead/common/state.py @@ -19,11 +19,19 @@ # # ** end header +from __future__ import annotations # Can be removed with 3.10. import copy from lxml import etree +from typing import TYPE_CHECKING +import sys -def _translate_to_state_acis(acis): +if TYPE_CHECKING and sys.version_info > (3, 7): + from typing import Optional + from linkahead.common.models import ACL, ACI + + +def _translate_to_state_acis(acis: set[ACI]) -> set[ACI]: result = set() for aci in acis: aci = copy.copy(aci) @@ -50,7 +58,13 @@ class Transition: A state name """ - def __init__(self, name, from_state, to_state, description=None): + def __init__( + self, + name: Optional[str], + from_state: Optional[str], + to_state: Optional[str], + description: Optional[str] = None, + ): self._name = name self._from_state = from_state self._to_state = to_state @@ -76,25 +90,29 @@ class Transition: return f'Transition(name="{self.name}", from_state="{self.from_state}", to_state="{self.to_state}", description="{self.description}")' def __eq__(self, other): - return (isinstance(other, Transition) - and other.name == self.name - and other.to_state == self.to_state - and other.from_state == self.from_state) + return ( + isinstance(other, Transition) + and other.name == self.name + and other.to_state == self.to_state + and other.from_state == self.from_state + ) def __hash__(self): return 23472 + hash(self.name) + hash(self.from_state) + hash(self.to_state) @staticmethod - def from_xml(xml): - to_state = [to.get("name") for to in xml - if to.tag.lower() == "tostate"] - from_state = [from_.get("name") for from_ in xml - if from_.tag.lower() == "fromstate"] - result = Transition(name=xml.get("name"), - description=xml.get("description"), - from_state=from_state[0] if from_state else None, - to_state=to_state[0] if to_state else None) - return result + def from_xml(xml: etree._Element) -> "Transition": + to_state = [to.get("name") + for to in xml if to.tag.lower() == "tostate"] + from_state = [ + from_.get("name") for from_ in xml if from_.tag.lower() == "fromstate" + ] + return Transition( + name=xml.get("name"), + description=xml.get("description"), + from_state=from_state[0] if from_state else None, + to_state=to_state[0] if to_state else None, + ) class State: @@ -119,12 +137,12 @@ class State: All transitions which are available from this state (read-only) """ - def __init__(self, model, name): + def __init__(self, model: str, name: str): self.name = name self.model = model - self._id = None - self._description = None - self._transitions = None + self._id: Optional[str] = None + self._description: Optional[str] = None + self._transitions: Optional[set[Transition]] = None @property def id(self): @@ -139,9 +157,11 @@ class State: return self._transitions def __eq__(self, other): - return (isinstance(other, State) - and self.name == other.name - and self.model == other.model) + return ( + isinstance(other, State) + and self.name == other.name + and self.model == other.model + ) def __hash__(self): return hash(self.name) + hash(self.model) @@ -164,7 +184,7 @@ class State: return xml @staticmethod - def from_xml(xml): + def from_xml(xml: etree._Element): """Create a new State instance from an xml Element. Parameters @@ -176,23 +196,31 @@ class State: state : State """ name = xml.get("name") + if name is None: + raise ValueError(f"State name is missing from xml:{str(xml)}") model = xml.get("model") + if model is None: + raise ValueError(f"State model is missing from xml:{str(xml)}") result = State(name=name, model=model) result._id = xml.get("id") result._description = xml.get("description") - transitions = [Transition.from_xml(t) for t in xml if t.tag.lower() == - "transition"] + transitions = [ + Transition.from_xml(t) for t in xml if t.tag.lower() == "transition" + ] if transitions: result._transitions = set(transitions) return result @staticmethod - def create_state_acl(acl): + def create_state_acl(acl: ACL): from .models import ACL + state_acl = ACL() state_acl._grants = _translate_to_state_acis(acl._grants) state_acl._denials = _translate_to_state_acis(acl._denials) - state_acl._priority_grants = _translate_to_state_acis(acl._priority_grants) - state_acl._priority_denials = _translate_to_state_acis(acl._priority_denials) + state_acl._priority_grants = _translate_to_state_acis( + acl._priority_grants) + state_acl._priority_denials = _translate_to_state_acis( + acl._priority_denials) return state_acl