From fe768d2cd69ec00f5f04636e1382bdc92e86372b Mon Sep 17 00:00:00 2001 From: Devin Wong Date: Tue, 30 Jun 2026 15:35:16 -0700 Subject: [PATCH] feat(anc): add check-hotfix subcommand to read hotfix pointer from LPS Add a fail-open 'check-hotfix' CLI subcommand that reads the base->hotfix pointer map from the live-patching-service (LPS) over the IMDS-attested SNI path that is reachable pre-kubelet, and stages the resolved {hotfixes:{...}} pointer to the path download-hotfix already reads. download-hotfix keeps its unchanged patch-only, strictly-higher gating; check-hotfix only fetches and writes the pointer. - Raw net/http HTTPS GET (no client-go). TLS ServerName pinned to the LPS SNI host while the TCP dial is forced to the apiserver FQDN (curl --resolve trick); Authorization is the IMDS attested-data signature; the server cert is verified against the cluster CA from the provision-config. - FQDN + cluster CA come from the AKSNodeConfig ANC already parses (the only credential source present pre-provisioning); caSource is logged. - Shares the hotfixConfig parser/data contract with download-hotfix. - Always exits 0; emits CheckHotfix telemetry (lpsRead, noHotfixForBase, noHotfixAvailable, customDataFallback, failed). - A reachable LPS with no hotfix published for this node (HTTP 401, 403, 404) is a benign no-op (noHotfixAvailable): no overlay is staged and it is never classified as a failure. Only transport/5xx failures fall back. - PoC cold-start fallback reads a lenient top-level hotfixes object from the node config when the LPS read fails (TODO: typed contract field). - Injectable App fields (checkHotfixFetcher, fetchAttestedToken, nodeConfigPath) for network-free unit tests. - The LPS route + response schema are a planned-maintenance deliverable that is not finalized; lpsHotfixPath is a clearly-marked placeholder with a TODO. The IMDS/LPS client helpers mirror the connectivity prototype and should be de-duplicated into a shared LPS client when that lands. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- aks-node-controller/app.go | 25 + aks-node-controller/checkhotfix.go | 616 ++++++++++++++++++ aks-node-controller/checkhotfix_test.go | 607 +++++++++++++++++ aks-node-controller/common/httpclient.go | 83 +++ aks-node-controller/common/httpclient_test.go | 126 ++++ 5 files changed, 1457 insertions(+) create mode 100644 aks-node-controller/checkhotfix.go create mode 100644 aks-node-controller/checkhotfix_test.go create mode 100644 aks-node-controller/common/httpclient.go create mode 100644 aks-node-controller/common/httpclient_test.go diff --git a/aks-node-controller/app.go b/aks-node-controller/app.go index 40d42acb73a..1c4216bc0c1 100644 --- a/aks-node-controller/app.go +++ b/aks-node-controller/app.go @@ -43,11 +43,23 @@ type App struct { eventLogger *helpers.EventLogger // hotfixVersionPath overrides the default hotfix version file location for testing. + // It is also the path check-hotfix writes the resolved pointer to. hotfixVersionPath string // aptSourcesDir overrides the default APT sources directory for testing. aptSourcesDir string // nodeCustomDataPath overrides the default nodecustomdata path for testing. nodeCustomDataPath string + // nodeConfigPath overrides the default AKSNodeConfig path for testing. It is the + // source for check-hotfix's LPS endpoint (apiserver FQDN + cluster CA) and the + // cold-start fallback pointer. + nodeConfigPath string + // checkHotfixFetcher overrides the real LPS hotfix-pointer GET for testing, letting + // unit tests inject a canned pointer body or errors without real networking. + checkHotfixFetcher func(ctx context.Context) ([]byte, error) + // fetchAttestedToken overrides retrieval of the IMDS attested-data token used as the + // Authorization header for the check-hotfix LPS fetch. When nil, the real IMDS endpoint + // is queried. 
+ fetchAttestedToken func(ctx context.Context) (string, error) } // provision.json values are emitted as strings by the shell jq invocation. @@ -137,6 +149,19 @@ func (a *App) Run(ctx context.Context, args []string) int { return a.runDownloadHotfixCommand(ctx) }, }, + { + Name: "check-hotfix", + Usage: "Read the hotfix pointer from the live-patching-service and stage it (fail-open)", + Action: func(ctx context.Context, cmd *cli.Command) error { + if extra := cmd.Args().Slice(); len(extra) > 0 { + // Fail-open: check-hotfix must always exit 0 so provisioning is never + // blocked, so unexpected args are logged and ignored rather than turned + // into a non-zero exit code via errToExitCode. + slog.Warn("ignoring unexpected check-hotfix arguments", "args", strings.Join(extra, " ")) + } + return a.runCheckHotfixCommand(ctx) + }, + }, }, } diff --git a/aks-node-controller/checkhotfix.go b/aks-node-controller/checkhotfix.go new file mode 100644 index 00000000000..5e02cd39daf --- /dev/null +++ b/aks-node-controller/checkhotfix.go @@ -0,0 +1,616 @@ +package main + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/Azure/agentbaker/aks-node-controller/common" + "github.com/Azure/agentbaker/aks-node-controller/helpers" + "github.com/Azure/agentbaker/aks-node-controller/pkg/nodeconfigutils" +) + +// check-hotfix reads the hotfix pointer from the live-patching-service (LPS) over the +// IMDS-attested SNI path that is reachable pre-kubelet, then writes it to the same path +// download-hotfix already reads. download-hotfix then re-resolves the pointer against the +// node's baked ANC version and keeps its unchanged patch-only, strictly-higher gating. +// check-hotfix only fetches and stages the pointer; it never installs anything and never +// blocks provisioning (fail-open). 
// Endpoint identity for the check-hotfix fetch: where the LPS request is addressed (SNI host,
// port, route) and where the IMDS attested document used for Authorization is read from.
const (
	// lpsSNIHost is the live-patching-service SNI/Host that the kube-api-proxy envoy on the
	// apiserver front routes to the LPS backend. The TLS handshake pins this as ServerName
	// while the TCP connection is forced to the apiserver FQDN (the curl --resolve trick),
	// giving the faithful end-to-end path node -> SNI(lpsSNIHost) -> envoy -> LPS.
	lpsSNIHost = "aks-security-patch.data.mcr.microsoft.com"

	// lpsAPIServerPort is the HTTPS port the apiserver front (and thus the LPS path) listens on.
	lpsAPIServerPort = "443"

	// lpsHotfixPath is the LPS route serving the base->hotfix pointer map.
	//
	// TODO(provisioning-hotfix): this route and its response schema are a planned-maintenance
	// LPS-endpoint deliverable that is NOT finalized yet. The connectivity prototype only
	// proved reachability of the LPS read path (/v1/packages). Replace this placeholder with
	// the real route once the LPS endpoint contract is published. The expected response body
	// is the {"hotfixes":{"<base>":"<hotfix>"}} JSON object, for example
	// {"hotfixes":{"202604.01":"202604.01.1"}}, that parses directly into the shared
	// hotfixConfig type (see parseHotfixConfig).
	// NOTE(review): the <base>/<hotfix> placeholders were empty strings in the original
	// comment and are reconstructed here from the unit-test fixtures; confirm the wording.
	lpsHotfixPath = "/v1/anc-hotfix"

	// imdsAttestedDocURL returns the IMDS attested-data document, whose signature is used as
	// the LPS Authorization token. IMDS is reachable pre-kubelet (the same primitive Secure
	// TLS Bootstrap uses), so this works before any kube credential exists.
	imdsAttestedDocURL = "http://169.254.169.254/metadata/attested/document?api-version=2025-04-07"
)

// Timeout tuning for the IMDS and LPS fetches. The generic transport/retry mechanics live in
// the common package (common.NewBaseTransport + common.RetryStringFetch); these constants are
// the domain-specific budgets this command layers on top. This file otherwise keeps only the
// domain logic: the LPS endpoint identity (SNI host, route), the TLS/CA build, the forced dial,
// the attested-token parsing, and the fail-open fetch/parse/stage workflow.
const (
	// LPS timeouts. The LPS GET is on the provisioning critical path, so every phase fails
	// fast; lpsFetchTimeout bounds the whole round-trip. A too-tight deadline is safe because
	// check-hotfix is fail-open and falls back to the cold-start pointer. lpsFetchTimeout is
	// the single knob to loosen if private-cluster fronts prove slower (a possibility raised
	// during review); the per-phase connect/handshake budgets can stay tight.
	//
	// Note that the three per-phase budgets add up to more than lpsFetchTimeout, so the
	// overall deadline is the binding limit when more than one phase is slow.
	lpsDialTimeout           = 2 * time.Second // TCP connect to the apiserver front
	lpsTLSHandshakeTimeout   = 2 * time.Second // envoy TLS negotiation
	lpsResponseHeaderTimeout = 2 * time.Second // server time-to-first-byte
	lpsFetchTimeout          = 3 * time.Second // overall (ctx + http.Client.Timeout)

	// IMDS timeouts. IMDS is a link-local endpoint (169.254.169.254) that is normally
	// near-instant, so these are tighter than the LPS knobs. imdsFetchTimeout bounds a single
	// attempt.
	imdsDialTimeout = 1 * time.Second
	// imdsTLSHandshakeTimeout is inert in practice: IMDS is a plain-HTTP endpoint
	// (http://169.254.169.254), so the transport never enters the TLS handshake path and this
	// timer is never armed. It is kept nonzero only so the shared transport builder stays
	// uniform across the LPS (HTTPS) and IMDS callers, and to remain defensive if IMDS is ever
	// pointed at an HTTPS endpoint.
	imdsTLSHandshakeTimeout   = 1 * time.Second
	imdsResponseHeaderTimeout = 1 * time.Second
	imdsFetchTimeout          = 2 * time.Second

	// imdsMaxAttempts is the total number of IMDS attempts (1 initial + retries). IMDS is
	// local and usually reliable, so a single quick retry is enough to smooth a one-off blip
	// without materially adding to provisioning latency.
	imdsMaxAttempts = 2
)

