perf: use fused gdn gating for qwen3.5 prefill. #1301
yingxudeng wants to merge 1 commit into jd-opensource:main
Conversation
Code Review
This pull request replaces manual PyTorch operations with a fused GDN gating kernel in the prefill path of the Qwen3 Gated Delta Net and adds a new test case for large batch sizes. Review feedback suggests optimizing tensor reshaping by using .reshape() instead of the .contiguous().view() pattern and removing redundant .squeeze(0) and .contiguous() calls on kernel outputs to improve performance in the prefill hot path.
    gdn_params.a = a.contiguous().view({-1, a.size(-1)});
    gdn_params.b = b.contiguous().view({-1, b.size(-1)});
Prefer using .reshape() over the .contiguous().view() pattern. reshape() is more idiomatic in PyTorch; it returns a view if the tensor is already contiguous and only performs a copy if necessary. This avoids redundant operations and potential memory allocations if the input tensors a and b are already contiguous.
Suggested change:
-    gdn_params.a = a.contiguous().view({-1, a.size(-1)});
-    gdn_params.b = b.contiguous().view({-1, b.size(-1)});
+    gdn_params.a = a.reshape({-1, a.size(-1)});
+    gdn_params.b = b.reshape({-1, b.size(-1)});
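The view-vs-copy rule this suggestion relies on is the same one NumPy's reshape follows, so it can be illustrated without libtorch. A minimal Python/NumPy sketch (not the PR's C++ code; shapes are made up for illustration):

```python
import numpy as np

# reshape returns a view when the array is already contiguous...
a = np.arange(24).reshape(2, 3, 4)      # contiguous, like the kernel inputs
flat = a.reshape(-1, a.shape[-1])       # analogue of a.reshape({-1, a.size(-1)})
assert np.shares_memory(a, flat)        # view: no copy was made

# ...and only copies when it has to, e.g. for a non-contiguous transpose.
t = a.transpose(1, 0, 2)
flat_t = t.reshape(-1, t.shape[-1])
assert not np.shares_memory(t, flat_t)  # here a copy was unavoidable
```

PyTorch's `Tensor.reshape` (Python and C++) follows the same rule, which is why a single `reshape` call expresses the intent that `.contiguous().view()` spells out in two steps.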
    g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
    beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});
The calls to .squeeze(0) and .contiguous() are redundant here. Tensors returned by custom kernels are typically contiguous, and view() can handle the reshaping directly from the kernel's output shape (whether it is [total_tokens, hidden] or [1, total_tokens, hidden]) to the target [batch, seq, hidden] shape. Removing these unnecessary calls improves performance in the prefill hot path.
Suggested change:
-    g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
-    beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});
+    g = g.view({batch_size, seq_len, a.size(-1)});
+    beta = beta.view({batch_size, seq_len, b.size(-1)});
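To see why the squeeze(0) is unnecessary: a contiguous output holds the same batch_size * seq_len * hidden elements whether or not it carries a leading size-1 dimension, so one reshape reaches the target shape either way. A NumPy sketch of the two layouts the review mentions (hypothetical sizes, mirroring the [total_tokens, hidden] and [1, total_tokens, hidden] cases):

```python
import numpy as np

batch_size, seq_len, hidden = 2, 3, 4

# Either kernel output layout reshapes directly to [batch, seq, hidden].
for shape in [(batch_size * seq_len, hidden), (1, batch_size * seq_len, hidden)]:
    g = np.arange(batch_size * seq_len * hidden, dtype=np.float32).reshape(shape)
    out = g.reshape(batch_size, seq_len, hidden)  # no squeeze(0)/contiguous() needed
    assert out.shape == (batch_size, seq_len, hidden)
    assert np.shares_memory(g, out)               # still a view: no copy
```

The same element-count argument applies to the C++ `view()` call, provided the kernel output is contiguous as the review assumes.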
Force-pushed from e7994c6 to 05a5e86