Skip to content
Snippets Groups Projects
Commit 6fe98e47 authored by Joscha Schmiedt's avatar Joscha Schmiedt
Browse files

Add type hints to linkahead/connection

parent 43e0b33f
No related branches found
No related tags found
2 merge requests!143Release 0.15.0,!135Add and fix more type hints
Pipeline #49666 failed
......@@ -23,7 +23,7 @@
# ** end header
#
"""Connection to a LinkAhead server."""
from __future__ import absolute_import, print_function, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals, annotations
import logging
import ssl
......@@ -32,7 +32,7 @@ import warnings
from builtins import str # pylint: disable=redefined-builtin
from errno import EPIPE as BrokenPipe
from socket import error as SocketError
from urllib.parse import quote, urlparse
from urllib.parse import ParseResult, quote, urlparse
from warnings import warn
from requests import Session as HTTPSession
......@@ -58,16 +58,24 @@ from .encode import MultipartYielder, ReadableMultiparts
from .interface import CaosDBHTTPResponse, CaosDBServerConnection
from .utils import make_uri_path, parse_url, urlencode
from typing import TYPE_CHECKING
if TYPE_CHECKING and sys.version_info > (3, 7):
from typing import Optional, List, Any, Iterator, Dict, Union
from requests.models import Response
from ssl import _SSLMethod
from .authentication.interface import AbstractAuthenticator
_LOGGER = logging.getLogger(__name__)
class _WrappedHTTPResponse(CaosDBHTTPResponse):
def __init__(self, response):
self.response = response
self._generator = None
self._buffer = b''
self._stream_consumed = False
def __init__(self, response: Response):
self.response: Response = response
self._generator: Optional[Iterator[Any]] = None
self._buffer: Optional[bytes] = b''
self._stream_consumed: bool = False
@property
def reason(self):
......@@ -77,7 +85,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
def status(self):
return self.response.status_code
def read(self, size=None):
def read(self, size: Optional[int] = None):
if self._stream_consumed is True:
raise RuntimeError("Stream is consumed")
......@@ -91,6 +99,10 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
self._stream_consumed = True
return self.response.content
if size is None or size == 0:
raise RuntimeError(
"size parameter should not be None if the stream is not consumed yet")
if len(self._buffer) >= size:
# still enough bytes in the buffer
result = chunk[:size]
......@@ -117,7 +129,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
self._buffer = None
return result
def getheader(self, name, default=None):
def getheader(self, name: str, default=None):
return self.response.headers[name] if name in self.response.headers else default
def getheaders(self):
......@@ -130,7 +142,7 @@ class _WrappedHTTPResponse(CaosDBHTTPResponse):
class _SSLAdapter(HTTPAdapter):
"""Transport adapter that allows us to use different SSL versions."""
def __init__(self, ssl_version):
def __init__(self, ssl_version: _SSLMethod):
self.ssl_version = ssl_version
super().__init__()
......@@ -156,7 +168,11 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
self._session = None
self._timeout = None
def request(self, method, path, headers=None, body=None):
def request(self,
method: str, path: str,
headers: Optional[Dict[str, str]] = None,
body: Union[str, bytes, None] = None,
**kwargs):
"""request.
Send a HTTP request to the server.
......@@ -169,7 +185,7 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
An URI path segment (without the 'scheme://host:port/' parts),
including query and frament segments.
headers : dict of str -> str, optional
HTTP request headers. (Defautl: None)
HTTP request headers. (Default: None)
body : str or bytes or readable, optional
The body of the HTTP request. Bytes should be a utf-8 encoded
string.
......@@ -232,14 +248,15 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
"No connection url specified. Please "
"do so via linkahead.configure_connection(...) or in a config "
"file.")
if (not config["url"].lower().startswith("https://") and not config["url"].lower().startswith("http://")):
url_string: str = config["url"]
if (not url_string.lower().startswith("https://") and not url_string.lower().startswith("http://")):
raise LinkAheadConnectionError("The connection url is expected "
"to be a http or https url and "
"must include the url scheme "
"(i.e. start with https:// or "
"http://).")
url = urlparse(config["url"])
url: ParseResult = urlparse(url=url_string)
path = url.path.strip("/")
if len(path) > 0:
path = path + "/"
......@@ -271,7 +288,7 @@ class _DefaultCaosDBServerConnection(CaosDBServerConnection):
if "timeout" in config:
self._timeout = config["timeout"]
def _setup_ssl(self, config):
def _setup_ssl(self, config: Dict[str, Any]):
if "ssl_version" in config and config["cacert"] is not None:
ssl_version = getattr(ssl, config["ssl_version"])
else:
......@@ -325,7 +342,7 @@ _DEFAULT_CONF = {
}
def _get_authenticator(**config):
def _get_authenticator(**config) -> AbstractAuthenticator:
"""_get_authenticator.
Import and configure the password_method.
......@@ -337,7 +354,7 @@ def _get_authenticator(**config):
Currently, there are four valid values for this parameter: 'plain',
'pass', 'keyring' and 'auth_token'.
**config :
Any other keyword arguments are passed the configre method of the
Any other keyword arguments are passed the configure method of the
password_method.
Returns
......@@ -534,8 +551,8 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
__instance = None
def __init__(self):
self._delegate_connection = None
self._authenticator = None
self._delegate_connection: Optional[CaosDBServerConnection] = None
self._authenticator: Optional[AbstractAuthenticator] = None
self.is_configured = False
@classmethod
......@@ -553,7 +570,8 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
"Missing CaosDBServerConnection implementation. You did not "
"specify an `implementation` for the connection.")
try:
self._delegate_connection = config["implementation"]()
self._delegate_connection: CaosDBServerConnection = config["implementation"](
)
if not isinstance(self._delegate_connection,
CaosDBServerConnection):
......@@ -579,7 +597,10 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
return self
def retrieve(self, entity_uri_segments=None, query_dict=None, **kwargs):
def retrieve(self,
entity_uri_segments: Optional[List[str]] = None,
query_dict: Optional[Dict[str, Optional[str]]] = None,
**kwargs):
path = make_uri_path(entity_uri_segments, query_dict)
http_response = self._http_request(method="GET", path=path, **kwargs)
......@@ -641,7 +662,7 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
return http_response
def download_file(self, path):
def download_file(self, path: str):
"""This function downloads a file via HTTP from the LinkAhead file
system."""
try:
......@@ -681,7 +702,11 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
**kwargs)
raise
def _retry_http_request(self, method, path, headers, body, **kwargs):
def _retry_http_request(self,
method: str,
path: str,
headers: Optional[Dict["str", Any]],
body: Union[str, bytes], **kwargs) -> CaosDBHTTPResponse:
if hasattr(body, "encode"):
# python3
......@@ -689,8 +714,18 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
if headers is None:
headers = {}
if self._authenticator is None:
raise ValueError(
"No authenticator set. Please call configure_connection() first.")
self._authenticator.on_request(method=method, path=path,
headers=headers)
if self._delegate_connection is None:
raise ValueError(
"No connection set. Please call configure_connection() first.")
_LOGGER.debug("request: %s %s %s", method, path, str(headers))
http_response = self._delegate_connection.request(
method=method,
......@@ -704,10 +739,16 @@ class _Connection(object): # pylint: disable=useless-object-inheritance
return http_response
def get_username(self):
def get_username(self) -> str:
"""
Return the username of the current connection.
Shortcut for: get_connection()._authenticator._credentials_provider.username
"""
if self._authenticator is None:
raise ValueError(
"No authenticator set. Please call configure_connection() first.")
if self._authenticator._credentials_provider is None:
raise ValueError(
"No credentials provider set. Please call configure_connection() first.")
return self._authenticator._credentials_provider.username
......@@ -48,6 +48,7 @@ as multipart/form-data suitable for a HTTP POST or PUT request.
multipart/form-data is the standard way to upload files over HTTP
"""
from __future__ import annotations
__all__ = [
'gen_boundary', 'encode_and_quote', 'MultipartParam', 'encode_string',
......@@ -61,6 +62,10 @@ import re
import os
import mimetypes
from email.header import Header
from typing import TYPE_CHECKING
import sys
if TYPE_CHECKING and sys.version_info > (3, 7):
from typing import Optional
def gen_boundary():
......@@ -68,7 +73,7 @@ def gen_boundary():
return uuid.uuid4().hex
def encode_and_quote(data):
def encode_and_quote(data: Optional[str]) -> Optional[str]:
"""If ``data`` is unicode, return urllib.quote_plus(data.encode("utf-8"))
otherwise return urllib.quote_plus(data)"""
if data is None:
......@@ -111,7 +116,7 @@ class MultipartParam(object):
"""
def __init__(self,
name,
name: str,
value=None,
filename=None,
filetype=None,
......
......@@ -22,11 +22,14 @@
# ** end header
#
"""This module defines the CaosDBServerConnection interface."""
from abc import ABCMeta, abstractmethod, abstractproperty
from __future__ import annotations
from abc import ABCMeta, abstractmethod, ABC
from warnings import warn
# meta class compatible with Python 2 *and* 3:
ABC = ABCMeta('ABC', (object, ), {'__slots__': ()})
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Dict, Union
class CaosDBHTTPResponse(ABC):
......@@ -34,14 +37,14 @@ class CaosDBHTTPResponse(ABC):
LinkAheadServer."""
@abstractmethod
def read(self, size=-1):
def read(self, size: Optional[int] = -1):
"""Read up to *size* bytes from the response body.
If size is unspecified or -1, all bytes until EOF are returned.
"""
@abstractmethod
def getheader(self, name, default=None):
def getheader(self, name: str, default=None):
"""Return the value of the header *name* or the value of *default* if
there is no such header.
......@@ -50,12 +53,13 @@ class CaosDBHTTPResponse(ABC):
are returned likewise.
"""
@abstractproperty
def status(self):
@property
@abstractmethod
def status(self) -> int:
"""Status code of the response."""
@abstractmethod
def getheaders(self):
def getheaders(self) -> Dict[str, str]:
"""Return all headers."""
def __enter__(self):
......@@ -78,7 +82,12 @@ class CaosDBServerConnection(ABC):
LinkAhead server."""
@abstractmethod
def request(self, method, path, headers=None, body=None, **kwargs):
def request(self,
method: str,
path: str,
headers: Optional[Dict[str, str]] = None,
body: Union[str, bytes, None] = None,
**kwargs) -> CaosDBHTTPResponse:
"""Abstract method. Implement this method for HTTP requests to the
LinkAhead server.
......
......@@ -22,14 +22,19 @@
# ** end header
#
"""Utility functions for the connection module."""
from __future__ import unicode_literals, print_function
from __future__ import unicode_literals, print_function, annotations
from builtins import str as unicode
from urllib.parse import (urlencode as _urlencode, quote as _quote,
urlparse, urlunparse, unquote as _unquote)
import re
from typing import TYPE_CHECKING
import sys
if TYPE_CHECKING and sys.version_info > (3, 7):
from typing import Optional, Dict, List
def urlencode(query):
def urlencode(query: Dict[str, Optional[str]]) -> str:
"""Convert a dict of into a url-encoded (unicode) string.
This is basically a python2/python3 compatibility wrapper for the respective
......@@ -79,7 +84,8 @@ modules when they are called with only the query parameter.
}))
def make_uri_path(segments=None, query=None):
def make_uri_path(segments: Optional[List[str]] = None,
query: Optional[Dict[str, Optional[str]]] = None) -> str:
"""Url-encode all segments, concat them with slashes and append the query.
Examples
......@@ -105,7 +111,10 @@ def make_uri_path(segments=None, query=None):
"""
path_no_query = ("/".join([quote(segment) for segment in segments])
if segments else "")
return str(path_no_query if query is None else "?".join([
if query is None:
return str(path_no_query)
return str("?".join([
path_no_query, "&".join([
quote(key) + "=" +
(quote(query[key]) if query[key] is not None else "")
......@@ -114,13 +123,13 @@ def make_uri_path(segments=None, query=None):
]))
def quote(string):
def quote(string: str) -> str:
enc = string.encode('utf-8')
return _quote(enc).replace('/', '%2F')
def parse_url(url):
fullurl = urlparse(url)
def parse_url(url: str):
fullurl = urlparse(url=url)
# make sure the path ends with a slash
if not fullurl.path.endswith("/"):
parse_result = list(fullurl)
......@@ -132,7 +141,7 @@ def parse_url(url):
_PATTERN = re.compile(r"^SessionToken=([^;]*);.*$")
def unquote(string):
def unquote(string) -> str:
"""unquote.
Decode an urlencoded string into a plain text string.
......@@ -144,7 +153,7 @@ def unquote(string):
return bts
def parse_auth_token(cookie):
def parse_auth_token(cookie: Optional[str]) -> Optional[str]:
"""parse_auth_token.
Parse an auth token from a cookie.
......@@ -165,7 +174,7 @@ def parse_auth_token(cookie):
return auth_token
def auth_token_to_cookie(auth_token):
def auth_token_to_cookie(auth_token: str) -> str:
"""auth_token_to_cookie.
Urlencode an auth token string and format it as a cookie.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment