Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ linters:
rules:
- linters:
- errcheck
- goconst
- unparam
path: _test.go
paths:
Expand Down
106 changes: 106 additions & 0 deletions pkg/server/metrics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package server

import (
"fmt"
"io"
"sort"
"strings"
"time"
)

// prometheusLabelValueEscaper rewrites the three characters that must be
// escaped inside a quoted Prometheus label value: backslash, line feed and
// double quote. It is built once and shared; a strings.Replacer is safe for
// concurrent use and performs all substitutions in a single pass over the
// input without rescanning the text it has already emitted.
var prometheusLabelValueEscaper = strings.NewReplacer(
	`\`, `\\`,
	"\n", `\n`,
	`"`, `\"`,
)

// prometheusEscapeLabelValue returns s with backslash, line feed and double
// quote replaced by their two-character escape sequences (\\, \n and \") so the
// result can be placed between double quotes as a label value. All other bytes
// are passed through unchanged.
func prometheusEscapeLabelValue(s string) string {
	return prometheusLabelValueEscaper.Replace(s)
}

// prometheusLabelValueOrUnknown substitutes the placeholder "unknown" for an
// empty label value and returns every non-empty value untouched.
func prometheusLabelValueOrUnknown(s string) string {
	if s != "" {
		return s
	}
	return "unknown"
}

// prometheusLabels renders the label list shared by the per-connection
// metrics: index, module and upstream, in that order, without the surrounding
// braces. Empty module or upstream values are reported as "unknown", and both
// are escaped before being quoted.
func prometheusLabels(index uint32, module, upstream string) string {
	moduleValue := prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(module))
	upstreamValue := prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(upstream))
	return fmt.Sprintf(`index="%d",module="%s",upstream="%s"`, index, moduleValue, upstreamValue)
}

// prometheusConnectionGroup identifies one (module, upstream) pair. It is a
// comparable struct so it can be used directly as a map key when counting
// connections per pair.
type prometheusConnectionGroup struct {
	module, upstream string
}

