Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions modules/caddyhttp/reverseproxy/active_health_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package reverseproxy

import (
"testing"
)

func newTestUpstream() *Upstream {
return &Upstream{
Host: new(Host),
activeHealthStats: &ActiveHealthStats{},
}
}

// TestConsecutiveCounterResetOnPass verifies that a health check pass
// resets the consecutive failure counter to zero. Without this, non-
// consecutive failures could accumulate and incorrectly trip the threshold.
func TestConsecutiveCounterResetOnPass(t *testing.T) {
upstream := newTestUpstream()

// Simulate: fail, fail, then pass
upstream.countHealthFail(1)
upstream.countHealthFail(1)
if upstream.activeHealthFails() != 2 {
t.Fatalf("expected 2 fails, got %d", upstream.activeHealthFails())
}

// A pass should reset the fail counter
upstream.countHealthPass(1)
if upstream.activeHealthFails() != 0 {
t.Errorf("expected fail counter to reset to 0 after a pass, got %d", upstream.activeHealthFails())
}
if upstream.activeHealthPasses() != 1 {
t.Errorf("expected 1 pass, got %d", upstream.activeHealthPasses())
}
}

// TestConsecutiveCounterResetOnFail verifies that a health check failure
// resets the consecutive pass counter to zero.
func TestConsecutiveCounterResetOnFail(t *testing.T) {
upstream := newTestUpstream()

// Simulate: pass, pass, then fail
upstream.countHealthPass(1)
upstream.countHealthPass(1)
if upstream.activeHealthPasses() != 2 {
t.Fatalf("expected 2 passes, got %d", upstream.activeHealthPasses())
}

// A fail should reset the pass counter
upstream.countHealthFail(1)
if upstream.activeHealthPasses() != 0 {
t.Errorf("expected pass counter to reset to 0 after a fail, got %d", upstream.activeHealthPasses())
}
if upstream.activeHealthFails() != 1 {
t.Errorf("expected 1 fail, got %d", upstream.activeHealthFails())
}
}

// TestNonConsecutiveFailuresDoNotTripThreshold is a regression test:
// interleaved pass/fail results must NOT accumulate toward the threshold.
// Before the fix, fail-pass-fail-pass-fail would reach Fails=3 even
// though there were zero consecutive failures.
func TestNonConsecutiveFailuresDoNotTripThreshold(t *testing.T) {
upstream := newTestUpstream()

// Interleave: fail, pass, fail, pass, fail
for i := 0; i < 3; i++ {
upstream.countHealthFail(1)
if i < 2 {
upstream.countHealthPass(1)
}
}

// With correct consecutive tracking, we should have only 1 consecutive fail
if upstream.activeHealthFails() != 1 {
t.Errorf("expected 1 consecutive fail, got %d", upstream.activeHealthFails())
}
}

// TestConsecutiveFailuresDoTripThreshold verifies that truly consecutive
// failures correctly accumulate and trip the threshold.
func TestConsecutiveFailuresDoTripThreshold(t *testing.T) {
upstream := newTestUpstream()

const failThreshold = 3

upstream.countHealthFail(1)
upstream.countHealthFail(1)
upstream.countHealthFail(1)

if upstream.activeHealthFails() != 3 {
t.Errorf("expected 3 consecutive fails, got %d", upstream.activeHealthFails())
}
if upstream.activeHealthFails() < failThreshold {
t.Error("3 consecutive failures should trip threshold of 3")
}
// Pass counter should be 0 (reset by the first fail)
if upstream.activeHealthPasses() != 0 {
t.Errorf("expected 0 passes after consecutive fails, got %d", upstream.activeHealthPasses())
}
}

