Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/fleetctl/fleetctl/gitops.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ func getLabelUsage(config *spec.GitOps) (map[string][]LabelUsage, error) {
if len(setting.LabelsIncludeAll) > 0 {
labels = setting.LabelsIncludeAll
}
if overlap := fleet.ProfileLabelOverlap(labels, setting.LabelsExcludeAny); overlap != "" {
if overlap := fleet.LabelOverlap(labels, setting.LabelsExcludeAny); overlap != "" {
return nil, fmt.Errorf("configuration profile '%s': label %q cannot appear in both include and exclude lists.", filepath.Base(setting.Path), overlap)
}
labels = append(labels, setting.LabelsExcludeAny...)
Expand Down
10 changes: 8 additions & 2 deletions server/datastore/mysql/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3758,7 +3758,11 @@ func (ds *Datastore) ListPoliciesForHost(ctx context.Context, host *fleet.Host)
-- count of include_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 0 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_include_all_count,
-- 1 if this host is a member of at least one exclude_any label
MAX(CASE WHEN pl.exclude = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude
MAX(CASE WHEN pl.exclude = 1 AND pl.require_all = 0 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude_any,
-- count of exclude_all labels on this policy
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 THEN 1 ELSE 0 END) AS exclude_all_count,
-- count of exclude_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_exclude_all_count
FROM policy_labels pl
LEFT JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = ?
GROUP BY pl.policy_id
Expand All @@ -3770,7 +3774,9 @@ func (ds *Datastore) ListPoliciesForHost(ctx context.Context, host *fleet.Host)
-- Policy has no include_all labels, or host is in all of them
AND (COALESCE(pl_agg.include_all_count, 0) = 0 OR pl_agg.host_include_all_count = pl_agg.include_all_count)
-- Host is not in any exclude_any label
AND COALESCE(pl_agg.host_in_exclude, 0) = 0
AND COALESCE(pl_agg.host_in_exclude_any, 0) = 0
-- Policy has no exclude_all labels, or host is not in all of them
AND (COALESCE(pl_agg.exclude_all_count, 0) = 0 OR pl_agg.host_exclude_all_count < pl_agg.exclude_all_count)
ORDER BY FIELD(response, 'fail', '', 'pass'), p.name`

var policies []*fleet.HostPolicy
Expand Down
138 changes: 87 additions & 51 deletions server/datastore/mysql/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func newGlobalPolicy(ctx context.Context, db sqlx.ExtContext, authorID *uint, ar
LabelsIncludeAny: fleet.LabelNamesToIdents(args.LabelsIncludeAny),
LabelsIncludeAll: fleet.LabelNamesToIdents(args.LabelsIncludeAll),
LabelsExcludeAny: fleet.LabelNamesToIdents(args.LabelsExcludeAny),
LabelsExcludeAll: fleet.LabelNamesToIdents(args.LabelsExcludeAll),
},
}

Expand All @@ -164,68 +165,73 @@ func updatePolicyLabelsTx(ctx context.Context, tx sqlx.ExtContext, policy *fleet
WHERE name IN (?)
`

// Scopes are mutually exclusive
scopesSet := 0
if len(policy.LabelsIncludeAny) > 0 {
scopesSet++
}
if len(policy.LabelsIncludeAll) > 0 {
scopesSet++
}
if len(policy.LabelsExcludeAny) > 0 {
scopesSet++
if err := fleet.VerifyPolicyLabelScopes(
fleet.LabelIdentsToNames(policy.LabelsIncludeAny),
fleet.LabelIdentsToNames(policy.LabelsIncludeAll),
fleet.LabelIdentsToNames(policy.LabelsExcludeAny),
fleet.LabelIdentsToNames(policy.LabelsExcludeAll),
); err != nil {
return ctxerr.Wrap(ctx, err, "validating policy label scopes")
}
if scopesSet > 1 {
return ctxerr.Wrap(ctx, fleet.ErrPolicyConflictingLabels)

if _, err := tx.ExecContext(ctx, deleteLabelsStmt, policy.ID); err != nil {
return ctxerr.Wrap(ctx, err, "deleting old policy labels")
}
Comment thread
nulmete marked this conversation as resolved.

var (
labelNames []string
exclude bool
requireAll bool
)
switch {
case len(policy.LabelsIncludeAll) > 0:
requireAll = true
for _, label := range policy.LabelsIncludeAll {
labelNames = append(labelNames, label.LabelName)
// Each scope maps to a (exclude, require_all) pair in policy_labels:
// include_any -> exclude=0, require_all=0
// include_all -> exclude=0, require_all=1
// exclude_any -> exclude=1, require_all=0
// exclude_all -> exclude=1, require_all=1
insertScope := func(labels []fleet.LabelIdent, exclude, requireAll bool) error {
names := make([]string, 0, len(labels))
for _, label := range labels {
names = append(names, label.LabelName)
}
case len(policy.LabelsExcludeAny) > 0:
exclude = true
for _, label := range policy.LabelsExcludeAny {
labelNames = append(labelNames, label.LabelName)

stmt, args, err := sqlx.In(insertLabelStmt, policy.ID, exclude, requireAll, names)
if err != nil {
return ctxerr.Wrap(ctx, err, "constructing policy label update query")
}
default:
for _, label := range policy.LabelsIncludeAny {
labelNames = append(labelNames, label.LabelName)

res, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "creating policy labels")
}
}

if _, err := tx.ExecContext(ctx, deleteLabelsStmt, policy.ID); err != nil {
return ctxerr.Wrap(ctx, err, "deleting old policy labels")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return ctxerr.Wrap(ctx, err, "listing number of policy labels affected")
}

if len(labelNames) == 0 {
if rowsAffected != int64(len(names)) {
return ctxerr.Errorf(ctx, "invalid label")
}
return nil
}

labelStmt, args, err := sqlx.In(insertLabelStmt, policy.ID, exclude, requireAll, labelNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "constructing policy label update query")
// Only insert the scopes that actually have labels. PolicyPayload.Verify in the service layer
// enforces at most one include scope (any/all) and one exclude scope (any/all), so in practice
// at most two of these inserts run.
if len(policy.LabelsIncludeAny) > 0 {
if err := insertScope(policy.LabelsIncludeAny, false, false); err != nil {
return err
}
}

res, err := tx.ExecContext(ctx, labelStmt, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "creating policy labels")
if len(policy.LabelsIncludeAll) > 0 {
if err := insertScope(policy.LabelsIncludeAll, false, true); err != nil {
return err
}
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return ctxerr.Wrap(ctx, err, "listing number of policy labels affected")
if len(policy.LabelsExcludeAny) > 0 {
if err := insertScope(policy.LabelsExcludeAny, true, false); err != nil {
return err
}
}

if rowsAffected != int64(len(labelNames)) {
return ctxerr.Errorf(ctx, "invalid label")
if len(policy.LabelsExcludeAll) > 0 {
if err := insertScope(policy.LabelsExcludeAll, true, true); err != nil {
return err
}
}

return nil
Expand Down Expand Up @@ -255,6 +261,7 @@ func loadLabelsForPolicies(ctx context.Context, db sqlx.QueryerContext, policies
policy.LabelsIncludeAny = nil
policy.LabelsIncludeAll = nil
policy.LabelsExcludeAny = nil
policy.LabelsExcludeAll = nil
policyIDs = append(policyIDs, policy.ID)
policyMap[policy.ID] = policy
}
Expand All @@ -279,7 +286,12 @@ func loadLabelsForPolicies(ctx context.Context, db sqlx.QueryerContext, policies
for _, row := range rows {
ident := fleet.LabelIdent{LabelName: row.LabelName, LabelID: row.LabelID}
policy := policyMap[row.PolicyID]
if policy == nil {
continue
}
switch {
case row.Exclude && row.RequireAll:
policy.LabelsExcludeAll = append(policy.LabelsExcludeAll, ident)
case row.Exclude:
policy.LabelsExcludeAny = append(policy.LabelsExcludeAny, ident)
case row.RequireAll:
Expand Down Expand Up @@ -1152,6 +1164,7 @@ func deletePolicyDB(ctx context.Context, q sqlx.ExtContext, ids []uint, teamID *
// exclude=0, require_all=0 -> include_any
// exclude=0, require_all=1 -> include_all
// exclude=1, require_all=0 -> exclude_any
// exclude=1, require_all=1 -> exclude_all
//
// Placeholder order: lm.host_id, team_id, platform. policyQueriesForHostInScope appends "AND p.id IN (?)" (and its arg) to
// restrict to specific policies.
Expand All @@ -1169,7 +1182,11 @@ const policyQueriesForHostStmt = `
-- count of include_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 0 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_include_all_count,
-- 1 if this host is a member of at least one exclude_any label
MAX(CASE WHEN pl.exclude = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude
MAX(CASE WHEN pl.exclude = 1 AND pl.require_all = 0 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude_any,
-- count of exclude_all labels on this policy
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 THEN 1 ELSE 0 END) AS exclude_all_count,
-- count of exclude_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_exclude_all_count
FROM policy_labels pl
LEFT JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = ?
GROUP BY pl.policy_id
Expand All @@ -1182,7 +1199,9 @@ const policyQueriesForHostStmt = `
-- Policy has no include_all labels, or host is in all of them
(COALESCE(pl_agg.include_all_count, 0) = 0 OR pl_agg.host_include_all_count = pl_agg.include_all_count) AND
-- Host is not in any exclude_any label
COALESCE(pl_agg.host_in_exclude, 0) = 0`
COALESCE(pl_agg.host_in_exclude_any, 0) = 0 AND
-- Policy has no exclude_all labels, or host is not in all of them
(COALESCE(pl_agg.exclude_all_count, 0) = 0 OR pl_agg.host_exclude_all_count < pl_agg.exclude_all_count)`

// PolicyQueriesForHost returns the policy queries that are to be executed on the given host.
func (ds *Datastore) PolicyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
Expand Down Expand Up @@ -1384,6 +1403,7 @@ func newTeamPolicy(ctx context.Context, db sqlx.ExtContext, teamID uint, authorI
LabelsIncludeAny: fleet.LabelNamesToIdents(args.LabelsIncludeAny),
LabelsIncludeAll: fleet.LabelNamesToIdents(args.LabelsIncludeAll),
LabelsExcludeAny: fleet.LabelNamesToIdents(args.LabelsExcludeAny),
LabelsExcludeAll: fleet.LabelNamesToIdents(args.LabelsExcludeAll),
},
}

Expand Down Expand Up @@ -1808,6 +1828,7 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
LabelsIncludeAny: fleet.LabelNamesToIdents(spec.LabelsIncludeAny),
LabelsIncludeAll: fleet.LabelNamesToIdents(spec.LabelsIncludeAll),
LabelsExcludeAny: fleet.LabelNamesToIdents(spec.LabelsExcludeAny),
LabelsExcludeAll: fleet.LabelNamesToIdents(spec.LabelsExcludeAll),
},
})
if err != nil {
Expand Down Expand Up @@ -2154,7 +2175,22 @@ func cleanupPolicyMembershipOnPolicyUpdate(
AND NOT EXISTS (
SELECT 1 FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 0
)
-- If the policy has exclude_all labels, the host must not be in all of them.
AND (
NOT EXISTS (
SELECT 1 FROM policy_labels pl
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 1
)
OR (
SELECT COUNT(*) FROM policy_labels pl
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 1
) > (
SELECT COUNT(*) FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 1
)
)
)
ORDER BY pm.host_id ASC
Expand Down
Loading
Loading