Skip to content
Snippets Groups Projects
Commit 6fe98e47 authored by Joscha Schmiedt's avatar Joscha Schmiedt
Browse files

Add type hints to linkahead/connection

parent 43e0b33f
Branches
Tags
2 merge requests!143Release 0.15.0,!135Add and fix more type hints
Pipeline #49666 failed
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
# ** end header # ** end header
# #
"""Connection to a LinkAhead server.""" """Connection to a LinkAhead server."""
from __future__ import absolute_import, print_function, unicode_literals from __future__ import absolute_import, print_function, unicode_literals, annotations
import logging import logging
import ssl import ssl
...@@ -32,7 +32,7 @@ import warnings ...@@ -32,7 +32,7 @@ import warnings
from builtins import str # pylint: disable=redefined-builtin from builtins import str # pylint: disable=redefined-builtin
from errno import EPIPE as BrokenPipe from errno import EPIPE as BrokenPipe
from socket import error as SocketError from socket import error as SocketError
from urllib.parse import quote, urlparse from urllib.parse import ParseResult, quote, urlparse
from warnings import warn from warnings import warn
from requests import Session as HTTPSession from requests import Session as HTTPSession
...@@ -58,16 +58,24 @@ from .encode import MultipartYielder, ReadableMultiparts ...@@ -58,16 +58,24 @@ from .encode import MultipartYielder, ReadableMultiparts
from .interface import CaosDBHTTPResponse, CaosDBServerConnection from .interface import CaosDBHTTPResponse, CaosDBServerConnection
from .utils import make_uri_path, parse_url, urlencode from .utils import make_uri_path, parse_url, urlencode
from typing import TYPE_CHECKING
if TYPE_CHECKING and sys.version_info > (3, 7):
from typing import Optional, List, Any, Iterator, Dict, Union
from requests.models import Response
from ssl import _SSLMethod
from .authentication.interface import AbstractAuthenticator
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class _WrappedHTTPResponse(CaosDBHTTPResponse): class _WrappedHTTPResponse(CaosDBHTTPResponse):
def __init__(self, response): def __init__(self, response: Response):
self.response = response self.response: Response = response
self._generator = None self._generator: Optional[Iterator[Any]] = None
self._buffer = b'' self._buffer: Optional[bytes] = b''
self._stream_consumed = False self._stream_consumed: bool = False
@property @property
def reason(self): def reason(self):
...@@ -77,7 +85,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse): ...@@ -77,7 +85,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
def status(self): def status(self):
return self.response.status_code return self.response.status_code
def read(self, size=None): def read(self, size: Optional[int] = None):
if self._stream_consumed is True: if self._stream_consumed is True:
raise RuntimeError("Stream is consumed") raise RuntimeError("Stream is consumed")
...@@ -91,6 +99,10 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse): ...@@ -91,6 +99,10 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
self._stream_consumed = True self._stream_consumed = True
return self.response.content return self.response.content
if size is None or size == 0:
raise RuntimeError(
"size parameter should not be None if the stream is not consumed yet")
if len(self._buffer) >= size: if len(self._buffer) >= size:
# still enough bytes in the buffer # still enough bytes in the buffer
result = chunk[:size] result = chunk[:size]
...@@ -117,7 +129,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse): ...@@ -117,7 +129,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
self._buffer = None self._buffer = None
return result return result
def getheader(self, name, default=None): def getheader(self, name: str, default=None):
return self.response.headers[name] if name in self.response.headers else default return self.response.headers[name] if name in self.response.headers else default
def getheaders(self): def getheaders(self):
...@@ -130,7 +142,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse): ...@@ -130,7 +142,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
class _SSLAdapter(HTTPAdapter): class _SSLAdapter(HTTPAdapter):
"""Transport adapter that allows us to use different SSL versions.""" """Transport adapter that allows us to use different SSL versions."""
def __init__(self, ssl_version): def __init__(self, ssl_version: _SSLMethod):
self.ssl_version = ssl_version self.ssl_version = ssl_version
super().__init__() super().__init__()
...@@ -156,7 +168,11 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection): ...@@ -156,7 +168,11 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
self._session = None self._session = None
self._timeout = None self._timeout = None
def request(self, method, path, headers=None, body=None): def request(self,
method: str, path: str,
headers: Optional[Dict[str, str]] = None,
body: Union[str, bytes, None] = None,
**kwargs):
"""request. """request.
Send a HTTP request to the server. Send a HTTP request to the server.
...@@ -169,7 +185,7 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection): ...@@ -169,7 +185,7 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
An URI path segment (without the 'scheme://host:port/' parts), An URI path segment (without the 'scheme://host:port/' parts),
including query and frament segments. including query and frament segments.
headers : dict of str -> str, optional headers : dict of str -> str, optional
HTTP request headers. (Defautl: None) HTTP request headers. (Default: None)
body : str or bytes or readable, optional body : str or bytes or readable, optional
The body of the HTTP request. Bytes should be a utf-8 encoded The body of the HTTP request. Bytes should be a utf-8 encoded
string. string.
...@@ -232,14 +248,15 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection): ...@@ -232,14 +248,15 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
"No connection url specified. Please " "No connection url specified. Please "
"do so via linkahead.configure_connection(...) or in a config " "do so via linkahead.configure_connection(...) or in a config "
"file.") "file.")
if (not config["url"].lower().startswith("https://") and not config["url"].lower().startswith("http://")): url_string: str = config["url"]
if (not url_string.lower().startswith("https://") and not url_string.lower().startswith("http://")):
raise LinkAheadConnectionError("The connection url is expected " raise LinkAheadConnectionError("The connection url is expected "
"to be a http or https url and " "to be a http or https url and "
"must include the url scheme " "must include the url scheme "
"(i.e. start with https:// or " "(i.e. start with https:// or "
"http://).") "http://).")
url = urlparse(config["url"]) url: ParseResult = urlparse(url=url_string)
path = url.path.strip("/") path = url.path.strip("/")
if len(path) > 0: if len(path) > 0:
path = path + "/" path = path + "/"
...@@ -271,7 +288,7 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection): ...@@ -271,7 +288,7 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
if "timeout" in config: if "timeout" in config:
self._timeout = config["timeout"] self._timeout = config["timeout"]
def _setup_ssl(self, config): def _setup_ssl(self, config: Dict[str, Any]):
if "ssl_version" in config and config["cacert"] is not None: if "ssl_version" in config and config["cacert"] is not None:
ssl_version = getattr(ssl, config["ssl_version"]) ssl_version = getattr(ssl, config["ssl_version"])
else: else:
...@@ -325,7 +342,7 @@ _DEFAULT_CONF = { ...@@ -325,7 +342,7 @@ _DEFAULT_CONF = {
} }
def _get_authenticator(**config): def _get_authenticator(**config) -> AbstractAuthenticator:
"""_get_authenticator. """_get_authenticator.
Import and configure the password_method. Import and configure the password_method.
...@@ -337,7 +354,7 @@ def _get_authenticator(**config): ...@@ -337,7 +354,7 @@ def _get_authenticator(**config):
Currently, there are four valid values for this parameter: 'plain', Currently, there are four valid values for this parameter: 'plain',
'pass', 'keyring' and 'auth_token'. 'pass', 'keyring' and 'auth_token'.
**config : **config :
Any other keyword arguments are passed the configre method of the Any other keyword arguments are passed the configure method of the
password_method. password_method.
Returns Returns
...@@ -534,8 +551,8 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -534,8 +551,8 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
__instance = None __instance = None
def __init__(self): def __init__(self):
self._delegate_connection = None self._delegate_connection: Optional[CaosDBServerConnection] = None
self._authenticator = None self._authenticator: Optional[AbstractAuthenticator] = None
self.is_configured = False self.is_configured = False
@classmethod @classmethod
...@@ -553,7 +570,8 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -553,7 +570,8 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
"Missing CaosDBServerConnection implementation. You did not " "Missing CaosDBServerConnection implementation. You did not "
"specify an `implementation` for the connection.") "specify an `implementation` for the connection.")
try: try:
self._delegate_connection = config["implementation"]() self._delegate_connection: CaosDBServerConnection = config["implementation"](
)
if not isinstance(self._delegate_connection, if not isinstance(self._delegate_connection,
CaosDBServerConnection): CaosDBServerConnection):
...@@ -579,7 +597,10 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -579,7 +597,10 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
return self return self
def retrieve(self, entity_uri_segments=None, query_dict=None, **kwargs): def retrieve(self,
entity_uri_segments: Optional[List[str]] = None,
query_dict: Optional[Dict[str, Optional[str]]] = None,
**kwargs):
path = make_uri_path(entity_uri_segments, query_dict) path = make_uri_path(entity_uri_segments, query_dict)
http_response = self._http_request(method="GET", path=path, **kwargs) http_response = self._http_request(method="GET", path=path, **kwargs)
...@@ -641,7 +662,7 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -641,7 +662,7 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
return http_response return http_response
def download_file(self, path): def download_file(self, path: str):
"""This function downloads a file via HTTP from the LinkAhead file """This function downloads a file via HTTP from the LinkAhead file
system.""" system."""
try: try:
...@@ -681,7 +702,11 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -681,7 +702,11 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
**kwargs) **kwargs)
raise raise
def _retry_http_request(self, method, path, headers, body, **kwargs): def _retry_http_request(self,
method: str,
path: str,
headers: Optional[Dict["str", Any]],
body: Union[str, bytes], **kwargs) -> CaosDBHTTPResponse:
if hasattr(body, "encode"): if hasattr(body, "encode"):
# python3 # python3
...@@ -689,8 +714,18 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -689,8 +714,18 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
if headers is None: if headers is None:
headers = {} headers = {}
if self._authenticator is None:
raise ValueError(
"No authenticator set. Please call configure_connection() first.")
self._authenticator.on_request(method=method, path=path, self._authenticator.on_request(method=method, path=path,
headers=headers) headers=headers)
if self._delegate_connection is None:
raise ValueError(
"No connection set. Please call configure_connection() first.")
_LOGGER.debug("request: %s %s %s", method, path, str(headers)) _LOGGER.debug("request: %s %s %s", method, path, str(headers))
http_response = self._delegate_connection.request( http_response = self._delegate_connection.request(
method=method, method=method,
...@@ -704,10 +739,16 @@ class _Connection(object): # pylint: disable=useless-object-inheritance ...@@ -704,10 +739,16 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
return http_response return http_response
def get_username(self): def get_username(self) -> str:
""" """
Return the username of the current connection. Return the username of the current connection.
Shortcut for: get_connection()._authenticator._credentials_provider.username Shortcut for: get_connection()._authenticator._credentials_provider.username
""" """
if self._authenticator is None:
raise ValueError(
"No authenticator set. Please call configure_connection() first.")
if self._authenticator._credentials_provider is None:
raise ValueError(
"No credentials provider set. Please call configure_connection() first.")
return self._authenticator._credentials_provider.username return self._authenticator._credentials_provider.username
...@@ -48,6 +48,7 @@ as multipart/form-data suitable for a HTTP POST or PUT request. ...@@ -48,6 +48,7 @@ as multipart/form-data suitable for a HTTP POST or PUT request.
multipart/form-data is the standard way to upload files over HTTP multipart/form-data is the standard way to upload files over HTTP
""" """
from __future__ import annotations
__all__ = [ __all__ = [
'gen_boundary', 'encode_and_quote', 'MultipartParam', 'encode_string', 'gen_boundary', 'encode_and_quote', 'MultipartParam', 'encode_string',
...@@ -61,6 +62,10 @@ import re ...@@ -61,6 +62,10 @@ import re
import os import os
import mimetypes import mimetypes
from email.header import Header from email.header import Header
from typing import TYPE_CHECKING
import sys
if TYPE_CHECKING and sys.version_info > (3, 7):
from typing import Optional
def gen_boundary(): def gen_boundary():
...@@ -68,7 +73,7 @@ def gen_boundary(): ...@@ -68,7 +73,7 @@ def gen_boundary():
return uuid.uuid4().hex return uuid.uuid4().hex
def encode_and_quote(data): def encode_and_quote(data: Optional[str]) -> Optional[str]:
"""If ``data`` is unicode, return urllib.quote_plus(data.encode("utf-8")) """If ``data`` is unicode, return urllib.quote_plus(data.encode("utf-8"))
otherwise return urllib.quote_plus(data)""" otherwise return urllib.quote_plus(data)"""
if data is None: if data is None:
...@@ -111,7 +116,7 @@ class MultipartParam(object): ...@@ -111,7 +116,7 @@ class MultipartParam(object):
""" """
def __init__(self, def __init__(self,
name, name: str,
value=None, value=None,
filename=None, filename=None,
filetype=None, filetype=None,
......
...@@ -22,11 +22,14 @@ ...@@ -22,11 +22,14 @@
# ** end header # ** end header
# #
"""This module defines the CaosDBServerConnection interface.""" """This module defines the CaosDBServerConnection interface."""
from abc import ABCMeta, abstractmethod, abstractproperty from __future__ import annotations
from abc import ABCMeta, abstractmethod, ABC
from warnings import warn from warnings import warn
# meta class compatible with Python 2 *and* 3:
ABC = ABCMeta('ABC', (object, ), {'__slots__': ()}) from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Dict, Union
class CaosDBHTTPResponse(ABC): class CaosDBHTTPResponse(ABC):
...@@ -34,14 +37,14 @@ class CaosDBHTTPResponse(ABC): ...@@ -34,14 +37,14 @@ class CaosDBHTTPResponse(ABC):
LinkAheadServer.""" LinkAheadServer."""
@abstractmethod @abstractmethod
def read(self, size=-1): def read(self, size: Optional[int] = -1):
"""Read up to *size* bytes from the response body. """Read up to *size* bytes from the response body.
If size is unspecified or -1, all bytes until EOF are returned. If size is unspecified or -1, all bytes until EOF are returned.
""" """
@abstractmethod @abstractmethod
def getheader(self, name, default=None): def getheader(self, name: str, default=None):
"""Return the value of the header *name* or the value of *default* if """Return the value of the header *name* or the value of *default* if
there is no such header. there is no such header.
...@@ -50,12 +53,13 @@ class CaosDBHTTPResponse(ABC): ...@@ -50,12 +53,13 @@ class CaosDBHTTPResponse(ABC):
are returned likewise. are returned likewise.
""" """
@abstractproperty @property
def status(self): @abstractmethod
def status(self) -> int:
"""Status code of the response.""" """Status code of the response."""
@abstractmethod @abstractmethod
def getheaders(self): def getheaders(self) -> Dict[str, str]:
"""Return all headers.""" """Return all headers."""
def __enter__(self): def __enter__(self):
...@@ -78,7 +82,12 @@ class CaosDBServerConnection(ABC): ...@@ -78,7 +82,12 @@ class CaosDBServerConnection(ABC):
LinkAhead server.""" LinkAhead server."""
@abstractmethod @abstractmethod
def request(self, method, path, headers=None, body=None, **kwargs): def request(self,
method: str,
path: str,
headers: Optional[Dict[str, str]] = None,
body: Union[str, bytes, None] = None,
**kwargs) -> CaosDBHTTPResponse:
"""Abstract method. Implement this method for HTTP requests to the """Abstract method. Implement this method for HTTP requests to the
LinkAhead server. LinkAhead server.
......
...@@ -22,14 +22,19 @@ ...@@ -22,14 +22,19 @@
# ** end header # ** end header
# #
"""Utility functions for the connection module.""" """Utility functions for the connection module."""
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function, annotations
from builtins import str as unicode from builtins import str as unicode
from urllib.parse import (urlencode as _urlencode, quote as _quote, from urllib.parse import (urlencode as _urlencode, quote as _quote,
urlparse, urlunparse, unquote as _unquote) urlparse, urlunparse, unquote as _unquote)
import re import re
from typing import TYPE_CHECKING
import sys
if TYPE_CHECKING and sys.version_info > (3, 7):
from typing import Optional, Dict, List
def urlencode(query):
def urlencode(query: Dict[str, Optional[str]]) -> str:
"""Convert a dict of into a url-encoded (unicode) string. """Convert a dict of into a url-encoded (unicode) string.
This is basically a python2/python3 compatibility wrapper for the respective This is basically a python2/python3 compatibility wrapper for the respective
...@@ -79,7 +84,8 @@ modules when they are called with only the query parameter. ...@@ -79,7 +84,8 @@ modules when they are called with only the query parameter.
})) }))
def make_uri_path(segments=None, query=None): def make_uri_path(segments: Optional[List[str]] = None,
query: Optional[Dict[str, Optional[str]]] = None) -> str:
"""Url-encode all segments, concat them with slashes and append the query. """Url-encode all segments, concat them with slashes and append the query.
Examples Examples
...@@ -105,7 +111,10 @@ def make_uri_path(segments=None, query=None): ...@@ -105,7 +111,10 @@ def make_uri_path(segments=None, query=None):
""" """
path_no_query = ("/".join([quote(segment) for segment in segments]) path_no_query = ("/".join([quote(segment) for segment in segments])
if segments else "") if segments else "")
return str(path_no_query if query is None else "?".join([ if query is None:
return str(path_no_query)
return str("?".join([
path_no_query, "&".join([ path_no_query, "&".join([
quote(key) + "=" + quote(key) + "=" +
(quote(query[key]) if query[key] is not None else "") (quote(query[key]) if query[key] is not None else "")
...@@ -114,13 +123,13 @@ def make_uri_path(segments=None, query=None): ...@@ -114,13 +123,13 @@ def make_uri_path(segments=None, query=None):
])) ]))
def quote(string): def quote(string: str) -> str:
enc = string.encode('utf-8') enc = string.encode('utf-8')
return _quote(enc).replace('/', '%2F') return _quote(enc).replace('/', '%2F')
def parse_url(url): def parse_url(url: str):
fullurl = urlparse(url) fullurl = urlparse(url=url)
# make sure the path ends with a slash # make sure the path ends with a slash
if not fullurl.path.endswith("/"): if not fullurl.path.endswith("/"):
parse_result = list(fullurl) parse_result = list(fullurl)
...@@ -132,7 +141,7 @@ def parse_url(url): ...@@ -132,7 +141,7 @@ def parse_url(url):
_PATTERN = re.compile(r"^SessionToken=([^;]*);.*$") _PATTERN = re.compile(r"^SessionToken=([^;]*);.*$")
def unquote(string): def unquote(string) -> str:
"""unquote. """unquote.
Decode an urlencoded string into a plain text string. Decode an urlencoded string into a plain text string.
...@@ -144,7 +153,7 @@ def unquote(string): ...@@ -144,7 +153,7 @@ def unquote(string):
return bts return bts
def parse_auth_token(cookie): def parse_auth_token(cookie: Optional[str]) -> Optional[str]:
"""parse_auth_token. """parse_auth_token.
Parse an auth token from a cookie. Parse an auth token from a cookie.
...@@ -165,7 +174,7 @@ def parse_auth_token(cookie): ...@@ -165,7 +174,7 @@ def parse_auth_token(cookie):
return auth_token return auth_token
def auth_token_to_cookie(auth_token): def auth_token_to_cookie(auth_token: str) -> str:
"""auth_token_to_cookie. """auth_token_to_cookie.
Urlencode an auth token string and format it as a cookie. Urlencode an auth token string and format it as a cookie.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment