hackerhouse/.venv/Lib/site-packages/aiohappyeyeballs/impl.py
2025-01-17 09:32:32 -06:00

222 lines
7.9 KiB
Python

"""Base implementation."""
import asyncio
import collections
import functools
import itertools
import socket
import sys
from typing import List, Optional, Sequence, Union
from . import _staggered
from .types import AddrInfoType
if sys.version_info < (3, 8, 2): # noqa: UP036
# asyncio.staggered is broken in Python 3.8.0 and 3.8.1
# so it must be patched:
# https://github.com/aio-libs/aiohttp/issues/8556
# https://bugs.python.org/issue39129
# https://github.com/python/cpython/pull/17693
import asyncio.futures
asyncio.futures.TimeoutError = asyncio.TimeoutError # type: ignore[attr-defined]
async def start_connection(
addr_infos: Sequence[AddrInfoType],
*,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
happy_eyeballs_delay: Optional[float] = None,
interleave: Optional[int] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> socket.socket:
"""
Connect to a TCP server.
Create a socket connection to a specified destination. The
destination is specified as a list of AddrInfoType tuples as
returned from getaddrinfo().
The arguments are, in order:
* ``family``: the address family, e.g. ``socket.AF_INET`` or
``socket.AF_INET6``.
* ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
``socket.SOCK_DGRAM``.
* ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
``socket.IPPROTO_UDP``.
* ``canonname``: the canonical name of the address, e.g.
``"www.python.org"``.
* ``sockaddr``: the socket address
This method is a coroutine which will try to establish the connection
in the background. When successful, the coroutine returns a
socket.
The expected use case is to use this method in conjunction with
loop.create_connection() to establish a connection to a server::
socket = await start_connection(addr_infos)
transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)
"""
if not (current_loop := loop):
current_loop = asyncio.get_running_loop()
single_addr_info = len(addr_infos) == 1
if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
if interleave and not single_addr_info:
addr_infos = _interleave_addrinfos(addr_infos, interleave)
sock: Optional[socket.socket] = None
# uvloop can raise RuntimeError instead of OSError
exceptions: List[List[Union[OSError, RuntimeError]]] = []
if happy_eyeballs_delay is None or single_addr_info:
# not using happy eyeballs
for addrinfo in addr_infos:
try:
sock = await _connect_sock(
current_loop, exceptions, addrinfo, local_addr_infos
)
break
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await _staggered.staggered_race(
(
functools.partial(
_connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
)
for addrinfo in addr_infos
),
happy_eyeballs_delay,
)
if sock is None:
all_exceptions = [exc for sub in exceptions for exc in sub]
try:
first_exception = all_exceptions[0]
if len(all_exceptions) == 1:
raise first_exception
else:
# If they all have the same str(), raise one.
model = str(first_exception)
if all(str(exc) == model for exc in all_exceptions):
raise first_exception
# Raise a combined exception so the user can see all
# the various error messages.
msg = "Multiple exceptions: {}".format(
", ".join(str(exc) for exc in all_exceptions)
)
# If the errno is the same for all exceptions, raise
# an OSError with that errno.
if isinstance(first_exception, OSError):
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
):
raise OSError(first_errno, msg)
elif isinstance(first_exception, RuntimeError) and all(
isinstance(exc, RuntimeError) for exc in all_exceptions
):
raise RuntimeError(msg)
# We have a mix of OSError and RuntimeError
# so we have to pick which one to raise.
# and we raise OSError for compatibility
raise OSError(msg)
finally:
all_exceptions = None # type: ignore[assignment]
exceptions = None # type: ignore[assignment]
return sock
async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[Union[OSError, RuntimeError]]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: List[Union[OSError, RuntimeError]] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
try:
sock = socket.socket(family=family, type=type_, proto=proto)
sock.setblocking(False)
if local_addr_infos is not None:
for lfamily, _, _, _, laddr in local_addr_infos:
# skip local addresses of different family
if lfamily != family:
continue
try:
sock.bind(laddr)
break
except OSError as exc:
msg = (
f"error while attempting to bind on "
f"address {laddr!r}: "
f"{exc.strerror.lower()}"
)
exc = OSError(exc.errno, msg)
my_exceptions.append(exc)
else: # all bind attempts failed
if my_exceptions:
raise my_exceptions.pop()
else:
raise OSError(f"no matching local address with {family=} found")
await loop.sock_connect(sock, address)
return sock
except (RuntimeError, OSError) as exc:
my_exceptions.append(exc)
if sock is not None:
try:
sock.close()
except OSError as e:
my_exceptions.append(e)
raise
raise
except:
if sock is not None:
try:
sock.close()
except OSError as e:
my_exceptions.append(e)
raise
raise
finally:
exceptions = my_exceptions = None # type: ignore[assignment]
def _interleave_addrinfos(
addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> List[AddrInfoType]:
"""Interleave list of addrinfo tuples by family."""
# Group addresses by family
addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
collections.OrderedDict()
)
for addr in addrinfos:
family = addr[0]
if family not in addrinfos_by_family:
addrinfos_by_family[family] = []
addrinfos_by_family[family].append(addr)
addrinfos_lists = list(addrinfos_by_family.values())
reordered: List[AddrInfoType] = []
if first_address_family_count > 1:
reordered.extend(addrinfos_lists[0][: first_address_family_count - 1])
del addrinfos_lists[0][: first_address_family_count - 1]
reordered.extend(
a
for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
if a is not None
)
return reordered