// checkHotfixOutcome is the telemetry taxonomy emitted under TaskName "CheckHotfix".
// Exactly one of the constants below is reported per check-hotfix run.
type checkHotfixOutcome string

const (
	// outcomeLPSRead: LPS pointer fetched + parsed OK and a hotfix entry matched this node's base.
	outcomeLPSRead checkHotfixOutcome = "lpsRead"
	// outcomeNoHotfixForBase: LPS read OK but no entry matched this node's YYYYMM.DD base.
	outcomeNoHotfixForBase checkHotfixOutcome = "noHotfixForBase"
	// outcomeNoHotfixAvailable: the LPS was reachable and the request was well-formed, but it
	// returned no hotfix for this node (HTTP 401/403/404). This simply means the LPS has
	// nothing published for this node yet (e.g. the planned-maintenance hotfix route is not
	// serving content for it). It is the expected steady state, so it is a benign no-op: no
	// overlay is staged and it is never treated as a failure.
	outcomeNoHotfixAvailable checkHotfixOutcome = "noHotfixAvailable"
	// outcomeCustomDataFallback: LPS read failed; the embedded customdata pointer was used.
	outcomeCustomDataFallback checkHotfixOutcome = "customDataFallback"
	// outcomeFailed: everything failed; nothing was staged. Provisioning still proceeds (exit 0).
	outcomeFailed checkHotfixOutcome = "failed"
)

// lpsUnavailableError marks a benign LPS response that means "no hotfix is available for this
// node yet" rather than a failure. It is returned for HTTP 401, 403, and 404, all of which mean
// the LPS is reachable but has nothing published for this node. check-hotfix must not classify
// it as a hard failure or retry it.
type lpsUnavailableError struct {
	// statusCode is the HTTP status the LPS answered with (401, 403 or 404).
	statusCode int
}

// Error implements the error interface and includes the HTTP status for diagnostics.
func (e *lpsUnavailableError) Error() string {
	return fmt.Sprintf("LPS has no hotfix available for this node (status %d)", e.statusCode)
}

// isLPSUnavailable reports whether err is a benign "nothing for this node yet" LPS response.
// It uses errors.As, so a wrapped *lpsUnavailableError is recognized as well.
func isLPSUnavailable(err error) bool {
	var u *lpsUnavailableError
	return errors.As(err, &u)
}

// lpsHTTPError is a non-2xx LPS response that is NOT the benign 401/403/404 set. It carries the
// status code so the caller can distinguish a reachable-but-erroring server from a transport
// failure.
// A 4xx here means the server was reached and rejected the request (e.g. 400/429), so
// the cold-start pointer (which may be stale) must NOT be staged; a 5xx means the server is
// broken/overloaded, for which cold-start fallback is appropriate.
type lpsHTTPError struct {
	// statusCode is the non-2xx HTTP status returned by the LPS.
	statusCode int
}

// Error implements the error interface and includes the HTTP status for diagnostics.
func (e *lpsHTTPError) Error() string {
	return fmt.Sprintf("LPS returned status %d", e.statusCode)
}

// shouldColdStartFallback reports whether a failed LPS fetch should fall back to the embedded
// cold-start pointer. Fallback is only appropriate when the LPS could not be reached or talked
// to: transport/pre-network errors (no HTTP status) and server errors (5xx). A reachable LPS
// that returns a non-benign 4xx (e.g. 400/429) is authoritative that the request was bad, not a
// reason to stage a possibly-stale cold-start pointer.
//
// Any lpsHTTPError with a status below 500 returns false, so a non-2xx status outside the 4xx
// range that reaches the caller (for example a 3xx) does not fall back either. Every error
// that is not an lpsHTTPError returns true; callers are expected to filter out the benign
// lpsUnavailableError first (handleFetchError does).
func shouldColdStartFallback(err error) bool {
	var httpErr *lpsHTTPError
	if errors.As(err, &httpErr) {
		return httpErr.statusCode >= 500
	}
	return true
}

// runCheckHotfixCommand is the cli Action for `check-hotfix`. It ALWAYS returns nil so
// provisioning is never blocked: any error (timeout, non-benign HTTP status, parse or write
// failure) is logged, emitted as telemetry, and swallowed. A benign 401/403/404 is not an
// error at all here; it is reported as the noHotfixAvailable outcome. Internal helpers return
// errors for testability only.
func (a *App) runCheckHotfixCommand(ctx context.Context) (err error) {
	slog.Info("aks-node-controller check-hotfix started")
	startTime := time.Now()

	// Fail-open hardening: a panic anywhere in the check-hotfix workflow must not crash the
	// process. The wrapper runs check-hotfix before the customdata (cold-start) route, so a
	// crash here could otherwise prevent that route from completing. Recover, emit failed
	// telemetry, and return nil so provisioning proceeds.
	defer func() {
		if r := recover(); r != nil {
			slog.Error("check-hotfix panicked (fail-open)", "panic", r)
			if a.eventLogger != nil {
				a.eventLogger.LogEvent("CheckHotfix",
					fmt.Sprintf("check-hotfix outcome=%s panic=%v", outcomeFailed, r),
					helpers.EventLevelError, startTime, time.Now())
			}
			// err is the named result, so clearing it here is what the caller receives
			// after a recovered panic.
			err = nil
		}
	}()

	outcome, err := a.checkHotfix(ctx)

	endTime := time.Now()
	level := helpersEventLevel(outcome)
	message := fmt.Sprintf("check-hotfix outcome=%s", outcome)
	if err != nil {
		message = fmt.Sprintf("%s error=%s", message, err.Error())
		slog.Warn("check-hotfix completed with error (fail-open)", "outcome", outcome, "error", err)
	} else {
		slog.Info("check-hotfix completed", "outcome", outcome)
	}
	// eventLogger may be nil (the guard below exists for that case); telemetry is then skipped.
	if a.eventLogger != nil {
		a.eventLogger.LogEvent("CheckHotfix", message, level, startTime, endTime)
	}

	// Fail-open: never propagate an error so the cli exit code stays 0.
	return nil
}

// checkHotfix performs the fetch/parse/stage workflow and reports a telemetry outcome.
// It is fail-open by contract: the only caller (runCheckHotfixCommand) swallows the error.
//
// The pointer is written to a.hotfixVersionPath, or to defaultHotfixVersionPath when that
// field is empty. A failed fetch is delegated entirely to handleFetchError.
func (a *App) checkHotfix(ctx context.Context) (checkHotfixOutcome, error) {
	hotfixPath := a.hotfixVersionPath
	if hotfixPath == "" {
		hotfixPath = defaultHotfixVersionPath
	}

	data, fetchErr := a.fetchHotfix(ctx)
	if fetchErr != nil {
		return a.handleFetchError(hotfixPath, fetchErr)
	}

	cfg, err := parseHotfixConfig(data)
	if err != nil {
		return outcomeFailed, fmt.Errorf("parsing LPS hotfix pointer: %w", err)
	}

	// staged is the exact shape writeHotfixConfig persists (map-only; the legacy
	// Version field is dropped). Basing both the write and the telemetry decision on this same
	// value keeps the reported outcome consistent with what download-hotfix will actually read:
	// a legacy-only pointer stages nothing resolvable, so it must report noHotfixForBase, not
	// LPSRead.
	staged := hotfixConfig{Hotfixes: cfg.Hotfixes}

	if err := writeHotfixConfig(hotfixPath, staged); err != nil {
		return outcomeFailed, fmt.Errorf("writing hotfix config: %w", err)
	}

	// Report whether this node's base actually has a pointer in the staged config.
	// download-hotfix still performs the authoritative patch-only-strictly-higher gating;
	// this is telemetry only.
	if staged.resolveVersion(Version) == "" {
		return outcomeNoHotfixForBase, nil
	}
	return outcomeLPSRead, nil
}

