Expose flex_attention kernel options in DFlash and Domino training #586

Open
heiheiha798 wants to merge 1 commit into sgl-project:main from heiheiha798:pr/expose-flex-kernel-options-cli

heiheiha798 commented Jun 18, 2026


Summary

  • add --flex-kernel-options-json to DFlash and Domino training CLIs
  • parse the JSON dict in the training entrypoints and forward it to the online wrappers
  • pass the resulting dict through to kernel_options in the flex attention draft forward path
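
The three steps above can be sketched end to end as follows. This is a minimal illustration, not the code in this PR: `parse_flex_kernel_options`, `StubDraftModel`, and `StubOnlineWrapper` are hypothetical names, and the stubs stand in for the real draft model and online wrappers, which are not reproduced here.

```python
import argparse
import json
from typing import Any, Optional


def parse_flex_kernel_options(raw: Optional[str]) -> Optional[dict[str, Any]]:
    """Decode the flag value; None means the flag was not given."""
    if raw is None:
        return None
    options = json.loads(raw)
    if not isinstance(options, dict):
        raise ValueError("--flex-kernel-options-json must decode to a JSON object.")
    return options


class StubDraftModel:
    """Hypothetical stand-in for the draft model; returns what it receives."""

    def __call__(self, *, kernel_options: Optional[dict[str, Any]] = None):
        return kernel_options


class StubOnlineWrapper:
    """Hypothetical stand-in for the online training wrapper."""

    def __init__(self, draft_model, flex_kernel_options=None):
        self.draft_model = draft_model
        self.flex_kernel_options = flex_kernel_options

    def forward(self):
        # The stored dict is passed through unchanged as kernel_options.
        return self.draft_model(kernel_options=self.flex_kernel_options)


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--flex-kernel-options-json", type=str, default=None)
    return parser


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args(["--flex-kernel-options-json", '{"num_stages": 2}'])
wrapper = StubOnlineWrapper(
    StubDraftModel(),
    parse_flex_kernel_options(args.flex_kernel_options_json),
)
print(wrapper.forward())  # {'num_stages': 2}
```

Keeping the decoding in one small function means both training scripts can share the same dict check instead of duplicating it.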

Why

On RTX A6000 (sm86), we hit a training failure when using attention_backend=flex_attention.

The root cause was not an NCCL error and not a generic CUDA OOM. The failure came from the Triton kernel configuration selected for flex attention during training: the generated configuration could exceed the sm86 shared-memory-per-block opt-in limit.

In practice, this means the default flex attention kernel configuration is not always viable on A6000 for these training paths.

Local symptom on A6000

A representative workaround on RTX A6000 is to explicitly pass:

--flex-kernel-options-json '{"num_stages": 2}'

Lowering num_stages reduces the software-pipelining depth of the generated Triton kernel, and with it the shared memory requested per block, enough for the training path to proceed on our A6000 setup.
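
To give a feel for why num_stages matters, here is a back-of-envelope estimate. It is only a rough model under stated assumptions, not Triton's actual allocation: the tile sizes (128 x 64, head dimension 128, 2-byte elements) are hypothetical, and the model simply counts one Q tile plus one K and one V tile per pipeline stage. The 99 KB figure is the per-block opt-in shared-memory limit for compute capability 8.6.

```python
# Per-block opt-in shared-memory limit on compute capability 8.6, in bytes.
SM86_SMEM_OPT_IN_LIMIT = 99 * 1024


def estimate_smem_bytes(block_m, block_n, head_dim, num_stages, dtype_bytes=2):
    """Rough model: one Q tile plus num_stages pipelined K and V tiles.

    This is an illustrative simplification; the real kernel's usage depends
    on the generated code and the compiler.
    """
    q_tile = block_m * head_dim * dtype_bytes
    kv_tiles_per_stage = 2 * block_n * head_dim * dtype_bytes
    return q_tile + num_stages * kv_tiles_per_stage


def fits_sm86(block_m, block_n, head_dim, num_stages):
    """True if the rough estimate stays within the sm86 opt-in limit."""
    used = estimate_smem_bytes(block_m, block_n, head_dim, num_stages)
    return used <= SM86_SMEM_OPT_IN_LIMIT


# Hypothetical tile sizes, chosen only to illustrate the trend.
for stages in (3, 2):
    used = estimate_smem_bytes(128, 64, 128, stages)
    print(stages, used, fits_sm86(128, 64, 128, stages))
# 3 131072 False
# 2 98304 True
```

Under this toy model, dropping from three stages to two is what brings the estimate under the limit, which matches the direction of the workaround even though the real numbers will differ.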

Scope

This PR does not hardcode any hardware-specific default.

It only exposes the existing lower-level flex attention kernel_options capability through the DFlash and Domino training CLIs, so users can tune it from the command line when needed.

Example

python scripts/train_domino.py \
  --attention-backend flex_attention \
  --flex-kernel-options-json '{"num_stages": 2}'

The same flag is also available in scripts/train_dflash.py.

Validation

  • verified that the CLI accepts a JSON dict and rejects non-dict JSON
  • verified the parsed dict is forwarded to the online training wrapper
  • this was added specifically to unblock A6000 flex attention training by allowing num_stages=2 to be passed from the CLI

Copilot AI review requested due to automatic review settings June 18, 2026 12:17

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for forwarding custom kernel options to the draft model when using the flex_attention backend. It adds a new CLI argument, --flex-kernel-options-json, to both train_dflash.py and train_domino.py; the value is parsed and passed down to the DFlash and Domino model wrappers. Feedback on the changes highlights two main issues:

  • Passing kernel_options unconditionally to the draft model will result in a TypeError when other attention backends (such as sdpa or eager) are selected, so it should be passed conditionally.
  • If --flex-kernel-options-json is provided without setting the attention backend to flex_attention, the options are silently ignored, so adding validation to raise a ValueError in this scenario is recommended.

Comment thread specforge/core/dflash.py
Comment on lines 311 to 315
    noise_embedding=noise_embedding,
    target_hidden=hidden_states,
    attention_mask=dflash_attn_mask,
    kernel_options=self.flex_kernel_options,
)

high

Passing kernel_options directly to self.draft_model will cause a TypeError when using other attention backends like sdpa or eager (which are valid choices in the CLI), because their underlying attention functions do not accept kernel_options as a parameter. To prevent this, we should only pass kernel_options when the attention backend is set to flex_attention and flex_kernel_options is provided.

Suggested change
-     noise_embedding=noise_embedding,
-     target_hidden=hidden_states,
-     attention_mask=dflash_attn_mask,
-     kernel_options=self.flex_kernel_options,
- )
+     noise_embedding=noise_embedding,
+     target_hidden=hidden_states,
+     attention_mask=dflash_attn_mask,
+     **({'kernel_options': self.flex_kernel_options} if (self.attention_backend == 'flex_attention' and self.flex_kernel_options is not None) else {}),
+ )
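
The same condition as in the suggestion above can be factored into a small helper, which may read more easily than the inline dict splat. This is a sketch only: `draft_extra_kwargs` and `sdpa_like_draft` are hypothetical names, and the stub forward function merely mimics a backend whose signature has no kernel_options parameter.

```python
from typing import Any, Optional


def draft_extra_kwargs(
    attention_backend: str,
    flex_kernel_options: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """Return kernel_options only when the flex_attention backend can use it."""
    if attention_backend == "flex_attention" and flex_kernel_options is not None:
        return {"kernel_options": flex_kernel_options}
    return {}


def sdpa_like_draft(noise_embedding, target_hidden, attention_mask):
    """Hypothetical draft forward that does not accept kernel_options."""
    return "ok"


# With a non-flex backend the helper returns {}, so no unexpected keyword
# argument reaches the call and no TypeError is raised.
result = sdpa_like_draft(
    noise_embedding=None,
    target_hidden=None,
    attention_mask=None,
    **draft_extra_kwargs("sdpa", {"num_stages": 2}),
)
print(result)  # ok
```

The call site then becomes `**draft_extra_kwargs(self.attention_backend, self.flex_kernel_options)`, assuming the wrapper stores both attributes as the suggestion implies.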

Comment thread specforge/core/domino.py
Comment on lines 339 to 343
    noise_embedding=noise_embedding,
    target_hidden=hidden_states,
    attention_mask=dflash_attn_mask,
    kernel_options=self.flex_kernel_options,
)

high

Passing kernel_options directly to self.draft_model will cause a TypeError when using other attention backends like sdpa or eager (which are valid choices in the CLI), because their underlying attention functions do not accept kernel_options as a parameter. To prevent this, we should only pass kernel_options when the attention backend is set to flex_attention and flex_kernel_options is provided.

Suggested change
-     noise_embedding=noise_embedding,
-     target_hidden=hidden_states,
-     attention_mask=dflash_attn_mask,
-     kernel_options=self.flex_kernel_options,
- )
+     noise_embedding=noise_embedding,
+     target_hidden=hidden_states,
+     attention_mask=dflash_attn_mask,
+     **({'kernel_options': self.flex_kernel_options} if (self.attention_backend == 'flex_attention' and self.flex_kernel_options is not None) else {}),
+ )

Comment thread scripts/train_dflash.py Outdated
Comment on lines +386 to +391
if args.flex_kernel_options_json is not None:
    flex_kernel_options = json.loads(args.flex_kernel_options_json)
    if not isinstance(flex_kernel_options, dict):
        raise ValueError(
            "--flex-kernel-options-json must decode to a JSON object."
        )

medium

If the user provides --flex-kernel-options-json but --attention-backend is not set to flex_attention, these options will be silently ignored. It is better to validate this and raise a ValueError to prevent user confusion.

Suggested change
- if args.flex_kernel_options_json is not None:
-     flex_kernel_options = json.loads(args.flex_kernel_options_json)
-     if not isinstance(flex_kernel_options, dict):
-         raise ValueError(
-             "--flex-kernel-options-json must decode to a JSON object."
-         )
+ if args.flex_kernel_options_json is not None:
+     if args.attention_backend != 'flex_attention':
+         raise ValueError(
+             "--flex-kernel-options-json can only be used when --attention-backend is 'flex_attention'."
+         )
+     flex_kernel_options = json.loads(args.flex_kernel_options_json)
+     if not isinstance(flex_kernel_options, dict):
+         raise ValueError(
+             "--flex-kernel-options-json must decode to a JSON object."
+         )
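
Since the same check is suggested for both training scripts, one option is to put the validation in a single function that each entrypoint calls. The sketch below is an assumption about how that could look, not code from this PR: `resolve_flex_kernel_options` is a hypothetical name, and wrapping the JSON decode error in a flag-specific message is an addition beyond what the suggestion asks for.

```python
import json
from typing import Any, Optional


def resolve_flex_kernel_options(
    attention_backend: str,
    raw_json: Optional[str],
) -> Optional[dict[str, Any]]:
    """Validate the flag against the backend, then decode it to a dict."""
    if raw_json is None:
        return None
    if attention_backend != "flex_attention":
        raise ValueError(
            "--flex-kernel-options-json can only be used when "
            "--attention-backend is 'flex_attention'."
        )
    try:
        options = json.loads(raw_json)
    except json.JSONDecodeError as exc:
        # Re-raise with the flag name so the user sees which argument is malformed.
        raise ValueError(
            f"--flex-kernel-options-json is not valid JSON: {exc}"
        ) from exc
    if not isinstance(options, dict):
        raise ValueError("--flex-kernel-options-json must decode to a JSON object.")
    return options


print(resolve_flex_kernel_options("flex_attention", '{"num_stages": 2}'))
# {'num_stages': 2}
```

Checking the backend before decoding means a misplaced flag is reported even when its JSON payload happens to be valid.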

Comment thread scripts/train_domino.py Outdated
Comment on lines +455 to +460
if args.flex_kernel_options_json is not None:
    flex_kernel_options = json.loads(args.flex_kernel_options_json)
    if not isinstance(flex_kernel_options, dict):
        raise ValueError(
            "--flex-kernel-options-json must decode to a JSON object."
        )

medium

If the user provides --flex-kernel-options-json but --attention-backend is not set to flex_attention, these options will be silently ignored. It is better to validate this and raise a ValueError to prevent user confusion.

Suggested change
- if args.flex_kernel_options_json is not None:
-     flex_kernel_options = json.loads(args.flex_kernel_options_json)
-     if not isinstance(flex_kernel_options, dict):
-         raise ValueError(
-             "--flex-kernel-options-json must decode to a JSON object."
-         )
+ if args.flex_kernel_options_json is not None:
+     if args.attention_backend != 'flex_attention':
+         raise ValueError(
+             "--flex-kernel-options-json can only be used when --attention-backend is 'flex_attention'."
+         )
+     flex_kernel_options = json.loads(args.flex_kernel_options_json)
+     if not isinstance(flex_kernel_options, dict):
+         raise ValueError(
+             "--flex-kernel-options-json must decode to a JSON object."
+         )

Copilot AI left a comment

Pull request overview

This PR exposes Flex Attention kernel configuration (kernel_options) through the DFlash and Domino training CLIs, parses a JSON dict in the entrypoints, and forwards the resulting options through the online training wrappers into the draft model’s Flex Attention forward path. This is intended to let users tune kernel selection to avoid hardware-specific shared-memory issues (e.g., RTX A6000 sm86).

Changes:

  • Add --flex-kernel-options-json to train_dflash.py and train_domino.py, parse it as JSON, and forward it into the online wrappers.
  • Add flex_kernel_options plumbing to OnlineDFlashModel and OnlineDominoModel.
  • Forward the options as kernel_options=... into the draft model call path.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| specforge/core/domino.py | Adds flex_kernel_options param/storage and forwards options into the draft-model call. |
| specforge/core/dflash.py | Adds flex_kernel_options param/storage and forwards options into the draft-model call. |
| scripts/train_domino.py | Adds CLI flag and JSON parsing; forwards parsed dict into OnlineDominoModel. |
| scripts/train_dflash.py | Adds CLI flag and JSON parsing; forwards parsed dict into OnlineDFlashModel. |

Comment thread specforge/core/dflash.py Outdated
    noise_embedding=noise_embedding,
    target_hidden=hidden_states,
    attention_mask=dflash_attn_mask,
    kernel_options=self.flex_kernel_options,

Comment thread specforge/core/domino.py Outdated
    noise_embedding=noise_embedding,
    target_hidden=hidden_states,
    attention_mask=dflash_attn_mask,
    kernel_options=self.flex_kernel_options,

Comment thread scripts/train_domino.py Outdated
Comment on lines +454 to +460
flex_kernel_options = None
if args.flex_kernel_options_json is not None:
    flex_kernel_options = json.loads(args.flex_kernel_options_json)
    if not isinstance(flex_kernel_options, dict):
        raise ValueError(
            "--flex-kernel-options-json must decode to a JSON object."
        )

Comment thread scripts/train_dflash.py Outdated
Comment on lines +385 to +391
flex_kernel_options = None
if args.flex_kernel_options_json is not None:
    flex_kernel_options = json.loads(args.flex_kernel_options_json)
    if not isinstance(flex_kernel_options, dict):
        raise ValueError(
            "--flex-kernel-options-json must decode to a JSON object."
        )

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated no new comments.
