On branch DiscordProfile

Initial commit
This commit is contained in:
EG
2026-07-01 15:15:07 +03:00
commit d4bf750c9e
3125 changed files with 601334 additions and 0 deletions
@@ -0,0 +1,19 @@
"""
discord.voice
~~~~~~~~~~~~~
Voice support for the Discord API.
:copyright: (c) 2015-2021 Rapptz & 2021-present Pycord Development
:license: MIT, see LICENSE for more details.
"""
from ..errors import MissingVoiceDependenciesError
from ..utils import get_missing_voice_dependencies
if missing := get_missing_voice_dependencies():
raise MissingVoiceDependenciesError(missing=missing)
from ._types import *
from .client import *
from .packets import *
@@ -0,0 +1,166 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Generic, TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from discord import abc
from discord.client import Client
from discord.raw_models import (
RawVoiceServerUpdateEvent,
RawVoiceStateUpdateEvent,
)
P = ParamSpec("P")
R = TypeVar("R")
ClientT = TypeVar("ClientT", bound="Client", covariant=True)
__all__ = ("VoiceProtocol",)
class VoiceProtocol(Generic[ClientT]):
"""A class that represents the Discord voice protocol.
.. warning::
If you are an end user, you **should not construct this manually** but instead
take it from the return type in :meth:`abc.Connectable.connect <VoiceChannel.connect>`.
The parameters and methods being documented here is so third party libraries can refer to it
when implementing their own VoiceProtocol types.
This is an abstract class. The library provides a concrete implementation
under :class:`VoiceClient`.
This class allows you to implement a protocol to allow for an external
method of sending voice, such as Lavalink_ or a native library implementation.
These classes are passed to :meth:`abc.Connectable.connect <VoiceChannel.connect>`.
.. _Lavalink: https://github.com/freyacodes/Lavalink
Parameters
----------
client: :class:`Client`
The client (or its subclasses) that started the connection request.
channel: :class:`abc.Connectable`
The voice channel that is being connected to.
"""
def __init__(self, client: ClientT, channel: abc.Connectable) -> None:
self.client: ClientT = client
self.channel: abc.Connectable = channel
async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None:
"""|coro|
A method called when the client's voice state has changed. This corresponds
to the ``VOICE_STATE_UPDATE`` event.
Parameters
----------
data: :class:`RawVoiceStateUpdateEvent`
The voice state payload.
.. versionchanged:: 2.7
This now gets passed a `RawVoiceStateUpdateEvent` object instead of a :class:`dict`, but
accessing keys via ``data[key]`` or ``data.get(key)`` is still supported, but deprecated.
"""
raise NotImplementedError
async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None:
"""|coro|
A method called when the client's intially connecting to voice. This corresponds
to the ``VOICE_SERVER_UPDATE`` event.
Parameters
----------
data: :class:`RawVoiceServerUpdateEvent`
The voice server payload.
.. versionchanged:: 2.7
This now gets passed a `RawVoiceServerUpdateEvent` object instead of a :class:`dict`, but
accessing keys via ``data[key]`` or ``data.get(key)`` is still supported, but deprecated.
"""
raise NotImplementedError
async def connect(self, *, timeout: float, reconnect: bool) -> None:
"""|coro|
A method called to initialise the connection.
The library initialises this class and calls ``__init__``, and then :meth:`connect` when attempting
to start a connection to the voice. If an error ocurrs, it calls :meth:`disconnect`, so if you need
to implement any cleanup, you should manually call it in :meth:`disconnect` as the library will not
do so for you.
Within this method, to start the voice connection flow, it is recommened to use :meth:`Guild.change_voice_state`
to start the flow. After which :meth:`on_voice_server_update` and :meth:`on_voice_state_update` will be called,
although this could vary and cause unexpected behaviour, but that falls under Discord's way of handling the voice
connection.
Parameters
----------
timeout: :class:`float`
The timeout for the connection.
reconnect: :class:`bool`
Whether reconnection is expected.
"""
raise NotImplementedError
async def disconnect(self, *, force: bool) -> None:
"""|coro|
A method called to terminate the voice connection.
This can be either called manually when forcing a disconnection, or when an exception in :meth:`connect` ocurrs.
It is recommended to call :meth:`cleanup` here.
Parameters
----------
force: :class:`bool`
Whether the disconnection was forced.
"""
def cleanup(self) -> None:
"""This method *must* be called to ensure proper clean-up during a disconnect.
It is advisable to call this from within :meth:`disconnect` when you are completely
done with the voice protocol instance.
This method removes it from the internal state cache that keeps track of the currently
alive voice clients. Failure to clean-up will cause subsequent connections to report that
it's still connected.
**The library will NOT automatically call this for you**, unlike :meth:`connect` and :meth:`disconnect`.
"""
key, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key)
@@ -0,0 +1,810 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import datetime
import logging
import struct
import warnings
from typing import TYPE_CHECKING, Any, Literal, overload
from discord import opus
from discord.enums import SpeakingState, try_enum
from discord.errors import ClientException
from discord.player import AudioPlayer, AudioSource
from discord.sinks.core import Sink
from discord.sinks.errors import RecordingException
from discord.utils import MISSING
from ..utils import get_missing_voice_dependencies
from ._types import VoiceProtocol
from .enums import OpCodes
from .receive import AudioReader
from .state import VoiceConnectionState
from .utils.dependencies import HAS_DAVEY, HAS_NACL
if HAS_NACL:
import nacl.secret
import nacl.utils
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from discord import abc
from discord.client import Client
from discord.guild import Guild, VocalGuildChannel
from discord.member import Member
from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Encoder
from discord.raw_models import (
RawVoiceServerUpdateEvent,
RawVoiceStateUpdateEvent,
)
from discord.state import ConnectionState
from discord.types.voice import SupportedModes
from discord.user import ClientUser, User
from .gateway import VoiceWebSocket
from .receive.reader import AfterCallback
P = ParamSpec("P")
_log = logging.getLogger(__name__)
__all__ = ("VoiceClient",)
class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection.
You do not create these, you typically get them from e.g.
:meth:`VoiceChannel.connect`.
Attributes
----------
session_id: :class:`str`
The voice connection session ID.
token: :class:`str`
The voice connection token.
endpoint: :class:`str`
The endpoint we are connecting to.
channel: Union[:class:`VoiceChannel`, :class:`StageChannel`]
The channel we are connected to.
Warning
-------
In order to use PCM based AudioSources, you must have the opus library
installed on your system and loaded through :func:`opus.load_opus`.
Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`)
or the library will not be able ot transmit audio.
"""
def __init__(
self,
client: Client,
channel: abc.Connectable,
) -> None:
missing = get_missing_voice_dependencies()
if missing:
deps = ", ".join(missing)
raise RuntimeError(
f"{deps} {'library is' if len(missing) == 1 else 'libraries are'} needed "
"in order to use voice related features, "
'you can run "pip install py-cord[voice]" to install all voice-related '
"dependencies."
)
super().__init__(client, channel)
state = client._connection
self.server_id: int = MISSING
self.socket = MISSING
self.loop: asyncio.AbstractEventLoop = state.loop
self._state: ConnectionState = state
self.sequence: int = 0
self.timestamp: int = 0
self._player: AudioPlayer | None = None
self._player_future: asyncio.Future[None] | None = None
self.encoder: Encoder = MISSING
self._incr_nonce: int = 0
self._connection: VoiceConnectionState = self.create_connection_state()
self._ssrc_to_id: dict[int, int] = {}
self._id_to_ssrc: dict[int, int] = {}
self._event_listeners: dict[str, list] = {}
self._reader: AudioReader = MISSING
@staticmethod
def _set_future_result_if_pending(
future: asyncio.Future[Any], result: Exception | None
) -> None:
if not future.done():
future.set_result(result)
supported_modes: tuple[SupportedModes, ...] = (
"aead_xchacha20_poly1305_rtpsize",
"xsalsa20_poly1305_lite",
"xsalsa20_poly1305_suffix",
"xsalsa20_poly1305",
)
@property
def guild(self) -> Guild:
"""Returns the guild the channel we're connecting to is bound to."""
channel: VocalGuildChannel = self.channel
return channel.guild
@property
def user(self) -> ClientUser:
"""The user connected to voice (i.e. ourselves)"""
return self._state.user # type: ignore
@property
def session_id(self) -> str | None:
"""The session ID of the current voice call."""
return self._connection.session_id
@property
def token(self) -> str | None:
"""The token of the voice connection. You should not share this."""
return self._connection.token
@property
def endpoint(self) -> str | None:
"""The endpoint where the client is connected."""
return self._connection.endpoint
@property
def ssrc(self) -> int:
"""The SSRC of the current user in the voice call."""
return self._connection.ssrc
@property
def mode(self) -> SupportedModes:
"""The encryption / decryption mode the voice client is currently using."""
return self._connection.mode
@property
def secret_key(self) -> list[int]:
"""Returns the secret key of the current connected voice call."""
return self._connection.secret_key
@property
def ws(self) -> VoiceWebSocket:
return self._connection.ws
@property
def timeout(self) -> float:
"""The amount of ms the client waits before killing connect attempts."""
return self._connection.timeout
def is_dave_connection(self) -> bool:
"""Whether the voice client is connected to a DAVE call."""
session = self._connection.dave_session
return session is not None
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
setattr(self, attr, 0)
else:
setattr(self, attr, val + value)
def create_connection_state(self) -> VoiceConnectionState:
return VoiceConnectionState(self, hook=self._recv_hook)
async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None:
op = msg["op"]
data = msg.get("d", {})
if op == OpCodes.ready:
self._add_ssrc(self.guild.me.id, data["ssrc"])
elif op == OpCodes.speaking:
uid = int(data["user_id"])
ssrc = data["ssrc"]
self._add_ssrc(uid, ssrc)
member = self.guild.get_member(uid)
state = try_enum(SpeakingState, data["speaking"])
self.dispatch("member_speaking_state_update", member, ssrc, state)
elif op == OpCodes.clients_connect:
uids = list(map(int, data["user_ids"]))
for uid in uids:
member = self.guild.get_member(uid)
if not member:
_log.warning(
"Skipping member referencing ID %d on member_connect", uid
)
continue
self.dispatch("member_connect", member)
elif op == OpCodes.client_disconnect:
uid = int(data["user_id"])
ssrc = self._id_to_ssrc.get(uid)
if self._reader and ssrc is not None:
_log.debug("Destroying decoder for user %d, ssrc=%d", uid, ssrc)
self._reader.packet_router.destroy_decoder(ssrc)
self._remove_ssrc(user_id=uid)
member = self.guild.get_member(uid)
self.dispatch("member_disconnect", member, ssrc)
# maybe handle video and such things?
async def _run_event(
self, coro, event_name: str, *args: Any, **kwargs: Any
) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
pass
except Exception:
_log.exception("Error calling %s", event_name)
def _schedule_event(
self, coro, event_name: str, *args: Any, **kwargs: Any
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
return self.client.loop.create_task(
wrapped, name=f"voice-receiver-event-dispatch: {event_name}"
)
def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None:
_log.debug("Dispatching voice_client event %s", event)
event_name = f"on_{event}"
for coro in self._event_listeners.get(event_name, []):
task = self._schedule_event(coro, event_name, *args, **kwargs)
self._connection._dispatch_task_set.add(task)
task.add_done_callback(self._connection._dispatch_task_set.discard)
self._dispatch_sink(event, *args, **kwargs)
self.client.dispatch(event, *args, **kwargs)
async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None:
old_channel_id = self.channel.id if self.channel else None
await self._connection.voice_state_update(data)
if data.channel_id is None:
return
if self._reader and data.channel_id != old_channel_id:
_log.debug("Destroying voice receive decoders in guild %s", self.guild.id)
self._reader.packet_router.destroy_all_decoders()
async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None:
await self._connection.voice_server_update(data)
def _dispatch_sink(self, event: str, /, *args: Any, **kwargs: Any) -> None:
if self._reader:
self._reader.event_router.dispatch(event, *args, **kwargs)
def _add_ssrc(self, user_id: int, ssrc: int) -> None:
self._ssrc_to_id[ssrc] = user_id
self._id_to_ssrc[user_id] = ssrc
if self._reader:
self._reader.packet_router.set_user_id(ssrc, user_id)
def _remove_ssrc(self, *, user_id: int) -> None:
ssrc = self._id_to_ssrc.pop(user_id, None)
if ssrc:
self._reader.speaking_timer.drop_ssrc(ssrc)
self._ssrc_to_id.pop(ssrc, None)
async def connect(
self,
*,
reconnect: bool,
timeout: float,
self_deaf: bool = False,
self_mute: bool = False,
) -> None:
await self._connection.connect(
reconnect=reconnect,
timeout=timeout,
self_deaf=self_deaf,
self_mute=self_mute,
resume=False,
)
def wait_until_connected(self, timeout: float | None = 30.0) -> bool:
self._connection.wait_for(timeout=timeout)
return self._connection.is_connected()
@property
def latency(self) -> float:
"""Latency between a HEARTBEAT and a HEARBEAT_ACK in seconds.
This chould be referred to as the Discord Voice WebSocket latency and is
and analogue of user's voice latencies as seen in the Discord client.
.. versionadded:: 1.4
"""
ws = self.ws
return float("inf") if not ws else ws.latency
@property
def average_latency(self) -> float:
"""Average of most recent 20 HEARBEAT latencies in seconds.
.. versionadded:: 1.4
"""
ws = self.ws
return float("inf") if not ws else ws.average_latency
@property
def privacy_code(self) -> str | None:
"""Returns the current voice session's privacy code, only available if the call has upgraded to use the
DAVE protocol
"""
session = self._connection.dave_session
return session and session.voice_privacy_code
async def disconnect(self, *, force: bool = False) -> None:
"""|coro|
Disconnects this voice client from voice.
"""
self.stop()
await self._connection.disconnect(force=force, wait=True)
self.cleanup()
async def move_to(
self, channel: abc.Snowflake | None, *, timeout: float | None = 30.0
) -> None:
"""|coro|
moves you to a different voice channel.
Parameters
----------
channel: Optional[:class:`abc.Snowflake`]
The channel to move to. If this is ``None``, it is an equivalent of calling :meth:`.disconnect`.
timeout: Optional[:class:`float`]
The maximum time in seconds to wait for the channel move to be completed, defaults to 30.
If it is ``None``, then there is no timeout.
Raises
------
asyncio.TimeoutError
Waiting for channel move timed out.
"""
await self._connection.move_to(channel, timeout)
def is_connected(self) -> bool:
"""Whether the voice client is connected to voice."""
return self._connection.is_connected()
def is_playing(self) -> bool:
"""INdicates if we're playing audio."""
return self._player is not None and self._player.is_playing()
def is_paused(self) -> bool:
"""Indicates if we're playing audio, but if we're paused."""
return self._player is not None and self._player.is_paused()
# audio related
def _get_voice_packet(self, data: Any) -> bytes:
session = self._connection.dave_session
packet = session.encrypt_opus(data) if session and session.ready else data
header = bytearray(12)
# formulate rtp header
header[0] = 0x80
header[1] = 0x78
struct.pack_into(">H", header, 2, self.sequence)
struct.pack_into(">I", header, 4, self.timestamp)
struct.pack_into(">I", header, 8, self.ssrc)
encrypt_packet = getattr(self, f"_encrypt_{self.mode}")
return encrypt_packet(header, packet)
# encryption methods
def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes:
# deprecated
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:12] = header
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext
def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data: Any) -> bytes:
# deprecated
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)
return header + box.encrypt(bytes(data), nonce).ciphertext + nonce
def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes:
# deprecated
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:4] = struct.pack(">I", self._incr_nonce)
self.checked_add("_incr_nonce", 1, 4294967295)
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
def _encrypt_aead_xchacha20_poly1305_rtpsize(
self, header: bytes, data: Any
) -> bytes:
box = nacl.secret.Aead(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:4] = struct.pack(">I", self._incr_nonce)
self.checked_add("_incr_nonce", 1, 4294967295)
return (
header
+ box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext
+ nonce[:4]
)
@overload
def play(
self,
source: AudioSource,
*,
after: AfterCallback | None = ...,
application: APPLICATION_CTL = ...,
bitrate: int = ...,
fec: bool = ...,
expected_packet_loss: float = ...,
bandwidth: BAND_CTL = ...,
signal_type: SIGNAL_CTL = ...,
wait_finish: Literal[False] = ...,
) -> None: ...
@overload
def play(
self,
source: AudioSource,
*,
after: AfterCallback = ...,
application: APPLICATION_CTL = ...,
bitrate: int = ...,
fec: bool = ...,
expected_packet_loss: float = ...,
bandwidth: BAND_CTL = ...,
signal_type: SIGNAL_CTL = ...,
wait_finish: Literal[True],
) -> asyncio.Future[None]: ...
def play(
self,
source: AudioSource,
*,
after: AfterCallback | None = None,
application: APPLICATION_CTL = "audio",
bitrate: int = 128,
fec: bool = True,
expected_packet_loss: float = 0.15,
bandwidth: BAND_CTL = "full",
signal_type: SIGNAL_CTL = "auto",
wait_finish: bool = False,
) -> None | asyncio.Future[None]:
"""Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted
or an error occurred.
IF an error happens while the audio player is running, the exception is
caught and the audio player is then stopped. If no after callback is passed,
any caught exception will be displayed as if it were raised.
Parameters
----------
source: :class:`AudioSource`
The audio source we're reading from.
after: Callable[[Optional[:class:`Exception`]], Any]
The finalizer that is called after the stream is exhausted.
This function must have a single parameter, ``error``, that
denotes an optional exception that was raised during playing.
application: :class:`str`
The intended application encoder application type. Must be one of
``audio``, ``voip``, or ``lowdelay``. Defaults to ``audio``.
bitrate: :class:`int`
The encoder's bitrate. Must be between ``16`` and ``512``. Defaults
to ``128``.
fec: :class:`bool`
Configures the encoder's use of inband forward error correction (fec).
Defaults to ``True``.
expected_packet_loss: :class:`float`
How much packet loss percentage is expected from the encoder. This requires ``fec``
to be set to ``True``. Defaults to ``0.15``.
bandwidth: :class:`str`
The encoder's bandpass. Must be one of ``narrow``, ``medium``, ``wide``,
``superwide``, or ``full``. Defaults to ``full``.
signal_type: :class:`str`
The type of signal being encoded. Must be one of ``auto``, ``voice``, ``music``.
Defaults to ``auto``.
wait_finish: :class:`bool`
If ``True``, then an awaitable is returned that waits for the audio source to be
exhausted, and will return an optional exception that could have been raised.
If ``False``, ``None`` is returned and the function does not block.
.. versionadded:: 2.5
Raises
------
ClientException
Already playing audio, or not connected to voice.
TypeError
Source is not a :class:`AudioSource`, or after is not a callable.
OpusNotLoaded
Source is not opus encoded and opus is not loaded.
"""
if not self.is_connected():
raise ClientException("Not connected to voice")
if self.is_playing():
raise ClientException("Already playing audio")
if not isinstance(source, AudioSource):
raise TypeError(
f"Source must be an AudioSource, not {source.__class__.__name__}",
)
if not self.encoder and not source.is_opus():
self.encoder = opus.Encoder(
application=application,
bitrate=bitrate,
fec=fec,
expected_packet_loss=expected_packet_loss,
bandwidth=bandwidth,
signal_type=signal_type,
)
future = None
if wait_finish:
self._player_future = future = self.loop.create_future()
after_callback = after
def _after(exc: Exception | None) -> None:
if callable(after_callback):
after_callback(exc)
self.loop.call_soon_threadsafe(
self._set_future_result_if_pending, future, exc
)
after = _after
self._player = AudioPlayer(source, self, after=after)
self._player.start()
return future
def stop(self) -> None:
"""Stops playing audio, if applicable."""
if self._player:
self._player.stop()
if self._player_future:
self.loop.call_soon_threadsafe(
self._set_future_result_if_pending, self._player_future, None
)
if self._reader is not MISSING:
self._reader.stop()
self._reader = MISSING
self._player = None
self._player_future = None
def pause(self) -> None:
"""Pauses the audio playing."""
if self._player:
self._player.pause()
def resume(self) -> None:
"""Resumes the audio playing."""
if self._player:
self._player.resume()
@property
def source(self) -> AudioSource | None:
"""The audio source being player, if playing.
This property can also be used to change the audio source currently being played.
"""
return self._player and self._player.source
@source.setter
def source(self, value: AudioSource) -> None:
if not isinstance(value, AudioSource):
raise TypeError(f"expected AudioSource, not {value.__class__.__name__}")
if self._player is None:
raise ValueError("the client is not playing anything")
self._player.set_source(value)
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
"""Sends an audio packet composed of the ``data``.
You must be connected to play audio.
Parameters
----------
data: :class:`bytes`
The :term:`py:bytes-like object` denoting PCM or Opus voice data.
encode: :class:`bool`
Indicates if ``data`` should be encoded into Opus.
Raises
------
ClientException
You are not connected.
opus.OpusError
Encoding the data failed.
"""
self.checked_add("sequence", 1, 65535)
if encode:
encoded = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
else:
encoded = data
packet = self._get_voice_packet(encoded)
try:
self._connection.send_packet(packet)
except OSError:
_log.debug(
"A packet has been dropped (seq: %s, timestamp: %s)",
self.sequence,
self.timestamp,
)
self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
def elapsed(self) -> datetime.timedelta:
"""Returns the elapsed time of the playing audio."""
if self._player:
return datetime.timedelta(milliseconds=self._player.played_frames() * 20)
return datetime.timedelta()
def start_recording(
self,
sink: Sink,
callback: AfterCallback | None = None,
*args: Any,
sync_start: bool = MISSING,
) -> None:
r"""Start recording the audio from the current connected channel to the provided sink.
.. versionadded:: 2.0
.. warning::
Recording may not work as expected due to the new DAVE (End-to-End Encryption) for voice calls.
Parameters
----------
sink: :class:`~.Sink`
A Sink in which all audio packets will be processed in.
callback: Callable[[:class:`Exception` | None], Any]
A function which is called after the bot has stopped recording. This must take exactly one positonal(-only)
parameter, ``exception``, which is the exception that was raised during the recording of the Sink.
.. versionchanged:: 2.7
This parameter is now optional, and must take exactly one parameter, ``exception``.
\*args:
The arguments to pass to the callback coroutine.
.. deprecated:: 2.7
Passing custom arguments to the callback is now deprecated and ignored.
sync_start: :class:`bool`
If ``True``, the recordings of subsequent users will start with silence.
This is useful for recording audio just as it was heard.
.. deprecated:: 2.7
This parameter is now ignored and deprecated.
Raises
------
RecordingException
Not connected to a voice channel
TypeError
You did not provide a Sink object.
"""
warnings.warn(
"Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. "
+ "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139",
RuntimeWarning,
stacklevel=2,
)
# TODO: remove warning in voice-recv fix PR
if not self.is_connected():
raise RecordingException("not connected to a voice channel")
if not isinstance(sink, Sink):
raise TypeError(f"expected a Sink object, got {sink.__class__.__name__}")
if self.is_recording():
raise ClientException("Already recording audio")
if len(args) > 0:
warnings.warn(
"'args' parameter is deprecated since 2.7 and will be removed in 3.0"
)
if sync_start is not MISSING:
warnings.warn(
"'sync_start' parameter is deprecated since 2.7 and will be removed in 3.0"
)
self._reader = AudioReader(sink, self, after=callback, start=True)
start_listening = start_recording
def stop_recording(self) -> None:
"""Stops the recording of the provided ``sink``, or all recording sinks.
.. versionadded:: 2.0
Raises
------
RecordingException
You are not recording.
"""
warnings.warn(
"Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. "
+ "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139",
RuntimeWarning,
stacklevel=2,
)
if self._reader is not MISSING:
self._reader.stop()
self._reader = MISSING
else:
raise RecordingException("You are not recording")
stop_listening = stop_recording
def is_recording(self) -> bool:
"""Whether the current client is recording in any sink."""
return self._reader and self._reader.is_listening()
def is_speaking(self, member: Member | User) -> bool | None:
"""Whether a user is speaking.
This is an approximate calculation and may have outdated or wrong data.
If the member speaking status has not been yet saved, it returns ``None``.
.. versionadded:: 2.7
"""
warnings.warn(
"Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. "
+ "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139",
RuntimeWarning,
stacklevel=2,
)
ssrc = self._id_to_ssrc.get(member.id)
if ssrc is None:
return None
if self._reader:
return self._reader.speaking_timer.get_speaking(ssrc)
@@ -0,0 +1,79 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from discord.enums import Enum
class OpCodes(Enum):
identify = 0
select_protocol = 1
ready = 2
heartbeat = 3
session_description = 4
speaking = 5
heartbeat_ack = 6
resume = 7
hello = 8
resumed = 9
clients_connect = 11
client_connect = 12
client_disconnect = 13
# dave protocol stuff
dave_prepare_transition = 21
dave_execute_transition = 22
dave_transition_ready = 23
dave_prepare_epoch = 24
mls_external_sender_package = 25
mls_key_package = 26
mls_proposals = 27
mls_commit_welcome = 28
mls_commit_transition = 29
mls_welcome = 30
mls_invalid_commit_welcome = 31
def __eq__(self, other: object) -> bool:
if isinstance(other, int):
return self.value == other
elif isinstance(other, self.__class__):
return self is other
return NotImplemented
def __int__(self) -> int:
return self.value
class ConnectionFlowState(Enum):
disconnected = 0
set_guild_voice_state = 1
got_voice_state_update = 2
got_voice_server_update = 3
got_both_voice_updates = 4
websocket_connected = 5
got_websocket_ready = 6
got_ip_discovery = 7
connected = 8
@@ -0,0 +1,521 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import logging
import struct
import threading
import time
from collections import deque
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any
import aiohttp
from .utils.dependencies import HAS_DAVEY
if HAS_DAVEY:
import davey
from discord import utils
from discord.enums import SpeakingState
from discord.errors import ConnectionClosed
from discord.gateway import DiscordWebSocket
from discord.gateway import KeepAliveHandler as KeepAliveHandlerBase
from .enums import OpCodes
if TYPE_CHECKING:
from _typeshed import ConvertibleToInt
from typing_extensions import Self
from .state import VoiceConnectionState
_log = logging.getLogger(__name__)
class KeepAliveHandler(KeepAliveHandlerBase):
if TYPE_CHECKING:
ws: VoiceWebSocket
def __init__(
self,
*args: Any,
ws: VoiceWebSocket,
interval: float | None = None,
**kwargs: Any,
) -> None:
daemon: bool = kwargs.pop("daemon", True)
name: str = kwargs.pop("name", f"voice-keep-alive-handler:{id(self):#x}")
super().__init__(
*args,
**kwargs,
name=name,
daemon=daemon,
ws=ws,
interval=interval,
)
self.msg: str = "Keeping shard ID %s voice websocket alive with timestamp %s."
self.block_msg: str = (
"Shard ID %s voice heartbeat blocked for more than %s seconds."
)
self.behing_msg: str = (
"High socket latency, shard ID %s heartbeat is %.1fs behind."
)
self.recent_ack_latencies: deque[float] = deque(maxlen=20)
def get_payload(self) -> dict[str, Any]:
return {
"op": int(OpCodes.heartbeat),
"d": {
"t": int(time.time() * 1000),
"seq_ack": self.ws.seq_ack,
},
}
def ack(self) -> None:
ack_time = time.perf_counter()
self._last_ack = ack_time
self._last_recv = ack_time
self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
class VoiceWebSocket(DiscordWebSocket):
def __init__(
self,
socket: aiohttp.ClientWebSocketResponse,
loop: asyncio.AbstractEventLoop,
state: VoiceConnectionState,
*,
hook: Callable[..., Coroutine[Any, Any, Any]] | None = None,
) -> None:
self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: KeepAliveHandler | None = None
self._close_code: int | None = None
self.secret_key: list[int] | None = None
self.seq_ack: int = -1
self.state: VoiceConnectionState = state
self.ssrc_map: dict[str, dict[str, Any]] = {}
self.known_users: dict[int, Any] = {}
if hook:
self._hook = hook or state.ws_hook # type: ignore
@property
def token(self) -> str | None:
return self.state.token
@token.setter
def token(self, value: str | None) -> None:
self.state.token = value
@property
def session_id(self) -> str | None:
return self.state.session_id
@session_id.setter
def session_id(self, value: str | None) -> None:
self.state.session_id = value
@property
def self_id(self) -> int:
return self._connection.self_id
async def _hook(self, *args: Any) -> Any:
pass
async def send_as_bytes(self, op: ConvertibleToInt, data: bytes) -> None:
packet = bytes([int(op)]) + data
_log.debug(
"Sending voice websocket binary frame: op: %s size: %d", op, len(data)
)
await self.ws.send_bytes(packet)
async def send_as_json(self, data: Any) -> None:
_log.debug("Sending voice websocket frame: %s.", data)
if data.get("op", None) == OpCodes.identify:
_log.info("Identifying ourselves: %s", data)
await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json
async def resume(self) -> None:
payload = {
"op": int(OpCodes.resume),
"d": {
"token": self.token,
"server_id": str(self.state.server_id),
"session_id": self.session_id,
"seq_ack": self.seq_ack,
},
}
await self.send_as_json(payload)
async def received_message(self, msg: Any, /):
_log.debug("Voice websocket frame received: %s", msg)
op = msg["op"]
data = msg.get("d", {}) # this key should ALWAYS be given, but guard anyways
self.seq_ack = msg.get("seq", self.seq_ack) # keep the seq_ack updated
state = self.state
if op == OpCodes.ready:
await self.ready(data)
elif op == OpCodes.heartbeat_ack:
if not self._keep_alive:
_log.error(
"Received a heartbeat ACK but no keep alive handler was set.",
)
return
self._keep_alive.ack()
elif op == OpCodes.resumed:
_log.info(
f"Voice connection on channel ID {self.state.channel_id} (guild {self.state.guild_id}) was "
"successfully RESUMED.",
)
elif op == OpCodes.session_description:
state.mode = data["mode"]
state.dave_protocol_version = data["dave_protocol_version"]
await self.load_secret_key(data)
await state.reinit_dave_session()
elif op == OpCodes.hello:
interval = data["heartbeat_interval"] / 1000.0
self._keep_alive = KeepAliveHandler(
ws=self,
interval=min(interval, 5),
)
self._keep_alive.start()
elif state.dave_session:
if op == OpCodes.dave_prepare_transition:
_log.info(
"Preparing to upgrade to a DAVE connection for channel %s for transition %d proto version %d",
state.channel_id,
data["transition_id"],
data["protocol_version"],
)
state.dave_pending_transition = data
transition_id = data["transition_id"]
if transition_id == 0:
await state.execute_dave_transition(data["transition_id"])
else:
if data["protocol_version"] == 0 and state.dave_session:
state.dave_session.set_passthrough_mode(True, 10)
await self.send_dave_transition_ready(transition_id)
elif op == OpCodes.dave_execute_transition:
_log.info(
"Upgrading to DAVE connection for channel %s", state.channel_id
)
await state.execute_dave_transition(data["transition_id"])
elif op == OpCodes.dave_prepare_epoch:
epoch = data["epoch"]
_log.debug(
"Preparing for DAVE epoch in channel %s: %s",
state.channel_id,
epoch,
)
# if epoch is 1 then a new MLS group is to be created for the proto version
if epoch == 1:
state.dave_protocol_version = data["protocol_version"]
await state.reinit_dave_session()
else:
_log.debug("Unhandled op code: %s with data %s", op, data)
await utils.maybe_coroutine(self._hook, self, msg)
async def received_binary_message(self, msg: bytes) -> None:
self.seq_ack = struct.unpack_from(">H", msg, 0)[0]
op = msg[2]
_log.debug(
"Voice websocket binary frame received: %d bytes, seq: %s, op: %s",
len(msg),
self.seq_ack,
op,
)
state = self.state
if not state.dave_session:
return
if op == OpCodes.mls_external_sender_package:
_log.debug("Received MLS External Sender Package, applying to DAVE session")
state.dave_session.set_external_sender(msg[3:])
_log.debug(
"Applied MLS External Sender Package, user IDs available: %s",
state.dave_session.get_user_ids(),
)
elif op == OpCodes.mls_proposals:
op_type = msg[3]
result = state.dave_session.process_proposals(
(
davey.ProposalsOperationType.append
if op_type == 0
else davey.ProposalsOperationType.revoke
),
msg[4:],
)
if isinstance(result, davey.CommitWelcome):
data = (
(result.commit + result.welcome)
if result.welcome
else result.commit
)
_log.debug("Sending MLS key package with data: %s", data)
await self.send_as_bytes(
OpCodes.mls_commit_welcome,
data,
)
_log.debug("Processed MLS proposals for current dave session: %r", result)
elif op == OpCodes.mls_commit_transition:
transt_id = struct.unpack_from(">H", msg, 3)[0]
try:
state.dave_session.process_commit(msg[5:])
if transt_id != 0:
state.dave_pending_transition = {
"transition_id": transt_id,
"protocol_version": state.dave_protocol_version,
}
_log.debug(
"Sending DAVE transition ready from MLS commit transition with data: %s",
state.dave_pending_transition,
)
await self.send_dave_transition_ready(transt_id)
_log.debug("Processed MLS commit for transition %s", transt_id)
except Exception as exc:
_log.debug(
"An exception ocurred while processing a MLS commit, this should be safe to ignore: %s",
exc,
)
await state.recover_dave_from_invalid_commit(transt_id)
elif op == OpCodes.mls_welcome:
transt_id = struct.unpack_from(">H", msg, 3)[0]
try:
state.dave_session.process_welcome(msg[5:])
if transt_id != 0:
state.dave_pending_transition = {
"transition_id": transt_id,
"protocol_version": state.dave_protocol_version,
}
_log.debug(
"Sending DAVE transition ready from MLS welcome with data: %s",
state.dave_pending_transition,
)
await self.send_dave_transition_ready(transt_id)
_log.debug("Processed MLS welcome for transition %s", transt_id)
except Exception as exc:
_log.debug(
"An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s",
exc,
)
await state.recover_dave_from_invalid_commit(transt_id)
async def ready(self, data: dict[str, Any]) -> None:
state = self.state
state.ssrc = data["ssrc"]
state.voice_port = data["port"]
state.endpoint_ip = data["ip"]
_log.debug(
f"Connecting to {state.endpoint_ip} (port {state.voice_port}).",
)
await self.loop.sock_connect(
state.socket,
(state.endpoint_ip, state.voice_port),
)
_log.debug(
"Connected socket to %s (port %s)",
state.endpoint_ip,
state.voice_port,
)
state.ip, state.port = await self.get_ip()
modes = [mode for mode in data["modes"] if mode in self.state.supported_modes]
_log.debug("Received available voice connection modes: %s", modes)
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug("Selected voice protocol %s for this connection", mode)
async def select_protocol(self, ip: str, port: int, mode: str) -> None:
payload = {
"op": int(OpCodes.select_protocol),
"d": {
"protocol": "udp",
"data": {
"address": ip,
"port": port,
"mode": mode,
},
},
}
await self.send_as_json(payload)
async def get_ip(self) -> tuple[str, int]:
state = self.state
packet = bytearray(74)
struct.pack_into(">H", packet, 0, 1) # 1 = Send
struct.pack_into(">H", packet, 2, 70) # 70 = Length
struct.pack_into(">I", packet, 4, state.ssrc)
_log.debug(
f"Sending IP discovery packet for voice in channel {state.channel_id} (guild {state.guild_id})"
)
await self.loop.sock_sendall(state.socket, packet)
fut: asyncio.Future[bytes] = self.loop.create_future()
def get_ip_packet(data: bytes) -> None:
if data[1] == 0x02 and len(data) == 74:
self.loop.call_soon_threadsafe(fut.set_result, data)
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
state.add_socket_listener(get_ip_packet)
recv = await fut
_log.debug("Received IP discovery packet with data %s", recv)
ip_start = 8
ip_end = recv.index(0, ip_start)
ip = recv[ip_start:ip_end].decode("ascii")
port = struct.unpack_from(">H", recv, len(recv) - 2)[0]
_log.debug("Detected IP %s with port %s", ip, port)
return ip, port
@property
def latency(self) -> float:
heartbeat = self._keep_alive
return float("inf") if heartbeat is None else heartbeat.latency
@property
def average_latency(self) -> float:
heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies:
return float("inf")
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data: dict[str, Any]) -> None:
_log.debug(
f"Received secret key for voice connection in channel {self.state.channel_id} (guild {self.state.guild_id})"
)
self.secret_key = self.state.secret_key = data["secret_key"]
await self.speak(SpeakingState.none)
async def poll_event(self) -> None:
msg = await asyncio.wait_for(self.ws.receive(), timeout=30)
if msg.type is aiohttp.WSMsgType.TEXT:
_log.debug("Received text payload: %s", msg.data)
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.BINARY:
_log.debug("Received binary payload: size: %d", len(msg.data))
await self.received_binary_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug("Received %s", msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
_log.debug("Received %s", msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code: int = 1000) -> None:
if self._keep_alive:
self._keep_alive.stop()
self._close_code = code
await self.ws.close(code=self._close_code)
async def speak(self, state: SpeakingState = SpeakingState.voice) -> None:
await self.send_as_json(
{
"op": int(OpCodes.speaking),
"d": {
"speaking": int(state),
"delay": 0,
},
},
)
@classmethod
async def from_state(
cls,
state: VoiceConnectionState,
*,
resume: bool = False,
hook: Callable[..., Coroutine[Any, Any, Any]] | None = None,
seq_ack: int = -1,
) -> Self:
gateway = f"wss://{state.endpoint}/?v=8"
client = state.client
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook, state=state)
ws.gateway = gateway
ws.seq_ack = seq_ack
ws._max_heartbeat_timeout = 60.0
ws.thread_id = threading.get_ident()
if resume:
await ws.resume()
else:
await ws.identify()
return ws
async def identify(self) -> None:
state = self.state
payload = {
"op": int(OpCodes.identify),
"d": {
"server_id": str(state.server_id),
"user_id": str(state.user.id),
"session_id": self.session_id,
"token": self.token,
"max_dave_protocol_version": state.max_dave_proto_version,
},
}
await self.send_as_json(payload)
async def send_dave_transition_ready(self, transition_id: int) -> None:
payload = {
"op": int(OpCodes.dave_transition_ready),
"d": {
"transition_id": transition_id,
},
}
await self.send_as_json(payload)
@@ -0,0 +1,63 @@
"""
discord.voice.packets
~~~~~~~~~~~~~~~~~~~~~
Sink packet handlers.
:copyright: (c) 2015-2021 Rapptz & 2021-present Pycord Development
:license: MIT, see LICENSE for more details.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .core import Packet
from .rtp import (
FakePacket,
ReceiverReportPacket,
RTCPPacket,
RTPPacket,
SenderReportPacket,
SilencePacket,
)
if TYPE_CHECKING:
from discord import Member, User
__all__ = (
"Packet",
"RTPPacket",
"RTCPPacket",
"FakePacket",
"ReceiverReportPacket",
"SenderReportPacket",
"SilencePacket",
"VoiceData",
)
class VoiceData:
"""Represents an audio data from a source.
.. versionadded:: 2.7
Attributes
----------
packet: :class:`~discord.sinks.Packet`
The packet this source data contains.
source: :class:`~discord.User` | :class:`~discord.Member` | None
The user that emitted this audio source.
pcm: :class:`bytes`
The PCM bytes of this source.
"""
def __init__(
self, packet: Packet, source: User | Member | None, *, pcm: bytes | None = None
) -> None:
self.packet: Packet = packet
self.source: User | Member | None = source
self.pcm: bytes = pcm if pcm else b""
@property
def opus(self) -> bytes | None:
return self.packet.decrypted_data
@@ -0,0 +1,93 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing_extensions import Final
OPUS_SILENCE: Final = b"\xf8\xff\xfe"
class Packet:
"""Represents an audio stream bytes packet.
Attributes
----------
data: :class:`bytes`
The bytes data of this packet. This has not been decoded.
"""
if TYPE_CHECKING:
ssrc: int
sequence: int
timestamp: int
type: int
decrypted_data: bytes
def __init__(self, data: bytes) -> None:
self.data: bytes = data
def __repr__(self) -> str:
return f"<{self.__class__.__name__}> data={len(self.data)} bytes>"
def __bool__(self) -> bool:
return bool(self.data)
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
if self.ssrc != other.ssrc:
raise TypeError(
f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})"
)
return self.sequence == other.sequence and self.timestamp == other.timestamp
def __gt__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
if self.ssrc != other.ssrc:
raise TypeError(
f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})"
)
return self.sequence > other.sequence and self.timestamp > other.timestamp
def __lt__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
if self.ssrc != other.ssrc:
raise TypeError(
f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})"
)
return self.sequence < other.sequence and self.timestamp < other.timestamp
def is_silence(self) -> bool:
data = getattr(self, "decrypted_data", None)
return data == OPUS_SILENCE
def __hash__(self) -> int:
return hash(self.data)
@@ -0,0 +1,320 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import struct
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Literal
from .core import OPUS_SILENCE, Packet
if TYPE_CHECKING:
from typing_extensions import Final
MAX_UINT_32 = 0xFFFFFFFF
MAX_UINT_16 = 0xFFFF
RTP_PACKET_TYPE_VOICE = 120
def decode(data: bytes) -> Packet:
if not data[0] >> 6 == 2:
raise ValueError(f"Invalid packet header 0b{data[0]:0>8b}")
return _rtcp_map.get(data[1], RTPPacket)(data)
class FakePacket(Packet):
data = b""
decrypted_data: bytes = b""
extension_data: dict = {}
def __init__(
self,
ssrc: int,
sequence: int,
timestamp: int,
) -> None:
self.ssrc = ssrc
self.sequence = sequence
self.timestamp = timestamp
def __bool__(self) -> Literal[False]:
return False
class SilencePacket(Packet):
decrypted_data: Final = OPUS_SILENCE
extension_data: Final[dict[int, Any]] = {}
sequence: int = -1
def __init__(self, ssrc: int, timestamp: int) -> None:
self.ssrc = ssrc
self.timestamp = timestamp
def is_silence(self) -> bool:
return True
class RTPPacket(Packet):
"""Represents an RTP packet.
.. versionadded:: 2.7
Attributes
----------
data: :class:`bytes`
The raw data of the packet.
"""
_hstruct = struct.Struct(">xxHII")
_ext_header = namedtuple("Extension", "profile length values")
_ext_magic = b"\xbe\xde"
def __init__(self, data: bytes) -> None:
super().__init__(data)
self.version: int = data[0] >> 6
self.padding: bool = bool(data[0] & 0b00100000)
self.extended: bool = bool(data[0] & 0b00010000)
self.cc: int = data[0] & 0b00001111
self.marker: bool = bool(data[1] & 0b10000000)
self.payload: int = data[1] & 0b01111111
sequence, timestamp, ssrc = self._hstruct.unpack_from(data)
self.sequence = sequence
self.timestamp = timestamp
self.ssrc = ssrc
self.csrcs: tuple[int, ...] = ()
self.extension = None
self.extension_data: dict[int, bytes] = {}
self.header = data[:12]
self.data = data[12:]
self.decrypted_data: bytes | None = None
self.nonce: bytes = b""
self._rtpsize: bool = False
if self.cc:
fmt = ">%sI" % self.cc
offset = struct.calcsize(fmt) + 12
self.csrcs = struct.unpack(fmt, data[12:offset])
self.data = data[offset:]
def adjust_rtpsize(self) -> None:
"""Automatically adjusts this packet header and data based on the rtpsize format."""
self._rtpsize = True
self.nonce = self.data[-4:]
if not self.extended:
self.data = self.data[:-4]
return
self.header += self.data[:4]
self.data = self.data[4:-4]
def update_extended_header(self, data: bytes) -> int:
"""Updates the extended header using ``data`` and returns the pd offset."""
if not self.extended:
return 0
if self._rtpsize:
data = self.header[-4:] + data
if len(data) < 4:
return 0
profile, length = struct.unpack_from(">2sH", data)
total_ext_length = length * 4
if profile == self._ext_magic:
self._parse_bede_header(data, length)
if len(data) >= 4 + total_ext_length:
try:
values = struct.unpack(">%sI" % length, data[4 : 4 + total_ext_length])
self.extension = self._ext_header(profile, length, values)
except struct.error:
self.extension = self._ext_header(profile, 0, [])
offset = 4 + total_ext_length
if self._rtpsize:
offset -= 4
return max(0, min(offset, len(data)))
def _parse_bede_header(self, data: bytes, length: int) -> None:
offset = 4
n = 0
max_bytes = length * 4 + 4
while n < length:
if offset >= len(data) or offset >= max_bytes:
break
next_byte = data[offset : offset + 1]
if next_byte == b"\x00":
offset += 1
continue
header = struct.unpack(">B", next_byte)[0]
el_id = header >> 4
el_len = 1 + (header & 0b0000_1111)
self.extension_data[el_id] = data[offset + 1 : offset + 1 + el_len]
offset += 1 + el_len
n += 1
def __repr__(self) -> str:
return (
"<RTPPacket "
f"ssrc={self.ssrc} "
f"sequence={self.sequence} "
f"timestamp={self.timestamp} "
f"size={len(self.data)} "
f"ext={set(self.extension_data)}"
">"
)
class RTCPPacket(Packet):
_header = struct.Struct(">BBH")
_ssrc_fmt = struct.Struct(">I")
type = None
def __init__(self, data: bytes) -> None:
super().__init__(data)
self.length: int
head, _, self.length = self._header.unpack_from(data)
self.version: int = head >> 6
self.padding: bool = bool(head & 0b00100000)
setattr(self, "report_count", head & 0b00011111)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} version={self.version} padding={self.padding} length={self.length}>"
@classmethod
def from_data(cls, data: bytes) -> Packet:
_, ptype, _ = cls._header.unpack_from(data)
return _rtcp_map[ptype](data)
def _parse_low(x: int, bitlen: int = 32) -> float:
return x / 2.0**bitlen
def _to_low(x: float, bitlen: int = 32) -> int:
return int(x * 2.0**bitlen)
class SenderReportPacket(RTCPPacket):
_info_fmt = struct.Struct(">5I")
_report_fmt = struct.Struct(">IB3x4I")
_24bit_int_fmt = struct.Struct(">4xI")
_info = namedtuple("RRSenderInfo", "ntp_ts rtp_ts packet_count octet_count")
_report = namedtuple(
"RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr"
)
type = 200
if TYPE_CHECKING:
report_count: int
def __init__(self, data: bytes) -> None:
super().__init__(data)
self.ssrc = self._ssrc_fmt.unpack_from(data, 4)[0]
self.info = self._read_sender_info(data, 8)
_report = self._report
reports: list[_report] = []
for x in range(self.report_count):
offset = 28 + 24 * x
reports.append(self._read_report(data, offset))
self.reports: tuple[_report, ...] = tuple(reports)
self.extension = None
if len(data) > 28 + 24 * self.report_count:
self.extension = data[28 + 24 * self.report_count :]
def _read_sender_info(self, data: bytes, offset: int) -> _info:
nhigh, nlow, rtp_ts, pcount, ocount = self._info_fmt.unpack_from(data, offset)
ntotal = nhigh + _parse_low(nlow)
return self._info(ntotal, rtp_ts, pcount, ocount)
def _read_report(self, data: bytes, offset: int) -> _report:
ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset)
clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF
return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr)
class ReceiverReportPacket(RTCPPacket):
_report_fmt = struct.Struct(">IB3x4I")
_24bit_int_fmt = struct.Struct(">4xI")
_report = namedtuple(
"RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr"
)
type = 201
reports: tuple[_report, ...]
if TYPE_CHECKING:
report_count: int
def __init__(self, data: bytes) -> None:
super().__init__(data)
self.ssrc: int = self._ssrc_fmt.unpack_from(data, 4)[0]
_report = self._report
reports: list[_report] = []
for x in range(self.report_count):
offset = 8 + 24 * x
reports.append(self._read_report(data, offset))
self.reports = tuple(reports)
self.extension: bytes | None = None
if len(data) > 8 + 24 * self.report_count:
self.extension = data[8 + 24 * self.report_count :]
def _read_report(self, data: bytes, offset: int) -> _report:
ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset)
clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF
return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr)
_rtcp_map = {
200: SenderReportPacket,
201: ReceiverReportPacket,
}
@@ -0,0 +1,33 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from .reader import AudioReader
from .router import PacketRouter, SinkEventRouter
__all__ = (
"AudioReader",
"PacketRouter",
"SinkEventRouter",
)
@@ -0,0 +1,577 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import logging
import threading
import time
from collections.abc import Callable
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Literal
from ..packets.core import OPUS_SILENCE
from ..packets.rtp import ReceiverReportPacket, RTCPPacket, decode
from ..utils.dependencies import HAS_DAVEY, HAS_NACL
from .router import PacketRouter, SinkEventRouter
if HAS_DAVEY:
import davey
if HAS_NACL:
import nacl.secret
from nacl.exceptions import CryptoError
if TYPE_CHECKING:
from discord.member import Member
from discord.sinks import Sink
from discord.types.voice import SupportedModes
from ..client import VoiceClient
from ..packets import RTPPacket
AfterCallback = Callable[[Exception | None], Any]
DecryptRTP = Callable[[RTPPacket], bytes]
DecryptRTCP = Callable[[bytes], bytes]
SpeakingEvent = Literal["member_speaking_start", "member_speaking_stop"]
EncryptionBox = nacl.secret.SecretBox | nacl.secret.Aead
_log = logging.getLogger(__name__)
__all__ = ("AudioReader",)
def is_rtcp(data: bytes) -> bool:
return 200 <= data[1] <= 204
class AudioReader:
def __init__(
self,
sink: Sink,
client: VoiceClient,
*,
after: AfterCallback | None = None,
start: bool = False,
) -> None:
if after is not None and not callable(after):
raise TypeError(
f"expected a callable for the 'after' parameter, got {after.__class__.__name__!r} instead"
)
self.sink: Sink = sink
self.client: VoiceClient = client
self.after: AfterCallback | None = after
# self.sink._client = client
self.active: bool = False
self.error: Exception | None = None
self.packet_router: PacketRouter = PacketRouter(self.sink, self)
self.event_router: SinkEventRouter = SinkEventRouter(self.sink, self)
self.decryptor: PacketDecryptor = PacketDecryptor(
client.mode, bytes(client.secret_key), client
)
self.speaking_timer: SpeakingTimer = SpeakingTimer(self)
self.keep_alive: UDPKeepAlive = UDPKeepAlive(client)
if start:
self.start()
def is_listening(self) -> bool:
return self.active
def update_secret_key(self, secret_key: bytes) -> None:
self.decryptor.update_secret_key(secret_key)
def start(self) -> None:
if self.active:
_log.debug("Reader is already running", exc_info=True)
return
self.client._connection.add_socket_listener(self.callback)
self.speaking_timer.start()
self.event_router.start()
self.packet_router.start()
self.keep_alive.start()
self.active = True
def stop(self) -> None:
if not self.active:
_log.debug("Reader is not active")
return
self.client._connection.remove_socket_listener(self.callback)
self.speaking_timer.notify()
self._stop()
self.active = False
def _stop(self) -> None:
try:
if self.packet_router.is_alive():
self.packet_router.stop()
except Exception as exc:
self.error = exc
_log.exception("An error ocurred while stopping packet router.")
try:
self.event_router.stop()
except Exception as exc:
self.error = exc
_log.exception("An error ocurred while stopping event router.")
self.speaking_timer.stop()
self.keep_alive.stop()
if self.after:
try:
self.after(self.error)
except Exception:
_log.exception(
"An error ocurred while calling the after callback on audio reader"
)
"""for sink in self.sink.root.walk_children(with_self=True):
try:
sink.cleanup()
except Exception as exc:
_log.exception("Error calling cleanup() for %s", sink, exc_info=exc)"""
def set_sink(self, sink: Sink) -> Sink:
old_sink = self.sink
# old_sink._client = None
# sink._client = self.client
self.packet_router.set_sink(sink)
self.sink = sink
return old_sink
def _is_ip_discovery_packet(self, data: bytes) -> bool:
return len(data) == 74 and data[1] == 0x02
def callback(self, packet_data: bytes) -> None:
packet = rtp_packet = rtcp_packet = None
try:
if not is_rtcp(packet_data):
packet = rtp_packet = decode(packet_data)
packet.decrypted_data = self.decryptor.decrypt_rtp(packet) # type: ignore
else:
packet = rtcp_packet = decode(packet_data)
if not isinstance(packet, ReceiverReportPacket):
_log.info(
"Received unexpected rtcp packet type=%s, %s",
packet.type,
type(packet),
)
except CryptoError as exc:
_log.error("CryptoError while decoding a voice packet", exc_info=exc)
return
except Exception as exc:
if self._is_ip_discovery_packet(packet_data):
_log.debug("Received an IP Discovery Packet, ignoring...")
return
_log.exception(
"An exception ocurred while decoding voice packets", exc_info=exc
)
finally:
if self.error:
_log.debug("Callback errored out, stopping: %s", self.error)
self.stop()
return
if not packet:
_log.debug("No packet found after callback")
return
if rtcp_packet:
self.packet_router.feed_rtcp(rtcp_packet) # type: ignore
elif rtp_packet:
if not rtp_packet.decrypted_data:
_log.debug(
"No decrypted data for RTP packet, this should be safe to ignore."
)
return
ssrc = rtp_packet.ssrc
if ssrc not in self.client._connection.ssrc_user_map:
if rtp_packet.is_silence():
return
else:
_log.info(
"Received a packet for unknown SSRC %s: %s", ssrc, rtp_packet
)
_log.debug(
"Current SSRCs: %s", self.client._connection.ssrc_user_map
)
self.speaking_timer.notify(ssrc)
try:
_log.debug("Feeding packet to packet router")
self.packet_router.feed_rtp(rtp_packet) # type: ignore
except Exception as exc:
_log.exception(
"An error ocurred while processing RTP packet %s", rtp_packet
)
self.error = exc
self.stop()
class PacketDecryptor:
supported_modes: list[SupportedModes] = [
"aead_xchacha20_poly1305_rtpsize",
"xsalsa20_poly1305",
"xsalsa20_poly1305_lite",
"xsalsa20_poly1305_suffix",
]
def __init__(
self, mode: SupportedModes, secret_key: bytes, client: VoiceClient
) -> None:
self.mode: SupportedModes = mode
self.client: VoiceClient = client
try:
self._decryptor_rtp: DecryptRTP = getattr(self, "_decrypt_rtp_" + mode)
self._decryptor_rtcp: DecryptRTCP = getattr(self, "_decrypt_rtcp_" + mode)
except AttributeError as exc:
raise NotImplementedError(mode) from exc
self.box: EncryptionBox = self._make_box(secret_key)
def _make_box(self, secret_key: bytes) -> EncryptionBox:
if self.mode.startswith("aead"):
return nacl.secret.Aead(secret_key)
else:
return nacl.secret.SecretBox(secret_key)
"""def decrypt_rtp(self, packet: RTPPacket) -> bytes:
state = self.client._connection
dave = state.dave_session
data = self._decryptor_rtp(packet)
if dave is not None and dave.ready and packet.ssrc in state.ssrc_user_map:
data = dave.decrypt(
state.ssrc_user_map[packet.ssrc], davey.MediaType.audio, data
)
if packet.extended:
offset = packet.update_extended_header(data)
data = data[offset:]
return data"""
def decrypt_rtp(self, packet: RTPPacket) -> bytes:
state = self.client._connection
dave = state.dave_session
raw_payload = self._decryptor_rtp(packet)
if dave is not None and dave.ready:
uid = state.ssrc_user_map.get(packet.ssrc)
if uid:
try:
decrypted_audio = dave.decrypt(
uid,
davey.MediaType.audio,
raw_payload,
)
if packet.extended:
offset = packet.update_extended_header(decrypted_audio)
packet.decrypted_data = decrypted_audio[offset:]
else:
packet.decrypted_data = decrypted_audio
except Exception as exc:
_log.debug(
"Ignoring exception while decoding DAVE packet", exc_info=exc
)
packet.decrypted_data = OPUS_SILENCE
return packet.decrypted_data
def decrypt_rtcp(self, packet: bytes) -> bytes:
data = self._decryptor_rtcp(packet)
# parse the rtcp packet to its respective report type
offset = 0
while offset < len(data):
# offset will allow us to read the compund packets
current_data = data[offset:]
if len(current_data) < 8:
break
p_header = RTCPPacket.from_data(current_data)
# the sender ssrc will always be at offset 4 of the current packet
# doesn't matter if it is a sr or a rr
ssrc = p_header.ssrc
state = self.client._connection
dave = state.dave_session
if dave is not None and dave.ready and ssrc in state.ssrc_user_map:
return dave.decrypt(
state.ssrc_user_map[ssrc],
davey.MediaType.audio,
current_data,
)
return data
def update_secret_key(self, secret_key: bytes) -> None:
self.box = self._make_box(secret_key)
def _decrypt_rtp_xsalsa20_poly1305(self, packet: RTPPacket) -> bytes:
nonce = bytearray(24)
nonce[:12] = packet.header
result = self.box.decrypt(bytes(packet.data), bytes(nonce))
if packet.extended:
offset = packet.update_extended_header(result)
result = result[offset:]
return result
def _decrypt_rtcp_xsalsa20_poly1305(self, data: bytes) -> bytes:
nonce = bytearray(24)
nonce[:8] = data[:8]
result = self.box.decrypt(data[8:], bytes(nonce))
return data[:8] + result
def _decrypt_rtp_xsalsa20_poly1305_suffix(self, packet: RTPPacket) -> bytes:
nonce = packet.data[-24:]
voice_data = packet.data[:-24]
result = self.box.decrypt(bytes(voice_data), bytes(nonce))
if packet.extended:
offset = packet.update_extended_header(result)
result = result[offset:]
return result
def _decrypt_rtcp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes:
nonce = data[-24:]
header = data[:8]
result = self.box.decrypt(data[8:-24], nonce)
return header + result
def _decrypt_rtp_xsalsa20_poly1305_lite(self, packet: RTPPacket) -> bytes:
nonce = bytearray(24)
nonce[:4] = packet.data[-4:]
voice_data = packet.data[:-4]
result = self.box.decrypt(bytes(voice_data), bytes(nonce))
if packet.extended:
offset = packet.update_extended_header(result)
result = result[offset:]
return result
def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes:
nonce = bytearray(24)
nonce[:4] = data[-4:]
header = data[:8]
result = self.box.decrypt(data[8:-4], bytes(nonce))
return header + result
def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> bytes:
_log.debug(
"Decrypting RTP AEAD XChaCha20 Poly1305 RTPSize, has decrypted data?: %s",
packet.decrypted_data is not None,
)
packet.adjust_rtpsize()
nonce = packet.nonce + b"\x00" * 20
assert isinstance(self.box, nacl.secret.Aead)
try:
result = self.box.decrypt(
packet.decrypted_data or packet.data,
bytes(packet.header),
nonce,
)
except Exception as exc:
_log.error("Critical error at AEAD: %s", exc)
raise CryptoError(exc)
if packet.extended:
packet.update_extended_header(result)
return result[8:]
def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes:
_log.debug("Decrypting RTCP AEAD XChaCha20 Poly1305 RTPSize")
nonce = bytearray(24)
nonce[:4] = data[-4:]
header = data[:8]
assert isinstance(self.box, nacl.secret.Aead)
result = self.box.decrypt(data[8:-4], bytes(header), bytes(nonce))
return header + result
class SpeakingTimer(threading.Thread):
def __init__(self, reader: AudioReader) -> None:
super().__init__(
daemon=True,
name=f"voice-receiver-speaking-timer:{id(self):#x}",
)
self.reader: AudioReader = reader
self.client: VoiceClient = reader.client
self.speaking_timeout_delay: float = 0.2
self.last_speaking_state: dict[int, bool] = {}
self.speaking_cache: dict[int, float] = {}
self.speaking_timer_event: threading.Event = threading.Event()
self._end_thread: threading.Event = threading.Event()
def _lookup_member(self, ssrc: int) -> Member | None:
id = self.client._connection.ssrc_user_map.get(ssrc)
if not self.client.guild:
return None
return self.client.guild.get_member(id) if id else None
def maybe_dispatch_speaking_start(self, ssrc: int) -> None:
tlast = self.speaking_cache.get(ssrc)
if tlast is None or tlast + self.speaking_timeout_delay < time.perf_counter():
self.dispatch("member_speaking_start", ssrc)
def dispatch(self, event: SpeakingEvent, ssrc: int) -> None:
member = self._lookup_member(ssrc)
if not member:
return None
self.client._dispatch_sink(event, member)
def notify(self, ssrc: int | None = None) -> None:
if ssrc is not None:
self.last_speaking_state[ssrc] = True
self.maybe_dispatch_speaking_start(ssrc)
self.speaking_cache[ssrc] = time.perf_counter()
self.speaking_timer_event.set()
self.speaking_timer_event.clear()
def drop_ssrc(self, ssrc: int) -> None:
self.speaking_cache.pop(ssrc, None)
state = self.last_speaking_state.pop(ssrc, None)
if state:
self.dispatch("member_speaking_stop", ssrc)
self.notify()
def get_speaking(self, ssrc: int) -> bool | None:
return self.last_speaking_state.get(ssrc)
def stop(self) -> None:
self._end_thread.set()
self.notify()
def run(self) -> None:
_i1 = itemgetter(1)
def get_next_entry():
cache = sorted(self.speaking_cache.items(), key=_i1)
for ssrc, tlast in cache:
if self.last_speaking_state.get(ssrc):
return ssrc, tlast
return None, None
self.speaking_timer_event.wait()
while not self._end_thread.is_set():
if not self.speaking_cache:
self.speaking_timer_event.wait()
tnow = time.perf_counter()
ssrc, tlast = get_next_entry()
if ssrc is None or tlast is None:
self.speaking_timer_event.wait()
continue
self.speaking_timer_event.wait(tlast + self.speaking_timeout_delay - tnow)
if time.perf_counter() < tlast + self.speaking_timeout_delay:
continue
self.dispatch("member_speaking_stop", ssrc)
self.last_speaking_state[ssrc] = False
class UDPKeepAlive(threading.Thread):
delay: int = 5000
def __init__(self, client: VoiceClient) -> None:
super().__init__(
daemon=True,
name=f"voice-receiver-udp-keep-alive:{id(self):#x}",
)
self.client: VoiceClient = client
self.last_time: float = 0
self.counter: int = 0
self._end_thread: threading.Event = threading.Event()
def run(self) -> None:
self.client.wait_until_connected()
while not self._end_thread.is_set():
vc = self.client
try:
packet = self.counter.to_bytes(8, "big")
except OverflowError:
self.counter = 0
continue
try:
vc._connection.socket.sendto(
packet, (vc._connection.endpoint_ip, vc._connection.voice_port)
)
except Exception as exc:
_log.debug(
"Error while sending udp keep alive to socket %s at %s:%s",
vc._connection.socket,
vc._connection.endpoint_ip,
vc._connection.voice_port,
exc_info=exc,
)
vc.wait_until_connected()
if vc.is_connected():
continue
break
else:
self.counter += 1
time.sleep(self.delay)
def stop(self) -> None:
self._end_thread.set()
@@ -0,0 +1,231 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import logging
import queue
import threading
from collections import deque
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from discord.opus import PacketDecoder
from ..utils.multidataevent import MultiDataEvent
if TYPE_CHECKING:
from discord.sinks import Sink
from ..packets import RTCPPacket, RTPPacket
from .reader import AudioReader
EventCB = Callable[..., Any]
EventData = tuple[str, tuple[Any, ...], dict[str, Any]]
_log = logging.getLogger(__name__)
class PacketRouter(threading.Thread):
def __init__(self, sink: Sink, reader: AudioReader) -> None:
super().__init__(
daemon=True,
name=f"voice-receiver-packet-router:{id(self):#x}",
)
self.sink: Sink = sink
self.decoders: dict[int, PacketDecoder] = {}
self.reader: AudioReader = reader
self.waiter: MultiDataEvent[PacketDecoder] = MultiDataEvent()
self._lock: threading.RLock = threading.RLock()
self._end_thread: threading.Event = threading.Event()
self._dropped_ssrcs: deque[int] = deque(maxlen=16)
def feed_rtp(self, packet: RTPPacket) -> None:
if packet.ssrc in self._dropped_ssrcs:
_log.debug("Ignoring packet from dropped ssrc %s", packet.ssrc)
with self._lock:
decoder = self.get_decoder(packet.ssrc)
if decoder is not None:
decoder.push_packet(packet)
def feed_rtcp(self, packet: RTCPPacket) -> None:
guild = self.sink.client.guild if self.sink.client else None
event_router = self.reader.event_router
event_router.dispatch("rtcp_packet", packet, guild)
def get_decoder(self, ssrc: int) -> PacketDecoder | None:
with self._lock:
decoder = self.decoders.get(ssrc)
if decoder is None:
decoder = self.decoders[ssrc] = PacketDecoder(self, ssrc)
return decoder
def set_sink(self, sink: Sink) -> None:
with self._lock:
self.sink = sink
def set_user_id(self, ssrc: int, user_id: int) -> None:
with self._lock:
if ssrc in self._dropped_ssrcs:
self._dropped_ssrcs.remove(ssrc)
decoder = self.decoders.get(ssrc)
if decoder is not None:
decoder.set_user_id(user_id)
def destroy_decoder(self, ssrc: int) -> None:
with self._lock:
decoder = self.decoders.pop(ssrc, None)
if decoder is not None:
self._dropped_ssrcs.append(ssrc)
decoder.destroy()
def destroy_all_decoders(self) -> None:
with self._lock:
for ssrc in self.decoders.keys():
self.destroy_decoder(ssrc)
def stop(self) -> None:
self._end_thread.set()
self.waiter.notify()
def run(self) -> None:
try:
self._do_run()
except Exception as exc:
_log.exception("Error in %s loop", self)
self.reader.error = exc
finally:
self.reader.client.stop_recording()
self.waiter.clear()
def _do_run(self) -> None:
while not self._end_thread.is_set():
self.waiter.wait()
with self._lock:
for decoder in self.waiter.items:
data = decoder.pop_data()
if data is not None:
self.sink.write(data, data.source)
class SinkEventRouter(threading.Thread):
def __init__(self, sink: Sink, reader: AudioReader) -> None:
super().__init__(
daemon=True, name=f"voice-receiver-sink-event-router:{id(self):#x}"
)
self.sink: Sink = sink
self.reader: AudioReader = reader
self._event_listeners: dict[str, list[EventCB]] = {}
self._buffer: queue.SimpleQueue[EventData] = queue.SimpleQueue()
self._lock = threading.RLock()
self._end_thread = threading.Event()
self.register_events()
def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None:
_log.debug("Dispatch voice event %s", event)
self._buffer.put_nowait((event, args, kwargs))
def set_sink(self, sink: Sink) -> None:
with self._lock:
self.unregister_events()
self.sink = sink
self.register_events()
def register_events(self) -> None:
with self._lock:
self._register_listeners(self.sink)
for child in self.sink.walk_children():
self._register_listeners(child)
def unregister_events(self) -> None:
with self._lock:
self._unregister_listeners(self.sink)
for child in self.sink.walk_children():
self._unregister_listeners(child)
def _register_listeners(self, sink: Sink) -> None:
_log.debug("Registering events for %s: %s", sink, sink.__sink_listeners__)
for name, method_name in sink.__sink_listeners__:
func = getattr(sink, method_name)
_log.debug("Registering event: %r (callback at %r)", name, method_name)
if name in self._event_listeners:
self._event_listeners[name].append(func)
else:
self._event_listeners[name] = [func]
def _unregister_listeners(self, sink: Sink) -> None:
for name, method_name in sink.__sink_listeners__:
func = getattr(sink, method_name)
if name in self._event_listeners:
try:
self._event_listeners[name].remove(func)
except ValueError:
pass
def _dispatch_to_listeners(self, event: str, *args: Any, **kwargs: Any) -> None:
for listener in self._event_listeners.get(f"on_{event}", []):
try:
listener(*args, **kwargs)
except Exception as exc:
_log.exception(
"Unhandled exception while dispatching event %s (args: %s; kwargs: %s)",
event,
args,
kwargs,
exc_info=exc,
)
def stop(self) -> None:
self._end_thread.set()
def run(self) -> None:
try:
self._do_run()
except Exception as exc:
_log.exception("Error in sink event router", exc_info=exc)
self.reader.error = exc
self.reader.client.stop_recording()
def _do_run(self) -> None:
while not self._end_thread.is_set():
try:
event, args, kwargs = self._buffer.get(timeout=0.5)
except queue.Empty:
continue
else:
with self._lock:
with self.reader.packet_router._lock:
self._dispatch_to_listeners(event, *args, **kwargs)
@@ -0,0 +1,983 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import logging
import select
import socket
import threading
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any
from discord import utils
from discord.backoff import ExponentialBackoff
from discord.errors import ConnectionClosed
from discord.voice.utils.dependencies import DAVE_PROTOCOL_VERSION, HAS_DAVEY
from .enums import ConnectionFlowState, OpCodes
from .gateway import VoiceWebSocket
if TYPE_CHECKING:
from discord import abc
from discord.guild import Guild
from discord.member import VoiceState
from discord.raw_models import RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent
from discord.state import ConnectionState
from discord.types.voice import SupportedModes
from discord.user import ClientUser
from .client import VoiceClient
MISSING = utils.MISSING
SocketReaderCallback = Callable[[bytes], Any]
_log = logging.getLogger(__name__)
_recv_log = logging.getLogger("discord.voice.receiver")
if HAS_DAVEY:
import davey
class SocketReader(threading.Thread):
def __init__(
self,
state: VoiceConnectionState,
name: str,
buffer_size: int,
*,
start_paused: bool = True,
) -> None:
super().__init__(
daemon=True,
name=name,
)
self.buffer_size: int = buffer_size
self.state: VoiceConnectionState = state
self.start_paused: bool = start_paused
self._callbacks: list[SocketReaderCallback] = []
self._running: threading.Event = threading.Event()
self._end: threading.Event = threading.Event()
self._idle_paused: bool = True
self._started: threading.Event = threading.Event()
self._warned_wait: bool = False
def is_running(self) -> bool:
return self._started.is_set()
def register(self, callback: SocketReaderCallback) -> None:
self._callbacks.append(callback)
if self._idle_paused:
self._idle_paused = False
self._running.set()
def unregister(self, callback: SocketReaderCallback) -> None:
try:
self._callbacks.remove(callback)
except ValueError:
pass
else:
if not self._callbacks and self._running.is_set():
self._idle_paused = True
self._running.clear()
def pause(self) -> None:
self._idle_paused = False
self._running.clear()
def is_paused(self) -> bool:
return self._idle_paused or (
not self._running.is_set() and not self._end.is_set()
)
def resume(self, *, force: bool = False) -> None:
if self._running.is_set():
return
if not force and not self._callbacks:
self._idle_paused = True
return
self._idle_paused = False
self._running.set()
def stop(self) -> None:
self._started.clear()
self._end.set()
self._running.set()
def run(self) -> None:
self._started.set()
self._end.clear()
self._running.set()
if self.start_paused:
self.pause()
try:
self._do_run()
except Exception:
_log.exception(
"An error ocurred while running the socket reader %s",
self.name,
)
finally:
self.stop()
self._started.clear()
self._running.clear()
self._callbacks.clear()
def _do_run(self) -> None:
while not self._end.is_set():
if not self._running.is_set():
if not self._warned_wait:
_log.warning(
"Socket reader %s is waiting to be set as running", self.name
)
self._warned_wait = True
self._running.wait()
continue
if self._warned_wait:
_log.info("Socket reader %s was set as running", self.name)
self._warned_wait = False
try:
readable, _, _ = select.select([self.state.socket], [], [], 30)
except (ValueError, TypeError, OSError) as e:
_log.debug(
"Select error handling socket in reader, this should be safe to ignore: %s: %s",
e.__class__.__name__,
e,
)
continue
if not readable:
continue
try:
data = self.state.socket.recv(self.buffer_size)
except OSError:
_log.debug(
"Error reading from socket in %s, this should be safe to ignore",
self,
exc_info=True,
)
else:
for cb in self._callbacks:
try:
task = asyncio.ensure_future(
self.state.loop.create_task(
utils.maybe_coroutine(cb, data)
),
loop=self.state.loop,
)
self.state._dispatch_task_set.add(task)
task.add_done_callback(self.state._dispatch_task_set.discard)
except Exception:
_log.exception(
"Error while calling %s in %s",
cb,
self,
)
class SocketEventReader(SocketReader):
def __init__(
self, state: VoiceConnectionState, *, start_paused: bool = True
) -> None:
super().__init__(
state,
f"voice-socket-event-reader:{id(self):#x}",
2048,
start_paused=start_paused,
)
class VoiceConnectionState:
def __init__(
self,
client: VoiceClient,
*,
hook: (
Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None
) = None,
) -> None:
self.client: VoiceClient = client
self.hook = hook
self.loop: asyncio.AbstractEventLoop = client.loop
self.timeout: float = 30.0
self.reconnect: bool = True
self.self_deaf: bool = False
self.self_mute: bool = False
self.endpoint: str | None = None
self.endpoint_ip: str | None = None
self.server_id: int | None = None
self.ip: str | None = None
self.port: int | None = None
self.voice_port: int | None = None
self.secret_key: list[int] = MISSING
self.ssrc: int = MISSING
self.mode: SupportedModes = MISSING
self.socket: socket.socket = MISSING
self.ws: VoiceWebSocket = MISSING
self.session_id: str | None = None
self.token: str | None = None
self._connection: ConnectionState = client._state
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
self._expecting_disconnect: bool = False
self._connected = threading.Event()
self._state_event = asyncio.Event()
self._disconnected = asyncio.Event()
self._runner: asyncio.Task[None] | None = None
self._connector: asyncio.Task[None] | None = None
self._socket_reader = SocketEventReader(self)
self._socket_reader.start()
self.recording_done_callbacks: list[
tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]]
] = []
self._dispatch_task_set: set[asyncio.Task] = set()
if not self._connection.self_id:
raise RuntimeError("client self ID is not set")
if not self.channel_id:
raise RuntimeError("client channel being connected to is not set")
self.dave_session: davey.DaveSession | None = None
self.dave_protocol_version: int = 0
self.dave_pending_transition: dict[str, int] | None = None
self.downgraded_dave = False
@property
def user_ssrc_map(self) -> dict[int, int]:
return self.client._id_to_ssrc
@property
def ssrc_user_map(self) -> dict[int, int]:
return {v: k for k, v in self.user_ssrc_map.items()}
@property
def max_dave_proto_version(self) -> int:
return DAVE_PROTOCOL_VERSION
@property
def state(self) -> ConnectionFlowState:
return self._state
@state.setter
def state(self, state: ConnectionFlowState) -> None:
if state is not self._state:
_log.debug("State changed from %s to %s", self._state.name, state.name)
self._state = state
self._state_event.set()
self._state_event.clear()
if state is ConnectionFlowState.connected:
self._connected.set()
else:
self._connected.clear()
@property
def guild(self) -> Guild:
return self.client.guild
@property
def user(self) -> ClientUser:
return self.client.user
@property
def channel_id(self) -> int | None:
return self.client.channel is not None and self.client.channel.id
@property
def guild_id(self) -> int:
return self.guild.id
@property
def supported_modes(self) -> tuple[SupportedModes, ...]:
return self.client.supported_modes
@property
def self_voice_state(self) -> VoiceState | None:
return self.guild.me.voice
def is_connected(self) -> bool:
return self.state is ConnectionFlowState.connected
def _inside_runner(self) -> bool:
return self._runner is not None and asyncio.current_task() == self._runner
async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None:
channel_id = data.channel_id
if channel_id is None:
self._disconnected.set()
if self._expecting_disconnect:
self._expecting_disconnect = False
else:
_log.debug("We have been disconnected from voice")
await self.disconnect()
return
self.session_id = data.session_id
if self.state in (
ConnectionFlowState.set_guild_voice_state,
ConnectionFlowState.got_voice_server_update,
):
if self.state is ConnectionFlowState.set_guild_voice_state:
self.state = ConnectionFlowState.got_voice_state_update
if channel_id != self.client.channel.id:
# moved from channel
self._update_voice_channel(channel_id)
else:
self.state = ConnectionFlowState.got_both_voice_updates
return
if self.state is ConnectionFlowState.connected:
self._update_voice_channel(channel_id)
elif self.state is not ConnectionFlowState.disconnected:
if channel_id != self.client.channel.id:
_log.info("We were moved from the channel while connecting...")
self._update_voice_channel(channel_id)
await self.soft_disconnect(
with_state=ConnectionFlowState.got_voice_state_update
)
await self.connect(
reconnect=self.reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=False,
wait=False,
)
else:
_log.debug("Ignoring unexpected VOICE_STATEUPDATE event")
async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None:
previous_token = self.token
previous_server_id = self.server_id
previous_endpoint = self.endpoint
self.token = data.token
self.server_id = data.guild_id
endpoint = data.endpoint
if self.token is None or endpoint is None:
_log.warning(
"Awaiting endpoint... This requires waiting. "
"If timeout occurred considering raising the timeout and reconnecting."
)
return
# strip the prefix off since we add it later
self.endpoint = endpoint.removeprefix("wss://")
if self.state in (
ConnectionFlowState.set_guild_voice_state,
ConnectionFlowState.got_voice_state_update,
):
self.endpoint_ip = MISSING
self._create_socket()
if self.state is ConnectionFlowState.set_guild_voice_state:
self.state = ConnectionFlowState.got_voice_server_update
else:
self.state = ConnectionFlowState.got_both_voice_updates
elif self.state is ConnectionFlowState.connected:
_log.debug("Voice server update, closing old voice websocket")
await self.ws.close(4014) # 4014 = main gw dropped
self.state = ConnectionFlowState.got_voice_server_update
elif self.state is not ConnectionFlowState.disconnected:
if (
previous_token == self.token
and previous_server_id == self.server_id
and previous_endpoint == self.endpoint
):
return
_log.debug("Unexpected VOICE_SERVER_UPDATE event received, handling...")
await self.soft_disconnect(
with_state=ConnectionFlowState.got_voice_server_update
)
await self.connect(
reconnect=self.reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=False,
wait=False,
)
self._create_socket()
async def connect(
self,
*,
reconnect: bool,
timeout: float,
self_deaf: bool,
self_mute: bool,
resume: bool,
wait: bool = True,
) -> None:
if self._connector:
self._connector.cancel()
self._connector = None
if self._runner:
self._runner.cancel()
self._runner = None
self.timeout = timeout
self.reconnect = reconnect
self._connector = self.client.loop.create_task(
self._wrap_connect(
reconnect,
timeout,
self_deaf,
self_mute,
resume,
),
name=f"voice-connector:{id(self):#x}",
)
if wait:
await self._connector
async def _wrap_connect(
self,
reconnect: bool,
timeout: float,
self_deaf: bool,
self_mute: bool,
resume: bool,
) -> None:
try:
await self._connect(reconnect, timeout, self_deaf, self_mute, resume)
except asyncio.CancelledError:
_log.debug("Cancelling voice connection")
await self.soft_disconnect()
raise
except asyncio.TimeoutError:
_log.info("Timed out while connecting to voice")
await self.disconnect()
raise
except Exception:
_log.exception("Error while connecting to voice... disconnecting")
await self.disconnect()
raise
async def _inner_connect(
self, reconnect: bool, self_deaf: bool, self_mute: bool, resume: bool
) -> None:
for i in range(5):
_log.info("Starting voice handshake (connection attempt %s)", i + 1)
await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute)
if self.state is ConnectionFlowState.disconnected:
self.state = ConnectionFlowState.set_guild_voice_state
await self._wait_for_state(ConnectionFlowState.got_both_voice_updates)
_log.info("Voice handshake complete. Endpoint found: %s", self.endpoint)
try:
self.ws = await self._connect_websocket(resume)
await self._handshake_websocket()
break
except ConnectionClosed:
if reconnect:
wait = 1 + i * 2
_log.exception(
"Failed to connect to voice... Retrying in %s seconds", wait
)
await self.disconnect(cleanup=False)
await asyncio.sleep(wait)
continue
else:
await self.disconnect()
raise
async def _connect(
self,
reconnect: bool,
timeout: float,
self_deaf: bool,
self_mute: bool,
resume: bool,
) -> None:
_log.info(f"Connecting to voice {self.client.channel.id}")
await asyncio.wait_for(
self._inner_connect(
reconnect=reconnect,
self_deaf=self_deaf,
self_mute=self_mute,
resume=resume,
),
timeout=timeout,
)
_log.info("Voice connection completed")
if not self._runner:
self._runner = self.client.loop.create_task(
self._poll_ws(reconnect),
name=f"voice-ws-poller:{id(self):#x}",
)
async def disconnect(
self, *, force: bool = True, cleanup: bool = True, wait: bool = False
) -> None:
if not force and not self.is_connected():
return
_log.debug(
"Attempting a voice disconnect for channel %s (guild %s)",
self.channel_id,
self.guild_id,
)
try:
await self._voice_disconnect()
if self.ws:
await self.ws.close()
except Exception:
_log.debug(
"Ignoring exception while disconnecting from voice", exc_info=True
)
finally:
self.state = ConnectionFlowState.disconnected
self._socket_reader.pause()
if cleanup:
self._socket_reader.stop()
self.client.stop()
self._connected.set()
self._connected.clear()
if self.socket:
self.socket.close()
self.ip = MISSING
self.port = MISSING
if wait and not self._inside_runner():
try:
await asyncio.wait_for(
self._disconnected.wait(), timeout=self.timeout
)
except TimeoutError:
_log.debug("Timed out waiting for voice disconnect confirmation")
except asyncio.CancelledError:
pass
if cleanup:
self.client.cleanup()
async def soft_disconnect(
self,
*,
with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates,
) -> None:
_log.debug("Soft disconnecting from voice")
if self._runner:
self._runner.cancel()
self._runner = None
try:
if self.ws:
await self.ws.close()
except Exception:
_log.debug(
"Ignoring exception while soft disconnecting from voice", exc_info=True
)
finally:
self.state = with_state
self._socket_reader.pause()
if self.socket:
self.socket.close()
self.ip = MISSING
self.port = MISSING
async def move_to(
self, channel: abc.Snowflake | None, timeout: float | None
) -> None:
if channel is None:
await self.disconnect(wait=True)
return
if self.client.channel and channel.id == self.client.channel.id:
return
previous_state = self.state
await self._move_to(channel)
last_state = self.state
try:
await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout)
except asyncio.TimeoutError:
_log.warning(
"Timed out trying to move to channel %s in guild %s",
channel.id,
self.guild.id,
)
if self.state is last_state:
_log.debug(
"Reverting state %s to previous state: %s",
last_state.name,
previous_state.name,
)
self.state = previous_state
def wait_for(
self,
state: ConnectionFlowState = ConnectionFlowState.connected,
timeout: float | None = None,
) -> Any:
if state is ConnectionFlowState.connected:
return self._connected.wait(timeout)
return self._wait_for_state(state, timeout=timeout)
def send_packet(self, packet: bytes) -> None:
self.socket.sendall(packet)
def add_socket_listener(self, callback: SocketReaderCallback) -> None:
_log.debug("Registering a socket listener callback %s", callback)
self._socket_reader.register(callback)
def remove_socket_listener(self, callback: SocketReaderCallback) -> None:
_log.debug("Unregistering a socket listener callback %s", callback)
self._socket_reader.unregister(callback)
async def _wait_for_state(
self,
*states: ConnectionFlowState,
timeout: float | None = None,
) -> None:
if not states:
raise ValueError
while True:
if self.state in states:
return
_, pending = await asyncio.wait(
[
asyncio.ensure_future(self._state_event.wait()),
],
timeout=timeout,
)
if pending:
# if we're here, it means that the state event
# has timed out, so just raise the exception
raise asyncio.TimeoutError
async def _voice_connect(
self, *, self_deaf: bool = False, self_mute: bool = False
) -> None:
channel = self.client.channel
await channel.guild.change_voice_state(
channel=channel, self_deaf=self_deaf, self_mute=self_mute
)
async def _voice_disconnect(self) -> None:
_log.info(
"Terminating voice handshake for channel %s (guild %s)",
self.client.channel.id,
self.client.guild.id,
)
self.state = ConnectionFlowState.disconnected
await self.client.channel.guild.change_voice_state(
channel=None
) # pyright: ignore[reportAttributeAccessIssue]
self._expecting_disconnect = True
self._disconnected.clear()
async def _connect_websocket(self, resume: bool) -> VoiceWebSocket:
seq_ack = -1
if self.ws is not MISSING:
seq_ack = self.ws.seq_ack
ws = await VoiceWebSocket.from_state(
self, resume=resume, hook=self.hook, seq_ack=seq_ack
)
self.state = ConnectionFlowState.websocket_connected
return ws
async def _handshake_websocket(self) -> None:
while not self.ip:
await self.ws.poll_event()
self.state = ConnectionFlowState.got_ip_discovery
while self.ws.secret_key is None:
await self.ws.poll_event()
self.state = ConnectionFlowState.connected
def _create_socket(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
self._socket_reader.resume()
async def _poll_ws(self, reconnect: bool) -> None:
backoff = ExponentialBackoff()
while True:
try:
await self.ws.poll_event()
except asyncio.CancelledError:
return
except (ConnectionClosed, asyncio.TimeoutError) as exc:
if isinstance(exc, ConnectionClosed):
# 1000 - normal closure - not resumable
# 4014 - externally disconnected - not resumable
# 4015 - voice server crashed - resumable
# 4021 - ratelimited, not reconnect - not resumable
# 4022 - call terminated, similar to 4014 - not resumable
if exc.code == 1000:
if not self._expecting_disconnect:
_log.info(
"Disconnecting from voice manually, close code %d",
exc.code,
)
await self.disconnect()
break
elif exc.code in (4014, 4022):
if self._disconnected.is_set():
_log.info(
"Disconnecting from voice by Discord, close code %d",
exc.code,
)
await self.disconnect()
break
_log.info(
"Disconnecting from voice by force... potentially reconnecting..."
)
successful = await self._potential_reconnect()
if not successful:
_log.info(
"Reconnect was unsuccessful, disconnecting from voice normally"
)
if self.state is not ConnectionFlowState.disconnected:
await self.disconnect()
break
else:
# we have successfully resumed so just keep polling events
continue
elif exc.code == 4021:
_log.warning(
"We are being rate limited while attempting to connect to voice. Disconnecting...",
)
if self.state is not ConnectionFlowState.disconnected:
await self.disconnect()
break
elif exc.code == 4015:
_log.info(
"Disconnected from voice due to a Discord-side issue, attempting to reconnect and resume...",
)
try:
await self._connect(
reconnect=reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=True,
)
except asyncio.TimeoutError:
_log.info(
"Could not resume the voice connection... Disconnecting..."
)
if self.state is not ConnectionFlowState.disconnected:
await self.disconnect()
break
except Exception:
_log.exception(
"An exception was raised while attempting a reconnect and resume... Disconnecting...",
exc_info=True,
)
if self.state is not ConnectionFlowState.disconnected:
await self.disconnect()
break
else:
_log.info(
"Successfully reconnected and resumed the voice connection"
)
continue
else:
_log.debug(
"Not handling close code %s (%s)",
exc.code,
exc.reason or "No reason was provided",
)
if not reconnect:
await self.disconnect()
raise
retry = backoff.delay()
_log.exception(
"Disconnected from voice... Reconnecting in %.2fs",
retry,
)
await asyncio.sleep(retry)
await self.disconnect(cleanup=False)
try:
await self._connect(
reconnect=reconnect,
timeout=self.timeout,
self_deaf=(self.self_voice_state or self).self_deaf,
self_mute=(self.self_voice_state or self).self_mute,
resume=False,
)
except asyncio.TimeoutError:
_log.warning("Could not connect to voice... Retrying...")
continue
async def _potential_reconnect(self) -> bool:
try:
await self._wait_for_state(
ConnectionFlowState.got_voice_server_update,
ConnectionFlowState.got_both_voice_updates,
ConnectionFlowState.disconnected,
timeout=self.timeout,
)
except asyncio.TimeoutError:
return False
else:
if self.state is ConnectionFlowState.disconnected:
return False
previous_ws = self.ws
try:
self.ws = await self._connect_websocket(False)
await self._handshake_websocket()
except (ConnectionClosed, asyncio.TimeoutError):
return False
else:
return True
finally:
await previous_ws.close()
async def _move_to(self, channel: abc.Snowflake) -> None:
await self.client.channel.guild.change_voice_state(
channel=channel
) # pyright: ignore[reportAttributeAccessIssue]
self.state = ConnectionFlowState.set_guild_voice_state
def _update_voice_channel(self, channel_id: int | None) -> None:
self.client.channel = channel_id and self.guild.get_channel(channel_id) # type: ignore
async def reinit_dave_session(self) -> None:
assert self.channel_id
if self.dave_protocol_version > 0:
if self.dave_session:
self.dave_session.reinit(
self.dave_protocol_version, self.user.id, self.channel_id
)
else:
self.dave_session = davey.DaveSession(
self.dave_protocol_version,
self.user.id,
self.channel_id,
)
await self.ws.send_as_bytes(
OpCodes.mls_key_package,
self.dave_session.get_serialized_key_package(),
)
elif self.dave_session:
self.dave_session.reset()
self.dave_session.set_passthrough_mode(True, 10)
async def recover_dave_from_invalid_commit(self, transition: int) -> None:
payload = {
"op": int(OpCodes.mls_invalid_commit_welcome),
"d": {"transition_id": transition},
}
await self.ws.send_as_json(payload)
await self.reinit_dave_session()
async def execute_dave_transition(self, transition: int) -> None:
_log.debug("Executing DAVE transition with id %s", transition)
if not self.dave_pending_transition:
_log.warning(
"Attempted to execute a transition without having a pending transition for id %s, "
"this is a Discord bug.",
transition,
)
return
pending_transition = self.dave_pending_transition["transition_id"]
pending_proto = self.dave_pending_transition["protocol_version"]
session = self.dave_session
if transition == pending_transition:
old_version = self.dave_protocol_version
self.dave_protocol_version = pending_proto
if (
old_version != self.dave_protocol_version
and self.dave_protocol_version == 0
):
_log.warning(
"DAVE was downgraded, voice client non-e2ee session has been deprecated since 2.7"
)
self.downgraded_dave = True
elif transition > 0 and self.downgraded_dave:
self.downgraded_dave = False
if session:
session.set_passthrough_mode(True, 10)
_log.info("Upgraded voice session to use DAVE")
else:
_log.debug(
"Received an execute transition id %s when expected was %s, ignoring",
transition,
pending_proto,
)
self.dave_pending_transition = None
@@ -0,0 +1,215 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import heapq
import logging
import threading
from typing import Protocol, TypeVar
from ..packets import Packet
from .wrapped import add_wrapped, gap_wrapped
__all__ = (
"Buffer",
"JitterBuffer",
)
T = TypeVar("T")
PacketT = TypeVar("PacketT", bound=Packet)
_log = logging.getLogger(__name__)
class Buffer(Protocol[T]):
def __len__(self) -> int: ...
def push(self, item: T) -> None: ...
def pop(self) -> T | None: ...
def peek(self) -> T | None: ...
def flush(self) -> list[T]: ...
def reset(self) -> None: ...
class BaseBuff(Buffer[PacketT]):
def __init__(self) -> None:
self._buffer: list[PacketT] = []
def __len__(self) -> int:
return len(self._buffer)
def push(self, item: PacketT) -> None:
self._buffer.append(item)
def pop(self) -> PacketT | None:
return self._buffer.pop()
def peek(self) -> PacketT | None:
return self._buffer[-1] if self._buffer else None
def flush(self) -> list[PacketT]:
buf = self._buffer.copy()
self._buffer.clear()
return buf
def reset(self) -> None:
self._buffer.clear()
class JitterBuffer(BaseBuff[PacketT]):
_threshold: int = 10000
def __init__(
self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1
) -> None:
if max_size < 1:
raise ValueError(f"max_size must be greater than 1, not {max_size}")
if not 0 <= pref_size <= max_size:
raise ValueError(f"pref_size must be between 0 and max_size ({max_size})")
self.max_size: int = max_size
self.pref_size: int = pref_size
self.prefill: int = prefill
self._prefill: int = prefill
self._last_tx_seq: int = -1
self._has_item: threading.Event = threading.Event()
# self._lock: threading.Lock = threading.Lock()
self._buffer: list[Packet] = []
def _push(self, packet: Packet) -> None:
heapq.heappush(self._buffer, packet)
def _pop(self) -> Packet:
return heapq.heappop(self._buffer)
def _get_packet_if_ready(self) -> Packet | None:
return self._buffer[0] if len(self._buffer) > self.pref_size else None
def _pop_if_ready(self) -> Packet | None:
return self._pop() if len(self._buffer) > self.pref_size else None
def _update_has_item(self) -> None:
prefilled = self._prefill == 0
packet_ready = len(self._buffer) > self.pref_size
if not prefilled or not packet_ready:
self._has_item.clear()
return
next_packet = self._buffer[0]
sequential = add_wrapped(self._last_tx_seq, 1) == next_packet.sequence
positive_seq = self._last_tx_seq >= 0
if (
(sequential and positive_seq)
or not positive_seq
or len(self._buffer) >= self.max_size
):
self._has_item.set()
else:
self._has_item.clear()
def _cleanup(self) -> None:
while len(self._buffer) > self.max_size:
heapq.heappop(self._buffer)
def push(self, packet: Packet) -> bool:
seq = packet.sequence
if (
gap_wrapped(self._last_tx_seq, seq) > self._threshold
and self._last_tx_seq != -1
):
_log.debug("Dropping old packet %s", packet)
return False
self._push(packet)
if self._prefill > 0:
self._prefill -= 1
self._cleanup()
self._update_has_item()
return True
def pop(self, *, timeout: float | None = 0) -> Packet | None:
ok = self._has_item.wait(timeout)
if not ok:
return None
if self._prefill > 0:
return None
packet = self._pop_if_ready()
if packet is not None:
self._last_tx_seq = packet.sequence
self._update_has_item()
return packet
def peek(self, *, all: bool = False) -> Packet | None:
if not self._buffer:
return None
if all:
return self._buffer[0]
else:
return self._get_packet_if_ready()
def peek_next(self) -> Packet | None:
packet = self.peek(all=True)
if packet is None:
return None
if (
packet.sequence == add_wrapped(self._last_tx_seq, 1)
or self._last_tx_seq < 0
):
return packet
def gap(self) -> int:
if self._buffer and self._last_tx_seq > 0:
return gap_wrapped(self._last_tx_seq, self._buffer[0].sequence)
return 0
def flush(self) -> list[Packet]:
packets = sorted(self._buffer)
self._buffer.clear()
if packets:
self._last_tx_seq = packets[-1].sequence
self._prefill = self.prefill
self._has_item.clear()
return packets
def reset(self) -> None:
self._buffer.clear()
self._has_item.clear()
self._prefill = self.prefill
self._last_tx_seq = -1
@@ -0,0 +1,40 @@
"""
The MIT License (MIT)
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
try:
import davey
except ImportError:
HAS_DAVEY = False
DAVE_PROTOCOL_VERSION = 0
else:
HAS_DAVEY = True
DAVE_PROTOCOL_VERSION = davey.DAVE_PROTOCOL_VERSION
try:
import nacl.secret
import nacl.utils
except ImportError:
HAS_NACL = False
else:
HAS_NACL = True
@@ -0,0 +1,79 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import threading
from typing import Generic, TypeVar
T = TypeVar("T")
class MultiDataEvent(Generic[T]):
"""
Something like the inverse of a Condition. A 1-waiting-on-N type of object,
with accompanying data object for convenience.
"""
def __init__(self):
self._items: list[T] = []
self._ready: threading.Event = threading.Event()
@property
def items(self) -> list[T]:
"""A shallow copy of the currently ready objects."""
return self._items.copy()
def is_ready(self) -> bool:
return self._ready.is_set()
def _check_ready(self) -> None:
if self._items:
self._ready.set()
else:
self._ready.clear()
def notify(self) -> None:
self._ready.set()
self._check_ready()
def wait(self, timeout: float | None = None) -> bool:
self._check_ready()
return self._ready.wait(timeout)
def register(self, item: T) -> None:
self._items.append(item)
self._ready.set()
def unregister(self, item: T) -> None:
try:
self._items.remove(item)
except ValueError:
pass
self._check_ready()
def clear(self) -> None:
self._items.clear()
self._ready.clear()
@@ -0,0 +1,32 @@
"""
The MIT License (MIT)
Copyright (c) 2015-2021 Rapptz
Copyright (c) 2021-present Pycord Development
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
def gap_wrapped(a: int, b: int, *, wrap: int = 65536) -> int:
return (b - (a + 1) + wrap) % wrap
def add_wrapped(a: int, b: int, *, wrap: int = 65536) -> int:
return (a + b) % wrap