1653 lines
60 KiB
Python
1653 lines
60 KiB
Python
import asyncio
|
|
import functools
|
|
import random
|
|
import socket
|
|
import sys
|
|
import traceback
|
|
import warnings
|
|
from collections import OrderedDict, defaultdict, deque
|
|
from contextlib import suppress
|
|
from http import HTTPStatus
|
|
from itertools import chain, cycle, islice
|
|
from time import monotonic
|
|
from types import TracebackType
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Awaitable,
|
|
Callable,
|
|
DefaultDict,
|
|
Deque,
|
|
Dict,
|
|
Iterator,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
import aiohappyeyeballs
|
|
|
|
from . import hdrs, helpers
|
|
from .abc import AbstractResolver, ResolveResult
|
|
from .client_exceptions import (
|
|
ClientConnectionError,
|
|
ClientConnectorCertificateError,
|
|
ClientConnectorDNSError,
|
|
ClientConnectorError,
|
|
ClientConnectorSSLError,
|
|
ClientHttpProxyError,
|
|
ClientProxyConnectionError,
|
|
ServerFingerprintMismatch,
|
|
UnixClientConnectorError,
|
|
cert_errors,
|
|
ssl_errors,
|
|
)
|
|
from .client_proto import ResponseHandler
|
|
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
|
|
from .helpers import (
|
|
ceil_timeout,
|
|
is_ip_address,
|
|
noop,
|
|
sentinel,
|
|
set_exception,
|
|
set_result,
|
|
)
|
|
from .resolver import DefaultResolver
|
|
|
|
if TYPE_CHECKING:
    # For type checkers, ssl is always importable.
    import ssl

    SSLContext = ssl.SSLContext
else:
    # At runtime, Python may be built without ssl support; fall back to
    # sentinel values so the rest of the module can degrade gracefully.
    try:
        import ssl

        SSLContext = ssl.SSLContext
    except ImportError:  # pragma: no cover
        ssl = None  # type: ignore[assignment]
        SSLContext = object  # type: ignore[misc,assignment]


# Scheme sets used to validate which URL schemes a connector may handle.
EMPTY_SCHEMA_SET = frozenset({""})
HTTP_SCHEMA_SET = frozenset({"http", "https"})
WS_SCHEMA_SET = frozenset({"ws", "wss"})

HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET

# True when the running interpreter still needs the "cleanup closed"
# workaround: 3.13.0 exactly, or anything older than 3.12.7.
NEEDS_CLEANUP_CLOSED = (3, 13, 0) <= sys.version_info < (
    3,
    13,
    1,
) or sys.version_info < (3, 12, 7)
# Cleanup closed is no longer needed after https://github.com/python/cpython/pull/118960
# which first appeared in Python 3.12.7 and 3.13.1


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")


if TYPE_CHECKING:
    # Imported only for annotations to avoid import cycles at runtime.
    from .client import ClientTimeout
    from .client_reqrep import ConnectionKey
    from .tracing import Trace
|
|
|
|
|
|
class _DeprecationWaiter:
    """Awaitable proxy that emits a DeprecationWarning if never awaited.

    Wraps the awaitable returned by ``Connector.close()`` so that callers
    who forget the ``await`` get a warning from the finalizer instead of
    silent misbehavior.
    """

    __slots__ = ("_awaitable", "_awaited")

    def __init__(self, awaitable: Awaitable[Any]) -> None:
        self._awaitable = awaitable
        # Flipped to True the moment the wrapper is awaited.
        self._awaited = False

    def __await__(self) -> Any:
        self._awaited = True
        return self._awaitable.__await__()

    def __del__(self) -> None:
        # Guard clause: nothing to report when the caller awaited us.
        if self._awaited:
            return
        warnings.warn(
            "Connector.close() is a coroutine, "
            "please use await connector.close()",
            DeprecationWarning,
        )
|
|
|
|
|
|
class Connection:
    """A single connection checked out of a :class:`BaseConnector`.

    Wraps a :class:`ResponseHandler` protocol and knows how to hand it
    back to the owning connector, either for reuse (``release``) or for
    teardown (``close``).
    """

    # Populated only in debug mode; class-level default keeps __del__ safe.
    _source_traceback = None

    def __init__(
        self,
        connector: "BaseConnector",
        key: "ConnectionKey",
        protocol: ResponseHandler,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._key = key
        self._connector = connector
        self._loop = loop
        # None once the connection has been released/closed.
        self._protocol: Optional[ResponseHandler] = protocol
        # Callbacks invoked (once) when the connection is released or closed.
        self._callbacks: List[Callable[[], None]] = []

        if loop.get_debug():
            # Capture where the connection was created to aid leak reports.
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __repr__(self) -> str:
        return f"Connection<{self._key}>"

    def __del__(self, _warnings: Any = warnings) -> None:
        # Finalizer: warn about connections that were never released.
        if self._protocol is not None:
            kwargs = {"source": self}
            _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs)
            if self._loop.is_closed():
                # Event loop is gone; nothing more we can safely do.
                return

            # Force-close the protocol so the transport is not leaked.
            self._connector._release(self._key, self._protocol, should_close=True)

            context = {"client_connection": self, "message": "Unclosed connection"}
            if self._source_traceback is not None:
                context["source_traceback"] = self._source_traceback
            self._loop.call_exception_handler(context)

    def __bool__(self) -> Literal[True]:
        """Force subclasses to not be falsy, to make checks simpler."""
        return True

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Deprecated accessor for the event loop this connection uses."""
        warnings.warn(
            "connector.loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        return self._loop

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """The underlying transport, or None once released/closed."""
        if self._protocol is None:
            return None
        return self._protocol.transport

    @property
    def protocol(self) -> Optional[ResponseHandler]:
        """The wrapped protocol, or None once released/closed."""
        return self._protocol

    def add_callback(self, callback: Callable[[], None]) -> None:
        """Register *callback* to run when this connection is released."""
        if callback is not None:
            self._callbacks.append(callback)

    def _notify_release(self) -> None:
        # Swap out the callback list first so re-entrant adds are not lost.
        callbacks, self._callbacks = self._callbacks[:], []

        for cb in callbacks:
            # Release notifications are best-effort; one bad callback must
            # not prevent the others (or the release itself).
            with suppress(Exception):
                cb()

    def close(self) -> None:
        """Return the connection to the connector and force-close it."""
        self._notify_release()

        if self._protocol is not None:
            self._connector._release(self._key, self._protocol, should_close=True)
            self._protocol = None

    def release(self) -> None:
        """Return the connection to the connector for possible reuse."""
        self._notify_release()

        if self._protocol is not None:
            self._connector._release(self._key, self._protocol)
            self._protocol = None

    @property
    def closed(self) -> bool:
        """True when released or when the transport is no longer connected."""
        return self._protocol is None or not self._protocol.is_connected()
|
|
|
|
|
|
class _TransportPlaceholder:
    """placeholder for BaseConnector.connect function

    Stands in for a real ResponseHandler in the ``_acquired`` sets while a
    connection is still being established, so limit accounting stays correct.
    """

    __slots__ = ()

    def close(self) -> None:
        """Close the placeholder transport."""
        # Intentionally a no-op: there is no real transport yet.
|
|
|
|
|
|
class BaseConnector:
    """Base connector class.

    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    enable_cleanup_closed - Enables clean-up closed ssl transports.
                            Disabled by default.
    timeout_ceil_threshold - Trigger ceiling of timeout values when
                             it's above timeout_ceil_threshold.
    loop - Optional event loop.
    """

    _closed = True  # prevent AttributeError in __del__ if ctor was failed
    _source_traceback = None

    # abort transport after 2 seconds (cleanup broken connections)
    _cleanup_closed_period = 2.0

    # URL schemes requests through this connector may use.
    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET

    def __init__(
        self,
        *,
        keepalive_timeout: Union[object, None, float] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
    ) -> None:

        if force_close:
            # force_close and keep-alive are mutually exclusive.
            if keepalive_timeout is not None and keepalive_timeout is not sentinel:
                raise ValueError(
                    "keepalive_timeout cannot be set if force_close is True"
                )
        else:
            if keepalive_timeout is sentinel:
                # Default keep-alive window in seconds.
                keepalive_timeout = 15.0

        loop = loop or asyncio.get_running_loop()
        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._closed = False
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Connection pool of reusable connections.
        # We use a deque to store connections because it has O(1) popleft()
        # and O(1) append() operations to implement a FIFO queue.
        self._conns: DefaultDict[
            ConnectionKey, Deque[Tuple[ResponseHandler, float]]
        ] = defaultdict(deque)
        self._limit = limit
        self._limit_per_host = limit_per_host
        # Protocols currently checked out (plus placeholders while connecting).
        self._acquired: Set[ResponseHandler] = set()
        self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = (
            defaultdict(set)
        )
        self._keepalive_timeout = cast(float, keepalive_timeout)
        self._force_close = force_close

        # {host_key: FIFO list of waiters}
        # The FIFO is implemented with an OrderedDict with None keys because
        # python does not have an ordered set.
        self._waiters: DefaultDict[
            ConnectionKey, OrderedDict[asyncio.Future[None], None]
        ] = defaultdict(OrderedDict)

        self._loop = loop
        self._factory = functools.partial(ResponseHandler, loop=loop)

        # start keep-alive connection cleanup task
        self._cleanup_handle: Optional[asyncio.TimerHandle] = None

        # start cleanup closed transports task
        self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None

        if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED:
            # The interpreter already aborts closed SSL transports correctly;
            # honoring the flag would only waste a periodic timer.
            warnings.warn(
                "enable_cleanup_closed ignored because "
                "https://github.com/python/cpython/pull/118960 is fixed "
                f"in Python version {sys.version_info}",
                DeprecationWarning,
                stacklevel=2,
            )
            enable_cleanup_closed = False

        self._cleanup_closed_disabled = not enable_cleanup_closed
        self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
        # Kick off the periodic abort of half-closed SSL transports (no-op
        # when disabled).
        self._cleanup_closed()

    def __del__(self, _warnings: Any = warnings) -> None:
        # Finalizer: warn about a connector garbage-collected with live
        # pooled connections, then force-close everything.
        if self._closed:
            return
        if not self._conns:
            return

        # Snapshot reprs before _close() clears the pool.
        conns = [repr(c) for c in self._conns.values()]

        self._close()

        kwargs = {"source": self}
        _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs)
        context = {
            "connector": self,
            "connections": conns,
            "message": "Unclosed connector",
        }
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __enter__(self) -> "BaseConnector":
        warnings.warn(
            '"with Connector():" is deprecated, '
            'use "async with Connector():" instead',
            DeprecationWarning,
        )
        return self

    def __exit__(self, *exc: Any) -> None:
        self._close()

    async def __aenter__(self) -> "BaseConnector":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        exc_traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    @property
    def force_close(self) -> bool:
        """Ultimately close connection on releasing if True."""
        return self._force_close

    @property
    def limit(self) -> int:
        """The total number for simultaneous connections.

        If limit is 0 the connector has no limit.
        The default limit size is 100.
        """
        return self._limit

    @property
    def limit_per_host(self) -> int:
        """The limit for simultaneous connections to the same endpoint.

        Endpoints are the same if they are have equal
        (host, port, is_ssl) triple.
        """
        return self._limit_per_host

    def _cleanup(self) -> None:
        """Cleanup unused transports."""
        if self._cleanup_handle:
            self._cleanup_handle.cancel()
            # _cleanup_handle should be unset, otherwise _release() will not
            # recreate it ever!
            self._cleanup_handle = None

        now = monotonic()
        timeout = self._keepalive_timeout

        if self._conns:
            # Rebuild the pool keeping only connections that are still
            # connected and within the keep-alive window.
            connections = defaultdict(deque)
            deadline = now - timeout
            for key, conns in self._conns.items():
                alive: Deque[Tuple[ResponseHandler, float]] = deque()
                for proto, use_time in conns:
                    if proto.is_connected() and use_time - deadline >= 0:
                        alive.append((proto, use_time))
                        continue
                    # Stale or dead: close it now.
                    transport = proto.transport
                    proto.close()
                    if not self._cleanup_closed_disabled and key.is_ssl:
                        # SSL transports get a second forced abort later.
                        self._cleanup_closed_transports.append(transport)

                if alive:
                    connections[key] = alive

            self._conns = connections

        if self._conns:
            # Re-arm the cleanup timer while idle connections remain.
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def _cleanup_closed(self) -> None:
        """Double confirmation for transport close.

        Some broken ssl servers may leave socket open without proper close.
        """
        if self._cleanup_closed_handle:
            self._cleanup_closed_handle.cancel()

        for transport in self._cleanup_closed_transports:
            if transport is not None:
                transport.abort()

        self._cleanup_closed_transports = []

        if not self._cleanup_closed_disabled:
            # Re-arm the periodic abort timer (weakref so the timer does not
            # keep the connector alive).
            self._cleanup_closed_handle = helpers.weakref_handle(
                self,
                "_cleanup_closed",
                self._cleanup_closed_period,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def close(self) -> Awaitable[None]:
        """Close all opened transports."""
        self._close()
        # Returned waiter warns if the caller forgets to await close().
        return _DeprecationWaiter(noop())

    def _close(self) -> None:
        # Synchronous core of close(); idempotent.
        if self._closed:
            return

        self._closed = True

        try:
            if self._loop.is_closed():
                # No loop left to run callbacks on; just mark closed.
                return

            # cancel cleanup task
            if self._cleanup_handle:
                self._cleanup_handle.cancel()

            # cancel cleanup close task
            if self._cleanup_closed_handle:
                self._cleanup_closed_handle.cancel()

            for data in self._conns.values():
                for proto, t0 in data:
                    proto.close()

            for proto in self._acquired:
                proto.close()

            for transport in self._cleanup_closed_transports:
                if transport is not None:
                    transport.abort()

        finally:
            # Always drop internal state, even if a close above raised.
            self._conns.clear()
            self._acquired.clear()
            for keyed_waiters in self._waiters.values():
                for keyed_waiter in keyed_waiters:
                    keyed_waiter.cancel()
            self._waiters.clear()
            self._cleanup_handle = None
            self._cleanup_closed_transports.clear()
            self._cleanup_closed_handle = None

    @property
    def closed(self) -> bool:
        """Is connector closed.

        A readonly property.
        """
        return self._closed

    def _available_connections(self, key: "ConnectionKey") -> int:
        """
        Return number of available connections.

        The limit, limit_per_host and the connection key are taken into account.

        If it returns less than 1 means that there are no connections
        available.
        """
        # check total available connections
        # If there are no limits, this will always return 1
        total_remain = 1

        if self._limit and (total_remain := self._limit - len(self._acquired)) <= 0:
            return total_remain

        # check limit per host
        if host_remain := self._limit_per_host:
            if acquired := self._acquired_per_host.get(key):
                host_remain -= len(acquired)
            if total_remain > host_remain:
                # Per-host limit is the tighter constraint.
                return host_remain

        return total_remain

    async def connect(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Connection:
        """Get from pool or create new connection."""
        key = req.connection_key
        if (conn := await self._get(key, traces)) is not None:
            # If we do not have to wait and we can get a connection from the pool
            # we can avoid the timeout ceil logic and directly return the connection
            return conn

        async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
            if self._available_connections(key) <= 0:
                await self._wait_for_available_connection(key, traces)
                # A connection may have been returned to the pool while we
                # waited; prefer reusing it over creating a new one.
                if (conn := await self._get(key, traces)) is not None:
                    return conn

            # Reserve a slot with a placeholder so concurrent connect()
            # calls respect the limits while we dial.
            placeholder = cast(ResponseHandler, _TransportPlaceholder())
            self._acquired.add(placeholder)
            if self._limit_per_host:
                self._acquired_per_host[key].add(placeholder)

            try:
                # Traces are done inside the try block to ensure that the
                # that the placeholder is still cleaned up if an exception
                # is raised.
                if traces:
                    for trace in traces:
                        await trace.send_connection_create_start()
                proto = await self._create_connection(req, traces, timeout)
                if traces:
                    for trace in traces:
                        await trace.send_connection_create_end()
            except BaseException:
                self._release_acquired(key, placeholder)
                raise
            else:
                if self._closed:
                    # Connector was closed while we were connecting.
                    proto.close()
                    raise ClientConnectionError("Connector is closed.")

        # The connection was successfully created, drop the placeholder
        # and add the real connection to the acquired set. There should
        # be no awaits after the proto is added to the acquired set
        # to ensure that the connection is not left in the acquired set
        # on cancellation.
        self._acquired.remove(placeholder)
        self._acquired.add(proto)
        if self._limit_per_host:
            acquired_per_host = self._acquired_per_host[key]
            acquired_per_host.remove(placeholder)
            acquired_per_host.add(proto)
        return Connection(self, key, proto, self._loop)

    async def _wait_for_available_connection(
        self, key: "ConnectionKey", traces: List["Trace"]
    ) -> None:
        """Wait for an available connection slot."""
        # We loop here because there is a race between
        # the connection limit check and the connection
        # being acquired.  If the connection is acquired
        # between the check and the await statement, we
        # need to loop again to check if the connection
        # slot is still available.
        attempts = 0
        while True:
            fut: asyncio.Future[None] = self._loop.create_future()
            keyed_waiters = self._waiters[key]
            keyed_waiters[fut] = None
            if attempts:
                # If we have waited before, we need to move the waiter
                # to the front of the queue as otherwise we might get
                # starved and hit the timeout.
                keyed_waiters.move_to_end(fut, last=False)

            try:
                # Traces happen in the try block to ensure that the
                # the waiter is still cleaned up if an exception is raised.
                if traces:
                    for trace in traces:
                        await trace.send_connection_queued_start()
                await fut
                if traces:
                    for trace in traces:
                        await trace.send_connection_queued_end()
            finally:
                # pop the waiter from the queue if its still
                # there and not already removed by _release_waiter
                keyed_waiters.pop(fut, None)
                if not self._waiters.get(key, True):
                    # Drop empty per-key queues so _waiters does not grow.
                    del self._waiters[key]

            if self._available_connections(key) > 0:
                break
            attempts += 1

    async def _get(
        self, key: "ConnectionKey", traces: List["Trace"]
    ) -> Optional[Connection]:
        """Get next reusable connection for the key or None.

        The connection will be marked as acquired.
        """
        if (conns := self._conns.get(key)) is None:
            return None

        t1 = monotonic()
        while conns:
            proto, t0 = conns.popleft()
            # We will reuse the connection if it is connected and
            # the keepalive timeout has not been exceeded
            if proto.is_connected() and t1 - t0 <= self._keepalive_timeout:
                if not conns:
                    # The very last connection was reclaimed: drop the key
                    del self._conns[key]
                self._acquired.add(proto)
                if self._limit_per_host:
                    self._acquired_per_host[key].add(proto)
                if traces:
                    for trace in traces:
                        try:
                            await trace.send_connection_reuseconn()
                        except BaseException:
                            # Undo the acquisition if a trace callback fails.
                            self._release_acquired(key, proto)
                            raise
                return Connection(self, key, proto, self._loop)

            # Connection cannot be reused, close it
            transport = proto.transport
            proto.close()
            # only for SSL transports
            if not self._cleanup_closed_disabled and key.is_ssl:
                self._cleanup_closed_transports.append(transport)

        # No more connections: drop the key
        del self._conns[key]
        return None

    def _release_waiter(self) -> None:
        """
        Iterates over all waiters until one to be released is found.

        The one to be released is not finished and
        belongs to a host that has available connections.
        """
        if not self._waiters:
            return

        # Having the dict keys ordered this avoids to iterate
        # at the same order at each call.
        queues = list(self._waiters)
        random.shuffle(queues)

        for key in queues:
            if self._available_connections(key) < 1:
                continue

            waiters = self._waiters[key]
            while waiters:
                # FIFO: wake the oldest pending waiter for this key.
                waiter, _ = waiters.popitem(last=False)
                if not waiter.done():
                    waiter.set_result(None)
                    return

    def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
        """Release acquired connection."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._acquired.discard(proto)
        if self._limit_per_host and (conns := self._acquired_per_host.get(key)):
            conns.discard(proto)
            if not conns:
                # Drop empty per-host sets to keep the mapping small.
                del self._acquired_per_host[key]
        # A slot just opened up; possibly wake a waiter.
        self._release_waiter()

    def _release(
        self,
        key: "ConnectionKey",
        protocol: ResponseHandler,
        *,
        should_close: bool = False,
    ) -> None:
        """Return *protocol* to the pool, or close it if reuse is not allowed."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._release_acquired(key, protocol)

        if self._force_close or should_close or protocol.should_close:
            transport = protocol.transport
            protocol.close()

            if key.is_ssl and not self._cleanup_closed_disabled:
                self._cleanup_closed_transports.append(transport)
            return

        # Reusable: put it back with the current time for keep-alive checks.
        self._conns[key].append((protocol, monotonic()))

        if self._cleanup_handle is None:
            # First pooled connection: arm the keep-alive cleanup timer.
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                self._keepalive_timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        # Subclasses implement the actual dialing strategy.
        raise NotImplementedError()
|
|
|
|
|
|
class _DNSCacheTable:
    """TTL-aware DNS cache with round-robin rotation of cached addresses."""

    def __init__(self, ttl: Optional[float] = None) -> None:
        # Maps (host, port) -> (endless round-robin iterator, entry count).
        self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
        # Insertion timestamps, tracked only when a TTL is configured.
        self._timestamps: Dict[Tuple[str, int], float] = {}
        self._ttl = ttl

    def __contains__(self, host: object) -> bool:
        return host in self._addrs_rr

    def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
        """Store resolved *addrs* for *key*, resetting its rotation."""
        self._addrs_rr[key] = (cycle(addrs), len(addrs))

        if self._ttl is not None:
            self._timestamps[key] = monotonic()

    def remove(self, key: Tuple[str, int]) -> None:
        """Forget *key*; a no-op when it was never cached."""
        self._addrs_rr.pop(key, None)

        if self._ttl is not None:
            self._timestamps.pop(key, None)

    def clear(self) -> None:
        """Drop every cached entry."""
        self._addrs_rr.clear()
        self._timestamps.clear()

    def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
        """Return all addresses for *key*, rotated one step per call."""
        rr, count = self._addrs_rr[key]
        addrs = [next(rr) for _ in range(count)]
        # Consume one more element to shift internal state of `cycle`
        next(rr)
        return addrs

    def expired(self, key: Tuple[str, int]) -> bool:
        """True when *key*'s entry is older than the configured TTL."""
        if self._ttl is None:
            # No TTL configured: entries never expire.
            return False

        return self._timestamps[key] + self._ttl < monotonic()
|
|
|
|
|
|
def _make_ssl_context(verified: bool) -> SSLContext:
    """Create SSL context.

    This method is not async-friendly and should be called from a thread
    because it will load certificates from disk and do other blocking I/O.
    """
    if ssl is None:
        # No ssl support
        return None

    if verified:
        sslcontext = ssl.create_default_context()
        sslcontext.set_alpn_protocols(("http/1.1",))
        return sslcontext

    # Unverified context: certificate and hostname checks are disabled.
    sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    sslcontext.options |= ssl.OP_NO_SSLv2
    sslcontext.options |= ssl.OP_NO_SSLv3
    sslcontext.check_hostname = False
    sslcontext.verify_mode = ssl.CERT_NONE
    sslcontext.options |= ssl.OP_NO_COMPRESSION
    sslcontext.set_default_verify_paths()
    sslcontext.set_alpn_protocols(("http/1.1",))
    return sslcontext
|
|
|
|
|
|
# The default SSLContext objects are created at import time
# since they do blocking I/O to load certificates from disk,
# and imports should always be done before the event loop starts
# or in a thread.
_SSL_CONTEXT_VERIFIED = _make_ssl_context(True)
_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False)
|
|
|
|
|
|
class TCPConnector(BaseConnector):
|
|
"""TCP connector.
|
|
|
|
verify_ssl - Set to True to check ssl certifications.
|
|
fingerprint - Pass the binary sha256
|
|
digest of the expected certificate in DER format to verify
|
|
that the certificate the server presents matches. See also
|
|
https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning
|
|
resolver - Enable DNS lookups and use this
|
|
resolver
|
|
use_dns_cache - Use memory cache for DNS lookups.
|
|
ttl_dns_cache - Max seconds having cached a DNS entry, None forever.
|
|
family - socket address family
|
|
local_addr - local tuple of (host, port) to bind socket to
|
|
|
|
keepalive_timeout - (optional) Keep-alive timeout.
|
|
force_close - Set to True to force close and do reconnect
|
|
after each request (and between redirects).
|
|
limit - The total number of simultaneous connections.
|
|
limit_per_host - Number of simultaneous connections to one host.
|
|
enable_cleanup_closed - Enables clean-up closed ssl transports.
|
|
Disabled by default.
|
|
happy_eyeballs_delay - This is the “Connection Attempt Delay”
|
|
as defined in RFC 8305. To disable
|
|
the happy eyeballs algorithm, set to None.
|
|
interleave - “First Address Family Count” as defined in RFC 8305
|
|
loop - Optional event loop.
|
|
"""
|
|
|
|
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
|
|
|
|
    def __init__(
        self,
        *,
        verify_ssl: bool = True,
        fingerprint: Optional[bytes] = None,
        use_dns_cache: bool = True,
        ttl_dns_cache: Optional[int] = 10,
        family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
        ssl_context: Optional[SSLContext] = None,
        ssl: Union[bool, Fingerprint, SSLContext] = True,
        local_addr: Optional[Tuple[str, int]] = None,
        resolver: Optional[AbstractResolver] = None,
        keepalive_timeout: Union[None, float, object] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
        happy_eyeballs_delay: Optional[float] = 0.25,
        interleave: Optional[int] = None,
    ):
        # Pool-management parameters are handled by BaseConnector.
        super().__init__(
            keepalive_timeout=keepalive_timeout,
            force_close=force_close,
            limit=limit,
            limit_per_host=limit_per_host,
            enable_cleanup_closed=enable_cleanup_closed,
            loop=loop,
            timeout_ceil_threshold=timeout_ceil_threshold,
        )

        # Collapse the legacy verify_ssl/ssl_context/fingerprint arguments
        # with the modern `ssl` argument into a single setting.
        self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
        if resolver is None:
            resolver = DefaultResolver(loop=self._loop)
        self._resolver = resolver

        self._use_dns_cache = use_dns_cache
        self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
        # Per-(host, port) futures used to coalesce concurrent DNS lookups.
        self._throttle_dns_futures: Dict[
            Tuple[str, int], Set["asyncio.Future[None]"]
        ] = {}
        self._family = family
        self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
        # Happy Eyeballs (RFC 8305) tuning; delay=None disables the algorithm.
        self._happy_eyeballs_delay = happy_eyeballs_delay
        self._interleave = interleave
        # Strong refs to in-flight shielded resolve tasks so they are not GC'd.
        self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
