Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,26 @@ func forwardToPrometheus(metrics []byte, forwardURL string) error {
return nil
}

// maxTelemetryBodySize is the maximum accepted size for a telemetry POST body.
const maxTelemetryBodySize = 10 * 1024 * 1024 // 10 MB

// isValidClusterID accepts only alphanumeric characters, hyphens, underscores,
// and dots (standard Kubernetes naming). This prevents Prometheus text-format
// injection via the X-Cluster-ID header (e.g. a value containing '"' or '\n'
// would corrupt the label output forwarded to VictoriaMetrics).
func isValidClusterID(s string) bool {
if s == "" || len(s) > 253 {
return false
}
for _, c := range s {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') || c == '-' || c == '_' || c == '.') {
return false
}
}
return true
}
Comment on lines +113 to +124
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For validating string formats like this, using a compiled regular expression is a common and idiomatic approach in Go. It's more declarative and can be easier to read and maintain than manual character-by-character iteration.

Consider replacing this function with a regex-based implementation for better readability and maintainability:

// At package level, to compile once:
// var isValidClusterIDRegexp = regexp.MustCompile(`^[a-zA-Z0-9._-]{1,253}$`)

func isValidClusterID(s string) bool {
	// Assumes isValidClusterIDRegexp is defined at package level
	// and 'regexp' package is imported.
	return isValidClusterIDRegexp.MatchString(s)
}

This would make the validation logic more concise and align with common Go practices for this type of task. You would need to add an import for the regexp package.


func handleTelemetry(w http.ResponseWriter, r *http.Request, forwardURL string) {
startTime := time.Now()

Expand All @@ -113,12 +133,13 @@ func handleTelemetry(w http.ResponseWriter, r *http.Request, forwardURL string)
}

clusterID := r.Header.Get("X-Cluster-ID")
if clusterID == "" {
log.Printf("Request rejected: missing X-Cluster-ID header")
http.Error(w, "X-Cluster-ID header is required", http.StatusBadRequest)
if !isValidClusterID(clusterID) {
log.Printf("Request rejected: invalid or missing X-Cluster-ID header")
http.Error(w, "X-Cluster-ID header is required and must contain only alphanumeric characters, hyphens, underscores, or dots", http.StatusBadRequest)
return
}

r.Body = http.MaxBytesReader(w, r.Body, maxTelemetryBodySize)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using http.MaxBytesReader is correct for limiting the request body size. However, the error handling for io.ReadAll on the subsequent lines is now incomplete. When the size limit is exceeded, io.ReadAll returns a *http.MaxBytesError. The http.Server automatically sends a 413 response in this case. The current code at lines 144-147 will then attempt to write another error response, which will cause a superfluous response.WriteHeader error log on the server.

You should specifically check for *http.MaxBytesError and return without writing another response.

body, err := io.ReadAll(r.Body)
if err != nil {
    var maxBytesErr *http.MaxBytesError
    if errors.As(err, &maxBytesErr) {
        // http.Server automatically sends a 413 response.
        // We don't need to write another error, just log and return.
        log.Printf("Request rejected: body exceeded %d bytes limit", maxBytesErr.Limit)
        return
    }
    log.Printf("Error reading request body: %v", err)
    http.Error(w, fmt.Sprintf("Error reading request: %v", err), http.StatusBadRequest)
    return
}

You will also need to import the errors package.

body, err := io.ReadAll(r.Body)
if err != nil {
log.Printf("Error reading request body: %v", err)
Expand Down