Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions pkg/server/metrics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package server

import (
"fmt"
"io"
"sort"
"strings"
"time"
)

// prometheusEscapeLabelValue escapes a string for use as a Prometheus label
// value: backslash, newline, and double quote are backslash-escaped per the
// text exposition format. All other runes pass through unchanged.
func prometheusEscapeLabelValue(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		switch r {
		case '\\':
			b.WriteString(`\\`)
		case '\n':
			b.WriteString(`\n`)
		case '"':
			b.WriteString(`\"`)
		default:
			b.WriteRune(r)
		}
	}
	return b.String()
}

// prometheusLabelValueOrUnknown substitutes the placeholder "unknown" for an
// empty label value so every emitted time series carries a non-empty label.
func prometheusLabelValueOrUnknown(s string) string {
	if s != "" {
		return s
	}
	return "unknown"
}

// prometheusLabels renders the shared {index,module,upstream} label set used
// by every per-connection metric. Module and upstream are mapped to
// "unknown" when empty and escaped for the text exposition format.
func prometheusLabels(index uint32, module, upstream string) string {
	moduleLabel := prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(module))
	upstreamLabel := prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(upstream))
	return fmt.Sprintf(`index="%d",module="%s",upstream="%s"`, index, moduleLabel, upstreamLabel)
}

// prometheusConnectionGroup is the comparable key used to aggregate active
// connections by (module, upstream) for the per-module gauge. A struct key
// (rather than a concatenated string) keeps label values containing
// arbitrary bytes from colliding with each other.
type prometheusConnectionGroup struct {
	module   string
	upstream string
}

