On branch DiscordProfile
Initial commit
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
discord.voice
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
Voice support for the Discord API.
|
||||
|
||||
:copyright: (c) 2015-2021 Rapptz & 2021-present Pycord Development
|
||||
:license: MIT, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from ..errors import MissingVoiceDependenciesError
|
||||
from ..utils import get_missing_voice_dependencies
|
||||
|
||||
if missing := get_missing_voice_dependencies():
|
||||
raise MissingVoiceDependenciesError(missing=missing)
|
||||
|
||||
from ._types import *
|
||||
from .client import *
|
||||
from .packets import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from discord import abc
|
||||
from discord.client import Client
|
||||
from discord.raw_models import (
|
||||
RawVoiceServerUpdateEvent,
|
||||
RawVoiceStateUpdateEvent,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
ClientT = TypeVar("ClientT", bound="Client", covariant=True)
|
||||
|
||||
__all__ = ("VoiceProtocol",)
|
||||
|
||||
|
||||
class VoiceProtocol(Generic[ClientT]):
|
||||
"""A class that represents the Discord voice protocol.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you are an end user, you **should not construct this manually** but instead
|
||||
take it from the return type in :meth:`abc.Connectable.connect <VoiceChannel.connect>`.
|
||||
The parameters and methods being documented here is so third party libraries can refer to it
|
||||
when implementing their own VoiceProtocol types.
|
||||
|
||||
This is an abstract class. The library provides a concrete implementation
|
||||
under :class:`VoiceClient`.
|
||||
|
||||
This class allows you to implement a protocol to allow for an external
|
||||
method of sending voice, such as Lavalink_ or a native library implementation.
|
||||
|
||||
These classes are passed to :meth:`abc.Connectable.connect <VoiceChannel.connect>`.
|
||||
|
||||
.. _Lavalink: https://github.com/freyacodes/Lavalink
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client: :class:`Client`
|
||||
The client (or its subclasses) that started the connection request.
|
||||
channel: :class:`abc.Connectable`
|
||||
The voice channel that is being connected to.
|
||||
"""
|
||||
|
||||
def __init__(self, client: ClientT, channel: abc.Connectable) -> None:
|
||||
self.client: ClientT = client
|
||||
self.channel: abc.Connectable = channel
|
||||
|
||||
async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None:
|
||||
"""|coro|
|
||||
|
||||
A method called when the client's voice state has changed. This corresponds
|
||||
to the ``VOICE_STATE_UPDATE`` event.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: :class:`RawVoiceStateUpdateEvent`
|
||||
The voice state payload.
|
||||
|
||||
.. versionchanged:: 2.7
|
||||
This now gets passed a `RawVoiceStateUpdateEvent` object instead of a :class:`dict`, but
|
||||
accessing keys via ``data[key]`` or ``data.get(key)`` is still supported, but deprecated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None:
|
||||
"""|coro|
|
||||
|
||||
A method called when the client's intially connecting to voice. This corresponds
|
||||
to the ``VOICE_SERVER_UPDATE`` event.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: :class:`RawVoiceServerUpdateEvent`
|
||||
The voice server payload.
|
||||
|
||||
.. versionchanged:: 2.7
|
||||
This now gets passed a `RawVoiceServerUpdateEvent` object instead of a :class:`dict`, but
|
||||
accessing keys via ``data[key]`` or ``data.get(key)`` is still supported, but deprecated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def connect(self, *, timeout: float, reconnect: bool) -> None:
|
||||
"""|coro|
|
||||
|
||||
A method called to initialise the connection.
|
||||
|
||||
The library initialises this class and calls ``__init__``, and then :meth:`connect` when attempting
|
||||
to start a connection to the voice. If an error ocurrs, it calls :meth:`disconnect`, so if you need
|
||||
to implement any cleanup, you should manually call it in :meth:`disconnect` as the library will not
|
||||
do so for you.
|
||||
|
||||
Within this method, to start the voice connection flow, it is recommened to use :meth:`Guild.change_voice_state`
|
||||
to start the flow. After which :meth:`on_voice_server_update` and :meth:`on_voice_state_update` will be called,
|
||||
although this could vary and cause unexpected behaviour, but that falls under Discord's way of handling the voice
|
||||
connection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timeout: :class:`float`
|
||||
The timeout for the connection.
|
||||
reconnect: :class:`bool`
|
||||
Whether reconnection is expected.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def disconnect(self, *, force: bool) -> None:
|
||||
"""|coro|
|
||||
|
||||
A method called to terminate the voice connection.
|
||||
|
||||
This can be either called manually when forcing a disconnection, or when an exception in :meth:`connect` ocurrs.
|
||||
|
||||
It is recommended to call :meth:`cleanup` here.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
force: :class:`bool`
|
||||
Whether the disconnection was forced.
|
||||
"""
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""This method *must* be called to ensure proper clean-up during a disconnect.
|
||||
|
||||
It is advisable to call this from within :meth:`disconnect` when you are completely
|
||||
done with the voice protocol instance.
|
||||
|
||||
This method removes it from the internal state cache that keeps track of the currently
|
||||
alive voice clients. Failure to clean-up will cause subsequent connections to report that
|
||||
it's still connected.
|
||||
|
||||
**The library will NOT automatically call this for you**, unlike :meth:`connect` and :meth:`disconnect`.
|
||||
"""
|
||||
key, _ = self.channel._get_voice_client_key()
|
||||
self.client._connection._remove_voice_client(key)
|
||||
@@ -0,0 +1,810 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import struct
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from discord import opus
|
||||
from discord.enums import SpeakingState, try_enum
|
||||
from discord.errors import ClientException
|
||||
from discord.player import AudioPlayer, AudioSource
|
||||
from discord.sinks.core import Sink
|
||||
from discord.sinks.errors import RecordingException
|
||||
from discord.utils import MISSING
|
||||
|
||||
from ..utils import get_missing_voice_dependencies
|
||||
from ._types import VoiceProtocol
|
||||
from .enums import OpCodes
|
||||
from .receive import AudioReader
|
||||
from .state import VoiceConnectionState
|
||||
from .utils.dependencies import HAS_DAVEY, HAS_NACL
|
||||
|
||||
if HAS_NACL:
|
||||
import nacl.secret
|
||||
import nacl.utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from discord import abc
|
||||
from discord.client import Client
|
||||
from discord.guild import Guild, VocalGuildChannel
|
||||
from discord.member import Member
|
||||
from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Encoder
|
||||
from discord.raw_models import (
|
||||
RawVoiceServerUpdateEvent,
|
||||
RawVoiceStateUpdateEvent,
|
||||
)
|
||||
from discord.state import ConnectionState
|
||||
from discord.types.voice import SupportedModes
|
||||
from discord.user import ClientUser, User
|
||||
|
||||
from .gateway import VoiceWebSocket
|
||||
from .receive.reader import AfterCallback
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ("VoiceClient",)
|
||||
|
||||
|
||||
class VoiceClient(VoiceProtocol):
|
||||
"""Represents a Discord voice connection.
|
||||
|
||||
You do not create these, you typically get them from e.g.
|
||||
:meth:`VoiceChannel.connect`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
session_id: :class:`str`
|
||||
The voice connection session ID.
|
||||
token: :class:`str`
|
||||
The voice connection token.
|
||||
endpoint: :class:`str`
|
||||
The endpoint we are connecting to.
|
||||
channel: Union[:class:`VoiceChannel`, :class:`StageChannel`]
|
||||
The channel we are connected to.
|
||||
|
||||
Warning
|
||||
-------
|
||||
In order to use PCM based AudioSources, you must have the opus library
|
||||
installed on your system and loaded through :func:`opus.load_opus`.
|
||||
Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`)
|
||||
or the library will not be able ot transmit audio.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
channel: abc.Connectable,
|
||||
) -> None:
|
||||
missing = get_missing_voice_dependencies()
|
||||
if missing:
|
||||
deps = ", ".join(missing)
|
||||
raise RuntimeError(
|
||||
f"{deps} {'library is' if len(missing) == 1 else 'libraries are'} needed "
|
||||
"in order to use voice related features, "
|
||||
'you can run "pip install py-cord[voice]" to install all voice-related '
|
||||
"dependencies."
|
||||
)
|
||||
|
||||
super().__init__(client, channel)
|
||||
state = client._connection
|
||||
|
||||
self.server_id: int = MISSING
|
||||
self.socket = MISSING
|
||||
self.loop: asyncio.AbstractEventLoop = state.loop
|
||||
self._state: ConnectionState = state
|
||||
|
||||
self.sequence: int = 0
|
||||
self.timestamp: int = 0
|
||||
self._player: AudioPlayer | None = None
|
||||
self._player_future: asyncio.Future[None] | None = None
|
||||
self.encoder: Encoder = MISSING
|
||||
self._incr_nonce: int = 0
|
||||
|
||||
self._connection: VoiceConnectionState = self.create_connection_state()
|
||||
|
||||
self._ssrc_to_id: dict[int, int] = {}
|
||||
self._id_to_ssrc: dict[int, int] = {}
|
||||
self._event_listeners: dict[str, list] = {}
|
||||
self._reader: AudioReader = MISSING
|
||||
|
||||
@staticmethod
|
||||
def _set_future_result_if_pending(
|
||||
future: asyncio.Future[Any], result: Exception | None
|
||||
) -> None:
|
||||
if not future.done():
|
||||
future.set_result(result)
|
||||
|
||||
supported_modes: tuple[SupportedModes, ...] = (
|
||||
"aead_xchacha20_poly1305_rtpsize",
|
||||
"xsalsa20_poly1305_lite",
|
||||
"xsalsa20_poly1305_suffix",
|
||||
"xsalsa20_poly1305",
|
||||
)
|
||||
|
||||
@property
|
||||
def guild(self) -> Guild:
|
||||
"""Returns the guild the channel we're connecting to is bound to."""
|
||||
channel: VocalGuildChannel = self.channel
|
||||
return channel.guild
|
||||
|
||||
@property
|
||||
def user(self) -> ClientUser:
|
||||
"""The user connected to voice (i.e. ourselves)"""
|
||||
return self._state.user # type: ignore
|
||||
|
||||
@property
|
||||
def session_id(self) -> str | None:
|
||||
"""The session ID of the current voice call."""
|
||||
return self._connection.session_id
|
||||
|
||||
@property
|
||||
def token(self) -> str | None:
|
||||
"""The token of the voice connection. You should not share this."""
|
||||
return self._connection.token
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str | None:
|
||||
"""The endpoint where the client is connected."""
|
||||
return self._connection.endpoint
|
||||
|
||||
@property
|
||||
def ssrc(self) -> int:
|
||||
"""The SSRC of the current user in the voice call."""
|
||||
return self._connection.ssrc
|
||||
|
||||
@property
|
||||
def mode(self) -> SupportedModes:
|
||||
"""The encryption / decryption mode the voice client is currently using."""
|
||||
return self._connection.mode
|
||||
|
||||
@property
|
||||
def secret_key(self) -> list[int]:
|
||||
"""Returns the secret key of the current connected voice call."""
|
||||
return self._connection.secret_key
|
||||
|
||||
@property
|
||||
def ws(self) -> VoiceWebSocket:
|
||||
return self._connection.ws
|
||||
|
||||
@property
|
||||
def timeout(self) -> float:
|
||||
"""The amount of ms the client waits before killing connect attempts."""
|
||||
return self._connection.timeout
|
||||
|
||||
def is_dave_connection(self) -> bool:
|
||||
"""Whether the voice client is connected to a DAVE call."""
|
||||
session = self._connection.dave_session
|
||||
return session is not None
|
||||
|
||||
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
||||
val = getattr(self, attr)
|
||||
if val + value > limit:
|
||||
setattr(self, attr, 0)
|
||||
else:
|
||||
setattr(self, attr, val + value)
|
||||
|
||||
def create_connection_state(self) -> VoiceConnectionState:
|
||||
return VoiceConnectionState(self, hook=self._recv_hook)
|
||||
|
||||
async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None:
|
||||
op = msg["op"]
|
||||
data = msg.get("d", {})
|
||||
|
||||
if op == OpCodes.ready:
|
||||
self._add_ssrc(self.guild.me.id, data["ssrc"])
|
||||
elif op == OpCodes.speaking:
|
||||
uid = int(data["user_id"])
|
||||
ssrc = data["ssrc"]
|
||||
|
||||
self._add_ssrc(uid, ssrc)
|
||||
|
||||
member = self.guild.get_member(uid)
|
||||
state = try_enum(SpeakingState, data["speaking"])
|
||||
self.dispatch("member_speaking_state_update", member, ssrc, state)
|
||||
elif op == OpCodes.clients_connect:
|
||||
uids = list(map(int, data["user_ids"]))
|
||||
|
||||
for uid in uids:
|
||||
member = self.guild.get_member(uid)
|
||||
if not member:
|
||||
_log.warning(
|
||||
"Skipping member referencing ID %d on member_connect", uid
|
||||
)
|
||||
continue
|
||||
self.dispatch("member_connect", member)
|
||||
elif op == OpCodes.client_disconnect:
|
||||
uid = int(data["user_id"])
|
||||
ssrc = self._id_to_ssrc.get(uid)
|
||||
|
||||
if self._reader and ssrc is not None:
|
||||
_log.debug("Destroying decoder for user %d, ssrc=%d", uid, ssrc)
|
||||
self._reader.packet_router.destroy_decoder(ssrc)
|
||||
|
||||
self._remove_ssrc(user_id=uid)
|
||||
member = self.guild.get_member(uid)
|
||||
self.dispatch("member_disconnect", member, ssrc)
|
||||
|
||||
# maybe handle video and such things?
|
||||
|
||||
async def _run_event(
|
||||
self, coro, event_name: str, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
try:
|
||||
await coro(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
_log.exception("Error calling %s", event_name)
|
||||
|
||||
def _schedule_event(
|
||||
self, coro, event_name: str, *args: Any, **kwargs: Any
|
||||
) -> asyncio.Task:
|
||||
wrapped = self._run_event(coro, event_name, *args, **kwargs)
|
||||
return self.client.loop.create_task(
|
||||
wrapped, name=f"voice-receiver-event-dispatch: {event_name}"
|
||||
)
|
||||
|
||||
def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None:
|
||||
_log.debug("Dispatching voice_client event %s", event)
|
||||
|
||||
event_name = f"on_{event}"
|
||||
for coro in self._event_listeners.get(event_name, []):
|
||||
task = self._schedule_event(coro, event_name, *args, **kwargs)
|
||||
self._connection._dispatch_task_set.add(task)
|
||||
task.add_done_callback(self._connection._dispatch_task_set.discard)
|
||||
|
||||
self._dispatch_sink(event, *args, **kwargs)
|
||||
self.client.dispatch(event, *args, **kwargs)
|
||||
|
||||
async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None:
|
||||
old_channel_id = self.channel.id if self.channel else None
|
||||
await self._connection.voice_state_update(data)
|
||||
|
||||
if data.channel_id is None:
|
||||
return
|
||||
|
||||
if self._reader and data.channel_id != old_channel_id:
|
||||
_log.debug("Destroying voice receive decoders in guild %s", self.guild.id)
|
||||
self._reader.packet_router.destroy_all_decoders()
|
||||
|
||||
async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None:
|
||||
await self._connection.voice_server_update(data)
|
||||
|
||||
def _dispatch_sink(self, event: str, /, *args: Any, **kwargs: Any) -> None:
|
||||
if self._reader:
|
||||
self._reader.event_router.dispatch(event, *args, **kwargs)
|
||||
|
||||
def _add_ssrc(self, user_id: int, ssrc: int) -> None:
|
||||
self._ssrc_to_id[ssrc] = user_id
|
||||
self._id_to_ssrc[user_id] = ssrc
|
||||
|
||||
if self._reader:
|
||||
self._reader.packet_router.set_user_id(ssrc, user_id)
|
||||
|
||||
def _remove_ssrc(self, *, user_id: int) -> None:
|
||||
ssrc = self._id_to_ssrc.pop(user_id, None)
|
||||
|
||||
if ssrc:
|
||||
self._reader.speaking_timer.drop_ssrc(ssrc)
|
||||
self._ssrc_to_id.pop(ssrc, None)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
reconnect: bool,
|
||||
timeout: float,
|
||||
self_deaf: bool = False,
|
||||
self_mute: bool = False,
|
||||
) -> None:
|
||||
await self._connection.connect(
|
||||
reconnect=reconnect,
|
||||
timeout=timeout,
|
||||
self_deaf=self_deaf,
|
||||
self_mute=self_mute,
|
||||
resume=False,
|
||||
)
|
||||
|
||||
def wait_until_connected(self, timeout: float | None = 30.0) -> bool:
|
||||
self._connection.wait_for(timeout=timeout)
|
||||
return self._connection.is_connected()
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
"""Latency between a HEARTBEAT and a HEARBEAT_ACK in seconds.
|
||||
|
||||
This chould be referred to as the Discord Voice WebSocket latency and is
|
||||
and analogue of user's voice latencies as seen in the Discord client.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
ws = self.ws
|
||||
return float("inf") if not ws else ws.latency
|
||||
|
||||
@property
|
||||
def average_latency(self) -> float:
|
||||
"""Average of most recent 20 HEARBEAT latencies in seconds.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
ws = self.ws
|
||||
return float("inf") if not ws else ws.average_latency
|
||||
|
||||
@property
|
||||
def privacy_code(self) -> str | None:
|
||||
"""Returns the current voice session's privacy code, only available if the call has upgraded to use the
|
||||
DAVE protocol
|
||||
"""
|
||||
session = self._connection.dave_session
|
||||
return session and session.voice_privacy_code
|
||||
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects this voice client from voice.
|
||||
"""
|
||||
|
||||
self.stop()
|
||||
await self._connection.disconnect(force=force, wait=True)
|
||||
self.cleanup()
|
||||
|
||||
async def move_to(
|
||||
self, channel: abc.Snowflake | None, *, timeout: float | None = 30.0
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
moves you to a different voice channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel: Optional[:class:`abc.Snowflake`]
|
||||
The channel to move to. If this is ``None``, it is an equivalent of calling :meth:`.disconnect`.
|
||||
timeout: Optional[:class:`float`]
|
||||
The maximum time in seconds to wait for the channel move to be completed, defaults to 30.
|
||||
If it is ``None``, then there is no timeout.
|
||||
|
||||
Raises
|
||||
------
|
||||
asyncio.TimeoutError
|
||||
Waiting for channel move timed out.
|
||||
"""
|
||||
await self._connection.move_to(channel, timeout)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Whether the voice client is connected to voice."""
|
||||
return self._connection.is_connected()
|
||||
|
||||
def is_playing(self) -> bool:
|
||||
"""INdicates if we're playing audio."""
|
||||
return self._player is not None and self._player.is_playing()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
"""Indicates if we're playing audio, but if we're paused."""
|
||||
return self._player is not None and self._player.is_paused()
|
||||
|
||||
# audio related
|
||||
|
||||
def _get_voice_packet(self, data: Any) -> bytes:
|
||||
|
||||
session = self._connection.dave_session
|
||||
packet = session.encrypt_opus(data) if session and session.ready else data
|
||||
|
||||
header = bytearray(12)
|
||||
|
||||
# formulate rtp header
|
||||
header[0] = 0x80
|
||||
header[1] = 0x78
|
||||
struct.pack_into(">H", header, 2, self.sequence)
|
||||
struct.pack_into(">I", header, 4, self.timestamp)
|
||||
struct.pack_into(">I", header, 8, self.ssrc)
|
||||
|
||||
encrypt_packet = getattr(self, f"_encrypt_{self.mode}")
|
||||
return encrypt_packet(header, packet)
|
||||
|
||||
# encryption methods
|
||||
|
||||
def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes:
|
||||
# deprecated
|
||||
box = nacl.secret.SecretBox(bytes(self.secret_key))
|
||||
nonce = bytearray(24)
|
||||
nonce[:12] = header
|
||||
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext
|
||||
|
||||
def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data: Any) -> bytes:
|
||||
# deprecated
|
||||
box = nacl.secret.SecretBox(bytes(self.secret_key))
|
||||
nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)
|
||||
return header + box.encrypt(bytes(data), nonce).ciphertext + nonce
|
||||
|
||||
def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes:
|
||||
# deprecated
|
||||
box = nacl.secret.SecretBox(bytes(self.secret_key))
|
||||
nonce = bytearray(24)
|
||||
nonce[:4] = struct.pack(">I", self._incr_nonce)
|
||||
self.checked_add("_incr_nonce", 1, 4294967295)
|
||||
|
||||
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
|
||||
|
||||
def _encrypt_aead_xchacha20_poly1305_rtpsize(
|
||||
self, header: bytes, data: Any
|
||||
) -> bytes:
|
||||
box = nacl.secret.Aead(bytes(self.secret_key))
|
||||
nonce = bytearray(24)
|
||||
nonce[:4] = struct.pack(">I", self._incr_nonce)
|
||||
self.checked_add("_incr_nonce", 1, 4294967295)
|
||||
return (
|
||||
header
|
||||
+ box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext
|
||||
+ nonce[:4]
|
||||
)
|
||||
|
||||
@overload
|
||||
def play(
|
||||
self,
|
||||
source: AudioSource,
|
||||
*,
|
||||
after: AfterCallback | None = ...,
|
||||
application: APPLICATION_CTL = ...,
|
||||
bitrate: int = ...,
|
||||
fec: bool = ...,
|
||||
expected_packet_loss: float = ...,
|
||||
bandwidth: BAND_CTL = ...,
|
||||
signal_type: SIGNAL_CTL = ...,
|
||||
wait_finish: Literal[False] = ...,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def play(
|
||||
self,
|
||||
source: AudioSource,
|
||||
*,
|
||||
after: AfterCallback = ...,
|
||||
application: APPLICATION_CTL = ...,
|
||||
bitrate: int = ...,
|
||||
fec: bool = ...,
|
||||
expected_packet_loss: float = ...,
|
||||
bandwidth: BAND_CTL = ...,
|
||||
signal_type: SIGNAL_CTL = ...,
|
||||
wait_finish: Literal[True],
|
||||
) -> asyncio.Future[None]: ...
|
||||
|
||||
def play(
|
||||
self,
|
||||
source: AudioSource,
|
||||
*,
|
||||
after: AfterCallback | None = None,
|
||||
application: APPLICATION_CTL = "audio",
|
||||
bitrate: int = 128,
|
||||
fec: bool = True,
|
||||
expected_packet_loss: float = 0.15,
|
||||
bandwidth: BAND_CTL = "full",
|
||||
signal_type: SIGNAL_CTL = "auto",
|
||||
wait_finish: bool = False,
|
||||
) -> None | asyncio.Future[None]:
|
||||
"""Plays an :class:`AudioSource`.
|
||||
|
||||
The finalizer, ``after`` is called after the source has been exhausted
|
||||
or an error occurred.
|
||||
|
||||
IF an error happens while the audio player is running, the exception is
|
||||
caught and the audio player is then stopped. If no after callback is passed,
|
||||
any caught exception will be displayed as if it were raised.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source: :class:`AudioSource`
|
||||
The audio source we're reading from.
|
||||
after: Callable[[Optional[:class:`Exception`]], Any]
|
||||
The finalizer that is called after the stream is exhausted.
|
||||
This function must have a single parameter, ``error``, that
|
||||
denotes an optional exception that was raised during playing.
|
||||
application: :class:`str`
|
||||
The intended application encoder application type. Must be one of
|
||||
``audio``, ``voip``, or ``lowdelay``. Defaults to ``audio``.
|
||||
bitrate: :class:`int`
|
||||
The encoder's bitrate. Must be between ``16`` and ``512``. Defaults
|
||||
to ``128``.
|
||||
fec: :class:`bool`
|
||||
Configures the encoder's use of inband forward error correction (fec).
|
||||
Defaults to ``True``.
|
||||
expected_packet_loss: :class:`float`
|
||||
How much packet loss percentage is expected from the encoder. This requires ``fec``
|
||||
to be set to ``True``. Defaults to ``0.15``.
|
||||
bandwidth: :class:`str`
|
||||
The encoder's bandpass. Must be one of ``narrow``, ``medium``, ``wide``,
|
||||
``superwide``, or ``full``. Defaults to ``full``.
|
||||
signal_type: :class:`str`
|
||||
The type of signal being encoded. Must be one of ``auto``, ``voice``, ``music``.
|
||||
Defaults to ``auto``.
|
||||
wait_finish: :class:`bool`
|
||||
If ``True``, then an awaitable is returned that waits for the audio source to be
|
||||
exhausted, and will return an optional exception that could have been raised.
|
||||
|
||||
If ``False``, ``None`` is returned and the function does not block.
|
||||
|
||||
.. versionadded:: 2.5
|
||||
|
||||
Raises
|
||||
------
|
||||
ClientException
|
||||
Already playing audio, or not connected to voice.
|
||||
TypeError
|
||||
Source is not a :class:`AudioSource`, or after is not a callable.
|
||||
OpusNotLoaded
|
||||
Source is not opus encoded and opus is not loaded.
|
||||
"""
|
||||
|
||||
if not self.is_connected():
|
||||
raise ClientException("Not connected to voice")
|
||||
if self.is_playing():
|
||||
raise ClientException("Already playing audio")
|
||||
if not isinstance(source, AudioSource):
|
||||
raise TypeError(
|
||||
f"Source must be an AudioSource, not {source.__class__.__name__}",
|
||||
)
|
||||
if not self.encoder and not source.is_opus():
|
||||
self.encoder = opus.Encoder(
|
||||
application=application,
|
||||
bitrate=bitrate,
|
||||
fec=fec,
|
||||
expected_packet_loss=expected_packet_loss,
|
||||
bandwidth=bandwidth,
|
||||
signal_type=signal_type,
|
||||
)
|
||||
|
||||
future = None
|
||||
if wait_finish:
|
||||
self._player_future = future = self.loop.create_future()
|
||||
after_callback = after
|
||||
|
||||
def _after(exc: Exception | None) -> None:
|
||||
if callable(after_callback):
|
||||
after_callback(exc)
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._set_future_result_if_pending, future, exc
|
||||
)
|
||||
|
||||
after = _after
|
||||
|
||||
self._player = AudioPlayer(source, self, after=after)
|
||||
self._player.start()
|
||||
return future
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops playing audio, if applicable."""
|
||||
if self._player:
|
||||
self._player.stop()
|
||||
if self._player_future:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._set_future_result_if_pending, self._player_future, None
|
||||
)
|
||||
if self._reader is not MISSING:
|
||||
self._reader.stop()
|
||||
self._reader = MISSING
|
||||
|
||||
self._player = None
|
||||
self._player_future = None
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pauses the audio playing."""
|
||||
if self._player:
|
||||
self._player.pause()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resumes the audio playing."""
|
||||
if self._player:
|
||||
self._player.resume()
|
||||
|
||||
@property
|
||||
def source(self) -> AudioSource | None:
|
||||
"""The audio source being player, if playing.
|
||||
|
||||
This property can also be used to change the audio source currently being played.
|
||||
"""
|
||||
return self._player and self._player.source
|
||||
|
||||
@source.setter
|
||||
def source(self, value: AudioSource) -> None:
|
||||
if not isinstance(value, AudioSource):
|
||||
raise TypeError(f"expected AudioSource, not {value.__class__.__name__}")
|
||||
|
||||
if self._player is None:
|
||||
raise ValueError("the client is not playing anything")
|
||||
|
||||
self._player.set_source(value)
|
||||
|
||||
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
|
||||
"""Sends an audio packet composed of the ``data``.
|
||||
|
||||
You must be connected to play audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: :class:`bytes`
|
||||
The :term:`py:bytes-like object` denoting PCM or Opus voice data.
|
||||
encode: :class:`bool`
|
||||
Indicates if ``data`` should be encoded into Opus.
|
||||
|
||||
Raises
|
||||
------
|
||||
ClientException
|
||||
You are not connected.
|
||||
opus.OpusError
|
||||
Encoding the data failed.
|
||||
"""
|
||||
|
||||
self.checked_add("sequence", 1, 65535)
|
||||
if encode:
|
||||
encoded = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
|
||||
else:
|
||||
encoded = data
|
||||
|
||||
packet = self._get_voice_packet(encoded)
|
||||
try:
|
||||
self._connection.send_packet(packet)
|
||||
except OSError:
|
||||
_log.debug(
|
||||
"A packet has been dropped (seq: %s, timestamp: %s)",
|
||||
self.sequence,
|
||||
self.timestamp,
|
||||
)
|
||||
|
||||
self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
|
||||
def elapsed(self) -> datetime.timedelta:
|
||||
"""Returns the elapsed time of the playing audio."""
|
||||
if self._player:
|
||||
return datetime.timedelta(milliseconds=self._player.played_frames() * 20)
|
||||
return datetime.timedelta()
|
||||
|
||||
def start_recording(
|
||||
self,
|
||||
sink: Sink,
|
||||
callback: AfterCallback | None = None,
|
||||
*args: Any,
|
||||
sync_start: bool = MISSING,
|
||||
) -> None:
|
||||
r"""Start recording the audio from the current connected channel to the provided sink.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. warning::
|
||||
|
||||
Recording may not work as expected due to the new DAVE (End-to-End Encryption) for voice calls.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sink: :class:`~.Sink`
|
||||
A Sink in which all audio packets will be processed in.
|
||||
callback: Callable[[:class:`Exception` | None], Any]
|
||||
A function which is called after the bot has stopped recording. This must take exactly one positonal(-only)
|
||||
parameter, ``exception``, which is the exception that was raised during the recording of the Sink.
|
||||
|
||||
.. versionchanged:: 2.7
|
||||
This parameter is now optional, and must take exactly one parameter, ``exception``.
|
||||
\*args:
|
||||
The arguments to pass to the callback coroutine.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
Passing custom arguments to the callback is now deprecated and ignored.
|
||||
sync_start: :class:`bool`
|
||||
If ``True``, the recordings of subsequent users will start with silence.
|
||||
This is useful for recording audio just as it was heard.
|
||||
|
||||
.. deprecated:: 2.7
|
||||
This parameter is now ignored and deprecated.
|
||||
|
||||
Raises
|
||||
------
|
||||
RecordingException
|
||||
Not connected to a voice channel
|
||||
TypeError
|
||||
You did not provide a Sink object.
|
||||
"""
|
||||
warnings.warn(
|
||||
"Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. "
|
||||
+ "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# TODO: remove warning in voice-recv fix PR
|
||||
if not self.is_connected():
|
||||
raise RecordingException("not connected to a voice channel")
|
||||
if not isinstance(sink, Sink):
|
||||
raise TypeError(f"expected a Sink object, got {sink.__class__.__name__}")
|
||||
|
||||
if self.is_recording():
|
||||
raise ClientException("Already recording audio")
|
||||
|
||||
if len(args) > 0:
|
||||
warnings.warn(
|
||||
"'args' parameter is deprecated since 2.7 and will be removed in 3.0"
|
||||
)
|
||||
if sync_start is not MISSING:
|
||||
warnings.warn(
|
||||
"'sync_start' parameter is deprecated since 2.7 and will be removed in 3.0"
|
||||
)
|
||||
|
||||
self._reader = AudioReader(sink, self, after=callback, start=True)
|
||||
|
||||
start_listening = start_recording
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
"""Stops the recording of the provided ``sink``, or all recording sinks.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Raises
|
||||
------
|
||||
RecordingException
|
||||
You are not recording.
|
||||
"""
|
||||
warnings.warn(
|
||||
"Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. "
|
||||
+ "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if self._reader is not MISSING:
|
||||
self._reader.stop()
|
||||
self._reader = MISSING
|
||||
else:
|
||||
raise RecordingException("You are not recording")
|
||||
|
||||
stop_listening = stop_recording
|
||||
|
||||
def is_recording(self) -> bool:
|
||||
"""Whether the current client is recording in any sink."""
|
||||
return self._reader and self._reader.is_listening()
|
||||
|
||||
def is_speaking(self, member: Member | User) -> bool | None:
|
||||
"""Whether a user is speaking.
|
||||
|
||||
This is an approximate calculation and may have outdated or wrong data.
|
||||
|
||||
If the member speaking status has not been yet saved, it returns ``None``.
|
||||
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
warnings.warn(
|
||||
"Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. "
|
||||
+ "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
ssrc = self._id_to_ssrc.get(member.id)
|
||||
if ssrc is None:
|
||||
return None
|
||||
if self._reader:
|
||||
return self._reader.speaking_timer.get_speaking(ssrc)
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from discord.enums import Enum
|
||||
|
||||
|
||||
class OpCodes(Enum):
|
||||
identify = 0
|
||||
select_protocol = 1
|
||||
ready = 2
|
||||
heartbeat = 3
|
||||
session_description = 4
|
||||
speaking = 5
|
||||
heartbeat_ack = 6
|
||||
resume = 7
|
||||
hello = 8
|
||||
resumed = 9
|
||||
clients_connect = 11
|
||||
client_connect = 12
|
||||
client_disconnect = 13
|
||||
|
||||
# dave protocol stuff
|
||||
dave_prepare_transition = 21
|
||||
dave_execute_transition = 22
|
||||
dave_transition_ready = 23
|
||||
dave_prepare_epoch = 24
|
||||
mls_external_sender_package = 25
|
||||
mls_key_package = 26
|
||||
mls_proposals = 27
|
||||
mls_commit_welcome = 28
|
||||
mls_commit_transition = 29
|
||||
mls_welcome = 30
|
||||
mls_invalid_commit_welcome = 31
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
elif isinstance(other, self.__class__):
|
||||
return self is other
|
||||
return NotImplemented
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
|
||||
class ConnectionFlowState(Enum):
|
||||
disconnected = 0
|
||||
set_guild_voice_state = 1
|
||||
got_voice_state_update = 2
|
||||
got_voice_server_update = 3
|
||||
got_both_voice_updates = 4
|
||||
websocket_connected = 5
|
||||
got_websocket_ready = 6
|
||||
got_ip_discovery = 7
|
||||
connected = 8
|
||||
@@ -0,0 +1,521 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .utils.dependencies import HAS_DAVEY
|
||||
|
||||
if HAS_DAVEY:
|
||||
import davey
|
||||
|
||||
from discord import utils
|
||||
from discord.enums import SpeakingState
|
||||
from discord.errors import ConnectionClosed
|
||||
from discord.gateway import DiscordWebSocket
|
||||
from discord.gateway import KeepAliveHandler as KeepAliveHandlerBase
|
||||
|
||||
from .enums import OpCodes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import ConvertibleToInt
|
||||
from typing_extensions import Self
|
||||
|
||||
from .state import VoiceConnectionState
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeepAliveHandler(KeepAliveHandlerBase):
|
||||
if TYPE_CHECKING:
|
||||
ws: VoiceWebSocket
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
ws: VoiceWebSocket,
|
||||
interval: float | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
daemon: bool = kwargs.pop("daemon", True)
|
||||
name: str = kwargs.pop("name", f"voice-keep-alive-handler:{id(self):#x}")
|
||||
super().__init__(
|
||||
*args,
|
||||
**kwargs,
|
||||
name=name,
|
||||
daemon=daemon,
|
||||
ws=ws,
|
||||
interval=interval,
|
||||
)
|
||||
|
||||
self.msg: str = "Keeping shard ID %s voice websocket alive with timestamp %s."
|
||||
self.block_msg: str = (
|
||||
"Shard ID %s voice heartbeat blocked for more than %s seconds."
|
||||
)
|
||||
self.behing_msg: str = (
|
||||
"High socket latency, shard ID %s heartbeat is %.1fs behind."
|
||||
)
|
||||
self.recent_ack_latencies: deque[float] = deque(maxlen=20)
|
||||
|
||||
def get_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"op": int(OpCodes.heartbeat),
|
||||
"d": {
|
||||
"t": int(time.time() * 1000),
|
||||
"seq_ack": self.ws.seq_ack,
|
||||
},
|
||||
}
|
||||
|
||||
def ack(self) -> None:
|
||||
ack_time = time.perf_counter()
|
||||
self._last_ack = ack_time
|
||||
self._last_recv = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
self.recent_ack_latencies.append(self.latency)
|
||||
|
||||
|
||||
class VoiceWebSocket(DiscordWebSocket):
|
||||
def __init__(
|
||||
self,
|
||||
socket: aiohttp.ClientWebSocketResponse,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
state: VoiceConnectionState,
|
||||
*,
|
||||
hook: Callable[..., Coroutine[Any, Any, Any]] | None = None,
|
||||
) -> None:
|
||||
self.ws: aiohttp.ClientWebSocketResponse = socket
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
self._keep_alive: KeepAliveHandler | None = None
|
||||
self._close_code: int | None = None
|
||||
self.secret_key: list[int] | None = None
|
||||
self.seq_ack: int = -1
|
||||
self.state: VoiceConnectionState = state
|
||||
self.ssrc_map: dict[str, dict[str, Any]] = {}
|
||||
self.known_users: dict[int, Any] = {}
|
||||
|
||||
if hook:
|
||||
self._hook = hook or state.ws_hook # type: ignore
|
||||
|
||||
@property
|
||||
def token(self) -> str | None:
|
||||
return self.state.token
|
||||
|
||||
@token.setter
|
||||
def token(self, value: str | None) -> None:
|
||||
self.state.token = value
|
||||
|
||||
@property
|
||||
def session_id(self) -> str | None:
|
||||
return self.state.session_id
|
||||
|
||||
@session_id.setter
|
||||
def session_id(self, value: str | None) -> None:
|
||||
self.state.session_id = value
|
||||
|
||||
@property
|
||||
def self_id(self) -> int:
|
||||
return self._connection.self_id
|
||||
|
||||
async def _hook(self, *args: Any) -> Any:
|
||||
pass
|
||||
|
||||
async def send_as_bytes(self, op: ConvertibleToInt, data: bytes) -> None:
|
||||
packet = bytes([int(op)]) + data
|
||||
_log.debug(
|
||||
"Sending voice websocket binary frame: op: %s size: %d", op, len(data)
|
||||
)
|
||||
await self.ws.send_bytes(packet)
|
||||
|
||||
async def send_as_json(self, data: Any) -> None:
|
||||
_log.debug("Sending voice websocket frame: %s.", data)
|
||||
if data.get("op", None) == OpCodes.identify:
|
||||
_log.info("Identifying ourselves: %s", data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
async def resume(self) -> None:
|
||||
payload = {
|
||||
"op": int(OpCodes.resume),
|
||||
"d": {
|
||||
"token": self.token,
|
||||
"server_id": str(self.state.server_id),
|
||||
"session_id": self.session_id,
|
||||
"seq_ack": self.seq_ack,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg: Any, /):
|
||||
_log.debug("Voice websocket frame received: %s", msg)
|
||||
op = msg["op"]
|
||||
data = msg.get("d", {}) # this key should ALWAYS be given, but guard anyways
|
||||
self.seq_ack = msg.get("seq", self.seq_ack) # keep the seq_ack updated
|
||||
state = self.state
|
||||
|
||||
if op == OpCodes.ready:
|
||||
await self.ready(data)
|
||||
elif op == OpCodes.heartbeat_ack:
|
||||
if not self._keep_alive:
|
||||
_log.error(
|
||||
"Received a heartbeat ACK but no keep alive handler was set.",
|
||||
)
|
||||
return
|
||||
self._keep_alive.ack()
|
||||
elif op == OpCodes.resumed:
|
||||
_log.info(
|
||||
f"Voice connection on channel ID {self.state.channel_id} (guild {self.state.guild_id}) was "
|
||||
"successfully RESUMED.",
|
||||
)
|
||||
elif op == OpCodes.session_description:
|
||||
state.mode = data["mode"]
|
||||
state.dave_protocol_version = data["dave_protocol_version"]
|
||||
await self.load_secret_key(data)
|
||||
await state.reinit_dave_session()
|
||||
elif op == OpCodes.hello:
|
||||
interval = data["heartbeat_interval"] / 1000.0
|
||||
self._keep_alive = KeepAliveHandler(
|
||||
ws=self,
|
||||
interval=min(interval, 5),
|
||||
)
|
||||
self._keep_alive.start()
|
||||
elif state.dave_session:
|
||||
if op == OpCodes.dave_prepare_transition:
|
||||
_log.info(
|
||||
"Preparing to upgrade to a DAVE connection for channel %s for transition %d proto version %d",
|
||||
state.channel_id,
|
||||
data["transition_id"],
|
||||
data["protocol_version"],
|
||||
)
|
||||
state.dave_pending_transition = data
|
||||
|
||||
transition_id = data["transition_id"]
|
||||
|
||||
if transition_id == 0:
|
||||
await state.execute_dave_transition(data["transition_id"])
|
||||
else:
|
||||
if data["protocol_version"] == 0 and state.dave_session:
|
||||
state.dave_session.set_passthrough_mode(True, 10)
|
||||
await self.send_dave_transition_ready(transition_id)
|
||||
elif op == OpCodes.dave_execute_transition:
|
||||
_log.info(
|
||||
"Upgrading to DAVE connection for channel %s", state.channel_id
|
||||
)
|
||||
await state.execute_dave_transition(data["transition_id"])
|
||||
elif op == OpCodes.dave_prepare_epoch:
|
||||
epoch = data["epoch"]
|
||||
_log.debug(
|
||||
"Preparing for DAVE epoch in channel %s: %s",
|
||||
state.channel_id,
|
||||
epoch,
|
||||
)
|
||||
# if epoch is 1 then a new MLS group is to be created for the proto version
|
||||
if epoch == 1:
|
||||
state.dave_protocol_version = data["protocol_version"]
|
||||
await state.reinit_dave_session()
|
||||
else:
|
||||
_log.debug("Unhandled op code: %s with data %s", op, data)
|
||||
|
||||
await utils.maybe_coroutine(self._hook, self, msg)
|
||||
|
||||
async def received_binary_message(self, msg: bytes) -> None:
|
||||
self.seq_ack = struct.unpack_from(">H", msg, 0)[0]
|
||||
op = msg[2]
|
||||
_log.debug(
|
||||
"Voice websocket binary frame received: %d bytes, seq: %s, op: %s",
|
||||
len(msg),
|
||||
self.seq_ack,
|
||||
op,
|
||||
)
|
||||
|
||||
state = self.state
|
||||
|
||||
if not state.dave_session:
|
||||
return
|
||||
|
||||
if op == OpCodes.mls_external_sender_package:
|
||||
_log.debug("Received MLS External Sender Package, applying to DAVE session")
|
||||
state.dave_session.set_external_sender(msg[3:])
|
||||
_log.debug(
|
||||
"Applied MLS External Sender Package, user IDs available: %s",
|
||||
state.dave_session.get_user_ids(),
|
||||
)
|
||||
elif op == OpCodes.mls_proposals:
|
||||
op_type = msg[3]
|
||||
result = state.dave_session.process_proposals(
|
||||
(
|
||||
davey.ProposalsOperationType.append
|
||||
if op_type == 0
|
||||
else davey.ProposalsOperationType.revoke
|
||||
),
|
||||
msg[4:],
|
||||
)
|
||||
|
||||
if isinstance(result, davey.CommitWelcome):
|
||||
data = (
|
||||
(result.commit + result.welcome)
|
||||
if result.welcome
|
||||
else result.commit
|
||||
)
|
||||
_log.debug("Sending MLS key package with data: %s", data)
|
||||
await self.send_as_bytes(
|
||||
OpCodes.mls_commit_welcome,
|
||||
data,
|
||||
)
|
||||
_log.debug("Processed MLS proposals for current dave session: %r", result)
|
||||
elif op == OpCodes.mls_commit_transition:
|
||||
transt_id = struct.unpack_from(">H", msg, 3)[0]
|
||||
try:
|
||||
state.dave_session.process_commit(msg[5:])
|
||||
if transt_id != 0:
|
||||
state.dave_pending_transition = {
|
||||
"transition_id": transt_id,
|
||||
"protocol_version": state.dave_protocol_version,
|
||||
}
|
||||
_log.debug(
|
||||
"Sending DAVE transition ready from MLS commit transition with data: %s",
|
||||
state.dave_pending_transition,
|
||||
)
|
||||
await self.send_dave_transition_ready(transt_id)
|
||||
_log.debug("Processed MLS commit for transition %s", transt_id)
|
||||
except Exception as exc:
|
||||
_log.debug(
|
||||
"An exception ocurred while processing a MLS commit, this should be safe to ignore: %s",
|
||||
exc,
|
||||
)
|
||||
await state.recover_dave_from_invalid_commit(transt_id)
|
||||
elif op == OpCodes.mls_welcome:
|
||||
transt_id = struct.unpack_from(">H", msg, 3)[0]
|
||||
try:
|
||||
state.dave_session.process_welcome(msg[5:])
|
||||
if transt_id != 0:
|
||||
state.dave_pending_transition = {
|
||||
"transition_id": transt_id,
|
||||
"protocol_version": state.dave_protocol_version,
|
||||
}
|
||||
_log.debug(
|
||||
"Sending DAVE transition ready from MLS welcome with data: %s",
|
||||
state.dave_pending_transition,
|
||||
)
|
||||
await self.send_dave_transition_ready(transt_id)
|
||||
_log.debug("Processed MLS welcome for transition %s", transt_id)
|
||||
except Exception as exc:
|
||||
_log.debug(
|
||||
"An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s",
|
||||
exc,
|
||||
)
|
||||
await state.recover_dave_from_invalid_commit(transt_id)
|
||||
|
||||
async def ready(self, data: dict[str, Any]) -> None:
|
||||
state = self.state
|
||||
|
||||
state.ssrc = data["ssrc"]
|
||||
state.voice_port = data["port"]
|
||||
state.endpoint_ip = data["ip"]
|
||||
|
||||
_log.debug(
|
||||
f"Connecting to {state.endpoint_ip} (port {state.voice_port}).",
|
||||
)
|
||||
|
||||
await self.loop.sock_connect(
|
||||
state.socket,
|
||||
(state.endpoint_ip, state.voice_port),
|
||||
)
|
||||
|
||||
_log.debug(
|
||||
"Connected socket to %s (port %s)",
|
||||
state.endpoint_ip,
|
||||
state.voice_port,
|
||||
)
|
||||
|
||||
state.ip, state.port = await self.get_ip()
|
||||
|
||||
modes = [mode for mode in data["modes"] if mode in self.state.supported_modes]
|
||||
_log.debug("Received available voice connection modes: %s", modes)
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
_log.debug("Selected voice protocol %s for this connection", mode)
|
||||
|
||||
async def select_protocol(self, ip: str, port: int, mode: str) -> None:
|
||||
payload = {
|
||||
"op": int(OpCodes.select_protocol),
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {
|
||||
"address": ip,
|
||||
"port": port,
|
||||
"mode": mode,
|
||||
},
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def get_ip(self) -> tuple[str, int]:
|
||||
state = self.state
|
||||
packet = bytearray(74)
|
||||
struct.pack_into(">H", packet, 0, 1) # 1 = Send
|
||||
struct.pack_into(">H", packet, 2, 70) # 70 = Length
|
||||
struct.pack_into(">I", packet, 4, state.ssrc)
|
||||
|
||||
_log.debug(
|
||||
f"Sending IP discovery packet for voice in channel {state.channel_id} (guild {state.guild_id})"
|
||||
)
|
||||
await self.loop.sock_sendall(state.socket, packet)
|
||||
|
||||
fut: asyncio.Future[bytes] = self.loop.create_future()
|
||||
|
||||
def get_ip_packet(data: bytes) -> None:
|
||||
if data[1] == 0x02 and len(data) == 74:
|
||||
self.loop.call_soon_threadsafe(fut.set_result, data)
|
||||
|
||||
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
|
||||
state.add_socket_listener(get_ip_packet)
|
||||
recv = await fut
|
||||
|
||||
_log.debug("Received IP discovery packet with data %s", recv)
|
||||
|
||||
ip_start = 8
|
||||
ip_end = recv.index(0, ip_start)
|
||||
ip = recv[ip_start:ip_end].decode("ascii")
|
||||
port = struct.unpack_from(">H", recv, len(recv) - 2)[0]
|
||||
_log.debug("Detected IP %s with port %s", ip, port)
|
||||
|
||||
return ip, port
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
heartbeat = self._keep_alive
|
||||
return float("inf") if heartbeat is None else heartbeat.latency
|
||||
|
||||
@property
|
||||
def average_latency(self) -> float:
|
||||
heartbeat = self._keep_alive
|
||||
if heartbeat is None or not heartbeat.recent_ack_latencies:
|
||||
return float("inf")
|
||||
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
|
||||
|
||||
async def load_secret_key(self, data: dict[str, Any]) -> None:
|
||||
_log.debug(
|
||||
f"Received secret key for voice connection in channel {self.state.channel_id} (guild {self.state.guild_id})"
|
||||
)
|
||||
self.secret_key = self.state.secret_key = data["secret_key"]
|
||||
await self.speak(SpeakingState.none)
|
||||
|
||||
async def poll_event(self) -> None:
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30)
|
||||
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
_log.debug("Received text payload: %s", msg.data)
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||
_log.debug("Received binary payload: size: %d", len(msg.data))
|
||||
await self.received_binary_message(msg.data)
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug("Received %s", msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
elif msg.type in (
|
||||
aiohttp.WSMsgType.CLOSED,
|
||||
aiohttp.WSMsgType.CLOSE,
|
||||
aiohttp.WSMsgType.CLOSING,
|
||||
):
|
||||
_log.debug("Received %s", msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
if self._keep_alive:
|
||||
self._keep_alive.stop()
|
||||
|
||||
self._close_code = code
|
||||
await self.ws.close(code=self._close_code)
|
||||
|
||||
async def speak(self, state: SpeakingState = SpeakingState.voice) -> None:
|
||||
await self.send_as_json(
|
||||
{
|
||||
"op": int(OpCodes.speaking),
|
||||
"d": {
|
||||
"speaking": int(state),
|
||||
"delay": 0,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_state(
|
||||
cls,
|
||||
state: VoiceConnectionState,
|
||||
*,
|
||||
resume: bool = False,
|
||||
hook: Callable[..., Coroutine[Any, Any, Any]] | None = None,
|
||||
seq_ack: int = -1,
|
||||
) -> Self:
|
||||
gateway = f"wss://{state.endpoint}/?v=8"
|
||||
client = state.client
|
||||
http = client._state.http
|
||||
socket = await http.ws_connect(gateway, compress=15)
|
||||
ws = cls(socket, loop=client.loop, hook=hook, state=state)
|
||||
ws.gateway = gateway
|
||||
ws.seq_ack = seq_ack
|
||||
ws._max_heartbeat_timeout = 60.0
|
||||
ws.thread_id = threading.get_ident()
|
||||
|
||||
if resume:
|
||||
await ws.resume()
|
||||
else:
|
||||
await ws.identify()
|
||||
return ws
|
||||
|
||||
async def identify(self) -> None:
|
||||
state = self.state
|
||||
payload = {
|
||||
"op": int(OpCodes.identify),
|
||||
"d": {
|
||||
"server_id": str(state.server_id),
|
||||
"user_id": str(state.user.id),
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
"max_dave_protocol_version": state.max_dave_proto_version,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def send_dave_transition_ready(self, transition_id: int) -> None:
|
||||
payload = {
|
||||
"op": int(OpCodes.dave_transition_ready),
|
||||
"d": {
|
||||
"transition_id": transition_id,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
discord.voice.packets
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Sink packet handlers.
|
||||
:copyright: (c) 2015-2021 Rapptz & 2021-present Pycord Development
|
||||
:license: MIT, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .core import Packet
|
||||
from .rtp import (
|
||||
FakePacket,
|
||||
ReceiverReportPacket,
|
||||
RTCPPacket,
|
||||
RTPPacket,
|
||||
SenderReportPacket,
|
||||
SilencePacket,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord import Member, User
|
||||
|
||||
__all__ = (
|
||||
"Packet",
|
||||
"RTPPacket",
|
||||
"RTCPPacket",
|
||||
"FakePacket",
|
||||
"ReceiverReportPacket",
|
||||
"SenderReportPacket",
|
||||
"SilencePacket",
|
||||
"VoiceData",
|
||||
)
|
||||
|
||||
|
||||
class VoiceData:
|
||||
"""Represents an audio data from a source.
|
||||
|
||||
.. versionadded:: 2.7
|
||||
|
||||
Attributes
|
||||
----------
|
||||
packet: :class:`~discord.sinks.Packet`
|
||||
The packet this source data contains.
|
||||
source: :class:`~discord.User` | :class:`~discord.Member` | None
|
||||
The user that emitted this audio source.
|
||||
pcm: :class:`bytes`
|
||||
The PCM bytes of this source.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, packet: Packet, source: User | Member | None, *, pcm: bytes | None = None
|
||||
) -> None:
|
||||
self.packet: Packet = packet
|
||||
self.source: User | Member | None = source
|
||||
self.pcm: bytes = pcm if pcm else b""
|
||||
|
||||
@property
|
||||
def opus(self) -> bytes | None:
|
||||
return self.packet.decrypted_data
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Final
|
||||
|
||||
OPUS_SILENCE: Final = b"\xf8\xff\xfe"
|
||||
|
||||
|
||||
class Packet:
|
||||
"""Represents an audio stream bytes packet.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
data: :class:`bytes`
|
||||
The bytes data of this packet. This has not been decoded.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ssrc: int
|
||||
sequence: int
|
||||
timestamp: int
|
||||
type: int
|
||||
decrypted_data: bytes
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
self.data: bytes = data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}> data={len(self.data)} bytes>"
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.data)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
if self.ssrc != other.ssrc:
|
||||
raise TypeError(
|
||||
f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})"
|
||||
)
|
||||
return self.sequence == other.sequence and self.timestamp == other.timestamp
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
if self.ssrc != other.ssrc:
|
||||
raise TypeError(
|
||||
f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})"
|
||||
)
|
||||
return self.sequence > other.sequence and self.timestamp > other.timestamp
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
if self.ssrc != other.ssrc:
|
||||
raise TypeError(
|
||||
f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})"
|
||||
)
|
||||
return self.sequence < other.sequence and self.timestamp < other.timestamp
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
data = getattr(self, "decrypted_data", None)
|
||||
return data == OPUS_SILENCE
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.data)
|
||||
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from .core import OPUS_SILENCE, Packet
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Final
|
||||
|
||||
MAX_UINT_32 = 0xFFFFFFFF
|
||||
MAX_UINT_16 = 0xFFFF
|
||||
|
||||
RTP_PACKET_TYPE_VOICE = 120
|
||||
|
||||
|
||||
def decode(data: bytes) -> Packet:
|
||||
if not data[0] >> 6 == 2:
|
||||
raise ValueError(f"Invalid packet header 0b{data[0]:0>8b}")
|
||||
return _rtcp_map.get(data[1], RTPPacket)(data)
|
||||
|
||||
|
||||
class FakePacket(Packet):
|
||||
data = b""
|
||||
decrypted_data: bytes = b""
|
||||
extension_data: dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssrc: int,
|
||||
sequence: int,
|
||||
timestamp: int,
|
||||
) -> None:
|
||||
self.ssrc = ssrc
|
||||
self.sequence = sequence
|
||||
self.timestamp = timestamp
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
|
||||
class SilencePacket(Packet):
|
||||
decrypted_data: Final = OPUS_SILENCE
|
||||
extension_data: Final[dict[int, Any]] = {}
|
||||
sequence: int = -1
|
||||
|
||||
def __init__(self, ssrc: int, timestamp: int) -> None:
|
||||
self.ssrc = ssrc
|
||||
self.timestamp = timestamp
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class RTPPacket(Packet):
|
||||
"""Represents an RTP packet.
|
||||
|
||||
.. versionadded:: 2.7
|
||||
|
||||
Attributes
|
||||
----------
|
||||
data: :class:`bytes`
|
||||
The raw data of the packet.
|
||||
"""
|
||||
|
||||
_hstruct = struct.Struct(">xxHII")
|
||||
_ext_header = namedtuple("Extension", "profile length values")
|
||||
_ext_magic = b"\xbe\xde"
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
super().__init__(data)
|
||||
|
||||
self.version: int = data[0] >> 6
|
||||
self.padding: bool = bool(data[0] & 0b00100000)
|
||||
self.extended: bool = bool(data[0] & 0b00010000)
|
||||
self.cc: int = data[0] & 0b00001111
|
||||
|
||||
self.marker: bool = bool(data[1] & 0b10000000)
|
||||
self.payload: int = data[1] & 0b01111111
|
||||
|
||||
sequence, timestamp, ssrc = self._hstruct.unpack_from(data)
|
||||
self.sequence = sequence
|
||||
self.timestamp = timestamp
|
||||
self.ssrc = ssrc
|
||||
|
||||
self.csrcs: tuple[int, ...] = ()
|
||||
self.extension = None
|
||||
self.extension_data: dict[int, bytes] = {}
|
||||
|
||||
self.header = data[:12]
|
||||
self.data = data[12:]
|
||||
self.decrypted_data: bytes | None = None
|
||||
|
||||
self.nonce: bytes = b""
|
||||
self._rtpsize: bool = False
|
||||
|
||||
if self.cc:
|
||||
fmt = ">%sI" % self.cc
|
||||
offset = struct.calcsize(fmt) + 12
|
||||
self.csrcs = struct.unpack(fmt, data[12:offset])
|
||||
self.data = data[offset:]
|
||||
|
||||
def adjust_rtpsize(self) -> None:
|
||||
"""Automatically adjusts this packet header and data based on the rtpsize format."""
|
||||
|
||||
self._rtpsize = True
|
||||
self.nonce = self.data[-4:]
|
||||
|
||||
if not self.extended:
|
||||
self.data = self.data[:-4]
|
||||
return
|
||||
|
||||
self.header += self.data[:4]
|
||||
self.data = self.data[4:-4]
|
||||
|
||||
def update_extended_header(self, data: bytes) -> int:
|
||||
"""Updates the extended header using ``data`` and returns the pd offset."""
|
||||
|
||||
if not self.extended:
|
||||
return 0
|
||||
|
||||
if self._rtpsize:
|
||||
data = self.header[-4:] + data
|
||||
|
||||
if len(data) < 4:
|
||||
return 0
|
||||
|
||||
profile, length = struct.unpack_from(">2sH", data)
|
||||
total_ext_length = length * 4
|
||||
|
||||
if profile == self._ext_magic:
|
||||
self._parse_bede_header(data, length)
|
||||
|
||||
if len(data) >= 4 + total_ext_length:
|
||||
try:
|
||||
values = struct.unpack(">%sI" % length, data[4 : 4 + total_ext_length])
|
||||
self.extension = self._ext_header(profile, length, values)
|
||||
except struct.error:
|
||||
self.extension = self._ext_header(profile, 0, [])
|
||||
|
||||
offset = 4 + total_ext_length
|
||||
|
||||
if self._rtpsize:
|
||||
offset -= 4
|
||||
|
||||
return max(0, min(offset, len(data)))
|
||||
|
||||
def _parse_bede_header(self, data: bytes, length: int) -> None:
|
||||
offset = 4
|
||||
n = 0
|
||||
|
||||
max_bytes = length * 4 + 4
|
||||
|
||||
while n < length:
|
||||
if offset >= len(data) or offset >= max_bytes:
|
||||
break
|
||||
|
||||
next_byte = data[offset : offset + 1]
|
||||
|
||||
if next_byte == b"\x00":
|
||||
offset += 1
|
||||
continue
|
||||
|
||||
header = struct.unpack(">B", next_byte)[0]
|
||||
el_id = header >> 4
|
||||
el_len = 1 + (header & 0b0000_1111)
|
||||
|
||||
self.extension_data[el_id] = data[offset + 1 : offset + 1 + el_len]
|
||||
offset += 1 + el_len
|
||||
n += 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"<RTPPacket "
|
||||
f"ssrc={self.ssrc} "
|
||||
f"sequence={self.sequence} "
|
||||
f"timestamp={self.timestamp} "
|
||||
f"size={len(self.data)} "
|
||||
f"ext={set(self.extension_data)}"
|
||||
">"
|
||||
)
|
||||
|
||||
|
||||
class RTCPPacket(Packet):
|
||||
_header = struct.Struct(">BBH")
|
||||
_ssrc_fmt = struct.Struct(">I")
|
||||
type = None
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
super().__init__(data)
|
||||
self.length: int
|
||||
head, _, self.length = self._header.unpack_from(data)
|
||||
|
||||
self.version: int = head >> 6
|
||||
self.padding: bool = bool(head & 0b00100000)
|
||||
setattr(self, "report_count", head & 0b00011111)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} version={self.version} padding={self.padding} length={self.length}>"
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data: bytes) -> Packet:
|
||||
_, ptype, _ = cls._header.unpack_from(data)
|
||||
return _rtcp_map[ptype](data)
|
||||
|
||||
|
||||
def _parse_low(x: int, bitlen: int = 32) -> float:
|
||||
return x / 2.0**bitlen
|
||||
|
||||
|
||||
def _to_low(x: float, bitlen: int = 32) -> int:
|
||||
return int(x * 2.0**bitlen)
|
||||
|
||||
|
||||
class SenderReportPacket(RTCPPacket):
|
||||
_info_fmt = struct.Struct(">5I")
|
||||
_report_fmt = struct.Struct(">IB3x4I")
|
||||
_24bit_int_fmt = struct.Struct(">4xI")
|
||||
_info = namedtuple("RRSenderInfo", "ntp_ts rtp_ts packet_count octet_count")
|
||||
_report = namedtuple(
|
||||
"RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr"
|
||||
)
|
||||
type = 200
|
||||
|
||||
if TYPE_CHECKING:
|
||||
report_count: int
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
super().__init__(data)
|
||||
|
||||
self.ssrc = self._ssrc_fmt.unpack_from(data, 4)[0]
|
||||
self.info = self._read_sender_info(data, 8)
|
||||
|
||||
_report = self._report
|
||||
reports: list[_report] = []
|
||||
for x in range(self.report_count):
|
||||
offset = 28 + 24 * x
|
||||
reports.append(self._read_report(data, offset))
|
||||
|
||||
self.reports: tuple[_report, ...] = tuple(reports)
|
||||
self.extension = None
|
||||
if len(data) > 28 + 24 * self.report_count:
|
||||
self.extension = data[28 + 24 * self.report_count :]
|
||||
|
||||
def _read_sender_info(self, data: bytes, offset: int) -> _info:
|
||||
nhigh, nlow, rtp_ts, pcount, ocount = self._info_fmt.unpack_from(data, offset)
|
||||
ntotal = nhigh + _parse_low(nlow)
|
||||
return self._info(ntotal, rtp_ts, pcount, ocount)
|
||||
|
||||
def _read_report(self, data: bytes, offset: int) -> _report:
|
||||
ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset)
|
||||
clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF
|
||||
return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr)
|
||||
|
||||
|
||||
class ReceiverReportPacket(RTCPPacket):
|
||||
_report_fmt = struct.Struct(">IB3x4I")
|
||||
_24bit_int_fmt = struct.Struct(">4xI")
|
||||
_report = namedtuple(
|
||||
"RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr"
|
||||
)
|
||||
type = 201
|
||||
|
||||
reports: tuple[_report, ...]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
report_count: int
|
||||
|
||||
def __init__(self, data: bytes) -> None:
|
||||
super().__init__(data)
|
||||
self.ssrc: int = self._ssrc_fmt.unpack_from(data, 4)[0]
|
||||
|
||||
_report = self._report
|
||||
reports: list[_report] = []
|
||||
for x in range(self.report_count):
|
||||
offset = 8 + 24 * x
|
||||
reports.append(self._read_report(data, offset))
|
||||
|
||||
self.reports = tuple(reports)
|
||||
|
||||
self.extension: bytes | None = None
|
||||
if len(data) > 8 + 24 * self.report_count:
|
||||
self.extension = data[8 + 24 * self.report_count :]
|
||||
|
||||
def _read_report(self, data: bytes, offset: int) -> _report:
|
||||
ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset)
|
||||
clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF
|
||||
return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr)
|
||||
|
||||
|
||||
_rtcp_map = {
|
||||
200: SenderReportPacket,
|
||||
201: ReceiverReportPacket,
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from .reader import AudioReader
|
||||
from .router import PacketRouter, SinkEventRouter
|
||||
|
||||
__all__ = (
|
||||
"AudioReader",
|
||||
"PacketRouter",
|
||||
"SinkEventRouter",
|
||||
)
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from operator import itemgetter
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from ..packets.core import OPUS_SILENCE
|
||||
from ..packets.rtp import ReceiverReportPacket, RTCPPacket, decode
|
||||
from ..utils.dependencies import HAS_DAVEY, HAS_NACL
|
||||
from .router import PacketRouter, SinkEventRouter
|
||||
|
||||
if HAS_DAVEY:
|
||||
import davey
|
||||
|
||||
if HAS_NACL:
|
||||
import nacl.secret
|
||||
from nacl.exceptions import CryptoError
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.member import Member
|
||||
from discord.sinks import Sink
|
||||
from discord.types.voice import SupportedModes
|
||||
|
||||
from ..client import VoiceClient
|
||||
from ..packets import RTPPacket
|
||||
|
||||
AfterCallback = Callable[[Exception | None], Any]
|
||||
DecryptRTP = Callable[[RTPPacket], bytes]
|
||||
DecryptRTCP = Callable[[bytes], bytes]
|
||||
SpeakingEvent = Literal["member_speaking_start", "member_speaking_stop"]
|
||||
EncryptionBox = nacl.secret.SecretBox | nacl.secret.Aead
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ("AudioReader",)
|
||||
|
||||
|
||||
def is_rtcp(data: bytes) -> bool:
|
||||
return 200 <= data[1] <= 204
|
||||
|
||||
|
||||
class AudioReader:
|
||||
def __init__(
|
||||
self,
|
||||
sink: Sink,
|
||||
client: VoiceClient,
|
||||
*,
|
||||
after: AfterCallback | None = None,
|
||||
start: bool = False,
|
||||
) -> None:
|
||||
if after is not None and not callable(after):
|
||||
raise TypeError(
|
||||
f"expected a callable for the 'after' parameter, got {after.__class__.__name__!r} instead"
|
||||
)
|
||||
|
||||
self.sink: Sink = sink
|
||||
self.client: VoiceClient = client
|
||||
self.after: AfterCallback | None = after
|
||||
|
||||
# self.sink._client = client
|
||||
|
||||
self.active: bool = False
|
||||
self.error: Exception | None = None
|
||||
self.packet_router: PacketRouter = PacketRouter(self.sink, self)
|
||||
self.event_router: SinkEventRouter = SinkEventRouter(self.sink, self)
|
||||
self.decryptor: PacketDecryptor = PacketDecryptor(
|
||||
client.mode, bytes(client.secret_key), client
|
||||
)
|
||||
self.speaking_timer: SpeakingTimer = SpeakingTimer(self)
|
||||
self.keep_alive: UDPKeepAlive = UDPKeepAlive(client)
|
||||
|
||||
if start:
|
||||
self.start()
|
||||
|
||||
def is_listening(self) -> bool:
|
||||
return self.active
|
||||
|
||||
def update_secret_key(self, secret_key: bytes) -> None:
|
||||
self.decryptor.update_secret_key(secret_key)
|
||||
|
||||
def start(self) -> None:
|
||||
if self.active:
|
||||
_log.debug("Reader is already running", exc_info=True)
|
||||
return
|
||||
|
||||
self.client._connection.add_socket_listener(self.callback)
|
||||
self.speaking_timer.start()
|
||||
self.event_router.start()
|
||||
self.packet_router.start()
|
||||
self.keep_alive.start()
|
||||
self.active = True
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self.active:
|
||||
_log.debug("Reader is not active")
|
||||
return
|
||||
|
||||
self.client._connection.remove_socket_listener(self.callback)
|
||||
self.speaking_timer.notify()
|
||||
self._stop()
|
||||
self.active = False
|
||||
|
||||
def _stop(self) -> None:
|
||||
try:
|
||||
if self.packet_router.is_alive():
|
||||
self.packet_router.stop()
|
||||
except Exception as exc:
|
||||
self.error = exc
|
||||
_log.exception("An error ocurred while stopping packet router.")
|
||||
|
||||
try:
|
||||
self.event_router.stop()
|
||||
except Exception as exc:
|
||||
self.error = exc
|
||||
_log.exception("An error ocurred while stopping event router.")
|
||||
|
||||
self.speaking_timer.stop()
|
||||
self.keep_alive.stop()
|
||||
|
||||
if self.after:
|
||||
try:
|
||||
self.after(self.error)
|
||||
except Exception:
|
||||
_log.exception(
|
||||
"An error ocurred while calling the after callback on audio reader"
|
||||
)
|
||||
|
||||
"""for sink in self.sink.root.walk_children(with_self=True):
|
||||
try:
|
||||
sink.cleanup()
|
||||
except Exception as exc:
|
||||
_log.exception("Error calling cleanup() for %s", sink, exc_info=exc)"""
|
||||
|
||||
def set_sink(self, sink: Sink) -> Sink:
|
||||
old_sink = self.sink
|
||||
# old_sink._client = None
|
||||
# sink._client = self.client
|
||||
self.packet_router.set_sink(sink)
|
||||
self.sink = sink
|
||||
return old_sink
|
||||
|
||||
def _is_ip_discovery_packet(self, data: bytes) -> bool:
|
||||
return len(data) == 74 and data[1] == 0x02
|
||||
|
||||
def callback(self, packet_data: bytes) -> None:
|
||||
|
||||
packet = rtp_packet = rtcp_packet = None
|
||||
|
||||
try:
|
||||
if not is_rtcp(packet_data):
|
||||
packet = rtp_packet = decode(packet_data)
|
||||
packet.decrypted_data = self.decryptor.decrypt_rtp(packet) # type: ignore
|
||||
else:
|
||||
packet = rtcp_packet = decode(packet_data)
|
||||
|
||||
if not isinstance(packet, ReceiverReportPacket):
|
||||
_log.info(
|
||||
"Received unexpected rtcp packet type=%s, %s",
|
||||
packet.type,
|
||||
type(packet),
|
||||
)
|
||||
except CryptoError as exc:
|
||||
_log.error("CryptoError while decoding a voice packet", exc_info=exc)
|
||||
return
|
||||
except Exception as exc:
|
||||
if self._is_ip_discovery_packet(packet_data):
|
||||
_log.debug("Received an IP Discovery Packet, ignoring...")
|
||||
return
|
||||
_log.exception(
|
||||
"An exception ocurred while decoding voice packets", exc_info=exc
|
||||
)
|
||||
finally:
|
||||
if self.error:
|
||||
_log.debug("Callback errored out, stopping: %s", self.error)
|
||||
self.stop()
|
||||
return
|
||||
if not packet:
|
||||
_log.debug("No packet found after callback")
|
||||
return
|
||||
|
||||
if rtcp_packet:
|
||||
self.packet_router.feed_rtcp(rtcp_packet) # type: ignore
|
||||
elif rtp_packet:
|
||||
|
||||
if not rtp_packet.decrypted_data:
|
||||
_log.debug(
|
||||
"No decrypted data for RTP packet, this should be safe to ignore."
|
||||
)
|
||||
return
|
||||
|
||||
ssrc = rtp_packet.ssrc
|
||||
|
||||
if ssrc not in self.client._connection.ssrc_user_map:
|
||||
if rtp_packet.is_silence():
|
||||
return
|
||||
else:
|
||||
_log.info(
|
||||
"Received a packet for unknown SSRC %s: %s", ssrc, rtp_packet
|
||||
)
|
||||
_log.debug(
|
||||
"Current SSRCs: %s", self.client._connection.ssrc_user_map
|
||||
)
|
||||
|
||||
self.speaking_timer.notify(ssrc)
|
||||
|
||||
try:
|
||||
_log.debug("Feeding packet to packet router")
|
||||
self.packet_router.feed_rtp(rtp_packet) # type: ignore
|
||||
except Exception as exc:
|
||||
_log.exception(
|
||||
"An error ocurred while processing RTP packet %s", rtp_packet
|
||||
)
|
||||
self.error = exc
|
||||
self.stop()
|
||||
|
||||
|
||||
class PacketDecryptor:
|
||||
supported_modes: list[SupportedModes] = [
|
||||
"aead_xchacha20_poly1305_rtpsize",
|
||||
"xsalsa20_poly1305",
|
||||
"xsalsa20_poly1305_lite",
|
||||
"xsalsa20_poly1305_suffix",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self, mode: SupportedModes, secret_key: bytes, client: VoiceClient
|
||||
) -> None:
|
||||
self.mode: SupportedModes = mode
|
||||
self.client: VoiceClient = client
|
||||
|
||||
try:
|
||||
self._decryptor_rtp: DecryptRTP = getattr(self, "_decrypt_rtp_" + mode)
|
||||
self._decryptor_rtcp: DecryptRTCP = getattr(self, "_decrypt_rtcp_" + mode)
|
||||
except AttributeError as exc:
|
||||
raise NotImplementedError(mode) from exc
|
||||
|
||||
self.box: EncryptionBox = self._make_box(secret_key)
|
||||
|
||||
def _make_box(self, secret_key: bytes) -> EncryptionBox:
|
||||
if self.mode.startswith("aead"):
|
||||
return nacl.secret.Aead(secret_key)
|
||||
else:
|
||||
return nacl.secret.SecretBox(secret_key)
|
||||
|
||||
"""def decrypt_rtp(self, packet: RTPPacket) -> bytes:
|
||||
state = self.client._connection
|
||||
dave = state.dave_session
|
||||
data = self._decryptor_rtp(packet)
|
||||
|
||||
if dave is not None and dave.ready and packet.ssrc in state.ssrc_user_map:
|
||||
data = dave.decrypt(
|
||||
state.ssrc_user_map[packet.ssrc], davey.MediaType.audio, data
|
||||
)
|
||||
|
||||
if packet.extended:
|
||||
offset = packet.update_extended_header(data)
|
||||
data = data[offset:]
|
||||
|
||||
return data"""
|
||||
|
||||
def decrypt_rtp(self, packet: RTPPacket) -> bytes:
|
||||
state = self.client._connection
|
||||
dave = state.dave_session
|
||||
|
||||
raw_payload = self._decryptor_rtp(packet)
|
||||
|
||||
if dave is not None and dave.ready:
|
||||
uid = state.ssrc_user_map.get(packet.ssrc)
|
||||
if uid:
|
||||
try:
|
||||
decrypted_audio = dave.decrypt(
|
||||
uid,
|
||||
davey.MediaType.audio,
|
||||
raw_payload,
|
||||
)
|
||||
|
||||
if packet.extended:
|
||||
offset = packet.update_extended_header(decrypted_audio)
|
||||
packet.decrypted_data = decrypted_audio[offset:]
|
||||
else:
|
||||
packet.decrypted_data = decrypted_audio
|
||||
except Exception as exc:
|
||||
_log.debug(
|
||||
"Ignoring exception while decoding DAVE packet", exc_info=exc
|
||||
)
|
||||
packet.decrypted_data = OPUS_SILENCE
|
||||
|
||||
return packet.decrypted_data
|
||||
|
||||
def decrypt_rtcp(self, packet: bytes) -> bytes:
|
||||
data = self._decryptor_rtcp(packet)
|
||||
|
||||
# parse the rtcp packet to its respective report type
|
||||
offset = 0
|
||||
|
||||
while offset < len(data):
|
||||
# offset will allow us to read the compund packets
|
||||
current_data = data[offset:]
|
||||
if len(current_data) < 8:
|
||||
break
|
||||
|
||||
p_header = RTCPPacket.from_data(current_data)
|
||||
|
||||
# the sender ssrc will always be at offset 4 of the current packet
|
||||
# doesn't matter if it is a sr or a rr
|
||||
ssrc = p_header.ssrc
|
||||
|
||||
state = self.client._connection
|
||||
dave = state.dave_session
|
||||
|
||||
if dave is not None and dave.ready and ssrc in state.ssrc_user_map:
|
||||
return dave.decrypt(
|
||||
state.ssrc_user_map[ssrc],
|
||||
davey.MediaType.audio,
|
||||
current_data,
|
||||
)
|
||||
return data
|
||||
|
||||
def update_secret_key(self, secret_key: bytes) -> None:
|
||||
self.box = self._make_box(secret_key)
|
||||
|
||||
def _decrypt_rtp_xsalsa20_poly1305(self, packet: RTPPacket) -> bytes:
|
||||
nonce = bytearray(24)
|
||||
nonce[:12] = packet.header
|
||||
result = self.box.decrypt(bytes(packet.data), bytes(nonce))
|
||||
|
||||
if packet.extended:
|
||||
offset = packet.update_extended_header(result)
|
||||
result = result[offset:]
|
||||
|
||||
return result
|
||||
|
||||
def _decrypt_rtcp_xsalsa20_poly1305(self, data: bytes) -> bytes:
|
||||
nonce = bytearray(24)
|
||||
nonce[:8] = data[:8]
|
||||
result = self.box.decrypt(data[8:], bytes(nonce))
|
||||
|
||||
return data[:8] + result
|
||||
|
||||
def _decrypt_rtp_xsalsa20_poly1305_suffix(self, packet: RTPPacket) -> bytes:
|
||||
nonce = packet.data[-24:]
|
||||
voice_data = packet.data[:-24]
|
||||
result = self.box.decrypt(bytes(voice_data), bytes(nonce))
|
||||
|
||||
if packet.extended:
|
||||
offset = packet.update_extended_header(result)
|
||||
result = result[offset:]
|
||||
|
||||
return result
|
||||
|
||||
def _decrypt_rtcp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes:
|
||||
nonce = data[-24:]
|
||||
header = data[:8]
|
||||
result = self.box.decrypt(data[8:-24], nonce)
|
||||
|
||||
return header + result
|
||||
|
||||
def _decrypt_rtp_xsalsa20_poly1305_lite(self, packet: RTPPacket) -> bytes:
|
||||
nonce = bytearray(24)
|
||||
nonce[:4] = packet.data[-4:]
|
||||
voice_data = packet.data[:-4]
|
||||
result = self.box.decrypt(bytes(voice_data), bytes(nonce))
|
||||
|
||||
if packet.extended:
|
||||
offset = packet.update_extended_header(result)
|
||||
result = result[offset:]
|
||||
|
||||
return result
|
||||
|
||||
def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes:
|
||||
nonce = bytearray(24)
|
||||
nonce[:4] = data[-4:]
|
||||
header = data[:8]
|
||||
result = self.box.decrypt(data[8:-4], bytes(nonce))
|
||||
|
||||
return header + result
|
||||
|
||||
def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> bytes:
|
||||
_log.debug(
|
||||
"Decrypting RTP AEAD XChaCha20 Poly1305 RTPSize, has decrypted data?: %s",
|
||||
packet.decrypted_data is not None,
|
||||
)
|
||||
packet.adjust_rtpsize()
|
||||
nonce = packet.nonce + b"\x00" * 20
|
||||
|
||||
assert isinstance(self.box, nacl.secret.Aead)
|
||||
|
||||
try:
|
||||
result = self.box.decrypt(
|
||||
packet.decrypted_data or packet.data,
|
||||
bytes(packet.header),
|
||||
nonce,
|
||||
)
|
||||
except Exception as exc:
|
||||
_log.error("Critical error at AEAD: %s", exc)
|
||||
raise CryptoError(exc)
|
||||
|
||||
if packet.extended:
|
||||
packet.update_extended_header(result)
|
||||
|
||||
return result[8:]
|
||||
|
||||
def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes:
|
||||
_log.debug("Decrypting RTCP AEAD XChaCha20 Poly1305 RTPSize")
|
||||
nonce = bytearray(24)
|
||||
nonce[:4] = data[-4:]
|
||||
header = data[:8]
|
||||
|
||||
assert isinstance(self.box, nacl.secret.Aead)
|
||||
result = self.box.decrypt(data[8:-4], bytes(header), bytes(nonce))
|
||||
|
||||
return header + result
|
||||
|
||||
|
||||
class SpeakingTimer(threading.Thread):
|
||||
def __init__(self, reader: AudioReader) -> None:
|
||||
super().__init__(
|
||||
daemon=True,
|
||||
name=f"voice-receiver-speaking-timer:{id(self):#x}",
|
||||
)
|
||||
|
||||
self.reader: AudioReader = reader
|
||||
self.client: VoiceClient = reader.client
|
||||
self.speaking_timeout_delay: float = 0.2
|
||||
self.last_speaking_state: dict[int, bool] = {}
|
||||
self.speaking_cache: dict[int, float] = {}
|
||||
self.speaking_timer_event: threading.Event = threading.Event()
|
||||
self._end_thread: threading.Event = threading.Event()
|
||||
|
||||
def _lookup_member(self, ssrc: int) -> Member | None:
|
||||
id = self.client._connection.ssrc_user_map.get(ssrc)
|
||||
if not self.client.guild:
|
||||
return None
|
||||
return self.client.guild.get_member(id) if id else None
|
||||
|
||||
def maybe_dispatch_speaking_start(self, ssrc: int) -> None:
|
||||
tlast = self.speaking_cache.get(ssrc)
|
||||
if tlast is None or tlast + self.speaking_timeout_delay < time.perf_counter():
|
||||
self.dispatch("member_speaking_start", ssrc)
|
||||
|
||||
def dispatch(self, event: SpeakingEvent, ssrc: int) -> None:
|
||||
member = self._lookup_member(ssrc)
|
||||
if not member:
|
||||
return None
|
||||
self.client._dispatch_sink(event, member)
|
||||
|
||||
def notify(self, ssrc: int | None = None) -> None:
|
||||
if ssrc is not None:
|
||||
self.last_speaking_state[ssrc] = True
|
||||
self.maybe_dispatch_speaking_start(ssrc)
|
||||
self.speaking_cache[ssrc] = time.perf_counter()
|
||||
|
||||
self.speaking_timer_event.set()
|
||||
self.speaking_timer_event.clear()
|
||||
|
||||
def drop_ssrc(self, ssrc: int) -> None:
|
||||
self.speaking_cache.pop(ssrc, None)
|
||||
state = self.last_speaking_state.pop(ssrc, None)
|
||||
if state:
|
||||
self.dispatch("member_speaking_stop", ssrc)
|
||||
self.notify()
|
||||
|
||||
def get_speaking(self, ssrc: int) -> bool | None:
|
||||
return self.last_speaking_state.get(ssrc)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._end_thread.set()
|
||||
self.notify()
|
||||
|
||||
def run(self) -> None:
|
||||
_i1 = itemgetter(1)
|
||||
|
||||
def get_next_entry():
|
||||
cache = sorted(self.speaking_cache.items(), key=_i1)
|
||||
for ssrc, tlast in cache:
|
||||
if self.last_speaking_state.get(ssrc):
|
||||
return ssrc, tlast
|
||||
return None, None
|
||||
|
||||
self.speaking_timer_event.wait()
|
||||
while not self._end_thread.is_set():
|
||||
if not self.speaking_cache:
|
||||
self.speaking_timer_event.wait()
|
||||
|
||||
tnow = time.perf_counter()
|
||||
ssrc, tlast = get_next_entry()
|
||||
|
||||
if ssrc is None or tlast is None:
|
||||
self.speaking_timer_event.wait()
|
||||
continue
|
||||
|
||||
self.speaking_timer_event.wait(tlast + self.speaking_timeout_delay - tnow)
|
||||
|
||||
if time.perf_counter() < tlast + self.speaking_timeout_delay:
|
||||
continue
|
||||
|
||||
self.dispatch("member_speaking_stop", ssrc)
|
||||
self.last_speaking_state[ssrc] = False
|
||||
|
||||
|
||||
class UDPKeepAlive(threading.Thread):
|
||||
delay: int = 5000
|
||||
|
||||
def __init__(self, client: VoiceClient) -> None:
|
||||
super().__init__(
|
||||
daemon=True,
|
||||
name=f"voice-receiver-udp-keep-alive:{id(self):#x}",
|
||||
)
|
||||
|
||||
self.client: VoiceClient = client
|
||||
self.last_time: float = 0
|
||||
self.counter: int = 0
|
||||
self._end_thread: threading.Event = threading.Event()
|
||||
|
||||
def run(self) -> None:
|
||||
self.client.wait_until_connected()
|
||||
|
||||
while not self._end_thread.is_set():
|
||||
vc = self.client
|
||||
|
||||
try:
|
||||
packet = self.counter.to_bytes(8, "big")
|
||||
except OverflowError:
|
||||
self.counter = 0
|
||||
continue
|
||||
|
||||
try:
|
||||
vc._connection.socket.sendto(
|
||||
packet, (vc._connection.endpoint_ip, vc._connection.voice_port)
|
||||
)
|
||||
except Exception as exc:
|
||||
_log.debug(
|
||||
"Error while sending udp keep alive to socket %s at %s:%s",
|
||||
vc._connection.socket,
|
||||
vc._connection.endpoint_ip,
|
||||
vc._connection.voice_port,
|
||||
exc_info=exc,
|
||||
)
|
||||
vc.wait_until_connected()
|
||||
if vc.is_connected():
|
||||
continue
|
||||
break
|
||||
else:
|
||||
self.counter += 1
|
||||
time.sleep(self.delay)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._end_thread.set()
|
||||
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from discord.opus import PacketDecoder
|
||||
|
||||
from ..utils.multidataevent import MultiDataEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.sinks import Sink
|
||||
|
||||
from ..packets import RTCPPacket, RTPPacket
|
||||
from .reader import AudioReader
|
||||
|
||||
EventCB = Callable[..., Any]
|
||||
EventData = tuple[str, tuple[Any, ...], dict[str, Any]]
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PacketRouter(threading.Thread):
|
||||
def __init__(self, sink: Sink, reader: AudioReader) -> None:
|
||||
super().__init__(
|
||||
daemon=True,
|
||||
name=f"voice-receiver-packet-router:{id(self):#x}",
|
||||
)
|
||||
|
||||
self.sink: Sink = sink
|
||||
self.decoders: dict[int, PacketDecoder] = {}
|
||||
self.reader: AudioReader = reader
|
||||
self.waiter: MultiDataEvent[PacketDecoder] = MultiDataEvent()
|
||||
|
||||
self._lock: threading.RLock = threading.RLock()
|
||||
self._end_thread: threading.Event = threading.Event()
|
||||
self._dropped_ssrcs: deque[int] = deque(maxlen=16)
|
||||
|
||||
def feed_rtp(self, packet: RTPPacket) -> None:
|
||||
if packet.ssrc in self._dropped_ssrcs:
|
||||
_log.debug("Ignoring packet from dropped ssrc %s", packet.ssrc)
|
||||
|
||||
with self._lock:
|
||||
decoder = self.get_decoder(packet.ssrc)
|
||||
if decoder is not None:
|
||||
decoder.push_packet(packet)
|
||||
|
||||
def feed_rtcp(self, packet: RTCPPacket) -> None:
|
||||
guild = self.sink.client.guild if self.sink.client else None
|
||||
event_router = self.reader.event_router
|
||||
event_router.dispatch("rtcp_packet", packet, guild)
|
||||
|
||||
def get_decoder(self, ssrc: int) -> PacketDecoder | None:
|
||||
with self._lock:
|
||||
decoder = self.decoders.get(ssrc)
|
||||
if decoder is None:
|
||||
decoder = self.decoders[ssrc] = PacketDecoder(self, ssrc)
|
||||
return decoder
|
||||
|
||||
def set_sink(self, sink: Sink) -> None:
|
||||
with self._lock:
|
||||
self.sink = sink
|
||||
|
||||
def set_user_id(self, ssrc: int, user_id: int) -> None:
|
||||
with self._lock:
|
||||
if ssrc in self._dropped_ssrcs:
|
||||
self._dropped_ssrcs.remove(ssrc)
|
||||
|
||||
decoder = self.decoders.get(ssrc)
|
||||
if decoder is not None:
|
||||
decoder.set_user_id(user_id)
|
||||
|
||||
def destroy_decoder(self, ssrc: int) -> None:
|
||||
with self._lock:
|
||||
decoder = self.decoders.pop(ssrc, None)
|
||||
if decoder is not None:
|
||||
self._dropped_ssrcs.append(ssrc)
|
||||
decoder.destroy()
|
||||
|
||||
def destroy_all_decoders(self) -> None:
|
||||
with self._lock:
|
||||
for ssrc in self.decoders.keys():
|
||||
self.destroy_decoder(ssrc)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._end_thread.set()
|
||||
self.waiter.notify()
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
self._do_run()
|
||||
except Exception as exc:
|
||||
_log.exception("Error in %s loop", self)
|
||||
self.reader.error = exc
|
||||
finally:
|
||||
self.reader.client.stop_recording()
|
||||
self.waiter.clear()
|
||||
|
||||
def _do_run(self) -> None:
|
||||
while not self._end_thread.is_set():
|
||||
self.waiter.wait()
|
||||
|
||||
with self._lock:
|
||||
for decoder in self.waiter.items:
|
||||
data = decoder.pop_data()
|
||||
if data is not None:
|
||||
self.sink.write(data, data.source)
|
||||
|
||||
|
||||
class SinkEventRouter(threading.Thread):
|
||||
def __init__(self, sink: Sink, reader: AudioReader) -> None:
|
||||
super().__init__(
|
||||
daemon=True, name=f"voice-receiver-sink-event-router:{id(self):#x}"
|
||||
)
|
||||
|
||||
self.sink: Sink = sink
|
||||
self.reader: AudioReader = reader
|
||||
|
||||
self._event_listeners: dict[str, list[EventCB]] = {}
|
||||
self._buffer: queue.SimpleQueue[EventData] = queue.SimpleQueue()
|
||||
self._lock = threading.RLock()
|
||||
self._end_thread = threading.Event()
|
||||
|
||||
self.register_events()
|
||||
|
||||
def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None:
|
||||
_log.debug("Dispatch voice event %s", event)
|
||||
self._buffer.put_nowait((event, args, kwargs))
|
||||
|
||||
def set_sink(self, sink: Sink) -> None:
|
||||
with self._lock:
|
||||
self.unregister_events()
|
||||
self.sink = sink
|
||||
self.register_events()
|
||||
|
||||
def register_events(self) -> None:
|
||||
with self._lock:
|
||||
self._register_listeners(self.sink)
|
||||
for child in self.sink.walk_children():
|
||||
self._register_listeners(child)
|
||||
|
||||
def unregister_events(self) -> None:
|
||||
with self._lock:
|
||||
self._unregister_listeners(self.sink)
|
||||
for child in self.sink.walk_children():
|
||||
self._unregister_listeners(child)
|
||||
|
||||
def _register_listeners(self, sink: Sink) -> None:
|
||||
_log.debug("Registering events for %s: %s", sink, sink.__sink_listeners__)
|
||||
|
||||
for name, method_name in sink.__sink_listeners__:
|
||||
func = getattr(sink, method_name)
|
||||
_log.debug("Registering event: %r (callback at %r)", name, method_name)
|
||||
|
||||
if name in self._event_listeners:
|
||||
self._event_listeners[name].append(func)
|
||||
else:
|
||||
self._event_listeners[name] = [func]
|
||||
|
||||
def _unregister_listeners(self, sink: Sink) -> None:
|
||||
for name, method_name in sink.__sink_listeners__:
|
||||
func = getattr(sink, method_name)
|
||||
|
||||
if name in self._event_listeners:
|
||||
try:
|
||||
self._event_listeners[name].remove(func)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _dispatch_to_listeners(self, event: str, *args: Any, **kwargs: Any) -> None:
|
||||
for listener in self._event_listeners.get(f"on_{event}", []):
|
||||
try:
|
||||
listener(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
_log.exception(
|
||||
"Unhandled exception while dispatching event %s (args: %s; kwargs: %s)",
|
||||
event,
|
||||
args,
|
||||
kwargs,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._end_thread.set()
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
self._do_run()
|
||||
except Exception as exc:
|
||||
_log.exception("Error in sink event router", exc_info=exc)
|
||||
self.reader.error = exc
|
||||
self.reader.client.stop_recording()
|
||||
|
||||
def _do_run(self) -> None:
|
||||
while not self._end_thread.is_set():
|
||||
try:
|
||||
event, args, kwargs = self._buffer.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
else:
|
||||
with self._lock:
|
||||
with self.reader.packet_router._lock:
|
||||
self._dispatch_to_listeners(event, *args, **kwargs)
|
||||
@@ -0,0 +1,983 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import select
|
||||
import socket
|
||||
import threading
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from discord import utils
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.errors import ConnectionClosed
|
||||
from discord.voice.utils.dependencies import DAVE_PROTOCOL_VERSION, HAS_DAVEY
|
||||
|
||||
from .enums import ConnectionFlowState, OpCodes
|
||||
from .gateway import VoiceWebSocket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord import abc
|
||||
from discord.guild import Guild
|
||||
from discord.member import VoiceState
|
||||
from discord.raw_models import RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent
|
||||
from discord.state import ConnectionState
|
||||
from discord.types.voice import SupportedModes
|
||||
from discord.user import ClientUser
|
||||
|
||||
from .client import VoiceClient
|
||||
|
||||
MISSING = utils.MISSING
|
||||
SocketReaderCallback = Callable[[bytes], Any]
|
||||
_log = logging.getLogger(__name__)
|
||||
_recv_log = logging.getLogger("discord.voice.receiver")
|
||||
|
||||
if HAS_DAVEY:
|
||||
import davey
|
||||
|
||||
|
||||
class SocketReader(threading.Thread):
|
||||
def __init__(
|
||||
self,
|
||||
state: VoiceConnectionState,
|
||||
name: str,
|
||||
buffer_size: int,
|
||||
*,
|
||||
start_paused: bool = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
daemon=True,
|
||||
name=name,
|
||||
)
|
||||
|
||||
self.buffer_size: int = buffer_size
|
||||
self.state: VoiceConnectionState = state
|
||||
self.start_paused: bool = start_paused
|
||||
self._callbacks: list[SocketReaderCallback] = []
|
||||
self._running: threading.Event = threading.Event()
|
||||
self._end: threading.Event = threading.Event()
|
||||
self._idle_paused: bool = True
|
||||
self._started: threading.Event = threading.Event()
|
||||
self._warned_wait: bool = False
|
||||
|
||||
def is_running(self) -> bool:
|
||||
return self._started.is_set()
|
||||
|
||||
def register(self, callback: SocketReaderCallback) -> None:
|
||||
self._callbacks.append(callback)
|
||||
if self._idle_paused:
|
||||
self._idle_paused = False
|
||||
self._running.set()
|
||||
|
||||
def unregister(self, callback: SocketReaderCallback) -> None:
|
||||
try:
|
||||
self._callbacks.remove(callback)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
if not self._callbacks and self._running.is_set():
|
||||
self._idle_paused = True
|
||||
self._running.clear()
|
||||
|
||||
def pause(self) -> None:
|
||||
self._idle_paused = False
|
||||
self._running.clear()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
return self._idle_paused or (
|
||||
not self._running.is_set() and not self._end.is_set()
|
||||
)
|
||||
|
||||
def resume(self, *, force: bool = False) -> None:
|
||||
if self._running.is_set():
|
||||
return
|
||||
|
||||
if not force and not self._callbacks:
|
||||
self._idle_paused = True
|
||||
return
|
||||
|
||||
self._idle_paused = False
|
||||
self._running.set()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._started.clear()
|
||||
self._end.set()
|
||||
self._running.set()
|
||||
|
||||
def run(self) -> None:
|
||||
self._started.set()
|
||||
self._end.clear()
|
||||
self._running.set()
|
||||
|
||||
if self.start_paused:
|
||||
self.pause()
|
||||
|
||||
try:
|
||||
self._do_run()
|
||||
except Exception:
|
||||
_log.exception(
|
||||
"An error ocurred while running the socket reader %s",
|
||||
self.name,
|
||||
)
|
||||
finally:
|
||||
self.stop()
|
||||
self._started.clear()
|
||||
self._running.clear()
|
||||
self._callbacks.clear()
|
||||
|
||||
def _do_run(self) -> None:
|
||||
while not self._end.is_set():
|
||||
if not self._running.is_set():
|
||||
if not self._warned_wait:
|
||||
_log.warning(
|
||||
"Socket reader %s is waiting to be set as running", self.name
|
||||
)
|
||||
self._warned_wait = True
|
||||
self._running.wait()
|
||||
continue
|
||||
|
||||
if self._warned_wait:
|
||||
_log.info("Socket reader %s was set as running", self.name)
|
||||
self._warned_wait = False
|
||||
|
||||
try:
|
||||
readable, _, _ = select.select([self.state.socket], [], [], 30)
|
||||
except (ValueError, TypeError, OSError) as e:
|
||||
_log.debug(
|
||||
"Select error handling socket in reader, this should be safe to ignore: %s: %s",
|
||||
e.__class__.__name__,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
if not readable:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = self.state.socket.recv(self.buffer_size)
|
||||
except OSError:
|
||||
_log.debug(
|
||||
"Error reading from socket in %s, this should be safe to ignore",
|
||||
self,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
for cb in self._callbacks:
|
||||
try:
|
||||
task = asyncio.ensure_future(
|
||||
self.state.loop.create_task(
|
||||
utils.maybe_coroutine(cb, data)
|
||||
),
|
||||
loop=self.state.loop,
|
||||
)
|
||||
self.state._dispatch_task_set.add(task)
|
||||
task.add_done_callback(self.state._dispatch_task_set.discard)
|
||||
except Exception:
|
||||
_log.exception(
|
||||
"Error while calling %s in %s",
|
||||
cb,
|
||||
self,
|
||||
)
|
||||
|
||||
|
||||
class SocketEventReader(SocketReader):
|
||||
def __init__(
|
||||
self, state: VoiceConnectionState, *, start_paused: bool = True
|
||||
) -> None:
|
||||
super().__init__(
|
||||
state,
|
||||
f"voice-socket-event-reader:{id(self):#x}",
|
||||
2048,
|
||||
start_paused=start_paused,
|
||||
)
|
||||
|
||||
|
||||
class VoiceConnectionState:
|
||||
def __init__(
|
||||
self,
|
||||
client: VoiceClient,
|
||||
*,
|
||||
hook: (
|
||||
Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None
|
||||
) = None,
|
||||
) -> None:
|
||||
self.client: VoiceClient = client
|
||||
self.hook = hook
|
||||
self.loop: asyncio.AbstractEventLoop = client.loop
|
||||
|
||||
self.timeout: float = 30.0
|
||||
self.reconnect: bool = True
|
||||
self.self_deaf: bool = False
|
||||
self.self_mute: bool = False
|
||||
self.endpoint: str | None = None
|
||||
self.endpoint_ip: str | None = None
|
||||
self.server_id: int | None = None
|
||||
self.ip: str | None = None
|
||||
self.port: int | None = None
|
||||
self.voice_port: int | None = None
|
||||
self.secret_key: list[int] = MISSING
|
||||
self.ssrc: int = MISSING
|
||||
self.mode: SupportedModes = MISSING
|
||||
self.socket: socket.socket = MISSING
|
||||
self.ws: VoiceWebSocket = MISSING
|
||||
self.session_id: str | None = None
|
||||
self.token: str | None = None
|
||||
|
||||
self._connection: ConnectionState = client._state
|
||||
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
|
||||
self._expecting_disconnect: bool = False
|
||||
self._connected = threading.Event()
|
||||
self._state_event = asyncio.Event()
|
||||
self._disconnected = asyncio.Event()
|
||||
self._runner: asyncio.Task[None] | None = None
|
||||
self._connector: asyncio.Task[None] | None = None
|
||||
self._socket_reader = SocketEventReader(self)
|
||||
self._socket_reader.start()
|
||||
self.recording_done_callbacks: list[
|
||||
tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]]
|
||||
] = []
|
||||
self._dispatch_task_set: set[asyncio.Task] = set()
|
||||
|
||||
if not self._connection.self_id:
|
||||
raise RuntimeError("client self ID is not set")
|
||||
if not self.channel_id:
|
||||
raise RuntimeError("client channel being connected to is not set")
|
||||
|
||||
self.dave_session: davey.DaveSession | None = None
|
||||
self.dave_protocol_version: int = 0
|
||||
self.dave_pending_transition: dict[str, int] | None = None
|
||||
self.downgraded_dave = False
|
||||
|
||||
@property
|
||||
def user_ssrc_map(self) -> dict[int, int]:
|
||||
return self.client._id_to_ssrc
|
||||
|
||||
@property
|
||||
def ssrc_user_map(self) -> dict[int, int]:
|
||||
return {v: k for k, v in self.user_ssrc_map.items()}
|
||||
|
||||
@property
|
||||
def max_dave_proto_version(self) -> int:
|
||||
return DAVE_PROTOCOL_VERSION
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionFlowState:
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, state: ConnectionFlowState) -> None:
|
||||
if state is not self._state:
|
||||
_log.debug("State changed from %s to %s", self._state.name, state.name)
|
||||
|
||||
self._state = state
|
||||
self._state_event.set()
|
||||
self._state_event.clear()
|
||||
|
||||
if state is ConnectionFlowState.connected:
|
||||
self._connected.set()
|
||||
else:
|
||||
self._connected.clear()
|
||||
|
||||
@property
|
||||
def guild(self) -> Guild:
|
||||
return self.client.guild
|
||||
|
||||
@property
|
||||
def user(self) -> ClientUser:
|
||||
return self.client.user
|
||||
|
||||
@property
|
||||
def channel_id(self) -> int | None:
|
||||
return self.client.channel is not None and self.client.channel.id
|
||||
|
||||
@property
|
||||
def guild_id(self) -> int:
|
||||
return self.guild.id
|
||||
|
||||
@property
|
||||
def supported_modes(self) -> tuple[SupportedModes, ...]:
|
||||
return self.client.supported_modes
|
||||
|
||||
@property
|
||||
def self_voice_state(self) -> VoiceState | None:
|
||||
return self.guild.me.voice
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self.state is ConnectionFlowState.connected
|
||||
|
||||
def _inside_runner(self) -> bool:
|
||||
return self._runner is not None and asyncio.current_task() == self._runner
|
||||
|
||||
async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None:
|
||||
channel_id = data.channel_id
|
||||
|
||||
if channel_id is None:
|
||||
self._disconnected.set()
|
||||
|
||||
if self._expecting_disconnect:
|
||||
self._expecting_disconnect = False
|
||||
else:
|
||||
_log.debug("We have been disconnected from voice")
|
||||
await self.disconnect()
|
||||
return
|
||||
|
||||
self.session_id = data.session_id
|
||||
|
||||
if self.state in (
|
||||
ConnectionFlowState.set_guild_voice_state,
|
||||
ConnectionFlowState.got_voice_server_update,
|
||||
):
|
||||
if self.state is ConnectionFlowState.set_guild_voice_state:
|
||||
self.state = ConnectionFlowState.got_voice_state_update
|
||||
|
||||
if channel_id != self.client.channel.id:
|
||||
# moved from channel
|
||||
self._update_voice_channel(channel_id)
|
||||
else:
|
||||
self.state = ConnectionFlowState.got_both_voice_updates
|
||||
return
|
||||
|
||||
if self.state is ConnectionFlowState.connected:
|
||||
self._update_voice_channel(channel_id)
|
||||
|
||||
elif self.state is not ConnectionFlowState.disconnected:
|
||||
if channel_id != self.client.channel.id:
|
||||
_log.info("We were moved from the channel while connecting...")
|
||||
|
||||
self._update_voice_channel(channel_id)
|
||||
await self.soft_disconnect(
|
||||
with_state=ConnectionFlowState.got_voice_state_update
|
||||
)
|
||||
await self.connect(
|
||||
reconnect=self.reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=False,
|
||||
wait=False,
|
||||
)
|
||||
else:
|
||||
_log.debug("Ignoring unexpected VOICE_STATEUPDATE event")
|
||||
|
||||
async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None:
|
||||
previous_token = self.token
|
||||
previous_server_id = self.server_id
|
||||
previous_endpoint = self.endpoint
|
||||
|
||||
self.token = data.token
|
||||
self.server_id = data.guild_id
|
||||
endpoint = data.endpoint
|
||||
|
||||
if self.token is None or endpoint is None:
|
||||
_log.warning(
|
||||
"Awaiting endpoint... This requires waiting. "
|
||||
"If timeout occurred considering raising the timeout and reconnecting."
|
||||
)
|
||||
return
|
||||
|
||||
# strip the prefix off since we add it later
|
||||
self.endpoint = endpoint.removeprefix("wss://")
|
||||
|
||||
if self.state in (
|
||||
ConnectionFlowState.set_guild_voice_state,
|
||||
ConnectionFlowState.got_voice_state_update,
|
||||
):
|
||||
self.endpoint_ip = MISSING
|
||||
self._create_socket()
|
||||
|
||||
if self.state is ConnectionFlowState.set_guild_voice_state:
|
||||
self.state = ConnectionFlowState.got_voice_server_update
|
||||
else:
|
||||
self.state = ConnectionFlowState.got_both_voice_updates
|
||||
|
||||
elif self.state is ConnectionFlowState.connected:
|
||||
_log.debug("Voice server update, closing old voice websocket")
|
||||
await self.ws.close(4014) # 4014 = main gw dropped
|
||||
self.state = ConnectionFlowState.got_voice_server_update
|
||||
|
||||
elif self.state is not ConnectionFlowState.disconnected:
|
||||
if (
|
||||
previous_token == self.token
|
||||
and previous_server_id == self.server_id
|
||||
and previous_endpoint == self.endpoint
|
||||
):
|
||||
return
|
||||
|
||||
_log.debug("Unexpected VOICE_SERVER_UPDATE event received, handling...")
|
||||
|
||||
await self.soft_disconnect(
|
||||
with_state=ConnectionFlowState.got_voice_server_update
|
||||
)
|
||||
await self.connect(
|
||||
reconnect=self.reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=False,
|
||||
wait=False,
|
||||
)
|
||||
self._create_socket()
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
reconnect: bool,
|
||||
timeout: float,
|
||||
self_deaf: bool,
|
||||
self_mute: bool,
|
||||
resume: bool,
|
||||
wait: bool = True,
|
||||
) -> None:
|
||||
if self._connector:
|
||||
self._connector.cancel()
|
||||
self._connector = None
|
||||
|
||||
if self._runner:
|
||||
self._runner.cancel()
|
||||
self._runner = None
|
||||
|
||||
self.timeout = timeout
|
||||
self.reconnect = reconnect
|
||||
self._connector = self.client.loop.create_task(
|
||||
self._wrap_connect(
|
||||
reconnect,
|
||||
timeout,
|
||||
self_deaf,
|
||||
self_mute,
|
||||
resume,
|
||||
),
|
||||
name=f"voice-connector:{id(self):#x}",
|
||||
)
|
||||
|
||||
if wait:
|
||||
await self._connector
|
||||
|
||||
async def _wrap_connect(
|
||||
self,
|
||||
reconnect: bool,
|
||||
timeout: float,
|
||||
self_deaf: bool,
|
||||
self_mute: bool,
|
||||
resume: bool,
|
||||
) -> None:
|
||||
try:
|
||||
await self._connect(reconnect, timeout, self_deaf, self_mute, resume)
|
||||
except asyncio.CancelledError:
|
||||
_log.debug("Cancelling voice connection")
|
||||
await self.soft_disconnect()
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
_log.info("Timed out while connecting to voice")
|
||||
await self.disconnect()
|
||||
raise
|
||||
except Exception:
|
||||
_log.exception("Error while connecting to voice... disconnecting")
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
async def _inner_connect(
|
||||
self, reconnect: bool, self_deaf: bool, self_mute: bool, resume: bool
|
||||
) -> None:
|
||||
for i in range(5):
|
||||
_log.info("Starting voice handshake (connection attempt %s)", i + 1)
|
||||
|
||||
await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute)
|
||||
if self.state is ConnectionFlowState.disconnected:
|
||||
self.state = ConnectionFlowState.set_guild_voice_state
|
||||
|
||||
await self._wait_for_state(ConnectionFlowState.got_both_voice_updates)
|
||||
|
||||
_log.info("Voice handshake complete. Endpoint found: %s", self.endpoint)
|
||||
|
||||
try:
|
||||
self.ws = await self._connect_websocket(resume)
|
||||
await self._handshake_websocket()
|
||||
break
|
||||
except ConnectionClosed:
|
||||
if reconnect:
|
||||
wait = 1 + i * 2
|
||||
_log.exception(
|
||||
"Failed to connect to voice... Retrying in %s seconds", wait
|
||||
)
|
||||
await self.disconnect(cleanup=False)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
else:
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
async def _connect(
|
||||
self,
|
||||
reconnect: bool,
|
||||
timeout: float,
|
||||
self_deaf: bool,
|
||||
self_mute: bool,
|
||||
resume: bool,
|
||||
) -> None:
|
||||
_log.info(f"Connecting to voice {self.client.channel.id}")
|
||||
|
||||
await asyncio.wait_for(
|
||||
self._inner_connect(
|
||||
reconnect=reconnect,
|
||||
self_deaf=self_deaf,
|
||||
self_mute=self_mute,
|
||||
resume=resume,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
_log.info("Voice connection completed")
|
||||
|
||||
if not self._runner:
|
||||
self._runner = self.client.loop.create_task(
|
||||
self._poll_ws(reconnect),
|
||||
name=f"voice-ws-poller:{id(self):#x}",
|
||||
)
|
||||
|
||||
async def disconnect(
|
||||
self, *, force: bool = True, cleanup: bool = True, wait: bool = False
|
||||
) -> None:
|
||||
if not force and not self.is_connected():
|
||||
return
|
||||
|
||||
_log.debug(
|
||||
"Attempting a voice disconnect for channel %s (guild %s)",
|
||||
self.channel_id,
|
||||
self.guild_id,
|
||||
)
|
||||
try:
|
||||
await self._voice_disconnect()
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
_log.debug(
|
||||
"Ignoring exception while disconnecting from voice", exc_info=True
|
||||
)
|
||||
finally:
|
||||
self.state = ConnectionFlowState.disconnected
|
||||
self._socket_reader.pause()
|
||||
|
||||
if cleanup:
|
||||
self._socket_reader.stop()
|
||||
self.client.stop()
|
||||
|
||||
self._connected.set()
|
||||
self._connected.clear()
|
||||
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
|
||||
self.ip = MISSING
|
||||
self.port = MISSING
|
||||
|
||||
if wait and not self._inside_runner():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._disconnected.wait(), timeout=self.timeout
|
||||
)
|
||||
except TimeoutError:
|
||||
_log.debug("Timed out waiting for voice disconnect confirmation")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if cleanup:
|
||||
self.client.cleanup()
|
||||
|
||||
async def soft_disconnect(
|
||||
self,
|
||||
*,
|
||||
with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates,
|
||||
) -> None:
|
||||
_log.debug("Soft disconnecting from voice")
|
||||
|
||||
if self._runner:
|
||||
self._runner.cancel()
|
||||
self._runner = None
|
||||
|
||||
try:
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
_log.debug(
|
||||
"Ignoring exception while soft disconnecting from voice", exc_info=True
|
||||
)
|
||||
finally:
|
||||
self.state = with_state
|
||||
self._socket_reader.pause()
|
||||
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
|
||||
self.ip = MISSING
|
||||
self.port = MISSING
|
||||
|
||||
async def move_to(
|
||||
self, channel: abc.Snowflake | None, timeout: float | None
|
||||
) -> None:
|
||||
if channel is None:
|
||||
await self.disconnect(wait=True)
|
||||
return
|
||||
|
||||
if self.client.channel and channel.id == self.client.channel.id:
|
||||
return
|
||||
|
||||
previous_state = self.state
|
||||
await self._move_to(channel)
|
||||
|
||||
last_state = self.state
|
||||
|
||||
try:
|
||||
await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
_log.warning(
|
||||
"Timed out trying to move to channel %s in guild %s",
|
||||
channel.id,
|
||||
self.guild.id,
|
||||
)
|
||||
if self.state is last_state:
|
||||
_log.debug(
|
||||
"Reverting state %s to previous state: %s",
|
||||
last_state.name,
|
||||
previous_state.name,
|
||||
)
|
||||
self.state = previous_state
|
||||
|
||||
def wait_for(
|
||||
self,
|
||||
state: ConnectionFlowState = ConnectionFlowState.connected,
|
||||
timeout: float | None = None,
|
||||
) -> Any:
|
||||
if state is ConnectionFlowState.connected:
|
||||
return self._connected.wait(timeout)
|
||||
return self._wait_for_state(state, timeout=timeout)
|
||||
|
||||
def send_packet(self, packet: bytes) -> None:
|
||||
self.socket.sendall(packet)
|
||||
|
||||
def add_socket_listener(self, callback: SocketReaderCallback) -> None:
|
||||
_log.debug("Registering a socket listener callback %s", callback)
|
||||
self._socket_reader.register(callback)
|
||||
|
||||
def remove_socket_listener(self, callback: SocketReaderCallback) -> None:
|
||||
_log.debug("Unregistering a socket listener callback %s", callback)
|
||||
self._socket_reader.unregister(callback)
|
||||
|
||||
async def _wait_for_state(
|
||||
self,
|
||||
*states: ConnectionFlowState,
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
if not states:
|
||||
raise ValueError
|
||||
|
||||
while True:
|
||||
if self.state in states:
|
||||
return
|
||||
|
||||
_, pending = await asyncio.wait(
|
||||
[
|
||||
asyncio.ensure_future(self._state_event.wait()),
|
||||
],
|
||||
timeout=timeout,
|
||||
)
|
||||
if pending:
|
||||
# if we're here, it means that the state event
|
||||
# has timed out, so just raise the exception
|
||||
raise asyncio.TimeoutError
|
||||
|
||||
async def _voice_connect(
|
||||
self, *, self_deaf: bool = False, self_mute: bool = False
|
||||
) -> None:
|
||||
channel = self.client.channel
|
||||
await channel.guild.change_voice_state(
|
||||
channel=channel, self_deaf=self_deaf, self_mute=self_mute
|
||||
)
|
||||
|
||||
async def _voice_disconnect(self) -> None:
|
||||
_log.info(
|
||||
"Terminating voice handshake for channel %s (guild %s)",
|
||||
self.client.channel.id,
|
||||
self.client.guild.id,
|
||||
)
|
||||
|
||||
self.state = ConnectionFlowState.disconnected
|
||||
await self.client.channel.guild.change_voice_state(
|
||||
channel=None
|
||||
) # pyright: ignore[reportAttributeAccessIssue]
|
||||
self._expecting_disconnect = True
|
||||
self._disconnected.clear()
|
||||
|
||||
async def _connect_websocket(self, resume: bool) -> VoiceWebSocket:
|
||||
seq_ack = -1
|
||||
if self.ws is not MISSING:
|
||||
seq_ack = self.ws.seq_ack
|
||||
ws = await VoiceWebSocket.from_state(
|
||||
self, resume=resume, hook=self.hook, seq_ack=seq_ack
|
||||
)
|
||||
self.state = ConnectionFlowState.websocket_connected
|
||||
return ws
|
||||
|
||||
async def _handshake_websocket(self) -> None:
|
||||
while not self.ip:
|
||||
await self.ws.poll_event()
|
||||
|
||||
self.state = ConnectionFlowState.got_ip_discovery
|
||||
while self.ws.secret_key is None:
|
||||
await self.ws.poll_event()
|
||||
|
||||
self.state = ConnectionFlowState.connected
|
||||
|
||||
def _create_socket(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.setblocking(False)
|
||||
self._socket_reader.resume()
|
||||
|
||||
async def _poll_ws(self, reconnect: bool) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except (ConnectionClosed, asyncio.TimeoutError) as exc:
|
||||
if isinstance(exc, ConnectionClosed):
|
||||
# 1000 - normal closure - not resumable
|
||||
# 4014 - externally disconnected - not resumable
|
||||
# 4015 - voice server crashed - resumable
|
||||
# 4021 - ratelimited, not reconnect - not resumable
|
||||
# 4022 - call terminated, similar to 4014 - not resumable
|
||||
|
||||
if exc.code == 1000:
|
||||
if not self._expecting_disconnect:
|
||||
_log.info(
|
||||
"Disconnecting from voice manually, close code %d",
|
||||
exc.code,
|
||||
)
|
||||
await self.disconnect()
|
||||
break
|
||||
elif exc.code in (4014, 4022):
|
||||
if self._disconnected.is_set():
|
||||
_log.info(
|
||||
"Disconnecting from voice by Discord, close code %d",
|
||||
exc.code,
|
||||
)
|
||||
await self.disconnect()
|
||||
break
|
||||
|
||||
_log.info(
|
||||
"Disconnecting from voice by force... potentially reconnecting..."
|
||||
)
|
||||
successful = await self._potential_reconnect()
|
||||
if not successful:
|
||||
_log.info(
|
||||
"Reconnect was unsuccessful, disconnecting from voice normally"
|
||||
)
|
||||
if self.state is not ConnectionFlowState.disconnected:
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
# we have successfully resumed so just keep polling events
|
||||
continue
|
||||
elif exc.code == 4021:
|
||||
_log.warning(
|
||||
"We are being rate limited while attempting to connect to voice. Disconnecting...",
|
||||
)
|
||||
if self.state is not ConnectionFlowState.disconnected:
|
||||
await self.disconnect()
|
||||
break
|
||||
elif exc.code == 4015:
|
||||
_log.info(
|
||||
"Disconnected from voice due to a Discord-side issue, attempting to reconnect and resume...",
|
||||
)
|
||||
|
||||
try:
|
||||
await self._connect(
|
||||
reconnect=reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=True,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
_log.info(
|
||||
"Could not resume the voice connection... Disconnecting..."
|
||||
)
|
||||
if self.state is not ConnectionFlowState.disconnected:
|
||||
await self.disconnect()
|
||||
break
|
||||
except Exception:
|
||||
_log.exception(
|
||||
"An exception was raised while attempting a reconnect and resume... Disconnecting...",
|
||||
exc_info=True,
|
||||
)
|
||||
if self.state is not ConnectionFlowState.disconnected:
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
_log.info(
|
||||
"Successfully reconnected and resumed the voice connection"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
_log.debug(
|
||||
"Not handling close code %s (%s)",
|
||||
exc.code,
|
||||
exc.reason or "No reason was provided",
|
||||
)
|
||||
|
||||
if not reconnect:
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
_log.exception(
|
||||
"Disconnected from voice... Reconnecting in %.2fs",
|
||||
retry,
|
||||
)
|
||||
await asyncio.sleep(retry)
|
||||
await self.disconnect(cleanup=False)
|
||||
|
||||
try:
|
||||
await self._connect(
|
||||
reconnect=reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=False,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
_log.warning("Could not connect to voice... Retrying...")
|
||||
continue
|
||||
|
||||
async def _potential_reconnect(self) -> bool:
|
||||
try:
|
||||
await self._wait_for_state(
|
||||
ConnectionFlowState.got_voice_server_update,
|
||||
ConnectionFlowState.got_both_voice_updates,
|
||||
ConnectionFlowState.disconnected,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
else:
|
||||
if self.state is ConnectionFlowState.disconnected:
|
||||
return False
|
||||
|
||||
previous_ws = self.ws
|
||||
|
||||
try:
|
||||
self.ws = await self._connect_websocket(False)
|
||||
await self._handshake_websocket()
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
finally:
|
||||
await previous_ws.close()
|
||||
|
||||
async def _move_to(self, channel: abc.Snowflake) -> None:
|
||||
await self.client.channel.guild.change_voice_state(
|
||||
channel=channel
|
||||
) # pyright: ignore[reportAttributeAccessIssue]
|
||||
self.state = ConnectionFlowState.set_guild_voice_state
|
||||
|
||||
def _update_voice_channel(self, channel_id: int | None) -> None:
|
||||
self.client.channel = channel_id and self.guild.get_channel(channel_id) # type: ignore
|
||||
|
||||
async def reinit_dave_session(self) -> None:
|
||||
assert self.channel_id
|
||||
|
||||
if self.dave_protocol_version > 0:
|
||||
if self.dave_session:
|
||||
self.dave_session.reinit(
|
||||
self.dave_protocol_version, self.user.id, self.channel_id
|
||||
)
|
||||
else:
|
||||
self.dave_session = davey.DaveSession(
|
||||
self.dave_protocol_version,
|
||||
self.user.id,
|
||||
self.channel_id,
|
||||
)
|
||||
|
||||
await self.ws.send_as_bytes(
|
||||
OpCodes.mls_key_package,
|
||||
self.dave_session.get_serialized_key_package(),
|
||||
)
|
||||
elif self.dave_session:
|
||||
self.dave_session.reset()
|
||||
self.dave_session.set_passthrough_mode(True, 10)
|
||||
|
||||
async def recover_dave_from_invalid_commit(self, transition: int) -> None:
|
||||
payload = {
|
||||
"op": int(OpCodes.mls_invalid_commit_welcome),
|
||||
"d": {"transition_id": transition},
|
||||
}
|
||||
await self.ws.send_as_json(payload)
|
||||
await self.reinit_dave_session()
|
||||
|
||||
async def execute_dave_transition(self, transition: int) -> None:
|
||||
_log.debug("Executing DAVE transition with id %s", transition)
|
||||
|
||||
if not self.dave_pending_transition:
|
||||
_log.warning(
|
||||
"Attempted to execute a transition without having a pending transition for id %s, "
|
||||
"this is a Discord bug.",
|
||||
transition,
|
||||
)
|
||||
return
|
||||
|
||||
pending_transition = self.dave_pending_transition["transition_id"]
|
||||
pending_proto = self.dave_pending_transition["protocol_version"]
|
||||
|
||||
session = self.dave_session
|
||||
|
||||
if transition == pending_transition:
|
||||
old_version = self.dave_protocol_version
|
||||
self.dave_protocol_version = pending_proto
|
||||
|
||||
if (
|
||||
old_version != self.dave_protocol_version
|
||||
and self.dave_protocol_version == 0
|
||||
):
|
||||
_log.warning(
|
||||
"DAVE was downgraded, voice client non-e2ee session has been deprecated since 2.7"
|
||||
)
|
||||
self.downgraded_dave = True
|
||||
elif transition > 0 and self.downgraded_dave:
|
||||
self.downgraded_dave = False
|
||||
if session:
|
||||
session.set_passthrough_mode(True, 10)
|
||||
_log.info("Upgraded voice session to use DAVE")
|
||||
else:
|
||||
_log.debug(
|
||||
"Received an execute transition id %s when expected was %s, ignoring",
|
||||
transition,
|
||||
pending_proto,
|
||||
)
|
||||
|
||||
self.dave_pending_transition = None
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from typing import Protocol, TypeVar
|
||||
|
||||
from ..packets import Packet
|
||||
from .wrapped import add_wrapped, gap_wrapped
|
||||
|
||||
__all__ = (
|
||||
"Buffer",
|
||||
"JitterBuffer",
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
PacketT = TypeVar("PacketT", bound=Packet)
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Buffer(Protocol[T]):
|
||||
def __len__(self) -> int: ...
|
||||
def push(self, item: T) -> None: ...
|
||||
def pop(self) -> T | None: ...
|
||||
def peek(self) -> T | None: ...
|
||||
def flush(self) -> list[T]: ...
|
||||
def reset(self) -> None: ...
|
||||
|
||||
|
||||
class BaseBuff(Buffer[PacketT]):
|
||||
def __init__(self) -> None:
|
||||
self._buffer: list[PacketT] = []
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._buffer)
|
||||
|
||||
def push(self, item: PacketT) -> None:
|
||||
self._buffer.append(item)
|
||||
|
||||
def pop(self) -> PacketT | None:
|
||||
return self._buffer.pop()
|
||||
|
||||
def peek(self) -> PacketT | None:
|
||||
return self._buffer[-1] if self._buffer else None
|
||||
|
||||
def flush(self) -> list[PacketT]:
|
||||
buf = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
return buf
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer.clear()
|
||||
|
||||
|
||||
class JitterBuffer(BaseBuff[PacketT]):
|
||||
_threshold: int = 10000
|
||||
|
||||
def __init__(
|
||||
self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1
|
||||
) -> None:
|
||||
if max_size < 1:
|
||||
raise ValueError(f"max_size must be greater than 1, not {max_size}")
|
||||
|
||||
if not 0 <= pref_size <= max_size:
|
||||
raise ValueError(f"pref_size must be between 0 and max_size ({max_size})")
|
||||
|
||||
self.max_size: int = max_size
|
||||
self.pref_size: int = pref_size
|
||||
self.prefill: int = prefill
|
||||
self._prefill: int = prefill
|
||||
self._last_tx_seq: int = -1
|
||||
self._has_item: threading.Event = threading.Event()
|
||||
# self._lock: threading.Lock = threading.Lock()
|
||||
self._buffer: list[Packet] = []
|
||||
|
||||
def _push(self, packet: Packet) -> None:
|
||||
heapq.heappush(self._buffer, packet)
|
||||
|
||||
def _pop(self) -> Packet:
|
||||
return heapq.heappop(self._buffer)
|
||||
|
||||
def _get_packet_if_ready(self) -> Packet | None:
|
||||
return self._buffer[0] if len(self._buffer) > self.pref_size else None
|
||||
|
||||
def _pop_if_ready(self) -> Packet | None:
|
||||
return self._pop() if len(self._buffer) > self.pref_size else None
|
||||
|
||||
def _update_has_item(self) -> None:
|
||||
prefilled = self._prefill == 0
|
||||
packet_ready = len(self._buffer) > self.pref_size
|
||||
|
||||
if not prefilled or not packet_ready:
|
||||
self._has_item.clear()
|
||||
return
|
||||
|
||||
next_packet = self._buffer[0]
|
||||
sequential = add_wrapped(self._last_tx_seq, 1) == next_packet.sequence
|
||||
positive_seq = self._last_tx_seq >= 0
|
||||
|
||||
if (
|
||||
(sequential and positive_seq)
|
||||
or not positive_seq
|
||||
or len(self._buffer) >= self.max_size
|
||||
):
|
||||
self._has_item.set()
|
||||
else:
|
||||
self._has_item.clear()
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
while len(self._buffer) > self.max_size:
|
||||
heapq.heappop(self._buffer)
|
||||
|
||||
def push(self, packet: Packet) -> bool:
|
||||
seq = packet.sequence
|
||||
|
||||
if (
|
||||
gap_wrapped(self._last_tx_seq, seq) > self._threshold
|
||||
and self._last_tx_seq != -1
|
||||
):
|
||||
_log.debug("Dropping old packet %s", packet)
|
||||
return False
|
||||
|
||||
self._push(packet)
|
||||
|
||||
if self._prefill > 0:
|
||||
self._prefill -= 1
|
||||
|
||||
self._cleanup()
|
||||
self._update_has_item()
|
||||
return True
|
||||
|
||||
def pop(self, *, timeout: float | None = 0) -> Packet | None:
|
||||
ok = self._has_item.wait(timeout)
|
||||
if not ok:
|
||||
return None
|
||||
|
||||
if self._prefill > 0:
|
||||
return None
|
||||
|
||||
packet = self._pop_if_ready()
|
||||
|
||||
if packet is not None:
|
||||
self._last_tx_seq = packet.sequence
|
||||
|
||||
self._update_has_item()
|
||||
return packet
|
||||
|
||||
def peek(self, *, all: bool = False) -> Packet | None:
|
||||
if not self._buffer:
|
||||
return None
|
||||
|
||||
if all:
|
||||
return self._buffer[0]
|
||||
else:
|
||||
return self._get_packet_if_ready()
|
||||
|
||||
def peek_next(self) -> Packet | None:
|
||||
packet = self.peek(all=True)
|
||||
|
||||
if packet is None:
|
||||
return None
|
||||
|
||||
if (
|
||||
packet.sequence == add_wrapped(self._last_tx_seq, 1)
|
||||
or self._last_tx_seq < 0
|
||||
):
|
||||
return packet
|
||||
|
||||
def gap(self) -> int:
|
||||
if self._buffer and self._last_tx_seq > 0:
|
||||
return gap_wrapped(self._last_tx_seq, self._buffer[0].sequence)
|
||||
return 0
|
||||
|
||||
def flush(self) -> list[Packet]:
|
||||
packets = sorted(self._buffer)
|
||||
self._buffer.clear()
|
||||
|
||||
if packets:
|
||||
self._last_tx_seq = packets[-1].sequence
|
||||
|
||||
self._prefill = self.prefill
|
||||
self._has_item.clear()
|
||||
return packets
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer.clear()
|
||||
self._has_item.clear()
|
||||
self._prefill = self.prefill
|
||||
self._last_tx_seq = -1
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
try:
|
||||
import davey
|
||||
except ImportError:
|
||||
HAS_DAVEY = False
|
||||
DAVE_PROTOCOL_VERSION = 0
|
||||
else:
|
||||
HAS_DAVEY = True
|
||||
DAVE_PROTOCOL_VERSION = davey.DAVE_PROTOCOL_VERSION
|
||||
|
||||
try:
|
||||
import nacl.secret
|
||||
import nacl.utils
|
||||
except ImportError:
|
||||
HAS_NACL = False
|
||||
else:
|
||||
HAS_NACL = True
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MultiDataEvent(Generic[T]):
|
||||
"""
|
||||
Something like the inverse of a Condition. A 1-waiting-on-N type of object,
|
||||
with accompanying data object for convenience.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._items: list[T] = []
|
||||
self._ready: threading.Event = threading.Event()
|
||||
|
||||
@property
|
||||
def items(self) -> list[T]:
|
||||
"""A shallow copy of the currently ready objects."""
|
||||
return self._items.copy()
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self._ready.is_set()
|
||||
|
||||
def _check_ready(self) -> None:
|
||||
if self._items:
|
||||
self._ready.set()
|
||||
else:
|
||||
self._ready.clear()
|
||||
|
||||
def notify(self) -> None:
|
||||
self._ready.set()
|
||||
self._check_ready()
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
self._check_ready()
|
||||
return self._ready.wait(timeout)
|
||||
|
||||
def register(self, item: T) -> None:
|
||||
self._items.append(item)
|
||||
self._ready.set()
|
||||
|
||||
def unregister(self, item: T) -> None:
|
||||
try:
|
||||
self._items.remove(item)
|
||||
except ValueError:
|
||||
pass
|
||||
self._check_ready()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._items.clear()
|
||||
self._ready.clear()
|
||||
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
|
||||
def gap_wrapped(a: int, b: int, *, wrap: int = 65536) -> int:
|
||||
return (b - (a + 1) + wrap) % wrap
|
||||
|
||||
|
||||
def add_wrapped(a: int, b: int, *, wrap: int = 65536) -> int:
|
||||
return (a + b) % wrap
|
||||
Reference in New Issue
Block a user