diff --git a/cmd/emulator/main.go b/cmd/emulator/main.go index a47eee00..9b0528f3 100644 --- a/cmd/emulator/main.go +++ b/cmd/emulator/main.go @@ -43,7 +43,12 @@ func defaultServiceKey( } func main() { - if err := start.Cmd(defaultServiceKey).Execute(); err != nil { + config := start.StartConfig{ + GetServiceKey: defaultServiceKey, + RestMiddlewares: []start.HttpMiddleware{}, + } + + if err := start.Cmd(config).Execute(); err != nil { start.Exit(1, err.Error()) } } diff --git a/cmd/emulator/start/start.go b/cmd/emulator/start/start.go index 2acf9762..2bf5d1c6 100644 --- a/cmd/emulator/start/start.go +++ b/cmd/emulator/start/start.go @@ -22,6 +22,7 @@ import ( "encoding/hex" "fmt" "log" + "net/http" "os" "strings" "time" @@ -91,7 +92,14 @@ type serviceKeyFunc func( hashAlgo crypto.HashAlgorithm, ) (crypto.PrivateKey, crypto.SignatureAlgorithm, crypto.HashAlgorithm) -func Cmd(getServiceKey serviceKeyFunc) *cobra.Command { +type HttpMiddleware func(http.Handler) http.Handler + +type StartConfig struct { + GetServiceKey serviceKeyFunc + RestMiddlewares []HttpMiddleware +} + +func Cmd(config StartConfig) *cobra.Command { cmd := &cobra.Command{ Use: "start", Short: "Starts the Flow emulator server", @@ -123,7 +131,7 @@ func Cmd(getServiceKey serviceKeyFunc) *cobra.Command { servicePublicKey = servicePrivateKey.PublicKey() } else { // if we don't provide any config values use the serviceKeyFunc to obtain the key - servicePrivateKey, serviceKeySigAlgo, serviceKeyHashAlgo = getServiceKey( + servicePrivateKey, serviceKeySigAlgo, serviceKeyHashAlgo = config.GetServiceKey( conf.Init, serviceKeySigAlgo, serviceKeyHashAlgo, @@ -216,6 +224,9 @@ func Cmd(getServiceKey serviceKeyFunc) *cobra.Command { emu := server.NewEmulatorServer(logger, serverConf) if emu != nil { + for _, middleware := range config.RestMiddlewares { + emu.UseRestMiddleware(middleware) + } emu.Start() } else { Exit(-1, "") diff --git a/server/access/rest.go b/server/access/rest.go index 646b9d7a..7ecc027c 100644 --- a/server/access/rest.go +++ b/server/access/rest.go @@ -82,6 +82,12 @@ func (r *RestServer) Stop() { _ = r.server.Shutdown(context.Background()) } +func (r *RestServer) UseMiddleware(middleware func(http.Handler) http.Handler) { + if r.server != nil { + r.server.Handler = middleware(r.server.Handler) + } +} + func NewRestServer(logger *zerolog.Logger, blockchain *emulator.Blockchain, adapter *adapters.AccessAdapter, chain flow.Chain, host string, port int, debug bool) (*RestServer, error) { debugLogger := zerolog.Logger{} diff --git a/server/server.go b/server/server.go index 0b53a0ba..bddf1990 100644 --- a/server/server.go +++ b/server/server.go @@ -20,6 +20,7 @@ package server import ( "fmt" + "net/http" "os" "sort" "time" @@ -511,3 +512,9 @@ func sanitizeConfig(conf *Config) *Config { return conf } + +func (s *EmulatorServer) UseRestMiddleware(middleware func(http.Handler) http.Handler) { + if s.rest != nil { + s.rest.UseMiddleware(middleware) + } +}