|
|
|
|
def close(self) -> Awaitable[None]:
|
|
"""Close all ongoing DNS calls."""
|
|
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
|
|
fut.cancel()
|
|
|
|
for t in self._resolve_host_tasks:
|
|
t.cancel()
|
|
|
|
return super().close()
|
|
|
|
    @property
    def family(self) -> int:
        """Socket family like AF_INET."""
        return self._family
|
|
|
|
    @property
    def use_dns_cache(self) -> bool:
        """True if local DNS caching is enabled."""
        return self._use_dns_cache
|
|
|
|
def clear_dns_cache(
|
|
self, host: Optional[str] = None, port: Optional[int] = None
|
|
) -> None:
|
|
"""Remove specified host/port or clear all dns local cache."""
|
|
if host is not None and port is not None:
|
|
self._cached_hosts.remove((host, port))
|
|
elif host is not None or port is not None:
|
|
raise ValueError("either both host and port or none of them are allowed")
|
|
else:
|
|
self._cached_hosts.clear()
|
|
|
|
    async def _resolve_host(
        self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None
    ) -> List[ResolveResult]:
        """Resolve host and return list of addresses."""
        if is_ip_address(host):
            # Literal IP: no DNS needed, synthesize a single result.
            return [
                {
                    "hostname": host,
                    "host": host,
                    "port": port,
                    "family": self._family,
                    "proto": 0,
                    "flags": 0,
                }
            ]

        if not self._use_dns_cache:
            # Caching disabled: resolve directly, with trace notifications.

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_start(host)

            res = await self._resolver.resolve(host, port, family=self._family)

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_end(host)

            return res

        key = (host, port)
        if key in self._cached_hosts and not self._cached_hosts.expired(key):
            # get result early, before any await (#4014)
            result = self._cached_hosts.next_addrs(key)

            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            return result

        futures: Set["asyncio.Future[None]"]
        #
        # If multiple connectors are resolving the same host, we wait
        # for the first one to resolve and then use the result for all of them.
        # We use a throttle to ensure that we only resolve the host once
        # and then use the result for all the waiters.
        #
        if key in self._throttle_dns_futures:
            # get futures early, before any await (#4014)
            futures = self._throttle_dns_futures[key]
            future: asyncio.Future[None] = self._loop.create_future()
            futures.add(future)
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            try:
                await future
            finally:
                # Clean up our waiter whether the lookup succeeded or not.
                futures.discard(future)
            # The winning lookup populated the cache; rotate from it.
            return self._cached_hosts.next_addrs(key)

        # update dict early, before any await (#4014)
        self._throttle_dns_futures[key] = futures = set()
        # In this case we need to create a task to ensure that we can shield
        # the task from cancellation as cancelling this lookup should not cancel
        # the underlying lookup or else the cancel event will get broadcast to
        # all the waiters across all connections.
        #
        coro = self._resolve_host_with_throttle(key, host, port, futures, traces)
        loop = asyncio.get_running_loop()
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send immediately
            resolved_host_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            resolved_host_task = loop.create_task(coro)

        if not resolved_host_task.done():
            # Keep a strong reference until the task completes.
            self._resolve_host_tasks.add(resolved_host_task)
            resolved_host_task.add_done_callback(self._resolve_host_tasks.discard)

        try:
            # Shield so cancelling this caller does not cancel the shared lookup.
            return await asyncio.shield(resolved_host_task)
        except asyncio.CancelledError:

            def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
                # Retrieve the result so asyncio does not log
                # "exception was never retrieved".
                with suppress(Exception, asyncio.CancelledError):
                    fut.result()

            resolved_host_task.add_done_callback(drop_exception)
            raise
