diff --git a/src/linkahead/cached.py b/src/linkahead/cached.py index b27afe0469bcaac733ece4c0be3d8d124f6305c0..121e26149179b772068863bcc4b183c4b151bd29 100644 --- a/src/linkahead/cached.py +++ b/src/linkahead/cached.py @@ -32,9 +32,10 @@ See also - ``cached_get_entity_by(...)`` : Get an Entity by name, id, ... """ +from __future__ import annotations from enum import Enum from functools import lru_cache -from typing import Union +from typing import Union, Optional, Tuple, Any, Dict from .exceptions import EmptyUniqueQueryError, QueryNotUniqueError from .utils import get_entity @@ -45,7 +46,7 @@ from .common.models import execute_query, Entity, Container DEFAULT_SIZE = 33333 # This dict cache is solely for filling the real cache manually (e.g. to reuse older query results) -_DUMMY_CACHE = {} +_DUMMY_CACHE: Dict[Union[str, int], Any] = {} class AccessType(Enum): @@ -59,8 +60,10 @@ class AccessType(Enum): NAME = 4 -def cached_get_entity_by(eid: Union[str, int] = None, name: str = None, path: str = None, query: - str = None) -> Entity: +def cached_get_entity_by(eid: Union[str, int, None] = None, + name: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None) -> Union[Entity, Tuple[None]]: """Return a single entity that is identified uniquely by one argument. You must supply exactly one argument. @@ -99,7 +102,7 @@ If a query phrase is given, the result must be unique. If this is not what you raise RuntimeError("This line should never be reached.") -def cached_query(query_string) -> Container: +def cached_query(query_string: str) -> Container: """A cached version of :func:`linkahead.execute_query<linkahead.common.models.execute_query>`. All additional arguments are at their default values. @@ -111,8 +114,8 @@ All additional arguments are at their default values. return result -@lru_cache(maxsize=DEFAULT_SIZE) -def _cached_access(kind: AccessType, value: Union[str, int], unique=True): +@ lru_cache(maxsize=DEFAULT_SIZE) +def _cached_access(kind: AccessType, value: Union[str, int], unique: bool = True): # This is the function that is actually cached. # Due to the arguments, the cache has kind of separate sections for cached_query and # cached_get_entity_by with the different AccessTypes. However, there is only one cache size. @@ -123,12 +126,20 @@ def _cached_access(kind: AccessType, value: Union[str, int], unique=True): try: if kind == AccessType.QUERY: + assert isinstance(value, str), f"If kind is QUERY, value must be a string, not { + type(value)}." return execute_query(value, unique=unique) if kind == AccessType.NAME: + assert isinstance(value, str), f"If kind is NAME, value must be a string, not { + type(value)}." return get_entity.get_entity_by_name(value) if kind == AccessType.EID: + assert isinstance(value, (str, int)), f"If kind is EID, value must be a string or int, not { + type(value)}." return get_entity.get_entity_by_id(value) if kind == AccessType.PATH: + assert isinstance(value, str), f"If kind is PATH, value must be a string, not { + type(value)}." return get_entity.get_entity_by_path(value) except (QueryNotUniqueError, EmptyUniqueQueryError) as exc: return exc @@ -152,7 +163,7 @@ out: named tuple return _cached_access.cache_info() -def cache_initialize(maxsize=DEFAULT_SIZE) -> None: +def cache_initialize(maxsize: int = DEFAULT_SIZE) -> None: """Create a new cache with the given size for `cached_query` and `cached_get_entity_by`. This implies a call of :func:`cache_clear`, the old cache is emptied. @@ -163,7 +174,9 @@ def cache_initialize(maxsize=DEFAULT_SIZE) -> None: _cached_access = lru_cache(maxsize=maxsize)(_cached_access.__wrapped__) -def cache_fill(items: dict, kind: AccessType = AccessType.EID, unique: bool = True) -> None: +def cache_fill(items: Dict[Union[str, int], Any], + kind: AccessType = AccessType.EID, + unique: bool = True) -> None: """Add entries to the cache manually. This allows to fill the cache without actually submitting queries. Note that this does not @@ -186,6 +199,19 @@ unique: bool, optional :func:`cached_query`. """ + match kind: + case AccessType.QUERY: + assert all(isinstance(key, str) for key in items.keys()), "Keys must be strings." + case AccessType.NAME: + assert all(isinstance(key, str) for key in items.keys()), "Keys must be strings." + case AccessType.EID: + assert all(isinstance(key, (str, int)) + for key in items.keys()), "Keys must be strings or integers." + case AccessType.PATH: + assert all(isinstance(key, str) for key in items.keys()), "Keys must be strings." + case _: + raise ValueError(f"Unknown AccessType: {kind}") + # 1. add the given items to the corresponding dummy dict cache _DUMMY_CACHE.update(items)