diff --git a/pkg/graceful/serve.go b/pkg/graceful/serve.go index 2120b54..8dec5e5 100644 --- a/pkg/graceful/serve.go +++ b/pkg/graceful/serve.go @@ -16,9 +16,38 @@ import ( func Serve(server *http.Server, shutdownDuration time.Duration) { go listenAndServe(server) log.Info().Msg("Started HTTP server") + serveGracefully(server, shutdownDuration) +} + +func listenAndServe(server *http.Server) { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Err(err).Msg("Failed to serve") + } +} + +// ServeTLS serves an HTTPS handler with graceful shutdown of connections on +// SIGINT and SIGTERM. The server's TLSConfig must have at least one certificate +// configured. +func ServeTLS(server *http.Server, shutdownDuration time.Duration) { + if server.TLSConfig == nil || len(server.TLSConfig.Certificates) == 0 { + panic("ServeTLS requires TLSConfig with at least one certificate") + } + go listenAndServeTLS(server) + log.Info().Msg("Started HTTPS server") + serveGracefully(server, shutdownDuration) +} +func listenAndServeTLS(server *http.Server) { + // cert/key are in server.TLSConfig, so pass empty strings + if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + log.Err(err).Msg("Failed to serve TLS") + } +} + +func serveGracefully(server *http.Server, shutdownDuration time.Duration) { signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(signalChan) <-signalChan log.Info().Msg("Received termination signal") @@ -30,9 +59,3 @@ func Serve(server *http.Server, shutdownDuration time.Duration) { } log.Info().Msg("Server exited") } - -func listenAndServe(server *http.Server) { - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Err(err).Msg("Failed to serve") - } -} diff --git a/pkg/graceful/serve_test.go b/pkg/graceful/serve_test.go index bd4b191..ba40105 100644 --- a/pkg/graceful/serve_test.go +++ b/pkg/graceful/serve_test.go @@ -2,7 +2,15 @@ package graceful import ( "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "io" + "math/big" "net/http" "os" "regexp" @@ -17,6 +25,8 @@ import ( func TestServeLogStream(t *testing.T) { logBuffer := &bytes.Buffer{} + original := log.Logger + t.Cleanup(func() { log.Logger = original }) log.Logger = zerolog.New(logBuffer) server := http.Server{ @@ -44,6 +54,58 @@ func TestServeLogStream(t *testing.T) { assert.Regexp(t, regexp.MustCompile(`Started.*\n.*Received termination signal`), logStream) } +func TestServeTLSLogStream(t *testing.T) { + logBuffer := &bytes.Buffer{} + original := log.Logger + t.Cleanup(func() { log.Logger = original }) + log.Logger = zerolog.New(logBuffer) + + key := must(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, + } + certDER := must(x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)) + keyDER := must(x509.MarshalECPrivateKey(key)) + tlsCert := must(tls.X509KeyPair( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), + pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}), + )) + + server := http.Server{ + Handler: http.NewServeMux(), + Addr: "127.0.0.1:8443", + TLSConfig: &tls.Config{Certificates: []tls.Certificate{tlsCert}}, + } + go ServeTLS(&server, 10*time.Millisecond) + time.Sleep(100 * time.Millisecond) // some startup time + + process := must(os.FindProcess(os.Getpid())) + err := process.Signal(syscall.SIGINT) + assert.NoError(t, err) + time.Sleep(20 * time.Millisecond) // some shutdown time > shutdown duration + + logStream := string(must(io.ReadAll(logBuffer))) + expectedLines := []string{ + "Started HTTPS server", + "Received termination signal", + "Closing server", + "Server exited", + } + for _, expectedLine := range expectedLines { + assert.Contains(t, logStream, expectedLine) + } + assert.Regexp(t, regexp.MustCompile(`Started.*\n.*Received termination signal`), logStream) +} + +func TestServeTLSPanicsWithoutTLSConfig(t *testing.T) { + server := http.Server{ + Handler: http.NewServeMux(), + Addr: "127.0.0.1:8443", + } + assert.Panics(t, func() { ServeTLS(&server, 10*time.Millisecond) }) +} + func must[T any](obj T, err error) T { if err != nil { panic(err)