|
|
|
|
    async def _resolve_host_with_throttle(
        self,
        key: Tuple[str, int],
        host: str,
        port: int,
        futures: Set["asyncio.Future[None]"],
        traces: Optional[Sequence["Trace"]],
    ) -> List[ResolveResult]:
        """Resolve host and set result for all waiters.

        This method must be run in a task and shielded from cancellation
        to avoid cancelling the underlying lookup.
        """
        if traces:
            for trace in traces:
                await trace.send_dns_cache_miss(host)
        try:
            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_start(host)

            addrs = await self._resolver.resolve(host, port, family=self._family)
            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_end(host)

            # Publish the result, then wake every coalesced waiter.
            self._cached_hosts.add(key, addrs)
            for fut in futures:
                set_result(fut, None)
        except BaseException as e:
            # any DNS exception is set for the waiters to raise the same exception.
            # This coro is always run in task that is shielded from cancellation so
            # we should never be propagating cancellation here.
            for fut in futures:
                set_exception(fut, e)
            raise
        finally:
            # Lookup finished (either way): stop throttling this key.
            self._throttle_dns_futures.pop(key)

        return self._cached_hosts.next_addrs(key)
|
|
|
|
async def _create_connection(
|
|
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
|
|
) -> ResponseHandler:
|
|
"""Create connection.
|
|
|
|
Has same keyword arguments as BaseEventLoop.create_connection.
|
|
"""
|
|
if req.proxy:
|
|
_, proto = await self._create_proxy_connection(req, traces, timeout)
|
|
else:
|
|
_, proto = await self._create_direct_connection(req, traces, timeout)
|
|
|
|
return proto
|
|
|
|
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
|
|
"""Logic to get the correct SSL context
|
|
|
|
0. if req.ssl is false, return None
|
|
|
|
1. if ssl_context is specified in req, use it
|
|
2. if _ssl_context is specified in self, use it
|
|
3. otherwise:
|
|
1. if verify_ssl is not specified in req, use self.ssl_context
|
|
(will generate a default context according to self.verify_ssl)
|
|
2. if verify_ssl is True in req, generate a default SSL context
|
|
3. if verify_ssl is False in req, generate a SSL context that
|
|
won't verify
|
|
"""
|
|
if not req.is_ssl():
|
|
return None
|
|
|
|
if ssl is None: # pragma: no cover
|
|
raise RuntimeError("SSL is not supported.")
|
|
sslcontext = req.ssl
|
|
if isinstance(sslcontext, ssl.SSLContext):
|
|
return sslcontext
|
|
if sslcontext is not True:
|
|
# not verified or fingerprinted
|
|
return _SSL_CONTEXT_UNVERIFIED
|
|
sslcontext = self._ssl
|
|
if isinstance(sslcontext, ssl.SSLContext):
|
|
return sslcontext
|
|
if sslcontext is not True:
|
|
# not verified or fingerprinted
|
|
return _SSL_CONTEXT_UNVERIFIED
|
|
return _SSL_CONTEXT_VERIFIED
|
|
|
|
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
|
|
ret = req.ssl
|
|
if isinstance(ret, Fingerprint):
|
|
return ret
|
|
ret = self._ssl
|
|
if isinstance(ret, Fingerprint):
|
|
return ret
|
|
return None
|
|
|
|
    async def _wrap_create_connection(
        self,
        *args: Any,
        addr_infos: List[aiohappyeyeballs.AddrInfoType],
        req: ClientRequest,
        timeout: "ClientTimeout",
        client_error: Type[Exception] = ClientConnectorError,
        **kwargs: Any,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Connect via Happy Eyeballs and translate low-level errors.

        Opens a socket to one of ``addr_infos`` using aiohappyeyeballs
        (interleave/delay settings come from the connector), then hands the
        connected socket to ``loop.create_connection``.  Certificate, SSL
        and OS errors are re-raised as the corresponding aiohttp client
        connector exceptions keyed by ``req.connection_key``.
        """
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                sock = await aiohappyeyeballs.start_connection(
                    addr_infos=addr_infos,
                    local_addr_infos=self._local_addr_infos,
                    happy_eyeballs_delay=self._happy_eyeballs_delay,
                    interleave=self._interleave,
                    loop=self._loop,
                )
                return await self._loop.create_connection(*args, **kwargs, sock=sock)
        except cert_errors as exc:
            raise ClientConnectorCertificateError(req.connection_key, exc) from exc
        except ssl_errors as exc:
            raise ClientConnectorSSLError(req.connection_key, exc) from exc
        except OSError as exc:
            # A timeout can be caught here as an OSError (TimeoutError is an
            # OSError subclass); re-raise it unchanged so callers can tell
            # timeouts apart from genuine connect failures.
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            raise client_error(req.connection_key, exc) from exc
|
|
|
|
async def _wrap_existing_connection(
|
|
self,
|
|
*args: Any,
|
|
req: ClientRequest,
|
|
timeout: "ClientTimeout",
|
|
client_error: Type[Exception] = ClientConnectorError,
|
|
**kwargs: Any,
|
|
) -> Tuple[asyncio.Transport, ResponseHandler]:
|
|
try:
|
|
async with ceil_timeout(
|
|
timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
|
|
):
|
|
return await self._loop.create_connection(*args, **kwargs)
|
|
except cert_errors as exc:
|
|
raise ClientConnectorCertificateError(req.connection_key, exc) from exc
|
|
except ssl_errors as exc:
|
|
raise ClientConnectorSSLError(req.connection_key, exc) from exc
|
|
except OSError as exc:
|
|
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
|
|
raise
|
|
raise client_error(req.connection_key, exc) from exc
|
|
|
|
def _fail_on_no_start_tls(self, req: "ClientRequest") -> None:
|
|
"""Raise a :py:exc:`RuntimeError` on missing ``start_tls()``.
|
|
|
|
It is necessary for TLS-in-TLS so that it is possible to
|
|
send HTTPS queries through HTTPS proxies.
|
|
|
|
This doesn't affect regular HTTP requests, though.
|
|
"""
|
|
if not req.is_ssl():
|
|
return
|
|
|
|
proxy_url = req.proxy
|
|
assert proxy_url is not None
|
|
if proxy_url.scheme != "https":
|
|
return
|
|
|
|
self._check_loop_for_start_tls()
|
|
|
|
def _check_loop_for_start_tls(self) -> None:
|
|
try:
|
|
self._loop.start_tls
|
|
except AttributeError as attr_exc:
|
|
raise RuntimeError(
|
|
"An HTTPS request is being sent through an HTTPS proxy. "
|
|
"This needs support for TLS in TLS but it is not implemented "
|
|
"in your runtime for the stdlib asyncio.\n\n"
|
|
"Please upgrade to Python 3.11 or higher. For more details, "
|
|
"please see:\n"
|
|
"* https://bugs.python.org/issue37179\n"
|
|
"* https://github.com/python/cpython/pull/28073\n"
|
|
"* https://docs.aiohttp.org/en/stable/"
|
|
"client_advanced.html#proxy-support\n"
|
|
"* https://github.com/aio-libs/aiohttp/discussions/6044\n",
|
|
) from attr_exc
|
|
|
|
def _loop_supports_start_tls(self) -> bool:
|
|
try:
|
|
self._check_loop_for_start_tls()
|
|
except RuntimeError:
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
def _warn_about_tls_in_tls(
|
|
self,
|
|
underlying_transport: asyncio.Transport,
|
|
req: ClientRequest,
|
|
) -> None:
|
|
"""Issue a warning if the requested URL has HTTPS scheme."""
|
|
if req.request_info.url.scheme != "https":
|
|
return
|
|
|
|
asyncio_supports_tls_in_tls = getattr(
|
|
underlying_transport,
|
|
"_start_tls_compatible",
|
|
False,
|
|
)
|
|
|
|
if asyncio_supports_tls_in_tls:
|
|
return
|
|
|
|
warnings.warn(
|
|
"An HTTPS request is being sent through an HTTPS proxy. "
|
|
"This support for TLS in TLS is known to be disabled "
|
|
"in the stdlib asyncio (Python <3.11). This is why you'll probably see "
|
|
"an error in the log below.\n\n"
|
|
"It is possible to enable it via monkeypatching. "
|
|
"For more details, see:\n"
|
|
"* https://bugs.python.org/issue37179\n"
|
|
"* https://github.com/python/cpython/pull/28073\n\n"
|
|
"You can temporarily patch this as follows:\n"
|
|
"* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n"
|
|
"* https://github.com/aio-libs/aiohttp/discussions/6044\n",
|
|
RuntimeWarning,
|
|
source=self,
|
|
# Why `4`? At least 3 of the calls in the stack originate
|
|
# from the methods in this class.
|
|
stacklevel=3,
|
|
)
|
|
|
|
    async def _start_tls_connection(
        self,
        underlying_transport: asyncio.Transport,
        req: ClientRequest,
        timeout: "ClientTimeout",
        client_error: Type[Exception] = ClientConnectorError,
    ) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
        """Wrap the raw TCP transport with TLS.

        Used for TLS-in-TLS (an HTTPS request through an HTTPS proxy):
        upgrades the already-connected ``underlying_transport`` via
        ``loop.start_tls()`` with a brand-new protocol instance,
        optionally checks a pinned certificate fingerprint, and maps
        low-level failures onto aiohttp client exceptions.
        """
        tls_proto = self._factory()  # Create a brand new proto for TLS
        sslcontext = self._get_ssl_context(req)
        if TYPE_CHECKING:
            # _start_tls_connection is unreachable in the current code path
            # if sslcontext is None.
            assert sslcontext is not None

        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                try:
                    tls_transport = await self._loop.start_tls(
                        underlying_transport,
                        tls_proto,
                        sslcontext,
                        server_hostname=req.server_hostname or req.host,
                        ssl_handshake_timeout=timeout.total,
                    )
                except BaseException:
                    # We need to close the underlying transport since
                    # `start_tls()` probably failed before it had a
                    # chance to do this:
                    underlying_transport.close()
                    raise
                if isinstance(tls_transport, asyncio.Transport):
                    fingerprint = self._get_fingerprint(req)
                    if fingerprint:
                        try:
                            fingerprint.check(tls_transport)
                        except ServerFingerprintMismatch:
                            # Mismatch: close the TLS transport and keep it
                            # around for delayed cleanup unless disabled.
                            tls_transport.close()
                            if not self._cleanup_closed_disabled:
                                self._cleanup_closed_transports.append(tls_transport)
                            raise
        except cert_errors as exc:
            raise ClientConnectorCertificateError(req.connection_key, exc) from exc
        except ssl_errors as exc:
            raise ClientConnectorSSLError(req.connection_key, exc) from exc
        except OSError as exc:
            # Re-raise timeouts unchanged; only genuine connect errors are
            # wrapped in client_error.
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            raise client_error(req.connection_key, exc) from exc
        except TypeError as type_err:
            # Example cause looks like this:
            # TypeError: transport <asyncio.sslproto._SSLProtocolTransport
            # object at 0x7f760615e460> is not supported by start_tls()

            raise ClientConnectionError(
                "Cannot initialize a TLS-in-TLS connection to host "
                f"{req.host!s}:{req.port:d} through an underlying connection "
                f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} "
                f"[{type_err!s}]"
            ) from type_err
        else:
            if tls_transport is None:
                # start_tls() returns None when the transport closed mid-upgrade.
                msg = "Failed to start TLS (possibly caused by closing transport)"
                raise client_error(req.connection_key, OSError(msg))
            tls_proto.connection_made(
                tls_transport
            )  # Kick the state machine of the new TLS protocol

        return tls_transport, tls_proto
|
|
|
|
def _convert_hosts_to_addr_infos(
|
|
self, hosts: List[ResolveResult]
|
|
) -> List[aiohappyeyeballs.AddrInfoType]:
|
|
"""Converts the list of hosts to a list of addr_infos.
|
|
|
|
The list of hosts is the result of a DNS lookup. The list of
|
|
addr_infos is the result of a call to `socket.getaddrinfo()`.
|
|
"""
|
|
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
|
|
for hinfo in hosts:
|
|
host = hinfo["host"]
|
|
is_ipv6 = ":" in host
|
|
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
|
|
if self._family and self._family != family:
|
|
continue
|
|
addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"])
|
|
addr_infos.append(
|
|
(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)
|
|
)
|
|
return addr_infos
|
|
|
|
    async def _create_direct_connection(
        self,
        req: ClientRequest,
        traces: List["Trace"],
        timeout: "ClientTimeout",
        *,
        client_error: Type[Exception] = ClientConnectorError,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Resolve the target host and connect to it directly (no tunnel).

        Tries the resolved addresses until one connects and (for TLS with
        a pinned fingerprint) passes the fingerprint check; if every
        address fails, the last connection error is raised.
        """
        sslcontext = self._get_ssl_context(req)
        fingerprint = self._get_fingerprint(req)

        host = req.url.raw_host
        assert host is not None
        # Replace multiple trailing dots with a single one.
        # A trailing dot is only present for fully-qualified domain names.
        # See https://github.com/aio-libs/aiohttp/pull/7364.
        if host.endswith(".."):
            host = host.rstrip(".") + "."
        port = req.port
        assert port is not None
        try:
            # Cancelling this lookup should not cancel the underlying lookup
            # or else the cancel event will get broadcast to all the waiters
            # across all connections.
            hosts = await self._resolve_host(host, port, traces=traces)
        except OSError as exc:
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            # in case of proxy it is not ClientProxyConnectionError
            # it is problem of resolving proxy ip itself
            raise ClientConnectorDNSError(req.connection_key, exc) from exc

        last_exc: Optional[Exception] = None
        addr_infos = self._convert_hosts_to_addr_infos(hosts)
        while addr_infos:
            # Strip trailing dots, certificates contain FQDN without dots.
            # See https://github.com/aio-libs/aiohttp/issues/3636
            server_hostname = (
                (req.server_hostname or host).rstrip(".") if sslcontext else None
            )

            try:
                transp, proto = await self._wrap_create_connection(
                    self._factory,
                    timeout=timeout,
                    ssl=sslcontext,
                    addr_infos=addr_infos,
                    server_hostname=server_hostname,
                    req=req,
                    client_error=client_error,
                )
            except (ClientConnectorError, asyncio.TimeoutError) as exc:
                last_exc = exc
                # Drop the failed address(es) and retry with what is left.
                aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave)
                continue

            if req.is_ssl() and fingerprint:
                try:
                    fingerprint.check(transp)
                except ServerFingerprintMismatch as exc:
                    transp.close()
                    if not self._cleanup_closed_disabled:
                        self._cleanup_closed_transports.append(transp)
                    last_exc = exc
                    # Remove the bad peer from the list of addr_infos
                    sock: socket.socket = transp.get_extra_info("socket")
                    bad_peer = sock.getpeername()
                    aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer)
                    continue

            return transp, proto
        else:
            # Loop exhausted without returning: every address failed.
            assert last_exc is not None
            raise last_exc
