On branch DiscordProfile
Initial commit
This commit is contained in:
@@ -0,0 +1 @@
|
||||
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd
|
||||
@@ -0,0 +1 @@
|
||||
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx
|
||||
@@ -0,0 +1 @@
|
||||
97e3831a92693b1e05c69b02b644722139a646f065468f26bfceea36079065ba /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd
|
||||
@@ -0,0 +1 @@
|
||||
"""WebSocket protocol versions 13 and 8."""
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,148 @@
|
||||
"""Helpers for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import functools
|
||||
import re
|
||||
from re import Pattern
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
from ..helpers import NO_EXTENSIONS
|
||||
from .models import WSHandshakeError
|
||||
|
||||
UNPACK_LEN3 = Struct("!Q").unpack_from
|
||||
UNPACK_CLOSE_CODE = Struct("!H").unpack
|
||||
PACK_LEN1 = Struct("!BB").pack
|
||||
PACK_LEN2 = Struct("!BBH").pack
|
||||
PACK_LEN3 = Struct("!BBQ").pack
|
||||
PACK_CLOSE_CODE = Struct("!H").pack
|
||||
PACK_RANDBITS = Struct("!L").pack
|
||||
MSG_SIZE: Final[int] = 2**14
|
||||
MASK_LEN: Final[int] = 4
|
||||
|
||||
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
# Used by _websocket_mask_python
|
||||
@functools.lru_cache
|
||||
def _xor_table() -> list[bytes]:
|
||||
return [bytes(a ^ b for a in range(256)) for b in range(256)]
|
||||
|
||||
|
||||
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
|
||||
"""Websocket masking function.
|
||||
|
||||
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
|
||||
object of any length. The contents of `data` are masked with `mask`,
|
||||
as specified in section 5.3 of RFC 6455.
|
||||
|
||||
Note that this function mutates the `data` argument.
|
||||
|
||||
This pure-python implementation may be replaced by an optimized
|
||||
version when available.
|
||||
|
||||
"""
|
||||
assert isinstance(data, bytearray), data
|
||||
assert len(mask) == 4, mask
|
||||
|
||||
if data:
|
||||
_XOR_TABLE = _xor_table()
|
||||
a, b, c, d = (_XOR_TABLE[n] for n in mask)
|
||||
data[::4] = data[::4].translate(a)
|
||||
data[1::4] = data[1::4].translate(b)
|
||||
data[2::4] = data[2::4].translate(c)
|
||||
data[3::4] = data[3::4].translate(d)
|
||||
|
||||
|
||||
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
||||
websocket_mask = _websocket_mask_python
|
||||
else:
|
||||
try:
|
||||
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
|
||||
|
||||
websocket_mask = _websocket_mask_cython
|
||||
except ImportError: # pragma: no cover
|
||||
websocket_mask = _websocket_mask_python
|
||||
|
||||
|
||||
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
|
||||
r"^(?:;\s*(?:"
|
||||
r"(server_no_context_takeover)|"
|
||||
r"(client_no_context_takeover)|"
|
||||
r"(server_max_window_bits(?:=(\d+))?)|"
|
||||
r"(client_max_window_bits(?:=(\d+))?)))*$"
|
||||
)
|
||||
|
||||
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
|
||||
|
||||
|
||||
def ws_ext_parse(extstr: str | None, isserver: bool = False) -> tuple[int, bool]:
|
||||
if not extstr:
|
||||
return 0, False
|
||||
|
||||
compress = 0
|
||||
notakeover = False
|
||||
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
|
||||
defext = ext.group(1)
|
||||
# Return compress = 15 when get `permessage-deflate`
|
||||
if not defext:
|
||||
compress = 15
|
||||
break
|
||||
match = _WS_EXT_RE.match(defext)
|
||||
if match:
|
||||
compress = 15
|
||||
if isserver:
|
||||
# Server never fail to detect compress handshake.
|
||||
# Server does not need to send max wbit to client
|
||||
if match.group(4):
|
||||
compress = int(match.group(4))
|
||||
# Group3 must match if group4 matches
|
||||
# Compress wbit 8 does not support in zlib
|
||||
# If compress level not support,
|
||||
# CONTINUE to next extension
|
||||
if compress > 15 or compress < 9:
|
||||
compress = 0
|
||||
continue
|
||||
if match.group(1):
|
||||
notakeover = True
|
||||
# Ignore regex group 5 & 6 for client_max_window_bits
|
||||
break
|
||||
else:
|
||||
if match.group(6):
|
||||
compress = int(match.group(6))
|
||||
# Group5 must match if group6 matches
|
||||
# Compress wbit 8 does not support in zlib
|
||||
# If compress level not support,
|
||||
# FAIL the parse progress
|
||||
if compress > 15 or compress < 9:
|
||||
raise WSHandshakeError("Invalid window size")
|
||||
if match.group(2):
|
||||
notakeover = True
|
||||
# Ignore regex group 5 & 6 for client_max_window_bits
|
||||
break
|
||||
# Return Fail if client side and not match
|
||||
elif not isserver:
|
||||
raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
|
||||
|
||||
return compress, notakeover
|
||||
|
||||
|
||||
def ws_ext_gen(
|
||||
compress: int = 15, isserver: bool = False, server_notakeover: bool = False
|
||||
) -> str:
|
||||
# client_notakeover=False not used for server
|
||||
# compress wbit 8 does not support in zlib
|
||||
if compress < 9 or compress > 15:
|
||||
raise ValueError(
|
||||
"Compress wbits must between 9 and 15, zlib does not support wbits=8"
|
||||
)
|
||||
enabledext = ["permessage-deflate"]
|
||||
if not isserver:
|
||||
enabledext.append("client_max_window_bits")
|
||||
|
||||
if compress < 15:
|
||||
enabledext.append("server_max_window_bits=" + str(compress))
|
||||
if server_notakeover:
|
||||
enabledext.append("server_no_context_takeover")
|
||||
# if client_notakeover:
|
||||
# enabledext.append('client_no_context_takeover')
|
||||
return "; ".join(enabledext)
|
||||
Executable
BIN
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""Cython declarations for websocket masking."""
|
||||
|
||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data)
|
||||
@@ -0,0 +1,48 @@
|
||||
from cpython cimport PyBytes_AsString
|
||||
|
||||
|
||||
#from cpython cimport PyByteArray_AsString # cython still not exports that
|
||||
cdef extern from "Python.h":
|
||||
char* PyByteArray_AsString(bytearray ba) except NULL
|
||||
|
||||
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
|
||||
|
||||
|
||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
|
||||
"""Note, this function mutates its `data` argument
|
||||
"""
|
||||
cdef:
|
||||
Py_ssize_t data_len, i
|
||||
# bit operations on signed integers are implementation-specific
|
||||
unsigned char * in_buf
|
||||
const unsigned char * mask_buf
|
||||
uint32_t uint32_msk
|
||||
uint64_t uint64_msk
|
||||
|
||||
assert len(mask) == 4
|
||||
|
||||
data_len = len(data)
|
||||
in_buf = <unsigned char*>PyByteArray_AsString(data)
|
||||
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
|
||||
uint32_msk = (<uint32_t*>mask_buf)[0]
|
||||
|
||||
# TODO: align in_data ptr to achieve even faster speeds
|
||||
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
|
||||
|
||||
if sizeof(size_t) >= 8:
|
||||
uint64_msk = uint32_msk
|
||||
uint64_msk = (uint64_msk << 32) | uint32_msk
|
||||
|
||||
while data_len >= 8:
|
||||
(<uint64_t*>in_buf)[0] ^= uint64_msk
|
||||
in_buf += 8
|
||||
data_len -= 8
|
||||
|
||||
|
||||
while data_len >= 4:
|
||||
(<uint32_t*>in_buf)[0] ^= uint32_msk
|
||||
in_buf += 4
|
||||
data_len -= 4
|
||||
|
||||
for i in range(0, data_len):
|
||||
in_buf[i] ^= mask_buf[i]
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Models for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from enum import IntEnum
|
||||
from typing import Any, Final, NamedTuple, cast
|
||||
|
||||
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
|
||||
|
||||
|
||||
class WSCloseCode(IntEnum):
|
||||
OK = 1000
|
||||
GOING_AWAY = 1001
|
||||
PROTOCOL_ERROR = 1002
|
||||
UNSUPPORTED_DATA = 1003
|
||||
ABNORMAL_CLOSURE = 1006
|
||||
INVALID_TEXT = 1007
|
||||
POLICY_VIOLATION = 1008
|
||||
MESSAGE_TOO_BIG = 1009
|
||||
MANDATORY_EXTENSION = 1010
|
||||
INTERNAL_ERROR = 1011
|
||||
SERVICE_RESTART = 1012
|
||||
TRY_AGAIN_LATER = 1013
|
||||
BAD_GATEWAY = 1014
|
||||
|
||||
|
||||
class WSMsgType(IntEnum):
|
||||
# websocket spec types
|
||||
CONTINUATION = 0x0
|
||||
TEXT = 0x1
|
||||
BINARY = 0x2
|
||||
PING = 0x9
|
||||
PONG = 0xA
|
||||
CLOSE = 0x8
|
||||
|
||||
# aiohttp specific types
|
||||
CLOSING = 0x100
|
||||
CLOSED = 0x101
|
||||
ERROR = 0x102
|
||||
|
||||
text = TEXT
|
||||
binary = BINARY
|
||||
ping = PING
|
||||
pong = PONG
|
||||
close = CLOSE
|
||||
closing = CLOSING
|
||||
closed = CLOSED
|
||||
error = ERROR
|
||||
|
||||
|
||||
class WSMessage(NamedTuple):
|
||||
type: WSMsgType
|
||||
# To type correctly, this would need some kind of tagged union for each type.
|
||||
data: Any
|
||||
extra: str | None
|
||||
|
||||
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
|
||||
"""Return parsed JSON data.
|
||||
|
||||
.. versionadded:: 0.22
|
||||
"""
|
||||
return loads(self.data)
|
||||
|
||||
|
||||
class WSMessageTextBytes(NamedTuple):
|
||||
"""WebSocket TEXT message with raw bytes (no UTF-8 decoding)."""
|
||||
|
||||
type: WSMsgType
|
||||
# To type correctly, this would need some kind of tagged union for each type.
|
||||
# In 4.0, we use a union of message types to properly type data, but in 3.x
|
||||
# we keep it as Any to avoid a breaking change.
|
||||
data: Any
|
||||
extra: str | None
|
||||
|
||||
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
|
||||
"""Return parsed JSON data."""
|
||||
return loads(self.data)
|
||||
|
||||
|
||||
# Type aliases for message types based on decode_text setting
|
||||
# When decode_text=True, TEXT messages have str data (WSMessage)
|
||||
# When decode_text=False, TEXT messages have bytes data (WSMessageTextBytes)
|
||||
WSMessageDecodeText = WSMessage
|
||||
WSMessageNoDecodeText = WSMessage | WSMessageTextBytes
|
||||
|
||||
|
||||
# Constructing the tuple directly to avoid the overhead of
|
||||
# the lambda and arg processing since NamedTuples are constructed
|
||||
# with a run time built lambda
|
||||
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
|
||||
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
|
||||
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
|
||||
|
||||
|
||||
class WebSocketError(Exception):
|
||||
"""WebSocket protocol parser error."""
|
||||
|
||||
def __init__(self, code: int, message: str) -> None:
|
||||
self.code = code
|
||||
super().__init__(code, message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return cast(str, self.args[1])
|
||||
|
||||
|
||||
class WSHandshakeError(Exception):
|
||||
"""WebSocket protocol handshake error."""
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..helpers import NO_EXTENSIONS
|
||||
|
||||
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
||||
from .reader_py import (
|
||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
||||
WebSocketReader as WebSocketReaderPython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderPython
|
||||
WebSocketDataQueue = WebSocketDataQueuePython
|
||||
else:
|
||||
try:
|
||||
from .reader_c import ( # type: ignore[import-not-found]
|
||||
WebSocketDataQueue as WebSocketDataQueueCython,
|
||||
WebSocketReader as WebSocketReaderCython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderCython
|
||||
WebSocketDataQueue = WebSocketDataQueueCython
|
||||
except ImportError: # pragma: no cover
|
||||
from .reader_py import (
|
||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
||||
WebSocketReader as WebSocketReaderPython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderPython
|
||||
WebSocketDataQueue = WebSocketDataQueuePython
|
||||
Executable
BIN
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
import cython
|
||||
|
||||
from .mask cimport _websocket_mask_cython as websocket_mask
|
||||
|
||||
|
||||
cdef unsigned int READ_HEADER
|
||||
cdef unsigned int READ_PAYLOAD_LENGTH
|
||||
cdef unsigned int READ_PAYLOAD_MASK
|
||||
cdef unsigned int READ_PAYLOAD
|
||||
|
||||
cdef int OP_CODE_NOT_SET
|
||||
cdef int OP_CODE_CONTINUATION
|
||||
cdef int OP_CODE_TEXT
|
||||
cdef int OP_CODE_BINARY
|
||||
cdef int OP_CODE_CLOSE
|
||||
cdef int OP_CODE_PING
|
||||
cdef int OP_CODE_PONG
|
||||
|
||||
cdef int COMPRESSED_NOT_SET
|
||||
cdef int COMPRESSED_FALSE
|
||||
cdef int COMPRESSED_TRUE
|
||||
|
||||
cdef object UNPACK_LEN3
|
||||
cdef object UNPACK_CLOSE_CODE
|
||||
cdef object TUPLE_NEW
|
||||
|
||||
cdef object WSMsgType
|
||||
cdef object WSMessage
|
||||
cdef object WSMessageTextBytes
|
||||
|
||||
cdef object WS_MSG_TYPE_TEXT
|
||||
cdef object WS_MSG_TYPE_BINARY
|
||||
|
||||
cdef set ALLOWED_CLOSE_CODES
|
||||
cdef set MESSAGE_TYPES_WITH_CONTENT
|
||||
|
||||
cdef tuple EMPTY_FRAME
|
||||
cdef tuple EMPTY_FRAME_ERROR
|
||||
|
||||
cdef class WebSocketDataQueue:
|
||||
|
||||
cdef unsigned int _size
|
||||
cdef public object _protocol
|
||||
cdef unsigned int _limit
|
||||
cdef object _loop
|
||||
cdef bint _eof
|
||||
cdef object _waiter
|
||||
cdef object _exception
|
||||
cdef public object _buffer
|
||||
cdef object _get_buffer
|
||||
cdef object _put_buffer
|
||||
|
||||
cdef void _release_waiter(self)
|
||||
|
||||
cpdef void feed_data(self, object data, unsigned int size)
|
||||
|
||||
@cython.locals(size="unsigned int")
|
||||
cdef _read_from_buffer(self)
|
||||
|
||||
cdef class WebSocketReader:
|
||||
|
||||
cdef WebSocketDataQueue queue
|
||||
cdef unsigned int _max_msg_size
|
||||
cdef bint _decode_text
|
||||
|
||||
cdef Exception _exc
|
||||
cdef bytearray _partial
|
||||
cdef unsigned int _state
|
||||
|
||||
cdef int _opcode
|
||||
cdef bint _frame_fin
|
||||
cdef int _frame_opcode
|
||||
cdef list _payload_fragments
|
||||
cdef Py_ssize_t _frame_payload_len
|
||||
|
||||
cdef bytes _tail
|
||||
cdef bint _has_mask
|
||||
cdef bytes _frame_mask
|
||||
cdef Py_ssize_t _payload_bytes_to_read
|
||||
cdef unsigned int _payload_len_flag
|
||||
cdef int _compressed
|
||||
cdef object _decompressobj
|
||||
cdef bint _compress
|
||||
|
||||
cpdef tuple feed_data(self, object data)
|
||||
|
||||
@cython.locals(
|
||||
is_continuation=bint,
|
||||
fin=bint,
|
||||
has_partial=bint,
|
||||
payload_merged=bytes,
|
||||
)
|
||||
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
|
||||
|
||||
@cython.locals(
|
||||
start_pos=Py_ssize_t,
|
||||
data_len=Py_ssize_t,
|
||||
length=Py_ssize_t,
|
||||
chunk_size=Py_ssize_t,
|
||||
chunk_len=Py_ssize_t,
|
||||
data_len=Py_ssize_t,
|
||||
data_cstr="const unsigned char *",
|
||||
first_byte="unsigned char",
|
||||
second_byte="unsigned char",
|
||||
f_start_pos=Py_ssize_t,
|
||||
f_end_pos=Py_ssize_t,
|
||||
has_mask=bint,
|
||||
fin=bint,
|
||||
had_fragments=Py_ssize_t,
|
||||
payload_bytearray=bytearray,
|
||||
)
|
||||
cpdef void _feed_data(self, bytes data) except *
|
||||
@@ -0,0 +1,509 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
from collections import deque
|
||||
from typing import Final
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..compression_utils import ZLibDecompressor
|
||||
from ..helpers import _EXC_SENTINEL, set_exception
|
||||
from ..streams import EofStream
|
||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
||||
from .models import (
|
||||
WS_DEFLATE_TRAILING,
|
||||
WebSocketError,
|
||||
WSCloseCode,
|
||||
WSMessage,
|
||||
WSMessageTextBytes,
|
||||
WSMsgType,
|
||||
)
|
||||
|
||||
ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode}
|
||||
|
||||
# States for the reader, used to parse the WebSocket frame
|
||||
# integer values are used so they can be cythonized
|
||||
READ_HEADER = 1
|
||||
READ_PAYLOAD_LENGTH = 2
|
||||
READ_PAYLOAD_MASK = 3
|
||||
READ_PAYLOAD = 4
|
||||
|
||||
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
||||
|
||||
# WSMsgType values unpacked so they can by cythonized to ints
|
||||
OP_CODE_NOT_SET = -1
|
||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
||||
OP_CODE_PING = WSMsgType.PING.value
|
||||
OP_CODE_PONG = WSMsgType.PONG.value
|
||||
|
||||
EMPTY_FRAME_ERROR = (True, b"")
|
||||
EMPTY_FRAME = (False, b"")
|
||||
|
||||
COMPRESSED_NOT_SET = -1
|
||||
COMPRESSED_FALSE = 0
|
||||
COMPRESSED_TRUE = 1
|
||||
|
||||
TUPLE_NEW = tuple.__new__
|
||||
|
||||
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
|
||||
|
||||
|
||||
class WebSocketDataQueue:
|
||||
"""WebSocketDataQueue resumes and pauses an underlying stream.
|
||||
|
||||
It is a destination for WebSocket data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self._size = 0
|
||||
self._protocol = protocol
|
||||
self._limit = limit * 2
|
||||
self._loop = loop
|
||||
self._eof = False
|
||||
self._waiter: asyncio.Future[None] | None = None
|
||||
self._exception: BaseException | None = None
|
||||
self._buffer: deque[tuple[WSMessage | WSMessageTextBytes, int]] = deque()
|
||||
self._get_buffer = self._buffer.popleft
|
||||
self._put_buffer = self._buffer.append
|
||||
|
||||
def is_eof(self) -> bool:
|
||||
return self._eof
|
||||
|
||||
def exception(self) -> BaseException | None:
|
||||
return self._exception
|
||||
|
||||
def set_exception(
|
||||
self,
|
||||
exc: BaseException,
|
||||
exc_cause: builtins.BaseException = _EXC_SENTINEL,
|
||||
) -> None:
|
||||
self._eof = True
|
||||
self._exception = exc
|
||||
if (waiter := self._waiter) is not None:
|
||||
self._waiter = None
|
||||
set_exception(waiter, exc, exc_cause)
|
||||
|
||||
def _release_waiter(self) -> None:
|
||||
if (waiter := self._waiter) is None:
|
||||
return
|
||||
self._waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self._eof = True
|
||||
self._release_waiter()
|
||||
self._exception = None # Break cyclic references
|
||||
|
||||
def feed_data(
|
||||
self, data: "WSMessage | WSMessageTextBytes", size: "cython_int"
|
||||
) -> None:
|
||||
self._size += size
|
||||
self._put_buffer((data, size))
|
||||
self._release_waiter()
|
||||
if self._size > self._limit and not self._protocol._reading_paused:
|
||||
self._protocol.pause_reading()
|
||||
|
||||
async def read(self) -> WSMessage | WSMessageTextBytes:
|
||||
if not self._buffer and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = self._loop.create_future()
|
||||
try:
|
||||
await self._waiter
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
self._waiter = None
|
||||
raise
|
||||
return self._read_from_buffer()
|
||||
|
||||
def _read_from_buffer(self) -> WSMessage | WSMessageTextBytes:
|
||||
if self._buffer:
|
||||
data, size = self._get_buffer()
|
||||
self._size -= size
|
||||
if self._size < self._limit and self._protocol._reading_paused:
|
||||
self._protocol.resume_reading()
|
||||
return data
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
raise EofStream
|
||||
|
||||
|
||||
class WebSocketReader:
|
||||
def __init__(
|
||||
self,
|
||||
queue: WebSocketDataQueue,
|
||||
max_msg_size: int,
|
||||
compress: bool = True,
|
||||
decode_text: bool = True,
|
||||
) -> None:
|
||||
self.queue = queue
|
||||
self._max_msg_size = max_msg_size
|
||||
self._decode_text = decode_text
|
||||
|
||||
self._exc: Exception | None = None
|
||||
self._partial = bytearray()
|
||||
self._state = READ_HEADER
|
||||
|
||||
self._opcode: int = OP_CODE_NOT_SET
|
||||
self._frame_fin = False
|
||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
||||
self._payload_fragments: list[bytes] = []
|
||||
self._frame_payload_len = 0
|
||||
|
||||
self._tail: bytes = b""
|
||||
self._has_mask = False
|
||||
self._frame_mask: bytes | None = None
|
||||
self._payload_bytes_to_read = 0
|
||||
self._payload_len_flag = 0
|
||||
self._compressed: int = COMPRESSED_NOT_SET
|
||||
self._decompressobj: ZLibDecompressor | None = None
|
||||
self._compress = compress
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self.queue.feed_eof()
|
||||
|
||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
||||
# coerce data to bytes if it is not
|
||||
def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
|
||||
if type(data) is not bytes:
|
||||
data = bytes(data)
|
||||
|
||||
if self._exc is not None:
|
||||
return True, data
|
||||
|
||||
try:
|
||||
self._feed_data(data)
|
||||
except Exception as exc:
|
||||
self._exc = exc
|
||||
set_exception(self.queue, exc)
|
||||
return EMPTY_FRAME_ERROR
|
||||
|
||||
return EMPTY_FRAME
|
||||
|
||||
def _handle_frame(
|
||||
self,
|
||||
fin: bool,
|
||||
opcode: int | cython_int, # Union intended: Cython pxd uses C int
|
||||
payload: bytes | bytearray,
|
||||
compressed: int | cython_int, # Union intended: Cython pxd uses C int
|
||||
) -> None:
|
||||
msg: WSMessage
|
||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
||||
# Validate continuation frames before processing
|
||||
if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Continuation frame for non started message",
|
||||
)
|
||||
|
||||
# load text/binary
|
||||
if not fin:
|
||||
# got partial frame payload
|
||||
if opcode != OP_CODE_CONTINUATION:
|
||||
self._opcode = opcode
|
||||
self._partial += payload
|
||||
return
|
||||
|
||||
has_partial = bool(self._partial)
|
||||
if opcode == OP_CODE_CONTINUATION:
|
||||
opcode = self._opcode
|
||||
self._opcode = OP_CODE_NOT_SET
|
||||
# previous frame was non finished
|
||||
# we should get continuation opcode
|
||||
elif has_partial:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"The opcode in non-fin frame is expected "
|
||||
f"to be zero, got {opcode!r}",
|
||||
)
|
||||
|
||||
assembled_payload: bytes | bytearray
|
||||
if has_partial:
|
||||
assembled_payload = self._partial + payload
|
||||
self._partial.clear()
|
||||
else:
|
||||
assembled_payload = payload
|
||||
|
||||
# Decompress process must to be done after all packets
|
||||
# received.
|
||||
if compressed:
|
||||
if not self._decompressobj:
|
||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
||||
# XXX: It's possible that the zlib backend (isal is known to
|
||||
# do this, maybe others too?) will return max_length bytes,
|
||||
# but internally buffer more data such that the payload is
|
||||
# >max_length, so we return one extra byte and if we're able
|
||||
# to do that, then the message is too big.
|
||||
payload_merged = self._decompressobj.decompress_sync(
|
||||
assembled_payload + WS_DEFLATE_TRAILING,
|
||||
(
|
||||
self._max_msg_size + 1
|
||||
if self._max_msg_size
|
||||
else self._max_msg_size
|
||||
),
|
||||
)
|
||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
||||
)
|
||||
elif type(assembled_payload) is bytes:
|
||||
payload_merged = assembled_payload
|
||||
else:
|
||||
payload_merged = bytes(assembled_payload)
|
||||
|
||||
if opcode == OP_CODE_TEXT:
|
||||
if self._decode_text:
|
||||
try:
|
||||
text = payload_merged.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
|
||||
# XXX: The Text and Binary messages here can be a performance
|
||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
||||
# This is not type safe, but many tests should fail in
|
||||
# test_client_ws_functional.py if this is wrong.
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
# Return raw bytes for TEXT messages when decode_text=False
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(
|
||||
WSMessageTextBytes, (WS_MSG_TYPE_TEXT, payload_merged, "")
|
||||
),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
elif opcode == OP_CODE_CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close code: {close_code}",
|
||||
)
|
||||
try:
|
||||
close_message = payload[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
||||
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
||||
)
|
||||
else:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
||||
|
||||
self.queue.feed_data(msg, 0)
|
||||
elif opcode == OP_CODE_PING:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
elif opcode == OP_CODE_PONG:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
else:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
||||
)
|
||||
|
||||
def _feed_data(self, data: bytes) -> None:
|
||||
"""Return the next frame from the socket."""
|
||||
if self._tail:
|
||||
data, self._tail = self._tail + data, b""
|
||||
|
||||
start_pos: int = 0
|
||||
data_len = len(data)
|
||||
data_cstr = data
|
||||
|
||||
while True:
|
||||
# read header
|
||||
if self._state == READ_HEADER:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0xF
|
||||
|
||||
# frame-fin = %x0 ; more frames of this message follow
|
||||
# / %x1 ; final frame of this message
|
||||
# frame-rsv1 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv2 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv3 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
#
|
||||
# Remove rsv1 from this test for deflate development
|
||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
if opcode not in {
|
||||
OP_CODE_CONTINUATION,
|
||||
OP_CODE_TEXT,
|
||||
OP_CODE_BINARY,
|
||||
OP_CODE_CLOSE,
|
||||
OP_CODE_PING,
|
||||
OP_CODE_PONG,
|
||||
}:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Unexpected opcode={opcode!r}",
|
||||
)
|
||||
|
||||
if opcode > 0x7 and fin == 0:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received fragmented control frame",
|
||||
)
|
||||
|
||||
has_mask = (second_byte >> 7) & 1
|
||||
length = second_byte & 0x7F
|
||||
|
||||
# Control frames MUST have a payload
|
||||
# length of 125 bytes or less
|
||||
if opcode > 0x7 and length > 125:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Control frame payload cannot be larger than 125 bytes",
|
||||
)
|
||||
|
||||
# Set compress status if last package is FIN
|
||||
# OR set compress status if this is first fragment
|
||||
# Raise error if not first fragment with rsv1 = 0x1
|
||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
||||
elif rsv1:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
self._frame_fin = bool(fin)
|
||||
self._frame_opcode = opcode
|
||||
self._has_mask = bool(has_mask)
|
||||
self._payload_len_flag = length
|
||||
self._state = READ_PAYLOAD_LENGTH
|
||||
|
||||
# read payload length
|
||||
if self._state == READ_PAYLOAD_LENGTH:
|
||||
len_flag = self._payload_len_flag
|
||||
if len_flag == 126:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
||||
elif len_flag > 126:
|
||||
if data_len - start_pos < 8:
|
||||
break
|
||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
||||
start_pos += 8
|
||||
else:
|
||||
self._payload_bytes_to_read = len_flag
|
||||
|
||||
# Reject oversized data frames before buffering any payload
|
||||
# bytes. Control frames are capped at 125 bytes (checked in
|
||||
# READ_HEADER) so only text/binary/continuation need this.
|
||||
if self._max_msg_size and self._frame_opcode in {
|
||||
OP_CODE_TEXT,
|
||||
OP_CODE_BINARY,
|
||||
OP_CODE_CONTINUATION,
|
||||
}:
|
||||
projected_size = self._payload_bytes_to_read + len(self._partial)
|
||||
if projected_size >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {projected_size} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
|
||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
||||
|
||||
# read payload mask
|
||||
if self._state == READ_PAYLOAD_MASK:
|
||||
if data_len - start_pos < 4:
|
||||
break
|
||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
||||
start_pos += 4
|
||||
self._state = READ_PAYLOAD
|
||||
|
||||
if self._state == READ_PAYLOAD:
|
||||
chunk_len = data_len - start_pos
|
||||
if self._payload_bytes_to_read >= chunk_len:
|
||||
f_end_pos = data_len
|
||||
self._payload_bytes_to_read -= chunk_len
|
||||
else:
|
||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
||||
self._payload_bytes_to_read = 0
|
||||
|
||||
had_fragments = self._frame_payload_len
|
||||
self._frame_payload_len += f_end_pos - start_pos
|
||||
f_start_pos = start_pos
|
||||
start_pos = f_end_pos
|
||||
|
||||
if self._payload_bytes_to_read != 0:
|
||||
# If we don't have a complete frame, we need to save the
|
||||
# data for the next call to feed_data.
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
break
|
||||
|
||||
payload: bytes | bytearray
|
||||
if had_fragments:
|
||||
# We have to join the payload fragments get the payload
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
if self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = b"".join(self._payload_fragments)
|
||||
self._payload_fragments.clear()
|
||||
elif self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
||||
# Cython will do the conversion for us
|
||||
# but we need to do it for Python and we
|
||||
# will always get here in Python
|
||||
payload_bytearray = bytearray(payload_bytearray)
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = data_cstr[f_start_pos:f_end_pos]
|
||||
|
||||
self._handle_frame(
|
||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
||||
)
|
||||
self._frame_payload_len = 0
|
||||
self._state = READ_HEADER
|
||||
|
||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
||||
@@ -0,0 +1,509 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
from collections import deque
|
||||
from typing import Final
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..compression_utils import ZLibDecompressor
|
||||
from ..helpers import _EXC_SENTINEL, set_exception
|
||||
from ..streams import EofStream
|
||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
||||
from .models import (
|
||||
WS_DEFLATE_TRAILING,
|
||||
WebSocketError,
|
||||
WSCloseCode,
|
||||
WSMessage,
|
||||
WSMessageTextBytes,
|
||||
WSMsgType,
|
||||
)
|
||||
|
||||
ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode}
|
||||
|
||||
# States for the reader, used to parse the WebSocket frame
|
||||
# integer values are used so they can be cythonized
|
||||
READ_HEADER = 1
|
||||
READ_PAYLOAD_LENGTH = 2
|
||||
READ_PAYLOAD_MASK = 3
|
||||
READ_PAYLOAD = 4
|
||||
|
||||
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
||||
|
||||
# WSMsgType values unpacked so they can by cythonized to ints
|
||||
OP_CODE_NOT_SET = -1
|
||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
||||
OP_CODE_PING = WSMsgType.PING.value
|
||||
OP_CODE_PONG = WSMsgType.PONG.value
|
||||
|
||||
EMPTY_FRAME_ERROR = (True, b"")
|
||||
EMPTY_FRAME = (False, b"")
|
||||
|
||||
COMPRESSED_NOT_SET = -1
|
||||
COMPRESSED_FALSE = 0
|
||||
COMPRESSED_TRUE = 1
|
||||
|
||||
TUPLE_NEW = tuple.__new__
|
||||
|
||||
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
|
||||
|
||||
|
||||
class WebSocketDataQueue:
|
||||
"""WebSocketDataQueue resumes and pauses an underlying stream.
|
||||
|
||||
It is a destination for WebSocket data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self._size = 0
|
||||
self._protocol = protocol
|
||||
self._limit = limit * 2
|
||||
self._loop = loop
|
||||
self._eof = False
|
||||
self._waiter: asyncio.Future[None] | None = None
|
||||
self._exception: BaseException | None = None
|
||||
self._buffer: deque[tuple[WSMessage | WSMessageTextBytes, int]] = deque()
|
||||
self._get_buffer = self._buffer.popleft
|
||||
self._put_buffer = self._buffer.append
|
||||
|
||||
def is_eof(self) -> bool:
|
||||
return self._eof
|
||||
|
||||
def exception(self) -> BaseException | None:
|
||||
return self._exception
|
||||
|
||||
def set_exception(
|
||||
self,
|
||||
exc: BaseException,
|
||||
exc_cause: builtins.BaseException = _EXC_SENTINEL,
|
||||
) -> None:
|
||||
self._eof = True
|
||||
self._exception = exc
|
||||
if (waiter := self._waiter) is not None:
|
||||
self._waiter = None
|
||||
set_exception(waiter, exc, exc_cause)
|
||||
|
||||
def _release_waiter(self) -> None:
|
||||
if (waiter := self._waiter) is None:
|
||||
return
|
||||
self._waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self._eof = True
|
||||
self._release_waiter()
|
||||
self._exception = None # Break cyclic references
|
||||
|
||||
def feed_data(
|
||||
self, data: "WSMessage | WSMessageTextBytes", size: "cython_int"
|
||||
) -> None:
|
||||
self._size += size
|
||||
self._put_buffer((data, size))
|
||||
self._release_waiter()
|
||||
if self._size > self._limit and not self._protocol._reading_paused:
|
||||
self._protocol.pause_reading()
|
||||
|
||||
async def read(self) -> WSMessage | WSMessageTextBytes:
|
||||
if not self._buffer and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = self._loop.create_future()
|
||||
try:
|
||||
await self._waiter
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
self._waiter = None
|
||||
raise
|
||||
return self._read_from_buffer()
|
||||
|
||||
def _read_from_buffer(self) -> WSMessage | WSMessageTextBytes:
|
||||
if self._buffer:
|
||||
data, size = self._get_buffer()
|
||||
self._size -= size
|
||||
if self._size < self._limit and self._protocol._reading_paused:
|
||||
self._protocol.resume_reading()
|
||||
return data
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
raise EofStream
|
||||
|
||||
|
||||
class WebSocketReader:
|
||||
def __init__(
|
||||
self,
|
||||
queue: WebSocketDataQueue,
|
||||
max_msg_size: int,
|
||||
compress: bool = True,
|
||||
decode_text: bool = True,
|
||||
) -> None:
|
||||
self.queue = queue
|
||||
self._max_msg_size = max_msg_size
|
||||
self._decode_text = decode_text
|
||||
|
||||
self._exc: Exception | None = None
|
||||
self._partial = bytearray()
|
||||
self._state = READ_HEADER
|
||||
|
||||
self._opcode: int = OP_CODE_NOT_SET
|
||||
self._frame_fin = False
|
||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
||||
self._payload_fragments: list[bytes] = []
|
||||
self._frame_payload_len = 0
|
||||
|
||||
self._tail: bytes = b""
|
||||
self._has_mask = False
|
||||
self._frame_mask: bytes | None = None
|
||||
self._payload_bytes_to_read = 0
|
||||
self._payload_len_flag = 0
|
||||
self._compressed: int = COMPRESSED_NOT_SET
|
||||
self._decompressobj: ZLibDecompressor | None = None
|
||||
self._compress = compress
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self.queue.feed_eof()
|
||||
|
||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
||||
# coerce data to bytes if it is not
|
||||
def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
|
||||
if type(data) is not bytes:
|
||||
data = bytes(data)
|
||||
|
||||
if self._exc is not None:
|
||||
return True, data
|
||||
|
||||
try:
|
||||
self._feed_data(data)
|
||||
except Exception as exc:
|
||||
self._exc = exc
|
||||
set_exception(self.queue, exc)
|
||||
return EMPTY_FRAME_ERROR
|
||||
|
||||
return EMPTY_FRAME
|
||||
|
||||
def _handle_frame(
|
||||
self,
|
||||
fin: bool,
|
||||
opcode: int | cython_int, # Union intended: Cython pxd uses C int
|
||||
payload: bytes | bytearray,
|
||||
compressed: int | cython_int, # Union intended: Cython pxd uses C int
|
||||
) -> None:
|
||||
msg: WSMessage
|
||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
||||
# Validate continuation frames before processing
|
||||
if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Continuation frame for non started message",
|
||||
)
|
||||
|
||||
# load text/binary
|
||||
if not fin:
|
||||
# got partial frame payload
|
||||
if opcode != OP_CODE_CONTINUATION:
|
||||
self._opcode = opcode
|
||||
self._partial += payload
|
||||
return
|
||||
|
||||
has_partial = bool(self._partial)
|
||||
if opcode == OP_CODE_CONTINUATION:
|
||||
opcode = self._opcode
|
||||
self._opcode = OP_CODE_NOT_SET
|
||||
# previous frame was non finished
|
||||
# we should get continuation opcode
|
||||
elif has_partial:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"The opcode in non-fin frame is expected "
|
||||
f"to be zero, got {opcode!r}",
|
||||
)
|
||||
|
||||
assembled_payload: bytes | bytearray
|
||||
if has_partial:
|
||||
assembled_payload = self._partial + payload
|
||||
self._partial.clear()
|
||||
else:
|
||||
assembled_payload = payload
|
||||
|
||||
# Decompress process must to be done after all packets
|
||||
# received.
|
||||
if compressed:
|
||||
if not self._decompressobj:
|
||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
||||
# XXX: It's possible that the zlib backend (isal is known to
|
||||
# do this, maybe others too?) will return max_length bytes,
|
||||
# but internally buffer more data such that the payload is
|
||||
# >max_length, so we return one extra byte and if we're able
|
||||
# to do that, then the message is too big.
|
||||
payload_merged = self._decompressobj.decompress_sync(
|
||||
assembled_payload + WS_DEFLATE_TRAILING,
|
||||
(
|
||||
self._max_msg_size + 1
|
||||
if self._max_msg_size
|
||||
else self._max_msg_size
|
||||
),
|
||||
)
|
||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
||||
)
|
||||
elif type(assembled_payload) is bytes:
|
||||
payload_merged = assembled_payload
|
||||
else:
|
||||
payload_merged = bytes(assembled_payload)
|
||||
|
||||
if opcode == OP_CODE_TEXT:
|
||||
if self._decode_text:
|
||||
try:
|
||||
text = payload_merged.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
|
||||
# XXX: The Text and Binary messages here can be a performance
|
||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
||||
# This is not type safe, but many tests should fail in
|
||||
# test_client_ws_functional.py if this is wrong.
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
# Return raw bytes for TEXT messages when decode_text=False
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(
|
||||
WSMessageTextBytes, (WS_MSG_TYPE_TEXT, payload_merged, "")
|
||||
),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
elif opcode == OP_CODE_CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close code: {close_code}",
|
||||
)
|
||||
try:
|
||||
close_message = payload[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
||||
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
||||
)
|
||||
else:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
||||
|
||||
self.queue.feed_data(msg, 0)
|
||||
elif opcode == OP_CODE_PING:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
elif opcode == OP_CODE_PONG:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
else:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
||||
)
|
||||
|
||||
def _feed_data(self, data: bytes) -> None:
|
||||
"""Return the next frame from the socket."""
|
||||
if self._tail:
|
||||
data, self._tail = self._tail + data, b""
|
||||
|
||||
start_pos: int = 0
|
||||
data_len = len(data)
|
||||
data_cstr = data
|
||||
|
||||
while True:
|
||||
# read header
|
||||
if self._state == READ_HEADER:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0xF
|
||||
|
||||
# frame-fin = %x0 ; more frames of this message follow
|
||||
# / %x1 ; final frame of this message
|
||||
# frame-rsv1 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv2 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv3 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
#
|
||||
# Remove rsv1 from this test for deflate development
|
||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
if opcode not in {
|
||||
OP_CODE_CONTINUATION,
|
||||
OP_CODE_TEXT,
|
||||
OP_CODE_BINARY,
|
||||
OP_CODE_CLOSE,
|
||||
OP_CODE_PING,
|
||||
OP_CODE_PONG,
|
||||
}:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Unexpected opcode={opcode!r}",
|
||||
)
|
||||
|
||||
if opcode > 0x7 and fin == 0:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received fragmented control frame",
|
||||
)
|
||||
|
||||
has_mask = (second_byte >> 7) & 1
|
||||
length = second_byte & 0x7F
|
||||
|
||||
# Control frames MUST have a payload
|
||||
# length of 125 bytes or less
|
||||
if opcode > 0x7 and length > 125:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Control frame payload cannot be larger than 125 bytes",
|
||||
)
|
||||
|
||||
# Set compress status if last package is FIN
|
||||
# OR set compress status if this is first fragment
|
||||
# Raise error if not first fragment with rsv1 = 0x1
|
||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
||||
elif rsv1:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
self._frame_fin = bool(fin)
|
||||
self._frame_opcode = opcode
|
||||
self._has_mask = bool(has_mask)
|
||||
self._payload_len_flag = length
|
||||
self._state = READ_PAYLOAD_LENGTH
|
||||
|
||||
# read payload length
|
||||
if self._state == READ_PAYLOAD_LENGTH:
|
||||
len_flag = self._payload_len_flag
|
||||
if len_flag == 126:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
||||
elif len_flag > 126:
|
||||
if data_len - start_pos < 8:
|
||||
break
|
||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
||||
start_pos += 8
|
||||
else:
|
||||
self._payload_bytes_to_read = len_flag
|
||||
|
||||
# Reject oversized data frames before buffering any payload
|
||||
# bytes. Control frames are capped at 125 bytes (checked in
|
||||
# READ_HEADER) so only text/binary/continuation need this.
|
||||
if self._max_msg_size and self._frame_opcode in {
|
||||
OP_CODE_TEXT,
|
||||
OP_CODE_BINARY,
|
||||
OP_CODE_CONTINUATION,
|
||||
}:
|
||||
projected_size = self._payload_bytes_to_read + len(self._partial)
|
||||
if projected_size >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {projected_size} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
|
||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
||||
|
||||
# read payload mask
|
||||
if self._state == READ_PAYLOAD_MASK:
|
||||
if data_len - start_pos < 4:
|
||||
break
|
||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
||||
start_pos += 4
|
||||
self._state = READ_PAYLOAD
|
||||
|
||||
if self._state == READ_PAYLOAD:
|
||||
chunk_len = data_len - start_pos
|
||||
if self._payload_bytes_to_read >= chunk_len:
|
||||
f_end_pos = data_len
|
||||
self._payload_bytes_to_read -= chunk_len
|
||||
else:
|
||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
||||
self._payload_bytes_to_read = 0
|
||||
|
||||
had_fragments = self._frame_payload_len
|
||||
self._frame_payload_len += f_end_pos - start_pos
|
||||
f_start_pos = start_pos
|
||||
start_pos = f_end_pos
|
||||
|
||||
if self._payload_bytes_to_read != 0:
|
||||
# If we don't have a complete frame, we need to save the
|
||||
# data for the next call to feed_data.
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
break
|
||||
|
||||
payload: bytes | bytearray
|
||||
if had_fragments:
|
||||
# We have to join the payload fragments get the payload
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
if self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = b"".join(self._payload_fragments)
|
||||
self._payload_fragments.clear()
|
||||
elif self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
||||
# Cython will do the conversion for us
|
||||
# but we need to do it for Python and we
|
||||
# will always get here in Python
|
||||
payload_bytearray = bytearray(payload_bytearray)
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = data_cstr[f_start_pos:f_end_pos]
|
||||
|
||||
self._handle_frame(
|
||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
||||
)
|
||||
self._frame_payload_len = 0
|
||||
self._state = READ_HEADER
|
||||
|
||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
||||
@@ -0,0 +1,261 @@
|
||||
"""WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Final, Optional, Set
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..client_exceptions import ClientConnectionResetError
|
||||
from ..compression_utils import ZLibBackend, ZLibCompressor
|
||||
from ..helpers import DEFAULT_CHUNK_SIZE
|
||||
from .helpers import (
|
||||
MASK_LEN,
|
||||
MSG_SIZE,
|
||||
PACK_CLOSE_CODE,
|
||||
PACK_LEN1,
|
||||
PACK_LEN2,
|
||||
PACK_LEN3,
|
||||
PACK_RANDBITS,
|
||||
websocket_mask,
|
||||
)
|
||||
from .models import WS_DEFLATE_TRAILING, WSMsgType
|
||||
|
||||
# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
|
||||
# Control frames (ping, pong, close) are never compressed
|
||||
WS_CONTROL_FRAME_OPCODE: Final[int] = 8
|
||||
|
||||
# For websockets, keeping latency low is extremely important as implementations
|
||||
# generally expect to be able to send and receive messages quickly. We use a
|
||||
# larger chunk size to reduce the number of executor calls and avoid task
|
||||
# creation overhead, since both are significant sources of latency when chunks
|
||||
# are small. A size of 16KiB was chosen as a balance between avoiding task
|
||||
# overhead and not blocking the event loop too long with synchronous compression.
|
||||
|
||||
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024
|
||||
|
||||
|
||||
class WebSocketWriter:
|
||||
"""WebSocket writer.
|
||||
|
||||
The writer is responsible for sending messages to the client. It is
|
||||
created by the protocol when a connection is established. The writer
|
||||
should avoid implementing any application logic and should only be
|
||||
concerned with the low-level details of the WebSocket protocol.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocol: BaseProtocol,
|
||||
transport: asyncio.Transport,
|
||||
*,
|
||||
use_mask: bool = False,
|
||||
limit: int = DEFAULT_CHUNK_SIZE,
|
||||
random: random.Random = random.Random(),
|
||||
compress: int = 0,
|
||||
notakeover: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a WebSocket writer."""
|
||||
self.protocol = protocol
|
||||
self.transport = transport
|
||||
self.use_mask = use_mask
|
||||
self.get_random_bits = partial(random.getrandbits, 32)
|
||||
self.compress = compress
|
||||
self.notakeover = notakeover
|
||||
self._closing = False
|
||||
self._limit = limit
|
||||
self._output_size = 0
|
||||
self._compressobj: Optional[ZLibCompressor] = None
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._background_tasks: Set[asyncio.Task[None]] = set()
|
||||
|
||||
async def send_frame(
|
||||
self, message: bytes, opcode: int, compress: int | None = None
|
||||
) -> None:
|
||||
"""Send a frame over the websocket with message as its payload."""
|
||||
if self._closing and not (opcode & WSMsgType.CLOSE):
|
||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
||||
|
||||
if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
|
||||
# Non-compressed frames don't need lock or shield
|
||||
self._write_websocket_frame(message, opcode, 0)
|
||||
elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
|
||||
# Small compressed payloads - compress synchronously in event loop
|
||||
# We need the lock even though sync compression has no await points.
|
||||
# This prevents small frames from interleaving with large frames that
|
||||
# compress in the executor, avoiding compressor state corruption.
|
||||
async with self._send_lock:
|
||||
self._send_compressed_frame_sync(message, opcode, compress)
|
||||
else:
|
||||
# Large compressed frames need shield to prevent corruption
|
||||
# For large compressed frames, the entire compress+send
|
||||
# operation must be atomic. If cancelled after compression but
|
||||
# before send, the compressor state would be advanced but data
|
||||
# not sent, corrupting subsequent frames.
|
||||
# Create a task to shield from cancellation
|
||||
# The lock is acquired inside the shielded task so the entire
|
||||
# operation (lock + compress + send) completes atomically.
|
||||
# Use eager_start on Python 3.12+ to avoid scheduling overhead
|
||||
loop = asyncio.get_running_loop()
|
||||
coro = self._send_compressed_frame_async_locked(message, opcode, compress)
|
||||
if sys.version_info >= (3, 12):
|
||||
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
|
||||
else:
|
||||
send_task = loop.create_task(coro)
|
||||
# Keep a strong reference to prevent garbage collection
|
||||
self._background_tasks.add(send_task)
|
||||
send_task.add_done_callback(self._background_tasks.discard)
|
||||
await asyncio.shield(send_task)
|
||||
|
||||
# It is safe to return control to the event loop when using compression
|
||||
# after this point as we have already sent or buffered all the data.
|
||||
# Once we have written output_size up to the limit, we call the
|
||||
# drain helper which waits for the transport to be ready to accept
|
||||
# more data. This is a flow control mechanism to prevent the buffer
|
||||
# from growing too large. The drain helper will return right away
|
||||
# if the writer is not paused.
|
||||
if self._output_size > self._limit:
|
||||
self._output_size = 0
|
||||
if self.protocol._paused:
|
||||
await self.protocol._drain_helper()
|
||||
|
||||
def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
|
||||
"""
|
||||
Write a websocket frame to the transport.
|
||||
|
||||
This method handles frame header construction, masking, and writing to transport.
|
||||
It does not handle compression or flow control - those are the responsibility
|
||||
of the caller.
|
||||
"""
|
||||
msg_length = len(message)
|
||||
|
||||
use_mask = self.use_mask
|
||||
mask_bit = 0x80 if use_mask else 0
|
||||
|
||||
# Depending on the message length, the header is assembled differently.
|
||||
# The first byte is reserved for the opcode and the RSV bits.
|
||||
first_byte = 0x80 | rsv | opcode
|
||||
if msg_length < 126:
|
||||
header = PACK_LEN1(first_byte, msg_length | mask_bit)
|
||||
header_len = 2
|
||||
elif msg_length < 65536:
|
||||
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
|
||||
header_len = 4
|
||||
else:
|
||||
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
|
||||
header_len = 10
|
||||
|
||||
if self.transport.is_closing():
|
||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
||||
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
|
||||
# If we are using a mask, we need to generate it randomly
|
||||
# and apply it to the message before sending it. A mask is
|
||||
# a 32-bit value that is applied to the message using a
|
||||
# bitwise XOR operation. It is used to prevent certain types
|
||||
# of attacks on the websocket protocol. The mask is only used
|
||||
# when aiohttp is acting as a client. Servers do not use a mask.
|
||||
if use_mask:
|
||||
mask = PACK_RANDBITS(self.get_random_bits())
|
||||
message_arr = bytearray(message)
|
||||
websocket_mask(mask, message_arr)
|
||||
self.transport.write(header + mask + message_arr)
|
||||
self._output_size += MASK_LEN
|
||||
elif msg_length > MSG_SIZE:
|
||||
self.transport.write(header)
|
||||
self.transport.write(message)
|
||||
else:
|
||||
self.transport.write(header + message)
|
||||
|
||||
self._output_size += header_len + msg_length
|
||||
|
||||
def _get_compressor(self, compress: int | None) -> ZLibCompressor:
|
||||
"""Get or create a compressor object for the given compression level."""
|
||||
if compress:
|
||||
# Do not set self._compress if compressing is for this frame
|
||||
return ZLibCompressor(
|
||||
level=ZLibBackend.Z_BEST_SPEED,
|
||||
wbits=-compress,
|
||||
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
|
||||
)
|
||||
if not self._compressobj:
|
||||
self._compressobj = ZLibCompressor(
|
||||
level=ZLibBackend.Z_BEST_SPEED,
|
||||
wbits=-self.compress,
|
||||
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
|
||||
)
|
||||
return self._compressobj
|
||||
|
||||
def _send_compressed_frame_sync(
|
||||
self, message: bytes, opcode: int, compress: int | None
|
||||
) -> None:
|
||||
"""
|
||||
Synchronous send for small compressed frames.
|
||||
|
||||
This is used for small compressed payloads that compress synchronously in the event loop.
|
||||
Since there are no await points, this is inherently cancellation-safe.
|
||||
"""
|
||||
# RSV are the reserved bits in the frame header. They are used to
|
||||
# indicate that the frame is using an extension.
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
||||
compressobj = self._get_compressor(compress)
|
||||
# (0x40) RSV1 is set for compressed frames
|
||||
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
|
||||
self._write_websocket_frame(
|
||||
(
|
||||
compressobj.compress_sync(message)
|
||||
+ compressobj.flush(
|
||||
ZLibBackend.Z_FULL_FLUSH
|
||||
if self.notakeover
|
||||
else ZLibBackend.Z_SYNC_FLUSH
|
||||
)
|
||||
).removesuffix(WS_DEFLATE_TRAILING),
|
||||
opcode,
|
||||
0x40,
|
||||
)
|
||||
|
||||
async def _send_compressed_frame_async_locked(
|
||||
self, message: bytes, opcode: int, compress: int | None
|
||||
) -> None:
|
||||
"""
|
||||
Async send for large compressed frames with lock.
|
||||
|
||||
Acquires the lock and compresses large payloads asynchronously in
|
||||
the executor. The lock is held for the entire operation to ensure
|
||||
the compressor state is not corrupted by concurrent sends.
|
||||
|
||||
MUST be run shielded from cancellation. If cancelled after
|
||||
compression but before sending, the compressor state would be
|
||||
advanced but data not sent, corrupting subsequent frames.
|
||||
"""
|
||||
async with self._send_lock:
|
||||
# RSV are the reserved bits in the frame header. They are used to
|
||||
# indicate that the frame is using an extension.
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
||||
compressobj = self._get_compressor(compress)
|
||||
# (0x40) RSV1 is set for compressed frames
|
||||
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
|
||||
self._write_websocket_frame(
|
||||
(
|
||||
await compressobj.compress(message)
|
||||
+ compressobj.flush(
|
||||
ZLibBackend.Z_FULL_FLUSH
|
||||
if self.notakeover
|
||||
else ZLibBackend.Z_SYNC_FLUSH
|
||||
)
|
||||
).removesuffix(WS_DEFLATE_TRAILING),
|
||||
opcode,
|
||||
0x40,
|
||||
)
|
||||
|
||||
async def close(self, code: int = 1000, message: bytes | str = b"") -> None:
|
||||
"""Close the websocket, sending the specified code and message."""
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
try:
|
||||
await self.send_frame(
|
||||
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
|
||||
)
|
||||
finally:
|
||||
self._closing = True
|
||||
Reference in New Issue
Block a user