On branch DiscordProfile
Initial commit
This commit is contained in:
@@ -0,0 +1,586 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2021 Rapptz
|
||||
Copyright (c) 2021-present Pycord Development
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .backoff import ExponentialBackoff
|
||||
from .client import Client
|
||||
from .enums import Status
|
||||
from .errors import (
|
||||
ClientException,
|
||||
ConnectionClosed,
|
||||
GatewayNotFound,
|
||||
HTTPException,
|
||||
PrivilegedIntentsRequired,
|
||||
)
|
||||
from .gateway import *
|
||||
from .state import AutoShardedConnectionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .activity import BaseActivity
|
||||
from .gateway import DiscordWebSocket
|
||||
|
||||
EI = TypeVar("EI", bound="EventItem")
|
||||
|
||||
__all__ = (
|
||||
"AutoShardedClient",
|
||||
"ShardInfo",
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType:
|
||||
close = 0
|
||||
reconnect = 1
|
||||
resume = 2
|
||||
identify = 3
|
||||
terminate = 4
|
||||
clean_close = 5
|
||||
|
||||
|
||||
class EventItem:
|
||||
__slots__ = ("type", "shard", "error")
|
||||
|
||||
def __init__(
|
||||
self, etype: int, shard: Shard | None, error: Exception | None
|
||||
) -> None:
|
||||
self.type: int = etype
|
||||
self.shard: Shard | None = shard
|
||||
self.error: Exception | None = error
|
||||
|
||||
def __lt__(self: EI, other: EI) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type < other.type
|
||||
|
||||
def __eq__(self: EI, other: EI) -> bool:
|
||||
if not isinstance(other, EventItem):
|
||||
return NotImplemented
|
||||
return self.type == other.type
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.type)
|
||||
|
||||
|
||||
class Shard:
|
||||
def __init__(
|
||||
self,
|
||||
ws: DiscordWebSocket,
|
||||
client: AutoShardedClient,
|
||||
queue_put: Callable[[EventItem], None],
|
||||
) -> None:
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._client: Client = client
|
||||
self._dispatch: Callable[..., None] = client.dispatch
|
||||
self._queue_put: Callable[[EventItem], None] = queue_put
|
||||
self.loop: asyncio.AbstractEventLoop = self._client.loop
|
||||
self._disconnect: bool = False
|
||||
self._reconnect = client._reconnect
|
||||
self._backoff: ExponentialBackoff = ExponentialBackoff()
|
||||
self._task: asyncio.Task | None = None
|
||||
self._handled_exceptions: tuple[type[Exception], ...] = (
|
||||
OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError,
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
# DiscordWebSocket.shard_id is set in the from_client classmethod
|
||||
return self.ws.shard_id # type: ignore
|
||||
|
||||
def launch(self) -> None:
|
||||
self._task = self.loop.create_task(self.worker())
|
||||
|
||||
def _cancel_task(self) -> None:
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._cancel_task()
|
||||
await self.ws.close(code=1000)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
await self.close()
|
||||
self._dispatch("shard_disconnect", self.id)
|
||||
|
||||
async def _handle_disconnect(self, e: Exception) -> None:
|
||||
self._dispatch("disconnect")
|
||||
self._dispatch("shard_disconnect", self.id)
|
||||
if not self._reconnect:
|
||||
self._queue_put(EventItem(EventType.close, self, e))
|
||||
return
|
||||
|
||||
if self._client.is_closed():
|
||||
return
|
||||
|
||||
if isinstance(e, OSError) and e.errno in (54, 10054):
|
||||
# If we get Connection reset by peer then always try to RESUME the connection.
|
||||
exc = ReconnectWebSocket(self.id, resume=True)
|
||||
self._queue_put(EventItem(EventType.resume, self, exc))
|
||||
return
|
||||
|
||||
if isinstance(e, ConnectionClosed):
|
||||
if e.code == 4014:
|
||||
self._queue_put(
|
||||
EventItem(
|
||||
EventType.terminate, self, PrivilegedIntentsRequired(self.id)
|
||||
)
|
||||
)
|
||||
return
|
||||
if e.code != 1000:
|
||||
self._queue_put(EventItem(EventType.close, self, e))
|
||||
return
|
||||
|
||||
retry = self._backoff.delay()
|
||||
_log.error(
|
||||
"Attempting a reconnect for shard ID %s in %.2fs",
|
||||
self.id,
|
||||
retry,
|
||||
exc_info=e,
|
||||
)
|
||||
await asyncio.sleep(retry)
|
||||
self._queue_put(EventItem(EventType.reconnect, self, e))
|
||||
|
||||
async def worker(self) -> None:
|
||||
while not self._client.is_closed():
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except ReconnectWebSocket as e:
|
||||
etype = EventType.resume if e.resume else EventType.identify
|
||||
self._queue_put(EventItem(etype, self, e))
|
||||
break
|
||||
except self._handled_exceptions as e:
|
||||
await self._handle_disconnect(e)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._queue_put(EventItem(EventType.terminate, self, e))
|
||||
break
|
||||
|
||||
async def reidentify(self, exc: ReconnectWebSocket) -> None:
|
||||
self._cancel_task()
|
||||
self._dispatch("disconnect")
|
||||
self._dispatch("shard_disconnect", self.id)
|
||||
_log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id)
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self._client,
|
||||
resume=exc.resume,
|
||||
shard_id=self.id,
|
||||
session=self.ws.session_id,
|
||||
sequence=self.ws.sequence,
|
||||
)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0)
|
||||
except self._handled_exceptions as e:
|
||||
await self._handle_disconnect(e)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
self._queue_put(EventItem(EventType.terminate, self, e))
|
||||
else:
|
||||
self.launch()
|
||||
|
||||
async def reconnect(self) -> None:
|
||||
self._cancel_task()
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self._client,
|
||||
gateway=self.ws.resume_gateway_url,
|
||||
shard_id=self.id,
|
||||
)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0)
|
||||
except self._handled_exceptions as e:
|
||||
await self._handle_disconnect(e)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
self._queue_put(EventItem(EventType.terminate, self, e))
|
||||
else:
|
||||
self.launch()
|
||||
|
||||
|
||||
class ShardInfo:
|
||||
"""A class that gives information and control over a specific shard.
|
||||
|
||||
You can retrieve this object via :meth:`AutoShardedClient.get_shard`
|
||||
or :attr:`AutoShardedClient.shards`.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
Attributes
|
||||
----------
|
||||
id: :class:`int`
|
||||
The shard ID for this shard.
|
||||
shard_count: Optional[:class:`int`]
|
||||
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
|
||||
"""
|
||||
|
||||
__slots__ = ("_parent", "id", "shard_count")
|
||||
|
||||
def __init__(self, parent: Shard, shard_count: int | None) -> None:
|
||||
self._parent: Shard = parent
|
||||
self.id: int = parent.id
|
||||
self.shard_count: int | None = shard_count
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Whether the shard connection is currently closed."""
|
||||
return not self._parent.ws.open
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects a shard. When this is called, the shard connection will no
|
||||
longer be open.
|
||||
|
||||
If the shard is already disconnected this does nothing.
|
||||
"""
|
||||
if self.is_closed():
|
||||
return
|
||||
|
||||
await self._parent.disconnect()
|
||||
|
||||
async def reconnect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects and then connects the shard again.
|
||||
"""
|
||||
if not self.is_closed():
|
||||
await self._parent.disconnect()
|
||||
await self._parent.reconnect()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Connects a shard. If the shard is already connected this does nothing.
|
||||
"""
|
||||
if not self.is_closed():
|
||||
return
|
||||
|
||||
await self._parent.reconnect()
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
"""Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard. If no heartbeat
|
||||
has been received yet this returns ``float('inf')``.
|
||||
"""
|
||||
return self._parent.ws.latency
|
||||
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
"""Whether the websocket is currently rate limited.
|
||||
|
||||
This can be useful to know when deciding whether you should query members
|
||||
using HTTP or via the gateway.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
return self._parent.ws.is_ratelimited()
|
||||
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A client similar to :class:`Client` except it handles the complications
|
||||
of sharding for the user into a more manageable and transparent single
|
||||
process bot.
|
||||
|
||||
When using this client, you will be able to use it as-if it was a regular
|
||||
:class:`Client` with a single shard when implementation wise internally it
|
||||
is split up into multiple shards. This allows you to not have to deal with
|
||||
IPC or other complicated infrastructure.
|
||||
|
||||
It is recommended to use this client only if you have surpassed at least
|
||||
1000 guilds.
|
||||
|
||||
If no :attr:`.shard_count` is provided, then the library will use the
|
||||
Bot Gateway endpoint call to figure out how many shards to use.
|
||||
|
||||
If a ``shard_ids`` parameter is given, then those shard IDs will be used
|
||||
to launch the internal shards. Note that :attr:`.shard_count` must be provided
|
||||
if this is used. By default, when omitted, the client will launch shards from
|
||||
0 to ``shard_count - 1``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
shard_ids: Optional[List[:class:`int`]]
|
||||
An optional list of shard_ids to launch the shards with.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_connection: AutoShardedConnectionState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs.pop("shard_id", None)
|
||||
self.shard_ids: list[int] | None = kwargs.pop("shard_ids", None)
|
||||
super().__init__(*args, loop=loop, **kwargs)
|
||||
|
||||
if self.shard_ids is not None:
|
||||
if self.shard_count is None:
|
||||
raise ClientException(
|
||||
"When passing manual shard_ids, you must provide a shard_count."
|
||||
)
|
||||
elif not isinstance(self.shard_ids, (list, tuple)):
|
||||
raise ClientException("shard_ids parameter must be a list or a tuple.")
|
||||
|
||||
# instead of a single websocket, we have multiple.
|
||||
# the key is the shard_id
|
||||
self.__shards = {}
|
||||
self._connection._get_websocket = self._get_websocket
|
||||
self._connection._get_client = lambda: self
|
||||
self.__queue = asyncio.PriorityQueue()
|
||||
|
||||
def _get_websocket(
|
||||
self, guild_id: int | None = None, *, shard_id: int | None = None
|
||||
) -> DiscordWebSocket:
|
||||
if shard_id is None:
|
||||
# guild_id won't be None if shard_id is None and shard_count won't be None here
|
||||
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
|
||||
return self.__shards[shard_id].ws
|
||||
|
||||
def _get_state(self, **options: Any) -> AutoShardedConnectionState:
|
||||
return AutoShardedConnectionState(
|
||||
dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks,
|
||||
http=self.http,
|
||||
loop=self.loop,
|
||||
**options,
|
||||
)
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
"""Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
|
||||
|
||||
This operates similarly to :meth:`Client.latency` except it uses the average
|
||||
latency of every shard's latency. To get a list of shard latency, check the
|
||||
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
|
||||
"""
|
||||
if not self.__shards:
|
||||
return float("nan")
|
||||
return sum(latency for _, latency in self.latencies) / len(self.__shards)
|
||||
|
||||
@property
|
||||
def latencies(self) -> list[tuple[int, float]]:
|
||||
"""A list of latencies between a
|
||||
HEARTBEAT and a HEARTBEAT_ACK in seconds.
|
||||
|
||||
This returns a list of tuples with elements ``(shard_id, latency)``.
|
||||
"""
|
||||
return [
|
||||
(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()
|
||||
]
|
||||
|
||||
def get_shard(self, shard_id: int) -> ShardInfo | None:
|
||||
"""Gets the shard information at a given shard ID or ``None`` if not found."""
|
||||
try:
|
||||
parent = self.__shards[shard_id]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
return ShardInfo(parent, self.shard_count)
|
||||
|
||||
@property
|
||||
def shards(self) -> dict[int, ShardInfo]:
|
||||
"""Returns a mapping of shard IDs to their respective info object."""
|
||||
return {
|
||||
shard_id: ShardInfo(parent, self.shard_count)
|
||||
for shard_id, parent in self.__shards.items()
|
||||
}
|
||||
|
||||
async def launch_shard(
|
||||
self, gateway: str, shard_id: int, *, initial: bool = False
|
||||
) -> None:
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self, initial=initial, gateway=gateway, shard_id=shard_id
|
||||
)
|
||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
except Exception:
|
||||
_log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id)
|
||||
await asyncio.sleep(5.0)
|
||||
return await self.launch_shard(gateway, shard_id)
|
||||
|
||||
# keep reading the shard while others connect
|
||||
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
|
||||
ret.launch()
|
||||
|
||||
async def launch_shards(self) -> None:
|
||||
if self.shard_count is None:
|
||||
self.shard_count, gateway = await self.http.get_bot_gateway()
|
||||
else:
|
||||
gateway = await self.http.get_gateway()
|
||||
|
||||
self._connection.shard_count = self.shard_count
|
||||
|
||||
shard_ids = self.shard_ids or range(self.shard_count)
|
||||
self._connection.shard_ids = shard_ids
|
||||
|
||||
for shard_id in shard_ids:
|
||||
initial = shard_id == shard_ids[0]
|
||||
await self.launch_shard(gateway, shard_id, initial=initial)
|
||||
|
||||
self._connection.shards_launched.set()
|
||||
|
||||
async def connect(self, *, reconnect: bool = True) -> None:
|
||||
self._reconnect = reconnect
|
||||
await self.launch_shards()
|
||||
|
||||
while not self.is_closed():
|
||||
item = await self.__queue.get()
|
||||
if item.type == EventType.close:
|
||||
await self.close()
|
||||
if isinstance(item.error, ConnectionClosed) and item.error.code != 1000:
|
||||
raise item.error
|
||||
return
|
||||
elif item.type in (EventType.identify, EventType.resume):
|
||||
await item.shard.reidentify(item.error)
|
||||
elif item.type == EventType.reconnect:
|
||||
await item.shard.reconnect()
|
||||
elif item.type == EventType.terminate:
|
||||
await self.close()
|
||||
raise item.error
|
||||
elif item.type == EventType.clean_close:
|
||||
return
|
||||
|
||||
async def close(self) -> None:
|
||||
"""|coro|
|
||||
|
||||
Closes the connection to Discord.
|
||||
"""
|
||||
if self.is_closed():
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
for vc in self.voice_clients:
|
||||
try:
|
||||
await vc.disconnect(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
to_close = [
|
||||
asyncio.ensure_future(shard.close(), loop=self.loop)
|
||||
for shard in self.__shards.values()
|
||||
]
|
||||
if to_close:
|
||||
await asyncio.wait(to_close)
|
||||
|
||||
await self.http.close()
|
||||
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
|
||||
|
||||
async def change_presence(
|
||||
self,
|
||||
*,
|
||||
activity: BaseActivity | None = None,
|
||||
status: Status | None = None,
|
||||
shard_id: int = None,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Changes the client's presence.
|
||||
|
||||
Example: ::
|
||||
|
||||
game = discord.Game("with the API")
|
||||
await client.change_presence(status=discord.Status.idle, activity=game)
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
Removed the ``afk`` keyword-only parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
activity: Optional[:class:`BaseActivity`]
|
||||
The activity being done. ``None`` if no currently active activity is done.
|
||||
status: Optional[:class:`Status`]
|
||||
Indicates what status to change to. If ``None``, then
|
||||
:attr:`Status.online` is used.
|
||||
shard_id: Optional[:class:`int`]
|
||||
The shard_id to change the presence to. If not specified
|
||||
or ``None``, then it will change the presence of every
|
||||
shard the bot can see.
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidArgument
|
||||
If the ``activity`` parameter is not of proper type.
|
||||
"""
|
||||
|
||||
if status is None:
|
||||
status_value = "online"
|
||||
status_enum = Status.online
|
||||
elif status is Status.offline:
|
||||
status_value = "invisible"
|
||||
status_enum = Status.offline
|
||||
else:
|
||||
status_enum = status
|
||||
status_value = str(status)
|
||||
|
||||
if shard_id is None:
|
||||
for shard in self.__shards.values():
|
||||
await shard.ws.change_presence(activity=activity, status=status_value)
|
||||
|
||||
guilds = self._connection.guilds
|
||||
else:
|
||||
shard = self.__shards[shard_id]
|
||||
await shard.ws.change_presence(activity=activity, status=status_value)
|
||||
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
|
||||
|
||||
activities = () if activity is None else (activity,)
|
||||
for guild in guilds:
|
||||
me = guild.me
|
||||
if me is None:
|
||||
continue
|
||||
|
||||
# Member.activities is typehinted as Tuple[ActivityType, ...],
|
||||
# we may be setting it as Tuple[BaseActivity, ...]
|
||||
me.activities = activities # type: ignore
|
||||
me.status = status_enum
|
||||
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
"""Whether the websocket is currently rate limited.
|
||||
|
||||
This can be useful to know when deciding whether you should query members
|
||||
using HTTP or via the gateway.
|
||||
|
||||
This implementation checks if any of the shards are rate limited.
|
||||
For more granular control, consider :meth:`ShardInfo.is_ws_ratelimited`.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
return any(shard.ws.is_ratelimited() for shard in self.__shards.values())
|
||||
Reference in New Issue
Block a user