// TestInitiallyUnhealthy verifies that when InitiallyUnhealthy is true
// and there are no prior health check passes, the upstream starts unhealthy.
func TestInitiallyUnhealthy(t *testing.T) {
upstream := &Upstream{
Dial: "10.4.0.1:80",
Host: new(Host),
activeHealthStats: &ActiveHealthStats{},
}

// Simulate what Provision does when InitiallyUnhealthy=true and
// passes=0 (fresh host, no prior health checks)
passes := 1 // default Passes threshold
upstream.setHealthy(upstream.activeHealthPasses() >= passes)

if upstream.healthy() {
t.Error("upstream should be unhealthy when InitiallyUnhealthy=true and no passes recorded")
}
}

// TestInitiallyUnhealthyWithPriorPasses verifies that when InitiallyUnhealthy
// is true but the host already has enough passes (e.g., across a reload),
// it starts healthy.
func TestInitiallyUnhealthyWithPriorPasses(t *testing.T) {
stats := &ActiveHealthStats{}
upstream := &Upstream{
Dial: "10.4.0.2:80",
Host: new(Host),
activeHealthStats: stats,
}
upstream.countHealthPass(1) // simulate a prior health check pass

passes := 1
upstream.setHealthy(upstream.activeHealthPasses() >= passes)

if !upstream.healthy() {
t.Error("upstream should be healthy when it has enough prior passes, even with InitiallyUnhealthy=true")
}
}

// TestInitiallyHealthyDefault verifies the default behavior: upstreams
// start healthy unless they have accumulated enough failures.
func TestInitiallyHealthyDefault(t *testing.T) {
upstream := &Upstream{
Dial: "10.4.0.3:80",
Host: new(Host),
activeHealthStats: &ActiveHealthStats{},
}

// Default behavior: healthy unless fails >= threshold
fails := 1
upstream.setHealthy(upstream.activeHealthFails() < fails)

if !upstream.healthy() {
t.Error("upstream should be healthy by default when no failures recorded")
}
}

// TestInitiallyHealthyDefaultWithPriorFails verifies that an upstream
// with prior failures (e.g., from before a reload) starts unhealthy.
func TestInitiallyHealthyDefaultWithPriorFails(t *testing.T) {
upstream := &Upstream{
Dial: "10.4.0.4:80",
Host: new(Host),
activeHealthStats: &ActiveHealthStats{},
}
upstream.countHealthFail(1) // simulate a prior failure

fails := 1
upstream.setHealthy(upstream.activeHealthFails() < fails)

if upstream.healthy() {
t.Error("upstream should be unhealthy when it has prior failures >= threshold")
}
}
12 changes: 12 additions & 0 deletions modules/caddyhttp/reverseproxy/caddyfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,18 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
h.HealthChecks.Active.FollowRedirects = true

case "health_initially_unhealthy":
if d.NextArg() {
return d.ArgErr()
}
if h.HealthChecks == nil {
h.HealthChecks = new(HealthChecks)
}
if h.HealthChecks.Active == nil {
h.HealthChecks.Active = new(ActiveHealthChecks)
}
h.HealthChecks.Active.InitiallyUnhealthy = true

