diff --git a/fasthttpadaptor/request.go b/fasthttpadaptor/request.go index 6213f289aa..aca0fa50e0 100644 --- a/fasthttpadaptor/request.go +++ b/fasthttpadaptor/request.go @@ -3,6 +3,8 @@ package fasthttpadaptor import ( "bytes" "io" + "math" + "net" "net/http" "net/url" "strings" @@ -68,3 +70,72 @@ func ConvertRequest(ctx *fasthttp.RequestCtx, r *http.Request, forServer bool) e return nil } + +// ConvertNetHTTPRequestToFastHTTPRequest converts an http.Request to a fasthttp.RequestCtx. +// +// The caller is responsible for the lifecycle of the fasthttp.RequestCtx and the +// underlying fasthttp.Request. The ctx (and its Request) must only be used for +// the duration that fasthttp considers it valid (typically within a handler), +// and must not be accessed after the handler has returned. +// +// The request body is not copied. If r.Body is non-nil, it is passed directly to +// ctx.Request via SetBodyStream. This means: +// - r.Body must remain readable for as long as ctx may need to read it. +// - r.Body should not be read from, written to, or closed by the caller until +// fasthttp is done with ctx. +// - The same r.Body must not be reused concurrently in other goroutines while +// it is attached to ctx.Request. +// +// After calling this function, you should treat r.Body as effectively owned by +// ctx.Request for the lifetime of that context. +func ConvertNetHTTPRequestToFastHTTPRequest(r *http.Request, ctx *fasthttp.RequestCtx) { + ctx.Request.Header.SetMethod(r.Method) + + if r.RequestURI != "" { + ctx.Request.SetRequestURI(r.RequestURI) + } else if r.URL != nil { + ctx.Request.SetRequestURI(r.URL.RequestURI()) + } + + ctx.Request.Header.SetProtocol(r.Proto) + ctx.Request.SetHost(r.Host) + + for k, values := range r.Header { + for i, v := range values { + if i == 0 { + ctx.Request.Header.Set(k, v) + } else { + ctx.Request.Header.Add(k, v) + } + } + } + + if r.Body != nil { + contentLength := int(r.ContentLength) + if r.ContentLength >= int64(math.MaxInt) { + contentLength = -1 + } + + ctx.Request.SetBodyStream(r.Body, contentLength) + } + + if r.RemoteAddr != "" { + addr := parseRemoteAddr(r.RemoteAddr) + ctx.SetRemoteAddr(addr) + } +} + +func parseRemoteAddr(addr string) net.Addr { + if tcpAddr, err := net.ResolveTCPAddr("tcp", addr); err == nil { + return tcpAddr + } + + if _, _, err := net.SplitHostPort(addr); err != nil { + if tcpAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(addr, "0")); err == nil { + return tcpAddr + } + } + + host := strings.Trim(addr, "[]") + return &net.TCPAddr{IP: net.ParseIP(host)} +} diff --git a/fasthttpadaptor/request_test.go b/fasthttpadaptor/request_test.go index 1f214c2d06..14a35294f5 100644 --- a/fasthttpadaptor/request_test.go +++ b/fasthttpadaptor/request_test.go @@ -1,7 +1,11 @@ package fasthttpadaptor import ( + "bytes" + "errors" + "io" "net/http" + "net/url" "testing" "github.com/valyala/fasthttp" @@ -27,3 +31,332 @@ func BenchmarkConvertRequest(b *testing.B) { _ = ConvertRequest(ctx, &httpReq, true) } } + +func BenchmarkConvertNetHTTPRequestToFastHTTPRequest(b *testing.B) { + httpReq := http.Request{ + Method: "GET", + RequestURI: "/test", + Host: "test", + Header: http.Header{ + "X": []string{"test"}, + "Y": []string{"test"}, + }, + } + + ctx := &fasthttp.RequestCtx{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ConvertNetHTTPRequestToFastHTTPRequest(&httpReq, ctx) + } +} + +// errReader is a reader that always returns an error. +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, errors.New("read error") +} + +func TestConvertNetHTTPRequestToFastHTTPRequest(t *testing.T) { + t.Parallel() + + t.Run("basic conversion", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "POST", + RequestURI: "/test/path?query=1", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + if string(ctx.Method()) != "POST" { + t.Errorf("expected method POST, got %s", ctx.Method()) + } + if string(ctx.RequestURI()) != "/test/path?query=1" { + t.Errorf("expected URI /test/path?query=1, got %s", ctx.RequestURI()) + } + if string(ctx.Request.Header.Protocol()) != "HTTP/1.1" { + t.Errorf("expected protocol HTTP/1.1, got %s", ctx.Request.Header.Protocol()) + } + if string(ctx.Host()) != "example.com" { + t.Errorf("expected host example.com, got %s", ctx.Host()) + } + }) + + t.Run("URL fallback when RequestURI is empty", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "", + URL: &url.URL{ + Path: "/fallback/path", + RawQuery: "foo=bar", + }, + Proto: "HTTP/1.1", + Host: "fallback.com", + Header: http.Header{}, + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + if string(ctx.RequestURI()) != "/fallback/path?foo=bar" { + t.Errorf("expected URI /fallback/path?foo=bar, got %s", ctx.RequestURI()) + } + }) + + t.Run("single header", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{ + "X-Custom-Header": []string{"custom-value"}, + }, + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + if string(ctx.Request.Header.Peek("X-Custom-Header")) != "custom-value" { + t.Errorf("expected header value custom-value, got %s", ctx.Request.Header.Peek("X-Custom-Header")) + } + }) + + t.Run("multiple header values", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{ + "Accept": []string{"text/html", "application/json", "text/plain"}, + }, + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + // Check all header values are present + var values []string + ctx.Request.Header.All()(func(key, value []byte) bool { + if string(key) == "Accept" { + values = append(values, string(value)) + } + return true + }) + + if len(values) != 3 { + t.Errorf("expected 3 Accept header values, got %d", len(values)) + } + }) + + t.Run("request body", func(t *testing.T) { + t.Parallel() + bodyContent := []byte("test body content") + httpReq := &http.Request{ + Method: "POST", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(bodyContent)), + ContentLength: int64(len(bodyContent)), + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + if !bytes.Equal(ctx.Request.Body(), bodyContent) { + t.Errorf("expected body %q, got %q", bodyContent, ctx.Request.Body()) + } + }) + + t.Run("nil body", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + Body: nil, + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + if len(ctx.Request.Body()) != 0 { + t.Errorf("expected empty body, got %q", ctx.Request.Body()) + } + }) + + t.Run("remote address with port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "192.168.1.100:8080", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "192.168.1.100:8080" { + t.Errorf("expected remote addr 192.168.1.100:8080, got %s", remoteAddr) + } + }) + + t.Run("remote address without port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "192.168.1.100", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "192.168.1.100:0" { + t.Errorf("expected remote addr 192.168.1.100:0, got %s", remoteAddr) + } + }) + + t.Run("IPv6 remote address with port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "[2001:db8::1]:8080", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "[2001:db8::1]:8080" { + t.Errorf("expected remote addr [2001:db8::1]:8080, got %s", remoteAddr) + } + }) + + t.Run("IPv6 remote address without port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "2001:db8::1", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "[2001:db8::1]:0" { + t.Errorf("expected remote addr [2001:db8::1]:0, got %s", remoteAddr) + } + }) + + t.Run("IPv6 remote address with zone and port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "[fe80::1%eth0]:9090", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "[fe80::1%eth0]:9090" { + t.Errorf("expected remote addr [fe80::1%%eth0]:9090, got %s", remoteAddr) + } + }) + + t.Run("IPv6 remote address with zone without port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "fe80::1%eth0", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "[fe80::1%eth0]:0" { + t.Errorf("expected remote addr [fe80::1%%eth0]:0, got %s", remoteAddr) + } + }) + + t.Run("IPv6 loopback with port", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "GET", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + RemoteAddr: "[::1]:3000", + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + remoteAddr := ctx.RemoteAddr().String() + if remoteAddr != "[::1]:3000" { + t.Errorf("expected remote addr [::1]:3000, got %s", remoteAddr) + } + }) + + t.Run("body read error", func(t *testing.T) { + t.Parallel() + httpReq := &http.Request{ + Method: "POST", + RequestURI: "/", + Proto: "HTTP/1.1", + Host: "example.com", + Header: http.Header{}, + Body: io.NopCloser(errReader{}), + ContentLength: 10, + } + + ctx := &fasthttp.RequestCtx{} + ConvertNetHTTPRequestToFastHTTPRequest(httpReq, ctx) + + _, err := io.ReadAll(ctx.RequestBodyStream()) + if err == nil { + t.Fatal("expected error when reading body stream, got nil") + } + }) +}