diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c5052723c2..89135cbc2cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -132,6 +132,8 @@ jobs: - name: Run tests # id: step_test # continue-on-error: true + env: + GODEBUG: http2xconnect=1 run: | # (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out go test -v -coverprofile="cover-profile.out" -short -race ./... @@ -191,7 +193,7 @@ jobs: retries=3 exit_code=0 while ((retries > 0)); do - CGO_ENABLED=0 go test -p 1 -v ./... + GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./... exit_code=$? if ((exit_code == 0)); then break diff --git a/caddytest/integration/reverseproxy_extended_connect_test.go b/caddytest/integration/reverseproxy_extended_connect_test.go new file mode 100644 index 00000000000..8822988be09 --- /dev/null +++ b/caddytest/integration/reverseproxy_extended_connect_test.go @@ -0,0 +1,328 @@ +package integration + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support") + +func TestReverseProxyExtendedConnectOverH2(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newWebsocketUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust + servers :9443 { + protocols h2 + } +} + +https://localhost:9443 { + reverse_proxy %s +} +`, backend.addr), "caddyfile") + + const payload = "extended-connect-echo\n" + if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil { + if errors.Is(err, errExtendedConnectUnsupportedByPeer) { + t.Skipf("skipping extended CONNECT integration test: %v", err) + } + t.Fatalf("extended connect h2 echo failed: %v", err) + } +} + +func assertExtendedConnectH2Echo(addr, payload string) error { + conn, err := tlsDialH2(addr) + if err != nil { + return fmt.Errorf("dialing h2 tls: %w", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + return fmt.Errorf("setting deadline: %w", err) + } + + fr := http2.NewFramer(conn, conn) + + if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil { + return fmt.Errorf("writing client preface: %w", err) + } + if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil { + return fmt.Errorf("writing client settings: %w", err) + } + + supported, err := waitForServerSettings(fr) + if err != nil { + return err + } + if !supported { + return errExtendedConnectUnsupportedByPeer + } + if err := waitForSettingsAck(fr); err != nil { + return err + } + + if err := writeExtendedConnectHeaders(fr, addr); err != nil { + return err + } + + status, err := readResponseStatus(fr, 1) + if err != nil { + return err + } + if status != "200" { + return fmt.Errorf("unexpected extended connect status: got=%s want=200", status) + } + + if err := fr.WriteData(1, false, []byte(payload)); err != nil { + return fmt.Errorf("writing stream data: %w", err) + } + + echo, err := readStreamData(fr, 1, len(payload)) + if err != nil { + return err + } + if echo != payload { + return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload) + } + + _ = 
fr.WriteRSTStream(1, http2.ErrCodeNo) + return nil +} + +func tlsDialH2(addr string) (net.Conn, error) { + var lastErr error + for i := 0; i < 30; i++ { + dialer := &net.Dialer{Timeout: 2 * time.Second} + conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + if err == nil { + return conn, nil + } + lastErr = err + time.Sleep(100 * time.Millisecond) + } + return nil, lastErr +} + +func waitForServerSettings(fr *http2.Framer) (bool, error) { + for { + frame, err := fr.ReadFrame() + if err != nil { + return false, fmt.Errorf("reading frame before connect: %w", err) + } + settings, ok := frame.(*http2.SettingsFrame) + if !ok { + continue + } + if settings.IsAck() { + continue + } + + supported := false + if err := settings.ForeachSetting(func(s http2.Setting) error { + if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 { + supported = true + } + return nil + }); err != nil { + return false, fmt.Errorf("reading server settings: %w", err) + } + + if err := fr.WriteSettingsAck(); err != nil { + return false, fmt.Errorf("writing settings ack: %w", err) + } + return supported, nil + } +} + +func waitForSettingsAck(fr *http2.Framer) error { + for { + frame, err := fr.ReadFrame() + if err != nil { + return fmt.Errorf("reading settings ack: %w", err) + } + settings, ok := frame.(*http2.SettingsFrame) + if ok && settings.IsAck() { + return nil + } + } +} + +func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error { + var hb bytes.Buffer + enc := hpack.NewEncoder(&hb) + for _, hf := range []hpack.HeaderField{ + {Name: ":method", Value: "CONNECT"}, + {Name: ":scheme", Value: "https"}, + {Name: ":authority", Value: addr}, + {Name: ":path", Value: "/upgrade"}, + {Name: ":protocol", Value: "websocket"}, + } { + if err := enc.WriteField(hf); err != nil { + return fmt.Errorf("encoding request headers: %w", err) + } + } + + if err := fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: hb.Bytes(), + EndHeaders: true, + EndStream: false, + }); err != nil { + return fmt.Errorf("writing extended connect headers: %w", err) + } + return nil +} + +func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) { + var block bytes.Buffer + + for { + frame, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading response headers: %w", err) + } + if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID { + return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode) + } + + h, ok := frame.(*http2.HeadersFrame) + if !ok || h.StreamID != streamID { + continue + } + + if _, err := block.Write(h.HeaderBlockFragment()); err != nil { + return "", fmt.Errorf("buffering response header fragment: %w", err) + } + for !h.HeadersEnded() { + next, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading continuation frame: %w", err) + } + c, ok := next.(*http2.ContinuationFrame) + if !ok || c.StreamID != streamID { + continue + } + if _, err := block.Write(c.HeaderBlockFragment()); err != nil { + return "", fmt.Errorf("buffering continuation fragment: %w", err) + } + if c.HeadersEnded() { + break + } + } + break + } + + var status string + dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) { + if f.Name == ":status" { + status = f.Value + } + }) + if _, err := dec.Write(block.Bytes()); err != nil { + return "", fmt.Errorf("decoding response header block: %w", err) + } + if status == "" { + return "", 
fmt.Errorf("missing :status in response headers") + } + return status, nil +} + +func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) { + buf := make([]byte, 0, n) + for len(buf) < n { + frame, err := fr.ReadFrame() + if err != nil { + return "", fmt.Errorf("reading stream data: %w", err) + } + d, ok := frame.(*http2.DataFrame) + if !ok || d.StreamID != streamID { + continue + } + buf = append(buf, d.Data()...) + } + return string(buf[:n]), nil +} + +type websocketUpgradeEchoBackend struct { + addr string + ln net.Listener + server *http.Server +} + +func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend { + t.Helper() + + backend := &websocketUpgradeEchoBackend{} + backend.server = &http.Server{ + Handler: http.HandlerFunc(backend.serveHTTP), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for websocket backend: %v", err) + } + backend.ln = ln + backend.addr = ln.Addr().String() + + go func() { + _ = backend.server.Serve(ln) + }() + + return backend +} + +func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + http.Error(w, "upgrade required", http.StatusUpgradeRequired) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + conn, rw, err := hijacker.Hijack() + if err != nil { + return + } + + _, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + _ = rw.Flush() + + go func() { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }() +} + +func (b *websocketUpgradeEchoBackend) Close() { + _ = b.server.Close() + _ = b.ln.Close() +} diff --git a/caddytest/integration/reverseproxy_upgrade_handlers_test.go b/caddytest/integration/reverseproxy_upgrade_handlers_test.go new file mode 100644 index 00000000000..dda93db0ea3 --- /dev/null +++ b/caddytest/integration/reverseproxy_upgrade_handlers_test.go @@ -0,0 +1,130 @@ +package integration + +import ( + "bufio" + "fmt" + "io" + "net" + "net/textproto" + "strings" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +func TestReverseProxyUpgradeWithEncode(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + route { + encode gzip + reverse_proxy %s + } +} +`, backend.addr), "caddyfile") + + client := newUpgradedStreamClientWithHeaders(t, map[string]string{ + "Accept-Encoding": "gzip", + }) + defer client.Close() + + if err := client.echo("encode-upgrade\n"); err != nil { + t.Fatalf("upgraded stream echo through encode failed: %v", err) + } +} + +func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) { + tester := caddytest.NewTester(t) + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + tester.InitServer(fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + route { + intercept { + @upgrade status 101 + handle_response @upgrade { + respond "should-not-run" + } + } + reverse_proxy %s + } +} +`, backend.addr), "caddyfile") + + client := newUpgradedStreamClientWithHeaders(t, nil) + defer 
client.Close() + + if err := client.echo("intercept-upgrade\n"); err != nil { + t.Fatalf("upgraded stream echo through intercept failed: %v", err) + } +} + +func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient { + t.Helper() + + conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second) + if err != nil { + t.Fatalf("dialing caddy: %v", err) + } + + requestLines := []string{ + "GET /upgrade HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: stress-stream", + } + for k, v := range extraHeaders { + requestLines = append(requestLines, k+": "+v) + } + requestLines = append(requestLines, "", "") + + if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil { + _ = conn.Close() + t.Fatalf("writing upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + tproto := textproto.NewReader(reader) + statusLine, err := tproto.ReadLine() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + _ = conn.Close() + t.Fatalf("unexpected upgrade status: %s", statusLine) + } + + headers, err := tproto.ReadMIMEHeader() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade headers: %v", err) + } + if !strings.EqualFold(headers.Get("Connection"), "Upgrade") { + _ = conn.Close() + t.Fatalf("unexpected upgrade response headers: %v", headers) + } + + return &upgradedStreamClient{conn: conn, reader: reader} +} diff --git a/caddytest/integration/stream_reload_stress_test.go b/caddytest/integration/stream_reload_stress_test.go new file mode 100644 index 00000000000..6ae6e9fa087 --- /dev/null +++ b/caddytest/integration/stream_reload_stress_test.go @@ -0,0 +1,504 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "net/textproto" + "os" + "runtime" + "runtime/debug" + "runtime/pprof" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +const ( + defaultStressStreamCount = 1 + defaultStressReloadCount = 1 + defaultStressCloseDelay = 500 * time.Millisecond +) + +func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) { + tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{ + LoadRequestTimeout: 30 * time.Second, + TestRequestTimeout: 30 * time.Second, + }) + + backend := newUpgradeEchoBackend(t) + defer backend.Close() + + // Three scenarios, each sequential so they don't share Caddy state: + // + // legacy – no delay, close on reload immediately (old default) + // close_delay – stream_close_delay, the old "keep-alive workaround" + // detached – stream_detached, the new explicit detached flag + // + // Reloads are spread across time and interleaved with echo-checks so + // stream health is exercised at each reload boundary, not only at the end. 
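+	// Stream and reload counts default to 1; raise them via the
+	// CADDY_STRESS_STREAM_COUNT and CADDY_STRESS_RELOAD_COUNT environment
+	// variables, and the close delay via CADDY_STRESS_CLOSE_DELAY, for a heavier run.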
+ legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0) + closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t)) + detached := runReloadStress(t, tester, backend.addr, "detached", true, 0) + + if legacy.aliveAfterReloads != 0 { + t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads) + } + if closeDelay.aliveBeforeDelayExpiry == 0 { + t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)") + } + if closeDelay.aliveAfterReloads != 0 { + t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads) + } + if detached.aliveAfterReloads != detached.streamCount { + t.Fatalf("detached mode kept %d/%d upgraded streams alive after reloads", detached.aliveAfterReloads, detached.streamCount) + } + + t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(legacy.beforeReload.HeapInuse), + formatBytes(legacy.midReload.HeapInuse), + formatBytes(legacy.afterReload.HeapInuse), + formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse), + legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects, + legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames, + ) + t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(closeDelay.beforeReload.HeapInuse), + formatBytes(closeDelay.midReload.HeapInuse), + formatBytes(closeDelay.afterReload.HeapInuse), + formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse), + closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects, + closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames, + ) + t.Logf("detached heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)", + formatBytes(detached.beforeReload.HeapInuse), + formatBytes(detached.midReload.HeapInuse), + formatBytes(detached.afterReload.HeapInuse), + formatBytesDiff(detached.beforeReload.HeapInuse, detached.afterReload.HeapInuse), + detached.beforeReload.HeapObjects, detached.afterReload.HeapObjects, + detached.beforeReload.handlerFrames, detached.afterReload.handlerFrames, + ) +} + +type stressRunResult struct { + streamCount int + aliveAfterReloads int + aliveBeforeDelayExpiry int // only meaningful for close_delay mode + beforeReload heapSnapshot + midReload heapSnapshot // after all reloads, before delay expiry clean-up + afterReload heapSnapshot // after all streams have been fully cleaned up +} + +type heapSnapshot struct { + HeapInuse uint64 + HeapObjects uint64 + handlerFrames int + profileBytes int +} + +// runReloadStress opens streamCount upgraded streams, then performs reloadCount +// config reloads spread over time. An echo check is performed every 6 reloads so +// stream health is exercised at each reload boundary rather than only at the end. +// closeDelay mirrors the stream_close_delay config option; pass 0 to disable. 
+func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, detach bool, closeDelay time.Duration) stressRunResult { + t.Helper() + + const echoEvery = 6 // perform an echo check every N reloads + + streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount) + reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount) + + tester.InitServer(reloadStressConfig(backendAddr, detach, closeDelay, 0), "caddyfile") + + clients := make([]*upgradedStreamClient, 0, streamCount) + for i := 0; i < streamCount; i++ { + client := newUpgradedStreamClient(t) + clients = append(clients, client) + if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil { + closeClients(clients) + t.Fatalf("warmup echo failed in %s mode: %v", mode, err) + } + } + defer closeClients(clients) + + before := captureHeapSnapshot(t) + + // Reloads are spread across time; between batches of echoEvery reloads we + // pause briefly and measure stream health so the snapshot reflects real-world + // reload cadence rather than a tight loop. + for i := 1; i <= reloadCount; i++ { + loadCaddyfileConfig(t, reloadStressConfig(backendAddr, detach, closeDelay, i)) + + // Small pause after each reload to let connection teardown propagate. + time.Sleep(50 * time.Millisecond) + + if i%echoEvery == 0 { + alive := countAliveStreams(clients) + t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i) + + // In detached mode, every stream must survive every reload (upstream unchanged). + if detach { + for j, client := range clients { + if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil { + t.Fatalf("detached mode stream %d died at reload %d: %v", j, i, err) + } + } + } + } + } + + // mid snapshot: after all reloads but before any close_delay timer has fired + // (the delay is long enough to still be running at this point). + mid := captureHeapSnapshot(t) + + // For legacy mode: the reloads close streams immediately; wait for that to complete. + // For close_delay mode: streams are still alive here; wait for the delay to fire. + // For detached mode: streams survive indefinitely; no wait needed. 
+ var aliveBeforeDelayExpiry int + aliveAfterReloads := countAliveStreams(clients) + switch { + case detach: + // nothing to wait for + case closeDelay > 0: + // streams should still be alive at this point (delay hasn't expired) + aliveBeforeDelayExpiry = aliveAfterReloads + t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup", + mode, aliveBeforeDelayExpiry, streamCount, closeDelay) + time.Sleep(closeDelay + 200*time.Millisecond) + aliveAfterReloads = countAliveStreams(clients) + default: + deadline := time.Now().Add(2 * time.Second) + for aliveAfterReloads > 0 && time.Now().Before(deadline) { + time.Sleep(50 * time.Millisecond) + aliveAfterReloads = countAliveStreams(clients) + } + } + + after := captureHeapSnapshot(t) + t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)", + mode, + before.profileBytes, mid.profileBytes, after.profileBytes, + before.HeapObjects, mid.HeapObjects, after.HeapObjects, + ) + + return stressRunResult{ + streamCount: streamCount, + aliveAfterReloads: aliveAfterReloads, + aliveBeforeDelayExpiry: aliveBeforeDelayExpiry, + beforeReload: before, + midReload: mid, + afterReload: after, + } +} + +func envIntOrDefault(t *testing.T, key string, def int) int { + t.Helper() + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return def + } + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + t.Fatalf("invalid %s=%q: must be a positive integer", key, raw) + } + return v +} + +func stressCloseDelay(t *testing.T) time.Duration { + t.Helper() + + const key = "CADDY_STRESS_CLOSE_DELAY" + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return defaultStressCloseDelay + } + v, err := time.ParseDuration(raw) + if err != nil || v <= 0 { + t.Fatalf("invalid %s=%q: must be a positive duration", key, raw) + } + return v +} + +func loadCaddyfileConfig(t *testing.T, rawConfig string) { + t.Helper() + + client := &http.Client{Timeout: 30 * time.Second} + req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig)) + if err != nil { + t.Fatalf("creating load request: %v", err) + } + req.Header.Set("Content-Type", "text/caddyfile") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("loading config: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading load response: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body) + } +} + +func reloadStressConfig(backendAddr string, detach bool, closeDelay time.Duration, revision int) string { + var directives string + if detach { + directives += "\n\t\tstream_detached" + } + if closeDelay > 0 { + directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay) + } + + return fmt.Sprintf(` +{ + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + skip_install_trust +} + +localhost:9080 { + reverse_proxy %s { + header_up X-Reload-Revision %d%s + } +} +`, backendAddr, revision, directives) +} + +func captureHeapSnapshot(t *testing.T) heapSnapshot { + t.Helper() + + runtime.GC() + debug.FreeOSMemory() + + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + + var buf bytes.Buffer + if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil { + t.Fatalf("capturing heap profile: %v", err) + } + profile := buf.String() + + return heapSnapshot{ + HeapInuse: mem.HeapInuse, + HeapObjects: mem.HeapObjects, + handlerFrames: 
strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"), + profileBytes: buf.Len(), + } +} + +func countAliveStreams(clients []*upgradedStreamClient) int { + alive := 0 + for index, client := range clients { + if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil { + alive++ + } + } + return alive +} + +func closeClients(clients []*upgradedStreamClient) { + for _, client := range clients { + if client != nil { + _ = client.Close() + } + } +} + +func formatBytes(value uint64) string { + const unit = 1024 + if value < unit { + return fmt.Sprintf("%d B", value) + } + div, exp := uint64(unit), 0 + for n := value / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(value)/float64(div), "KMGTPE"[exp]) +} + +func formatBytesDiff(before, after uint64) string { + if after >= before { + return "+" + formatBytes(after-before) + } + return "-" + formatBytes(before-after) +} + +type upgradedStreamClient struct { + conn net.Conn + reader *bufio.Reader + mu sync.Mutex +} + +func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient { + t.Helper() + + conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second) + if err != nil { + t.Fatalf("dialing caddy: %v", err) + } + + request := strings.Join([]string{ + "GET /upgrade HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: stress-stream", + "", + "", + }, "\r\n") + if _, err := io.WriteString(conn, request); err != nil { + _ = conn.Close() + t.Fatalf("writing upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + tproto := textproto.NewReader(reader) + statusLine, err := tproto.ReadLine() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + _ = conn.Close() + t.Fatalf("unexpected upgrade status: %s", statusLine) + } + + headers, err := tproto.ReadMIMEHeader() + if err != nil { + _ = conn.Close() + t.Fatalf("reading upgrade headers: %v", err) + } + if !strings.EqualFold(headers.Get("Connection"), "Upgrade") { + _ = conn.Close() + t.Fatalf("unexpected upgrade response headers: %v", headers) + } + + return &upgradedStreamClient{conn: conn, reader: reader} +} + +func (c *upgradedStreamClient) echo(payload string) error { + c.mu.Lock() + defer c.mu.Unlock() + + deadline := time.Now().Add(1 * time.Second) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return err + } + if _, err := io.WriteString(c.conn, payload); err != nil { + return err + } + if err := c.conn.SetReadDeadline(deadline); err != nil { + return err + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(c.reader, buf); err != nil { + return err + } + if string(buf) != payload { + return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload) + } + return nil +} + +func (c *upgradedStreamClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.Close() +} + +type upgradeEchoBackend struct { + addr string + ln net.Listener + mu sync.Mutex + conns map[net.Conn]struct{} + server *http.Server +} + +func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend { + t.Helper() + + backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})} + backend.server = &http.Server{ + Handler: http.HandlerFunc(backend.serveHTTP), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening for backend: %v", err) + } + backend.ln = ln + backend.addr = ln.Addr().String() + + go func() { + _ = 
backend.server.Serve(ln) + }() + + return backend +} + +func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") { + http.Error(w, "upgrade required", http.StatusUpgradeRequired) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + conn, rw, err := hijacker.Hijack() + if err != nil { + return + } + + b.trackConn(conn) + _, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n") + _ = rw.Flush() + + go func() { + defer b.untrackConn(conn) + defer conn.Close() + _, _ = io.Copy(conn, conn) + }() +} + +func (b *upgradeEchoBackend) trackConn(conn net.Conn) { + b.mu.Lock() + b.conns[conn] = struct{}{} + b.mu.Unlock() +} + +func (b *upgradeEchoBackend) untrackConn(conn net.Conn) { + b.mu.Lock() + delete(b.conns, conn) + b.mu.Unlock() +} + +func (b *upgradeEchoBackend) Close() { + _ = b.server.Close() + _ = b.ln.Close() + + b.mu.Lock() + defer b.mu.Unlock() + for conn := range b.conns { + _ = conn.Close() + } + clear(b.conns) +} diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 904c30c0352..d710160bd54 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -21,6 +21,8 @@ import ( "io" "net" "net/http" + + "github.com/caddyserver/caddy/v2" ) // ResponseWriterWrapper wraps an underlying ResponseWriter and @@ -70,6 +72,8 @@ type responseRecorder struct { size int wroteHeader bool stream bool + hijacked bool + detached bool readSize *int } @@ -144,7 +148,8 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer // WriteHeader writes the headers with statusCode to the wrapped // ResponseWriter unless the response is to be buffered instead. -// 1xx responses are never buffered. +// 1xx responses are never buffered, except 101 which is treated +// as a final upgrade response. 
func (rr *responseRecorder) WriteHeader(statusCode int) { if rr.wroteHeader { return } @@ -161,12 +166,12 @@ rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header()) } - // 1xx responses aren't final; just informational - if statusCode < 100 || statusCode > 199 { + // 1xx responses except 101 aren't final; just informational + if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols { rr.wroteHeader = true } - // if informational or not buffered, immediately write header + // if 1xx or not buffered, immediately write header if rr.stream || (100 <= statusCode && statusCode <= 199) { rr.ResponseWriterWrapper.WriteHeader(statusCode) } @@ -222,7 +227,18 @@ func (rr *responseRecorder) Buffered() bool { return !rr.stream } +func (rr *responseRecorder) DetachAfterHijack(detached bool) bool { + if rr.hijacked { + return false + } + rr.detached = detached + return true +} + func (rr *responseRecorder) WriteResponse() error { + if rr.hijacked { + return nil + } if rr.statusCode == 0 { // could happen if no handlers actually wrote anything, // and this prevents a panic; status must be > 0 @@ -253,11 +269,25 @@ func (rr *responseRecorder) setReadSize(size *int) { } func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !rr.wroteHeader { + // hijacking without writing status code first works as long as + // subsequent writes follow the HTTP/1.1 wire format, but it will + // show up with a status code of 0 in the access log, and bytes + // written will include response headers. Response headers won't + // be present in the log if not set on the response writer. + caddy.Log().Warn("hijacking without writing status code first") + } //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() if err != nil { return nil, nil, err } + rr.hijacked = true + rr.stream = true + rr.wroteHeader = true + if rr.detached { + return conn, brw, nil + } // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not conn = &hijackedConn{conn, rr} brw.Writer.Reset(conn) @@ -311,6 +341,29 @@ func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) { return n, err } +// DetachResponseWriterAfterHijack detaches w or one of its wrapped +// response writers when it's hijacked. Returns true if not already +// hijacked. When detached, bytes read or written stats will not be +// recorded for the hijacked connection, and it's safe to use the +// connection after http middleware returns. +func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool { + for w != nil { + if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok { + return detacher.DetachAfterHijack(detached) + } + unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter }) + if !ok { + return false + } + next := unwrapper.Unwrap() + if next == w { + return false + } + w = next + } + return false +} + // ResponseRecorder is a http.ResponseWriter that records // responses instead of writing them to the client. See // docs for NewResponseRecorder for proper usage. 
@@ -319,6 +372,7 @@ type ResponseRecorder interface { Status() int Buffer() *bytes.Buffer Buffered() bool + DetachAfterHijack(bool) bool Size() int WriteResponse() error } diff --git a/modules/caddyhttp/responsewriter_test.go b/modules/caddyhttp/responsewriter_test.go index c08ad26a472..4111164815e 100644 --- a/modules/caddyhttp/responsewriter_test.go +++ b/modules/caddyhttp/responsewriter_test.go @@ -1,11 +1,14 @@ package caddyhttp import ( + "bufio" "bytes" "io" + "net" "net/http" "strings" "testing" + "time" ) type responseWriterSpy interface { @@ -44,6 +47,50 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) { func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called } +type hijackRespWriter struct { + baseRespWriter + header http.Header + status int + conn net.Conn +} + +func newHijackRespWriter() *hijackRespWriter { + return &hijackRespWriter{ + header: make(http.Header), + conn: stubConn{}, + } +} + +func (hrw *hijackRespWriter) Header() http.Header { + return hrw.header +} + +func (hrw *hijackRespWriter) WriteHeader(statusCode int) { + hrw.status = statusCode +} + +func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + br := bufio.NewReader(hrw.conn) + bw := bufio.NewWriter(hrw.conn) + return hrw.conn, bufio.NewReadWriter(br, bw), nil +} + +type stubConn struct{} + +func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (stubConn) Write(p []byte) (int, error) { return len(p), nil } +func (stubConn) Close() error { return nil } +func (stubConn) LocalAddr() net.Addr { return stubAddr("local") } +func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") } +func (stubConn) SetDeadline(time.Time) error { return nil } +func (stubConn) SetReadDeadline(time.Time) error { return nil } +func (stubConn) SetWriteDeadline(time.Time) error { return nil } + +type stubAddr string + +func (a stubAddr) Network() string { return "tcp" } +func (a stubAddr) String() string { return string(a) } + func TestResponseWriterWrapperReadFrom(t *testing.T) { tests := map[string]struct { responseWriter responseWriterSpy @@ -169,3 +216,49 @@ func TestResponseRecorderReadFrom(t *testing.T) { }) } } + +func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) { + w := newHijackRespWriter() + var buf bytes.Buffer + + rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool { + return true + }) + rr.WriteHeader(http.StatusSwitchingProtocols) + + if rr.Status() != http.StatusSwitchingProtocols { + t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols) + } + if w.status != http.StatusSwitchingProtocols { + t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols) + } + + hj, ok := rr.(http.Hijacker) + if !ok { + t.Fatal("response recorder does not implement http.Hijacker") + } + conn, _, err := hj.Hijack() + if err != nil { + t.Fatalf("Hijack() error = %v", err) + } + defer conn.Close() + + if rr.Buffered() { + t.Fatal("hijacked response should not remain buffered") + } + if rr.DetachAfterHijack(true) { + t.Fatal("response recorder should report hijacked state by returning false") + } + if DetachResponseWriterAfterHijack(rr, true) { + t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack") + } + if err := rr.WriteResponse(); err != nil { + t.Fatalf("WriteResponse() after hijack returned error: %v", err) + } + if rr.Size() != 0 { + t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size()) + } + if got := w.Written(); got != "" { + 
t.Fatalf("unexpected buffered body write after hijack: %q", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index 8716babe336..56eb3fd112d 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -99,6 +99,12 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) // stream_buffer_size // stream_timeout // stream_close_delay +// stream_detached +// stream_logs { +// level +// logger_name +// skip_handshake +// } // verbose_logs // // # request manipulation @@ -703,6 +709,49 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { h.StreamCloseDelay = caddy.Duration(dur) } + case "stream_detached": + if d.NextArg() { + return d.ArgErr() + } + h.StreamDetached = true + + case "stream_logs": + if d.NextArg() { + return d.ArgErr() + } + if h.StreamLogs == nil { + h.StreamLogs = new(StreamLogs) + } + + nesting := d.Nesting() + for d.NextBlock(nesting) { + switch d.Val() { + case "level": + if !d.NextArg() { + return d.ArgErr() + } + h.StreamLogs.Level = d.Val() + if d.NextArg() { + return d.ArgErr() + } + case "logger_name": + if !d.NextArg() { + return d.ArgErr() + } + h.StreamLogs.LoggerName = d.Val() + if d.NextArg() { + return d.ArgErr() + } + case "skip_handshake": + if d.NextArg() { + return d.ArgErr() + } + h.StreamLogs.SkipHandshake = true + default: + return d.Errf("unrecognized stream_logs option: %s", d.Val()) + } + } + case "trusted_proxies": for d.NextArg() { if d.Val() == "private_ranges" { diff --git a/modules/caddyhttp/reverseproxy/copyresponse.go b/modules/caddyhttp/reverseproxy/copyresponse.go index c1c9de92ba8..ec1720d31b4 100644 --- a/modules/caddyhttp/reverseproxy/copyresponse.go +++ b/modules/caddyhttp/reverseproxy/copyresponse.go @@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request hrc.isFinalized = true // write the response - return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger) + return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr) } // CopyResponseHeadersHandler is a special HTTP handler which may diff --git a/modules/caddyhttp/reverseproxy/extended_connect_test.go b/modules/caddyhttp/reverseproxy/extended_connect_test.go new file mode 100644 index 00000000000..5cb27d807e3 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/extended_connect_test.go @@ -0,0 +1,146 @@ +package reverseproxy + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "go.uber.org/zap" + + "github.com/caddyserver/caddy/v2/modules/caddyhttp" +) + +type extendedConnectCapture struct { + method string + headers http.Header + body []byte + extendedBodyPresent bool + extendedConnectBody []byte +} + +type extendedConnectCaptureTransport struct { + mu sync.Mutex + capture extendedConnectCapture +} + +func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + c := extendedConnectCapture{ + method: req.Method, + headers: req.Header.Clone(), + body: body, + } + if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok { + c.extendedBodyPresent = true + c.extendedConnectBody, err = io.ReadAll(rc) + if err != nil { + return nil, err + } + _ = rc.Close() + } + + tr.mu.Lock() + tr.capture = c + tr.mu.Unlock() + + return 
&http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: req, + }, nil +} + +func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture { + tr.mu.Lock() + defer tr.mu.Unlock() + return tr.capture +} + +func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) { + tests := []struct { + name string + protoMajor int + proto string + headers map[string]string + }{ + { + name: "h2 extended connect", + protoMajor: 2, + proto: "HTTP/2.0", + headers: map[string]string{ + ":protocol": "websocket", + }, + }, + { + name: "h3 extended connect", + protoMajor: 3, + proto: "websocket", + headers: map[string]string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + const payload = "extended-connect-body" + + transport := new(extendedConnectCaptureTransport) + h := &Handler{ + logger: zap.NewNop(), + Transport: transport, + Upstreams: UpstreamPool{ + &Upstream{Host: new(Host), Dial: "127.0.0.1:8443"}, + }, + LoadBalancing: &LoadBalancing{ + SelectionPolicy: &RoundRobinSelection{}, + }, + } + + req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload)) + req.ProtoMajor = tc.protoMajor + req.Proto = tc.proto + for key, value := range tc.headers { + req.Header.Set(key, value) + } + req = prepareTestRequest(req) + + rr := httptest.NewRecorder() + err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + return nil + })) + if err != nil { + t.Fatalf("ServeHTTP() error = %v", err) + } + + captured := transport.Snapshot() + if captured.method != http.MethodGet { + t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet) + } + if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") { + t.Fatalf("Upgrade header = %q, want websocket", got) + } + if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") { + t.Fatalf("Connection header = %q, want Upgrade", got) + } + if got := captured.headers.Get(":protocol"); got != "" { + t.Fatalf(":protocol header should be removed, got %q", got) + } + if len(captured.body) != 0 { + t.Fatalf("upstream request body length = %d, want 0", len(captured.body)) + } + if !captured.extendedBodyPresent { + t.Fatal("extended_connect_websocket_body variable missing from request context") + } + if string(captured.extendedConnectBody) != payload { + t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload) + } + }) + } +} diff --git a/modules/caddyhttp/reverseproxy/metrics.go b/modules/caddyhttp/reverseproxy/metrics.go index 2488427304e..4b26d86419c 100644 --- a/modules/caddyhttp/reverseproxy/metrics.go +++ b/modules/caddyhttp/reverseproxy/metrics.go @@ -16,6 +16,10 @@ import ( var reverseProxyMetrics = struct { once sync.Once upstreamsHealthy *prometheus.GaugeVec + streamsActive *prometheus.GaugeVec + streamsTotal *prometheus.CounterVec + streamDuration *prometheus.HistogramVec + streamBytes *prometheus.CounterVec logger *zap.Logger }{} @@ -23,6 +27,8 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { const ns, sub = "caddy", "reverse_proxy" upstreamsLabels := []string{"upstream"} + streamResultLabels := []string{"upstream", "result"} + streamBytesLabels := []string{"upstream", "direction"} reverseProxyMetrics.once.Do(func() { reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: 
ns, @@ -30,6 +36,31 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { Name: "upstreams_healthy", Help: "Health status of reverse proxy upstreams.", }, upstreamsLabels) + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: ns, + Subsystem: sub, + Name: "streams_active", + Help: "Number of currently active upgraded reverse proxy streams.", + }, upstreamsLabels) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: ns, + Subsystem: sub, + Name: "streams_total", + Help: "Total number of upgraded reverse proxy streams by close result.", + }, streamResultLabels) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: ns, + Subsystem: sub, + Name: "stream_duration_seconds", + Help: "Duration of upgraded reverse proxy streams by close result.", + Buckets: prometheus.DefBuckets, + }, streamResultLabels) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: ns, + Subsystem: sub, + Name: "stream_bytes_total", + Help: "Total bytes proxied across upgraded reverse proxy streams.", + }, streamBytesLabels) }) // duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because @@ -42,10 +73,58 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) { }) { panic(err) } + if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamsActive, + NewCollector: reverseProxyMetrics.streamsActive, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamsTotal, + NewCollector: reverseProxyMetrics.streamsTotal, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamDuration, + NewCollector: reverseProxyMetrics.streamDuration, + }) { + panic(err) + } + if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil && + !errors.Is(err, prometheus.AlreadyRegisteredError{ + ExistingCollector: reverseProxyMetrics.streamBytes, + NewCollector: reverseProxyMetrics.streamBytes, + }) { + panic(err) + } reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics") } +func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) { + labels := prometheus.Labels{"upstream": upstream} + reverseProxyMetrics.streamsActive.With(labels).Inc() + + var once sync.Once + return func(result string, duration time.Duration, toBackend, fromBackend int64) { + once.Do(func() { + reverseProxyMetrics.streamsActive.With(labels).Dec() + reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc() + reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds()) + if toBackend > 0 { + reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend)) + } + if fromBackend > 0 { + reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend)) + } + }) + } +} + type metricsUpstreamsHealthyUpdater struct { handler *Handler } diff --git 
a/modules/caddyhttp/reverseproxy/metrics_test.go b/modules/caddyhttp/reverseproxy/metrics_test.go new file mode 100644 index 00000000000..edbe9ca8d76 --- /dev/null +++ b/modules/caddyhttp/reverseproxy/metrics_test.go @@ -0,0 +1,67 @@ +package reverseproxy + +import ( + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) { + const upstream = "127.0.0.1:7443" + + // Use fresh metric vectors for deterministic assertions in this unit test. + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"}) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"}) + + finish := trackActiveStream(upstream) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 { + t.Fatalf("active streams = %v, want 1", got) + } + + finish("closed", 150*time.Millisecond, 1234, 4321) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 { + t.Fatalf("active streams = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 { + t.Fatalf("streams_total closed = %v, want 1", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 { + t.Fatalf("bytes to_upstream = %v, want 1234", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 { + t.Fatalf("bytes from_upstream = %v, want 4321", got) + } + + // A second finish call should be ignored by the once guard. 
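+	// (trackActiveStream guards its finish callback with sync.Once, so this
+	// second call is a no-op and must not record an "error" result.)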
+ finish("error", 1*time.Second, 111, 222) + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 { + t.Fatalf("streams_total error = %v, want 0", got) + } +} + +func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) { + const upstream = "127.0.0.1:9000" + + reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"}) + reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"}) + reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"}) + + trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0) + + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 { + t.Fatalf("bytes to_upstream = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 { + t.Fatalf("bytes from_upstream = %v, want 0", got) + } + if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 { + t.Fatalf("streams_total timeout = %v, want 1", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 52d2b1ab30f..61f31b7657e 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -186,6 +186,22 @@ type Handler struct { // by the previous config closing. Default: no delay. StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"` + // If true, upgraded connections such as WebSockets are detached from + // the handler and retained across config reloads when their upstream + // still exists in the new config. Connections using upstreams that are + // removed are closed during cleanup. By default this is false, preserving + // legacy behavior where upgraded connections are closed on reload + // (optionally delayed by stream_close_delay). + // Only HTTP/1.1 WebSocket connections are affected; WebSockets over h2/h3 + // are not. When enabled, bytes transferred for HTTP/1.1 in the access + // logs will be zero, but those stats can be found in the stream logs for + // HTTP/1.1, h2, and h3 regardless of whether this is enabled. + StreamDetached bool `json:"stream_detached,omitempty"` + + // Controls logging behavior for upgraded stream lifecycle events. + // If omitted, defaults are used (level=DEBUG, logger_name="http.handlers.reverse_proxy.stream"). + StreamLogs *StreamLogs `json:"stream_logs,omitempty"` + // If configured, rewrites the copy of the upstream request. + // Allows changing the request method and URI (path and query). + // Since the rewrite is applied to the copy, it does not persist @@ -240,14 +256,16 @@ type Handler struct { // Holds the handle_response Caddyfile tokens while adapting handleResponseSegments []*caddyfile.Dispenser - // Stores upgraded requests (hijacked connections) for proper cleanup - connections map[io.ReadWriteCloser]openConnection - connectionsCloseTimer *time.Timer - connectionsMu *sync.Mutex + // Tracks hijacked/upgraded connections (WebSocket etc.) so they can be + // closed when their upstream is removed from the config. 
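+	// When stream_detached is enabled, the tracker is also added to the
+	// package-level detached tracker registry so its tunnels can be closed
+	// once their upstream is no longer referenced by any active config
+	// (see Provision and Cleanup).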
+ tunnelTracker *tunnelTracker ctx caddy.Context logger *zap.Logger events *caddyevents.App + + streamLogLevel zapcore.Level + streamLogLoggerName string } // CaddyModule returns the Caddy module information. @@ -267,8 +285,25 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.events = eventAppIface.(*caddyevents.App) h.ctx = ctx h.logger = ctx.Logger() - h.connections = make(map[io.ReadWriteCloser]openConnection) - h.connectionsMu = new(sync.Mutex) + h.tunnelTracker = newTunnelTracker(h.logger, time.Duration(h.StreamCloseDelay)) + h.streamLogLevel = defaultStreamLogLevel + h.streamLogLoggerName = defaultStreamLoggerName + if h.StreamLogs != nil { + if h.StreamLogs.Level != "" { + lvl, err := zapcore.ParseLevel(strings.ToLower(strings.TrimSpace(h.StreamLogs.Level))) + if err != nil { + return fmt.Errorf("invalid stream_logs.level %q: %w", h.StreamLogs.Level, err) + } + h.streamLogLevel = lvl + } + if name := strings.TrimSpace(h.StreamLogs.LoggerName); name != "" { + h.streamLogLoggerName = name + } + } + + if h.StreamDetached { + registerDetachedTunnelTrackers(h.tunnelTracker) + } // warn about unsafe buffering config if h.RequestBuffers == -1 || h.ResponseBuffers == -1 { @@ -437,15 +472,85 @@ func (h *Handler) Provision(ctx caddy.Context) error { return nil } +func (h Handler) streamLogsSkipHandshake() bool { + return h.StreamLogs != nil && h.StreamLogs.SkipHandshake +} + +func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger { + name := strings.TrimSpace(h.streamLogLoggerName) + if name == "" { + name = defaultStreamLoggerName + } + + if name == streamLoggerNameUseAccess { + logger := caddy.Log().Named(defaultAccessLoggerBase) + names := caddyhttp.GetVar(req.Context(), caddyhttp.AccessLoggerNameVarKey) + namesSlice, ok := names.([]any) + if !ok { + return logger + } + for _, v := range namesSlice { + name, ok := v.(string) + if !ok { + continue + } + if name == "" { + return logger + } + return logger.Named(name) + } + return logger + } + + return caddy.Log().Named(name) +} + +var ( + detachedTunnelTrackers = make(map[*tunnelTracker]struct{}) + detachedTunnelTrackersMu sync.Mutex +) + +func registerDetachedTunnelTrackers(ts *tunnelTracker) { + detachedTunnelTrackersMu.Lock() + defer detachedTunnelTrackersMu.Unlock() + detachedTunnelTrackers[ts] = struct{}{} +} + +func notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream string, self *tunnelTracker) error { + detachedTunnelTrackersMu.Lock() + defer detachedTunnelTrackersMu.Unlock() + + var err error + for tunnel := range detachedTunnelTrackers { + if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil { + err = closeErr + } + } + return err +} + +func unregisterDetachedTunnelTrackers(ts *tunnelTracker) { + detachedTunnelTrackersMu.Lock() + defer detachedTunnelTrackersMu.Unlock() + delete(detachedTunnelTrackers, ts) +} + // Cleanup cleans up the resources made by h. func (h *Handler) Cleanup() error { - err := h.cleanupConnections() - - // remove hosts from our config from the pool + // even if StreamDetached is true, extended connect websockets may still be running + err := h.tunnelTracker.cleanupAttachedConnections() for _, upstream := range h.Upstreams { - _, _ = hosts.Delete(upstream.String()) + // hosts.Delete returns deleted=true when the ref count reaches zero, + // meaning no other active config references this upstream. 
In that + // case close any tunnels proxying to it; otherwise let them survive + // to their natural end since the upstream is still in use. + deleted, _ := hosts.Delete(upstream.String()) + if deleted { + if closeErr := notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream.String(), h.tunnelTracker); closeErr != nil && err == nil { + err = closeErr + } + } } - return err } @@ -1127,10 +1232,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe // we use the original request here, so that any routes from 'next' // see the original request rather than the proxy cloned request. hrc := &handleResponseContext{ - handler: h, - response: res, - start: start, - logger: logger, + handler: h, + response: res, + start: start, + logger: logger, + upstreamAddr: di.Upstream.String(), } ctx := origReq.Context() ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc) @@ -1160,7 +1266,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe } // copy the response body and headers back to the upstream client - return h.finalizeResponse(rw, req, res, repl, start, logger) + return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String()) } // finalizeResponse prepares and copies the response. @@ -1171,12 +1277,11 @@ func (h *Handler) finalizeResponse( repl *caddy.Replacer, start time.Time, logger *zap.Logger, + upstreamAddr string, ) error { // deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { - var wg sync.WaitGroup - h.handleUpgradeResponse(logger, &wg, rw, req, res) - wg.Wait() + h.handleUpgradeResponse(logger, rw, req, res, upstreamAddr) return nil } @@ -1765,6 +1870,22 @@ func (brc bodyReadCloser) Close() error { return nil } +// StreamLogs controls logging for upgraded stream lifecycle events. +type StreamLogs struct { + // The minimum level at which stream lifecycle events are logged. + // Supported values are debug, info, warn, and error. Default: debug. + Level string `json:"level,omitempty"` + + // Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream". + // Special value "access" uses the access logger namespace and, if set, + // respects the first value in access_logger_names/log_name for the request. + LoggerName string `json:"logger_name,omitempty"` + + // If true, suppresses the access log entry normally emitted when an + // upgraded stream handshake completes and the request unwinds. + SkipHandshake bool `json:"skip_handshake,omitempty"` +} + // bufPool is used for buffering requests and responses. var bufPool = sync.Pool{ New: func() any { @@ -1797,6 +1918,9 @@ type handleResponseContext struct { // i.e. copied and closed, to make sure that it doesn't // happen twice. isFinalized bool + + // upstreamAddr is the selected upstream address for this request. + upstreamAddr string } // proxyHandleResponseContextCtxKey is the context key for the active proxy handler @@ -1807,6 +1931,13 @@ const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_resp // errNoUpstream occurs when there are no upstream available. 
 var errNoUpstream = fmt.Errorf("no upstreams available")
+const (
+	defaultStreamLogLevel     = zapcore.DebugLevel
+	defaultStreamLoggerName   = "http.handlers.reverse_proxy.stream"
+	streamLoggerNameUseAccess = "access"
+	defaultAccessLoggerBase   = "http.log.access"
+)
+
 // Interface guards
 var (
 	_ caddy.Provisioner = (*Handler)(nil)
diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go
index e454ee65547..a50e615e423 100644
--- a/modules/caddyhttp/reverseproxy/streaming.go
+++ b/modules/caddyhttp/reverseproxy/streaming.go
@@ -26,6 +26,7 @@ import (
 	"io"
 	weakrand "math/rand/v2"
 	"mime"
+	"net"
 	"net/http"
 	"sync"
 	"time"
@@ -35,15 +36,16 @@ import (
 	"go.uber.org/zap/zapcore"
 	"golang.org/x/net/http/httpguts"
 
+	"github.com/caddyserver/caddy/v2"
 	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
 )
 
-type h2ReadWriteCloser struct {
+type extendedConnectReadWriteCloser struct {
 	io.ReadCloser
 	http.ResponseWriter
 }
 
-func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
+func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
 	n, err = rwc.ResponseWriter.Write(p)
 	if err != nil {
 		return 0, err
@@ -57,7 +59,7 @@ func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
 	return n, nil
 }
 
-func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
+func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
 	reqUpType := upgradeType(req.Header)
 	resUpType := upgradeType(res.Header)
@@ -90,13 +92,37 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
 	copyHeader(rw.Header(), res.Header)
 	normalizeWebsocketHeaders(rw.Header())
+	// Capture all h fields needed by the tunnel now, so that the Handler (h)
+	// is not referenced after this function returns (for HTTP/1.1 hijacked
+	// connections the tunnel runs in a detached goroutine).
+	tunnel := h.tunnelTracker
+	bufferSize := h.StreamBufferSize
+	streamTimeout := time.Duration(h.StreamTimeout)
+
+	if h.StreamDetached {
+		// the return value should be true as it's not hijacked yet,
+		// but some middleware may wrap response writers incorrectly
+		if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) {
+			if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil {
+				c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if the connection is already hijacked"))
+			}
+		}
+	}
+
 	var (
-		conn io.ReadWriteCloser
-		brw  *bufio.ReadWriter
+		conn     io.ReadWriteCloser
+		brw      *bufio.ReadWriter
+		detached = h.StreamDetached
 	)
-	// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
-	// TODO: once we can reliably detect backend support this, it can be removed for those backends
+	// websocket over http2 or http3 if extended connect is enabled;
+	// assuming the backend doesn't support this, the request will be
+	// modified to an http1.1 upgrade
+	// TODO: once we can reliably detect whether the backend supports
+	// this, it can be removed for those backends
 	if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
+		// websocket over extended connect can't be detached.
rw and req.Body + // are only valid while the handler goroutine is running + detached = false req.Body = body rw.Header().Del("Upgrade") rw.Header().Del("Connection") @@ -104,18 +130,18 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw.WriteHeader(http.StatusOK) if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil { - c.Write(zap.Int("http_version", 2)) + c.Write(zap.Int("http_version", req.ProtoMajor)) } //nolint:bodyclose flushErr := http.NewResponseController(rw).Flush() if flushErr != nil { - if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil { + if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil { c.Write(zap.Error(flushErr)) } return } - conn = h2ReadWriteCloser{req.Body, rw} + conn = extendedConnectReadWriteCloser{req.Body, rw} // bufio is not needed, use minimal buffer brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) } else { @@ -143,27 +169,6 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5 - backConnCloseCh := make(chan struct{}) - go func() { - // Ensure that the cancellation of a request closes the backend. - // See issue https://golang.org/issue/35559. - select { - case <-req.Context().Done(): - case <-backConnCloseCh: - } - backConn.Close() - }() - defer close(backConnCloseCh) - - start := time.Now() - defer func() { - conn.Close() - if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil { - c.Write(zap.Duration("duration", time.Since(start))) - } - }() - if err := brw.Flush(); err != nil { if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil { c.Write(zap.Error(err)) @@ -184,13 +189,12 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } } - // Ensure the hijacked client connection, and the new connection established - // with the backend, are both closed in the event of a server shutdown. This - // is done by registering them. We also try to gracefully close connections - // we recognize as websockets. - // We need to make sure the client connection messages (i.e. to upstream) - // are masked, so we need to know whether the connection is considered the - // server or the client side of the proxy. + // Register both connections with the tunnel tracker. We also try to + // gracefully close connections we recognize as websockets. We need to make + // sure the client connection messages (i.e. to upstream) are masked, so we + // need to know whether the connection is considered the server or the + // client side of the proxy. Note that gracefulClose must not capture h, + // since the tunnel may outlive the handler instance. 
gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error { if isWebsocket(req) { return func() error { @@ -199,43 +203,147 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, } return nil } - deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false)) - deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true)) - defer deleteFrontConn() + deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), detached, upstreamAddr) + deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), detached, upstreamAddr) + if h.streamLogsSkipHandshake() { + caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true) + } + repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + repl.Set("http.reverse_proxy.upgraded", true) + streamUUID, _ := repl.GetString("http.request.uuid") + streamFields := makeStreamLogFields(streamUUID) + streamLogger := h.streamLoggerForRequest(req) + streamLevel := h.streamLogLevel + finishMetrics := trackActiveStream(upstreamAddr) + + start := time.Now() + + if !detached { + handleUpgradeTunnel( + streamLogger, + streamLevel, + conn, + backConn, + deleteFrontConn, + deleteBackConn, + bufferSize, + streamTimeout, + start, + finishMetrics, + streamFields, + ) + } else { + // start a new goroutine + go handleUpgradeTunnel( + streamLogger, + streamLevel, + conn, + backConn, + deleteFrontConn, + deleteBackConn, + bufferSize, + streamTimeout, + start, + finishMetrics, + streamFields, + ) + } +} + +// handleUpgradeTunnel returns when transfer is done. +func handleUpgradeTunnel( + streamLogger *zap.Logger, + streamLevel zapcore.Level, + conn io.ReadWriteCloser, + backConn io.ReadWriteCloser, + deleteFrontConn func(), + deleteBackConn func(), + bufferSize int, + streamTimeout time.Duration, + start time.Time, + finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64), + streamFields []zap.Field, +) { defer deleteBackConn() + defer deleteFrontConn() + var ( + wg sync.WaitGroup + toBackend int64 + fromBackend int64 + result string + ) + // when a stream timeout is encountered, no error will be read from errc + // a buffer size of 2 will allow both the read and write goroutines to + // send the error and exit + // see: https://github.com/caddyserver/caddy/issues/7418 + errc := make(chan error, 2) spc := switchProtocolCopier{ user: conn, backend: backConn, - wg: wg, - bufferSize: h.StreamBufferSize, + wg: &wg, + bufferSize: bufferSize, + sent: &toBackend, + received: &fromBackend, } + wg.Add(2) - // setup the timeout if requested var timeoutc <-chan time.Time - if h.StreamTimeout > 0 { - timer := time.NewTimer(time.Duration(h.StreamTimeout)) + if streamTimeout > 0 { + timer := time.NewTimer(streamTimeout) defer timer.Stop() timeoutc = timer.C } - // when a stream timeout is encountered, no error will be read from errc - // a buffer size of 2 will allow both the read and write goroutines to send the error and exit - // see: https://github.com/caddyserver/caddy/issues/7418 - errc := make(chan error, 2) - wg.Add(2) go spc.copyToBackend(errc) go spc.copyFromBackend(errc) select { case err := <-errc: - if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil { + result = classifyStreamResult(err) + if c := streamLogger.Check(streamLevel, "streaming error"); c != nil { c.Write(zap.Error(err)) } - case time := <-timeoutc: - if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil { - 
c.Write(zap.Time("timeout", time)) + case t := <-timeoutc: + result = "timeout" + if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil { + c.Write(zap.Time("timeout", t)) } } + + // Close both ends to unblock the still-running copy goroutine, + // then wait for it so byte counts are final before metrics/logging. + conn.Close() + backConn.Close() + wg.Wait() + + finishMetrics(result, time.Since(start), toBackend, fromBackend) + if c := streamLogger.Check(streamLevel, "connection closed"); c != nil { + fields := append([]zap.Field{}, streamFields...) + fields = append(fields, + zap.Duration("duration", time.Since(start)), + zap.Int64("bytes_to_backend", toBackend), + zap.Int64("bytes_from_backend", fromBackend), + ) + c.Write(fields...) + } +} + +func classifyStreamResult(err error) string { + if err == nil || + errors.Is(err, io.EOF) || + errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) { + return "closed" + } + return "error" +} + +func makeStreamLogFields(streamUUID string) []zap.Field { + fields := make([]zap.Field, 0, 1) + if streamUUID != "" { + fields = append(fields, zap.String("uuid", streamUUID)) + } + return fields } // flushInterval returns the p.FlushInterval value, conditionally @@ -375,75 +483,101 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za } } -// registerConnection holds onto conn so it can be closed in the event -// of a server shutdown. This is useful because hijacked connections or -// connections dialed to backends don't close when server is shut down. -// The caller should call the returned delete() function when the -// connection is done to remove it from memory. -func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) { - h.connectionsMu.Lock() - h.connections[conn] = openConnection{conn, gracefulClose} - h.connectionsMu.Unlock() +// openConnection maps an open connection to an optional function for graceful +// close and records which upstream address the connection is proxying to. +// Also tracks whether the connection is detached, which means it should only be +// closed when the upstream is removed from the config, not on every reload. +type openConnection struct { + conn io.ReadWriteCloser + gracefulClose func() error + detached bool + upstream string +} + +// tunnelTracker tracks hijacked/upgraded connections for selective cleanup. +// This exists to detach the lifecycle of streaming connections from the proxy +// Handler and config, since we typically want them to survive past config reloads. +// It also allows for selective connection cleanup based on their attachment status. +type tunnelTracker struct { + connections map[io.ReadWriteCloser]openConnection + closeTimer *time.Timer + closeDelay time.Duration + stopped bool + mu sync.Mutex + logger *zap.Logger +} + +func newTunnelTracker(logger *zap.Logger, closeDelay time.Duration) *tunnelTracker { + return &tunnelTracker{ + connections: make(map[io.ReadWriteCloser]openConnection), + closeDelay: closeDelay, + logger: logger, + } +} + +// registerConnection stores conn in the tracking map. The caller must invoke +// the returned del func when the connection is done. 
+func (ts *tunnelTracker) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) { + ts.mu.Lock() + ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream} + ts.mu.Unlock() return func() { - h.connectionsMu.Lock() - delete(h.connections, conn) - // if there is no connection left before the connections close timer fires - if len(h.connections) == 0 && h.connectionsCloseTimer != nil { - // we release the timer that holds the reference to Handler - if (*h.connectionsCloseTimer).Stop() { - h.logger.Debug("stopped streaming connections close timer - all connections are already closed") + ts.mu.Lock() + delete(ts.connections, conn) + if len(ts.connections) == 0 && ts.stopped { + unregisterDetachedTunnelTrackers(ts) + if ts.closeTimer != nil { + if ts.closeTimer.Stop() { + ts.logger.Debug("stopped streaming connections close timer - all connections are already closed") + } + ts.closeTimer = nil } - h.connectionsCloseTimer = nil } - h.connectionsMu.Unlock() + ts.mu.Unlock() } } -// closeConnections immediately closes all hijacked connections (both to client and backend). -func (h *Handler) closeConnections() error { +// closeAttachedConnections closes all tracked attached connections. +func (ts *tunnelTracker) closeAttachedConnections() error { var err error - h.connectionsMu.Lock() - defer h.connectionsMu.Unlock() - - for _, oc := range h.connections { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.stopped = true + for _, oc := range ts.connections { + // detached connections are only closed when the upstream is gone from the config + if oc.detached { + continue + } if oc.gracefulClose != nil { - // this is potentially blocking while we have the lock on the connections - // map, but that should be OK since the server has in theory shut down - // and we are no longer using the connections map - gracefulErr := oc.gracefulClose() - if gracefulErr != nil && err == nil { + if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { err = gracefulErr } } - closeErr := oc.conn.Close() - if closeErr != nil && err == nil { + if closeErr := oc.conn.Close(); closeErr != nil && err == nil { err = closeErr } } return err } -// cleanupConnections closes hijacked connections. -// Depending on the value of StreamCloseDelay it does that either immediately -// or sets up a timer that will do that later. -func (h *Handler) cleanupConnections() error { - if h.StreamCloseDelay == 0 { - return h.closeConnections() - } - - h.connectionsMu.Lock() - defer h.connectionsMu.Unlock() - // the handler is shut down, no new connection can appear, - // so we can skip setting up the timer when there are no connections - if len(h.connections) > 0 { - delay := time.Duration(h.StreamCloseDelay) - h.connectionsCloseTimer = time.AfterFunc(delay, func() { - if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { +// cleanupAttachedConnections closes upgraded attached connections. +// Depending on closeDelay it does that either immediately or after a timer. 
+func (ts *tunnelTracker) cleanupAttachedConnections() error { + if ts.closeDelay == 0 { + return ts.closeAttachedConnections() + } + + ts.mu.Lock() + defer ts.mu.Unlock() + if len(ts.connections) > 0 { + delay := ts.closeDelay + ts.closeTimer = time.AfterFunc(delay, func() { + if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil { c.Write(zap.Duration("delay", delay)) } - err := h.closeConnections() + err := ts.closeAttachedConnections() if err != nil { - if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil { + if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil { c.Write( zap.Error(err), zap.Duration("delay", delay), @@ -567,11 +701,29 @@ func isWebsocket(r *http.Request) bool { httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket") } -// openConnection maps an open connection to -// an optional function for graceful close. -type openConnection struct { - conn io.ReadWriteCloser - gracefulClose func() error +// closeConnectionsForUpstream closes all tracked connections that were +// established to the given upstream address. +func (ts *tunnelTracker) closeConnectionsForUpstream(addr string) error { + var err error + ts.mu.Lock() + defer ts.mu.Unlock() + if !ts.stopped { + return nil + } + for _, oc := range ts.connections { + if oc.upstream != addr { + continue + } + if oc.gracefulClose != nil { + if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil { + err = gracefulErr + } + } + if closeErr := oc.conn.Close(); closeErr != nil && err == nil { + err = closeErr + } + } + return err } type maxLatencyWriter struct { @@ -642,16 +794,23 @@ type switchProtocolCopier struct { user, backend io.ReadWriteCloser wg *sync.WaitGroup bufferSize int + // sent and received accumulate byte counts for each direction. + // They are written before wg.Done() and read after wg.Wait(), so no + // additional synchronization is needed beyond the WaitGroup barrier. 
+ sent *int64 // bytes copied to backend; must be non-nil + received *int64 // bytes copied from backend; must be non-nil } func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { - _, err := io.CopyBuffer(c.user, c.backend, c.buffer()) + n, err := io.CopyBuffer(c.user, c.backend, c.buffer()) + *c.received = n errc <- err c.wg.Done() } func (c switchProtocolCopier) copyToBackend(errc chan<- error) { - _, err := io.CopyBuffer(c.backend, c.user, c.buffer()) + n, err := io.CopyBuffer(c.backend, c.user, c.buffer()) + *c.sent = n errc <- err c.wg.Done() } diff --git a/modules/caddyhttp/reverseproxy/streaming_test.go b/modules/caddyhttp/reverseproxy/streaming_test.go index ce0db65a06c..7dc5e476cf3 100644 --- a/modules/caddyhttp/reverseproxy/streaming_test.go +++ b/modules/caddyhttp/reverseproxy/streaming_test.go @@ -7,8 +7,10 @@ import ( "strings" "sync" "testing" + "time" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" ) func TestHandlerCopyResponse(t *testing.T) { @@ -41,12 +43,15 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) { var wg sync.WaitGroup var errc = make(chan error, 1) var dst bytes.Buffer + var sent, received int64 copier := switchProtocolCopier{ user: nopReadWriteCloser{Reader: strings.NewReader("hello")}, backend: nopReadWriteCloser{Writer: &dst}, wg: &wg, bufferSize: 7, + sent: &sent, + received: &received, } buf := copier.buffer() @@ -80,3 +85,146 @@ type nopReadWriteCloser struct { } func (nopReadWriteCloser) Close() error { return nil } + +type trackingReadWriteCloser struct { + closed chan struct{} + one sync.Once +} + +func newTrackingReadWriteCloser() *trackingReadWriteCloser { + return &trackingReadWriteCloser{closed: make(chan struct{})} +} + +func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil } +func (c *trackingReadWriteCloser) Close() error { + c.one.Do(func() { + close(c.closed) + }) + return nil +} + +func (c *trackingReadWriteCloser) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) { + ts := newTunnelTracker(caddy.Log(), 0) + connA := newTrackingReadWriteCloser() + connB := newTrackingReadWriteCloser() + ts.registerConnection(connA, nil, false, "a") + ts.registerConnection(connB, nil, false, "b") + + h := &Handler{ + tunnelTracker: ts, + StreamDetached: false, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + if !connA.isClosed() || !connB.isClosed() { + t.Fatalf("legacy cleanup should close all upgraded connections") + } +} + +func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) { + ts := newTunnelTracker(caddy.Log(), 40*time.Millisecond) + conn := newTrackingReadWriteCloser() + ts.registerConnection(conn, nil, false, "a") + + h := &Handler{ + tunnelTracker: ts, + StreamDetached: false, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + if conn.isClosed() { + t.Fatal("connection should not close immediately when stream_close_delay is set") + } + + select { + case <-conn.closed: + case <-time.After(500 * time.Millisecond): + t.Fatal("connection did not close after stream_close_delay elapsed") + } +} + +func TestHandlerCleanupDetachedModeClosesOnlyRemovedUpstreams(t *testing.T) { + const upstreamA = "upstream-a" + const upstreamB = "upstream-b" + + // Simulate old+new configs both referencing 
upstreamA (refcount 2), + // while upstreamB is only referenced by the old config (refcount 1). + hosts.LoadOrStore(upstreamA, struct{}{}) + hosts.LoadOrStore(upstreamA, struct{}{}) + hosts.LoadOrStore(upstreamB, struct{}{}) + t.Cleanup(func() { + _, _ = hosts.Delete(upstreamA) + _, _ = hosts.Delete(upstreamA) + _, _ = hosts.Delete(upstreamB) + }) + + ts := newTunnelTracker(caddy.Log(), 0) + registerDetachedTunnelTrackers(ts) + connA := newTrackingReadWriteCloser() + connB := newTrackingReadWriteCloser() + ts.registerConnection(connA, nil, true, upstreamA) + ts.registerConnection(connB, nil, true, upstreamB) + + h := &Handler{ + tunnelTracker: ts, + StreamDetached: true, + Upstreams: UpstreamPool{ + &Upstream{Dial: upstreamA}, + &Upstream{Dial: upstreamB}, + }, + } + + if err := h.Cleanup(); err != nil { + t.Fatalf("cleanup failed: %v", err) + } + + if connA.isClosed() { + t.Fatal("connection for detached upstream should remain open") + } + if !connB.isClosed() { + t.Fatal("connection for removed upstream should be closed") + } +} + +func TestHandlerUnmarshalCaddyfileStreamLogsBlock(t *testing.T) { + d := caddyfile.NewTestDispenser(` + reverse_proxy localhost:9000 { + stream_logs { + level info + logger_name access + skip_handshake + } + } + `) + + var h Handler + if err := h.UnmarshalCaddyfile(d); err != nil { + t.Fatalf("UnmarshalCaddyfile() error = %v", err) + } + if h.StreamLogs == nil { + t.Fatal("expected stream_logs to be configured") + } + if h.StreamLogs.Level != "info" { + t.Fatalf("expected stream_logs.level=info, got %q", h.StreamLogs.Level) + } + if h.StreamLogs.LoggerName != "access" { + t.Fatalf("expected stream_logs.logger_name=access, got %q", h.StreamLogs.LoggerName) + } + if !h.StreamLogs.SkipHandshake { + t.Fatal("expected stream_logs.skip_handshake=true") + } +}
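
A minimal Caddyfile sketch of the stream_logs options exercised by TestHandlerUnmarshalCaddyfileStreamLogsBlock above, assuming a hypothetical site block; the site address and the localhost:9000 upstream are illustrative only and not taken from this diff:

example.com {
	reverse_proxy localhost:9000 {
		stream_logs {
			level info
			logger_name access
			skip_handshake
		}
	}
}

Per the StreamLogs field docs and the constants added in this change, logger_name access routes stream lifecycle entries to the access logger namespace (http.log.access) instead of the default http.handlers.reverse_proxy.stream logger, level raises the minimum level from its debug default, and skip_handshake suppresses the access log entry normally emitted when the upgrade handshake completes.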