From 24f2e39651f3a8be1fa06a56ffff9acb10347078 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Henrik=20tom=20W=C3=B6rden?= <h.tomwoerden@indiscale.com>
Date: Mon, 16 Oct 2023 20:13:27 +0200
Subject: [PATCH] MAINT: refactor get_list_datatype

---
 src/linkahead/common/datatype.py | 25 ++++++++++++++++++-------
 unittests/test_datatype.py       | 11 +++++++++++
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/src/linkahead/common/datatype.py b/src/linkahead/common/datatype.py
index 83284456..7bb0fb23 100644
--- a/src/linkahead/common/datatype.py
+++ b/src/linkahead/common/datatype.py
@@ -43,16 +43,27 @@ def LIST(datatype):
     return "LIST<" + str(datatype) + ">"
 
 
-def get_list_datatype(datatype):
+def get_list_datatype(datatype: str, strict: bool = False):
     """ returns the datatype of the elements in the list """
-    if not isinstance(datatype, str):
-        return None
-    match = re.match("LIST(<|&lt;)(?P<datatype>.*)(>|&gt;)", datatype)
+    if not isinstance(datatype, str) or not datatype.lower().startswith("list"):
+        if strict:
+            raise ValueError(f"Not a list dtype: {datatype}")
+        else:
+            return None
+    pattern = r"^[Ll][Ii][Ss][Tt]((<|&lt;)(?P<dtype1>.*)(>|&gt;)|\((?P<dtype2>.*)\))$"
+    match = re.match(pattern, datatype)
+
+    if match and "dtype1" in match.groupdict() and match.groupdict()["dtype1"] is not None:
+        return match.groupdict()["dtype1"]
+
+    elif match and "dtype2" in match.groupdict() and match.groupdict()["dtype2"] is not None:
+        return match.groupdict()["dtype2"]
 
-    if match is not None:
-        return match.group("datatype")
     else:
-        return None
+        if strict:
+            raise ValueError(f"Not a list dtype: {datatype}")
+        else:
+            return None
 
 
 def is_list_datatype(datatype):
diff --git a/unittests/test_datatype.py b/unittests/test_datatype.py
index 5a5e82cc..838edc12 100644
--- a/unittests/test_datatype.py
+++ b/unittests/test_datatype.py
@@ -33,6 +33,17 @@ def test_list_utilites():
     """Test for example if get_list_datatype works."""
     dtype = db.LIST(db.INTEGER)
     assert datatype.get_list_datatype(dtype) == db.INTEGER
+    assert datatype.get_list_datatype("LIST(Person)") == "Person"
+    assert datatype.get_list_datatype("Person") is None
+    assert datatype.get_list_datatype("LIST[]") is None
+    with raises(ValueError):
+        datatype.get_list_datatype("LIST[]", strict=True)
+    with raises(ValueError):
+        datatype.get_list_datatype("Person", strict=True)
+    with raises(ValueError):
+        datatype.get_list_datatype(5, strict=True)
+    with raises(ValueError):
+        datatype.get_list_datatype("listlol", strict=True)
 
 
 def test_parsing_of_intger_list_values():
-- 
GitLab