|
|
|
|
    async def _create_proxy_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
        """Connect through the configured HTTP(S) proxy.

        For a plain-HTTP request the connection to the proxy itself is
        returned.  For an HTTPS request a CONNECT tunnel is established
        through the proxy and the connection is then upgraded to TLS —
        via ``loop.start_tls()`` when available, otherwise by duplicating
        the raw socket and creating a fresh SSL connection over it.
        """
        self._fail_on_no_start_tls(req)
        runtime_has_start_tls = self._loop_supports_start_tls()

        headers: Dict[str, str] = {}
        if req.proxy_headers is not None:
            # NOTE(review): when proxy_headers is supplied it is used (and
            # mutated — HOST is added below) in place rather than copied.
            headers = req.proxy_headers  # type: ignore[assignment]
        headers[hdrs.HOST] = req.headers[hdrs.HOST]

        url = req.proxy
        assert url is not None
        proxy_req = ClientRequest(
            hdrs.METH_GET,
            url,
            headers=headers,
            auth=req.proxy_auth,
            loop=self._loop,
            ssl=req.ssl,
        )

        # create connection to proxy server
        transport, proto = await self._create_direct_connection(
            proxy_req, [], timeout, client_error=ClientProxyConnectionError
        )

        # Move the proxy request's Authorization header to
        # Proxy-Authorization: on the plain request for HTTP, or kept on the
        # CONNECT request for HTTPS tunnelling.
        auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
        if auth is not None:
            if not req.is_ssl():
                req.headers[hdrs.PROXY_AUTHORIZATION] = auth
            else:
                proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth

        if req.is_ssl():
            if runtime_has_start_tls:
                self._warn_about_tls_in_tls(transport, req)

            # For HTTPS requests over HTTP proxy
            # we must notify proxy to tunnel connection
            # so we send CONNECT command:
            # CONNECT www.python.org:443 HTTP/1.1
            # Host: www.python.org
            #
            # next we must do TLS handshake and so on
            # to do this we must wrap raw socket into secure one
            # asyncio handles this perfectly
            proxy_req.method = hdrs.METH_CONNECT
            proxy_req.url = req.url
            key = req.connection_key._replace(
                proxy=None, proxy_auth=None, proxy_headers_hash=None
            )
            conn = Connection(self, key, proto, self._loop)
            proxy_resp = await proxy_req.send(conn)
            try:
                protocol = conn._protocol
                assert protocol is not None

                # read_until_eof=True will ensure the connection isn't closed
                # once the response is received and processed allowing
                # START_TLS to work on the connection below.
                protocol.set_response_params(
                    read_until_eof=runtime_has_start_tls,
                    timeout_ceil_threshold=self._timeout_ceil_threshold,
                )
                resp = await proxy_resp.start(conn)
            except BaseException:
                proxy_resp.close()
                conn.close()
                raise
            else:
                # Detach the protocol so closing `conn` later does not tear
                # down the tunnelled connection.
                conn._protocol = None
                try:
                    if resp.status != 200:
                        message = resp.reason
                        if message is None:
                            message = HTTPStatus(resp.status).phrase
                        raise ClientHttpProxyError(
                            proxy_resp.request_info,
                            resp.history,
                            status=resp.status,
                            message=message,
                            headers=resp.headers,
                        )
                    if not runtime_has_start_tls:
                        rawsock = transport.get_extra_info("socket", default=None)
                        if rawsock is None:
                            raise RuntimeError(
                                "Transport does not expose socket instance"
                            )
                        # Duplicate the socket, so now we can close proxy transport
                        rawsock = rawsock.dup()
                except BaseException:
                    # It shouldn't be closed in `finally` because it's fed to
                    # `loop.start_tls()` and the docs say not to touch it after
                    # passing there.
                    transport.close()
                    raise
                finally:
                    if not runtime_has_start_tls:
                        transport.close()

                if not runtime_has_start_tls:
                    # HTTP proxy with support for upgrade to HTTPS
                    sslcontext = self._get_ssl_context(req)
                    return await self._wrap_existing_connection(
                        self._factory,
                        timeout=timeout,
                        ssl=sslcontext,
                        sock=rawsock,
                        server_hostname=req.host,
                        req=req,
                    )

                return await self._start_tls_connection(
                    # Access the old transport for the last time before it's
                    # closed and forgotten forever:
                    transport,
                    req=req,
                    timeout=timeout,
                )
            finally:
                proxy_resp.close()

        return transport, proto
