Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions pkg/ratelimit/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,10 @@ func (m *rateLimitMiddleware) Close() error {
return nil
}

// CreateMiddleware is the factory function for rate limit middleware.
func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error {
var params MiddlewareParams
if err := json.Unmarshal(config.Parameters, &params); err != nil {
return fmt.Errorf("failed to unmarshal rate limit middleware parameters: %w", err)
}

// NewMiddleware creates a Redis-backed rate limit middleware from typed params.
func NewMiddleware(params MiddlewareParams) (types.Middleware, error) {
if params.RedisAddr == "" {
return fmt.Errorf("rate limit middleware requires a Redis address")
return nil, fmt.Errorf("rate limit middleware requires a Redis address")
}

// TODO: share a Redis client builder with session storage to get TLS,
Expand All @@ -89,18 +84,31 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
defer pingCancel()
if err := client.Ping(pingCtx).Err(); err != nil {
_ = client.Close()
return fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w", params.RedisAddr, err)
return nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w", params.RedisAddr, err)
}

limiter, err := NewLimiter(client, params.Namespace, params.ServerName, params.Config)
if err != nil {
_ = client.Close()
return fmt.Errorf("failed to create rate limiter: %w", err)
return nil, fmt.Errorf("failed to create rate limiter: %w", err)
}

mw := &rateLimitMiddleware{
return &rateLimitMiddleware{
handler: rateLimitHandler(limiter),
client: client,
}, nil
}

// CreateMiddleware is the factory function for rate limit middleware.
func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error {
var params MiddlewareParams
if err := json.Unmarshal(config.Parameters, &params); err != nil {
return fmt.Errorf("failed to unmarshal rate limit middleware parameters: %w", err)
}

mw, err := NewMiddleware(params)
if err != nil {
return err
}
runner.AddMiddleware(MiddlewareType, mw)
return nil
Expand Down
69 changes: 69 additions & 0 deletions pkg/ratelimit/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ import (
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/mcp"
transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks"
)

// dummyLimiter is a test double for the Limiter interface.
Expand Down Expand Up @@ -208,3 +214,66 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
assert.Equal(t, "echo", recorder.toolName)
assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
}

func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) {
t.Parallel()

expected := rateLimitHandler(&dummyLimiter{decision: &Decision{Allowed: true}})
mw := &rateLimitMiddleware{handler: expected}

assert.NotNil(t, mw.Handler())
}

func TestNewMiddlewareReturnsUsableMiddleware(t *testing.T) {
t.Parallel()

mr := miniredis.RunT(t)
middleware, err := NewMiddleware(MiddlewareParams{
Namespace: "default",
ServerName: "server",
RedisAddr: mr.Addr(),
Config: &v1beta1.RateLimitConfig{
Shared: &v1beta1.RateLimitBucket{
MaxTokens: 1,
RefillPeriod: metav1.Duration{Duration: time.Minute},
},
},
})

require.NoError(t, err)
require.NotNil(t, middleware)
require.NotNil(t, middleware.Handler())
require.NoError(t, middleware.Close())
}

func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) {
t.Parallel()

mr := miniredis.RunT(t)
cfg, err := transporttypes.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{
Namespace: "default",
ServerName: "server",
RedisAddr: mr.Addr(),
Config: &v1beta1.RateLimitConfig{
Shared: &v1beta1.RateLimitBucket{
MaxTokens: 1,
RefillPeriod: metav1.Duration{Duration: time.Minute},
},
},
})
require.NoError(t, err)

ctrl := gomock.NewController(t)
runner := transportmocks.NewMockMiddlewareRunner(ctrl)
var registered transporttypes.Middleware
runner.EXPECT().
AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})).
Do(func(_ string, middleware transporttypes.Middleware) {
registered = middleware
})

require.NoError(t, CreateMiddleware(cfg, runner))
require.NotNil(t, registered)
require.NotNil(t, registered.Handler())
require.NoError(t, registered.Close())
}
34 changes: 31 additions & 3 deletions pkg/vmcp/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ import (
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
"github.com/stacklok/toolhive/pkg/vmcp/auth/factory"
authfactory "github.com/stacklok/toolhive/pkg/vmcp/auth/factory"
vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client"
"github.com/stacklok/toolhive/pkg/vmcp/config"
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
"github.com/stacklok/toolhive/pkg/vmcp/health"
"github.com/stacklok/toolhive/pkg/vmcp/k8s"
"github.com/stacklok/toolhive/pkg/vmcp/optimizer"
ratelimitfactory "github.com/stacklok/toolhive/pkg/vmcp/ratelimit/factory"
vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router"
vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server"
vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
Expand Down Expand Up @@ -372,13 +373,31 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
}

