From 24f2e39651f3a8be1fa06a56ffff9acb10347078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20tom=20W=C3=B6rden?= <h.tomwoerden@indiscale.com> Date: Mon, 16 Oct 2023 20:13:27 +0200 Subject: [PATCH] MAINT: refactor get_list_datatype --- src/linkahead/common/datatype.py | 25 ++++++++++++++++++------- unittests/test_datatype.py | 11 +++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/linkahead/common/datatype.py b/src/linkahead/common/datatype.py index 83284456..7bb0fb23 100644 --- a/src/linkahead/common/datatype.py +++ b/src/linkahead/common/datatype.py @@ -43,16 +43,27 @@ def LIST(datatype): return "LIST<" + str(datatype) + ">" -def get_list_datatype(datatype): +def get_list_datatype(datatype: str, strict: bool = False): """ returns the datatype of the elements in the list """ - if not isinstance(datatype, str): - return None - match = re.match("LIST(<|<)(?P<datatype>.*)(>|>)", datatype) + if not isinstance(datatype, str) or not datatype.lower().startswith("list"): + if strict: + raise ValueError(f"Not a list dtype: {datatype}") + else: + return None + pattern = r"^[Ll][Ii][Ss][Tt]((<|<)(?P<dtype1>.*)(>|>)|\((?P<dtype2>.*)\))$" + match = re.match(pattern, datatype) + + if match and "dtype1" in match.groupdict() and match.groupdict()["dtype1"] is not None: + return match.groupdict()["dtype1"] + + elif match and "dtype2" in match.groupdict() and match.groupdict()["dtype2"] is not None: + return match.groupdict()["dtype2"] - if match is not None: - return match.group("datatype") else: - return None + if strict: + raise ValueError(f"Not a list dtype: {datatype}") + else: + return None def is_list_datatype(datatype): diff --git a/unittests/test_datatype.py b/unittests/test_datatype.py index 5a5e82cc..838edc12 100644 --- a/unittests/test_datatype.py +++ b/unittests/test_datatype.py @@ -33,6 +33,17 @@ def test_list_utilites(): """Test for example if get_list_datatype works.""" dtype = db.LIST(db.INTEGER) assert datatype.get_list_datatype(dtype) == db.INTEGER + assert datatype.get_list_datatype("LIST(Person)") == "Person" + assert datatype.get_list_datatype("Person") is None + assert datatype.get_list_datatype("LIST[]") is None + with raises(ValueError): + datatype.get_list_datatype("LIST[]", strict=True) + with raises(ValueError): + datatype.get_list_datatype("Person", strict=True) + with raises(ValueError): + datatype.get_list_datatype(5, strict=True) + with raises(ValueError): + datatype.get_list_datatype("listlol", strict=True) def test_parsing_of_intger_list_values(): -- GitLab