using Discord.API;
using Discord.Rest;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Threading;

namespace Discord.WebSocket
{
    /// <summary>
    ///     A Discord client that runs several <see cref="DiscordSocketClient"/> shards and exposes their
    ///     combined state (guilds, channels, users) and events through a single client.
    /// </summary>
    public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient
    {
        // Configuration every shard's own config is cloned from.
        private readonly DiscordSocketConfig _baseConfig;
        // Passed to every shard's constructor; presumably serializes shard connection attempts — confirm in DiscordSocketClient.
        private readonly SemaphoreSlim _connectionGroupLock;
        // Maps a shard id to its index in _shardIds / _shards.
        private readonly Dictionary<int, int> _shardIdsToIndex;
        // True when config.TotalShards was not given: shards are then created at login from the recommended count.
        private readonly bool _automaticShards;
        private int[] _shardIds;
        private DiscordSocketClient[] _shards;
        private int _totalShards;

        private bool _isDisposed;

        /// <inheritdoc />
        public override int Latency { get => GetLatency(); protected set { } }
        /// <inheritdoc />
        public override UserStatus Status { get => _shards[0].Status; protected set { } }
        /// <inheritdoc />
        public override IActivity Activity { get => _shards[0].Activity; protected set { } }

        internal new DiscordSocketApiClient ApiClient => base.ApiClient as DiscordSocketApiClient;
        /// <inheritdoc />
        public override IReadOnlyCollection<SocketGuild> Guilds => GetGuilds().ToReadOnlyCollection(GetGuildCount);
        /// <inheritdoc />
        public override IReadOnlyCollection<ISocketPrivateChannel> PrivateChannels => GetPrivateChannels().ToReadOnlyCollection(GetPrivateChannelCount);
        /// <summary>
        ///     The shard clients hosted by this client. This is <c>null</c> before login when the shard
        ///     count is determined automatically.
        /// </summary>
        public IReadOnlyCollection<DiscordSocketClient> Shards => _shards;
        /// <inheritdoc />
        public override IReadOnlyCollection<RestVoiceRegion> VoiceRegions => _shards[0].VoiceRegions;

        /// <summary>
        ///     Provides access to a REST-only client with a shared state from this client.
        /// </summary>
        public override DiscordSocketRestClient Rest => _shards[0].Rest;

        /// <summary> Creates a new REST/WebSocket Discord client. </summary>
        public DiscordShardedClient() : this(null, new DiscordSocketConfig()) { }
        /// <summary> Creates a new REST/WebSocket Discord client. </summary>
#pragma warning disable IDISP004
        public DiscordShardedClient(DiscordSocketConfig config) : this(null, config, CreateApiClient(config)) { }
#pragma warning restore IDISP004
        /// <summary> Creates a new REST/WebSocket Discord client. </summary>
        public DiscordShardedClient(int[] ids) : this(ids, new DiscordSocketConfig()) { }
        /// <summary> Creates a new REST/WebSocket Discord client. </summary>
#pragma warning disable IDISP004
        public DiscordShardedClient(int[] ids, DiscordSocketConfig config) : this(ids, config, CreateApiClient(config)) { }
#pragma warning restore IDISP004
        /// <summary>
        ///     Shared constructor. When <c>config.TotalShards</c> is set, the shards are created immediately
        ///     (either the given <paramref name="ids"/> or 0..TotalShards-1). Otherwise, creation is deferred to login.
        /// </summary>
        /// <exception cref="ArgumentException">
        ///     <c>config.ShardId</c> is set, or custom <paramref name="ids"/> are given without <c>config.TotalShards</c>.
        /// </exception>
        private DiscordShardedClient(int[] ids, DiscordSocketConfig config, API.DiscordSocketApiClient client)
            : base(config, client)
        {
            if (config.ShardId != null)
                throw new ArgumentException($"{nameof(config.ShardId)} must not be set.");
            if (ids != null && config.TotalShards == null)
                throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");

            _shardIdsToIndex = new Dictionary<int, int>();
            config.DisplayInitialLog = false;
            _baseConfig = config;
            _connectionGroupLock = new SemaphoreSlim(1, 1);

            if (config.TotalShards == null)
                _automaticShards = true;
            else
            {
                _totalShards = config.TotalShards.Value;
                _shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
                CreateShards(config);
            }
        }
        private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
            => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, rateLimitPrecision: config.RateLimitPrecision);

