"""Helpers for WebSocket protocol versions 13 and 8.""" import functools import re from struct import Struct from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple from ..helpers import NO_EXTENSIONS from .models import WSHandshakeError UNPACK_LEN3 = Struct("!Q").unpack_from UNPACK_CLOSE_CODE = Struct("!H").unpack PACK_LEN1 = Struct("!BB").pack PACK_LEN2 = Struct("!BBH").pack PACK_LEN3 = Struct("!BBQ").pack PACK_CLOSE_CODE = Struct("!H").pack PACK_RANDBITS = Struct("!L").pack MSG_SIZE: Final[int] = 2**14 MASK_LEN: Final[int] = 4 WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" # Used by _websocket_mask_python @functools.lru_cache def _xor_table() -> List[bytes]: return [bytes(a ^ b for a in range(256)) for b in range(256)] def _websocket_mask_python(mask: bytes, data: bytearray) -> None: """Websocket masking function. `mask` is a `bytes` object of length 4; `data` is a `bytearray` object of any length. The contents of `data` are masked with `mask`, as specified in section 5.3 of RFC 6455. Note that this function mutates the `data` argument. This pure-python implementation may be replaced by an optimized version when available. """ assert isinstance(data, bytearray), data assert len(mask) == 4, mask if data: _XOR_TABLE = _xor_table() a, b, c, d = (_XOR_TABLE[n] for n in mask) data[::4] = data[::4].translate(a) data[1::4] = data[1::4].translate(b) data[2::4] = data[2::4].translate(c) data[3::4] = data[3::4].translate(d) if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover websocket_mask = _websocket_mask_python else: try: from .mask import _websocket_mask_cython # type: ignore[import-not-found] websocket_mask = _websocket_mask_cython except ImportError: # pragma: no cover websocket_mask = _websocket_mask_python _WS_EXT_RE: Final[Pattern[str]] = re.compile( r"^(?:;\s*(?:" r"(server_no_context_takeover)|" r"(client_no_context_takeover)|" r"(server_max_window_bits(?:=(\d+))?)|" r"(client_max_window_bits(?:=(\d+))?)))*$" ) _WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]: if not extstr: return 0, False compress = 0 notakeover = False for ext in _WS_EXT_RE_SPLIT.finditer(extstr): defext = ext.group(1) # Return compress = 15 when get `permessage-deflate` if not defext: compress = 15 break match = _WS_EXT_RE.match(defext) if match: compress = 15 if isserver: # Server never fail to detect compress handshake. # Server does not need to send max wbit to client if match.group(4): compress = int(match.group(4)) # Group3 must match if group4 matches # Compress wbit 8 does not support in zlib # If compress level not support, # CONTINUE to next extension if compress > 15 or compress < 9: compress = 0 continue if match.group(1): notakeover = True # Ignore regex group 5 & 6 for client_max_window_bits break else: if match.group(6): compress = int(match.group(6)) # Group5 must match if group6 matches # Compress wbit 8 does not support in zlib # If compress level not support, # FAIL the parse progress if compress > 15 or compress < 9: raise WSHandshakeError("Invalid window size") if match.group(2): notakeover = True # Ignore regex group 5 & 6 for client_max_window_bits break # Return Fail if client side and not match elif not isserver: raise WSHandshakeError("Extension for deflate not supported" + ext.group(1)) return compress, notakeover def ws_ext_gen( compress: int = 15, isserver: bool = False, server_notakeover: bool = False ) -> str: # client_notakeover=False not used for server # compress wbit 8 does not support in zlib if compress < 9 or compress > 15: raise ValueError( "Compress wbits must between 9 and 15, zlib does not support wbits=8" ) enabledext = ["permessage-deflate"] if not isserver: enabledext.append("client_max_window_bits") if compress < 15: enabledext.append("server_max_window_bits=" + str(compress)) if server_notakeover: enabledext.append("server_no_context_takeover") # if client_notakeover: # enabledext.append('client_no_context_takeover') return "; ".join(enabledext)