// writePrometheusMetrics renders the server's connection metrics to w in the
// Prometheus text exposition format (version 0.0.4).
//
// now is passed in (rather than read inside) so tests can pin the duration
// metric. Each connection is snapshotted exactly once, so all metric
// families report a mutually consistent view of a connection — the original
// code re-snapshotted per family, taking the lock up to four times per
// connection and allowing the families to disagree mid-scrape.
func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) {
	// One immutable view per active connection, captured up front.
	type connSnapshot struct {
		module        string
		upstream      string
		connectedAt   time.Time
		sentBytes     int64
		receivedBytes int64
		labels        string // pre-rendered {index,module,upstream} label set
	}

	connections := s.ListConnectionInfo()
	snapshots := make([]connSnapshot, 0, len(connections))
	for _, conn := range connections {
		index, module, upstream, connectedAt, sentBytes, receivedBytes := conn.snapshot()
		snapshots = append(snapshots, connSnapshot{
			module:        module,
			upstream:      upstream,
			connectedAt:   connectedAt,
			sentBytes:     sentBytes,
			receivedBytes: receivedBytes,
			labels:        prometheusLabels(index, module, upstream),
		})
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge")
	_, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount())

	// Aggregate by (module, upstream); empty values become "unknown" so the
	// grouping matches the labels emitted below.
	connectionCounts := make(map[prometheusConnectionGroup]int, len(snapshots))
	for _, snap := range snapshots {
		key := prometheusConnectionGroup{
			module:   prometheusLabelValueOrUnknown(snap.module),
			upstream: prometheusLabelValueOrUnknown(snap.upstream),
		}
		connectionCounts[key]++
	}

	// Sort keys so scrape output is deterministic.
	keys := make([]prometheusConnectionGroup, 0, len(connectionCounts))
	for key := range connectionCounts {
		keys = append(keys, key)
	}
	sort.Slice(keys, func(i, j int) bool {
		if keys[i].module != keys[j].module {
			return keys[i].module < keys[j].module
		}
		return keys[i].upstream < keys[j].upstream
	})

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections_by_module Current active rsync proxy connections by module and upstream.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections_by_module gauge")
	for _, key := range keys {
		module := prometheusEscapeLabelValue(key.module)
		upstream := prometheusEscapeLabelValue(key.upstream)
		_, _ = fmt.Fprintf(w, "rsync_proxy_active_connections_by_module{module=\"%s\",upstream=\"%s\"} %d\n", module, upstream, connectionCounts[key])
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_sent_bytes Bytes sent to clients for active connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_sent_bytes gauge")
	for _, snap := range snapshots {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_sent_bytes{%s} %d\n", snap.labels, snap.sentBytes)
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_received_bytes Bytes received from clients for active connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_received_bytes gauge")
	for _, snap := range snapshots {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_received_bytes{%s} %d\n", snap.labels, snap.receivedBytes)
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_connected_timestamp_seconds Unix timestamp when active connections were established.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_connected_timestamp_seconds gauge")
	for _, snap := range snapshots {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_connected_timestamp_seconds{%s} %d\n", snap.labels, snap.connectedAt.Unix())
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_duration_seconds Current duration of active connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_duration_seconds gauge")
	for _, snap := range snapshots {
		// Clamp negative durations (clock skew between ConnectedAt and now).
		duration := now.Sub(snap.connectedAt).Seconds()
		if duration < 0 {
			duration = 0
		}
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.3f\n", snap.labels, duration)
	}
}
48 changes: 38 additions & 10 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var (
const lineFeed = '\n'

type ConnInfo struct {
mu sync.RWMutex
Index uint32
LocalAddr string
RemoteAddr string
Expand All @@ -60,7 +61,26 @@ type ConnInfo struct {
ReceivedBytes atomic.Int64
}

// SetModule records the rsync module requested by the client, guarded by
// the connection's mutex so concurrent readers (metrics, JSON marshalling)
// observe a consistent value.
func (c *ConnInfo) SetModule(module string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.Module = module
}

// SetUpstreamAddr records the upstream address chosen for this connection,
// guarded by the connection's mutex so concurrent readers observe a
// consistent value.
func (c *ConnInfo) SetUpstreamAddr(upstreamAddr string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.UpstreamAddr = upstreamAddr
}

// snapshot returns a mutually consistent copy of the connection's fields
// under the read lock. The atomic byte counters are loaded inside the same
// critical section so they line up with the mutex-guarded fields.
func (c *ConnInfo) snapshot() (index uint32, module, upstreamAddr string, connectedAt time.Time, sentBytes, receivedBytes int64) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.Index, c.Module, c.UpstreamAddr, c.ConnectedAt, c.SentBytes.Load(), c.ReceivedBytes.Load()
}

func (c *ConnInfo) MarshalJSON() ([]byte, error) {
index, module, upstreamAddr, connectedAt, sentBytes, receivedBytes := c.snapshot()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can change the return type of c.snapshot() to the struct that follows, and then this function can be simplified to json.Marshal(c.snapshot()).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implemented in e0080e4. snapshot() now returns a JSON-ready snapshot struct and MarshalJSON() is simplified to json.Marshal(c.snapshot()). I also updated metrics/tests to use the snapshot fields directly.

Verification run:

  • golangci-lint run in the official container: 0 issues
  • go test -race ./... in a Go container with rsync/netcat installed: passed

// Handle atomic value (cannot marshal directly)
return json.Marshal(struct {
Index uint32 `json:"index"`
Expand All @@ -72,14 +92,14 @@ func (c *ConnInfo) MarshalJSON() ([]byte, error) {
SentBytes int64 `json:"sentBytes"`
ReceivedBytes int64 `json:"receivedBytes"`
}{
Index: c.Index,
Index: index,
LocalAddr: c.LocalAddr,
RemoteAddr: c.RemoteAddr,
ConnectedAt: c.ConnectedAt,
Module: c.Module,
UpstreamAddr: c.UpstreamAddr,
SentBytes: c.SentBytes.Load(),
ReceivedBytes: c.ReceivedBytes.Load(),
ConnectedAt: connectedAt,
Module: module,
UpstreamAddr: upstreamAddr,
SentBytes: sentBytes,
ReceivedBytes: receivedBytes,
})
}

Expand Down Expand Up @@ -537,8 +557,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err
}

moduleName := string(buf[:n-1]) // trim trailing \n
info.Module = moduleName
s.connInfo.Store(index, &info)
info.SetModule(moduleName)

targets, ok := s.getTargetsForModule(moduleName)
if !ok {
Expand All @@ -551,8 +570,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err
target := targets[chooseTargetByClientIP(net.ParseIP(ip), len(targets))]
upstreamAddr := target.Addr
useProxyProtocol := target.UseProxyProtocol
info.UpstreamAddr = upstreamAddr
s.connInfo.Store(index, &info)
info.SetUpstreamAddr(upstreamAddr)

upstreamQueue, ok := s.getQueueForUpstream(target.Upstream)
if !ok {
Expand Down Expand Up @@ -804,6 +822,16 @@ func (s *Server) runHTTPServer() error {
_, _ = fmt.Fprintf(w, "rsync-proxy,host=%s count=%d %d\n", hostname, count, timestamp)
})

	// GET /metrics: Prometheus scrape endpoint; other methods get 405.
	mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		// Content type mandated by the Prometheus text exposition format.
		w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
		s.writePrometheusMetrics(w, time.Now())
	})

return http.Serve(s.HTTPListener, &mux)
}

Expand Down
145 changes: 144 additions & 1 deletion pkg/server/server_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package server

import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -57,6 +59,10 @@
return srv
}

func testHTTPClient() *http.Client {
return &http.Client{Timeout: time.Second}
}

func doClientHandshake(conn *rsync.Conn, version []byte, module string) (svrVersion string, err error) {
_, err = conn.Write(version)
if err != nil {
Expand Down Expand Up @@ -109,7 +115,7 @@
defer fakeRsync.Close()

srv.modules = map[string][]Target{
"fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}},

Check failure on line 118 in pkg/server/server_test.go

View workflow job for this annotation

GitHub Actions / Build

string `fake` has 7 occurrences, make it a constant (goconst)
}
srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)}

Expand Down Expand Up @@ -220,7 +226,7 @@

rawConn, err := tls.Dial("tcp", srv.TLSListenAddr, &tls.Config{
RootCAs: pool,
ServerName: "localhost",

Check failure on line 229 in pkg/server/server_test.go

View workflow job for this annotation

GitHub Actions / Build

string `localhost` has 3 occurrences, make it a constant (goconst)
})
r.NoError(err)
conn := rsync.NewConn(rawConn)
Expand Down Expand Up @@ -334,12 +340,149 @@

require.Eventually(t, func() bool {
infos := srv.ListConnectionInfo()
return len(infos) == 1 && infos[0].UpstreamAddr == upstreamAddr
if len(infos) != 1 {
return false
}
_, _, infoUpstreamAddr, _, _, _ := infos[0].snapshot()
return infoUpstreamAddr == upstreamAddr
}, time.Second, 10*time.Millisecond)

wg.Done()
}

// TestMetricsEndpointNoConnections scrapes a server with no active
// connections and checks the status code, content type, and the
// zero-valued active-connections gauge.
func TestMetricsEndpointNoConnections(t *testing.T) {
	srv := startServer(t)
	defer srv.Close()

	metricsURL := "http://" + srv.HTTPListener.Addr().String() + "/metrics"
	resp, err := testHTTPClient().Get(metricsURL)
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	text := string(body)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "text/plain; version=0.0.4; charset=utf-8", resp.Header.Get("Content-Type"))
	for _, want := range []string{
		"# HELP rsync_proxy_active_connections Current active rsync proxy connections.",
		"# TYPE rsync_proxy_active_connections gauge",
		"rsync_proxy_active_connections 0\n",
	} {
		assert.Contains(t, text, want)
	}
}

func TestMetricsEndpointRejectsNonGET(t *testing.T) {
srv := startServer(t)
defer srv.Close()

resp, err := testHTTPClient().Post("http://"+srv.HTTPListener.Addr().String()+"/metrics", "text/plain", nil)
require.NoError(t, err)
defer resp.Body.Close()

Comment on lines +352 to +378
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e617a49. The HTTP tests now use a test client with a timeout instead of package-level http.Get/http.Post.

assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
}

// TestMetricsIncludesActiveConnections exercises a full scrape while one
// proxied connection is live: it dials through the proxy to a fake rsync
// upstream, waits until the server has recorded the upstream address, then
// asserts every per-connection metric family appears in the output.
func TestMetricsIncludesActiveConnections(t *testing.T) {
	srv := startServer(t)
	defer srv.Close()

	// The fake upstream blocks after the handshake until wg.Done() at the
	// end of the test, keeping the connection active during the scrape.
	var wg sync.WaitGroup
	wg.Add(1)
	fakeRsync := rsync.NewServer(func(conn *rsync.Conn) {
		defer conn.Close()
		_, _, err := doServerHandshake(conn, RsyncdServerVersion)
		require.NoError(t, err)
		wg.Wait()
	})
	fakeRsync.Start()
	defer fakeRsync.Close()

	upstreamAddr := fakeRsync.Listener.Addr().String()
	srv.modules = map[string][]Target{
		"fake": {{Upstream: "u1", Addr: upstreamAddr}},
	}
	srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)}

	rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String())
	require.NoError(t, err)
	conn := rsync.NewConn(rawConn)
	defer conn.Close()

	_, err = doClientHandshake(conn, RsyncdServerVersion, "fake")
	require.NoError(t, err)

	// Wait until the relay has resolved and stored the upstream address.
	require.Eventually(t, func() bool {
		infos := srv.ListConnectionInfo()
		if len(infos) != 1 {
			return false
		}
		_, _, infoUpstreamAddr, _, _, _ := infos[0].snapshot()
		return infoUpstreamAddr == upstreamAddr
	}, time.Second, 10*time.Millisecond)

	resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics")
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	text := string(body)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Contains(t, text, "rsync_proxy_active_connections 1\n")
	assert.Contains(t, text, fmt.Sprintf("rsync_proxy_active_connections_by_module{module=\"fake\",upstream=%q} 1\n", upstreamAddr))
	assert.Contains(t, text, "rsync_proxy_connection_sent_bytes{index=\"")
	assert.Contains(t, text, "module=\"fake\"")
	assert.Contains(t, text, fmt.Sprintf("upstream=%q", upstreamAddr))
	assert.Contains(t, text, "rsync_proxy_connection_received_bytes{index=\"")
	assert.Contains(t, text, "rsync_proxy_connection_connected_timestamp_seconds{index=\"")
	assert.Contains(t, text, "rsync_proxy_connection_duration_seconds{index=\"")
	// Client addresses must not leak into the metric output.
	assert.NotContains(t, text, rawConn.LocalAddr().String())

	// Unblock the fake upstream so the connection can tear down.
	wg.Done()
}