        /// <summary>
        ///     Builds one <see cref="DiscordSocketClient"/> per entry of <c>_shardIds</c> and registers its events.
        ///     Each shard gets a clone of <paramref name="template"/> with its own ShardId and the current TotalShards.
        /// </summary>
        private void CreateShards(DiscordSocketConfig template)
        {
            _shards = new DiscordSocketClient[_shardIds.Length];
            for (int i = 0; i < _shardIds.Length; i++)
            {
                // Throws ArgumentException on duplicate shard ids.
                _shardIdsToIndex.Add(_shardIds[i], i);
                var newConfig = template.Clone();
                newConfig.ShardId = _shardIds[i];
                newConfig.TotalShards = _totalShards;
                // Every non-first shard is handed shard 0 as its parent client.
                _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
                RegisterEvents(_shards[i], i == 0);
            }
        }

        internal override async Task OnLoginAsync(TokenType tokenType, string token)
        {
            if (_automaticShards)
            {
                var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false);
                _shardIds = Enumerable.Range(0, shardCount).ToArray();
                _totalShards = _shardIds.Length;
                CreateShards(_baseConfig);
            }

            //Assume thread safe: already in a connection lock
            for (int i = 0; i < _shards.Length; i++)
                await _shards[i].LoginAsync(tokenType, token).ConfigureAwait(false);
        }
        internal override async Task OnLogoutAsync()
        {
            //Assume thread safe: already in a connection lock
            if (_shards != null)
            {
                for (int i = 0; i < _shards.Length; i++)
                    await _shards[i].LogoutAsync().ConfigureAwait(false);
            }

            CurrentUser = null;
            if (_automaticShards)
            {
                // Automatically sized shards are rebuilt on the next login from a fresh recommended count.
                // NOTE(review): the discarded shard clients are not disposed here — confirm this is intended.
                _shardIds = Array.Empty<int>();
                _shardIdsToIndex.Clear();
                _totalShards = 0;
                _shards = null;
            }
        }

        /// <inheritdoc />
        public override async Task StartAsync()
            => await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false);
        /// <inheritdoc />
        public override async Task StopAsync()
            => await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false);

        /// <summary>
        ///     Gets the shard client with the given shard id.
        /// </summary>
        /// <returns>The shard, or <c>null</c> if this client does not host a shard with that id.</returns>
        public DiscordSocketClient GetShard(int id)
        {
            if (_shardIdsToIndex.TryGetValue(id, out var index))
                return _shards[index];
            return null;
        }
        // Shard id owning a guild: (guildId >> 22) % totalShards. Throws DivideByZeroException while _totalShards is 0.
        private int GetShardIdFor(ulong guildId)
            => (int)((guildId >> 22) % (uint)_totalShards);
        /// <summary>
        ///     Gets the id of the shard responsible for <paramref name="guild"/>. A <c>null</c> guild is treated as id 0.
        /// </summary>
        public int GetShardIdFor(IGuild guild)
            => GetShardIdFor(guild?.Id ?? 0);
        private DiscordSocketClient GetShardFor(ulong guildId)
            => GetShard(GetShardIdFor(guildId));
        /// <summary>
        ///     Gets the shard client responsible for <paramref name="guild"/>, or <c>null</c> if that shard is not hosted here.
        /// </summary>
        public DiscordSocketClient GetShardFor(IGuild guild)
            => GetShardFor(guild?.Id ?? 0);

        /// <inheritdoc />
        public override async Task<RestApplication> GetApplicationInfoAsync(RequestOptions options = null)
            => await _shards[0].GetApplicationInfoAsync(options).ConfigureAwait(false);

        /// <inheritdoc />
        public override SocketGuild GetGuild(ulong id)
            => GetShardFor(id).GetGuild(id);

        /// <inheritdoc />
        public override SocketChannel GetChannel(ulong id)
        {
            for (int i = 0; i < _shards.Length; i++)
            {
                var channel = _shards[i].GetChannel(id);
                if (channel != null)
                    return channel;
            }
            return null;
        }
        private IEnumerable<ISocketPrivateChannel> GetPrivateChannels()
        {
            for (int i = 0; i < _shards.Length; i++)
            {
                foreach (var channel in _shards[i].PrivateChannels)
                    yield return channel;
            }
        }
        private int GetPrivateChannelCount()
        {
            int result = 0;
            for (int i = 0; i < _shards.Length; i++)
                result += _shards[i].PrivateChannels.Count;
            return result;
        }

        private IEnumerable<SocketGuild> GetGuilds()
        {
            for (int i = 0; i < _shards.Length; i++)
            {
                foreach (var guild in _shards[i].Guilds)
                    yield return guild;
            }
        }
        private int GetGuildCount()
        {
            int result = 0;
            for (int i = 0; i < _shards.Length; i++)
                result += _shards[i].Guilds.Count;
            return result;
        }

        /// <inheritdoc />
        public override SocketUser GetUser(ulong id)
        {
            for (int i = 0; i < _shards.Length; i++)
            {
                var user = _shards[i].GetUser(id);
                if (user != null)
                    return user;
            }
            return null;
        }
        /// <inheritdoc />
        public override SocketUser GetUser(string username, string discriminator)
        {
            for (int i = 0; i < _shards.Length; i++)
            {
                var user = _shards[i].GetUser(username, discriminator);
                if (user != null)
                    return user;
            }
            return null;
        }

        /// <inheritdoc />
        public override RestVoiceRegion GetVoiceRegion(string id)
            => _shards[0].GetVoiceRegion(id);

        /// <inheritdoc />
        /// <exception cref="ArgumentNullException"><paramref name="guilds"/> is <c>null</c></exception>
        public override async Task DownloadUsersAsync(IEnumerable<IGuild> guilds)
        {
            if (guilds == null) throw new ArgumentNullException(nameof(guilds));

            // Materialize once: the sequence is filtered once per shard below, and a lazy
            // source would otherwise be re-evaluated (and could change) on every pass.
            var allGuilds = guilds.ToArray();
            for (int i = 0; i < _shards.Length; i++)
            {
                int id = _shardIds[i];
                var arr = allGuilds.Where(x => GetShardIdFor(x) == id).ToArray();
                if (arr.Length > 0)
                    await _shards[i].DownloadUsersAsync(arr).ConfigureAwait(false);
            }
        }

        // Mean shard latency, rounded to the nearest integer (Math.Round default: midpoints to even).
        // Returns 0 when no shards exist yet, e.g. before login with automatic sharding.
        private int GetLatency()
        {
            if (_shards == null || _shards.Length == 0)
                return 0;

            long total = 0;
            for (int i = 0; i < _shards.Length; i++)
                total += _shards[i].Latency;
            return (int)Math.Round(total / (double)_shards.Length);
        }

        /// <inheritdoc />
        public override async Task SetStatusAsync(UserStatus status)
        {
            for (int i = 0; i < _shards.Length; i++)
                await _shards[i].SetStatusAsync(status).ConfigureAwait(false);
        }
        /// <inheritdoc />
        public override async Task SetGameAsync(string name, string streamUrl = null, ActivityType type = ActivityType.Playing)
        {
            IActivity activity = null;
            if (!string.IsNullOrEmpty(streamUrl))
                activity = new StreamingGame(name, streamUrl);
            else if (!string.IsNullOrEmpty(name))
                activity = new Game(name, type);
            await SetActivityAsync(activity).ConfigureAwait(false);
        }
        /// <inheritdoc />
        public override async Task SetActivityAsync(IActivity activity)
        {
            for (int i = 0; i < _shards.Length; i++)
                await _shards[i].SetActivityAsync(activity).ConfigureAwait(false);
        }

        /// <summary>
        ///     Forwards a shard's events to this client's events. The primary shard also supplies <c>CurrentUser</c> when it becomes ready.
        /// </summary>
        private void RegisterEvents(DiscordSocketClient client, bool isPrimary)
        {
            client.Log += (msg) => _logEvent.InvokeAsync(msg);
            client.LoggedOut += () =>
            {
                var state = LoginState;
                if (state == LoginState.LoggedIn || state == LoginState.LoggingIn)
                {
                    //Should only happen if token is changed
                    var _ = LogoutAsync(); //Signal the logout, fire and forget
                }
                return Task.Delay(0);
            };
            if (isPrimary)
            {
                client.Ready += () =>
                {
                    CurrentUser = client.CurrentUser;
                    return Task.Delay(0);
                };
            }

            client.Connected += () => _shardConnectedEvent.InvokeAsync(client);
            client.Disconnected += (exception) => _shardDisconnectedEvent.InvokeAsync(exception, client);
            client.Ready += () => _shardReadyEvent.InvokeAsync(client);
            client.LatencyUpdated += (oldLatency, newLatency) => _shardLatencyUpdatedEvent.InvokeAsync(oldLatency, newLatency, client);

            client.ChannelCreated += (channel) => _channelCreatedEvent.InvokeAsync(channel);
            client.ChannelDestroyed += (channel) => _channelDestroyedEvent.InvokeAsync(channel);
            client.ChannelUpdated += (oldChannel, newChannel) => _channelUpdatedEvent.InvokeAsync(oldChannel, newChannel);

            client.MessageReceived += (msg) => _messageReceivedEvent.InvokeAsync(msg);
            client.MessageDeleted += (cache, channel) => _messageDeletedEvent.InvokeAsync(cache, channel);
            client.MessagesBulkDeleted += (cache, channel) => _messagesBulkDeletedEvent.InvokeAsync(cache, channel);
            client.MessageUpdated += (oldMsg, newMsg, channel) => _messageUpdatedEvent.InvokeAsync(oldMsg, newMsg, channel);
            client.ReactionAdded += (cache, channel, reaction) => _reactionAddedEvent.InvokeAsync(cache, channel, reaction);
            client.ReactionRemoved += (cache, channel, reaction) => _reactionRemovedEvent.InvokeAsync(cache, channel, reaction);
            client.ReactionsCleared += (cache, channel) => _reactionsClearedEvent.InvokeAsync(cache, channel);
            client.ReactionsRemovedForEmote += (cache, channel, emote) => _reactionsRemovedForEmoteEvent.InvokeAsync(cache, channel, emote);

            client.RoleCreated += (role) => _roleCreatedEvent.InvokeAsync(role);
            client.RoleDeleted += (role) => _roleDeletedEvent.InvokeAsync(role);
            client.RoleUpdated += (oldRole, newRole) => _roleUpdatedEvent.InvokeAsync(oldRole, newRole);

            client.JoinedGuild += (guild) => _joinedGuildEvent.InvokeAsync(guild);
            client.LeftGuild += (guild) => _leftGuildEvent.InvokeAsync(guild);
            client.GuildAvailable += (guild) => _guildAvailableEvent.InvokeAsync(guild);
            client.GuildUnavailable += (guild) => _guildUnavailableEvent.InvokeAsync(guild);
            client.GuildMembersDownloaded += (guild) => _guildMembersDownloadedEvent.InvokeAsync(guild);
            client.GuildUpdated += (oldGuild, newGuild) => _guildUpdatedEvent.InvokeAsync(oldGuild, newGuild);

            client.UserJoined += (user) => _userJoinedEvent.InvokeAsync(user);
            client.UserLeft += (user) => _userLeftEvent.InvokeAsync(user);
            client.UserBanned += (user, guild) => _userBannedEvent.InvokeAsync(user, guild);
            client.UserUnbanned += (user, guild) => _userUnbannedEvent.InvokeAsync(user, guild);
            client.UserUpdated += (oldUser, newUser) => _userUpdatedEvent.InvokeAsync(oldUser, newUser);
            client.GuildMemberUpdated += (oldUser, newUser) => _guildMemberUpdatedEvent.InvokeAsync(oldUser, newUser);
            client.UserVoiceStateUpdated += (user, oldVoiceState, newVoiceState) => _userVoiceStateUpdatedEvent.InvokeAsync(user, oldVoiceState, newVoiceState);
            client.VoiceServerUpdated += (server) => _voiceServerUpdatedEvent.InvokeAsync(server);
            client.CurrentUserUpdated += (oldUser, newUser) => _selfUpdatedEvent.InvokeAsync(oldUser, newUser);
            client.UserIsTyping += (user, channel) => _userIsTypingEvent.InvokeAsync(user, channel);
            client.RecipientAdded += (user) => _recipientAddedEvent.InvokeAsync(user);
            client.RecipientRemoved += (user) => _recipientRemovedEvent.InvokeAsync(user);
        }

        //IDiscordClient
        // The explicit implementations below forward the caller's RequestOptions wherever the
        // underlying method accepts them, so per-request options are not silently dropped.

        /// <inheritdoc />
        async Task<IApplication> IDiscordClient.GetApplicationInfoAsync(RequestOptions options)
            => await GetApplicationInfoAsync(options).ConfigureAwait(false);

        /// <inheritdoc />
        Task<IChannel> IDiscordClient.GetChannelAsync(ulong id, CacheMode mode, RequestOptions options)
            => Task.FromResult<IChannel>(GetChannel(id));
        /// <inheritdoc />
        Task<IReadOnlyCollection<IPrivateChannel>> IDiscordClient.GetPrivateChannelsAsync(CacheMode mode, RequestOptions options)
            => Task.FromResult<IReadOnlyCollection<IPrivateChannel>>(PrivateChannels);

        /// <inheritdoc />
        async Task<IReadOnlyCollection<IConnection>> IDiscordClient.GetConnectionsAsync(RequestOptions options)
            => await GetConnectionsAsync(options).ConfigureAwait(false);

        /// <inheritdoc />
        async Task<IInvite> IDiscordClient.GetInviteAsync(string inviteId, RequestOptions options)
            => await GetInviteAsync(inviteId, options).ConfigureAwait(false);

        /// <inheritdoc />
        Task<IGuild> IDiscordClient.GetGuildAsync(ulong id, CacheMode mode, RequestOptions options)
            => Task.FromResult<IGuild>(GetGuild(id));
        /// <inheritdoc />
        Task<IReadOnlyCollection<IGuild>> IDiscordClient.GetGuildsAsync(CacheMode mode, RequestOptions options)
            => Task.FromResult<IReadOnlyCollection<IGuild>>(Guilds);

        /// <inheritdoc />
        async Task<IGuild> IDiscordClient.CreateGuildAsync(string name, IVoiceRegion region, Stream jpegIcon, RequestOptions options)
            => await CreateGuildAsync(name, region, jpegIcon, options).ConfigureAwait(false);

        /// <inheritdoc />
        Task<IUser> IDiscordClient.GetUserAsync(ulong id, CacheMode mode, RequestOptions options)
            => Task.FromResult<IUser>(GetUser(id));
        /// <inheritdoc />
        Task<IUser> IDiscordClient.GetUserAsync(string username, string discriminator, RequestOptions options)
            => Task.FromResult<IUser>(GetUser(username, discriminator));

        /// <inheritdoc />
        Task<IReadOnlyCollection<IVoiceRegion>> IDiscordClient.GetVoiceRegionsAsync(RequestOptions options)
            => Task.FromResult<IReadOnlyCollection<IVoiceRegion>>(VoiceRegions);
        /// <inheritdoc />
        Task<IVoiceRegion> IDiscordClient.GetVoiceRegionAsync(string id, RequestOptions options)
            => Task.FromResult<IVoiceRegion>(GetVoiceRegion(id));

        // Disposes every current shard and the shared connection-group lock once, then defers to the base class.
        internal override void Dispose(bool disposing)
        {
            if (!_isDisposed)
            {
                if (disposing)
                {
                    if (_shards != null)
                    {
                        foreach (var client in _shards)
                            client?.Dispose();
                    }
                    _connectionGroupLock?.Dispose();
                }

                _isDisposed = true;
            }

            base.Dispose(disposing);
        }
    }
}