diff --git a/ee/server/service/apple_psso.go b/ee/server/service/apple_psso.go index bac596c49a8..08085525cff 100644 --- a/ee/server/service/apple_psso.go +++ b/ee/server/service/apple_psso.go @@ -141,8 +141,9 @@ func parsePSSOSigningKeyPEM(pemBytes []byte) (*ecdsa.PrivateKey, string, error) } // computeKID returns base64url-nopad SHA-256 of the SubjectPublicKeyInfo DER -// encoding of pub. This matches the kid format the extension sends with its -// JWTs (SHA-256 of the public key bytes, base64'd). +// encoding of pub. Used only for Fleet's own signing key (JWKS/JWT kid). +// Device key kids are different: the extension computes them as SHA-256 of +// the raw X9.63 point bytes and submits them at registration. func computeKID(pub *ecdsa.PublicKey) (string, error) { der, err := x509.MarshalPKIXPublicKey(pub) if err != nil { @@ -252,8 +253,8 @@ func (svc *Service) PSSORegisterBegin(ctx context.Context) (string, error) { } // PSSORegisterComplete consumes the device-key enrollment POST from the Mac -// extension: it resolves the enrolled host from the hardware device UUID, -// mints a KeyExchangeKey, and persists the device record + KeyID rows. +// extension: it resolves the enrolled host from the hardware device UUID and +// persists the device record plus its public key rows. // // Password-mode registration carries no OAuth code/state — the extension // simply submits the public halves of its Secure Enclave signing and @@ -269,9 +270,19 @@ func (svc *Service) PSSORegisterComplete(ctx context.Context, req fleet.PSSORegi return &fleet.BadRequestError{Message: "missing required psso register fields"} } - // Resolve host_id from device UUID. PSSO requires a matching enrolled host - // since the device record is keyed by host_id. - host, err := svc.ds.HostLiteByIdentifier(ctx, req.DeviceUUID) + // Reject unparseable key material up front: a bad PEM stored here would + // otherwise only surface as opaque verification failures at every + // subsequent login. + if _, err := parseECPublicKeyPEM([]byte(req.DeviceSigningKey)); err != nil { + return &fleet.BadRequestError{Message: "psso register: signing key is not a valid P-256 public key"} + } + if _, err := parseECPublicKeyPEM([]byte(req.DeviceEncryptionKey)); err != nil { + return &fleet.BadRequestError{Message: "psso register: encryption key is not a valid P-256 public key"} + } + + // PSSO requires a matching enrolled host; the registration is keyed by the + // host's UUID. + host, err := svc.ds.HostByUUID(ctx, req.DeviceUUID) if err != nil { if fleet.IsNotFound(err) { return &fleet.BadRequestError{Message: fmt.Sprintf("psso register: no enrolled host matches device UUID %q", req.DeviceUUID)} @@ -279,37 +290,22 @@ func (svc *Service) PSSORegisterComplete(ctx context.Context, req fleet.PSSORegi return ctxerr.Wrap(ctx, err, "look up host by device uuid") } - // Mint a 32-byte KeyExchangeKey. This is the v2 secret returned to the - // device on its first key_request and reused for symmetric session keys - // thereafter. - var kek [32]byte - if _, err := rand.Read(kek[:]); err != nil { - return ctxerr.Wrap(ctx, err, "generate key exchange key") - } - - device := fleet.PSSODevice{ - HostID: host.ID, - DeviceUUID: req.DeviceUUID, - SigningKeyPEM: req.DeviceSigningKey, - EncryptionKeyPEM: req.DeviceEncryptionKey, - KeyExchangeKey: kek[:], - } // Store kids in canonical form so the token endpoint's lookup (which // canonicalizes the JWT's kid) matches regardless of base64 padding or // alphabet differences between the extension and Apple's framework. - signKID := fleet.PSSOKeyID{ - KID: canonicalizeKID(req.SignKeyID), - HostID: host.ID, - KeyType: fleet.PSSOKeyTypeSigning, - PEM: req.DeviceSigningKey, - } - encKID := fleet.PSSOKeyID{ - KID: canonicalizeKID(req.EncKeyID), - HostID: host.ID, - KeyType: fleet.PSSOKeyTypeEncryption, - PEM: req.DeviceEncryptionKey, - } - if err := svc.ds.SetOrUpdatePSSODevice(ctx, device, signKID, encKID); err != nil { + keys := []fleet.PSSOKey{ + { + KID: canonicalizeKID(req.SignKeyID), + KeyType: fleet.PSSOKeyTypeSigning, + PEM: req.DeviceSigningKey, + }, + { + KID: canonicalizeKID(req.EncKeyID), + KeyType: fleet.PSSOKeyTypeEncryption, + PEM: req.DeviceEncryptionKey, + }, + } + if err := svc.ds.SetOrUpdatePSSODevice(ctx, host.UUID, keys); err != nil { return ctxerr.Wrap(ctx, err, "persist psso device registration") } return nil @@ -328,7 +324,7 @@ func (svc *Service) PSSOToken(ctx context.Context, jwtBytes []byte) ([]byte, err return nil, &fleet.BadRequestError{Message: "psso token: empty request body"} } - claims, device, err := svc.parsePSSOInboundJWT(ctx, jwtBytes) + claims, signKey, err := svc.parsePSSOInboundJWT(ctx, jwtBytes) if err != nil { return nil, err } @@ -336,18 +332,14 @@ func (svc *Service) PSSOToken(ctx context.Context, jwtBytes []byte) ([]byte, err // PSSO v2 Password login: a single grant_type=password round trip carrying // a plaintext password and a jwe_crypto response recipe. if claims.GrantType == pssoGrantTypePassword { - return svc.handlePSSOPasswordLogin(ctx, device, claims) + return svc.handlePSSOPasswordLogin(ctx, signKey.HostUUID, claims) } - // Legacy request_type handshake model — retained but not exercised by the - // Password flow. switch claims.RequestType { case pssoRequestKey: - return svc.handlePSSOKeyRequest(ctx, device, claims) + return svc.handlePSSOKeyRequest(ctx, signKey.HostUUID, claims) case pssoRequestExchange: - return svc.handlePSSOKeyExchange(ctx, device, claims) - case pssoRequestPassword: - return svc.handlePSSOPasswordRequest(ctx, device, claims) + return svc.handlePSSOKeyExchange(ctx, signKey.HostUUID, claims) default: return nil, &fleet.BadRequestError{Message: "psso token: unsupported grant_type/request_type"} } @@ -391,7 +383,7 @@ func (svc *Service) pssoIDTokenIssuer(ctx context.Context) (string, error) { // Fleet validates the password against the upstream IdP, then returns the // resulting OIDC claims as a server-signed JWT wrapped in a JWE encrypted per // that recipe. -func (svc *Service) handlePSSOPasswordLogin(ctx context.Context, device *fleet.PSSODevice, claims *pssoTokenClaims) ([]byte, error) { +func (svc *Service) handlePSSOPasswordLogin(ctx context.Context, hostUUID string, claims *pssoTokenClaims) ([]byte, error) { if svc.pssoIdPClient == nil { return nil, ctxerr.New(ctx, "psso idp client not configured") } @@ -429,9 +421,9 @@ func (svc *Service) handlePSSOPasswordLogin(ctx context.Context, device *fleet.P return nil, ctxerr.Wrap(ctx, err, "psso password validation") } - recipientPub, err := parseECPublicKeyPEM([]byte(device.EncryptionKeyPEM)) + recipientPub, err := svc.resolvePSSOEncryptionKey(ctx, hostUUID, claims.JWECrypto.APV) if err != nil { - return nil, ctxerr.Wrap(ctx, err, "parse device encryption pubkey") + return nil, ctxerr.Wrap(ctx, err, "resolve device encryption pubkey") } // Per Apple's JWE login-response doc, the response id_token is verified by @@ -503,13 +495,13 @@ func (svc *Service) handlePSSOPasswordLogin(ctx context.Context, device *fleet.P // key_context} in a JWE (typ=platformsso-key-response+jwt) encrypted to the // device. key_context carries the provisioned PRIVATE key, sealed under a // server key, so the later key exchange can recover it statelessly. -func (svc *Service) handlePSSOKeyRequest(ctx context.Context, device *fleet.PSSODevice, claims *pssoTokenClaims) ([]byte, error) { +func (svc *Service) handlePSSOKeyRequest(ctx context.Context, hostUUID string, claims *pssoTokenClaims) ([]byte, error) { if claims.JWECrypto == nil || claims.JWECrypto.APV == "" { return nil, &fleet.BadRequestError{Message: "psso key request: missing jwe_crypto recipe"} } - encPub, err := parseECPublicKeyPEM([]byte(device.EncryptionKeyPEM)) + encPub, err := svc.resolvePSSOEncryptionKey(ctx, hostUUID, claims.JWECrypto.APV) if err != nil { - return nil, ctxerr.Wrap(ctx, err, "parse device encryption pubkey") + return nil, ctxerr.Wrap(ctx, err, "resolve device encryption pubkey") } provisioned, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -605,7 +597,7 @@ func (svc *Service) issuePSSOProvisionedCertificate(ctx context.Context, provisi // provisioned private key from key_context, computes the raw ECDH shared // secret against other_publickey (this is the unlock key), and returns // {iat, exp, key, key_context} in the same JWE envelope. -func (svc *Service) handlePSSOKeyExchange(ctx context.Context, device *fleet.PSSODevice, claims *pssoTokenClaims) ([]byte, error) { +func (svc *Service) handlePSSOKeyExchange(ctx context.Context, hostUUID string, claims *pssoTokenClaims) ([]byte, error) { if claims.JWECrypto == nil || claims.JWECrypto.APV == "" { return nil, &fleet.BadRequestError{Message: "psso key exchange: missing jwe_crypto recipe"} } @@ -635,9 +627,9 @@ func (svc *Service) handlePSSOKeyExchange(ctx context.Context, device *fleet.PSS return nil, ctxerr.Wrap(ctx, err, "compute key exchange shared secret") } - encPub, err := parseECPublicKeyPEM([]byte(device.EncryptionKeyPEM)) + encPub, err := svc.resolvePSSOEncryptionKey(ctx, hostUUID, claims.JWECrypto.APV) if err != nil { - return nil, ctxerr.Wrap(ctx, err, "parse device encryption pubkey") + return nil, ctxerr.Wrap(ctx, err, "resolve device encryption pubkey") } now := time.Now() @@ -658,46 +650,6 @@ func (svc *Service) handlePSSOKeyExchange(ctx context.Context, device *fleet.PSS return jwe, nil } -// handlePSSOPasswordRequest decrypts the password the device sent under -// the previously-established session key, validates it against the -// upstream IdP via the wired PSSOIdPClient, and returns the resulting -// claims as a JWT-inside-JWE. -func (svc *Service) handlePSSOPasswordRequest(ctx context.Context, device *fleet.PSSODevice, claims *pssoTokenClaims) ([]byte, error) { - if svc.pssoIdPClient == nil { - return nil, ctxerr.New(ctx, "psso idp client not configured") - } - if claims.Username == "" || claims.EncryptedPwd == "" { - return nil, &fleet.BadRequestError{Message: "psso password_request missing username or encrypted_password"} - } - - sessionKey, err := deriveSessionKey(device.KeyExchangeKey, []byte(claims.RequestNonce)) - if err != nil { - return nil, fmt.Errorf("derive session key: %w", err) - } - pwdPlain, err := decryptSymmetricBlob([]byte(claims.EncryptedPwd), sessionKey) - if err != nil { - return nil, fmt.Errorf("decrypt password blob: %w", err) - } - - idpClaims, err := svc.pssoIdPClient.ValidatePasswordAndGetClaims(ctx, claims.Username, string(pwdPlain)) - if err != nil { - return nil, err - } - - // Wrap the OIDC-shaped claims in a server-signed JWT, then JWE-wrap the - // JWT under the session key. - innerToken, err := svc.signServerJWT(ctx, jwt.MapClaims{ - "sub": idpClaims.Subject, - "email": idpClaims.Email, - "name": idpClaims.Name, - "preferred_username": idpClaims.PreferredUsername, - }) - if err != nil { - return nil, err - } - return buildSymmetricJWE(innerToken, sessionKey) -} - // PSSOJWKS returns the JWKS JSON with Fleet's PSSO signing public key. func (svc *Service) PSSOJWKS(ctx context.Context) ([]byte, error) { // skipauth: This is an unauthenticated public endpoint serving only the diff --git a/ee/server/service/apple_psso_crypto.go b/ee/server/service/apple_psso_crypto.go index 893905938c3..65512f02532 100644 --- a/ee/server/service/apple_psso_crypto.go +++ b/ee/server/service/apple_psso_crypto.go @@ -6,18 +6,11 @@ package service // // Cryptographic choices for the POC: // - Inbound JWTs from the Mac extension are ES256 (P-256). The kid in the -// header points to a PEM stored in mdm_apple_psso_key_ids. -// - "Asymmetric" JWE responses (key_request) use ECDH-ES with A256GCM, -// wrapped to the device's encryption pubkey. -// - "Symmetric" JWE responses (key_exchange, password_request) use -// A256GCM with the content-encryption key derived from KeyExchangeKey -// via HKDF-SHA256. -// -// TODO(apple-psso-spec): The exact claim names and the precise HKDF salt / -// info bindings in Apple's published spec should be confirmed before this -// POC ships to a real Mac. The names below ("key_exchange_key", "claims", -// etc.) are clean-room placeholders; if Apple's framework rejects them, this -// is the first place to look. +// header points to a PEM stored in mdm_apple_psso_keys. +// - JWE responses use ECDH-ES with A256GCM, wrapped to the device's +// registered encryption pubkey (resolved from the request's apv). +// - key_context blobs are sealed with A256GCM under a key derived from +// Fleet's PSSO signing key via HKDF-SHA256 — no per-device server state. import ( "context" @@ -36,6 +29,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" @@ -51,15 +45,13 @@ type pssoRequestType string const ( pssoRequestKey pssoRequestType = "key_request" pssoRequestExchange pssoRequestType = "key_exchange" - pssoRequestPassword pssoRequestType = "password_request" ) // pssoTokenClaims models the union of claims an inbound token JWT can -// carry. The real PSSO v2 Password login request identifies itself with +// carry. The PSSO v2 Password login request identifies itself with // GrantType=="password" and carries a plaintext Password plus a JWECrypto -// recipe describing how the response must be encrypted. The RequestType / -// Encrypted* fields belong to an earlier handshake model and are retained -// only so the legacy dispatch path still compiles. +// recipe describing how the response must be encrypted; key requests and +// key exchanges identify themselves via RequestType instead. type pssoTokenClaims struct { jwt.RegisteredClaims @@ -76,11 +68,29 @@ type pssoTokenClaims struct { RequestType pssoRequestType `json:"request_type,omitempty"` OtherPublicKey string `json:"other_publickey,omitempty"` // device DH public key (key_exchange) KeyContext string `json:"key_context,omitempty"` // server-sealed provisioned key, echoed back +} - // Legacy symmetric password_request handshake (unused by the Password - // grant flow). - EncryptedPwd string `json:"encrypted_password,omitempty"` - EncryptedNonce string `json:"encrypted_nonce,omitempty"` +// pssoJWTLeeway is the clock-skew tolerance applied to inbound JWT time +// claims. The default RegisteredClaims validation allows zero skew, so a Mac +// whose clock runs even a second ahead of the server gets "token used before +// issued" on every login. +const pssoJWTLeeway = time.Minute + +// Valid overrides the embedded RegisteredClaims validation to apply +// pssoJWTLeeway to exp, iat, and nbf. jwt/v4 has no parser-level leeway +// option (that arrived in v5), so the claims type does it. +func (c *pssoTokenClaims) Valid() error { + now := time.Now() + if !c.VerifyExpiresAt(now.Add(-pssoJWTLeeway), false) { + return jwt.ErrTokenExpired + } + if !c.VerifyIssuedAt(now.Add(pssoJWTLeeway), false) { + return jwt.ErrTokenUsedBeforeIssued + } + if !c.VerifyNotBefore(now.Add(pssoJWTLeeway), false) { + return jwt.ErrTokenNotValidYet + } + return nil } // pssoJWECrypto is the jwe_crypto claim the extension sends to tell Fleet how @@ -96,8 +106,8 @@ type pssoJWECrypto struct { // parsePSSOInboundJWT verifies the inbound compact JWS using the device's // signing pubkey (resolved by kid) and returns the parsed claims plus the -// associated device record. -func (svc *Service) parsePSSOInboundJWT(ctx context.Context, jwtBytes []byte) (*pssoTokenClaims, *fleet.PSSODevice, error) { +// signing key row that matched (its HostUUID identifies the device). +func (svc *Service) parsePSSOInboundJWT(ctx context.Context, jwtBytes []byte) (*pssoTokenClaims, *fleet.PSSOKey, error) { // First parse without verification to extract kid. unverified, _, err := jwt.NewParser(jwt.WithoutClaimsValidation()).ParseUnverified(string(jwtBytes), &pssoTokenClaims{}) if err != nil { @@ -109,15 +119,15 @@ func (svc *Service) parsePSSOInboundJWT(ctx context.Context, jwtBytes []byte) (* } kid = canonicalizeKID(kid) - device, keyID, err := svc.ds.GetPSSODeviceByKeyID(ctx, kid) + signKey, err := svc.ds.GetPSSOKey(ctx, kid) if err != nil { - return nil, nil, ctxerr.Wrap(ctx, err, "look up psso device by kid") + return nil, nil, ctxerr.Wrap(ctx, err, "look up psso key by kid") } - if keyID.KeyType != fleet.PSSOKeyTypeSigning { + if signKey.KeyType != fleet.PSSOKeyTypeSigning { return nil, nil, &fleet.BadRequestError{Message: "psso jwt kid does not reference a signing key"} } - pub, err := parseECPublicKeyPEM([]byte(keyID.PEM)) + pub, err := parseECPublicKeyPEM([]byte(signKey.PEM)) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "parse device signing pubkey") } @@ -132,7 +142,67 @@ func (svc *Service) parsePSSOInboundJWT(ctx context.Context, jwtBytes []byte) (* if !ok || !tok.Valid { return nil, nil, &fleet.BadRequestError{Message: "psso jwt claims invalid"} } - return claims, device, nil + return claims, signKey, nil +} + +// resolvePSSOEncryptionKey returns the registered encryption public key the +// response JWE must be wrapped to. The device names its encryption key inside +// the request's apv party-info blob ("Apple" || deviceEncKey || nonce), and +// the extension registered that key under kid = base64url(SHA-256(raw key +// bytes)) — so the kid is recomputed from apv and looked up. As a fallback +// against any re-encoding of the key by Apple's framework, the raw point is +// compared against each of the host's registered encryption keys. A key that +// resolves but belongs to a different host, or doesn't resolve at all, is +// rejected: responses are only ever encrypted to keys the host registered. +func (svc *Service) resolvePSSOEncryptionKey(ctx context.Context, hostUUID, apvB64 string) (*ecdsa.PublicKey, error) { + apvRaw, err := decodeJOSEB64(apvB64) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "decode apv") + } + fields, err := parseApplePartyInfo(apvRaw) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "parse apv party-info") + } + if len(fields) < 2 || string(fields[0]) != apvPartyLabel { + return nil, &fleet.BadRequestError{Message: "psso: apv is not an Apple party-info blob"} + } + encKeyRaw := fields[1] + + sum := sha256.Sum256(encKeyRaw) + kid := canonicalizeKID(base64.RawURLEncoding.EncodeToString(sum[:])) + key, err := svc.ds.GetPSSOKey(ctx, kid) + switch { + case err == nil: + if key.KeyType != fleet.PSSOKeyTypeEncryption || key.HostUUID != hostUUID { + return nil, &fleet.BadRequestError{Message: "psso: apv key is not a registered encryption key for this device"} + } + return parseECPublicKeyPEM([]byte(key.PEM)) + case !fleet.IsNotFound(err): + return nil, ctxerr.Wrap(ctx, err, "look up encryption key by apv kid") + } + + apvPub, err := parseRawECPoint(encKeyRaw) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "parse apv encryption key") + } + hostKeys, err := svc.ds.ListPSSOKeys(ctx, hostUUID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "list psso keys for apv fallback") + } + for _, k := range hostKeys { + if k.KeyType != fleet.PSSOKeyTypeEncryption { + continue + } + pub, err := parseECPublicKeyPEM([]byte(k.PEM)) + if err != nil { + svc.logger.WarnContext(ctx, "psso: skipping unparseable registered encryption key", "kid", k.KID, "err", err) + continue + } + if pub.Equal(apvPub) { + return pub, nil + } + } + return nil, &fleet.BadRequestError{Message: "psso: apv key is not a registered encryption key for this device"} } // canonicalizeKID normalizes a key ID to a stable comparison form. Apple's @@ -198,8 +268,8 @@ func parseRawECPoint(raw []byte) (*ecdsa.PublicKey, error) { } // buildAsymmetricJWE encrypts payload to deviceEncPub using JWE -// ECDH-ES + A256GCM. Used for the key_request response that delivers the -// initial KeyExchangeKey to the device. +// ECDH-ES + A256GCM via go-jose's stock encrypter (empty apu/apv — see +// buildPSSOResponseJWE for the Apple-party-info variant the handlers use). func buildAsymmetricJWE(payload []byte, deviceEncPub *ecdsa.PublicKey, kid string) ([]byte, error) { enc, err := jose.NewEncrypter( jose.A256GCM, @@ -422,16 +492,15 @@ func decodeBase64Flexible(s string) ([]byte, error) { return nil, errors.New("psso: value is not valid base64") } -// pssoSessionInfo is the HKDF info string distinguishing PSSO session keys -// from any other purpose KeyExchangeKey could be used for. +// pssoSessionInfo is the HKDF info string distinguishing PSSO-derived keys +// from any other derivation the same input keying material could feed. var pssoSessionInfo = []byte("fleetdm-psso-session-key-v1") -// deriveSessionKey returns a 32-byte AES-256 key derived from the device's -// KeyExchangeKey via HKDF-SHA256. The salt parameter binds the derivation -// to a specific request (typically the request_nonce) so each sign-in uses -// a distinct content-encryption key. -func deriveSessionKey(kek []byte, salt []byte) ([]byte, error) { - r := hkdf.New(sha256.New, kek, salt, pssoSessionInfo) +// deriveSessionKey returns a 32-byte AES-256 key derived from ikm via +// HKDF-SHA256. The salt parameter binds the derivation to a purpose (e.g. +// the key_context info string in deriveKeyContextKey). +func deriveSessionKey(ikm []byte, salt []byte) ([]byte, error) { + r := hkdf.New(sha256.New, ikm, salt, pssoSessionInfo) out := make([]byte, 32) if _, err := r.Read(out); err != nil { return nil, fmt.Errorf("hkdf read: %w", err) @@ -440,8 +509,8 @@ func deriveSessionKey(kek []byte, salt []byte) ([]byte, error) { } // buildSymmetricJWE returns an A256GCM JWE of payload, keyed by sessionKey. -// Used for key_exchange and password_request responses where the device -// has already established a shared secret via the KeyExchangeKey handshake. +// Used to seal key_context blobs so the provisioned private key can +// round-trip statelessly between key_request and key_exchange. func buildSymmetricJWE(payload []byte, sessionKey []byte) ([]byte, error) { if len(sessionKey) != 32 { return nil, fmt.Errorf("psso: session key must be 32 bytes, got %d", len(sessionKey)) @@ -476,9 +545,8 @@ func buildSymmetricJWE(payload []byte, sessionKey []byte) ([]byte, error) { return json.Marshal(envelope) } -// decryptSymmetricBlob is the inverse of buildSymmetricJWE — used in -// password_request to decrypt the password the device sent under the -// previously-established session key. +// decryptSymmetricBlob is the inverse of buildSymmetricJWE — used to open +// the key_context blob a device echoes back in a key-exchange request. func decryptSymmetricBlob(blob []byte, sessionKey []byte) ([]byte, error) { if len(sessionKey) != 32 { return nil, fmt.Errorf("psso: session key must be 32 bytes, got %d", len(sessionKey)) @@ -517,7 +585,7 @@ func (svc *Service) signServerJWT(ctx context.Context, claims jwt.Claims) ([]byt tok.Header["kid"] = kid signed, err := tok.SignedString(key) if err != nil { - return nil, fmt.Errorf("sign server jwt: %w", err) + return nil, ctxerr.Wrap(ctx, err, "sign server jwt") } return []byte(signed), nil } diff --git a/ee/server/service/apple_psso_crypto_test.go b/ee/server/service/apple_psso_crypto_test.go index 163b87ba39e..de55c052868 100644 --- a/ee/server/service/apple_psso_crypto_test.go +++ b/ee/server/service/apple_psso_crypto_test.go @@ -1,22 +1,30 @@ package service import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/pem" + "io" + "log/slog" "testing" + "time" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" jose "github.com/go-jose/go-jose/v3" + jwt "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// TestPSSO_SymmetricRoundTrip exercises the AES-256-GCM envelope used for -// key_exchange and password_request responses. Encrypting and then -// decrypting under the same session key must yield the original plaintext. +// TestPSSO_SymmetricRoundTrip exercises the AES-256-GCM envelope used to +// seal key_context blobs. Encrypting and then decrypting under the same +// session key must yield the original plaintext. func TestPSSO_SymmetricRoundTrip(t *testing.T) { key := make([]byte, 32) _, err := rand.Read(key) @@ -81,13 +89,13 @@ func TestPSSO_HKDFDifferentSaltDifferentKey(t *testing.T) { } // TestPSSO_AsymmetricEncryptRoundTrip confirms that a payload encrypted to -// a device's encryption pubkey via JWE ECDH-ES + A256GCM can be decrypted -// with the corresponding private key. This is the key_request flow. +// a device's encryption pubkey via JWE ECDH-ES + A256GCM produces a valid +// compact JWE. func TestPSSO_AsymmetricEncryptRoundTrip(t *testing.T) { deviceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - payload := []byte(`{"key_exchange_key":"AAECAwQF"}`) + payload := []byte(`{"claims":"AAECAwQF"}`) jweCompact, err := buildAsymmetricJWE(payload, &deviceKey.PublicKey, "") require.NoError(t, err) require.NotEmpty(t, jweCompact) @@ -227,6 +235,38 @@ func TestPSSO_KeyExchangeSharedSecretMatches(t *testing.T) { assert.Equal(t, deviceShared, serverShared) } +// TestPSSO_TokenClaimsLeeway confirms inbound JWT time claims tolerate small +// clock skew between the Mac and the server: an iat slightly in the future +// (Mac clock ahead) or an exp slightly in the past must not fail validation, +// while skew beyond the leeway still does. +func TestPSSO_TokenClaimsLeeway(t *testing.T) { + now := time.Now() + claimsAt := func(iat, exp time.Time) *pssoTokenClaims { + return &pssoTokenClaims{RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(iat), + ExpiresAt: jwt.NewNumericDate(exp), + }} + } + + // In sync: valid. + require.NoError(t, claimsAt(now, now.Add(5*time.Minute)).Valid()) + + // Mac clock slightly ahead: iat in the (server's) future, within leeway. + require.NoError(t, claimsAt(now.Add(30*time.Second), now.Add(5*time.Minute)).Valid()) + + // exp just passed, within leeway. + require.NoError(t, claimsAt(now.Add(-5*time.Minute), now.Add(-30*time.Second)).Valid()) + + // Beyond leeway both ways. + err := claimsAt(now.Add(pssoJWTLeeway+time.Minute), now.Add(10*time.Minute)).Valid() + require.ErrorIs(t, err, jwt.ErrTokenUsedBeforeIssued) + err = claimsAt(now.Add(-10*time.Minute), now.Add(-pssoJWTLeeway-time.Minute)).Valid() + require.ErrorIs(t, err, jwt.ErrTokenExpired) + + // Absent time claims are not required (registration-era JWTs). + require.NoError(t, (&pssoTokenClaims{}).Valid()) +} + // TestPSSO_CanonicalizeKID confirms the padded base64 kid Apple's framework // sends in the JWT header and the unpadded base64url kid the extension // registers collapse to the same value, so device lookup by kid succeeds. @@ -275,6 +315,108 @@ func TestPSSO_ParseECPublicKey(t *testing.T) { require.Error(t, err) } +// TestPSSO_ResolveEncryptionKey covers resolving the response-encryption key +// from a request's apv blob: the kid is recomputed as SHA-256 of the raw key +// bytes the device placed in apv (matching how the extension registers its +// kids), looked up, and validated as an encryption key belonging to the +// requesting host. When the kid lookup misses, the host's registered +// encryption keys are compared point-by-point as a fallback. +func TestPSSO_ResolveEncryptionKey(t *testing.T) { + const hostUUID = "ABCDEFGH-0000-0000-0000-111111111111" + + encPriv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + encECDH, err := encPriv.PublicKey.ECDH() + require.NoError(t, err) + rawPoint := encECDH.Bytes() + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: rawPoint}) + + sum := sha256.Sum256(rawPoint) + kid := canonicalizeKID(base64.RawURLEncoding.EncodeToString(sum[:])) + + apv := base64.RawURLEncoding.EncodeToString( + encodeApplePartyInfo([]byte(apvPartyLabel), rawPoint, []byte("nonce"))) + + newSvc := func() (*Service, *mock.DataStore) { + ds := new(mock.DataStore) + svc := &Service{ds: ds, logger: slog.New(slog.NewTextHandler(io.Discard, nil))} + return svc, ds + } + registeredKey := &fleet.PSSOKey{ + KID: kid, + HostUUID: hostUUID, + KeyType: fleet.PSSOKeyTypeEncryption, + PEM: string(pemBytes), + } + + t.Run("resolves by kid computed from apv", func(t *testing.T) { + svc, ds := newSvc() + ds.GetPSSOKeyFunc = func(ctx context.Context, gotKID string) (*fleet.PSSOKey, error) { + require.Equal(t, kid, gotKID) + return registeredKey, nil + } + pub, err := svc.resolvePSSOEncryptionKey(t.Context(), hostUUID, apv) + require.NoError(t, err) + assert.True(t, pub.Equal(&encPriv.PublicKey)) + }) + + t.Run("rejects a key registered to a different host", func(t *testing.T) { + svc, ds := newSvc() + ds.GetPSSOKeyFunc = func(ctx context.Context, _ string) (*fleet.PSSOKey, error) { + other := *registeredKey + other.HostUUID = "some-other-host" + return &other, nil + } + _, err := svc.resolvePSSOEncryptionKey(t.Context(), hostUUID, apv) + require.Error(t, err) + }) + + t.Run("rejects a signing key", func(t *testing.T) { + svc, ds := newSvc() + ds.GetPSSOKeyFunc = func(ctx context.Context, _ string) (*fleet.PSSOKey, error) { + other := *registeredKey + other.KeyType = fleet.PSSOKeyTypeSigning + return &other, nil + } + _, err := svc.resolvePSSOEncryptionKey(t.Context(), hostUUID, apv) + require.Error(t, err) + }) + + t.Run("falls back to comparing the host's registered keys", func(t *testing.T) { + svc, ds := newSvc() + ds.GetPSSOKeyFunc = func(ctx context.Context, _ string) (*fleet.PSSOKey, error) { + return nil, &testNotFoundError{} + } + ds.ListPSSOKeysFunc = func(ctx context.Context, gotUUID string) ([]*fleet.PSSOKey, error) { + require.Equal(t, hostUUID, gotUUID) + return []*fleet.PSSOKey{registeredKey}, nil + } + pub, err := svc.resolvePSSOEncryptionKey(t.Context(), hostUUID, apv) + require.NoError(t, err) + assert.True(t, pub.Equal(&encPriv.PublicKey)) + assert.True(t, ds.ListPSSOKeysFuncInvoked) + }) + + t.Run("rejects when no registered key matches", func(t *testing.T) { + svc, ds := newSvc() + ds.GetPSSOKeyFunc = func(ctx context.Context, _ string) (*fleet.PSSOKey, error) { + return nil, &testNotFoundError{} + } + ds.ListPSSOKeysFunc = func(ctx context.Context, _ string) ([]*fleet.PSSOKey, error) { + return nil, nil + } + _, err := svc.resolvePSSOEncryptionKey(t.Context(), hostUUID, apv) + require.Error(t, err) + }) + + t.Run("rejects a malformed apv", func(t *testing.T) { + svc, _ := newSvc() + _, err := svc.resolvePSSOEncryptionKey(t.Context(), hostUUID, + base64.RawURLEncoding.EncodeToString([]byte("not party info"))) + require.Error(t, err) + }) +} + // TestPSSO_ParseRawECPointPEM covers the form the macOS extension actually // sends: a raw ANSI X9.63 uncompressed point (0x04 || X || Y) PEM-wrapped // under a "PUBLIC KEY" label rather than DER SubjectPublicKeyInfo. diff --git a/ee/server/service/service.go b/ee/server/service/service.go index 3aa89bddeb4..88be39177ee 100644 --- a/ee/server/service/service.go +++ b/ee/server/service/service.go @@ -27,7 +27,7 @@ type Service struct { // PSSO POC. Required for PSSO nonce/register/token flows. pssoNonceStore fleet.PSSONonceStore - // pssoIdPClient validates passwords for the PSSO password_request flow. + // pssoIdPClient validates passwords for the PSSO password login flow. // Wired via SetPSSOIdPClient. pssoIdPClient fleet.PSSOIdPClient diff --git a/server/datastore/mysql/apple_mdm.go b/server/datastore/mysql/apple_mdm.go index 3d7f79594b6..ef3e83cfb32 100644 --- a/server/datastore/mysql/apple_mdm.go +++ b/server/datastore/mysql/apple_mdm.go @@ -7482,6 +7482,12 @@ func (ds *Datastore) MDMAppleResetOnReenrollment(ctx context.Context, hostUUID s } } + // Clear the PSSO registration (keys cascade) so an ADE re-enrollment + // starts from fresh device keys. + if _, err := tx.ExecContext(ctx, "DELETE FROM mdm_apple_psso_devices WHERE host_uuid = ?", hostUUID); err != nil { + return ctxerr.Wrap(ctx, err, "clear psso registration for mdm reset", "host_uuid", hostUUID) + } + if !preserveHostActivities { if err := ds.clearHostActivitiesForAppleMDMReset(ctx, tx, hostUUID, hostID); err != nil { return ctxerr.Wrap(ctx, err, "clear host activities for mdm reset") diff --git a/server/datastore/mysql/apple_mdm_test.go b/server/datastore/mysql/apple_mdm_test.go index d9503bd6403..1bc4fba70d4 100644 --- a/server/datastore/mysql/apple_mdm_test.go +++ b/server/datastore/mysql/apple_mdm_test.go @@ -5597,11 +5597,18 @@ func testMDMAppleResetOnReenrollment(t *testing.T, ds *Datastore) { _, err = ds.writer(ctx).ExecContext(ctx, `INSERT INTO script_upcoming_activities (upcoming_activity_id) VALUES (?)`, uaID) require.NoError(t, err) + + // PSSO registration (host_uuid ref - device row plus a cascading key) + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, h.UUID, []fleet.PSSOKey{ + {KID: "kid-" + h.UUID, KeyType: fleet.PSSOKeyTypeSigning, PEM: "pem-" + h.UUID}, + })) } type counts struct { - label int - upcoming int + label int + upcoming int + pssoDevice int + pssoKey int } countRows := func(t *testing.T, h *fleet.Host) counts { var c counts @@ -5609,9 +5616,13 @@ func testMDMAppleResetOnReenrollment(t *testing.T, ds *Datastore) { `SELECT COUNT(*) FROM label_membership WHERE host_id = ?`, h.ID)) require.NoError(t, sqlx.GetContext(ctx, ds.writer(ctx), &c.upcoming, `SELECT COUNT(*) FROM upcoming_activities WHERE host_id = ?`, h.ID)) + require.NoError(t, sqlx.GetContext(ctx, ds.writer(ctx), &c.pssoDevice, + `SELECT COUNT(*) FROM mdm_apple_psso_devices WHERE host_uuid = ?`, h.UUID)) + require.NoError(t, sqlx.GetContext(ctx, ds.writer(ctx), &c.pssoKey, + `SELECT COUNT(*) FROM mdm_apple_psso_keys WHERE host_uuid = ?`, h.UUID)) return c } - seeded := counts{label: 1, upcoming: 1} + seeded := counts{label: 1, upcoming: 1, pssoDevice: 1, pssoKey: 1} t.Run("clears expected tables and leaves other hosts untouched", func(t *testing.T) { hostA := newHost("clear-A") @@ -5626,7 +5637,7 @@ func testMDMAppleResetOnReenrollment(t *testing.T, ds *Datastore) { require.NoError(t, ds.MDMAppleResetOnReenrollment(ctx, hostA.UUID, true)) // host A: everything cleared - assert.Equal(t, counts{label: 0, upcoming: 0}, countRows(t, hostA)) + assert.Equal(t, counts{}, countRows(t, hostA)) // host B: untouched (control - proves the reset is host-scoped) assert.Equal(t, seeded, countRows(t, hostB)) diff --git a/server/datastore/mysql/apple_psso.go b/server/datastore/mysql/apple_psso.go index 6a368b93ccc..ad50dcbc828 100644 --- a/server/datastore/mysql/apple_psso.go +++ b/server/datastore/mysql/apple_psso.go @@ -10,114 +10,89 @@ import ( "github.com/jmoiron/sqlx" ) -// SetOrUpdatePSSODevice replaces (or creates) a host's PSSO registration in a -// single transaction: upserts the device row, deletes any stale KeyID rows for -// the host, then inserts the two new KeyID rows (signing + encryption). -func (ds *Datastore) SetOrUpdatePSSODevice( - ctx context.Context, - device fleet.PSSODevice, - signKeyID fleet.PSSOKeyID, - encKeyID fleet.PSSOKeyID, -) error { +// SetOrUpdatePSSODevice upserts a host's PSSO registration: the device row +// plus the given key rows in a single transaction. Keys are upserted by kid; +// keys from earlier registrations are left in place so they keep working. +func (ds *Datastore) SetOrUpdatePSSODevice(ctx context.Context, hostUUID string, keys []fleet.PSSOKey) error { return ds.withTx(ctx, func(tx sqlx.ExtContext) error { const upsertDevice = ` - INSERT INTO mdm_apple_psso_devices - (host_id, device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key) - VALUES (?, ?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - device_uuid = VALUES(device_uuid), - signing_key_pem = VALUES(signing_key_pem), - encryption_key_pem = VALUES(encryption_key_pem), - key_exchange_key = VALUES(key_exchange_key) + INSERT INTO mdm_apple_psso_devices (host_uuid) + VALUES (?) + ON DUPLICATE KEY UPDATE updated_at = CURRENT_TIMESTAMP(6) ` - if _, err := tx.ExecContext(ctx, upsertDevice, - device.HostID, - device.DeviceUUID, - device.SigningKeyPEM, - device.EncryptionKeyPEM, - device.KeyExchangeKey, - ); err != nil { + if _, err := tx.ExecContext(ctx, upsertDevice, hostUUID); err != nil { return ctxerr.Wrap(ctx, err, "upsert psso device") } - if _, err := tx.ExecContext(ctx, - `DELETE FROM mdm_apple_psso_key_ids WHERE host_id = ?`, - device.HostID, - ); err != nil { - return ctxerr.Wrap(ctx, err, "clear existing psso key_ids") - } - - const insertKeyID = ` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) + const upsertKey = ` + INSERT INTO mdm_apple_psso_keys (kid, host_uuid, key_type, pem) VALUES (?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + host_uuid = VALUES(host_uuid), + key_type = VALUES(key_type), + pem = VALUES(pem) ` - for _, k := range []fleet.PSSOKeyID{signKeyID, encKeyID} { - if _, err := tx.ExecContext(ctx, insertKeyID, k.KID, k.HostID, k.KeyType, k.PEM); err != nil { - return ctxerr.Wrap(ctx, err, "insert psso key_id") + for _, k := range keys { + if _, err := tx.ExecContext(ctx, upsertKey, k.KID, hostUUID, k.KeyType, k.PEM); err != nil { + return ctxerr.Wrap(ctx, err, "upsert psso key") } } return nil }) } -// GetPSSODeviceByKeyID resolves a kid back to its owning device and the -// specific KeyID row that matched (so callers know whether they're holding the -// signing or encryption side of the device's keypair). -func (ds *Datastore) GetPSSODeviceByKeyID(ctx context.Context, kid string) (*fleet.PSSODevice, *fleet.PSSOKeyID, error) { - type joined struct { - // device columns - HostID uint `db:"host_id"` - DeviceUUID string `db:"device_uuid"` - SigningKeyPEM string `db:"signing_key_pem"` - EncryptionKeyPEM string `db:"encryption_key_pem"` - KeyExchangeKey []byte `db:"key_exchange_key"` - DeviceCreatedAt []byte `db:"device_created_at"` - DeviceUpdatedAt []byte `db:"device_updated_at"` - // key_id columns - KID string `db:"kid"` - KeyType fleet.PSSOKeyType `db:"key_type"` - PEM string `db:"pem"` - KIDCreated []byte `db:"kid_created_at"` +func (ds *Datastore) GetPSSODevice(ctx context.Context, hostUUID string) (*fleet.PSSODevice, error) { + const stmt = ` + SELECT host_uuid, created_at, updated_at + FROM mdm_apple_psso_devices + WHERE host_uuid = ? + ` + var device fleet.PSSODevice + if err := sqlx.GetContext(ctx, ds.reader(ctx), &device, stmt, hostUUID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ctxerr.Wrap(ctx, notFound("PSSODevice").WithName(hostUUID)) + } + return nil, ctxerr.Wrap(ctx, err, "get psso device") } + return &device, nil +} +func (ds *Datastore) GetPSSOKey(ctx context.Context, kid string) (*fleet.PSSOKey, error) { const stmt = ` - SELECT - d.host_id AS host_id, - d.device_uuid AS device_uuid, - d.signing_key_pem AS signing_key_pem, - d.encryption_key_pem AS encryption_key_pem, - d.key_exchange_key AS key_exchange_key, - d.created_at AS device_created_at, - d.updated_at AS device_updated_at, - k.kid AS kid, - k.key_type AS key_type, - k.pem AS pem, - k.created_at AS kid_created_at - FROM mdm_apple_psso_key_ids k - JOIN mdm_apple_psso_devices d ON d.host_id = k.host_id - WHERE k.kid = ? + SELECT kid, host_uuid, key_type, pem, created_at, updated_at + FROM mdm_apple_psso_keys + WHERE kid = ? ` - - var row joined - if err := sqlx.GetContext(ctx, ds.reader(ctx), &row, stmt, kid); err != nil { + var key fleet.PSSOKey + if err := sqlx.GetContext(ctx, ds.reader(ctx), &key, stmt, kid); err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, nil, ctxerr.Wrap(ctx, notFound("PSSOKeyID").WithName(kid)) + return nil, ctxerr.Wrap(ctx, notFound("PSSOKey").WithName(kid)) } - return nil, nil, ctxerr.Wrap(ctx, err, "get psso device by kid") + return nil, ctxerr.Wrap(ctx, err, "get psso key") } + return &key, nil +} - device := &fleet.PSSODevice{ - HostID: row.HostID, - DeviceUUID: row.DeviceUUID, - SigningKeyPEM: row.SigningKeyPEM, - EncryptionKeyPEM: row.EncryptionKeyPEM, - KeyExchangeKey: row.KeyExchangeKey, +func (ds *Datastore) ListPSSOKeys(ctx context.Context, hostUUID string) ([]*fleet.PSSOKey, error) { + const stmt = ` + SELECT kid, host_uuid, key_type, pem, created_at, updated_at + FROM mdm_apple_psso_keys + WHERE host_uuid = ? + ORDER BY created_at DESC, kid + ` + var keys []*fleet.PSSOKey + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &keys, stmt, hostUUID); err != nil { + return nil, ctxerr.Wrap(ctx, err, "list psso keys") } - keyID := &fleet.PSSOKeyID{ - KID: row.KID, - HostID: row.HostID, - KeyType: row.KeyType, - PEM: row.PEM, + return keys, nil +} + +// DeletePSSODevice clears a host's PSSO registration; the keys cascade. +func (ds *Datastore) DeletePSSODevice(ctx context.Context, hostUUID string) error { + if _, err := ds.writer(ctx).ExecContext(ctx, + `DELETE FROM mdm_apple_psso_devices WHERE host_uuid = ?`, hostUUID, + ); err != nil { + return ctxerr.Wrap(ctx, err, "delete psso device") } - return device, keyID, nil + return nil } diff --git a/server/datastore/mysql/apple_psso_test.go b/server/datastore/mysql/apple_psso_test.go new file mode 100644 index 00000000000..deddbb63b00 --- /dev/null +++ b/server/datastore/mysql/apple_psso_test.go @@ -0,0 +1,147 @@ +package mysql + +import ( + "testing" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplePSSO(t *testing.T) { + ds := CreateMySQLDS(t) + + cases := []struct { + name string + fn func(t *testing.T, ds *Datastore) + }{ + {"SetOrUpdateAndGet", testPSSOSetOrUpdateAndGet}, + {"ReRegistrationKeepsOldKeys", testPSSOReRegistrationKeepsOldKeys}, + {"DeleteDevice", testPSSODeleteDevice}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + defer TruncateTables(t, ds) + c.fn(t, ds) + }) + } +} + +func testPSSOSetOrUpdateAndGet(t *testing.T, ds *Datastore) { + ctx := t.Context() + const hostUUID = "ABCDEFGH-0000-0000-0000-111111111111" + + keys := []fleet.PSSOKey{ + {KID: "kid-sign-1", KeyType: fleet.PSSOKeyTypeSigning, PEM: "sign-pem-1"}, + {KID: "kid-enc-1", KeyType: fleet.PSSOKeyTypeEncryption, PEM: "enc-pem-1"}, + } + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, hostUUID, keys)) + + device, err := ds.GetPSSODevice(ctx, hostUUID) + require.NoError(t, err) + assert.Equal(t, hostUUID, device.HostUUID) + assert.False(t, device.CreatedAt.IsZero()) + assert.False(t, device.UpdatedAt.IsZero()) + + _, err = ds.GetPSSODevice(ctx, "unregistered-uuid") + require.Error(t, err) + assert.True(t, fleet.IsNotFound(err)) + + signKey, err := ds.GetPSSOKey(ctx, "kid-sign-1") + require.NoError(t, err) + assert.Equal(t, hostUUID, signKey.HostUUID) + assert.Equal(t, fleet.PSSOKeyTypeSigning, signKey.KeyType) + assert.Equal(t, "sign-pem-1", signKey.PEM) + + encKey, err := ds.GetPSSOKey(ctx, "kid-enc-1") + require.NoError(t, err) + assert.Equal(t, fleet.PSSOKeyTypeEncryption, encKey.KeyType) + assert.Equal(t, "enc-pem-1", encKey.PEM) + + _, err = ds.GetPSSOKey(ctx, "no-such-kid") + require.Error(t, err) + assert.True(t, fleet.IsNotFound(err)) + + listed, err := ds.ListPSSOKeys(ctx, hostUUID) + require.NoError(t, err) + assert.Len(t, listed, 2) + + listed, err = ds.ListPSSOKeys(ctx, "unregistered-uuid") + require.NoError(t, err) + assert.Empty(t, listed) + + // Upserting the same kid updates the row in place. + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, hostUUID, []fleet.PSSOKey{ + {KID: "kid-sign-1", KeyType: fleet.PSSOKeyTypeSigning, PEM: "sign-pem-1-rotated"}, + })) + signKey, err = ds.GetPSSOKey(ctx, "kid-sign-1") + require.NoError(t, err) + assert.Equal(t, "sign-pem-1-rotated", signKey.PEM) + + listed, err = ds.ListPSSOKeys(ctx, hostUUID) + require.NoError(t, err) + assert.Len(t, listed, 2) +} + +func testPSSOReRegistrationKeepsOldKeys(t *testing.T, ds *Datastore) { + ctx := t.Context() + const hostUUID = "ABCDEFGH-0000-0000-0000-222222222222" + + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, hostUUID, []fleet.PSSOKey{ + {KID: "kid-sign-old", KeyType: fleet.PSSOKeyTypeSigning, PEM: "sign-pem-old"}, + {KID: "kid-enc-old", KeyType: fleet.PSSOKeyTypeEncryption, PEM: "enc-pem-old"}, + })) + + // Re-register with fresh keys: old keys must remain resolvable. + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, hostUUID, []fleet.PSSOKey{ + {KID: "kid-sign-new", KeyType: fleet.PSSOKeyTypeSigning, PEM: "sign-pem-new"}, + {KID: "kid-enc-new", KeyType: fleet.PSSOKeyTypeEncryption, PEM: "enc-pem-new"}, + })) + + for _, kid := range []string{"kid-sign-old", "kid-enc-old", "kid-sign-new", "kid-enc-new"} { + key, err := ds.GetPSSOKey(ctx, kid) + require.NoError(t, err, "kid %s", kid) + assert.Equal(t, hostUUID, key.HostUUID) + } + + listed, err := ds.ListPSSOKeys(ctx, hostUUID) + require.NoError(t, err) + assert.Len(t, listed, 4) +} + +func testPSSODeleteDevice(t *testing.T, ds *Datastore) { + ctx := t.Context() + const ( + hostUUID1 = "ABCDEFGH-0000-0000-0000-333333333333" + hostUUID2 = "ABCDEFGH-0000-0000-0000-444444444444" + ) + + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, hostUUID1, []fleet.PSSOKey{ + {KID: "kid-sign-h1", KeyType: fleet.PSSOKeyTypeSigning, PEM: "p"}, + {KID: "kid-enc-h1", KeyType: fleet.PSSOKeyTypeEncryption, PEM: "p"}, + })) + require.NoError(t, ds.SetOrUpdatePSSODevice(ctx, hostUUID2, []fleet.PSSOKey{ + {KID: "kid-sign-h2", KeyType: fleet.PSSOKeyTypeSigning, PEM: "p"}, + })) + + require.NoError(t, ds.DeletePSSODevice(ctx, hostUUID1)) + + _, err := ds.GetPSSODevice(ctx, hostUUID1) + assert.True(t, fleet.IsNotFound(err)) + + // Keys cascade with the device row. + _, err = ds.GetPSSOKey(ctx, "kid-sign-h1") + assert.True(t, fleet.IsNotFound(err)) + listed, err := ds.ListPSSOKeys(ctx, hostUUID1) + require.NoError(t, err) + assert.Empty(t, listed) + + // Other hosts are untouched. + _, err = ds.GetPSSODevice(ctx, hostUUID2) + require.NoError(t, err) + _, err = ds.GetPSSOKey(ctx, "kid-sign-h2") + require.NoError(t, err) + + // Deleting an unregistered host is a no-op. + require.NoError(t, ds.DeletePSSODevice(ctx, "never-registered")) +} diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 52ff44057fd..ea7937357c5 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -626,6 +626,11 @@ var hostRefs = []string{ // Orbit re-enrollment recreates the host row and the existing password row remains // reachable for view/rotate. Apple-MDM unenroll/re-enroll is handled separately by // MDMResetEnrollment, which soft-deletes the row. +// - mdm_apple_psso_devices / mdm_apple_psso_keys: keyed by host_uuid, intentionally +// preserved across host deletion for the same reason — the Mac may still be +// MDM-enrolled with Platform SSO active, and its registered keys must keep +// authenticating token requests. ADE re-enrollment clears them via +// MDMAppleResetOnReenrollment. // additionalHostRefsByUUID are host refs cannot be deleted using the host.id like the hostRefs // above. They use the host.uuid instead. Additionally, the column name that refers to diff --git a/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables.go b/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables.go index 80a887e3f67..44887b05167 100644 --- a/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables.go +++ b/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables.go @@ -12,34 +12,28 @@ func init() { func Up_20260611140639(tx *sql.Tx) error { if _, err := tx.Exec(` CREATE TABLE mdm_apple_psso_devices ( - host_id INT UNSIGNED NOT NULL, - device_uuid VARCHAR(255) COLLATE utf8mb4_unicode_ci NOT NULL, - signing_key_pem TEXT COLLATE utf8mb4_unicode_ci NOT NULL, - encryption_key_pem TEXT COLLATE utf8mb4_unicode_ci NOT NULL, - key_exchange_key VARBINARY(64) NOT NULL, - created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - PRIMARY KEY (host_id), - UNIQUE KEY idx_mdm_apple_psso_devices_device_uuid (device_uuid), - CONSTRAINT fk_mdm_apple_psso_devices_host_id FOREIGN KEY (host_id) REFERENCES hosts (id) ON DELETE CASCADE + host_uuid VARCHAR(255) COLLATE utf8mb4_unicode_ci NOT NULL, + created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (host_uuid) ) `); err != nil { return fmt.Errorf("creating mdm_apple_psso_devices table: %w", err) } if _, err := tx.Exec(` - CREATE TABLE mdm_apple_psso_key_ids ( - kid VARCHAR(255) COLLATE utf8mb4_unicode_ci NOT NULL, - host_id INT UNSIGNED NOT NULL, - key_type ENUM('signing','encryption') COLLATE utf8mb4_unicode_ci NOT NULL, - pem TEXT COLLATE utf8mb4_unicode_ci NOT NULL, - created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + CREATE TABLE mdm_apple_psso_keys ( + kid VARCHAR(255) COLLATE utf8mb4_unicode_ci NOT NULL, + host_uuid VARCHAR(255) COLLATE utf8mb4_unicode_ci NOT NULL, + key_type ENUM('signing','encryption') COLLATE utf8mb4_unicode_ci NOT NULL, + pem TEXT COLLATE utf8mb4_unicode_ci NOT NULL, + created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), PRIMARY KEY (kid), - UNIQUE KEY idx_mdm_apple_psso_key_ids_host_type (host_id, key_type), - CONSTRAINT fk_mdm_apple_psso_key_ids_host_id FOREIGN KEY (host_id) REFERENCES hosts (id) ON DELETE CASCADE + CONSTRAINT fk_mdm_apple_psso_keys_host_uuid FOREIGN KEY (host_uuid) REFERENCES mdm_apple_psso_devices (host_uuid) ON DELETE CASCADE ) `); err != nil { - return fmt.Errorf("creating mdm_apple_psso_key_ids table: %w", err) + return fmt.Errorf("creating mdm_apple_psso_keys table: %w", err) } return nil diff --git a/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables_test.go b/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables_test.go index fba885f33c1..101034d7ea5 100644 --- a/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables_test.go +++ b/server/datastore/mysql/migrations/tables/20260611140639_CreateApplePSSOTables_test.go @@ -10,130 +10,75 @@ import ( func TestUp_20260611140639(t *testing.T) { db := applyUpToPrev(t) - - hostInsert := `INSERT INTO hosts (hardware_serial, osquery_host_id, node_key, uuid, platform) VALUES (?, ?, ?, ?, ?)` - hostID1 := execNoErrLastID(t, db, hostInsert, "serial-1", "osq-1", "node-key-1", "uuid-1", "darwin") - hostID2 := execNoErrLastID(t, db, hostInsert, "serial-2", "osq-2", "node-key-2", "uuid-2", "darwin") - applyNext(t, db) - // Insert a device row with explicit values. - _, err := db.Exec(` - INSERT INTO mdm_apple_psso_devices - (host_id, device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key) - VALUES (?, ?, ?, ?, ?) - `, hostID1, "ABCDEFGH-0000-0000-0000-111111111111", "signing-pem-1", "encryption-pem-1", []byte("0123456789abcdef0123456789abcdef")) - require.NoError(t, err) + const ( + hostUUID1 = "ABCDEFGH-0000-0000-0000-111111111111" + hostUUID2 = "ABCDEFGH-0000-0000-0000-222222222222" + ) + + // Register two devices. + execNoErr(t, db, `INSERT INTO mdm_apple_psso_devices (host_uuid) VALUES (?)`, hostUUID1) + execNoErr(t, db, `INSERT INTO mdm_apple_psso_devices (host_uuid) VALUES (?)`, hostUUID2) var ( - gotUUID string - gotSigningPEM string - gotEncryptPEM string - gotKEK []byte - gotCreatedAt time.Time - gotUpdatedAt time.Time + gotCreatedAt time.Time + gotUpdatedAt time.Time ) - err = db.QueryRow(` - SELECT device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key, created_at, updated_at - FROM mdm_apple_psso_devices WHERE host_id = ? - `, hostID1).Scan(&gotUUID, &gotSigningPEM, &gotEncryptPEM, &gotKEK, &gotCreatedAt, &gotUpdatedAt) + err := db.QueryRow(` + SELECT created_at, updated_at FROM mdm_apple_psso_devices WHERE host_uuid = ? + `, hostUUID1).Scan(&gotCreatedAt, &gotUpdatedAt) require.NoError(t, err) - assert.Equal(t, "ABCDEFGH-0000-0000-0000-111111111111", gotUUID) - assert.Equal(t, "signing-pem-1", gotSigningPEM) - assert.Equal(t, "encryption-pem-1", gotEncryptPEM) - assert.Equal(t, []byte("0123456789abcdef0123456789abcdef"), gotKEK) assert.False(t, gotCreatedAt.IsZero()) assert.False(t, gotUpdatedAt.IsZero()) - // Duplicate host_id is rejected by the PK. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_devices - (host_id, device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key) - VALUES (?, ?, ?, ?, ?) - `, hostID1, "different-uuid", "x", "y", []byte("kek")) - require.Error(t, err) - - // Duplicate device_uuid across hosts is rejected by the unique index. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_devices - (host_id, device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key) - VALUES (?, ?, ?, ?, ?) - `, hostID2, "ABCDEFGH-0000-0000-0000-111111111111", "x", "y", []byte("kek")) - require.Error(t, err) - - // FK to hosts is enforced. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_devices - (host_id, device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key) - VALUES (?, ?, ?, ?, ?) - `, 999999, "ghost-uuid", "x", "y", []byte("kek")) + // Duplicate host_uuid is rejected by the PK. + _, err = db.Exec(`INSERT INTO mdm_apple_psso_devices (host_uuid) VALUES (?)`, hostUUID1) require.Error(t, err) - // Second host can register independently. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_devices - (host_id, device_uuid, signing_key_pem, encryption_key_pem, key_exchange_key) - VALUES (?, ?, ?, ?, ?) - `, hostID2, "ABCDEFGH-0000-0000-0000-222222222222", "signing-pem-2", "encryption-pem-2", []byte("ffeeddccbbaa99887766554433221100")) - require.NoError(t, err) + keyInsert := `INSERT INTO mdm_apple_psso_keys (kid, host_uuid, key_type, pem) VALUES (?, ?, ?, ?)` - // Insert key_id rows for host1: one signing, one encryption. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) - VALUES (?, ?, ?, ?) - `, "kid-sign-host1", hostID1, "signing", "signing-pem-1") - require.NoError(t, err) - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) - VALUES (?, ?, ?, ?) - `, "kid-enc-host1", hostID1, "encryption", "encryption-pem-1") - require.NoError(t, err) + // One signing and one encryption key for host1. + execNoErr(t, db, keyInsert, "kid-sign-host1", hostUUID1, "signing", "signing-pem-1") + execNoErr(t, db, keyInsert, "kid-enc-host1", hostUUID1, "encryption", "encryption-pem-1") - // Duplicate kid is rejected by PK. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) - VALUES (?, ?, ?, ?) - `, "kid-sign-host1", hostID2, "signing", "x") - require.Error(t, err) + // Multiple keys of the same type per host are allowed (re-registration + // keeps old keys working). + execNoErr(t, db, keyInsert, "kid-sign-host1-v2", hostUUID1, "signing", "signing-pem-1-v2") + execNoErr(t, db, keyInsert, "kid-enc-host1-v2", hostUUID1, "encryption", "encryption-pem-1-v2") - // Duplicate (host_id, key_type) is rejected by unique index — a host has at most one signing and one encryption key. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) - VALUES (?, ?, ?, ?) - `, "kid-sign-host1-v2", hostID1, "signing", "x") + // Duplicate kid is rejected by the PK. + _, err = db.Exec(keyInsert, "kid-sign-host1", hostUUID2, "signing", "x") require.Error(t, err) - // Invalid key_type is rejected by ENUM. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) - VALUES (?, ?, ?, ?) - `, "kid-bogus", hostID1, "bogus", "x") + // Invalid key_type is rejected by the ENUM. + _, err = db.Exec(keyInsert, "kid-bogus", hostUUID1, "bogus", "x") require.Error(t, err) - // FK to hosts is enforced. - _, err = db.Exec(` - INSERT INTO mdm_apple_psso_key_ids (kid, host_id, key_type, pem) - VALUES (?, ?, ?, ?) - `, "kid-ghost", 999999, "signing", "x") + // Keys must reference a registered device. + _, err = db.Exec(keyInsert, "kid-ghost", "no-such-device-uuid", "signing", "x") require.Error(t, err) - // ON DELETE CASCADE: deleting host2 wipes its psso rows. - _, err = db.Exec(`DELETE FROM hosts WHERE id = ?`, hostID2) + // Key timestamps are populated. + err = db.QueryRow(` + SELECT created_at, updated_at FROM mdm_apple_psso_keys WHERE kid = ? + `, "kid-sign-host1").Scan(&gotCreatedAt, &gotUpdatedAt) require.NoError(t, err) + assert.False(t, gotCreatedAt.IsZero()) + assert.False(t, gotUpdatedAt.IsZero()) - var devicesRemaining int - err = db.QueryRow(`SELECT COUNT(*) FROM mdm_apple_psso_devices WHERE host_id = ?`, hostID2).Scan(&devicesRemaining) - require.NoError(t, err) - assert.Equal(t, 0, devicesRemaining) + // ON DELETE CASCADE: deleting a device wipes its keys. + execNoErr(t, db, keyInsert, "kid-sign-host2", hostUUID2, "signing", "signing-pem-2") + execNoErr(t, db, `DELETE FROM mdm_apple_psso_devices WHERE host_uuid = ?`, hostUUID2) - // host1's rows survive. - var host1Devices int - err = db.QueryRow(`SELECT COUNT(*) FROM mdm_apple_psso_devices WHERE host_id = ?`, hostID1).Scan(&host1Devices) + var keysRemaining int + err = db.QueryRow(`SELECT COUNT(*) FROM mdm_apple_psso_keys WHERE host_uuid = ?`, hostUUID2).Scan(&keysRemaining) require.NoError(t, err) - assert.Equal(t, 1, host1Devices) + assert.Equal(t, 0, keysRemaining) + // host1's rows survive. var host1Keys int - err = db.QueryRow(`SELECT COUNT(*) FROM mdm_apple_psso_key_ids WHERE host_id = ?`, hostID1).Scan(&host1Keys) + err = db.QueryRow(`SELECT COUNT(*) FROM mdm_apple_psso_keys WHERE host_uuid = ?`, hostUUID1).Scan(&host1Keys) require.NoError(t, err) - assert.Equal(t, 2, host1Keys) + assert.Equal(t, 4, host1Keys) } diff --git a/server/datastore/mysql/schema.sql b/server/datastore/mysql/schema.sql index 9b379f593d6..01332af9c05 100644 --- a/server/datastore/mysql/schema.sql +++ b/server/datastore/mysql/schema.sql @@ -1826,29 +1826,24 @@ CREATE TABLE `mdm_apple_installers` ( /*!40101 SET @saved_cs_client = @@character_set_client */; /*!50503 SET character_set_client = utf8mb4 */; CREATE TABLE `mdm_apple_psso_devices` ( - `host_id` int unsigned NOT NULL, - `device_uuid` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, - `signing_key_pem` text CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, - `encryption_key_pem` text CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, - `key_exchange_key` varbinary(64) NOT NULL, + `host_uuid` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, `created_at` timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `updated_at` timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - PRIMARY KEY (`host_id`), - UNIQUE KEY `idx_mdm_apple_psso_devices_device_uuid` (`device_uuid`), - CONSTRAINT `fk_mdm_apple_psso_devices_host_id` FOREIGN KEY (`host_id`) REFERENCES `hosts` (`id`) ON DELETE CASCADE + PRIMARY KEY (`host_uuid`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; /*!40101 SET character_set_client = @saved_cs_client */; /*!40101 SET @saved_cs_client = @@character_set_client */; /*!50503 SET character_set_client = utf8mb4 */; -CREATE TABLE `mdm_apple_psso_key_ids` ( +CREATE TABLE `mdm_apple_psso_keys` ( `kid` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, - `host_id` int unsigned NOT NULL, + `host_uuid` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, `key_type` enum('signing','encryption') CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, `pem` text CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, `created_at` timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + `updated_at` timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), PRIMARY KEY (`kid`), - UNIQUE KEY `idx_mdm_apple_psso_key_ids_host_type` (`host_id`,`key_type`), - CONSTRAINT `fk_mdm_apple_psso_key_ids_host_id` FOREIGN KEY (`host_id`) REFERENCES `hosts` (`id`) ON DELETE CASCADE + KEY `fk_mdm_apple_psso_keys_host_uuid` (`host_uuid`), + CONSTRAINT `fk_mdm_apple_psso_keys_host_uuid` FOREIGN KEY (`host_uuid`) REFERENCES `mdm_apple_psso_devices` (`host_uuid`) ON DELETE CASCADE ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; /*!40101 SET character_set_client = @saved_cs_client */; /*!40101 SET @saved_cs_client = @@character_set_client */; diff --git a/server/fleet/apple_psso.go b/server/fleet/apple_psso.go index aa04aecdd13..e87f2c2f684 100644 --- a/server/fleet/apple_psso.go +++ b/server/fleet/apple_psso.go @@ -5,15 +5,12 @@ import ( "time" ) -// PSSODevice is a Mac host's Apple Platform SSO registration record. +// PSSODevice marks a Mac host as Apple Platform SSO-registered. It carries no +// key material itself. The device's public keys live in PSSOKey rows type PSSODevice struct { - HostID uint `db:"host_id"` - DeviceUUID string `db:"device_uuid"` - SigningKeyPEM string `db:"signing_key_pem"` - EncryptionKeyPEM string `db:"encryption_key_pem"` - KeyExchangeKey []byte `db:"key_exchange_key"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt time.Time `db:"updated_at"` + HostUUID string `db:"host_uuid"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` } // PSSOKeyType discriminates a device's signing key from its encryption key. @@ -24,15 +21,18 @@ const ( PSSOKeyTypeEncryption PSSOKeyType = "encryption" ) -// PSSOKeyID indexes a device key by its kid (base64 SHA-256 of the key) so the -// server can look up the owning device when an extension presents a JWT with -// that kid in its header. -type PSSOKeyID struct { +// PSSOKey is one of a registered device's public keys, indexed by kid (base64 +// SHA-256 of the key bytes) so the server can resolve the owning device when +// an extension presents a JWT with that kid in its header. A host may hold +// several keys of the same type: re-registration adds new keys without +// invalidating old ones. +type PSSOKey struct { KID string `db:"kid"` - HostID uint `db:"host_id"` + HostUUID string `db:"host_uuid"` KeyType PSSOKeyType `db:"key_type"` PEM string `db:"pem"` CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` } // PSSOClaims is the OIDC-shaped claim set the upstream IdP returns after a diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 84c0fe70c3e..875a7c7e55f 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -3438,15 +3438,25 @@ type Datastore interface { // Apple Platform SSO (PSSO) // SetOrUpdatePSSODevice persists a Mac's PSSO registration: the device row - // plus both KeyID rows (signing + encryption) in a single transaction. - // Replaces any existing registration for the same host. - SetOrUpdatePSSODevice(ctx context.Context, device PSSODevice, signKeyID PSSOKeyID, encKeyID PSSOKeyID) error - - // GetPSSODeviceByKeyID looks up the device that owns the given kid and - // returns both the device row and the specific KeyID row (so the caller - // knows which key — signing or encryption — was referenced). - GetPSSODeviceByKeyID(ctx context.Context, kid string) (*PSSODevice, *PSSOKeyID, error) - + // plus the given key rows in a single transaction. Keys are upserted by + // kid; existing keys for the host are left in place so they keep working + // after a re-registration. + SetOrUpdatePSSODevice(ctx context.Context, hostUUID string, keys []PSSOKey) error + + // GetPSSODevice returns the PSSO registration record for the given host + // UUID, or a notFound error if the host isn't registered. + GetPSSODevice(ctx context.Context, hostUUID string) (*PSSODevice, error) + + // GetPSSOKey looks up a registered device key by its kid. + GetPSSOKey(ctx context.Context, kid string) (*PSSOKey, error) + + // ListPSSOKeys returns all keys registered for the given host UUID. + ListPSSOKeys(ctx context.Context, hostUUID string) ([]*PSSOKey, error) + + // DeletePSSODevice removes a host's PSSO registration and, via cascade, + // all of its registered keys. Deleting an unregistered host is a no-op. + DeletePSSODevice(ctx context.Context, hostUUID string) error + // HasAppleUpdateConfigProfileConfigured checks if a declaration profile for the team already exists in the update_settings table. HasAppleUpdateConfigProfileConfigured(ctx context.Context, teamID uint) (bool, error) diff --git a/server/fleet/service.go b/server/fleet/service.go index a53cbdbeba1..b3562c21c4a 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -1556,8 +1556,8 @@ type Service interface { // validates the device-key payload, and persists the registration. PSSORegisterComplete(ctx context.Context, req PSSORegisterRequest) error // PSSOToken handles the per-sign-in protocol message: parses the inbound - // signed JWT, dispatches on RequestType (key_request / key_exchange / - // password_request), and returns the JWE response body. + // signed JWT, dispatches on grant_type (password login) or request_type + // (key_request / key_exchange), and returns the JWE response body. PSSOToken(ctx context.Context, jwtBytes []byte) ([]byte, error) // PSSOJWKS returns the JSON web key set that publishes Fleet's PSSO // signing public key. diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 88b1dfbb7a9..986e2db74d6 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -2090,9 +2090,15 @@ type MDMAppleResetOnReenrollmentFunc func(ctx context.Context, hostUUID string, type VerifyAppleConfigProfileScopesDoNotConflictFunc func(ctx context.Context, cps []*fleet.MDMAppleConfigProfile) error -type SetOrUpdatePSSODeviceFunc func(ctx context.Context, device fleet.PSSODevice, signKeyID fleet.PSSOKeyID, encKeyID fleet.PSSOKeyID) error +type SetOrUpdatePSSODeviceFunc func(ctx context.Context, hostUUID string, keys []fleet.PSSOKey) error -type GetPSSODeviceByKeyIDFunc func(ctx context.Context, kid string) (*fleet.PSSODevice, *fleet.PSSOKeyID, error) +type GetPSSODeviceFunc func(ctx context.Context, hostUUID string) (*fleet.PSSODevice, error) + +type GetPSSOKeyFunc func(ctx context.Context, kid string) (*fleet.PSSOKey, error) + +type ListPSSOKeysFunc func(ctx context.Context, hostUUID string) ([]*fleet.PSSOKey, error) + +type DeletePSSODeviceFunc func(ctx context.Context, hostUUID string) error type HasAppleUpdateConfigProfileConfiguredFunc func(ctx context.Context, teamID uint) (bool, error) @@ -5205,8 +5211,17 @@ type DataStore struct { SetOrUpdatePSSODeviceFunc SetOrUpdatePSSODeviceFunc SetOrUpdatePSSODeviceFuncInvoked bool - GetPSSODeviceByKeyIDFunc GetPSSODeviceByKeyIDFunc - GetPSSODeviceByKeyIDFuncInvoked bool + GetPSSODeviceFunc GetPSSODeviceFunc + GetPSSODeviceFuncInvoked bool + + GetPSSOKeyFunc GetPSSOKeyFunc + GetPSSOKeyFuncInvoked bool + + ListPSSOKeysFunc ListPSSOKeysFunc + ListPSSOKeysFuncInvoked bool + + DeletePSSODeviceFunc DeletePSSODeviceFunc + DeletePSSODeviceFuncInvoked bool HasAppleUpdateConfigProfileConfiguredFunc HasAppleUpdateConfigProfileConfiguredFunc HasAppleUpdateConfigProfileConfiguredFuncInvoked bool @@ -12454,18 +12469,39 @@ func (s *DataStore) VerifyAppleConfigProfileScopesDoNotConflict(ctx context.Cont return s.VerifyAppleConfigProfileScopesDoNotConflictFunc(ctx, cps) } -func (s *DataStore) SetOrUpdatePSSODevice(ctx context.Context, device fleet.PSSODevice, signKeyID fleet.PSSOKeyID, encKeyID fleet.PSSOKeyID) error { +func (s *DataStore) SetOrUpdatePSSODevice(ctx context.Context, hostUUID string, keys []fleet.PSSOKey) error { s.mu.Lock() s.SetOrUpdatePSSODeviceFuncInvoked = true s.mu.Unlock() - return s.SetOrUpdatePSSODeviceFunc(ctx, device, signKeyID, encKeyID) + return s.SetOrUpdatePSSODeviceFunc(ctx, hostUUID, keys) +} + +func (s *DataStore) GetPSSODevice(ctx context.Context, hostUUID string) (*fleet.PSSODevice, error) { + s.mu.Lock() + s.GetPSSODeviceFuncInvoked = true + s.mu.Unlock() + return s.GetPSSODeviceFunc(ctx, hostUUID) +} + +func (s *DataStore) GetPSSOKey(ctx context.Context, kid string) (*fleet.PSSOKey, error) { + s.mu.Lock() + s.GetPSSOKeyFuncInvoked = true + s.mu.Unlock() + return s.GetPSSOKeyFunc(ctx, kid) +} + +func (s *DataStore) ListPSSOKeys(ctx context.Context, hostUUID string) ([]*fleet.PSSOKey, error) { + s.mu.Lock() + s.ListPSSOKeysFuncInvoked = true + s.mu.Unlock() + return s.ListPSSOKeysFunc(ctx, hostUUID) } -func (s *DataStore) GetPSSODeviceByKeyID(ctx context.Context, kid string) (*fleet.PSSODevice, *fleet.PSSOKeyID, error) { +func (s *DataStore) DeletePSSODevice(ctx context.Context, hostUUID string) error { s.mu.Lock() - s.GetPSSODeviceByKeyIDFuncInvoked = true + s.DeletePSSODeviceFuncInvoked = true s.mu.Unlock() - return s.GetPSSODeviceByKeyIDFunc(ctx, kid) + return s.DeletePSSODeviceFunc(ctx, hostUUID) } func (s *DataStore) HasAppleUpdateConfigProfileConfigured(ctx context.Context, teamID uint) (bool, error) {