Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/fleetctl/fleetctl/gitops.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ func getLabelUsage(config *spec.GitOps) (map[string][]LabelUsage, error) {
if len(setting.LabelsIncludeAll) > 0 {
labels = setting.LabelsIncludeAll
}
if overlap := fleet.ProfileLabelOverlap(labels, setting.LabelsExcludeAny); overlap != "" {
if overlap := fleet.LabelOverlap(labels, setting.LabelsExcludeAny); overlap != "" {
return nil, fmt.Errorf("configuration profile '%s': label %q cannot appear in both include and exclude lists.", filepath.Base(setting.Path), overlap)
}
labels = append(labels, setting.LabelsExcludeAny...)
Expand Down
10 changes: 8 additions & 2 deletions server/datastore/mysql/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3758,7 +3758,11 @@ func (ds *Datastore) ListPoliciesForHost(ctx context.Context, host *fleet.Host)
-- count of include_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 0 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_include_all_count,
-- 1 if this host is a member of at least one exclude_any label
MAX(CASE WHEN pl.exclude = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude
MAX(CASE WHEN pl.exclude = 1 AND pl.require_all = 0 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude_any,
-- count of exclude_all labels on this policy
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 THEN 1 ELSE 0 END) AS exclude_all_count,
-- count of exclude_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_exclude_all_count
FROM policy_labels pl
LEFT JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = ?
GROUP BY pl.policy_id
Expand All @@ -3770,7 +3774,9 @@ func (ds *Datastore) ListPoliciesForHost(ctx context.Context, host *fleet.Host)
-- Policy has no include_all labels, or host is in all of them
AND (COALESCE(pl_agg.include_all_count, 0) = 0 OR pl_agg.host_include_all_count = pl_agg.include_all_count)
-- Host is not in any exclude_any label
AND COALESCE(pl_agg.host_in_exclude, 0) = 0
AND COALESCE(pl_agg.host_in_exclude_any, 0) = 0
-- Policy has no exclude_all labels, or host is not in all of them
AND (COALESCE(pl_agg.exclude_all_count, 0) = 0 OR pl_agg.host_exclude_all_count < pl_agg.exclude_all_count)
ORDER BY FIELD(response, 'fail', '', 'pass'), p.name`

var policies []*fleet.HostPolicy
Expand Down
131 changes: 79 additions & 52 deletions server/datastore/mysql/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func newGlobalPolicy(ctx context.Context, db sqlx.ExtContext, authorID *uint, ar
LabelsIncludeAny: fleet.LabelNamesToIdents(args.LabelsIncludeAny),
LabelsIncludeAll: fleet.LabelNamesToIdents(args.LabelsIncludeAll),
LabelsExcludeAny: fleet.LabelNamesToIdents(args.LabelsExcludeAny),
LabelsExcludeAll: fleet.LabelNamesToIdents(args.LabelsExcludeAll),
},
}

Expand All @@ -164,68 +165,64 @@ func updatePolicyLabelsTx(ctx context.Context, tx sqlx.ExtContext, policy *fleet
WHERE name IN (?)
`

// Scopes are mutually exclusive
scopesSet := 0
if len(policy.LabelsIncludeAny) > 0 {
scopesSet++
}
if len(policy.LabelsIncludeAll) > 0 {
scopesSet++
}
if len(policy.LabelsExcludeAny) > 0 {
scopesSet++
}
if scopesSet > 1 {
return ctxerr.Wrap(ctx, fleet.ErrPolicyConflictingLabels)
if _, err := tx.ExecContext(ctx, deleteLabelsStmt, policy.ID); err != nil {
return ctxerr.Wrap(ctx, err, "deleting old policy labels")
}
Comment thread
nulmete marked this conversation as resolved.

var (
labelNames []string
exclude bool
requireAll bool
)
switch {
case len(policy.LabelsIncludeAll) > 0:
requireAll = true
for _, label := range policy.LabelsIncludeAll {
labelNames = append(labelNames, label.LabelName)
// Each scope maps to a (exclude, require_all) pair in policy_labels:
// include_any -> exclude=0, require_all=0
// include_all -> exclude=0, require_all=1
// exclude_any -> exclude=1, require_all=0
// exclude_all -> exclude=1, require_all=1
insertScope := func(labels []fleet.LabelIdent, exclude, requireAll bool) error {
names := make([]string, 0, len(labels))
for _, label := range labels {
names = append(names, label.LabelName)
}
case len(policy.LabelsExcludeAny) > 0:
exclude = true
for _, label := range policy.LabelsExcludeAny {
labelNames = append(labelNames, label.LabelName)

stmt, args, err := sqlx.In(insertLabelStmt, policy.ID, exclude, requireAll, names)
if err != nil {
return ctxerr.Wrap(ctx, err, "constructing policy label update query")
}
default:
for _, label := range policy.LabelsIncludeAny {
labelNames = append(labelNames, label.LabelName)

res, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "creating policy labels")
}
}

if _, err := tx.ExecContext(ctx, deleteLabelsStmt, policy.ID); err != nil {
return ctxerr.Wrap(ctx, err, "deleting old policy labels")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return ctxerr.Wrap(ctx, err, "listing number of policy labels affected")
}

if len(labelNames) == 0 {
if rowsAffected != int64(len(names)) {
return ctxerr.Errorf(ctx, "invalid label")
}
return nil
}

labelStmt, args, err := sqlx.In(insertLabelStmt, policy.ID, exclude, requireAll, labelNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "constructing policy label update query")
// Only insert the scopes that actually have labels. PolicyPayload.Verify in the service layer
// enforces at most one include scope (any/all) and one exclude scope (any/all), so in practice
// at most two of these inserts run.
if len(policy.LabelsIncludeAny) > 0 {
if err := insertScope(policy.LabelsIncludeAny, false, false); err != nil {
return err
}
}

res, err := tx.ExecContext(ctx, labelStmt, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "creating policy labels")
if len(policy.LabelsIncludeAll) > 0 {
if err := insertScope(policy.LabelsIncludeAll, false, true); err != nil {
return err
}
}

rowsAffected, err := res.RowsAffected()
if err != nil {
return ctxerr.Wrap(ctx, err, "listing number of policy labels affected")
if len(policy.LabelsExcludeAny) > 0 {
if err := insertScope(policy.LabelsExcludeAny, true, false); err != nil {
return err
}
}

if rowsAffected != int64(len(labelNames)) {
return ctxerr.Errorf(ctx, "invalid label")
if len(policy.LabelsExcludeAll) > 0 {
if err := insertScope(policy.LabelsExcludeAll, true, true); err != nil {
return err
}
}

return nil
Expand Down Expand Up @@ -255,6 +252,7 @@ func loadLabelsForPolicies(ctx context.Context, db sqlx.QueryerContext, policies
policy.LabelsIncludeAny = nil
policy.LabelsIncludeAll = nil
policy.LabelsExcludeAny = nil
policy.LabelsExcludeAll = nil
policyIDs = append(policyIDs, policy.ID)
policyMap[policy.ID] = policy
}
Expand All @@ -279,7 +277,12 @@ func loadLabelsForPolicies(ctx context.Context, db sqlx.QueryerContext, policies
for _, row := range rows {
ident := fleet.LabelIdent{LabelName: row.LabelName, LabelID: row.LabelID}
policy := policyMap[row.PolicyID]
if policy == nil {
continue
}
switch {
case row.Exclude && row.RequireAll:
policy.LabelsExcludeAll = append(policy.LabelsExcludeAll, ident)
case row.Exclude:
policy.LabelsExcludeAny = append(policy.LabelsExcludeAny, ident)
case row.RequireAll:
Expand Down Expand Up @@ -1152,6 +1155,7 @@ func deletePolicyDB(ctx context.Context, q sqlx.ExtContext, ids []uint, teamID *
// exclude=0, require_all=0 -> include_any
// exclude=0, require_all=1 -> include_all
// exclude=1, require_all=0 -> exclude_any
// exclude=1, require_all=1 -> exclude_all
//
// Placeholder order: lm.host_id, team_id, platform. policyQueriesForHostInScope appends "AND p.id IN (?)" (and its arg) to
// restrict to specific policies.
Expand All @@ -1169,7 +1173,11 @@ const policyQueriesForHostStmt = `
-- count of include_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 0 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_include_all_count,
-- 1 if this host is a member of at least one exclude_any label
MAX(CASE WHEN pl.exclude = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude
MAX(CASE WHEN pl.exclude = 1 AND pl.require_all = 0 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_in_exclude_any,
-- count of exclude_all labels on this policy
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 THEN 1 ELSE 0 END) AS exclude_all_count,
-- count of exclude_all labels this host is a member of
SUM(CASE WHEN pl.exclude = 1 AND pl.require_all = 1 AND lm.host_id IS NOT NULL THEN 1 ELSE 0 END) AS host_exclude_all_count
FROM policy_labels pl
LEFT JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = ?
GROUP BY pl.policy_id
Expand All @@ -1182,7 +1190,9 @@ const policyQueriesForHostStmt = `
-- Policy has no include_all labels, or host is in all of them
(COALESCE(pl_agg.include_all_count, 0) = 0 OR pl_agg.host_include_all_count = pl_agg.include_all_count) AND
-- Host is not in any exclude_any label
COALESCE(pl_agg.host_in_exclude, 0) = 0`
COALESCE(pl_agg.host_in_exclude_any, 0) = 0 AND
-- Policy has no exclude_all labels, or host is not in all of them
(COALESCE(pl_agg.exclude_all_count, 0) = 0 OR pl_agg.host_exclude_all_count < pl_agg.exclude_all_count)`

// PolicyQueriesForHost returns the policy queries that are to be executed on the given host.
func (ds *Datastore) PolicyQueriesForHost(ctx context.Context, host *fleet.Host) (map[string]string, error) {
Expand Down Expand Up @@ -1384,6 +1394,7 @@ func newTeamPolicy(ctx context.Context, db sqlx.ExtContext, teamID uint, authorI
LabelsIncludeAny: fleet.LabelNamesToIdents(args.LabelsIncludeAny),
LabelsIncludeAll: fleet.LabelNamesToIdents(args.LabelsIncludeAll),
LabelsExcludeAny: fleet.LabelNamesToIdents(args.LabelsExcludeAny),
LabelsExcludeAll: fleet.LabelNamesToIdents(args.LabelsExcludeAll),
},
}

Expand Down Expand Up @@ -1808,6 +1819,7 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
LabelsIncludeAny: fleet.LabelNamesToIdents(spec.LabelsIncludeAny),
LabelsIncludeAll: fleet.LabelNamesToIdents(spec.LabelsIncludeAll),
LabelsExcludeAny: fleet.LabelNamesToIdents(spec.LabelsExcludeAny),
LabelsExcludeAll: fleet.LabelNamesToIdents(spec.LabelsExcludeAll),
},
})
if err != nil {
Expand Down Expand Up @@ -2154,7 +2166,22 @@ func cleanupPolicyMembershipOnPolicyUpdate(
AND NOT EXISTS (
SELECT 1 FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 0
)
-- If the policy has exclude_all labels, the host must not be in all of them.
AND (
NOT EXISTS (
SELECT 1 FROM policy_labels pl
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 1
)
OR (
SELECT COUNT(*) FROM policy_labels pl
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 1
) > (
SELECT COUNT(*) FROM policy_labels pl
JOIN label_membership lm ON lm.label_id = pl.label_id AND lm.host_id = pm.host_id
WHERE pl.policy_id = pm.policy_id AND pl.exclude = 1 AND pl.require_all = 1
)
)
)
ORDER BY pm.host_id ASC
Expand Down
Loading
Loading