On branch DiscordProfile
Initial commit
This commit is contained in:
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
AutoHTTPProtocol: type[asyncio.Protocol]
|
||||
try:
|
||||
import httptools # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
from uvicorn.protocols.http.h11_impl import H11Protocol
|
||||
|
||||
AutoHTTPProtocol = H11Protocol
|
||||
else: # pragma: no cover
|
||||
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
|
||||
|
||||
AutoHTTPProtocol = HttpToolsProtocol
|
||||
@@ -0,0 +1,54 @@
|
||||
import asyncio
|
||||
|
||||
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
||||
|
||||
CLOSE_HEADER = (b"connection", b"close")
|
||||
|
||||
HIGH_WATER_LIMIT = 65536
|
||||
|
||||
|
||||
class FlowControl:
|
||||
def __init__(self, transport: asyncio.Transport) -> None:
|
||||
self._transport = transport
|
||||
self.read_paused = False
|
||||
self.write_paused = False
|
||||
self._is_writable_event = asyncio.Event()
|
||||
self._is_writable_event.set()
|
||||
|
||||
async def drain(self) -> None:
|
||||
await self._is_writable_event.wait() # pragma: full coverage
|
||||
|
||||
def pause_reading(self) -> None:
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self._transport.pause_reading()
|
||||
|
||||
def resume_reading(self) -> None:
|
||||
if self.read_paused:
|
||||
self.read_paused = False
|
||||
self._transport.resume_reading()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
if not self.write_paused: # pragma: full coverage
|
||||
self.write_paused = True
|
||||
self._is_writable_event.clear()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
if self.write_paused: # pragma: full coverage
|
||||
self.write_paused = False
|
||||
self._is_writable_event.set()
|
||||
|
||||
|
||||
async def service_unavailable(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 503,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"19"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": b"Service Unavailable", "more_body": False})
|
||||
@@ -0,0 +1,544 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import http
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import unquote
|
||||
|
||||
import h11
|
||||
from h11._connection import DEFAULT_MAX_INCOMPLETE_EVENT_SIZE
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPResponseBodyEvent,
|
||||
HTTPResponseStartEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
def _get_status_phrase(status_code: int) -> bytes:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
return b""
|
||||
|
||||
|
||||
STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class H11Protocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.conn = h11.Connection(
|
||||
h11.SERVER,
|
||||
config.h11_max_incomplete_event_size
|
||||
if config.h11_max_incomplete_event_size is not None
|
||||
else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Shared server state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int | None] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
|
||||
|
||||
if self.cycle and not self.cycle.response_complete:
|
||||
self.cycle.disconnected = True
|
||||
if self.conn.our_state != h11.ERROR:
|
||||
event = h11.ConnectionClosed()
|
||||
try:
|
||||
self.conn.send(event)
|
||||
except h11.LocalProtocolError:
|
||||
# Premature client disconnect
|
||||
pass
|
||||
|
||||
if self.cycle is not None:
|
||||
self.cycle.message_event.set()
|
||||
if self.flow is not None:
|
||||
self.flow.resume_writing()
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def _unset_keepalive_if_required(self) -> None:
|
||||
if self.timeout_keep_alive_task is not None:
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
if name == b"connection":
|
||||
connection = [token.lower().strip() for token in value.split(b",")]
|
||||
if name == b"upgrade":
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
msg = "Unsupported upgrade request."
|
||||
self.logger.warning(msg)
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
if upgrade == b"websocket" and self._should_upgrade_to_ws():
|
||||
return True
|
||||
if upgrade is not None:
|
||||
self._unsupported_upgrade_warning()
|
||||
return False
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.conn.receive_data(data)
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self.conn.next_event()
|
||||
except h11.RemoteProtocolError:
|
||||
msg = "Invalid HTTP request received."
|
||||
self.logger.warning(msg)
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
break
|
||||
|
||||
elif event is h11.PAUSED:
|
||||
# This case can occur in HTTP pipelining, so we need to
|
||||
# stop reading any more data, and ensure that at the end
|
||||
# of the active request/response cycle we handle any
|
||||
# events that have been buffered up.
|
||||
self.flow.pause_reading()
|
||||
break
|
||||
|
||||
elif isinstance(event, h11.Request):
|
||||
self.headers = [(key.lower(), value) for key, value in event.headers]
|
||||
raw_path, _, query_string = event.target.partition(b"?")
|
||||
path = unquote(raw_path.decode("ascii"))
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": event.http_version.decode("ascii"),
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"method": event.method.decode("ascii"),
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade(event)
|
||||
return
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
self.logger.warning(message)
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
# When starting to process a request, disable the keep-alive
|
||||
# timeout. Normally we disable this when receiving data from
|
||||
# client and set back when finishing processing its request.
|
||||
# However, for pipelined requests processing finishes after
|
||||
# already receiving the next request and thus the timer may
|
||||
# be set here, which we don't want.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
conn=self.conn,
|
||||
transport=self.transport,
|
||||
flow=self.flow,
|
||||
logger=self.logger,
|
||||
access_logger=self.access_logger,
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
if self.config.reset_contextvars:
|
||||
# Opt-in workaround for https://github.com/python/cpython/issues/140947:
|
||||
# asyncio can leak context vars between tasks. Hides context set in the
|
||||
# lifespan or by external instrumentation.
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
|
||||
else: # pragma: py-gte-311
|
||||
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
|
||||
else:
|
||||
task = self.loop.create_task(self.cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
|
||||
elif isinstance(event, h11.Data):
|
||||
if self.conn.our_state is h11.DONE:
|
||||
continue
|
||||
self.cycle.body += event.data
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
self.flow.pause_reading()
|
||||
self.cycle.message_event.set()
|
||||
|
||||
elif isinstance(event, h11.EndOfMessage):
|
||||
if self.conn.our_state is h11.DONE:
|
||||
self.transport.resume_reading()
|
||||
self.conn.start_next_cycle()
|
||||
continue
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
if self.conn.their_state == h11.MUST_CLOSE:
|
||||
break
|
||||
|
||||
def handle_websocket_upgrade(self, event: h11.Request) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL: # pragma: full coverage
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
self.connections.discard(self)
|
||||
output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
|
||||
for name, value in self.headers:
|
||||
output += [name, b": ", value, b"\r\n"]
|
||||
output.append(b"\r\n")
|
||||
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
|
||||
config=self.config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.app_state,
|
||||
)
|
||||
protocol.connection_made(self.transport)
|
||||
protocol.data_received(b"".join(output))
|
||||
self.transport.set_protocol(protocol)
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
reason = STATUS_PHRASES[400]
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
]
|
||||
event = h11.Response(status_code=400, headers=headers, reason=reason)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
|
||||
output = self.conn.send(event=h11.Data(data=msg.encode("ascii")))
|
||||
self.transport.write(output)
|
||||
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
self.transport.close()
|
||||
|
||||
def on_response_complete(self) -> None:
|
||||
self.server_state.total_requests += 1
|
||||
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
# Set a short Keep-Alive timeout.
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events.
|
||||
if self.conn.our_state is h11.DONE and self.conn.their_state is h11.DONE:
|
||||
self.conn.start_next_cycle()
|
||||
self.handle_events()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called by the server to commence a graceful shutdown.
|
||||
"""
|
||||
if self.cycle is None or self.cycle.response_complete:
|
||||
event = h11.ConnectionClosed()
|
||||
self.conn.send(event)
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
Called on a keep-alive connection if no new data is received after a short
|
||||
delay.
|
||||
"""
|
||||
if not self.transport.is_closing():
|
||||
event = h11.ConnectionClosed()
|
||||
self.conn.send(event)
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
conn: h11.Connection,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
on_response: Callable[..., None],
|
||||
) -> None:
|
||||
self.scope = scope
|
||||
self.conn = conn
|
||||
self.transport = transport
|
||||
self.flow = flow
|
||||
self.logger = logger
|
||||
self.access_logger = access_logger
|
||||
self.access_log = access_log
|
||||
self.default_headers = default_headers
|
||||
self.message_event = message_event
|
||||
self.on_response = on_response
|
||||
|
||||
# Connection state
|
||||
self.disconnected = False
|
||||
self.keep_alive = True
|
||||
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
|
||||
|
||||
# Request state
|
||||
self.body = bytearray()
|
||||
self.more_body = True
|
||||
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
)
|
||||
except BaseException as exc:
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
if not self.response_started:
|
||||
await self.send_500_response()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
elif not self.response_started and not self.disconnected:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
response_start_event: HTTPResponseStartEvent = {
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
await self.send(response_start_event)
|
||||
response_body_event: HTTPResponseBodyEvent = {
|
||||
"type": "http.response.body",
|
||||
"body": b"Internal Server Error",
|
||||
"more_body": False,
|
||||
}
|
||||
await self.send(response_body_event)
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message["type"] != "http.response.start":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.start', but got '{message['type']}'.")
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
status = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
|
||||
if self.access_log:
|
||||
self.access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
get_client_addr(self.scope),
|
||||
self.scope["method"],
|
||||
get_path_with_query_string(self.scope),
|
||||
self.scope["http_version"],
|
||||
status,
|
||||
)
|
||||
|
||||
# Write response status line and headers
|
||||
reason = STATUS_PHRASES[status]
|
||||
response = h11.Response(status_code=status, headers=headers, reason=reason)
|
||||
output = self.conn.send(event=response)
|
||||
self.transport.write(output)
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message["type"] != "http.response.body":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.body', but got '{message['type']}'.")
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
data = b"" if self.scope["method"] == "HEAD" else body
|
||||
output = self.conn.send(event=h11.Data(data=data))
|
||||
self.transport.write(output)
|
||||
|
||||
# Handle response completion
|
||||
if not more_body:
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
output = self.conn.send(event=h11.EndOfMessage())
|
||||
self.transport.write(output)
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
|
||||
|
||||
if self.response_complete:
|
||||
if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
|
||||
self.conn.send(event=h11.ConnectionClosed())
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
headers: list[tuple[str, str]] = []
|
||||
event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
|
||||
output = self.conn.send(event=event)
|
||||
self.transport.write(output)
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}
|
||||
self.body = bytearray()
|
||||
return message
|
||||
@@ -0,0 +1,576 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import http
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import urllib
|
||||
from asyncio.events import TimerHandle
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
|
||||
import httptools
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
HTTPRequestEvent,
|
||||
HTTPScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
|
||||
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:\\[\\]={} \t\\\\"]')
|
||||
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
|
||||
|
||||
|
||||
def _get_status_line(status_code: int) -> bytes:
|
||||
try:
|
||||
phrase = http.HTTPStatus(status_code).phrase.encode()
|
||||
except ValueError:
|
||||
phrase = b""
|
||||
return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
|
||||
|
||||
|
||||
STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
|
||||
|
||||
|
||||
class HttpToolsProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.access_logger = logging.getLogger("uvicorn.access")
|
||||
self.access_log = self.access_logger.hasHandlers()
|
||||
self.parser = httptools.HttpRequestParser(self)
|
||||
|
||||
try:
|
||||
# Enable dangerous leniencies to allow server to a response on the first request from a pipelined request.
|
||||
self.parser.set_dangerous_leniencies(lenient_data_after_close=True)
|
||||
except AttributeError: # pragma: no cover
|
||||
# httptools < 0.6.3
|
||||
pass
|
||||
|
||||
self.ws_protocol_class = config.ws_protocol_class
|
||||
self.root_path = config.root_path
|
||||
self.limit_concurrency = config.limit_concurrency
|
||||
self.app_state = app_state
|
||||
|
||||
# Timeouts
|
||||
self.timeout_keep_alive_task: TimerHandle | None = None
|
||||
self.timeout_keep_alive = config.timeout_keep_alive
|
||||
|
||||
# Global state
|
||||
self.server_state = server_state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Per-connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.flow: FlowControl = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int | None] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["http", "https"] | None = None
|
||||
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
|
||||
|
||||
# Per-request state
|
||||
self.scope: HTTPScope = None # type: ignore[assignment]
|
||||
self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
|
||||
self.expect_100_continue = False
|
||||
self.cycle: RequestResponseCycle = None # type: ignore[assignment]
|
||||
|
||||
# Protocol interface
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
|
||||
self.transport = transport
|
||||
self.flow = FlowControl(transport)
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "https" if is_ssl(transport) else "http"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.discard(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
|
||||
|
||||
if self.cycle and not self.cycle.response_complete:
|
||||
self.cycle.disconnected = True
|
||||
if self.cycle is not None:
|
||||
self.cycle.message_event.set()
|
||||
if self.flow is not None:
|
||||
self.flow.resume_writing()
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
self.parser = None # type: ignore[assignment]
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def _unset_keepalive_if_required(self) -> None:
|
||||
if self.timeout_keep_alive_task is not None:
|
||||
self.timeout_keep_alive_task.cancel()
|
||||
self.timeout_keep_alive_task = None
|
||||
|
||||
def _get_upgrade(self) -> bytes | None:
|
||||
connection = []
|
||||
upgrade = None
|
||||
for name, value in self.headers:
|
||||
if name == b"connection":
|
||||
connection = [token.lower().strip() for token in value.split(b",")]
|
||||
if name == b"upgrade":
|
||||
upgrade = value.lower()
|
||||
if b"upgrade" in connection:
|
||||
return upgrade
|
||||
return None # pragma: full coverage
|
||||
|
||||
def _should_upgrade_to_ws(self) -> bool:
|
||||
if self.ws_protocol_class is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _unsupported_upgrade_warning(self) -> None:
|
||||
self.logger.warning("Unsupported upgrade request.")
|
||||
if not self._should_upgrade_to_ws():
|
||||
msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
|
||||
self.logger.warning(msg)
|
||||
|
||||
def _should_upgrade(self) -> bool:
|
||||
upgrade = self._get_upgrade()
|
||||
return upgrade == b"websocket" and self._should_upgrade_to_ws()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
except httptools.HttpParserError:
|
||||
msg = "Invalid HTTP request received."
|
||||
self.logger.warning(msg)
|
||||
self.send_400_response(msg)
|
||||
return
|
||||
except httptools.HttpParserUpgrade:
|
||||
if self._should_upgrade():
|
||||
self.handle_websocket_upgrade()
|
||||
else:
|
||||
self._unsupported_upgrade_warning()
|
||||
|
||||
def handle_websocket_upgrade(self) -> None:
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
|
||||
|
||||
self.connections.discard(self)
|
||||
method = self.scope["method"].encode()
|
||||
output = [method, b" ", self.url, b" HTTP/1.1\r\n"]
|
||||
for name, value in self.scope["headers"]:
|
||||
output += [name, b": ", value, b"\r\n"]
|
||||
output.append(b"\r\n")
|
||||
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
|
||||
config=self.config,
|
||||
server_state=self.server_state,
|
||||
app_state=self.app_state,
|
||||
)
|
||||
protocol.connection_made(self.transport)
|
||||
protocol.data_received(b"".join(output))
|
||||
self.transport.set_protocol(protocol)
|
||||
|
||||
def send_400_response(self, msg: str) -> None:
|
||||
content = [STATUS_LINE[400]]
|
||||
for name, value in self.server_state.default_headers:
|
||||
content.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage
|
||||
content.extend(
|
||||
[
|
||||
b"content-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
msg.encode("ascii"),
|
||||
]
|
||||
)
|
||||
self.transport.write(b"".join(content))
|
||||
self.transport.close()
|
||||
|
||||
def on_message_begin(self) -> None:
|
||||
self.url = b""
|
||||
self.expect_100_continue = False
|
||||
self.headers = []
|
||||
self.scope = { # type: ignore[typeddict-item]
|
||||
"type": "http",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
|
||||
"http_version": "1.1",
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"scheme": self.scheme, # type: ignore[typeddict-item]
|
||||
"root_path": self.root_path,
|
||||
"headers": self.headers,
|
||||
"state": self.app_state.copy(),
|
||||
}
|
||||
|
||||
# Parser callbacks
|
||||
def on_url(self, url: bytes) -> None:
|
||||
self.url += url
|
||||
|
||||
def on_header(self, name: bytes, value: bytes) -> None:
|
||||
name = name.lower()
|
||||
if name == b"expect" and value.lower() == b"100-continue":
|
||||
self.expect_100_continue = True
|
||||
self.headers.append((name, value))
|
||||
|
||||
def on_headers_complete(self) -> None:
|
||||
http_version = self.parser.get_http_version()
|
||||
method = self.parser.get_method()
|
||||
self.scope["method"] = method.decode("ascii")
|
||||
if http_version != "1.1":
|
||||
self.scope["http_version"] = http_version
|
||||
if self.parser.should_upgrade() and self._should_upgrade():
|
||||
return
|
||||
parsed_url = httptools.parse_url(self.url)
|
||||
raw_path = parsed_url.path
|
||||
path = raw_path.decode("ascii")
|
||||
if "%" in path:
|
||||
path = urllib.parse.unquote(path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path
|
||||
self.scope["path"] = full_path
|
||||
self.scope["raw_path"] = full_raw_path
|
||||
self.scope["query_string"] = parsed_url.query or b""
|
||||
|
||||
# Handle 503 responses when 'limit_concurrency' is exceeded.
|
||||
if self.limit_concurrency is not None and (
|
||||
len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
|
||||
):
|
||||
app = service_unavailable
|
||||
message = "Exceeded concurrency limit."
|
||||
self.logger.warning(message)
|
||||
else:
|
||||
app = self.app
|
||||
|
||||
existing_cycle = self.cycle
|
||||
self.cycle = RequestResponseCycle(
|
||||
scope=self.scope,
|
||||
transport=self.transport,
|
||||
flow=self.flow,
|
||||
logger=self.logger,
|
||||
access_logger=self.access_logger,
|
||||
access_log=self.access_log,
|
||||
default_headers=self.server_state.default_headers,
|
||||
message_event=asyncio.Event(),
|
||||
expect_100_continue=self.expect_100_continue,
|
||||
keep_alive=http_version != "1.0",
|
||||
on_response=self.on_response_complete,
|
||||
)
|
||||
if existing_cycle is None or existing_cycle.response_complete:
|
||||
# Standard case - start processing the request.
|
||||
self._start_asgi_task(self.cycle, app)
|
||||
else:
|
||||
# Pipelined HTTP requests need to be queued up.
|
||||
self.flow.pause_reading()
|
||||
self.pipeline.appendleft((self.cycle, app))
|
||||
|
||||
def _start_asgi_task(self, cycle: RequestResponseCycle, app: ASGI3Application) -> None:
|
||||
if self.config.reset_contextvars:
|
||||
# Opt-in workaround for https://github.com/python/cpython/issues/140947:
|
||||
# asyncio can leak context vars between tasks. Hides context set in the
|
||||
# lifespan or by external instrumentation.
|
||||
if sys.version_info >= (3, 11): # pragma: py-lt-311
|
||||
task = self.loop.create_task(cycle.run_asgi(app), context=contextvars.Context())
|
||||
else: # pragma: py-gte-311
|
||||
task = contextvars.Context().run(self.loop.create_task, cycle.run_asgi(app))
|
||||
else:
|
||||
task = self.loop.create_task(cycle.run_asgi(app))
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
self.tasks.add(task)
|
||||
|
||||
def on_body(self, body: bytes) -> None:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.body += body
|
||||
if len(self.cycle.body) > HIGH_WATER_LIMIT:
|
||||
self.flow.pause_reading()
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_message_complete(self) -> None:
|
||||
if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
|
||||
return
|
||||
self.cycle.more_body = False
|
||||
self.cycle.message_event.set()
|
||||
|
||||
def on_response_complete(self) -> None:
|
||||
# Callback for pipelined HTTP requests to be started.
|
||||
self.server_state.total_requests += 1
|
||||
|
||||
if self.transport.is_closing():
|
||||
return
|
||||
|
||||
self._unset_keepalive_if_required()
|
||||
|
||||
# Unpause data reads if needed.
|
||||
self.flow.resume_reading()
|
||||
|
||||
# Unblock any pipelined events. If there are none, arm the
|
||||
# Keep-Alive timeout instead.
|
||||
if self.pipeline:
|
||||
cycle, app = self.pipeline.pop()
|
||||
self._start_asgi_task(cycle, app)
|
||||
else:
|
||||
self.timeout_keep_alive_task = self.loop.call_later(
|
||||
self.timeout_keep_alive, self.timeout_keep_alive_handler
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Called by the server to commence a graceful shutdown.
|
||||
"""
|
||||
if self.cycle is None or self.cycle.response_complete:
|
||||
self.transport.close()
|
||||
else:
|
||||
self.cycle.keep_alive = False
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.flow.pause_writing() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.flow.resume_writing() # pragma: full coverage
|
||||
|
||||
def timeout_keep_alive_handler(self) -> None:
|
||||
"""
|
||||
Called on a keep-alive connection if no new data is received after a short
|
||||
delay.
|
||||
"""
|
||||
if not self.transport.is_closing():
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class RequestResponseCycle:
|
||||
def __init__(
|
||||
self,
|
||||
scope: HTTPScope,
|
||||
transport: asyncio.Transport,
|
||||
flow: FlowControl,
|
||||
logger: logging.Logger,
|
||||
access_logger: logging.Logger,
|
||||
access_log: bool,
|
||||
default_headers: list[tuple[bytes, bytes]],
|
||||
message_event: asyncio.Event,
|
||||
expect_100_continue: bool,
|
||||
keep_alive: bool,
|
||||
on_response: Callable[..., None],
|
||||
):
|
||||
self.scope = scope
|
||||
self.transport = transport
|
||||
self.flow = flow
|
||||
self.logger = logger
|
||||
self.access_logger = access_logger
|
||||
self.access_log = access_log
|
||||
self.default_headers = default_headers
|
||||
self.message_event = message_event
|
||||
self.on_response = on_response
|
||||
|
||||
# Connection state
|
||||
self.disconnected = False
|
||||
self.keep_alive = keep_alive
|
||||
self.waiting_for_100_continue = expect_100_continue
|
||||
|
||||
# Request state
|
||||
self.body = bytearray()
|
||||
self.more_body = True
|
||||
|
||||
# Response state
|
||||
self.response_started = False
|
||||
self.response_complete = False
|
||||
self.chunked_encoding: bool | None = None
|
||||
self.expected_content_length = 0
|
||||
|
||||
# ASGI exception wrapper
|
||||
async def run_asgi(self, app: ASGI3Application) -> None:
|
||||
try:
|
||||
result = await app( # type: ignore[func-returns-value]
|
||||
self.scope, self.receive, self.send
|
||||
)
|
||||
except BaseException as exc:
|
||||
msg = "Exception in ASGI application\n"
|
||||
self.logger.error(msg, exc_info=exc)
|
||||
if not self.response_started:
|
||||
await self.send_500_response()
|
||||
else:
|
||||
self.transport.close()
|
||||
else:
|
||||
if result is not None:
|
||||
msg = "ASGI callable should return None, but returned '%s'."
|
||||
self.logger.error(msg, result)
|
||||
self.transport.close()
|
||||
elif not self.response_started and not self.disconnected:
|
||||
msg = "ASGI callable returned without starting response."
|
||||
self.logger.error(msg)
|
||||
await self.send_500_response()
|
||||
elif not self.response_complete and not self.disconnected:
|
||||
msg = "ASGI callable returned without completing response."
|
||||
self.logger.error(msg)
|
||||
self.transport.close()
|
||||
finally:
|
||||
self.on_response = lambda: None
|
||||
|
||||
async def send_500_response(self) -> None:
|
||||
await self.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"content-length", b"21"),
|
||||
(b"connection", b"close"),
|
||||
],
|
||||
}
|
||||
)
|
||||
await self.send({"type": "http.response.body", "body": b"Internal Server Error", "more_body": False})
|
||||
|
||||
# ASGI interface
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
if self.flow.write_paused and not self.disconnected:
|
||||
await self.flow.drain() # pragma: full coverage
|
||||
|
||||
if self.disconnected:
|
||||
return # pragma: full coverage
|
||||
|
||||
if not self.response_started:
|
||||
# Sending response status line and headers
|
||||
if message["type"] != "http.response.start":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.start', but got '{message['type']}'.")
|
||||
|
||||
self.response_started = True
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
status_code = message["status"]
|
||||
headers = self.default_headers + list(message.get("headers", []))
|
||||
|
||||
if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
|
||||
headers = headers + [CLOSE_HEADER]
|
||||
|
||||
if self.access_log:
|
||||
self.access_logger.info(
|
||||
'%s - "%s %s HTTP/%s" %d',
|
||||
get_client_addr(self.scope),
|
||||
self.scope["method"],
|
||||
get_path_with_query_string(self.scope),
|
||||
self.scope["http_version"],
|
||||
status_code,
|
||||
)
|
||||
|
||||
# Write response status line and headers
|
||||
content = [STATUS_LINE[status_code]]
|
||||
|
||||
for name, value in headers:
|
||||
if HEADER_RE.search(name):
|
||||
raise RuntimeError("Invalid HTTP header name.") # pragma: full coverage
|
||||
if HEADER_VALUE_RE.search(value):
|
||||
raise RuntimeError("Invalid HTTP header value.")
|
||||
|
||||
name = name.lower()
|
||||
if name == b"content-length" and self.chunked_encoding is None:
|
||||
self.expected_content_length = int(value.decode())
|
||||
self.chunked_encoding = False
|
||||
elif name == b"transfer-encoding" and value.lower() == b"chunked":
|
||||
self.expected_content_length = 0
|
||||
self.chunked_encoding = True
|
||||
elif name == b"connection" and value.lower() == b"close":
|
||||
self.keep_alive = False
|
||||
content.extend([name, b": ", value, b"\r\n"])
|
||||
|
||||
if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
|
||||
# Neither content-length nor transfer-encoding specified
|
||||
self.chunked_encoding = True
|
||||
content.append(b"transfer-encoding: chunked\r\n")
|
||||
|
||||
content.append(b"\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
|
||||
elif not self.response_complete:
|
||||
# Sending response body
|
||||
if message["type"] != "http.response.body":
|
||||
raise RuntimeError(f"Expected ASGI message 'http.response.body', but got '{message['type']}'.")
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Write response body
|
||||
if self.scope["method"] == "HEAD":
|
||||
self.expected_content_length = 0
|
||||
elif self.chunked_encoding:
|
||||
if body:
|
||||
content = [b"%x\r\n" % len(body), body, b"\r\n"]
|
||||
else:
|
||||
content = []
|
||||
if not more_body:
|
||||
content.append(b"0\r\n\r\n")
|
||||
self.transport.write(b"".join(content))
|
||||
else:
|
||||
num_bytes = len(body)
|
||||
if num_bytes > self.expected_content_length:
|
||||
raise RuntimeError("Response content longer than Content-Length")
|
||||
else:
|
||||
self.expected_content_length -= num_bytes
|
||||
self.transport.write(body)
|
||||
|
||||
# Handle response completion
|
||||
if not more_body:
|
||||
if self.expected_content_length != 0:
|
||||
raise RuntimeError("Response content shorter than Content-Length")
|
||||
self.response_complete = True
|
||||
self.message_event.set()
|
||||
if not self.keep_alive:
|
||||
self.transport.close()
|
||||
self.on_response()
|
||||
|
||||
else:
|
||||
# Response already sent
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}' sent, after response already completed.")
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
if self.waiting_for_100_continue and not self.transport.is_closing():
|
||||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
self.waiting_for_100_continue = False
|
||||
|
||||
if not self.disconnected and not self.response_complete:
|
||||
self.flow.resume_reading()
|
||||
await self.message_event.wait()
|
||||
self.message_event.clear()
|
||||
|
||||
if self.disconnected or self.response_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}
|
||||
self.body = bytearray()
|
||||
return message
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import urllib.parse
|
||||
|
||||
from uvicorn._types import WWWScope
|
||||
|
||||
|
||||
class ClientDisconnected(OSError): ...
|
||||
|
||||
|
||||
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
|
||||
socket_info: socket.socket | None = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
try:
|
||||
info = socket_info.getpeername()
|
||||
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
|
||||
except OSError: # pragma: no cover
|
||||
# This case appears to inconsistently occur with uvloop
|
||||
# bound to a unix domain socket.
|
||||
return None
|
||||
|
||||
info = transport.get_extra_info("peername")
|
||||
if info is not None and isinstance(info, list | tuple) and len(info) == 2:
|
||||
return (str(info[0]), int(info[1]))
|
||||
return None
|
||||
|
||||
|
||||
def get_local_addr(transport: asyncio.Transport) -> tuple[str, int | None] | None:
|
||||
socket_info: socket.socket | None = transport.get_extra_info("socket")
|
||||
if socket_info is not None:
|
||||
info = socket_info.getsockname()
|
||||
if isinstance(info, tuple):
|
||||
return (str(info[0]), int(info[1]))
|
||||
if isinstance(info, str):
|
||||
return (info, None)
|
||||
return None
|
||||
info = transport.get_extra_info("sockname")
|
||||
if info is not None and isinstance(info, list | tuple) and len(info) == 2:
|
||||
return (str(info[0]), int(info[1]))
|
||||
if isinstance(info, str):
|
||||
return (info, None)
|
||||
return None
|
||||
|
||||
|
||||
def is_ssl(transport: asyncio.Transport) -> bool:
|
||||
return bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
|
||||
def get_client_addr(scope: WWWScope) -> str:
|
||||
client = scope.get("client")
|
||||
if not client:
|
||||
return ""
|
||||
return "%s:%d" % client
|
||||
|
||||
|
||||
def get_path_with_query_string(scope: WWWScope) -> str:
|
||||
path_with_query_string = urllib.parse.quote(scope["path"])
|
||||
if scope["query_string"]:
|
||||
path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
|
||||
return path_with_query_string
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
|
||||
AutoWebSocketsProtocol: Callable[..., asyncio.Protocol] | None
|
||||
try:
|
||||
import websockets # noqa
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
import wsproto # noqa
|
||||
except ImportError:
|
||||
AutoWebSocketsProtocol = None
|
||||
else:
|
||||
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
|
||||
|
||||
AutoWebSocketsProtocol = WSProtocol
|
||||
else:
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
|
||||
AutoWebSocketsProtocol = WebSocketProtocol
|
||||
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import websockets
|
||||
import websockets.legacy.handshake
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.extensions.base import ServerExtensionFactory
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.legacy.server import HTTPResponse
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGI3Application,
|
||||
ASGISendEvent,
|
||||
WebSocketConnectEvent,
|
||||
WebSocketDisconnectEvent,
|
||||
WebSocketReceiveEvent,
|
||||
WebSocketScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class Server:
|
||||
closing = False
|
||||
|
||||
def register(self, ws: WebSocketServerProtocol) -> None:
|
||||
pass
|
||||
|
||||
def unregister(self, ws: WebSocketServerProtocol) -> None:
|
||||
pass
|
||||
|
||||
def is_serving(self) -> bool:
|
||||
return not self.closing
|
||||
|
||||
|
||||
class WebSocketProtocol(WebSocketServerProtocol):
|
||||
extra_headers: list[tuple[str, str]]
|
||||
logger: logging.Logger | logging.LoggerAdapter[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
):
|
||||
if not config.loaded:
|
||||
config.load()
|
||||
|
||||
self.config = config
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int | None] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# Connection events
|
||||
self.scope: WebSocketScope
|
||||
self.handshake_started_event = asyncio.Event()
|
||||
self.handshake_completed_event = asyncio.Event()
|
||||
self.closed_event = asyncio.Event()
|
||||
self.initial_response: HTTPResponse | None = None
|
||||
self.connect_sent = False
|
||||
self.lost_connection_before_handshake = False
|
||||
self.accepted_subprotocol: Subprotocol | None = None
|
||||
|
||||
self.ws_server: Server = Server() # type: ignore[assignment]
|
||||
|
||||
extensions: list[ServerExtensionFactory] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(ServerPerMessageDeflateFactory())
|
||||
|
||||
super().__init__(
|
||||
ws_handler=self.ws_handler,
|
||||
ws_server=self.ws_server, # type: ignore[arg-type]
|
||||
max_size=self.config.ws_max_size,
|
||||
max_queue=self.config.ws_max_queue,
|
||||
ping_interval=self.config.ws_ping_interval,
|
||||
ping_timeout=self.config.ws_ping_timeout,
|
||||
extensions=extensions,
|
||||
logger=logging.getLogger("uvicorn.error"),
|
||||
)
|
||||
self.server_header = None
|
||||
self.extra_headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
|
||||
]
|
||||
|
||||
def connection_made( # type: ignore[override]
|
||||
self, transport: asyncio.Transport
|
||||
) -> None:
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
super().connection_made(transport)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
|
||||
self.handshake_completed_event.set()
|
||||
super().connection_lost(exc)
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.ws_server.closing = True
|
||||
if self.handshake_completed_event.is_set():
|
||||
self.fail_connection(1012)
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
|
||||
"""
|
||||
This hook is called to determine if the websocket should return
|
||||
an HTTP response and close.
|
||||
|
||||
Our behavior here is to start the ASGI application, and then wait
|
||||
for either `accept` or `close` in order to determine if we should
|
||||
close the connection.
|
||||
"""
|
||||
path_portion, _, query_string = path.partition("?")
|
||||
|
||||
websockets.legacy.handshake.check_request(request_headers)
|
||||
|
||||
subprotocols: list[str] = []
|
||||
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
|
||||
subprotocols.extend([token.strip() for token in header.split(",")])
|
||||
|
||||
asgi_headers = [
|
||||
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for name, value in request_headers.raw_items()
|
||||
]
|
||||
path = unquote(path_portion)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")
|
||||
|
||||
self.scope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": asgi_headers,
|
||||
"subprotocols": subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
await self.handshake_started_event.wait()
|
||||
return self.initial_response
|
||||
|
||||
def process_subprotocol(
|
||||
self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
We override the standard 'process_subprotocol' behavior here so that
|
||||
we return whatever subprotocol is sent in the 'accept' message.
|
||||
"""
|
||||
return self.accepted_subprotocol
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
msg = b"Internal Server Error"
|
||||
content = [
|
||||
b"HTTP/1.1 500 Internal Server Error\r\ncontent-type: text/plain; charset=utf-8\r\n",
|
||||
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
|
||||
b"connection: close\r\n",
|
||||
b"\r\n",
|
||||
msg,
|
||||
]
|
||||
self.transport.write(b"".join(content))
|
||||
# Allow handler task to terminate cleanly, as websockets doesn't cancel it by
|
||||
# itself (see https://github.com/Kludex/uvicorn/issues/920)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
async def ws_handler(self, protocol: WebSocketServerProtocol, path: str) -> Any: # type: ignore[override]
|
||||
"""
|
||||
This is the main handler function for the 'websockets' implementation
|
||||
to call into. We just wait for close then return, and instead allow
|
||||
'send' and 'receive' events to drive the flow.
|
||||
"""
|
||||
self.handshake_completed_event.set()
|
||||
await self.wait_closed()
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
"""
|
||||
Wrapper around the ASGI callable, handling exceptions and unexpected
|
||||
termination states.
|
||||
"""
|
||||
try:
|
||||
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected: # pragma: full coverage
|
||||
self.closed_event.set()
|
||||
except BaseException:
|
||||
self.closed_event.set()
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.send_500_response()
|
||||
else:
|
||||
await self.handshake_completed_event.wait()
|
||||
else:
|
||||
self.closed_event.set()
|
||||
if not self.handshake_started_event.is_set():
|
||||
self.logger.error("ASGI callable returned without sending handshake.")
|
||||
self.send_500_response()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
await self.handshake_completed_event.wait()
|
||||
self.transport.close()
|
||||
|
||||
async def asgi_send(self, message: ASGISendEvent) -> None:
|
||||
if not self.handshake_started_event.is_set():
|
||||
if message["type"] == "websocket.accept":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = None
|
||||
self.accepted_subprotocol = cast(Subprotocol | None, message.get("subprotocol"))
|
||||
if "headers" in message:
|
||||
self.extra_headers.extend(
|
||||
# ASGI spec requires bytes
|
||||
# But for compatibility we need to convert it to strings
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in message["headers"]
|
||||
)
|
||||
self.handshake_started_event.set()
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
|
||||
self.handshake_started_event.set()
|
||||
self.closed_event.set()
|
||||
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
# websockets requires the status to be an enum. look it up.
|
||||
status = http.HTTPStatus(message["status"])
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
|
||||
]
|
||||
self.initial_response = (status, headers, b"")
|
||||
self.handshake_started_event.set()
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close', "
|
||||
f"or 'websocket.http.response.start' but got '{message['type']}'."
|
||||
)
|
||||
|
||||
elif not self.closed_event.is_set() and self.initial_response is None:
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
try:
|
||||
if message["type"] == "websocket.send":
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
await self.send(data) # type: ignore[arg-type]
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
await self.close(code, reason)
|
||||
self.closed_event.set()
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
|
||||
)
|
||||
except ConnectionClosed as exc:
|
||||
raise ClientDisconnected from exc
|
||||
|
||||
elif self.initial_response is not None:
|
||||
if message["type"] == "websocket.http.response.body":
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
self.closed_event.set()
|
||||
else:
|
||||
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close' "
|
||||
"or response already completed."
|
||||
)
|
||||
|
||||
async def asgi_receive(self) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
|
||||
if not self.connect_sent:
|
||||
self.connect_sent = True
|
||||
return {"type": "websocket.connect"}
|
||||
|
||||
await self.handshake_completed_event.wait()
|
||||
|
||||
if self.lost_connection_before_handshake:
|
||||
# If the handshake failed or the app closed before handshake completion,
|
||||
# use 1006 Abnormal Closure.
|
||||
return {"type": "websocket.disconnect", "code": 1006}
|
||||
|
||||
if self.closed_event.is_set():
|
||||
return {"type": "websocket.disconnect", "code": 1005}
|
||||
|
||||
try:
|
||||
data = await self.recv()
|
||||
except ConnectionClosed:
|
||||
self.closed_event.set()
|
||||
if self.ws_server.closing:
|
||||
return {"type": "websocket.disconnect", "code": 1012}
|
||||
return {"type": "websocket.disconnect", "code": self.close_code or 1005, "reason": self.close_reason}
|
||||
|
||||
if isinstance(data, str):
|
||||
return {"type": "websocket.receive", "text": data}
|
||||
return {"type": "websocket.receive", "bytes": data}
|
||||
+477
@@ -0,0 +1,477 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import struct
|
||||
import sys
|
||||
from asyncio import TimerHandle
|
||||
from asyncio.transports import BaseTransport, Transport
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
from websockets.exceptions import InvalidState
|
||||
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
||||
from websockets.frames import Frame, Opcode
|
||||
from websockets.http11 import Request
|
||||
from websockets.server import ServerProtocol
|
||||
|
||||
from uvicorn._types import (
|
||||
ASGIReceiveEvent,
|
||||
ASGISendEvent,
|
||||
WebSocketScope,
|
||||
)
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
if sys.version_info >= (3, 11): # pragma: no cover
|
||||
from typing import assert_never
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import assert_never
|
||||
|
||||
|
||||
class WebSocketsSansIOProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load() # pragma: no cover
|
||||
|
||||
self.config = config
|
||||
self.app = config.loaded_app
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
self.default_headers = server_state.default_headers
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int | None] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
|
||||
self.handshake_initiated = False
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None
|
||||
|
||||
extensions = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions = [
|
||||
ServerPerMessageDeflateFactory(
|
||||
server_max_window_bits=12,
|
||||
client_max_window_bits=12,
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
self.conn = ServerProtocol(
|
||||
extensions=extensions,
|
||||
max_size=self.config.ws_max_size,
|
||||
logger=logging.getLogger("uvicorn.error"),
|
||||
)
|
||||
|
||||
self.read_paused = False
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Keepalive state
|
||||
self.ping_interval = config.ws_ping_interval
|
||||
self.ping_timeout = config.ws_ping_timeout
|
||||
self.ping_timer: TimerHandle | None = None
|
||||
self.pong_timer: TimerHandle | None = None
|
||||
self.pending_ping_payload: bytes | None = None
|
||||
self.ping_sent_at: float = 0.0
|
||||
self.last_ping_rtt: float = 0.0
|
||||
|
||||
# Buffers
|
||||
self.bytes = bytearray()
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called when a connection is made."""
|
||||
transport = cast(Transport, transport)
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.stop_keepalive()
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.handshake_complete = True
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.stop_keepalive()
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
self.conn.send_close(1012)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self.conn.receive_data(data)
|
||||
if self.conn.parser_exc is not None: # pragma: no cover
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
for event in self.conn.events_received():
|
||||
if isinstance(event, Request):
|
||||
self.handle_connect(event)
|
||||
if isinstance(event, Frame):
|
||||
if event.opcode == Opcode.CONT:
|
||||
self.handle_cont(event) # pragma: no cover
|
||||
elif event.opcode == Opcode.TEXT:
|
||||
self.handle_text(event)
|
||||
elif event.opcode == Opcode.BINARY:
|
||||
self.handle_bytes(event)
|
||||
elif event.opcode == Opcode.PING:
|
||||
self.handle_ping()
|
||||
elif event.opcode == Opcode.PONG:
|
||||
self.handle_pong(event)
|
||||
elif event.opcode == Opcode.CLOSE:
|
||||
self.handle_close(event)
|
||||
else:
|
||||
assert_never(event.opcode) # pragma: no cover
|
||||
|
||||
# Event handlers
|
||||
|
||||
def handle_connect(self, event: Request) -> None:
|
||||
self.request = event
|
||||
self.response = self.conn.accept(event)
|
||||
self.handshake_initiated = True
|
||||
if self.response.status_code != 101:
|
||||
self.handshake_complete = True
|
||||
self.close_sent = True
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
return
|
||||
|
||||
headers = [
|
||||
(key.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
|
||||
for key, value in event.headers.raw_items()
|
||||
]
|
||||
raw_path, _, query_string = event.path.partition("?")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": self.root_path + unquote(raw_path),
|
||||
"raw_path": self.root_path.encode("ascii") + raw_path.encode("ascii"),
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_cont(self, event: Frame) -> None:
|
||||
self.bytes.extend(event.data)
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_text(self, event: Frame) -> None:
|
||||
self.bytes = bytearray(event.data)
|
||||
self.curr_msg_data_type: Literal["text", "bytes"] = "text"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def handle_bytes(self, event: Frame) -> None:
|
||||
self.bytes = bytearray(event.data)
|
||||
self.curr_msg_data_type = "bytes"
|
||||
if event.fin:
|
||||
self.send_receive_event_to_app()
|
||||
|
||||
def send_receive_event_to_app(self) -> None:
|
||||
if self.curr_msg_data_type == "text":
|
||||
try:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "text": self.bytes.decode()})
|
||||
except UnicodeDecodeError: # pragma: no cover
|
||||
self.logger.exception("Invalid UTF-8 sequence received from client.")
|
||||
self.conn.send_close(1007)
|
||||
self.handle_parser_exception()
|
||||
return
|
||||
else:
|
||||
self.queue.put_nowait({"type": "websocket.receive", "bytes": bytes(self.bytes)})
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_ping(self) -> None:
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
def handle_pong(self, event: Frame) -> None:
|
||||
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight
|
||||
if self.pending_ping_payload is None or bytes(event.data) != self.pending_ping_payload:
|
||||
return # pragma: no cover
|
||||
|
||||
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
|
||||
self.pending_ping_payload = None
|
||||
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
|
||||
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
|
||||
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
|
||||
if self.pong_timer is not None:
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.schedule_ping()
|
||||
|
||||
def start_keepalive(self) -> None:
|
||||
if self.ping_interval is not None and self.ping_interval > 0:
|
||||
self.schedule_ping()
|
||||
|
||||
def stop_keepalive(self) -> None:
|
||||
if self.ping_timer is not None:
|
||||
self.ping_timer.cancel()
|
||||
self.ping_timer = None
|
||||
if self.pong_timer is not None: # pragma: no cover
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
|
||||
def schedule_ping(self) -> None:
|
||||
assert self.ping_interval is not None
|
||||
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
|
||||
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
|
||||
|
||||
def send_keepalive_ping(self) -> None:
|
||||
self.ping_timer = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
|
||||
# See https://github.com/python-websockets/websockets/blob/4d229bf9f583d593aa103287aee0a77c9fbc3a79/src/websockets/asyncio/connection.py#L624
|
||||
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
|
||||
self.ping_sent_at = self.loop.time()
|
||||
self.conn.send_ping(self.pending_ping_payload)
|
||||
self.transport.write(b"".join(self.conn.data_to_send()))
|
||||
if self.ping_timeout is not None:
|
||||
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
|
||||
else: # pragma: no cover
|
||||
self.schedule_ping()
|
||||
|
||||
def keepalive_timeout(self) -> None:
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
|
||||
self.conn.fail(1011, "keepalive ping timeout")
|
||||
self.transport.write(b"".join(self.conn.data_to_send()))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def handle_close(self, event: Frame) -> None:
|
||||
if not self.close_sent and not self.transport.is_closing():
|
||||
assert self.conn.close_rcvd is not None
|
||||
code = self.conn.close_rcvd.code
|
||||
reason = self.conn.close_rcvd.reason
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
|
||||
def handle_parser_exception(self) -> None: # pragma: no cover
|
||||
assert self.conn.close_sent is not None
|
||||
code = self.conn.close_sent.code
|
||||
reason = self.conn.close_sent.reason
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send)
|
||||
except ClientDisconnected:
|
||||
pass # pragma: full coverage
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.initial_response or self.handshake_complete:
|
||||
return
|
||||
response = self.conn.reject(500, "Internal Server Error")
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
if not self.handshake_complete and self.initial_response is None:
|
||||
if message["type"] == "websocket.accept":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
headers = [
|
||||
(name.decode("latin-1").lower(), value.decode("latin-1"))
|
||||
for name, value in (self.default_headers + list(message.get("headers", [])))
|
||||
]
|
||||
accepted_subprotocol = message.get("subprotocol")
|
||||
if accepted_subprotocol:
|
||||
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))
|
||||
self.response.headers.update(headers)
|
||||
|
||||
if not self.transport.is_closing():
|
||||
self.handshake_complete = True
|
||||
self.conn.send_response(self.response)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.start_keepalive()
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
response = self.conn.reject(HTTPStatus.FORBIDDEN, "")
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.close_sent = True
|
||||
self.handshake_complete = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
elif message["type"] == "websocket.http.response.start" and self.initial_response is None:
|
||||
if not (100 <= message["status"] < 600):
|
||||
raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
headers = [
|
||||
(name.decode("latin-1"), value.decode("latin-1"))
|
||||
for name, value in list(message.get("headers", []))
|
||||
]
|
||||
self.initial_response = (message["status"], headers, b"")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
f"or 'websocket.http.response.start' but got '{message['type']}'."
|
||||
)
|
||||
|
||||
elif not self.close_sent and self.initial_response is None:
|
||||
try:
|
||||
if message["type"] == "websocket.send":
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
if bytes_data is not None:
|
||||
self.conn.send_binary(bytes_data)
|
||||
elif text_data is not None:
|
||||
self.conn.send_text(text_data.encode())
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
if not self.transport.is_closing():
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
self.conn.send_close(code, reason)
|
||||
output = self.conn.data_to_send()
|
||||
self.transport.write(b"".join(output))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
|
||||
)
|
||||
except InvalidState:
|
||||
raise ClientDisconnected()
|
||||
elif self.initial_response is not None:
|
||||
if message["type"] == "websocket.http.response.body":
|
||||
body = self.initial_response[2] + message["body"]
|
||||
self.initial_response = self.initial_response[:2] + (body,)
|
||||
if not message.get("more_body", False):
|
||||
response = self.conn.reject(self.initial_response[0], body.decode())
|
||||
response.headers.update(self.initial_response[1])
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.conn.send_response(response)
|
||||
output = self.conn.data_to_send()
|
||||
self.close_sent = True
|
||||
self.transport.write(b"".join(output))
|
||||
self.transport.close()
|
||||
else: # pragma: no cover
|
||||
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close'.")
|
||||
|
||||
async def receive(self) -> ASGIReceiveEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
self.transport.resume_reading()
|
||||
return message
|
||||
@@ -0,0 +1,458 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import struct
|
||||
from asyncio import TimerHandle
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any, Literal, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import wsproto
|
||||
from wsproto import ConnectionType, events
|
||||
from wsproto.connection import ConnectionState
|
||||
from wsproto.extensions import Extension, PerMessageDeflate
|
||||
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
|
||||
|
||||
from uvicorn._types import ASGI3Application, ASGISendEvent, WebSocketEvent, WebSocketReceiveEvent, WebSocketScope
|
||||
from uvicorn.config import Config
|
||||
from uvicorn.logging import TRACE_LOG_LEVEL
|
||||
from uvicorn.protocols.utils import (
|
||||
ClientDisconnected,
|
||||
get_client_addr,
|
||||
get_local_addr,
|
||||
get_path_with_query_string,
|
||||
get_remote_addr,
|
||||
is_ssl,
|
||||
)
|
||||
from uvicorn.server import ServerState
|
||||
|
||||
|
||||
class FrameTooLargeError(Exception):
|
||||
"""Raised when accumulated websocket message bytes exceed `ws_max_size`."""
|
||||
|
||||
|
||||
class WebsocketBuffer:
|
||||
def __init__(self, max_length: int) -> None:
|
||||
self.value: BytesIO | StringIO | None = None
|
||||
self.length = 0
|
||||
self.max_length = max_length
|
||||
|
||||
def extend(self, event: events.TextMessage | events.BytesMessage) -> None:
|
||||
if self.value is None:
|
||||
self.value = StringIO() if isinstance(event, events.TextMessage) else BytesIO()
|
||||
self.value.write(event.data) # type: ignore[arg-type]
|
||||
# `ws_max_size` is a byte budget, so count UTF-8 bytes for text.
|
||||
self.length += len(event.data.encode()) if isinstance(event, events.TextMessage) else len(event.data)
|
||||
if self.length > self.max_length:
|
||||
raise FrameTooLargeError
|
||||
|
||||
def clear(self) -> None:
|
||||
self.value = None
|
||||
self.length = 0
|
||||
|
||||
def to_message(self) -> WebSocketReceiveEvent:
|
||||
if isinstance(self.value, StringIO):
|
||||
return {"type": "websocket.receive", "text": self.value.getvalue()}
|
||||
assert isinstance(self.value, BytesIO)
|
||||
return {"type": "websocket.receive", "bytes": self.value.getvalue()}
|
||||
|
||||
|
||||
class WSProtocol(asyncio.Protocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
server_state: ServerState,
|
||||
app_state: dict[str, Any],
|
||||
_loop: asyncio.AbstractEventLoop | None = None,
|
||||
) -> None:
|
||||
if not config.loaded:
|
||||
config.load() # pragma: full coverage
|
||||
|
||||
self.config = config
|
||||
self.app = cast(ASGI3Application, config.loaded_app)
|
||||
self.loop = _loop or asyncio.get_event_loop()
|
||||
self.logger = logging.getLogger("uvicorn.error")
|
||||
self.root_path = config.root_path
|
||||
self.app_state = app_state
|
||||
|
||||
# Shared server state
|
||||
self.connections = server_state.connections
|
||||
self.tasks = server_state.tasks
|
||||
self.default_headers = server_state.default_headers
|
||||
|
||||
# Connection state
|
||||
self.transport: asyncio.Transport = None # type: ignore[assignment]
|
||||
self.server: tuple[str, int | None] | None = None
|
||||
self.client: tuple[str, int] | None = None
|
||||
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
|
||||
|
||||
# WebSocket state
|
||||
self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue()
|
||||
self.handshake_complete = False
|
||||
self.close_sent = False
|
||||
|
||||
# Rejection state
|
||||
self.response_started = False
|
||||
|
||||
self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)
|
||||
|
||||
self.read_paused = False
|
||||
self.writable = asyncio.Event()
|
||||
self.writable.set()
|
||||
|
||||
# Keepalive state
|
||||
self.ping_interval = config.ws_ping_interval
|
||||
self.ping_timeout = config.ws_ping_timeout
|
||||
self.ping_timer: TimerHandle | None = None
|
||||
self.pong_timer: TimerHandle | None = None
|
||||
self.pending_ping_payload: bytes | None = None
|
||||
self.ping_sent_at: float = 0.0
|
||||
self.last_ping_rtt: float = 0.0
|
||||
|
||||
# Buffer
|
||||
self.buffer = WebsocketBuffer(self.config.ws_max_size)
|
||||
|
||||
# Protocol interface
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore[override]
|
||||
self.connections.add(self)
|
||||
self.transport = transport
|
||||
self.server = get_local_addr(transport)
|
||||
self.client = get_remote_addr(transport)
|
||||
self.scheme = "wss" if is_ssl(transport) else "ws"
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
|
||||
|
||||
def connection_lost(self, exc: Exception | None) -> None:
|
||||
self.stop_keepalive()
|
||||
code = 1005 if self.handshake_complete else 1006
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
|
||||
self.connections.remove(self)
|
||||
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
|
||||
|
||||
self.handshake_complete = True
|
||||
if exc is None:
|
||||
self.transport.close()
|
||||
|
||||
def eof_received(self) -> None:
|
||||
pass
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
try:
|
||||
self.conn.receive_data(data)
|
||||
except RemoteProtocolError as err:
|
||||
# TODO: Remove `type: ignore` when wsproto fixes the type annotation.
|
||||
self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type] # noqa: E501
|
||||
self.transport.close()
|
||||
else:
|
||||
self.handle_events()
|
||||
|
||||
def handle_events(self) -> None:
|
||||
for event in self.conn.events():
|
||||
if self.close_sent:
|
||||
return
|
||||
if isinstance(event, events.Request):
|
||||
self.handle_connect(event)
|
||||
elif isinstance(event, (events.TextMessage, events.BytesMessage)):
|
||||
self.handle_message(event)
|
||||
elif isinstance(event, events.CloseConnection):
|
||||
self.handle_close(event)
|
||||
elif isinstance(event, events.Ping):
|
||||
self.handle_ping(event)
|
||||
elif isinstance(event, events.Pong):
|
||||
self.handle_pong(event)
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer exceeds the high water mark.
|
||||
"""
|
||||
self.writable.clear() # pragma: full coverage
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
"""
|
||||
Called by the transport when the write buffer drops below the low water mark.
|
||||
"""
|
||||
self.writable.set() # pragma: full coverage
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.stop_keepalive()
|
||||
if self.handshake_complete:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
|
||||
self.transport.write(output)
|
||||
else:
|
||||
self.send_500_response()
|
||||
self.transport.close()
|
||||
|
||||
def on_task_complete(self, task: asyncio.Task[None]) -> None:
|
||||
self.tasks.discard(task)
|
||||
|
||||
# Event handlers
|
||||
|
||||
def handle_connect(self, event: events.Request) -> None:
|
||||
headers = [(b"host", event.host.encode())]
|
||||
headers += [(key.lower(), value) for key, value in event.extra_headers]
|
||||
raw_path, _, query_string = event.target.partition("?")
|
||||
path = unquote(raw_path)
|
||||
full_path = self.root_path + path
|
||||
full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
|
||||
self.scope: WebSocketScope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
|
||||
"http_version": "1.1",
|
||||
"scheme": self.scheme,
|
||||
"server": self.server,
|
||||
"client": self.client,
|
||||
"root_path": self.root_path,
|
||||
"path": full_path,
|
||||
"raw_path": full_raw_path,
|
||||
"query_string": query_string.encode("ascii"),
|
||||
"headers": headers,
|
||||
"subprotocols": event.subprotocols,
|
||||
"state": self.app_state.copy(),
|
||||
"extensions": {"websocket.http.response": {}},
|
||||
}
|
||||
self.queue.put_nowait({"type": "websocket.connect"})
|
||||
task = self.loop.create_task(self.run_asgi())
|
||||
task.add_done_callback(self.on_task_complete)
|
||||
self.tasks.add(task)
|
||||
|
||||
def handle_message(self, event: events.TextMessage | events.BytesMessage) -> None:
|
||||
try:
|
||||
self.buffer.extend(event)
|
||||
except FrameTooLargeError:
|
||||
self.close_sent = True
|
||||
reason = f"Message exceeds the maximum size ({self.config.ws_max_size} bytes)"
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1009, "reason": reason})
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1009, reason=reason)))
|
||||
self.transport.close()
|
||||
return
|
||||
if event.message_finished:
|
||||
self.queue.put_nowait(self.buffer.to_message())
|
||||
self.buffer.clear()
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self.transport.pause_reading()
|
||||
|
||||
def handle_close(self, event: events.CloseConnection) -> None:
|
||||
if self.conn.state == ConnectionState.REMOTE_CLOSING:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code, "reason": event.reason})
|
||||
self.transport.close()
|
||||
|
||||
def handle_ping(self, event: events.Ping) -> None:
|
||||
self.transport.write(self.conn.send(event.response()))
|
||||
|
||||
def handle_pong(self, event: events.Pong) -> None:
|
||||
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight.
|
||||
if self.pending_ping_payload is None or bytes(event.payload) != self.pending_ping_payload:
|
||||
return # pragma: no cover
|
||||
|
||||
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
|
||||
self.pending_ping_payload = None
|
||||
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
|
||||
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
|
||||
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
|
||||
if self.pong_timer is not None:
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.schedule_ping()
|
||||
|
||||
def start_keepalive(self) -> None:
|
||||
if self.ping_interval is not None and self.ping_interval > 0:
|
||||
self.schedule_ping()
|
||||
|
||||
def stop_keepalive(self) -> None:
|
||||
if self.ping_timer is not None:
|
||||
self.ping_timer.cancel()
|
||||
self.ping_timer = None
|
||||
if self.pong_timer is not None: # pragma: no cover
|
||||
self.pong_timer.cancel()
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
|
||||
def schedule_ping(self) -> None:
|
||||
assert self.ping_interval is not None
|
||||
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
|
||||
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
|
||||
|
||||
def send_keepalive_ping(self) -> None:
|
||||
self.ping_timer = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
|
||||
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
|
||||
self.ping_sent_at = self.loop.time()
|
||||
self.transport.write(self.conn.send(wsproto.events.Ping(payload=self.pending_ping_payload)))
|
||||
if self.ping_timeout is not None:
|
||||
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
|
||||
else: # pragma: no cover
|
||||
self.schedule_ping()
|
||||
|
||||
def keepalive_timeout(self) -> None:
|
||||
self.pong_timer = None
|
||||
self.pending_ping_payload = None
|
||||
if self.close_sent or self.transport.is_closing(): # pragma: no cover
|
||||
return
|
||||
if self.logger.level <= TRACE_LOG_LEVEL:
|
||||
prefix = "%s:%d - " % self.client if self.client else ""
|
||||
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
|
||||
reason = "keepalive ping timeout"
|
||||
self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1011, reason=reason)))
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
def send_500_response(self) -> None:
|
||||
if self.response_started or self.handshake_complete:
|
||||
return # we cannot send responses anymore
|
||||
headers: list[tuple[bytes, bytes]] = [
|
||||
(b"content-type", b"text/plain; charset=utf-8"),
|
||||
(b"connection", b"close"),
|
||||
(b"content-length", b"21"),
|
||||
]
|
||||
output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
|
||||
output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
|
||||
self.transport.write(output)
|
||||
|
||||
async def run_asgi(self) -> None:
|
||||
try:
|
||||
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
|
||||
except ClientDisconnected:
|
||||
pass # pragma: full coverage
|
||||
except BaseException:
|
||||
self.logger.exception("Exception in ASGI application\n")
|
||||
self.send_500_response()
|
||||
else:
|
||||
if not self.handshake_complete:
|
||||
self.logger.error("ASGI callable returned without completing handshake.")
|
||||
self.send_500_response()
|
||||
elif result is not None:
|
||||
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
|
||||
self.transport.close()
|
||||
|
||||
async def send(self, message: ASGISendEvent) -> None:
|
||||
await self.writable.wait()
|
||||
|
||||
if not self.handshake_complete:
|
||||
if message["type"] == "websocket.accept":
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" [accepted]',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
subprotocol = message.get("subprotocol")
|
||||
extra_headers = self.default_headers + list(message.get("headers", []))
|
||||
extensions: list[Extension] = []
|
||||
if self.config.ws_per_message_deflate:
|
||||
extensions.append(PerMessageDeflate())
|
||||
if not self.transport.is_closing():
|
||||
self.handshake_complete = True
|
||||
output = self.conn.send(
|
||||
wsproto.events.AcceptConnection(
|
||||
subprotocol=subprotocol,
|
||||
extensions=extensions,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
)
|
||||
self.transport.write(output)
|
||||
self.start_keepalive()
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" 403',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
)
|
||||
self.handshake_complete = True
|
||||
self.close_sent = True
|
||||
event = events.RejectConnection(status_code=403, headers=[])
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
elif message["type"] == "websocket.http.response.start":
|
||||
# ensure status code is in the valid range
|
||||
if not (100 <= message["status"] < 600):
|
||||
msg = "Invalid HTTP status code '%d' in response."
|
||||
raise RuntimeError(msg % message["status"])
|
||||
self.logger.info(
|
||||
'%s - "WebSocket %s" %d',
|
||||
get_client_addr(self.scope),
|
||||
get_path_with_query_string(self.scope),
|
||||
message["status"],
|
||||
)
|
||||
self.handshake_complete = True
|
||||
event = events.RejectConnection(
|
||||
status_code=message["status"],
|
||||
headers=list(message["headers"]),
|
||||
has_body=True,
|
||||
)
|
||||
output = self.conn.send(event)
|
||||
self.transport.write(output)
|
||||
self.response_started = True
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Expected ASGI message 'websocket.accept', 'websocket.close' "
|
||||
f"or 'websocket.http.response.start' but got '{message['type']}'."
|
||||
)
|
||||
|
||||
elif not self.close_sent and not self.response_started:
|
||||
try:
|
||||
if message["type"] == "websocket.send":
|
||||
bytes_data = message.get("bytes")
|
||||
text_data = message.get("text")
|
||||
data = text_data if bytes_data is None else bytes_data
|
||||
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
|
||||
elif message["type"] == "websocket.close":
|
||||
self.close_sent = True
|
||||
code = message.get("code", 1000)
|
||||
reason = message.get("reason", "") or ""
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": code, "reason": reason})
|
||||
output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
|
||||
if not self.transport.is_closing():
|
||||
self.transport.write(output)
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Expected ASGI message 'websocket.send' or 'websocket.close', but got '{message['type']}'."
|
||||
)
|
||||
except LocalProtocolError as exc:
|
||||
raise ClientDisconnected from exc
|
||||
elif self.response_started:
|
||||
if message["type"] == "websocket.http.response.body":
|
||||
body_finished = not message.get("more_body", False)
|
||||
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
|
||||
output = self.conn.send(reject_data)
|
||||
self.transport.write(output)
|
||||
|
||||
if body_finished:
|
||||
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
|
||||
self.close_sent = True
|
||||
self.transport.close()
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Expected ASGI message 'websocket.http.response.body' but got '{message['type']}'.")
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected ASGI message '{message['type']}', after sending 'websocket.close'.")
|
||||
|
||||
async def receive(self) -> WebSocketEvent:
|
||||
message = await self.queue.get()
|
||||
if self.read_paused and self.queue.empty():
|
||||
self.read_paused = False
|
||||
self.transport.resume_reading()
|
||||
return message
|
||||
Reference in New Issue
Block a user