# SPDX-License-Identifier: AGPL-3.0-or-later
# pylint: disable=missing-module-docstring, global-statement

import asyncio
import threading
import concurrent.futures
from queue import SimpleQueue
from types import MethodType
from timeit import default_timer
from typing import Any, Iterable, NamedTuple, Tuple, List, Dict, Union
from contextlib import contextmanager

import httpx
import anyio

from .network import get_network, initialize, check_network_configuration  # pylint:disable=cyclic-import
from .client import get_loop
from .raise_for_httperror import raise_for_httperror


THREADLOCAL = threading.local()
"""Thread-local data is data for thread specific values."""


def reset_time_for_thread():
    """Reset the calling thread's accumulated HTTP time to zero.

    Once ``total_time`` exists on the thread, :py:obj:`_record_http_time`
    adds the duration of each request to it.
    """
    THREADLOCAL.total_time = 0


def get_time_for_thread():
    """returns thread's total time or None"""
    return THREADLOCAL.__dict__.get('total_time')


def set_timeout_for_thread(timeout, start_time=None):
    """Store the default timeout and the reference start time for the calling thread.

    Both values are read back by :py:obj:`_get_timeout` and
    :py:obj:`_record_http_time`.
    """
    THREADLOCAL.timeout = timeout
    THREADLOCAL.start_time = start_time


def set_context_network_name(network_name):
    """Select the network used by the calling thread, looked up by name."""
    THREADLOCAL.network = get_network(network_name)


def get_context_network():
    """If set return thread's network.

    If unset, return value from :py:obj:`get_network`.
    """
    return THREADLOCAL.__dict__.get('network') or get_network()


@contextmanager
def _record_http_time():
    """Context manager that yields the start time and records the elapsed time.

    The yielded value is the thread's ``start_time`` when the attribute exists
    (it may be ``None``), otherwise the time at which the context was entered.
    On exit, the time spent inside the context is added to the thread's
    ``total_time``, but only if that attribute exists.
    """
    # pylint: disable=too-many-branches
    time_before_request = default_timer()
    start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
    try:
        yield start_time
    finally:
        # update total_time.
        # See get_time_for_thread() and reset_time_for_thread()
        if hasattr(THREADLOCAL, 'total_time'):
            time_after_request = default_timer()
            THREADLOCAL.total_time += time_after_request - time_before_request


def _get_timeout(start_time, kwargs):
    """Return how long to wait for the future, in seconds.

    The timeout comes from ``kwargs['timeout']`` when present, otherwise from
    the thread-local timeout.  In the latter case a non-``None`` value is
    written into ``kwargs`` (the mapping is modified in place) so that it is
    forwarded to the network layer.

    The returned value is the timeout (120 when unset or falsy) plus 0.2
    seconds of overhead, minus the time already elapsed since ``start_time``
    when ``start_time`` is truthy.
    """
    # pylint: disable=too-many-branches

    # timeout (httpx)
    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
    else:
        timeout = getattr(THREADLOCAL, 'timeout', None)
        if timeout is not None:
            kwargs['timeout'] = timeout

    # 2 minutes timeout for the requests without timeout
    timeout = timeout or 120

    # adjust actual timeout
    timeout += 0.2  # overhead
    if start_time:
        timeout -= default_timer() - start_time

    return timeout


def request(method, url, **kwargs):
    """same as requests/requests/api.py request(...)

    The coroutine is scheduled on the loop returned by ``get_loop()`` and the
    calling thread blocks until the result is available.

    :raises httpx.TimeoutException: when the result is not available within
        the timeout computed by :py:obj:`_get_timeout`.
    """
    with _record_http_time() as start_time:
        network = get_context_network()
        timeout = _get_timeout(start_time, kwargs)
        future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
        try:
            return future.result(timeout)
        except concurrent.futures.TimeoutError as e:
            raise httpx.TimeoutException('Timeout', request=None) from e


def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
    """send multiple HTTP requests in parallel. Wait for all requests to finish.

    The returned list has one entry per request, in the same order: either the
    response, or the exception describing why the request failed (a
    ``httpx.TimeoutException`` when the result was not available in time).
    Exceptions are returned, not raised.
    """
    with _record_http_time() as start_time:
        # send the requests
        network = get_context_network()
        loop = get_loop()
        future_list = []
        for request_desc in request_list:
            # Work on a private copy: _get_timeout() may write the effective
            # timeout into the mapping.  Mutating request_desc.kwargs in place
            # would alter the caller's dict and, for a Request built without
            # kwargs, the default dict shared by every such Request.
            kwargs = dict(request_desc.kwargs)
            timeout = _get_timeout(start_time, kwargs)
            future = asyncio.run_coroutine_threadsafe(
                network.request(request_desc.method, request_desc.url, **kwargs), loop
            )
            future_list.append((future, timeout))

        # read the responses
        responses = []
        for future, timeout in future_list:
            try:
                responses.append(future.result(timeout))
            except concurrent.futures.TimeoutError:
                responses.append(httpx.TimeoutException('Timeout', request=None))
            except Exception as e:  # pylint: disable=broad-except
                responses.append(e)
        return responses