authMiddleware, authzMiddleware, authInfoHandler, err :=
factory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, vmcpCfg.Name, passThroughTools, upstreamReader, keyProvider)
authfactory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, vmcpCfg.Name, passThroughTools, upstreamReader, keyProvider)
if err != nil {
return fmt.Errorf("failed to create authentication middleware: %w", err)
}

slog.Info(fmt.Sprintf("Incoming authentication configured: %s", vmcpCfg.IncomingAuth.Type))

namespace := vmcpNamespace()
rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{
Namespace: namespace,
ServerName: vmcpCfg.Name,
RateLimiting: vmcpCfg.RateLimiting,
SessionStorage: vmcpCfg.SessionStorage,
})
if err != nil {
return fmt.Errorf("failed to create rate limit middleware: %w", err)
}
if rateLimitCleanup != nil {
defer func() {
if closeErr := rateLimitCleanup(context.Background()); closeErr != nil {
slog.Error(fmt.Sprintf("failed to close rate limit middleware: %v", closeErr))
}
}()
}

serverCfg := &vmcpserver.Config{
Name: vmcpCfg.Name,
Version: versions.Version,
Expand All @@ -389,6 +408,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
AuthMiddleware: authMiddleware,
AuthzMiddleware: authzMiddleware,
AuthInfoHandler: authInfoHandler,
RateLimitMiddleware: rateLimitMiddleware,
AuthServer: embeddedAuthServer,
TelemetryProvider: telemetryProvider,
AuditConfig: vmcpCfg.Audit,
Expand Down Expand Up @@ -534,6 +554,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) {
return cfg, nil
}

func vmcpNamespace() string {
namespace := os.Getenv("VMCP_NAMESPACE")
if namespace == "" {
return "local"
}
return namespace
}

// loadAuthServerConfig loads the auth server RunConfig from a sibling file
// alongside the main config. The operator serializes authserver.RunConfig as a
// separate ConfigMap key (authserver-config.yaml).
Expand Down Expand Up @@ -565,7 +593,7 @@ func discoverBackends(
) ([]vmcp.Backend, vmcp.BackendClient, vmcpauth.OutgoingAuthRegistry, error) {
slog.Info("initializing outgoing authentication")
envReader := &env.OSReader{}
outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, envReader)
outgoingRegistry, err := authfactory.NewOutgoingAuthRegistry(ctx, envReader)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create outgoing authentication registry: %w", err)
}
Expand Down
14 changes: 14 additions & 0 deletions pkg/vmcp/cli/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,20 @@ func TestValidateQuickModeHost(t *testing.T) {
}
}

func TestVMCPNamespace(t *testing.T) {
t.Run("defaults to local", func(t *testing.T) {
t.Setenv("VMCP_NAMESPACE", "")

assert.Equal(t, "local", vmcpNamespace())
})

t.Run("uses environment value", func(t *testing.T) {
t.Setenv("VMCP_NAMESPACE", "toolhive-system")

assert.Equal(t, "toolhive-system", vmcpNamespace())
})
}

// TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the
// discoverer succeeds but returns no backends. The function must return a
// non-error, an empty (non-nil) backend slice, and pass through the client and
Expand Down
55 changes: 55 additions & 0 deletions pkg/vmcp/ratelimit/factory/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

// Package factory builds vMCP-specific rate-limit middleware.
package factory

import (
"context"
"fmt"
"net/http"

"github.com/stacklok/toolhive/pkg/ratelimit"
ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
)

// Config contains the vMCP rate-limit middleware inputs.
type Config struct {
Namespace string
ServerName string
RateLimiting *ratelimittypes.RateLimitConfig
SessionStorage *vmcpconfig.SessionStorageConfig
}

// NewMiddleware creates Redis-backed rate-limit middleware for vMCP.
func NewMiddleware(
_ context.Context,
cfg Config,
) (func(http.Handler) http.Handler, func(context.Context) error, error) {
if cfg.RateLimiting == nil {
return nil, nil, nil
}
if cfg.SessionStorage == nil || cfg.SessionStorage.Provider != "redis" {
return nil, nil, fmt.Errorf("rate limiting requires Redis session storage")
}
if cfg.SessionStorage.Address == "" {
return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address")
}

middleware, err := ratelimit.NewMiddleware(ratelimit.MiddlewareParams{
Namespace: cfg.Namespace,
ServerName: cfg.ServerName,
Config: cfg.RateLimiting,
RedisAddr: cfg.SessionStorage.Address,
RedisDB: cfg.SessionStorage.DB,
})
if err != nil {
return nil, nil, err
}

cleanup := func(context.Context) error {
return middleware.Close()
}
return middleware.Handler(), cleanup, nil
}
Loading
Loading