// handleFetchError maps a failed LPS fetch to a check-hotfix outcome. It is fail-open: benign
// 401/403/404 is a no-op, a reachable client error (non-benign 4xx) fails without staging a
// possibly-stale pointer, and only an unreachable/5xx LPS falls back to the cold-start pointer.
//
// hotfixPath is where the cold-start pointer is written when the fallback is taken; fetchErr
// is the error returned by fetchHotfix and is wrapped into most returned errors.
func (a *App) handleFetchError(hotfixPath string, fetchErr error) (checkHotfixOutcome, error) {
	if isLPSUnavailable(fetchErr) {
		// The LPS is reachable but has no hotfix published for this node (HTTP 401/403/404).
		// This is the expected steady state, not a failure: stage no overlay (download-hotfix
		// keeps whatever pointer it had) and report a benign outcome. Fail-open.
		slog.Info("LPS reports no hotfix available for this node (fail-open)", "reason", fetchErr)
		return outcomeNoHotfixAvailable, nil
	}
	if !shouldColdStartFallback(fetchErr) {
		// The LPS was reachable but rejected the request with a non-benign 4xx (e.g. 400/429).
		// The server is authoritative here, so do NOT stage the (possibly stale) cold-start
		// pointer; fail-open without a fallback.
		slog.Warn("LPS returned a client error; not falling back to cold-start pointer (fail-open)",
			"error", fetchErr)
		return outcomeFailed, fmt.Errorf("LPS fetch failed with a client error, not falling back: %w", fetchErr)
	}
	// LPS could not be reached or talked to (transport failure / 5xx): fall back to the pointer
	// embedded in the node config (cold-start path). See coldStartHotfixConfig for the contract TODO.
	slog.Warn("failed to reach LPS, attempting cold-start fallback", "error", fetchErr)
	cfg, ok, coldErr := a.coldStartHotfixConfig()
	if coldErr != nil {
		return outcomeFailed, fmt.Errorf("LPS fetch failed (%w) and cold-start fallback failed: %w", fetchErr, coldErr)
	}
	if !ok {
		return outcomeFailed, fmt.Errorf("LPS fetch failed and no cold-start pointer present: %w", fetchErr)
	}
	// NOTE(review): on a write failure the returned error carries only the write error;
	// fetchErr is visible only in the warning logged above.
	if err := writeHotfixConfig(hotfixPath, cfg); err != nil {
		return outcomeFailed, fmt.Errorf("writing cold-start hotfix config: %w", err)
	}
	return outcomeCustomDataFallback, nil
}

// helpersEventLevel maps a check-hotfix outcome to a guest-agent event level. Only the
// terminal "failed" outcome is reported as an error; the rest are informational because
// the command is fail-open and provisioning continues regardless.
func helpersEventLevel(outcome checkHotfixOutcome) helpers.EventLevel {
	if outcome == outcomeFailed {
		return helpers.EventLevelError
	}
	return helpers.EventLevelInformational
}

// fetchHotfix returns the raw LPS response body. Tests inject checkHotfixFetcher to supply
// canned pointer JSON or errors without networking. When no fetcher is injected the real
// network path (fetchHotfixFromLPS) is used.
func (a *App) fetchHotfix(ctx context.Context) ([]byte, error) {
	if a.checkHotfixFetcher != nil {
		return a.checkHotfixFetcher(ctx)
	}
	return a.fetchHotfixFromLPS(ctx)
}