// TestPrometheusConnectionGroupingUsesStructuredKey verifies that the
// per-module aggregation keys on a (module, upstream) struct rather than a
// concatenated string: connections ("a\xffb", "c") and ("a", "b\xffc") must
// stay two distinct series rather than collapsing into one with count 2.
func TestPrometheusConnectionGroupingUsesStructuredKey(t *testing.T) {
	srv := New()

	first := &ConnInfo{Index: 1, ConnectedAt: time.Unix(100, 0)}
	first.Module = "a\xffb"
	first.UpstreamAddr = "c"
	srv.connInfo.Store(first.Index, first)

	second := &ConnInfo{Index: 2, ConnectedAt: time.Unix(100, 0)}
	second.Module = "a"
	second.UpstreamAddr = "b\xffc"
	srv.connInfo.Store(second.Index, second)

	var buf bytes.Buffer
	srv.writePrometheusMetrics(&buf, time.Unix(101, 0))
	text := buf.String()

	assert.Contains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\xffb\",upstream=\"c\"} 1\n")
	assert.Contains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\xffc\"} 1\n")
	assert.NotContains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\xffc\"} 2\n")
}

// TestPrometheusDurationIncludesFractionalSeconds pins the duration metric's
// millisecond precision (%.3f) by scraping at a fixed "now" 250ms after the
// connection was established.
func TestPrometheusDurationIncludesFractionalSeconds(t *testing.T) {
	srv := New()
	conn := &ConnInfo{Index: 1, ConnectedAt: time.Unix(100, 0)}
	conn.Module = "fake"
	conn.UpstreamAddr = "127.0.0.1:873"
	srv.connInfo.Store(conn.Index, conn)

	var buf bytes.Buffer
	srv.writePrometheusMetrics(&buf, time.Unix(100, 250_000_000))

	assert.Contains(t, buf.String(), "rsync_proxy_connection_duration_seconds{index=\"1\",module=\"fake\",upstream=\"127.0.0.1:873\"} 0.250\n")
}

// TestPrometheusLabelValueEscaping pins the escaping rules for Prometheus
// label values (backslash, quote, newline) and the empty-value placeholder.
func TestPrometheusLabelValueEscaping(t *testing.T) {
	cases := []struct{ in, want string }{
		{"plain", `plain`},
		{`quote"value`, `quote\"value`},
		{`slash\value`, `slash\\value`},
		{"line\nbreak", `line\nbreak`},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, prometheusEscapeLabelValue(tc.in))
	}
	assert.Equal(t, `unknown`, prometheusLabelValueOrUnknown(""))
}

func TestPerUpstreamQueueIsolation(t *testing.T) {
srv := startServer(t)
defer srv.Close()
Expand Down Expand Up @@ -517,7 +660,7 @@
dir := t.TempDir()
configPath := filepath.Join(dir, "config.toml")

firstUpstream := rsync.NewModuleListServer([]string{"foo"})

Check failure on line 663 in pkg/server/server_test.go

View workflow job for this annotation

GitHub Actions / Build

string `foo` has 17 occurrences, make it a constant (goconst)
firstUpstream.Start()
defer firstUpstream.Close()

Expand Down Expand Up @@ -564,7 +707,7 @@
srv := New()
srv.reloadLock.Lock()
srv.upstreams = []upstreamConfig{
{Name: "u1", Modules: []string{"foo", "bar"}},

Check failure on line 710 in pkg/server/server_test.go

View workflow job for this annotation

GitHub Actions / Build

string `bar` has 16 occurrences, make it a constant (goconst)
{Name: "u2", Modules: []string{"baz"}},
}
srv.reloadLock.Unlock()
Expand Down
Loading