case "health_passes":
if !d.NextArg() {
return d.ArgErr()
Expand Down
62 changes: 35 additions & 27 deletions modules/caddyhttp/reverseproxy/healthchecks.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package reverseproxy

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -127,7 +128,10 @@ type ActiveHealthChecks struct {
// body of a healthy backend.
ExpectBody string `json:"expect_body,omitempty"`

uri *url.URL
// Whether backends are initially considered unhealthy.
InitiallyUnhealthy bool `json:"initially_unhealthy,omitempty"`

uri url.URL
httpClient *http.Client
bodyRegexp *regexp.Regexp
logger *zap.Logger
Expand Down Expand Up @@ -163,15 +167,16 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error {

if a.Path != "" {
a.logger.Warn("the 'path' option is deprecated, please use 'uri' instead!")
a.uri.Path = a.Path
}

// parse the URI string (supports path and query)
// parse the URI string (supports path and query) and takes precedence over the deprecated Path field
if a.URI != "" {
parsedURI, err := url.Parse(a.URI)
if err != nil {
return err
}
a.uri = parsedURI
a.uri = *parsedURI
}

a.httpClient = &http.Client{
Expand All @@ -185,7 +190,22 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error {
},
}

if a.Passes < 1 {
a.Passes = 1
}

if a.Fails < 1 {
a.Fails = 1
}

for _, upstream := range h.Upstreams {
upstream.provisionActiveHealthStats(a.uri.String())
if a.InitiallyUnhealthy {
upstream.setHealthy(upstream.activeHealthPasses() >= a.Passes)
} else {
upstream.setHealthy(upstream.activeHealthFails() < a.Fails)
}

// if there's an alternative upstream for health-check provided in the config,
// then use it, otherwise use the upstream's dial address. if upstream is used,
// then the port is ignored.
Expand All @@ -210,14 +230,6 @@ func (a *ActiveHealthChecks) Provision(ctx caddy.Context, h *Handler) error {
}
}

if a.Passes < 1 {
a.Passes = 1
}

if a.Fails < 1 {
a.Fails = 1
}

return nil
}

Expand Down Expand Up @@ -391,8 +403,10 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networkAddr string, upstream *Upstream) error {
// create the URL for the request that acts as a health check
u := &url.URL{
Scheme: "http",
Host: hostAddr,
Scheme: "http",
Host: hostAddr,
Path: h.HealthChecks.Active.uri.Path,
RawQuery: h.HealthChecks.Active.uri.RawQuery,
}

// split the host and port if possible, override the port if configured
Expand All @@ -415,15 +429,6 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
hcsot.OverrideHealthCheckScheme(u, port)
}

// if we have a provisioned uri, use that, otherwise use
// the deprecated Path option
if h.HealthChecks.Active.uri != nil {
u.Path = h.HealthChecks.Active.uri.Path
u.RawQuery = h.HealthChecks.Active.uri.RawQuery
} else {
u.Path = h.HealthChecks.Active.Path
}

// replacer used for both body and headers. Only globals (env vars, system info, etc.) are available
repl := caddy.NewReplacer()

Expand Down Expand Up @@ -463,7 +468,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ

markUnhealthy := func() {
// increment failures and then check if it has reached the threshold to mark unhealthy
err := upstream.Host.countHealthFail(1)
err := upstream.countHealthFail(1)
if err != nil {
if c := h.HealthChecks.Active.logger.Check(zapcore.ErrorLevel, "could not count active health failure"); c != nil {
c.Write(
Expand All @@ -473,11 +478,10 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
}
return
}
if upstream.Host.activeHealthFails() >= h.HealthChecks.Active.Fails {
if upstream.activeHealthFails() >= h.HealthChecks.Active.Fails {
// dispatch an event that the host newly became unhealthy
if upstream.setHealthy(false) {
h.events.Emit(h.ctx, "unhealthy", map[string]any{"host": hostAddr})
upstream.Host.resetHealth()
}
}
}
Expand All @@ -494,20 +498,24 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, networ
}
return
}
if upstream.Host.activeHealthPasses() >= h.HealthChecks.Active.Passes {
if upstream.activeHealthPasses() >= h.HealthChecks.Active.Passes {
if upstream.setHealthy(true) {
if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "host is up"); c != nil {
c.Write(zap.String("host", hostAddr))
}
h.events.Emit(h.ctx, "healthy", map[string]any{"host": hostAddr})
upstream.Host.resetHealth()
}
}
}

// do the request, being careful to tame the response body
resp, err := h.HealthChecks.Active.httpClient.Do(req) //nolint:gosec // no SSRF
if err != nil {
if errors.Is(err, context.Canceled) {
// context was canceled, so don't count this as a failure
return nil
}

if c := h.HealthChecks.Active.logger.Check(zapcore.InfoLevel, "HTTP request failed"); c != nil {
c.Write(
zap.String("host", hostAddr),
Expand Down
Loading