Skip to content
Snippets Groups Projects
Commit b05dcfc4 authored by Joscha Schmiedt's avatar Joscha Schmiedt
Browse files

Add type hints and type checks to cached.py

parent 10acb62c
Branches
Tags
2 merge requests!143Release 0.15.0,!135Add and fix more type hints
......@@ -32,9 +32,10 @@ See also
- ``cached_get_entity_by(...)`` : Get an Entity by name, id, ...
"""
from __future__ import annotations
from enum import Enum
from functools import lru_cache
from typing import Union
from typing import Union, Optional, Tuple, Any, Dict
from .exceptions import EmptyUniqueQueryError, QueryNotUniqueError
from .utils import get_entity
......@@ -45,7 +46,7 @@ from .common.models import execute_query, Entity, Container
DEFAULT_SIZE = 33333
# This dict cache is solely for filling the real cache manually (e.g. to reuse older query results)
_DUMMY_CACHE = {}
_DUMMY_CACHE: Dict[Union[str, int], Any] = {}
class AccessType(Enum):
......@@ -59,8 +60,10 @@ class AccessType(Enum):
NAME = 4
def cached_get_entity_by(eid: Union[str, int] = None, name: str = None, path: str = None, query:
str = None) -> Entity:
def cached_get_entity_by(eid: Union[str, int, None] = None,
name: Optional[str] = None,
path: Optional[str] = None,
query: Optional[str] = None) -> Union[Entity, Tuple[None]]:
"""Return a single entity that is identified uniquely by one argument.
You must supply exactly one argument.
......@@ -99,7 +102,7 @@ If a query phrase is given, the result must be unique. If this is not what you
raise RuntimeError("This line should never be reached.")
def cached_query(query_string) -> Container:
def cached_query(query_string: str) -> Container:
"""A cached version of :func:`linkahead.execute_query<linkahead.common.models.execute_query>`.
All additional arguments are at their default values.
......@@ -111,8 +114,8 @@ All additional arguments are at their default values.
return result
@lru_cache(maxsize=DEFAULT_SIZE)
def _cached_access(kind: AccessType, value: Union[str, int], unique=True):
@ lru_cache(maxsize=DEFAULT_SIZE)
def _cached_access(kind: AccessType, value: Union[str, int], unique: bool = True):
# This is the function that is actually cached.
# Due to the arguments, the cache has kind of separate sections for cached_query and
# cached_get_entity_by with the different AccessTypes. However, there is only one cache size.
......@@ -123,12 +126,20 @@ def _cached_access(kind: AccessType, value: Union[str, int], unique=True):
try:
if kind == AccessType.QUERY:
assert isinstance(value, str), f"If kind is QUERY, value must be a string, not {
type(value)}."
return execute_query(value, unique=unique)
if kind == AccessType.NAME:
assert isinstance(value, str), f"If kind is NAME, value must be a string, not {
type(value)}."
return get_entity.get_entity_by_name(value)
if kind == AccessType.EID:
assert isinstance(value, (str, int)), f"If kind is EID, value must be a string or int, not {
type(value)}."
return get_entity.get_entity_by_id(value)
if kind == AccessType.PATH:
assert isinstance(value, str), f"If kind is PATH, value must be a string, not {
type(value)}."
return get_entity.get_entity_by_path(value)
except (QueryNotUniqueError, EmptyUniqueQueryError) as exc:
return exc
......@@ -152,7 +163,7 @@ out: named tuple
return _cached_access.cache_info()
def cache_initialize(maxsize=DEFAULT_SIZE) -> None:
def cache_initialize(maxsize: int = DEFAULT_SIZE) -> None:
"""Create a new cache with the given size for `cached_query` and `cached_get_entity_by`.
This implies a call of :func:`cache_clear`, the old cache is emptied.
......@@ -163,7 +174,9 @@ def cache_initialize(maxsize=DEFAULT_SIZE) -> None:
_cached_access = lru_cache(maxsize=maxsize)(_cached_access.__wrapped__)
def cache_fill(items: dict, kind: AccessType = AccessType.EID, unique: bool = True) -> None:
def cache_fill(items: Dict[Union[str, int], Any],
kind: AccessType = AccessType.EID,
unique: bool = True) -> None:
"""Add entries to the cache manually.
This allows to fill the cache without actually submitting queries. Note that this does not
......@@ -186,6 +199,19 @@ unique: bool, optional
:func:`cached_query`.
"""
match kind:
case AccessType.QUERY:
assert all(isinstance(key, str) for key in items.keys()), "Keys must be strings."
case AccessType.NAME:
assert all(isinstance(key, str) for key in items.keys()), "Keys must be strings."
case AccessType.EID:
assert all(isinstance(key, (str, int))
for key in items.keys()), "Keys must be strings or integers."
case AccessType.PATH:
assert all(isinstance(key, str) for key in items.keys()), "Keys must be strings."
case _:
raise ValueError(f"Unknown AccessType: {kind}")
# 1. add the given items to the corresponding dummy dict cache
_DUMMY_CACHE.update(items)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment