Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,29 @@ type Server struct {
// like they are normal requests.
ContinueHandler func(header *RequestHeader) bool

// ExpectHandler is called after receiving the Expect 100 Continue Header.
//
// https://www.rfc-editor.org/rfc/rfc9110.html#field.expect
//
// ExpectHandler provides more control than ContinueHandler by allowing
// the server to respond with any final status code. The handler should return
// StatusContinue (100) to accept the request and proceed to read the body,
// or any other status code to reject it. If StatusExpectationFailed (417) is
// returned, the response is sent and the connection is closed. For any other
// non-100 status code, the response is also sent and the connection is closed,
// since the client may have already started sending the request body.
Comment thread
miretskiy marked this conversation as resolved.
Outdated
//
// The ctx provides access to request headers and connection metadata (e.g.
// RemoteAddr for IP-based filtering). The response must not be modified —
// only the returned status code is used.
Comment thread
miretskiy marked this conversation as resolved.
Outdated
//
// If both ExpectHandler and ContinueHandler are set, ExpectHandler
// takes precedence.
//
// The default behavior (when neither handler is set) is to automatically accept
// the request body.
ExpectHandler func(ctx *RequestCtx) int

// ConnState specifies an optional callback function that is
// called when a client connection changes state. See the
// ConnState type and associated constants for details.
Expand Down Expand Up @@ -2445,10 +2468,20 @@ func (s *Server) serveConn(c net.Conn) error {
}

// 'Expect: 100-continue' request handling.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
// See https://www.rfc-editor.org/rfc/rfc9110.html#field.expect for details.
if ctx.Request.MayContinue() {
// Allow the ability to deny reading the incoming request body
if s.ContinueHandler != nil {
// Allow the ability to deny reading the incoming request body.
if s.ExpectHandler != nil {
if expectStatus := s.ExpectHandler(ctx); expectStatus != StatusContinue {
continueReadingRequest = false
if br != nil {
br.Reset(ctx.c)
}
ctx.SetStatusCode(expectStatus)
// Close connection since client may have already started sending body data.
connectionClose = true
}
} else if s.ContinueHandler != nil {
if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
if br != nil {
br.Reset(ctx.c)
Expand Down Expand Up @@ -2498,8 +2531,9 @@ func (s *Server) serveConn(c net.Conn) error {
}
}

// store req.ConnectionClose so even if it was changed inside of handler
connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
// store req.ConnectionClose so even if it was changed inside of handler.
// Preserve connectionClose if already set (e.g., by ExpectHandler).
connectionClose = connectionClose || s.DisableKeepalive || ctx.Request.Header.ConnectionClose()

if serverName != "" {
ctx.Response.Header.SetServer(serverName)
Expand Down
219 changes: 219 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2102,6 +2102,225 @@ func TestServerContinueHandler(t *testing.T) {
}
}

func TestServerExpectHandler(t *testing.T) {
t.Parallel()

acceptContentLength := 5
s := &Server{
ExpectHandler: func(ctx *RequestCtx) int {
if !ctx.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
}

ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpected content-type: %q. Expecting %q", ct, "a/b")
}

if ctx.Request.Header.contentLength == acceptContentLength {
return StatusContinue
}
return StatusExpectationFailed
},
Handler: func(ctx *RequestCtx) {
if ctx.Request.Header.contentLength != acceptContentLength {
t.Errorf("all requests with content-length other than %d should be denied", acceptContentLength)
}
if !ctx.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
}
if string(ctx.Path()) != "/foo" {
t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
}
ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpected content-type: %q. Expecting %q", ct, "a/b")
}
if string(ctx.PostBody()) != "12345" {
t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
}
ctx.WriteString("foobar") //nolint:errcheck
},
}

sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) {
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}

br := bufio.NewReader(&rw.w)
verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse)

data, err := io.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
}
}

rw := &readWriter{}
for range 25 {
// Regular requests without Expect header
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
sendRequest(rw, StatusOK, "foobar")

// Expect 100-continue requests that are accepted
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
sendRequest(rw, StatusOK, "foobar")

// Requests rejected with 417
rw.r.Reset()
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456")
sendRequest(rw, StatusExpectationFailed, "")
}
}

func TestServerExpectHandlerCustomStatusCode(t *testing.T) {
t.Parallel()

s := &Server{
ExpectHandler: func(ctx *RequestCtx) int {
// Reject with 413 Request Entity Too Large for large bodies
if ctx.Request.Header.ContentLength() > 5 {
return StatusRequestEntityTooLarge
}
return StatusContinue
},
Handler: func(ctx *RequestCtx) {
ctx.WriteString("ok") //nolint:errcheck
},
}

// Accepted request
rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusOK, string(defaultContentType), "ok")

// Rejected request with 413 — connection should be closed
rw = &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br = bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusRequestEntityTooLarge, string(defaultContentType), "")
}

func TestServerExpectHandlerConnectionClose(t *testing.T) {
t.Parallel()

s := &Server{
ExpectHandler: func(ctx *RequestCtx) int {
if ctx.Request.Header.ContentLength() > 5 {
return StatusExpectationFailed
}
return StatusContinue
},
Handler: func(ctx *RequestCtx) {
ctx.WriteString("ok") //nolint:errcheck
},
}

// When rejected, the response should have Connection: close and
// subsequent pipelined requests should NOT be processed.
rw := &readWriter{}
// Send two pipelined requests: one rejected, one that should never be processed.
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456")
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")

if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}

br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response: %v", err)
}
if resp.StatusCode() != StatusExpectationFailed {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusExpectationFailed)
}
if !resp.Header.ConnectionClose() {
t.Fatal("response should have Connection: close header")
}

// Verify no second response was sent (connection was closed after first rejection).
data, err := io.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q. Connection should have been closed", data)
}
}

func TestServerExpectHandlerPrecedence(t *testing.T) {
t.Parallel()

// When both ExpectHandler and ContinueHandler are set,
// ExpectHandler should take precedence.
continueHandlerCalled := false
s := &Server{
ContinueHandler: func(headers *RequestHeader) bool {
continueHandlerCalled = true
return true
},
ExpectHandler: func(ctx *RequestCtx) int {
return StatusRequestEntityTooLarge
},
Handler: func(ctx *RequestCtx) {
t.Error("handler should not be called when ExpectHandler rejects")
},
}

rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}

br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusRequestEntityTooLarge, string(defaultContentType), "")

if continueHandlerCalled {
t.Fatal("ContinueHandler should not be called when ExpectHandler is set")
}
}

func TestServerExpectHandlerRemoteAddr(t *testing.T) {
t.Parallel()

// ExpectHandler can use ctx.RemoteAddr() for IP-based filtering.
// readWriter returns zeroTCPAddr (0.0.0.0:0) as the remote address.
s := &Server{
ExpectHandler: func(ctx *RequestCtx) int {
if ctx.RemoteAddr().String() == "0.0.0.0:0" {
return StatusForbidden
}
return StatusContinue
},
Handler: func(ctx *RequestCtx) {
ctx.WriteString("ok") //nolint:errcheck
},
}

rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusForbidden, string(defaultContentType), "")
}

func TestCompressHandler(t *testing.T) {
t.Parallel()

Expand Down