Fix: Plumb vmem limit and bkv_compute_in to the custom kernel #416
eltsai wants to merge 1 commit into
Pretty great find @eltsai, awesome work! Few questions
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request correctly plumbs the vmem_limit_bytes configuration option and other block-level tuning parameters (such as block_kv_compute_in and heads_per_tile) all the way down to the custom TPU Pallas splash attention kernel. This allows executing under raised scoped VMEM limits, unlocking optimal layout sweeping and a ~3.7% performance improvement in denoising.
🔍 General Feedback
- Preservation of Configuration Parameters: Introducing the hashable, frozen CustomFlashBlockSizes carrier is an excellent design decision to prevent JAX from silently dropping these custom properties.
- Robust Config Extraction: The extraction of custom properties in attention_flax.py correctly handles both raw dicts and custom objects, which provides excellent resilience against different configuration entry points.
- Consistency Check: Please verify that tpu_custom_attention (often used by TorchAX or standalone tests) uses the same object-safe extraction logic to prevent AttributeError crashes, as detailed in the inline comment.
block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
🔴 With the introduction of the frozen dataclass CustomFlashBlockSizes in max_utils.py to preserve custom properties across JAX boundaries, passing flash_block_sizes to tpu_custom_attention (e.g., via the TorchAX SDPA wrapper path in make_custom_splash_sdpa) will result in a runtime crash. Since flash_block_sizes is no longer guaranteed to be a dictionary, calling .get() on a CustomFlashBlockSizes instance will raise an AttributeError.
We should safely extract values by checking if flash_block_sizes is a dictionary or an object, similar to how it is done in attention_flax.py.
The entire block starting from line 597 should be refactored as follows:
    if flash_block_sizes is not None:
        if isinstance(flash_block_sizes, dict):
            block_q = flash_block_sizes.get("block_q", block_q)
            block_kv = flash_block_sizes.get("block_kv", block_kv)
            block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute)
            block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
            heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
            vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
        else:
            block_q = getattr(flash_block_sizes, "block_q", block_q)
            block_kv = getattr(flash_block_sizes, "block_kv", block_kv)
            block_kv_compute = getattr(flash_block_sizes, "block_kv_compute", block_kv_compute)
            block_kv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", block_kv_compute_in)
            heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
            vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes)
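The same dict-or-object extraction could also be written as a loop over the field names, which keeps the two branches from drifting apart. This is only a sketch: `read_block_sizes`, `CarrierStandIn`, and the numeric values are hypothetical names and placeholders for illustration, not code from this PR, and the stand-in deliberately carries only two fields so the `getattr` fallback to the defaults is visible.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

# Field names taken from the suggested refactor in the review comment.
BLOCK_KEYS = (
    "block_q",
    "block_kv",
    "block_kv_compute",
    "block_kv_compute_in",
    "heads_per_tile",
    "vmem_limit_bytes",
)


@dataclass(frozen=True)
class CarrierStandIn:
    """Hypothetical stand-in for CustomFlashBlockSizes (real fields not shown in the PR)."""

    block_q: Optional[int] = None
    vmem_limit_bytes: Optional[int] = None


def read_block_sizes(flash_block_sizes: Any, defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve each knob from a dict or an attribute-style object, else keep the default."""
    resolved = dict(defaults)
    if flash_block_sizes is None:
        return resolved
    for key in BLOCK_KEYS:
        if isinstance(flash_block_sizes, dict):
            resolved[key] = flash_block_sizes.get(key, resolved[key])
        else:
            # getattr with a default never raises AttributeError for a missing field.
            resolved[key] = getattr(flash_block_sizes, key, resolved[key])
    return resolved


# Illustrative defaults only; the kernel's real defaults are not stated in this thread.
defaults = {key: None for key in BLOCK_KEYS}
defaults.update(block_q=512, block_kv=512)

from_dict = read_block_sizes({"block_q": 1024, "vmem_limit_bytes": 1 << 26}, defaults)
from_obj = read_block_sizes(CarrierStandIn(block_q=1024, vmem_limit_bytes=1 << 26), defaults)
```

Both entry points resolve to the same values here. As in the suggested refactor, an attribute that exists but is set to None on the object overrides the default, so callers may want to decide whether that is intended.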
@Perseus14 (not sure why I can't directly reply to the comments)
Yes, it's read from the flash_block_sizes dict in the config, for example:
No, they're separate knobs at different layers and neither feeds the other.
before this commit,
Right now, the vmem_limit_bytes and block_kv_compute_in arguments are never passed into the kernel. This PR plumbs vmem_limit_bytes from the flash_block_sizes config down to the Mosaic CompilerParams so the custom splash kernel can run under a raised scoped-VMEM limit.
- custom_splash_attention.py: vmem_limit_bytes param through both forwards -> pltpu.CompilerParams, make_splash_mha, tpu_custom_attention (+ flash_block_sizes read), make_custom_splash_sdpa (+ kwargs read).
- attention_flax.py: read vmem_limit_bytes in wrap_ulysses_attention, pass to make_splash_mha.
- max_utils.py: CustomFlashBlockSizes frozen carrier so vmem_limit_bytes (and block_kv_compute_in / heads_per_tile) survive config -> kernel instead of being dropped by the JAX BlockSizes dataclass.
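To illustrate the frozen-carrier idea, here is a minimal sketch of a hashable, immutable dataclass that keeps the extra knobs alongside the block sizes. `FrozenCarrierSketch` and `carrier_from_config` are hypothetical names, and the field list is an assumption based on the parameters named in this PR; the actual CustomFlashBlockSizes definition in max_utils.py may differ. Hashability matters because values passed as static arguments under `jax.jit` must be hashable.

```python
from dataclasses import FrozenInstanceError, dataclass, fields
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class FrozenCarrierSketch:
    """Assumed shape of a carrier; frozen=True with the default eq=True makes it hashable."""

    block_q: int
    block_kv: int
    block_kv_compute_in: Optional[int] = None
    heads_per_tile: Optional[int] = None
    vmem_limit_bytes: Optional[int] = None


def carrier_from_config(cfg: Dict[str, Any]) -> FrozenCarrierSketch:
    """Build the carrier from a config dict, ignoring keys the carrier does not define."""
    known = {f.name for f in fields(FrozenCarrierSketch)}
    return FrozenCarrierSketch(**{k: v for k, v in cfg.items() if k in known})


# Illustrative values only; they are not the tuned config from this PR.
carrier = carrier_from_config(
    {"block_q": 1024, "block_kv": 512, "vmem_limit_bytes": 1 << 26, "unrelated_key": 1}
)

# Equal field values give equal hashes, so the carrier can serve as a cache key.
same = carrier_from_config({"block_q": 1024, "block_kv": 512, "vmem_limit_bytes": 1 << 26})

try:
    carrier.block_q = 2048  # frozen dataclasses reject mutation
    mutated = True
except FrozenInstanceError:
    mutated = False
```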
By raising the vmem to the 64 MB limit, our grid search finds a better config:
E2E result:
Compared to the previous config (BQ=4864), the denoising steps are 3.7% faster (122.6 sec vs 127.3 sec).
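As a quick check of the reported figure, the relative speedup can be recomputed from the two timings. The byte count below assumes "64 MB" means 64 MiB, which the thread does not spell out.

```python
def relative_speedup(before_s: float, after_s: float) -> float:
    """Fraction of wall time saved relative to the earlier run."""
    return (before_s - after_s) / before_s


# Timings quoted in the PR description.
speedup = relative_speedup(before_s=127.3, after_s=122.6)
print(f"{speedup:.1%}")  # prints 3.7%

# Assumption: the 64 MB limit is expressed in binary megabytes.
vmem_limit_bytes = 64 * 1024 * 1024
```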