diff --git a/src/linkahead/common/models.py b/src/linkahead/common/models.py index b087496d1db9800b352f4b4cc916ea084b2c83f6..2dc7f7a0d46f2d5f0b6ee868a92f71858cf9b984 100644 --- a/src/linkahead/common/models.py +++ b/src/linkahead/common/models.py @@ -35,19 +35,18 @@ transactions. from __future__ import annotations # Can be removed with 3.10. from __future__ import print_function, unicode_literals -from enum import Enum import re import sys from builtins import str from copy import deepcopy +from enum import Enum from functools import cmp_to_key from hashlib import sha512 from os import listdir from os.path import isdir from random import randint from tempfile import NamedTemporaryFile - from typing import TYPE_CHECKING if TYPE_CHECKING and sys.version_info > (3, 7): @@ -57,7 +56,6 @@ if TYPE_CHECKING and sys.version_info > (3, 7): from tempfile import _TemporaryFileWrapper from io import BufferedWriter - from warnings import warn from lxml import etree @@ -155,7 +153,7 @@ class Entity: self.value = value self.messages = Messages() self.properties = _Properties() - self.parents = _ParentList() + self.parents = ParentList() self.path: Optional[str] = None self.file: Optional[File] = None self.unit: Optional[str] = None @@ -899,7 +897,7 @@ out: bool def get_parents(self): """Get all parents of this entity. - @return: _ParentList(list) + @return: ParentList(list) """ return self.parents @@ -2515,7 +2513,7 @@ class _Properties(list): raise KeyError(str(prop) + " not found.") -class _ParentList(list): +class ParentList(list): # TODO unclear why this class is private. Isn't it use full for users? def _get_entity_by_cuid(self, cuid): @@ -2532,10 +2530,15 @@ class _ParentList(list): return e raise KeyError("No entity with that cuid in this container.") - def __init__(self): - list.__init__(self) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self._element_by_name = dict() self._element_by_id = dict() + for el in self: + if el.name is not None: + self._element_by_name[el.name] = el + if el.id is not None: + self._element_by_name[str(el.id)] = el def extend(self, parents): self.append(parents) @@ -2596,6 +2599,16 @@ class _ParentList(list): return xml2str(xml) + def __contains__(self, parent): + missing = False + if parent.name is not None: + if parent.name not in self._element_by_name: + missing = True + if parent.id is not None: + if str(parent.id) not in self._element_by_id: + missing = True + return not missing + def remove(self, parent: Union[Entity, int, str]): if isinstance(parent, Entity): if parent in self: @@ -2637,6 +2650,13 @@ class _ParentList(list): raise KeyError(str(parent) + " not found.") +class _ParentList(ParentList): + def __init__(self, *args, **kwargs): + warnings.warn(DeprecationWarning("This class is depricated. Please use ParentList " + "(without underscore.")) + super().__init__(*args, **kwargs) + + class Messages(list): """This specialization of list stores error, warning, info, and other messages. The mentioned three messages types play a special role. diff --git a/unittests/test_entity.py b/unittests/test_entity.py index abf82f0a9b557cf9d1d2365e01fedaa4eae0c565..d48c9ad71ad709b296ad48f529e6e11aaef87791 100644 --- a/unittests/test_entity.py +++ b/unittests/test_entity.py @@ -22,14 +22,15 @@ # ** end header # """Tests for the Entity class.""" +import os # pylint: disable=missing-docstring import unittest -from lxml import etree -import os +import linkahead from linkahead import (INTEGER, Entity, Property, Record, RecordType, configure_connection) from linkahead.connection.mockup import MockUpServerConnection +from lxml import etree UNITTESTDIR = os.path.dirname(os.path.abspath(__file__)) @@ -97,3 +98,21 @@ class TestEntity(unittest.TestCase): # test whether the __role property of this object has explicitely been # set. self.assertEqual(getattr(entity, "_Entity__role"), "Record") + + +def test_parent_list(): + p1 = RecordType(name="A") + pl = linkahead.common.models.ParentList([p1]) + assert p1 in pl + assert RecordType(name="A") in pl + assert not RecordType(id=101) in pl + pl.append(RecordType(id=101)) + assert RecordType(name="A") in pl + assert RecordType(id=101) in pl + assert not RecordType(id=102) in pl + pl.append(RecordType(id=103, name='B')) + assert RecordType(name="A") in pl + assert RecordType(name="B") in pl + assert RecordType(id=101) in pl + assert RecordType(id=103) in pl + assert not RecordType(id=105, name="B") in pl