Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions pkg/graceful/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,38 @@ import (
func Serve(server *http.Server, shutdownDuration time.Duration) {
go listenAndServe(server)
log.Info().Msg("Started HTTP server")
serveGracefully(server, shutdownDuration)
}

func listenAndServe(server *http.Server) {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Err(err).Msg("Failed to serve")
}
}

// ServeTLS serves an HTTPS handler with graceful shutdown of connections on
// SIGINT and SIGTERM. The server's TLSConfig must have at least one certificate
// configured.
func ServeTLS(server *http.Server, shutdownDuration time.Duration) {
if server.TLSConfig == nil || len(server.TLSConfig.Certificates) == 0 {
panic("ServeTLS requires TLSConfig with at least one certificate")
}
go listenAndServeTLS(server)
log.Info().Msg("Started HTTPS server")
serveGracefully(server, shutdownDuration)
Comment thread
milanmlft marked this conversation as resolved.
}

func listenAndServeTLS(server *http.Server) {
// cert/key are in server.TLSConfig, so pass empty strings
if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
log.Err(err).Msg("Failed to serve TLS")
}
}

func serveGracefully(server *http.Server, shutdownDuration time.Duration) {
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
Comment thread
milanmlft marked this conversation as resolved.
defer signal.Stop(signalChan)
<-signalChan
log.Info().Msg("Received termination signal")

Expand All @@ -30,9 +59,3 @@ func Serve(server *http.Server, shutdownDuration time.Duration) {
}
log.Info().Msg("Server exited")
}

func listenAndServe(server *http.Server) {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Err(err).Msg("Failed to serve")
}
}
62 changes: 62 additions & 0 deletions pkg/graceful/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@ package graceful

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math/big"
"net/http"
"os"
"regexp"
Expand All @@ -17,6 +25,8 @@ import (

func TestServeLogStream(t *testing.T) {
logBuffer := &bytes.Buffer{}
original := log.Logger
t.Cleanup(func() { log.Logger = original })
log.Logger = zerolog.New(logBuffer)

server := http.Server{
Expand Down Expand Up @@ -44,6 +54,58 @@ func TestServeLogStream(t *testing.T) {
assert.Regexp(t, regexp.MustCompile(`Started.*\n.*Received termination signal`), logStream)
}

func TestServeTLSLogStream(t *testing.T) {
logBuffer := &bytes.Buffer{}
original := log.Logger
t.Cleanup(func() { log.Logger = original })
log.Logger = zerolog.New(logBuffer)

Comment thread
milanmlft marked this conversation as resolved.
key := must(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
}
certDER := must(x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key))
keyDER := must(x509.MarshalECPrivateKey(key))
tlsCert := must(tls.X509KeyPair(
pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}),
pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}),
))

server := http.Server{
Handler: http.NewServeMux(),
Addr: "127.0.0.1:8443",
TLSConfig: &tls.Config{Certificates: []tls.Certificate{tlsCert}},
}
Comment thread
milanmlft marked this conversation as resolved.
go ServeTLS(&server, 10*time.Millisecond)
time.Sleep(100 * time.Millisecond) // some startup time

process := must(os.FindProcess(os.Getpid()))
err := process.Signal(syscall.SIGINT)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond) // some shutdown time > shutdown duration

Comment thread
milanmlft marked this conversation as resolved.
logStream := string(must(io.ReadAll(logBuffer)))
expectedLines := []string{
"Started HTTPS server",
"Received termination signal",
"Closing server",
"Server exited",
}
for _, expectedLine := range expectedLines {
assert.Contains(t, logStream, expectedLine)
}
assert.Regexp(t, regexp.MustCompile(`Started.*\n.*Received termination signal`), logStream)
}

func TestServeTLSPanicsWithoutTLSConfig(t *testing.T) {
server := http.Server{
Handler: http.NewServeMux(),
Addr: "127.0.0.1:8443",
}
assert.Panics(t, func() { ServeTLS(&server, 10*time.Millisecond) })
}

func must[T any](obj T, err error) T {
if err != nil {
panic(err)
Expand Down
Loading