
perf: use fused gdn gating for qwen3.5 prefill.#1301

Closed
yingxudeng wants to merge 1 commit into jd-opensource:main from yingxudeng:perf/qwen3-gdn-prefill-fusion-2

Conversation

@yingxudeng (Collaborator)

No description provided.

@gemini-code-assist (Bot, Contributor) left a comment


Code Review

This pull request replaces manual PyTorch operations with a fused GDN gating kernel in the prefill path of the Qwen3 Gated Delta Net and adds a new test case for large batch sizes. Review feedback suggests optimizing tensor reshaping by using .reshape() instead of the .contiguous().view() pattern and removing redundant .squeeze(0) and .contiguous() calls on kernel outputs to improve performance in the prefill hot path.

Comment on lines +398 to +399
gdn_params.a = a.contiguous().view({-1, a.size(-1)});
gdn_params.b = b.contiguous().view({-1, b.size(-1)});

Severity: high

Prefer .reshape() over the .contiguous().view() pattern. reshape() is more idiomatic in PyTorch: it returns a view when the tensor is already contiguous and only copies when necessary, avoiding redundant work and extra allocations when the inputs a and b are already contiguous (see the sketch after the suggested change below).

Suggested change
- gdn_params.a = a.contiguous().view({-1, a.size(-1)});
- gdn_params.b = b.contiguous().view({-1, b.size(-1)});
+ gdn_params.a = a.reshape({-1, a.size(-1)});
+ gdn_params.b = b.reshape({-1, b.size(-1)});
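
To make the reshape()-vs-.contiguous().view() point concrete, here is a minimal LibTorch sketch (tensor names and shapes are illustrative, not taken from this PR): reshape() aliases the storage of an already-contiguous tensor and only copies when it has to.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Contiguous input: reshape() returns a view over the same storage.
  auto a = torch::randn({4, 8, 16});
  auto flat = a.reshape({-1, a.size(-1)});
  std::cout << (flat.data_ptr() == a.data_ptr()) << "\n";  // 1: view, no copy

  // Non-contiguous input (e.g. after transpose): reshape() copies,
  // which is the same work .contiguous().view() would do anyway.
  auto t = a.transpose(0, 1);
  auto flat_t = t.reshape({-1, t.size(-1)});
  std::cout << (flat_t.data_ptr() == t.data_ptr()) << "\n";  // 0: copied
  return 0;
}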

Comment on lines +404 to +405
g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});

Severity: high

The .squeeze(0) and .contiguous() calls are redundant here. Tensors returned by custom kernels are typically contiguous, and view() can reshape the kernel's output directly (whether it is [total_tokens, hidden] or [1, total_tokens, hidden]) into the target [batch, seq, hidden] shape. Removing these calls avoids unnecessary work in the prefill hot path (see the sketch after the suggested change below).

Suggested change
- g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
- beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});
+ g = g.view({batch_size, seq_len, a.size(-1)});
+ beta = beta.view({batch_size, seq_len, b.size(-1)});
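
As a sanity check on the claim above, here is a minimal LibTorch sketch (shapes and names are hypothetical, not from this PR) showing that a contiguous kernel output can be view()-ed straight to {batch, seq, hidden}, with or without a leading unit dimension:

#include <torch/torch.h>

int main() {
  const int64_t batch = 2, seq = 5, hidden = 16;

  // Two plausible contiguous kernel output shapes.
  auto out_2d = torch::randn({batch * seq, hidden});     // [total_tokens, hidden]
  auto out_3d = torch::randn({1, batch * seq, hidden});  // [1, total_tokens, hidden]

  // view() reshapes both directly; no squeeze(0) or contiguous() needed.
  auto g1 = out_2d.view({batch, seq, hidden});
  auto g2 = out_3d.view({batch, seq, hidden});

  TORCH_CHECK(g1.sizes() == g2.sizes());
  return 0;
}

Note that this only holds if the kernel output really is contiguous, which is what the comment assumes; if a kernel could return a non-contiguous tensor, reshape() would be the safer drop-in.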

@yingxudeng (Collaborator, Author)

yingxudeng commented Apr 16, 2026

Before the fused operator:
[image]
After the fused operator:
[image]

After switching to the fused operator, performance degraded slightly, so this PR will not be merged for now.

@yingxudeng (Collaborator, Author)

[image]
Modified the code so that the matching kernel for bs3520 is compiled and actually hit; performance improved but is still worse than the original. Therefore, not merging for now.

@yingxudeng force-pushed the perf/qwen3-gdn-prefill-fusion-2 branch from e7994c6 to 05a5e86 on April 16, 2026 18:01
@yingxudeng closed this on Apr 20, 2026
