Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,20 @@ func (c *Config) InitDefaults() {
}
}

// parseDSN splits a "scheme://address" DSN and returns the two parts.
// Returns an error when the DSN does not contain "://".
func parseDSN(dsn string) (scheme, addr string, err error) {
scheme, addr, ok := strings.Cut(dsn, "://")
if !ok {
return "", "", errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
}
return scheme, addr, nil
}
Comment thread
rustatian marked this conversation as resolved.

// Valid returns nil if config is valid.
func (c *Config) Valid() error {
if dsn := strings.Split(c.Listen, "://"); len(dsn) != 2 {
return errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
if _, _, err := parseDSN(c.Listen); err != nil {
return err
}
if c.RequestTimeout < 0 {
return errors.New("rpc request_timeout must be non-negative")
Expand All @@ -63,10 +73,10 @@ func (c *Config) Listener() (net.Listener, error) {

// Dialer creates rpc socket Dialer.
func (c *Config) Dialer() (net.Conn, error) {
dsn := strings.Split(c.Listen, "://")
if len(dsn) != 2 {
return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
scheme, addr, err := parseDSN(c.Listen)
if err != nil {
return nil, err
}
var d net.Dialer
return d.DialContext(context.Background(), dsn[0], dsn[1])
return d.DialContext(context.Background(), scheme, addr)
}
11 changes: 4 additions & 7 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ func (s *Plugin) Init(cfg Configurer, log Logger) error {
return errors.E(op, err)
}

var WholeCfg any
err = cfg.Unmarshal(&WholeCfg)
var wholeCfg any
err = cfg.Unmarshal(&wholeCfg)
if err != nil {
return errors.E(op, err)
}

s.wcfg, err = json.Marshal(WholeCfg)
s.wcfg, err = json.Marshal(wholeCfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -125,10 +125,7 @@ func (s *Plugin) Serve() chan error {
mux.Handle(path, handler)
// derive the gRPC service name from the mount path
// (`/<service>/<Method>` or `/<service>/`)
svc := strings.TrimPrefix(path, "/")
if i := strings.Index(svc, "/"); i >= 0 {
svc = svc[:i]
}
svc, _, _ := strings.Cut(strings.TrimPrefix(path, "/"), "/")
services = append(services, svc)
}

Expand Down
Loading