diff --git a/src/Discord.Net.WebSocket/API/Gateway/IdentifyParams.cs b/src/Discord.Net.WebSocket/API/Gateway/IdentifyParams.cs index e87c58221..af16f22f5 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/IdentifyParams.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/IdentifyParams.cs @@ -13,8 +13,6 @@ namespace Discord.API.Gateway public IDictionary Properties { get; set; } [JsonProperty("large_threshold")] public int LargeThreshold { get; set; } - [JsonProperty("compress")] - public bool UseCompression { get; set; } [JsonProperty("shard")] public Optional ShardingParams { get; set; } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 7d680eaf2..72781204c 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -29,7 +29,11 @@ namespace Discord.API private CancellationTokenSource _connectCancelToken; private string _gatewayUrl; private bool _isExplicitUrl; - + + //Store our decompression streams for zlib shared state + private MemoryStream _compressed; + private DeflateStream _decompressor; + internal IWebSocketClient WebSocketClient { get; } public ConnectionState ConnectionState { get; private set; } @@ -43,14 +47,29 @@ namespace Discord.API _isExplicitUrl = true; WebSocketClient = webSocketProvider(); //WebSocketClient.SetHeader("user-agent", DiscordConfig.UserAgent); (Causes issues in .NET Framework 4.6+) + WebSocketClient.BinaryMessage += async (data, index, count) => { - using (var compressed = new MemoryStream(data, index + 2, count - 2)) using (var decompressed = new MemoryStream()) { - using (var zlib = new DeflateStream(compressed, CompressionMode.Decompress)) - zlib.CopyTo(decompressed); + if (data[0] == 0x78) + { + //Strip the zlib header + _compressed.Write(data, index + 2, count - 2); + _compressed.SetLength(count - 2); + } + else + { + _compressed.Write(data, index, count); + _compressed.SetLength(count); + } + + //Reset positions so we don't run out of memory + _compressed.Position = 0; + _decompressor.CopyTo(decompressed); + _compressed.Position = 0; decompressed.Position = 0; + using (var reader = new StreamReader(decompressed)) using (var jsonReader = new JsonTextReader(reader)) { @@ -76,6 +95,7 @@ namespace Discord.API await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false); }; } + internal override void Dispose(bool disposing) { if (!_isDisposed) @@ -84,6 +104,8 @@ namespace Discord.API { _connectCancelToken?.Dispose(); (WebSocketClient as IDisposable)?.Dispose(); + _decompressor?.Dispose(); + _compressed?.Dispose(); } _isDisposed = true; } @@ -105,6 +127,12 @@ namespace Discord.API if (WebSocketClient == null) throw new NotSupportedException("This client is not configured with websocket support."); + //Re-create streams to reset the zlib state + _compressed?.Dispose(); + _decompressor?.Dispose(); + _compressed = new MemoryStream(); + _decompressor = new DeflateStream(_compressed, CompressionMode.Decompress); + ConnectionState = ConnectionState.Connecting; try { @@ -115,7 +143,7 @@ namespace Discord.API if (!_isExplicitUrl) { var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false); - _gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}"; + _gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream"; } await WebSocketClient.ConnectAsync(_gatewayUrl).ConfigureAwait(false); @@ -191,7 +219,7 @@ namespace Discord.API options = RequestOptions.CreateOrClone(options); return await SendAsync("GET", () => "gateway/bot", new BucketIds(), options: options).ConfigureAwait(false); } - public async Task SendIdentifyAsync(int largeThreshold = 100, bool useCompression = true, int shardID = 0, int totalShards = 1, RequestOptions options = null) + public async Task SendIdentifyAsync(int largeThreshold = 100, int shardID = 0, int totalShards = 1, RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); var props = new Dictionary @@ -202,8 +230,7 @@ namespace Discord.API { Token = AuthToken, Properties = props, - LargeThreshold = largeThreshold, - UseCompression = useCompression, + LargeThreshold = largeThreshold }; if (totalShards > 1) msg.ShardingParams = new int[] { shardID, totalShards };