class Request(NamedTuple):
    """Request description for the multi_requests function"""

    method: str
    url: str
    # Keyword arguments forwarded to the network layer.  The default dict is a
    # single object shared by every Request created without kwargs, so it must
    # be treated as read-only (multi_requests copies it before use).
    kwargs: Dict[str, Any] = {}

    @staticmethod
    def get(url, **kwargs):
        """Describe a GET request."""
        return Request('GET', url, kwargs)

    @staticmethod
    def options(url, **kwargs):
        """Describe an OPTIONS request."""
        return Request('OPTIONS', url, kwargs)

    @staticmethod
    def head(url, **kwargs):
        """Describe a HEAD request."""
        return Request('HEAD', url, kwargs)

    @staticmethod
    def post(url, **kwargs):
        """Describe a POST request."""
        return Request('POST', url, kwargs)

    @staticmethod
    def put(url, **kwargs):
        """Describe a PUT request."""
        return Request('PUT', url, kwargs)

    @staticmethod
    def patch(url, **kwargs):
        """Describe a PATCH request."""
        return Request('PATCH', url, kwargs)

    @staticmethod
    def delete(url, **kwargs):
        """Describe a DELETE request."""
        return Request('DELETE', url, kwargs)


def get(url, **kwargs):
    """Send a GET request; ``allow_redirects`` defaults to ``True``."""
    kwargs.setdefault('allow_redirects', True)
    return request('get', url, **kwargs)


def options(url, **kwargs):
    """Send an OPTIONS request; ``allow_redirects`` defaults to ``True``."""
    kwargs.setdefault('allow_redirects', True)
    return request('options', url, **kwargs)


def head(url, **kwargs):
    """Send a HEAD request; ``allow_redirects`` defaults to ``False``."""
    kwargs.setdefault('allow_redirects', False)
    return request('head', url, **kwargs)


def post(url, data=None, **kwargs):
    """Send a POST request with the optional ``data`` payload."""
    return request('post', url, data=data, **kwargs)


def put(url, data=None, **kwargs):
    """Send a PUT request with the optional ``data`` payload."""
    return request('put', url, data=data, **kwargs)


def patch(url, data=None, **kwargs):
    """Send a PATCH request with the optional ``data`` payload."""
    return request('patch', url, data=data, **kwargs)


def delete(url, **kwargs):
    """Send a DELETE request."""
    return request('delete', url, **kwargs)


async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    """Stream a response and push it, chunk by chunk, into ``queue``.

    The items are queued in this order: the response object, the non-empty
    raw chunks, then ``None`` as the end marker.  An exception other than a
    closed stream is queued (before the end marker) instead of being raised.
    """
    try:
        async with await network.stream(method, url, **kwargs) as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for chunk in response.aiter_raw(65536):
                if len(chunk) > 0:
                    queue.put(chunk)
    except (httpx.StreamClosed, anyio.ClosedResourceError):
        # the response was queued before the exception.
        # the exception was raised on aiter_raw.
        # we do nothing here: in the finally block, None will be queued
        # so stream(method, url, **kwargs) generator can stop
        pass
    except Exception as e:  # pylint: disable=broad-except
        # broad except to avoid this scenario:
        # exception in network.stream(method, url, **kwargs)
        # -> the exception is not caught here
        # -> queue None (in finally)
        # -> the function below stream(method, url, **kwargs) has nothing to return
        queue.put(e)
    finally:
        queue.put(None)


def _stream_generator(method, url, **kwargs):
    """Yield the items queued by :py:obj:`stream_chunk_to_queue`.

    The first item yielded is the response, the following ones are the chunks.
    A queued exception is raised in the calling thread.  The generator ends
    when the ``None`` end marker is read.
    """
    queue = SimpleQueue()
    network = get_context_network()
    future = asyncio.run_coroutine_threadsafe(stream_chunk_to_queue(network, queue, method, url, **kwargs), get_loop())

    # yield chunks
    obj_or_exception = queue.get()
    while obj_or_exception is not None:
        if isinstance(obj_or_exception, Exception):
            raise obj_or_exception
        yield obj_or_exception
        obj_or_exception = queue.get()
    future.result()


def _close_response_method(self):
    """Replacement for ``response.close`` on responses returned by :py:obj:`stream`."""
    asyncio.run_coroutine_threadsafe(self.aclose(), get_loop())
    # reach the end of self._generator ( _stream_generator ) to avoid a memory leak.
    # it makes sure that :
    # * the httpx response is closed (see the stream_chunk_to_queue function)
    # * to call future.result() in _stream_generator
    for _ in self._generator:  # pylint: disable=protected-access
        continue


def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
    """Replace httpx.stream.

    Usage:
    response, stream = poolrequests.stream(...)
    for chunk in stream:
        ...

    httpx.Client.stream requires to write the httpx.HTTPTransport version of
    the httpx.AsyncHTTPTransport declared above.
    """
    generator = _stream_generator(method, url, **kwargs)

    # yield response
    response = next(generator)  # pylint: disable=stop-iteration-return
    if isinstance(response, Exception):
        raise response

    response._generator = generator  # pylint: disable=protected-access
    response.close = MethodType(_close_response_method, response)

    return response, generator