|
|
|
|
|
|
class UnixConnector(BaseConnector):
    """Connector that talks to a local Unix domain socket.

    path - Unix socket path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
            limit=limit,
            limit_per_host=limit_per_host,
            loop=loop,
        )
        self._path = path

    @property
    def path(self) -> str:
        """Path to unix socket."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a connection over the configured unix socket."""
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                _, proto = await self._loop.create_unix_connection(
                    self._factory, self._path
                )
        except OSError as err:
            # A timeout may surface as an OSError here; re-raise it
            # untouched instead of wrapping it as a connector error.
            if err.errno is None and isinstance(err, asyncio.TimeoutError):
                raise
            raise UnixClientConnectorError(self.path, req.connection_key, err) from err

        return proto
|
|
|
|
|
|
class NamedPipeConnector(BaseConnector):
    """Connector that talks to a Windows named pipe.

    Only supported by the proactor event loop.
    See also: https://docs.python.org/3/library/asyncio-eventloop.html

    path - Windows named pipe path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
            limit=limit,
            limit_per_host=limit_per_host,
            loop=loop,
        )
        # Named pipes are only reachable through the proactor loop's
        # create_pipe_connection(), so reject any other loop up front.
        if not isinstance(
            self._loop, asyncio.ProactorEventLoop  # type: ignore[attr-defined]
        ):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        self._path = path

    @property
    def path(self) -> str:
        """Path to the named pipe."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a connection over the configured named pipe."""
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                _, proto = await self._loop.create_pipe_connection(  # type: ignore[attr-defined]
                    self._factory, self._path
                )
                # the drain is required so that the connection_made is called
                # and transport is set otherwise it is not set before the
                # `assert conn.transport is not None`
                # in client.py's _request method
                await asyncio.sleep(0)
                # other option is to manually set transport like
                # `proto.transport = trans`
        except OSError as err:
            # Let a timeout surfacing as an OSError propagate unchanged.
            if err.errno is None and isinstance(err, asyncio.TimeoutError):
                raise
            raise ClientConnectorError(req.connection_key, err) from err

        return cast(ResponseHandler, proto)
|