From f8cd21fc8079bf1ee80d5e63046e32885b3cec67 Mon Sep 17 00:00:00 2001
From: Daniel <d.hornung@indiscale.com>
Date: Thu, 8 Feb 2024 15:56:44 +0100
Subject: [PATCH] ENH: `cached_query()` now also caches uniqueness related
 exceptions.

---
 CHANGELOG.md            |  2 ++
 src/linkahead/cached.py | 41 +++++++++++++++++++++++++++--------------
 2 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d09fad76..5fe2d07c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed ###
 
+* `cached_query()` now also caches uniqueness related exceptions.
+
 ### Deprecated ###
 
 ### Removed ###
diff --git a/src/linkahead/cached.py b/src/linkahead/cached.py
index 2eff5b1b..b27afe04 100644
--- a/src/linkahead/cached.py
+++ b/src/linkahead/cached.py
@@ -36,6 +36,7 @@ from enum import Enum
 from functools import lru_cache
 from typing import Union
 
+from .exceptions import EmptyUniqueQueryError, QueryNotUniqueError
 from .utils import get_entity
 from .common.models import execute_query, Entity, Container
 
@@ -80,16 +81,22 @@ If a query phrase is given, the result must be unique.  If this is not what you
     if count != 1:
         raise ValueError("You must supply exactly one argument.")
 
+    result = (None, )
     if eid is not None:
-        return _cached_access(AccessType.EID, eid, unique=True)
+        result = _cached_access(AccessType.EID, eid, unique=True)
     if name is not None:
-        return _cached_access(AccessType.NAME, name, unique=True)
+        result = _cached_access(AccessType.NAME, name, unique=True)
     if path is not None:
-        return _cached_access(AccessType.PATH, path, unique=True)
+        result = _cached_access(AccessType.PATH, path, unique=True)
     if query is not None:
-        return _cached_access(AccessType.QUERY, query, unique=True)
+        result = _cached_access(AccessType.QUERY, query, unique=True)
 
-    raise ValueError("Not all arguments may be None.")
+    if result != (None, ):
+        if isinstance(result, (QueryNotUniqueError, EmptyUniqueQueryError)):
+            raise result
+        return result
+
+    raise RuntimeError("This line should never be reached.")
 
 
 def cached_query(query_string) -> Container:
@@ -98,7 +105,10 @@ def cached_query(query_string) -> Container:
 All additional arguments are at their default values.
 
     """
-    return _cached_access(AccessType.QUERY, query_string, unique=False)
+    result = _cached_access(AccessType.QUERY, query_string, unique=False)
+    if isinstance(result, (QueryNotUniqueError, EmptyUniqueQueryError)):
+        raise result
+    return result
 
 
 @lru_cache(maxsize=DEFAULT_SIZE)
@@ -111,14 +121,17 @@ def _cached_access(kind: AccessType, value: Union[str, int], unique=True):
     if value in _DUMMY_CACHE:
         return _DUMMY_CACHE[value]
 
-    if kind == AccessType.QUERY:
-        return execute_query(value, unique=unique)
-    if kind == AccessType.NAME:
-        return get_entity.get_entity_by_name(value)
-    if kind == AccessType.EID:
-        return get_entity.get_entity_by_id(value)
-    if kind == AccessType.PATH:
-        return get_entity.get_entity_by_path(value)
+    try:
+        if kind == AccessType.QUERY:
+            return execute_query(value, unique=unique)
+        if kind == AccessType.NAME:
+            return get_entity.get_entity_by_name(value)
+        if kind == AccessType.EID:
+            return get_entity.get_entity_by_id(value)
+        if kind == AccessType.PATH:
+            return get_entity.get_entity_by_path(value)
+    except (QueryNotUniqueError, EmptyUniqueQueryError) as exc:
+        return exc
 
     raise ValueError(f"Unknown AccessType: {kind}")
 
-- 
GitLab