diff --git a/config.go b/config.go index f5b45cc..e750572 100644 --- a/config.go +++ b/config.go @@ -40,10 +40,20 @@ func (c *Config) InitDefaults() { } } +// parseDSN splits a "scheme://address" DSN and returns the two parts. +// Returns an error unless the DSN contains exactly one "://" separator. +func parseDSN(dsn string) (scheme, addr string, err error) { + scheme, addr, ok := strings.Cut(dsn, "://") + if !ok || strings.Contains(addr, "://") { + return "", "", errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + } + return scheme, addr, nil +} + // Valid returns nil if config is valid. func (c *Config) Valid() error { - if dsn := strings.Split(c.Listen, "://"); len(dsn) != 2 { - return errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + if _, _, err := parseDSN(c.Listen); err != nil { + return err } if c.RequestTimeout < 0 { return errors.New("rpc request_timeout must be non-negative") @@ -63,10 +73,10 @@ func (c *Config) Listener() (net.Listener, error) { // Dialer creates rpc socket Dialer. func (c *Config) Dialer() (net.Conn, error) { - dsn := strings.Split(c.Listen, "://") - if len(dsn) != 2 { - return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + scheme, addr, err := parseDSN(c.Listen) + if err != nil { + return nil, err } var d net.Dialer - return d.DialContext(context.Background(), dsn[0], dsn[1]) + return d.DialContext(context.Background(), scheme, addr) } diff --git a/plugin.go b/plugin.go index e590aff..526d1a1 100644 --- a/plugin.go +++ b/plugin.go @@ -85,13 +85,13 @@ func (s *Plugin) Init(cfg Configurer, log Logger) error { return errors.E(op, err) } - var WholeCfg any - err = cfg.Unmarshal(&WholeCfg) + var wholeCfg any + err = cfg.Unmarshal(&wholeCfg) if err != nil { return errors.E(op, err) } - s.wcfg, err = json.Marshal(WholeCfg) + s.wcfg, err = json.Marshal(wholeCfg) if err != nil { return err } @@ -125,10 +125,7 @@ func (s *Plugin) Serve() chan error { mux.Handle(path, handler) // derive the gRPC service name from the mount path // (`//` or `//`) - svc := strings.TrimPrefix(path, "/") - if i := strings.Index(svc, "/"); i >= 0 { - svc = svc[:i] - } + svc, _, _ := strings.Cut(strings.TrimPrefix(path, "/"), "/") services = append(services, svc) } diff --git a/tests/config_test.go b/tests/config_test.go index 6d23f6d..9482e4e 100644 --- a/tests/config_test.go +++ b/tests/config_test.go @@ -156,6 +156,18 @@ func Test_Config_DialerErrorMethod(t *testing.T) { assert.Error(t, err) } +func Test_Config_MultipleSeparators(t *testing.T) { + // A DSN with more than one "://" must be rejected by both Valid and Dialer. + cfg := &rpc.Config{Listen: "tcp://host://6001"} + + assert.Error(t, cfg.Valid()) + + conn, err := cfg.Dialer() + assert.Nil(t, conn) + assert.Error(t, err) + assert.Equal(t, "invalid socket DSN (tcp://:6001, unix://file.sock)", err.Error()) +} + func Test_Config_Defaults(t *testing.T) { c := &rpc.Config{} c.InitDefaults()