Expose flex_attention kernel options in DFlash and Domino training #586
heiheiha798 wants to merge 1 commit into
Code Review
This pull request introduces support for forwarding custom kernel options to the draft model when using the flex_attention backend. This is achieved by adding a new CLI argument --flex-kernel-options-json to both train_dflash.py and train_domino.py, which is then parsed and passed down to the DFlash and Domino model wrappers. Feedback on the changes highlights two main issues: first, passing kernel_options unconditionally to the draft model will result in a TypeError when other attention backends (like sdpa or eager) are selected, so it should be passed conditionally. Second, if --flex-kernel-options-json is provided without setting the attention backend to flex_attention, the options are silently ignored, so adding validation to raise a ValueError in this scenario is recommended.
| noise_embedding=noise_embedding, | ||
| target_hidden=hidden_states, | ||
| attention_mask=dflash_attn_mask, | ||
| kernel_options=self.flex_kernel_options, | ||
| ) |
Passing kernel_options directly to self.draft_model will cause a TypeError when using other attention backends like sdpa or eager (which are valid choices in the CLI), because their underlying attention functions do not accept kernel_options as a parameter. To prevent this, we should only pass kernel_options when the attention backend is set to flex_attention and flex_kernel_options is provided.
| noise_embedding=noise_embedding, | |
| target_hidden=hidden_states, | |
| attention_mask=dflash_attn_mask, | |
| kernel_options=self.flex_kernel_options, | |
| ) | |
| noise_embedding=noise_embedding, | |
| target_hidden=hidden_states, | |
| attention_mask=dflash_attn_mask, | |
| **({'kernel_options': self.flex_kernel_options} if (self.attention_backend == 'flex_attention' and self.flex_kernel_options is not None) else {}), | |
| ) |
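The failure mode and the guarded call can be reproduced in isolation. The following is a minimal sketch, not SpecForge code: `sdpa_forward`, `flex_forward`, and `build_extra_kwargs` are hypothetical stand-ins, and it assumes only the flex path declares a `kernel_options` parameter.

```python
from typing import Any, Optional


def sdpa_forward(attention_mask: Any) -> str:
    # Hypothetical stand-in for a backend whose forward has no kernel_options parameter.
    return "sdpa"


def flex_forward(attention_mask: Any, kernel_options: Optional[dict] = None) -> str:
    # Hypothetical stand-in for the flex_attention path, which accepts kernel_options.
    return f"flex:{kernel_options}"


def build_extra_kwargs(attention_backend: str, flex_kernel_options: Optional[dict]) -> dict:
    # Emit the keyword only when the backend can consume it and options were supplied.
    if attention_backend == "flex_attention" and flex_kernel_options is not None:
        return {"kernel_options": flex_kernel_options}
    return {}


opts = {"num_stages": 2}

# Unconditional forwarding fails on the backend that lacks the parameter.
try:
    sdpa_forward(attention_mask=None, kernel_options=opts)
    unconditional_ok = True
except TypeError:
    unconditional_ok = False

# Guarded forwarding works for both backends.
sdpa_result = sdpa_forward(attention_mask=None, **build_extra_kwargs("sdpa", opts))
flex_result = flex_forward(attention_mask=None, **build_extra_kwargs("flex_attention", opts))
```

Pulling the condition into a small helper keeps the call site readable compared with an inline conditional dict expansion, at the cost of one extra function.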
| noise_embedding=noise_embedding, | ||
| target_hidden=hidden_states, | ||
| attention_mask=dflash_attn_mask, | ||
| kernel_options=self.flex_kernel_options, | ||
| ) |
Passing kernel_options directly to self.draft_model will cause a TypeError when using other attention backends like sdpa or eager (which are valid choices in the CLI), because their underlying attention functions do not accept kernel_options as a parameter. To prevent this, we should only pass kernel_options when the attention backend is set to flex_attention and flex_kernel_options is provided.
| noise_embedding=noise_embedding, | |
| target_hidden=hidden_states, | |
| attention_mask=dflash_attn_mask, | |
| kernel_options=self.flex_kernel_options, | |
| ) | |
| noise_embedding=noise_embedding, | |
| target_hidden=hidden_states, | |
| attention_mask=dflash_attn_mask, | |
| **({'kernel_options': self.flex_kernel_options} if (self.attention_backend == 'flex_attention' and self.flex_kernel_options is not None) else {}), | |
| ) |
| if args.flex_kernel_options_json is not None: | ||
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | ||
| if not isinstance(flex_kernel_options, dict): | ||
| raise ValueError( | ||
| "--flex-kernel-options-json must decode to a JSON object." | ||
| ) |
If the user provides --flex-kernel-options-json but --attention-backend is not set to flex_attention, these options will be silently ignored. It is better to validate this and raise a ValueError to prevent user confusion.
| if args.flex_kernel_options_json is not None: | |
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | |
| if not isinstance(flex_kernel_options, dict): | |
| raise ValueError( | |
| "--flex-kernel-options-json must decode to a JSON object." | |
| ) | |
| if args.flex_kernel_options_json is not None: | |
| if args.attention_backend != 'flex_attention': | |
| raise ValueError( | |
| "--flex-kernel-options-json can only be used when --attention-backend is 'flex_attention'." | |
| ) | |
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | |
| if not isinstance(flex_kernel_options, dict): | |
| raise ValueError( | |
| "--flex-kernel-options-json must decode to a JSON object." | |
| ) |
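Both checks can live in one helper so that the two training scripts share the same behavior. This is a sketch under assumptions: `parse_flex_kernel_options` and `build_parser` are hypothetical names, and the `--attention-backend` choices and default shown here are guesses for illustration rather than the scripts' actual definitions. It also wraps malformed JSON in a `ValueError` that names the flag, which the suggested change above does not do.

```python
import argparse
import json
from typing import Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Assumed choices and default; the real scripts may define these differently.
    parser.add_argument(
        "--attention-backend",
        choices=["flex_attention", "sdpa", "eager"],
        default="flex_attention",
    )
    parser.add_argument("--flex-kernel-options-json", type=str, default=None)
    return parser


def parse_flex_kernel_options(args: argparse.Namespace) -> Optional[dict]:
    raw = args.flex_kernel_options_json
    if raw is None:
        return None
    # Reject the flag outright when the selected backend would ignore it.
    if args.attention_backend != "flex_attention":
        raise ValueError(
            "--flex-kernel-options-json can only be used when "
            "--attention-backend is 'flex_attention'."
        )
    try:
        options = json.loads(raw)
    except json.JSONDecodeError as exc:
        # Re-raise with the flag name so the user knows which argument is malformed.
        raise ValueError(f"--flex-kernel-options-json is not valid JSON: {exc}") from exc
    if not isinstance(options, dict):
        raise ValueError("--flex-kernel-options-json must decode to a JSON object.")
    return options


args = build_parser().parse_args(
    ["--attention-backend", "flex_attention", "--flex-kernel-options-json", '{"num_stages": 2}']
)
flex_kernel_options = parse_flex_kernel_options(args)
```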
| if args.flex_kernel_options_json is not None: | ||
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | ||
| if not isinstance(flex_kernel_options, dict): | ||
| raise ValueError( | ||
| "--flex-kernel-options-json must decode to a JSON object." | ||
| ) |
If the user provides --flex-kernel-options-json but --attention-backend is not set to flex_attention, these options will be silently ignored. It is better to validate this and raise a ValueError to prevent user confusion.
| if args.flex_kernel_options_json is not None: | |
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | |
| if not isinstance(flex_kernel_options, dict): | |
| raise ValueError( | |
| "--flex-kernel-options-json must decode to a JSON object." | |
| ) | |
| if args.flex_kernel_options_json is not None: | |
| if args.attention_backend != 'flex_attention': | |
| raise ValueError( | |
| "--flex-kernel-options-json can only be used when --attention-backend is 'flex_attention'." | |
| ) | |
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | |
| if not isinstance(flex_kernel_options, dict): | |
| raise ValueError( | |
| "--flex-kernel-options-json must decode to a JSON object." | |
| ) |
Pull request overview
This PR exposes Flex Attention kernel configuration (kernel_options) through the DFlash and Domino training CLIs, parses a JSON dict in the entrypoints, and forwards the resulting options through the online training wrappers into the draft model’s Flex Attention forward path. This is intended to let users tune kernel selection to avoid hardware-specific shared-memory issues (e.g., RTX A6000 sm86).
Changes:
- Add --flex-kernel-options-json to train_dflash.py and train_domino.py, parse it as JSON, and forward it into the online wrappers.
- Add flex_kernel_options plumbing to OnlineDFlashModel and OnlineDominoModel.
- Forward the options as kernel_options=... into the draft model call path.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| specforge/core/domino.py | Adds flex_kernel_options param/storage and forwards options into the draft-model call. |
| specforge/core/dflash.py | Adds flex_kernel_options param/storage and forwards options into the draft-model call. |
| scripts/train_domino.py | Adds CLI flag and JSON parsing; forwards parsed dict into OnlineDominoModel. |
| scripts/train_dflash.py | Adds CLI flag and JSON parsing; forwards parsed dict into OnlineDFlashModel. |
| noise_embedding=noise_embedding, | ||
| target_hidden=hidden_states, | ||
| attention_mask=dflash_attn_mask, | ||
| kernel_options=self.flex_kernel_options, |
| noise_embedding=noise_embedding, | ||
| target_hidden=hidden_states, | ||
| attention_mask=dflash_attn_mask, | ||
| kernel_options=self.flex_kernel_options, |
| flex_kernel_options = None | ||
| if args.flex_kernel_options_json is not None: | ||
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | ||
| if not isinstance(flex_kernel_options, dict): | ||
| raise ValueError( | ||
| "--flex-kernel-options-json must decode to a JSON object." | ||
| ) |
| flex_kernel_options = None | ||
| if args.flex_kernel_options_json is not None: | ||
| flex_kernel_options = json.loads(args.flex_kernel_options_json) | ||
| if not isinstance(flex_kernel_options, dict): | ||
| raise ValueError( | ||
| "--flex-kernel-options-json must decode to a JSON object." | ||
| ) |
Summary
- Add --flex-kernel-options-json to DFlash and Domino training CLIs
- Forward kernel_options in the flex attention draft forward path
Why
On RTX A6000 (sm86), we hit a training failure when using attention_backend=flex_attention.
The root cause was not an NCCL error and not a generic CUDA OOM. The failure came from the Triton kernel configuration selected for flex attention during training: the generated configuration could exceed the sm86 shared-memory-per-block opt-in limit.
In practice, this means the default flex attention kernel configuration is not always viable on A6000 for these training paths.
Local symptom on A6000
A representative workaround on RTX A6000 is to explicitly pass:
--flex-kernel-options-json '{"num_stages": 2}'
This reduces the selected kernel aggressiveness enough for the training path to proceed on our A6000 setup.
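When the training command is assembled by a launcher script rather than typed by hand, the JSON value has to survive shell quoting. A minimal sketch using only the standard library, assuming the launcher is written in Python; the variable names are illustrative:

```python
import json
import shlex

kernel_options = {"num_stages": 2}

# Build the argument vector first, then quote it once for display or a shell.
argv = [
    "python",
    "scripts/train_domino.py",
    "--attention-backend", "flex_attention",
    "--flex-kernel-options-json", json.dumps(kernel_options),
]
command = shlex.join(argv)

# Round trip: splitting the quoted command recovers the original JSON payload.
recovered = json.loads(shlex.split(command)[-1])
```

Passing `argv` directly to `subprocess.run` avoids the shell entirely; `shlex.join` is only needed when the command must be rendered as a single string.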
Scope
This PR does not hardcode any hardware-specific default.
It only exposes the existing lower-level flex attention kernel_options capability through the DFlash and Domino training CLIs, so users can tune it from the command line when needed.
Example
python scripts/train_domino.py \
  --attention-backend flex_attention \
  --flex-kernel-options-json '{"num_stages": 2}'
The same flag is also available in scripts/train_dflash.py.
Validation
num_stages=2 to be passed from the CLI