// writePrometheusMetrics writes the server's connection metrics to w in the
// Prometheus text format: the total number of active connections, the number
// of connections per (module, upstream) pair, and for every listed connection
// its sent bytes, received bytes, connect timestamp and current duration.
//
// now is the reference instant for the duration gauge; a connection whose
// ConnectedAt lies after now is reported with a duration of 0.
//
// Every connection is snapshotted exactly once and that single snapshot feeds
// all metric families. A connection's module and upstream are filled in while
// it is being relayed, so re-reading them for each family could emit the same
// connection under different label sets within one response.
//
// Write errors are intentionally ignored; the output is best-effort.
func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) {
	// connectionSample is the rendered form of one connection snapshot.
	type connectionSample struct {
		labels        string
		sentBytes     int64
		receivedBytes int64
		connectedUnix int64
		duration      float64
	}

	connections := s.ListConnectionInfo()

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge")
	_, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount())

	samples := make([]connectionSample, 0, len(connections))
	connectionCounts := make(map[prometheusConnectionGroup]int)
	for _, conn := range connections {
		snapshot := conn.snapshot()

		connectionCounts[prometheusConnectionGroup{
			module:   prometheusLabelValueOrUnknown(snapshot.Module),
			upstream: prometheusLabelValueOrUnknown(snapshot.UpstreamAddr),
		}]++

		// Clamp so a connect time ahead of now never yields a negative gauge.
		duration := now.Sub(snapshot.ConnectedAt).Seconds()
		if duration < 0 {
			duration = 0
		}

		samples = append(samples, connectionSample{
			labels:        prometheusLabels(snapshot.Index, snapshot.Module, snapshot.UpstreamAddr),
			sentBytes:     snapshot.SentBytes,
			receivedBytes: snapshot.ReceivedBytes,
			connectedUnix: snapshot.ConnectedAt.Unix(),
			duration:      duration,
		})
	}

	// Map iteration order is random; sort the groups so the output is stable.
	keys := make([]prometheusConnectionGroup, 0, len(connectionCounts))
	for key := range connectionCounts {
		keys = append(keys, key)
	}
	sort.Slice(keys, func(i, j int) bool {
		if keys[i].module != keys[j].module {
			return keys[i].module < keys[j].module
		}
		return keys[i].upstream < keys[j].upstream
	})

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections_by_module Current active rsync proxy connections by module and upstream.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections_by_module gauge")
	for _, key := range keys {
		module := prometheusEscapeLabelValue(key.module)
		upstream := prometheusEscapeLabelValue(key.upstream)
		_, _ = fmt.Fprintf(w, "rsync_proxy_active_connections_by_module{module=\"%s\",upstream=\"%s\"} %d\n", module, upstream, connectionCounts[key])
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_sent_bytes Bytes sent to clients for active connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_sent_bytes gauge")
	for _, sample := range samples {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_sent_bytes{%s} %d\n", sample.labels, sample.sentBytes)
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_received_bytes Bytes received from clients for active connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_received_bytes gauge")
	for _, sample := range samples {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_received_bytes{%s} %d\n", sample.labels, sample.receivedBytes)
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_connected_timestamp_seconds Unix timestamp when active connections were established.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_connected_timestamp_seconds gauge")
	for _, sample := range samples {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_connected_timestamp_seconds{%s} %d\n", sample.labels, sample.connectedUnix)
	}

	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_duration_seconds Current duration of active connections.")
	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_duration_seconds gauge")
	for _, sample := range samples {
		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.3f\n", sample.labels, sample.duration)
	}
}
62 changes: 45 additions & 17 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var (
const lineFeed = '\n'

type ConnInfo struct {
mu sync.RWMutex
Index uint32
LocalAddr string
RemoteAddr string
Expand All @@ -60,18 +61,33 @@ type ConnInfo struct {
ReceivedBytes atomic.Int64
}

func (c *ConnInfo) MarshalJSON() ([]byte, error) {
// Handle atomic value (cannot marshal directly)
return json.Marshal(struct {
Index uint32 `json:"index"`
LocalAddr string `json:"local"`
RemoteAddr string `json:"remote"`
ConnectedAt time.Time `json:"connected"`
Module string `json:"module"`
UpstreamAddr string `json:"upstream"`
SentBytes int64 `json:"sentBytes"`
ReceivedBytes int64 `json:"receivedBytes"`
}{
// connInfoSnapshot is a plain-value copy of a ConnInfo as returned by
// ConnInfo.snapshot. It contains no lock or atomic fields, so it can be passed
// to json.Marshal directly; the struct tags define the JSON key of each field.
type connInfoSnapshot struct {
Index uint32 `json:"index"`
LocalAddr string `json:"local"`
RemoteAddr string `json:"remote"`
ConnectedAt time.Time `json:"connected"`
Module string `json:"module"`
UpstreamAddr string `json:"upstream"`
// SentBytes and ReceivedBytes are the values loaded from the ConnInfo
// atomic counters at the moment the snapshot was taken.
SentBytes int64 `json:"sentBytes"`
ReceivedBytes int64 `json:"receivedBytes"`
}

// SetModule stores module as the connection's module name while holding the
// write lock, so readers that take the read lock never see a partial update.
func (c *ConnInfo) SetModule(module string) {
	c.mu.Lock()
	c.Module = module
	c.mu.Unlock()
}

// SetUpstreamAddr stores upstreamAddr as the connection's upstream address
// while holding the write lock, so readers that take the read lock never see a
// partial update.
func (c *ConnInfo) SetUpstreamAddr(upstreamAddr string) {
	c.mu.Lock()
	c.UpstreamAddr = upstreamAddr
	c.mu.Unlock()
}

func (c *ConnInfo) snapshot() connInfoSnapshot {
c.mu.RLock()
defer c.mu.RUnlock()
return connInfoSnapshot{
Index: c.Index,
LocalAddr: c.LocalAddr,
RemoteAddr: c.RemoteAddr,
Expand All @@ -80,7 +96,11 @@ func (c *ConnInfo) MarshalJSON() ([]byte, error) {
UpstreamAddr: c.UpstreamAddr,
SentBytes: c.SentBytes.Load(),
ReceivedBytes: c.ReceivedBytes.Load(),
})
}
}

// MarshalJSON implements json.Marshaler. It encodes a snapshot of the
// connection rather than the ConnInfo itself.
func (c *ConnInfo) MarshalJSON() ([]byte, error) {
	current := c.snapshot()
	return json.Marshal(current)
}

type Target struct {
Expand Down Expand Up @@ -537,8 +557,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err
}

moduleName := string(buf[:n-1]) // trim trailing \n
info.Module = moduleName
s.connInfo.Store(index, &info)
info.SetModule(moduleName)

targets, ok := s.getTargetsForModule(moduleName)
if !ok {
Expand All @@ -551,8 +570,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err
target := targets[chooseTargetByClientIP(net.ParseIP(ip), len(targets))]
upstreamAddr := target.Addr
useProxyProtocol := target.UseProxyProtocol
info.UpstreamAddr = upstreamAddr
s.connInfo.Store(index, &info)
info.SetUpstreamAddr(upstreamAddr)

upstreamQueue, ok := s.getQueueForUpstream(target.Upstream)
if !ok {
Expand Down Expand Up @@ -804,6 +822,16 @@ func (s *Server) runHTTPServer() error {
_, _ = fmt.Fprintf(w, "rsync-proxy,host=%s count=%d %d\n", hostname, count, timestamp)
})

// /metrics serves the connection metrics in the Prometheus text format.
// Only GET is accepted.
mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		// A 405 response must tell the client which methods are supported.
		w.Header().Set("Allow", http.MethodGet)
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
	s.writePrometheusMetrics(w, time.Now())
})

return http.Serve(s.HTTPListener, &mux)
}

Expand Down
143 changes: 142 additions & 1 deletion pkg/server/server_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package server

import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -57,6 +59,10 @@ func startServer(t *testing.T) *Server {
return srv
}

// testHTTPClient returns a fresh HTTP client whose requests time out after one
// second, so a test never blocks indefinitely on an unresponsive server.
func testHTTPClient() *http.Client {
	client := new(http.Client)
	client.Timeout = time.Second
	return client
}

func doClientHandshake(conn *rsync.Conn, version []byte, module string) (svrVersion string, err error) {
_, err = conn.Write(version)
if err != nil {
Expand Down Expand Up @@ -334,12 +340,147 @@ func TestStatusIncludesSelectedUpstream(t *testing.T) {

require.Eventually(t, func() bool {
infos := srv.ListConnectionInfo()
return len(infos) == 1 && infos[0].UpstreamAddr == upstreamAddr
if len(infos) != 1 {
return false
}
return infos[0].snapshot().UpstreamAddr == upstreamAddr
}, time.Second, 10*time.Millisecond)

wg.Done()
}

// TestMetricsEndpointNoConnections requests /metrics from a freshly started
// server and checks the status code, the content type, and that the active
// connection gauge is present with its HELP/TYPE lines and a value of 0.
func TestMetricsEndpointNoConnections(t *testing.T) {
	srv := startServer(t)
	defer srv.Close()

	metricsURL := fmt.Sprintf("http://%s/metrics", srv.HTTPListener.Addr().String())
	resp, err := testHTTPClient().Get(metricsURL)
	require.NoError(t, err)
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	text := string(raw)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "text/plain; version=0.0.4; charset=utf-8", resp.Header.Get("Content-Type"))
	for _, want := range []string{
		"# HELP rsync_proxy_active_connections Current active rsync proxy connections.",
		"# TYPE rsync_proxy_active_connections gauge",
		"rsync_proxy_active_connections 0\n",
	} {
		assert.Contains(t, text, want)
	}
}

func TestMetricsEndpointRejectsNonGET(t *testing.T) {
srv := startServer(t)
defer srv.Close()

resp, err := testHTTPClient().Post("http://"+srv.HTTPListener.Addr().String()+"/metrics", "text/plain", nil)
require.NoError(t, err)
defer resp.Body.Close()

Comment on lines +352 to +378
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e617a49. The HTTP tests now use a test client with a timeout instead of package-level http.Get/http.Post.

assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
}

// TestMetricsIncludesActiveConnections opens one proxied connection to a fake
// upstream, waits until the server has recorded the selected upstream address,
// and then checks that /metrics reports the connection in the total, in the
// per-module count and in every per-connection metric, without exposing the
// client's address.
func TestMetricsIncludesActiveConnections(t *testing.T) {
	srv := startServer(t)
	defer srv.Close()

	// wg keeps the fake upstream's handler (and therefore the proxied
	// connection) alive until the test is done inspecting the metrics.
	var wg sync.WaitGroup
	wg.Add(1)
	fakeRsync := rsync.NewServer(func(conn *rsync.Conn) {
		defer conn.Close()
		// Use assert, not require: this handler does not run on the test
		// goroutine, and FailNow may only be called from the test goroutine.
		_, _, err := doServerHandshake(conn, RsyncdServerVersion)
		if !assert.NoError(t, err) {
			return
		}
		wg.Wait()
	})
	fakeRsync.Start()
	defer fakeRsync.Close()
	// Deferred after fakeRsync.Close so it runs before it, and deferred rather
	// than called at the end so the handler is released even when a require
	// below aborts the test early.
	defer wg.Done()

	upstreamAddr := fakeRsync.Listener.Addr().String()
	srv.modules = map[string][]Target{
		"fake": {{Upstream: "u1", Addr: upstreamAddr}},
	}
	srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)}

	rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String())
	require.NoError(t, err)
	conn := rsync.NewConn(rawConn)
	defer conn.Close()

	_, err = doClientHandshake(conn, RsyncdServerVersion, "fake")
	require.NoError(t, err)

	require.Eventually(t, func() bool {
		infos := srv.ListConnectionInfo()
		if len(infos) != 1 {
			return false
		}
		return infos[0].snapshot().UpstreamAddr == upstreamAddr
	}, time.Second, 10*time.Millisecond)

	resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics")
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	text := string(body)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Contains(t, text, "rsync_proxy_active_connections 1\n")
	assert.Contains(t, text, fmt.Sprintf("rsync_proxy_active_connections_by_module{module=\"fake\",upstream=%q} 1\n", upstreamAddr))
	assert.Contains(t, text, "rsync_proxy_connection_sent_bytes{index=\"")
	assert.Contains(t, text, "module=\"fake\"")
	assert.Contains(t, text, fmt.Sprintf("upstream=%q", upstreamAddr))
	assert.Contains(t, text, "rsync_proxy_connection_received_bytes{index=\"")
	assert.Contains(t, text, "rsync_proxy_connection_connected_timestamp_seconds{index=\"")
	assert.Contains(t, text, "rsync_proxy_connection_duration_seconds{index=\"")
	// The client's own address must not appear anywhere in the metrics output.
	assert.NotContains(t, text, rawConn.LocalAddr().String())
}

// TestPrometheusConnectionGroupingUsesStructuredKey registers two connections
// whose (module, upstream) pairs differ but would be identical if the two
// values were joined with a "\xff" separator, and checks that each pair is
// counted on its own line with a count of 1.
func TestPrometheusConnectionGroupingUsesStructuredKey(t *testing.T) {
	srv := New()

	connectedAt := time.Unix(100, 0)
	for _, info := range []*ConnInfo{
		{Index: 1, ConnectedAt: connectedAt, Module: "a\xffb", UpstreamAddr: "c"},
		{Index: 2, ConnectedAt: connectedAt, Module: "a", UpstreamAddr: "b\xffc"},
	} {
		srv.connInfo.Store(info.Index, info)
	}

	var out bytes.Buffer
	srv.writePrometheusMetrics(&out, time.Unix(101, 0))
	text := out.String()

	for _, want := range []string{
		"rsync_proxy_active_connections_by_module{module=\"a\xffb\",upstream=\"c\"} 1\n",
		"rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\xffc\"} 1\n",
	} {
		assert.Contains(t, text, want)
	}
	assert.NotContains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\xffc\"} 2\n")
}

// TestPrometheusDurationIncludesFractionalSeconds renders the metrics for a
// connection established 250 ms before the reference time and checks that the
// duration gauge is written with millisecond precision (0.250).
func TestPrometheusDurationIncludesFractionalSeconds(t *testing.T) {
	connectedAt := time.Unix(100, 0)

	srv := New()
	info := &ConnInfo{
		Index:        1,
		ConnectedAt:  connectedAt,
		Module:       "fake",
		UpstreamAddr: "127.0.0.1:873",
	}
	srv.connInfo.Store(info.Index, info)

	var out bytes.Buffer
	srv.writePrometheusMetrics(&out, connectedAt.Add(250*time.Millisecond))

	want := "rsync_proxy_connection_duration_seconds{index=\"1\",module=\"fake\",upstream=\"127.0.0.1:873\"} 0.250\n"
	assert.Contains(t, out.String(), want)
}

// TestPrometheusLabelValueEscaping checks that label values pass through
// unchanged when they need no escaping, that double quote, backslash and line
// feed are escaped, and that an empty value is replaced by "unknown".
func TestPrometheusLabelValueEscaping(t *testing.T) {
	for _, tc := range []struct{ input, want string }{
		{"plain", `plain`},
		{`quote"value`, `quote\"value`},
		{`slash\value`, `slash\\value`},
		{"line\nbreak", `line\nbreak`},
	} {
		assert.Equal(t, tc.want, prometheusEscapeLabelValue(tc.input))
	}

	assert.Equal(t, `unknown`, prometheusLabelValueOrUnknown(""))
}

func TestPerUpstreamQueueIsolation(t *testing.T) {
srv := startServer(t)
defer srv.Close()
Expand Down
Loading