diff --git a/src/linkahead/common/datatype.py b/src/linkahead/common/datatype.py index 832844567bca31f4c46e205094daa709a8af9e71..7bb0fb23393df3fb4fcba3f795264e71a1d7348b 100644 --- a/src/linkahead/common/datatype.py +++ b/src/linkahead/common/datatype.py @@ -43,16 +43,27 @@ def LIST(datatype): return "LIST<" + str(datatype) + ">" -def get_list_datatype(datatype): +def get_list_datatype(datatype: str, strict: bool = False): """ returns the datatype of the elements in the list """ - if not isinstance(datatype, str): - return None - match = re.match("LIST(<|<)(?P<datatype>.*)(>|>)", datatype) + if not isinstance(datatype, str) or not datatype.lower().startswith("list"): + if strict: + raise ValueError(f"Not a list dtype: {datatype}") + else: + return None + pattern = r"^[Ll][Ii][Ss][Tt]((<|<)(?P<dtype1>.*)(>|>)|\((?P<dtype2>.*)\))$" + match = re.match(pattern, datatype) + + if match and "dtype1" in match.groupdict() and match.groupdict()["dtype1"] is not None: + return match.groupdict()["dtype1"] + + elif match and "dtype2" in match.groupdict() and match.groupdict()["dtype2"] is not None: + return match.groupdict()["dtype2"] - if match is not None: - return match.group("datatype") else: - return None + if strict: + raise ValueError(f"Not a list dtype: {datatype}") + else: + return None def is_list_datatype(datatype): diff --git a/unittests/test_datatype.py b/unittests/test_datatype.py index 5a5e82cc5bfba9ac46a91b4baf4fe45665049c84..838edc120755e564cd6d237193a354c20652d492 100644 --- a/unittests/test_datatype.py +++ b/unittests/test_datatype.py @@ -33,6 +33,17 @@ def test_list_utilites(): """Test for example if get_list_datatype works.""" dtype = db.LIST(db.INTEGER) assert datatype.get_list_datatype(dtype) == db.INTEGER + assert datatype.get_list_datatype("LIST(Person)") == "Person" + assert datatype.get_list_datatype("Person") is None + assert datatype.get_list_datatype("LIST[]") is None + with raises(ValueError): + datatype.get_list_datatype("LIST[]", strict=True) + with raises(ValueError): + datatype.get_list_datatype("Person", strict=True) + with raises(ValueError): + datatype.get_list_datatype(5, strict=True) + with raises(ValueError): + datatype.get_list_datatype("listlol", strict=True) def test_parsing_of_intger_list_values():