// fetchHotfixFromLPS performs the real network GET against the LPS over the IMDS-attested
// SNI path. It sources the apiserver FQDN and cluster CA from the AKSNodeConfig that ANC
// already parses, pins the TLS ServerName to lpsSNIHost while forcing the TCP dial to the
// FQDN, attaches the IMDS attested-data token as the Authorization header, and returns the
// raw response body.
A 2xx returns the body; 401/403/404 are surfaced as a benign +// lpsUnavailableError ("nothing for this node yet"); any other non-2xx is surfaced as a typed +// lpsHTTPError carrying the status so the caller can distinguish a reachable-but-erroring server +// (non-benign 4xx -> no cold-start fallback) from a server/transport failure (5xx / transport +// error -> cold-start fallback). +func (a *App) fetchHotfixFromLPS(ctx context.Context) ([]byte, error) { + fqdn, caPEM, err := a.lpsTargetFromNodeConfig() + if err != nil { + return nil, fmt.Errorf("resolving LPS endpoint from node config: %w", err) + } + + token, err := a.attestedToken(ctx) + if err != nil { + return nil, fmt.Errorf("imds attested token: %w", err) + } + + client, caSource, err := buildLPSHTTPClient(fqdn, caPEM) + if err != nil { + return nil, fmt.Errorf("building LPS http client: %w", err) + } + // caSource records the TLS trust source (the provision-config CA) for diagnosis. + slog.Info("check-hotfix LPS TLS trust source", "caSource", caSource, "dialHost", fqdn) + + ctx, cancel := context.WithTimeout(ctx, lpsFetchTimeout) + defer cancel() + + url := "https://" + net.JoinHostPort(lpsSNIHost, lpsAPIServerPort) + lpsHotfixPath + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("building request: %w", err) + } + req.Header.Set("Authorization", token) + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("GET %s: %w", url, err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + switch { + case resp.StatusCode >= 200 && resp.StatusCode < 300: + return body, nil + case resp.StatusCode == http.StatusUnauthorized, + resp.StatusCode == http.StatusForbidden, + resp.StatusCode == http.StatusNotFound: + // Benign: reachable LPS with no hotfix published for 
this node yet. + // Surfaced as a typed error so the caller treats it as a no-op, not a failure. + return nil, &lpsUnavailableError{statusCode: resp.StatusCode} + default: + return nil, &lpsHTTPError{statusCode: resp.StatusCode} + } +} + +// lpsTargetFromNodeConfig reads the apiserver FQDN (the forced dial target) and the cluster +// CA (TLS trust) from the AKSNodeConfig. +// +// check-hotfix runs before the provisioning scripts (cse_config.sh), so the on-node decoded +// CA file (/etc/kubernetes/certs/ca.crt) does not exist yet -- it is written later during +// provisioning. The node config is the only credential source guaranteed to be present at +// this point and it carries the CA as base64-encoded PEM (the same value cse_config.sh later +// decodes into that file). +func (a *App) lpsTargetFromNodeConfig() (string, []byte, error) { + path := a.getNodeConfigPath() + raw, err := os.ReadFile(path) + if err != nil { + return "", nil, fmt.Errorf("reading node config %s: %w", path, err) + } + cfg, perr := nodeconfigutils.UnmarshalConfigurationV1(raw) + if perr != nil { + // Forward-compatible parse: unknown fields are discarded, so a non-nil error here + // means some fields were unusable. Continue with whatever parsed. + slog.Info("node config parsed with errors, continuing with partial config", "error", perr) + } + if cfg == nil { + if perr != nil { + return "", nil, fmt.Errorf("node config %s could not be parsed: %w", path, perr) + } + return "", nil, fmt.Errorf("node config %s could not be parsed", path) + } + + fqdn := cfg.GetApiServerConfig().GetApiServerName() + if fqdn == "" { + if perr != nil { + // A required field is missing and the unmarshal reported an error; surface the + // original parse failure as the root cause rather than masking it. 
+ return "", nil, fmt.Errorf("node config has no api_server_config.api_server_name (parse error: %w)", perr) + } + return "", nil, fmt.Errorf("node config has no api_server_config.api_server_name") + } + + var caPEM []byte + if caB64 := cfg.GetKubernetesCaCert(); caB64 != "" { + decoded, derr := base64.StdEncoding.DecodeString(caB64) + if derr != nil { + return "", nil, fmt.Errorf("decoding node config kubernetes_ca_cert: %w", derr) + } + caPEM = decoded + } + return fqdn, caPEM, nil +} + +// attestedToken returns the IMDS attested-data signature used as the LPS Authorization +// token. The fetch is overridable for testing via App.fetchAttestedToken. +func (a *App) attestedToken(ctx context.Context) (string, error) { + if a.fetchAttestedToken != nil { + return a.fetchAttestedToken(ctx) + } + return fetchIMDSAttestedToken(ctx) +} + +// fetchIMDSAttestedToken queries IMDS for the attested-data document and returns its +// signature, the same primitive Secure TLS Bootstrap and the custom-patching flow use. IMDS +// is local and usually reliable, so it makes up to imdsMaxAttempts attempts (one quick retry) +// to smooth a one-off blip; each attempt is independently bounded by imdsFetchTimeout. +func fetchIMDSAttestedToken(ctx context.Context) (string, error) { + return common.RetryStringFetch(ctx, imdsMaxAttempts, fetchIMDSAttestedTokenOnce) +} + +// fetchIMDSAttestedTokenOnce performs a single IMDS attested-document GET and returns the +// document signature. +func fetchIMDSAttestedTokenOnce(ctx context.Context) (string, error) { + ctx, cancel := context.WithTimeout(ctx, imdsFetchTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsAttestedDocURL, nil) + if err != nil { + return "", err + } + req.Header.Set("Metadata", "true") + + // IMDS (169.254.169.254) is a link-local endpoint that must never be routed through an + // HTTP(S) proxy. 
The shared base transport disables proxying (unlike the default client, + // which honors HTTP(S)_PROXY env vars), matching the shell implementation. + imdsClient := &http.Client{ + Timeout: imdsFetchTimeout, + Transport: common.NewBaseTransport(common.HTTPTransportOptions{ + DialTimeout: imdsDialTimeout, + TLSHandshakeTimeout: imdsTLSHandshakeTimeout, + ResponseHeaderTimeout: imdsResponseHeaderTimeout, + }), + } + resp, err := imdsClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("imds returned status %d", resp.StatusCode) + } + var doc struct { + Signature string `json:"signature"` + } + if err := json.Unmarshal(body, &doc); err != nil { + return "", err + } + if doc.Signature == "" { + return "", fmt.Errorf("imds attested document had empty signature") + } + return doc.Signature, nil +} + +// buildLPSHTTPClient builds the HTTP client for the LPS fetch: TLS ServerName pinned to +// lpsSNIHost, the TCP dial forced to dialHost:443 (the curl --resolve equivalent), and +// RootCAs from the cluster CA. It returns the client and a string describing the TLS trust +// source ("provision-config") for diagnostics, with a short timeout so provisioning is never +// delayed. +// +// The cluster CA from the provision-config is REQUIRED. Without it the LPS server certificate +// cannot be verified, and rather than weaken TLS (InsecureSkipVerify) we return an error so +// the caller fails open (no overlay staged) instead of trusting an unverified channel. 
+func buildLPSHTTPClient(dialHost string, caPEM []byte) (*http.Client, string, error) { + if len(caPEM) == 0 { + return nil, "", fmt.Errorf("cluster CA unavailable from provision-config; refusing to fetch over unverified TLS") + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPEM) { + return nil, "", fmt.Errorf("failed to parse cluster CA PEM") + } + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12, ServerName: lpsSNIHost, RootCAs: pool} + caSource := "provision-config" + + // api_server_name may arrive with a scheme already stripped but still carry a port + // (e.g. "host:443"); normalize to a bare hostname so JoinHostPort does not produce an + // invalid "[host:443]:443" address. + host := dialHost + if h, _, splitErr := net.SplitHostPort(dialHost); splitErr == nil { + host = h + } + dialAddr := net.JoinHostPort(host, lpsAPIServerPort) + transport := common.NewBaseTransport(common.HTTPTransportOptions{ + DialTimeout: lpsDialTimeout, + TLSHandshakeTimeout: lpsTLSHandshakeTimeout, + ResponseHeaderTimeout: lpsResponseHeaderTimeout, + TLSConfig: tlsConfig, + // Force every dial to the apiserver front regardless of the SNI/Host (lpsSNIHost) in + // the request URL -- the curl --resolve equivalent. + DialAddrOverride: dialAddr, + }) + return &http.Client{Timeout: lpsFetchTimeout, Transport: transport}, caSource, nil +} + +// parseHotfixConfig extracts the hotfix pointer from an LPS response body. The body is the +// {"hotfixes":{...}} JSON object that unmarshals DIRECTLY into the shared 2.1a hotfixConfig, +// so check-hotfix and download-hotfix share ONE identical data contract. 
+func parseHotfixConfig(data []byte) (hotfixConfig, error) { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return hotfixConfig{}, fmt.Errorf("LPS response body is empty") + } + var cfg hotfixConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return hotfixConfig{}, fmt.Errorf("unmarshaling hotfix pointer JSON: %w", err) + } + return cfg, nil +} + +// coldStartHotfixConfig reads a LENIENT top-level "hotfixes" object from the AKSNodeConfig +// JSON. This is the PoC cold-start fallback used only when the LPS endpoint could not be +// reached or talked to (transport failure / 5xx). A benign 401/403/404 is NOT a cold-start: +// the LPS authoritatively has nothing for this node, so that path stages no overlay. +// +// TODO(provisioning-hotfix): There is no formalized AKSNodeConfig contract field for the +// embedded pointer yet - the control-plane side that would populate a typed field is not +// built. Once that contract exists, replace this lenient top-level read with the typed field +// and drop the permissive JSON shape. Until then we read it best-effort and never fail +// provisioning. +func (a *App) coldStartHotfixConfig() (hotfixConfig, bool, error) { + path := a.getNodeConfigPath() + raw, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return hotfixConfig{}, false, nil + } + return hotfixConfig{}, false, fmt.Errorf("reading node config %s: %w", path, err) + } + + // Lenient parse: the AKSNodeConfig is protojson, but the cold-start pointer is an + // out-of-contract top-level object, so parse it permissively with encoding/json. 
+ var lenient struct { + Hotfixes map[string]string `json:"hotfixes"` + } + if err := json.Unmarshal(raw, &lenient); err != nil { + return hotfixConfig{}, false, fmt.Errorf("parsing cold-start hotfixes from node config: %w", err) + } + if len(lenient.Hotfixes) == 0 { + return hotfixConfig{}, false, nil + } + return hotfixConfig{Hotfixes: lenient.Hotfixes}, true, nil +} + +// writeHotfixConfig writes the resolved config to the path download-hotfix reads, in the +// exact {"hotfixes":{...}} shape so download-hotfix re-resolves and applies its unchanged +// gating. The write is atomic (temp file + rename) so a concurrent reader never sees a +// partial file. +func writeHotfixConfig(path string, cfg hotfixConfig) error { + // Only persist the map shape; the legacy Version field is intentionally omitted so the + // on-disk contract matches what the live-patching-service publishes. Marshal a dedicated + // struct without omitempty (and normalize nil -> empty map) so the staged file always has + // a stable top-level "hotfixes" key, i.e. {"hotfixes":{}} rather than {} for an empty map. + // hotfixConfig.Hotfixes carries omitempty for its own read path, so it must not be reused here. 
+ hotfixes := cfg.Hotfixes + if hotfixes == nil { + hotfixes = map[string]string{} + } + out := struct { + Hotfixes map[string]string `json:"hotfixes"` + }{Hotfixes: hotfixes} + data, err := json.Marshal(out) + if err != nil { + return fmt.Errorf("marshaling hotfix config: %w", err) + } + + dir := filepath.Dir(path) + tmp, err := os.CreateTemp(dir, ".aks-node-controller-hotfix-*") + if err != nil { + return fmt.Errorf("creating temp file in %s: %w", dir, err) + } + tmpPath := tmp.Name() + if _, err := tmp.Write(data); err != nil { + tmp.Close() + os.Remove(tmpPath) + return fmt.Errorf("writing temp file %s: %w", tmpPath, err) + } + if err := tmp.Close(); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("closing temp file %s: %w", tmpPath, err) + } + // CreateTemp defaults to 0600, but the same hotfix file is generated by cloud-init at + // 0644 (hotfix/anc_hotfix_generate.py), so rewriting it must not silently tighten the + // mode. Preserve the existing file's mode when present, otherwise match the 0644 contract. + fileMode := os.FileMode(0o644) + if info, statErr := os.Stat(path); statErr == nil { + fileMode = info.Mode().Perm() + } + if err := os.Chmod(tmpPath, fileMode); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("setting mode on temp file %s: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("renaming %s to %s: %w", tmpPath, path, err) + } + slog.Info("staged hotfix pointer for download-hotfix", "path", path) + return nil +} + +// getNodeConfigPath returns the injectable node-config path, defaulting to the standard +// AKSNodeConfig location that ANC already reads. 
+func (a *App) getNodeConfigPath() string {
+	// A non-empty override (set by tests) wins over the package default.
+	if override := a.nodeConfigPath; override != "" {
+		return override
+	}
+	return nodeconfigutils.AKSNodeConfigFilePath
+}
diff --git a/aks-node-controller/checkhotfix_test.go b/aks-node-controller/checkhotfix_test.go
new file mode 100644
index 00000000000..ece32f35950
--- /dev/null
+++ b/aks-node-controller/checkhotfix_test.go
@@ -0,0 +1,607 @@
+package main
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"net/http"
+	"os"
+	"path/filepath"
+	"runtime"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// testCAPEM is a self-signed CA certificate used to exercise the provision-config TLS
+// trust path in buildLPSHTTPClient.
+const testCAPEM = `-----BEGIN CERTIFICATE-----
+MIIBVDCB+6ADAgECAgEBMAoGCCqGSM49BAMCMBIxEDAOBgNVBAMTB3Rlc3QtY2Ew
+HhcNMjYwNjE5MjEwNDM4WhcNMzYwNjE2MjEwNDM4WjASMRAwDgYDVQQDEwd0ZXN0
+LWNhMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEDEsevoDBYiQ68iPrOeDKJLfJ
+EhavIoHla/EJ5jy1EeaLp5qnDttz9IQe8PiZGSat6Dc1in8pwwQJkTcCwDMlzaNC
+MEAwDgYDVR0PAQH/BAQDAgIEMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFI5z
+oesQcLTRf96etb8XDK8w9wFRMAoGCCqGSM49BAMCA0gAMEUCIQCDOJZ8qJDAnEB1
+2LbXQPzOc3n5Pcz3lpwQnczk/UdVJAIgcFqNv0HsWdn7Img3gNsUgSaOT1M9QBAL
+52RBAH6U7DI=
+-----END CERTIFICATE-----
+`
+
+// lpsPointerBody marshals the given base->hotfix map into the {"hotfixes":{...}} JSON shape
+// of an LPS hotfix-pointer response, failing the test if marshaling fails.
+func lpsPointerBody(t *testing.T, hotfixes map[string]string) []byte {
+	t.Helper()
+	payload := map[string]any{"hotfixes": hotfixes}
+	encoded, err := json.Marshal(payload)
+	require.NoError(t, err)
+	return encoded
+}
+
+// readStagedConfig reads back the hotfix config check-hotfix wrote.
+func readStagedConfig(t *testing.T, path string) hotfixConfig {
+	t.Helper()
+	raw, err := os.ReadFile(path)
+	require.NoError(t, err)
+	var staged hotfixConfig
+	require.NoError(t, json.Unmarshal(raw, &staged))
+	return staged
+}
+
+// TestParseHotfixConfig covers the shared pointer parser: the happy path, whitespace
+// tolerance, the two rejection cases, and parity with download-hotfix's file reader.
+func TestParseHotfixConfig(t *testing.T) {
+	t.Run("parses the hotfixes object directly", func(t *testing.T) {
+		want := map[string]string{"202604.01": "202604.01.1", "202605.01": "202605.01.2"}
+		got, err := parseHotfixConfig([]byte(`{"hotfixes":{"202604.01":"202604.01.1","202605.01":"202605.01.2"}}`))
+		require.NoError(t, err)
+		assert.Equal(t, want, got.Hotfixes)
+	})
+
+	t.Run("tolerates surrounding whitespace", func(t *testing.T) {
+		padded := []byte(" \n{\"hotfixes\":{\"202604.01\":\"202604.01.1\"}}\n ")
+		got, err := parseHotfixConfig(padded)
+		require.NoError(t, err)
+		assert.Equal(t, map[string]string{"202604.01": "202604.01.1"}, got.Hotfixes)
+	})
+
+	t.Run("empty body is an error", func(t *testing.T) {
+		_, err := parseHotfixConfig([]byte(" "))
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "empty")
+	})
+
+	t.Run("invalid JSON is an error", func(t *testing.T) {
+		_, err := parseHotfixConfig([]byte("not json"))
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "unmarshaling hotfix pointer JSON")
+	})
+
+	t.Run("shares parser shape with download-hotfix readHotfixConfig", func(t *testing.T) {
+		// Feed the same bytes through both entry points: whatever the LPS serves must decode
+		// to exactly what download-hotfix's readHotfixConfig reads from disk.
+		const body = `{"hotfixes":{"202604.01":"202604.01.3"}}`
+
+		parsed, err := parseHotfixConfig([]byte(body))
+		require.NoError(t, err)
+
+		onDisk := filepath.Join(t.TempDir(), "h.json")
+		require.NoError(t, os.WriteFile(onDisk, []byte(body), 0644))
+		loaded, err := readHotfixConfig(onDisk)
+		require.NoError(t, err)
+
+		assert.Equal(t, loaded, parsed)
+	})
+}
+
+// TestCheckHotfix_SuccessReadAndWrite: a pointer that contains this node's base is reported
+// as an LPS read and staged verbatim.
+func TestCheckHotfix_SuccessReadAndWrite(t *testing.T) {
+	savedVersion := Version
+	t.Cleanup(func() { Version = savedVersion })
+	Version = "202604.01.0"
+
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+	tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+		return lpsPointerBody(t, map[string]string{"202604.01": "202604.01.1"}), nil
+	}
+
+	outcome, err := tt.App.checkHotfix(context.Background())
+	require.NoError(t, err)
+	assert.Equal(t, outcomeLPSRead, outcome)
+
+	staged := readStagedConfig(t, stagedPath)
+	assert.Equal(t, map[string]string{"202604.01": "202604.01.1"}, staged.Hotfixes)
+}
+
+// TestCheckHotfix_NoHotfixForBase: the pointer is valid but has no entry for this node's base.
+func TestCheckHotfix_NoHotfixForBase(t *testing.T) {
+	savedVersion := Version
+	t.Cleanup(func() { Version = savedVersion })
+	Version = "202607.15.0" // base not present in the LPS pointer
+
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+	tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+		return lpsPointerBody(t, map[string]string{"202604.01": "202604.01.1"}), nil
+	}
+
+	outcome, err := tt.App.checkHotfix(context.Background())
+	require.NoError(t, err)
+	assert.Equal(t, outcomeNoHotfixForBase, outcome)
+
+	// The full pointer is still staged so download-hotfix re-resolves authoritatively.
+	staged := readStagedConfig(t, stagedPath)
+	assert.Equal(t, map[string]string{"202604.01": "202604.01.1"}, staged.Hotfixes)
+}
+
+// TestCheckHotfix_LegacyOnlyPointerReportsNoHotfixForBase guards the telemetry/staging
+// consistency contract: writeHotfixConfig persists only the Hotfixes map (the legacy Version
+// field is dropped), so a legacy-only pointer stages nothing resolvable. The reported outcome
+// must match what download-hotfix will actually read - noHotfixForBase, not LPSRead.
+func TestCheckHotfix_LegacyOnlyPointerReportsNoHotfixForBase(t *testing.T) {
+	savedVersion := Version
+	t.Cleanup(func() { Version = savedVersion })
+	Version = "202604.01.0"
+
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+	// Legacy shape: a top-level "version" pointer with no "hotfixes" map.
+	legacyBody := []byte(`{"version":"202604.01.9"}`)
+	tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+		return legacyBody, nil
+	}
+
+	outcome, err := tt.App.checkHotfix(context.Background())
+	require.NoError(t, err)
+	assert.Equal(t, outcomeNoHotfixForBase, outcome)
+
+	// The staged file must not carry the legacy Version - only the (empty) map shape.
+	staged := readStagedConfig(t, stagedPath)
+	assert.Empty(t, staged.Version)
+	assert.Empty(t, staged.Hotfixes)
+}
+
+func TestCheckHotfix_LPSUnavailableIsBenign(t *testing.T) {
+	// A reachable LPS that has no hotfix published for this node (HTTP 401, 403, 404) is the
+	// expected steady state. It must be a benign no-op: outcome noHotfixAvailable, no error,
+	// nothing staged, and NO cold-start overlay even when the node config carries an embedded
+	// pointer.
+	benignStatuses := map[string]int{
+		"401 unauthorized": http.StatusUnauthorized,
+		"403 forbidden":    http.StatusForbidden,
+		"404 not found":    http.StatusNotFound,
+	}
+	for name, status := range benignStatuses {
+		t.Run(name, func(t *testing.T) {
+			tt := NewTestApp(t, TestAppConfig{})
+			stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+			tt.App.hotfixVersionPath = stagedPath
+
+			// Even with a cold-start pointer present, a benign LPS answer stages no overlay.
+			nodeConfig := filepath.Join(t.TempDir(), "aks-node-controller-config.json")
+			coldStart := []byte(`{"version":"v1","hotfixes":{"202604.01":"202604.01.9"}}`)
+			require.NoError(t, os.WriteFile(nodeConfig, coldStart, 0644))
+			tt.App.nodeConfigPath = nodeConfig
+
+			tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+				return nil, &lpsUnavailableError{statusCode: status}
+			}
+
+			outcome, err := tt.App.checkHotfix(context.Background())
+			require.NoError(t, err)
+			assert.Equal(t, outcomeNoHotfixAvailable, outcome)
+
+			// No overlay staged.
+			_, statErr := os.Stat(stagedPath)
+			assert.True(t, os.IsNotExist(statErr))
+		})
+	}
+}
+
+// TestCheckHotfix_FetchErrorFailsOpenWithoutFallback: a non-benign fetch error with no
+// cold-start pointer to fall back on yields outcomeFailed and stages nothing.
+func TestCheckHotfix_FetchErrorFailsOpenWithoutFallback(t *testing.T) {
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+	// No node config -> no cold-start fallback available.
+	tt.App.nodeConfigPath = filepath.Join(t.TempDir(), "nonexistent-config.json")
+
+	// Transport-level failures (not benign 401/403/404) with no fallback -> failed.
+	fetchFailures := map[string]error{
+		"timeout":        context.DeadlineExceeded,
+		"connection err": errors.New("dial tcp: connection refused"),
+		"server error":   errors.New("LPS returned status 500"),
+	}
+	for name, failure := range fetchFailures {
+		t.Run(name, func(t *testing.T) {
+			tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+				return nil, failure
+			}
+			outcome, err := tt.App.checkHotfix(context.Background())
+			assert.Equal(t, outcomeFailed, outcome)
+			assert.Error(t, err)
+			// Nothing should be staged.
+			_, statErr := os.Stat(stagedPath)
+			assert.True(t, os.IsNotExist(statErr))
+		})
+	}
+}
+
+// TestCheckHotfix_InvalidPointerFailsOpen: an unparseable LPS body is a failure and stages nothing.
+func TestCheckHotfix_InvalidPointerFailsOpen(t *testing.T) {
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+	tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+		return []byte("not valid json"), nil
+	}
+
+	outcome, err := tt.App.checkHotfix(context.Background())
+	assert.Equal(t, outcomeFailed, outcome)
+	assert.Error(t, err)
+
+	_, statErr := os.Stat(stagedPath)
+	assert.True(t, os.IsNotExist(statErr))
+}
+
+// TestCheckHotfix_ColdStartFallback: when the fetch fails at the transport level and the node
+// config carries a pointer, that pointer is staged and reported as customDataFallback.
+func TestCheckHotfix_ColdStartFallback(t *testing.T) {
+	savedVersion := Version
+	t.Cleanup(func() { Version = savedVersion })
+	Version = "202604.01.0"
+
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+
+	// Node config carries a lenient top-level hotfixes pointer (PoC cold-start contract).
+	nodeConfig := filepath.Join(t.TempDir(), "aks-node-controller-config.json")
+	coldStart := []byte(`{"version":"v1","hotfixes":{"202604.01":"202604.01.2"}}`)
+	require.NoError(t, os.WriteFile(nodeConfig, coldStart, 0644))
+	tt.App.nodeConfigPath = nodeConfig
+
+	tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+		return nil, errors.New("dial tcp: connection refused")
+	}
+
+	outcome, err := tt.App.checkHotfix(context.Background())
+	require.NoError(t, err)
+	assert.Equal(t, outcomeCustomDataFallback, outcome)
+
+	staged := readStagedConfig(t, stagedPath)
+	assert.Equal(t, map[string]string{"202604.01": "202604.01.2"}, staged.Hotfixes)
+}
+
+// TestCheckHotfix_ColdStartNoPointerFails: the node config exists but has no hotfixes key, so
+// a failed fetch has nothing to fall back to.
+func TestCheckHotfix_ColdStartNoPointerFails(t *testing.T) {
+	tt := NewTestApp(t, TestAppConfig{})
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.hotfixVersionPath = stagedPath
+
+	nodeConfig := filepath.Join(t.TempDir(), "aks-node-controller-config.json")
+	require.NoError(t, os.WriteFile(nodeConfig, []byte(`{"version":"v1"}`), 0644))
+	tt.App.nodeConfigPath = nodeConfig
+	tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+		return nil, errors.New("dial tcp: connection refused")
+	}
+
+	outcome, err := tt.App.checkHotfix(context.Background())
+	assert.Equal(t, outcomeFailed, outcome)
+	assert.Error(t, err)
+
+	_, statErr := os.Stat(stagedPath)
+	assert.True(t, os.IsNotExist(statErr))
+}
+
+// TestCheckHotfix_FallbackOnlyForUnreachableLPS verifies that the cold-start fallback is used
+// only when the LPS could not be reached or is server-broken (transport error / 5xx). A
+// reachable LPS returning a non-benign 4xx (e.g. 400/429) is authoritative: check-hotfix must
+// NOT stage the (possibly stale) cold-start pointer even though one is present.
+func TestCheckHotfix_FallbackOnlyForUnreachableLPS(t *testing.T) {
+	savedVersion := Version
+	t.Cleanup(func() { Version = savedVersion })
+	Version = "202604.01.0"
+
+	type scenario struct {
+		name        string
+		fetchErr    error
+		wantOutcome checkHotfixOutcome
+		wantStaged  bool
+	}
+	scenarios := []scenario{
+		{
+			name:        "5xx falls back to cold-start",
+			fetchErr:    &lpsHTTPError{statusCode: 503},
+			wantOutcome: outcomeCustomDataFallback,
+			wantStaged:  true,
+		},
+		{
+			name:        "transport error falls back to cold-start",
+			fetchErr:    errors.New("dial tcp: connection refused"),
+			wantOutcome: outcomeCustomDataFallback,
+			wantStaged:  true,
+		},
+		{
+			name:        "non-benign 4xx does not fall back",
+			fetchErr:    &lpsHTTPError{statusCode: 429},
+			wantOutcome: outcomeFailed,
+			wantStaged:  false,
+		},
+		{
+			name:        "400 does not fall back",
+			fetchErr:    &lpsHTTPError{statusCode: 400},
+			wantOutcome: outcomeFailed,
+			wantStaged:  false,
+		},
+	}
+	for _, sc := range scenarios {
+		t.Run(sc.name, func(t *testing.T) {
+			tt := NewTestApp(t, TestAppConfig{})
+			stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+			tt.App.hotfixVersionPath = stagedPath
+
+			// A cold-start pointer IS present, so the outcome hinges purely on whether the
+			// fetch error is treated as "unreachable" (fall back) or "reachable client error".
+			nodeConfig := filepath.Join(t.TempDir(), "aks-node-controller-config.json")
+			coldStart := []byte(`{"version":"v1","hotfixes":{"202604.01":"202604.01.2"}}`)
+			require.NoError(t, os.WriteFile(nodeConfig, coldStart, 0644))
+			tt.App.nodeConfigPath = nodeConfig
+			tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+				return nil, sc.fetchErr
+			}
+
+			// Only the outcome and the staging side effect are under test here; the returned
+			// error is deliberately not inspected.
+			outcome, _ := tt.App.checkHotfix(context.Background())
+			assert.Equal(t, sc.wantOutcome, outcome)
+
+			_, statErr := os.Stat(stagedPath)
+			if !sc.wantStaged {
+				assert.True(t, os.IsNotExist(statErr), "expected nothing to be staged")
+				return
+			}
+			assert.NoError(t, statErr, "expected the cold-start pointer to be staged")
+		})
+	}
+}
+
+// TestRunCheckHotfixCommand_AlwaysFailOpen verifies the cli Action always returns nil
+// (exit 0) and emits telemetry, regardless of the underlying outcome.
+func TestRunCheckHotfixCommand_AlwaysFailOpen(t *testing.T) {
+	const checkHotfixTask = "AKS.AKSNodeController.CheckHotfix"
+
+	t.Run("success path emits informational event and exits 0", func(t *testing.T) {
+		savedVersion := Version
+		t.Cleanup(func() { Version = savedVersion })
+		Version = "202604.01.0"
+
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.hotfixVersionPath = filepath.Join(t.TempDir(), "hotfix.json")
+		tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+			return lpsPointerBody(t, map[string]string{"202604.01": "202604.01.1"}), nil
+		}
+
+		require.NoError(t, tt.App.runCheckHotfixCommand(context.Background()))
+
+		events := tt.eventLogger.Events()
+		require.Len(t, events, 1)
+		event := events[0]
+		assert.Equal(t, checkHotfixTask, event.TaskName)
+		assert.Equal(t, "Informational", event.EventLevel)
+		assert.Contains(t, event.Message, string(outcomeLPSRead))
+	})
+
+	t.Run("failure path emits error event but still exits 0", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.hotfixVersionPath = filepath.Join(t.TempDir(), "hotfix.json")
+		tt.App.nodeConfigPath = filepath.Join(t.TempDir(), "nonexistent.json")
+		tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+			return nil, errors.New("LPS returned status 500")
+		}
+
+		require.NoError(t, tt.App.runCheckHotfixCommand(context.Background()))
+
+		events := tt.eventLogger.Events()
+		require.Len(t, events, 1)
+		event := events[0]
+		assert.Equal(t, checkHotfixTask, event.TaskName)
+		assert.Equal(t, "Error", event.EventLevel)
+		assert.Contains(t, event.Message, string(outcomeFailed))
+	})
+
+	t.Run("cli wiring returns exit code 0 even on fetch failure", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.hotfixVersionPath = filepath.Join(t.TempDir(), "hotfix.json")
+		tt.App.nodeConfigPath = filepath.Join(t.TempDir(), "nonexistent.json")
+		tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+			return nil, errors.New("boom")
+		}
+
+		args := []string{"aks-node-controller", "check-hotfix"}
+		assert.Equal(t, 0, tt.App.Run(context.Background(), args))
+	})
+
+	t.Run("a panic in the workflow is recovered and still exits 0", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.hotfixVersionPath = filepath.Join(t.TempDir(), "hotfix.json")
+		// A fetcher that panics simulates an unexpected crash anywhere in the workflow.
+		tt.App.checkHotfixFetcher = func(context.Context) ([]byte, error) {
+			panic("unexpected boom")
+		}
+
+		err := tt.App.runCheckHotfixCommand(context.Background())
+		require.NoError(t, err, "a panic must be swallowed so provisioning proceeds")
+
+		events := tt.eventLogger.Events()
+		require.Len(t, events, 1)
+		event := events[0]
+		assert.Equal(t, checkHotfixTask, event.TaskName)
+		assert.Equal(t, "Error", event.EventLevel)
+		assert.Contains(t, event.Message, string(outcomeFailed))
+		assert.Contains(t, event.Message, "panic")
+	})
+}
+
+func TestCheckHotfix_DefaultsToLPSFetcherWhenNoInjection(t *testing.T) {
+	// With no injected fetcher and no readable node config, the real path is exercised: it
+	// must fail-open. Point the node-config source at a nonexistent path so LPS endpoint
+	// resolution fails deterministically and the network is never actually dialed.
+	tt := NewTestApp(t, TestAppConfig{})
+	tt.App.hotfixVersionPath = filepath.Join(t.TempDir(), "hotfix.json")
+	tt.App.nodeConfigPath = filepath.Join(t.TempDir(), "nonexistent.json")
+	// checkHotfixFetcher intentionally nil.
+
+	require.NoError(t, tt.App.runCheckHotfixCommand(context.Background()))
+}
+
+// TestAttestedToken_InjectionOverridesIMDS checks that an injected fetchAttestedToken is what
+// attestedToken returns, for both the success and the error case.
+func TestAttestedToken_InjectionOverridesIMDS(t *testing.T) {
+	t.Run("injected token is returned without networking", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.fetchAttestedToken = func(context.Context) (string, error) {
+			return "injected-signature", nil
+		}
+
+		token, err := tt.App.attestedToken(context.Background())
+		require.NoError(t, err)
+		assert.Equal(t, "injected-signature", token)
+	})
+
+	t.Run("injected error propagates", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.fetchAttestedToken = func(context.Context) (string, error) {
+			return "", errors.New("imds down")
+		}
+
+		_, err := tt.App.attestedToken(context.Background())
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "imds down")
+	})
+}
+
+func TestLPSTargetFromNodeConfig(t *testing.T) {
+	// A minimal AKSNodeConfig in the on-disk shape: MarshalConfigurationV1 sets
+	// UseProtoNames=true, so production JSON uses proto (snake_case) field names.
+	const caPEM = "-----BEGIN CERTIFICATE-----\nMIIB\n-----END CERTIFICATE-----\n"
+	encodedCA := base64.StdEncoding.EncodeToString([]byte(caPEM))
+
+	t.Run("reads fqdn and decodes CA", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		configPath := filepath.Join(t.TempDir(), "config.json")
+		body := `{"version":"v1","api_server_config":{"api_server_name":"myapi.example.com"},"kubernetes_ca_cert":"` + encodedCA + `"}`
+		require.NoError(t, os.WriteFile(configPath, []byte(body), 0644))
+		tt.App.nodeConfigPath = configPath
+
+		fqdn, ca, err := tt.App.lpsTargetFromNodeConfig()
+		require.NoError(t, err)
+		assert.Equal(t, "myapi.example.com", fqdn)
+		assert.Equal(t, []byte(caPEM), ca)
+	})
+
+	t.Run("missing apiserver name is an error", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		configPath := filepath.Join(t.TempDir(), "config.json")
+		require.NoError(t, os.WriteFile(configPath, []byte(`{"version":"v1"}`), 0644))
+		tt.App.nodeConfigPath = configPath
+
+		_, _, err := tt.App.lpsTargetFromNodeConfig()
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "api_server_name")
+	})
+
+	t.Run("missing file is an error", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.nodeConfigPath = filepath.Join(t.TempDir(), "nope.json")
+
+		_, _, err := tt.App.lpsTargetFromNodeConfig()
+		require.Error(t, err)
+	})
+}
+
+// TestBuildLPSHTTPClient covers CA validation, the SNI pin, and the transport settings of the
+// client returned by buildLPSHTTPClient.
+func TestBuildLPSHTTPClient(t *testing.T) {
+	const apiServer = "myapi.example.com"
+
+	t.Run("invalid CA PEM is an error", func(t *testing.T) {
+		_, _, err := buildLPSHTTPClient(apiServer, []byte("not a pem"))
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "cluster CA PEM")
+	})
+
+	t.Run("valid CA pins ServerName and reports provision-config trust", func(t *testing.T) {
+		// A real (self-signed) cert PEM so AppendCertsFromPEM succeeds.
+		client, caSource, err := buildLPSHTTPClient(apiServer, []byte(testCAPEM))
+		require.NoError(t, err)
+		assert.Equal(t, "provision-config", caSource)
+		assert.Equal(t, lpsFetchTimeout, client.Timeout)
+
+		transport, ok := client.Transport.(*http.Transport)
+		require.True(t, ok)
+		tlsCfg := transport.TLSClientConfig
+		assert.Equal(t, lpsSNIHost, tlsCfg.ServerName)
+		assert.False(t, tlsCfg.InsecureSkipVerify)
+		// The shared base transport must apply the fail-fast per-phase budgets and disable proxying.
+		assert.Nil(t, transport.Proxy)
+		assert.Equal(t, lpsTLSHandshakeTimeout, transport.TLSHandshakeTimeout)
+		assert.Equal(t, lpsResponseHeaderTimeout, transport.ResponseHeaderTimeout)
+	})
+
+	t.Run("no CA is a hard error (no insecure fallback)", func(t *testing.T) {
+		_, _, err := buildLPSHTTPClient(apiServer, nil)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "cluster CA unavailable")
+	})
+
+	t.Run("dialHost with an existing port is normalized", func(t *testing.T) {
+		// api_server_name may carry a port; the client must still build (no insecure
+		// fallback) and JoinHostPort must not produce an invalid "[host:443]:443" address.
+		client, _, err := buildLPSHTTPClient("myapi.example.com:443", []byte(testCAPEM))
+		require.NoError(t, err)
+
+		transport, ok := client.Transport.(*http.Transport)
+		require.True(t, ok)
+		assert.Equal(t, lpsSNIHost, transport.TLSClientConfig.ServerName)
+	})
+}
+
+// TestColdStartHotfixConfig covers the three node-config states the cold-start reader sees:
+// no file, a file with a hotfixes pointer, and a file without one.
+func TestColdStartHotfixConfig(t *testing.T) {
+	t.Run("missing file returns not-present without error", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		tt.App.nodeConfigPath = filepath.Join(t.TempDir(), "nope.json")
+
+		cfg, present, err := tt.App.coldStartHotfixConfig()
+		require.NoError(t, err)
+		assert.False(t, present)
+		assert.Nil(t, cfg.Hotfixes)
+	})
+
+	t.Run("present pointer is parsed", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		configPath := filepath.Join(t.TempDir(), "config.json")
+		body := []byte(`{"version":"v1","hotfixes":{"202604.01":"202604.01.5"}}`)
+		require.NoError(t, os.WriteFile(configPath, body, 0644))
+		tt.App.nodeConfigPath = configPath
+
+		cfg, present, err := tt.App.coldStartHotfixConfig()
+		require.NoError(t, err)
+		assert.True(t, present)
+		assert.Equal(t, map[string]string{"202604.01": "202604.01.5"}, cfg.Hotfixes)
+	})
+
+	t.Run("no hotfixes key returns not-present", func(t *testing.T) {
+		tt := NewTestApp(t, TestAppConfig{})
+		configPath := filepath.Join(t.TempDir(), "config.json")
+		require.NoError(t, os.WriteFile(configPath, []byte(`{"version":"v1"}`), 0644))
+		tt.App.nodeConfigPath = configPath
+
+		_, present, err := tt.App.coldStartHotfixConfig()
+		require.NoError(t, err)
+		assert.False(t, present)
+	})
+}
+
+// TestWriteHotfixConfig_ShapeAndAtomicity checks the serialized shape of the staged file and
+// that download-hotfix's reader resolves a version from it.
+func TestWriteHotfixConfig_ShapeAndAtomicity(t *testing.T) {
+	stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+	input := hotfixConfig{Hotfixes: map[string]string{"202604.01": "202604.01.1"}}
+	require.NoError(t, writeHotfixConfig(stagedPath, input))
+
+	raw, err := os.ReadFile(stagedPath)
+	require.NoError(t, err)
+	// Must serialize in the {"hotfixes":{...}} shape with no legacy version field.
+	assert.JSONEq(t, `{"hotfixes":{"202604.01":"202604.01.1"}}`, string(raw))
+
+	// Round-trips through download-hotfix's reader.
+	reread, err := readHotfixConfig(stagedPath)
+	require.NoError(t, err)
+	assert.Equal(t, "202604.01.1", reread.resolveVersion("202604.01.0"))
+}
+
+// TestWriteHotfixConfig_EmptyMapKeepsStableKey guards the on-disk/LPS contract: even when the
+// map is empty or nil, the staged file must retain a top-level "hotfixes" key ({"hotfixes":{}})
+// rather than collapsing to {} (which hotfixConfig's omitempty tag would otherwise produce).
+func TestWriteHotfixConfig_EmptyMapKeepsStableKey(t *testing.T) {
+	inputs := map[string]hotfixConfig{
+		"empty map": {Hotfixes: map[string]string{}},
+		"nil map":   {Hotfixes: nil},
+	}
+	for name, input := range inputs {
+		t.Run(name, func(t *testing.T) {
+			stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+			require.NoError(t, writeHotfixConfig(stagedPath, input))
+
+			raw, err := os.ReadFile(stagedPath)
+			require.NoError(t, err)
+			assert.JSONEq(t, `{"hotfixes":{}}`, string(raw))
+		})
+	}
+}
+
+// TestWriteHotfixConfig_FileMode checks the permission bits of the staged file: 0644 for a
+// new file, and the pre-existing mode when the file is being replaced.
+func TestWriteHotfixConfig_FileMode(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("unix file modes are not represented on windows")
+	}
+	input := hotfixConfig{Hotfixes: map[string]string{"202604.01": "202604.01.1"}}
+
+	t.Run("new file uses the 0644 cloud-init contract (not CreateTemp's 0600)", func(t *testing.T) {
+		stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+		require.NoError(t, writeHotfixConfig(stagedPath, input))
+
+		info, err := os.Stat(stagedPath)
+		require.NoError(t, err)
+		assert.Equal(t, os.FileMode(0o644), info.Mode().Perm())
+	})
+
+	t.Run("existing file mode is preserved", func(t *testing.T) {
+		stagedPath := filepath.Join(t.TempDir(), "hotfix.json")
+		require.NoError(t, os.WriteFile(stagedPath, []byte("{}"), 0o600))
+		require.NoError(t, writeHotfixConfig(stagedPath, input))
+
+		info, err := os.Stat(stagedPath)
+		require.NoError(t, err)
+		assert.Equal(t, os.FileMode(0o600), info.Mode().Perm())
+	})
+}
diff --git a/aks-node-controller/common/httpclient.go b/aks-node-controller/common/httpclient.go
new file mode 100644
index 00000000000..b7a43746531
--- /dev/null
+++
b/aks-node-controller/common/httpclient.go
@@ -0,0 +1,83 @@
+// Package common holds low-level primitives shared across the aks-node-controller commands.
+//
+// The HTTP helpers here are the reusable transport/retry mechanics for the check-hotfix
+// network calls (the IMDS attested-token GET and the live-patching-service hotfix-pointer GET)
+// and are intentionally domain-agnostic: callers supply their own timeouts, TLS trust, dial
+// override, overall deadline, and retry policy. Endpoint identity and timeout tuning live with
+// the caller, not here.
+//
+// TODO(provisioning-hotfix): when the check-lps connectivity client (lps.go) lands in main,
+// have it consume NewBaseTransport too so there is a single canonical HTTP client.
+package common
+
+import (
+	"context"
+	"crypto/tls"
+	"log/slog"
+	"net"
+	"net/http"
+	"time"
+)
+
+// HTTPTransportOptions configures the base transport built by NewBaseTransport.
+type HTTPTransportOptions struct {
+	// DialTimeout bounds the TCP connect.
+	DialTimeout time.Duration
+	// TLSHandshakeTimeout bounds the TLS handshake (ignored for plain-HTTP endpoints).
+	TLSHandshakeTimeout time.Duration
+	// ResponseHeaderTimeout bounds the wait for response headers after the request is written.
+	ResponseHeaderTimeout time.Duration
+	// TLSConfig is the TLS client config; nil for plain-HTTP endpoints (e.g. IMDS).
+	TLSConfig *tls.Config
+	// DialAddrOverride forces every dial to this host:port regardless of the URL host (the
+	// curl --resolve equivalent used by the LPS SNI-pin path). Empty means dial the URL host.
+	DialAddrOverride string
+}
+
+// NewBaseTransport returns an *http.Transport configured from opts: the dial, TLS-handshake
+// and response-header timeouts are applied as given, and proxying is always off. Proxying is
+// disabled unconditionally because the LPS client forces its dial to the apiserver front (a
+// proxy CONNECT would defeat the --resolve pin) and IMDS is a link-local endpoint that must
+// never be proxied. Retry and overall-deadline policy are layered by callers.
+func NewBaseTransport(opts HTTPTransportOptions) *http.Transport {
+	dialer := &net.Dialer{Timeout: opts.DialTimeout}
+
+	// dial swaps in the override address, when one is configured, before connecting.
+	dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
+		target := addr
+		if opts.DialAddrOverride != "" {
+			target = opts.DialAddrOverride
+		}
+		return dialer.DialContext(ctx, network, target)
+	}
+
+	return &http.Transport{
+		Proxy:                 nil,
+		DialContext:           dial,
+		TLSClientConfig:       opts.TLSConfig,
+		TLSHandshakeTimeout:   opts.TLSHandshakeTimeout,
+		ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
+	}
+}
+
+// RetryStringFetch invokes fn until it succeeds, up to maxAttempts times, and returns the
+// first successful value. It gives up early once the context is done, since a further
+// attempt cannot then succeed. There is deliberately no sleep between attempts: each attempt
+// is already bounded by its own timeout and the provisioning path wants fail-fast behavior.
+// Used for the IMDS attested-token fetch (one quick retry). When every attempt fails, the
+// error from the last attempt is returned.
+func RetryStringFetch(ctx context.Context, maxAttempts int, fn func(context.Context) (string, error)) (string, error) {
+	// Clamp to at least one attempt so a zero/negative maxAttempts cannot silently return
+	// ("", nil) -- that would make a misconfiguration look like a successful empty fetch.
+	attempts := max(maxAttempts, 1)
+
+	for attempt := 1; ; attempt++ {
+		value, err := fn(ctx)
+		if err == nil {
+			return value, nil
+		}
+		// Stop when the context is done or the attempt budget is spent.
+		if ctx.Err() != nil || attempt >= attempts {
+			return "", err
+		}
+		slog.Warn("fetch attempt failed, retrying", "attempt", attempt, "maxAttempts", attempts, "error", err)
+	}
+}
diff --git a/aks-node-controller/common/httpclient_test.go b/aks-node-controller/common/httpclient_test.go
new file mode 100644
index 00000000000..ebace2fee08
--- /dev/null
+++ b/aks-node-controller/common/httpclient_test.go
@@ -0,0 +1,126 @@
+package common
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestNewBaseTransport checks the transport's static settings and the dial-override behavior
+// against local httptest servers.
+func TestNewBaseTransport(t *testing.T) {
+	t.Run("disables proxy and sets the per-phase timeouts", func(t *testing.T) {
+		const budget = 2 * time.Second
+		transport := NewBaseTransport(HTTPTransportOptions{
+			DialTimeout:           budget,
+			TLSHandshakeTimeout:   budget,
+			ResponseHeaderTimeout: budget,
+		})
+		assert.Nil(t, transport.Proxy, "proxy must be disabled for the forced-dial/link-local paths")
+		assert.Equal(t, budget, transport.TLSHandshakeTimeout)
+		assert.Equal(t, budget, transport.ResponseHeaderTimeout)
+		require.NotNil(t, transport.DialContext)
+	})
+
+	t.Run("dialAddrOverride redirects every dial to the override addr", func(t *testing.T) {
+		okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+			w.WriteHeader(http.StatusOK)
+			_, _ = w.Write([]byte("ok"))
+		})
+		srv := httptest.NewServer(okHandler)
+		defer srv.Close()
+
+		client := &http.Client{
+			Timeout: 3 * time.Second,
+			Transport: NewBaseTransport(HTTPTransportOptions{
+				DialTimeout:      2 * time.Second,
+				DialAddrOverride: srv.Listener.Addr().String(),
+			}),
+		}
+
+		// Request a host that does not resolve; the override must redirect the dial to the
+		// test server (the curl --resolve equivalent used by the LPS SNI-pin path).
+		resp, err := client.Get("http://this-host-does-not-resolve.invalid/")
+		require.NoError(t, err)
+		defer resp.Body.Close()
+		assert.Equal(t, http.StatusOK, resp.StatusCode)
+	})
+
+	t.Run("without an override the URL host is dialed", func(t *testing.T) {
+		teapotHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+			w.WriteHeader(http.StatusTeapot)
+		})
+		srv := httptest.NewServer(teapotHandler)
+		defer srv.Close()
+
+		client := &http.Client{
+			Timeout:   3 * time.Second,
+			Transport: NewBaseTransport(HTTPTransportOptions{DialTimeout: 2 * time.Second}),
+		}
+
+		resp, err := client.Get(srv.URL)
+		require.NoError(t, err)
+		defer resp.Body.Close()
+		assert.Equal(t, http.StatusTeapot, resp.StatusCode)
+	})
+}
+
+// TestRetryStringFetch pins the retry contract: stop at the first success, surface the last
+// error once attempts are exhausted, do not retry after the context is done, and treat a
+// maxAttempts below 1 as a single attempt.
+func TestRetryStringFetch(t *testing.T) {
+	t.Run("returns the first success without extra calls", func(t *testing.T) {
+		invocations := 0
+		got, err := RetryStringFetch(context.Background(), 2, func(context.Context) (string, error) {
+			invocations++
+			return "ok", nil
+		})
+		require.NoError(t, err)
+		assert.Equal(t, "ok", got)
+		assert.Equal(t, 1, invocations)
+	})
+
+	t.Run("retries once then succeeds", func(t *testing.T) {
+		invocations := 0
+		got, err := RetryStringFetch(context.Background(), 2, func(context.Context) (string, error) {
+			invocations++
+			if invocations > 1 {
+				return "ok", nil
+			}
+			return "", errors.New("transient")
+		})
+		require.NoError(t, err)
+		assert.Equal(t, "ok", got)
+		assert.Equal(t, 2, invocations)
+	})
+
+	t.Run("returns the last error after exhausting attempts", func(t *testing.T) {
+		invocations := 0
+		_, err := RetryStringFetch(context.Background(), 2, func(context.Context) (string, error) {
+			invocations++
+			return "", fmt.Errorf("attempt %d", invocations)
+		})
+		require.Error(t, err)
+		assert.Equal(t, 2, invocations)
+		assert.Contains(t, err.Error(), "attempt 2")
+	})
+
+	t.Run("stops early when the context is already done", func(t *testing.T) {
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+
+		invocations := 0
+		_, err := RetryStringFetch(ctx, 3, func(context.Context) (string, error) {
+			invocations++
+			return "", errors.New("boom")
+		})
+		require.Error(t, err)
+		assert.Equal(t, 1, invocations, "must not retry once the context is done")
+	})
+
+	t.Run("normalizes maxAttempts below 1 to a single attempt", func(t *testing.T) {
+		for _, maxAttempts := range []int{0, -1} {
+			invocations := 0
+			_, err := RetryStringFetch(context.Background(), maxAttempts, func(context.Context) (string, error) {
+				invocations++
+				return "", errors.New("boom")
+			})
+			require.Error(t, err, "maxAttempts=%d must surface a real error, not (\"\", nil)", maxAttempts)
+			assert.Equal(t, 1, invocations, "maxAttempts=%d must still attempt exactly once", maxAttempts)
+		}
+	})
+}