@@ -26,7 +26,15 @@
 VECTOR_BYTES_PER_ITER = 256
 SUPPORTED_NUM_HEADS = (4, 6, 8, 12, 16, 24, 32, 48, 64, 128)
 MAX_VEC_CORE_NUM = detect_vec_core_num()
-BATCH_SIZE_SPECIALIZATIONS = tuple(range(2, 49, 2))
+
+
+def _build_batch_size_specializations() -> tuple[int, ...]:
+    small_dense = range(2, 49, 2)
+    large_sparse = range(64, DEFAULT_MAX_BATCH + 1, 64)
+    return tuple(sorted(set((*small_dense, *large_sparse))))
+
+
+BATCH_SIZE_SPECIALIZATIONS = _build_batch_size_specializations()


 def select_launch_block_num(*, num_batches: int, vec_core_num: int) -> int:
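For reference, a minimal sketch of what the rebuilt specialization table produces. The DEFAULT_MAX_BATCH value below is an assumption for illustration only; the real constant is defined elsewhere in the module.

```python
# Illustrative only: mirrors _build_batch_size_specializations() with an
# assumed DEFAULT_MAX_BATCH of 4096 (not taken from this diff).
DEFAULT_MAX_BATCH = 4096

small_dense = range(2, 49, 2)                        # 2, 4, ..., 48
large_sparse = range(64, DEFAULT_MAX_BATCH + 1, 64)  # 64, 128, ..., 4096

specializations = tuple(sorted(set((*small_dense, *large_sparse))))
print(len(specializations))   # 88 compiled variants: 24 dense + 64 sparse
print(specializations[:5])    # (2, 4, 6, 8, 10)
print(specializations[-3:])   # (3968, 4032, 4096)
```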
5 changes: 3 additions & 2 deletions xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp
@@ -193,8 +193,9 @@ void run_tilelang_fused_gdn_gating_chunk(const torch::Tensor& A_log,

   auto specialization = build_runtime_specialization(a);
   const auto* entry = find_fused_gdn_gating_kernel_entry(specialization);
-  // Expected fast path: compiled batch_size variants are dense [2, 4, ..., 48].
-  // If a value is missing, fall back to the nearest smaller batch_size.
+  // Small batch-size variants are dense [2, 4, ..., 48]. Larger long-prefill
+  // variants are compiled sparsely and dispatch falls back to the nearest
+  // smaller batch_size when an exact value is not available.
   if (entry == nullptr) {
     int32_t fallback_batch_size =
         specialization.batch_size - kBatchSpecializationStep;
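To make the dispatch policy described in the new comment concrete, here is a minimal Python sketch of the nearest-smaller fallback. The actual C++ dispatch above steps down by kBatchSpecializationStep and the rest of its logic is not shown in this hunk, so treat this as an illustration of the stated policy, not the implementation.

```python
def nearest_compiled_batch_size(batch_size: int,
                                compiled: tuple[int, ...]) -> int:
    """Pick the largest compiled specialization that does not exceed batch_size."""
    candidates = [b for b in compiled if b <= batch_size]
    if not candidates:
        raise ValueError(f"no compiled variant covers batch_size={batch_size}")
    return max(candidates)

# With variants (2, 4, ..., 48, 64, 128, ...), a long-prefill request with
# batch_size 3533 (the case added to the test below) would dispatch to 3520,
# the nearest smaller multiple of 64.
```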
@@ -144,6 +144,12 @@ TEST_F(TileLangFusedGdnGatingWrapperTest, MatchesTorchReference) {
         .num_heads = 32,
         .seed = 106,
     },
+    {
+        .name = "long_prefill_b3533_h32",
+        .num_batches = 3533,
+        .num_heads = 32,
+        .seed = 108,
+    },
     {
         .name = "custom_beta2_threshold0p5_b33_h64",
         .num_batches = 33,
19 changes: 10 additions & 9 deletions xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp
@@ -393,15 +393,16 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(

   // Compute gated delta net decay and beta terms.
   if (attn_metadata.is_prefill) {
-    beta = torch::sigmoid(b);
-    torch::Tensor A_log_exp = A_log_.exp();
-    torch::Tensor a_float = a.to(torch::kFloat32);
-    torch::Tensor a_plus_dt = a_float + dt_bias_;
-    torch::Tensor softplus_out = torch::nn::functional::softplus(
-        a_plus_dt,
-        torch::nn::functional::SoftplusFuncOptions().beta(1.0).threshold(20.0));
-    g = -A_log_exp * softplus_out;
-    g = g.to(a.dtype()).contiguous();
+    xllm::kernel::FusedGdnGatingParams gdn_params;
+    gdn_params.A_log = A_log_;
+    gdn_params.a = a.contiguous().view({-1, a.size(-1)});
+    gdn_params.b = b.contiguous().view({-1, b.size(-1)});
Comment on lines +398 to +399 (Contributor, severity: high):
Prefer using .reshape() over the .contiguous().view() pattern. reshape() is more idiomatic in PyTorch; it returns a view if the tensor is already contiguous and only performs a copy if necessary. This avoids redundant operations and potential memory allocations if the input tensors a and b are already contiguous.

Suggested change:
-    gdn_params.a = a.contiguous().view({-1, a.size(-1)});
-    gdn_params.b = b.contiguous().view({-1, b.size(-1)});
+    gdn_params.a = a.reshape({-1, a.size(-1)});
+    gdn_params.b = b.reshape({-1, b.size(-1)});
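A quick illustration of the point above, using the Python API for brevity (the same semantics hold for the libtorch calls in this file): reshape() on an already-contiguous tensor returns a view that shares storage, so no copy is made.

```python
import torch

x = torch.arange(24).view(2, 3, 4)        # contiguous input
y = x.reshape(-1, x.size(-1))             # returns a view; no copy
print(y.data_ptr() == x.data_ptr())       # True: storage is shared

z = x.contiguous().view(-1, x.size(-1))   # same result here, since .contiguous()
print(z.data_ptr() == x.data_ptr())       # is a no-op on a contiguous tensor; True
```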

+    gdn_params.dt_bias = dt_bias_;
+    gdn_params.beta = 1.0f;
+    gdn_params.threshold = 20.0f;
+    std::tie(g, beta) = xllm::kernel::fused_gdn_gating(gdn_params);
+    g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
+    beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});
Comment on lines +404 to +405 (Contributor, severity: high):
The calls to .squeeze(0) and .contiguous() are redundant here. Tensors returned by custom kernels are typically contiguous, and view() can handle the reshaping directly from the kernel's output shape (whether it is [total_tokens, hidden] or [1, total_tokens, hidden]) to the target [batch, seq, hidden] shape. Removing these unnecessary calls improves performance in the prefill hot path.

Suggested change:
-    g = g.squeeze(0).contiguous().view({batch_size, seq_len, a.size(-1)});
-    beta = beta.squeeze(0).contiguous().view({batch_size, seq_len, b.size(-1)});
+    g = g.view({batch_size, seq_len, a.size(-1)});
+    beta = beta.view({batch_size, seq_len, b.size(-1)});

   } else {
     xllm::kernel::FusedGdnGatingParams gdn_params;
     gdn_params.A_log = A_log_;