diff --git a/src/Runner.Worker/Dap/DapDebugger.cs b/src/Runner.Worker/Dap/DapDebugger.cs index 99b61e1b1a2..89879e2ff2c 100644 --- a/src/Runner.Worker/Dap/DapDebugger.cs +++ b/src/Runner.Worker/Dap/DapDebugger.cs @@ -66,6 +66,7 @@ public sealed class DapDebugger : RunnerService, IDapDebugger // Dev Tunnel relay host for remote debugging private TunnelRelayTunnelHost _tunnelRelayHost; + private IWebSocketDapBridge _webSocketBridge; // Cancellation source for the connection loop, cancelled in StopAsync // so AcceptTcpClientAsync unblocks cleanly without relying on listener disposal. @@ -74,6 +75,10 @@ public sealed class DapDebugger : RunnerService, IDapDebugger // When true, skip tunnel relay startup (unit tests only) internal bool SkipTunnelRelay { get; set; } + // When true, skip the public websocket bridge and expose the raw DAP + // listener directly on the configured tunnel port (unit tests only). + internal bool SkipWebSocketBridge { get; set; } + // Synchronization for step execution private TaskCompletionSource _commandTcs; private readonly object _stateLock = new object(); @@ -108,6 +113,7 @@ public sealed class DapDebugger : RunnerService, IDapDebugger _state == DapSessionState.Running; internal DapSessionState State => _state; + internal int InternalDapPort => (_listener?.LocalEndpoint as IPEndPoint)?.Port ?? 0; public override void Initialize(IHostContext hostContext) { @@ -133,9 +139,19 @@ public async Task StartAsync(IExecutionContext jobContext) _jobContext = jobContext; _readyTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _listener = new TcpListener(IPAddress.Loopback, debuggerConfig.Tunnel.Port); + var dapPort = SkipWebSocketBridge ? debuggerConfig.Tunnel.Port : 0; + _listener = new TcpListener(IPAddress.Loopback, dapPort); _listener.Start(); - Trace.Info($"DAP debugger listening on {_listener.LocalEndpoint}"); + if (SkipWebSocketBridge) + { + Trace.Info($"DAP debugger listening on {_listener.LocalEndpoint}"); + } + else + { + Trace.Info($"Internal DAP debugger listening on {_listener.LocalEndpoint}"); + _webSocketBridge = HostContext.CreateService(); + _webSocketBridge.Start(debuggerConfig.Tunnel.Port, InternalDapPort); + } // Start Dev Tunnel relay so remote clients reach the local DAP port. // The relay is torn down explicitly in StopAsync (after the DAP session @@ -274,6 +290,25 @@ public async Task StopAsync() _tunnelRelayHost = null; } + if (_webSocketBridge != null) + { + Trace.Info("Stopping WebSocket DAP bridge"); + var shutdownTask = _webSocketBridge.ShutdownAsync(); + if (await Task.WhenAny(shutdownTask, Task.Delay(5_000)) != shutdownTask) + { + Trace.Warning("WebSocket DAP bridge shutdown timed out after 5s"); + _ = shutdownTask.ContinueWith( + t => Trace.Error($"WebSocket DAP bridge shutdown faulted: {t.Exception?.GetBaseException().Message}"), + TaskContinuationOptions.OnlyOnFaulted); + } + else + { + Trace.Info("WebSocket DAP bridge stopped"); + } + + _webSocketBridge = null; + } + CleanupConnection(); // Cancel the connection loop first so AcceptTcpClientAsync unblocks @@ -315,6 +350,7 @@ public async Task StopAsync() _connectionLoopTask = null; _loopCts?.Dispose(); _loopCts = null; + _webSocketBridge = null; } public async Task OnStepStartingAsync(IStep step) diff --git a/src/Runner.Worker/Dap/IWebSocketDapBridge.cs b/src/Runner.Worker/Dap/IWebSocketDapBridge.cs new file mode 100644 index 00000000000..a468aa54a82 --- /dev/null +++ b/src/Runner.Worker/Dap/IWebSocketDapBridge.cs @@ -0,0 +1,12 @@ +using System.Threading.Tasks; +using GitHub.Runner.Common; + +namespace GitHub.Runner.Worker.Dap +{ + [ServiceLocator(Default = typeof(WebSocketDapBridge))] + public interface IWebSocketDapBridge : IRunnerService + { + void Start(int listenPort, int targetPort); + Task ShutdownAsync(); + } +} diff --git a/src/Runner.Worker/Dap/WebSocketDapBridge.cs b/src/Runner.Worker/Dap/WebSocketDapBridge.cs new file mode 100644 index 00000000000..e2f8dc47370 --- /dev/null +++ b/src/Runner.Worker/Dap/WebSocketDapBridge.cs @@ -0,0 +1,839 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Net.WebSockets; +using System.Security.Cryptography; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using GitHub.Runner.Common; + +namespace GitHub.Runner.Worker.Dap +{ + internal sealed class WebSocketDapBridge : RunnerService, IWebSocketDapBridge + { + internal enum IncomingStreamPrefixKind + { + Unknown, + HttpWebSocketUpgrade, + PreUpgradedWebSocket, + WebSocketReservedBits, + Http2Preface, + TlsClientHello, + } + + private const int _bufferSize = 32 * 1024; + private const int _maxHeaderLineLength = 8 * 1024; + private const int _defaultMaxInboundMessageSize = 10 * 1024 * 1024; // 10 MB + private static readonly TimeSpan _keepAliveInterval = TimeSpan.FromSeconds(30); + private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5); + private static readonly TimeSpan _handshakeTimeout = TimeSpan.FromSeconds(10); + private const string _webSocketAcceptMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + private const int _maxHeaderCount = 64; + private static readonly byte[] _headerEndMarker = new byte[] { (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' }; + + private int _listenPort; + private int _targetPort; + + private TcpListener _listener; + private CancellationTokenSource _loopCts; + private Task _acceptLoopTask; + + public int MaxInboundMessageSize { get; set; } = _defaultMaxInboundMessageSize; + + internal int ListenPort => (_listener?.LocalEndpoint as IPEndPoint)?.Port ?? 0; + + public void Start(int listenPort, int targetPort) + { + if (_listener != null) + { + throw new InvalidOperationException("WebSocket DAP bridge already started."); + } + + _listenPort = listenPort; + _targetPort = targetPort; + + _listener = new TcpListener(IPAddress.Loopback, _listenPort); + _listener.Start(); + _loopCts = new CancellationTokenSource(); + _acceptLoopTask = AcceptLoopAsync(_loopCts.Token); + + Trace.Info($"WebSocket DAP bridge listening on {_listener.LocalEndpoint} -> 127.0.0.1:{_targetPort}"); + } + + public async Task ShutdownAsync() + { + _loopCts?.Cancel(); + + try + { + _listener?.Stop(); + } + catch (Exception ex) + { + Trace.Warning($"Error stopping listener during shutdown ({ex.GetType().Name})"); + } + + if (_acceptLoopTask != null) + { + try + { + await _acceptLoopTask; + } + catch (OperationCanceledException) + { + // expected on shutdown + } + } + + _loopCts?.Dispose(); + _loopCts = null; + _listener = null; + _acceptLoopTask = null; + } + + private async Task AcceptLoopAsync(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + TcpClient client = null; + try + { + client = await _listener.AcceptTcpClientAsync(cancellationToken); + client.NoDelay = true; + await HandleClientAsync(client, cancellationToken); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + break; + } + catch (Exception ex) + { + client?.Dispose(); + Trace.Error($"WebSocket DAP bridge connection error"); + Trace.Error(ex); + } + finally + { + client?.Dispose(); + } + } + + Trace.Info("WebSocket DAP bridge accept loop ended"); + } + + private async Task HandleClientAsync(TcpClient incomingClient, CancellationToken cancellationToken) + { + using (var incomingStream = incomingClient.GetStream()) + { + Trace.Info($"WebSocket DAP bridge accepted client {incomingClient.Client.RemoteEndPoint}"); + + WebSocket webSocket; + using (var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)) + { + handshakeCts.CancelAfter(_handshakeTimeout); + try + { + webSocket = await AcceptWebSocketAsync(incomingStream, handshakeCts.Token); + } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + Trace.Warning("WebSocket handshake timed out"); + return; + } + } + if (webSocket == null) + { + return; + } + + using (webSocket) + using (var dapClient = new TcpClient()) + { + dapClient.NoDelay = true; + await dapClient.ConnectAsync(IPAddress.Loopback, _targetPort, cancellationToken); + + using (var dapStream = dapClient.GetStream()) + using (var sessionCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)) + { + var proxyToken = sessionCts.Token; + var wsToTcpTask = PumpWebSocketToTcpAsync(webSocket, dapStream, proxyToken); + var tcpToWsTask = PumpTcpToWebSocketAsync(dapStream, webSocket, proxyToken); + + await Task.WhenAny(wsToTcpTask, tcpToWsTask); + sessionCts.Cancel(); + + await CloseWebSocketAsync(webSocket); + + try + { + await Task.WhenAll(wsToTcpTask, tcpToWsTask); + } + catch (OperationCanceledException) when (proxyToken.IsCancellationRequested) + { + // expected during shutdown + } + catch (Exception ex) + { + Trace.Warning($"DAP protocol error: {ex}"); + } + } + } + } + } + + private async Task AcceptWebSocketAsync(NetworkStream stream, CancellationToken cancellationToken) + { + var initialBytes = await ReadInitialBytesAsync(stream, cancellationToken); + if (initialBytes == null || initialBytes.Length == 0) + { + return null; + } + + var prefixKind = ClassifyIncomingStreamPrefix(initialBytes); + if (prefixKind == IncomingStreamPrefixKind.PreUpgradedWebSocket) + { + Trace.Info($"Treating incoming tunnel stream as an already-upgraded websocket connection ({DescribeInitialBytes(initialBytes)})"); + return WebSocket.CreateFromStream( + new ReplayableStream(stream, initialBytes), + isServer: true, + subProtocol: null, + keepAliveInterval: _keepAliveInterval); + } + + if (prefixKind != IncomingStreamPrefixKind.HttpWebSocketUpgrade) + { + Trace.Warning($"Unsupported debugger tunnel stream prefix ({prefixKind}): {DescribeInitialBytes(initialBytes)}"); + return null; + } + + var handshakeStream = new ReplayableStream(stream, initialBytes); + var requestLine = await ReadLineAsync(handshakeStream, cancellationToken); + if (string.IsNullOrEmpty(requestLine)) + { + return null; + } + + var headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + while (true) + { + if (headers.Count >= _maxHeaderCount) + { + Trace.Warning($"Rejected WebSocket request with too many headers (>{_maxHeaderCount})"); + await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Too many headers.", cancellationToken); + return null; + } + + var line = await ReadLineAsync(handshakeStream, cancellationToken); + if (line == null) + { + return null; + } + + if (line.Length == 0) + { + break; + } + + var separatorIndex = line.IndexOf(':'); + if (separatorIndex <= 0) + { + await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Invalid HTTP header.", cancellationToken); + return null; + } + + var headerName = line.Substring(0, separatorIndex).Trim(); + var headerValue = line.Substring(separatorIndex + 1).Trim(); + + if (headers.TryGetValue(headerName, out var existingValue)) + { + headers[headerName] = $"{existingValue}, {headerValue}"; + } + else + { + headers[headerName] = headerValue; + } + } + + if (!IsValidWebSocketRequest(requestLine, headers)) + { + var method = requestLine.Split(' ')[0]; + Trace.Info($"Rejected non-websocket request (method={method})"); + await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Expected a websocket upgrade request.", cancellationToken); + return null; + } + + if (!headers.TryGetValue("Sec-WebSocket-Version", out var webSocketVersion) || + !string.Equals(webSocketVersion.Trim(), "13", StringComparison.Ordinal)) + { + Trace.Warning("Rejected WebSocket request with unsupported version"); + await WriteHttpErrorAsync(stream, (HttpStatusCode)426, "Unsupported WebSocket version. Expected: 13.", cancellationToken); + return null; + } + + var webSocketKey = headers["Sec-WebSocket-Key"]; + if (!IsValidWebSocketKey(webSocketKey)) + { + Trace.Warning("Rejected WebSocket request with invalid Sec-WebSocket-Key"); + await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Invalid Sec-WebSocket-Key.", cancellationToken); + return null; + } + + var acceptValue = ComputeAcceptValue(webSocketKey); + var responseBytes = Encoding.ASCII.GetBytes( + "HTTP/1.1 101 Switching Protocols\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: websocket\r\n" + + $"Sec-WebSocket-Accept: {acceptValue}\r\n" + + "\r\n"); + + await handshakeStream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken); + await handshakeStream.FlushAsync(cancellationToken); + + Trace.Info("WebSocket DAP bridge completed websocket handshake"); + return WebSocket.CreateFromStream(handshakeStream, isServer: true, subProtocol: null, keepAliveInterval: _keepAliveInterval); + } + + private async Task PumpWebSocketToTcpAsync(WebSocket source, NetworkStream destination, CancellationToken cancellationToken) + { + var buffer = new byte[_bufferSize]; + + while (!cancellationToken.IsCancellationRequested) + { + using (var messageStream = new MemoryStream()) + { + WebSocketReceiveResult result; + do + { + result = await source.ReceiveAsync(new ArraySegment(buffer), cancellationToken); + if (result.MessageType == WebSocketMessageType.Close) + { + return; + } + + if (result.MessageType != WebSocketMessageType.Binary && + result.MessageType != WebSocketMessageType.Text) + { + break; + } + + if (result.Count > 0) + { + if (messageStream.Length + result.Count > MaxInboundMessageSize) + { + Trace.Warning($"WebSocket message exceeds maximum allowed size of {MaxInboundMessageSize} bytes, closing connection"); + await source.CloseAsync( + WebSocketCloseStatus.MessageTooBig, + $"Message exceeds {MaxInboundMessageSize} byte limit", + CancellationToken.None); + return; + } + + messageStream.Write(buffer, 0, result.Count); + } + } + while (!result.EndOfMessage); + + if (result.MessageType != WebSocketMessageType.Binary && + result.MessageType != WebSocketMessageType.Text) + { + continue; + } + + var messageBytes = messageStream.ToArray(); + if (messageBytes.Length == 0) + { + continue; + } + + var contentLengthHeader = Encoding.ASCII.GetBytes($"Content-Length: {messageBytes.Length}\r\n\r\n"); + await destination.WriteAsync(contentLengthHeader, 0, contentLengthHeader.Length, cancellationToken); + await destination.WriteAsync(messageBytes, 0, messageBytes.Length, cancellationToken); + await destination.FlushAsync(cancellationToken); + } + } + } + + private static async Task PumpTcpToWebSocketAsync(NetworkStream source, WebSocket destination, CancellationToken cancellationToken) + { + var readBuffer = new byte[_bufferSize]; + var dapBuffer = new List(); + + while (!cancellationToken.IsCancellationRequested) + { + var bytesRead = await source.ReadAsync(readBuffer, 0, readBuffer.Length, cancellationToken); + if (bytesRead == 0) + { + break; + } + + dapBuffer.AddRange(new ArraySegment(readBuffer, 0, bytesRead)); + + while (TryParseDapMessage(dapBuffer, out var messageBody)) + { + await destination.SendAsync( + new ArraySegment(messageBody), + WebSocketMessageType.Text, + endOfMessage: true, + cancellationToken); + } + } + } + + private static bool TryParseDapMessage(List buffer, out byte[] messageBody) + { + messageBody = null; + + var headerEndIndex = FindSequence(buffer, _headerEndMarker); + if (headerEndIndex == -1) + { + return false; + } + + var headerBytes = buffer.GetRange(0, headerEndIndex).ToArray(); + var headerText = Encoding.ASCII.GetString(headerBytes); + + var contentLength = -1; + foreach (var line in headerText.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries)) + { + if (line.StartsWith("Content-Length:", StringComparison.OrdinalIgnoreCase)) + { + var valueStart = line.IndexOf(':') + 1; + if (int.TryParse(line.Substring(valueStart).Trim(), out var parsedLength)) + { + contentLength = parsedLength; + break; + } + } + } + + if (contentLength < 0) + { + throw new InvalidOperationException("DAP message missing or unparseable Content-Length header; tearing down session."); + } + + var messageStart = headerEndIndex + 4; + var messageEnd = messageStart + contentLength; + + if (buffer.Count < messageEnd) + { + return false; + } + + messageBody = buffer.GetRange(messageStart, contentLength).ToArray(); + buffer.RemoveRange(0, messageEnd); + return true; + } + + private static int FindSequence(List buffer, byte[] sequence) + { + if (buffer.Count < sequence.Length) + { + return -1; + } + + for (int i = 0; i <= buffer.Count - sequence.Length; i++) + { + var match = true; + for (int j = 0; j < sequence.Length; j++) + { + if (buffer[i + j] != sequence[j]) + { + match = false; + break; + } + } + + if (match) + { + return i; + } + } + + return -1; + } + + private static bool IsValidWebSocketRequest(string requestLine, IDictionary headers) + { + if (string.IsNullOrWhiteSpace(requestLine)) + { + return false; + } + + var requestLineParts = requestLine.Split(' '); + if (requestLineParts.Length < 3 || !string.Equals(requestLineParts[0], "GET", StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + return HeaderContainsToken(headers, "Connection", "Upgrade") && + HeaderContainsToken(headers, "Upgrade", "websocket") && + headers.ContainsKey("Sec-WebSocket-Key"); + } + + private static bool HeaderContainsToken(IDictionary headers, string headerName, string expectedToken) + { + if (!headers.TryGetValue(headerName, out var headerValue) || string.IsNullOrWhiteSpace(headerValue)) + { + return false; + } + + return headerValue + .Split(',') + .Select(token => token.Trim()) + .Any(token => string.Equals(token, expectedToken, StringComparison.OrdinalIgnoreCase)); + } + + private static string ComputeAcceptValue(string webSocketKey) + { + using (var sha1 = SHA1.Create()) + { + var inputBytes = Encoding.ASCII.GetBytes($"{webSocketKey}{_webSocketAcceptMagic}"); + var hashBytes = sha1.ComputeHash(inputBytes); + return Convert.ToBase64String(hashBytes); + } + } + + private static bool IsValidWebSocketKey(string key) + { + if (string.IsNullOrEmpty(key) || key.IndexOfAny(new[] { '\r', '\n' }) >= 0) + { + return false; + } + + try + { + var decoded = Convert.FromBase64String(key); + return decoded.Length == 16; + } + catch (FormatException) + { + return false; + } + } + + private static async Task ReadLineAsync(Stream stream, CancellationToken cancellationToken) + { + var lineBuilder = new StringBuilder(); + var buffer = new byte[1]; + var previousWasCarriageReturn = false; + + while (true) + { + var bytesRead = await stream.ReadAsync(buffer, 0, 1, cancellationToken); + if (bytesRead == 0) + { + return lineBuilder.Length > 0 ? lineBuilder.ToString() : null; + } + + var currentChar = (char)buffer[0]; + if (currentChar == '\n' && previousWasCarriageReturn) + { + if (lineBuilder.Length > 0 && lineBuilder[lineBuilder.Length - 1] == '\r') + { + lineBuilder.Length--; + } + + return lineBuilder.ToString(); + } + + previousWasCarriageReturn = currentChar == '\r'; + lineBuilder.Append(currentChar); + + if (lineBuilder.Length > _maxHeaderLineLength) + { + throw new InvalidDataException($"HTTP header line exceeds maximum length of {_maxHeaderLineLength}"); + } + } + } + + private static async Task ReadInitialBytesAsync(NetworkStream stream, CancellationToken cancellationToken) + { + var buffer = new byte[4]; + var totalRead = 0; + + while (totalRead < buffer.Length) + { + var bytesRead = await stream.ReadAsync(buffer, totalRead, buffer.Length - totalRead, cancellationToken); + if (bytesRead == 0) + { + break; + } + + totalRead += bytesRead; + } + + if (totalRead == 0) + { + return Array.Empty(); + } + + if (totalRead == buffer.Length) + { + return buffer; + } + + var initialBytes = new byte[totalRead]; + Array.Copy(buffer, initialBytes, totalRead); + return initialBytes; + } + + internal static IncomingStreamPrefixKind ClassifyIncomingStreamPrefix(byte[] initialBytes) + { + if (LooksLikeHttpUpgrade(initialBytes)) + { + return IncomingStreamPrefixKind.HttpWebSocketUpgrade; + } + + if (LooksLikeHttp2Preface(initialBytes)) + { + return IncomingStreamPrefixKind.Http2Preface; + } + + if (LooksLikeTlsClientHello(initialBytes)) + { + return IncomingStreamPrefixKind.TlsClientHello; + } + + if (LooksLikeWebSocketFramePrefix(initialBytes, requireReservedBitsClear: false)) + { + return HasReservedBitsSet(initialBytes[0]) + ? IncomingStreamPrefixKind.WebSocketReservedBits + : IncomingStreamPrefixKind.PreUpgradedWebSocket; + } + + return IncomingStreamPrefixKind.Unknown; + } + + internal static string DescribeInitialBytes(byte[] initialBytes) + { + if (initialBytes == null || initialBytes.Length == 0) + { + return "no bytes read"; + } + + var hex = BitConverter.ToString(initialBytes); + var ascii = new string(initialBytes.Select(value => value >= 32 && value <= 126 ? (char)value : '.').ToArray()); + return $"hex={hex}, ascii=\"{ascii}\""; + } + + private static bool LooksLikeHttpUpgrade(byte[] initialBytes) + { + if (initialBytes == null || initialBytes.Length < 4) + { + return false; + } + + return initialBytes[0] == (byte)'G' && + initialBytes[1] == (byte)'E' && + initialBytes[2] == (byte)'T' && + initialBytes[3] == (byte)' '; + } + + private static bool LooksLikeHttp2Preface(byte[] initialBytes) + { + if (initialBytes == null || initialBytes.Length < 4) + { + return false; + } + + return initialBytes[0] == (byte)'P' && + initialBytes[1] == (byte)'R' && + initialBytes[2] == (byte)'I' && + initialBytes[3] == (byte)' '; + } + + private static bool LooksLikeTlsClientHello(byte[] initialBytes) + { + if (initialBytes == null || initialBytes.Length < 3) + { + return false; + } + + return initialBytes[0] == 0x16 && + initialBytes[1] == 0x03 && + initialBytes[2] >= 0x00 && + initialBytes[2] <= 0x04; + } + + private static bool LooksLikeWebSocketFramePrefix(byte[] initialBytes, bool requireReservedBitsClear) + { + if (initialBytes == null || initialBytes.Length < 2) + { + return false; + } + + var firstByte = initialBytes[0]; + var secondByte = initialBytes[1]; + var opcode = firstByte & 0x0F; + var isMasked = (secondByte & 0x80) != 0; + + if (!isMasked || !IsSupportedWebSocketOpcode(opcode)) + { + return false; + } + + return !requireReservedBitsClear || !HasReservedBitsSet(firstByte); + } + + private static bool HasReservedBitsSet(byte firstByte) + { + return (firstByte & 0x70) != 0; + } + + private static bool IsSupportedWebSocketOpcode(int opcode) + { + switch (opcode) + { + case 0x0: + case 0x1: + case 0x2: + case 0x8: + case 0x9: + case 0xA: + return true; + default: + return false; + } + } + + private static async Task WriteHttpErrorAsync( + NetworkStream stream, + HttpStatusCode statusCode, + string message, + CancellationToken cancellationToken) + { + var bodyBytes = Encoding.UTF8.GetBytes(message); + var responseBytes = Encoding.ASCII.GetBytes( + $"HTTP/1.1 {(int)statusCode} {statusCode}\r\n" + + "Connection: close\r\n" + + "Content-Type: text/plain; charset=utf-8\r\n" + + $"Content-Length: {bodyBytes.Length}\r\n" + + "Sec-WebSocket-Version: 13\r\n" + + "\r\n"); + + await stream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken); + await stream.WriteAsync(bodyBytes, 0, bodyBytes.Length, cancellationToken); + await stream.FlushAsync(cancellationToken); + } + + private static async Task CloseWebSocketAsync(WebSocket webSocket) + { + if (webSocket == null) + { + return; + } + + if (webSocket.State != WebSocketState.Open && + webSocket.State != WebSocketState.CloseReceived) + { + return; + } + + try + { + using var cts = new CancellationTokenSource(_closeTimeout); + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cts.Token); + } + catch (OperationCanceledException) + { + // Graceful close timed out, abort the connection. + webSocket.Abort(); + } + catch (WebSocketException) + { + // Peer already disconnected. + } + } + + private sealed class ReplayableStream : Stream + { + private readonly Stream _innerStream; + private readonly byte[] _prefixBytes; + private int _prefixOffset; + + public ReplayableStream(Stream innerStream, byte[] prefixBytes) + { + _innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream)); + _prefixBytes = prefixBytes ?? Array.Empty(); + } + + public override bool CanRead => _innerStream.CanRead; + public override bool CanSeek => false; + public override bool CanWrite => _innerStream.CanWrite; + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() => _innerStream.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) => _innerStream.FlushAsync(cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) + { + if (TryReadPrefix(buffer, offset, count, out var bytesRead)) + { + return bytesRead; + } + + return _innerStream.Read(buffer, offset, count); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (TryReadPrefix(buffer, offset, count, out var bytesRead)) + { + return bytesRead; + } + + return await _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (_prefixOffset < _prefixBytes.Length) + { + var bytesToCopy = Math.Min(buffer.Length, _prefixBytes.Length - _prefixOffset); + new ReadOnlySpan(_prefixBytes, _prefixOffset, bytesToCopy).CopyTo(buffer.Span); + _prefixOffset += bytesToCopy; + return bytesToCopy; + } + + return await _innerStream.ReadAsync(buffer, cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => _innerStream.Write(buffer, offset, count); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => + _innerStream.WriteAsync(buffer, cancellationToken); + + private bool TryReadPrefix(byte[] buffer, int offset, int count, out int bytesRead) + { + if (_prefixOffset >= _prefixBytes.Length) + { + bytesRead = 0; + return false; + } + + bytesRead = Math.Min(count, _prefixBytes.Length - _prefixOffset); + Array.Copy(_prefixBytes, _prefixOffset, buffer, offset, bytesRead); + _prefixOffset += bytesRead; + return true; + } + } + } +} diff --git a/src/Test/L0/Worker/DapDebuggerL0.cs b/src/Test/L0/Worker/DapDebuggerL0.cs index 3b7487dd934..9454b1a3a67 100644 --- a/src/Test/L0/Worker/DapDebuggerL0.cs +++ b/src/Test/L0/Worker/DapDebuggerL0.cs @@ -1,7 +1,8 @@ -using System; +using System; using System.IO; using System.Net; using System.Net.Sockets; +using System.Net.WebSockets; using System.Runtime.CompilerServices; using System.Text; using System.Threading; @@ -19,13 +20,45 @@ public sealed class DapDebuggerL0 private const string TimeoutEnvironmentVariable = "ACTIONS_RUNNER_DAP_CONNECTION_TIMEOUT"; private const string TunnelConnectTimeoutVariable = "ACTIONS_RUNNER_DAP_TUNNEL_CONNECT_TIMEOUT_SECONDS"; private DapDebugger _debugger; + private TestWebSocketDapBridge _testWebSocketBridge; - private TestHostContext CreateTestContext([CallerMemberName] string testName = "") + private sealed class TestWebSocketDapBridge : RunnerService, IWebSocketDapBridge + { + private readonly WebSocketDapBridge _inner = new WebSocketDapBridge(); + + public int ListenPort => _inner.ListenPort; + + public override void Initialize(IHostContext hostContext) + { + base.Initialize(hostContext); + _inner.Initialize(hostContext); + } + + public void Start(int listenPort, int targetPort) + { + _inner.Start(0, targetPort); + } + + public Task ShutdownAsync() + { + return _inner.ShutdownAsync(); + } + } + + private TestHostContext CreateTestContext(bool enableWebSocketBridge = false, [CallerMemberName] string testName = "") { var hc = new TestHostContext(this, testName); _debugger = new DapDebugger(); + _testWebSocketBridge = null; _debugger.Initialize(hc); _debugger.SkipTunnelRelay = true; + _debugger.SkipWebSocketBridge = !enableWebSocketBridge; + if (enableWebSocketBridge) + { + _testWebSocketBridge = new TestWebSocketDapBridge(); + hc.EnqueueInstance(_testWebSocketBridge); + } + return hc; } @@ -71,6 +104,14 @@ private static async Task ConnectClientAsync(int port) return client; } + private static async Task ConnectWebSocketClientAsync(int port) + { + var client = new ClientWebSocket(); + client.Options.Proxy = null; + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + return client; + } + private static async Task SendRequestAsync(NetworkStream stream, Request request) { var json = JsonConvert.SerializeObject(request); @@ -83,6 +124,14 @@ private static async Task SendRequestAsync(NetworkStream stream, Request request await stream.FlushAsync(); } + private static async Task SendRequestAsync(WebSocket client, Request request) + { + var json = JsonConvert.SerializeObject(request); + var body = Encoding.UTF8.GetBytes(json); + + await client.SendAsync(new ArraySegment(body), WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None); + } + /// /// Reads a single DAP-framed message from a stream with a timeout. /// Parses the Content-Length header, reads exactly that many bytes, @@ -141,6 +190,52 @@ private static async Task ReadDapMessageAsync(NetworkStream stream, Time return Encoding.UTF8.GetString(body); } + private static async Task ReadWebSocketDataUntilAsync(WebSocket client, TimeSpan timeout, params string[] expectedFragments) + { + using var cts = new CancellationTokenSource(timeout); + var buffer = new byte[4096]; + var allMessages = new StringBuilder(); + + while (true) + { + using var messageStream = new MemoryStream(); + WebSocketReceiveResult result; + do + { + result = await client.ReceiveAsync(new ArraySegment(buffer), cts.Token); + if (result.MessageType == WebSocketMessageType.Close) + { + throw new EndOfStreamException("WebSocket closed before expected DAP messages were received."); + } + + if (result.Count > 0) + { + messageStream.Write(buffer, 0, result.Count); + } + } + while (!result.EndOfMessage); + + var messageText = Encoding.UTF8.GetString(messageStream.ToArray()); + allMessages.Append(messageText); + + var text = allMessages.ToString(); + var containsAllFragments = true; + foreach (var fragment in expectedFragments) + { + if (!text.Contains(fragment, StringComparison.Ordinal)) + { + containsAllFragments = false; + break; + } + } + + if (containsAllFragments) + { + return text; + } + } + } + private static Mock CreateJobContextWithTunnel(CancellationToken cancellationToken, ushort port, string jobName = null) { var tunnel = new GitHub.DistributedTask.Pipelines.DebuggerTunnelInfo @@ -208,6 +303,84 @@ public async Task StartAsyncUsesPortFromTunnelConfig() } } + [Fact] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + public async Task StartAsyncWithWebSocketBridgeAcceptsInitializeOverWebSocket() + { + using (CreateTestContext(enableWebSocketBridge: true)) + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + var jobContext = CreateJobContextWithTunnel(cts.Token, GetFreePort()); + await _debugger.StartAsync(jobContext.Object); + + var bridgePort = _testWebSocketBridge.ListenPort; + Assert.NotEqual(0, _debugger.InternalDapPort); + Assert.NotEqual(0, bridgePort); + Assert.NotEqual(bridgePort, _debugger.InternalDapPort); + + using var client = await ConnectWebSocketClientAsync(bridgePort); + await SendRequestAsync(client, new Request + { + Seq = 1, + Type = "request", + Command = "initialize" + }); + + var response = await ReadWebSocketDataUntilAsync( + client, + TimeSpan.FromSeconds(5), + "\"type\":\"response\"", + "\"command\":\"initialize\"", + "\"event\":\"initialized\""); + + Assert.Contains("\"success\":true", response); + await _debugger.StopAsync(); + } + } + + [Fact] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + public async Task StartAsyncWithWebSocketBridgeAcceptsPreUpgradedWebSocketStream() + { + using (CreateTestContext(enableWebSocketBridge: true)) + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + var jobContext = CreateJobContextWithTunnel(cts.Token, GetFreePort()); + await _debugger.StartAsync(jobContext.Object); + + var bridgePort = _testWebSocketBridge.ListenPort; + Assert.NotEqual(0, _debugger.InternalDapPort); + Assert.NotEqual(0, bridgePort); + Assert.NotEqual(bridgePort, _debugger.InternalDapPort); + + using var tcpClient = await ConnectClientAsync(bridgePort); + using var webSocket = WebSocket.CreateFromStream( + tcpClient.GetStream(), + isServer: false, + subProtocol: null, + keepAliveInterval: TimeSpan.FromSeconds(30)); + + await SendRequestAsync(webSocket, new Request + { + Seq = 1, + Type = "request", + Command = "initialize" + }); + + var response = await ReadWebSocketDataUntilAsync( + webSocket, + TimeSpan.FromSeconds(5), + "\"type\":\"response\"", + "\"command\":\"initialize\"", + "\"event\":\"initialized\""); + + Assert.Contains("\"success\":true", response); + await _debugger.StopAsync(); + } + } + [Fact] [Trait("Level", "L0")] [Trait("Category", "Worker")] diff --git a/src/Test/L0/Worker/WebSocketDapBridgeL0.cs b/src/Test/L0/Worker/WebSocketDapBridgeL0.cs new file mode 100644 index 00000000000..d4210dfcbc7 --- /dev/null +++ b/src/Test/L0/Worker/WebSocketDapBridgeL0.cs @@ -0,0 +1,266 @@ +using System; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using GitHub.Runner.Common; +using GitHub.Runner.Worker.Dap; +using Xunit; + +namespace GitHub.Runner.Common.Tests.Worker +{ + public sealed class WebSocketDapBridgeL0 + { + private TestHostContext CreateTestContext([CallerMemberName] string testName = "") + { + return new TestHostContext(this, testName); + } + + private static async Task ReadWebSocketMessageAsync(ClientWebSocket client, TimeSpan timeout) + { + using var cts = new CancellationTokenSource(timeout); + using var buffer = new MemoryStream(); + var receiveBuffer = new byte[1024]; + + while (true) + { + var result = await client.ReceiveAsync(new ArraySegment(receiveBuffer), cts.Token); + if (result.MessageType == WebSocketMessageType.Close) + { + throw new EndOfStreamException("WebSocket closed unexpectedly."); + } + + if (result.Count > 0) + { + buffer.Write(receiveBuffer, 0, result.Count); + } + + if (result.EndOfMessage) + { + return buffer.ToArray(); + } + } + } + + [Fact] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + public async Task BridgeForwardsWebSocketFramesToTcpAndBack() + { + using var hc = CreateTestContext(); + using var targetListener = new TcpListener(IPAddress.Loopback, 0); + targetListener.Start(); + + var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port; + + var bridge = new WebSocketDapBridge(); + bridge.Initialize(hc); + bridge.Start(0, targetPort); + var bridgePort = bridge.ListenPort; + + try + { + var echoTask = Task.Run(async () => + { + using var targetClient = await targetListener.AcceptTcpClientAsync(); + using var stream = targetClient.GetStream(); + + var headerBuilder = new StringBuilder(); + var buffer = new byte[1]; + var contentLength = -1; + + while (true) + { + var bytesRead = await stream.ReadAsync(buffer, 0, 1); + if (bytesRead == 0) break; + + headerBuilder.Append((char)buffer[0]); + var headers = headerBuilder.ToString(); + if (headers.EndsWith("\r\n\r\n", StringComparison.Ordinal)) + { + foreach (var line in headers.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries)) + { + if (line.StartsWith("Content-Length: ", StringComparison.OrdinalIgnoreCase)) + { + contentLength = int.Parse(line.Substring("Content-Length: ".Length).Trim()); + } + } + break; + } + } + + var body = new byte[contentLength]; + var totalRead = 0; + while (totalRead < contentLength) + { + var bytesRead = await stream.ReadAsync(body, totalRead, contentLength - totalRead); + if (bytesRead == 0) break; + totalRead += bytesRead; + } + + var header = $"Content-Length: {body.Length}\r\n\r\n"; + var headerBytes = Encoding.ASCII.GetBytes(header); + await stream.WriteAsync(headerBytes, 0, headerBytes.Length); + await stream.WriteAsync(body, 0, body.Length); + await stream.FlushAsync(); + }); + + using var client = new ClientWebSocket(); + client.Options.Proxy = null; + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{bridgePort}/"), CancellationToken.None); + + var dapMessage = "{\"type\":\"request\",\"seq\":1,\"command\":\"initialize\"}"; + var payload = Encoding.UTF8.GetBytes(dapMessage); + await client.SendAsync(new ArraySegment(payload), WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None); + + var echoed = await ReadWebSocketMessageAsync(client, TimeSpan.FromSeconds(5)); + Assert.Equal(payload, echoed); + + await echoTask; + } + finally + { + await bridge.ShutdownAsync(); + } + } + + [Fact] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + public async Task BridgeRejectsNonWebSocketRequests() + { + using var hc = CreateTestContext(); + + var bridge = new WebSocketDapBridge(); + bridge.Initialize(hc); + bridge.Start(0, 0); + var bridgePort = bridge.ListenPort; + + try + { + using var client = new TcpClient(); + await client.ConnectAsync(IPAddress.Loopback, bridgePort); + using var stream = client.GetStream(); + + var request = Encoding.ASCII.GetBytes( + "GET / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "\r\n"); + await stream.WriteAsync(request, 0, request.Length); + await stream.FlushAsync(); + + // Read until the server closes the connection (Connection: close). + // A single ReadAsync may return a partial response on some platforms. + using var ms = new MemoryStream(); + var responseBuffer = new byte[1024]; + int bytesRead; + while ((bytesRead = await stream.ReadAsync(responseBuffer, 0, responseBuffer.Length)) > 0) + { + ms.Write(responseBuffer, 0, bytesRead); + } + + var response = Encoding.ASCII.GetString(ms.ToArray()); + + Assert.Contains("400 BadRequest", response); + Assert.Contains("Expected a websocket upgrade request.", response); + } + finally + { + await bridge.ShutdownAsync(); + } + } + + [Theory] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + [InlineData(new byte[] { (byte)'G', (byte)'E', (byte)'T', (byte)' ' }, 1)] + [InlineData(new byte[] { 0x81, 0x85, 0x00, 0x00 }, 2)] + [InlineData(new byte[] { 0xC1, 0x85, 0x00, 0x00 }, 3)] + [InlineData(new byte[] { (byte)'P', (byte)'R', (byte)'I', (byte)' ' }, 4)] + [InlineData(new byte[] { 0x16, 0x03, 0x03, 0x01 }, 5)] + [InlineData(new byte[] { (byte)'B', (byte)'A', (byte)'D', (byte)'!' }, 0)] + public void ClassifyIncomingStreamPrefixDetectsExpectedProtocols(byte[] initialBytes, int expectedKind) + { + var actualKind = WebSocketDapBridge.ClassifyIncomingStreamPrefix(initialBytes); + Assert.Equal((WebSocketDapBridge.IncomingStreamPrefixKind)expectedKind, actualKind); + } + + [Fact] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + public async Task BridgeRejectsOversizedWebSocketMessage() + { + using var hc = CreateTestContext(); + using var targetListener = new TcpListener(IPAddress.Loopback, 0); + targetListener.Start(); + + var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port; + + var bridge = new WebSocketDapBridge(); + bridge.Initialize(hc); + bridge.MaxInboundMessageSize = 64; // artificially small limit for testing + bridge.Start(0, targetPort); + var bridgePort = bridge.ListenPort; + + try + { + using var client = new ClientWebSocket(); + client.Options.Proxy = null; + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{bridgePort}/"), CancellationToken.None); + + // Send a message that exceeds the 64-byte limit + var oversizedPayload = new byte[128]; + Array.Fill(oversizedPayload, (byte)'X'); + await client.SendAsync( + new ArraySegment(oversizedPayload), + WebSocketMessageType.Text, + endOfMessage: true, + CancellationToken.None); + + // The bridge should close the connection with MessageTooBig + var receiveBuffer = new byte[256]; + using var receiveCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var result = await client.ReceiveAsync( + new ArraySegment(receiveBuffer), + receiveCts.Token); + + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.MessageTooBig, client.CloseStatus); + } + finally + { + await bridge.ShutdownAsync(); + } + } + + [Fact] + [Trait("Level", "L0")] + [Trait("Category", "Worker")] + public async Task BridgeShutdownCompletesWhenPeerDoesNotCloseGracefully() + { + using var hc = CreateTestContext(); + using var targetListener = new TcpListener(IPAddress.Loopback, 0); + targetListener.Start(); + + var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port; + + var bridge = new WebSocketDapBridge(); + bridge.Initialize(hc); + bridge.Start(0, targetPort); + var bridgePort = bridge.ListenPort; + + // Connect a raw TCP client but never perform WebSocket close handshake + using var rawClient = new TcpClient(); + await rawClient.ConnectAsync(IPAddress.Loopback, bridgePort); + + // Shutdown should complete within a bounded time, not hang + var shutdownTask = bridge.ShutdownAsync(); + var completed = await Task.WhenAny(shutdownTask, Task.Delay(TimeSpan.FromSeconds(15))); + Assert.True(completed == shutdownTask, "Bridge shutdown should complete within the timeout, not hang on a non-cooperative peer"); + } + } +}