diff --git a/.agents/skills/code-review/SKILL.md b/.agents/skills/code-review/SKILL.md new file mode 100644 index 000000000..c709c2e0d --- /dev/null +++ b/.agents/skills/code-review/SKILL.md @@ -0,0 +1,115 @@ +--- +name: code-review +description: Review code changes for quality, security, performance, and correctness following project-specific standards. Use when reviewing pull requests, examining git diffs, or when the user asks for a code review. This skill should be used proactively — when the user asks for a review without specifying commits, automatically detect the current branch and diff against the main branch. +--- + +# Code Review + +## Workflow + +### Step 1: Determine the diff + +If the user provides explicit SHAs or a PR link, use those. Otherwise, **auto-detect**: + +```bash +# Fetch latest remote state +git fetch origin main --quiet + +# Detect current branch +CURRENT_BRANCH=$(git branch --show-current) + +# Find the merge base with origin/main +MERGE_BASE=$(git merge-base origin/main HEAD) + +# Show what changed +git diff --stat $MERGE_BASE..HEAD +git diff $MERGE_BASE..HEAD +``` + +If `CURRENT_BRANCH` is `main`, warn the user and ask which commits to review. + +### Step 2: Read project standards + +Read [custom-code-style.md](references/custom-code-style.md) for project-specific coding style. + +### Step 3: Review against the checklist + +**Correctness:** +- Logic handles edge cases and boundary conditions +- Error handling is comprehensive (no silent failures) +- Type safety maintained (no unsafe casts, proper use of `std::optional`) +- Resource lifecycle correct (RAII, no leaks, proper cleanup order) + +**Architecture:** +- Clean separation of concerns, no layer violations +- Dependencies flow in the correct direction +- Changes align with existing patterns in the codebase +- No unnecessary coupling introduced + +**Performance & Concurrency:** +- No performance regressions on hot paths +- Thread safety: proper locking, no data races +- CUDA/NPU kernels: memory coalescing, occupancy, sync correctness +- No unnecessary copies of large objects (tensors, vectors) + +**Testing:** +- Tests verify actual logic, not just mock wiring +- Edge cases and error paths covered +- Integration tests for cross-component changes + +**Production Readiness:** +- Backward compatibility maintained (or breaking changes documented) +- Migration strategy for schema/config changes +- No hardcoded values that should be configurable + +### Step 4: Output findings + +Use the format below. + +## Output Format + +### Strengths +[Specific things done well, with file:line references] + +### Issues + +#### Critical (Must Fix) +[Bugs, security holes, data loss risks, broken functionality] + +#### Important (Should Fix) +[Architecture problems, missing error handling, test gaps, performance issues] + +#### Minor (Nice to Have) +[Style, optimization opportunities, documentation improvements] + +**Each issue must include:** +- **File:line** reference +- **What** is wrong +- **Why** it matters +- **How** to fix (if not obvious) + +### Recommendations +[Broader improvements for code quality, architecture, or process] + +### Assessment + +**Ready to merge?** [Yes / No / With fixes] + +**Reasoning:** [1-2 sentence technical assessment] + +## Rules + +**DO:** +- Apply project-specific style from [custom-code-style.md](references/custom-code-style.md) +- Follow DDD (Domain Driven Design) principles, and keep the codebase clean and maintainable +- Categorize by actual severity (not everything is Critical) +- Be specific with file:line references +- Explain WHY issues matter +- Acknowledge strengths +- Give a clear verdict + +**DON'T:** +- Approve without thorough review +- Mark nitpicks as Critical +- Give feedback on code not in the diff +- Be vague (e.g., "improve error handling" without specifics) diff --git a/.agents/skills/code-review/references/custom-code-style.md b/.agents/skills/code-review/references/custom-code-style.md new file mode 100644 index 000000000..7801bd811 --- /dev/null +++ b/.agents/skills/code-review/references/custom-code-style.md @@ -0,0 +1,316 @@ +# Custom Code Style + +Project-specific coding style for xllm. The reviewer **MUST** enforce these style. + +--- + +## 1. Naming Conventions + +### C++ + +| Element | Style | Example | +|------------------|------------------------------------|--------------------------------------| +| Namespace | `snake_case` | `xllm`, `xllm::detail` | +| Class / Struct | `PascalCase` | `LlmModelImplBase`, `KVCache` | +| Function | `snake_case` | `get_input_embeddings`, `forward` | +| Member variable | `snake_case_` (trailing underscore)| `model_type_`, `embed_tokens_` | +| Local variable | `snake_case` | `inputs_embeds`, `kv_caches` | +| Constant | `k` + `PascalCase` | `kContentLength`, `kMaxBatchSize` | +| Enum type | `PascalCase` | `EngineType`, `DeviceType` | +| Enum value | `ALL_CAPS` | `LLM`, `VLM`, `INVALID` | +| Template param | `PascalCase` | `DecoderLayerType` | +| Macro | `ALL_CAPS` | `XLLM_CHECK`, `LOG_EVERY_N` | +| File name | `snake_case` | `llm_model_base.h`, `types.h` | +| Header guard | `#pragma once` | - | + +### Python + +| Element | Style | Example | +|------------------|------------------------|--------------------------------------| +| Module / file | `snake_case` | `model_loader.py` | +| Class | `PascalCase` | `TokenizerConfig` | +| Function | `snake_case` | `load_model` | +| Variable | `snake_case` | `batch_size` | +| Constant | `ALL_CAPS` | `MAX_SEQ_LEN` | +| Private member | `_leading_underscore` | `_internal_state` | + +--- + +## 2. File & Header Rules + +- **Copyright header required** on all new files. Use the correct year matching the file creation date. + +```cpp +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +... +==============================================================================*/ +``` + +- **No relative paths in `#include`**. Always use project-root-relative paths. + +```cpp +// Good +#include "core/common/types.h" + +// Bad +#include "../common/types.h" +#include "./types.h" +``` + +- **Remove redundant and duplicate includes**. Each header should be included exactly once, and unused includes must be cleaned up. + +--- + +## 3. Type System & Declarations + +- **Use fixed-width integers** (`int32_t`, `int64_t`) instead of plain `int`, unless the API you are calling explicitly requires `int`. + +```cpp +// Good +int32_t batch_size = 16; +int64_t total_tokens = 0; + +// Bad +int batch_size = 16; +``` + +- **Use `static_cast`** for all type conversions. Never use C-style casts. + +```cpp +// Good +auto len = static_cast(vec.size()); + +// Bad +auto len = (int32_t)vec.size(); +``` + +- **Do not use `auto` for simple/primitive types**. `auto` is acceptable for complex types (iterators, lambdas, template-deduced types) but not for `int32_t`, `float`, `bool`, `std::string`, etc. + +```cpp +// Good +int32_t count = 0; +auto it = map.find(key); // complex iterator type, auto is fine + +// Bad +auto count = 0; +auto name = std::string("model"); +``` + +- **Use `using` instead of `typedef`** for type aliases. Prefer aliases for complex types to improve readability. + +```cpp +// Good +using TensorVec = std::vector; +using CallbackFn = std::function; + +// Bad +typedef std::vector TensorVec; +``` + +- **Use `enum class`** instead of plain `enum` to provide type safety and prevent implicit conversions. + +```cpp +// Good +enum class DeviceType : int8_t { CPU = 0, CUDA = 1, NPU = 2 }; + +// Bad +enum DeviceType { CPU = 0, CUDA = 1, NPU = 2 }; +``` + +- **Use `nullptr`** instead of `NULL` or `0` for null pointers. + +- **Choose the right container**: use `std::unordered_map` / `std::unordered_set` when key ordering is irrelevant (O(1) average lookup). Use `std::map` / `std::set` only when sorted iteration or key ordering is required. + +--- + +## 4. Class Design + +- **Mark classes `final`** if they are not designed to be inherited from. + +```cpp +// Good +class TokenizerConfig final { ... }; + +// Bad – class has no virtual functions and is not intended as a base +class TokenizerConfig { ... }; +``` + +- **Use `explicit`** on any constructor that can be invoked with a single argument. This includes multi-parameter constructors where all parameters except the first have default values. + +```cpp +// Good +explicit ModelArgs(const std::string& path, int32_t num_layers = 12); + +// Bad – allows implicit conversion from std::string +ModelArgs(const std::string& path, int32_t num_layers = 12); +``` + +- **Use `override`** when overriding virtual functions in derived classes. Never repeat the `virtual` keyword on overrides. + +```cpp +// Good +ModelOutput forward(torch::Tensor tokens, ...) override; + +// Bad +virtual ModelOutput forward(torch::Tensor tokens, ...); +``` + +- **Structs must not have member functions**. If you need methods, use a `class`. Structs are for plain data aggregation only. + +--- + +## 5. Memory & Resource Management + +- **Avoid raw pointers**. Prefer smart pointers (`std::unique_ptr`, `std::shared_ptr`) for ownership semantics. + - Use `std::unique_ptr` by default (sole ownership). + - Use `std::shared_ptr` only when shared ownership is genuinely needed. + - Raw pointers are acceptable only for non-owning references where the lifetime is clearly managed elsewhere. + +--- + +## 6. Scoping & Visibility + +### C++ + +- **File-local functions and variables** (used only within a single `.cpp` file) must be placed in an **anonymous namespace**. + +```cpp +namespace { +int32_t compute_padding(int32_t seq_len, int32_t alignment) { + return (alignment - seq_len % alignment) % alignment; +} +} // namespace +``` + +### Python + +- **File-local helper functions** (not part of the public API) must be prefixed with `_`. +- **Non-public member functions** of a class must be prefixed with `_`. + +```python +def _validate_config(config: dict) -> bool: + ... + +class ModelLoader: + def load(self, path: str) -> Model: + self._check_path(path) + ... + + def _check_path(self, path: str) -> None: + ... +``` + +--- + +## 7. Torch & Framework API Usage + +- **Use `torch::` namespace** instead of `at::` or `c10::` wherever possible. Prefer the highest-level PyTorch C++ API. + +```cpp +// Good +torch::Tensor output = torch::zeros({batch_size, hidden_dim}); + +// Bad +at::Tensor output = at::zeros({batch_size, hidden_dim}); +c10::optional mask = c10::nullopt; // use std::optional +``` + +- **Use `CHECK`** (glog) instead of `TORCH_CHECK` for assertions. + +```cpp +// Good +CHECK(tensor.is_contiguous()) << "Input tensor must be contiguous"; + +// Bad +TORCH_CHECK(tensor.is_contiguous(), "Input tensor must be contiguous"); +``` + +- **Use `LOG(FATAL)`** for unrecoverable errors instead of throwing `std::runtime_error`. + +```cpp +// Good +LOG(FATAL) << "Unsupported model type: " << model_type; + +// Bad +throw std::runtime_error("Unsupported model type: " + model_type); +``` + +--- + +## 8. Code Style & Control Flow + +- **Always use braces `{}`** with `if`, `while`, `for`, even for single-line bodies. + +```cpp +// Good +if (x > 0) { + return x; +} + +// Bad +if (x > 0) return x; +``` + +- **Avoid `if` inside `for` loops** when possible. Prefer filtering the data beforehand or restructuring the logic (e.g., early `continue`, separate loops, `std::copy_if`). + +- **Define variables close to first use**. Do not declare all variables at the top of a function. + +- **Annotate constant arguments** with a comment indicating the parameter name when calling functions or constructors. + +```cpp +// Good +auto layer = DecoderLayer(/*hidden_size=*/4096, /*num_heads=*/32); + +// Bad +auto layer = DecoderLayer(4096, 32); +``` + +--- + +## 9. STL Best Practices + +- **Always `reserve()` before filling a `std::vector`** when the size is known or can be estimated. + +```cpp +// Good +std::vector outputs; +outputs.reserve(num_layers); +for (int32_t i = 0; i < num_layers; ++i) { + outputs.emplace_back(compute_layer(i)); +} + +// Bad – causes multiple reallocations +std::vector outputs; +for (int32_t i = 0; i < num_layers; ++i) { + outputs.push_back(compute_layer(i)); +} +``` + +- **Prefer `emplace_back`** over `push_back` to construct elements in-place and avoid unnecessary copies. + +--- + +## 10. Global Flags + +- **Do not overuse `FLAGS_` global variables**. Prefer passing configuration through constructor parameters or config structs. Only use global flags for top-level, process-wide settings. +- **Register new flags in `help_formatter.h`**. When adding a new global flag, always add a corresponding entry in `help_formatter.h` so it appears in `--help` output. + +--- + +## 11. Python-Specific Rules + +- **Type annotations are required** on all function signatures (parameters and return types). Use `typing` module types where needed. + +```python +# Good +def load_model(path: str, device: str = "cuda") -> nn.Module: + ... + +# Bad +def load_model(path, device="cuda"): + ... +``` + +- **Private helpers**: prefix with `_` (see Section 6). diff --git a/.agents/skills/git-workflow/SKILL.md b/.agents/skills/git-workflow/SKILL.md new file mode 100644 index 000000000..5e48ed715 --- /dev/null +++ b/.agents/skills/git-workflow/SKILL.md @@ -0,0 +1,49 @@ +--- +name: git-workflow +description: Use when the task involves Git operations for the public xLLM repository, including choosing branch or tag names, preparing commits and pull requests, backporting fixes, checking repo-specific review expectations, or drafting commit messages from actual diffs. +--- + +# Git Workflow + +Use xLLM repo reality, not generic Git habits. + +## Reference Map + +Load only the file that matches the user's immediate Git task. + +| File | What it is for | When to load it | +| --- | --- | --- | +| `references/source-of-truth.md` | Repo-specific source priority and canonical files to consult | Load first when repo docs, local state, and user wording may disagree | +| `references/branch-naming.md` | Branch naming patterns and default branch conventions | Load when the user asks how to name a branch or which branch to branch from | +| `references/development-flow.md` | Day-to-day fork, sync, branch, validate, and push flow | Load when the user asks for normal development steps from local change to push | +| `references/pr-review.md` | PR targeting, PR scope, and review expectations | Load when the task is about opening a PR, choosing the target branch, or deciding who should review | +| `references/release-layout.md` | Release branch and tag shapes used by xLLM | Load when the task mentions release branches, release tags, or patch version naming | +| `references/backport-flow.md` | Preferred backport and hotfix flow for released lines | Load when the task mentions cherry-picks, hotfixes, or fixing an already released branch | +| `references/commit-format.md` | Commit title/body conventions and xLLM-style examples | Load when the user asks for a commit message, commit style guidance, or message cleanup | + +## Workflow + +1. Decide which subtask the user actually needs. +2. Read `references/source-of-truth.md` when you need repo-specific confirmation. +3. Then load only the most relevant task file from the table above. +4. For commit message drafting, run `bash scripts/collect_git_context.sh [--staged|--all|--unstaged]` before writing the final message. +5. Draft commit messages from the actual diff, not from filenames alone. +6. If one diff mixes unrelated concerns, recommend splitting the commit instead of forcing one vague summary. +7. Default PR targets to `main` unless the task is clearly a release or backport flow. +8. For released lines, prefer landing on `main` first and then backporting unless the user explicitly wants a direct hotfix flow. + +## Output + +Return the smallest useful answer for the user's Git task: + +- workflow questions: concrete branch, tag, PR, or backport steps +- commit message requests: `: ` plus an optional short bullet body +- repo-convention questions: current xLLM-specific guidance, not generic Git advice + +## Quick Checks + +- branch names match xLLM style such as `feat/`, `bugfix/`, or `release/vX.Y.Z` +- PR target is `main` unless this is a release or backport task +- commit title matches the dominant change in the diff +- release tags use semantic versions such as `v0.9.0` or `v0.9.1` +- owner-review expectations come from `.github/CODEOWNERS` when relevant diff --git a/.agents/skills/git-workflow/references/backport-flow.md b/.agents/skills/git-workflow/references/backport-flow.md new file mode 100644 index 000000000..d8a594e06 --- /dev/null +++ b/.agents/skills/git-workflow/references/backport-flow.md @@ -0,0 +1,30 @@ +# Backport Flow + +When a released line needs a bugfix, prefer this flow: + +1. land the fix on `main` first unless the user explicitly needs a direct hotfix flow +2. cherry-pick or backport the fix to the matching `release/vX.Y.Z` branch +3. update release content as needed on that release branch +4. create the next patch tag for that release line + +Example shape: + +```bash +# land on main first +git checkout main +git pull --rebase upstream main +git checkout -b bugfix/ +git commit -m "bugfix: fix ." + +# then backport +git checkout release/v0.9.0 +git pull --rebase upstream release/v0.9.0 +git cherry-pick +git tag v0.9.1 +``` + +## Quick Checklist + +- backport from a commit already landed on `main` when possible +- cherry-pick onto the matching `release/vX.Y.Z` branch +- use the next semantic patch tag for the release line diff --git a/.agents/skills/git-workflow/references/branch-naming.md b/.agents/skills/git-workflow/references/branch-naming.md new file mode 100644 index 000000000..b1bf00805 --- /dev/null +++ b/.agents/skills/git-workflow/references/branch-naming.md @@ -0,0 +1,54 @@ +# Branch Naming + +Use one of these branch shapes: + +```text +/ +// +preview/ +release/vX.Y.Z +``` + +Recommended lowercase branch types: + +- `feat`: new user-visible capability or feature work +- `bugfix`: incorrect behavior, regressions, or hot fixes +- `refactor`: structural changes without intended behavior changes +- `docs`: documentation-only work when a dedicated branch is useful +- `test`: test-only work when separated from product changes +- `perf`: runtime or memory improvements +- `chore`: repo maintenance that does not fit the other categories +- `build`: dependency, CI, packaging, or release tooling changes + +Topic guidelines: + +- use lowercase letters, numbers, and hyphens by default +- keep the topic short, specific, and review-friendly +- prefer nouns or short noun phrases like `scheduler`, `lm-head-new`, `npu-template` +- avoid spaces, uppercase letters, and vague names like `misc`, `temp`, `test-branch` +- avoid repeating the type in the topic, such as `feat/feature-x` +- use slash-separated namespace prefixes only when the work clearly belongs to a scoped stream + +Scoped branch guidelines: + +- use `//` for team-, model-, or project-scoped work such as `dsv4/feat/rope-dsv4` +- keep the namespace stable and meaningful, not personal or temporary +- use `preview/` only for preview-track work that intentionally aligns with preview branches upstream +- reserve `release/vX.Y.Z` for release preparation or release-only changes +- avoid direct development on `main` and `release/*` unless the user explicitly asks for it + +Examples: + +- `feat/skills` +- `feat/lm_head_new` +- `bugfix/scheduler` +- `refactor/npu_template` +- `preview/glm-5` +- `release/v0.9.0` + +Notes for xLLM: + +- `main` is the default development branch +- `feat/*`, `bugfix/*`, and `refactor/*` appear in current branch usage and are safe defaults +- `preview/*` and `release/*` are long-lived integration branches, not ordinary personal topic branches +- both hyphen and underscore appear in existing history, but prefer hyphens for new branch topics unless matching an established naming family diff --git a/.agents/skills/git-workflow/references/commit-format.md b/.agents/skills/git-workflow/references/commit-format.md new file mode 100644 index 000000000..1465ffb65 --- /dev/null +++ b/.agents/skills/git-workflow/references/commit-format.md @@ -0,0 +1,56 @@ +# Commit Format + +Use this exact first-line format: + +```text +: +``` + +Allowed lowercase types: + +- `feat`: add user-visible behavior or a new capability +- `bugfix`: correct incorrect behavior or a regression +- `docs`: change documentation only +- `test`: add or update tests only +- `refactor`: improve structure without changing intended behavior +- `chore`: repository maintenance that does not fit the other types +- `style`: formatting or style-only changes without logic changes +- `revert`: revert an earlier commit +- `perf`: improve runtime or memory behavior +- `model`: change model definitions, checkpoints, prompts, or inference behavior +- `build`: change build, release, or dependency wiring +- `release`: change release versioning, release notes, or release-only metadata + +Subject guidelines: + +- use lowercase letters by default +- include at least 4 words +- end with a period +- start with a verb like `add`, `fix`, `remove`, `refactor`, `document` +- keep it specific enough that a reviewer understands the main change +- avoid filler like `update`, `misc`, `stuff`, `changes` +- describe the effect or intent, not a mechanical file list + +Body guidelines: + +- add a body only when the title alone is not enough +- use short bullets for secondary details or important context +- mention follow-up work, migration steps, or compatibility impact when relevant +- if confidence is low because the diff is partial or noisy, say that explicitly + +Observed xLLM-style examples: + +- `feat: add rope_in_place tilelang kernel for npu device. (#964)` +- `bugfix: align rec initialization flags with options. (#1142)` +- `docs: update the document to align them with the latest code. (#1113)` +- `refactor: extract multi-modal input processors to processors dir. (#1022)` +- `perf: reserve vector capacity before batch push_back. (#1089)` +- `release: update xllm release version to v0.9.0. (#1124)` +- `feat: support qwen3.5/qwen3.5-moe mtp draft model for speculative decoding[3/N]. (#1119)` + +Notes for xLLM: + +- `bugfix:` appears more often than `fix:` in current history and should be the default bug-repair prefix +- `fix:` exists in a few historical commits but is less consistent than `bugfix:` +- PR-number suffixes like `(#1142)` are common in merged history but are optional unless the user explicitly wants them +- staged series markers like `[1/N]` or `[3/N]` appear when a change is intentionally split across multiple commits diff --git a/.agents/skills/git-workflow/references/development-flow.md b/.agents/skills/git-workflow/references/development-flow.md new file mode 100644 index 000000000..62210afc5 --- /dev/null +++ b/.agents/skills/git-workflow/references/development-flow.md @@ -0,0 +1,33 @@ +# Development Flow + +Follow this sequence unless the user asks for a different workflow: + +1. fork the upstream repository +2. sync local `main` with `upstream/main` +3. create a focused topic branch from `main` +4. implement the change +5. run formatting and the narrowest relevant validation +6. commit in clear English +7. push to the fork +8. open a PR to upstream `main` + +Example commands: + +```bash +git fetch upstream +git checkout main +git pull --rebase upstream main +git checkout -b feat/ + +# after development +git add +git commit -m "feat: add ." +git push origin feat/ +``` + +## Quick Checklist + +- branch from `main` unless this is a release or backport task +- keep the branch focused on one change +- run the narrowest relevant validation before commit +- push to your fork before opening the PR diff --git a/.agents/skills/git-workflow/references/pr-review.md b/.agents/skills/git-workflow/references/pr-review.md new file mode 100644 index 000000000..b7818f854 --- /dev/null +++ b/.agents/skills/git-workflow/references/pr-review.md @@ -0,0 +1,28 @@ +# PR And Review + +## Pull Request Guidance + +The public repo guidance is lightweight: + +- `README.md` and `CONTRIBUTING.md` ask contributors to fork, create a branch, and send a pull request +- keep PRs focused and easy to review, even though the public docs do not publish a hard line-count limit +- write commit messages and PR descriptions in clear English +- avoid unnecessary merge noise in branch history; prefer a clean linear history when practical + +## Target Branch + +- use `main` unless this is explicitly a release or backport task + +## Review Expectations + +Use `.github/CODEOWNERS` as the visible review signal: + +- changes under `/xllm/` have listed code owners +- expect owner review or owner attention for those paths +- if the user asks who should review a change under `/xllm/`, check `CODEOWNERS` + +## Quick Checklist + +- PR target is `main` unless this is a backport or release task +- PR is focused and clearly described +- review expectations are checked through `CODEOWNERS` diff --git a/.agents/skills/git-workflow/references/release-layout.md b/.agents/skills/git-workflow/references/release-layout.md new file mode 100644 index 000000000..c2518ba29 --- /dev/null +++ b/.agents/skills/git-workflow/references/release-layout.md @@ -0,0 +1,15 @@ +# Release Layout + +Observed public release layout: + +- release branches use `release/vX.Y.Z` +- observed release branches include `release/v0.6.0`, `release/v0.7.0`, `release/v0.8.0`, and `release/v0.9.0` +- release notes are tracked in `RELEASE.md` +- public tags use semantic version tags such as `v0.9.0` +- patch tags use normal patch versions such as `v0.7.1` and `v0.7.2`, not `-rcN` + +## Naming Guardrails + +- do not switch to `release_0.1.0` branches unless the user explicitly wants an older internal workflow +- do not assume `v0.1.0-rc0` style tags unless the user explicitly asks for them +- prefer the next semantic patch tag for bugfix releases diff --git a/.agents/skills/git-workflow/references/source-of-truth.md b/.agents/skills/git-workflow/references/source-of-truth.md new file mode 100644 index 000000000..39b426a88 --- /dev/null +++ b/.agents/skills/git-workflow/references/source-of-truth.md @@ -0,0 +1,29 @@ +# Source Of Truth + +Use these repo files first when the user asks for xLLM-specific Git guidance: + +- `README.md` +- `CONTRIBUTING.md` +- `RELEASE.md` +- `.github/workflows/check_format.yml` +- `.pre-commit-config.yaml` +- `.github/CODEOWNERS` + +Guidance priority: + +1. direct user request +2. current repo files and branch reality +3. visible remote repo conventions +4. older internal notes or remembered habits + +When repo docs, local repo state, and user wording disagree: + +- say the conflict explicitly +- prefer the most concrete source available +- avoid inventing rules that are not visible in the public repo + +Common examples of rules you should not assume without evidence: + +- rebase-only or squash-only merge requirements +- mandatory reviewer counts beyond what `CODEOWNERS` implies +- old naming patterns such as `features/*`, `release_0.1.0`, or `v0.1.0-rc0` diff --git a/.agents/skills/git-workflow/scripts/collect_git_context.sh b/.agents/skills/git-workflow/scripts/collect_git_context.sh new file mode 100644 index 000000000..b5a37c36a --- /dev/null +++ b/.agents/skills/git-workflow/scripts/collect_git_context.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash + +set -euo pipefail + +mode="auto" + +if [[ $# -gt 1 ]]; then + echo "usage: $0 [--staged|--all|--unstaged]" >&2 + exit 1 +fi + +if [[ $# -eq 1 ]]; then + case "$1" in + --staged) + mode="staged" + ;; + --all) + mode="all" + ;; + --unstaged) + mode="unstaged" + ;; + *) + echo "unknown option: $1" >&2 + echo "usage: $0 [--staged|--all|--unstaged]" >&2 + exit 1 + ;; + esac +fi + +if ! git rev-parse --show-toplevel >/dev/null 2>&1; then + echo "not inside a git repository" >&2 + exit 1 +fi + +repo_root="$(git rev-parse --show-toplevel)" +cd "$repo_root" + +staged_count="$(git diff --cached --name-only | wc -l | tr -d ' ')" +untracked_files="$(git ls-files --others --exclude-standard)" + +if [[ "$mode" == "auto" ]]; then + if [[ "$staged_count" != "0" ]]; then + mode="staged" + else + mode="all" + fi +fi + +case "$mode" in + staged) + status_cmd=(git diff --cached --name-status) + stat_cmd=(git diff --cached --stat) + patch_cmd=(git diff --cached --unified=1 --no-color) + headline="staged changes" + ;; + unstaged) + status_cmd=(git diff --name-status) + stat_cmd=(git diff --stat) + patch_cmd=(git diff --unified=1 --no-color) + headline="unstaged changes" + ;; + all) + headline="all local changes" + ;; + *) + echo "invalid mode: $mode" >&2 + exit 1 + ;; + esac + +echo "repo: $repo_root" +echo "branch: $(git branch --show-current)" +echo "scope: $headline" +echo +echo "status:" +git status --short +echo + +if [[ -n "$untracked_files" ]]; then + echo "untracked files:" + printf '%s\n' "$untracked_files" + echo +fi + +if [[ "$mode" == "all" ]]; then + echo "changed files:" + git diff HEAD --name-status + echo + echo "diff stat:" + git diff HEAD --stat + echo + echo "patch excerpt:" + git diff HEAD --unified=1 --no-color | sed -n '1,400p' +else + echo "changed files:" + "${status_cmd[@]}" + echo + echo "diff stat:" + "${stat_cmd[@]}" + echo + echo "patch excerpt:" + "${patch_cmd[@]}" | sed -n '1,400p' +fi diff --git a/.agents/skills/tilelang-api-best-practices b/.agents/skills/tilelang-api-best-practices new file mode 120000 index 000000000..1689a11c2 --- /dev/null +++ b/.agents/skills/tilelang-api-best-practices @@ -0,0 +1 @@ +../../third_party/tilelang-ascend/.agents/skills/tilelang-custom-skill/tilelang-api-best-practices \ No newline at end of file diff --git a/.agents/skills/tilelang-ascend-kernel/SKILL.md b/.agents/skills/tilelang-ascend-kernel/SKILL.md new file mode 100644 index 000000000..c153cc905 --- /dev/null +++ b/.agents/skills/tilelang-ascend-kernel/SKILL.md @@ -0,0 +1,139 @@ +--- +name: tilelang-ascend-kernel +description: Use when the user wants to add, modify, debug, or review an xLLM TileLang Ascend kernel or specialization, including Python kernel definitions, generated Ascend-C source, runtime wrapper dispatch, TileLang CMake wiring, and NPU tests. +--- + +# TileLang Ascend Kernel + +## When to use + +Use this skill when the task involves any of the following in the xLLM repo: + +- `xllm/compiler/tilelang/targets/ascend/kernels/*.py` +- `xllm/core/kernels/npu/tilelang/*_wrapper.cpp` +- `xllm/core/kernels/npu/tilelang/CMakeLists.txt` +- generated TileLang artifacts such as `manifest.json`, `registry.inc`, or specialization `.cpp` + +Run build and test commands inside the NPU container. + +Run TileLang commands from the xLLM repo root, not from an installed-package environment. + +## Entry points and TL_ROOT + +- Prefer `python xllm/compiler/tilelang_launcher.py ...` for end-to-end TileLang compile flows. +- From the xLLM repo root, use `export TL_ROOT=$PWD/third_party/tilelang-ascend` for xLLM TileLang tooling and verify `test -f "$TL_ROOT/tilelang/__init__.py"`. +- Before any raw script does `import tilelang`, run `export TL_ROOT=$PWD/third_party/tilelang-ascend && source third_party/tilelang-ascend/set_env.sh`, then execute the script. +- Do not run kernel files directly with `python rope.py`; use module execution because these files rely on relative imports. +- For direct kernel-script debugging, run them as modules and pass required CLI args: + +```bash +cd xllm +python -m compiler.tilelang.targets.ascend.kernels.rope \ + --output ../.tmp/rope.cpp +# Expected: [INFO] RoPE output matches torch reference +``` + +- The same module-style rule applies to other kernel files under `xllm/compiler/tilelang/targets/ascend/kernels/`. + +## Primary Reference And Mode Preference + +Primary reference: + +- `third_party/tilelang-ascend/docs/TileLang-Ascend Programming Guide.md` + +Use `third_party/tilelang-ascend/.agents/skills/tilelang-custom-skill/tilelang-api-best-practices/references/api-tile-ops.md` when the task depends on `T.tile.xxx` semantics such as `compare`, `select`, `cast`, or other vector intrinsics. + +Default to Expert mode for xLLM Ascend kernels: + +- prefer `T.tile.xxx`, explicit UB/shared allocation, and explicit `T.copy` +- prefer explicit `T.serial` control for row/block traversal +- do not introduce Developer mode `T.Parallel` unless the kernel is a clearly tile-local element-wise expression and the change does not reduce control over UB usage, temporary buffers, or exact runtime semantics +- when translating Triton kernels, preserve the Triton runtime semantics first, then choose the smallest Expert-mode lowering that matches them + +## Common Triton To TileLang-Ascend Semantics + +Use this table as the quick semantic mapping when translating Triton kernels: + +| Triton pattern | TileLang-Ascend pattern | Notes | +| --- | --- | --- | +| `x + y`, `x - y`, `x * y`, `x / y` | `T.tile.add/sub/mul/div` | Prefer tile ops in Expert-style vector code instead of hand-written scalar loops. | +| `tl.exp(x)`, `tl.log(x)`, `tl.abs(x)` | `T.tile.exp`, `T.tile.ln`, `T.tile.abs` | TileLang uses `ln`, not `log`. | +| `x.to(tl.float32)` or `tl.cast(...)` | `T.tile.cast(dst, src, "CAST_NONE", count)` | Pick a non-default cast mode only when rounding semantics are required. | +| `x <= y`, `x < y`, `x >= y`, `x == y` | `T.tile.compare(mask, x, y, "LE"/"LT"/"GE"/"EQ")` | `T.tile.compare` produces a bit mask, not a float tensor. | +| `tl.where(cond, a, b)` | `T.tile.select(dst, selMask, a, b, selMode)` | API-level match. If `cond` is a comparison expression such as `x <= y`, materialize `selMask` with `T.tile.compare(...)` first; use the matching `VSEL_*` mode for tensor-tensor or tensor-scalar selection. | +| `tl.full(shape, value, dtype)` | allocate buffer + `T.tile.fill(dst, value)` | Separate allocation from initialization. | +| `tl.arange(0, N)` | `T.tile.createvecindex(dst, 0)` or explicit loop indices | Prefer `createvecindex` only when the kernel truly needs a vector index tensor. | + +Rules for semantic-preserving lowering: + +- Preserve Triton control-flow, masking, and parameter semantics. Do not substitute a numerically similar formula unless the runtime-visible behavior is unchanged for the supported input domain. +- Keep the kernel ABI aligned with the lowering. Every runtime parameter must either participate in the TileLang implementation or be removed from the interface. +- Add targeted tests for branch, mask, and boundary behavior. Do not rely only on random inputs if some paths are hit only under specific values. +- Choose the correct `VSEL_*` mode based on the source operands. `VSEL_CMPMASK_SPR` is the natural match for a mask produced by `T.tile.compare`; `VSEL_TENSOR_SCALAR_MODE` and `VSEL_TENSOR_TENSOR_MODE` are for explicit tensor/scalar or tensor/tensor selection modes. + +## New kernel + +Follow this order: + +1. Implement `build__kernel(...)` +2. Implement `generate_source(...)` +3. Declare `DISPATCH_SCHEMA` and `SPECIALIZATIONS` +4. Run TileLang compilation once and inspect `registry.inc` +5. Add or update `_wrapper.cpp` +6. Register the kernel in `xllm/core/kernels/npu/tilelang/CMakeLists.txt` with: + - `tilelang_register_runtime_kernel(NAME WRAPPER_SRCS )` + +For wrapper work: + +- do kernel precision alignment on the Python side first (`build__kernel(...)` / `generate_source(...)`), not in the C++ wrapper +- handwrite tensor checks, layout transforms, and `build_runtime_specialization(...)` +- use generated `make__specialization(...)` and `find__kernel_entry(...)` +- do not handwrite kernel-specific specialization structs or kernel fn typedefs + +## New specialization + +Use this path when the kernel logic and wrapper ABI stay the same. + +1. Update the existing kernel's `SPECIALIZATIONS` +2. Confirm every runtime dispatch field still matches `DISPATCH_SCHEMA` +3. Re-run TileLang compilation +4. Check that `registry.inc` contains the new entry +5. Check that the wrapper's `build_runtime_specialization(...)` still constructs matching values + +## Debug generated Ascend-C + +When the task is to inspect codegen or compare two kernel implementations, use an isolated output root: + +```bash +python xllm/compiler/tilelang_launcher.py compile-kernels \ + --target ascend \ + --device a3 \ + --output-root .tmp/tilelang_debug \ + --kernels \ + --force +``` + +Then inspect: + +- `.tmp/tilelang_debug/targets/ascend///__kernel.cpp` +- `.tmp/tilelang_debug/targets/ascend//registry.inc` +- `.tmp/tilelang_debug/targets/ascend//manifest.json` + +Use this path before changing wrapper code when you need to understand generated symbols, field order, or ABI. + +## Validate + +Prefer the narrowest command first: + +- `python -m py_compile xllm/compiler/tilelang/targets/ascend/kernels/.py` +- `cd xllm && python -m compiler.tilelang.targets.ascend.kernels. --output ../.tmp/.cpp` +- `python xllm/compiler/tilelang_launcher.py prepare-ascend` +- `python setup.py test --test-name --device npu` + +## References + +Read `docs/en/dev_guide/tilelang_ascend_kernel_dev.md` for mechanism details. +Use `rope` as the concrete template: +- `xllm/compiler/tilelang/targets/ascend/kernels/rope.py` +- `xllm/core/kernels/npu/tilelang/rope_wrapper.cpp` +- `xllm/core/kernels/npu/tilelang/CMakeLists.txt` diff --git a/.agents/skills/tilelang-debug-helper b/.agents/skills/tilelang-debug-helper new file mode 120000 index 000000000..5f54a2d9e --- /dev/null +++ b/.agents/skills/tilelang-debug-helper @@ -0,0 +1 @@ +../../third_party/tilelang-ascend/.agents/skills/tilelang-custom-skill/tilelang-debug-helper \ No newline at end of file diff --git a/.agents/skills/tilelang-expert-to-developer b/.agents/skills/tilelang-expert-to-developer new file mode 120000 index 000000000..208abac42 --- /dev/null +++ b/.agents/skills/tilelang-expert-to-developer @@ -0,0 +1 @@ +../../third_party/tilelang-ascend/.agents/skills/tilelang-custom-skill/tilelang-expert-to-developer \ No newline at end of file diff --git a/.claude/skills b/.claude/skills new file mode 120000 index 000000000..ac173151d --- /dev/null +++ b/.claude/skills @@ -0,0 +1 @@ +../.agents/skills/ \ No newline at end of file diff --git a/.gemini/styleguide.md b/.gemini/styleguide.md new file mode 120000 index 000000000..75157a7f2 --- /dev/null +++ b/.gemini/styleguide.md @@ -0,0 +1 @@ +../.agents/skills/code-review/references/custom-code-style.md \ No newline at end of file diff --git a/CONTRIBUTING.md b/.github/CONTRIBUTING.md similarity index 98% rename from CONTRIBUTING.md rename to .github/CONTRIBUTING.md index 3cc6d8623..9f9657d3d 100644 --- a/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -42,4 +42,4 @@ For developers who want to contribute to our code, here is the guidance: ## 4. Test After the PR is submitted, we will format and test the code. -Our tests are still far from perfect, so you are welcomed to add tests to our project! \ No newline at end of file +Our tests are still far from perfect, so you are welcomed to add tests to our project! diff --git a/CONTRIBUTING_zh.md b/.github/CONTRIBUTING_zh.md similarity index 100% rename from CONTRIBUTING_zh.md rename to .github/CONTRIBUTING_zh.md diff --git a/.github/workflows/build_x86_64_npu.yaml b/.github/workflows/build_x86_64_npu.yaml index 25c054ef8..bc5c6b1b1 100644 --- a/.github/workflows/build_x86_64_npu.yaml +++ b/.github/workflows/build_x86_64_npu.yaml @@ -109,11 +109,24 @@ jobs: needs.check-sensitive.outputs.do_build == 'true' runs-on: [self-hosted] steps: + - name: Prepare submodule checkout + run: | + git config --global url."https://gitcode.com/xLLM-AI/tvm".insteadOf "https://github.com/TileLang/tvm" + git config --global url."https://gitcode.com/xLLM-AI/composable_kernel".insteadOf "https://github.com/ROCm/composable_kernel" + git config --global url."https://gitcode.com/xLLM-AI/cutlass".insteadOf "https://github.com/NVIDIA/cutlass" + if [ -d .git/modules ]; then + find .git/modules -type f -name '*.lock' -print -delete || true + fi + + rm -rf .git/modules/third_party/tilelang-ascend + rm -rf third_party/tilelang-ascend + - name: Checkout Code timeout-minutes: 5 uses: actions/checkout@v4 with: - submodules: true + clean: true + submodules: recursive - name: Build if: ${{ success() }} diff --git a/.gitignore b/.gitignore index 29119b494..dd2b4a71e 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,8 @@ MANIFEST xllm/*.pyd xllm/*.so xllm/version.py -**/__pycache__/* +__pycache__/ +.pkl_memoize_py3/ # compile_commands.json from nvbench compile_commands.json @@ -57,4 +58,4 @@ compile_commands.json # local files /local /logs -/log \ No newline at end of file +/log diff --git a/.gitmodules b/.gitmodules index 3f94d0032..a6863816e 100755 --- a/.gitmodules +++ b/.gitmodules @@ -1,39 +1,57 @@ [submodule "third_party/brpc"] path = third_party/brpc url = https://gitcode.com/xLLM-AI/brpc.git + fetchRecurseSubmodules = false [submodule "third_party/cpprestsdk"] path = third_party/cpprestsdk url = https://gitcode.com/xLLM-AI/cpprestsdk.git + fetchRecurseSubmodules = false [submodule "third_party/minja"] path = third_party/minja url = https://gitcode.com/xLLM-AI/minja.git + fetchRecurseSubmodules = false [submodule "third_party/sentencepiece"] path = third_party/sentencepiece url = https://gitcode.com/xLLM-AI/sentencepiece.git + fetchRecurseSubmodules = false [submodule "third_party/smhasher"] path = third_party/smhasher url = https://gitcode.com/xLLM-AI/smhasher.git + fetchRecurseSubmodules = false [submodule "third_party/xllm_ops"] path = third_party/xllm_ops url = https://gitcode.com/xLLM-AI/xllm_ops.git + fetchRecurseSubmodules = true [submodule "third_party/etcd_cpp_apiv3"] path = third_party/etcd_cpp_apiv3 url = https://gitcode.com/xLLM-AI/etcd-cpp-apiv3.git + fetchRecurseSubmodules = false [submodule "third_party/spdlog"] path = third_party/spdlog url = https://gitcode.com/xLLM-AI/spdlog.git + fetchRecurseSubmodules = false [submodule "third_party/Mooncake"] path = third_party/Mooncake url = https://gitcode.com/xLLM-AI/Mooncake.git + fetchRecurseSubmodules = false [submodule "third_party/torch_npu_ops"] path = third_party/torch_npu_ops url = https://gitcode.com/xLLM-AI/torch_npu_ops.git + fetchRecurseSubmodules = false [submodule "third_party/cutlass"] path = third_party/cutlass url = https://gitcode.com/xLLM-AI/cutlass.git + fetchRecurseSubmodules = false [submodule "third_party/xllm_atb_layers"] path = third_party/xllm_atb_layers url = https://gitcode.com/xLLM-AI/xllm_atb_layers.git + fetchRecurseSubmodules = false [submodule "third_party/xxHash"] path = third_party/xxHash url = https://gitcode.com/xLLM-AI/xxHash.git + fetchRecurseSubmodules = false +[submodule "third_party/tilelang-ascend"] + path = third_party/tilelang-ascend + url = https://gitcode.com/xLLM-AI/tilelang-ascend.git + branch = ascendc_pto + fetchRecurseSubmodules = true diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 0e9640c29..000000000 --- a/.style.yapf +++ /dev/null @@ -1,2 +0,0 @@ -[style] -based_on_style = google diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..686b21e2f --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,51 @@ +# xLLM Coding Agent Instructions + +## Directory Structure + +``` +├── xllm/ +| : main source folder +│ ├── api_service/ # code for api services +│ ├── c_api/ # code for c api +│ ├── cc_api/ # code for cc api +│ ├── core/ +│ │ : xllm core features folder +│ │ ├── common/ +│ │ ├── distributed_runtime/ # code for distributed and pd serving +│ │ ├── framework/ # code for execution orchestration +│ │ ├── kernels/ # adaption for npu kernels adaption +│ │ ├── layers/ # model layers impl +│ │ ├── platform/ # adaption for various platform +│ │ ├── runtime/ # code for worker and executor +│ │ ├── scheduler/ # code for batch and pd scheduler +│ │ └── util/ +│ ├── function_call # code for tool call parser +│ ├── models/ # models impl +│ ├── parser/ # parser reasoning +│ ├── processors/ # code for vlm pre-processing +│ ├── proto/ # communication protocol +│ ├── pybind/ # code for python bind +| └── server/ # xLLM server +├── examples/ # examples of calling xLLM +├── tools/ # code for npu time generations +└── xllm.cpp # entrypoint of xLLM +``` + +## Code Style Guide + +* Before editing, creating, refactoring, or reviewing any file under `xllm/`, you **MUST** read [custom-code-style.md](.agents/skills/code-review/references/custom-code-style.md). +* The file above is a **required instruction file**, not an optional reference. Do not skip reading it. +* Apply the rules in [custom-code-style.md](.agents/skills/code-review/references/custom-code-style.md) to **both code generation and code review**. +* Follow DDD (Domain Driven Design) principles, and keep the codebase clean and maintainable. +* If [custom-code-style.md](.agents/skills/code-review/references/custom-code-style.md) specifies a rule, that rule takes precedence over the Google C++/Python Style Guide. +* Use the Google C++/Python Style Guide only for cases not specified in [custom-code-style.md](.agents/skills/code-review/references/custom-code-style.md). + +## Review Instructions + +* For code review tasks, you **MUST** first read [code-review/SKILL.md](.agents/skills/code-review/SKILL.md). +* Then read [custom-code-style.md](.agents/skills/code-review/references/custom-code-style.md) and apply it during the review. +* Review code changes for quality, security, performance, correctness, and maintainability following the project-specific standards. +* Review code changes for DDD (Domain Driven Design) principles, and keep the codebase clean and maintainable. +* Use the review workflow, checklist, severity rules, and output format defined in [code-review/SKILL.md](.agents/skills/code-review/SKILL.md). +* Apply the Google C++/Python Style Guide only when the project-specific style guide does not define the rule. +* Focus the review on the requested diff or changed files. Do not comment on unrelated code. \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 000000000..47dc3e3d8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 18f022610..d7495bb18 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -261,6 +261,17 @@ find_package(leveldb CONFIG REQUIRED) find_package(OpenSSL REQUIRED) find_package(absl CONFIG REQUIRED) find_package(Protobuf CONFIG REQUIRED) +# Ensure vcpkg protobuf headers are found before PyTorch's bundled older version. +# Use BEFORE SYSTEM so that: +# 1) vcpkg protobuf is first in -isystem list (before torch's -isystem) +# 2) project-local -I paths (e.g. Mooncake fake_include stubs) still win over -isystem +# Protobuf_INCLUDE_DIRS may be empty in CONFIG mode; fall back to the imported target. +if(NOT Protobuf_INCLUDE_DIRS AND TARGET protobuf::libprotobuf) + get_target_property(Protobuf_INCLUDE_DIRS protobuf::libprotobuf INTERFACE_INCLUDE_DIRECTORIES) +endif() +if(Protobuf_INCLUDE_DIRS) + include_directories(BEFORE SYSTEM ${Protobuf_INCLUDE_DIRS}) +endif() find_package(gRPC CONFIG REQUIRED) find_package(folly CONFIG REQUIRED) find_package(GTest CONFIG REQUIRED) @@ -453,8 +464,8 @@ if(USE_CUDA) add_definitions(-DUSE_CUDA) add_compile_definitions(TORCH_CUDA=1) set(CMAKE_VERBOSE_MAKEFILE ON) - include_directories( - $ENV{PYTHON_INCLUDE_PATH} + include_directories($ENV{PYTHON_INCLUDE_PATH}) + include_directories(SYSTEM $ENV{PYTORCH_INSTALL_PATH}/include $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include ) @@ -465,15 +476,31 @@ if(USE_CUDA) $ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64 ) - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3) + # To reduce compilation time during development, use fewer architectures: + # export TORCH_CUDA_ARCH_LIST="9.0" + option(CUDA_DEV_MODE "Use -O1 instead of -O3 for faster CUDA compilation during development" OFF) + if(CUDA_DEV_MODE) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O1") + else() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") + endif() + # The following definitions must be undefined since half-precision operation is required. - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} - -U__CUDA_NO_HALF_OPERATORS__ - -U__CUDA_NO_HALF_CONVERSIONS__ - -U__CUDA_NO_HALF2_OPERATORS__ - -U__CUDA_NO_BFLOAT16_CONVERSIONS__) - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all) - message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}") + string(APPEND CMAKE_CUDA_FLAGS + " -U__CUDA_NO_HALF_OPERATORS__" + " -U__CUDA_NO_HALF_CONVERSIONS__" + " -U__CUDA_NO_HALF2_OPERATORS__" + " -U__CUDA_NO_BFLOAT16_CONVERSIONS__" + " --use_fast_math" + " -Xfatbin -compress-all") + + # Parallel nvcc compilation: compile multiple GPU architectures simultaneously within each .cu + # file. --threads 0 = auto-detect CPU core count. Requires CUDA >= 11.2. + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2) + string(APPEND CMAKE_CUDA_FLAGS " --threads 0") + endif() + + message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") # find_package(NCCL REQUIRED) diff --git a/README.md b/README.md index 8f8ff799c..4a9ad982b 100755 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -[English](./README.md) | [中文](./README_zh.md) +[English](./README.md) | [中文](./docs/project/README_zh.md)
xLLM diff --git a/docs/assets/fixed_steps_scheduler_orca.png b/docs/assets/fixed_steps_scheduler_orca.png new file mode 100644 index 000000000..883705d72 Binary files /dev/null and b/docs/assets/fixed_steps_scheduler_orca.png differ diff --git a/docs/assets/generative_recommendation_beam_search.png b/docs/assets/generative_recommendation_beam_search.png new file mode 100644 index 000000000..70a1b1457 Binary files /dev/null and b/docs/assets/generative_recommendation_beam_search.png differ diff --git a/docs/assets/generative_recommendation_integration_architecture.jpg b/docs/assets/generative_recommendation_integration_architecture.jpg new file mode 100644 index 000000000..9f9d3c583 Binary files /dev/null and b/docs/assets/generative_recommendation_integration_architecture.jpg differ diff --git a/docs/assets/generative_recommendation_integration_architecture_en.svg b/docs/assets/generative_recommendation_integration_architecture_en.svg new file mode 100644 index 000000000..ea877091b --- /dev/null +++ b/docs/assets/generative_recommendation_integration_architecture_en.svg @@ -0,0 +1,54 @@ + + + + + + + + + Generative Recommendation Integration Architecture + + + Predictor Side + + Sparse feature processing + + Sample construction / feature assembly + + Online service integration + Recall, ranking, and request orchestration + + + Shared Library Integration (.so) + Model body subgraph handoff + Prompt / tokens / tensors + Context handoff + runtime options + Low-overhead path without extra RPC + + + xLLM Side + + LLM main-body inference + + Operators / KV cache / scheduling + + Candidate expansion and response generation + REC runtime path + + + + + subgraph offload + runtime execution + diff --git a/docs/assets/generative_recommendation_model_onerec.png b/docs/assets/generative_recommendation_model_onerec.png new file mode 100644 index 000000000..cc83d9785 Binary files /dev/null and b/docs/assets/generative_recommendation_model_onerec.png differ diff --git a/docs/assets/generative_recommendation_model_onetrans.png b/docs/assets/generative_recommendation_model_onetrans.png new file mode 100644 index 000000000..07dbb070a Binary files /dev/null and b/docs/assets/generative_recommendation_model_onetrans.png differ diff --git a/docs/assets/generative_recommendation_overview.png b/docs/assets/generative_recommendation_overview.png new file mode 100644 index 000000000..1320e1637 Binary files /dev/null and b/docs/assets/generative_recommendation_overview.png differ diff --git a/docs/assets/paged_attention_comparison.png b/docs/assets/paged_attention_comparison.png new file mode 100644 index 000000000..c68b38519 Binary files /dev/null and b/docs/assets/paged_attention_comparison.png differ diff --git a/docs/assets/xattention_kv_layout.png b/docs/assets/xattention_kv_layout.png new file mode 100644 index 000000000..ba797ab16 Binary files /dev/null and b/docs/assets/xattention_kv_layout.png differ diff --git a/docs/assets/xattention_kv_layout_en.svg b/docs/assets/xattention_kv_layout_en.svg new file mode 100644 index 000000000..3afcb4e5f --- /dev/null +++ b/docs/assets/xattention_kv_layout_en.svg @@ -0,0 +1,82 @@ + + + + + + + + Separated KV Cache: Shared KV and Unshared KV + Shared Cache + + + + + one physical copy for shared prefix + + Unshared Cache Across Decode Rounds + + + + + + Round 0 + + + + + + + Round 1 + + + + + + + Round 2 + + + + decode + decode + + Beam Expansion Tree + + + + + + + + + + a + b + c + d + + + 0 + 1 + 2 + 3 + 4 + + + + + + + Shared KV is kept once; unshared KV grows with decode rounds and beam choices. + diff --git a/docs/assets/xattention_three_stage_pipeline.png b/docs/assets/xattention_three_stage_pipeline.png new file mode 100644 index 000000000..8c6890b69 Binary files /dev/null and b/docs/assets/xattention_three_stage_pipeline.png differ diff --git a/docs/assets/xattention_three_stage_pipeline_en.svg b/docs/assets/xattention_three_stage_pipeline_en.svg new file mode 100644 index 000000000..b49e31357 --- /dev/null +++ b/docs/assets/xattention_three_stage_pipeline_en.svg @@ -0,0 +1,55 @@ + + + xAttention Three-stage Execution Pipeline + + + Shared Stage + + MCU + BatchMatmul + BatchMatmul + BatchMatmul + BatchMatmul + + VCU + Softmax + Softmax + + + Unshared Stage + + MCU + BatchMatmul + BatchMatmul + BatchMatmul + BatchMatmul + + VCU + Softmax + Softmax + + + Merging Stage + + MCU + Post-processing + Post-processing + + VCU + OnlineSoftmax + OnlineSoftmax + OnlineSoftmax + + Shared and unshared attention are computed separately, then merged with OnlineSoftmax. + diff --git a/docs/en/design/generative_recommendation_design.md b/docs/en/design/generative_recommendation_design.md new file mode 100644 index 000000000..f562e0a34 --- /dev/null +++ b/docs/en/design/generative_recommendation_design.md @@ -0,0 +1,1062 @@ +# Generative Recommendation Design Document + +## Overview + +xLLM provides generative recommendation inference through the `backend=rec` path. The goal is not to replace the existing recommendation system, but to reuse the LLM inference engine in the recommendation pipeline while keeping the original predictor-side sparse feature processing and online serving capabilities. In practice, the LLM body is executed by xLLM, while the traditional recommendation system continues to handle feature preparation and online integration. + +This document focuses on the following topics: + +- the goal and constraints of generative recommendation inference +- model structure and integration architecture +- why the REC path prefers fixed scheduling and whole-graph multi-step execution +- how `xAttention` and `beam search` cooperate around memory efficiency and execution efficiency +- where the core REC-related code is located in the current branch + +The design goals of this document are: + +- explain the `backend=rec` path with a unified view +- clarify the relation among fixed scheduling, multi-step execution, and custom kernels +- provide a stable base document for future technical sharing and code walkthroughs + +The non-goals of this document are: + +- full training details of recommendation models +- all online serving differences across business scenarios +- replacing detailed module-level API documents + +## 1. Background and Problem + +In recent years, LLM-based generative recommendation has made substantial progress. xLLM has also introduced support for recommendation inference. The goal of generative recommendation is not simply to attach an LLM to a recommendation system, but to use generative modeling to improve candidate expansion and ranking quality, especially metrics such as `CTR`. + +In the current solution, xLLM serves as the unified inference engine and is integrated into the existing prediction pipeline through a shared library (`.so`) interface: + +- the `predictor` side continues to handle sparse feature processing, sample construction, and online service integration; +- the `xLLM` side is responsible for the LLM-related inference computation. + +This split allows the original recommendation engineering capabilities to remain intact, while reusing xLLM's infrastructure in operators, KV cache management, multi-backend execution, and scheduling. + +However, generative recommendation and general LLM inference optimize for different targets. + +- General LLM inference cares more about token-by-token interaction quality, such as returning the first result quickly, reducing the interval between generated tokens, and allowing requests to be inserted or ended dynamically. +- Generative recommendation cares more about total request latency and obtaining better candidate results within a limited number of decoding rounds. + +The reason is straightforward: recommendation is usually not about generating an open-ended passage, but about expanding candidates, comparing candidates, and outputting the best result within a fixed number of rounds. + +![Generative recommendation overview](../../assets/generative_recommendation_overview.png) + +This path often uses `beam search`. In this context, beam search means that at each decoding round the system keeps multiple high-scoring candidate branches instead of only the best one, then keeps expanding and comparing them in later rounds. In recommendation, the purpose is not to generate longer text, but to cover more high-quality candidates within a small number of decoding rounds and improve the final recommendation quality. + +![Beam search in generative recommendation](../../assets/generative_recommendation_beam_search.png) + +Therefore, generative recommendation naturally has two characteristics: + +- fixed-step decoding; +- synchronized comparison of multiple candidates. + +In other words, the real optimization target is not “make one sequence finish earlier”, but “push multiple candidates forward stably in a small number of rounds and compare them efficiently at each round”. This leads directly to the later design choice: fixed scheduling at the control layer and whole-graph multi-step execution at the execution layer, followed by custom operator optimization on top of that stable execution shape. + +### 1.1 Workload characteristics: why GR is not just “another attention model” + +At a high level, generative recommendation still uses attention-based architectures, so it is easy to assume that its serving path can simply reuse the general LLM serving mindset. But the workload profile is very different. + +Generative recommendation typically has the following characteristics: + +- the prompt is long because it carries user history, context, and recommendation-side signals; +- the output is short because the system only needs a fixed-length item token sequence; +- the number of decode rounds is fixed and usually small; +- each decode round is still expensive because candidate expansion is often combined with a large `beam_width` and `top_k`. + +This forms a sharp contrast with general LLM inference: + +- general LLM inference is usually “short prompt + long output”; +- generative recommendation is much closer to “long prompt + short output”. + +So recommendation does not become cheap simply because the output is short. The decode phase is still expensive, and in many cases the expensive part is no longer only attention itself, but the system cost around candidate expansion and beam comparison. + +### 1.2 Three core challenges of this workload + +From the workload characteristics of generative recommendation itself, three challenge categories stand out clearly when compared with the general LLM inference path. + +#### Challenge 1: long-prompt / short-output does not mean decode is cheap + +Even though the decode length is fixed and small, each decode round may still be expensive. Shared-prefix reuse, repeated KV access across beams, and beam-related block movement all become more visible because the system is not amortizing them over a long free-form generation. + +#### Challenge 2: beam search is not only an algorithmic issue + +In generative recommendation, beam search is not just a decoding technique for “better text quality”. It is part of the recommendation search process itself. Once `beam_width` and `top_k` increase, sorting, filtering, valid-item checking, candidate retention, and data-structure reuse all become system-level concerns. + +#### Challenge 3: the bottleneck is also in host-device cooperation + +The system usually runs under strict online latency constraints and high concurrency. If the host still comes back at every step to decide whether to continue, prepare the next input, and resend control to the device, the host-side control path itself becomes a major part of the latency budget. + +### 1.3 The role of this design document + +This document does not try to reproduce an external paper or replace an experiment section. Its role is more practical: it brings the workload characteristics, the system-level problems, and the concrete implementation paths of the current branch into one place, so that later talks, reviews, and code walkthroughs can all use the same technical base. + +## 2. Inference Architecture + +### 2.1 Model Structure + +Generative recommendation has become one of the most important paradigm shifts in modern recommendation systems. It breaks the traditional recall-ranking-rerank cascade and pushes the task from discriminative matching toward generative prediction. This document focuses on two model families that have shown strong quality and have been deployed at scale: OneRec for recall and OneTrans for ranking. + +![OneRec model structure](../../assets/generative_recommendation_model_onerec.png) + +![OneTrans model structure](../../assets/generative_recommendation_model_onetrans.png) + +A shared pattern across these models is that they keep the traditional recommendation signals such as user sequence features, user static features, and context features, then use an input adaptation layer to map heterogeneous recommendation inputs (discrete IDs, continuous values, sequences, and multimodal content) into embeddings that can be consumed by the LLM decoder. The model body itself is still an Encoder+Decoder or Decoder-only LLM structure, which means different parts of the model should be handled by different inference engines. + +### 2.2 Inference Integration Architecture + +Based on the model structure, the current solution splits the model into two groups: + +- the input adaptation layer still belongs to the traditional CTR-style inference domain and is handled by the `predictor` side; +- the LLM body is handled by xLLM. + +As the core LLM inference engine, xLLM provides two integration modes for generative recommendation: RPC integration and shared library (`.so`) integration. + +#### 2.2.1 RPC Integration + +The current marketing and online recall scenarios mainly use the RPC-based integration mode. Its advantage is a clean service boundary, while the downside is the extra RPC overhead. + +#### 2.2.2 Shared Library Integration + +Another mode is to embed xLLM into the predictor side as an internal inference engine for the LLM subgraph. This avoids RPC round trips and is more suitable for future low-latency scenarios. + +![Generative recommendation integration architecture](../../assets/generative_recommendation_integration_architecture_en.svg) + +## 3. Fixed Scheduling and Graph-style Execution + +### 3.1 Fixed-Step Scheduling + +![Background of Orca continuous batching](../../assets/fixed_steps_scheduler_orca.png) + +The figure above comes from the paper *Orca: A Distributed Serving System for Transformer-Based Generative Models*. It explains why continuous batching is useful for general text generation: it avoids idle compute caused by rigid fixed-batch execution. + +Generative recommendation changes that assumption because the decoding length is fixed and the candidate set needs to move forward synchronously. As a result, the REC path is more suitable for `fixed_steps_scheduler` than for continuous batching. The reason is not simply “fixed rounds imply fixed scheduling”. The deeper reason is that the workload itself is organized as a fixed number of rounds. When requests usually finish in a predefined number of rounds and multiple candidate branches must move forward together, the scheduler should focus on sending one stable candidate group efficiently instead of inserting and removing requests at every step. + +The first benefit of `fixed_steps_scheduler` is that it matches `beam search` better. In the decode stage, `beam width` is often large and multiple beams need to move and be compared in the same round. If continuous scheduling is used, every step may trigger batch rebuilding, sequence compaction, index remapping, and state pruning. Those operations make sense for general LLM inference because requests really do end dynamically. In generative recommendation, however, they are often additional cost instead of real value. Under fixed scheduling, the beam group of one request can move forward together inside one fixed window, without repeated batch rebuilding and repeated pruning decisions. + +The second benefit is execution stability. Once the number of rounds, the beam-group size, and the advancing rhythm all become stable, many later optimizations become possible. Buffers can be allocated early, workspace can be reused more easily, and cache access patterns become more regular. This stability makes profiling and capacity planning easier and allows the execution path to be made much more stable. + +![PagedAttention background and fixed-step comparison](../../assets/paged_attention_comparison.png) + +The third benefit is that it reduces overhead outside the real model computation. In recommendation inference, the major cost should belong to candidate expansion, attention computation, and beam comparison. But if the scheduler keeps participating in sequence reordering, batch rebuilding, metadata refresh, and index movement at every step, then extra overhead is introduced even though it is not part of the actual operator work. In this sense, fixed scheduling trades stronger execution determinism for higher throughput, lower scheduling overhead, and more stable runtime behavior. + +Of course, fixed scheduling also has a clear cost: a new request may wait longer. Under continuous scheduling, a new request may have a chance to be inserted after only one step. Under fixed scheduling, it often has to wait until the current fixed window finishes. This leads to a more visible queueing delay. The mitigation direction is not to go back to continuous scheduling, but to introduce `multi-stream` execution. The idea is to decouple the large request groups already running inside a fixed window from newly arriving small request groups and place them on different streams or execution channels. The purpose is not to eliminate waiting entirely, but to keep the throughput advantage of fixed scheduling while reducing the extra access latency of new requests. + +### 3.2 Graph-style Multi-Step Execution + +On top of fixed scheduling, `multi_step_pipeline` becomes the natural execution-side companion. It solves an execution-efficiency problem. Once we know the workload itself always runs a fixed number of rounds and does not usually end early, then there is no need to involve the host at every step. There is no need to perform a `D2H` synchronization every round just to ask whether the batch is finished, and there is no need to perform another `H2D` transfer every round just to prepare the next round's input. A more efficient design is to prepare the space, indices, and data structures needed by later rounds at the first step, and then let the device continue advancing through the later rounds. + +This brings several direct benefits: + +- fewer `D2H/H2D` round trips and less host participation; +- lower launch and control overhead at every round; +- better device-side data reuse because more intermediate data stays on device; +- a more pipeline-like execution flow instead of a repeated stop-prepare-run cycle. + +For a fixed-round recommendation workload, this is clearly more efficient than returning to the host after every step. + +Another important but often underestimated benefit of `multi_step_pipeline` is that it creates a better execution environment for custom operators. This is the point where `xAttention` and `beam search` custom kernels can be discussed together. `fixed step` solves scheduling stability, while whole-graph multi-step execution plus custom kernels solves execution efficiency. + +## 4. Memory Management and Operator Co-optimization + +### 4.1 Compute and Memory Bottlenecks + +#### 4.1.1 Model Input and Output Characteristics + +In the current generative recommendation setup, an item ID is represented by a fixed-length token sequence. As a result, `decode_step` is a known small constant, for example 3. One request can be summarized as: + +- one prefill stage with a long user-history context; +- `decode_step` rounds of decode, where each round generates one token and the final token sequence is combined into an item ID. + +Even if the number of decode rounds is small, the per-step cost is not small. To improve recall and diversity, generative recommendation often uses a large `beam_width`. In addition, each beam may expand to `top_k` candidates, and the system then selects the new beam set from a global candidate pool of `beam_width × top_k`. For example, when `beam_width=512` and `top_k=512`, the candidate pool size of one decode round reaches 262144 (about 2.6×10^5). So although the number of rounds is limited, the search and KV access cost per step is still considerable. + +#### 4.1.2 Storage Redundancy and Memory Fragmentation + +The main bottlenecks of this inference service can be summarized into two groups, and `xAttention` is designed around both of them. + +The first is redundant bandwidth consumption in attention. Shared prefixes are not explicitly represented as reusable structures. When beam width is large, all beams share the same long prompt, but a generic implementation often organizes KV as if every beam were a full independent sequence. This causes shared KV to be loaded repeatedly along the beam dimension and reduces the effective arithmetic intensity of the attention kernel, eventually making the path bandwidth-bound. + +The second is KV cache copying and fragmentation. Beam search frequently forks and retires branches, which leads to beam reordering. For a block-based KV management scheme such as PagedAttention, “reordering + block alignment” often implies block copying, fragmentation, and extra memory waste. Both memory capacity and memory bandwidth get amplified in the wrong direction. + +### 4.2 `xAttention` Design Principle + +#### 4.2.1 KV Cache Layout Optimization + +![xAttention KV cache layout](../../assets/xattention_kv_layout_en.svg) + +Given the fixed structure of generative recommendation inference, xAttention redesigns both KV cache organization and the attention execution strategy. The shared prefix is stored only once at the physical-memory level, while beam branching and reordering no longer cause high-cost copying. + +The first key idea is to split KV cache into two groups: + +- **Shared KV**: prompt KV generated in prefill and shared by all beams; +- **Unshared KV**: newly generated KV in decode, managed at token granularity for each beam. + +Once the cache is split this way, Unshared KV only stores the decode-generated tokens, which avoids both block copying and unnecessary memory waste. + +#### 4.2.2 Attention Compute Optimization + +![xAttention three-stage execution](../../assets/xattention_three_stage_pipeline_en.svg) + +To avoid concatenating Shared KV and Unshared KV into one long logical sequence, xAttention splits one attention computation into three stages: + +1. **shared stage**: compute local softmax statistics and partial outputs only on Shared KV; +2. **unshared stage**: compute local statistics and partial outputs only on Unshared KV; +3. **merge stage**: use OnlineSoftmax to merge the two parts stably. + +At the parallelization level, shared, unshared, and merge are assigned to different execution units or queues to form a pipeline. The goal is to overlap Shared and Unshared computation as much as possible while minimizing synchronization points. + +### 4.3 Treating beam search as a system problem + +If `xAttention` addresses the question “how can attention reuse shared context and avoid wasteful KV movement”, then this section addresses a different but equally important question: “how can beam search avoid turning recommendation decode into a sorting-heavy control bottleneck”. + +From a systems perspective, beam search in recommendation carries several cost layers at once: + +- each round must select a new beam set from a large candidate pool; +- not every token combination corresponds to a valid item; +- old candidates are discarded while new candidates are constantly introduced; +- if the implementation rebuilds data structures and performs full sorting every round, the overhead grows quickly. + +So this part should be understood less as “one sorting kernel” and more as a system-level optimization strategy around beam search. Its goals include: + +- terminate unnecessary sorting work as early as possible; +- filter invalid item paths as early as possible; +- reuse data structures across rounds instead of reconstructing them repeatedly. + +For a technical talk, the key message here is that beam search is not an accessory cost in recommendation serving. It is part of the main decode cost, and therefore has to be designed together with scheduling, memory layout, and hot operator paths. + +### 4.4 A system-level view of overlap and parallelism + +The third challenge is not that one operator is slow, but that the whole serving pipeline is not naturally shaped for the recommendation workload. In this document, this refers to the system-level effort to maximize overlap across host-side scheduling, engine-side execution, worker-side multi-round progression, and stream-level parallelism. + +At least three ideas belong to this layer: + +1. **clearer host-device role split** + - the host should participate less in every round; + - the device should carry more of the fixed-round progression directly. + +2. **more pipeline-friendly execution** + - while the current round is being executed, next-round inputs should already be under preparation; + - scheduler, batch builder, and worker should avoid introducing unnecessary idle boundaries. + +3. **multi-stream and multi-pipeline concurrency** + - a fixed execution window naturally increases waiting cost for new requests; + - but that cost can be partially offset through multiple streams or multiple execution pipelines. + +So, in the context of this document, this section is best understood as the system layer that combines fixed-step stability, device-side multi-step progression, and multi-stream overlap into one serving strategy. + +### 4.5 Why this design direction is worth the effort + +From the workload characteristics of generative recommendation, treating it as a special serving path is not over-engineering. It has clear system-level value. + +The reason is that this workload combines several properties that do not usually appear together in the general LLM path: + +- the input is long while the output is short; +- the decode rounds are fixed, but one decode round is still expensive; +- both `beam_width` and candidate-pool size are non-trivial; +- host-side control, batch rebuilding, and data movement can easily dominate the latency budget. + +As a result, simply reusing the general path tends to preserve unnecessary system cost. Once fixed-step scheduling, multi-step execution, shared/unshared KV design, beam search handling, and multi-stream overlap are considered together, the gain can appear at multiple levels at once: lower scheduling cost, less host participation, a cleaner memory layout, and a more stable hot path for custom operators. + +This also gives the technical talk a strong closing point for this section: the reason to separate recommendation from the general LLM path is not just conceptual neatness, but the fact that recommendation decode really creates a different serving problem, and the payoff of handling it explicitly can be significant. + +## 5. Code Layout + +The current branch organizes the generative recommendation path as follows. + +External integration: +- `xllm/c_api/rec.h` +- `xllm/c_api/internal/rec.cpp` +- `xllm/c_api/examples/simple_rec_completions.cpp` + +Service entry: +- `xllm/api_service/rec_completion_service_impl.cpp` +- `xllm/api_service/chat_service_impl.cpp` +- `xllm/api_service/api_service.cpp` +- `xllm/api_service/api_service.h` + +Scheduling and engine: +- `xllm/core/distributed_runtime/rec_master.cpp` +- `xllm/core/distributed_runtime/rec_master.h` +- `xllm/core/scheduler/fixed_steps_scheduler.cpp` +- `xllm/core/scheduler/fixed_steps_scheduler.h` +- `xllm/core/distributed_runtime/rec_engine.cpp` +- `xllm/core/distributed_runtime/rec_engine.h` + +Batch / request / proto: +- `xllm/core/framework/batch/rec_batch_input_builder.cpp` +- `xllm/core/framework/batch/rec_batch_input_builder.h` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.h` +- `xllm/core/framework/request/rec_type.h` +- `xllm/proto/rec.proto` +- `xllm/proto/completion.proto` +- `xllm/proto/xllm_service.proto` + +Runtime / worker: +- `xllm/core/runtime/rec_worker_impl.cpp` +- `xllm/core/runtime/rec_worker_impl.h` + +Kernel hot path: +- `xllm/core/layers/cuda/xattention.cpp` +- `xllm/core/layers/cuda/flashinfer_attention.cpp` +- `xllm/core/kernels/cuda/xattention/beam_search.cpp` +- `xllm/core/kernels/cuda/xattention/cache_select.cu` + +## 6. Current-branch Execution Flow + +To align the design with the actual implementation, the current branch can be understood as the following execution chain: + +1. **External entry** + - In shared-library mode, requests enter from `xllm/c_api/internal/rec.cpp` through `xllm_rec_text_completions`, `xllm_rec_token_completions`, or `xllm_rec_chat_completions`. + - In service mode, requests enter from `xllm/api_service/rec_completion_service_impl.cpp` or `chat_service_impl.cpp`, and are then forwarded to `RecMaster`. + +2. **Request convergence in `RecMaster`** + - `RecMaster` unifies different request forms such as prompt input, token input, and raw embedding input. + - It also distinguishes `kOneRec` and `kLlmRec`, then selects the corresponding request pipeline. + +3. **Entering fixed scheduling** + - `RecMaster` creates `FixedStepsScheduler` directly during initialization. + - Instead of rebuilding decode batches dynamically at every step, the scheduler is centered on a fixed number of rounds and a stable candidate group. + +4. **Engine execution** + - `RecEngine` then selects an execution path according to `RecPipelineType`. + - For the `LlmRec` multi-round scenario, execution is routed into `RecMultiRoundEnginePipeline`, which pushes more decode control logic down toward the worker side. + +5. **Batch building and input construction** + - `RecBatchInputBuilder` and `RecMultiRoundBatchInputBuilder` organize sequences, step information, decode positions, sampling parameters, and other metadata into `ForwardInput`. + - `step_meta` is especially important here because it provides the per-round information needed for multi-step execution. + +6. **Multi-round execution inside the worker** + - `RecWorkerImpl::LlmRecMultiRoundPipeline::step()` performs a device-side round loop. + - For each round, it coordinates: + - current-round input preparation + - model forward + - sample processing + - beam search + - cache selection + - preparation for the next round + +7. **Hot operator path** + - Attention-related execution goes through `xattention.cpp` and `flashinfer_attention.cpp` + - Beam-related selection goes through `beam_search.cpp` + - Cache selection after beam reordering goes through `cache_select.cu` + +From a technical-sharing perspective, this chain is a useful narrative backbone because it connects fixed scheduling, whole-graph multi-step execution, and custom kernels into one coherent execution story instead of presenting them as isolated optimizations. + +## 7. Trade-offs and Applicability Boundaries + +This design does not mean that fixed scheduling is always better than continuous scheduling, nor does it mean that multi-step pipeline execution fits every generation task. It works well because the current generative recommendation workload has a few specific characteristics: + +- the number of decode rounds is relatively fixed and requests usually do not terminate early like open-ended text generation; +- one request often carries a large `beam_width`, and multiple beams must be compared synchronously; +- total request latency matters more than token-by-token interactivity; +- device-side state such as KV cache, positions, and beam tensors can be prepared and reused in a relatively stable way. + +When those conditions hold, fixed scheduling and multi-step execution provide clear benefits. At the same time, they also have clear boundaries: + +### 7.1 Boundary of fixed scheduling + +- If request lengths vary dramatically and a large number of requests end early, continuous scheduling regains its advantage. +- If the business priority is to insert new requests as soon as possible, fixed scheduling naturally becomes less attractive. +- If candidate expansion no longer depends on synchronized beam progression, the benefit of a fixed execution window decreases. + +### 7.2 Boundary of `multi_step_pipeline` + +- If each step still requires strong host-side decisions, the value of device-side multi-step execution becomes smaller. +- If shape, batch layout, or key inputs change significantly at every round, it becomes much harder to keep a stable whole-graph execution path. +- If backend operators themselves are not ready for stable multi-round device-side execution, whole-graph execution will remain only a conceptual idea. + +### 7.3 Boundary of custom operators + +`xAttention` and custom `beam search` kernels are valuable because the execution shape is already stable enough. Without fixed-step scheduling and multi-step device-side progression, many operator-level optimizations would be diluted by repeated data movement, batch rebuilding, and extra host participation. + +So the more accurate order of reasoning is: + +1. first confirm that the workload truly fits fixed scheduling; +2. then confirm that multi-round execution can be pushed down to the device side; +3. only then optimize the stable hot path with custom operators. + +This is what makes `fixed_steps_scheduler`, `multi_step_pipeline`, `xAttention`, and `beam search` a coherent design stack rather than four unrelated optimization tricks. + +## 8. Code Path Appendix + +The purpose of this appendix is not to repeat the file list in the previous section, but to provide a practical reading order. If someone later wants to expand this document, prepare a technical talk, or walk through the implementation, the following path is a more useful starting point. + +### 8.1 Start from the external entry + +If the goal is to understand how `predictor` or the shared-library integration enters xLLM, start with: + +- `xllm/c_api/rec.h` + - exposes `xllm_rec_create`, `xllm_rec_initialize`, `xllm_rec_text_completions`, `xllm_rec_token_completions`, and `xllm_rec_chat_completions` + - useful for understanding what the REC-facing external contract looks like +- `xllm/c_api/internal/rec.cpp` + - the actual CAPI implementation + - useful for seeing how request parameters are wrapped and forwarded in `.so` mode +- `xllm/c_api/examples/simple_rec_completions.cpp` + - the smallest runnable example + - useful when explaining what a dynamic-library integration looks like in practice + +If the technical-sharing version needs one “minimal usage example”, this is the best layer to cite first. + +### 8.2 Then read the service entry layer + +If the focus is on RPC mode or the unified service path, continue with: + +- `xllm/api_service/api_service.cpp` + - dispatches service implementations according to `FLAGS_backend` + - under `backend == "rec"`, both `rec_completion_service_impl_` and `chat_service_impl_` are attached +- `xllm/api_service/rec_completion_service_impl.cpp` + - forwards REC completion requests into `RecMaster` + - this is where `routing`, `input_tensors`, and `RequestParams` are assembled together +- `xllm/api_service/chat_service_impl.cpp` + - also provides a chat entry for `RecMaster` + - useful to show that REC is not limited to token completion only + +This layer is useful in a technical talk when explaining that `backend=rec` is not a side path outside the service framework, but a first-class backend integrated into the existing service entry. + +### 8.3 Then read scheduling and engine as one chain + +If the main story is “why fixed step is a better fit for REC”, a practical reading order is: + +1. `xllm/core/distributed_runtime/rec_master.h` +2. `xllm/core/distributed_runtime/rec_master.cpp` +3. `xllm/core/scheduler/fixed_steps_scheduler.h` +4. `xllm/core/scheduler/fixed_steps_scheduler.cpp` +5. `xllm/core/distributed_runtime/rec_engine.h` +6. `xllm/core/distributed_runtime/rec_engine.cpp` + +The chain can be understood as: + +```text +Rec request + -> RecMaster + -> FixedStepsScheduler + -> RecEngine + -> RecEnginePipeline + -> Worker / worker_clients +``` + +The most important takeaways in this layer are: + +- `RecMaster` converges multiple request forms and chooses the request pipeline +- `FixedStepsScheduler` organizes requests into a batch shape that matches fixed-round progression +- `RecEngine` bridges the scheduler output to the actual execution path +- `RecMultiRoundEnginePipeline` is the concrete code path that represents “push multi-round decode control further down” + +If the talk needs to prove that “fixed step is not just an idea but the actual code path”, this is the core evidence layer. + +### 8.4 Read batch builders to understand why multi-step execution is possible + +To explain why `multi_step_pipeline` is possible, it is not enough to stop at the scheduler. The batch builder layer is equally important: + +- `xllm/core/framework/batch/rec_batch_input_builder.h` +- `xllm/core/framework/batch/rec_batch_input_builder.cpp` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.h` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp` +- `xllm/core/framework/batch/batch.cpp` + +The key points here are: + +- `RecBatchInputBuilder::create(...)` chooses different builders according to `RecType` and multi-round mode +- `RecMultiRoundBatchInputBuilder` is not just a small variation of the default builder, but a dedicated implementation for multi-round decode input construction +- `step_meta`, `decode_positions`, `sampling params`, and `batch forward type` are assembled here before being sent further into runtime + +So if the document or talk wants to explain “why later rounds can already be prepared at the first step”, this layer is more important than only talking about the engine loop. + +### 8.5 Read worker-side multi-round execution next + +The most valuable implementation details of `multi_step_pipeline` are inside: + +- `xllm/core/runtime/rec_worker_impl.h` +- `xllm/core/runtime/rec_worker_impl.cpp` + +Especially these parts: + +- `RecWorkerImpl::step_async(...)` + - shows how work is dispatched into the worker and onto the target stream +- `RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_inputs(...)` + - shows how multi-round inputs enter runtime +- `RecWorkerImpl::LlmRecMultiRoundPipeline::allocate_kv_caches_related()` + - shows why fixed-step execution is friendly to early KV-related allocation +- `RecWorkerImpl::LlmRecMultiRoundPipeline::step(...)` + - this is the single most important function for multi-round execution + - it makes the relationship among round loop, beam search, cache select, and next-round preparation explicit +- `compute_next_round_input_async(...)` + - this is the key function when explaining why host round-trips can be reduced + +If a technical talk wants to emphasize execution efficiency rather than scheduling policy, this is the best layer to expand. + +### 8.6 Finally read the hot operator path + +If the goal is to explain why `xAttention` and custom `beam search` kernels matter, the recommended reading order is: + +- `xllm/core/layers/cuda/xattention.cpp` +- `xllm/core/layers/cuda/flashinfer_attention.cpp` +- `xllm/core/kernels/cuda/xattention/xattention_ops_api.h` +- `xllm/core/kernels/cuda/xattention/beam_search.cpp` +- `xllm/core/kernels/cuda/xattention/cache_select.cu` + +This layer helps answer three questions: + +1. where the stable attention execution path actually lives +2. why beam search is not only a scheduling topic but also a hot kernel path +3. why cache selection after beam reordering must be discussed together with execution shape + +In other words, if the technical-sharing narrative wants to connect `fixed_steps_scheduler`, `multi_step_pipeline`, `xAttention`, and `beam search` into one coherent story, it eventually has to land here. + +### 8.7 Recommended reading order + +If this document is expanded further later, the recommended code-reading order is: + +```text +entry + -> api_service + -> RecMaster + -> FixedStepsScheduler + -> RecEngine + -> RecBatchInputBuilder / RecMultiRoundBatchInputBuilder + -> RecWorkerImpl::LlmRecMultiRoundPipeline + -> xAttention / beam_search / cache_select +``` + +The advantage of this order is that it matches how people usually understand the system: + +- first, how the request enters the system +- then, why REC prefers fixed-step scheduling +- then, how multi-step execution is actually pushed down to the device side +- finally, why custom operators become valuable only after the execution shape becomes stable + +This makes the logic easier to present in a technical talk and easier to reuse in future documentation work. + +## 9. Key Code Anchor Index + +If this design document is later extended into a technical talk, a PR explanation, or a deeper code walkthrough, the following anchors are the fastest way to connect the document back to the current branch implementation. + +### 9.1 Entry and service-layer anchors + +- `xllm/core/distributed_runtime/rec_master.cpp:575` + - `RecMaster::handle_request(...)` + - prompt / prompt_tokens / input_tensors entry +- `xllm/core/distributed_runtime/rec_master.cpp:603` + - `RecMaster::handle_request(...)` + - chat-message entry +- `xllm/core/distributed_runtime/rec_master.cpp:651` + - `RecMaster::handle_request(const std::vector& prompt_tokens, ...)` + - token / raw-input style entry + +This group is useful when the talk needs to answer a simple question first: how exactly does a REC request get converged into `RecMaster`. + +### 9.2 Scheduling anchors + +- `xllm/core/scheduler/fixed_steps_scheduler.cpp:337` + - `FixedStepsScheduler::step(const absl::Duration& timeout)` + - the actual scheduling step that advances the fixed-step path +- `xllm/core/scheduler/fixed_steps_scheduler.cpp:186` + - `FixedStepsScheduler::prepare_batch()` + - useful for explaining how requests are grouped under fixed scheduling +- `xllm/core/framework/batch/rec_batch_input_builder.cpp:29` + - `RecBatchInputBuilder::create(...)` + - useful for explaining how builders are selected according to `RecType` and multi-round mode + +Taken together, these anchors are strong evidence that fixed-step scheduling is not just a conceptual preference but the actual scheduling choice in the code path. + +### 9.3 Engine and multi-round execution anchors + +- `xllm/core/distributed_runtime/rec_engine.cpp:901` + - `RecEngine::RecMultiRoundEnginePipeline::step(...)` + - shows how engine-side execution is pushed into the multi-round pipeline +- `xllm/core/runtime/rec_worker_impl.cpp:849` + - `RecWorkerImpl::LlmRecMultiRoundPipeline::step(...)` + - the single most important multi-round execution function on the worker side +- `xllm/core/runtime/rec_worker_impl.cpp:1011` + - call site of `xllm::kernel::cuda::beam_search(...)` +- `xllm/core/runtime/rec_worker_impl.cpp:1066` + - call site of `xllm::kernel::cuda::cache_select(...)` + +If the talk needs to make it explicit that beam search and cache select are not only “scheduling concepts” but hot device-side execution paths, this is the right group of anchors to cite. + +### 9.4 How to use this anchor list + +This list does not need to be repeated in the main narrative, but it is especially useful as: + +- a “code evidence” slide in a technical talk +- a “key implementation locations” section in a PR description +- a quick response set when a reviewer asks where a specific statement comes from + +If the document is extended further later, these anchors are also the most practical starting points for adding deeper function-level explanations. + +## 10. Suggested Talk Order and Reading Strategy + +The previous sections, especially the code-path appendix and the key-code anchor list, are useful as a reference base. But if this document is later turned into a real technical talk, it also needs a more presentation-friendly reading order. + +### 10.1 Recommended talk order + +If the audience is not directly involved in implementing `backend=rec`, the talk should not start from `RecWorkerImpl` or `beam_search.cpp`. A more effective order is: + +1. **Start from the business goal** + - explain why generative recommendation and general LLM inference optimize for different targets + - explain why recommendation cares more about end-to-end request latency and fixed-round candidate comparison + +2. **Then explain the scheduling choice** + - why `fixed_steps_scheduler` is a better fit for this workload + - why a large `beam_width` makes fixed scheduling more valuable than continuous scheduling + +3. **Then explain the execution model** + - why `multi_step_pipeline` can push multi-round decode control further down to the device side + - why this reduces host round-trips and control overhead + +4. **Then explain custom operators** + - why `xAttention` and `beam search` optimizations become meaningful only after the execution shape becomes stable + - why those two topics should be presented together rather than as isolated tricks + +5. **Finally return to the code** + - use the code path and key anchors to prove the previous conclusions + +This order works well because the audience first understands the design motivation, then the execution strategy, and only after that the implementation evidence. + +### 10.2 For internal code walkthroughs, the order can be reversed + +If the audience already works on xLLM or recommendation infrastructure, the order can be more implementation-driven: + +1. start from `RecMaster -> FixedStepsScheduler -> RecEngine` +2. then `RecBatchInputBuilder` +3. then `RecWorkerImpl::LlmRecMultiRoundPipeline` +4. finally `xAttention / beam_search / cache_select` + +This path is more efficient for an internal walkthrough because it follows the actual call stack. The downside is that it pushes new readers into implementation detail before they have the design context. + +### 10.3 The three takeaways worth remembering + +If the document is later compressed into a short technical presentation, the whole design can be summarized into three lines: + +- `fixed_steps_scheduler` solves scheduling stability; +- `multi_step_pipeline` solves multi-round execution efficiency; +- `xAttention` and `beam search` custom kernels turn that stable execution shape into real performance gains. + +Those three lines are the most important part to remember. The rest of the document can be seen as evidence supporting them. + +## 11. Comparison with the General LLM Inference Path + +To avoid treating `backend=rec` as “just another small variation of LLM inference”, it is useful to compare it explicitly with the general LLM inference path. + +### 11.1 Different optimization targets + +General LLM inference is usually optimized for token-by-token generation experience: return the first result quickly, keep the token interval small, and allow new requests to enter as early as possible. +`backend=rec`, in contrast, is optimized for candidate expansion and comparison within a fixed number of rounds, so the primary concern becomes end-to-end request latency rather than the earliest completion of one sequence. + +That means both paths perform decode, but the target is already different: + +- general LLM inference is more about dynamic request management +- generative recommendation is more about synchronized progression inside a fixed execution window + +### 11.2 Different scheduling focus + +In general LLM inference, continuous scheduling is valuable because: + +- some sequences can end early and free space immediately +- batches can be rebuilt continuously +- dynamic insertion and dynamic exit are frequent and meaningful + +In `backend=rec`, the picture is different: + +- the number of decode rounds is more fixed +- `beam_width` is larger +- multiple candidates under the same request must be compared in the same round +- frequent batch rebuilding tends to amplify scheduling cost rather than reduce it + +So the key difference is not simply “fixed-step vs continuous”, but “stability-first scheduling” versus “flexibility-first scheduling”. + +### 11.3 Different execution model + +General LLM inference often allows the host to continue participating at each decode step, including step completion checks and next-step preparation. +Generative recommendation, because of its fixed-round and synchronized-candidate nature, benefits more from preparing multi-round structures early and letting the device continue forward. + +That is why `multi_step_pipeline` is especially valuable here: it is not only about removing a few memcpy calls, but about replacing a host-driven step-by-step control model with a device-side multi-round progression model. + +### 11.4 Different way to realize operator-level gains + +In general LLM inference, operator optimization often directly improves one decode step or the generic attention path. +In `backend=rec`, the gain of operator optimization depends much more on whether the execution shape has already been stabilized. + +If fixed scheduling is not established and batches are still rebuilt frequently, or if multi-step execution is not established and the host still repeatedly intervenes, a large part of the operator-level gain will be diluted by extra data movement and control overhead. +That is why the more accurate order in recommendation is: + +1. stabilize scheduling first +2. push multi-round execution to the device side +3. optimize the stable hot path with `xAttention`, `beam search`, and `cache_select` + +### 11.5 Design boundaries and applicability conditions + +The purpose of this comparison is not to claim that generative recommendation and general LLM inference are completely unrelated. The real purpose is to make it explicit which parts can be reused from the general path and which parts must be redesigned around the `backend=rec` workload. + +From the analysis above, `fixed_steps_scheduler`, `multi_step_pipeline`, `xAttention`, and `beam search` need to be discussed together because they depend on the same preconditions: + +- output length is fixed or approximately fixed; +- the number of decode rounds is small, but the cost of one round is still high; +- one request maintains a relatively large `beam_width`; +- candidate expansion and candidate comparison are part of the main serving path rather than optional post-processing; +- device-side state can be prepared early and reused across later rounds. + +Only when these conditions hold at the same time do fixed scheduling and whole-graph multi-step execution deliver stable gains. Once these conditions disappear, for example because requests vary dramatically in length, many sequences terminate early, or new-request latency matters more than throughput, the current `backend=rec` design may no longer be the best choice. + +It is also important to note that these four parts are not parallel ideas. They form a layered dependency: + +The first layer is `fixed_steps_scheduler`. It stabilizes the scheduling entrance. If the scheduler still rebuilds batches and prunes sequences at every step, the later execution path cannot become stable. + +The second layer is `multi_step_pipeline`. It stabilizes the progression of multi-round execution. Only after the scheduling window is stable does it become meaningful to push input preparation, KV organization, and round progression further down to the device side. + +The third layer is `xAttention`. It addresses the memory and bandwidth behavior of attention in the recommendation workload. It only pays off consistently when execution itself has already become stable enough. + +The fourth layer is `beam search` optimization. It reduces search and filtering cost under a large candidate space, but because beam logic directly participates in the main decode path, it cannot be treated as an isolated post-processing step. + +In short, the current branch should be understood not as four independent optimizations, but as one layered design shaped by the recommendation workload itself. + +### 11.6 Scenarios where the benefit is limited + +To avoid treating the current design as a universal answer, its boundaries should also be stated explicitly. + +The first category is where fixed scheduling is not ideal. If request lengths vary dramatically, many sequences terminate early, or the system cares more about the earliest insertion of a new request than about total batch throughput, continuous scheduling may become more attractive again. + +The second category is where whole-graph multi-step execution is hard to sustain. If each round still requires strong host-side decisions, or if shape, batch layout, or key inputs change significantly from round to round, then the benefits of `multi_step_pipeline` become weaker. + +The third category is where operator-level gains are limited. If beam size is small, candidate expansion is cheap, or the shared prefix is short, the engineering payoff of `xAttention` and beam-related optimization may still exist, but it will be much smaller than in a typical recommendation-serving workload. + +Writing down these boundaries is useful for two reasons: it prevents readers from assuming unconditional applicability, and it gives reviewers a clearer way to distinguish between design assumptions and implementation defects. + +## 12. Verification and Acceptance Guidance + +Verification should not stop at “the document renders correctly”. It should also answer whether the implementation in the current branch actually supports the design described in this document. + +### 12.1 Document-level verification + +The most basic verification still matters: + +- both the Chinese and English design documents render correctly; +- all image references resolve to actual files under `docs/assets/`; +- the English document points to English diagrams instead of Chinese-labeled figures; +- the entry pages and related feature pages can navigate to this design document. + +This level of verification ensures that the document itself is consumable. + +### 12.2 Code-path consistency verification + +The second level is to verify that the key paths described in the document really exist in the current branch and that their roles match the implementation: + +- whether `RecMaster` truly converges requests and selects the proper pipeline; +- whether `FixedStepsScheduler` is actually the fixed scheduling entry for `backend=rec`; +- whether `RecEngine` is really the engine-side execution organizer; +- whether `RecWorkerImpl::LlmRecMultiRoundPipeline` is really the device-side multi-round execution path; +- whether the paths of `xAttention`, `beam_search`, and `cache_select` match the design description. + +This level prevents the document from becoming a “target architecture note” that no longer matches the branch. + +### 12.3 Verification items tied to design goals + +If this document is used as a stable base for implementation discussion or technical sharing, the verification items should also map directly back to the design goals. + +For fixed scheduling, useful checks include: + +- whether decode still rebuilds batches frequently; +- whether beam-related sequences still trigger heavy pruning and movement every round; +- whether scheduler overhead is actually reduced in profiling. + +For `multi_step_pipeline`, useful checks include: + +- whether the host still decides termination at every round; +- whether next-round inputs can already be prepared while the current round is executing; +- whether `D2H/H2D` round trips are reduced relative to a host-driven round-by-round path. + +For `xAttention`, useful checks include: + +- whether shared KV is really stored once at the physical-storage level; +- whether unshared KV really carries only decode-generated tokens; +- whether repeated loading of the shared prefix is reduced. + +For `beam search`, useful checks include: + +- whether sorting and filtering still dominate the cost under large `beam_width`; +- whether beam-related data structures are reused across rounds; +- whether item filtering and candidate screening are already integrated with the main decode path. + +### 12.4 Final acceptance criteria + +If this document is treated as the design baseline of the current branch, then at minimum it should satisfy the following: + +- document level: bilingual readability, complete references, and reachable navigation; +- structure level: the relationship among scheduling, execution, and custom operators is clearly described; +- code level: key paths, key functions, and key files can be matched to the branch; +- design level: assumptions, boundaries, and trade-offs are stated clearly; +- sharing level: a reader can extract a talk outline directly from the document without rebuilding the logic from scratch. + +Only when those conditions are met does the document become a stable base shared by design discussion, technical sharing, and code review. + +## 13. FAQ / Common Misconceptions + +### Q1: Does fixed-step scheduling always outperform continuous scheduling? + +No. Fixed-step scheduling is a strong fit for the current `backend=rec` workload because the decode rounds are more fixed, `beam_width` is larger, and candidate comparison is more synchronized. If these assumptions disappear, continuous scheduling may become more attractive again. + +### Q2: Is `multi_step_pipeline` only about reducing memcpy calls? + +No. Fewer `D2H/H2D` round trips are only the most visible benefit. The more important change is the control model itself: instead of letting the host drive every round, the system pushes as much multi-round progression as possible down to the device side. + +### Q3: Why are `xAttention` and `beam search` discussed in the same section? + +Because they depend on the same condition: execution shape must already be stable. If the scheduler still rebuilds batches frequently, or the host still intervenes heavily at each round, then a large part of the operator-level gain will be diluted by system overhead. + +### Q4: Is this design only about OneRec? + +No. The document uses recommendation serving as the main narrative and often uses OneRec as an example, but the more important point is the workload shape behind `backend=rec`, not a single model name. + +### Q5: Why does the document repeatedly emphasize “stabilize scheduling first, optimize operators later”? + +Because order matters in system design. If the scheduling entrance is still unstable, and the execution path still depends heavily on host-side intervention, then operator-level optimization will struggle to produce stable end-to-end gains. + +## 14. Contract and Invariants + +If this document is expected to serve as a base for implementation and review, explanation alone is not enough. A few key contracts and invariants should be made explicit. + +### 14.1 External input shapes + +At a high level, the current `backend=rec` path accepts at least three kinds of inputs: + +- **prompt-based input** + - used when recommendation requests enter through text-like prompts +- **token / raw input** + - used when token sequences, embeddings, or related raw structures are already prepared outside +- **chat-style input** + - used when recommendation requests enter through a conversation-like service path + +The key point is not the number of input forms, but the fact that all of them must eventually converge into one unified request state and one scheduler / engine / worker execution chain. + +### 14.2 Key data that directly affects multi-round execution + +For `backend=rec`, the following data items form the most important contract surface: + +- `beam_width` + - controls how many branches are kept after each round +- `top_k` + - controls how many next-token candidates are expanded per branch +- `decode_step` / `total_round` + - defines the fixed-round boundary of execution +- `decode_positions` + - defines how token positions are organized across rounds +- shared / unshared KV + - defines how KV cache is split and reused in decode + +If these data items are not organized early enough and consistently enough, the gains of fixed-step scheduling and multi-step execution will not remain stable. + +### 14.3 Invariants that the document should keep stable + +If this document continues to grow later, the following invariants are worth preserving: + +- whenever the document mentions fixed-step scheduling, it should also explain the boundary conditions; +- whenever it mentions `multi_step_pipeline`, it should state which host-side work is reduced; +- whenever it mentions `xAttention`, it should explain its dependency on KV organization; +- whenever it mentions `beam search`, it should treat it as part of the main serving cost rather than only an algorithmic trick; +- all code anchors must continue to match the current branch and must not silently drift toward future paths. + +## 15. Implementation Status Matrix + +To avoid mixing design intent and branch reality, it is useful to make the current alignment explicit. This is not meant to be an exhaustive implementation matrix, but rather a stable summary of what the current document already treats as grounded facts. + +### 15.1 Capabilities that already have concrete implementation anchors + +- **fixed scheduling** + - the current branch contains `FixedStepsScheduler` + - it can be cited as a concrete scheduling fact rather than a conceptual placeholder +- **multi-round execution path** + - the current branch contains `RecMultiRoundEnginePipeline` + - the worker side contains `LlmRecMultiRoundPipeline` +- **beam search hot path** + - the current branch already contains beam-search-related call paths and kernels +- **cache select hot path** + - the current branch already contains cache-select-related call paths and kernels +- **xAttention / flashinfer path** + - the current branch already contains the relevant implementation files and call chain + +### 15.2 Parts that still remain design-level abstractions in this document + +- the relationship among fixed scheduling, multi-step execution, and custom operators + - this document explains it as a system design + - it is not represented by one single class or one single file +- the end-to-end gain of multi-stream overlap + - the current branch contains the concurrency and stream infrastructure + - but the gain discussion in the document still belongs to system-level interpretation, not one single implementation object + +### 15.3 Why this section matters + +The point of this section is to prevent readers from confusing “the document presents a coherent design” with “every part of the design is already fully materialized in the current branch”. It helps reviewers and readers distinguish between implementation-grounded facts and system-level abstraction. + +## 16. Failure Modes and Observability + +If the document only explains the normal path, it still feels closer to a talk draft than to a technical design note. An online system also needs a clear view of failure modes and observability. + +### 16.1 Typical failure modes + +- **waiting time grows too much** + - the fixed-step window increases queueing time for new requests +- **beam becomes too large** + - sorting, filtering, and candidate management dominate the cost +- **shared / unshared KV is not organized correctly** + - this can lead to memory waste, repeated loading, or wrong cache selection +- **host participation remains too high** + - this directly weakens the benefit of multi-step execution +- **operator-level gains do not materialize** + - this usually means the execution shape is not stable enough yet + +### 16.2 What should be observed first + +Even without a full monitoring chapter, the document should at least clarify the key observation points: + +- end-to-end request latency +- per-round decode latency +- scheduler overhead ratio +- host-device round trips or synchronization points +- beam-search / cache-select share in total latency +- memory peak and fragmentation tendency + +### 16.3 Shortest debugging path + +When behavior does not match the design expectation, the following order is often the shortest path: + +1. verify that the request actually matches the fixed-step workload assumptions; +2. verify that the decode path really enters the multi-round worker pipeline; +3. verify whether beam search and cache select dominate the cost; +4. finally verify whether xAttention and KV organization actually reduce repeated loading and copying. + +## 17. Benchmark and Acceptance Protocol + +To make this design document more useful as a review and acceptance baseline, it is helpful to include a lightweight benchmark and acceptance template. + +### 17.1 Recommended workload coverage + +- **long input / short output** + - the most representative generative recommendation scenario +- **different beam_width settings** + - for example, one moderate-beam case and one large-beam case +- **different top_k settings** + - useful for observing how search and filtering cost scales + +### 17.2 Recommended metric set + +- throughput +- P50 / P95 / P99 latency +- per-round decode cost +- scheduler overhead ratio +- host participation overhead +- peak memory usage + +### 17.3 Acceptance focus + +- whether fixed-step scheduling really reduces scheduling disturbance +- whether multi-step execution really reduces host-side round control overhead +- whether `xAttention` really improves shared-prefix reuse +- whether beam-related optimization really reduces search and filtering cost + +The value of this section is that it turns “the design sounds reasonable” into “the design can be evaluated with a stable acceptance template”. + +## 18. Key Parameters and Suggested Ranges + +To keep this document from remaining purely conceptual, it is useful to end with a lightweight parameter-oriented section. The purpose is not to declare one universally correct configuration, but to summarize which parameters matter the most and why. + +### 18.1 `beam_width` + +`beam_width` controls how many candidate branches survive after each round. It is one of the most important cost drivers in generative recommendation decode. + +- when `beam_width` is smaller: + - search cost is lower + - candidate coverage is weaker + - lower latency is easier to achieve +- when `beam_width` is larger: + - candidate coverage is stronger + - sorting, filtering, and cache-select cost all increase + - the gain of fixed-step and multi-step execution becomes more visible, but decode-side cost also grows quickly + +So `beam_width` should be treated as a core trade-off parameter among recommendation quality, total latency, and system overhead. + +### 18.2 `top_k` + +`top_k` controls how many candidates are expanded from each beam in one round. +In a large candidate space, increasing `top_k` directly amplifies: + +- sorting work +- invalid-item filtering work +- intermediate candidate-state management cost + +If `top_k` is too large and invalid-item filtering happens too late, then a large amount of computation is spent on candidates that will never survive into the next round. + +### 18.3 Decode-round count + +The more stable the decode-round count is, the easier it is for fixed-step scheduling and multi-step execution to keep producing stable gains. +If the number of rounds stays within a small and predictable range, then: + +- device-side structures can be allocated earlier; +- `step_meta` remains easier to stabilize; +- whole-graph multi-step execution becomes easier to exploit. + +If the round count itself becomes unstable, one of the core assumptions behind the current design becomes weaker. + +### 18.4 Multi-stream concurrency + +More streams do not automatically mean better performance. Too few streams may fail to hide waiting cost, while too many streams may introduce additional contention, scheduling overhead, and resource competition. + +A more reasonable tuning order is usually: + +1. confirm that fixed-step scheduling is already stable; +2. confirm that host-side waiting and queueing are actually significant; +3. only then increase multi-stream or multi-pipeline concurrency. + +### 18.5 Suggested parameter-tuning order + +If this path is tuned further later, the following order is usually more effective: + +1. first confirm that the decode-round count is stable; +2. then tune `beam_width`; +3. then tune `top_k`; +4. only after that tune multi-stream and system overlap. + +The reason is simple: the first two primarily determine workload shape, while the latter ones optimize the system under that shape. + +## 19. References + +This document is implementation-oriented, but some of the background problems, serving ideas, and memory/graph execution context naturally come from public references. If the document is extended later, the following materials are useful anchor points. + +### 19.1 Scheduling and serving systems + +- *Orca: A Distributed Serving System for Transformer-Based Generative Models* + - useful as the background reference for continuous batching + - useful as a comparison point when explaining why REC prefers fixed-step execution + +### 19.2 Memory management and KV cache + +- *Efficient Memory Management for Large Language Model Serving with PagedAttention* + - useful as a background reference for block-based KV management + - useful when explaining why shared/unshared KV separation matters + +### 19.3 Generative recommendation models + +- OneRec + - useful when describing the structure of recall-oriented generative recommendation models +- OneTrans + - useful when describing the structure of ranking-oriented generative recommendation models + +### 19.4 Graph execution and multi-step execution + +- the existing `Graph Mode` design document in this repository + - useful for graph capture / replay, parameterization, and Piecewise Graph +- the current generative recommendation design document + - useful for fixed-step scheduling, multi-round decode, beam search, and xAttention as one coherent serving stack + +### 19.5 Suggested use + +If the document is extended further later: + +- read Orca first when writing scheduling background; +- read PagedAttention first when writing KV / memory layout background; +- return to OneRec / OneTrans when writing model-structure background; +- return to this document and the code-anchor appendix when writing implementation mapping. + +This keeps future expansion grounded and prevents the document from turning into a loose collection of disconnected explanations. diff --git a/docs/en/design/graph_mode_design.md b/docs/en/design/graph_mode_design.md index 843547536..6dc92c302 100644 --- a/docs/en/design/graph_mode_design.md +++ b/docs/en/design/graph_mode_design.md @@ -24,6 +24,10 @@ The non-goals of this document are: - full adaptation details for every operator or every model - replacing feature documentation for flags and usage examples +Related design documents: + +- for a recommendation-oriented case study that focuses on fixed scheduling, multi-step execution, and custom operators, see: [Generative Recommendation Design Document](generative_recommendation_design.md) + ## 1. Graph Mode Fundamentals ### 1.1 Capture / Replay Basics diff --git a/docs/en/dev_guide/tilelang_ascend_kernel_dev.md b/docs/en/dev_guide/tilelang_ascend_kernel_dev.md new file mode 100644 index 000000000..95e6cdc78 --- /dev/null +++ b/docs/en/dev_guide/tilelang_ascend_kernel_dev.md @@ -0,0 +1,396 @@ +# xLLM Ascend TileLang Kernel Development Guide + +This document explains how to add or modify an Ascend TileLang kernel in xLLM. The examples use the current `rope` kernel throughout. + +Relevant directories: + +- Python kernel definitions: `xllm/xllm/compiler/tilelang/targets/ascend/kernels` +- NPU runtime wrappers: `xllm/xllm/core/kernels/npu/tilelang` + +Builds and tests should be run inside the NPU container. + +## 1. First Decide What Kind of Change You Are Making + +- Add a `specialization` + - Add one more compiled parameter combination to an existing kernel + - Reuse the same wrapper, the same runtime dispatch fields, and the same C ABI + - Typical changes are updates to `DISPATCH_SCHEMA` or `SPECIALIZATIONS` +- Add a `kernel` + - Add a new logical operator + - Typical changes are a new Python kernel file, a new wrapper C++ file, and one CMake registration + +For `rope`: + +- Adding one more item like `{"variant_key": "...", "head_dim": ..., "rope_dim": ..., "dtype": ...}` to `SPECIALIZATIONS` means adding a new `specialization` +- Adding a new external interface such as `xxx_wrapper.cpp` means adding a new `kernel` + +## 2. Development Order + +The recommended order is: + +1. Implement the TileLang kernel in a Python file such as `rope.py` +2. Implement `generate_source(...)` to lower the kernel into Ascend-C source +3. Declare `DISPATCH_SCHEMA` and `SPECIALIZATIONS` +4. Generate `registry.inc` once and inspect it +5. Then write or update the runtime specialization construction logic in the wrapper +6. Wire it into CMake and run tests + +The key idea behind this order is: + +- implement the kernel itself first +- then fix the runtime dispatch schema +- then write the wrapper against the generated `registry.inc` + +## 3. Write the Python Kernel + +Using `rope.py` as the example, the Python side can be understood in three layers: + +- `build_rope_kernel(...)`: kernel implementation +- `generate_source(...)`: AOT export +- `RopeKernel`: kernel registration plus dispatch schema and compiled instance declaration + +### 3.1 Implement `build_rope_kernel(...)` + +`build_rope_kernel(...)` is the actual TileLang kernel implementation. This is where you write: + +- `@T.prim_func` +- input and output tensor shapes +- parallel task organization under `with T.Kernel(...)` +- UB allocation and the actual compute logic + +The simplified structure in `rope.py` looks like this: + +```python +def build_rope_kernel( + head_dim: int, + rope_dim: int, + vec_core_num: int, + ub_buffer_bytes: int, +): + task_num = vec_core_num + m_num = vec_core_num // 2 + + @T.prim_func + def rope_in_place_kernel(...): + with T.Kernel(m_num, is_npu=True) as (cid, vid): + task_id = cid * 2 + vid + ... + + return rope_in_place_kernel +``` + +Here, `head_dim` and `rope_dim` are the compile-time parameters that this implementation actually depends on. + +Vector kernels such as `rope` must also follow the fixed-task convention used by the current AOT path. In the current AOT flow, the kernel launch `block_num` is fixed at compile time, which means: + +- runtime input shapes do not change the kernel launch `block_num` +- runtime input shapes only change workload splitting across the fixed tasks + +The convention in the current `rope.py` is: + +```python +task_num = vec_core_num +m_num = vec_core_num // 2 + +with T.Kernel(m_num, is_npu=True) as (cid, vid): + task_id = cid * 2 + vid +``` + +This means: + +- `cid` ranges over `[0, vec_core_num // 2)` +- `vid` ranges over `[0, 2)` +- the total task count is fixed as `task_num = vec_core_num` + +As a result, `rope.py` also derives the compile-time token count for one specialization using the fixed task count: + +```python +max_rows_num_in_ub = _derive_max_rows_num_in_ub(...) +compile_num_tokens = task_num * max_rows_num_in_ub +``` + +### 3.2 Implement `generate_source(...)` + +`generate_source(...)` lowers the TileLang kernel above into the final source code. The export layer takes one specialization and turns it into compilable Ascend-C source. + +For `rope`, the core logic is: + +```python +@staticmethod +def generate_source(head_dim: int, rope_dim: int, dtype: str) -> str: + vec_core_num = detect_vec_core_num() + tilelang_kernel = build_rope_kernel( + head_dim=head_dim, + rope_dim=rope_dim, + vec_core_num=vec_core_num, + ub_buffer_bytes=FIXED_UB_BUFFER_BYTES, + ) + with tilelang.tvm.transform.PassContext(...): + kernel = tilelang.engine.lower(tilelang_kernel) + return kernel.kernel_source +``` + +The rules here are: + +- the inputs to `generate_source(...)` come from the current `SPECIALIZATIONS` entry +- `generate_source(...)` calls `build_rope_kernel(...)` +- the return value is the lowered source string + +### 3.3 Declare `DISPATCH_SCHEMA` and `SPECIALIZATIONS` + +After the kernel implementation and export layer are done, use an `@register_kernel` class to attach the kernel to the framework. + +The current minimal template in `rope.py` is: + +```python +from ....common.spec import DispatchField, TilelangKernel, register_kernel + + +@register_kernel +class RopeKernel(TilelangKernel): + DISPATCH_SCHEMA = [ + DispatchField("head_dim", "int32"), + DispatchField("rope_dim", "int32"), + DispatchField("dtype", "dtype"), + ] + SPECIALIZATIONS = [ + { + "variant_key": "hd128_rd128_bf16", + "head_dim": 128, + "rope_dim": 128, + "dtype": "bf16", + }, + { + "variant_key": "hd576_rd64_bf16", + "head_dim": 576, + "rope_dim": 64, + "dtype": "bf16", + }, + ] + + @staticmethod + def generate_source(head_dim: int, rope_dim: int, dtype: str) -> str: + ... +``` + +There are two concepts to distinguish here: + +- `DISPATCH_SCHEMA` + - defines the field names, order, and types of the runtime specialization + - is the single source of truth for the C++ specialization struct, builder, and lookup interface +- `SPECIALIZATIONS` + - represents the set of instances that will actually be compiled + - each item corresponds to one variant + +The rules are: + +- every field in `DISPATCH_SCHEMA` must appear in every `SPECIALIZATIONS` item +- `SPECIALIZATIONS` may contain extra fields; those fields are passed into `generate_source(...)`, but do not enter the runtime dispatch schema +- `variant_key` is the unique identifier for that specialization +- `DISPATCH_SCHEMA` and `SPECIALIZATIONS` must match the runtime specialization one-to-one + +For `rope`, the runtime dispatch dimensions are: + +- `head_dim` +- `rope_dim` +- `dtype` + +So these three fields must appear in both: + +- `DISPATCH_SCHEMA` +- every `SPECIALIZATIONS` item + +At build time, Ascend build resolves the actual `bisheng_arch` from the `--device a2|a3` value passed by the main build path. + +### 3.4 Inspect the Generated Ascend-C Source + +When debugging the implementation details in `build_rope_kernel(...)`, or comparing how different kernel styles affect the final code generation, use the common `compile-kernels` entry to regenerate artifacts and inspect the Ascend-C source for the specialization you care about. + +For `rope`, you can fix: + +- `head_dim=576` +- `rope_dim=64` +- `dtype=bf16` + +Then regenerate the `rope` artifacts: + +```bash +python xllm/compiler/tilelang_launcher.py compile-kernels \ + --target ascend \ + --device a3 \ + --output-root /tmp/tilelang_debug \ + --kernels rope \ + --force +``` + +It is recommended to keep `--force` so the source and object files are regenerated from the current code instead of reusing an old cache hit. + +This command uses an isolated debug output directory, `/tmp/tilelang_debug`, so only the debug artifacts for `rope` are generated there and they do not get mixed with artifacts from other kernels in the main build directory. + +After that, you can directly inspect the generated source for the specialization, including the entry function, UB allocation, and vector compute logic: + +```bash +sed -n '1,200p' \ + /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp + +rg -n 'extern "C"|__global__|alloc_ub|alloc_shared|g_tilingKey' \ + /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp +``` + +To compare two kernel implementations, keep the specialization fixed, run `compile-kernels --force` before and after the change, then diff the generated `.cpp` file: + +```bash +cp /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp \ + /tmp/rope_before.cpp + +diff -u /tmp/rope_before.cpp \ + /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp +``` + +This helps isolate specialization changes from kernel implementation changes. + +After generation, the main files to inspect are: + +- `/tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp` +- `/tmp/tilelang_debug/targets/ascend/rope/registry.inc` +- `/tmp/tilelang_debug/targets/ascend/rope/manifest.json` + +These correspond to: + +- the final Ascend-C source for one specialization +- the runtime dispatch interface directly included by the wrapper +- the full compiled artifact record for the current kernel + +The recommended debugging sequence is: + +1. run `compile-kernels --force` to regenerate the current kernel artifacts +2. inspect the `.cpp` for the specialization and analyze the code generation result +3. inspect `registry.inc` and `manifest.json` to confirm they match expectations +4. finally run `rope_wrapper_test` to check end-to-end behavior and performance + +## 4. Update the Wrapper + +When adding a new `kernel`, you need a new wrapper. When adding a new `specialization`, the wrapper only needs an update if the runtime specialization semantics change. + +For `rope_wrapper.cpp`, the manually written parts should remain: + +- tensor shape, dtype, and layout validation +- reshaping inputs into `x_rows / sin_rows / cos_rows` +- constructing the runtime specialization from tensors +- assembling launch arguments and calling `entry->fn(...)` + +### 4.1 What `registry.inc` Generates Automatically + +`registry.inc` is generated automatically from the Python-side `DISPATCH_SCHEMA`, `SPECIALIZATIONS`, and the exported Ascend-C ABI. + +For `rope`, the generated content includes: + +- `RopeSpecialization` +- `RopeHeadDim` +- `RopeRopeDim` +- `RopeDType` +- `RopeKernelFn` +- `make_rope_specialization(...)` +- `find_rope_kernel_entry(...)` +- `available_rope_variant_keys()` + +For `rope_wrapper.cpp`, `registry.inc` directly provides dispatch-related definitions such as `RopeSpecialization`, `operator==(...)`, and `RopeKernelFn`. Dtype conversion uses the shared helper `to_tilelang_dtype(...)`. + +### 4.2 What the Wrapper Actually Needs to Write + +The most important handwritten logic in `rope_wrapper.cpp` is constructing the runtime specialization from the tensors. The current code looks like this: + +```cpp +RopeSpecialization build_runtime_specialization(const torch::Tensor& x_rows) { + return make_rope_specialization( + RopeHeadDim{static_cast(x_rows.stride(0))}, + RopeRopeDim{static_cast(x_rows.size(1))}, + RopeDType{to_tilelang_dtype(x_rows.scalar_type())}); +} +``` + +For `rope`: + +- `head_dim` maps to `x_rows.stride(0)`, which is the `x_stride` used by the kernel +- `rope_dim` maps to `x_rows.size(1)` +- `dtype` maps to `x_rows.scalar_type()` + +The runtime path is: + +1. the wrapper reshapes the inputs into `x_rows / sin_rows / cos_rows` +2. `build_runtime_specialization(...)` constructs a specialization from `x_rows` +3. `find_rope_kernel_entry(...)` performs an exact match in the static registry +4. after a match, `entry->fn(...)` calls the actual compiled symbol + +The current lookup strategy is a linear scan with exact matching. If any of `head_dim`, `rope_dim`, or `dtype` differs, the lookup will miss. + +So when you add a new `specialization`, the main things to cross-check are: + +- the field semantics in Python-side `DISPATCH_SCHEMA` +- the field values in Python-side `SPECIALIZATIONS` +- the field values constructed by `build_runtime_specialization(...)` in the wrapper + +All three must match exactly. + +### 4.3 Generate and Inspect `registry.inc` First + +Before writing or modifying wrapper code, generate `registry.inc` once and inspect it. Focus on: + +- whether the generated field order in `RopeSpecialization` matches expectations +- whether the generated wrapped field type names match expectations +- the parameter order of `make_rope_specialization(...)` +- the generated entry symbol names + +`registry.inc` is the direct contract for the wrapper. Inspect it first, then write the wrapper against it. + +## 5. Update CMake + +When adding a new `kernel`, register it in `xllm/xllm/core/kernels/npu/tilelang/CMakeLists.txt`. + +CMake registration is unified through the high-level helper: + +- `tilelang_register_runtime_kernel(NAME WRAPPER_SRCS )` + +Using `rope` as the example, the minimal template is: + +```cmake +tilelang_register_runtime_kernel( + NAME rope + WRAPPER_SRCS rope_wrapper.cpp +) +``` + +This helper will: + +- derive the manifest path as `TILELANG_GENERATED_ROOT/targets/ascend//manifest.json` +- import the manifest +- add the wrapper source and compiled objects into `tilelang_kernels` +- append the `XLLM_TL__REGISTRY_INC=...` compile definition automatically + +So when adding a new runtime kernel, the CMake-side work mainly consists of two things: + +1. make sure the Python side can already generate the manifest for that kernel +2. add one `tilelang_register_runtime_kernel(...)` entry in the TileLang CMakeLists + +For day-to-day kernel additions, add one `tilelang_register_runtime_kernel(...)` line directly in CMake. `tilelang_import_kernel_manifest(...)` stays underneath as the implementation base for that higher-level helper. + +## 6. Validate + +The recommended validation order is: + +1. compile the TileLang kernel and inspect the generated `registry.inc` +2. then run the full wrapper test + +Common commands: + +```bash +python xllm/compiler/tilelang_launcher.py compile-kernels \ + --target ascend \ + --device a3 \ + --output-root build/cmake.linux-aarch64-cpython-311/xllm/compiler/tilelang \ + --kernels rope + +python setup.py test --test-name rope_wrapper_test --device a3 +``` + +The first command generates `manifest.json`, `registry.inc`, and the object files. The second command validates the full integration path. diff --git a/docs/en/features/overview.md b/docs/en/features/overview.md index c2b530d89..d0a810def 100644 --- a/docs/en/features/overview.md +++ b/docs/en/features/overview.md @@ -42,7 +42,7 @@ We have implemented various request scheduling strategies supporting continuous #### Global KV Cache Management Utilizes ETCD as a metadata service middleware at the global level for cluster service registration, load information synchronization, and global cache state management. Each compute instance maintains a local multi-level cache pool. Regarding scheduling strategy, the system adopts a dynamic decision-making mechanism based on KV cache: it first performs prefix matching detection, calculates the KV cache reuse rate of candidate nodes, and finally selects the node with the optimal comprehensive performance for processing, achieving dynamic offloading and migration of KV cache. #### Speculative Inference -xLLM incorporates an optimized speculative inference algorithm that generates multiple tokens at once to boost throughput. xLLM reduces communication costs by下沉 (sinking) the speculative module and optimizes speculative inference computation through methods like overlapping scheduling and computation timelines and reducing operator data movement in speculative scenarios. +xLLM incorporates an optimized speculative inference algorithm that generates multiple tokens at once to boost throughput. xLLM reduces communication costs by sinking the speculative module and optimizes speculative inference computation through methods like overlapping scheduling and computation timelines and reducing operator data movement in speculative scenarios. #### MoE Load Balancing xLLM implements expert weight updates based on historical expert load statistics for MoE models. During inference, it achieves effective dynamic load balancing through efficient expert load statistics and double-buffered, seamless expert weight updates. diff --git a/docs/en/getting_started/quick_start.md b/docs/en/getting_started/quick_start.md index c316d11b5..98ca14d66 100644 --- a/docs/en/getting_started/quick_start.md +++ b/docs/en/getting_started/quick_start.md @@ -94,7 +94,7 @@ cd xllm pip install pre-commit pre-commit install -git submodule update --init +git submodule update --init --recursive ``` The compiled binary file is located at `/path/to/xllm/build/xllm/core/server/xllm`. In a new image, the first compilation of xllm takes a long time because all dependencies in vcpkg need to be compiled, but subsequent compilations will be much faster. diff --git a/docs/en/supported_models.md b/docs/en/supported_models.md index 3f116375a..f8731093e 100644 --- a/docs/en/supported_models.md +++ b/docs/en/supported_models.md @@ -43,6 +43,7 @@ ## Rec | | NPU | MLU | ILU | | --- | :---: | :---: | :---: | -| | | | | -| | | | | -| | | | | \ No newline at end of file +| OneRec | ✅ | ❌ | ❌ | +| Qwen2 | ✅ | ❌ | ❌ | +| Qwen2.5 | ✅ | ❌ | ❌ | +| Qwen3 | ✅ | ❌ | ❌ | diff --git a/README_zh.md b/docs/project/README_zh.md old mode 100755 new mode 100644 similarity index 94% rename from README_zh.md rename to docs/project/README_zh.md index f4534b10b..654fe75c3 --- a/README_zh.md +++ b/docs/project/README_zh.md @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -[English](./README.md) | [中文](./README_zh.md) +[English](../../README.md) | [中文](./README_zh.md)
-xLLM +xLLM [![Document](https://img.shields.io/badge/Document-black?logo=html5&labelColor=grey&color=red)](https://xllm.readthedocs.io/zh-cn/latest/) [![Docker](https://img.shields.io/badge/Docker-black?logo=docker&labelColor=grey&color=%231E90FF)](https://hub.docker.com/r/xllm/xllm-ai) [![License](https://img.shields.io/badge/license-Apache%202.0-brightgreen?labelColor=grey)](https://opensource.org/licenses/Apache-2.0) [![report](https://img.shields.io/badge/Technical%20Report-red?logo=arxiv&logoColor=%23B31B1B&labelColor=%23F0EBEB&color=%23D42626)](https://arxiv.org/abs/2510.14686) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/jd-opensource/xllm) @@ -41,7 +41,7 @@ limitations under the License. --> **xLLM** 是一个高效的开源大模型推理框架,专为**国产芯片**优化设计,提供企业级的服务部署,使得性能更高、成本更低。该框架采用**服务-引擎分离的推理架构**,通过服务层的在离线请求弹性调度、动态PD分离、EPD混合机制及高可用容错设计,结合引擎层的多流并行计算、图融合优化、投机推理、动态负载均衡及全局KV缓存管理,实现推理效率突破性提升。xLLM整体架构和功能如下图所示:
-xllm_arch +xllm_arch
**xLLM** 已支持主流大模型(如 *DeepSeek-V3.1*,*Qwen2/3*等)在国产芯片上的高效部署,助力企业实现高性能、低成本的 AI 大模型应用落地。xLLM已全面落地京东零售核心业务,涵盖智能客服、风控、供应链优化、广告推荐等多种场景。 @@ -85,13 +85,13 @@ xLLM 提供了强大的智能计算能力,通过硬件系统的算力优化与 | ILU | BI150 | | | MUSA | S5000 | | -此外,请在[模型支持列表](docs/zh/supported_models.md)查看不同硬件上的模型支持情况。 +此外,请在[模型支持列表](../zh/supported_models.md)查看不同硬件上的模型支持情况。 --- ## 快速开始 -请参考[快速开始文档](docs/zh/getting_started/quick_start.md)。 +请参考[快速开始文档](../zh/getting_started/quick_start.md)。 --- @@ -120,7 +120,7 @@ xLLM 提供了强大的智能计算能力,通过硬件系统的算力优化与 如果您有企业内部Slack,请直接联系xLLM Core团队。另外,我们建立了官方微信群,可以访问以下二维码加入。欢迎沟通和联系我们:
- qrcode3 + qrcode3
--- diff --git a/RELEASE.md b/docs/project/RELEASE.md similarity index 100% rename from RELEASE.md rename to docs/project/RELEASE.md diff --git a/docs/zh/design/generative_recommendation_design.md b/docs/zh/design/generative_recommendation_design.md new file mode 100644 index 000000000..0bcaa3545 --- /dev/null +++ b/docs/zh/design/generative_recommendation_design.md @@ -0,0 +1,1076 @@ +# 生成式推荐设计文档 + +## 概述 + +xLLM 在 `backend=rec` 场景下提供了生成式推荐推理能力。其目标不是替代现有推荐系统,而是在保留 `predictor` 侧稀疏特征处理和在线服务能力的前提下,把 LLM 主体推理能力复用到推荐场景中,用于候选扩展、候选比较和最终结果生成。 + +本文档重点说明以下内容: + +- 生成式推荐场景的目标与约束 +- 推荐模型结构与推理接入方式 +- 为什么推荐场景更适合固定调度和整图执行 +- `xAttention` 与 `beam search` 如何围绕显存和执行效率协同优化 +- 当前分支中与生成式推荐相关的核心代码分布 + +本文档的设计目标包括: + +- 用统一视角解释 `backend=rec` 的推理链路 +- 说明固定调度、整图执行和定制算子之间的关系 +- 为后续技术分享、代码走读和文档扩展提供稳定底稿 + +本文档的非目标包括: + +- 不展开推荐模型训练细节 +- 不覆盖所有线上业务接入差异 +- 不替代各模块的详细 API 文档 + +## 1. 背景和问题 + +最近几年,基于 LLM 的生成式推荐取得了比较明显的进展。在 xLLM 中,我们也逐步补齐了对生成式推荐推理的支持。生成式推荐的目标,不是简单把大模型能力接进推荐系统,而是希望利用生成式建模能力,在候选扩展和排序阶段提升效果,尤其是提升 `CTR` 这类核心指标。 + +在当前方案中,我们使用自研 xLLM 作为统一推理引擎,通过动态库(`.so`)方式接入现有预测链路: + +- `predictor` 侧继续负责稀疏特征处理、样本组织和在线服务集成; +- `xLLM` 侧负责完成 LLM 相关推理计算。 + +这样做的价值在于,推荐系统原有的工程能力可以保留,而 xLLM 在算子、KV Cache、多后端执行和调度上的基础设施也能够直接复用。 + +但生成式推荐和通用 LLM 推理,优化目标并不相同。 + +- 通用 LLM 推理更关注逐步生成的体验,例如尽快返回第一个结果、尽量缩短每一步生成之间的间隔,并允许请求在执行过程中灵活插入和提前结束; +- 生成式推荐更关注整次请求的总时延,以及在有限几轮内得到更优的候选结果。 + +原因很直接:推荐场景通常不是生成一段开放文本,而是在固定几轮里不断扩展候选、比较候选,最后输出更优结果。 + +![生成式推荐整体背景](../../assets/generative_recommendation_overview.png) + +这里经常会用到 `beam search`。可以把它理解为:在每一轮里,不只保留当前最优的一条路径,而是同时保留多个高分候选,并在后续轮次继续扩展和比较,最后从这些候选里选出更优结果。在推荐场景里,这样做的意义不是“生成更长内容”,而是“在有限几步内覆盖更多高质量候选,提高最终推荐效果”。 + +![Beam Search 在生成式推荐中的作用](../../assets/generative_recommendation_beam_search.png) + +因此,生成式推荐天然有两个特征: + +- 固定步数推进; +- 多个候选同步比较。 + +也就是说,这个场景真正要优化的,不是“某一条序列先跑完”,而是“多个候选在固定几轮里稳定推进,并在每一轮完成低开销比较”。这也决定了后续的设计方向:调度层更适合使用固定调度,执行层更适合做整图执行,并在稳定执行形态上做专门的算子优化。 + +### 1.1 生成式推荐和通用 LLM 推理的 workload 差异 + +如果只从“模型里也有 attention”这一点来看,生成式推荐似乎和通用 LLM 推理很接近;但从服务 workload 的角度看,两者其实差异很大。 + +生成式推荐更常见的输入输出形态是: + +- 输入很长:因为用户历史行为、上下文特征、候选上下文往往都要进入模型; +- 输出很短:最终只需要生成固定长度的 item token 序列; +- decode 轮数固定:例如只生成 2 到 4 个 token 就结束; +- 每轮 decode 的代价不低:因为不是单路 greedy decode,而往往伴随着较大的 `beam_width` 和 `top_k` 候选扩展。 + +这和通用 LLM 推理正好形成一个鲜明对照: + +- 通用 LLM 推理常见的是“短 prompt + 长输出”; +- 生成式推荐常见的是“长 prompt + 短输出”。 + +这意味着生成式推荐并不会因为“输出很短”就自然变得便宜。相反,正因为输出阶段带着 `beam search`、大候选池和固定多轮比较,单步 decode 的控制和搜索成本会被放大。 + +### 1.2 这个场景最难的 3 个问题 + +从生成式推荐服务本身的 workload 特征出发,和通用 LLM 推理相比,最值得单独拎出来的有 3 类问题。 + +#### 问题一:长 prompt、短输出,但 decode 单步并不便宜 + +生成式推荐的 decode 轮数虽然固定且较少,但每一轮往往都要处理大规模候选比较。这和“输出短所以推理简单”的直觉不一样。 +如果继续沿用通用推理系统对 decode 的组织方式,就会把大量优化空间浪费在 shared prefix 的重复加载、beam 之间的冗余 KV 访问,以及 block 粒度下的重排与复制上。 + +#### 问题二:beam search 不只是搜索问题,还是系统问题 + +在生成式推荐里,`beam search` 的作用不是语言生成里的“多样性增强”,而是核心推荐候选扩展逻辑的一部分。 +一旦 `beam_width` 和 `top_k` 变大,排序、过滤、候选保活、结构复用都会成为显著成本。 +也就是说,beam search 在这里不是单独一个算法模块,而是会牵动调度、显存、数据结构和 kernel 组织方式的系统级问题。 + +#### 问题三:系统瓶颈不只在算子,也在 Host 与 Device 的协作方式 + +生成式推荐通常有严格的在线时延约束,同时并发量又高。 +如果 host 仍然按照“每一步都回来做判断、准备下一步、再发回设备”的方式控制整个过程,那么 host 调度和数据搬运本身就会成为显著瓶颈。 +这也是为什么这个场景里不仅要优化算子,还要重构整体 pipeline,让 host 和 device 的职责划分更适合 fixed-step workload。 + +### 1.3 这篇设计文档的定位 + +这篇文档不会去复述外部材料,也不试图替代论文式分析。 +更准确的定位是:它把生成式推荐场景里的工作负载特征、系统问题和当前分支的真实实现路径整理到同一份设计文档里,帮助后续做分享、走读和代码 review。 + +## 2. 推理架构 + +### 2.1 模型结构介绍 + +生成式推荐是近两年推荐系统领域的重要方向。它正在打破传统“召回-排序-重排”的级联边界,把推荐任务从“判别式匹配”推进到“生成式预测”。当前文档里重点关注两类已经在线上大规模使用的模型:用于召回的 OneRec 模型,以及用于精排的 OneTrans 模型。 + +![OneRec 模型结构](../../assets/generative_recommendation_model_onerec.png) + +![OneTrans 模型结构](../../assets/generative_recommendation_model_onetrans.png) + +从这些模型的共同点来看,它们保留了传统 CTR 场景里的序列特征、用户静态特征和上下文特征,并由输入适配层把异构推荐信号(离散 ID、连续值、序列、多模态内容)统一映射为 LLM Decoder 可理解的嵌入表示(embedding),必要时再与 LLM 的词表嵌入空间对齐。模型主体则是 LLM 的 Encoder+Decoder 或 Decoder-only 结构,因此不同部分需要不同的推理引擎承接。 + +### 2.2 推理架构介绍 + +根据模型结构特点,当前方案把模型切成两类子图: + +- 输入适配层仍然归属于传统 CTR 推理范畴,由 `predictor` 承接; +- LLM 主体部分由 xLLM 承接。 + +作为 LLM 推理的核心引擎,xLLM 在生成式推荐场景下提供了两种接入方式:RPC 接入与动态库(`.so`)接入。 + +#### 2.2.1 RPC 接入方式 + +当前营销等在线召回场景的生成式推荐主要采用 RPC 方式接入。它的优点是服务边界清晰、接入方式稳定,但也会引入额外的 RPC 调用开销。 + +#### 2.2.2 动态库接入方式 + +另一种方式是把 xLLM 作为 `predictor` 内部的独立推理引擎,对模型中属于 LLM 主体的子图直接做推理。这样可以省掉 RPC 往返开销,后续更适合承接需要低延迟的相关业务。 + +![xLLM 在生成式推荐中的接入架构](../../assets/generative_recommendation_integration_architecture.jpg) + +## 3. 固定调度与整图执行 + +### 3.1 固定步数调度 + +![Orca 中 continuous batching 的背景](../../assets/fixed_steps_scheduler_orca.png) + +上图来自论文《Orca: A Distributed Serving System for Transformer-Based Generative Models》,它介绍了 `continuous batching` 的背景:通过动态重组 batch,避免固定 batch 调度导致算力空转。 + +但生成式推荐是固定步数的,这一点改变了调度问题本身。从调度角度看,生成式推荐更适合 `fixed_steps_scheduler`,而不是 `continuous batching`。原因不只是“固定步数所以固定调度”,而是因为这个场景本身就是按固定几轮来组织计算的。既然请求通常会在约定好的几步里完成,而且多个候选需要同步向前推进,那么调度器最重要的任务就不是“随时插队、随时清退”,而是“把这一组候选稳定地发出去,并尽量减少额外调度动作”。 + +`fixed_steps_scheduler` 的第一个好处,是更适合 `beam search`。在 `decode` 阶段,`beam width` 往往比较大,我们希望多个 beam 在同一轮里一起推进、一起比较。如果采用连续调度,那么每一步都可能触发 batch 重组、sequence 压缩、索引重排和状态裁剪。这些动作在通用 LLM 推理里是合理的,因为请求确实会动态结束;但在生成式推荐里,它们很多时候并不是收益,而是额外成本。使用固定调度之后,同一个请求下的多个 beam 可以在固定窗口里齐头并进,调度器不需要每一步都重新组织 batch,也不需要反复判断哪些序列该保留、哪些序列该剔除。这样做可以明显减少调度层的控制开销。 + +第二个好处,是执行形态会更稳定。一旦解码轮数固定、beam group 规模固定、推进节奏固定,很多后续优化才真正有了基础。比如 buffer 可以提前分配,workspace 更容易复用,cache 访问模式也更规整。对于性能优化来说,这种稳定性很重要,因为它意味着更容易做 profiling、更容易做容量规划,也更容易把执行链路固化下来。换句话说,`fixed_steps_scheduler` 解决的是调度稳定性问题,它让执行入口从动态、不规则、频繁变化的状态,收敛成了一个稳定的固定窗口。 + +![PagedAttention 的问题背景与固定步场景对比](../../assets/paged_attention_comparison.png) + +第三个好处,是它减少了很多与模型计算无关的损耗。在推荐场景里,主要成本本来应该集中在真正的候选扩展、注意力计算和 beam 比较上;但如果每一步都让调度器参与 sequence 重排、batch 重组、元数据更新和索引搬运,那么会引入不少“不是算子本身、但又必须付出”的额外成本。从这个角度看,固定调度本质上是在用更强的执行确定性,换更高的吞吐、更低的调度成本以及更稳定的运行时行为。 + +当然,固定调度也有代价。最明显的问题就是,新请求的等待时间会变长。因为连续调度的一个优势,是新请求可能等一步就有机会被插入;而固定调度下,新请求通常要等当前这一轮固定窗口结束,才能进入下一轮执行。这会带来更明显的排队等待。这个问题的缓解方向,不是退回到连续调度,而是引入 `multi-stream`。也就是说,把已经在固定窗口里的大批请求和新接入的小批请求尽量解耦,让它们落在不同 stream 或不同执行通道上。这样做的目的,不是完全消除等待,而是在保住固定调度吞吐优势的同时,降低新请求接入的额外时延。 + +### 3.2 整图执行 + +在这个基础上,`multi_step_pipeline` 就成为固定调度的天然配套设计。它解决的是执行效率问题。既然我们已经知道这个场景本身就是固定几步,而且通常不会提前结束,那么就没有必要每一步都让 host 参与一次控制:没有必要每一步都做一次 `D2H` 去判断“这一批是不是结束了”,也没有必要每一步都再做一次 `H2D` 去准备下一轮输入。更高效的做法,是在第一步启动时,就把后续若干步需要用到的空间、索引和数据结构一次性准备好,然后让 device 侧连续向前推进。 + +这样做的收益非常直接: + +- 减少 `D2H/H2D` 往返,降低 host 参与频率; +- 减少每一步的 launch 和控制开销; +- 让更多中间数据停留在 device 侧,提高数据复用效率; +- 让整段执行过程更像一条连续流水,而不是“每一步停一下、准备一下、再继续”。 + +对于生成式推荐这种固定轮数任务来说,这种连续执行方式明显比逐步回到 host 再下发下一轮更高效。 + +`multi_step_pipeline` 还有一个经常被低估的价值,就是它为定制算子创造了更好的运行条件。在执行形态稳定之后,配合定制算子把关键热路径进一步做快。`fixed step` 解决的是调度稳定性,而整图执行加上算子定制,解决的是执行效率。 + +## 4. 显存管理与算子协同优化 + +### 4.1 计算与显存瓶颈 + +#### 4.1.1 模型输入输出特征 + +在当前生成式推荐推理设定中,item id 由固定长度 token 序列表示,因此 `decode_step` 是已知的小常数(例如 3)。一次请求的推理流程可以概括为: + +- 一次 prefill:输入为长序列,即用户历史上下文; +- `decode_step` 次 decode:每步生成 1 个 token,最终组合为 item id。 + +单步 decode 的单位开销并不低。为了召回与多样性,生成式推荐通常需要较大的 `beam_width`;同时每条 beam 还要扩展 `top_k` 个候选,再在全局候选池 `beam_width × top_k` 上选择新的 beam 集合,最终 beam 集合大小仍保持为 `beam_width`。例如当 `beam_width=512`、`top_k=512` 时,单步候选池大小达到 262144(约 2.6×10^5)。因此 decode 的步数虽然不多,但每步的搜索选择与 KV 访问开销仍然不低。 + +#### 4.1.2 存储冗余与显存碎片 + +生成式推荐推理服务的主要瓶颈可以拆成两类,而 `xAttention` 就是围绕这两类问题来设计的。 + +第一类是 Attention 的冗余带宽消耗:shared prefix 没有被显式建模为可复用结构。在较大的 beam 场景下,所有 beam 都共享同一段长 prompt,但通用实现往往以“每条 beam 一条完整序列”的视角组织 KV,导致 Shared KV 在 beam 维度被重复触发加载,attention kernel 的有效算术强度下降,最终受限于 HBM 带宽。 + +第二类是 KV Cache 的复制与碎片:beam 分叉与 block 级管理之间存在结构性冲突。beam search 会频繁 fork 与 retire,并触发 beam 重排。对于基于 block 的 KV 管理(例如 PagedAttention 一类),“重排 + block 对齐”往往意味着 block copy、碎片化以及额外空间浪费,显存和带宽都会被放大。 + +### 4.2 `xAttention` 设计原理 + +#### 4.2.1 KV Cache 存储优化 + +![xAttention KV cache 布局](../../assets/xattention_kv_layout.png) + +围绕当前生成式推荐推理的固定结构,xAttention 把 KV Cache 的组织方式与 attention 计算和并行策略一起重新设计,将 shared prefix 在显存层面只存一份,同时 beam 的分叉与重排不再触发高代价的数据拷贝。 + +首先,KV Cache 被按“是否共享前缀”拆成两类: + +- **Shared KV**:prefill 阶段生成的 prompt KV,所有 beam 共享同一份物理存储; +- **Unshared KV**:decode 阶段每条 beam 新生成 token 的 KV,按 token 粒度管理。 + +拆成两类 KV 之后,Unshared KV 只存储 decode 阶段产生的新 token,从而避免 block copy 与显存浪费。 + +#### 4.2.2 Attention 计算优化 + +![xAttention 三阶段执行](../../assets/xattention_three_stage_pipeline.png) + +为了避免把 Shared 与 Unshared KV 直接拼接成一个逻辑长序列,以及由此带来的访存与拷贝问题,xAttention 把一次 attention 拆成三个阶段: + +1. **shared stage**:仅对 Shared KV 计算局部 softmax 统计量与部分输出; +2. **unshared stage**:仅对 Unshared KV 计算局部统计量与部分输出; +3. **merge stage**:使用 OnlineSoftmax 把两段结果稳定合并。 + +并行化层面,会把 shared、unshared 与 merge 分配到不同执行单元和队列中形成流水线,目标是让 Shared 与 Unshared 的计算尽量重叠执行,同时把同步点压缩到最少。 + +### 4.3 Beam Search 的系统化处理 + +如果说 `xAttention` 解决的是“attention 如何更省显存、更少重复读取 Shared KV”,那么这一节解决的就是“在大候选池下,beam search 怎么避免把时间浪费在无效比较和无谓排序上”。 + +从系统视角看,beam search 在生成式推荐场景里至少包含下面几层成本: + +- 每一步要从大量候选中选出新的 beam 集合; +- 候选里并不是每个 token 组合都代表真实 item; +- 随着 decode 往前推进,旧候选会被淘汰,新候选会不断产生; +- 如果每次都新建数据结构、全量排序、全量过滤,开销会非常高。 + +因此,这一层更适合被理解成“围绕 beam search 的系统优化”,而不是单纯一个排序 kernel。 +它的目标包括: + +- 尽早终止不必要的排序; +- 在 item 空间约束下尽早过滤无效路径; +- 尽量复用已有数据结构,避免每轮反复创建和销毁候选容器。 + +对于技术分享来说,这一节最值得强调的是: +在生成式推荐里,beam search 的代价不是附属成本,而是 decode 主成本的一部分。也正因为如此,beam search 需要和 fixed-step 调度、multi-step 执行、KV cache 组织一起被系统性考虑。 + +### 4.4 多级流水线和多流并发 + +第三类问题不是某个单点算子慢,而是整个 pipeline 的分层协作方式不够适合生成式推荐。 +在当前设计里,这一节关注的是一类系统层优化思想:如何让 host、engine、worker、算子执行之间尽量重叠。 + +它至少包含 3 层含义: + +1. **Host 与 Device 的分工更清晰** + - host 尽量少做每轮控制 + - device 尽量多做连续多轮执行 + +2. **执行链路尽量流水化** + - 当前轮执行时,下一轮输入准备已经开始 + - scheduler、batch builder、worker 之间尽量减少停顿 + +3. **利用多流与多并发 pipeline 缓解等待** + - 固定窗口会带来等待代价 + - 但可以通过 `multi-stream` 或多套执行 pipeline 把这部分代价摊薄 + +因此,这一节讨论的可以理解为:把 fixed-step 场景下的调度稳定性、设备侧连续执行,以及多流并发能力组合起来的系统层设计。 + +### 4.5 为什么这条设计路线值得做 + +从生成式推荐服务的 workload 特征来看,把它单独当成一种特殊服务路径来设计调度、执行和算子,并不是“过度工程化”,而是有明确收益空间的。 + +原因在于,这个场景同时具备几种在通用 LLM 推理中不常同时出现的特征: + +- 输入长,但输出短; +- decode 轮数固定,但单轮成本高; +- `beam_width` 和候选池规模都不小; +- host 参与、batch 重排和数据搬运很容易放大整体时延。 + +所以,只要沿用通用推理路径,很多系统成本就会被保留下来;而一旦把 fixed-step、multi-step、shared/unshared KV、beam search 和多流执行放到一起设计,收益就会从多个层次同时出现:调度成本下降、host 参与变少、显存组织更合理、热路径更适合做定制算子优化。 + +所以在技术分享里,这一节可以作为一个很好的收口: +我们不是因为“fixed step 听起来合理”才这样做,而是因为生成式推荐在系统层面确实和通用 LLM 推理不一样,这样做能换来更高吞吐和更稳定的低时延表现。 + +## 5. 代码结构 + +当前分支里,生成式推荐相关代码可以按下面的结构理解: + +外部接入: +- `xllm/c_api/rec.h` +- `xllm/c_api/internal/rec.cpp` +- `xllm/c_api/examples/simple_rec_completions.cpp` + +服务入口: +- `xllm/api_service/rec_completion_service_impl.cpp` +- `xllm/api_service/chat_service_impl.cpp` +- `xllm/api_service/api_service.cpp` +- `xllm/api_service/api_service.h` + +调度与引擎: +- `xllm/core/distributed_runtime/rec_master.cpp` +- `xllm/core/distributed_runtime/rec_master.h` +- `xllm/core/scheduler/fixed_steps_scheduler.cpp` +- `xllm/core/scheduler/fixed_steps_scheduler.h` +- `xllm/core/distributed_runtime/rec_engine.cpp` +- `xllm/core/distributed_runtime/rec_engine.h` + +batch / request / proto: +- `xllm/core/framework/batch/rec_batch_input_builder.cpp` +- `xllm/core/framework/batch/rec_batch_input_builder.h` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.h` +- `xllm/core/framework/request/rec_type.h` +- `xllm/proto/rec.proto` +- `xllm/proto/completion.proto` +- `xllm/proto/xllm_service.proto` + +runtime / worker: +- `xllm/core/runtime/rec_worker_impl.cpp` +- `xllm/core/runtime/rec_worker_impl.h` + +kernel / 算子热路径: +- `xllm/core/layers/cuda/xattention.cpp` +- `xllm/core/layers/cuda/flashinfer_attention.cpp` +- `xllm/core/kernels/cuda/xattention/beam_search.cpp` +- `xllm/core/kernels/cuda/xattention/cache_select.cu` + +## 6. 当前分支的执行主链 + +为了把设计和实现真正对起来,可以把当前分支的主执行链拆成下面几步来理解: + +1. **外部接入** + - 如果走动态库方式,请求会从 `xllm/c_api/internal/rec.cpp` 中的 `xllm_rec_text_completions`、`xllm_rec_token_completions` 或 `xllm_rec_chat_completions` 进入。 + - 如果走服务方式,请求会从 `xllm/api_service/rec_completion_service_impl.cpp` 或 `chat_service_impl.cpp` 进入,再转到 `RecMaster`。 + +2. **请求进入 `RecMaster`** + - `RecMaster` 负责把 prompt、token ids、raw embedding 等不同入口统一收敛到 request 构造逻辑。 + - 在这里会根据模型类型区分 `kOneRec` 和 `kLlmRec`,并选择不同的 request pipeline。 + +3. **进入固定调度** + - `RecMaster` 在初始化时直接创建 `FixedStepsScheduler`。 + - 调度器不再按“每一步都动态重排 batch”的思路工作,而是优先围绕固定轮数和固定候选组去构造 batch。 + +4. **引擎执行** + - `RecEngine` 再根据 `RecPipelineType` 选择执行路径。 + - 对 `LlmRec` multi-round 场景,会下沉到 `RecMultiRoundEnginePipeline`,把多轮 decode 的主要控制逻辑继续往 worker 侧下压。 + +5. **batch 与输入拼装** + - `RecBatchInputBuilder` 和 `RecMultiRoundBatchInputBuilder` 负责把 sequence、step 信息、decode positions、sampling params 等整理成 `ForwardInput`。 + - 这里的 `step_meta` 是 multi-step 执行的关键数据来源,它决定后续每一轮 decode 该如何构造位置、cache 和 beam 相关输入。 + +6. **worker 侧多轮执行** + - `RecWorkerImpl::LlmRecMultiRoundPipeline::step()` 会在设备侧循环多轮。 + - 它会先准备 beam search tensor、full/unshared KV 相关结构,再在每一轮中执行: + - 当前轮输入准备 + - 模型 forward + - sample 输出处理 + - beam search + - cache select + - 下一轮输入预计算 + +7. **算子热路径** + - Attention 相关路径落在 `xattention.cpp` 与 `flashinfer_attention.cpp` + - beam 相关路径落在 `beam_search.cpp` + - beam 重排后的 cache 选择路径落在 `cache_select.cu` + +如果从技术分享视角来讲,这条主链非常适合作为“架构总图”之后的第一条展开线,因为它把“固定调度、整图执行、定制算子”三件事串成了一个具体执行过程,而不是三个割裂的优化点。 + +## 7. 设计取舍与适用边界 + +这套设计并不意味着固定调度一定优于连续调度,也不意味着 multi-step pipeline 适合所有生成任务。它成立的前提,是当前生成式推荐场景具有下面几个特征: + +- decode 轮数较固定,通常不会像开放式文本生成那样提前结束; +- 同一请求下存在较大的 `beam_width`,而且多个 beam 需要同步比较; +- 整次请求的总时延,比逐 token 的交互体验更重要; +- 设备侧状态(KV Cache、positions、beam tensors)可以提前组织并稳定复用。 + +在这些前提成立时,固定调度和整图执行的收益会比较明显。但它也有明确边界: + +### 7.1 固定调度的边界 + +- 如果请求长度差异极大,而且大量请求会提前结束,那么连续调度的灵活性会更有价值; +- 如果业务更关心“新请求能不能立刻插入”,而不是“当前窗口吞吐是否最优”,固定调度会天然吃亏; +- 如果候选扩展不依赖大规模 beam 同步推进,那么固定窗口的收益会下降。 + +### 7.2 multi-step pipeline 的边界 + +- 如果每一步都必须回到 host 做强控制决策,那么 multi-step pipeline 的优势会被削弱; +- 如果 shape、batch 或关键输入在每一轮都大幅波动,那么想要把多轮执行稳定下来会更难; +- 如果后端算子本身还不支持稳定的多轮设备侧推进,那么整图执行只会停留在概念层。 + +### 7.3 定制算子的边界 + +`xAttention` 和 `beam search` 定制算子之所以值得做,是因为当前执行形态已经足够稳定。如果没有 fixed-step 带来的稳定 batch 形态,也没有 multi-step pipeline 带来的稳定多轮推进,那么很多定制优化都会被反复的数据搬运、batch 重组和 host 参与开销抵消掉。 + +因此,更合理的理解顺序不是“先有定制算子,再决定调度”,而是: + +1. 先确认 workload 适合固定调度; +2. 再确认多轮执行可以尽可能下沉到设备侧; +3. 最后再围绕真正稳定下来的热路径做算子定制。 + +这样才能让 `fixed_steps_scheduler`、`multi_step_pipeline`、`xAttention` 和 `beam search` 四者形成一套前后自洽的设计,而不是彼此孤立的优化点。 + +## 8. 代码路径附录 + +这一节的目的,不是重复“代码结构”里的文件清单,而是给读者一个更可执行的阅读顺序:如果后续要继续做技术分享、走读代码,或者排查 `backend=rec` 路径上的行为差异,可以直接按下面的顺序进入。 + +### 8.1 从外部入口开始看 + +如果要理解 `predictor` 或动态库接入是怎样进入 xLLM 的,建议先看下面几处: + +- `xllm/c_api/rec.h` + - 对外暴露 `xllm_rec_create`、`xllm_rec_initialize`、`xllm_rec_text_completions`、`xllm_rec_token_completions`、`xllm_rec_chat_completions` + - 适合先理解“外部系统到底能怎么调用 REC 能力” +- `xllm/c_api/internal/rec.cpp` + - 这是真正的 CAPI 实现 + - 适合看 `.so` 模式下,request 参数是怎样被封装和转发的 +- `xllm/c_api/examples/simple_rec_completions.cpp` + - 是最短的调用示例 + - 如果要给新人解释“动态库接入长什么样”,这里最直观 + +如果技术分享里想给一段“最小调用样例”,这一层最适合出现在开头。 + +### 8.2 从服务入口看统一分发 + +如果你更关心 RPC 或统一服务链路,可以继续看: + +- `xllm/api_service/api_service.cpp` + - 这里会根据 `FLAGS_backend` 决定挂哪类 service impl + - `backend == "rec"` 时,`rec_completion_service_impl_` 和 `chat_service_impl_` 都会被接上 +- `xllm/api_service/rec_completion_service_impl.cpp` + - 负责把 rec completion 请求转给 `RecMaster` + - `routing`、`input_tensors`、`RequestParams` 都是在这里被整理进来 +- `xllm/api_service/chat_service_impl.cpp` + - 对 `RecMaster` 也有 chat 入口 + - 适合说明“REC 不是只有 token completion,一样可以走 chat 形态” + +这一层适合在技术分享中回答一个问题:为什么说 `backend=rec` 不是另起炉灶,而是接进了现有服务框架。 + +### 8.3 从调度和引擎看主链 + +如果分享的重点是“为什么 fixed step 更适合 rec”,阅读顺序建议是: + +1. `xllm/core/distributed_runtime/rec_master.h` +2. `xllm/core/distributed_runtime/rec_master.cpp` +3. `xllm/core/scheduler/fixed_steps_scheduler.h` +4. `xllm/core/scheduler/fixed_steps_scheduler.cpp` +5. `xllm/core/distributed_runtime/rec_engine.h` +6. `xllm/core/distributed_runtime/rec_engine.cpp` + +可以按下面这条链理解: + +```text +Rec request + -> RecMaster + -> FixedStepsScheduler + -> RecEngine + -> RecEnginePipeline + -> Worker / worker_clients +``` + +这里最值得讲的几个点是: + +- `RecMaster` 负责入口收敛和 pipeline 选择 +- `FixedStepsScheduler` 负责把 request 组织成适合固定轮数推进的 batch +- `RecEngine` 负责把调度结果交给实际执行路径 +- `RecMultiRoundEnginePipeline` 代表“多轮 decode 控制进一步下沉”的实现方式 + +如果分享时要强调“fixed step 不是口头概念,而是代码主链的真实选择”,这一层就是核心证据。 + +### 8.4 从 batch builder 看 multi-step 的输入组织 + +如果要解释 `multi_step_pipeline` 为什么能成立,单看 scheduler 还不够,必须继续看 batch builder: + +- `xllm/core/framework/batch/rec_batch_input_builder.h` +- `xllm/core/framework/batch/rec_batch_input_builder.cpp` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.h` +- `xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp` +- `xllm/core/framework/batch/batch.cpp` + +这里最重要的不是“类名”,而是几个关键事实: + +- `RecBatchInputBuilder::create(...)` 会按 `RecType` 和 multi-round 模式选择 builder +- `RecMultiRoundBatchInputBuilder` 不是普通 builder 的轻微变种,而是专门为多轮 decode 组织输入的实现 +- `step_meta`、`decode_positions`、`sampling params`、`batch forward type` 等信息是在这一层被拼好并送往后续 runtime 的 + +所以,如果要说明“为什么第一步就能把后面几步的输入准备好”,这一层比只讲 engine 更关键。 + +### 8.5 从 worker 看 device 侧多轮执行 + +`multi_step_pipeline` 真正最值得展开的代码在: + +- `xllm/core/runtime/rec_worker_impl.h` +- `xllm/core/runtime/rec_worker_impl.cpp` + +尤其是下面这些点: + +- `RecWorkerImpl::step_async(...)` + - 说明请求是如何进到 worker 内部并在特定 stream 上执行的 +- `RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_inputs(...)` + - 说明 multi-round 输入如何进入 runtime +- `RecWorkerImpl::LlmRecMultiRoundPipeline::allocate_kv_caches_related()` + - 说明为什么 fixed-step 场景更适合提前分配 KV 相关结构 +- `RecWorkerImpl::LlmRecMultiRoundPipeline::step(...)` + - 这是最核心的一段 + - 明确展示了多轮循环、beam search、cache select、下一轮输入预计算之间的关系 +- `compute_next_round_input_async(...)` + - 这是解释“为什么可以减少 host 往返”的关键点 + +如果技术分享想从“执行效率”而不是“调度策略”切入,这一层是最值得重点展开的。 + +### 8.6 从 kernel 热路径看为什么定制算子值得做 + +如果要讲 `xAttention` 和 `beam search` 定制算子,推荐按下面的顺序进: + +- `xllm/core/layers/cuda/xattention.cpp` +- `xllm/core/layers/cuda/flashinfer_attention.cpp` +- `xllm/core/kernels/cuda/xattention/xattention_ops_api.h` +- `xllm/core/kernels/cuda/xattention/beam_search.cpp` +- `xllm/core/kernels/cuda/xattention/cache_select.cu` + +这一层适合回答 3 个问题: + +1. Attention 的稳定执行路径到底落在哪 +2. Beam Search 为什么不只是调度问题,而是 kernel 热路径问题 +3. beam 重排之后的 cache select 为什么必须和前面的执行形态一起考虑 + +也就是说,技术分享如果要把 `fixed_steps_scheduler`、`multi_step_pipeline`、`xAttention` 和 `beam search` 串成一条线,最终一定会落到这里。 + +### 8.7 推荐的代码阅读顺序 + +如果后续你或者其他同学还要继续扩写这篇文档,我建议代码阅读顺序固定成下面这样: + +```text +入口 + -> api_service + -> RecMaster + -> FixedStepsScheduler + -> RecEngine + -> RecBatchInputBuilder / RecMultiRoundBatchInputBuilder + -> RecWorkerImpl::LlmRecMultiRoundPipeline + -> xAttention / beam_search / cache_select +``` + +这个顺序的好处是: + +- 先看“请求怎么进来” +- 再看“为什么 fixed step” +- 再看“multi-step 是怎么在设备侧成立的” +- 最后看“定制算子为什么在这里有价值” + +这样逻辑最顺,也最适合技术分享展开。 + +## 9. 关键代码锚点索引 + +如果后续需要继续写技术分享、补代码注释,或者在评审时快速证明文档里的说法来自当前分支实现,可以直接从下面这组锚点入手。 + +### 9.1 入口与服务层锚点 + +- `xllm/core/distributed_runtime/rec_master.cpp:575` + - `RecMaster::handle_request(...)` + - 对应 prompt / prompt_tokens / input_tensors 入口 +- `xllm/core/distributed_runtime/rec_master.cpp:603` + - `RecMaster::handle_request(...)` + - 对应 chat messages 入口 +- `xllm/core/distributed_runtime/rec_master.cpp:651` + - `RecMaster::handle_request(const std::vector& prompt_tokens, ...)` + - 对应 token / raw input 类入口 + +这些锚点适合回答一个问题:REC 请求到底是怎么被收敛到 `RecMaster` 里的。 + +### 9.2 调度层锚点 + +- `xllm/core/scheduler/fixed_steps_scheduler.cpp:337` + - `FixedStepsScheduler::step(const absl::Duration& timeout)` + - 这是 fixed-step 调度真正往前推进的一步 +- `xllm/core/scheduler/fixed_steps_scheduler.cpp:186` + - `FixedStepsScheduler::prepare_batch()` + - 适合解释当前 batch 是怎么在固定步场景下被组织的 +- `xllm/core/framework/batch/rec_batch_input_builder.cpp:29` + - `RecBatchInputBuilder::create(...)` + - 适合解释 builder 是如何按 `RecType` 和 multi-round 模式切换的 + +这几处放在一起,可以直接支持“fixed_steps_scheduler 不是概念,而是代码主链里的真实选择”这一点。 + +### 9.3 引擎与多轮执行锚点 + +- `xllm/core/distributed_runtime/rec_engine.cpp:901` + - `RecEngine::RecMultiRoundEnginePipeline::step(...)` + - 说明 engine 层如何把执行进一步下沉到 multi-round pipeline +- `xllm/core/runtime/rec_worker_impl.cpp:849` + - `RecWorkerImpl::LlmRecMultiRoundPipeline::step(...)` + - 这是最关键的一段,真正体现多轮 decode 循环在设备侧发生 +- `xllm/core/runtime/rec_worker_impl.cpp:1011` + - `xllm::kernel::cuda::beam_search(...)` 调用点 +- `xllm/core/runtime/rec_worker_impl.cpp:1066` + - `xllm::kernel::cuda::cache_select(...)` 调用点 + +如果你需要在分享中明确“beam search 和 cache select 不是调度层概念,而是直接落到 device 热路径上的实现”,这一组锚点最适合引用。 + +### 9.4 这组锚点怎么用 + +这组锚点不需要在正文中全部展开,但非常适合作为: + +- 技术分享讲稿里的“代码证据页” +- PR 说明里的“关键实现位置” +- 后续 reviewer 问“这句话代码在哪”的快速答复 + +如果后续还要继续扩写文档,建议优先围绕这几处函数补更细的输入输出说明,而不是再扩展泛化描述。 + +## 10. 推荐讲解顺序与阅读策略 + +前面的“代码路径附录”和“关键代码锚点索引”更像是资料库,适合在需要时回查;但如果这篇文档要真正变成一场技术分享的底稿,还需要一条更偏“讲述顺序”的线索。下面给出一个更适合对外讲的展开顺序。 + +### 10.1 推荐的讲解顺序 + +如果听众并不直接参与 `backend=rec` 的实现,建议不要一上来就讲 `RecWorkerImpl` 或 `beam_search.cpp`,而是按下面的顺序推进: + +1. **先讲业务目标** + - 为什么生成式推荐和通用 LLM 推理的目标不一样 + - 为什么这里更关注“整次请求的总时延”和“固定几轮的候选比较” + +2. **再讲调度选择** + - 为什么 `fixed_steps_scheduler` 更适合这个 workload + - 为什么大 `beam_width` 会让固定调度比连续调度更划算 + +3. **然后讲执行方式** + - `multi_step_pipeline` 为什么可以把多轮 decode 控制下沉到设备侧 + - 为什么这样可以减少 host 往返和控制开销 + +4. **再讲定制算子** + - 为什么执行形态稳定之后,`xAttention` 和 `beam search` 优化才真正值得做 + - 为什么这两类优化要放在一节里讲,而不是拆成孤立话题 + +5. **最后回到代码** + - 再用代码主链和锚点证明前面的结论,不让整场分享停留在概念层 + +这个顺序的好处是:听众先理解“为什么”,再理解“怎么做”,最后再看“代码在哪里”。这比一开始就从实现文件名切入更容易跟上。 + +### 10.2 如果是内部走读,顺序可以反过来 + +如果听众本身就是 xLLM 或推荐基础设施相关同学,那么也可以采用另一条顺序: + +1. 先看 `RecMaster -> FixedStepsScheduler -> RecEngine` +2. 再看 `RecBatchInputBuilder` +3. 再看 `RecWorkerImpl::LlmRecMultiRoundPipeline` +4. 最后看 `xAttention / beam_search / cache_select` + +这种讲法更适合代码走读,因为它直接沿着调用栈往下走。但它的缺点是,对没有上下文的听众来说,一开始就会掉进实现细节里,不容易先抓住设计选择背后的动机。 + +### 10.3 建议在分享中强调的 3 个结论 + +如果后续要把这篇文档压缩成 10 分钟技术分享,我建议把整篇内容收束成下面 3 句话: + +- `fixed_steps_scheduler` 解决的是调度稳定性问题; +- `multi_step_pipeline` 解决的是多轮执行效率问题; +- `xAttention` 和 `beam search` 定制算子是在执行形态稳定之后,进一步兑现性能收益的关键路径优化。 + +这三句话是整篇文档最值得被记住的部分,后面的代码链路和锚点都可以理解为是在为这三句结论提供证据。 + +## 11. 与通用 LLM 推理路径的对照 + +为了避免把 `backend=rec` 看成“只是把 LLM 推理拿来改一改”,这里把它和通用 LLM 推理的主路径做一个更明确的对照。 + +### 11.1 优化目标不同 + +通用 LLM 推理更强调逐 token 的生成体验,例如尽快返回第一个结果、尽量缩短 token 间隔,并让新的请求尽快插入执行。 +而 `backend=rec` 更强调固定轮数内的候选扩展与比较,因此更关注整次请求的总时延,而不是单条序列的最早结束时间。 + +这意味着两者虽然都在做 decode,但优化目标已经发生了偏移: + +- 通用 LLM 推理更偏“动态请求管理” +- 生成式推荐更偏“固定窗口内的同步推进” + +### 11.2 调度重点不同 + +在通用 LLM 推理里,连续调度的价值来自于: + +- 某些序列提前结束后可以立刻让位; +- batch 可以被持续重组; +- decode 路径中动态插入和动态退出是高频动作。 + +而在 `backend=rec` 里: + +- decode 轮数更固定; +- `beam_width` 更大; +- 同一个 request 的多个候选需要在同一步比较; +- 频繁重排 batch 反而会放大调度成本。 + +所以,通用 LLM 推理的调度器重点是“灵活”,而生成式推荐的调度器重点是“稳定”。 + +### 11.3 执行方式不同 + +通用 LLM 推理常常允许 host 在每一步都继续参与调度、结束判断和下一轮输入准备。 +而生成式推荐因为轮数固定、候选同步推进,所以更适合在第一步时就把后续几步所需的结构准备好,再让 device 侧连续推进。 + +这也是为什么 `multi_step_pipeline` 对推荐场景更有价值:它不是简单减少几个 memcpy,而是把原本“每一步都回来问一下 host”的控制模式,替换成“尽量在设备侧连续完成多轮”的模式。 + +### 11.4 算子收益兑现方式不同 + +通用 LLM 推理里,算子优化往往直接服务于单步 decode 或通用 attention 路径。 +而在 `backend=rec` 中,算子优化的价值高度依赖执行形态是否已经稳定下来。 + +如果固定调度没有成立,batch 还在频繁重组;如果 multi-step pipeline 没有成立,host 还在反复介入,那么很多算子级优化都会被额外的数据搬运和控制开销抵消掉。 +因此,在推荐场景里更合理的顺序是: + +1. 先把调度稳定下来; +2. 再把多轮执行尽量下沉到设备侧; +3. 最后才围绕真正稳定的热路径做 `xAttention`、`beam search`、`cache_select` 这类优化。 + +### 11.5 设计边界与适用条件 + +上面的对照并不是为了强调“生成式推荐和通用 LLM 推理完全不同”,而是为了明确哪些设计可以直接复用通用推理框架,哪些设计必须围绕 `backend=rec` 的 workload 特征单独处理。 + +从当前分析看,`fixed_steps_scheduler`、`multi_step_pipeline`、`xAttention` 和 `beam search` 这四部分之所以需要被放在一起讨论,本质原因是它们依赖同一组前提条件: + +- 输出长度固定或近似固定; +- decode 轮数较少,但单轮计算成本高; +- 同一请求下需要维护较大的 `beam_width`; +- 候选扩展和候选比较是主路径,而不是附属逻辑; +- 设备侧状态可以提前组织,并在后续轮次中稳定复用。 + +只有在这些条件同时成立时,固定调度和整图执行的收益才会比较稳定。如果离开这些条件,例如请求长度差异极大、序列可能随时提前结束、或者新请求接入时延比吞吐更重要,那么 `backend=rec` 当前这套设计就不一定仍然是最优选择。 + +进一步看,这四部分之间并不是并列关系,而是逐层依赖关系。 + +第一层是 `fixed_steps_scheduler`。它解决的是调度入口的稳定性问题。如果调度层仍然每一步都在重组 batch、裁剪 sequence、回收空位,那么后面执行层的稳定性基础就不存在。 + +第二层是 `multi_step_pipeline`。它解决的是执行过程的连续性问题。只有在调度窗口已经稳定之后,才有意义进一步把多轮 decode 的输入准备、KV 组织和轮次推进尽量下沉到 device 侧。 + +第三层是 `xAttention`。它解决的是 attention 在生成式推荐场景下的显存与访存问题。只有在执行过程已经相对稳定之后,shared KV 与 unshared KV 的拆分、分段 attention 和流水线计算才会稳定兑现收益。 + +第四层是 `beam search` 优化。它解决的是大候选池下的搜索与筛选成本问题。由于 beam 相关逻辑会直接影响 decode 阶段的主路径,因此它必须和前面三层设计一起考虑,而不能独立看成一个后处理模块。 + +因此,这里真正想说明的是:当前分支中的这套设计不是四个离散优化点,而是一套围绕生成式推荐 workload 逐层收敛出来的组合方案。调度先稳定,执行再连续,算子最后吃收益。如果顺序反过来,很多优化很可能在系统层被额外开销抵消。 + +### 11.6 不适用或收益有限的场景 + +为了避免把当前方案理解成“生成式推荐的统一答案”,还需要明确它的边界。 + +第一类是不适合固定调度的场景。比如请求长度波动极大、序列经常提前结束、或者系统更关心新请求的最短接入延迟而不是整批吞吐。在这种情况下,连续调度的灵活性可能更重要,固定窗口反而会放大等待成本。 + +第二类是不适合整图执行的场景。比如每一轮都必须由 host 做强控制决策,或者每一轮的 shape、batch 结构、关键输入变化都非常剧烈。这类场景下,`multi_step_pipeline` 很难把多轮控制稳定地下沉到 device 侧,整图执行也就不容易获得稳定收益。 + +第三类是算子侧收益不明显的场景。如果 beam 不大、候选扩展成本不高、或者 shared prefix 不长,那么 `xAttention` 和 `beam search` 相关优化虽然依然成立,但其工程收益可能不会像典型生成式推荐场景那样突出。 + +从技术文档角度,把这些边界写清楚有两个好处:一是避免读者误以为这套方案无条件普适;二是让后续 reviewer 在看实现时能更明确地区分“设计前提”与“实现缺陷”。 + +## 12. 验证与验收建议 + +这一节的验证不应只停留在“文档能不能打开”,而应该和前文所描述的设计目标直接对应。换句话说,验证要回答的问题是:当前分支中的实现,是否真的支持文档里所描述的这套设计。 + +### 12.1 文档层验证 + +最基础的验证仍然需要保留: + +- 中文与英文设计文档能够正常渲染; +- 所有图片引用都能正确解析到 `docs/assets/`; +- 英文文档引用的是英文版示意图,而不是中文标注图片; +- 入口页和相关 feature 文档能够导航到这篇设计文档。 + +这一层验证的意义,是保证文档本身可以被正常消费。 + +### 12.2 代码路径一致性验证 + +第二层验证,是确认文档中提到的关键路径在当前分支中确实存在,并且角色描述与实际实现一致。这里建议至少对以下路径做静态核对: + +- `RecMaster` 是否确实承担 request 收敛与 pipeline 选择; +- `FixedStepsScheduler` 是否确实是 `backend=rec` 的固定调度入口; +- `RecEngine` 是否承担 engine 层执行组织; +- `RecWorkerImpl::LlmRecMultiRoundPipeline` 是否确实承担多轮 device 侧执行; +- `xAttention`、`beam_search`、`cache_select` 的 kernel 路径是否与文档描述一致。 + +这一层验证的目的,是避免文档讲的是“理想结构”,而代码里其实已经偏离。 + +### 12.3 设计目标对应的验证项 + +如果要进一步把这篇文档作为技术分享和实现说明的统一底稿,建议把验证项和设计目标一一对应起来。 + +针对固定调度,建议关注: + +- decode 过程中 batch 是否仍频繁重组; +- beam 相关 sequence 是否仍在每轮发生大量裁剪和搬移; +- 调度开销是否在 profile 中被压低。 + +针对 `multi_step_pipeline`,建议关注: + +- host 是否仍在每一轮都做终止判断; +- 下一轮输入是否已经可以在当前轮计算期间准备; +- `D2H/H2D` 往返是否相比逐轮控制路径有所下降。 + +针对 `xAttention`,建议关注: + +- shared KV 是否真的只保留一份物理存储; +- unshared KV 是否只承载 decode 产生的新 token; +- 是否减少了 shared prefix 的重复加载。 + +针对 `beam search`,建议关注: + +- 大 `beam_width` 场景下排序和过滤是否仍然占主成本; +- 相关数据结构是否可以跨轮复用; +- item 过滤和候选筛选是否已经与 decode 主链有效衔接。 + +### 12.4 最终验收标准 + +如果把这篇文档作为当前分支生成式推荐设计的验收基线,那么最终至少应满足以下标准: + +- 文档层面:中英文可读、引用完整、导航可达; +- 结构层面:调度、执行、算子三层关系讲清楚; +- 代码层面:关键路径、关键函数、关键文件能对上; +- 设计层面:适用条件、边界条件和主要取舍写清楚; +- 分享层面:读者能够从文档中直接抽出讲稿主线,而不需要再次重构逻辑。 + +只有同时满足这几类条件,这篇文档才不只是“说明写过了”,而是真正能作为设计、分享、review 三者共用的稳定底稿。 + +## 13. FAQ / 常见误解 + +### Q1: 固定步调度是不是一定比连续调度更好? + +不是。固定步调度之所以在 `backend=rec` 场景更合适,是因为这里的 decode 轮数更固定、`beam_width` 更大、候选比较更同步。如果这些前提不成立,连续调度的灵活性可能更有价值。 + +### Q2: `multi_step_pipeline` 是不是只是少做几次 memcpy? + +不是。减少 `D2H/H2D` 往返只是表面收益。更关键的是,它改变了控制方式:从“每一步都依赖 host 回来做决策”,变成“尽量在 device 侧连续推进多轮执行”。 + +### Q3: `xAttention` 和 `beam search` 为什么要放在同一节? + +因为它们都依赖同一件事:执行形态已经稳定。如果调度还在频繁重组 batch、host 还在每一步强介入,那么不管是 attention 优化还是 beam search 优化,都会被系统层额外开销稀释。 + +### Q4: 这套设计是不是只适用于 OneRec? + +不是。当前文档以生成式推荐场景为主线,重点落在 `backend=rec` 的固定步和多轮执行特征上。OneRec 是重要例子,但重点不是某个具体模型名,而是这类 workload 的服务特征。 + +### Q5: 为什么文档里反复强调“先稳定调度,再优化算子”? + +因为在系统设计里,顺序很重要。如果调度入口不稳定,batch 形态不稳定,执行过程也不能稳定地下沉到设备侧,那么算子级优化很难持续兑现收益。只有调度和执行先稳定,算子优化才不会被反复的数据搬运和控制开销抵消。 + +## 14. 接口与数据契约 + +如果这篇文档要继续作为实现和评审的基础材料,仅靠“原理解释”还不够,还需要把一些关键接口和数据契约明确写出来。这里不追求覆盖所有字段,而是把最影响实现正确性的部分固定下来。 + +### 14.1 外部入口的输入形态 + +当前 `backend=rec` 至少存在三类外部输入形态: + +- **文本 prompt 入口** + - 适用于 prompt 驱动的推荐请求 + - 最终会被收敛到 `RecMaster::handle_request(...)` +- **token / raw 输入入口** + - 适用于已经完成 tokenizer 处理或外部直接传 token 序列的场景 + - 也可能带有 embedding 或索引信息 +- **chat 入口** + - 适用于对话式 recommendation 场景 + - 在服务层最终仍然会落回 `RecMaster` + +从设计角度看,这里最重要的并不是“有多少入口”,而是:所有入口最终都要收敛到统一 request state 和统一 scheduler / engine / worker 主链中。 + +### 14.2 与 multi-round 直接相关的关键数据 + +对 `backend=rec` 来说,下面几类数据是关键契约: + +- `beam_width` + - 决定每轮保留多少候选分支 +- `top_k` + - 决定每轮每个分支扩展多少候选 +- `decode_step` / `total_round` + - 决定固定轮数和整图执行边界 +- `decode_positions` + - 决定每一轮 token 的位置组织方式 +- shared / unshared KV + - 决定 decode 阶段 KV cache 如何拆分和复用 + +如果这些数据没有被提前组织好,那么 fixed-step 和 multi-step 的收益就无法稳定落地。 + +### 14.3 文档层应该长期保持的不变量 + +后续如果继续扩写这篇文档,我建议至少保持下面几条不变量: + +- 文档里一旦提到 fixed-step,就必须同时交代 fixed-step 的边界条件; +- 文档里一旦提到 multi-step,就必须说明 host 侧减少了哪些职责; +- 文档里一旦提到 `xAttention`,就必须解释它和 KV 组织之间的关系; +- 文档里一旦提到 `beam search`,就不能只讲算法,还要讲它的系统成本; +- 文档里的代码锚点必须能在当前分支对上,不能引用未来规划路径。 + +## 15. 实现对齐表 + +为了避免读者把“设计意图”和“当前分支已经落地的能力”混在一起,这里给出一个简化的实现对齐表。它不是精确到每个 helper 的状态表,而是当前文档里最重要的设计项对齐结果。 + +### 15.1 已有实现 + +- **固定调度** + - 当前分支已有 `FixedStepsScheduler` + - 可以作为固定窗口调度的代码事实依据 +- **multi-round 执行主链** + - 当前分支已有 `RecMultiRoundEnginePipeline` + - worker 侧已有 `LlmRecMultiRoundPipeline` +- **beam search 热路径** + - 当前分支已有 beam search 相关调用与 kernel 路径 +- **cache select 热路径** + - 当前分支已有 cache select 调用和对应 CUDA 路径 +- **xAttention / flashinfer 路径** + - 当前分支已有对应实现文件和调用链 + +### 15.2 当前文档里仍然属于“设计抽象”的部分 + +- `xAttention`、beam search、fixed-step、multi-step 之间的收益关系 + - 当前文档做了归纳 + - 但它们不是某个单一类或单一模块的名字 +- 多流缓解等待的整体收益 + - 当前代码里有并发 pipeline 与 stream 的基础设施 + - 但文档里的收益分析仍属于系统层抽象,不是单个类能直接表达的事实 + +### 15.3 为什么实现对齐表值得单独写 + +这一节的作用,是避免 reviewer 或听众把“文档里讲得通”误以为“当前分支里一定已经逐项完全实现”。 +实现对齐表能让后续讨论更聚焦:哪些内容是已经可以拿代码证明的,哪些内容是当前设计的系统层解释。 + +## 16. 失败模式与可观测性 + +如果这篇文档只讲“正常路径”,那它仍然更像分享稿,而不是技术文档。在线系统一定要考虑失败模式与观测手段,因此这里补一个最小版本。 + +### 16.1 典型失败模式 + +- **等待时间变长** + - fixed-step 窗口导致新请求排队时间增加 +- **beam 过大导致 decode 成本过高** + - 候选扩展、排序与过滤成本放大 +- **shared/unshared KV 组织不正确** + - 导致显存浪费、重复读取或错误的 cache 选择 +- **host 参与过多** + - 使 multi-step 的收益被冲掉 +- **算子热路径收益不明显** + - 说明调度和执行形态还不够稳定 + +### 16.2 最值得观察的指标 + +即便不展开完整监控系统,这篇文档也至少应该把观察点写清楚: + +- 请求总时延 +- 不同 decode round 的耗时分布 +- scheduler 相关开销占比 +- host-device 往返次数或同步点 +- beam search / cache select 在总耗时中的占比 +- 显存占用和碎片化趋势 + +### 16.3 排障时最短路径 + +如果线上效果和预期不一致,建议优先按下面顺序排: + +1. 先确认请求是不是确实符合 fixed-step 场景假设; +2. 再确认 decode 主链是否真的走到了 multi-round pipeline; +3. 再看 beam search 和 cache select 是否成为主成本; +4. 最后再看 xAttention / KV 组织是否真的减少了重复加载和拷贝。 + +## 17. 基准与验收协议 + +为了让这篇文档后续更容易被用作评审和验收依据,这里补一个轻量的基准与验收协议模板。 + +### 17.1 建议至少覆盖的 workload + +- **长输入 / 短输出** + - 典型生成式推荐主场景 +- **不同 beam_width** + - 例如中等 beam 和大 beam 两组 +- **不同 top_k** + - 用于观察搜索和过滤成本的放大趋势 + +### 17.2 建议至少对比的指标 + +- 吞吐 +- P50 / P95 / P99 时延 +- 不同 decode round 的耗时 +- scheduler 开销占比 +- host 参与开销 +- 显存峰值 + +### 17.3 验收关注点 + +- fixed-step 是否真的降低了调度扰动 +- multi-step 是否真的降低了 host 回合控制开销 +- `xAttention` 是否真的改善了 shared prefix 的重复访问 +- beam 相关优化是否真的降低了搜索与过滤开销 + +这部分内容的价值在于:它能把“这套设计听起来合理”进一步收敛为“这套设计应该如何被验证”。 + +## 18. 关键参数与建议范围 + +为了避免这篇文档停留在抽象讨论层面,最后补一个轻量但实用的参数章节。这里不追求给出唯一正确的数值,而是整理出后续继续调优时最值得关注的参数范围和影响方向。 + +### 18.1 `beam_width` + +`beam_width` 决定每一轮保留的候选分支数,是生成式推荐 decode 成本最敏感的参数之一。 + +- `beam_width` 偏小: + - 搜索成本低 + - 候选覆盖较弱 + - 更容易得到较低时延 +- `beam_width` 偏大: + - 候选覆盖更强 + - 排序、过滤、cache select 成本会明显上升 + - fixed-step 和 multi-step 的收益会更突出,但同时更容易放大 decode 侧主成本 + +因此,`beam_width` 不是简单的“越大越好”或“越小越快”,而是推荐质量、时延和系统开销之间的核心平衡点。 + +### 18.2 `top_k` + +`top_k` 决定每个 beam 在一轮中扩展多少候选。 +在大候选池场景下,`top_k` 增长会直接放大: + +- 排序工作量 +- 无效 item 路径过滤工作量 +- 候选集合的中间态管理成本 + +如果 `top_k` 设计得过大,而 item 过滤又不够早,那么会导致很多计算都浪费在最终不会进入下一轮的候选上。 + +### 18.3 decode 轮数 + +decode 轮数越固定,fixed-step 调度和 multi-step 执行就越容易稳定发挥作用。 +如果 decode 轮数稳定在较小范围内,那么: + +- 设备侧更容易提前分配相关结构; +- `step_meta` 更容易保持稳定; +- 整图执行的收益更容易兑现。 + +反过来,如果 decode 轮数本身开始变得不稳定,那么 fixed-step 和 multi-step 的设计前提就会被削弱。 + +### 18.4 多流并发数量 + +多流不是越多越好。流数量过少,不能有效缓解等待;流数量过多,又可能带来新的调度竞争、资源抢占和额外管理开销。 + +因此,比较合理的调优方式通常是: + +1. 先确认 fixed-step 是否已经稳定; +2. 再评估 host 侧等待和排队问题是否明显; +3. 最后再增加多流或多执行 pipeline 数量。 + +### 18.5 参数调优的顺序建议 + +如果后续要继续做性能调优,建议按下面顺序看参数: + +1. 先看 decode 轮数是否稳定; +2. 再看 `beam_width` 是否合理; +3. 再看 `top_k` 是否过大; +4. 最后再看多流并发和系统层 overlap 是否足够。 + +这个顺序的原因是:前两个参数决定的是 workload 形状,后两个参数更多是在既定 workload 上做系统优化。 + +## 19. 参考资料 + +这篇文档尽量立足当前分支实现,但一些背景问题、调度思路和 memory / graph 执行语境,本身就来自公开资料和已有模型实践。后续如果要继续扩写,可以优先参考下面这些公开资料。 + +### 19.1 调度与服务系统 + +- Orca: A Distributed Serving System for Transformer-Based Generative Models + - 用于理解 continuous batching 的问题背景 + - 适合作为 fixed-step 对照材料 + +### 19.2 显存与 KV Cache + +- Efficient Memory Management for Large Language Model Serving with PagedAttention + - 用于理解 block-based KV 管理的常见问题 + - 适合作为 shared/unshared KV 设计的对照背景 + +### 19.3 生成式推荐模型 + +- OneRec + - 用于理解生成式召回模型的结构特点 +- OneTrans + - 用于理解生成式精排模型的结构特点 + +### 19.4 图执行与多步执行 + +- 仓库内已有 `Graph Mode` 设计文档 + - 适合理解 graph capture / replay、参数化和 Piecewise Graph +- 当前生成式推荐设计文档 + - 更适合理解 fixed-step、多轮 decode、beam search 和 xAttention 在推荐场景里的组合关系 + +### 19.5 使用建议 + +如果后续还要继续扩写这篇文档,我建议: + +- 需要写调度背景时,先看 Orca; +- 需要写 KV / memory 组织问题时,先看 PagedAttention; +- 需要写推荐模型结构时,先回到 OneRec / OneTrans; +- 需要写当前分支落点时,再回到本文档和代码锚点。 + +这样能保证文档继续变厚时,内容来源清楚,而且不会变成无边界的泛化解释。 diff --git a/docs/zh/design/graph_mode_design.md b/docs/zh/design/graph_mode_design.md index d72b31683..f5d81f58a 100644 --- a/docs/zh/design/graph_mode_design.md +++ b/docs/zh/design/graph_mode_design.md @@ -24,6 +24,10 @@ xLLM 的 Graph Mode 覆盖多种图执行后端。其目标是在推理服务场 - 不覆盖所有算子或所有模型的适配细节 - 不替代功能文档中的参数说明与使用示例 +相关设计文档: + +- 若希望看一个更偏业务推理场景、并且聚焦固定调度、多步执行和定制算子的案例,可参考:[生成式推荐设计文档](generative_recommendation_design.md) + ## 1. Graph Mode 原理和在 xLLM 中的落地 ### 1.1 Graph Capture / Replay 的基本原理 diff --git a/docs/zh/dev_guide/tilelang_ascend_kernel_dev.md b/docs/zh/dev_guide/tilelang_ascend_kernel_dev.md new file mode 100644 index 000000000..b82809462 --- /dev/null +++ b/docs/zh/dev_guide/tilelang_ascend_kernel_dev.md @@ -0,0 +1,396 @@ +# xLLM Ascend TileLang Kernel 开发指南 + +本文说明在 xLLM 中新增或修改 Ascend TileLang kernel 的开发方式。示例全程使用当前的 `rope` kernel。 + +相关目录: + +- Python kernel 定义:`xllm/xllm/compiler/tilelang/targets/ascend/kernels` +- NPU runtime wrapper:`xllm/xllm/core/kernels/npu/tilelang` + +构建和测试应在 NPU 容器中执行。 + +## 1. 先判断修改类型 + +- 新增 `specialization` + - 给现有 kernel 增加一组新的编译参数组合 + - 仍复用同一个 wrapper、同一套 runtime dispatch 字段和同一套 C ABI + - 典型动作是修改 `DISPATCH_SCHEMA` 或 `SPECIALIZATIONS` +- 新增 `kernel` + - 新增一个新的逻辑算子 + - 典型动作是新增 Python kernel 文件、wrapper C++ 文件和一条 CMake 接线 + +对 `rope` 来说: + +- 给 `SPECIALIZATIONS` 增加一项 `{"variant_key": "...", "head_dim": ..., "rope_dim": ..., "dtype": ...}`,这是新增 `specialization` +- 新增一个新的 `xxx_wrapper.cpp` 对外接口,这是新增 `kernel` + +## 2. 开发顺序 + +推荐按下面顺序开发: + +1. 在 `rope.py` 这类 Python 文件里先写 TileLang kernel 实现 +2. 实现 `generate_source(...)`,把 kernel lower 成 Ascend-C 源码 +3. 声明 `DISPATCH_SCHEMA` 和 `SPECIALIZATIONS` +4. 先生成一次 `registry.inc` 并查看内容 +5. 再写或修改 wrapper 里的 runtime specialization 构造逻辑 +6. 接入 CMake 并运行测试 + +这个顺序的重点是: + +- 先把 kernel 本身实现出来 +- 再把 runtime dispatch schema 固定下来 +- 最后根据生成出来的 `registry.inc` 写 wrapper + +## 3. 编写 Python Kernel + +以 `rope.py` 为例,Python 侧可以按三层理解: + +- `build_rope_kernel(...)`:kernel 实现 +- `generate_source(...)`:AOT 导出 +- `RopeKernel`:注册 kernel,并声明 dispatch schema 与编译实例 + +### 3.1 实现 `build_rope_kernel(...)` + +`build_rope_kernel(...)` 才是 TileLang kernel 的实现主体。这里负责写: + +- `@T.prim_func` +- 输入输出张量 shape +- `with T.Kernel(...)` 下的并行任务组织 +- UB 分配和实际计算逻辑 + +`rope.py` 的精简骨架如下: + +```python +def build_rope_kernel( + head_dim: int, + rope_dim: int, + vec_core_num: int, + ub_buffer_bytes: int, +): + task_num = vec_core_num + m_num = vec_core_num // 2 + + @T.prim_func + def rope_in_place_kernel(...): + with T.Kernel(m_num, is_npu=True) as (cid, vid): + task_id = cid * 2 + vid + ... + + return rope_in_place_kernel +``` + +这里的 `head_dim`、`rope_dim` 是这一组实现真正依赖的编译参数。 + +`rope` 这类 vector kernel 还要遵守当前 AOT 使用方式下的固定任务约定。当前路径是 AOT 编译,kernel launch 的 `block_num` 会在编译时固定下来,因此: + +- 运行时输入 shape 不影响 kernel launch 的 `block_num` +- 运行时输入 shape 只影响固定任务之间的 workload 切分 + +当前 `rope.py` 的约定是: + +```python +task_num = vec_core_num +m_num = vec_core_num // 2 + +with T.Kernel(m_num, is_npu=True) as (cid, vid): + task_id = cid * 2 + vid +``` + +这表示: + +- `cid` 范围是 `[0, vec_core_num // 2)` +- `vid` 范围是 `[0, 2)` +- 总任务数固定为 `task_num = vec_core_num` + +因此,`rope.py` 在推导单个 specialization 的编译 token 数时,也按固定任务数计算: + +```python +max_rows_num_in_ub = _derive_max_rows_num_in_ub(...) +compile_num_tokens = task_num * max_rows_num_in_ub +``` + +### 3.2 实现 `generate_source(...)` + +`generate_source(...)` 负责把上面的 TileLang kernel lower 成最终源码。导出层的职责,是把一组 specialization 参数转换成可编译的 Ascend-C 源码。 + +对 `rope` 来说,核心逻辑如下: + +```python +@staticmethod +def generate_source(head_dim: int, rope_dim: int, dtype: str) -> str: + vec_core_num = detect_vec_core_num() + tilelang_kernel = build_rope_kernel( + head_dim=head_dim, + rope_dim=rope_dim, + vec_core_num=vec_core_num, + ub_buffer_bytes=FIXED_UB_BUFFER_BYTES, + ) + with tilelang.tvm.transform.PassContext(...): + kernel = tilelang.engine.lower(tilelang_kernel) + return kernel.kernel_source +``` + +这里的规则是: + +- `generate_source(...)` 的输入来自当前这组 `SPECIALIZATIONS` +- `generate_source(...)` 内部调用 `build_rope_kernel(...)` +- 返回值是 lower 后的源码字符串 + +### 3.3 声明 `DISPATCH_SCHEMA` 与 `SPECIALIZATIONS` + +当 kernel 实现和导出层写完后,再通过 `@register_kernel` 类把它接入框架。 + +`rope.py` 当前的最小模板如下: + +```python +from ....common.spec import DispatchField, TilelangKernel, register_kernel + + +@register_kernel +class RopeKernel(TilelangKernel): + DISPATCH_SCHEMA = [ + DispatchField("head_dim", "int32"), + DispatchField("rope_dim", "int32"), + DispatchField("dtype", "dtype"), + ] + SPECIALIZATIONS = [ + { + "variant_key": "hd128_rd128_bf16", + "head_dim": 128, + "rope_dim": 128, + "dtype": "bf16", + }, + { + "variant_key": "hd576_rd64_bf16", + "head_dim": 576, + "rope_dim": 64, + "dtype": "bf16", + }, + ] + + @staticmethod + def generate_source(head_dim: int, rope_dim: int, dtype: str) -> str: + ... +``` + +这里要分清楚两个概念: + +- `DISPATCH_SCHEMA` + - 定义 runtime specialization 的字段名、顺序和类型 + - 是 C++ 侧 specialization struct、builder 和查表接口的单一真相源 +- `SPECIALIZATIONS` + - 表示要实际编译出的实例集合 + - 每一项都对应一个 variant + +规则如下: + +- `DISPATCH_SCHEMA` 中每个字段都必须出现在每一项 `SPECIALIZATIONS` 里 +- `SPECIALIZATIONS` 中可以有额外字段,这些字段会传给 `generate_source(...)`,但不会进入 runtime dispatch schema +- `variant_key` 是这一组 specialization 的唯一标识 +- `DISPATCH_SCHEMA` 和 `SPECIALIZATIONS` 必须与 runtime specialization 一一对应 + +对 `rope` 来说,runtime dispatch 维度是: + +- `head_dim` +- `rope_dim` +- `dtype` + +所以这三个字段必须同时出现在: + +- `DISPATCH_SCHEMA` +- 每一项 `SPECIALIZATIONS` + +构建时,Ascend build 会根据主构建路径传入的 `--device a2|a3` 解析实际使用的 `bisheng_arch`。 + +### 3.4 查看生成的 Ascend-C 源码 + +在调试 `build_rope_kernel(...)` 的实现细节,或者比较不同 kernel 写法对最终代码生成的影响时,建议通过公共入口 `compile-kernels` 重新生成产物,再查看对应 specialization 的 Ascend-C 源码。 + +对 `rope` 来说,可以先固定: + +- `head_dim=576` +- `rope_dim=64` +- `dtype=bf16` + +然后重新生成 `rope` 的编译产物: + +```bash +python xllm/compiler/tilelang_launcher.py compile-kernels \ + --target ascend \ + --device a3 \ + --output-root /tmp/tilelang_debug \ + --kernels rope \ + --force +``` + +这里建议带上 `--force`,保证源码和 object 会按当前修改重新生成,不直接命中旧 cache。 + +这里把 `--output-root` 指到独立的调试目录 `/tmp/tilelang_debug`,这样当前命令只会在这个目录下生成 `rope` 的调试产物,不会和主构建目录里的其他 kernel 产物混在一起。 + +执行后,可以直接查看对应 specialization 的源码中的入口函数、UB 分配和向量计算逻辑: + +```bash +sed -n '1,200p' \ + /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp + +rg -n 'extern "C"|__global__|alloc_ub|alloc_shared|g_tilingKey' \ + /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp +``` + +如果要比较两种 kernel 写法的差异,做法是保持 specialization 不变,在修改前后各执行一次 `compile-kernels --force`,再对生成的 `.cpp` 做 diff: + +```bash +cp /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp \ + /tmp/rope_before.cpp + +diff -u /tmp/rope_before.cpp \ + /tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp +``` + +这样可以把“specialization 变化”和“kernel 实现变化”分开看。 + +执行后,可以重点查看这些文件: + +- `/tmp/tilelang_debug/targets/ascend/rope/hd576_rd64_bf16/rope_hd576_rd64_bf16_kernel.cpp` +- `/tmp/tilelang_debug/targets/ascend/rope/registry.inc` +- `/tmp/tilelang_debug/targets/ascend/rope/manifest.json` + +这三类文件分别对应: + +- 某个 specialization 的最终 Ascend-C 源码 +- wrapper 会直接包含的 runtime dispatch 接口 +- 当前 kernel 的全部编译产物记录 + +调试顺序建议是: + +1. 先执行 `compile-kernels --force` 重新生成当前 kernel 的产物 +2. 查看对应 specialization 的 `.cpp`,分析代码生成结果 +3. 再看 `registry.inc` 和 `manifest.json` 是否符合预期 +4. 最后通过 `rope_wrapper_test` 看完整接入后的结果和性能 + +## 4. 修改 Wrapper + +新增 `kernel` 时,需要新增 wrapper。新增 `specialization` 时,只有 runtime specialization 语义变化,才需要同步修改 wrapper。 + +对 `rope_wrapper.cpp` 来说,人工需要保留的内容是: + +- tensor shape、dtype、layout 校验 +- 把输入整理成 `x_rows / sin_rows / cos_rows` +- 从 tensor 构造 runtime specialization +- 组装 launch 参数并调用 `entry->fn(...)` + +### 4.1 `registry.inc` 会自动生成什么 + +`registry.inc` 由 Python 侧的 `DISPATCH_SCHEMA`、`SPECIALIZATIONS` 和导出出来的 Ascend-C ABI 自动生成。 + +对 `rope` 来说,生成内容包括: + +- `RopeSpecialization` +- `RopeHeadDim` +- `RopeRopeDim` +- `RopeDType` +- `RopeKernelFn` +- `make_rope_specialization(...)` +- `find_rope_kernel_entry(...)` +- `available_rope_variant_keys()` + +对 `rope_wrapper.cpp` 来说,`registry.inc` 会直接提供 `RopeSpecialization`、`operator==(...)`、`RopeKernelFn` 等 dispatch 相关定义。`dtype` 转换统一使用公共函数 `to_tilelang_dtype(...)`。 + +### 4.2 wrapper 里真正要写的东西 + +`rope_wrapper.cpp` 中最关键的人工逻辑,是从 tensor 构造 runtime specialization。当前写法如下: + +```cpp +RopeSpecialization build_runtime_specialization(const torch::Tensor& x_rows) { + return make_rope_specialization( + RopeHeadDim{static_cast(x_rows.stride(0))}, + RopeRopeDim{static_cast(x_rows.size(1))}, + RopeDType{to_tilelang_dtype(x_rows.scalar_type())}); +} +``` + +对 `rope` 而言: + +- `head_dim` 对应 `x_rows.stride(0)`,也就是 kernel 使用的 `x_stride` +- `rope_dim` 对应 `x_rows.size(1)` +- `dtype` 对应 `x_rows.scalar_type()` + +运行时路径是: + +1. wrapper 把输入整理成 `x_rows / sin_rows / cos_rows` +2. `build_runtime_specialization(...)` 从 `x_rows` 构造 specialization +3. `find_rope_kernel_entry(...)` 在静态 registry 中精确匹配 +4. 命中后通过 `entry->fn(...)` 调用实际编译出来的符号 + +当前查找策略是线性扫描加精确匹配。只要 `head_dim`、`rope_dim`、`dtype` 有一个不一致,就不会命中。 + +所以新增 `specialization` 时,要重点核对的是: + +- Python 侧 `DISPATCH_SCHEMA` 的字段语义 +- Python 侧 `SPECIALIZATIONS` 的字段值 +- wrapper 里 `build_runtime_specialization(...)` 构造出的字段值 + +这三者必须完全对齐。 + +### 4.3 先生成并查看 `registry.inc` + +在写或修改 wrapper 之前,先生成一次 `registry.inc` 并查看内容。重点看: + +- 生成出来的 `RopeSpecialization` 字段顺序是否符合预期 +- 生成出来的字段包装类型名是否符合预期 +- `make_rope_specialization(...)` 的参数顺序是什么 +- 生成出来的 entry symbol 名称是什么 + +`registry.inc` 是 wrapper 的直接契约,先看它,再写 wrapper。 + +## 5. 修改 CMake + +新增 `kernel` 时,在 `xllm/xllm/core/kernels/npu/tilelang/CMakeLists.txt` 中接入这个 kernel。 + +CMake 接入统一通过高层 helper 完成: + +- `tilelang_register_runtime_kernel(NAME WRAPPER_SRCS )` + +以 `rope` 为例,最小模板如下: + +```cmake +tilelang_register_runtime_kernel( + NAME rope + WRAPPER_SRCS rope_wrapper.cpp +) +``` + +这条 helper 会完成: + +- 按 `TILELANG_GENERATED_ROOT/targets/ascend//manifest.json` 推导 manifest 路径 +- 导入 manifest +- 把该 kernel 的 wrapper source 和 compiled objects 加入 `tilelang_kernels` +- 自动追加 `XLLM_TL__REGISTRY_INC=...` compile definition + +因此,新增一个 runtime kernel 时,CMake 侧主要就是两件事: + +1. 保证 Python 侧已经能生成该 kernel 的 manifest +2. 在 `tilelang` 的 CMakeLists 里新增一条 `tilelang_register_runtime_kernel(...)` + +日常新增 kernel 时,直接在 CMake 中增加一条 `tilelang_register_runtime_kernel(...)`。`tilelang_import_kernel_manifest(...)` 作为这条高层 helper 的实现基础,保留在底层。 + +## 6. 验证 + +推荐按下面顺序验证: + +1. 先编译 TileLang kernel,并查看生成的 `registry.inc` +2. 再跑完整的 wrapper 测试 + +常用命令: + +```bash +python xllm/compiler/tilelang_launcher.py compile-kernels \ + --target ascend \ + --device a3 \ + --output-root build/cmake.linux-aarch64-cpython-311/xllm/compiler/tilelang \ + --kernels rope + +python setup.py test --test-name rope_wrapper_test --device a3 +``` + +第一条命令用于生成 `manifest.json`、`registry.inc` 和 object;第二条命令用于验证完整接入路径。 diff --git a/docs/zh/features/overview.md b/docs/zh/features/overview.md index 0ffc82fe5..197c68899 100644 --- a/docs/zh/features/overview.md +++ b/docs/zh/features/overview.md @@ -36,11 +36,11 @@ xLLM全面支持PD分离场景,实现了高效的PD实例的管理以及PD实 #### 全局调度 xLLM对请求和实例做全周期的资源调度智能管理。 ##### 实例调度 -我们实现了多种实例调度策略来选择如何将实例分配到更适合的实例。包括简单的Round Robin策略,基于请求在各实例上的prefix cache命中率来选择的prefix cahce-aware策略,基于实例上的显存空闲程度的kv cache-aware策略。另外,针对PD分离场景,由于静态的PD比例往往无法很好应对流量以及请求输入输出长度突变的场景,我们实现了一种自适应的PD动态调度器,负责在线请求的全局实例分配与运行时PD动态调整。 +我们实现了多种实例调度策略来选择如何将实例分配到更适合的实例。包括简单的Round Robin策略,基于请求在各实例上的 prefix cache 命中率来选择的 prefix cache-aware 策略,基于实例上的显存空闲程度的 KV Cache-aware 策略。另外,针对PD分离场景,由于静态的PD比例往往无法很好应对流量以及请求输入输出长度突变的场景,我们实现了一种自适应的PD动态调度器,负责在线请求的全局实例分配与运行时PD动态调整。 ##### 请求调度 我们实现了多种请求调度策略,支持continuous batching,包括chunked prefill,prefill优先和decode优先等batch策略,同时全面支持PD分离场景。 #### 全局kv cache管理 -在全局层面采用ETCD作为元数据服务中间件,实现集群服务注册、负载信息同步及全局缓存状态管理。每个计算实例维护本地多级缓存池。在调度策略方面,系统采用基于kv cache缓存的动态决策机制:首先进行前缀匹配检测,计算各候选节点的KV缓存复用率,最终选择综合性能最优的节点进行处理,实现kv cache的动态卸载与迁移。 +在全局层面采用ETCD作为元数据服务中间件,实现集群服务注册、负载信息同步及全局缓存状态管理。每个计算实例维护本地多级缓存池。在调度策略方面,系统采用基于 KV Cache 的动态决策机制:首先进行前缀匹配检测,计算各候选节点的 KV Cache 复用率,最终选择综合性能最优的节点进行处理,实现 KV Cache 的动态卸载与迁移。 #### 投机推理 xLLM内置优化后的投机推理算法,一次生成多个tokens提升吞吐。xLLM通过投机模块下沉减少通信成本,并使用调度和计算时序重叠优化、减少投机场景算子数据搬运等方式优化投机推理计算。 @@ -50,3 +50,8 @@ xLLM针对MoE模型实现了基于历史专家负载统计的专家权重更新 ### 多模态支持 xLLM对包括Qwen2-VL,MiniCPMV在内的多种多模态模型提供全面的支持。 + +### 相关设计文档 + +- [Graph Mode 设计文档](../design/graph_mode_design.md) +- [生成式推荐设计文档](../design/generative_recommendation_design.md) diff --git a/docs/zh/getting_started/quick_start.md b/docs/zh/getting_started/quick_start.md index bff46e75a..d08c91f5c 100644 --- a/docs/zh/getting_started/quick_start.md +++ b/docs/zh/getting_started/quick_start.md @@ -94,7 +94,7 @@ cd xllm pip install pre-commit pre-commit install -git submodule update --init +git submodule update --init --recursive ``` 编译生成的二进制文件位于`/path/to/xllm/build/xllm/core/server/xllm`,在新镜像中,第一次编译xllm耗时较长,因为需要编译vcpkg中的所有依赖,但是后续编译会很快。 diff --git a/docs/zh/getting_started/quick_start_GLM5.md b/docs/zh/getting_started/quick_start_GLM5.md index dd1f27524..0b8c2a15e 100644 --- a/docs/zh/getting_started/quick_start_GLM5.md +++ b/docs/zh/getting_started/quick_start_GLM5.md @@ -161,7 +161,7 @@ do --communication_backend="hccl" \ --enable_schedule_overlap=true \ --enable_graph=true \ - --enable_graph_mode_decode_no_padding=true \ + --enable_graph_mode_decode_no_padding=true \ --draft_model=$DRAFT_MODEL_PATH \ --draft_devices="npu:$DEVICE" \ --num_speculative_tokens=1 \ @@ -233,7 +233,7 @@ for (( i=0; i<$LOCAL_NODES; i++ ))do --communication_backend="hccl" \ --enable_schedule_overlap=true \ --enable_graph=true \ - --enable_graph_mode_decode_no_padding=true \ + --enable_graph_mode_decode_no_padding=true \ --ep_size=16 \ --dp_size=1 \ --rank_tablefile=/yourPath/ranktable.json \ @@ -273,7 +273,7 @@ for (( i=0; i<$LOCAL_NODES; i++ ))do --communication_backend="hccl" \ --enable_schedule_overlap=true \ --enable_graph=true \ - --enable_graph_mode_decode_no_padding=true \ + --enable_graph_mode_decode_no_padding=true \ --ep_size=16 \ --dp_size=1 \ --rank_tablefile=/yourPath/ranktable.json \ @@ -572,7 +572,506 @@ ENABLE_DECODE_RESPONSE_TO_SERVICE=true ../xllm-service/build/xllm_service/xllm_m #--etcd_addr=$LOCAL_HOST:3389 参考etcd中advertise-client-urls的配置 #--instance_role=DECODE PD配置,DECODE\PREFILL ``` - + +# GLM5/CP 特性压测数据(最优配置) +## 压测环境 +* 硬件: Ascend 910C(A3) / 4 Pods +* 主模型:GLM5-W8A8 +* 草稿模型:GLM5-W8A8-MTP +* PD分离配置: + * P实例:cp_size = 16,dp_size = 1, ep_size = 1 + * D实例:dp_size = 2, ep_size = 32 +* xllm 版本:release/v0.9.0(9be308aec60ea4a2dd799ee021ea42d608f4e67c - lastcommit) + +## PD分离服务启动脚本 +### PD分离4机配置 +#### prefill双节点配置 +##### prefill节点1 +``` +#!/bin/bash +set -e + +rm -rf core.* +rm -rf ~/ascend/log + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 +export HCCL_IF_BASE_PORT=43432 + +#export ASCEND_GLOBAL_LOG_LEVEL=1 +#export MINDIE_LOG_TO_STDOUT=1 + +#export LCCL_DETERMINISTIC=1 +#export HCCL_DETERMINISTIC=true +#export ATB_MATMUL_SHUFFLE_K_ENABLE=0 + +#export ASCEND_LAUNCH_BLOCKING=1 +#export ATB_STREAM_SYNC_EVERY_KERNEL_ENABLE=1 + +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export NPU_MEMORY_FRACTION=0.96 +export ATB_WORKSPACE_MEM_ALLOC_ALG_TYPE=3 +export ATB_WORKSPACE_MEM_ALLOC_GLOBAL=1 +export ATB_LAYER_INTERNAL_TENSOR_REUSE=1 +export ATB_CONTEXT_WORKSPACE_SIZE=0 + + +MODEL_PATH="/export/home/models/GLM-5-final-w8a8/" +#MODEL_PATH="/export/home/models/DeepSeek-V3.2-w8a8/" +#DRAFT_MODEL_PATH="/export/home/models/DeepSeek-V3.2-w8a8-mtp" +DRAFT_MODEL_PATH="/export/home/models/GLM-5-final-w8a8-MTP/" +MASTER_NODE_ADDR="11.87.191.98:1895" +START_PORT=48000 +START_DEVICE=0 +LOG_DIR="log" +NNODES=32 +LOCAL_NODES=16 +LOCAL_HOST="11.87.191.98" + +mkdir -p $LOG_DIR + + + #--draft_model $DRAFT_MODEL_PATH \ + #--draft_devices="npu:$DEVICE" \ + #--num_speculative_tokens 3 \ + +for (( i=0; i<$LOCAL_NODES; i++ )) +do + PORT=$((START_PORT + i)) + DEVICE=$((START_DEVICE + i)) + LOG_FILE="$LOG_DIR/node_$i.log" + nohup numactl -C $((DEVICE*40))-$((DEVICE*40+39)) /export/home/shifengmin.3/workspace/lt_xllm/build/xllm/core/server/xllm \ + --model $MODEL_PATH \ + --devices="npu:$DEVICE" \ + --port $PORT \ + --host $LOCAL_HOST \ + --master_node_addr=$MASTER_NODE_ADDR \ + --draft_model $DRAFT_MODEL_PATH \ # 草稿模型 + --draft_devices="npu:$DEVICE" \ + --num_speculative_tokens 3 \ # 采样率 + --nnodes=$NNODES \ + --max_memory_utilization=0.7 \ # 现存使用率 + --block_size=128 \ + --max_seqs_per_batch=9000 \ + --max_tokens_per_batch=67000 \ + --communication_backend="hccl" \ + --enable_prefix_cache=false \ + --enable_chunked_prefill=false \ + --enable_schedule_overlap=false \ + --enable_disagg_pd=true \ # 开启PD分离 + --instance_role=PREFILL \ + --etcd_addr=11.87.191.83:3389 \ + --transfer_listen_port=$((26000+i)) \ + --disagg_pd_port=7777 \ + --cp_size 16 \ # 开启CP + --dp_size 1 \ + --ep_size 1 \ + --node_rank=$i \ + --rank_tablefile=/export/home/shifengmin.3/workspace/ranktable_9899_new.json \ # prefill双机卡间通信路由表 + > $LOG_FILE 2>&1 & + +done + +tail -f log/node_0.log +``` + +##### prefill节点2 +``` +#!/bin/bash +set -e + +rm -rf core.* +rm -rf ~/ascend/log + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 +export HCCL_IF_BASE_PORT=43432 + +#export ASCEND_GLOBAL_LOG_LEVEL=1 +#export MINDIE_LOG_TO_STDOUT=1 + +#export LCCL_DETERMINISTIC=1 +#export HCCL_DETERMINISTIC=true +#export ATB_MATMUL_SHUFFLE_K_ENABLE=0 + +#export ASCEND_LAUNCH_BLOCKING=1 +#export ATB_STREAM_SYNC_EVERY_KERNEL_ENABLE=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export NPU_MEMORY_FRACTION=0.96 +export ATB_WORKSPACE_MEM_ALLOC_ALG_TYPE=3 +export ATB_WORKSPACE_MEM_ALLOC_GLOBAL=1 +export ATB_LAYER_INTERNAL_TENSOR_REUSE=1 +export ATB_CONTEXT_WORKSPACE_SIZE=0 + +MODEL_PATH="/export/home/models/GLM-5-final-w8a8/" +DRAFT_MODEL_PATH="/export/home/models/GLM-5-final-w8a8-MTP/" +MASTER_NODE_ADDR="11.87.191.98:1895" +START_PORT=48000 +START_DEVICE=0 +LOG_DIR="log" +NNODES=32 +LOCAL_NODES=16 +LOCAL_HOST="11.87.191.99" + +mkdir -p $LOG_DIR + + + #--draft_model $DRAFT_MODEL_PATH \ + #--draft_devices="npu:$DEVICE" \ + #--num_speculative_tokens 3 \ + +for (( i=0; i<$LOCAL_NODES; i++ )) +do + PORT=$((START_PORT + i)) + DEVICE=$((START_DEVICE + i)) + LOG_FILE="$LOG_DIR/node_$i.log" + nohup numactl -C $((DEVICE*40))-$((DEVICE*40+39)) /export/home/shifengmin.3/workspace/lt_xllm/build/xllm/core/server/xllm \ + --model $MODEL_PATH \ + --devices="npu:$DEVICE" \ + --port $PORT \ + --host $LOCAL_HOST \ + --master_node_addr=$MASTER_NODE_ADDR \ + --draft_model $DRAFT_MODEL_PATH \ # 草稿模型地址 + --draft_devices="npu:$DEVICE" \ + --num_speculative_tokens 3 \ # 采样率 + --nnodes=$NNODES \ + --max_memory_utilization=0.7 \ # 显存使用率 + --block_size=128 \ + --max_seqs_per_batch=9000 \ + --max_tokens_per_batch=67000 \ + --communication_backend="hccl" \ + --enable_prefix_cache=false \ + --enable_chunked_prefill=false \ + --enable_schedule_overlap=false \ + --enable_disagg_pd=true \ # 开启PD分离 + --instance_role=PREFILL \ + --etcd_addr=11.87.191.83:3389 \ + --transfer_listen_port=$((26100+i)) \ + --disagg_pd_port=7777 \ + --cp_size 16 \ # 开启CP + --dp_size 1 \ + --ep_size 1 \ + --node_rank=$((i+LOCAL_NODES)) \ + --rank_tablefile=/export/home/shifengmin.3/workspace/ranktable_9899_new.json \ + > $LOG_FILE 2>&1 & + +done + +tail -f log/node_0.log +``` +#### decode 双机配置 +##### decode节点1 +``` +#!/bin/bash +set -e + +rm -rf core.* +rm -rf ~/ascend/log + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 +export HCCL_IF_BASE_PORT=43432 + +#export ASCEND_GLOBAL_LOG_LEVEL=1 +#export MINDIE_LOG_TO_STDOUT=1 + +#export LCCL_DETERMINISTIC=1 +#export HCCL_DETERMINISTIC=true +#export ATB_MATMUL_SHUFFLE_K_ENABLE=0 + +#export ASCEND_LAUNCH_BLOCKING=1 +#export ATB_STREAM_SYNC_EVERY_KERNEL_ENABLE=1 + +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export NPU_MEMORY_FRACTION=0.96 +export ATB_WORKSPACE_MEM_ALLOC_ALG_TYPE=3 +export ATB_WORKSPACE_MEM_ALLOC_GLOBAL=1 +export ATB_LAYER_INTERNAL_TENSOR_REUSE=1 +export ATB_CONTEXT_WORKSPACE_SIZE=0 + + +MODEL_PATH="/export/home/models/GLM-5-final-w8a8/" +DRAFT_MODEL_PATH="/export/home/models/GLM-5-final-w8a8-MTP/" +MASTER_NODE_ADDR="11.87.191.83:1895" +START_PORT=48000 +START_DEVICE=0 +LOG_DIR="log" +NNODES=32 +LOCAL_NODES=16 +LOCAL_HOST="11.87.191.83" + +mkdir -p $LOG_DIR + + + #--draft_model $DRAFT_MODEL_PATH \ + #--draft_devices="npu:$DEVICE" \ + #--num_speculative_tokens 3 \ + +for (( i=0; i<$LOCAL_NODES; i++ )) +do + PORT=$((START_PORT + i)) + DEVICE=$((START_DEVICE + i)) + LOG_FILE="$LOG_DIR/node_$i.log" + nohup numactl -C $((DEVICE*40))-$((DEVICE*40+39)) /export/home/shifengmin.3/workspace/lt_xllm/build/xllm/core/server/xllm \ + --model $MODEL_PATH \ # GLM5.0权重 + --devices="npu:$DEVICE" \ + --port $PORT \ + --host $LOCAL_HOST \ + --master_node_addr=$MASTER_NODE_ADDR \ + --draft_model $DRAFT_MODEL_PATH \ # MTP权重 + --draft_devices="npu:$DEVICE" \ + --num_speculative_tokens 3 \ # 采样率 + --nnodes=$NNODES \ + --max_memory_utilization=0.80 \ + --block_size=128 \ + --max_seqs_per_batch=9000 \ + --communication_backend="hccl" \ + --enable_prefix_cache=false \ + --enable_chunked_prefill=false \ + --enable_schedule_overlap=true \ # 开启异步调度 + --enable_shm=true \ # 开启共享内存 + --enable_graph=false \ + --enable_graph_mode_decode_no_padding=false \ + --enable_disagg_pd=true \ # 开启PD分离 + --instance_role=DECODE \ + --etcd_addr=11.87.191.83:3389 \ + --transfer_listen_port=$((26000+i)) \ + --disagg_pd_port=7777 \ + --dp_size 2 \ # dp并行 + --ep_size 32 \ # EP并行 + --node_rank=$i \ + --rank_tablefile=/export/home/shifengmin.3/workspace/ranktable_8382_new.json \ # 设置卡间通信 + > $LOG_FILE 2>&1 & + +done + +tail -f log/node_0.log +``` +##### decode节点-2 +``` +#!/bin/bash +set -e + +rm -rf core.* +rm -rf ~/ascend/log + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 +export HCCL_IF_BASE_PORT=43432 + +#export ASCEND_GLOBAL_LOG_LEVEL=1 +#export MINDIE_LOG_TO_STDOUT=1 + +#export LCCL_DETERMINISTIC=1 +#export HCCL_DETERMINISTIC=true +#export ATB_MATMUL_SHUFFLE_K_ENABLE=0 + +#export ASCEND_LAUNCH_BLOCKING=1 +#export ATB_STREAM_SYNC_EVERY_KERNEL_ENABLE=1 + +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export NPU_MEMORY_FRACTION=0.96 +export ATB_WORKSPACE_MEM_ALLOC_ALG_TYPE=3 +export ATB_WORKSPACE_MEM_ALLOC_GLOBAL=1 +export ATB_LAYER_INTERNAL_TENSOR_REUSE=1 +export ATB_CONTEXT_WORKSPACE_SIZE=0 + + +MODEL_PATH="/export/home/models/GLM-5-final-w8a8/" +DRAFT_MODEL_PATH="/export/home/models/GLM-5-final-w8a8-MTP/" +MASTER_NODE_ADDR="11.87.191.83:1895" +START_PORT=48000 +START_DEVICE=0 +LOG_DIR="log" +NNODES=32 +LOCAL_NODES=16 +LOCAL_HOST="11.87.191.82" + +mkdir -p $LOG_DIR + + + #--draft_model $DRAFT_MODEL_PATH \ + #--draft_devices="npu:$DEVICE" \ + #--num_speculative_tokens 3 \ + +for (( i=0; i<$LOCAL_NODES; i++ )) +do + PORT=$((START_PORT + i)) + DEVICE=$((START_DEVICE + i)) + LOG_FILE="$LOG_DIR/node_$i.log" + nohup numactl -C $((DEVICE*40))-$((DEVICE*40+39)) /export/home/shifengmin.3/workspace/lt_xllm/build/xllm/core/server/xllm \ + --model $MODEL_PATH \ + --devices="npu:$DEVICE" \ + --port $PORT \ + --host $LOCAL_HOST \ + --master_node_addr=$MASTER_NODE_ADDR \ + --draft_model $DRAFT_MODEL_PATH \ # 草稿模型 + --draft_devices="npu:$DEVICE" \ + --num_speculative_tokens 3 \ # 采样率 + --nnodes=$NNODES \ + --max_memory_utilization=0.80 \ # 现存使用率 80% + --block_size=128 \ + --max_seqs_per_batch=9000 \ + --communication_backend="hccl" \ + --enable_prefix_cache=false \ + --enable_chunked_prefill=false \ + --enable_schedule_overlap=true \ # 开启异步调度 + --enable_shm=true \ # 开启共享内存 + --enable_graph=false \ + --enable_graph_mode_decode_no_padding=false \ + --enable_disagg_pd=true \ # PD分离 + --instance_role=DECODE \ # decode 节点 + --etcd_addr=11.87.191.83:3389 \ + --transfer_listen_port=$((26100+i)) \ + --disagg_pd_port=7777 \ + --dp_size 2 \ # 开启dp + --ep_size 32 \ # 开启ep + --node_rank=$((i+LOCAL_NODES)) \ + --rank_tablefile=/export/home/shifengmin.3/workspace/ranktable_8382_new.json \ # 双机间通信路由表 + > $LOG_FILE 2>&1 & + +done + +tail -f log/node_0.log +``` + +## 压测 +### 自定义数据集 - 输入输出配置 +Modified Location:/benchmark/ais_bench/datasets/synthetic/synthetic_config.py +``` +# +# [Uniform均匀分布] -- "Method" : "uniform" +# - MinValue: 最小值,范围为 [1, 2**20] +# - MaxValue: 最大值, 范围为 [1, 2**20], 可等于MinValue +# +# [Gaussian高斯分布] -- "Method" : "gaussian" +# - Mean : 平均值, 范围为 [-3.0e38, 3.0e38],分布中心位置 +# - Var : 方差, 范围为[0, 3.0e38],控制数据分散程度 +# - MinValue: 最小值, 范围为 [1, 2**20], 可低于Mean +# - MaxValue: 最大值, 范围为 [1, 2**20], 可高于Mean, 可等于MinValue +# +# [Zipf齐夫分布] -- "Method" : "zipf" +# - Alpha : 形状参数, 范围为(1.0,10.0], 值越大分布越均匀 +# - MinValue: 最小值, 范围为 [1, 2**20] +# - MaxValue: 最大值, 范围为 [1, 2**20], 需大于MinValue +""" +synthetic_config = { + "Type":"tokenid", # [tokenid/string],生成的随机数据集类型,支持固定长度的随机tokenid,和随机长度的string,两种类型的数据集 + "RequestCount": 10, # 生成的请求条数,应与模型侧配置文件中的 decode_batch_size 一致 + "TrustRemoteCode": False, #是否信任远端代码,tokenid模式下需要加载tokenizer生成tokenid,默认为Fasle + "StringConfig" : { # string类型的随机数据集的配置相关项,请参考以上注释处:"StringConfig中的随机生成方法参数说明" + "Input" : { # 每条请求的输入长度 + "Method": "uniform", + "Params": {"MinValue": 16384, "MaxValue": 16384} + }, + "Output" : { # 每条请求的输出长度 + "Method": "gaussian", + "Params": {"Mean": 1024, "Var": 0, "MinValue": 1024, "MaxValue": 1024} + } + }, + "TokenIdConfig" : { # tokenid类型的随机数据集的配置相关项 + "RequestSize": 16384 # 每条请求的长度,即每条请求中token id的个数,应与模型侧配置文件中的 input_seq_len 一致 + } +} +``` +### ais_bench客户端设置 +更新位置:/benchmark/ais_bench/benchmark/configs/models/vllm_api/vllm_api_stream_chat.py +``` +from ais_bench.benchmark.models import VLLMCustomAPIChatStream +from ais_bench.benchmark.utils.model_postprocessors import extract_non_reasoning_content + +models = [ + dict( + attr="service", + type=VLLMCustomAPIChatStream, + abbr='vllm-api-stream-chat', + path="[$GLM5_weight]", # GLM5 w8a8权重 + model="[$GLM5_mtp_weight]", # GLM5-MTP权重 + request_rate = 0, + retry = 1, + host_ip = "[$server_ip]", # 推理服务ip + host_port = [$server_port], # 推理服务port + max_out_len = 1024, # token输出数量 + batch_size=1, + trust_remote_code=False, + generation_kwargs = dict( + temperature = 0, + top_k = -1, + top_p = 1, + seed = None, + repetition_penalty = 1.03, + ignore_eos=True, + ), + pred_postprocessor=dict(type=extract_non_reasoning_content) + ) +] +``` + +### ais_bench客户端发起压测 +``` +ais_bench --models vllm_api_stream_chat --datasets synthetic_gen -m perf +``` + +## 压测数据 +### 32k/2k +* TTFT: P99 3.36/s +* TPOT:P99 42/ms + +``` +╒══════════════════════════╤═════════╤═════════════════╤═════════════════╤═════════════════╤═════════════════╤═════════════════╤═════════════════╤═════════════════╤═════╕ +│ Performance Parameters │ Stage │ Average │ Min │ Max │ Median │ P75 │ P90 │ P99 │ N │ +╞══════════════════════════╪═════════╪═════════════════╪═════════════════╪═════════════════╪═════════════════╪═════════════════╪═════════════════╪═════════════════╪═════╡ +│ E2EL │ total │ 77142.7 ms │ 63874.8 ms │ 89564.7 ms │ 78482.0 ms │ 82704.9 ms │ 84642.2 ms │ 89072.4 ms │ 10 │ +├──────────────────────────┼─────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────┤ +│ TTFT │ total │ 3221.8 ms │ 3179.5 ms │ 3375.2 ms │ 3198.3 ms │ 3230.2 ms │ 3255.4 ms │ 3363.2 ms │ 10 │ +├──────────────────────────┼─────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────┤ +│ TPOT │ total │ 36.1 ms │ 29.6 ms │ 42.2 ms │ 36.8 ms │ 38.8 ms │ 39.8 ms │ 42.0 ms │ 10 │ +├──────────────────────────┼─────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────┤ +│ ITL │ total │ 88.0 ms │ 0.0 ms │ 1866.4 ms │ 89.9 ms │ 92.0 ms │ 93.9 ms │ 110.2 ms │ 10 │ +├──────────────────────────┼─────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────┤ +│ InputTokens │ total │ 32744.3 │ 32629.0 │ 32923.0 │ 32733.5 │ 32801.25 │ 32867.2 │ 32917.42 │ 10 │ +├──────────────────────────┼─────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────┤ +│ OutputTokens │ total │ 2048.0 │ 2048.0 │ 2048.0 │ 2048.0 │ 2048.0 │ 2048.0 │ 2048.0 │ 10 │ +├──────────────────────────┼─────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────┤ +│ OutputTokenThroughput │ total │ 26.8428 token/s │ 22.8662 token/s │ 32.0627 token/s │ 26.0953 token/s │ 27.7358 token/s │ 31.9176 token/s │ 32.0482 token/s │ 10 │ +╘══════════════════════════╧═════════╧═════════════════╧═════════════════╧═════════════════╧═════════════════╧═════════════════╧═════════════════╧═════════════════╧═════╛ +╒══════════════════════════╤═════════╤════════════════════╕ +│ Common Metric │ Stage │ Value │ +╞══════════════════════════╪═════════╪════════════════════╡ +│ Benchmark Duration │ total │ 771444.3698 ms │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Total Requests │ total │ 10 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Failed Requests │ total │ 0 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Success Requests │ total │ 10 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Concurrency │ total │ 1.0 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Max Concurrency │ total │ 1 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Request Throughput │ total │ 0.013 req/s │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Total Input Tokens │ total │ 327443 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Prefill Token Throughput │ total │ 10163.3049 token/s │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Total Generated Tokens │ total │ 20480 │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Input Token Throughput │ total │ 424.4545 token/s │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Output Token Throughput │ total │ 26.5476 token/s │ +├──────────────────────────┼─────────┼────────────────────┤ +│ Total Token Throughput │ total │ 451.0021 token/s │ +╘══════════════════════════╧═════════╧═════════════════ +``` + + + 需要注意: - PD分离需要读取`/etc/hccn.conf`文件,确保将物理机上的该文件映射到了容器中 diff --git a/docs/zh/index.md b/docs/zh/index.md index 6c84c06d2..27d80b32a 100644 --- a/docs/zh/index.md +++ b/docs/zh/index.md @@ -59,4 +59,8 @@ xLLM 提供了强大的智能计算能力,通过硬件系统的算力优化与 - 投机推理优化,多核并行提升效率; - MoE专家的动态负载均衡,实现专家分布的高效调整。 +## 设计文档 + +- [Graph Mode 设计文档](design/graph_mode_design.md) +- [生成式推荐设计文档](design/generative_recommendation_design.md) diff --git a/docs/zh/supported_models.md b/docs/zh/supported_models.md index 5b9112538..eb2c2c2ea 100644 --- a/docs/zh/supported_models.md +++ b/docs/zh/supported_models.md @@ -43,6 +43,7 @@ ## Rec | | NPU | MLU | ILU | | --- | :---: | :---: | :---: | -| | | | | -| | | | | -| | | | | \ No newline at end of file +| OneRec | ✅ | ❌ | ❌ | +| Qwen2 | ✅ | ❌ | ❌ | +| Qwen2.5 | ✅ | ❌ | ❌ | +| Qwen3 | ✅ | ❌ | ❌ | diff --git a/mkdocs_en.yml b/mkdocs_en.yml index eb5d54f65..0a5bdb054 100644 --- a/mkdocs_en.yml +++ b/mkdocs_en.yml @@ -221,5 +221,8 @@ nav: - Advanced Guides: - Speculative Inference: features/mtp.md - GraphMode: features/graph_mode.md + - Developer Guide: + - dev_guide/code_arch.md + - dev_guide/tilelang_ascend_kernel_dev.md - CLI Reference: - cli_reference.md diff --git a/mkdocs_zh.yml b/mkdocs_zh.yml index cb4ee1679..13479521f 100644 --- a/mkdocs_zh.yml +++ b/mkdocs_zh.yml @@ -224,6 +224,9 @@ nav: - 投机推理: features/mtp.md - GraphMode: features/graph_mode.md - xLLM Service概览: features/xllm_service_overview.md + - 开发者指南: + - dev_guide/code_arch.md + - dev_guide/tilelang_ascend_kernel_dev.md - CLI参考: - cli_reference.md @@ -275,4 +278,4 @@ nav: # - ACLGraph: zh/features/acl_graph.md # - xLLM Service概览: zh/features/xllm_service_overview.md # - CLI参考: -# - zh/cli_reference.md \ No newline at end of file +# - zh/cli_reference.md diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..047912da9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "xllm" +description = "A high-performance inference system for large language models." +requires-python = ">=3.10" +license = { text = "Apache 2.0" } +authors = [{ name = "xLLM Team", email = "infer@xllm.ai" }] +dynamic = ["readme", "version"] +classifiers = [ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Programming Language :: C++", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Operating System :: POSIX", + "License :: OSI Approved :: Apache Software License", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +[project.urls] +Homepage = "https://xllm.readthedocs.io/zh-cn/latest/" +Documentation = "https://xllm.readthedocs.io/zh-cn/latest/" +Repository = "https://github.com/jd-opensource/xllm" + +[tool.setuptools.dynamic] +readme = { file = ["README.md"], content-type = "text/markdown" } +version = { file = ["version.txt"] } diff --git a/env.py b/scripts/build_support/env.py similarity index 100% rename from env.py rename to scripts/build_support/env.py diff --git a/utils.py b/scripts/build_support/utils.py similarity index 95% rename from utils.py rename to scripts/build_support/utils.py index 9f4788185..78a990db0 100644 --- a/utils.py +++ b/scripts/build_support/utils.py @@ -5,6 +5,7 @@ import sysconfig import io import shlex +from pathlib import Path from typing import Optional # get cpu architecture @@ -53,7 +54,13 @@ def get_device_type() -> str: exit(1) def get_base_dir() -> str: - return os.path.abspath(os.path.dirname(__file__)) + helper_path = Path(__file__).resolve() + for parent in helper_path.parents: + if all((parent / marker).exists() for marker in ("setup.py", "version.txt", "CMakeLists.txt")): + return str(parent) + + fallback_index = min(2, len(helper_path.parents) - 1) + return str(helper_path.parents[fallback_index]) def _join_path(*paths: str) -> str: return os.path.join(get_base_dir(), *paths) @@ -197,7 +204,7 @@ def _collect_submodule_init_issues(repo_root: str) -> dict[str, str]: _print_manual_check_commands([ f"cd {repo_root}", "git submodule status", - "git submodule update --init", + "git submodule update --init --recursive", ]) exit(1) @@ -306,7 +313,7 @@ def _validate_submodules_or_exit(repo_root: str) -> None: for path in sorted(issues): print(f" - {path}: {issues[path]}") print("\nPlease align submodules and try again:") - print(" git submodule update --init [-f|--force]") + print(" git submodule update --init --recursive [-f|--force]") exit(1) @@ -375,7 +382,7 @@ def _ensure_xllm_ops_rebuild_on_missing_marker() -> None: return def pre_build() -> None: - script_path = os.path.dirname(os.path.abspath(__file__)) + script_path = get_base_dir() _validate_submodules_or_exit(script_path) _ensure_prebuild_dependencies_installed(script_path) diff --git a/setup.py b/setup.py index 83ed3e3e4..2ad68c766 100644 --- a/setup.py +++ b/setup.py @@ -8,15 +8,60 @@ from typing import Any, Optional from distutils.core import Command -from setuptools import Extension, setup -from setuptools.command.bdist_wheel import bdist_wheel +from setuptools import Extension, find_namespace_packages, setup from setuptools.command.build_ext import build_ext -from env import get_cxx_abi, set_npu_envs, set_mlu_envs, set_cuda_envs, set_ilu_envs, set_musa_envs -from utils import get_cpu_arch, get_device_type, pre_build, get_version, check_and_install_pre_commit, read_readme, get_cmake_dir, get_base_dir, get_python_version, get_torch_version +try: + from setuptools.command.bdist_wheel import bdist_wheel +except ModuleNotFoundError: + from wheel.bdist_wheel import bdist_wheel + +from scripts.build_support.env import ( + get_cxx_abi, + set_cuda_envs, + set_ilu_envs, + set_mlu_envs, + set_musa_envs, + set_npu_envs, +) +from scripts.build_support.utils import ( + check_and_install_pre_commit, + get_base_dir, + get_cmake_dir, + get_cpu_arch, + get_device_type, + get_python_version, + get_torch_version, + get_version, + pre_build, + read_readme, +) BUILD_TEST_FILE: bool = True BUILD_EXPORT: bool = True + + +def _maybe_compile_tilelang_kernels(device: str) -> None: + if device != "npu": + return + + output_root = os.path.join(get_cmake_dir(), "xllm", "compiler", "tilelang") + os.makedirs(output_root, exist_ok=True) + + env = os.environ.copy() + base_dir = get_base_dir() + + cmd = [ + sys.executable, + os.path.join(base_dir, "xllm", "compiler", "tilelang_launcher.py"), + "compile-kernels", + "--target", + "ascend", + "--output-root", + output_root, + ] + print("[INFO] compiling TileLang kernels via source-tree launcher") + subprocess.check_call(cmd, cwd=base_dir, env=env) class CMakeExtension(Extension): def __init__(self, name: str, path: str, sourcedir: str = "") -> None: @@ -119,6 +164,7 @@ def build_extension(self, ext: CMakeExtension) -> None: if self.device == "npu": cmake_args += ["-DUSE_NPU=ON"] set_npu_envs() + _maybe_compile_tilelang_kernels(self.device) elif self.device == "mlu": cmake_args += ["-DUSE_MLU=ON"] set_mlu_envs() @@ -179,9 +225,7 @@ def build_cmake_targets( ) -> None: """Build CMake targets""" cmake_dir = get_cmake_dir() - subprocess.check_call( - ["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env - ) + subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) base_build_args = build_args # add build target to speed up the build process @@ -230,9 +274,7 @@ def build_cmake_targets( ) -> None: """Override method: only build the specified test target and run""" cmake_dir = get_cmake_dir() - subprocess.check_call( - ["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env - ) + subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) base_build_args = build_args # Only build the specified test target @@ -514,7 +556,6 @@ def parse_arguments() -> dict[str, Any]: default='auto', help='Device type: npu, mlu, ilu, cuda or musa (case-insensitive)' ) - parser.add_argument( '--generate-so', type=str.lower, @@ -591,7 +632,7 @@ def parse_arguments() -> dict[str, Any]: version=version, license="Apache 2.0", author="xLLM Team", - author_email="infer@jd.com", + author_email="infer@xllm.ai", description="A high-performance inference system for large language models.", long_description=read_readme(), long_description_content_type="text/markdown", @@ -618,6 +659,7 @@ def parse_arguments() -> dict[str, Any]: "test": test_cmd, 'bdist_wheel': BuildDistWheel}, options=options, + packages=find_namespace_packages(include=["scripts.build_support"]), zip_safe=False, py_modules=["xllm/launch_xllm", "xllm/__init__", "xllm/pybind/llm", "xllm/pybind/vlm", diff --git a/third_party/Mooncake b/third_party/Mooncake index 8c31e7484..1adcb2fb3 160000 --- a/third_party/Mooncake +++ b/third_party/Mooncake @@ -1 +1 @@ -Subproject commit 8c31e74842584c08ce1312b3a0b3347ba597263e +Subproject commit 1adcb2fb399661e69df0bc4c342c3f4a7028b8e8 diff --git a/third_party/tilelang-ascend b/third_party/tilelang-ascend new file mode 160000 index 000000000..289e1ae64 --- /dev/null +++ b/third_party/tilelang-ascend @@ -0,0 +1 @@ +Subproject commit 289e1ae6490084c87c74a1b635dd869ce75ed9ce diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops index bf90ef22c..9dc44e054 160000 --- a/third_party/torch_npu_ops +++ b/third_party/torch_npu_ops @@ -1 +1 @@ -Subproject commit bf90ef22cc789be1a89541da11d2813ef2c8dd4c +Subproject commit 9dc44e054e62a5afc778491674ec60d2298a7a1b diff --git a/third_party/xllm_atb_layers b/third_party/xllm_atb_layers index 918c03d2a..d6aa214ce 160000 --- a/third_party/xllm_atb_layers +++ b/third_party/xllm_atb_layers @@ -1 +1 @@ -Subproject commit 918c03d2abc4c9996196a797aefe743863b7e0ae +Subproject commit d6aa214ce69acac8a3061ee8f0ef48b94dd3f5f6 diff --git a/third_party/xllm_ops b/third_party/xllm_ops index cc2845f24..d2236de5b 160000 --- a/third_party/xllm_ops +++ b/third_party/xllm_ops @@ -1 +1 @@ -Subproject commit cc2845f2432731bc3af7d16edbc5be171ca30c8a +Subproject commit d2236de5b4820c3a58e65ab9fe1de63d37c3fedc diff --git a/xllm/api_service/CMakeLists.txt b/xllm/api_service/CMakeLists.txt index f882f0f30..722ee4ee6 100644 --- a/xllm/api_service/CMakeLists.txt +++ b/xllm/api_service/CMakeLists.txt @@ -8,7 +8,7 @@ cc_library( api_service.h api_service_impl.h call.h - chat_json_utils.h + chat_json_parser.h completion_service_impl.h rec_completion_service_impl.h chat_service_impl.h @@ -20,6 +20,7 @@ cc_library( qwen3_rerank_service_impl.h non_stream_call.h service_impl_factory.h + serving_mode.h stream_call.h models_service_impl.h stream_output_parser.h @@ -28,6 +29,8 @@ cc_library( utils.h SRCS api_service.cpp + chat_json_parser.cpp + service_impl_factory.cpp call.cpp completion_service_impl.cpp rec_completion_service_impl.cpp @@ -89,7 +92,7 @@ cc_test( NAME api_service_test SRCS - chat_json_utils_test.cpp + chat_json_parser_test.cpp DEPS api_service GTest::gtest_main diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 992c1878e..bff43b0c2 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -21,11 +21,12 @@ limitations under the License. #include #include -#include +#include "api_service/chat_json_parser.h" +#include "api_service/service_impl_factory.h" +#include "api_service/serving_mode.h" #include "call.h" #include "chat.pb.h" -#include "chat_json_utils.h" #include "common.pb.h" #include "completion.pb.h" #include "core/common/constants.h" @@ -54,66 +55,20 @@ google::protobuf::Arena* GetArenaWithCheck( } } -std::string build_sample_backend_error_message() { - return "Current backend '" + FLAGS_backend + - "' does not support /v1/sample; only llm is supported"; -} +const char* kSampleNotSupportedError = "/v1/sample is only supported for LLM"; + } // namespace APIService::APIService(Master* master, const std::vector& model_names, const std::vector& model_versions) : master_(master) { + set_model_master(model_names[0], master); if (FLAGS_node_rank != 0) { - set_model_master(model_names[0], master); return; } - if (FLAGS_backend == "llm") { - auto llm_master = dynamic_cast(master); - anthropic_service_impl_ = - std::make_unique(llm_master, model_names); - completion_service_impl_ = - ServiceImplFactory::create_service_impl( - llm_master, model_names); - sample_service_impl_ = - ServiceImplFactory::create_service_impl(llm_master, - model_names); - chat_service_impl_ = - ServiceImplFactory::create_service_impl(llm_master, - model_names); - embedding_service_impl_ = - ServiceImplFactory::create_service_impl( - llm_master, model_names); - if (FLAGS_enable_qwen3_reranker) { - rerank_service_impl_ = - ServiceImplFactory::create_service_impl( - llm_master, model_names); - } else { - rerank_service_impl_ = - ServiceImplFactory::create_service_impl( - llm_master, model_names); - } - } else if (FLAGS_backend == "vlm") { - auto vlm_master = dynamic_cast(master); - mm_chat_service_impl_ = - std::make_unique(vlm_master, model_names); - mm_embedding_service_impl_ = - std::make_unique(vlm_master, model_names); - } else if (FLAGS_backend == "dit") { - image_generation_service_impl_ = - std::make_unique( - dynamic_cast(master), model_names); - } else if (FLAGS_backend == "rec") { - auto rec_master = dynamic_cast(master); - rec_completion_service_impl_ = - std::make_unique(rec_master, model_names); - chat_service_impl_ = - std::make_unique(rec_master, model_names); - } - set_model_master(model_names[0], master); - models_service_impl_ = - ServiceImplFactory::create_service_impl( - model_names, model_versions); + ServiceImplFactory::create(this, master, model_names, model_versions); + register_chat_completions_handler(); } void APIService::set_model_master(const std::string& model_id, Master* master) { @@ -155,9 +110,9 @@ void APIService::Completions(::google::protobuf::RpcController* controller, } auto ctrl = reinterpret_cast(controller); - if (FLAGS_backend == "llm") { + if (completion_service_impl_) { completion_service_impl_->process_async_rpc_impl(request); - } else if (FLAGS_backend == "rec") { + } else if (rec_completion_service_impl_) { auto arena = GetArenaWithCheck(response); std::shared_ptr call = std::make_shared( ctrl, @@ -202,9 +157,9 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, std::shared_ptr call = std::make_shared( ctrl, done_guard.release(), req_pb, resp_pb, arena != nullptr); - if (FLAGS_backend == "llm") { + if (completion_service_impl_) { completion_service_impl_->process_async(call); - } else if (FLAGS_backend == "rec") { + } else if (rec_completion_service_impl_) { rec_completion_service_impl_->process_async(call); } } @@ -223,11 +178,10 @@ void APIService::Sample(::google::protobuf::RpcController* controller, } auto ctrl = reinterpret_cast(controller); - if (FLAGS_backend != "llm") { - ctrl->SetFailed(build_sample_backend_error_message()); + if (!sample_service_impl_) { + ctrl->SetFailed(kSampleNotSupportedError); return; } - CHECK(sample_service_impl_) << " sample service is invalid."; Status status; if (!sample_service_impl_->process_request(*request, response, &status)) { @@ -250,11 +204,10 @@ void APIService::SampleHttp(::google::protobuf::RpcController* controller, } auto ctrl = reinterpret_cast(controller); - if (FLAGS_backend != "llm") { - ctrl->SetFailed(build_sample_backend_error_message()); + if (!sample_service_impl_) { + ctrl->SetFailed(kSampleNotSupportedError); return; } - CHECK(sample_service_impl_) << " sample service is invalid."; auto arena = GetArenaWithCheck(response); auto req_pb = @@ -298,95 +251,6 @@ size_t get_json_content_length(const brpc::Controller* ctrl) { } // namespace -// Preprocess chat JSON to normalize array content to string. -// For text-only backends, combines text array items into a single string. -// For multimodal backends, passes through unchanged without parsing. -// Returns Status with processed JSON on success, or error status on failure. -std::pair preprocess_chat_json(std::string json_str, - bool is_multimodal) { - // Multimodal backends handle array content natively, skip parsing - if (is_multimodal) { - return {Status(), std::move(json_str)}; - } - - try { - auto json = nlohmann::json::parse(json_str); - if (!json.contains("messages") || !json["messages"].is_array()) { - return {Status(), std::move(json_str)}; - } - - bool modified = false; - for (auto& msg : json["messages"]) { - if (!msg.is_object()) { - return {Status(StatusCode::INVALID_ARGUMENT, - "Message in 'messages' array must be an object."), - ""}; - } - if (msg.contains("content") && msg["content"].is_array()) { - // Validate all items are text-only with proper text field - for (const auto& item : msg["content"]) { - if (!item.is_object()) { - return {Status(StatusCode::INVALID_ARGUMENT, - "Content array item must be an object."), - ""}; - } - if (!item.contains("type") || item["type"] != "text") { - // Non-text content on text-only backend is an error - return {Status(StatusCode::INVALID_ARGUMENT, - "Non-text content (e.g., image_url) requires " - "multimodal backend (-backend vlm)"), - ""}; - } - // Validate text items have proper text field - if (!item.contains("text") || !item["text"].is_string()) { - return {Status(StatusCode::INVALID_ARGUMENT, - "Missing or invalid 'text' field in content item."), - ""}; - } - } - - // All items are text-only; combine into single string. - // Pre-calculate total size to avoid reallocations. - size_t total_size = 0; - size_t num_items = msg["content"].size(); - for (const auto& item : msg["content"]) { - // Already validated above - total_size += item["text"].get_ref().size(); - } - // Add space for newline separators - if (num_items > 1) { - total_size += num_items - 1; - } - - // Reserve capacity once to avoid reallocations - std::string combined_text; - combined_text.reserve(total_size); - bool first = true; - for (const auto& item : msg["content"]) { - if (!first) { - combined_text += '\n'; - } - combined_text += item["text"].get_ref(); - first = false; - } - msg["content"] = combined_text; - modified = true; - } - } - return modified ? std::make_pair(Status(), json.dump()) - : std::make_pair(Status(), std::move(json_str)); - } catch (const nlohmann::json::exception& e) { - return {Status(StatusCode::INVALID_ARGUMENT, - "Invalid JSON format: " + std::string(e.what())), - ""}; - } catch (const std::exception& e) { - LOG(ERROR) << "Exception during JSON preprocessing: " << e.what(); - return {Status(StatusCode::UNKNOWN, - "Internal server error during JSON processing."), - ""}; - } -} - namespace { template @@ -395,7 +259,7 @@ void chat_completions_http_impl(std::unique_ptr& service, brpc::Controller* ctrl, const proto::HttpRequest* request, proto::HttpResponse* response, - bool is_multimodal) { + const ChatJsonParser& chat_json_parser) { auto arena = GetArenaWithCheck(response); auto req_pb = google::protobuf::Arena::CreateMessage(arena); @@ -412,7 +276,7 @@ void chat_completions_http_impl(std::unique_ptr& service, ctrl->request_attachment().copy_to(&attachment, content_len, 0); auto [preprocess_status, processed_json] = - preprocess_chat_json(std::move(attachment), is_multimodal); + chat_json_parser.preprocess(std::move(attachment)); if (!preprocess_status.ok()) { ctrl->SetFailed(preprocess_status.message()); LOG(ERROR) << "Complex message preprocessing failed: " @@ -437,6 +301,36 @@ void chat_completions_http_impl(std::unique_ptr& service, } // namespace +void APIService::register_chat_completions_handler() { + if (mm_chat_service_impl_) { + chat_completions_handler_ = [this](ClosureGuard& guard, + brpc::Controller* ctrl, + const proto::HttpRequest* request, + proto::HttpResponse* response) { + chat_completions_http_impl( + mm_chat_service_impl_, + guard, + ctrl, + request, + response, + ChatJsonParser::get(ServingMode::VLM)); + }; + } else if (chat_service_impl_) { + chat_completions_handler_ = [this](ClosureGuard& guard, + brpc::Controller* ctrl, + const proto::HttpRequest* request, + proto::HttpResponse* response) { + chat_completions_http_impl( + chat_service_impl_, + guard, + ctrl, + request, + response, + ChatJsonParser::get(ServingMode::LLM)); + }; + } +} + void APIService::ChatCompletions(::google::protobuf::RpcController* controller, const proto::ChatRequest* request, proto::ChatResponse* response, @@ -471,36 +365,13 @@ void APIService::ChatCompletionsHttp( return; } - auto ctrl = reinterpret_cast(controller); - - if (FLAGS_backend == "llm") { - CHECK(chat_service_impl_) << " chat service is invalid."; - chat_completions_http_impl( - chat_service_impl_, - done_guard, - ctrl, - request, - response, - /*is_multimodal=*/false); - } else if (FLAGS_backend == "vlm") { - CHECK(mm_chat_service_impl_) << " mm chat service is invalid."; - chat_completions_http_impl( - mm_chat_service_impl_, - done_guard, - ctrl, - request, - response, - /*is_multimodal=*/true); - } else if (FLAGS_backend == "rec") { - CHECK(chat_service_impl_) << " chat service is invalid."; - chat_completions_http_impl( - chat_service_impl_, - done_guard, - ctrl, - request, - response, - /*is_multimodal=*/false); + if (!chat_completions_handler_) { + LOG(ERROR) << "No chat completions handler registered"; + return; } + + auto ctrl = reinterpret_cast(controller); + chat_completions_handler_(done_guard, ctrl, request, response); } void APIService::Embeddings(::google::protobuf::RpcController* controller, @@ -560,12 +431,10 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller, const proto::HttpRequest* request, proto::HttpResponse* response, ::google::protobuf::Closure* done) { - if (FLAGS_backend == "llm") { - CHECK(embedding_service_impl_) << " embedding service is invalid."; + if (embedding_service_impl_) { handle_embedding_request( embedding_service_impl_, controller, request, response, done); - } else if (FLAGS_backend == "vlm") { - CHECK(mm_embedding_service_impl_) << " mm embedding service is invalid."; + } else if (mm_embedding_service_impl_) { handle_embedding_request( mm_embedding_service_impl_, controller, request, response, done); } @@ -733,53 +602,6 @@ void APIService::ModelVersionsHttp( namespace { -// Preprocess Anthropic API JSON to convert "content" field to -// protobuf-compatible format Anthropic API uses "content" field which can be -// string or array Our protobuf uses "content_string" for string and -// "content_blocks" for array -std::string preprocess_anthropic_json(const std::string& json_str) { - try { - nlohmann::json j = nlohmann::json::parse(json_str); - - if (j.contains("messages") && j["messages"].is_array()) { - for (auto& msg : j["messages"]) { - if (msg.contains("content")) { - auto& content = msg["content"]; - if (content.is_string()) { - // Convert "content": "string" to "content_string": "string" - msg["content_string"] = content.get(); - msg.erase("content"); - } else if (content.is_array()) { - // Convert "content": [...] to "content_blocks": {"blocks": [...]} - nlohmann::json content_blocks; - content_blocks["blocks"] = content; - msg["content_blocks"] = content_blocks; - msg.erase("content"); - } - } - } - } - - if (j.contains("system")) { - auto& system = j["system"]; - if (system.is_string()) { - j["system_string"] = system.get(); - j.erase("system"); - } else if (system.is_array()) { - nlohmann::json system_blocks; - system_blocks["blocks"] = system; - j["system_blocks"] = system_blocks; - j.erase("system"); - } - } - - return j.dump(); - } catch (const std::exception& e) { - LOG(ERROR) << "Failed to preprocess Anthropic JSON: " << e.what(); - return json_str; // Return original on error - } -} - void handle_anthropic_messages(std::unique_ptr& service, xllm::ClosureGuard& guard, brpc::Controller* ctrl, @@ -801,9 +623,14 @@ void handle_anthropic_messages(std::unique_ptr& service, std::string attachment; ctrl->request_attachment().copy_to(&attachment, content_len, 0); - // Preprocess JSON to convert Anthropic API format to protobuf-compatible - // format - std::string processed_json = preprocess_anthropic_json(attachment); + auto [preprocess_status, processed_json] = + ChatJsonParser::anthropic().preprocess(std::move(attachment)); + if (!preprocess_status.ok()) { + ctrl->SetFailed(preprocess_status.message()); + LOG(ERROR) << "Anthropic JSON preprocessing failed: " + << preprocess_status.message(); + return; + } google::protobuf::util::JsonParseOptions options; options.ignore_unknown_fields = true; @@ -840,13 +667,12 @@ void APIService::AnthropicMessagesHttp( auto ctrl = reinterpret_cast(controller); - if (FLAGS_backend == "llm") { - CHECK(anthropic_service_impl_) << " anthropic service is invalid."; + if (anthropic_service_impl_) { handle_anthropic_messages( anthropic_service_impl_, done_guard, ctrl, request, response); } else { - ctrl->SetFailed("Anthropic messages API is only supported for LLM backend"); - LOG(ERROR) << "Anthropic messages API is only supported for LLM backend"; + ctrl->SetFailed("Anthropic messages API is only supported for LLM engine"); + LOG(ERROR) << "Anthropic messages API is only supported for LLM engine"; } } @@ -918,8 +744,8 @@ void APIService::ForkMasterHttp(::google::protobuf::RpcController* controller, return; } - if (FLAGS_backend != "llm") { - LOG(ERROR) << "fork master only supports llm backend"; + if (to_serving_mode(master_->engine_type()) != ServingMode::LLM) { + LOG(ERROR) << "fork master only supports LLM engine"; return; } @@ -1101,7 +927,7 @@ void APIService::WakeupHttp(::google::protobuf::RpcController* controller, std::vector segments; segments.reserve(seg_list.segments_size()); for (const auto& proto_seg : seg_list.segments()) { - segments.push_back({proto_seg.offset(), proto_seg.size()}); + segments.emplace_back(proto_seg.offset(), proto_seg.size()); } wakeup_options.src_weight_segments.push_back(std::move(segments)); } diff --git a/xllm/api_service/api_service.h b/xllm/api_service/api_service.h index a1fe12a02..43bd054f0 100644 --- a/xllm/api_service/api_service.h +++ b/xllm/api_service/api_service.h @@ -15,6 +15,7 @@ limitations under the License. #pragma once +#include #include #include #include @@ -33,7 +34,12 @@ limitations under the License. namespace xllm { +class ClosureGuard; +class ServiceImplFactory; + class APIService : public proto::XllmAPIService { + friend class ServiceImplFactory; + public: APIService(Master* master, const std::vector& model_names, @@ -171,6 +177,13 @@ class APIService : public proto::XllmAPIService { ::google::protobuf::Closure* done) override; private: + using ChatHttpHandler = std::function; + + void register_chat_completions_handler(); + bool ParseForkMasterRequest(const proto::MasterInfos* request, Options& options); void set_model_master(const std::string& model_id, Master* master); @@ -179,6 +192,7 @@ class APIService : public proto::XllmAPIService { Master* get_model_master(const std::string& model_id) const; Master* master_; + ChatHttpHandler chat_completions_handler_; mutable std::shared_mutex masters_mutex_; std::unordered_map masters_; std::unique_ptr anthropic_service_impl_; diff --git a/xllm/api_service/chat_json_parser.cpp b/xllm/api_service/chat_json_parser.cpp new file mode 100644 index 000000000..6c205b92a --- /dev/null +++ b/xllm/api_service/chat_json_parser.cpp @@ -0,0 +1,164 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "api_service/chat_json_parser.h" + +#include + +#include + +namespace xllm { + +const ChatJsonParser& ChatJsonParser::get(ServingMode mode) { + if (mode == ServingMode::VLM) { + static const VlmChatJsonParser k_vlm_parser; + return k_vlm_parser; + } + static const LlmChatJsonParser k_llm_parser; + return k_llm_parser; +} + +const ChatJsonParser& ChatJsonParser::anthropic() { + static const AnthropicChatJsonParser k_anthropic_parser; + return k_anthropic_parser; +} + +std::pair VlmChatJsonParser::preprocess( + std::string json_str) const { + return {Status(), std::move(json_str)}; +} + +std::pair LlmChatJsonParser::preprocess( + std::string json_str) const { + try { + auto json = nlohmann::json::parse(json_str); + if (!json.contains("messages") || !json["messages"].is_array()) { + return {Status(), std::move(json_str)}; + } + + bool modified = false; + for (auto& msg : json["messages"]) { + if (!msg.is_object()) { + return {Status(StatusCode::INVALID_ARGUMENT, + "Message in 'messages' array must be an object."), + ""}; + } + if (msg.contains("content") && msg["content"].is_array()) { + for (const auto& item : msg["content"]) { + if (!item.is_object()) { + return {Status(StatusCode::INVALID_ARGUMENT, + "Content array item must be an object."), + ""}; + } + if (!item.contains("type") || item["type"] != "text") { + return {Status(StatusCode::INVALID_ARGUMENT, + "Non-text content (e.g., image_url) requires " + "multimodal backend (-backend vlm)"), + ""}; + } + if (!item.contains("text") || !item["text"].is_string()) { + return {Status(StatusCode::INVALID_ARGUMENT, + "Missing or invalid 'text' field in content item."), + ""}; + } + } + + size_t total_size = 0; + size_t num_items = msg["content"].size(); + for (const auto& item : msg["content"]) { + total_size += item["text"].get_ref().size(); + } + if (num_items > 1) { + total_size += num_items - 1; + } + + std::string combined_text; + combined_text.reserve(total_size); + bool first = true; + for (const auto& item : msg["content"]) { + if (!first) { + combined_text += '\n'; + } + combined_text += item["text"].get_ref(); + first = false; + } + msg["content"] = combined_text; + modified = true; + } + } + return modified ? std::make_pair(Status(), json.dump()) + : std::make_pair(Status(), std::move(json_str)); + } catch (const nlohmann::json::exception& e) { + return {Status(StatusCode::INVALID_ARGUMENT, + "Invalid JSON format: " + std::string(e.what())), + ""}; + } catch (const std::exception& e) { + LOG(ERROR) << "Exception during JSON preprocessing: " << e.what(); + return {Status(StatusCode::UNKNOWN, + "Internal server error during JSON processing."), + ""}; + } +} + +std::pair AnthropicChatJsonParser::preprocess( + std::string json_str) const { + try { + auto j = nlohmann::json::parse(json_str); + + if (j.contains("messages") && j["messages"].is_array()) { + for (auto& msg : j["messages"]) { + if (!msg.contains("content")) { + continue; + } + auto& content = msg["content"]; + if (content.is_string()) { + msg["content_string"] = content.get(); + msg.erase("content"); + } else if (content.is_array()) { + nlohmann::json content_blocks; + content_blocks["blocks"] = content; + msg["content_blocks"] = content_blocks; + msg.erase("content"); + } + } + } + + if (j.contains("system")) { + auto& system = j["system"]; + if (system.is_string()) { + j["system_string"] = system.get(); + j.erase("system"); + } else if (system.is_array()) { + nlohmann::json system_blocks; + system_blocks["blocks"] = system; + j["system_blocks"] = system_blocks; + j.erase("system"); + } + } + + return {Status(), j.dump()}; + } catch (const nlohmann::json::exception& e) { + return {Status(StatusCode::INVALID_ARGUMENT, + "Invalid JSON format: " + std::string(e.what())), + ""}; + } catch (const std::exception& e) { + LOG(ERROR) << "Exception during Anthropic JSON preprocessing: " << e.what(); + return {Status(StatusCode::UNKNOWN, + "Internal server error during JSON processing."), + ""}; + } +} + +} // namespace xllm diff --git a/xllm/api_service/chat_json_parser.h b/xllm/api_service/chat_json_parser.h new file mode 100644 index 000000000..85c69db97 --- /dev/null +++ b/xllm/api_service/chat_json_parser.h @@ -0,0 +1,68 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "api_service/serving_mode.h" +#include "core/common/types.h" + +namespace xllm { + +// Normalizes OpenAI-style chat JSON before protobuf parsing. LLM backends +// collapse text-only content arrays into a single string; VLM backends pass +// JSON through for downstream multimodal handling. +class ChatJsonParser { + public: + virtual ~ChatJsonParser() = default; + + [[nodiscard]] virtual std::pair preprocess( + std::string json_str) const = 0; + + // Returns the singleton parser for the given serving mode. + // LLM/REC → LlmChatJsonParser, VLM → VlmChatJsonParser. + static const ChatJsonParser& get(ServingMode mode); + + // Returns the Anthropic protocol parser (separate from serving mode). + static const ChatJsonParser& anthropic(); +}; + +// Text-only backend: combines array content items of type "text" into one +// string; rejects non-text parts (e.g. image_url). +class LlmChatJsonParser final : public ChatJsonParser { + public: + std::pair preprocess( + std::string json_str) const override; +}; + +// Multimodal backend: no preprocessing; array content stays as-is. +class VlmChatJsonParser final : public ChatJsonParser { + public: + std::pair preprocess( + std::string json_str) const override; +}; + +// Anthropic Messages API: remaps "content" (string|array) to +// "content_string"/"content_blocks" and "system" (string|array) to +// "system_string"/"system_blocks" for protobuf compatibility. +class AnthropicChatJsonParser final : public ChatJsonParser { + public: + std::pair preprocess( + std::string json_str) const override; +}; + +} // namespace xllm diff --git a/xllm/api_service/chat_json_utils_test.cpp b/xllm/api_service/chat_json_parser_test.cpp similarity index 58% rename from xllm/api_service/chat_json_utils_test.cpp rename to xllm/api_service/chat_json_parser_test.cpp index 92cb316ba..303997ced 100644 --- a/xllm/api_service/chat_json_utils_test.cpp +++ b/xllm/api_service/chat_json_parser_test.cpp @@ -1,4 +1,4 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. +/* Copyright 2026 The xLLM Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "chat_json_utils.h" +#include "api_service/chat_json_parser.h" #include @@ -23,23 +23,20 @@ namespace xllm { class PreprocessChatJsonTest : public ::testing::Test { protected: - // Helper to check successful preprocessing - void ExpectSuccess(const std::string& input, - bool is_multimodal, - const std::string& expected_output) { - auto [status, result] = preprocess_chat_json(input, is_multimodal); + void expect_success(const std::string& input, + const ChatJsonParser& parser, + const std::string& expected_output) { + auto [status, result] = parser.preprocess(input); ASSERT_TRUE(status.ok()) << "Unexpected error: " << status.message(); - // Parse both to compare JSON structure, not string equality auto result_json = nlohmann::json::parse(result); auto expected_json = nlohmann::json::parse(expected_output); EXPECT_EQ(result_json, expected_json); } - // Helper to check expected error - void ExpectError(const std::string& input, - bool is_multimodal, - const std::string& expected_error_substring) { - auto [status, result] = preprocess_chat_json(input, is_multimodal); + void expect_error(const std::string& input, + const ChatJsonParser& parser, + const std::string& expected_error_substring) { + auto [status, result] = parser.preprocess(input); ASSERT_FALSE(status.ok()) << "Expected error but got success"; EXPECT_NE(status.message().find(expected_error_substring), std::string::npos) @@ -58,14 +55,17 @@ TEST_F(PreprocessChatJsonTest, PassThroughNonArrayContent) { std::string input = R"({ "messages": [{"role": "user", "content": "Hello"}] })"; - ExpectSuccess(input, /*is_multimodal=*/false, input); - ExpectSuccess(input, /*is_multimodal=*/true, input); + LlmChatJsonParser llm_parser; + VlmChatJsonParser vlm_parser; + expect_success(input, llm_parser, input); + expect_success(input, vlm_parser, input); } TEST_F(PreprocessChatJsonTest, PassThroughNoMessages) { // JSON without messages field should pass through std::string input = R"({"model": "test"})"; - ExpectSuccess(input, /*is_multimodal=*/false, input); + LlmChatJsonParser llm_parser; + expect_success(input, llm_parser, input); } TEST_F(PreprocessChatJsonTest, CombineTextArrayIntoString) { @@ -83,9 +83,11 @@ TEST_F(PreprocessChatJsonTest, CombineTextArrayIntoString) { std::string expected = R"({ "messages": [{"role": "user", "content": "Hello\nWorld"}] })"; - ExpectSuccess(input, /*is_multimodal=*/false, expected); + LlmChatJsonParser llm_parser; + VlmChatJsonParser vlm_parser; + expect_success(input, llm_parser, expected); // For multimodal, array is preserved (not combined) - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } TEST_F(PreprocessChatJsonTest, SingleTextItemCombined) { @@ -99,9 +101,11 @@ TEST_F(PreprocessChatJsonTest, SingleTextItemCombined) { std::string expected = R"({ "messages": [{"role": "user", "content": "Hello"}] })"; - ExpectSuccess(input, /*is_multimodal=*/false, expected); + LlmChatJsonParser llm_parser; + VlmChatJsonParser vlm_parser; + expect_success(input, llm_parser, expected); // For multimodal, array is preserved - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } // ============================================================================= @@ -119,8 +123,9 @@ TEST_F(PreprocessChatJsonTest, ImageUrlPassesThroughOnMultimodal) { ] }] })"; + VlmChatJsonParser vlm_parser; // Should pass through unchanged for multimodal - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } TEST_F(PreprocessChatJsonTest, ImageUrlErrorsOnTextOnly) { @@ -134,8 +139,9 @@ TEST_F(PreprocessChatJsonTest, ImageUrlErrorsOnTextOnly) { ] }] })"; - ExpectError(input, /*is_multimodal=*/false, "multimodal backend"); - ExpectError(input, /*is_multimodal=*/false, "-backend vlm"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "multimodal backend"); + expect_error(input, llm_parser, "-backend vlm"); } TEST_F(PreprocessChatJsonTest, MultipleMessagesWithMixedContent) { @@ -156,8 +162,9 @@ TEST_F(PreprocessChatJsonTest, MultipleMessagesWithMixedContent) { } ] })"; + VlmChatJsonParser vlm_parser; // On multimodal: all arrays preserved unchanged - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } // ============================================================================= @@ -166,35 +173,38 @@ TEST_F(PreprocessChatJsonTest, MultipleMessagesWithMixedContent) { TEST_F(PreprocessChatJsonTest, InvalidJsonReturnsError) { std::string input = "not valid json"; - ExpectError(input, /*is_multimodal=*/false, "Invalid JSON"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "Invalid JSON"); } TEST_F(PreprocessChatJsonTest, NonObjectMessageReturnsError) { std::string input = R"({"messages": ["not an object"]})"; - ExpectError(input, /*is_multimodal=*/false, "must be an object"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "must be an object"); } TEST_F(PreprocessChatJsonTest, NonObjectContentItemReturnsError) { std::string input = R"({ "messages": [{"role": "user", "content": ["not an object"]}] })"; - ExpectError(input, /*is_multimodal=*/false, "must be an object"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "must be an object"); } TEST_F(PreprocessChatJsonTest, MissingTextFieldReturnsError) { std::string input = R"({ "messages": [{"role": "user", "content": [{"type": "text"}]}] })"; - ExpectError( - input, /*is_multimodal=*/false, "Missing or invalid 'text' field"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "Missing or invalid 'text' field"); } TEST_F(PreprocessChatJsonTest, NonStringTextFieldReturnsError) { std::string input = R"({ "messages": [{"role": "user", "content": [{"type": "text", "text": 123}]}] })"; - ExpectError( - input, /*is_multimodal=*/false, "Missing or invalid 'text' field"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "Missing or invalid 'text' field"); } TEST_F(PreprocessChatJsonTest, MalformedTextInMultimodalContent) { @@ -208,8 +218,9 @@ TEST_F(PreprocessChatJsonTest, MalformedTextInMultimodalContent) { ] }] })"; + VlmChatJsonParser vlm_parser; // Should pass through unchanged without validation - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } // ============================================================================= @@ -224,9 +235,11 @@ TEST_F(PreprocessChatJsonTest, EmptyContentArray) { std::string expected = R"({ "messages": [{"role": "user", "content": ""}] })"; - ExpectSuccess(input, /*is_multimodal=*/false, expected); + LlmChatJsonParser llm_parser; + VlmChatJsonParser vlm_parser; + expect_success(input, llm_parser, expected); // For multimodal, empty array is preserved - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } TEST_F(PreprocessChatJsonTest, PreservesOtherFields) { @@ -243,9 +256,11 @@ TEST_F(PreprocessChatJsonTest, PreservesOtherFields) { "temperature": 0.7, "max_tokens": 100 })"; - ExpectSuccess(input, /*is_multimodal=*/false, expected); + LlmChatJsonParser llm_parser; + VlmChatJsonParser vlm_parser; + expect_success(input, llm_parser, expected); // For multimodal, array is preserved - ExpectSuccess(input, /*is_multimodal=*/true, input); + expect_success(input, vlm_parser, input); } TEST_F(PreprocessChatJsonTest, UnknownContentTypeOnMultimodal) { @@ -256,7 +271,8 @@ TEST_F(PreprocessChatJsonTest, UnknownContentTypeOnMultimodal) { "content": [{"type": "video", "video": {"url": "..."}}] }] })"; - ExpectSuccess(input, /*is_multimodal=*/true, input); + VlmChatJsonParser vlm_parser; + expect_success(input, vlm_parser, input); } TEST_F(PreprocessChatJsonTest, UnknownContentTypeErrorsOnTextOnly) { @@ -267,7 +283,101 @@ TEST_F(PreprocessChatJsonTest, UnknownContentTypeErrorsOnTextOnly) { "content": [{"type": "video", "video": {"url": "..."}}] }] })"; - ExpectError(input, /*is_multimodal=*/false, "multimodal backend"); + LlmChatJsonParser llm_parser; + expect_error(input, llm_parser, "multimodal backend"); +} + +// ============================================================================= +// Anthropic parser tests +// ============================================================================= + +TEST_F(PreprocessChatJsonTest, AnthropicStringContentRemapped) { + std::string input = R"({ + "messages": [{"role": "user", "content": "Hello"}] + })"; + std::string expected = R"({ + "messages": [{"role": "user", "content_string": "Hello"}] + })"; + AnthropicChatJsonParser parser; + expect_success(input, parser, expected); +} + +TEST_F(PreprocessChatJsonTest, AnthropicArrayContentRemapped) { + std::string input = R"({ + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "image", "source": {"data": "abc"}} + ] + }] + })"; + std::string expected = R"({ + "messages": [{ + "role": "user", + "content_blocks": { + "blocks": [ + {"type": "text", "text": "Hello"}, + {"type": "image", "source": {"data": "abc"}} + ] + } + }] + })"; + AnthropicChatJsonParser parser; + expect_success(input, parser, expected); +} + +TEST_F(PreprocessChatJsonTest, AnthropicSystemStringRemapped) { + std::string input = R"({ + "system": "You are helpful.", + "messages": [{"role": "user", "content": "Hi"}] + })"; + std::string expected = R"({ + "system_string": "You are helpful.", + "messages": [{"role": "user", "content_string": "Hi"}] + })"; + AnthropicChatJsonParser parser; + expect_success(input, parser, expected); +} + +TEST_F(PreprocessChatJsonTest, AnthropicSystemArrayRemapped) { + std::string input = R"({ + "system": [{"type": "text", "text": "You are helpful."}], + "messages": [{"role": "user", "content": "Hi"}] + })"; + std::string expected = R"({ + "system_blocks": {"blocks": [{"type": "text", "text": "You are helpful."}]}, + "messages": [{"role": "user", "content_string": "Hi"}] + })"; + AnthropicChatJsonParser parser; + expect_success(input, parser, expected); +} + +TEST_F(PreprocessChatJsonTest, AnthropicNoContentNoSystem) { + std::string input = R"({"model": "claude-3"})"; + AnthropicChatJsonParser parser; + expect_success(input, parser, input); +} + +TEST_F(PreprocessChatJsonTest, AnthropicInvalidJsonReturnsError) { + std::string input = "not valid json"; + AnthropicChatJsonParser parser; + expect_error(input, parser, "Invalid JSON"); +} + +TEST_F(PreprocessChatJsonTest, AnthropicPreservesOtherFields) { + std::string input = R"({ + "model": "claude-3", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "Hello"}] + })"; + std::string expected = R"({ + "model": "claude-3", + "max_tokens": 1024, + "messages": [{"role": "user", "content_string": "Hello"}] + })"; + AnthropicChatJsonParser parser; + expect_success(input, parser, expected); } } // namespace xllm diff --git a/xllm/api_service/non_stream_call.h b/xllm/api_service/non_stream_call.h index 66e41c91b..641513fba 100644 --- a/xllm/api_service/non_stream_call.h +++ b/xllm/api_service/non_stream_call.h @@ -45,8 +45,7 @@ class NonStreamCall : public Call { request_(request), response_(response), use_arena_(use_arena) { - controller_->http_response().SetHeader("Content-Type", - "text/javascript; charset=utf-8"); + controller_->http_response().set_content_type("application/json"); json_options_.bytes_to_base64 = false; json_options_.jsonify_empty_array = true; diff --git a/xllm/api_service/openai_service_test.cpp b/xllm/api_service/openai_service_test.cpp index b9261d3a1..c4e4279bb 100644 --- a/xllm/api_service/openai_service_test.cpp +++ b/xllm/api_service/openai_service_test.cpp @@ -56,6 +56,7 @@ struct TestConfig { struct HttpResult { bool controller_failed = false; int status_code = 0; + std::string content_type; std::string error_text; std::string body; nlohmann::json json = nullptr; @@ -92,6 +93,7 @@ class HttpClient { HttpResult result; result.controller_failed = cntl.Failed(); result.status_code = cntl.http_response().status_code(); + result.content_type = cntl.http_response().content_type(); result.error_text = cntl.ErrorText(); result.body = cntl.response_attachment().to_string(); if (!result.body.empty()) { @@ -110,6 +112,9 @@ class HttpClient { std::string describe_result(const HttpResult& result) { std::string description = "status=" + std::to_string(result.status_code); + if (!result.content_type.empty()) { + description += ", content_type=" + result.content_type; + } if (!result.error_text.empty()) { description += ", error=" + result.error_text; } @@ -196,6 +201,7 @@ TEST_F(DISABLED_OpenAIServerFeaturesTest, SampleSingleMatch) { ASSERT_FALSE(result.controller_failed) << describe_result(result); ASSERT_EQ(result.status_code, 200) << describe_result(result); + EXPECT_EQ(result.content_type, "application/json") << describe_result(result); ASSERT_TRUE(result.json.is_object()) << describe_result(result); EXPECT_EQ(result.json["id"], "sample-it"); EXPECT_EQ(result.json["object"], "sample_completion"); @@ -261,6 +267,7 @@ TEST_F(DISABLED_OpenAIServerFeaturesTest, CompletionsRegressionSmoke) { ASSERT_FALSE(result.controller_failed) << describe_result(result); ASSERT_EQ(result.status_code, 200) << describe_result(result); + EXPECT_EQ(result.content_type, "application/json") << describe_result(result); ASSERT_TRUE(result.json.is_object()) << describe_result(result); ASSERT_TRUE(result.json.contains("choices")); ASSERT_EQ(result.json["choices"].size(), 1); @@ -279,6 +286,7 @@ TEST_F(DISABLED_OpenAIServerFeaturesTest, ChatCompletionsRegressionSmoke) { ASSERT_FALSE(result.controller_failed) << describe_result(result); ASSERT_EQ(result.status_code, 200) << describe_result(result); + EXPECT_EQ(result.content_type, "application/json") << describe_result(result); ASSERT_TRUE(result.json.is_object()) << describe_result(result); ASSERT_TRUE(result.json.contains("choices")); ASSERT_EQ(result.json["choices"].size(), 1); diff --git a/xllm/api_service/rec_completion_service_impl.cpp b/xllm/api_service/rec_completion_service_impl.cpp index b71eccef9..55a8b71ec 100644 --- a/xllm/api_service/rec_completion_service_impl.cpp +++ b/xllm/api_service/rec_completion_service_impl.cpp @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include #include #include +#include #include "common/global_flags.h" #include "common/instance_name.h" @@ -42,6 +44,21 @@ limitations under the License. namespace xllm { namespace { +void append_rec_logprobs(proto::InferTensorContents* logprobs_context, + const SequenceOutput& output, + int32_t expected_count) { + const auto& token_logprobs = output.token_ids_logprobs; + const int32_t actual_count = static_cast(token_logprobs.size()); + + for (int32_t i = 0; i < expected_count; ++i) { + if (i < actual_count && token_logprobs[i].has_value()) { + logprobs_context->mutable_fp32_contents()->Add(token_logprobs[i].value()); + } else { + logprobs_context->mutable_fp32_contents()->Add(0.0f); + } + } +} + void set_logprobs(proto::Choice* choice, const std::optional>& logprobs) { if (!logprobs.has_value() || logprobs.value().empty()) { @@ -91,32 +108,93 @@ bool send_result_to_client_brpc_rec(std::shared_ptr call, // Add rec specific output tensors auto output_tensor = response.mutable_output_tensors()->Add(); output_tensor->set_name("rec_result"); - if (FLAGS_enable_constrained_decoding) { + proto::InferOutputTensor* logprobs_tensor = nullptr; + int32_t logprob_width = 0; + if (FLAGS_enable_output_sku_logprobs && !req_output.outputs.empty()) { + logprobs_tensor = response.mutable_output_tensors()->Add(); + logprobs_tensor->set_name("sku_logprobs"); + logprobs_tensor->set_datatype(proto::DataType::FLOAT); + logprob_width = + static_cast(req_output.outputs[0].token_ids_logprobs.size()); + } + + if (FLAGS_enable_convert_tokens_to_item) { output_tensor->set_datatype(proto::DataType::INT64); - output_tensor->mutable_shape()->Add(req_output.outputs.size()); - output_tensor->mutable_shape()->Add(1); // Single item per output - // TODO: add following when next pr. - /* - auto context = output_tensor->mutable_contents(); - for (int i = 0; i < req_output.outputs.size(); ++i) { - if (req_output.outputs[i].item_ids.has_value()) { - context->mutable_int64_contents()->Add( - req_output.outputs[i].item_ids.value()); + const int32_t output_count = + static_cast(req_output.outputs.size()); + output_tensor->mutable_shape()->Add(output_count); + if (logprobs_tensor != nullptr) { + logprobs_tensor->mutable_shape()->Add(output_count); + logprobs_tensor->mutable_shape()->Add(logprob_width); + } + + auto* output_context = output_tensor->mutable_contents(); + auto* logprobs_context = logprobs_tensor == nullptr + ? nullptr + : logprobs_tensor->mutable_contents(); + auto append_output_logprobs = [&](int32_t output_index) { + if (logprobs_context != nullptr) { + append_rec_logprobs( + logprobs_context, req_output.outputs[output_index], logprob_width); + } + }; + int32_t total_count = 0; + const int32_t total_threshold = FLAGS_total_conversion_threshold; + for (int32_t i = 0; i < output_count; ++i) { + const auto& output = req_output.outputs[i]; + if (!output.item_ids_list.empty()) { + for (const int64_t item_id : output.item_ids_list) { + if (total_count >= total_threshold) { + break; + } + output_context->mutable_int64_contents()->Add(item_id); + append_output_logprobs(i); + ++total_count; + } + } else if (output.item_ids.has_value() && total_count < total_threshold) { + output_context->mutable_int64_contents()->Add(output.item_ids.value()); + append_output_logprobs(i); + ++total_count; } } - */ } else { output_tensor->set_datatype(proto::DataType::INT32); - output_tensor->mutable_shape()->Add(req_output.outputs.size()); + if (req_output.outputs.empty()) { + output_tensor->mutable_shape()->Add(0); + output_tensor->mutable_shape()->Add(0); + if (logprobs_tensor != nullptr) { + logprobs_tensor->mutable_shape()->Add(0); + logprobs_tensor->mutable_shape()->Add(0); + } + return call->write_and_finish(response); + } + + const int32_t output_count = + static_cast(req_output.outputs.size()); + output_tensor->mutable_shape()->Add(output_count); output_tensor->mutable_shape()->Add(req_output.outputs[0].token_ids.size()); + if (logprobs_tensor != nullptr) { + logprobs_tensor->mutable_shape()->Add(output_count); + logprobs_tensor->mutable_shape()->Add(logprob_width); + } - auto context = output_tensor->mutable_contents(); - for (int i = 0; i < req_output.outputs.size(); ++i) { + auto* context = output_tensor->mutable_contents(); + auto* logprobs_context = logprobs_tensor == nullptr + ? nullptr + : logprobs_tensor->mutable_contents(); + auto append_output_logprobs = [&](int32_t output_index) { + if (logprobs_context != nullptr) { + append_rec_logprobs( + logprobs_context, req_output.outputs[output_index], logprob_width); + } + }; + for (int32_t i = 0; i < output_count; ++i) { // LOG(INFO) << req_output.outputs[i].token_ids; context->mutable_int_contents()->Add( req_output.outputs[i].token_ids.begin(), req_output.outputs[i].token_ids.end()); + append_output_logprobs(i); } } @@ -153,6 +231,9 @@ void RecCompletionServiceImpl::process_async_impl( RequestParams request_params( rpc_request, call->get_x_request_id(), call->get_x_request_time()); + if (FLAGS_enable_output_sku_logprobs) { + request_params.logprobs = true; + } bool include_usage = false; if (rpc_request.has_stream_options()) { include_usage = rpc_request.stream_options().include_usage(); diff --git a/xllm/api_service/service_impl_factory.cpp b/xllm/api_service/service_impl_factory.cpp new file mode 100644 index 000000000..092a5aba4 --- /dev/null +++ b/xllm/api_service/service_impl_factory.cpp @@ -0,0 +1,124 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "api_service/service_impl_factory.h" + +#include + +#include +#include + +#include "api_service.h" +#include "api_service/serving_mode.h" +#include "core/common/global_flags.h" +#include "core/distributed_runtime/dit_master.h" +#include "core/distributed_runtime/llm_master.h" +#include "core/distributed_runtime/rec_master.h" +#include "core/distributed_runtime/vlm_master.h" + +namespace xllm { + +namespace { + +template +std::unique_ptr create_service_impl( + MasterT* master, + const std::vector& model_names) { + return std::make_unique(master, model_names); +} + +} // namespace + +void ServiceImplFactory::create( + APIService* service, + Master* master, + const std::vector& model_names, + const std::vector& model_versions) { + using InitFn = std::function&)>; + + static const std::unordered_map kRegistry = { + {static_cast(ServingMode::LLM), + [](APIService* self, + Master* master, + const std::vector& models) { + auto* llm_master = dynamic_cast(master); + self->anthropic_service_impl_ = + std::make_unique(llm_master, models); + self->completion_service_impl_ = + create_service_impl(llm_master, models); + self->sample_service_impl_ = + create_service_impl(llm_master, models); + self->chat_service_impl_ = + create_service_impl(llm_master, models); + self->embedding_service_impl_ = + create_service_impl(llm_master, models); + if (FLAGS_enable_qwen3_reranker) { + self->rerank_service_impl_ = + create_service_impl(llm_master, models); + } else { + self->rerank_service_impl_ = + create_service_impl(llm_master, models); + } + }}, + {static_cast(ServingMode::VLM), + [](APIService* self, + Master* master, + const std::vector& models) { + auto* vlm_master = dynamic_cast(master); + self->mm_chat_service_impl_ = + std::make_unique(vlm_master, models); + self->mm_embedding_service_impl_ = + std::make_unique(vlm_master, models); + }}, + {static_cast(ServingMode::DIT), + [](APIService* self, + Master* master, + const std::vector& models) { + self->image_generation_service_impl_ = + std::make_unique( + dynamic_cast(master), models); + }}, + {static_cast(ServingMode::REC), + [](APIService* self, + Master* master, + const std::vector& models) { + auto* rec_master = dynamic_cast(master); + self->rec_completion_service_impl_ = + std::make_unique(rec_master, models); + self->chat_service_impl_ = + std::make_unique(rec_master, models); + }}, + }; + + ServingMode mode = to_serving_mode(master->engine_type()); + auto it = kRegistry.find(static_cast(mode)); + if (it != kRegistry.end()) { + it->second(service, master, model_names); + } else { + LOG(FATAL) << "Unsupported serving mode for engine type: " + << master->engine_type().to_string(); + } + + CHECK_EQ(model_names.size(), model_versions.size()) + << "Models and model_versions size mismatch: model_names.size()=" + << model_names.size() + << ", model_versions.size()=" << model_versions.size(); + + service->models_service_impl_ = + std::make_unique(model_names, model_versions); +} + +} // namespace xllm diff --git a/xllm/api_service/service_impl_factory.h b/xllm/api_service/service_impl_factory.h index 270efc6e2..bd09e045e 100644 --- a/xllm/api_service/service_impl_factory.h +++ b/xllm/api_service/service_impl_factory.h @@ -1,4 +1,4 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. +/* Copyright 2026 The xLLM Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,38 +15,23 @@ limitations under the License. #pragma once -#include "api_service/api_service_impl.h" +#include +#include namespace xllm { -template +class APIService; +class Master; + +// Creates all service-impl instances that an APIService needs for the active +// engine type. Adding a new engine type only requires one new entry in the +// registry defined in service_impl_factory.cpp. class ServiceImplFactory { public: - static std::unique_ptr create_service_impl( - LLMMaster* master, - const std::vector& model_names) { - auto service_impl = std::make_unique(master, model_names); - if (!service_impl) { - LOG(ERROR) << "handler is nullptr"; - } - return service_impl; - } - - static std::unique_ptr create_service_impl( - const std::vector& model_names, - const std::vector& model_versions) { - if (model_names.size() != model_versions.size()) { - LOG(ERROR) - << "Models and model_versions size mismatch: model_names.size()=" - << model_names.size() - << ", model_versions.size()=" << model_versions.size(); - } - auto service_impl = std::make_unique(model_names, model_versions); - if (!service_impl) { - LOG(ERROR) << "handler is nullptr"; - } - return service_impl; - } + static void create(APIService* service, + Master* master, + const std::vector& model_names, + const std::vector& model_versions); }; } // namespace xllm diff --git a/xllm/api_service/serving_mode.h b/xllm/api_service/serving_mode.h new file mode 100644 index 000000000..dd5e5bf77 --- /dev/null +++ b/xllm/api_service/serving_mode.h @@ -0,0 +1,49 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "core/common/types.h" + +namespace xllm { + +// Service-layer abstraction of the active serving pipeline. +// Unlike EngineType (which includes engine-internal variants such as SSM), +// ServingMode only exposes the distinctions that matter to the API surface. +enum class ServingMode : int8_t { + LLM = 0, + VLM = 1, + DIT = 2, + REC = 3, +}; + +// Maps an engine-layer EngineType to its corresponding ServingMode. +// SSM (speculative decoding) serves the same API as LLM. +inline ServingMode to_serving_mode(EngineType engine_type) { + switch (static_cast(engine_type)) { + case EngineType::VLM: + return ServingMode::VLM; + case EngineType::DIT: + return ServingMode::DIT; + case EngineType::REC: + return ServingMode::REC; + default: + return ServingMode::LLM; + } +} + +} // namespace xllm diff --git a/xllm/api_service/stream_call.h b/xllm/api_service/stream_call.h index 97a72b5eb..6b1d07aed 100644 --- a/xllm/api_service/stream_call.h +++ b/xllm/api_service/stream_call.h @@ -51,7 +51,8 @@ class StreamCall : public Call { pa_ = controller_->CreateProgressiveAttachment(); // Send the first SSE response - controller_->http_response().set_content_type("text/event-stream"); + controller_->http_response().set_content_type( + "text/event-stream; charset=utf-8"); controller_->http_response().set_status_code(200); controller_->http_response().SetHeader("Connection", "keep-alive"); controller_->http_response().SetHeader("Cache-Control", "no-cache"); @@ -59,8 +60,7 @@ class StreamCall : public Call { done_->Run(); } else { - controller_->http_response().SetHeader("Content-Type", - "text/javascript; charset=utf-8"); + controller_->http_response().set_content_type("application/json"); } json_options_.bytes_to_base64 = false; diff --git a/xllm/c_api/internal/helper.cpp b/xllm/c_api/internal/helper.cpp index ef7cc9740..d7f16b35e 100644 --- a/xllm/c_api/internal/helper.cpp +++ b/xllm/c_api/internal/helper.cpp @@ -158,6 +158,7 @@ XLLM_Response* build_error_response(const std::string& request_id, XLLM_Response* build_success_response(const InferenceType& inference_type, const RequestOutput& output, + RecPipelineType rec_pipeline_type, const std::string& request_id, int64_t created_time, const std::string& model) { @@ -180,11 +181,30 @@ XLLM_Response* build_success_response(const InferenceType& inference_type, response->choices.entries_size = output.outputs.size(); response->choices.entries = new XLLM_Choice[response->choices.entries_size](); CHECK(nullptr != response->choices.entries); + const bool is_rec_inference = + inference_type == InferenceType::REC_COMPLETIONS || + inference_type == InferenceType::REC_CHAT_COMPLETIONS; + const bool is_onerec_pipeline = + is_rec_inference && rec_pipeline_type == RecPipelineType::kOneRecDefault; + if (is_onerec_pipeline) { + response->rec_outputs.entries_size = output.outputs.size(); + response->rec_outputs.entries = + new XLLM_RecOutput[response->rec_outputs.entries_size](); + CHECK(nullptr != response->rec_outputs.entries); + } + + int32_t total_item_count = 0; + const int32_t total_threshold = FLAGS_total_conversion_threshold; for (int i = 0; i < output.outputs.size(); i++) { const auto& seq_output = output.outputs[i]; XLLM_Choice& choice = response->choices.entries[i]; choice.index = seq_output.index; + XLLM_RecOutput* rec_output = nullptr; + if (response->rec_outputs.entries != nullptr) { + rec_output = &response->rec_outputs.entries[i]; + rec_output->index = seq_output.index; + } if (inference_type == InferenceType::LLM_COMPLETIONS || inference_type == InferenceType::REC_COMPLETIONS) { @@ -233,6 +253,50 @@ XLLM_Response* build_success_response(const InferenceType& inference_type, xllm_logprob.logprob = logprob.logprob; } } + + if (is_onerec_pipeline && FLAGS_enable_convert_tokens_to_item && + rec_output != nullptr) { + size_t copied_item_count = 0; + if (!seq_output.item_ids_list.empty()) { + copied_item_count = + std::min(seq_output.item_ids_list.size(), + static_cast( + std::max(total_threshold - total_item_count, 0))); + if (copied_item_count > 0) { + rec_output->item_ids_size = copied_item_count; + rec_output->item_ids = new int64_t[copied_item_count]; + CHECK(nullptr != rec_output->item_ids); + for (size_t j = 0; j < copied_item_count; ++j) { + rec_output->item_ids[j] = seq_output.item_ids_list[j]; + } + total_item_count += static_cast(copied_item_count); + } + } else if (seq_output.item_ids.has_value() && + total_item_count < total_threshold) { + rec_output->item_ids_size = 1; + rec_output->item_ids = new int64_t[1]; + CHECK(nullptr != rec_output->item_ids); + rec_output->item_ids[0] = seq_output.item_ids.value(); + ++total_item_count; + } + } + + if (is_onerec_pipeline && FLAGS_enable_output_sku_logprobs && + !seq_output.token_ids_logprobs.empty() && rec_output != nullptr) { + rec_output->rec_token_logprobs_size = + seq_output.token_ids_logprobs.size(); + rec_output->rec_token_logprobs = + new float[rec_output->rec_token_logprobs_size]; + CHECK(nullptr != rec_output->rec_token_logprobs); + for (size_t j = 0; j < rec_output->rec_token_logprobs_size; ++j) { + if (seq_output.token_ids_logprobs[j].has_value()) { + rec_output->rec_token_logprobs[j] = + seq_output.token_ids_logprobs[j].value(); + } else { + rec_output->rec_token_logprobs[j] = 0.0f; + } + } + } } if (output.usage.has_value()) { @@ -279,6 +343,14 @@ XLLM_Response* handle_inference_request( xllm::RequestParams xllm_request_params; transfer_request_params(inference_type, request_params, &xllm_request_params); xllm_request_params.request_id = request_id; + RecPipelineType rec_pipeline_type = RecPipelineType::kLlmRecDefault; + if constexpr (std::is_same_v) { + rec_pipeline_type = handler->pipeline_type; + if (FLAGS_enable_output_sku_logprobs && + rec_pipeline_type == RecPipelineType::kOneRecDefault) { + xllm_request_params.logprobs = true; + } + } const int64_t created_time = absl::ToUnixSeconds(absl::Now()); @@ -290,6 +362,7 @@ XLLM_Response* handle_inference_request( request_id, created_time, inference_type, + rec_pipeline_type, weak_promise = std::weak_ptr(promise_ptr)]( const RequestOutput& req_output) -> bool { if (auto locked_promise = weak_promise.lock()) { @@ -298,6 +371,7 @@ XLLM_Response* handle_inference_request( if (req_output.status.value().ok()) { locked_promise->setValue(build_success_response(inference_type, req_output, + rec_pipeline_type, request_id, created_time, model_id)); @@ -432,6 +506,24 @@ void xllm_free_response(XLLM_Response* resp) { } resp->choices.entries_size = 0; + if (nullptr != resp->rec_outputs.entries) { + for (size_t i = 0; i < resp->rec_outputs.entries_size; ++i) { + XLLM_RecOutput& rec_output = resp->rec_outputs.entries[i]; + if (nullptr != rec_output.item_ids) { + delete[] rec_output.item_ids; + rec_output.item_ids = nullptr; + rec_output.item_ids_size = 0; + } + if (nullptr != rec_output.rec_token_logprobs) { + delete[] rec_output.rec_token_logprobs; + rec_output.rec_token_logprobs = nullptr; + rec_output.rec_token_logprobs_size = 0; + } + } + delete[] resp->rec_outputs.entries; + resp->rec_outputs.entries = nullptr; + } + resp->rec_outputs.entries_size = 0; delete resp; return; diff --git a/xllm/c_api/internal/helper.h b/xllm/c_api/internal/helper.h index 3c541a637..74659a896 100644 --- a/xllm/c_api/internal/helper.h +++ b/xllm/c_api/internal/helper.h @@ -31,6 +31,7 @@ limitations under the License. #include "core/distributed_runtime/rec_master.h" #include "core/framework/request/request_output.h" #include "core/framework/request/request_params.h" +#include "core/util/rec_model_utils.h" /** * @brief Opaque handle for LLM inference instance @@ -56,6 +57,9 @@ struct XLLM_REC_Handler { /** Flag indicating if REC instance is initialized and ready for inference */ bool initialized{false}; + /** Selected REC pipeline type for the loaded model */ + xllm::RecPipelineType pipeline_type{xllm::RecPipelineType::kLlmRecDefault}; + /** List of loaded recommendation model IDs */ std::vector model_ids; @@ -127,6 +131,7 @@ XLLM_Response* build_error_response(const std::string& request_id, */ XLLM_Response* build_success_response(const InferenceType& inference_type, const xllm::RequestOutput& output, + xllm::RecPipelineType rec_pipeline_type, const std::string& request_id, int64_t created_time, const std::string& model); @@ -164,4 +169,4 @@ xllm::MMDataItem convert_xllm_mm_item_to_internal( bool convert_xllm_mm_data_to_internal(const XLLM_MM_Data* mm_data, xllm::MMData& internal_mm_data); } // namespace helper -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/c_api/internal/llm.cpp b/xllm/c_api/internal/llm.cpp index 14aa06aba..d409545ce 100644 --- a/xllm/c_api/internal/llm.cpp +++ b/xllm/c_api/internal/llm.cpp @@ -117,7 +117,7 @@ XLLM_CAPI_EXPORT bool xllm_llm_initialize( options.enable_graph(FLAGS_enable_graph); -#if !defined(USE_NPU) +#if !defined(USE_NPU) && !defined(USE_CUDA) FLAGS_enable_block_copy_kernel = false; #endif diff --git a/xllm/c_api/internal/rec.cpp b/xllm/c_api/internal/rec.cpp index 18bcbd151..eb77c4f99 100644 --- a/xllm/c_api/internal/rec.cpp +++ b/xllm/c_api/internal/rec.cpp @@ -26,13 +26,68 @@ limitations under the License. #include #include +#include "core/framework/model_loader.h" +#include "core/util/rec_model_utils.h" #include "helper.h" +namespace { + +const char* get_rec_pipeline_name(xllm::RecPipelineType pipeline_type) { + switch (pipeline_type) { + case xllm::RecPipelineType::kLlmRecDefault: + return "LlmRecEnginePipeline"; + case xllm::RecPipelineType::kLlmRecWithMmData: + return "LlmRecWithMmData"; + case xllm::RecPipelineType::kLlmRecMultiRoundPipeline: + return "RecMultiRoundEnginePipeline"; + case xllm::RecPipelineType::kOneRecDefault: + return "OneRecEnginePipeline"; + default: + return "UnknownRecPipeline"; + } +} + +void reset_pipeline_runtime_toggles() { + FLAGS_enable_rec_fast_sampler = false; + FLAGS_enable_prefill_piecewise_graph = false; + FLAGS_enable_xattention_one_stage = false; + FLAGS_enable_graph_mode_decode_no_padding = false; + FLAGS_enable_rec_prefill_only = false; + FLAGS_enable_constrained_decoding = false; + FLAGS_enable_topk_sorted = false; +} + +void apply_multi_round_pipeline_toggles() { + FLAGS_enable_rec_fast_sampler = true; + FLAGS_enable_prefill_piecewise_graph = true; + FLAGS_enable_xattention_one_stage = false; + FLAGS_enable_graph_mode_decode_no_padding = true; + FLAGS_enable_topk_sorted = false; +} + +void apply_onerec_pipeline_toggles(xllm::Options* options) { + FLAGS_enable_rec_prefill_only = true; + FLAGS_enable_constrained_decoding = true; + FLAGS_enable_prefix_cache = false; + FLAGS_enable_schedule_overlap = false; + FLAGS_enable_chunked_prefill = false; + + options->enable_prefix_cache(false) + .enable_schedule_overlap(false) + .enable_chunked_prefill(false); + + // OneRec does not use Rec multi-round decode rounds. + FLAGS_max_decode_rounds = 0; +} + +} // namespace + XLLM_CAPI_EXPORT XLLM_REC_Handler* xllm_rec_create(void) { XLLM_REC_Handler* handler = new XLLM_REC_Handler(); CHECK(nullptr != handler); handler->initialized = false; + handler->pipeline_type = xllm::RecPipelineType::kLlmRecDefault; return handler; } @@ -43,6 +98,7 @@ XLLM_CAPI_EXPORT void xllm_rec_destroy(XLLM_REC_Handler* handler) { handler->master.reset(); handler->executor.reset(); handler->model_ids.clear(); + handler->pipeline_type = xllm::RecPipelineType::kLlmRecDefault; handler->initialized = false; delete handler; @@ -116,22 +172,82 @@ XLLM_CAPI_EXPORT bool xllm_rec_initialize( // @TODO: Currently, gflags are configured through hard coding, which needs // to be improved in the future. For example, a separate gflags // configuration file can be provided to the so for setting gflags. + // + // REC so still has two configuration paths: + // - some request/runtime code reads FLAGS_* directly + // - master/worker construction reads xllm::Options + // + // The fields copied from init options below are read from FLAGS_* today. + // beam_width/block_size/max_tokens/max_seqs are also represented in + // Options, so duplicated values must stay aligned. FLAGS_beam_width = xllm_init_options.beam_width; FLAGS_max_decode_rounds = xllm_init_options.max_decode_rounds; FLAGS_max_seqs_per_batch = xllm_init_options.max_seqs_per_batch; FLAGS_max_tokens_per_batch = xllm_init_options.max_tokens_per_batch; FLAGS_block_size = xllm_init_options.block_size; + FLAGS_enable_prefix_cache = xllm_init_options.enable_prefix_cache; + FLAGS_enable_schedule_overlap = xllm_init_options.enable_schedule_overlap; + FLAGS_enable_chunked_prefill = xllm_init_options.enable_chunked_prefill; + auto model_loader = xllm::ModelLoader::create(model_path); + if (model_loader == nullptr) { + LOG(ERROR) << "Failed to create model loader for path: " << model_path; + return false; + } + const auto& model_args = model_loader->model_args(); + const xllm::RecModelKind rec_model_kind = + xllm::get_rec_model_kind(model_args.model_type()); + if (rec_model_kind == xllm::RecModelKind::kNone) { + LOG(ERROR) << "Unsupported rec model_type: " << model_args.model_type(); + return false; + } + const xllm::RecPipelineType pipeline_type = + xllm::get_rec_pipeline_type(rec_model_kind); + + // Hard-coded REC so settings. enable_graph and rec_worker_max_concurrency + // are dual-source: runtime may read FLAGS_* while setup also needs the same + // value in Options. FLAGS_enable_graph = true; FLAGS_rec_worker_max_concurrency = 2; - FLAGS_enable_rec_fast_sampler = true; - FLAGS_enable_prefill_piecewise_graph = true; - FLAGS_enable_xattention_one_stage = false; - FLAGS_enable_graph_mode_decode_no_padding = true; - // FLAGS_enable_rec_prefill_only = true; - FLAGS_enable_topk_sorted = false; - options.enable_graph(FLAGS_enable_graph); + // Pipeline-specific runtime toggles in the REC so path. + reset_pipeline_runtime_toggles(); + switch (pipeline_type) { + case xllm::RecPipelineType::kLlmRecMultiRoundPipeline: + apply_multi_round_pipeline_toggles(); + break; + case xllm::RecPipelineType::kOneRecDefault: + apply_onerec_pipeline_toggles(&options); + break; + case xllm::RecPipelineType::kLlmRecDefault: + case xllm::RecPipelineType::kLlmRecWithMmData: + break; + default: + LOG(ERROR) << "Unsupported rec pipeline type: " + << static_cast(pipeline_type); + return false; + } + + // Keep dual-source settings aligned with the FLAGS_* values above. + options.enable_graph(FLAGS_enable_graph) + .beam_width(FLAGS_beam_width) + .rec_worker_max_concurrency(FLAGS_rec_worker_max_concurrency); + + LOG(INFO) << "REC C API selected pipeline=" + << get_rec_pipeline_name(pipeline_type) + << ", model_type=" << model_args.model_type() + << ", enable_rec_prefill_only=" << FLAGS_enable_rec_prefill_only + << ", enable_constrained_decoding=" + << FLAGS_enable_constrained_decoding + << ", enable_prefix_cache=" << FLAGS_enable_prefix_cache + << ", enable_schedule_overlap=" << FLAGS_enable_schedule_overlap + << ", enable_chunked_prefill=" << FLAGS_enable_chunked_prefill + << ", enable_rec_fast_sampler=" << FLAGS_enable_rec_fast_sampler + << ", max_decode_rounds=" << FLAGS_max_decode_rounds; + +#if !defined(USE_NPU) && !defined(USE_CUDA) + FLAGS_enable_block_copy_kernel = false; +#endif handler->master = std::make_unique(options); handler->master->run(); @@ -155,6 +271,7 @@ XLLM_CAPI_EXPORT bool xllm_rec_initialize( } handler->model_ids.clear(); handler->model_ids.emplace_back(model_id); + handler->pipeline_type = pipeline_type; handler->initialized = true; @@ -166,6 +283,7 @@ XLLM_CAPI_EXPORT bool xllm_rec_initialize( handler->master.reset(); handler->executor.reset(); handler->model_ids.clear(); + handler->pipeline_type = xllm::RecPipelineType::kLlmRecDefault; handler->initialized = false; return false; diff --git a/xllm/c_api/types.h b/xllm/c_api/types.h index 25fc274b8..e95f1b125 100644 --- a/xllm/c_api/types.h +++ b/xllm/c_api/types.h @@ -323,6 +323,37 @@ typedef struct XLLM_CAPI_EXPORT XLLM_Choices { size_t entries_size; } XLLM_Choices; +/** + * @brief REC/OneRec specific output extension aligned by choice index + */ +typedef struct XLLM_CAPI_EXPORT XLLM_RecOutput { + /** Choice index this REC extension belongs to */ + uint32_t index; + + /** Selected REC item ids for this choice */ + int64_t* item_ids; + + /** Number of item ids in the item_ids array */ + size_t item_ids_size; + + /** Token-aligned REC/OneRec logprobs for this choice */ + float* rec_token_logprobs; + + /** Number of entries in rec_token_logprobs */ + size_t rec_token_logprobs_size; +} XLLM_RecOutput; + +/** + * @brief List of REC/OneRec specific output extensions + */ +typedef struct XLLM_CAPI_EXPORT XLLM_RecOutputs { + /** Pointer to array of REC output entries */ + XLLM_RecOutput* entries; + + /** Number of entries in the REC output array */ + size_t entries_size; +} XLLM_RecOutputs; + #define XLLM_ERROR_INFO_MAX_LEN 512 /** @@ -352,6 +383,9 @@ typedef struct XLLM_CAPI_EXPORT XLLM_Response { /** Token usage statistics for the request */ XLLM_Usage usage; + + /** REC/OneRec specific response extensions */ + XLLM_RecOutputs rec_outputs; } XLLM_Response; /** diff --git a/xllm/cc_api/llm.cpp b/xllm/cc_api/llm.cpp index dad5c5c76..ab3bd9e63 100644 --- a/xllm/cc_api/llm.cpp +++ b/xllm/cc_api/llm.cpp @@ -113,7 +113,7 @@ bool LLM::Initialize(const std::string& model_path, .is_local(init_options.is_local) .server_idx(init_options.server_idx); -#if !defined(USE_NPU) +#if !defined(USE_NPU) && !defined(USE_CUDA) FLAGS_enable_block_copy_kernel = false; #endif diff --git a/xllm/compiler/__init__.py b/xllm/compiler/__init__.py new file mode 100644 index 000000000..24134aa0b --- /dev/null +++ b/xllm/compiler/__init__.py @@ -0,0 +1 @@ +"""Compiler-side utilities for xLLM build and AOT flows.""" diff --git a/xllm/compiler/tilelang/__init__.py b/xllm/compiler/tilelang/__init__.py new file mode 100644 index 000000000..682dcd093 --- /dev/null +++ b/xllm/compiler/tilelang/__init__.py @@ -0,0 +1 @@ +"""TileLang AOT compiler support for xLLM.""" diff --git a/xllm/compiler/tilelang/bootstrap.py b/xllm/compiler/tilelang/bootstrap.py new file mode 100644 index 000000000..1ccbc005b --- /dev/null +++ b/xllm/compiler/tilelang/bootstrap.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from .tilelang_ascend_install import ( + PREPARE_ASCEND_COMMAND, + ensure_ascend_ready, + prepare_ascend, +) + +__all__ = [ + "PREPARE_ASCEND_COMMAND", + "ensure_ascend_ready", + "prepare_ascend", +] diff --git a/xllm/compiler/tilelang/cli/compile_kernels.py b/xllm/compiler/tilelang/cli/compile_kernels.py new file mode 100644 index 000000000..4a7e223e9 --- /dev/null +++ b/xllm/compiler/tilelang/cli/compile_kernels.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Compile xLLM TileLang kernels and emit manifests." + ) + parser.add_argument( + "--target", + required=True, + choices=["ascend", "cuda"], + help="Compilation target backend.", + ) + parser.add_argument( + "--output-root", + required=True, + help="Output root for compiled TileLang artifacts.", + ) + parser.add_argument( + "--device", + choices=["a2", "a3"], + default=None, + help="Ascend device type used to resolve build-time toolchain settings.", + ) + parser.add_argument( + "--kernels", + nargs="*", + default=None, + help="Optional kernel names. Compile all registered kernels when omitted.", + ) + parser.add_argument( + "--force", + action="store_true", + help="Force recompilation even when cache is hit.", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + output_root = Path(args.output_root).resolve() + output_root.mkdir(parents=True, exist_ok=True) + + if args.target == "ascend": + from ..bootstrap import prepare_ascend + + prepare_ascend() + from ..targets.ascend.build import build_kernels + + manifests = build_kernels( + output_root=output_root, + kernel_names=args.kernels, + force=args.force, + device=args.device, + ) + elif args.target == "cuda": + from ..targets.cuda.build import build_kernels + + manifests = build_kernels( + output_root=output_root, + kernel_names=args.kernels, + force=args.force, + ) + else: + raise ValueError(f"Unsupported target: {args.target}") + for manifest in manifests: + print(f"[INFO] built {manifest.target}:{manifest.kernel_name}") + print(f"[INFO] manifest: {Path(manifest.output_dir) / 'manifest.json'}") + + +if __name__ == "__main__": + main() diff --git a/xllm/compiler/tilelang/cli/prepare_ascend.py b/xllm/compiler/tilelang/cli/prepare_ascend.py new file mode 100644 index 000000000..41f2ffb3a --- /dev/null +++ b/xllm/compiler/tilelang/cli/prepare_ascend.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import argparse + +from ..bootstrap import PREPARE_ASCEND_COMMAND, prepare_ascend + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Prepare third_party/tilelang-ascend for xLLM Ascend TileLang builds." + ) + parser.add_argument( + "--force", + action="store_true", + help="Force rerunning install_ascend.sh even when cached artifacts look ready.", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + tilelang_root = prepare_ascend(force=args.force) + print(f"[INFO] tilelang-ascend is ready under: {tilelang_root}") + print("[INFO] Next step: run your usual `python setup.py build ...` or `python setup.py test ...` command.") + print(f"[INFO] Re-run this step explicitly with: `{PREPARE_ASCEND_COMMAND}`") + + +if __name__ == "__main__": + main() diff --git a/xllm/compiler/tilelang/common/cache.py b/xllm/compiler/tilelang/common/cache.py new file mode 100644 index 000000000..53604dded --- /dev/null +++ b/xllm/compiler/tilelang/common/cache.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Any + +from .manifest import KernelFamilyManifest +from .spec import KernelCompileSpec +from .toolchain import sha256_file + + +def compute_cache_key( + spec: KernelCompileSpec, + fingerprint: dict[str, Any], + dependency_files: list[str | Path], +) -> str: + payload = { + "spec": spec.cache_key_material(), + "fingerprint": fingerprint, + "dependencies": { + str(Path(path).resolve()): sha256_file(path) for path in dependency_files + }, + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + return hashlib.sha256(encoded).hexdigest() + + +def is_cache_hit( + manifest_path: str | Path, variant_key: str, expected_cache_key: str +) -> bool: + path = Path(manifest_path) + if not path.is_file(): + return False + + try: + manifest = KernelFamilyManifest.read(path) + except Exception: + return False + + variant = manifest.get_variant(variant_key) + if variant is None: + return False + + if variant.cache_key != expected_cache_key: + return False + + return Path(variant.generated_source).is_file() and Path(variant.compiled_binary).is_file() diff --git a/xllm/compiler/tilelang/common/manifest.py b/xllm/compiler/tilelang/common/manifest.py new file mode 100644 index 000000000..9fe32c1d6 --- /dev/null +++ b/xllm/compiler/tilelang/common/manifest.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +from .spec import DispatchField + + +@dataclass +class KernelAbiParameter: + cpp_type: str + name: str + + +@dataclass +class KernelAbi: + return_type: str + parameters: list[KernelAbiParameter] = field(default_factory=list) + + def to_json_dict(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass +class KernelVariantManifest: + variant_key: str + specialization: dict[str, Any] + generated_source: str + compiled_binary: str + entry_symbol: str + cache_key: str + dispatch_values: dict[str, Any] = field(default_factory=dict) + toolchain_options: dict[str, Any] = field(default_factory=dict) + fingerprint: dict[str, Any] = field(default_factory=dict) + compile_definitions: list[str] = field(default_factory=list) + + def to_json_dict(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass +class KernelFamilyManifest: + target: str + kernel_name: str + output_dir: str + variants_inc: str + registry_inc: str = "" + dispatch_schema: list[DispatchField] = field(default_factory=list) + kernel_abi: KernelAbi | None = None + variants: list[KernelVariantManifest] = field(default_factory=list) + schema_version: int = 2 + + def to_json_dict(self) -> dict[str, Any]: + data = asdict(self) + data["dispatch_schema"] = [asdict(field) for field in self.dispatch_schema] + data["kernel_abi"] = ( + None if self.kernel_abi is None else self.kernel_abi.to_json_dict() + ) + data["variants"] = [variant.to_json_dict() for variant in self.variants] + return data + + @property + def manifest_path(self) -> Path: + return Path(self.output_dir) / "manifest.json" + + def write(self, path: str | Path) -> None: + output = Path(path) + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text( + json.dumps(self.to_json_dict(), indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + @classmethod + def read(cls, path: str | Path) -> "KernelFamilyManifest": + data = json.loads(Path(path).read_text(encoding="utf-8")) + dispatch_schema = [ + DispatchField(**field) for field in data.pop("dispatch_schema", []) + ] + kernel_abi_data = data.pop("kernel_abi", None) + kernel_abi = None + if kernel_abi_data is not None: + kernel_abi = KernelAbi( + return_type=kernel_abi_data["return_type"], + parameters=[ + KernelAbiParameter(**param) + for param in kernel_abi_data.get("parameters", []) + ], + ) + variants = [ + KernelVariantManifest(**variant) for variant in data.pop("variants", []) + ] + return cls( + dispatch_schema=dispatch_schema, + kernel_abi=kernel_abi, + variants=variants, + **data, + ) + + def get_variant(self, variant_key: str) -> KernelVariantManifest | None: + for variant in self.variants: + if variant.variant_key == variant_key: + return variant + return None diff --git a/xllm/compiler/tilelang/common/spec.py b/xllm/compiler/tilelang/common/spec.py new file mode 100644 index 000000000..55e129cf9 --- /dev/null +++ b/xllm/compiler/tilelang/common/spec.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any + +_REGISTER_KERNEL_ATTR = "__xllm_tilelang_registered_kernel__" +_ENTRY_SYMBOL_CONTEXT_KEY = "entry_symbol" +_C_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +_SAFE_VARIANT_KEY_PATTERN = re.compile(r"^[A-Za-z0-9_]+$") +_SUPPORTED_DISPATCH_FIELD_KINDS = frozenset({"int32", "dtype"}) +_SPECIALIZATION_CONFIG_KEYS = frozenset( + { + "variant_key", + "specialization", + "compile_definitions", + "kernel_name", + "target", + "entry_name", + "source_entry_symbol", + } +) + + +@dataclass(frozen=True) +class DispatchField: + name: str + kind: str + + def validate(self) -> None: + if not _C_IDENTIFIER_PATTERN.match(self.name): + raise ValueError( + "DispatchField.name must be a valid C/C++ identifier" + ) + if self.kind not in _SUPPORTED_DISPATCH_FIELD_KINDS: + supported = ", ".join(sorted(_SUPPORTED_DISPATCH_FIELD_KINDS)) + raise ValueError( + f"DispatchField.kind must be one of: {supported}" + ) + + +@dataclass(frozen=True) +class KernelCompileSpec: + target: str + kernel_name: str + module_name: str + variant_key: str + specialization: dict[str, Any] = field(default_factory=dict) + dispatch_values: dict[str, Any] = field(default_factory=dict) + entry_name: str | None = None + source_entry_symbol: str = "call" + + def cache_key_material(self) -> dict[str, Any]: + return { + "target": self.target, + "kernel_name": self.kernel_name, + "module_name": self.module_name, + "variant_key": self.variant_key, + "specialization": self.specialization, + "dispatch_values": self.dispatch_values, + "entry_name": self.entry_name, + "source_entry_symbol": self.source_entry_symbol, + } + + +@dataclass(frozen=True) +class KernelSpec: + variant_key: str + specialization: dict[str, Any] = field(default_factory=dict) + compile_definitions: dict[str, str] = field(default_factory=dict) + kernel_name: str | None = None + target: str = "ascend" + entry_name: str | None = None + source_entry_symbol: str = "call" + + def validate(self) -> None: + if not self.variant_key: + raise ValueError("KernelSpec.variant_key must not be empty") + if not _SAFE_VARIANT_KEY_PATTERN.match(self.variant_key): + raise ValueError( + "KernelSpec.variant_key must contain only letters, digits, or " + "underscore" + ) + + if not self.specialization: + raise ValueError("KernelSpec.specialization must not be empty") + + if self.kernel_name is not None and not _C_IDENTIFIER_PATTERN.match( + self.kernel_name + ): + raise ValueError( + "KernelSpec.kernel_name must be a valid C/C++ identifier" + ) + + if self.entry_name is not None and not _C_IDENTIFIER_PATTERN.match( + self.entry_name + ): + raise ValueError("KernelSpec.entry_name must be a valid C/C++ identifier") + + context_keys = set(self.specialization) + context_keys.add(_ENTRY_SYMBOL_CONTEXT_KEY) + uses_entry_symbol = False + + for macro_name, context_key in self.compile_definitions.items(): + if context_key not in context_keys: + raise KeyError( + "KernelSpec.compile_definitions references unknown context " + f"key {context_key!r} for macro {macro_name!r}" + ) + uses_entry_symbol = ( + uses_entry_symbol or context_key == _ENTRY_SYMBOL_CONTEXT_KEY + ) + + if uses_entry_symbol and not self.entry_name: + raise ValueError( + "KernelSpec.entry_name is required when " + "compile_definitions references 'entry_symbol'" + ) + + def to_compile_spec( + self, *, module_name: str, dispatch_schema: list[DispatchField] + ) -> KernelCompileSpec: + dispatch_values = { + field.name: self.specialization[field.name] for field in dispatch_schema + } + return KernelCompileSpec( + target=self.target, + kernel_name=self.kernel_name or module_name, + module_name=module_name, + variant_key=self.variant_key, + specialization=dict(self.specialization), + dispatch_values=dispatch_values, + entry_name=self.entry_name, + source_entry_symbol=self.source_entry_symbol, + ) + + def render_compile_definitions(self, *, entry_symbol: str) -> list[str]: + self.validate() + context = dict(self.specialization) + context[_ENTRY_SYMBOL_CONTEXT_KEY] = entry_symbol + definitions: list[str] = [] + + for macro_name, context_key in self.compile_definitions.items(): + if context_key not in context: + raise KeyError( + "compile_definitions references unknown context key " + f"{context_key!r} for macro {macro_name!r}" + ) + definitions.append(f"{macro_name}={context[context_key]}") + + return definitions + + +class TilelangKernel: + """Marker base class for TileLang kernel generator classes.""" + + TARGET = "ascend" + KERNEL_NAME: str | None = None + ENTRY_NAME: str | None = None + COMPILE_DEFINITIONS: dict[str, str] = {} + SOURCE_ENTRY_SYMBOL = "call" + DISPATCH_SCHEMA: list[DispatchField] = [] + SPECIALIZATIONS: list[dict[str, Any] | KernelSpec] = [] + + @classmethod + def dispatch_schema(cls) -> list[DispatchField]: + if not cls.DISPATCH_SCHEMA: + raise NotImplementedError( + f"{cls.__name__} must define non-empty DISPATCH_SCHEMA" + ) + + normalized: list[DispatchField] = [] + seen_names: set[str] = set() + for index, field in enumerate(cls.DISPATCH_SCHEMA): + if not isinstance(field, DispatchField): + raise TypeError( + f"{cls.__name__}.DISPATCH_SCHEMA[{index}] must be DispatchField, " + f"got {type(field).__name__}" + ) + field.validate() + if field.name in seen_names: + raise ValueError( + f"{cls.__name__}.DISPATCH_SCHEMA contains duplicate field " + f"{field.name!r}" + ) + seen_names.add(field.name) + normalized.append(field) + return normalized + + @classmethod + def specs(cls) -> list[KernelSpec]: + if not cls.SPECIALIZATIONS: + raise NotImplementedError( + f"{cls.__name__} must define non-empty SPECIALIZATIONS or override " + "specs()" + ) + return [ + cls._specialization_to_spec(specialization) + for specialization in cls.SPECIALIZATIONS + ] + + @classmethod + def _specialization_to_spec( + cls, specialization: dict[str, Any] | KernelSpec + ) -> KernelSpec: + if isinstance(specialization, KernelSpec): + return specialization + + if not isinstance(specialization, dict): + raise TypeError( + f"{cls.__name__}.SPECIALIZATIONS entries must be dict or KernelSpec, " + f"got {type(specialization).__name__}" + ) + + specialization_data = dict(specialization) + specialization_fields = specialization_data.pop("specialization", None) + if specialization_fields is None: + specialization_fields = { + key: value + for key, value in specialization_data.items() + if key not in _SPECIALIZATION_CONFIG_KEYS + } + for key in specialization_fields: + specialization_data.pop(key) + elif not isinstance(specialization_fields, dict): + raise TypeError( + f"{cls.__name__}.SPECIALIZATIONS specialization must be dict, got " + f"{type(specialization_fields).__name__}" + ) + + compile_definitions = dict(cls.COMPILE_DEFINITIONS) + compile_definitions.update( + dict(specialization_data.pop("compile_definitions", {})) + ) + variant_key = specialization_data.pop("variant_key", None) + if not isinstance(variant_key, str) or not variant_key: + raise ValueError( + f"{cls.__name__}.SPECIALIZATIONS entries must define non-empty " + "'variant_key'" + ) + + spec = KernelSpec( + variant_key=variant_key, + specialization=dict(specialization_fields), + compile_definitions=compile_definitions, + kernel_name=specialization_data.pop("kernel_name", cls.KERNEL_NAME), + target=specialization_data.pop("target", cls.TARGET), + entry_name=specialization_data.pop("entry_name", cls.ENTRY_NAME), + source_entry_symbol=specialization_data.pop( + "source_entry_symbol", cls.SOURCE_ENTRY_SYMBOL + ), + ) + if specialization_data: + unknown_keys = ", ".join(sorted(specialization_data)) + raise KeyError( + f"{cls.__name__}.SPECIALIZATIONS contains unsupported config keys: " + f"{unknown_keys}" + ) + return spec + + +def register_kernel(cls: type[TilelangKernel]) -> type[TilelangKernel]: + setattr(cls, _REGISTER_KERNEL_ATTR, True) + return cls + + +def is_registered_kernel_class(obj: object) -> bool: + return isinstance(obj, type) and bool(obj.__dict__.get(_REGISTER_KERNEL_ATTR, False)) diff --git a/xllm/compiler/tilelang/common/toolchain.py b/xllm/compiler/tilelang/common/toolchain.py new file mode 100644 index 000000000..112f06b42 --- /dev/null +++ b/xllm/compiler/tilelang/common/toolchain.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import hashlib +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Sequence + + +def repo_root() -> Path: + return Path(__file__).resolve().parents[4] + + +def default_tilelang_root() -> Path: + return repo_root() / "third_party" / "tilelang-ascend" + + +def resolve_tilelang_root() -> Path: + value = os.environ.get("TL_ROOT", "").strip() + if value: + return Path(value).resolve() + return default_tilelang_root().resolve() + + +def require_env(name: str) -> str: + value = os.environ.get(name, "").strip() + if not value: + raise RuntimeError(f"Required environment variable is not set: {name}") + return value + + +def prepend_pythonpath(env: dict[str, str], path: str) -> None: + current = env.get("PYTHONPATH", "") + items = [item for item in current.split(os.pathsep) if item] + items = [item for item in items if item != path] + items.insert(0, path) + env["PYTHONPATH"] = os.pathsep.join(items) + + +def prepare_tilelang_import(tilelang_root: str | Path | None = None) -> Path: + tl_root = ( + Path(tilelang_root).resolve() if tilelang_root is not None else resolve_tilelang_root() + ) + os.environ["TL_ROOT"] = str(tl_root) + prepend_pythonpath(os.environ, str(tl_root)) + tl_root_str = str(tl_root) + # Keep TL_ROOT at sys.path front to avoid resolving the sibling + # package xllm/compiler/tilelang as top-level `tilelang`. + sys.path = [p for p in sys.path if p != tl_root_str] + sys.path.insert(0, tl_root_str) + os.environ.setdefault("ACL_OP_INIT_MODE", "1") + return tl_root + + +def run_checked( + cmd: Sequence[str], + *, + cwd: str | Path | None = None, + env: dict[str, str] | None = None, +) -> None: + subprocess.check_call(list(cmd), cwd=cwd, env=env) + + +def sha256_file(path: str | Path) -> str: + data = Path(path).read_bytes() + return hashlib.sha256(data).hexdigest() + + +def git_head(path: str | Path) -> str: + repo_path = str(Path(path).resolve()) + result = subprocess.run( + ["git", "-c", f"safe.directory={repo_path}", "-C", repo_path, "rev-parse", "HEAD"], + text=True, + capture_output=True, + check=False, + ) + if result.returncode != 0: + return "" + return result.stdout.strip() + + +def find_required_executable(name: str) -> str: + executable = shutil.which(name) + if not executable: + raise RuntimeError(f"Required executable was not found in PATH: {name}") + return executable diff --git a/xllm/compiler/tilelang/patches/tilelang_ascend/0001-install-ascend.patch b/xllm/compiler/tilelang/patches/tilelang_ascend/0001-install-ascend.patch new file mode 100644 index 000000000..e3114186d --- /dev/null +++ b/xllm/compiler/tilelang/patches/tilelang_ascend/0001-install-ascend.patch @@ -0,0 +1,23 @@ +diff --git a/requirements-build.txt b/requirements-build.txt +--- a/requirements-build.txt ++++ b/requirements-build.txt +@@ -1,6 +1,5 @@ + # Should be mirrored in pyproject.toml + build +-cmake>=3.26 + packaging + setuptools>=61 + torch +diff --git a/install_ascend.sh b/install_ascend.sh +--- a/install_ascend.sh ++++ b/install_ascend.sh +@@ -140,8 +140,8 @@ echo "Building TileLang with make..." + # Other wise, make will use all available cores + # and it may cause the system to be unresponsive + CORES=$(nproc) + MAKE_JOBS=$(( CORES * 50 / 100 )) +-make -j${MAKE_JOBS} ++make -j + + if [ $? -ne 0 ]; then + echo "Error: TileLang build failed." diff --git a/xllm/compiler/tilelang/targets/ascend/abi_entry.py b/xllm/compiler/tilelang/targets/ascend/abi_entry.py new file mode 100644 index 000000000..a686d9c75 --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/abi_entry.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import re + +from ...common.manifest import KernelAbi, KernelAbiParameter + + +def rename_entry_symbol(source: str, source_entry_symbol: str, entry_symbol: str) -> str: + pattern = rf"\b{re.escape(source_entry_symbol)}\b" + return re.sub(pattern, entry_symbol, source) + + +def rename_variant_internal_symbols(source: str, variant_key: str) -> str: + symbol_names: set[str] = set() + symbol_names.update( + re.findall( + r'extern\s+"C"\s+__global__\s+__aicore__\s+void\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(', + source, + ) + ) + symbol_names.update( + re.findall(r"\bvoid\s+([A-Za-z_][A-Za-z0-9_]*_tiling)\s*\(", source) + ) + + renamed_source = source + for symbol_name in sorted(symbol_names, key=len, reverse=True): + renamed_source = re.sub( + rf"\b{re.escape(symbol_name)}\b", + f"{symbol_name}__{variant_key}", + renamed_source, + ) + return renamed_source + + +def normalize_cpp_type(cpp_type: str) -> str: + normalized = re.sub(r"\s+", " ", cpp_type).strip() + normalized = re.sub(r"\s*([*&]+)\s*", r"\1", normalized) + return normalized + + +def parse_kernel_abi(source: str, entry_symbol: str) -> KernelAbi: + pattern = re.compile( + rf'extern\s+"C"\s+' + rf"(?P[^(){{}};]+?)\s+" + rf"{re.escape(entry_symbol)}\s*\(" + r"(?P[^)]*)\)\s*\{", + re.MULTILINE, + ) + match = pattern.search(source) + if match is None: + raise ValueError( + f"Failed to parse exported entry ABI for symbol {entry_symbol!r}" + ) + + return_type = normalize_cpp_type(match.group("return_type")) + params_text = match.group("params").strip() + parameters: list[KernelAbiParameter] = [] + if params_text and params_text != "void": + for param in (part.strip() for part in params_text.split(",")): + parsed = re.match( + r"(?P.+?[\*&]?)\s*(?P[A-Za-z_][A-Za-z0-9_]*)$", + param, + ) + if parsed is None: + raise ValueError( + "Failed to parse kernel ABI parameter " + f"{param!r} for symbol {entry_symbol!r}" + ) + parameters.append( + KernelAbiParameter( + cpp_type=normalize_cpp_type(parsed.group("type")), + name=parsed.group("name"), + ) + ) + + return KernelAbi(return_type=return_type, parameters=parameters) diff --git a/xllm/compiler/tilelang/targets/ascend/build.py b/xllm/compiler/tilelang/targets/ascend/build.py new file mode 100644 index 000000000..22fdddbd4 --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/build.py @@ -0,0 +1,53 @@ +import importlib +import os +import pkgutil +import re +from dataclasses import dataclass +from pathlib import Path + +from ...common.manifest import KernelFamilyManifest +from ...common.toolchain import find_required_executable +from .kernel_family_builder import build_kernel_family as _build_kernel_family +from .kernel_registry import RegisteredKernelFamily, get_default_families +from .toolchain import resolve_build_context + + +def build_kernel_family( + family: RegisteredKernelFamily, + output_root: str | Path, + force: bool = False, + device: str | None = None, +) -> KernelFamilyManifest: + context = resolve_build_context( + device=device, + bisheng_executable=find_required_executable("bisheng"), + ) + return _build_kernel_family( + family, + output_root=output_root, + context=context, + force=force, + ) + + +def build_kernels( + output_root: str | Path, + kernel_names: list[str] | None = None, + force: bool = False, + device: str | None = None, +) -> list[KernelFamilyManifest]: + context = resolve_build_context( + device=device, + bisheng_executable=find_required_executable("bisheng"), + ) + manifests = [] + for family in get_default_families(kernel_names): + manifests.append( + _build_kernel_family( + family, + output_root=output_root, + context=context, + force=force, + ) + ) + return manifests diff --git a/xllm/compiler/tilelang/targets/ascend/kernel_family_builder.py b/xllm/compiler/tilelang/targets/ascend/kernel_family_builder.py new file mode 100644 index 000000000..7f099bd1e --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/kernel_family_builder.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +import os +from pathlib import Path + +from ...common.cache import compute_cache_key, is_cache_hit +from ...common.manifest import KernelAbi, KernelFamilyManifest, KernelVariantManifest +from ...common.spec import DispatchField, KernelCompileSpec, KernelSpec, TilelangKernel +from ...common.toolchain import repo_root, run_checked +from . import abi_entry, kernel_registry, toolchain +from .kernel_registry import RegisteredKernelFamily +from .kernels import utils as kernel_utils +from .kernels.utils import render_family_registry_inc, render_family_variants_inc +from .toolchain import AscendBuildContext, TILELANG_BISHENG_COMMON_FLAGS + + +@dataclass(frozen=True) +class _VariantBuildPlan: + compile_spec: KernelCompileSpec + kernel_spec: KernelSpec + generated_source: Path + compiled_binary: Path + cache_key: str + + +@dataclass(frozen=True) +class _VariantBuildResult: + manifest: KernelVariantManifest + kernel_abi: KernelAbi + + +def _variant_entry_symbol(spec: KernelCompileSpec) -> str: + kernel_entry_name = spec.entry_name or spec.kernel_name + return f"{kernel_entry_name}__{spec.variant_key}_call" + + +def _read_family_manifest(path: Path) -> KernelFamilyManifest | None: + if not path.is_file(): + return None + try: + return KernelFamilyManifest.read(path) + except Exception: + return None + + +def _render_variants_inc( + kernel_name: str, + kernel_cls: type[TilelangKernel], + dispatch_schema: list[DispatchField], + variants: list[KernelVariantManifest], +) -> str: + renderer = getattr(kernel_cls, "render_variants_inc", None) + if renderer is None: + return render_family_variants_inc( + kernel_name=kernel_name, + dispatch_schema=dispatch_schema, + variants=variants, + ) + if not callable(renderer): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' defines " + "non-callable render_variants_inc" + ) + rendered = renderer(variants, dispatch_schema) + if not isinstance(rendered, str): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' " + "render_variants_inc(...) must return str" + ) + return rendered + + +def _render_registry_inc( + kernel_name: str, + kernel_cls: type[TilelangKernel], + dispatch_schema: list[DispatchField], + kernel_abi: KernelAbi, + variants: list[KernelVariantManifest], +) -> str: + renderer = getattr(kernel_cls, "render_registry_inc", None) + if renderer is None: + return render_family_registry_inc( + kernel_name=kernel_name, + dispatch_schema=dispatch_schema, + kernel_abi=kernel_abi, + variants=variants, + ) + if not callable(renderer): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' defines " + "non-callable render_registry_inc" + ) + rendered = renderer(variants, dispatch_schema, kernel_abi) + if not isinstance(rendered, str): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' " + "render_registry_inc(...) must return str" + ) + return rendered + + +def _build_dependency_files(family: RegisteredKernelFamily) -> list[Path]: + # Keep cache invalidation aligned with the split builder implementation. + files = [ + Path(family.module.__file__).resolve(), + Path(__file__).resolve(), + Path(toolchain.__file__).resolve(), + Path(kernel_registry.__file__).resolve(), + Path(abi_entry.__file__).resolve(), + Path(kernel_utils.__file__).resolve(), + Path(__file__).resolve().with_name("build.py"), + ] + deduped: list[Path] = [] + seen: set[Path] = set() + for path in files: + if path in seen: + continue + seen.add(path) + deduped.append(path) + return deduped + + +def build_kernel_family( + family: RegisteredKernelFamily, + output_root: str | Path, + context: AscendBuildContext, + force: bool = False, +) -> KernelFamilyManifest: + family_output_dir = Path(output_root) / "targets" / "ascend" / family.kernel_name + family_output_dir.mkdir(parents=True, exist_ok=True) + manifest_path = family_output_dir / "manifest.json" + existing_manifest = _read_family_manifest(manifest_path) + dependency_files = _build_dependency_files(family) + + variant_manifest_by_key: dict[str, KernelVariantManifest] = {} + uncached_plans: list[_VariantBuildPlan] = [] + family_kernel_abi: KernelAbi | None = None + + for compile_spec, kernel_spec in family.spec_pairs: + if compile_spec.target != "ascend": + raise ValueError( + f"Unsupported target for Ascend build.py: {compile_spec.target}" + ) + + variant_output_dir = family_output_dir / compile_spec.variant_key + variant_output_dir.mkdir(parents=True, exist_ok=True) + generated_source = ( + variant_output_dir + / f"{compile_spec.kernel_name}_{compile_spec.variant_key}_kernel.cpp" + ) + compiled_binary = ( + variant_output_dir + / f"{compile_spec.kernel_name}_{compile_spec.variant_key}_kernel.o" + ) + + cache_key = compute_cache_key( + compile_spec, + context.fingerprint, + dependency_files, + ) + + cached_variant = ( + existing_manifest.get_variant(compile_spec.variant_key) + if existing_manifest is not None + else None + ) + if ( + not force + and cached_variant is not None + and Path(cached_variant.generated_source).is_file() + and Path(cached_variant.compiled_binary).is_file() + and is_cache_hit(manifest_path, compile_spec.variant_key, cache_key) + ): + cached_source = Path(cached_variant.generated_source).read_text( + encoding="utf-8" + ) + kernel_abi = abi_entry.parse_kernel_abi( + cached_source, cached_variant.entry_symbol + ) + if family_kernel_abi is None: + family_kernel_abi = kernel_abi + elif kernel_abi != family_kernel_abi: + raise ValueError( + "All variants in a TileLang kernel must share the same exported " + f"C ABI. Mismatch found in variant {compile_spec.variant_key!r}." + ) + variant_manifest_by_key[compile_spec.variant_key] = KernelVariantManifest( + variant_key=compile_spec.variant_key, + specialization=dict(compile_spec.specialization), + dispatch_values=dict(compile_spec.dispatch_values), + generated_source=cached_variant.generated_source, + compiled_binary=cached_variant.compiled_binary, + entry_symbol=cached_variant.entry_symbol, + cache_key=cached_variant.cache_key, + toolchain_options=dict(context.toolchain_options), + fingerprint=dict(context.fingerprint), + compile_definitions=kernel_spec.render_compile_definitions( + entry_symbol=cached_variant.entry_symbol + ), + ) + continue + + uncached_plans.append( + _VariantBuildPlan( + compile_spec=compile_spec, + kernel_spec=kernel_spec, + generated_source=generated_source, + compiled_binary=compiled_binary, + cache_key=cache_key, + ) + ) + + compile_cwd = repo_root() + + def _run_variant_job(plan: _VariantBuildPlan) -> _VariantBuildResult: + compile_spec = plan.compile_spec + kernel_spec = plan.kernel_spec + source = family.kernel_cls.generate_source(**compile_spec.specialization) + entry_symbol = _variant_entry_symbol(compile_spec) + rendered_source = abi_entry.rename_variant_internal_symbols( + abi_entry.rename_entry_symbol( + source, compile_spec.source_entry_symbol, entry_symbol + ), + compile_spec.variant_key, + ) + kernel_abi = abi_entry.parse_kernel_abi(rendered_source, entry_symbol) + plan.generated_source.write_text(rendered_source, encoding="utf-8") + + compile_cmd = [ + context.bisheng_executable, + f"--cce-aicore-arch={context.bisheng_arch}", + *TILELANG_BISHENG_COMMON_FLAGS, + f"-Dg_tilingKey=g_tilingKey__{compile_spec.variant_key}", + *[f"-I{include_dir}" for include_dir in context.include_dirs], + str(plan.generated_source), + "-c", + "-o", + str(plan.compiled_binary), + ] + run_checked(compile_cmd, cwd=compile_cwd) + manifest = KernelVariantManifest( + variant_key=compile_spec.variant_key, + specialization=dict(compile_spec.specialization), + dispatch_values=dict(compile_spec.dispatch_values), + generated_source=str(plan.generated_source), + compiled_binary=str(plan.compiled_binary), + entry_symbol=entry_symbol, + cache_key=plan.cache_key, + toolchain_options=dict(context.toolchain_options), + fingerprint=dict(context.fingerprint), + compile_definitions=kernel_spec.render_compile_definitions( + entry_symbol=entry_symbol + ), + ) + return _VariantBuildResult(manifest=manifest, kernel_abi=kernel_abi) + + if uncached_plans: + max_workers = max(1, os.cpu_count() or 1) + if max_workers == 1: + for plan in uncached_plans: + result = _run_variant_job(plan) + if family_kernel_abi is None: + family_kernel_abi = result.kernel_abi + elif result.kernel_abi != family_kernel_abi: + raise ValueError( + "All variants in a TileLang kernel must share the same exported " + "C ABI. Mismatch found in variant " + f"{result.manifest.variant_key!r}." + ) + variant_manifest_by_key[result.manifest.variant_key] = result.manifest + else: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_plan = { + executor.submit(_run_variant_job, plan): plan + for plan in uncached_plans + } + for future in as_completed(future_to_plan): + plan = future_to_plan[future] + try: + result = future.result() + except Exception as exc: + raise RuntimeError( + "Ascend variant build failed for variant " + f"{plan.compile_spec.variant_key!r}" + ) from exc + if family_kernel_abi is None: + family_kernel_abi = result.kernel_abi + elif result.kernel_abi != family_kernel_abi: + raise ValueError( + "All variants in a TileLang kernel must share the same exported " + "C ABI. Mismatch found in variant " + f"{result.manifest.variant_key!r}." + ) + variant_manifest_by_key[result.manifest.variant_key] = result.manifest + + variant_manifests: list[KernelVariantManifest] = [ + variant_manifest_by_key[compile_spec.variant_key] + for compile_spec, _ in family.spec_pairs + ] + + if family_kernel_abi is None: + raise ValueError( + f"TileLang kernel {family.kernel_name!r} produced no exported kernel ABI" + ) + + variants_inc_path = family_output_dir / "variants.inc" + variants_inc_path.write_text( + _render_variants_inc( + family.kernel_name, + family.kernel_cls, + family.dispatch_schema, + variant_manifests, + ), + encoding="utf-8", + ) + registry_inc_path = family_output_dir / "registry.inc" + registry_inc_path.write_text( + _render_registry_inc( + family.kernel_name, + family.kernel_cls, + family.dispatch_schema, + family_kernel_abi, + variant_manifests, + ), + encoding="utf-8", + ) + + manifest = KernelFamilyManifest( + target="ascend", + kernel_name=family.kernel_name, + output_dir=str(family_output_dir), + variants_inc=str(variants_inc_path), + registry_inc=str(registry_inc_path), + dispatch_schema=list(family.dispatch_schema), + kernel_abi=family_kernel_abi, + variants=variant_manifests, + ) + manifest.write(manifest_path) + return manifest diff --git a/xllm/compiler/tilelang/targets/ascend/kernel_registry.py b/xllm/compiler/tilelang/targets/ascend/kernel_registry.py new file mode 100644 index 000000000..106b4a26b --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/kernel_registry.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import importlib +import pkgutil +from dataclasses import dataclass +from pathlib import Path +from types import ModuleType + +from ...common.spec import ( + DispatchField, + KernelCompileSpec, + KernelSpec, + TilelangKernel, + is_registered_kernel_class, +) +from ...common.toolchain import prepare_tilelang_import + + +@dataclass(frozen=True) +class RegisteredKernelFamily: + module: ModuleType + kernel_cls: type[TilelangKernel] + module_name: str + kernel_name: str + dispatch_schema: list[DispatchField] + spec_pairs: list[tuple[KernelCompileSpec, KernelSpec]] + + +def _load_kernel_module(module_name: str) -> ModuleType: + prepare_tilelang_import() + return importlib.import_module(f"{__package__}.kernels.{module_name}") + + +def _kernels_dir() -> Path: + return Path(__file__).resolve().parent / "kernels" + + +def _iter_kernel_module_names() -> list[str]: + return sorted( + module.name + for module in pkgutil.iter_modules([str(_kernels_dir())]) + if not module.name.startswith("_") + ) + + +def _resolve_registered_kernel_class( + module_name: str, +) -> tuple[ModuleType, type[TilelangKernel] | None]: + module = _load_kernel_module(module_name) + kernel_classes = [ + obj + for obj in vars(module).values() + if isinstance(obj, type) + and obj.__module__ == module.__name__ + and is_registered_kernel_class(obj) + ] + if not kernel_classes: + return module, None + if len(kernel_classes) > 1: + kernel_names = ", ".join(sorted(cls.__name__ for cls in kernel_classes)) + raise TypeError( + f"TileLang kernel module {module_name!r} must define at most one " + f"@register_kernel class, found: {kernel_names}" + ) + + kernel_cls = kernel_classes[0] + if not issubclass(kernel_cls, TilelangKernel): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must inherit " + "TilelangKernel" + ) + return module, kernel_cls + + +def _load_registered_kernel_family( + module_name: str, +) -> RegisteredKernelFamily | None: + module, kernel_cls = _resolve_registered_kernel_class(module_name) + if kernel_cls is None: + return None + + generate_source = kernel_cls.__dict__.get("generate_source") + if not isinstance(generate_source, (staticmethod, classmethod)): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must define " + "callable generate_source(...) as @staticmethod or @classmethod" + ) + + resolved_generate_source = getattr(kernel_cls, "generate_source", None) + if not callable(resolved_generate_source): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must define " + "callable generate_source(...)" + ) + + resolved_specs = getattr(kernel_cls, "specs", None) + if not callable(resolved_specs): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must define " + "callable specs() -> list[KernelSpec]" + ) + resolved_dispatch_schema = getattr(kernel_cls, "dispatch_schema", None) + if not callable(resolved_dispatch_schema): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must define " + "callable dispatch_schema() -> list[DispatchField]" + ) + + try: + kernel_specs = resolved_specs() + except NotImplementedError as exc: + raise TypeError(str(exc)) from exc + if not isinstance(kernel_specs, list) or not kernel_specs: + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must return a " + "non-empty list[KernelSpec] from specs()" + ) + try: + dispatch_schema = resolved_dispatch_schema() + except NotImplementedError as exc: + raise TypeError(str(exc)) from exc + if not isinstance(dispatch_schema, list) or not dispatch_schema: + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' must return a " + "non-empty list[DispatchField] from dispatch_schema()" + ) + for index, field in enumerate(dispatch_schema): + if not isinstance(field, DispatchField): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' " + f"dispatch_schema()[{index}] must be DispatchField" + ) + + family_kernel_name: str | None = None + seen_variant_keys: set[str] = set() + spec_pairs: list[tuple[KernelCompileSpec, KernelSpec]] = [] + + for index, kernel_spec in enumerate(kernel_specs): + if not isinstance(kernel_spec, KernelSpec): + raise TypeError( + f"registered kernel class '{kernel_cls.__name__}' specs()[{index}] " + "must be KernelSpec" + ) + + kernel_spec.validate() + missing_dispatch_fields = [ + field.name + for field in dispatch_schema + if field.name not in kernel_spec.specialization + ] + if missing_dispatch_fields: + raise ValueError( + f"registered kernel class '{kernel_cls.__name__}' specs()[{index}] " + "is missing DISPATCH_SCHEMA fields: " + f"{', '.join(missing_dispatch_fields)}" + ) + compile_spec = kernel_spec.to_compile_spec( + module_name=module_name, + dispatch_schema=dispatch_schema, + ) + if family_kernel_name is None: + family_kernel_name = compile_spec.kernel_name + elif compile_spec.kernel_name != family_kernel_name: + raise ValueError( + f"registered kernel class '{kernel_cls.__name__}' must return " + "KernelSpec entries with the same kernel_name" + ) + + if compile_spec.variant_key in seen_variant_keys: + raise ValueError( + f"registered kernel class '{kernel_cls.__name__}' has duplicate " + f"variant_key {compile_spec.variant_key!r}" + ) + seen_variant_keys.add(compile_spec.variant_key) + spec_pairs.append((compile_spec, kernel_spec)) + + assert family_kernel_name is not None + return RegisteredKernelFamily( + module=module, + kernel_cls=kernel_cls, + module_name=module_name, + kernel_name=family_kernel_name, + dispatch_schema=dispatch_schema, + spec_pairs=spec_pairs, + ) + + +def registered_families() -> dict[str, RegisteredKernelFamily]: + families: dict[str, RegisteredKernelFamily] = {} + for module_name in _iter_kernel_module_names(): + family = _load_registered_kernel_family(module_name) + if family is None: + continue + if family.kernel_name in families: + raise ValueError( + "Duplicate Ascend TileLang kernel_name registered: " + f"{family.kernel_name}" + ) + families[family.kernel_name] = family + return families + + +def get_default_families( + kernel_names: list[str] | None = None, +) -> list[RegisteredKernelFamily]: + families = registered_families() + if kernel_names is None: + return list(families.values()) + missing = [name for name in kernel_names if name not in families] + if missing: + raise ValueError(f"Unknown Ascend TileLang kernels: {', '.join(missing)}") + return [families[name] for name in kernel_names] diff --git a/xllm/compiler/tilelang/targets/ascend/kernels/__init__.py b/xllm/compiler/tilelang/targets/ascend/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py b/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py new file mode 100644 index 000000000..59e0d45c7 --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/kernels/fused_gdn_gating.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 + +import argparse +from pathlib import Path + +import tilelang +import tilelang.language as T + +from compiler.tilelang.targets.ascend.kernels.utils import ( + DEFAULT_ASCEND_PASS_CONFIGS, + detect_vec_core_num, +) +from compiler.tilelang.common.spec import ( + DispatchField, + TilelangKernel, + register_kernel, +) + +DEFAULT_NUM_HEADS = 32 +DEFAULT_DTYPE = "bf16" +DEFAULT_MAX_BATCH = 4096 +DEFAULT_MAX_HEADS = 128 +REF_CHECK_NUM_BATCHES = 16 +REF_CHECK_NUM_HEADS = (1, 16, 32, 48, 64, 128) +VEC_NUM = 2 +VECTOR_BYTES_PER_ITER = 256 +SUPPORTED_NUM_HEADS = (4, 6, 8, 12, 16, 24, 32, 48, 64, 128) +MAX_VEC_CORE_NUM = detect_vec_core_num() +BATCH_SIZE_SPECIALIZATIONS = tuple(range(2, 49, 2)) + + +def select_launch_block_num(*, num_batches: int, vec_core_num: int) -> int: + """Pick launch block_num by current batch size.""" + if num_batches <= 0: + raise ValueError(f"num_batches({num_batches}) must be > 0") + if vec_core_num <= 0: + raise ValueError(f"vec_core_num({vec_core_num}) must be > 0") + return min(num_batches, vec_core_num) + + +def _dtype_size_in_bytes(dtype: str) -> int: + sizes = { + "float16": 2, + "bfloat16": 2, + "float32": 4, + } + if dtype not in sizes: + raise ValueError(f"Unsupported dtype for vector alignment: {dtype}") + return sizes[dtype] + + +def _align_count_to_vector_bytes(count: int, dtype: str) -> int: + elem_bytes = _dtype_size_in_bytes(dtype) + elems_per_iter = VECTOR_BYTES_PER_ITER // elem_bytes + return ((count + elems_per_iter - 1) // elems_per_iter) * elems_per_iter + + +def build_fused_gdn_gating_kernel( + *, + batch_size: int, + compile_max_batch: int, + num_heads: int, +): + if num_heads not in SUPPORTED_NUM_HEADS: + raise ValueError( + "fused_gdn_gating only supports num_heads in " + f"{SUPPORTED_NUM_HEADS}, got {num_heads}" + ) + if batch_size <= 0: + raise ValueError(f"batch_size({batch_size}) must be > 0") + if compile_max_batch <= 0: + raise ValueError( + f"compile_max_batch({compile_max_batch}) must be > 0" + ) + if batch_size > compile_max_batch: + raise ValueError( + f"batch_size({batch_size}) must be <= compile_max_batch({compile_max_batch})" + ) + + # vec_core_num is hardware capability; block_num is launch-time choice. + # block_num = min(num_batches, full_vec_core_num). + vec_core_num = MAX_VEC_CORE_NUM + block_num = select_launch_block_num( + num_batches=batch_size, vec_core_num=vec_core_num + ) + cubecore_block_num = block_num + task_num = block_num * VEC_NUM + acc_dtype = "float32" + input_dtype = "bfloat16" + mask_dtype = "uint8" + ub_tensor_dim = _align_count_to_vector_bytes(num_heads, acc_dtype) + compare_select_mask_bytes = ub_tensor_dim // 8 + + @T.prim_func + def fused_gdn_gating_kernel( + A_log: T.Tensor((num_heads,), acc_dtype), + a: T.Tensor((compile_max_batch, num_heads), input_dtype), + b: T.Tensor((compile_max_batch, num_heads), input_dtype), + dt_bias: T.Tensor((num_heads,), acc_dtype), + g_out: T.Tensor((compile_max_batch, num_heads), acc_dtype), + beta_out: T.Tensor((compile_max_batch, num_heads), input_dtype), + num_batches: T.int32, + softplus_beta: T.float32, + softplus_threshold: T.float32, + ): + with T.Kernel(cubecore_block_num, is_npu=True) as (cid, vid): + task_id = cid * VEC_NUM + vid + block_m = (num_batches + task_num - 1) // task_num + row_start = task_id * block_m + rows_left = T.if_then_else( + num_batches > row_start, num_batches - row_start, 0 + ) + num_rows_per_vec = T.if_then_else( + rows_left < block_m, + rows_left, + block_m, + ) + + with T.Scope("V"): + A_log_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + neg_exp_A_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + dt_bias_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + a_half_ub = T.alloc_shared((1, ub_tensor_dim), input_dtype) + b_half_ub = T.alloc_shared((1, ub_tensor_dim), input_dtype) + x_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + beta_x_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + softplus_abs_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + softplus_tmp_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + beta_fp32_ub = T.alloc_shared((1, ub_tensor_dim), acc_dtype) + sigmoid_tmp_ub = T.alloc_ub((1, ub_tensor_dim), mask_dtype) + softplus_cmp_mask_ub = T.alloc_ub( + (1, compare_select_mask_bytes), mask_dtype + ) + + T.copy(A_log[0], A_log_ub[0, :num_heads]) + T.copy(dt_bias[0], dt_bias_ub[0, :num_heads]) + T.tile.exp(neg_exp_A_ub, A_log_ub) + T.tile.mul(neg_exp_A_ub, neg_exp_A_ub, -1.0) + + for row_local in T.serial(num_rows_per_vec): + row = row_start + row_local + + T.copy(a[row, 0], a_half_ub[0, :num_heads]) + T.copy(b[row, 0], b_half_ub[0, :num_heads]) + + # x = a + dt_bias + # beta_x = beta * x + # softplus_tmp = log(1 + exp(-abs(beta_x))) + T.tile.cast(x_ub, a_half_ub, "CAST_NONE", ub_tensor_dim) + T.tile.axpy(x_ub, dt_bias_ub, 1.0) + T.tile.mul(beta_x_ub, x_ub, softplus_beta) + T.tile.abs(softplus_abs_ub, beta_x_ub) + T.tile.mul(softplus_tmp_ub, softplus_abs_ub, -1.0) + T.tile.exp(beta_fp32_ub, softplus_tmp_ub) + T.tile.add(beta_fp32_ub, beta_fp32_ub, 1.0) + T.tile.ln(softplus_tmp_ub, beta_fp32_ub) + + # Ascend compare/select consumes one 256B vector chunk per + # iteration. For float32 this is 64 elements, so num_heads + # < 64 still uses UB tensors aligned to the full chunk. + T.tile.compare( + softplus_cmp_mask_ub, + beta_x_ub, + softplus_threshold, + "GT", + ) + # softplus(x) = log(1 + exp(-abs(beta_x))) / beta + # + 0.5 * (beta_x + abs(beta_x)) / beta + T.tile.add(beta_x_ub, beta_x_ub, softplus_abs_ub) + T.tile.mul(beta_x_ub, beta_x_ub, 0.5 / softplus_beta) + T.tile.axpy(beta_x_ub, softplus_tmp_ub, 1.0 / softplus_beta) + T.tile.select( + beta_x_ub, + softplus_cmp_mask_ub, + x_ub, + beta_x_ub, + "VSEL_TENSOR_TENSOR_MODE", + ) + + # Reuse x_ub as b_fp32 and g output buffer, and reuse + # b_half_ub as beta_half output buffer. + T.tile.cast(x_ub, b_half_ub, "CAST_NONE", ub_tensor_dim) + T.tile.sigmoid(beta_fp32_ub, x_ub, sigmoid_tmp_ub) + T.tile.mul(x_ub, neg_exp_A_ub, beta_x_ub) + T.tile.cast( + b_half_ub, beta_fp32_ub, "CAST_RINT", ub_tensor_dim + ) + + T.copy(x_ub[0, :num_heads], g_out[row, 0]) + T.copy(b_half_ub[0, :num_heads], beta_out[row, 0]) + + return fused_gdn_gating_kernel + + +@tilelang.jit(pass_configs=DEFAULT_ASCEND_PASS_CONFIGS) +def fused_gdn_gating_kernel_jit( + num_batches: int, + compile_max_batch: int, + num_heads: int, +): + return build_fused_gdn_gating_kernel( + batch_size=num_batches, + compile_max_batch=compile_max_batch, + num_heads=num_heads, + ) + + +@register_kernel +class FusedGdnGatingKernel(TilelangKernel): + DISPATCH_SCHEMA = [ + DispatchField("batch_size", "int32"), + DispatchField("num_heads", "int32"), + DispatchField("dtype", "dtype"), + ] + SPECIALIZATIONS = [ + { + "variant_key": f"bs{batch_size}_nh{num_heads}_bf16", + "batch_size": batch_size, + "num_heads": num_heads, + "dtype": DEFAULT_DTYPE, + } + for num_heads in SUPPORTED_NUM_HEADS + for batch_size in BATCH_SIZE_SPECIALIZATIONS + ] + + @staticmethod + def generate_source(batch_size: int, num_heads: int, dtype: str) -> str: + if dtype != DEFAULT_DTYPE: + raise ValueError( + f"fused_gdn_gating only supports dtype={DEFAULT_DTYPE}, got {dtype}" + ) + if num_heads not in SUPPORTED_NUM_HEADS: + raise ValueError( + "fused_gdn_gating only supports num_heads in " + f"{SUPPORTED_NUM_HEADS}, got {num_heads}" + ) + if batch_size not in BATCH_SIZE_SPECIALIZATIONS: + raise ValueError( + "fused_gdn_gating only supports batch_size in " + f"{BATCH_SIZE_SPECIALIZATIONS}, got {batch_size}" + ) + tilelang.disable_cache() + tilelang_kernel = build_fused_gdn_gating_kernel( + batch_size=batch_size, + compile_max_batch=DEFAULT_MAX_BATCH, + num_heads=num_heads, + ) + with tilelang.tvm.transform.PassContext( + opt_level=3, config=DEFAULT_ASCEND_PASS_CONFIGS + ): + kernel = tilelang.engine.lower(tilelang_kernel) + return kernel.kernel_source + + +def _torch_fused_gdn_gating( + A_log: "torch.Tensor", + a: "torch.Tensor", + b: "torch.Tensor", + dt_bias: "torch.Tensor", + softplus_beta: float, + softplus_threshold: float, +) -> tuple["torch.Tensor", "torch.Tensor"]: + import torch + import torch.nn.functional as F + + softplus_out = F.softplus( + a.to(torch.float32) + dt_bias, + beta=softplus_beta, + threshold=softplus_threshold, + ) + g_ref = -A_log.exp() * softplus_out + beta_ref = torch.sigmoid(b.to(torch.float32)).to(torch.bfloat16) + return g_ref, beta_ref + + +def _run_ref_check( + *, + num_batches: int, + num_heads: int, + compile_max_batch: int, + softplus_beta: float, + softplus_threshold: float, +) -> None: + import torch + + if not hasattr(torch, "npu") or not torch.npu.is_available(): + print("[WARN] Skip fused_gdn_gating reference check: NPU is not available") + return + + if num_batches <= 0: + raise ValueError(f"num_batches({num_batches}) must be > 0") + if num_batches > compile_max_batch: + raise ValueError( + f"num_batches({num_batches}) must be <= compile_max_batch({compile_max_batch})" + ) + + torch.manual_seed(42) + device = torch.device("npu") + + A_log = torch.randn((num_heads,), device=device, dtype=torch.float32) + a = torch.randn((num_batches, num_heads), device=device, dtype=torch.bfloat16) + b = torch.randn((num_batches, num_heads), device=device, dtype=torch.bfloat16) + dt_bias = torch.randn((num_heads,), device=device, dtype=torch.float32) + g_out = torch.empty((num_batches, num_heads), device=device, dtype=torch.float32) + beta_out = torch.empty( + (num_batches, num_heads), device=device, dtype=torch.bfloat16 + ) + + kernel = fused_gdn_gating_kernel_jit( + num_batches=num_batches, + compile_max_batch=num_batches, + num_heads=num_heads, + ) + kernel( + A_log, + a, + b, + dt_bias, + g_out, + beta_out, + num_batches, + softplus_beta, + softplus_threshold, + ) + torch.npu.synchronize() + + g_ref, beta_ref = _torch_fused_gdn_gating( + A_log=A_log, + a=a, + b=b, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + ) + torch.testing.assert_close(g_out, g_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + beta_out.to(torch.float32), + beta_ref.to(torch.float32), + rtol=1e-2, + atol=1e-2, + ) + print(f"[INFO] fused_gdn_gating output matches torch reference for num_heads={num_heads}") + + +def _run_ref_suite( + *, + num_batches: int, + compile_max_batch: int, + softplus_beta: float, + softplus_threshold: float, + ref_num_heads_list: list[int], +) -> None: + for num_heads in ref_num_heads_list: + _run_ref_check( + num_batches=num_batches, + num_heads=num_heads, + compile_max_batch=compile_max_batch, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate TileLang AscendC source for fused_gdn_gating AOT kernel." + ) + parser.add_argument("--output", type=Path, required=True) + parser.add_argument( + "--batch-size", + type=int, + default=max(BATCH_SIZE_SPECIALIZATIONS), + help=( + "Batch-size specialization used for source generation. " + f"Supported values: {BATCH_SIZE_SPECIALIZATIONS}" + ), + ) + parser.add_argument("--num-heads", type=int, default=DEFAULT_NUM_HEADS) + parser.add_argument("--dtype", type=str, default=DEFAULT_DTYPE) + parser.add_argument( + "--skip-ref-check", + action="store_true", + help="Skip runtime torch-reference check.", + ) + parser.add_argument( + "--ref-num-batches", + type=int, + default=REF_CHECK_NUM_BATCHES, + help="Batch size used by the optional torch-reference check.", + ) + parser.add_argument( + "--softplus-beta", + type=float, + default=1.0, + help="Softplus beta used by the optional torch-reference check.", + ) + parser.add_argument( + "--softplus-threshold", + type=float, + default=20.0, + help="Softplus threshold used by the optional torch-reference check.", + ) + parser.add_argument( + "--ref-num-heads-list", + type=int, + nargs="+", + default=list(REF_CHECK_NUM_HEADS), + help="Head counts covered by the optional bf16 torch-reference test suite.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + source = FusedGdnGatingKernel.generate_source( + batch_size=args.batch_size, + num_heads=args.num_heads, + dtype=args.dtype, + ) + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(source, encoding="utf-8") + + if not args.skip_ref_check: + _run_ref_suite( + num_batches=args.ref_num_batches, + compile_max_batch=DEFAULT_MAX_BATCH, + softplus_beta=args.softplus_beta, + softplus_threshold=args.softplus_threshold, + ref_num_heads_list=args.ref_num_heads_list, + ) + + +if __name__ == "__main__": + main() diff --git a/xllm/compiler/tilelang/targets/ascend/kernels/rope.py b/xllm/compiler/tilelang/targets/ascend/kernels/rope.py new file mode 100644 index 000000000..8b89e1cfc --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/kernels/rope.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 + +import argparse +from pathlib import Path + +import tilelang +import tilelang.language as T + +from .utils import ( + DEFAULT_ASCEND_PASS_CONFIGS, + detect_vec_core_num, +) +from ....common.spec import DispatchField, TilelangKernel, register_kernel + +DEFAULT_HEAD_DIM = 576 +DEFAULT_ROPE_DIM = 64 +DEFAULT_DTYPE = "bf16" +SECONDARY_HEAD_DIM = 128 +SECONDARY_ROPE_DIM = 128 +VEC_NUM = 2 +FIXED_UB_BUFFER_BYTES = 64 * 1024 +REF_CHECK_NUM_TOKENS = 16 +# AOT kernel tensor signatures still require static first-dim bounds. +# Keep a sufficiently large compile-time upper bound so runtime rows +# (`num_tokens * num_heads`) used by wrapper tests stay in-range. +MIN_COMPILE_NUM_TOKENS = 65536 + +# Per-row bytes in UB for this kernel: +# x_half(2) + x(4) + sin_half(2) + sin(4) + cos_half(2) + cos(4) +# + x_rotate(4) + out(4) + mask(4) = 30 bytes per rope element. +UB_BYTES_PER_ROW_PER_ROPE_ELEM = 30 + + +def _derive_max_rows_num_in_ub(rope_dim: int, ub_buffer_bytes: int) -> int: + if ub_buffer_bytes <= 0: + raise ValueError(f"ub_buffer_bytes({ub_buffer_bytes}) must be > 0") + if rope_dim <= 0: + raise ValueError(f"rope_dim({rope_dim}) must be > 0") + + bytes_per_row = UB_BYTES_PER_ROW_PER_ROPE_ELEM * rope_dim + max_rows = ub_buffer_bytes // bytes_per_row + if max_rows <= 0: + raise ValueError( + "UB budget is too small for current rope_dim: " + f"ub_buffer_bytes={ub_buffer_bytes}, rope_dim={rope_dim}" + ) + return max_rows + + +def build_rope_kernel( + head_dim: int, + rope_dim: int, + vec_core_num: int, + ub_buffer_bytes: int, +): + if rope_dim % 2 != 0: + raise ValueError(f"rope_dim({rope_dim}) must be even") + if rope_dim > head_dim: + raise ValueError(f"rope_dim({rope_dim}) must be <= head_dim({head_dim})") + if vec_core_num <= 0: + raise ValueError(f"vec_core_num({vec_core_num}) must be > 0") + if vec_core_num % VEC_NUM != 0: + raise ValueError( + f"vec_core_num({vec_core_num}) must be divisible by VEC_NUM({VEC_NUM})" + ) + + task_num = vec_core_num + m_num = vec_core_num // VEC_NUM + max_rows_num_in_ub = _derive_max_rows_num_in_ub( + rope_dim=rope_dim, + ub_buffer_bytes=ub_buffer_bytes, + ) + # Current AOT path fixes launch block_num at compile time, so runtime input + # shape only changes per-task workload splitting. The tensor signature still + # needs a static upper bound for the first dimension. + compile_num_tokens = max(task_num * max_rows_num_in_ub, MIN_COMPILE_NUM_TOKENS) + compile_flatten_width = compile_num_tokens * head_dim + acc_dtype = "float32" + mask_dtype = "uint32" + + @T.prim_func + def rope_in_place_kernel( + x_in: T.Tensor((1, compile_flatten_width), "bfloat16"), + sin: T.Tensor((compile_num_tokens, rope_dim), "bfloat16"), + cos: T.Tensor((compile_num_tokens, rope_dim), "bfloat16"), + x_out: T.Tensor((1, compile_flatten_width), "bfloat16"), + num_tokens: T.int32, + x_stride: T.int32, + ): + with T.Kernel(m_num, is_npu=True) as (cid, vid): + task_id = cid * VEC_NUM + vid + block_m = (num_tokens + task_num - 1) // task_num + row_start = task_id * block_m + rows_left = T.if_then_else( + num_tokens > row_start, num_tokens - row_start, 0 + ) + num_rows_per_vec = T.if_then_else( + rows_left < block_m, + rows_left, + block_m, + ) + + with T.Scope("V"): + mask_ub = T.alloc_ub([1, rope_dim], mask_dtype) + for j in T.serial(rope_dim // 2): + mask_ub[0, 2 * j] = 4 * (2 * j + 1) + mask_ub[0, 2 * j + 1] = 4 * (2 * j) + + sin_mask_ub = T.alloc_ub((rope_dim,), acc_dtype) + T.tile.fill(sin_mask_ub, 1.0) + for i in T.serial(rope_dim): + if i % 2 == 0: + sin_mask_ub[i] = -1.0 + x_half_ub = T.alloc_shared([1, rope_dim], "bfloat16") + x_ub = T.alloc_shared([1, rope_dim], acc_dtype) + sin_half_ub = T.alloc_shared([1, rope_dim], "bfloat16") + sin_ub = T.alloc_shared([1, rope_dim], acc_dtype) + cos_half_ub = T.alloc_shared([1, rope_dim], "bfloat16") + cos_ub = T.alloc_shared([1, rope_dim], acc_dtype) + x_rotate_ub = T.alloc_shared([1, rope_dim], acc_dtype) + out_ub = T.alloc_shared([1, rope_dim], acc_dtype) + + for row_local in T.serial(num_rows_per_vec): + row = row_start + row_local + row_offset = row * x_stride + T.copy(x_in[0, row_offset], x_half_ub[0, :]) + T.copy(sin[row, :], sin_half_ub[0, :]) + T.copy(cos[row, :], cos_half_ub[0, :]) + + T.tile.cast(x_ub, x_half_ub, "CAST_NONE", rope_dim) + T.tile.cast(sin_ub, sin_half_ub, "CAST_NONE", rope_dim) + T.tile.cast(cos_ub, cos_half_ub, "CAST_NONE", rope_dim) + T.tile.mul(sin_ub[0, :], sin_ub[0, :], sin_mask_ub) + + T.tile.gather(x_rotate_ub, x_ub, mask_ub, 0) + T.tile.mul(x_ub, x_ub, cos_ub) + T.tile.mul(x_rotate_ub, x_rotate_ub, sin_ub) + T.tile.add(out_ub, x_ub, x_rotate_ub) + T.tile.cast(x_half_ub, out_ub, "CAST_RINT", rope_dim) + T.copy(x_half_ub[0, :], x_out[0, row_offset]) + + return rope_in_place_kernel + + +@tilelang.jit(pass_configs=DEFAULT_ASCEND_PASS_CONFIGS) +def rope_in_place_kernel_jit( + head_dim: int, + rope_dim: int, + vec_core_num: int, + ub_buffer_bytes: int, +): + return build_rope_kernel( + head_dim=head_dim, + rope_dim=rope_dim, + vec_core_num=vec_core_num, + ub_buffer_bytes=ub_buffer_bytes, + ) + + +@register_kernel +class RopeKernel(TilelangKernel): + DISPATCH_SCHEMA = [ + DispatchField("head_dim", "int32"), + DispatchField("rope_dim", "int32"), + DispatchField("dtype", "dtype"), + ] + SPECIALIZATIONS = [ + { + "variant_key": "hd128_rd128_bf16", + "head_dim": SECONDARY_HEAD_DIM, + "rope_dim": SECONDARY_ROPE_DIM, + "dtype": DEFAULT_DTYPE, + }, + { + "variant_key": "hd576_rd64_bf16", + "head_dim": DEFAULT_HEAD_DIM, + "rope_dim": DEFAULT_ROPE_DIM, + "dtype": DEFAULT_DTYPE, + }, + ] + + @staticmethod + def generate_source(head_dim: int, rope_dim: int, dtype: str) -> str: + if dtype != DEFAULT_DTYPE: + raise ValueError( + f"RoPE TileLang kernel only supports dtype={DEFAULT_DTYPE}, got {dtype}" + ) + tilelang.disable_cache() + vec_core_num = detect_vec_core_num() + ub_buffer_bytes = FIXED_UB_BUFFER_BYTES + tilelang_kernel = build_rope_kernel( + head_dim=head_dim, + rope_dim=rope_dim, + vec_core_num=vec_core_num, + ub_buffer_bytes=ub_buffer_bytes, + ) + with tilelang.tvm.transform.PassContext( + opt_level=3, config=DEFAULT_ASCEND_PASS_CONFIGS + ): + kernel = tilelang.engine.lower(tilelang_kernel) + return kernel.kernel_source + + +def _torch_rope_ref_rows( + x: "torch.Tensor", + sin: "torch.Tensor", + cos: "torch.Tensor", + dim_start: int, +) -> "torch.Tensor": + import torch + + x_fp32 = x.to(torch.float32) + sin_fp32 = sin.to(torch.float32) + cos_fp32 = cos.to(torch.float32) + rope_dim = sin_fp32.shape[1] + x_part = x_fp32[:, dim_start : dim_start + rope_dim] + x_reshape = x_part.reshape(x_part.shape[0], -1, 2) + x0 = x_reshape[:, :, 0] + x1 = x_reshape[:, :, 1] + x_rot = torch.stack([-x1, x0], dim=-1).reshape_as(x_part) + + out = x.clone() + out[:, dim_start : dim_start + rope_dim] = ( + x_part * cos_fp32 + x_rot * sin_fp32 + ).to(torch.bfloat16) + return out + + +def _run_ref_check( + num_tokens: int, + head_dim: int, + rope_dim: int, + vec_core_num: int, + ub_buffer_bytes: int, +) -> None: + import torch + + if not hasattr(torch, "npu") or not torch.npu.is_available(): + print("[WARN] Skip RoPE reference check: NPU is not available") + return + + torch.manual_seed(42) + device = torch.device("npu") + x_in = torch.randn((num_tokens, head_dim), device=device, dtype=torch.bfloat16) + sin = torch.randn((num_tokens, rope_dim), device=device, dtype=torch.bfloat16) + cos = torch.randn((num_tokens, rope_dim), device=device, dtype=torch.bfloat16) + x_out = x_in.clone() + x_in_flat = x_in.view(1, -1) + x_out_flat = x_out.view(1, -1) + kernel = rope_in_place_kernel_jit( + head_dim=head_dim, + rope_dim=rope_dim, + vec_core_num=vec_core_num, + ub_buffer_bytes=ub_buffer_bytes, + ) + kernel(x_in_flat, sin, cos, x_out_flat, num_tokens, head_dim) + torch.npu.synchronize() + + x_ref = _torch_rope_ref_rows(x_in, sin, cos, 0) + torch.testing.assert_close(x_out, x_ref, rtol=1e-3, atol=1e-3) + print("[INFO] RoPE output matches torch reference") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate TileLang AscendC source for RoPE AOT kernel." + ) + parser.add_argument("--output", required=True, help="Output AscendC .cpp file") + parser.add_argument("--head-dim", type=int, default=DEFAULT_HEAD_DIM) + parser.add_argument("--rope-dim", type=int, default=DEFAULT_ROPE_DIM) + parser.add_argument("--dtype", default=DEFAULT_DTYPE) + parser.add_argument( + "--skip-ref-check", + action="store_true", + help="Skip runtime torch-reference check.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + output = Path(args.output).resolve() + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text( + RopeKernel.generate_source( + head_dim=args.head_dim, + rope_dim=args.rope_dim, + dtype=args.dtype, + ), + encoding="utf-8", + ) + + if not args.skip_ref_check: + _run_ref_check( + num_tokens=REF_CHECK_NUM_TOKENS, + head_dim=args.head_dim, + rope_dim=args.rope_dim, + vec_core_num=detect_vec_core_num(), + ub_buffer_bytes=FIXED_UB_BUFFER_BYTES, + ) + + +if __name__ == "__main__": + main() diff --git a/xllm/compiler/tilelang/targets/ascend/kernels/utils.py b/xllm/compiler/tilelang/targets/ascend/kernels/utils.py new file mode 100644 index 000000000..7a5744e34 --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/kernels/utils.py @@ -0,0 +1,249 @@ +import tilelang +from typing import Any + +from ....common.manifest import KernelAbi, KernelVariantManifest +from ....common.spec import DispatchField + +DEFAULT_ASCEND_PASS_CONFIGS = { + # Use raw pass-config strings to avoid hard dependency on + # tilelang.PassConfigKey export timing/version. + "tl.ascend_auto_sync": True, + "tl.ascend_memory_planning": True, + "tl.ascend_auto_cross_core_sync": True, + "tl.ascend_auto_cv_combine": True, +} + +DEFAULT_ASCEND_BISHENG_ARCH = "dav-c220" +ASCEND_VEC_CORE_NUM_PROPERTY_KEYS = ( + "vector_core_num", + "aiv_core_num", + "vec_core_num", +) + + +def detect_vec_core_num(default_vec_core_num: int = 48) -> int: + try: + import torch + + if hasattr(torch, "npu") and torch.npu.is_available(): + props = torch.npu.get_device_properties(torch.npu.current_device()) + for key in ASCEND_VEC_CORE_NUM_PROPERTY_KEYS: + value = getattr(props, key, None) + if isinstance(value, int) and value > 0: + return value + except Exception: + pass + + return default_vec_core_num + + +def _snake_to_pascal(name: str) -> str: + parts = [part for part in name.split("_") if part] + return "".join(part[:1].upper() + part[1:] for part in parts) + + +def _dispatch_field_suffix(name: str) -> str: + parts = [part for part in name.split("_") if part] + mapped = {"dtype": "DType"} + return "".join(mapped.get(part, part[:1].upper() + part[1:]) for part in parts) + + +def _dtype_enum_suffix(dtype_name: str) -> str: + common_suffixes = { + "bf16": "BF16", + "fp16": "Float16", + "fp32": "Float32", + "float16": "Float16", + "float32": "Float32", + "int8": "Int8", + "int32": "Int32", + "uint8": "UInt8", + } + if dtype_name in common_suffixes: + return common_suffixes[dtype_name] + return _snake_to_pascal(dtype_name) + + +def _dispatch_field_cpp_type(field: DispatchField) -> str: + field_types = { + "int32": "int32_t", + "dtype": "TilelangDType", + } + return field_types[field.kind] + + +def _render_dispatch_value_literal(*, field: DispatchField, value: Any) -> str: + if field.kind == "int32": + if not isinstance(value, int) or isinstance(value, bool): + raise TypeError( + f"Unsupported int32 dispatch value for {field.name!r}: {value!r}" + ) + return str(value) + if field.kind == "dtype": + if not isinstance(value, str): + raise TypeError( + f"Unsupported dtype dispatch value for {field.name!r}: {value!r}" + ) + dtype_suffix = _dtype_enum_suffix(value) + return f"TilelangDType::k{dtype_suffix}" + raise TypeError( + f"Unsupported dispatch field kind {field.kind!r} for {field.name!r}" + ) + + +def _validate_dispatch_values( + *, kernel_name: str, dispatch_schema: list[DispatchField], variant: KernelVariantManifest +) -> None: + schema_names = [field.name for field in dispatch_schema] + missing_keys = [name for name in schema_names if name not in variant.dispatch_values] + extra_keys = [name for name in variant.dispatch_values if name not in schema_names] + if missing_keys or extra_keys: + raise ValueError( + f"TileLang kernel family {kernel_name!r} variant " + f"{variant.variant_key!r} has inconsistent dispatch values: " + f"missing={missing_keys}, extra={extra_keys}" + ) + + +def render_family_variants_inc( + *, + kernel_name: str, + dispatch_schema: list[DispatchField], + variants: list[KernelVariantManifest], +) -> str: + if not variants: + return "" + + if not dispatch_schema: + raise ValueError( + f"TileLang kernel family {kernel_name!r} has empty dispatch schema" + ) + + macro_name = f"XLLM_TL_{kernel_name.upper()}_VARIANT" + lines: list[str] = [] + + for variant in variants: + _validate_dispatch_values( + kernel_name=kernel_name, + dispatch_schema=dispatch_schema, + variant=variant, + ) + variant_args = [ + _render_dispatch_value_literal( + field=field, + value=variant.dispatch_values[field.name], + ) + for field in dispatch_schema + ] + variant_args.append(f"\"{variant.variant_key}\"") + variant_args.append(variant.entry_symbol) + lines.append(f"{macro_name}({', '.join(variant_args)})") + + return "\n".join(lines) + "\n" + + +def render_family_registry_inc( + *, + kernel_name: str, + dispatch_schema: list[DispatchField], + kernel_abi: KernelAbi, + variants: list[KernelVariantManifest], +) -> str: + if not variants: + return "" + + if not dispatch_schema: + raise ValueError( + f"TileLang kernel family {kernel_name!r} has empty dispatch schema" + ) + + family_prefix = _snake_to_pascal(kernel_name) + specialization_type = f"{family_prefix}Specialization" + kernel_fn_type = f"{family_prefix}KernelFn" + registry_name = f"k{family_prefix}Registry" + entry_type = f"KernelEntry<{specialization_type}, {kernel_fn_type}>" + field_wrapper_types = { + field.name: f"{family_prefix}{_dispatch_field_suffix(field.name)}" + for field in dispatch_schema + } + + symbol_declarations: list[str] = [] + registry_entries: list[str] = [] + for variant in variants: + _validate_dispatch_values( + kernel_name=kernel_name, + dispatch_schema=dispatch_schema, + variant=variant, + ) + specialization_args = ", ".join( + f"{field_wrapper_types[field.name]}{{" + f"{_render_dispatch_value_literal(field=field, value=variant.dispatch_values[field.name])}" + f"}}" + for field in dispatch_schema + ) + symbol_declarations.append( + f'extern "C" function_type_t<{kernel_fn_type}> {variant.entry_symbol};' + ) + registry_entries.append( + f' {entry_type}{{make_{kernel_name}_specialization({specialization_args}), ' + f'"{variant.variant_key}", &{variant.entry_symbol}}},' + ) + + struct_fields = [ + f" {_dispatch_field_cpp_type(field)} {field.name};" for field in dispatch_schema + ] + equality_terms = [f"lhs.{field.name} == rhs.{field.name}" for field in dispatch_schema] + builder_params = [ + f"{field_wrapper_types[field.name]} {field.name}" for field in dispatch_schema + ] + builder_values = ", ".join(f"{field.name}.value" for field in dispatch_schema) + function_params = ", ".join( + f"{parameter.cpp_type} {parameter.name}" for parameter in kernel_abi.parameters + ) + + lines = [ + f"struct {specialization_type} {{", + *struct_fields, + "};", + "", + f"constexpr bool operator==(const {specialization_type}& lhs,", + f" const {specialization_type}& rhs) {{", + " return " + " && ".join(equality_terms) + ";", + "}", + "", + ] + for field in dispatch_schema: + lines.extend( + [ + f"struct {field_wrapper_types[field.name]} {{", + f" {_dispatch_field_cpp_type(field)} value;", + "};", + "", + ] + ) + lines.extend( + [ + f"constexpr {specialization_type} make_{kernel_name}_specialization(", + " " + ", ".join(builder_params) + ") {", + f" return {specialization_type}{{{builder_values}}};", + "}", + "", + f"using {kernel_fn_type} = {kernel_abi.return_type} (*)({function_params});", + "", + *symbol_declarations, + "", + f"constexpr std::array<{entry_type}, {len(variants)}> {registry_name}{{{{", + *registry_entries, + "}};", + "", + f"inline const {entry_type}* find_{kernel_name}_kernel_entry(", + f" const {specialization_type}& specialization) {{", + f" return find_kernel_entry({registry_name}, specialization);", + "}", + "", + f"inline std::string available_{kernel_name}_variant_keys() {{", + f" return available_variant_keys({registry_name});", + "}", + ] + ) + return "\n".join(lines) + "\n" diff --git a/xllm/compiler/tilelang/targets/ascend/toolchain.py b/xllm/compiler/tilelang/targets/ascend/toolchain.py new file mode 100644 index 000000000..06f9d50af --- /dev/null +++ b/xllm/compiler/tilelang/targets/ascend/toolchain.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + +from ...common.toolchain import git_head, require_env +from .kernels.utils import DEFAULT_ASCEND_BISHENG_ARCH + +TILELANG_BISHENG_COMMON_FLAGS = [ + "-O2", + "-std=gnu++17", + "-xcce", + "-fPIC", + "-mllvm", + "-cce-aicore-stack-size=0x8000", + "-mllvm", + "-cce-aicore-function-stack-size=0x8000", + "-mllvm", + "-cce-aicore-record-overflow=true", + "-mllvm", + "-cce-aicore-addr-transform", + "-mllvm", + "-cce-aicore-dcci-insert-for-scalar=false", + "-DL2_CACHE_HINT", + "-DBACKEND_HYBM", +] + +ASCEND_DEVICE_TO_BISHENG_ARCH = { + "a2": DEFAULT_ASCEND_BISHENG_ARCH, + "a3": DEFAULT_ASCEND_BISHENG_ARCH, +} + + +@dataclass(frozen=True) +class AscendBuildContext: + device: str | None + bisheng_arch: str + bisheng_executable: str + toolchain_options: dict[str, str] + fingerprint: dict[str, str] + include_dirs: list[str] + + +def normalize_ascend_device(device: str | None) -> str | None: + if device is None: + return None + normalized = device.strip().lower() + if not normalized: + return None + if normalized not in ASCEND_DEVICE_TO_BISHENG_ARCH: + supported = ", ".join(sorted(ASCEND_DEVICE_TO_BISHENG_ARCH)) + raise ValueError( + f"Unsupported Ascend TileLang device {device!r}. Expected one of: " + f"{supported}" + ) + return normalized + + +def resolve_bisheng_arch(device: str | None) -> tuple[str | None, str]: + normalized_device = normalize_ascend_device(device) + if normalized_device is None: + print( + "[WARN] TileLang Ascend build did not receive --device. Falling back " + f"to default bisheng_arch={DEFAULT_ASCEND_BISHENG_ARCH}. Prefer " + "running via xLLM main build path or pass --device a2|a3 explicitly." + ) + return None, DEFAULT_ASCEND_BISHENG_ARCH + return normalized_device, ASCEND_DEVICE_TO_BISHENG_ARCH[normalized_device] + + +def build_toolchain_options(device: str | None, bisheng_arch: str) -> dict[str, str]: + toolchain_options = {"bisheng_arch": bisheng_arch} + if device is not None: + toolchain_options["device"] = device + return toolchain_options + + +def resolve_npu_home_path() -> str: + for env_name in ("NPU_HOME_PATH", "NPU_TOOLKIT_HOME"): + value = os.environ.get(env_name, "").strip() + if value: + return value + + for candidate in ( + "/usr/local/Ascend/ascend-toolkit/latest", + "/usr/local/Ascend/ascend-toolkit", + ): + if Path(candidate).exists(): + return candidate + + raise RuntimeError( + "Required NPU toolkit root is not set. Expected NPU_HOME_PATH or " + "NPU_TOOLKIT_HOME, or a standard install path under " + "/usr/local/Ascend/ascend-toolkit." + ) + + +def bisheng_include_dirs() -> list[str]: + tl_root = require_env("TL_ROOT") + npu_home_path = resolve_npu_home_path() + return [ + f"{npu_home_path}/include", + f"{npu_home_path}/include/experiment/runtime", + f"{npu_home_path}/include/experiment/msprof", + f"{npu_home_path}/compiler/tikcpp", + f"{npu_home_path}/compiler/tikcpp/tikcfw", + f"{npu_home_path}/compiler/tikcpp/tikcfw/impl", + f"{npu_home_path}/compiler/tikcpp/tikcfw/interface", + f"{tl_root}/3rdparty/catlass/include", + f"{tl_root}/3rdparty/shmem/include", + f"{tl_root}/3rdparty/shmem/src/device", + f"{tl_root}/src", + ] + + +def build_fingerprint(bisheng_executable: str, bisheng_arch: str) -> dict[str, str]: + tl_root = require_env("TL_ROOT") + npu_home_path = resolve_npu_home_path() + return { + "target": "ascend", + "tl_root": tl_root, + "tilelang_git_head": git_head(tl_root), + "npu_home_path": npu_home_path, + "bisheng_executable": bisheng_executable, + "bisheng_arch": bisheng_arch, + } + + +def resolve_build_context(device: str | None, bisheng_executable: str) -> AscendBuildContext: + normalized_device, bisheng_arch = resolve_bisheng_arch(device) + fingerprint = build_fingerprint(bisheng_executable, bisheng_arch) + if normalized_device is not None: + fingerprint["device"] = normalized_device + return AscendBuildContext( + device=normalized_device, + bisheng_arch=bisheng_arch, + bisheng_executable=bisheng_executable, + toolchain_options=build_toolchain_options(normalized_device, bisheng_arch), + fingerprint=fingerprint, + include_dirs=bisheng_include_dirs(), + ) diff --git a/xllm/compiler/tilelang/targets/cuda/build.py b/xllm/compiler/tilelang/targets/cuda/build.py new file mode 100644 index 000000000..b8cff4571 --- /dev/null +++ b/xllm/compiler/tilelang/targets/cuda/build.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from pathlib import Path + +from ...common.manifest import KernelFamilyManifest + + +def build_kernels( + output_root: str | Path, + kernel_names: list[str] | None = None, + force: bool = False, +) -> list[KernelFamilyManifest]: + if kernel_names: + raise NotImplementedError( + "CUDA TileLang AOT build pipeline is scaffolded but no kernels are " + "registered yet." + ) + return [] diff --git a/xllm/compiler/tilelang/tilelang_ascend_install.py b/xllm/compiler/tilelang/tilelang_ascend_install.py new file mode 100644 index 000000000..f5222b73e --- /dev/null +++ b/xllm/compiler/tilelang/tilelang_ascend_install.py @@ -0,0 +1,378 @@ +from __future__ import annotations + +import os +import shlex +import subprocess +from dataclasses import dataclass +from pathlib import Path + +from scripts.build_support.env import set_npu_envs + +from .common.toolchain import ( + default_tilelang_root, + git_head, + prepare_tilelang_import, + repo_root, + resolve_tilelang_root, +) + +PREPARE_ASCEND_COMMAND = "python xllm/compiler/tilelang_launcher.py prepare-ascend" + + +@dataclass(frozen=True) +class TilelangPrepareState: + tilelang_root: Path + cann_set_env: Path + current_head: str + cached_head: str | None + artifacts_ready: bool + import_ok: bool + import_detail: str + + +def _ready_error(message: str) -> RuntimeError: + return RuntimeError(f"{message}\nRun `{PREPARE_ASCEND_COMMAND}` first.") + + +def _find_cann_set_env() -> Path | None: + candidates: list[Path] = [] + npu_home_path = os.environ.get("NPU_HOME_PATH", "").strip() + if npu_home_path: + toolkit_root = Path(npu_home_path).resolve() + candidates.append(toolkit_root / "set_env.sh") + candidates.append(toolkit_root.parent / "set_env.sh") + + candidates.extend( + [ + Path("/usr/local/Ascend/ascend-toolkit/set_env.sh"), + Path("/usr/local/Ascend/ascend-toolkit/latest/set_env.sh"), + ] + ) + + for script in candidates: + if script.is_file(): + return script.resolve() + return None + + +def resolve_cann_set_env() -> Path: + cann_set_env = _find_cann_set_env() + if cann_set_env is not None: + return cann_set_env + + set_npu_envs() + cann_set_env = _find_cann_set_env() + if cann_set_env is not None: + return cann_set_env + + raise RuntimeError( + "[ERROR] Cannot find CANN set_env.sh. Expected a path like " + "/usr/local/Ascend/ascend-toolkit/set_env.sh." + ) + + +def ensure_tilelang_submodules(tilelang_root: str | Path) -> Path: + tl_root = Path(tilelang_root).resolve() + required_markers = { + "3rdparty/catlass/CMakeLists.txt": tl_root / "3rdparty" / "catlass" / "CMakeLists.txt", + "3rdparty/composable_kernel/CMakeLists.txt": ( + tl_root / "3rdparty" / "composable_kernel" / "CMakeLists.txt" + ), + "3rdparty/cutlass/CMakeLists.txt": tl_root / "3rdparty" / "cutlass" / "CMakeLists.txt", + "3rdparty/pto-isa/CMakeLists.txt": tl_root / "3rdparty" / "pto-isa" / "CMakeLists.txt", + "3rdparty/shmem/CMakeLists.txt": tl_root / "3rdparty" / "shmem" / "CMakeLists.txt", + "3rdparty/tvm/CMakeLists.txt": tl_root / "3rdparty" / "tvm" / "CMakeLists.txt", + } + missing = [name for name, path in required_markers.items() if not path.is_file()] + if missing: + if (tl_root / ".git").exists(): + repair_hint = ( + "Run " + f"`git -C {shlex.quote(str(tl_root))} submodule update --init --recursive` " + "first." + ) + else: + bundled_root = default_tilelang_root().resolve() + bundled_repair_cmd = ( + f"git -C {shlex.quote(str(repo_root()))} " + "submodule update --init --recursive third_party/tilelang-ascend" + ) + if tl_root == bundled_root: + repair_hint = ( + "Sync the bundled TileLang checkout from the xLLM repo root: " + f"`{bundled_repair_cmd}`." + ) + else: + repair_hint = ( + f"`TL_ROOT={tl_root}` is not a git checkout. " + "Point TL_ROOT at a fully initialized tilelang-ascend clone, " + f"or sync the bundled checkout with `{bundled_repair_cmd}`." + ) + raise RuntimeError( + "[ERROR] tilelang-ascend nested dependencies are incomplete: " + f"missing {', '.join(missing)}. " + f"{repair_hint}" + ) + return tl_root + + +def tilelang_git_head_cache_path(tilelang_root: str | Path) -> Path: + return Path(tilelang_root).resolve() / "build" / ".xllm_tilelang_git_head_cached" + + +def read_tilelang_git_head_cached(tilelang_root: str | Path) -> str | None: + cache_path = tilelang_git_head_cache_path(tilelang_root) + if not cache_path.is_file(): + return None + value = cache_path.read_text(encoding="utf-8").strip() + return value or None + + +def write_tilelang_git_head_cached(tilelang_root: str | Path, head: str) -> None: + cache_path = tilelang_git_head_cache_path(tilelang_root) + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_text(head + "\n", encoding="utf-8") + + +def tilelang_artifacts_ready(tilelang_root: str | Path) -> bool: + tl_root = Path(tilelang_root).resolve() + required = [ + tl_root / "build" / "libtilelang_module.so", + tl_root / "build" / "libtilelang.so", + tl_root / "build" / "tvm" / "libtvm.so", + ] + return all(path.exists() for path in required) + + +def verify_tilelang_import(tilelang_root: str | Path) -> tuple[bool, str]: + tl_root = prepare_tilelang_import(tilelang_root) + env = os.environ.copy() + env["TL_ROOT"] = str(tl_root) + pythonpath = env.get("PYTHONPATH", "") + pythonpath_items = [item for item in pythonpath.split(os.pathsep) if item] + if str(tl_root) not in pythonpath_items: + pythonpath_items.insert(0, str(tl_root)) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_items) + cmd = [ + "bash", + "-lc", + "python - <<'PY'\n" + "import tilelang\n" + "print(getattr(tilelang, '__file__', ''))\n" + "PY", + ] + result = subprocess.run( + cmd, + text=True, + capture_output=True, + check=False, + env=env, + ) + if result.returncode != 0: + detail = (result.stderr or result.stdout).strip() + return False, detail + return True, result.stdout.strip() + + +def _tilelang_patch_dir() -> Path: + return Path(__file__).resolve().parent / "patches" / "tilelang_ascend" + + +def _tilelang_install_patch_path() -> Path: + patch_path = _tilelang_patch_dir() / "0001-install-ascend.patch" + if not patch_path.is_file(): + raise RuntimeError(f"[ERROR] Missing TileLang patch file: {patch_path}") + return patch_path + + +def _git_apply_base_cmd(repo_root: Path) -> list[str]: + return [ + "git", + "-c", + f"safe.directory={repo_root}", + "-C", + str(repo_root), + "apply", + "--whitespace=nowarn", + ] + + +def _check_git_patch_state(repo_root: Path, patch_path: Path) -> str: + apply_check = subprocess.run( + _git_apply_base_cmd(repo_root) + ["--check", str(patch_path)], + text=True, + capture_output=True, + check=False, + ) + if apply_check.returncode == 0: + return "unapplied" + + reverse_check = subprocess.run( + _git_apply_base_cmd(repo_root) + ["--reverse", "--check", str(patch_path)], + text=True, + capture_output=True, + check=False, + ) + if reverse_check.returncode == 0: + return "applied" + + apply_detail = (apply_check.stderr or apply_check.stdout).strip() + reverse_detail = (reverse_check.stderr or reverse_check.stdout).strip() + raise RuntimeError( + "[ERROR] Failed to match TileLang patch " + f"{patch_path.name} against {repo_root}.\n" + f"apply --check: {apply_detail or ''}\n" + f"reverse --check: {reverse_detail or ''}" + ) + + +def _apply_git_patch(repo_root: Path, patch_path: Path, message: str) -> None: + if _check_git_patch_state(repo_root, patch_path) == "applied": + return + subprocess.check_call(_git_apply_base_cmd(repo_root) + [str(patch_path)]) + print(message) + + +def _restore_git_patch(repo_root: Path, patch_path: Path) -> None: + if _check_git_patch_state(repo_root, patch_path) != "applied": + return + subprocess.check_call(_git_apply_base_cmd(repo_root) + ["--reverse", str(patch_path)]) + + +def _patch_tilelang_install_tree(tilelang_root: str | Path) -> None: + tl_root = Path(tilelang_root).resolve() + required_files = ( + tl_root / "install_ascend.sh", + tl_root / "requirements-build.txt", + ) + missing = [str(path.name) for path in required_files if not path.is_file()] + if missing: + raise RuntimeError( + "[ERROR] Missing tilelang install files: " + ", ".join(missing) + ) + _apply_git_patch( + tl_root, + _tilelang_install_patch_path(), + "[INFO] Applied tilelang install patch", + ) + + +def _restore_tilelang_install_tree(tilelang_root: str | Path) -> None: + tl_root = Path(tilelang_root).resolve() + _restore_git_patch(tl_root, _tilelang_install_patch_path()) + + +def _run_tilelang_install(tilelang_root: str | Path, cann_set_env: str | Path) -> None: + tl_root = ensure_tilelang_submodules(tilelang_root) + _patch_tilelang_install_tree(tl_root) + + cmd = ( + f"source {shlex.quote(str(cann_set_env))} && " + "bash install_ascend.sh && " + "source set_env.sh" + ) + env = os.environ.copy() + git_config_count = int(env.get("GIT_CONFIG_COUNT", "0") or "0") + env[f"GIT_CONFIG_KEY_{git_config_count}"] = "safe.directory" + env[f"GIT_CONFIG_VALUE_{git_config_count}"] = str(tl_root) + env["GIT_CONFIG_COUNT"] = str(git_config_count + 1) + try: + subprocess.check_call( + ["bash", "-lc", cmd], + cwd=str(tl_root), + env=env, + ) + finally: + _restore_tilelang_install_tree(tl_root) + + +def collect_prepare_state() -> TilelangPrepareState: + tilelang_root = ensure_tilelang_submodules(resolve_tilelang_root()) + prepare_tilelang_import(tilelang_root) + return TilelangPrepareState( + tilelang_root=tilelang_root, + cann_set_env=resolve_cann_set_env(), + current_head=git_head(tilelang_root), + cached_head=read_tilelang_git_head_cached(tilelang_root), + artifacts_ready=tilelang_artifacts_ready(tilelang_root), + import_ok=False, + import_detail="", + ) + + +def refresh_prepare_state_import(state: TilelangPrepareState) -> TilelangPrepareState: + import_ok, import_detail = verify_tilelang_import(state.tilelang_root) + return TilelangPrepareState( + tilelang_root=state.tilelang_root, + cann_set_env=state.cann_set_env, + current_head=state.current_head, + cached_head=state.cached_head, + artifacts_ready=tilelang_artifacts_ready(state.tilelang_root), + import_ok=import_ok, + import_detail=import_detail, + ) + + +def prepare_state() -> TilelangPrepareState: + return refresh_prepare_state_import(collect_prepare_state()) + + +def install_reasons(state: TilelangPrepareState, *, force: bool) -> list[str]: + reasons: list[str] = [] + if force: + reasons.append("forced") + if state.cached_head is None: + reasons.append("HEAD cache missing") + elif state.current_head != state.cached_head: + reasons.append("HEAD changed") + if not state.artifacts_ready: + reasons.append("artifacts missing") + if not state.import_ok: + reasons.append("tilelang import failed") + return list(dict.fromkeys(reasons)) + + +def ensure_ascend_ready() -> Path: + set_npu_envs() + state = prepare_state() + + if not state.artifacts_ready: + raise _ready_error( + "[ERROR] tilelang-ascend artifacts are missing under " + f"{state.tilelang_root / 'build'}." + ) + + if not state.import_ok: + raise _ready_error( + "[ERROR] Failed to import tilelang after configuring TL_ROOT=" + f"{state.tilelang_root}: {state.import_detail}" + ) + + return state.tilelang_root + + +def prepare_ascend(*, force: bool = False) -> Path: + set_npu_envs() + state = prepare_state() + reasons = install_reasons(state, force=force) + + if reasons: + print("[INFO] Preparing tilelang-ascend: " + "; ".join(reasons)) + _run_tilelang_install(state.tilelang_root, state.cann_set_env) + prepare_tilelang_import(state.tilelang_root) + write_tilelang_git_head_cached(state.tilelang_root, state.current_head) + state = prepare_state() + + if not state.artifacts_ready: + raise RuntimeError( + "[ERROR] tilelang-ascend artifacts are still missing after prepare." + ) + + if not state.import_ok: + raise RuntimeError( + "[ERROR] tilelang import still failed after prepare: " + f"{state.import_detail}" + ) + + print(f"[INFO] tilelang import success: {state.import_detail}") + return state.tilelang_root diff --git a/xllm/compiler/tilelang_launcher.py b/xllm/compiler/tilelang_launcher.py new file mode 100644 index 000000000..fc15bd0c0 --- /dev/null +++ b/xllm/compiler/tilelang_launcher.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + + +def _bootstrap_import_paths() -> None: + compiler_dir = Path(__file__).resolve().parent + package_root = compiler_dir.parent + repo_root = package_root.parent + for path in (repo_root, package_root): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Source-tree launcher for xLLM TileLang prepare/compile flows." + ) + subparsers = parser.add_subparsers(dest="command") + subparsers.required = True + subparsers.add_parser( + "prepare-ascend", + add_help=False, + help="Prepare third_party/tilelang-ascend for Ascend TileLang builds.", + ) + subparsers.add_parser( + "compile-kernels", + add_help=False, + help="Compile TileLang kernels and emit manifests.", + ) + return parser + + +def main(argv: list[str] | None = None) -> None: + parser = _build_parser() + args, remainder = parser.parse_known_args(argv) + + _bootstrap_import_paths() + + if args.command == "prepare-ascend": + from compiler.tilelang.cli.prepare_ascend import main as entrypoint + elif args.command == "compile-kernels": + from compiler.tilelang.cli.compile_kernels import main as entrypoint + else: # pragma: no cover - argparse enforces choices + raise ValueError(f"Unsupported TileLang launcher command: {args.command}") + + entrypoint(remainder) + + +if __name__ == "__main__": + main() diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt index 18c062542..acddc10ce 100644 --- a/xllm/core/common/CMakeLists.txt +++ b/xllm/core/common/CMakeLists.txt @@ -15,7 +15,6 @@ cc_library( $<$:mspti_helper.h> options.h rate_limiter.h - rec_model_utils.h types.h device_monitor.h version_singleton.h diff --git a/xllm/core/common/etcd_client.cpp b/xllm/core/common/etcd_client.cpp index 14e2028e3..73c1bf51d 100644 --- a/xllm/core/common/etcd_client.cpp +++ b/xllm/core/common/etcd_client.cpp @@ -21,17 +21,43 @@ limitations under the License. #include #include +#include "xllm/core/common/etcd_utils.h" + namespace xllm { -EtcdClient::EtcdClient(const std::string& etcd_addr) - : client_(etcd_addr), etcd_addr_(etcd_addr) { - auto response = client_.put("XLLM_PING", "PING"); +EtcdClient::EtcdClient(const std::string& etcd_addr, + const std::string& etcd_namespace) + : client_(etcd_addr), + etcd_addr_(etcd_addr), + etcd_namespace_prefix_(normalize_etcd_namespace(etcd_namespace)) { + auto response = client_.put(namespaced_key("XLLM_PING"), "PING"); + if (!response.is_ok()) { + LOG(FATAL) << "etcd connect to etcd server failed: " + << response.error_message(); + } +} + +EtcdClient::EtcdClient(const std::string& etcd_addr, + const std::string& username, + const std::string& password, + const std::string& etcd_namespace) + : client_(etcd_addr, username, password), + etcd_addr_(etcd_addr), + etcd_namespace_prefix_(normalize_etcd_namespace(etcd_namespace)) { + auto response = client_.put(namespaced_key("XLLM_PING"), "PING"); if (!response.is_ok()) { LOG(FATAL) << "etcd connect to etcd server failed: " << response.error_message(); } } +std::string EtcdClient::namespaced_key(const std::string& logical_key) const { + if (etcd_namespace_prefix_.empty()) { + return logical_key; + } + return etcd_namespace_prefix_ + logical_key; +} + EtcdClient::~EtcdClient() { stop_watch(); @@ -52,10 +78,14 @@ void EtcdClient::add_watch(const std::string& key_prefix, if (watchers_.find(key_prefix) != watchers_.end()) { watchers_[key_prefix].watcher->Cancel(); } + + uint64_t prefix_len = etcd_namespace_prefix_.size(); + auto bound_callback = std::bind(callback, std::placeholders::_1, prefix_len); + auto watcher = std::make_unique( client_, - key_prefix, - [callback](etcd::Response response) { callback(response); }, + namespaced_key(key_prefix), + [bound_callback](etcd::Response response) { bound_callback(response); }, recursive); watchers_[key_prefix] = {std::move(watcher), callback}; @@ -83,7 +113,7 @@ void EtcdClient::stop_watch() { bool EtcdClient::get_master_service(const std::string& key, std::string* values) { - auto response = client_.get(key); + auto response = client_.get(namespaced_key(key)); if (!response.is_ok()) { LOG(ERROR) << "etcd get " << key << " failed: " << response.error_message(); return false; @@ -93,7 +123,7 @@ bool EtcdClient::get_master_service(const std::string& key, } bool EtcdClient::get(const std::string& key, std::string* value) { - auto response = client_.get(key); + auto response = client_.get(namespaced_key(key)); if (!response.is_ok()) { LOG(ERROR) << "etcd get " << key << " failed: " << response.error_message(); return false; @@ -107,7 +137,7 @@ bool EtcdClient::get(const std::string& key, std::string* value) { bool EtcdClient::get_all_xservices(const std::string& key_prefix, std::vector* values) { - auto response = client_.ls(key_prefix); + auto response = client_.ls(namespaced_key(key_prefix)); if (!response.is_ok()) { LOG(ERROR) << "etcd get " << key_prefix << " failed: " << response.error_message(); @@ -131,7 +161,7 @@ bool EtcdClient::register_instance(const std::string& key, const std::string& value, const int ttl) { auto keep_alive = std::make_shared(&client_, ttl); - auto response = client_.put(key, value, keep_alive->Lease()); + auto response = client_.put(namespaced_key(key), value, keep_alive->Lease()); if (!response.is_ok()) { LOG(ERROR) << "etcd set " << key << " failed: " << response.error_message(); keep_alive->Cancel(); diff --git a/xllm/core/common/etcd_client.h b/xllm/core/common/etcd_client.h index 145c72d9a..a709e4b2e 100644 --- a/xllm/core/common/etcd_client.h +++ b/xllm/core/common/etcd_client.h @@ -25,11 +25,16 @@ limitations under the License. namespace xllm { -using Callback = std::function; +using Callback = std::function; class EtcdClient { public: - EtcdClient(const std::string& etcd_addr); + explicit EtcdClient(const std::string& etcd_addr, + const std::string& etcd_namespace = ""); + EtcdClient(const std::string& etcd_addr, + const std::string& username, + const std::string& password, + const std::string& etcd_namespace = ""); ~EtcdClient(); void add_watch(const std::string& key_prefix, @@ -51,6 +56,8 @@ class EtcdClient { std::vector* values); private: + std::string namespaced_key(const std::string& logical_key) const; + struct WatcherInfo { std::unique_ptr watcher; Callback callback; @@ -58,6 +65,7 @@ class EtcdClient { etcd::SyncClient client_; std::string etcd_addr_; + std::string etcd_namespace_prefix_; std::mutex watchers_mutex_; std::unordered_map watchers_; std::mutex keep_alives_mutex_; diff --git a/xllm/api_service/chat_json_utils.h b/xllm/core/common/etcd_utils.h similarity index 51% rename from xllm/api_service/chat_json_utils.h rename to xllm/core/common/etcd_utils.h index 2757573da..042de3843 100644 --- a/xllm/api_service/chat_json_utils.h +++ b/xllm/core/common/etcd_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. +/* Copyright 2026 The xLLM Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,19 +16,27 @@ limitations under the License. #pragma once #include -#include - -#include "core/common/types.h" namespace xllm { -// Preprocess chat JSON to normalize array content to string. -// For text-only backends (is_multimodal=false), combines text array items into -// a single string. Returns an error if non-text content is encountered. -// For multimodal backends (is_multimodal=true), leaves non-text content -// unchanged for downstream processing. -// Returns Status with processed JSON on success, or error status on failure. -std::pair preprocess_chat_json(std::string json_str, - bool is_multimodal); +inline std::string normalize_etcd_namespace(const std::string& etcd_namespace) { + if (etcd_namespace.empty()) { + return ""; + } + + size_t start = 0; + size_t end = etcd_namespace.size(); + while (start < end && etcd_namespace[start] == '/') { + ++start; + } + while (end > start && etcd_namespace[end - 1] == '/') { + --end; + } + + if (start >= end) { + return ""; + } + return "/" + etcd_namespace.substr(start, end - start) + "/"; +} } // namespace xllm diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index da65bca52..fa24d4d9c 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -343,8 +343,6 @@ DEFINE_string(kv_cache_transfer_mode, "PUSH", "The mode of kv cache transfer(e.g. PUSH, PULL)."); -DEFINE_int32(npu_phy_id, -1, "npu phy id"); - DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port."); DEFINE_uint64(input_shm_size, @@ -408,14 +406,24 @@ DEFINE_bool(enable_atb_spec_kernel, // --- block copy config --- +#if defined(USE_NPU) || defined(USE_CUDA) DEFINE_bool(enable_block_copy_kernel, true, - "Whether to use ATB block copy kernel."); + "Whether to use block copy kernel on supported backends."); +#else +DEFINE_bool(enable_block_copy_kernel, + false, + "Whether to use block copy kernel on supported backends."); +#endif // --- service routing config --- DEFINE_string(etcd_addr, "", "Etcd adderss for save instance meta info."); +DEFINE_string(etcd_namespace, + "", + "Optional etcd namespace prefix for all xllm keys, e.g. prod-a."); + DEFINE_bool(enable_service_routing, false, "Whether to use xllm service routing."); @@ -597,7 +605,7 @@ DEFINE_int32(random_seed, -1, "Random seed for random number generator."); DEFINE_string(dit_cache_policy, "TaylorSeer", "The policy of dit cache(e.g. None, FBCache, TaylorSeer, " - "FBCacheTaylorSeer)."); + "FBCacheTaylorSeer, ResidualCache)."); DEFINE_int64(dit_cache_warmup_steps, 0, "The number of warmup steps."); @@ -619,6 +627,60 @@ DEFINE_bool(enable_constrained_decoding, "that the output meets specific format or structural requirements " "through pre-defined rules."); +DEFINE_bool(enable_convert_tokens_to_item, + false, + "Enable token ids conversion to item id in REC/OneRec response."); + +DEFINE_int64(dit_cache_start_steps, + 5, + "The number of steps to skip at the start"); + +DEFINE_int64(dit_cache_end_steps, 5, "The number of steps to skip at the end."); + +DEFINE_int64(dit_cache_start_blocks, + 5, + "The number of blocks to skip at the start."); + +DEFINE_int64(dit_cache_end_blocks, + 5, + "The number of blocks to skip at the end."); + +// --- dit parallel config --- + +DEFINE_int64(tp_size, 1, "Tensor parallelism size"); + +DEFINE_int64(sp_size, 1, "Sequence parallelism size"); + +DEFINE_int64(cfg_size, 1, "Classifier-free guidiance parallelism size"); + +DEFINE_int64(dit_sp_communication_overlap, + 1, + "Communication & Computation overlap for sequence parallel"); + +// --- dit debug --- + +DEFINE_bool(dit_debug_print, + false, + "whether print the debug info for dit models"); + +// --- embedding type --- + +DEFINE_bool(enable_return_mm_full_embeddings, + false, + "return vit and sequence embeddings for vlm models"); + +DEFINE_bool(enable_output_sku_logprobs, + false, + "Enable REC / OneRec token-aligned logprobs tensor output."); + +DEFINE_int32(each_conversion_threshold, + 50, + "Maximum number of items emitted for each REC token triplet."); + +DEFINE_int32(total_conversion_threshold, + 1000, + "Maximum total number of items emitted in one REC response."); + DEFINE_bool( use_audio_in_video, false, diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 67a27afa9..59412570d 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -138,8 +138,6 @@ DECLARE_string(kv_cache_transfer_type); DECLARE_string(kv_cache_transfer_mode); -DECLARE_int32(npu_phy_id); - DECLARE_string(device_ip); DECLARE_int32(transfer_listen_port); @@ -164,6 +162,8 @@ DECLARE_bool(enable_block_copy_kernel); DECLARE_string(etcd_addr); +DECLARE_string(etcd_namespace); + DECLARE_bool(enable_service_routing); DECLARE_double(heart_beat_interval); @@ -254,6 +254,8 @@ DECLARE_bool(enable_beam_search_kernel); DECLARE_bool(enable_rec_fast_sampler); +DECLARE_bool(enable_rec_prefill_only); + DECLARE_bool(enable_topk_sorted); DECLARE_bool(output_rec_logprobs); @@ -292,6 +294,28 @@ DECLARE_int64(dit_cache_skip_interval_steps); DECLARE_double(dit_cache_residual_diff_threshold); DECLARE_bool(enable_constrained_decoding); +DECLARE_bool(enable_convert_tokens_to_item); +DECLARE_bool(enable_output_sku_logprobs); +DECLARE_int32(each_conversion_threshold); +DECLARE_int32(total_conversion_threshold); + +DECLARE_bool(enable_return_mm_full_embeddings); + +DECLARE_int64(dit_cache_start_steps); + +DECLARE_int64(dit_cache_end_steps); + +DECLARE_int64(dit_cache_start_blocks); + +DECLARE_int64(dit_cache_end_blocks); + +DECLARE_int64(tp_size); + +DECLARE_int64(sp_size); + +DECLARE_int64(cfg_size); + +DECLARE_bool(dit_debug_print); // --- multi-step decode config --- diff --git a/xllm/core/common/help_formatter.h b/xllm/core/common/help_formatter.h index dd5364a16..4719eee0a 100644 --- a/xllm/core/common/help_formatter.h +++ b/xllm/core/common/help_formatter.h @@ -60,6 +60,10 @@ const OptionCategory kMoeModelOptions = { "MOE MODEL OPTIONS", {"dp_size", "ep_size", "expert_parallel_degree"}}; +const OptionCategory kDiTModelOptions = { + "DiT MODEL OPTIONS", + {"dp_size", "tp_size", "sp_size", "cfg_size"}}; + const OptionCategory kDisaggregatedPrefillDecodeOptions = { "DISAGGREGATED PREFILL-DECODE OPTIONS", {"enable_disagg_pd", @@ -67,7 +71,6 @@ const OptionCategory kDisaggregatedPrefillDecodeOptions = { "instance_role", "kv_cache_transfer_mode", "device_ip", - "npu_phy_id", "transfer_listen_port"}}; const OptionCategory kMultiStepDecodeOptions = { @@ -86,14 +89,22 @@ const OptionCategory kMtpOptions = {"SPECULATIVE OPTIONS", "speculative_suffix_max_cached_requests", "speculative_suffix_use_tree_spec"}}; -const OptionCategory kXllmServiceOptions = {"XLLM-SERVICE OPTIONS", - {"etcd_addr", "rank_tablefile"}}; +const OptionCategory kXllmServiceOptions = { + "XLLM-SERVICE OPTIONS", + {"etcd_addr", "rank_tablefile", "etcd_namespace"}}; + +const OptionCategory kBeamSearchOptions = { + "BEAM SEARCH OPTIONS", + {"enable_beam_search_kernel", "enable_topk_sorted"}}; -const OptionCategory kBeamSearchOptions = {"BEAM SEARCH OPTIONS", - {"enable_beam_search_kernel", - "enable_rec_fast_sampler", - "enable_topk_sorted", - "output_rec_logprobs"}}; +const OptionCategory kRecOptions = {"REC OPTIONS", + {"enable_rec_fast_sampler", + "enable_convert_tokens_to_item", + "enable_output_sku_logprobs", + "each_conversion_threshold", + "total_conversion_threshold", + "enable_rec_prefill_only", + "output_rec_logprobs"}}; const OptionCategory kPrefixCacheOptions = { "PREFIX CACHE OPTIONS", @@ -113,11 +124,13 @@ const std::vector kOptionCategories = { kCommonOptions, kCacheOptions, kMoeModelOptions, + kDiTModelOptions, kDisaggregatedPrefillDecodeOptions, kMultiStepDecodeOptions, kMtpOptions, kXllmServiceOptions, kBeamSearchOptions, + kRecOptions, kPrefixCacheOptions, kOtherOptions}; diff --git a/xllm/core/common/macros.h b/xllm/core/common/macros.h index 8dddc257f..df6e1b4ac 100644 --- a/xllm/core/common/macros.h +++ b/xllm/core/common/macros.h @@ -62,6 +62,18 @@ namespace xllm { } \ } while (0) +#define TORCH_TENSOR_VEC_TO_PROTO_TENSOR_LIST(proto_field, torch_tensor_vec) \ + do { \ + proto_field->mutable_tensors()->Reserve(torch_tensor_vec.size()); \ + for (const auto& torch_tensor : torch_tensor_vec) { \ + proto::Tensor* pb_tensor = proto_field->add_tensors(); \ + if (!util::torch_to_proto(torch_tensor, pb_tensor)) { \ + LOG(ERROR) \ + << "Failed to convert torch Tensor to PB Tensor (list item)"; \ + } \ + } \ + } while (0) + #define CALLBACK_WITH_ERROR_ARGS2(CODE, MSG) callback(Status{CODE, MSG}) #define CALLBACK_WITH_ERROR_ARGS4(CODE, MSG, ID, TARGET_XSERVICE_ADDR) \ callback({Status{CODE, MSG}, ID, TARGET_XSERVICE_ADDR}) diff --git a/xllm/core/common/options.cpp b/xllm/core/common/options.cpp index 3b15cbd41..94469344c 100644 --- a/xllm/core/common/options.cpp +++ b/xllm/core/common/options.cpp @@ -65,6 +65,7 @@ std::string Options::to_string() const { << ", kv_cache_dtype: " << kv_cache_dtype() << ", kv_cache_transfer_mode: " << kv_cache_transfer_mode() << ", etcd_addr: " << etcd_addr().value_or("null") + << ", etcd_namespace: " << etcd_namespace().value_or("null") << ", enable_service_routing: " << enable_service_routing() << ", enable_cache_upload: " << enable_cache_upload() << ", enable_kvcache_store: " << enable_kvcache_store() diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 6179860c8..300364454 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -128,6 +128,12 @@ class Options { PROPERTY(int32_t, ep_size) = 1; + PROPERTY(int32_t, tp_size) = 1; + + PROPERTY(int32_t, sp_size) = 1; + + PROPERTY(int32_t, cfg_size) = 1; + PROPERTY(std::optional, instance_name); PROPERTY(bool, enable_disagg_pd) = false; @@ -146,6 +152,8 @@ class Options { PROPERTY(std::optional, etcd_addr); + PROPERTY(std::optional, etcd_namespace); + PROPERTY(bool, enable_service_routing) = false; PROPERTY(std::optional, tool_call_parser); diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index cf5064ba3..dadc09f31 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -41,15 +41,15 @@ class EngineType { constexpr EngineType(Value v) : value_(v) {} EngineType(const std::string& str) { - if (str == "LLM") { + if (str == "LLM" || str == "llm") { value_ = LLM; - } else if (str == "SSM") { + } else if (str == "SSM" || str == "ssm") { value_ = SSM; - } else if (str == "VLM") { + } else if (str == "VLM" || str == "vlm") { value_ = VLM; - } else if (str == "DIT") { + } else if (str == "DIT" || str == "dit") { value_ = DIT; - } else if (str == "REC") { + } else if (str == "REC" || str == "rec") { value_ = REC; } else { value_ = INVALID; diff --git a/xllm/core/distributed_runtime/dist_manager.cpp b/xllm/core/distributed_runtime/dist_manager.cpp index 4be72dc5c..c2a7cdade 100644 --- a/xllm/core/distributed_runtime/dist_manager.cpp +++ b/xllm/core/distributed_runtime/dist_manager.cpp @@ -182,12 +182,25 @@ void DistManager::setup_multi_node_workers( /* TODO(CP): support smem + CP */ const int32_t dp_local_tp_size = world_size / dp_size; - LOG(INFO) << "Multi-node serving world_size = " << world_size - << ", each_node_ranks = " << each_node_ranks - << ", current node rank = " << options.node_rank() - << ", nnodes = " << options.nnodes() << ", dp_size = " << dp_size - << ", cp_size = " << cp_size << ", ep_size = " << ep_size - << ", tp_size = " << dp_local_tp_size; + const auto& model_backend = options.backend(); + if (model_backend == "dit") { + const int32_t tp_size = options.tp_size(); + const int32_t sp_size = options.sp_size(); + const int32_t cfg_size = options.cfg_size(); + LOG(INFO) << "Multi-node serving world_size = " << world_size + << ", each_node_ranks = " << each_node_ranks + << ", current node rank = " << options.node_rank() + << ", nnodes = " << options.nnodes() << ", dp_size = " << dp_size + << ", tp_size = " << tp_size << ", sp_size = " << sp_size + << ", cfg_size = " << cfg_size; + } else { + LOG(INFO) << "Multi-node serving world_size = " << world_size + << ", each_node_ranks = " << each_node_ranks + << ", current node rank = " << options.node_rank() + << ", nnodes = " << options.nnodes() << ", dp_size = " << dp_size + << ", cp_size = " << cp_size << ", ep_size = " << ep_size + << ", tp_size = " << dp_local_tp_size; + } CHECK_EQ((world_size % dp_size), 0) << "Global world size must be divisible by dp size in multi-node " @@ -196,7 +209,6 @@ void DistManager::setup_multi_node_workers( runtime::Options worker_server_options = options; worker_server_options.world_size(world_size); WorkerType worker_type("LLM"); - const auto& model_backend = options.backend(); if (model_backend == "llm") { if (options.task_type() == "generate") { worker_type = WorkerType::LLM; @@ -219,6 +231,8 @@ void DistManager::setup_multi_node_workers( } } else if (model_backend == "rec") { worker_type = WorkerType::REC; + } else if (model_backend == "dit") { + worker_type = WorkerType::DIT; } else { LOG(FATAL) << "Unsupported " << model_backend << " in multi-node."; } diff --git a/xllm/core/distributed_runtime/dit_engine.cpp b/xllm/core/distributed_runtime/dit_engine.cpp index d7b50e834..96f234c5a 100644 --- a/xllm/core/distributed_runtime/dit_engine.cpp +++ b/xllm/core/distributed_runtime/dit_engine.cpp @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "common/device_monitor.h" #include "core/common/metrics.h" +#include "core/distributed_runtime/master.h" #include "core/platform/device.h" #include "framework/parallel_state/parallel_args.h" #include "framework/parallel_state/parallel_state.h" @@ -28,8 +30,16 @@ limitations under the License. #include "util/timer.h" namespace xllm { -DiTEngine::DiTEngine(const runtime::Options& options) : options_(options) { +DiTEngine::DiTEngine(const runtime::Options& options, + std::shared_ptr dist_manager) + : options_(options), dist_manager_(dist_manager) { + auto master_node_addr = options.master_node_addr().value_or(""); + CHECK(!master_node_addr.empty()) + << " DIT need to set master node addr, Please set --master_node_addr."; + const auto& devices = options_.devices(); + // initialize device monitor + DeviceMonitor::get_instance().initialize(devices); CHECK_GT(devices.size(), 0) << "At least one device is required"; CHECK(!devices[0].is_cpu()) << "CPU device is not supported"; @@ -37,43 +47,26 @@ DiTEngine::DiTEngine(const runtime::Options& options) : options_(options) { for (size_t i = 0; i < devices.size(); ++i) { CHECK(devices[i].type() == device_type) << "All devices should be the same type"; - Device device(devices[i]); - device.set_device(); - } - if (devices.size() > 1) { - // create a process group for each device if there are multiple gpus - process_groups_ = parallel_state::create_npu_process_groups(devices); +#if defined(USE_NPU) + FLAGS_enable_atb_comm_multiprocess = + options.enable_offline_inference() || (options.nnodes() > 1); +#endif } - const int32_t world_size = static_cast(devices.size()); - CHECK(!options_.enable_shm()) << "Dit can not support enable_shm currently."; + // setup all workers and create worker clients in nnode_rank=0 engine side. + setup_workers(options); + worker_clients_num_ = worker_clients_.size(); - // create workers - for (int32_t rank = 0; rank < world_size; ++rank) { - ProcessGroup* pg = world_size > 1 ? process_groups_[rank].get() : nullptr; - ParallelArgs parallel_args(rank, world_size, pg); - workers_.emplace_back( - std::make_unique(parallel_args, devices[rank], options_)); - } + // init thread pool + threadpool_ = std::make_unique(16); +} - if (workers_.size() > 1) { - // test process group - std::vector> futures; - futures.reserve(workers_.size()); - for (auto& worker : workers_) { - futures.emplace_back(worker->process_group_test_async()); - } - // Wait for all futures to complete with a configurable timeout. - // The timeout can be adjusted via the - // XLLM_PROCESS_GROUP_ASYNC_TIMEOUT_SECONDS environment variable (default: 4 - // seconds). This is particularly important in multi-node multi-device - // scenarios where network latency may require a longer timeout period. - const int timeout_seconds = util::get_process_group_test_timeout_seconds(); - folly::collectAll(futures) - .within(std::chrono::seconds(timeout_seconds)) - .get(); +void DiTEngine::setup_workers(const runtime::Options& options) { + if (!dist_manager_) { + dist_manager_ = std::make_shared(options); } + worker_clients_ = dist_manager_->get_worker_clients(); } bool DiTEngine::init() { @@ -86,13 +79,14 @@ bool DiTEngine::init() { bool DiTEngine::init_model() { const std::string& model_path = options_.model_path(); + // init model for each worker in parallel // multiple workers, call async init std::vector> futures; - LOG(INFO) << "Starting to init model on " << workers_.size() << " workers."; - futures.reserve(workers_.size()); - for (auto& worker : workers_) { - futures.push_back(worker->init_model(model_path)); + futures.reserve(worker_clients_num_); + for (auto& worker : worker_clients_) { + futures.push_back(worker->init_model_async( + model_path, FLAGS_random_seed, MasterStatus::WAKEUP)); } // wait for all futures to complete @@ -108,17 +102,25 @@ bool DiTEngine::init_model() { return true; } +// TODO : change to ForwardOutput? DiTForwardOutput DiTEngine::step(std::vector& batches) { - CHECK(!workers_.empty()); + if (worker_clients_.empty()) { + // empty worker, return + return {}; + } Timer timer; - auto forward_inputs = workers_[0]->prepare_inputs(batches[0]); + auto dit_forward_input = batches[0].prepare_forward_input(); + RawForwardInput raw_forward_input; + raw_forward_input.dit_forward_input = dit_forward_input; COUNTER_ADD(prepare_input_latency_seconds, timer.elapsed_seconds()); - std::vector>> futures; - futures.reserve(workers_.size()); - for (auto& worker : workers_) { - futures.emplace_back(worker->step(forward_inputs)); + std::vector>> futures; + futures.reserve(worker_clients_num_); + + for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) { + futures.emplace_back( + worker_clients_[worker_rank]->step_async(raw_forward_input)); } // wait for the all future to complete @@ -127,22 +129,22 @@ DiTForwardOutput DiTEngine::step(std::vector& batches) { // return the result from the driver auto forward_output = results.front().value(); DCHECK(forward_output.has_value()) << "Failed to execute model"; - batches[0].process_forward_output(forward_output.value()); - return forward_output.value(); + batches[0].process_forward_output(forward_output.value().dit_forward_output); + return forward_output.value().dit_forward_output; } std::vector DiTEngine::get_active_activation_memory() const { // call worker to get active activation memory std::vector> futures; - futures.reserve(workers_.size()); - for (auto& worker : workers_) { - futures.push_back(worker->get_active_activation_memory()); + futures.reserve(worker_clients_num_); + for (auto& worker : worker_clients_) { + futures.push_back(worker->get_active_activation_memory_async()); } // wait for all futures to complete auto results = folly::collectAll(futures).get(); std::vector active_activation_memories; - active_activation_memories.reserve(workers_.size()); + active_activation_memories.reserve(worker_clients_num_); for (auto& result : results) { active_activation_memories.push_back(result.value()); } diff --git a/xllm/core/distributed_runtime/dit_engine.h b/xllm/core/distributed_runtime/dit_engine.h index 519394b74..409c3f635 100644 --- a/xllm/core/distributed_runtime/dit_engine.h +++ b/xllm/core/distributed_runtime/dit_engine.h @@ -21,16 +21,19 @@ limitations under the License. #include #include "common/macros.h" +#include "dist_manager.h" +#include "engine.h" #include "framework/batch/dit_batch.h" #include "framework/parallel_state/process_group.h" #include "framework/quant_args.h" -#include "runtime/dit_worker.h" +#include "runtime/dit_worker_impl.h" namespace xllm { -class DiTEngine { +class DiTEngine : public Engine { public: - DiTEngine(const runtime::Options& options); + DiTEngine(const runtime::Options& options, + std::shared_ptr dist_manager = nullptr); ~DiTEngine() = default; @@ -43,16 +46,44 @@ class DiTEngine { // return the active activation memory std::vector get_active_activation_memory() const; + std::shared_ptr get_dist_manager() { return dist_manager_; } + + // These two functions wouldn't be used in dit inference progress + ForwardOutput step(std::vector& batch) override { + ForwardOutput output; + return output; + } + + void update_last_step_result(std::vector& batch) override { return; } + + protected: + // worker client which is used for call worker + // The reason for adding a worker client is to unify the + // access code for both local and remote workers, thereby + // introducing an additional worker_client abstraction. + std::vector> worker_clients_; + + // For multi-node serving + // engine brpc server, all workers connect to engine_server_, + // engine_server_ will send a UniqueId for workers to + // create process group. And workers send worker brpc server + // address to engine, engine will create WorkerClient for each worker. + // Engine call workers to step via these WorkerClients. + std::shared_ptr dist_manager_ = nullptr; + + std::unique_ptr threadpool_ = nullptr; + private: + // setup workers internal + void setup_workers(const runtime::Options& options); + // init models bool init_model(); // options runtime::Options options_; - + // num of worker_clients + int64_t worker_clients_num_; // a list of process groups, with each process group handling a single device std::vector> process_groups_; - - // a list of workers, with each worker handling a partial of model - std::vector> workers_; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/distributed_runtime/dit_master.cpp b/xllm/core/distributed_runtime/dit_master.cpp index 71d4437ec..1aac9d198 100644 --- a/xllm/core/distributed_runtime/dit_master.cpp +++ b/xllm/core/distributed_runtime/dit_master.cpp @@ -37,20 +37,10 @@ limitations under the License. #include "util/timer.h" namespace xllm { +volatile bool DiTAssistantMaster::running_ = false; + DiTMaster::DiTMaster(const Options& options) : Master(options, EngineType::DIT) { - // construct engine - const auto devices = - DeviceNameUtils::parse_devices(options_.devices().value_or("auto")); - LOG(INFO) << "Creating engine with devices: " - << DeviceNameUtils::to_string(devices); - - runtime::Options eng_options; - eng_options.model_path(options.model_path()) - .model_id(options.model_id()) - .devices(devices); - - engine_ = std::make_unique(eng_options); CHECK(engine_->init()); DiTScheduler::Options scheduler_options; @@ -157,4 +147,36 @@ void DiTMaster::generate() { running_.store(false, std::memory_order_relaxed); } +DiTAssistantMaster::DiTAssistantMaster(const Options& options) + : Master(options, EngineType::DIT) { + // setup process workers + auto master_node_addr = options_.master_node_addr().value_or(""); + // TODO: support local unix domain socket later. + if (master_node_addr.empty()) { + LOG(FATAL) + << "MultiNodeEngine required master_node_addr, current value is empty."; + return; + } + + running_ = true; +} + +DiTAssistantMaster::~DiTAssistantMaster() { + // wait for the loop thread to finish + if (loop_thread_.joinable()) { + loop_thread_.join(); + } +} + +void DiTAssistantMaster::run() { + signal(SIGINT, DiTAssistantMaster::handle_signal); + signal(SIGTERM, DiTAssistantMaster::handle_signal); + + loop_thread_ = std::thread([this]() { + while (running_) { + std::this_thread::sleep_for(std::chrono::seconds(5)); + } + }); +} + } // namespace xllm diff --git a/xllm/core/distributed_runtime/dit_master.h b/xllm/core/distributed_runtime/dit_master.h index 13385fb63..a67bad5a2 100644 --- a/xllm/core/distributed_runtime/dit_master.h +++ b/xllm/core/distributed_runtime/dit_master.h @@ -53,8 +53,6 @@ class DiTMaster : public Master { void generate(); private: - std::unique_ptr engine_; - std::unique_ptr scheduler_; // thread pool for handling requests @@ -70,4 +68,17 @@ class DiTMaster : public Master { std::atomic_bool running_{false}; }; +class DiTAssistantMaster : public Master { + public: + DiTAssistantMaster(const Options& options); + ~DiTAssistantMaster(); + void run() override; + + static void handle_signal(int signum) { running_ = false; } + + private: + std::thread loop_thread_; + static volatile bool running_; +}; + } // namespace xllm diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 1ec39f5d2..f337687c9 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -519,12 +519,26 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { } #endif + int64_t full_attention_interval = (args_.full_attention_interval() < 1) + ? 1 + : args_.full_attention_interval(); + int64_t num_full_attention_layers = + kv_cache_cap.n_layers / full_attention_interval; + int64_t num_linear_attention_layers = + kv_cache_cap.n_layers - num_full_attention_layers; // compute kv cache n_blocks const int32_t block_size = options_.block_size(); const int64_t block_size_in_bytes = block_size * (slot_size + index_slot_size + scale_slot_size); - kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes / - (kv_cache_cap.n_layers * block_size_in_bytes); + const int64_t full_cache_block_size_in_bytes = + block_size * (slot_size + index_slot_size + scale_slot_size); + const int64_t total_cache_block_size_in_bytes = + num_full_attention_layers * full_cache_block_size_in_bytes + + num_linear_attention_layers * linear_slot_size; + CHECK_GT(total_cache_block_size_in_bytes, 0) + << "invalid cache block size estimate"; + kv_cache_cap.n_blocks = + kv_cache_cap.cache_size_in_bytes / total_cache_block_size_in_bytes; CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; return kv_cache_cap; } @@ -539,7 +553,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { CHECK_GT(kv_cache_cap.n_blocks, 0) << "no memory for kv cache"; const int32_t block_size = options_.block_size(); - bool enable_lighting_indexer = args_.index_n_heads() > 1; + bool enable_lighting_indexer = args_.index_n_heads() > 0; bool enable_gdn_attention = has_linear_attention_layers(args_); // init kv cache for each worker diff --git a/xllm/core/distributed_runtime/llm_master.cpp b/xllm/core/distributed_runtime/llm_master.cpp index 337b8b314..3effdf56d 100644 --- a/xllm/core/distributed_runtime/llm_master.cpp +++ b/xllm/core/distributed_runtime/llm_master.cpp @@ -67,7 +67,8 @@ LLMMaster::LLMMaster(const Options& options) xservice_client_ = XServiceClient::get_instance(); if (!xservice_client_->init(options_.etcd_addr().value_or(""), options_.instance_name().value_or(""), - engine_->block_manager_pool())) { + engine_->block_manager_pool(), + options_.etcd_namespace().value_or(""))) { LOG(FATAL) << "XServiceClient init fail!"; return; } diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index 49174a6b8..cf5a50f10 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -128,7 +128,9 @@ void resolve_npu_kernel_backend_for_options(Options* options) { } // namespace Master::Master(const Options& options, EngineType type) - : options_(options), master_status_(options.master_status()) { + : options_(options), + engine_type_(type), + master_status_(options.master_status()) { const auto model_path = std::filesystem::path(options_.model_path()).lexically_normal(); options_.enable_mla(util::should_enable_mla(model_path, options_.backend())); @@ -253,8 +255,8 @@ Master::Master(const Options& options, EngineType type) options_.speculative_suffix_max_cached_requests()) .speculative_suffix_use_tree_spec( options_.speculative_suffix_use_tree_spec()) - .task_type(options.task_type()) - .enable_mla(options.enable_mla()) + .task_type(options_.task_type()) + .enable_mla(options_.enable_mla()) .npu_kernel_backend(options_.npu_kernel_backend()) .master_node_addr(options.master_node_addr()) .nnodes(options.nnodes()) @@ -382,6 +384,34 @@ Master::Master(const Options& options, EngineType type) .rec_worker_max_concurrency(options_.rec_worker_max_concurrency()); engine_ = std::make_unique(eng_options); + } else if (type == EngineType::DIT) { + // construct dit engine + runtime::Options eng_options; + eng_options.model_path(options.model_path()) + .model_id(options.model_id()) + .devices(devices) + .backend(options.backend()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .enable_offline_inference(options_.enable_offline_inference()) + .max_memory_utilization(options_.max_memory_utilization()) + .master_node_addr(options.master_node_addr()) + .nnodes(options.nnodes()) + .task_type(options_.task_type()) + .enable_shm(options_.enable_shm()) + .input_shm_size(options_.input_shm_size() * 1024 * 1024) + .output_shm_size(options_.output_shm_size() * 1024 * 1024) + .is_local(options_.is_local()) + .node_rank(options_.node_rank()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .dp_size(options_.dp_size()) + .ep_size(options_.ep_size()) + .tp_size(options_.tp_size()) + .sp_size(options_.sp_size()) + .cfg_size(options_.cfg_size()); + + auto dit_engine = std::make_unique(eng_options); + engine_ = std::move(dit_engine); } else { LOG(WARNING) << "Not supported llm engine type: " << static_cast(type); diff --git a/xllm/core/distributed_runtime/master.h b/xllm/core/distributed_runtime/master.h index 198d15536..fe05077cb 100644 --- a/xllm/core/distributed_runtime/master.h +++ b/xllm/core/distributed_runtime/master.h @@ -35,6 +35,7 @@ class Master { virtual ~Master() = default; virtual void run() = 0; virtual const Options& options() const { return options_; } + EngineType engine_type() const { return engine_type_; } virtual bool sleep() { return false; } @@ -62,6 +63,7 @@ class Master { protected: Options options_; + EngineType engine_type_ = EngineType::INVALID; std::unique_ptr engine_; RateLimiter rate_limiter_; MasterStatus master_status_{MasterStatus::WAKEUP}; diff --git a/xllm/core/distributed_runtime/rec_engine.cpp b/xllm/core/distributed_runtime/rec_engine.cpp index 8fda21b41..e9f02e989 100644 --- a/xllm/core/distributed_runtime/rec_engine.cpp +++ b/xllm/core/distributed_runtime/rec_engine.cpp @@ -25,7 +25,6 @@ limitations under the License. #include "common/global_flags.h" #include "common/metrics.h" -#include "common/rec_model_utils.h" #include "framework/model/model_args.h" #include "framework/model_loader.h" #include "framework/parallel_state/parallel_state.h" @@ -34,6 +33,7 @@ limitations under the License. #include "util/env_var.h" #include "util/net.h" #include "util/pretty_print.h" +#include "util/rec_model_utils.h" #include "util/timer.h" #include "util/utils.h" @@ -85,6 +85,8 @@ bool RecEngine::init_model() { tokenizer_args_ = model_loader->tokenizer_args(); // Determine rec model kind and create pipeline via factory rec_model_kind_ = get_rec_model_kind(args_.model_type()); + CHECK(rec_model_kind_ != RecModelKind::kNone) + << "Unsupported rec model_type: " << args_.model_type(); auto pipeline_type = get_rec_pipeline_type(rec_model_kind_); pipeline_ = create_pipeline(pipeline_type, *this); // LlmRec-specific initialization @@ -524,23 +526,36 @@ void RecEngine::OneRecEnginePipeline::process_group_test() { bool RecEngine::OneRecEnginePipeline::init_model_workers( const std::string& model_path) { const auto& devices = engine_.options_.devices(); - if (devices.size() > 1) { + const int32_t world_size = static_cast(devices.size()); + + // OneRec local workers still expect valid TP group metadata even on a + // single device. For world_size == 1, only rank/world_size metadata is + // needed, so avoid creating a real communication backend or extra streams. + // For multi-device NPU keep the HCCL-backed groups; other local backends use + // the generic local process group creation path. + if (world_size == 1) { + engine_.process_groups_.clear(); + engine_.process_groups_.emplace_back( + std::make_unique(/*rank=*/0, world_size, devices[0])); + } #if defined(USE_NPU) + else { engine_.process_groups_ = parallel_state::create_npu_process_groups(devices); + } #else + else { engine_.process_groups_ = parallel_state::create_local_process_groups(devices, engine_.options_); -#endif } +#endif engine_.workers_.clear(); WorkerType worker_type = WorkerType::REC; - const int32_t world_size = static_cast(devices.size()); for (int32_t rank = 0; rank < world_size; ++rank) { - ProcessGroup* pg = - world_size > 1 ? engine_.process_groups_[rank].get() : nullptr; + ProcessGroup* pg = engine_.process_groups_[rank].get(); ParallelArgs parallel_args(rank, world_size, pg); + parallel_args.tp_group_ = pg; engine_.workers_.emplace_back(std::make_unique( parallel_args, devices[rank], engine_.options_, worker_type)); } @@ -690,7 +705,36 @@ ForwardOutput RecEngine::OneRecEnginePipeline::get_model_output( auto forward_output = results.front().value(); CHECK(forward_output.has_value()) << "Failed to execute model"; - return forward_output.value(); + + auto& output = forward_output.value(); + auto& sample_output = output.sample_output; + + if (sample_output.embeddings.defined()) { + sample_output.embeddings = safe_to( + sample_output.embeddings, + torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32), + /*non_blocking=*/true); + } + + if (sample_output.next_tokens.defined()) { + sample_output.next_tokens = + safe_to(sample_output.next_tokens, torch::kCPU, /*non_blocking=*/true); + if (sample_output.logprobs.defined()) { + sample_output.logprobs = + safe_to(sample_output.logprobs, torch::kCPU, true); + } + if (sample_output.top_tokens.defined()) { + sample_output.top_tokens = + safe_to(sample_output.top_tokens, torch::kCPU, true); + } + if (sample_output.top_logprobs.defined()) { + sample_output.top_logprobs = + safe_to(sample_output.top_logprobs, torch::kCPU, true); + } + } + Device(engine_.workers_[0]->device()).synchronize_default_stream(); + + return output; } std::vector diff --git a/xllm/core/distributed_runtime/rec_engine.h b/xllm/core/distributed_runtime/rec_engine.h index b2de4ac08..100efb200 100644 --- a/xllm/core/distributed_runtime/rec_engine.h +++ b/xllm/core/distributed_runtime/rec_engine.h @@ -20,7 +20,6 @@ limitations under the License. #include #include "common/macros.h" -#include "common/rec_model_utils.h" #include "distributed_runtime/dist_manager.h" #include "engine.h" #include "framework/batch/batch.h" @@ -29,6 +28,7 @@ limitations under the License. #include "framework/tokenizer/tokenizer.h" #include "framework/tokenizer/tokenizer_args.h" #include "runtime/worker.h" +#include "util/rec_model_utils.h" #include "util/threadpool.h" namespace xllm { diff --git a/xllm/core/distributed_runtime/rec_master.cpp b/xllm/core/distributed_runtime/rec_master.cpp index fb34cd702..fafb3ff16 100644 --- a/xllm/core/distributed_runtime/rec_master.cpp +++ b/xllm/core/distributed_runtime/rec_master.cpp @@ -22,19 +22,20 @@ limitations under the License. #include #include +#include #include #include #include #include "common/macros.h" #include "common/metrics.h" -#include "common/rec_model_utils.h" #include "common/types.h" #include "framework/request/mm_data.h" #include "models/model_registry.h" #include "rec_engine.h" #include "runtime/xservice_client.h" #include "scheduler/scheduler_factory.h" +#include "util/rec_model_utils.h" #include "util/scope_guard.h" #include "util/threadpool.h" #include "util/utils.h" @@ -157,14 +158,21 @@ bool process_onerec_inputs( return false; } + if (len > std::numeric_limits::max() / hidden) { + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "OneRec input tensor '" + tensor_name + "' shape is too large"); + return false; + } + + const int64_t expected_numel = len * hidden; const int64_t actual_numel = static_cast(tensor.contents().fp32_contents_size()); - if (actual_numel % hidden != 0 || actual_numel / hidden != len) { + if (expected_numel != actual_numel) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "OneRec input tensor '" + tensor_name + "' fp32 contents size mismatch, expected " + - std::to_string(len) + " * " + - std::to_string(hidden) + ", got " + + std::to_string(expected_numel) + ", got " + std::to_string(actual_numel)); return false; } @@ -488,7 +496,8 @@ RecMaster::RecMaster(const Options& options) XServiceClient* xservice_client = XServiceClient::get_instance(); if (!xservice_client->init(options_.etcd_addr().value_or(""), options_.instance_name().value_or(""), - engine_->block_manager_pool())) { + engine_->block_manager_pool(), + options_.etcd_namespace().value_or(""))) { LOG(FATAL) << "XServiceClient init fail!"; return; } @@ -506,7 +515,8 @@ RecMaster::RecMaster(const Options& options) .enable_chunked_prefill(options_.enable_chunked_prefill()) .instance_role(options_.instance_role()) .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) - .enable_service_routing(options_.enable_service_routing()); + .enable_service_routing(options_.enable_service_routing()) + .rec_worker_max_concurrency(options_.rec_worker_max_concurrency()); scheduler_ = create_fixed_steps_scheduler(engine_.get(), scheduler_options); chat_template_ = nullptr; @@ -523,6 +533,8 @@ RecMaster::RecMaster(const Options& options) // Create pipelines based on rec_type auto rec_model_kind = get_rec_model_kind(model_args_.model_type()); + CHECK(rec_model_kind != RecModelKind::kNone) + << "Unsupported rec model_type: " << model_args_.model_type(); auto pipeline_type = get_rec_pipeline_type(rec_model_kind); pipeline_ = create_pipeline(pipeline_type, *this); diff --git a/xllm/core/distributed_runtime/vlm_master.cpp b/xllm/core/distributed_runtime/vlm_master.cpp index eb54b7f93..ccc615d30 100755 --- a/xllm/core/distributed_runtime/vlm_master.cpp +++ b/xllm/core/distributed_runtime/vlm_master.cpp @@ -53,7 +53,8 @@ VLMMaster::VLMMaster(const Options& options) XServiceClient* xservice_client = XServiceClient::get_instance(); if (!xservice_client->init(options_.etcd_addr().value_or(""), options_.instance_name().value_or(""), - engine_->block_manager_pool())) { + engine_->block_manager_pool(), + options_.etcd_namespace().value_or(""))) { LOG(FATAL) << "XServiceClient init fail!"; return; } diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp index 943b25d32..9e781ec36 100644 --- a/xllm/core/distributed_runtime/worker_server.cpp +++ b/xllm/core/distributed_runtime/worker_server.cpp @@ -41,6 +41,7 @@ limitations under the License. #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" #include "framework/parallel_state/collective_communicator.h" +#include "framework/parallel_state/dit_collective_communicator.h" #include "framework/parallel_state/mapping_npu.h" #include "framework/state_dict/state_dict.h" #include "runtime/forward_params.h" @@ -137,15 +138,32 @@ void WorkerServer::create_server( return; } - CollectiveCommunicator comm( - worker_global_rank, world_size, dp_size, ep_size, cp_size); - const ParallelArgs* parallel_args = comm.parallel_args(); - comm.create_process_groups(master_node_addr, device); + const ParallelArgs* parallel_args = nullptr; + std::unique_ptr comm; + if (worker_type == WorkerType::DIT) { + auto dit_comm = + std::make_unique(worker_global_rank, + world_size, + options.dp_size(), + options.tp_size(), + options.sp_size(), + options.cfg_size()); + comm = std::move(dit_comm); + } else { + auto common_comm = std::make_unique( + worker_global_rank, world_size, dp_size, ep_size, cp_size); + comm = std::move(common_comm); + } + + comm->create_process_groups(master_node_addr, device); + parallel_args = comm->parallel_args(); std::unique_ptr worker = std::make_unique(*parallel_args, device, options, worker_type); worker_service->set_worker(std::move(worker)); - if (options.enable_shm() && input_shm_manager && output_shm_manager) { + bool create_shm = + options.enable_shm() && input_shm_manager && output_shm_manager; + if (create_shm) { worker_service->create_polling_shm_thread(std::move(input_shm_manager), std::move(output_shm_manager)); } @@ -280,7 +298,8 @@ WorkerServer::WorkerServer(int local_worker_idx, if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM || worker_type == WorkerType::VLM || worker_type == WorkerType::EVLM || - worker_type == WorkerType::REC || worker_type == WorkerType::MMEVLM) { + worker_type == WorkerType::REC || worker_type == WorkerType::MMEVLM || + worker_type == WorkerType::DIT) { if (use_spawn_worker) { // start worker in a spawn process(for offline inference worker.) create_spawn_server(local_worker_idx, diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index 9bd8ad84a..8e3f08e85 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -73,14 +73,16 @@ void WorkerService::step(ForwardInput& fwd_input, torch::Tensor& top_logprobs, torch::Tensor& embeddings, std::vector& mm_embeddings, + std::vector& dit_images, torch::Tensor& expert_load_data, int32_t& prepared_layer_id, torch::Tensor& src_seq_idxes, torch::Tensor& out_tokens, torch::Tensor& out_logprobs) { + const bool use_default_stream = + !options_.enable_schedule_overlap() && options_.backend() == "llm"; // execute model auto future = worker_->step_async(fwd_input); - if (!options_.enable_schedule_overlap()) { auto forward_outputs = std::move(future).get(); // convert ForwardOutput to proto::ForwardOutput which contain Tokens. @@ -89,57 +91,78 @@ void WorkerService::step(ForwardInput& fwd_input, const auto& sample_output = forward_outputs.value().sample_output; const auto& beam_search_output = forward_outputs.value().beam_search_output; + const auto& dit_forward_output = + forward_outputs.value().dit_forward_output; expert_load_data = safe_to(forward_outputs.value().expert_load_data, torch::kCPU, true); prepared_layer_id = forward_outputs.value().prepared_layer_id; { - c10::StreamGuard streamGuard = stream_->set_stream_guard(); - // only driver worker (rank=0) need to fill this - // [num_seq, ..., embed_dim] FloatTensor - embeddings = safe_to(sample_output.embeddings, - torch::dtype(torch::kFloat32).device(torch::kCPU), - true); - - mm_embeddings.clear(); - mm_embeddings.reserve(sample_output.mm_embeddings.size()); - for (auto mm_embedding : sample_output.mm_embeddings) { - mm_embeddings.emplace_back(safe_to(mm_embedding, torch::kCPU, true)); - } + auto copy_output_to_host = [&]() { + // only driver worker (rank=0) need to fill this + // [num_seq, ..., embed_dim] FloatTensor + embeddings = + safe_to(sample_output.embeddings, + torch::dtype(torch::kFloat32).device(torch::kCPU), + true); - // [num_seq] - next_tokens = safe_to(sample_output.next_tokens, torch::kCPU, true); - if (next_tokens.defined()) { - // [num_seq] - logprobs = safe_to(sample_output.logprobs, torch::kCPU, true); - - if (!beam_search_output.src_seq_idxes.defined()) { - // beam search kernel will provide final tokens/logprobs in beam - // search output, so keep top_tokens/top_logprobs undefined to - // avoid returning them. - // [num_seq, topk] - top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true); - // [num_seq, topk] - top_logprobs = - safe_to(sample_output.top_logprobs, torch::kCPU, true); + mm_embeddings.clear(); + mm_embeddings.reserve(sample_output.mm_embeddings.size()); + for (auto mm_embedding : sample_output.mm_embeddings) { + mm_embeddings.emplace_back( + safe_to(mm_embedding, torch::kCPU, true)); + } + + dit_images.clear(); + dit_images.reserve(dit_forward_output.tensors.size()); + for (auto dit_image : dit_forward_output.tensors) { + dit_images.emplace_back(safe_to(dit_image, torch::kCPU, true)); } - } - // beam search output - // [num_seq] - src_seq_idxes = - safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true); - if (src_seq_idxes.defined()) { // [num_seq] - out_tokens = - safe_to(beam_search_output.out_tokens, torch::kCPU, true); + next_tokens = safe_to(sample_output.next_tokens, torch::kCPU, true); + if (next_tokens.defined()) { + // [num_seq] + logprobs = safe_to(sample_output.logprobs, torch::kCPU, true); + + if (!beam_search_output.src_seq_idxes.defined()) { + // beam search kernel will provide final tokens/logprobs in beam + // search output, so keep top_tokens/top_logprobs undefined to + // avoid returning them. + // [num_seq, topk] + top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true); + // [num_seq, topk] + top_logprobs = + safe_to(sample_output.top_logprobs, torch::kCPU, true); + } + } + + // beam search output // [num_seq] - out_logprobs = - safe_to(beam_search_output.out_logprobs, - torch::dtype(torch::kFloat32).device(torch::kCPU), - true); + src_seq_idxes = + safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true); + if (src_seq_idxes.defined()) { + // [num_seq] + out_tokens = + safe_to(beam_search_output.out_tokens, torch::kCPU, true); + // [num_seq] + out_logprobs = + safe_to(beam_search_output.out_logprobs, + torch::dtype(torch::kFloat32).device(torch::kCPU), + true); + } + }; + if (use_default_stream) { + copy_output_to_host(); + } else { + c10::StreamGuard stream_guard = stream_->set_stream_guard(); + copy_output_to_host(); + } + if (use_default_stream) { + device_.synchronize_default_stream(); + } else { + stream_->synchronize(); } - auto ret = stream_->synchronize(); } } } else { @@ -175,6 +198,7 @@ void WorkerService::create_polling_shm_thread( torch::Tensor top_logprobs; torch::Tensor embeddings; std::vector mm_embeddings; + std::vector dit_images; torch::Tensor expert_load_data; int32_t prepared_layer_id = -1; @@ -190,6 +214,7 @@ void WorkerService::create_polling_shm_thread( top_logprobs, embeddings, mm_embeddings, + dit_images, expert_load_data, prepared_layer_id, src_seq_idxes, @@ -202,6 +227,7 @@ void WorkerService::create_polling_shm_thread( top_logprobs, embeddings, mm_embeddings, + dit_images, expert_load_data, prepared_layer_id, src_seq_idxes, @@ -603,7 +629,7 @@ void WorkerService::Wakeup(::google::protobuf::RpcController* controller, std::vector segments; segments.reserve(seg_list.segments_size()); for (const auto& proto_seg : seg_list.segments()) { - segments.push_back({proto_seg.offset(), proto_seg.size()}); + segments.emplace_back(proto_seg.offset(), proto_seg.size()); } options.src_weight_segments.push_back(std::move(segments)); } @@ -635,6 +661,7 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, torch::Tensor top_logprobs; torch::Tensor embeddings; std::vector mm_embeddings; + std::vector dit_images; torch::Tensor expert_load_data; int32_t prepared_layer_id = -1; // beam search kernel output @@ -649,6 +676,7 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, top_logprobs, embeddings, mm_embeddings, + dit_images, expert_load_data, prepared_layer_id, src_seq_idxes, @@ -665,6 +693,7 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller, src_seq_idxes, out_tokens, out_logprobs, + dit_images, pb_forward_output); COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds()); }); @@ -678,6 +707,8 @@ void WorkerService::GetLastStepResult( threadpool_->schedule( [this, controller, req, pb_forward_output, done]() mutable { brpc::ClosureGuard done_guard(done); + const bool use_default_stream = + !options_.enable_schedule_overlap() && options_.backend() == "llm"; auto future = worker_->get_last_step_result_async(); auto forward_outputs = std::move(future).get(); @@ -688,38 +719,64 @@ void WorkerService::GetLastStepResult( int32_t prepared_layer_id = forward_outputs.value().prepared_layer_id; const auto& beam_search_output = forward_outputs.value().beam_search_output; - c10::StreamGuard streamGuard = stream_->set_stream_guard(); - // [num_seq, ..., embed_dim] - auto embeddings = - safe_to(sample_output.embeddings, torch::kCPU, true); - embeddings = safe_to(embeddings, torch::kFloat32, true); + torch::Tensor embeddings; + torch::Tensor next_tokens; + torch::Tensor logprobs; + torch::Tensor top_tokens; + torch::Tensor top_logprobs; + torch::Tensor src_seq_idxes; + torch::Tensor out_tokens; + torch::Tensor out_logprobs; + std::vector dit_images; + auto copy_output_to_host = [&]() { + // [num_seq, ..., embed_dim] + embeddings = safe_to(sample_output.embeddings, torch::kCPU, true); + embeddings = safe_to(embeddings, torch::kFloat32, true); + + dit_images.reserve( + forward_outputs.value().dit_forward_output.tensors.size()); + for (auto image : + forward_outputs.value().dit_forward_output.tensors) { + dit_images.emplace_back(image); + } - // [num_seq] - const auto& next_tokens = - safe_to(sample_output.next_tokens, torch::kCPU, true); - if (next_tokens.defined() || FLAGS_enable_eplb) { - // [num_seq] FloatTensor - const auto& logprobs = - safe_to(sample_output.logprobs, torch::kCPU, true); - // [num_seq, topk] - const auto& top_tokens = - safe_to(sample_output.top_tokens, torch::kCPU, true); - // [num_seq, topk] - const auto& top_logprobs = - safe_to(sample_output.top_logprobs, torch::kCPU, true); - // [num_seq] - const auto& src_seq_idxes = - safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true); // [num_seq] - const auto& out_tokens = - safe_to(beam_search_output.out_tokens, torch::kCPU, true); - // [num_seq] - const auto& out_logprobs = - safe_to(beam_search_output.out_logprobs, - torch::dtype(torch::kFloat32).device(torch::kCPU), - true); - auto ret = stream_->synchronize(); + next_tokens = safe_to(sample_output.next_tokens, torch::kCPU, true); + if (next_tokens.defined() || FLAGS_enable_eplb) { + // [num_seq] FloatTensor + logprobs = safe_to(sample_output.logprobs, torch::kCPU, true); + // [num_seq, topk] + top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true); + // [num_seq, topk] + top_logprobs = + safe_to(sample_output.top_logprobs, torch::kCPU, true); + // [num_seq] + src_seq_idxes = + safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true); + // [num_seq] + out_tokens = + safe_to(beam_search_output.out_tokens, torch::kCPU, true); + // [num_seq] + out_logprobs = + safe_to(beam_search_output.out_logprobs, + torch::dtype(torch::kFloat32).device(torch::kCPU), + true); + } + }; + + if (use_default_stream) { + copy_output_to_host(); + } else { + c10::StreamGuard stream_guard = stream_->set_stream_guard(); + copy_output_to_host(); + } + if (use_default_stream) { + device_.synchronize_default_stream(); + } else { + stream_->synchronize(); + } + if (next_tokens.defined() || FLAGS_enable_eplb) { forward_output_to_proto(next_tokens, logprobs, top_tokens, @@ -730,6 +787,7 @@ void WorkerService::GetLastStepResult( src_seq_idxes, out_tokens, out_logprobs, + dit_images, pb_forward_output); } } diff --git a/xllm/core/distributed_runtime/worker_service.h b/xllm/core/distributed_runtime/worker_service.h index 2cc4ee2ea..be0b97756 100644 --- a/xllm/core/distributed_runtime/worker_service.h +++ b/xllm/core/distributed_runtime/worker_service.h @@ -148,6 +148,7 @@ class WorkerService : public proto::DistributeWorker { torch::Tensor& top_logprobs, torch::Tensor& embeddings, std::vector& mm_embeddings, + std::vector& dit_images, torch::Tensor& expert_load_data, int32_t& prepared_layer_id, torch::Tensor& src_seq_idxes, diff --git a/xllm/core/framework/CMakeLists.txt b/xllm/core/framework/CMakeLists.txt index c35af345e..8da405c6e 100644 --- a/xllm/core/framework/CMakeLists.txt +++ b/xllm/core/framework/CMakeLists.txt @@ -12,6 +12,7 @@ add_subdirectory(xtensor) add_subdirectory(block) add_subdirectory(chat_template) add_subdirectory(kv_cache) +add_subdirectory(kv_cache_transfer) add_subdirectory(model) add_subdirectory(parallel_state) add_subdirectory(prefix_cache) diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index c7895b228..bda79bc0d 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -25,7 +25,7 @@ limitations under the License. #include "batch_input_builder.h" #include "common/global_flags.h" #include "common/metrics.h" -#include "core/common/rec_model_utils.h" +#include "core/util/rec_model_utils.h" #include "framework/batch/mposition.h" #include "framework/model/model_args.h" #include "framework/model/model_input_params.h" @@ -159,11 +159,25 @@ ForwardInput Batch::prepare_rec_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size, const ModelArgs& args, ThreadPool* thread_pool) { - output_targets_.clear(); RecType rec_type = RecType::kNone; if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { rec_type = sequence_groups_[0]->sequences()[0]->rec_type(); } + output_targets_.clear(); + if (rec_type == RecType::kOneRec) { + if (!sequence_groups_.empty()) { + // OneRec REC batches are tracked via sequence_groups_, while output + // target generation still walks sequences_. Refresh the flattened + // sequence view on every step so token writeback stays aligned after + // beam search expands or replaces the group-owned Sequence instances. + refresh_sequences_from_groups(); + } + if (FLAGS_enable_rec_prefill_only) { + refresh_onerec_prefill_output_targets(); + } else { + refresh_output_targets(); + } + } auto builder = RecBatchInputBuilder::create(rec_type, sequence_groups_, @@ -404,6 +418,59 @@ void Batch::refresh_output_targets() { } } +void Batch::refresh_onerec_prefill_output_targets() { + output_targets_.clear(); + if (sequences_.empty()) { + return; + } + + for (size_t seq_index = 0; seq_index < sequences_.size(); ++seq_index) { + auto* sequence = sequences_[seq_index]; + if (sequence == nullptr) { + continue; + } + + const auto token_ids = sequence->tokens(); + const uint32_t n_tokens = token_ids.size(); + const uint32_t n_kv_cache_tokens = + sequence->kv_state().kv_cache_tokens_num(); + const bool needs_context_target = sequence->is_onerec_model() && + n_tokens == 0 && n_kv_cache_tokens == 0 && + sequence->num_decoder_embeddings() > 0; + if (needs_context_target) { + output_targets_.push_back({sequence, /*sample_id=*/0, false}); + continue; + } + if (n_tokens <= n_kv_cache_tokens) { + continue; + } + + CHECK(allowed_max_tokens_[seq_index] > 0); + const uint32_t q_seq_len = + std::min(n_tokens - n_kv_cache_tokens, allowed_max_tokens_[seq_index]); + const uint32_t seq_len = q_seq_len + n_kv_cache_tokens; + const auto& sample_slots = sequence->sample_slots(); + + if (sample_slots.empty()) { + if (seq_len == n_tokens) { + output_targets_.push_back({sequence, /*sample_id=*/0, false}); + } + continue; + } + + for (const auto& sample_slot : sample_slots) { + const uint32_t sample_source_position = + get_sample_source_position(sample_slot); + if (sample_source_position < n_kv_cache_tokens || + sample_source_position >= seq_len) { + continue; + } + output_targets_.push_back( + {sequence, sample_slot.sample_id, /*from_sample_slot=*/true}); + } + } +} + void Batch::process_sample_output(const RawForwardOutput& raw_output, bool replace_fake_token) { if (raw_output.mm_embeddings.size() > 0) { @@ -416,15 +483,22 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output, continue; } std::vector seq_mm_embeddings; - seq_mm_embeddings.reserve(n_images); - for (int i = mm_embedding_idx; i < mm_embedding_idx + n_images; ++i) { + // if we want to return the full embeding of images and prompts, + // the output is a single embedding tensor, else it would be a vector of + // image embeddings + int64_t output_tensor_size = + FLAGS_enable_return_mm_full_embeddings ? 1 : n_images; + seq_mm_embeddings.reserve(output_tensor_size); + for (int64_t i = mm_embedding_idx; + i < mm_embedding_idx + output_tensor_size; + ++i) { CHECK_LT(i, raw_output.mm_embeddings.size()); seq_mm_embeddings.push_back(raw_output.mm_embeddings[i]); } seq->update_mm_embeddings(seq_mm_embeddings); // we only support complete mm embedding in one iteration now CHECK(seq->finished()); - mm_embedding_idx += n_images; + mm_embedding_idx += output_tensor_size; } } diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index b64b38d61..a14167b37 100644 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -148,6 +148,7 @@ class Batch { }; void refresh_output_targets(); + void refresh_onerec_prefill_output_targets(); bool update_sequence_state(Sequence* seq, bool replace_fake_token); void append_token_for_sequence(Sequence* seq, diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 8f50424c7..18e3c3e57 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include +#include #include #include "common/global_flags.h" @@ -695,11 +697,65 @@ void BatchInputBuilder::process_swap_block_infos( [](const BlockTransferInfo& a, const BlockTransferInfo& b) { return a.src_block_id < b.src_block_id; }); +#if defined(USE_CUDA) + raw_forward_input.swap_blocks.insert(raw_forward_input.swap_blocks.end(), + swap_blocks.begin(), + swap_blocks.end()); + + if (swap_blocks.size() > 0) { + std::vector src_indices, dst_indices, cum_sum; + std::unordered_set src_set; + std::unordered_map dst_to_src; + bool has_overlap = false; + int32_t current_src = swap_blocks[0].src_block_id; + src_indices.reserve(swap_blocks.size()); + dst_indices.reserve(swap_blocks.size()); + cum_sum.reserve(swap_blocks.size()); + + for (const auto& block : swap_blocks) { + src_set.insert(block.src_block_id); + } + + src_indices.push_back(swap_blocks[0].src_block_id); + dst_indices.push_back(swap_blocks[0].dst_block_id); + dst_to_src.emplace(swap_blocks[0].dst_block_id, + swap_blocks[0].src_block_id); + if (src_set.count(swap_blocks[0].dst_block_id) > 0 && + swap_blocks[0].dst_block_id != swap_blocks[0].src_block_id) { + has_overlap = true; + } + for (size_t i = 1; i < swap_blocks.size(); i++) { + dst_indices.push_back(swap_blocks[i].dst_block_id); + auto [it, inserted] = dst_to_src.emplace(swap_blocks[i].dst_block_id, + swap_blocks[i].src_block_id); + if (!inserted && it->second != swap_blocks[i].src_block_id) { + has_overlap = true; + } + if (src_set.count(swap_blocks[i].dst_block_id) > 0 && + swap_blocks[i].dst_block_id != swap_blocks[i].src_block_id) { + has_overlap = true; + } + if (swap_blocks[i].src_block_id != current_src) { + src_indices.push_back(swap_blocks[i].src_block_id); + cum_sum.push_back(i); + current_src = swap_blocks[i].src_block_id; + } + } + cum_sum.emplace_back(swap_blocks.size()); + + if (!has_overlap) { + raw_forward_input.src_block_indices = std::move(src_indices); + raw_forward_input.dst_block_indices = std::move(dst_indices); + raw_forward_input.cum_sum = std::move(cum_sum); + } + } +#else if (swap_blocks.size() > 0) { std::vector src_indices, dst_indices, cum_sum; int32_t current_src = swap_blocks[0].src_block_id; src_indices.reserve(swap_blocks.size()); dst_indices.reserve(swap_blocks.size()); + cum_sum.reserve(swap_blocks.size()); src_indices.push_back(swap_blocks[0].src_block_id); dst_indices.push_back(swap_blocks[0].dst_block_id); @@ -718,6 +774,7 @@ void BatchInputBuilder::process_swap_block_infos( raw_forward_input.dst_block_indices = std::move(dst_indices); raw_forward_input.cum_sum = std::move(cum_sum); } +#endif } else { raw_forward_input.swap_blocks.insert(raw_forward_input.swap_blocks.end(), swap_block_transfer_infos_->begin(), diff --git a/xllm/core/framework/batch/beam_search.h b/xllm/core/framework/batch/beam_search.h index 1cd6c213f..603bfc6c4 100644 --- a/xllm/core/framework/batch/beam_search.h +++ b/xllm/core/framework/batch/beam_search.h @@ -15,26 +15,31 @@ limitations under the License. #pragma once +#include +#include +#include + +#include "framework/block/block.h" + namespace xllm { +struct BeamSourceInfo { + size_t suffix_start_idx = 0; + std::vector generated_token_ids; + std::vector> generated_logprobs; + std::vector src_blocks; +}; + // BeamCandidate structure for beam search sorting struct BeamCandidate { - size_t seq_index; - float logprob_sum; - std::vector token_ids; - std::vector> logprobs; + size_t source_index = 0; + float logprob_sum = 0.0f; + bool override_last_token = false; + int32_t last_token_id = 0; + std::optional last_token_logprob; BeamCandidate() = default; - BeamCandidate(size_t seq_idx, - float logprob, - std::vector& token_ids, - std::vector>& logprobs) - : seq_index(seq_idx), - logprob_sum(logprob), - token_ids(std::move(token_ids)), - logprobs(std::move(logprobs)) {} - bool operator<(const BeamCandidate& other) const { return logprob_sum > other.logprob_sum; } @@ -127,4 +132,4 @@ class SimpleTopKOptimizer { using SimpleTopKOptimizerBeamCandidate = SimpleTopKOptimizer; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/batch/dit_batch.cpp b/xllm/core/framework/batch/dit_batch.cpp index 6b192b183..f2ba313f9 100644 --- a/xllm/core/framework/batch/dit_batch.cpp +++ b/xllm/core/framework/batch/dit_batch.cpp @@ -61,10 +61,21 @@ DiTForwardInput DiTBatch::prepare_forward_input() { std::vector negative_pooled_prompt_embeds; std::vector images; + std::vector condition_images; std::vector mask_images; std::vector control_images; std::vector latents; std::vector masked_image_latents; + const auto batch_size = request_vec_.size(); + prompt_embeds.reserve(batch_size); + pooled_prompt_embeds.reserve(batch_size); + negative_prompt_embeds.reserve(batch_size); + negative_pooled_prompt_embeds.reserve(batch_size); + images.reserve(batch_size); + mask_images.reserve(batch_size); + control_images.reserve(batch_size); + latents.reserve(batch_size); + masked_image_latents.reserve(batch_size); for (const auto& request : request_vec_) { const auto& generation_params = request->state().generation_params(); if (input.generation_params != generation_params) { @@ -96,6 +107,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() { images.emplace_back(input_params.image); mask_images.emplace_back(input_params.mask_image); + condition_images.emplace_back(input_params.condition_image); control_images.emplace_back(input_params.control_image); } @@ -119,6 +131,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() { input.images = torch::stack(images); } + if (check_tensors_valid(condition_images)) { + input.condition_images = torch::stack(condition_images); + } + if (check_tensors_valid(mask_images)) { input.mask_images = torch::stack(mask_images); } diff --git a/xllm/core/framework/batch/mposition.cpp b/xllm/core/framework/batch/mposition.cpp old mode 100755 new mode 100644 index ef4f0ee64..369fd5468 --- a/xllm/core/framework/batch/mposition.cpp +++ b/xllm/core/framework/batch/mposition.cpp @@ -23,22 +23,23 @@ limitations under the License. namespace xllm { namespace { -std::vector> groupByTokenType( +std::vector> groupByTokenType( const std::vector& token_types) { - std::vector> groups; + std::vector> groups; if (token_types.empty()) return groups; std::string current_key = token_types[0]; - int start = 0; + int32_t start = 0; - for (int i = 1; i < token_types.size(); ++i) { + for (size_t i = 1; i < token_types.size(); ++i) { if (token_types[i] != current_key) { groups.emplace_back(current_key, start, i); current_key = token_types[i]; start = i; } } - groups.emplace_back(current_key, start, static_cast(token_types.size())); + groups.emplace_back( + current_key, start, static_cast(token_types.size())); return groups; } } // namespace @@ -59,12 +60,15 @@ torch::Tensor MPositionHelper::get_positions() { torch::Tensor second_per_grid_ts; if (auto res = mm_data.get("second_per_grid_ts")) second_per_grid_ts = res.value(); - std::tuple res; - if (!absl::StartsWith(args_.model_type(), "glm4v")) { - res = get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts); - } else { + std::tuple res; + if (absl::StartsWith(args_.model_type(), "glm4v")) { res = get_positions_glm(image_grid_thw, video_grid_thw); + } else if (absl::StartsWith(args_.model_type(), "qwen3_vl")) { + res = get_positions_qwen3(image_grid_thw, video_grid_thw); + } else { + res = get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts); } + seq_.set_mrope_position_delta(std::get<1>(res)); return std::get<0>(res); } else { @@ -72,7 +76,7 @@ torch::Tensor MPositionHelper::get_positions() { } } -std::tuple MPositionHelper::get_positions_glm( +std::tuple MPositionHelper::get_positions_glm( torch::Tensor image_grid_thw, torch::Tensor video_grid_thw) { auto input_tokens = seq_.tokens(); @@ -86,9 +90,9 @@ std::tuple MPositionHelper::get_positions_glm( std::vector input_token_type; bool in_video = false; - int num_tokens = input_tokens.size(); + int32_t num_tokens = input_tokens.size(); - for (int index = 0; index < num_tokens; ++index) { + for (int32_t index = 0; index < num_tokens; ++index) { auto token = input_tokens[index]; if (token == video_start_token_id) { in_video = true; @@ -105,26 +109,26 @@ std::tuple MPositionHelper::get_positions_glm( } } auto input_type_group = groupByTokenType(input_token_type); - int image_index = 0; - int video_index = 0; - int video_group_index = 0; + int32_t image_index = 0; + int32_t video_index = 0; + int32_t video_group_index = 0; std::vector llm_pos_ids_list; - int video_frame_num = 1; + int32_t video_frame_num = 1; for (const auto& group : input_type_group) { const auto& modality_type = std::get<0>(group); - int start_idx = std::get<1>(group); - int end_idx = std::get<2>(group); - int st_idx = 0; + int32_t start_idx = std::get<1>(group); + int32_t end_idx = std::get<2>(group); + int32_t st_idx = 0; if (!llm_pos_ids_list.empty()) { - st_idx = llm_pos_ids_list.back().max().item() + 1; + st_idx = llm_pos_ids_list.back().max().item() + 1; } if (modality_type == "image") { auto grid = image_grid_thw[image_index]; - int t = grid[0].item(); - int h = grid[1].item() / spatial_merge_size; - int w = grid[2].item() / spatial_merge_size; + int32_t t = grid[0].item(); + int32_t h = grid[1].item() / spatial_merge_size; + int32_t w = grid[2].item() / spatial_merge_size; auto t_arange = torch::arange(t, dtype).view({-1, 1}).expand({-1, h * w}).flatten(); @@ -138,11 +142,13 @@ std::tuple MPositionHelper::get_positions_glm( video_frame_num = 1; image_index++; } else if (modality_type == "video") { - int t = video_frame_num; - int h = video_grid_thw[video_index][1].item() / spatial_merge_size; - int w = video_grid_thw[video_index][2].item() / spatial_merge_size; + int32_t t = video_frame_num; + int32_t h = + video_grid_thw[video_index][1].item() / spatial_merge_size; + int32_t w = + video_grid_thw[video_index][2].item() / spatial_merge_size; - for (int t_idx = 0; t_idx < t; ++t_idx) { + for (int32_t t_idx = 0; t_idx < t; ++t_idx) { auto t_tensor = torch::full({1, h * w}, t_idx, dtype).flatten(); auto h_tensor = torch::arange(h, dtype) .view({1, -1, 1}) @@ -158,13 +164,13 @@ std::tuple MPositionHelper::get_positions_glm( } video_group_index++; - if (video_group_index >= video_grid_thw[video_index][0].item()) { + if (video_group_index >= video_grid_thw[video_index][0].item()) { video_index++; video_group_index = 0; } video_frame_num++; } else { // text - int text_len = end_idx - start_idx; + int32_t text_len = end_idx - start_idx; auto arange = torch::arange(text_len, dtype).view({1, -1}).expand({3, -1}) + st_idx; llm_pos_ids_list.push_back(arange); @@ -175,13 +181,13 @@ std::tuple MPositionHelper::get_positions_glm( torch::Tensor llm_positions = torch::cat(llm_pos_ids_list, /*dim=*/1).reshape({3, -1}); llm_positions = llm_positions; - int mrope_position_delta = - (llm_positions.max().item() + 1 - input_tokens.size()); + int32_t mrope_position_delta = + (llm_positions.max().item() + 1 - input_tokens.size()); return std::make_pair(llm_positions, mrope_position_delta); } -std::tuple MPositionHelper::get_positions_p( +std::tuple MPositionHelper::get_positions_p( torch::Tensor image_grid_thw, torch::Tensor video_grid_thw, torch::Tensor second_per_grid_ts) { @@ -192,23 +198,26 @@ std::tuple MPositionHelper::get_positions_p( auto tokens_per_second = args_.mm_tokens_per_second(); auto input_tokens = seq_.tokens(); - auto input_tokens_tensor = torch::tensor(std::vector(input_tokens)); + auto input_tokens_tensor = + torch::tensor(std::vector(input_tokens), torch::kInt32); auto vision_start_indices = torch::argwhere(input_tokens_tensor == vision_start_token_id).squeeze(1); auto vision_tokens = input_tokens_tensor.index({vision_start_indices + 1}); - int image_nums = torch::sum(vision_tokens == image_token_id).item(); - int video_nums = torch::sum(vision_tokens == video_token_id).item(); + int32_t image_nums = + torch::sum(vision_tokens == image_token_id).item(); + int32_t video_nums = + torch::sum(vision_tokens == video_token_id).item(); std::vector llm_pos_ids_list; - int st = 0; - int remain_images = image_nums, remain_videos = video_nums; - int image_index = 0, video_index = 0; + int32_t st = 0; + int32_t remain_images = image_nums, remain_videos = video_nums; + int32_t image_index = 0, video_index = 0; - for (int i = 0; i < image_nums + video_nums; ++i) { + for (int32_t i = 0; i < image_nums + video_nums; ++i) { float video_second_per_grid_t = 1.0f; - int ed_image = input_tokens.size() + 1; - int ed_video = input_tokens.size() + 1; + int32_t ed_image = input_tokens.size() + 1; + int32_t ed_video = input_tokens.size() + 1; if (remain_images > 0) { auto it = std::find( @@ -226,19 +235,19 @@ std::tuple MPositionHelper::get_positions_p( } } - int t = 0, h = 0, w = 0; - int ed = 0; + int32_t t = 0, h = 0, w = 0; + int32_t ed = 0; if (ed_image < ed_video) { - t = image_grid_thw[image_index][0].item(); - h = image_grid_thw[image_index][1].item(); - w = image_grid_thw[image_index][2].item(); + t = image_grid_thw[image_index][0].item(); + h = image_grid_thw[image_index][1].item(); + w = image_grid_thw[image_index][2].item(); image_index++; remain_images--; ed = ed_image; } else { - t = video_grid_thw[video_index][0].item(); - h = video_grid_thw[video_index][1].item(); - w = video_grid_thw[video_index][2].item(); + t = video_grid_thw[video_index][0].item(); + h = video_grid_thw[video_index][1].item(); + w = video_grid_thw[video_index][2].item(); video_second_per_grid_t = second_per_grid_ts[video_index].item(); @@ -247,14 +256,14 @@ std::tuple MPositionHelper::get_positions_p( ed = ed_video; } - int llm_grid_t = t; - int llm_grid_h = h / spatial_merge_size; - int llm_grid_w = w / spatial_merge_size; - int text_len = ed - st; + int32_t llm_grid_t = t; + int32_t llm_grid_h = h / spatial_merge_size; + int32_t llm_grid_w = w / spatial_merge_size; + int32_t text_len = ed - st; - int st_idx = 0; + int32_t st_idx = 0; if (!llm_pos_ids_list.empty()) { - st_idx = llm_pos_ids_list.back().max().item() + 1; + st_idx = llm_pos_ids_list.back().max().item() + 1; } if (text_len > 0) { @@ -288,13 +297,142 @@ std::tuple MPositionHelper::get_positions_p( st = ed + llm_grid_t * llm_grid_h * llm_grid_w; } - if (st < static_cast(input_tokens.size())) { - int st_idx = 0; + if (st < static_cast(input_tokens.size())) { + int32_t st_idx = 0; + if (!llm_pos_ids_list.empty()) { + st_idx = llm_pos_ids_list.back().max().item() + 1; + } + + int32_t text_len = input_tokens.size() - st; + auto text_pos = + torch::arange(text_len, torch::kInt32).view({1, -1}).expand({3, -1}) + + st_idx; + llm_pos_ids_list.push_back(text_pos); + } + + auto llm_positions = torch::cat(llm_pos_ids_list, 1).reshape({3, -1}); + int32_t mrope_position_delta = + (llm_positions.max().item() + 1 - input_tokens.size()); + return std::make_tuple(llm_positions, mrope_position_delta); +} + +std::tuple MPositionHelper::get_positions_qwen3( + torch::Tensor image_grid_thw, + torch::Tensor video_grid_thw) { + auto image_token_id = args_.image_token_id(); + auto video_token_id = args_.video_token_id(); + auto vision_start_token_id = args_.vision_start_token_id(); + auto spatial_merge_size = args_.mm_spatial_merge_size(); + + if (video_grid_thw.defined() && video_grid_thw.numel() > 0) { + auto t_counts = + video_grid_thw.index({torch::indexing::Slice(), 0}).to(torch::kLong); + video_grid_thw = + torch::repeat_interleave(video_grid_thw, t_counts, /*dim=*/0); + video_grid_thw.index_put_({torch::indexing::Slice(), 0}, 1); + } + + auto input_tokens = seq_.tokens(); + auto input_tokens_tensor = + torch::tensor(std::vector(input_tokens), torch::kInt32); + auto vision_start_indices = + torch::argwhere(input_tokens_tensor == vision_start_token_id).squeeze(1); + auto vision_tokens = input_tokens_tensor.index({vision_start_indices + 1}); + + int32_t image_nums = + torch::sum(vision_tokens == image_token_id).item(); + int32_t video_nums = + torch::sum(vision_tokens == video_token_id).item(); + + std::vector llm_pos_ids_list; + int32_t st = 0; + int32_t remain_images = image_nums, remain_videos = video_nums; + int32_t image_index = 0, video_index = 0; + + for (int32_t i = 0; i < image_nums + video_nums; ++i) { + int32_t ed_image = input_tokens.size() + 1; + int32_t ed_video = input_tokens.size() + 1; + + if (remain_images > 0) { + auto it = std::find( + input_tokens.begin() + st, input_tokens.end(), image_token_id); + if (it != input_tokens.end()) { + ed_image = std::distance(input_tokens.begin(), it); + } + } + + if (remain_videos > 0) { + auto it = std::find( + input_tokens.begin() + st, input_tokens.end(), video_token_id); + if (it != input_tokens.end()) { + ed_video = std::distance(input_tokens.begin(), it); + } + } + + int32_t t = 0, h = 0, w = 0; + int32_t ed = 0; + if (ed_image < ed_video) { + t = image_grid_thw[image_index][0].item(); + h = image_grid_thw[image_index][1].item(); + w = image_grid_thw[image_index][2].item(); + image_index++; + remain_images--; + ed = ed_image; + } else { + t = video_grid_thw[video_index][0].item(); + h = video_grid_thw[video_index][1].item(); + w = video_grid_thw[video_index][2].item(); + video_index++; + remain_videos--; + ed = ed_video; + } + + int32_t llm_grid_t = t; + int32_t llm_grid_h = h / spatial_merge_size; + int32_t llm_grid_w = w / spatial_merge_size; + int32_t text_len = ed - st; + + int32_t st_idx = 0; + if (!llm_pos_ids_list.empty()) { + st_idx = llm_pos_ids_list.back().max().item() + 1; + } + + if (text_len > 0) { + auto text_pos = + torch::arange(text_len, torch::kInt32).view({1, -1}).expand({3, -1}) + + st_idx; + llm_pos_ids_list.push_back(text_pos); + } + + auto t_index = torch::arange(llm_grid_t, torch::kInt32) + .view({-1, 1}) + .expand({-1, llm_grid_h * llm_grid_w}) + .flatten(); + + auto h_index = torch::arange(llm_grid_h, torch::kInt32) + .view({1, -1, 1}) + .expand({llm_grid_t, -1, llm_grid_w}) + .flatten(); + + auto w_index = torch::arange(llm_grid_w, torch::kInt32) + .view({1, 1, -1}) + .expand({llm_grid_t, llm_grid_h, -1}) + .flatten(); + + auto visual_pos = + torch::stack({t_index, h_index, w_index}) + text_len + st_idx; + llm_pos_ids_list.push_back(visual_pos); + + st = ed + llm_grid_t * llm_grid_h * llm_grid_w; + } + + if (st < static_cast(input_tokens.size())) { + int32_t st_idx = 0; if (!llm_pos_ids_list.empty()) { - st_idx = llm_pos_ids_list.back().max().item() + 1; + st_idx = llm_pos_ids_list.back().max().item() + 1; } - int text_len = input_tokens.size() - st; + int32_t text_len = input_tokens.size() - st; auto text_pos = torch::arange(text_len, torch::kInt32).view({1, -1}).expand({3, -1}) + st_idx; @@ -302,17 +440,18 @@ std::tuple MPositionHelper::get_positions_p( } auto llm_positions = torch::cat(llm_pos_ids_list, 1).reshape({3, -1}); - int mrope_position_delta = - (llm_positions.max().item() + 1 - input_tokens.size()); + int32_t mrope_position_delta = + (llm_positions.max().item() + 1 - input_tokens.size()); return std::make_tuple(llm_positions, mrope_position_delta); } torch::Tensor MPositionHelper::get_positions_d() { auto mrope_position_delta = seq_.get_mrope_position_delta(); auto num_tokens = seq_.num_tokens(); - return torch::arange(int(mrope_position_delta + num_tokens - 1), - int(mrope_position_delta + num_tokens), - torch::kInt32) + return torch::arange( + static_cast(mrope_position_delta + num_tokens - 1), + static_cast(mrope_position_delta + num_tokens), + torch::kInt32) .expand({3, -1}); } diff --git a/xllm/core/framework/batch/mposition.h b/xllm/core/framework/batch/mposition.h index 466660baa..a89e72ab0 100644 --- a/xllm/core/framework/batch/mposition.h +++ b/xllm/core/framework/batch/mposition.h @@ -33,11 +33,14 @@ class MPositionHelper { torch::Tensor get_positions(); private: - std::tuple get_positions_p( + std::tuple get_positions_p( torch::Tensor image_grid_thw, torch::Tensor video_grid_thw, torch::Tensor second_per_grid_ts); - std::tuple get_positions_glm( + std::tuple get_positions_qwen3( + torch::Tensor image_grid_thw, + torch::Tensor video_grid_thw); + std::tuple get_positions_glm( torch::Tensor image_grid_thw, torch::Tensor video_grid_thw); diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.cpp b/xllm/core/framework/batch/onerec_batch_input_builder.cpp index 941585ef9..a96de9706 100644 --- a/xllm/core/framework/batch/onerec_batch_input_builder.cpp +++ b/xllm/core/framework/batch/onerec_batch_input_builder.cpp @@ -858,6 +858,14 @@ ForwardInput OneRecBatchInputBuilder::build_rec_forward_input( // ========== Common parameter settings ========== // Batch set other parameters input_params.embedding_ids.assign(num_sequences, 0); + input_params.batch_id = batch_id_; + input_params.request_ids.clear(); + input_params.request_ids.reserve(static_cast(num_sequences)); + for (auto* group : sequence_groups_) { + for (const auto& sequence : group->sequences()) { + input_params.request_ids.emplace_back(sequence->request_id()); + } + } // OneRec model parameters onerec_params.rec_stage = OneRecModelInputParams::RecStage::PREFILL; diff --git a/xllm/core/framework/batch/rec_batch_input_builder.cpp b/xllm/core/framework/batch/rec_batch_input_builder.cpp index 31b14b5d7..39cfb08ad 100644 --- a/xllm/core/framework/batch/rec_batch_input_builder.cpp +++ b/xllm/core/framework/batch/rec_batch_input_builder.cpp @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "core/common/rec_model_utils.h" +#include "core/util/rec_model_utils.h" #include "onerec_batch_input_builder.h" #include "rec_multi_round_batch_input_builder.h" diff --git a/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp b/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp index 6907f9db3..06ec0f0e9 100644 --- a/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp +++ b/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp @@ -27,7 +27,7 @@ limitations under the License. #include "common/global_flags.h" #include "common/metrics.h" -#include "core/common/rec_model_utils.h" +#include "core/util/rec_model_utils.h" #include "framework/batch/mposition.h" #include "framework/model/model_args.h" #include "framework/model/model_input_params.h" diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp index 5ad918cbe..bd7a28f92 100644 --- a/xllm/core/framework/block/block_manager_pool.cpp +++ b/xllm/core/framework/block/block_manager_pool.cpp @@ -192,13 +192,14 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) { AUTO_COUNTER(allocate_blocks_latency_seconds); DCHECK(sequence != nullptr); int32_t dp_rank = get_dp_rank(sequence); + const bool started_empty = sequence->kv_state().num_kv_blocks() == 0; const bool needs_embedding_id = !sequence->has_embedding_id(); if (needs_embedding_id && !allocate_embedding_id(sequence, dp_rank)) { return false; } // first try to allocate shared blocks - if (sequence->kv_state().num_kv_blocks() == 0) { + if (started_empty) { BlockManagerPool::allocate_shared(sequence); } @@ -215,6 +216,13 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) { const auto blocks = block_managers_[dp_rank]->allocate(num_additional_blocks); if (blocks.size() != num_additional_blocks) { + if (started_empty) { + block_managers_[dp_rank]->deallocate(sequence->kv_state().kv_blocks()); + if (needs_embedding_id) { + deallocate_embedding_id(sequence, dp_rank); + } + sequence->reset(); + } // LOG(ERROR) << " Fail to allocate " << num_additional_blocks << " // blocks."; return false; diff --git a/xllm/core/framework/dit_cache/CMakeLists.txt b/xllm/core/framework/dit_cache/CMakeLists.txt index dd0649fd3..7509fd99b 100644 --- a/xllm/core/framework/dit_cache/CMakeLists.txt +++ b/xllm/core/framework/dit_cache/CMakeLists.txt @@ -13,6 +13,7 @@ cc_library( fbcache.h fbcache_taylorseer.h taylorseer.h + residual_cache.h SRCS dit_cache_impl.cpp dit_cache.cpp @@ -20,8 +21,9 @@ cc_library( fbcache.cpp fbcache_taylorseer.cpp taylorseer.cpp + residual_cache.cpp DEPS torch glog::glog Folly::folly -) \ No newline at end of file +) diff --git a/xllm/core/framework/dit_cache/dit_cache.cpp b/xllm/core/framework/dit_cache/dit_cache.cpp index 9a94823e7..156161915 100644 --- a/xllm/core/framework/dit_cache/dit_cache.cpp +++ b/xllm/core/framework/dit_cache/dit_cache.cpp @@ -19,26 +19,48 @@ namespace xllm { bool DiTCache::init(const DiTCacheConfig& cfg) { active_cache_ = create_dit_cache(cfg); - if (!active_cache_) { + active_cond_cache_ = create_dit_cache(cfg); + if (!active_cache_ || !active_cond_cache_) { return false; } active_cache_->init(cfg); + active_cond_cache_->init(cfg); return true; } -bool DiTCache::on_before_block(const CacheBlockIn& blockin) { +torch::Tensor DiTCache::get_tensor_or_empty(const TensorMap& m, + const std::string& k) { + auto it = m.find(k); + if (it != m.end()) return it->second; + return torch::Tensor(); +} + +bool DiTCache::on_before_block(const CacheBlockIn& blockin, bool use_cfg) { + if (use_cfg) { + return active_cond_cache_->on_before_block(blockin); + } return active_cache_->on_before_block(blockin); } -CacheBlockOut DiTCache::on_after_block(const CacheBlockIn& blockin) { +CacheBlockOut DiTCache::on_after_block(const CacheBlockIn& blockin, + bool use_cfg) { + if (use_cfg) { + return active_cond_cache_->on_after_block(blockin); + } return active_cache_->on_after_block(blockin); } -bool DiTCache::on_before_step(const CacheStepIn& stepin) { +bool DiTCache::on_before_step(const CacheStepIn& stepin, bool use_cfg) { + if (use_cfg) { + return active_cond_cache_->on_before_step(stepin); + } return active_cache_->on_before_step(stepin); } -CacheStepOut DiTCache::on_after_step(const CacheStepIn& stepin) { +CacheStepOut DiTCache::on_after_step(const CacheStepIn& stepin, bool use_cfg) { + if (use_cfg) { + return active_cond_cache_->on_after_step(stepin); + } return active_cache_->on_after_step(stepin); } diff --git a/xllm/core/framework/dit_cache/dit_cache.h b/xllm/core/framework/dit_cache/dit_cache.h index 868b22eb7..f536f5685 100644 --- a/xllm/core/framework/dit_cache/dit_cache.h +++ b/xllm/core/framework/dit_cache/dit_cache.h @@ -35,14 +35,41 @@ class DiTCache { bool init(const DiTCacheConfig& cfg); - bool on_before_block(const CacheBlockIn& blockin); - CacheBlockOut on_after_block(const CacheBlockIn& blockin); + DiTCache(const DiTCacheConfig& cfg) { + active_cache_ = create_dit_cache(cfg); + active_cond_cache_ = create_dit_cache(cfg); + if (!active_cache_ || !active_cond_cache_) { + LOG(ERROR) << "failed to initialized dit cache, " + "please check your config"; + } + active_cache_->init(cfg); + active_cond_cache_->init(cfg); + } + + bool on_before_block(const CacheBlockIn& blockin, bool use_cfg = false); + + CacheBlockOut on_after_block(const CacheBlockIn& blockin, + bool use_cfg = false); - bool on_before_step(const CacheStepIn& stepin); - CacheStepOut on_after_step(const CacheStepIn& stepin); + bool on_before_step(const CacheStepIn& stepin, bool use_cfg = false); + + CacheStepOut on_after_step(const CacheStepIn& stepin, bool use_cfg = false); + + virtual void set_infer_steps(const int64_t& infer_steps) { + active_cache_->set_infer_steps(infer_steps); + active_cond_cache_->set_infer_steps(infer_steps); + } + + virtual void set_num_blocks(const int64_t& num_blocks) { + active_cache_->set_num_blocks(num_blocks); + active_cond_cache_->set_num_blocks(num_blocks); + } private: + torch::Tensor get_tensor_or_empty(const TensorMap& m, const std::string& k); + std::unique_ptr active_cache_; + std::unique_ptr active_cond_cache_; }; } // namespace xllm diff --git a/xllm/core/framework/dit_cache/dit_cache_config.h b/xllm/core/framework/dit_cache/dit_cache_config.h index 84aff731c..289aab985 100644 --- a/xllm/core/framework/dit_cache/dit_cache_config.h +++ b/xllm/core/framework/dit_cache/dit_cache_config.h @@ -22,6 +22,7 @@ enum class PolicyType { FBCache, TaylorSeer, FBCacheTaylorSeer, + ResidualCache }; struct DiTBaseCacheOptions { @@ -50,6 +51,23 @@ struct FBCacheTaylorSeerOptions : public DiTBaseCacheOptions { int n_derivatives = 3; }; +struct ResidualCacheOptions { + // The number of steps to skip at the start. + int64_t dit_cache_start_steps = 5; + + // The number of steps to skip at the end. + int64_t dit_cache_end_steps = 5; + + // The number of blocks to skip at the start. + int64_t dit_cache_start_blocks = 5; + + // The number of blocks to skip at the end. + int64_t dit_cache_end_blocks = 5; + + // the interval steps to skip for derivative calculation. + int64_t skip_interval_steps = 3; +}; + struct DiTCacheConfig { DiTCacheConfig() = default; @@ -64,6 +82,9 @@ struct DiTCacheConfig { // the configuration for combined FBCache with TaylorSeer policy. FBCacheTaylorSeerOptions fbcachetaylorseer; + + // the configuration for ResidualCache policy. + ResidualCacheOptions residual_cache; }; } // namespace xllm diff --git a/xllm/core/framework/dit_cache/dit_cache_impl.cpp b/xllm/core/framework/dit_cache/dit_cache_impl.cpp index 6adac9d68..1c0fe7b7a 100644 --- a/xllm/core/framework/dit_cache/dit_cache_impl.cpp +++ b/xllm/core/framework/dit_cache/dit_cache_impl.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "dit_non_cache.h" #include "fbcache.h" #include "fbcache_taylorseer.h" +#include "residual_cache.h" #include "taylorseer.h" namespace xllm { @@ -55,6 +56,8 @@ std::unique_ptr create_dit_cache(const DiTCacheConfig& cfg) { return std::make_unique(); case PolicyType::FBCacheTaylorSeer: return std::make_unique(); + case PolicyType::ResidualCache: + return std::make_unique(); default: return std::make_unique(); } diff --git a/xllm/core/framework/dit_cache/dit_cache_impl.h b/xllm/core/framework/dit_cache/dit_cache_impl.h index 9df6e5c31..f6096d303 100644 --- a/xllm/core/framework/dit_cache/dit_cache_impl.h +++ b/xllm/core/framework/dit_cache/dit_cache_impl.h @@ -37,10 +37,20 @@ class DitCacheImpl { virtual bool on_before_step(const CacheStepIn& stepin) = 0; virtual CacheStepOut on_after_step(const CacheStepIn& stepin) = 0; + virtual void set_infer_steps(const int64_t& infer_steps) { + infer_steps_ = infer_steps; + } + + virtual void set_num_blocks(const int64_t& num_blocks) { + num_blocks_ = num_blocks; + } + protected: int64_t num_inference_steps_; int64_t warmup_steps_; int64_t current_step_; + int64_t infer_steps_; + int64_t num_blocks_; TensorMap buffers; static torch::Tensor get_tensor_or_empty(const TensorMap& m, diff --git a/xllm/core/framework/dit_cache/fbcache.cpp b/xllm/core/framework/dit_cache/fbcache.cpp index 5327e0f68..ca80592b1 100644 --- a/xllm/core/framework/dit_cache/fbcache.cpp +++ b/xllm/core/framework/dit_cache/fbcache.cpp @@ -58,8 +58,8 @@ CacheBlockOut FBCache::on_after_block(const CacheBlockIn& blockin) { if (can_use_cache(first_hidden_states_residual)) { use_cache_ = true; - auto [new_hidden, new_encoder] = - apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states); + auto [new_hidden, new_encoder] = apply_prev_hidden_states_residual( + original_hidden_states, encoder_hidden_states); output_hidden_states = std::move(new_hidden); output_encoder_hidden_states = std::move(new_encoder); } else { diff --git a/xllm/core/framework/dit_cache/residual_cache.cpp b/xllm/core/framework/dit_cache/residual_cache.cpp new file mode 100644 index 000000000..2aa5319e4 --- /dev/null +++ b/xllm/core/framework/dit_cache/residual_cache.cpp @@ -0,0 +1,136 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "residual_cache.h" + +namespace xllm { + +void ResidualCache::init(const DiTCacheConfig& cfg) { + CHECK_GT(cfg.residual_cache.skip_interval_steps, 0) + << "skip_interval_steps must be > 0"; + CHECK_GT(cfg.residual_cache.dit_cache_start_steps, 0) + << "dit_cache_start_steps must be > 0"; + CHECK_GT(cfg.residual_cache.dit_cache_end_steps, 0) + << "dit_cache_end_steps must be > 0"; + CHECK_GT(cfg.residual_cache.dit_cache_start_blocks, 0) + << "dit_cache_start_blocks must be > 0"; + CHECK_GT(cfg.residual_cache.dit_cache_end_blocks, 0) + << "dit_cache_end_blocks must be > 0"; + skip_interval_steps_ = cfg.residual_cache.skip_interval_steps; + dit_cache_start_steps_ = cfg.residual_cache.dit_cache_start_steps; + dit_cache_end_steps_ = cfg.residual_cache.dit_cache_end_steps; + dit_cache_start_blocks_ = cfg.residual_cache.dit_cache_start_blocks; + dit_cache_end_blocks_ = cfg.residual_cache.dit_cache_end_blocks; + reset_cache(); +} + +void ResidualCache::reset_cache() { + use_cache_ = false; + update_cache_ = false; +} + +void ResidualCache::mark_step_begin() { ++current_step_; } + +torch::Tensor ResidualCache::get_residual(const torch::Tensor& hidden_states, + const std::string& key) { + return hidden_states - buffers[key]; +} + +torch::Tensor ResidualCache::add_residual(const torch::Tensor& hidden_states, + const std::string& key) { + return hidden_states + buffers[key]; +} + +void ResidualCache::update(const torch::Tensor& residual, + const std::string& key) { + buffers[key] = residual; +} + +bool ResidualCache::cache_validation() { + bool step_valid = + infer_steps_ > dit_cache_start_steps_ + dit_cache_end_steps_ && + infer_steps_ > dit_cache_start_steps_ && + infer_steps_ > dit_cache_end_steps_; + bool block_valid = + num_blocks_ > dit_cache_start_blocks_ + dit_cache_end_blocks_ && + num_blocks_ > dit_cache_start_blocks_ && + num_blocks_ > dit_cache_end_blocks_; + return step_valid & block_valid; +} + +bool ResidualCache::on_before_block(const CacheBlockIn& blockin) { + // when infer_steps is less than skipped_skips, won't use cache + if (!cache_validation() || !use_cache_ || + blockin.block_id < dit_cache_start_blocks_ || + blockin.block_id >= num_blocks_ - dit_cache_end_blocks_ - 1) { + return false; + } + + return true; +} + +CacheBlockOut ResidualCache::on_after_block(const CacheBlockIn& blockin) { + TensorMap out_map; + auto hidden_states = get_tensor_or_empty(blockin.tensors, "hidden_states"); + auto encoder_hidden_states = + get_tensor_or_empty(blockin.tensors, "encoder_hidden_states"); + if (cache_validation()) { + if (use_cache_) { + if (blockin.block_id == num_blocks_ - dit_cache_end_blocks_ - 1) { + out_map["hidden_states"] = add_residual(hidden_states, "hidden_states"); + out_map["encoder_hidden_states"] = + add_residual(encoder_hidden_states, "encoder_hidden_states"); + return CacheBlockOut(out_map); + } + } else if (update_cache_) { + if (blockin.block_id == dit_cache_start_blocks_ - 1) { + // cache + update(hidden_states.clone(), "hidden_states"); + update(encoder_hidden_states.clone(), "encoder_hidden_states"); + } else if (blockin.block_id == num_blocks_ - dit_cache_end_blocks_ - 1) { + // calculate residual and update cache + update(get_residual(hidden_states, "hidden_states"), "hidden_states"); + update(get_residual(encoder_hidden_states, "encoder_hidden_states"), + "encoder_hidden_states"); + } + } + } + out_map["hidden_states"] = hidden_states; + out_map["encoder_hidden_states"] = encoder_hidden_states; + return CacheBlockOut(out_map); +} + +bool ResidualCache::on_before_step(const CacheStepIn& stepin) { + current_step_ = stepin.step_id; + // if outside the target steps, do nothing + if (!cache_validation() || current_step_ < dit_cache_start_steps_ - 1 || + current_step_ >= infer_steps_ - dit_cache_end_steps_) { + reset_cache(); + return false; + } + // if inside target steps, use_cache when inside the interval + // update cache when interval ends + use_cache_ = + ((current_step_ - (dit_cache_start_steps_ - 1)) % skip_interval_steps_ != + 0); + update_cache_ = !use_cache_; + return false; +} + +CacheStepOut ResidualCache::on_after_step(const CacheStepIn& stepin) { + return CacheStepOut(stepin.tensors); +} + +} // namespace xllm diff --git a/xllm/core/framework/dit_cache/residual_cache.h b/xllm/core/framework/dit_cache/residual_cache.h new file mode 100644 index 000000000..27aee8c0c --- /dev/null +++ b/xllm/core/framework/dit_cache/residual_cache.h @@ -0,0 +1,76 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include + +#include "dit_cache_impl.h" + +namespace xllm { + +using TensorMap = std::unordered_map; + +class ResidualCache : public DitCacheImpl { + public: + ResidualCache() = default; + ~ResidualCache() = default; + + ResidualCache(const ResidualCache&) = delete; + ResidualCache& operator=(const ResidualCache&) = delete; + ResidualCache(ResidualCache&&) = default; + ResidualCache& operator=(ResidualCache&&) = default; + + void init(const DiTCacheConfig& cfg) override; + // check whether to use cache + bool cache_validation(); + + // Reset all cached derivatives and internal state + void reset_cache(); + + // Mark the beginning of a new inference step + void mark_step_begin(); + + // calculate residual + torch::Tensor get_residual(const torch::Tensor& hidden_states, + const std::string& key); + + // add residaul to hidden states + torch::Tensor add_residual(const torch::Tensor& hidden_states, + const std::string& key); + + // Update internal caches with the new residual + void update(const torch::Tensor& residual, const std::string& key); + + bool on_before_block(const CacheBlockIn& blockin) override; + CacheBlockOut on_after_block(const CacheBlockIn& blockin) override; + + bool on_before_step(const CacheStepIn& stepin) override; + CacheStepOut on_after_step(const CacheStepIn& stepin) override; + + private: + bool use_cache_ = false; + bool update_cache_ = false; + int64_t skip_interval_steps_; + int64_t dit_cache_start_steps_; + int64_t dit_cache_end_steps_; + int64_t dit_cache_start_blocks_; + int64_t dit_cache_end_blocks_; +}; + +} // namespace xllm diff --git a/xllm/core/framework/dit_model_context.cpp b/xllm/core/framework/dit_model_context.cpp index 18d703aea..d4391a246 100644 --- a/xllm/core/framework/dit_model_context.cpp +++ b/xllm/core/framework/dit_model_context.cpp @@ -33,11 +33,13 @@ DiTModelContext::DiTModelContext( const std::unordered_map& model_args, const std::unordered_map& quant_args, const torch::TensorOptions& tensor_options, + const DiTCacheConfig& dit_config, const std::string& model_type) : parallel_args_(input_parallel_args), model_args_(std::move(model_args)), quant_args_(std::move(quant_args)), tensor_options_(tensor_options), + dit_config_(dit_config), model_type_(model_type) { #if defined(USE_NPU) int32_t device_id = tensor_options.device().index(); @@ -72,14 +74,21 @@ const QuantArgs& DiTModelContext::get_quant_args( } } -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_CUDA) || defined(USE_MLU) ModelContext DiTModelContext::get_model_context( const std::string& component) const { +#if defined(USE_NPU) return ModelContext(parallel_args_, get_model_args(component), get_quant_args(component), tensor_options_, context_); +#else + return ModelContext(parallel_args_, + get_model_args(component), + get_quant_args(component), + tensor_options_); +#endif } #endif diff --git a/xllm/core/framework/dit_model_context.h b/xllm/core/framework/dit_model_context.h index 6f705f51c..238cd0cad 100644 --- a/xllm/core/framework/dit_model_context.h +++ b/xllm/core/framework/dit_model_context.h @@ -21,6 +21,7 @@ limitations under the License. #include +#include "core/framework/dit_cache/dit_cache_config.h" #include "core/framework/model/model_args.h" #include "core/framework/model_context.h" #include "core/framework/quant_args.h" @@ -36,13 +37,14 @@ class DiTModelContext { const std::unordered_map& model_args, const std::unordered_map& quant_args, const torch::TensorOptions& tensor_options, + const DiTCacheConfig& dit_config, const std::string& model_type); const ModelArgs& get_model_args(const std::string& component) const; const QuantArgs& get_quant_args(const std::string& component) const; -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_CUDA) || defined(USE_MLU) ModelContext get_model_context(const std::string& component) const; #endif @@ -54,6 +56,8 @@ class DiTModelContext { const std::string& model_type() const { return model_type_; } + const DiTCacheConfig& get_dit_config() const { return dit_config_; } + #if defined(USE_NPU) const atb::Context* get_atb_context() const { return context_; } #endif @@ -63,6 +67,7 @@ class DiTModelContext { std::unordered_map quant_args_; ParallelArgs parallel_args_; torch::TensorOptions tensor_options_; + DiTCacheConfig dit_config_; std::string model_type_; #if defined(USE_NPU) diff --git a/xllm/core/framework/dit_model_loader.cpp b/xllm/core/framework/dit_model_loader.cpp index 08a9c865d..a4388297b 100644 --- a/xllm/core/framework/dit_model_loader.cpp +++ b/xllm/core/framework/dit_model_loader.cpp @@ -276,6 +276,13 @@ DiTModelLoader::DiTModelLoader(const std::string& model_root_path) LOG(FATAL) << "DiTModelLoader: model_index.json root is not an object!"; } + if (root_json.contains("_class_name")) { + set_model_type(root_json["_class_name"]); + } else { + LOG(WARNING) + << "model_index.json doesn't contains the _class_name key, xllm may " + << "not obtain model type for dit model"; + } // parse model_index.json & initialize model_loader for (const auto& [json_key, json_value] : root_json.items()) { if (!json_value.is_array() || json_value.size() != 2) { diff --git a/xllm/core/framework/dit_model_loader.h b/xllm/core/framework/dit_model_loader.h index 621ab7b28..9fd1e38f4 100644 --- a/xllm/core/framework/dit_model_loader.h +++ b/xllm/core/framework/dit_model_loader.h @@ -71,8 +71,14 @@ class DiTModelLoader { std::string get_torch_dtype() const; + void set_model_type(const std::string& model_type) { + model_type_ = model_type; + } + std::string get_model_type() { return model_type_; } + private: std::string model_root_path_; + std::string model_type_; std::unordered_map> name_to_loader_; diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp index 2e2432cea..a48df61c6 100644 --- a/xllm/core/framework/hf_model_loader.cpp +++ b/xllm/core/framework/hf_model_loader.cpp @@ -34,7 +34,7 @@ limitations under the License. #include #include -#include "core/common/rec_model_utils.h" +#include "core/common/global_flags.h" #include "core/common/version_singleton.h" #include "core/framework/state_dict/rec_vocab_dict.h" #include "core/framework/state_dict/safetensors/safetensors.h" @@ -46,6 +46,7 @@ limitations under the License. #include "core/platform/device.h" #include "core/util/blocking_counter.h" #include "core/util/json_reader.h" +#include "core/util/rec_model_utils.h" #include "core/util/scope_guard.h" #include "core/util/tensor_helper.h" #include "models/model_registry.h" @@ -113,6 +114,11 @@ bool try_load_compressed_tensors_quant_cfg(const JsonReader& reader, if (dynamic_it != input_activations_it->end() && !dynamic_it->is_null()) { quant_args.activation_dynamic() = dynamic_it->get(); } + if (const auto ignore = reader.value>( + "quantization_config.ignore"); + ignore.has_value()) { + quant_args.ignored_modules() = *ignore; + } return true; } @@ -657,7 +663,19 @@ bool HFModelLoader::load_rec_vocab(const std::string& model_weights_path) { ->initialize(vocab_full_path)) << "Failed to initialize vocab dict from " << vocab_full_path; } else { - LOG(ERROR) << "Vocab file is not set"; + if (FLAGS_enable_constrained_decoding) { + LOG(ERROR) << "Vocab file is not set for OneRec REC tokenizer under " + << model_weights_path + << ". Constrained decoding requires `vocab_file` in " + "tokenizer_config.json."; + return false; + } + + LOG(WARNING) << "Vocab file is not set for OneRec REC tokenizer under " + << model_weights_path + << ". Skip vocab dict initialization because constrained " + "decoding is disabled."; + return true; } return true; diff --git a/xllm/core/framework/hf_model_loader_test.cpp b/xllm/core/framework/hf_model_loader_test.cpp index 9777ff02c..908464eb8 100644 --- a/xllm/core/framework/hf_model_loader_test.cpp +++ b/xllm/core/framework/hf_model_loader_test.cpp @@ -42,6 +42,10 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) { } } }, + "ignore": [ + "lm_head", + "model.layers.1.mlp.down_proj" + ], "quant_method": "compressed-tensors" } } @@ -54,6 +58,9 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) { EXPECT_EQ(quant_args.bits(), 8); EXPECT_EQ(quant_args.moe_weight_bits(), 8); EXPECT_FALSE(quant_args.activation_dynamic()); + ASSERT_EQ(quant_args.ignored_modules().size(), 2); + EXPECT_EQ(quant_args.ignored_modules()[0], "lm_head"); + EXPECT_EQ(quant_args.ignored_modules()[1], "model.layers.1.mlp.down_proj"); } } diff --git a/xllm/core/framework/kv_cache/CMakeLists.txt b/xllm/core/framework/kv_cache/CMakeLists.txt index 4b969064e..8343c3e72 100644 --- a/xllm/core/framework/kv_cache/CMakeLists.txt +++ b/xllm/core/framework/kv_cache/CMakeLists.txt @@ -10,35 +10,14 @@ cc_library( embedding_cache.h kv_cache.h kv_cache_event.h - kv_cache_transfer.h - $<$:llm_data_dist_transfer.h> - $<$:spec_kv_cache_transfer.h> - kv_cache_store.h - hierarchy_kv_cache_transfer.h - $<$:mooncake_transfer_engine.h> - $<$:mooncake_kv_cache_transfer.h> - $<$:mooncake_weight_transfer.h> SRCS embedding_cache.cpp kv_cache.cpp - kv_cache_transfer.cpp - $<$:llm_data_dist_transfer.cpp> - $<$:spec_kv_cache_transfer.cpp> - kv_cache_store.cpp - hierarchy_kv_cache_transfer.cpp - $<$:mooncake_transfer_engine.cpp> - $<$:mooncake_kv_cache_transfer.cpp> - $<$:mooncake_weight_transfer.cpp> DEPS :common - $<$:graph> glog::glog - $<$:llm_datadist> - $<$:c_sec> torch $<$:torch_npu> - mooncake_store - $<$:platform_npu> ) cc_test( @@ -58,4 +37,4 @@ target_link_libraries(embedding_cache_test $<$:ascendcl> $<$:hccl> $<$:c_sec> - $<$:nnopbase>) \ No newline at end of file + $<$:nnopbase>) diff --git a/xllm/core/framework/kv_cache_transfer/CMakeLists.txt b/xllm/core/framework/kv_cache_transfer/CMakeLists.txt new file mode 100644 index 000000000..3075efaaf --- /dev/null +++ b/xllm/core/framework/kv_cache_transfer/CMakeLists.txt @@ -0,0 +1,51 @@ +include(cc_library) +include(cc_test) + + +cc_library( + NAME + kv_cache_transfer + HDRS + kv_cache_transfer.h + $<$:llm_data_dist_transfer.h> + $<$:spec_kv_cache_transfer.h> + kv_cache_store.h + hierarchy_kv_cache_transfer.h + $<$,$>:mooncake_transfer_engine.h> + $<$,$>:mooncake_kv_cache_transfer.h> + $<$:mooncake_weight_transfer.h> + SRCS + kv_cache_transfer.cpp + $<$:llm_data_dist_transfer.cpp> + $<$:spec_kv_cache_transfer.cpp> + kv_cache_store.cpp + hierarchy_kv_cache_transfer.cpp + $<$,$>:mooncake_transfer_engine.cpp> + $<$,$>:mooncake_kv_cache_transfer.cpp> + $<$:mooncake_weight_transfer.cpp> + DEPS + :common + :kv_cache + :xtensor + $<$:graph> + glog::glog + $<$:llm_datadist> + $<$:c_sec> + torch + $<$:torch_npu> + mooncake_store + proto::xllm_proto + $<$:platform_npu> +) + +if(USE_NPU OR USE_MLU) +cc_test( + NAME + mooncake_transfer_engine_test + SRCS + mooncake_transfer_engine_test.cpp + DEPS + :kv_cache_transfer + GTest::gtest_main +) +endif() diff --git a/xllm/core/framework/kv_cache/hierarchy_kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp similarity index 99% rename from xllm/core/framework/kv_cache/hierarchy_kv_cache_transfer.cpp rename to xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp index e8ea32271..538b2e6f5 100644 --- a/xllm/core/framework/kv_cache/hierarchy_kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "hierarchy_kv_cache_transfer.h" +#include "framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h" #include #include @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "kv_cache_store.h" +#include "framework/kv_cache_transfer/kv_cache_store.h" namespace xllm { constexpr uint64_t MBUF_SIZE = 128 * 1024 * 1024; @@ -50,7 +50,7 @@ HierarchyKVCacheTransfer::HierarchyKVCacheTransfer( } if (options_.enable_kvcache_store()) { - StoreConfig config; + KVCacheStoreConfig config; config.localhost_name = options_.store_local_hostname(); config.protocol = options_.store_protocol(); config.metadata_server = options_.store_metadata_server(); diff --git a/xllm/core/framework/kv_cache/hierarchy_kv_cache_transfer.h b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h similarity index 98% rename from xllm/core/framework/kv_cache/hierarchy_kv_cache_transfer.h rename to xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h index d7741b847..32ffb7608 100644 --- a/xllm/core/framework/kv_cache/hierarchy_kv_cache_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "common/types.h" +#include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" -#include "kv_cache.h" #include "platform/device.h" #include "util/blockingconcurrentqueue.h" #include "util/threadpool.h" diff --git a/xllm/core/framework/kv_cache/kv_cache_store.cpp b/xllm/core/framework/kv_cache_transfer/kv_cache_store.cpp similarity index 98% rename from xllm/core/framework/kv_cache/kv_cache_store.cpp rename to xllm/core/framework/kv_cache_transfer/kv_cache_store.cpp index 94b85af07..1e600b3d8 100644 --- a/xllm/core/framework/kv_cache/kv_cache_store.cpp +++ b/xllm/core/framework/kv_cache_transfer/kv_cache_store.cpp @@ -1,5 +1,5 @@ -#include "kv_cache_store.h" +#include "framework/kv_cache_transfer/kv_cache_store.h" #include #include @@ -11,7 +11,7 @@ namespace xllm { -bool KVCacheStore::init(const StoreConfig& config, +bool KVCacheStore::init(const KVCacheStoreConfig& config, std::vector* host_kv_caches) { CHECK(!is_initialized_) << "KVCacheStore is initialized."; config_ = config; diff --git a/xllm/core/framework/kv_cache/kv_cache_store.h b/xllm/core/framework/kv_cache_transfer/kv_cache_store.h similarity index 92% rename from xllm/core/framework/kv_cache/kv_cache_store.h rename to xllm/core/framework/kv_cache_transfer/kv_cache_store.h index cae74ae21..323c48bb9 100644 --- a/xllm/core/framework/kv_cache/kv_cache_store.h +++ b/xllm/core/framework/kv_cache_transfer/kv_cache_store.h @@ -6,13 +6,13 @@ #include #include "common/macros.h" +#include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" -#include "kv_cache.h" #include "util/slice.h" namespace xllm { -struct StoreConfig { +struct KVCacheStoreConfig { std::string localhost_name = "127.0.0.1"; std::string protocol = "tcp"; std::string metadata_server = ""; @@ -27,7 +27,7 @@ class KVCacheStore { public: ~KVCacheStore(); - bool init(const StoreConfig& config, + bool init(const KVCacheStoreConfig& config, std::vector* host_kv_caches); uint32_t batch_put( @@ -66,7 +66,7 @@ class KVCacheStore { private: bool is_initialized_ = false; - StoreConfig config_; + KVCacheStoreConfig config_; mooncake::ReplicateConfig rep_config_; std::vector* host_kv_caches_; diff --git a/xllm/core/framework/kv_cache/kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/kv_cache_transfer.cpp similarity index 94% rename from xllm/core/framework/kv_cache/kv_cache_transfer.cpp rename to xllm/core/framework/kv_cache_transfer/kv_cache_transfer.cpp index 87d011954..3a6c42be0 100644 --- a/xllm/core/framework/kv_cache/kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/kv_cache_transfer.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "kv_cache_transfer.h" +#include "framework/kv_cache_transfer/kv_cache_transfer.h" #include @@ -21,9 +21,14 @@ limitations under the License. #if defined(USE_NPU) #include +#endif -#include "llm_data_dist_transfer.h" -#include "mooncake_kv_cache_transfer.h" +#if defined(USE_NPU) +#include "framework/kv_cache_transfer/llm_data_dist_transfer.h" +#endif + +#if defined(USE_NPU) || defined(USE_MLU) +#include "framework/kv_cache_transfer/mooncake_kv_cache_transfer.h" #endif namespace xllm { @@ -56,11 +61,11 @@ folly::SemiFuture KVCacheTransfer::pull_kv_blocks_async( return future; } -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) folly::SemiFuture KVCacheTransfer::push_kv_blocks_async( const std::vector& transfer_kv_infos, const ParallelArgs& parallel_args, - std::shared_ptr layer_synchronizer, + std::shared_ptr layer_synchronizer, bool is_spec_draft) { folly::Promise promise; auto future = promise.getSemiFuture(); @@ -245,10 +250,11 @@ std::shared_ptr KVCacheTransferFactory::create( int32_t device_id = device.index(); -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) LOG(INFO) << "Create KVCacheTransfer for " << transfer_type << "flag" << FLAGS_kv_cache_transfer_type; if (transfer_type == "LlmDataDist") { +#if defined(USE_NPU) transfer = std::make_shared(device_ip, transfer_listen_port, instance_role, @@ -259,8 +265,12 @@ std::shared_ptr KVCacheTransferFactory::create( transfer->initialize(device_id); transfer->allocate_kv_cache(kv_caches, num_layers, kv_cache_shape, dtype); +#else + LOG(FATAL) << "LlmDataDist is not supported on MLU backend."; +#endif } else if (transfer_type == "Mooncake") { std::shared_ptr mooncake_transfer; +#if defined(USE_NPU) if (FLAGS_enable_xtensor) { auto xtensor_transfer = std::make_shared( device_id, transfer_listen_port, device); @@ -272,9 +282,13 @@ std::shared_ptr KVCacheTransferFactory::create( } mooncake_transfer = xtensor_transfer; } else { - mooncake_transfer = std::make_shared( + mooncake_transfer = std::make_shared( device_id, transfer_listen_port, device, model_type); } +#else + mooncake_transfer = std::make_shared( + device_id, transfer_listen_port, device, model_type); +#endif mooncake_transfer->initialize(device_id); mooncake_transfer->allocate_kv_cache( diff --git a/xllm/core/framework/kv_cache/kv_cache_transfer.h b/xllm/core/framework/kv_cache_transfer/kv_cache_transfer.h similarity index 91% rename from xllm/core/framework/kv_cache/kv_cache_transfer.h rename to xllm/core/framework/kv_cache_transfer/kv_cache_transfer.h index ecb21b9e7..14e2c3e9c 100644 --- a/xllm/core/framework/kv_cache/kv_cache_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/kv_cache_transfer.h @@ -18,15 +18,25 @@ limitations under the License. #include #include "common/types.h" -#include "kv_cache.h" +#include "framework/kv_cache/kv_cache.h" #if defined(USE_NPU) #include "platform/npu/npu_layer_synchronizer.h" #endif +#if defined(USE_MLU) +#include "platform/mlu/mlu_layer_synchronizer.h" +#endif #include "framework/parallel_state/parallel_args.h" #include "platform/device.h" #include "util/threadpool.h" namespace xllm { + +#if defined(USE_NPU) +using KVPushSynchronizerImpl = NPULayerSynchronizerImpl; +#elif defined(USE_MLU) +using KVPushSynchronizerImpl = MLULayerSynchronizerImpl; +#endif + class KVCacheTransfer { public: struct KVCacheInfo { @@ -101,11 +111,11 @@ class KVCacheTransfer { const std::vector& src_blocks, const std::vector& dst_blocks); -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) virtual folly::SemiFuture push_kv_blocks_async( const std::vector& transfer_kv_infos, const ParallelArgs& parallel_args, - std::shared_ptr layer_synchronizer, + std::shared_ptr layer_synchronizer, bool is_spec_draft); #endif @@ -114,10 +124,10 @@ class KVCacheTransfer { const std::vector& transfer_kv_infos, const ParallelArgs& parallel_args); -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) virtual bool push_kv_blocks( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer, + std::shared_ptr& layer_synchronizer, bool is_spec_draft) = 0; #endif diff --git a/xllm/core/framework/kv_cache/llm_data_dist_transfer.cpp b/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp similarity index 99% rename from xllm/core/framework/kv_cache/llm_data_dist_transfer.cpp rename to xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp index 608862a08..319fd1a3b 100644 --- a/xllm/core/framework/kv_cache/llm_data_dist_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llm_data_dist_transfer.h" +#include "framework/kv_cache_transfer/llm_data_dist_transfer.h" #include diff --git a/xllm/core/framework/kv_cache/llm_data_dist_transfer.h b/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.h similarity index 98% rename from xllm/core/framework/kv_cache/llm_data_dist_transfer.h rename to xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.h index c2d013d12..43a15c6a3 100644 --- a/xllm/core/framework/kv_cache/llm_data_dist_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "kv_cache_transfer.h" +#include "framework/kv_cache_transfer/kv_cache_transfer.h" namespace xllm { diff --git a/xllm/core/framework/kv_cache/mooncake_kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp similarity index 74% rename from xllm/core/framework/kv_cache/mooncake_kv_cache_transfer.cpp rename to xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp index b3f67d0d9..9030092dd 100644 --- a/xllm/core/framework/kv_cache/mooncake_kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mooncake_kv_cache_transfer.h" +#include "framework/kv_cache_transfer/mooncake_kv_cache_transfer.h" #include +#include + #if defined(USE_NPU) #ifdef TORCH_HIGHER_THAN_PTA6 #include @@ -44,6 +46,7 @@ MooncakeKVCacheTransferBase::MooncakeKVCacheTransferBase( const torch::Device& device, std::unique_ptr engine) : device_id_(device_id), + device_(device), listen_port_(listen_port), mooncake_te_(std::move(engine)) { std::string instance_ip = net::get_local_ip_addr(); @@ -90,10 +93,10 @@ bool MooncakeKVCacheTransferBase::unlink_cluster(const uint64_t& cluster_id, } // ============================================================================ -// MooncakeKVCacheTransferNative +// MooncakeKVCacheTransferDefault // ============================================================================ -MooncakeKVCacheTransferNative::MooncakeKVCacheTransferNative( +MooncakeKVCacheTransferDefault::MooncakeKVCacheTransferDefault( const int32_t device_id, const int16_t listen_port, const torch::Device& device, @@ -105,36 +108,64 @@ MooncakeKVCacheTransferNative::MooncakeKVCacheTransferNative( std::make_unique(listen_port, device)), model_type_(model_type) {} -void MooncakeKVCacheTransferNative::allocate_kv_cache( +void MooncakeKVCacheTransferDefault::allocate_kv_cache( std::vector& kv_caches, const int64_t num_layers, const std::vector>& kv_cache_shape, torch::ScalarType dtype) { num_layers_ = num_layers; - allocate_kv_cache_native(kv_caches, num_layers, kv_cache_shape, dtype); + allocate_kv_cache_impl(kv_caches, num_layers, kv_cache_shape, dtype); } -void MooncakeKVCacheTransferNative::register_kv_cache( +void MooncakeKVCacheTransferDefault::register_kv_cache( std::vector& kv_caches, const std::vector>& kv_cache_shape, torch::ScalarType dtype) { num_layers_ = kv_caches.size(); + if (!kv_caches.empty()) { + torch::Tensor value_cache = kv_caches[0].get_v_cache(); + torch::Tensor index_cache = kv_caches[0].get_index_cache(); + has_v_cache_ = value_cache.defined() && value_cache.numel() > 0; + has_index_cache_ = index_cache.defined() && index_cache.numel() > 0; + } + buf_cnt_per_layer_ = 1 + static_cast(has_v_cache_) + + static_cast(has_index_cache_); int64_t data_size = torch::scalarTypeToTypeMeta(dtype).itemsize(); int64_t count_per_block = 1; - for (int32_t i = 1; i < kv_cache_shape[0].size(); ++i) { + for (size_t i = 1; i < kv_cache_shape[0].size(); ++i) { count_per_block *= kv_cache_shape[0][i]; } size_per_block_ = count_per_block * data_size; - register_per_layer_kv_cache(kv_caches, kv_cache_shape, dtype); + register_kv_cache_impl(kv_caches); } -void MooncakeKVCacheTransferNative::allocate_kv_cache_native( +void MooncakeKVCacheTransferDefault::allocate_kv_cache_impl( std::vector& kv_caches, int64_t num_layers, const std::vector>& kv_cache_shape, torch::ScalarType dtype) { +#if defined(USE_MLU) + torch::TensorOptions options = + torch::TensorOptions().dtype(dtype).device(device_); + for (int64_t i = 0; i < num_layers; ++i) { + torch::Tensor key_cache = torch::zeros(kv_cache_shape[0], options); + torch::Tensor value_cache; + torch::Tensor index_cache; + if (kv_cache_shape.size() > 1 && !kv_cache_shape[1].empty()) { + value_cache = torch::zeros(kv_cache_shape[1], options); + } + if (kv_cache_shape.size() > 2 && !kv_cache_shape[2].empty()) { + index_cache = torch::zeros(kv_cache_shape[2], options); + } + if (index_cache.defined()) { + kv_caches.emplace_back(key_cache, value_cache, index_cache); + } else { + kv_caches.emplace_back(key_cache, value_cache); + } + } +#else // Original mode: allocate device memory using aclrtMalloc // calculate the size of kv cache for each layer auto data_size = torch::elementSize(dtype); @@ -190,40 +221,81 @@ void MooncakeKVCacheTransferNative::allocate_kv_cache_native( value_cache = v_torch_tensors[i]; kv_caches.emplace_back(key_cache, value_cache); } +#endif } -void MooncakeKVCacheTransferNative::register_per_layer_kv_cache( - std::vector& kv_caches, - const std::vector>& kv_cache_shape, - torch::ScalarType dtype) { - int64_t num_cache = num_layers_ * 2; +void MooncakeKVCacheTransferDefault::add_buf( + const torch::Tensor& tensor, + std::vector& addrs, + std::vector& lens, + std::vector& buf_bytes) const { + if (!tensor.defined() || tensor.numel() == 0) { + return; + } + + CHECK_GT(tensor.dim(), 0) << "cache tensor dim must be positive"; + int64_t block_cnt = tensor.size(0); + CHECK_GT(block_cnt, 0) << "cache tensor block dim must be positive"; - std::vector cache_addrs; - std::vector cache_lens; - cache_addrs.reserve(num_cache); - cache_lens.reserve(num_cache); + addrs.emplace_back(tensor.data_ptr()); + lens.emplace_back(static_cast(tensor.nbytes())); + buf_bytes.emplace_back(static_cast(tensor.nbytes() / block_cnt)); +} + +std::vector MooncakeKVCacheTransferDefault::get_buf_ids( + const std::vector& layer_ids) const { + std::vector active_layer_ids; + if (layer_ids.empty()) { + active_layer_ids.resize(static_cast(num_layers_)); + std::iota(active_layer_ids.begin(), active_layer_ids.end(), 0); + } else { + active_layer_ids = layer_ids; + } - for (int32_t i = 0; i < num_layers_; ++i) { - cache_addrs.emplace_back(kv_caches[i].get_k_cache().data_ptr()); - cache_lens.emplace_back(kv_caches[i].get_k_cache().nbytes()); + std::vector buf_ids; + buf_ids.reserve(active_layer_ids.size() * + static_cast(buf_cnt_per_layer_)); + for (int64_t layer_id : active_layer_ids) { + CHECK_GE(layer_id, 0) << "layer_id must be non-negative"; + CHECK_LT(layer_id, num_layers_) << "layer_id out of range"; + + int64_t buf_id = layer_id * buf_cnt_per_layer_; + buf_ids.emplace_back(buf_id++); + if (has_v_cache_) { + buf_ids.emplace_back(buf_id++); + } + if (has_index_cache_) { + buf_ids.emplace_back(buf_id); + } } + return buf_ids; +} - for (int32_t i = 0; i < num_layers_; ++i) { - cache_addrs.emplace_back(kv_caches[i].get_v_cache().data_ptr()); - cache_lens.emplace_back(kv_caches[i].get_v_cache().nbytes()); +void MooncakeKVCacheTransferDefault::register_kv_cache_impl( + std::vector& kv_caches) { + std::vector addrs; + std::vector lens; + std::vector buf_bytes; + addrs.reserve(static_cast(num_layers_) * 3); + lens.reserve(static_cast(num_layers_) * 3); + buf_bytes.reserve(static_cast(num_layers_) * 3); + + for (int64_t i = 0; i < num_layers_; ++i) { + add_buf(kv_caches[i].get_k_cache(), addrs, lens, buf_bytes); + add_buf(kv_caches[i].get_v_cache(), addrs, lens, buf_bytes); + add_buf(kv_caches[i].get_index_cache(), addrs, lens, buf_bytes); } - if (!mooncake_te_->register_memory( - cache_addrs, cache_lens, size_per_block_)) { - LOG(ERROR) << "register_per_layer_kv_cache failed"; + if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) { + LOG(ERROR) << "register_kv_cache_impl failed"; return; } - LOG(INFO) << "register_per_layer_kv_cache success, num_layers=" << num_layers_ - << ", size_per_block=" << size_per_block_; + LOG(INFO) << "register_kv_cache_impl success, num_layers=" << num_layers_ + << ", buffers=" << buf_bytes.size(); } -bool MooncakeKVCacheTransferNative::pull_kv_blocks( +bool MooncakeKVCacheTransferDefault::pull_kv_blocks( const uint64_t src_cluster_id, const std::string& src_addr, const int64_t src_k_cache_id, @@ -234,8 +306,9 @@ bool MooncakeKVCacheTransferNative::pull_kv_blocks( (void)src_k_cache_id; (void)src_v_cache_id; std::vector layer_ids; + std::vector buf_ids = get_buf_ids(layer_ids); auto ret = mooncake_te_->pull_memory_blocks( - src_addr, src_blocks, dst_blocks, layer_ids); + src_addr, src_blocks, dst_blocks, buf_ids); if (!ret) { LOG(ERROR) << "Pull kv cache blocks failed, ret = " << ret; return false; @@ -243,18 +316,19 @@ bool MooncakeKVCacheTransferNative::pull_kv_blocks( return true; } -bool MooncakeKVCacheTransferNative::push_kv_blocks( +bool MooncakeKVCacheTransferDefault::push_kv_blocks( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer, + std::shared_ptr& layer_synchronizer, bool is_spec_draft) { (void)is_spec_draft; for (int64_t layer_index = 0; layer_index < num_layers_; ++layer_index) { layer_synchronizer->synchronize_layer(layer_index); for (const auto& pair : merged_kv_infos) { std::vector layer_ids = {layer_index}; + std::vector buf_ids = get_buf_ids(layer_ids); const KVCacheInfo& kv_info = pair.second; auto ret = mooncake_te_->push_memory_blocks( - kv_info.dst_addr, kv_info.src_blocks, kv_info.dst_blocks, layer_ids); + kv_info.dst_addr, kv_info.src_blocks, kv_info.dst_blocks, buf_ids); if (!ret) { LOG(ERROR) << "Push kv blocks failed, layer = " << layer_index << ", ret = " << ret; @@ -285,7 +359,7 @@ void MooncakeKVCacheTransferXTensor::allocate_kv_cache( const std::vector>& kv_cache_shape, torch::ScalarType dtype) { num_layers_ = num_layers; - allocate_kv_cache_xtensor(kv_caches, num_layers, kv_cache_shape, dtype); + allocate_kv_cache_impl(kv_caches, num_layers, kv_cache_shape, dtype); } void MooncakeKVCacheTransferXTensor::register_kv_cache( @@ -301,10 +375,10 @@ void MooncakeKVCacheTransferXTensor::register_kv_cache( } size_per_block_ = count_per_block * data_size; - register_global_xtensor(kv_cache_shape, dtype); + register_kv_cache_impl(); } -void MooncakeKVCacheTransferXTensor::allocate_kv_cache_xtensor( +void MooncakeKVCacheTransferXTensor::allocate_kv_cache_impl( std::vector& kv_caches, int64_t num_layers, const std::vector>& kv_cache_shape, @@ -333,9 +407,8 @@ void MooncakeKVCacheTransferXTensor::allocate_kv_cache_xtensor( << ", model_id=" << model_id_ << ", num_layers=" << num_layers; } -void MooncakeKVCacheTransferXTensor::register_global_xtensor( - const std::vector>& kv_cache_shape, - torch::ScalarType dtype) { +void MooncakeKVCacheTransferXTensor::register_kv_cache_impl() { + // XTensor mode registers one shared GlobalXTensor memory region. auto& global_xtensor = GlobalXTensor::get_instance(); if (!global_xtensor.is_initialized()) { LOG(ERROR) << "GlobalXTensor not initialized in xtensor mode"; @@ -349,14 +422,15 @@ void MooncakeKVCacheTransferXTensor::register_global_xtensor( std::vector addrs = {global_xtensor.base_vaddr()}; std::vector lens = {global_xtensor.total_size()}; + std::vector buf_bytes = {static_cast(size_per_block_)}; - if (!mooncake_te_->register_memory(addrs, lens, size_per_block_)) { + if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) { LOG(ERROR) << "register GlobalXTensor failed"; return; } global_xtensor.set_mooncake_registered(true); - LOG(INFO) << "register_global_xtensor success, total_size=" + LOG(INFO) << "register_kv_cache_impl success, total_size=" << global_xtensor.total_size() << ", num_pages=" << global_xtensor.num_total_pages() << ", size_per_block=" << size_per_block_; @@ -372,18 +446,18 @@ bool MooncakeKVCacheTransferXTensor::pull_kv_blocks( (void)src_cluster_id; (void)src_k_cache_id; (void)src_v_cache_id; - return pull_kv_blocks_xtensor_mode(src_addr, src_blocks, dst_blocks); + return pull_kv_blocks_impl(src_addr, src_blocks, dst_blocks); } bool MooncakeKVCacheTransferXTensor::push_kv_blocks( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer, + std::shared_ptr& layer_synchronizer, bool is_spec_draft) { (void)is_spec_draft; - return push_kv_blocks_xtensor_mode(merged_kv_infos, layer_synchronizer); + return push_kv_blocks_impl(merged_kv_infos, layer_synchronizer); } -bool MooncakeKVCacheTransferXTensor::pull_kv_blocks_xtensor_mode( +bool MooncakeKVCacheTransferXTensor::pull_kv_blocks_impl( const std::string& src_addr, const std::vector& src_blocks, const std::vector& dst_blocks) { @@ -436,19 +510,19 @@ bool MooncakeKVCacheTransferXTensor::pull_kv_blocks_xtensor_mode( size_per_block_, MooncakeTransferEngine::MoveOpcode::READ); if (!ret) { - LOG(ERROR) << "pull_kv_blocks_xtensor_mode failed at layer " << layer_id; + LOG(ERROR) << "pull_kv_blocks_impl failed at layer " << layer_id; return false; } } - VLOG(1) << "pull_kv_blocks_xtensor_mode success, num_blocks=" - << src_blocks.size() << ", num_layers=" << num_layers_; + VLOG(1) << "pull_kv_blocks_impl success, num_blocks=" << src_blocks.size() + << ", num_layers=" << num_layers_; return true; } -bool MooncakeKVCacheTransferXTensor::push_kv_blocks_xtensor_mode( +bool MooncakeKVCacheTransferXTensor::push_kv_blocks_impl( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer) { + std::shared_ptr& layer_synchronizer) { if (model_id_.empty()) { LOG(ERROR) << "model_id not set for XTensor mode push"; return false; @@ -520,14 +594,13 @@ bool MooncakeKVCacheTransferXTensor::push_kv_blocks_xtensor_mode( size_per_block_, MooncakeTransferEngine::MoveOpcode::WRITE); if (!ret) { - LOG(ERROR) << "push_kv_blocks_xtensor_mode failed at layer " - << layer_index; + LOG(ERROR) << "push_kv_blocks_impl failed at layer " << layer_index; return false; } } } - VLOG(1) << "push_kv_blocks_xtensor_mode success, num_layers=" << num_layers_; + VLOG(1) << "push_kv_blocks_impl success, num_layers=" << num_layers_; return true; } diff --git a/xllm/core/framework/kv_cache/mooncake_kv_cache_transfer.h b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h similarity index 73% rename from xllm/core/framework/kv_cache/mooncake_kv_cache_transfer.h rename to xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h index de95fbbd7..f9dfd361c 100644 --- a/xllm/core/framework/kv_cache/mooncake_kv_cache_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h @@ -15,13 +15,13 @@ limitations under the License. #pragma once -#include "kv_cache_transfer.h" -#include "mooncake_transfer_engine.h" +#include "framework/kv_cache_transfer/kv_cache_transfer.h" +#include "framework/kv_cache_transfer/mooncake_transfer_engine.h" namespace xllm { -// Base class for Mooncake-based KV cache transfer -// Native and XTensor subclasses inherit this class (single inheritance). +// Base class for Mooncake-based KV cache transfer. +// Default and XTensor subclasses inherit this class (single inheritance). class MooncakeKVCacheTransferBase : public KVCacheTransfer { public: MooncakeKVCacheTransferBase(const int32_t device_id, @@ -53,18 +53,20 @@ class MooncakeKVCacheTransferBase : public KVCacheTransfer { uint64_t cluster_id_; int16_t listen_port_; int32_t device_id_; + torch::Device device_; int64_t num_layers_ = 0; int64_t size_per_block_ = 0; std::unique_ptr mooncake_te_; }; -class MooncakeKVCacheTransferNative final : public MooncakeKVCacheTransferBase { +class MooncakeKVCacheTransferDefault final + : public MooncakeKVCacheTransferBase { public: - MooncakeKVCacheTransferNative(const int32_t device_id, - const int16_t listen_port, - const torch::Device& device, - const std::string& model_type); + MooncakeKVCacheTransferDefault(const int32_t device_id, + const int16_t listen_port, + const torch::Device& device, + const std::string& model_type); void allocate_kv_cache( std::vector& kv_caches, @@ -86,21 +88,28 @@ class MooncakeKVCacheTransferNative final : public MooncakeKVCacheTransferBase { bool push_kv_blocks( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer, + std::shared_ptr& layer_synchronizer, bool is_spec_draft) override; private: - void allocate_kv_cache_native( + void allocate_kv_cache_impl( std::vector& kv_caches, int64_t num_layers, const std::vector>& kv_cache_shape, torch::ScalarType dtype); - void register_per_layer_kv_cache( - std::vector& kv_caches, - const std::vector>& kv_cache_shape, - torch::ScalarType dtype); + void add_buf(const torch::Tensor& tensor, + std::vector& addrs, + std::vector& lens, + std::vector& buf_bytes) const; + std::vector get_buf_ids(const std::vector& layer_ids) const; + + // Register per-layer K/V tensor memory. + void register_kv_cache_impl(std::vector& kv_caches); + bool has_v_cache_ = true; + bool has_index_cache_ = false; + int64_t buf_cnt_per_layer_ = 2; std::string model_type_; }; @@ -133,27 +142,26 @@ class MooncakeKVCacheTransferXTensor final bool push_kv_blocks( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer, + std::shared_ptr& layer_synchronizer, bool is_spec_draft) override; private: - void allocate_kv_cache_xtensor( + void allocate_kv_cache_impl( std::vector& kv_caches, int64_t num_layers, const std::vector>& kv_cache_shape, torch::ScalarType dtype); - void register_global_xtensor( - const std::vector>& kv_cache_shape, - torch::ScalarType dtype); + // Register GlobalXTensor memory region. + void register_kv_cache_impl(); - bool pull_kv_blocks_xtensor_mode(const std::string& src_addr, - const std::vector& src_blocks, - const std::vector& dst_blocks); + bool pull_kv_blocks_impl(const std::string& src_addr, + const std::vector& src_blocks, + const std::vector& dst_blocks); - bool push_kv_blocks_xtensor_mode( + bool push_kv_blocks_impl( std::unordered_map& merged_kv_infos, - std::shared_ptr& layer_synchronizer); + std::shared_ptr& layer_synchronizer); std::string model_id_; }; diff --git a/xllm/core/framework/kv_cache/mooncake_transfer_engine.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine.cpp similarity index 66% rename from xllm/core/framework/kv_cache/mooncake_transfer_engine.cpp rename to xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine.cpp index f8a2396e3..77b713b9f 100644 --- a/xllm/core/framework/kv_cache/mooncake_transfer_engine.cpp +++ b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine.cpp @@ -13,27 +13,74 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mooncake_transfer_engine.h" +#include "framework/kv_cache_transfer/mooncake_transfer_engine.h" -#include -#include #include +#include #include -#include "common/global_flags.h" #include "util/net.h" namespace xllm { +namespace { + +bool close_remote_session(MooncakeTransferEngineCore* core, + uint64_t cluster_id) { + proto::MooncakeTransferEngineService_Stub* stub = + core->get_or_create_stub(cluster_id); + if (stub == nullptr) { + LOG(ERROR) << "create_rpc_channel failed for cluster_id=" << cluster_id; + return false; + } + + proto::SessionInfo session_info; + session_info.set_addr(core->addr()); + proto::Status response; + brpc::Controller cntl; + stub->CloseSession(&cntl, &session_info, &response, nullptr); + if (cntl.Failed() || !response.ok()) { + LOG(ERROR) << "CloseSession failed, " << cntl.ErrorText(); + return false; + } + return true; +} + +bool check_buf_range(uint64_t buf_len, + uint64_t buf_bytes, + uint64_t block_id, + uint64_t block_len, + int64_t buf_id) { + if (buf_bytes == 0) { + LOG(ERROR) << "buf bytes is zero, buf_id=" << buf_id; + return false; + } + if (buf_len % buf_bytes != 0) { + LOG(ERROR) << "buf len is not aligned with block bytes, buf_id=" << buf_id + << ", buf_len=" << buf_len << ", buf_bytes=" << buf_bytes; + return false; + } + + uint64_t block_cnt = buf_len / buf_bytes; + if (block_id > block_cnt || block_len > block_cnt - block_id) { + LOG(ERROR) << "block range out of bounds, buf_id=" << buf_id + << ", block_cnt=" << block_cnt << ", block_id=" << block_id + << ", block_len=" << block_len; + return false; + } + return true; +} + +} // namespace + // ============================================================================ // MooncakeTransferEngineCore (Singleton) // ============================================================================ MooncakeTransferEngineCore::~MooncakeTransferEngineCore() { - // free stub for (auto& pair : stub_map_) { - if (pair.second) { + if (pair.second != nullptr) { delete pair.second->channel(); delete pair.second; } @@ -58,23 +105,13 @@ bool MooncakeTransferEngineCore::initialize(int16_t listen_port, listen_port_ = listen_port; host_ip_ = net::get_local_ip_addr(); - // Create TransferEngine engine_ = std::make_unique(true); Device dev(device); dev.set_device(); dev.init_device_context(); - int32_t device_id = dev.index(); - std::string hostname; - int32_t phy_id = FLAGS_npu_phy_id; - if (phy_id != -1) { - hostname = host_ip_ + ":" + std::to_string(listen_port_) + ":npu_" + - std::to_string(phy_id); - } else { - hostname = host_ip_ + ":" + std::to_string(listen_port_) + ":npu_" + - std::to_string(device_id); - } + std::string hostname = host_ip_ + ":" + std::to_string(listen_port_); if (engine_->init("P2PHANDSHAKE", hostname, "", 0)) { LOG(ERROR) << "engine init failed, hostname=" << hostname; @@ -83,7 +120,6 @@ bool MooncakeTransferEngineCore::initialize(int16_t listen_port, LOG(INFO) << "TransferEngine init success, hostname=" << hostname; - // Create brpc service and server service_ = std::make_shared(); if (server_.AddService(service_.get(), brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { @@ -115,7 +151,7 @@ bool MooncakeTransferEngineCore::open_session(const uint64_t cluster_id, auto it = handles_.find(remote_addr); if (it != handles_.end()) { - // Session exists, just increment ref count + // Reuse the existing session until the last caller releases it. it->second.ref_count++; LOG(INFO) << "Reusing existing session for " << remote_addr << ", ref_count=" << it->second.ref_count; @@ -124,18 +160,18 @@ bool MooncakeTransferEngineCore::open_session(const uint64_t cluster_id, if (cluster_id != 0) { proto::MooncakeTransferEngineService_Stub* stub = - get_or_create_stub(cluster_id); - if (!stub) { + get_or_create_stub_locked(cluster_id); + if (stub == nullptr) { LOG(ERROR) << "create_rpc_channel failed"; return false; } - proto::SessionInfo proto_session_info; - proto_session_info.set_addr(addr_); - proto::Status status; + proto::SessionInfo request; + request.set_addr(addr_); + proto::Status response; brpc::Controller cntl; - stub->OpenSession(&cntl, &proto_session_info, &status, nullptr); - if (cntl.Failed() || !status.ok()) { + stub->OpenSession(&cntl, &request, &response, nullptr); + if (cntl.Failed() || !response.ok()) { LOG(ERROR) << "OpenSession failed, " << cntl.ErrorText(); return false; } @@ -145,9 +181,8 @@ bool MooncakeTransferEngineCore::open_session(const uint64_t cluster_id, return true; } - Transport::SegmentHandle handle; - handle = engine_->openSegment(remote_addr); - if (handle == (Transport::SegmentHandle)-1) { + Transport::SegmentHandle handle = engine_->openSegment(remote_addr); + if (handle == static_cast(-1)) { LOG(ERROR) << "Fail to connect to " << remote_addr; return false; } @@ -170,43 +205,35 @@ bool MooncakeTransferEngineCore::close_session(const uint64_t cluster_id, << ", remote_addr=" << remote_addr; auto it = handles_.find(remote_addr); + if (cluster_id != 0) { + if (it != handles_.end()) { + it->second.ref_count--; + LOG(INFO) << "Decremented ref_count for " << remote_addr + << ", ref_count=" << it->second.ref_count; + if (it->second.ref_count > 0) { + return true; + } + } + return close_remote_session(this, cluster_id); + } + if (it == handles_.end()) { return true; } - // Decrement ref count it->second.ref_count--; LOG(INFO) << "Decremented ref_count for " << remote_addr << ", ref_count=" << it->second.ref_count; - // Only close when ref_count reaches 0 if (it->second.ref_count > 0) { return true; } - if (cluster_id != 0) { - proto::MooncakeTransferEngineService_Stub* stub = - get_or_create_stub(cluster_id); - if (!stub) { - LOG(ERROR) << "create_rpc_channel failed"; - return false; - } - - proto::SessionInfo proto_session_info; - proto_session_info.set_addr(addr_); - - proto::Status status; - brpc::Controller cntl; - stub->CloseSession(&cntl, &proto_session_info, &status, nullptr); - if (cntl.Failed() || !status.ok()) { - LOG(ERROR) << "CloseSession failed, " << cntl.ErrorText(); - return false; - } - return true; + SegmentHandle handle = it->second.handle; + if (handle != static_cast(-1)) { + engine_->closeSegment(handle); } - - engine_->closeSegment(it->second.handle); - handles_.erase(remote_addr); + handles_.erase(it); LOG(INFO) << "Closed session for " << remote_addr; @@ -218,14 +245,19 @@ SegmentHandle MooncakeTransferEngineCore::get_handle( std::lock_guard lock(mutex_); auto it = handles_.find(remote_addr); if (it == handles_.end()) { - return (SegmentHandle)-1; + return static_cast(-1); } return it->second.handle; } proto::MooncakeTransferEngineService_Stub* MooncakeTransferEngineCore::get_or_create_stub(uint64_t cluster_id) { - // Note: caller should hold mutex_ if needed + std::lock_guard lock(mutex_); + return get_or_create_stub_locked(cluster_id); +} + +proto::MooncakeTransferEngineService_Stub* +MooncakeTransferEngineCore::get_or_create_stub_locked(uint64_t cluster_id) { auto it = stub_map_.find(cluster_id); if (it == stub_map_.end()) { auto [remote_ip, remote_port] = net::convert_uint64_to_ip_port(cluster_id); @@ -271,26 +303,28 @@ std::string MooncakeTransferEngine::initialize() { bool MooncakeTransferEngine::register_memory(std::vector addrs, std::vector lens, - int64_t size_per_block) { - int64_t num = addrs.size(); - num_layers_ = num / 2; - - std::vector buffers; - buffers.reserve(num); - for (size_t i = 0; i < num; i++) { - buffers.push_back(BufferEntry{(void*)addrs[i], lens[i]}); - } - - int ret = - core_.engine()->registerLocalMemoryBatch(buffers, kWildcardLocation); - if (ret) { - LOG(ERROR) << "registerLocalMemoryBatch failed, ret=" << ret; + std::vector buf_bytes) { + if (addrs.size() != lens.size() || addrs.size() != buf_bytes.size()) { + LOG(ERROR) << "register_memory input size mismatch, addrs=" << addrs.size() + << ", lens=" << lens.size() + << ", buf_bytes=" << buf_bytes.size(); return false; } - size_per_block_ = size_per_block; + TransferEngine* engine = core_.engine(); + for (size_t i = 0; i < addrs.size(); ++i) { + int32_t ret = engine->registerLocalMemory( + addrs[i], lens[i], kWildcardLocation, true, true); + if (ret != 0) { + LOG(ERROR) << "registerLocalMemory failed, buf_id=" << i + << ", addr=" << addrs[i] << ", len=" << lens[i] + << ", ret=" << ret; + return false; + } + } - LOG(INFO) << "register_memory success, size_per_block_=" << size_per_block_; + buf_bytes_ = std::move(buf_bytes); + LOG(INFO) << "register_memory success, buf_num=" << buf_bytes_.size(); return true; } @@ -317,11 +351,11 @@ void merge_block_ids(const std::vector& src_blocks, std::vector& merged_src_blocks, std::vector& merged_dst_blocks, std::vector& block_lengths) { - // Create an index array and sort it based on the values of src blocks. size_t block_num = src_blocks.size(); if (block_num == 0) { return; } + std::vector indices(block_num); std::iota(indices.begin(), indices.end(), 0); std::sort( @@ -329,17 +363,15 @@ void merge_block_ids(const std::vector& src_blocks, return src_blocks[i] < src_blocks[j]; }); - // Generate sorted src blocks and dst blocks. std::vector sorted_src_blocks; std::vector sorted_dst_blocks; sorted_src_blocks.reserve(block_num); sorted_dst_blocks.reserve(block_num); - for (auto id : indices) { + for (uint64_t id : indices) { sorted_src_blocks.emplace_back(src_blocks[id]); sorted_dst_blocks.emplace_back(dst_blocks[id]); } - // Obtain continuous blocks. uint64_t current_src_id = sorted_src_blocks[0]; uint64_t current_dst_id = sorted_dst_blocks[0]; uint64_t current_length = 1; @@ -368,32 +400,48 @@ bool MooncakeTransferEngine::move_memory_blocks( const std::string& remote_addr, const std::vector& src_blocks, const std::vector& dst_blocks, - const std::vector& layer_ids, + const std::vector& buf_ids, MoveOpcode move_opcode) { - auto remote_handle = core_.get_handle(remote_addr); - if (remote_handle == (SegmentHandle)-1) { + if (src_blocks.size() != dst_blocks.size()) { + LOG(ERROR) << "src_blocks size must equal dst_blocks size, src=" + << src_blocks.size() << ", dst=" << dst_blocks.size(); + return false; + } + + SegmentHandle remote_handle = core_.get_handle(remote_addr); + if (remote_handle == static_cast(-1)) { LOG(ERROR) << "remote addr does not exist: " << remote_addr; return false; } - auto* engine = core_.engine(); - std::shared_ptr remote_segment_desc; - remote_segment_desc = + TransferEngine* engine = core_.engine(); + std::shared_ptr remote_segment_desc = engine->getMetadata()->getSegmentDescByID(remote_handle); if (!remote_segment_desc) { LOG(ERROR) << "remote_segment_desc is null"; return false; } - std::shared_ptr local_segment_desc; - local_segment_desc = + std::shared_ptr local_segment_desc = engine->getMetadata()->getSegmentDescByID(LOCAL_SEGMENT_ID); if (!local_segment_desc) { LOG(ERROR) << "local_segment_desc is null"; return false; } - // Merge consecutive block ids to improve transmission efficiency. + size_t local_buf_cnt = local_segment_desc->buffers.size(); + size_t remote_buf_cnt = remote_segment_desc->buffers.size(); + if (local_buf_cnt != remote_buf_cnt) { + LOG(ERROR) << "buffer count mismatch, local=" << local_buf_cnt + << ", remote=" << remote_buf_cnt; + return false; + } + if (local_buf_cnt != buf_bytes_.size()) { + LOG(ERROR) << "registered buffer count mismatch, local=" << local_buf_cnt + << ", block_bytes=" << buf_bytes_.size(); + return false; + } + std::vector merged_src_blocks; std::vector merged_dst_blocks; std::vector block_lengths; @@ -403,59 +451,66 @@ bool MooncakeTransferEngine::move_memory_blocks( merged_dst_blocks, block_lengths); - std::vector addr_ids; - if (layer_ids.size() == 0) { - addr_ids.resize(num_layers_); - std::iota(addr_ids.begin(), addr_ids.end(), 0); + std::vector active_buf_ids; + if (buf_ids.empty()) { + active_buf_ids.resize(buf_bytes_.size()); + std::iota(active_buf_ids.begin(), active_buf_ids.end(), 0); } else { - addr_ids = layer_ids; + active_buf_ids = buf_ids; } - TransferRequest::OpCode opcode; + TransferRequest::OpCode opcode = TransferRequest::READ; if (move_opcode == MoveOpcode::WRITE) { opcode = TransferRequest::WRITE; - } else { - opcode = TransferRequest::READ; } std::vector entries; - for (auto addr_id : addr_ids) { - char* k_local_base = (char*)(local_segment_desc->buffers[addr_id].addr); - char* k_remote_base = (char*)(remote_segment_desc->buffers[addr_id].addr); + for (int64_t buf_id : active_buf_ids) { + if (buf_id < 0 || static_cast(buf_id) >= local_buf_cnt) { + LOG(ERROR) << "buf_id out of range, buf_id=" << buf_id + << ", buf_cnt=" << local_buf_cnt; + return false; + } - int64_t v_addr_id = addr_id + num_layers_; - char* v_local_base = (char*)(local_segment_desc->buffers[v_addr_id].addr); - char* v_remote_base = (char*)(remote_segment_desc->buffers[v_addr_id].addr); + size_t local_buf_id = static_cast(buf_id); + uint64_t buf_bytes = buf_bytes_[local_buf_id]; + uint64_t local_buf_len = local_segment_desc->buffers[local_buf_id].length; + uint64_t remote_buf_len = remote_segment_desc->buffers[local_buf_id].length; + char* local_base = + reinterpret_cast(local_segment_desc->buffers[local_buf_id].addr); + uint64_t remote_base = remote_segment_desc->buffers[local_buf_id].addr; for (size_t i = 0; i < merged_src_blocks.size(); ++i) { uint64_t src_block_id = merged_src_blocks[i]; uint64_t dst_block_id = merged_dst_blocks[i]; uint64_t block_length = block_lengths[i]; - uint64_t src_bias = src_block_id * size_per_block_; - uint64_t dst_bias = dst_block_id * size_per_block_; - uint64_t len = block_length * size_per_block_; - - TransferRequest k_entry; - k_entry.opcode = opcode; - k_entry.length = len; - k_entry.source = (void*)(k_local_base + src_bias); - k_entry.target_id = remote_handle; - k_entry.target_offset = (uint64_t)(k_remote_base + dst_bias); - k_entry.advise_retry_cnt = 0; - entries.push_back(k_entry); - - TransferRequest v_entry; - v_entry.opcode = opcode; - v_entry.length = len; - v_entry.source = (void*)(v_local_base + src_bias); - v_entry.target_id = remote_handle; - v_entry.target_offset = (uint64_t)(v_remote_base + dst_bias); - v_entry.advise_retry_cnt = 0; - entries.push_back(v_entry); + if (!check_buf_range( + local_buf_len, buf_bytes, src_block_id, block_length, buf_id) || + !check_buf_range( + remote_buf_len, buf_bytes, dst_block_id, block_length, buf_id)) { + return false; + } + + uint64_t src_bias = src_block_id * buf_bytes; + uint64_t dst_bias = dst_block_id * buf_bytes; + uint64_t len = block_length * buf_bytes; + + TransferRequest entry; + entry.opcode = opcode; + entry.length = len; + entry.source = reinterpret_cast(local_base + src_bias); + entry.target_id = remote_handle; + entry.target_offset = remote_base + dst_bias; + entry.advise_retry_cnt = 0; + entries.push_back(entry); } } - auto batch_size = entries.size(); + if (entries.empty()) { + return true; + } + + size_t batch_size = entries.size(); auto batch_id = engine->allocateBatchID(batch_size); mooncake::Status s = engine->submitTransfer(batch_id, entries); if (!s.ok()) { @@ -499,44 +554,41 @@ bool MooncakeTransferEngine::move_memory_by_global_offsets( const std::vector& dst_offsets, size_t transfer_size, MoveOpcode move_opcode) { - auto remote_handle = core_.get_handle(remote_addr); - if (remote_handle == (SegmentHandle)-1) { + SegmentHandle remote_handle = core_.get_handle(remote_addr); + if (remote_handle == static_cast(-1)) { LOG(ERROR) << "remote addr does not exist: " << remote_addr; return false; } - auto* engine = core_.engine(); - std::shared_ptr remote_segment_desc; - remote_segment_desc = + TransferEngine* engine = core_.engine(); + std::shared_ptr remote_segment_desc = engine->getMetadata()->getSegmentDescByID(remote_handle); if (!remote_segment_desc) { LOG(ERROR) << "remote_segment_desc is null"; return false; } - std::shared_ptr local_segment_desc; - local_segment_desc = + std::shared_ptr local_segment_desc = engine->getMetadata()->getSegmentDescByID(LOCAL_SEGMENT_ID); if (!local_segment_desc) { LOG(ERROR) << "local_segment_desc is null"; return false; } - // XTensor mode: use buffer[0] which is the GlobalXTensor if (local_segment_desc->buffers.empty() || remote_segment_desc->buffers.empty()) { LOG(ERROR) << "No buffers registered for XTensor mode"; return false; } - char* local_base = (char*)(local_segment_desc->buffers[0].addr); - char* remote_base = (char*)(remote_segment_desc->buffers[0].addr); + char* local_base = + reinterpret_cast(local_segment_desc->buffers[0].addr); + char* remote_base = + reinterpret_cast(remote_segment_desc->buffers[0].addr); - TransferRequest::OpCode opcode; + TransferRequest::OpCode opcode = TransferRequest::READ; if (move_opcode == MoveOpcode::WRITE) { opcode = TransferRequest::WRITE; - } else { - opcode = TransferRequest::READ; } std::vector entries; @@ -546,14 +598,15 @@ bool MooncakeTransferEngine::move_memory_by_global_offsets( TransferRequest entry; entry.opcode = opcode; entry.length = transfer_size; - entry.source = (void*)(local_base + src_offsets[i]); + entry.source = reinterpret_cast(local_base + src_offsets[i]); entry.target_id = remote_handle; - entry.target_offset = (uint64_t)(remote_base + dst_offsets[i]); + entry.target_offset = + reinterpret_cast(remote_base + dst_offsets[i]); entry.advise_retry_cnt = 0; entries.push_back(entry); } - auto batch_size = entries.size(); + size_t batch_size = entries.size(); auto batch_id = engine->allocateBatchID(batch_size); mooncake::Status s = engine->submitTransfer(batch_id, entries); if (!s.ok()) { @@ -595,9 +648,9 @@ bool MooncakeTransferEngine::pull_memory_blocks( const std::string& remote_addr, const std::vector& src_blocks, const std::vector& dst_blocks, - const std::vector& layer_ids) { - auto ret = move_memory_blocks( - remote_addr, src_blocks, dst_blocks, layer_ids, MoveOpcode::READ); + const std::vector& buf_ids) { + bool ret = move_memory_blocks( + remote_addr, src_blocks, dst_blocks, buf_ids, MoveOpcode::READ); if (!ret) { LOG(ERROR) << "Pull memory blocks failed, ret = " << ret; return false; @@ -610,9 +663,9 @@ bool MooncakeTransferEngine::push_memory_blocks( const std::string& remote_addr, const std::vector& src_blocks, const std::vector& dst_blocks, - const std::vector& layer_ids) { - auto ret = move_memory_blocks( - remote_addr, src_blocks, dst_blocks, layer_ids, MoveOpcode::WRITE); + const std::vector& buf_ids) { + bool ret = move_memory_blocks( + remote_addr, src_blocks, dst_blocks, buf_ids, MoveOpcode::WRITE); if (!ret) { LOG(ERROR) << "Push memory blocks failed, ret = " << ret; return false; @@ -631,11 +684,17 @@ void MooncakeTransferEngineService::OpenSession( proto::Status* response, ::google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); - if (!request || !response || !controller) { + if (request == nullptr || response == nullptr || controller == nullptr) { LOG(ERROR) << "brpc request | response | controller is null"; return; } + if (request->addr().empty()) { + LOG(ERROR) << "OpenSession request missing addr"; + response->set_ok(false); + return; + } + std::string remote_addr(request->addr()); bool result = MooncakeTransferEngineCore::get_instance().open_session(0, remote_addr); @@ -649,11 +708,17 @@ void MooncakeTransferEngineService::CloseSession( proto::Status* response, ::google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); - if (!request || !response || !controller) { + if (request == nullptr || response == nullptr || controller == nullptr) { LOG(ERROR) << "brpc request | response | controller is null"; return; } + if (request->addr().empty()) { + LOG(ERROR) << "CloseSession request missing addr"; + response->set_ok(false); + return; + } + std::string remote_addr(request->addr()); bool result = MooncakeTransferEngineCore::get_instance().close_session(0, remote_addr); diff --git a/xllm/core/framework/kv_cache/mooncake_transfer_engine.h b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine.h similarity index 82% rename from xllm/core/framework/kv_cache/mooncake_transfer_engine.h rename to xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine.h index ad0658877..d658a31ad 100644 --- a/xllm/core/framework/kv_cache/mooncake_transfer_engine.h +++ b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine.h @@ -19,8 +19,11 @@ limitations under the License. #include #include +#include #include #include +#include +#include #include "mooncake_transfer_engine.pb.h" #include "platform/device.h" @@ -31,38 +34,38 @@ using namespace mooncake; class MooncakeTransferEngineService; -// Singleton core that holds the actual TransferEngine and brpc Server -// Multiple MooncakeTransferEngine instances share this core +// Singleton core that holds the actual TransferEngine and brpc Server. +// Multiple MooncakeTransferEngine instances share this core. class MooncakeTransferEngineCore { public: - // Get the global singleton instance static MooncakeTransferEngineCore& get_instance() { static MooncakeTransferEngineCore instance; return instance; } - // Initialize the core (only first call takes effect) + // Initialize the shared core. Only the first call takes effect. bool initialize(int16_t listen_port, const torch::Device& device); - // Get the underlying TransferEngine TransferEngine* engine() { return engine_.get(); } - // Get the RPC address const std::string& addr() const { return addr_; } const std::string& host_ip() const { return host_ip_; } - // Session management (shared across all MooncakeTransferEngine instances) + // Session state is shared across all MooncakeTransferEngine instances. bool open_session(const uint64_t cluster_id, const std::string& remote_addr); bool close_session(const uint64_t cluster_id, const std::string& remote_addr); SegmentHandle get_handle(const std::string& remote_addr); - // RPC channel management + // Lazily create and cache the RPC stub for a remote cluster. proto::MooncakeTransferEngineService_Stub* get_or_create_stub( uint64_t cluster_id); bool is_initialized() const { return initialized_; } private: + proto::MooncakeTransferEngineService_Stub* get_or_create_stub_locked( + uint64_t cluster_id); + MooncakeTransferEngineCore() = default; ~MooncakeTransferEngineCore(); MooncakeTransferEngineCore(const MooncakeTransferEngineCore&) = delete; @@ -81,11 +84,10 @@ class MooncakeTransferEngineCore { brpc::Server server_; std::shared_ptr service_; - // Session handle with reference count for isolation between kv cache and - // weight transfer + // Keep a shared session handle so kv cache and weight transfer can reuse it. struct SessionInfo { - SegmentHandle handle; - int ref_count = 0; + SegmentHandle handle = static_cast(-1); + int32_t ref_count = 0; }; std::unordered_map handles_; std::unordered_map @@ -104,28 +106,25 @@ class MooncakeTransferEngine final { bool register_memory(std::vector addrs, std::vector lens, - int64_t size_per_block); + std::vector buf_bytes); bool move_memory_blocks(const std::string& remote_addr, const std::vector& src_blocks, const std::vector& dst_blocks, - const std::vector& layer_ids, + const std::vector& buf_ids, MoveOpcode move_opcode); bool pull_memory_blocks(const std::string& remote_addr, const std::vector& src_blocks, const std::vector& dst_blocks, - const std::vector& layer_ids); + const std::vector& buf_ids); bool push_memory_blocks(const std::string& remote_addr, const std::vector& src_blocks, const std::vector& dst_blocks, - const std::vector& layer_ids); + const std::vector& buf_ids); - // === XTensor mode: transfer by GlobalXTensor offsets === - // Instead of using block_id and per-layer buffers, this method uses - // raw offsets into the GlobalXTensor memory region (buffer[0]). - // src_offsets and dst_offsets are absolute offsets within GlobalXTensor. + // XTensor mode uses raw offsets in the GlobalXTensor region in buffer[0]. bool move_memory_by_global_offsets(const std::string& remote_addr, const std::vector& src_offsets, const std::vector& dst_offsets, @@ -141,8 +140,7 @@ class MooncakeTransferEngine final { private: int16_t listen_port_; - int64_t size_per_block_ = 0; - int64_t num_layers_ = 0; + std::vector buf_bytes_; Device device_; MooncakeTransferEngineCore& core_; }; diff --git a/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp new file mode 100644 index 000000000..5b55e4851 --- /dev/null +++ b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp @@ -0,0 +1,57 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "framework/kv_cache_transfer/mooncake_transfer_engine.h" + +#include +#include + +namespace xllm { + +TEST(MooncakeTransferEngineServiceTest, OpenSessionRejectsMissingAddr) { + MooncakeTransferEngineService service; + proto::SessionInfo request; + proto::Status response; + brpc::Controller cntl; + + service.OpenSession(&cntl, &request, &response, nullptr); + + EXPECT_FALSE(response.ok()); +} + +TEST(MooncakeTransferEngineServiceTest, CloseSessionRejectsMissingAddr) { + MooncakeTransferEngineService service; + proto::SessionInfo request; + proto::Status response; + brpc::Controller cntl; + + service.CloseSession(&cntl, &request, &response, nullptr); + + EXPECT_FALSE(response.ok()); +} + +TEST(MooncakeTransferEngineServiceTest, CloseSessionWithoutHandleReturnsTrue) { + MooncakeTransferEngineService service; + proto::SessionInfo request; + request.set_addr("127.0.0.1:5001"); + proto::Status response; + brpc::Controller cntl; + + service.CloseSession(&cntl, &request, &response, nullptr); + + EXPECT_TRUE(response.ok()); +} + +} // namespace xllm diff --git a/xllm/core/framework/kv_cache/mooncake_weight_transfer.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_weight_transfer.cpp similarity index 95% rename from xllm/core/framework/kv_cache/mooncake_weight_transfer.cpp rename to xllm/core/framework/kv_cache_transfer/mooncake_weight_transfer.cpp index ddb70f121..97a5f2904 100644 --- a/xllm/core/framework/kv_cache/mooncake_weight_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/mooncake_weight_transfer.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mooncake_weight_transfer.h" +#include "framework/kv_cache_transfer/mooncake_weight_transfer.h" #include @@ -53,8 +53,9 @@ bool MooncakeWeightTransfer::register_global_xtensor() { std::vector addrs = {global_xtensor.base_vaddr()}; std::vector lens = {global_xtensor.total_size()}; - if (!mooncake_te_->register_memory( - addrs, lens, static_cast(global_xtensor.page_size()))) { + std::vector buf_bytes = { + static_cast(global_xtensor.page_size())}; + if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) { LOG(ERROR) << "register GlobalXTensor failed"; return false; } diff --git a/xllm/core/framework/kv_cache/mooncake_weight_transfer.h b/xllm/core/framework/kv_cache_transfer/mooncake_weight_transfer.h similarity index 96% rename from xllm/core/framework/kv_cache/mooncake_weight_transfer.h rename to xllm/core/framework/kv_cache_transfer/mooncake_weight_transfer.h index b4e1b73e2..2cb6e5fba 100644 --- a/xllm/core/framework/kv_cache/mooncake_weight_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/mooncake_weight_transfer.h @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "mooncake_transfer_engine.h" +#include "framework/kv_cache_transfer/mooncake_transfer_engine.h" namespace xllm { diff --git a/xllm/core/framework/kv_cache/spec_kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp similarity index 99% rename from xllm/core/framework/kv_cache/spec_kv_cache_transfer.cpp rename to xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp index 1014216de..e42166dec 100644 --- a/xllm/core/framework/kv_cache/spec_kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "spec_kv_cache_transfer.h" +#include "framework/kv_cache_transfer/spec_kv_cache_transfer.h" #include #include diff --git a/xllm/core/framework/kv_cache/spec_kv_cache_transfer.h b/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.h similarity index 98% rename from xllm/core/framework/kv_cache/spec_kv_cache_transfer.h rename to xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.h index 191b4e0f5..5a63c0221 100644 --- a/xllm/core/framework/kv_cache/spec_kv_cache_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.h @@ -15,8 +15,8 @@ limitations under the License. #pragma once +#include "framework/kv_cache_transfer/llm_data_dist_transfer.h" #include "framework/parallel_state/parallel_args.h" -#include "llm_data_dist_transfer.h" namespace xllm { diff --git a/xllm/core/framework/model/CMakeLists.txt b/xllm/core/framework/model/CMakeLists.txt index e5f88bff6..c2d1a0c48 100644 --- a/xllm/core/framework/model/CMakeLists.txt +++ b/xllm/core/framework/model/CMakeLists.txt @@ -8,7 +8,6 @@ cc_library( causal_lm.h causal_vlm.h dit_model.h - embedding_vlm.h mm_embedding_vlm.h model_args.h npu_cp_ep_padding.h diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index 10e794ae2..f418c7846 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ b/xllm/core/framework/model/causal_vlm.h @@ -61,6 +61,15 @@ class CausalVLMImpl : public CausalVLM { return model_->forward(tokens, positions, kv_caches, parameters); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) override { + if constexpr (detail::has_pooler::value) { + return model_->pooler(hidden_states, seleted_idxes); + } else { + return CausalLM::pooler(hidden_states, seleted_idxes); + } + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) override { return model_->logits(hidden_states, seleted_idxes); diff --git a/xllm/core/framework/model/embedding_vlm.h b/xllm/core/framework/model/embedding_vlm.h deleted file mode 100644 index 4e7bae651..000000000 --- a/xllm/core/framework/model/embedding_vlm.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include -#include - -#include - -#include "causal_vlm.h" -#include "core/framework/kv_cache/kv_cache.h" -#include "core/framework/quant_args.h" -#include "core/framework/state_dict/state_dict.h" -#include "model_args.h" -#include "model_input_params.h" - -namespace xllm { - -class EmbeddingVLM : public CausalVLM { - public: - ~EmbeddingVLM() override = default; - - // hidden_states: [num_tokens, hidden_size] - // seleted_idxes: [num_tokens] - // returns: [num_seqs, hidden_size] - virtual torch::Tensor pooler(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) = 0; -}; - -template -class EmbeddingVLMImpl : public EmbeddingVLM { - public: - EmbeddingVLMImpl(Model model, const torch::TensorOptions& options) - : model_(std::move(model)), options_(options) {} - - torch::Tensor logits(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) override { - return model_->logits(hidden_states, seleted_idxes); - } - - torch::Tensor pooler(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) override { - return model_->pooler(hidden_states, seleted_idxes); - } - - void load_model(std::unique_ptr loader) override { - model_->load_model(std::move(loader)); - } - - torch::Device device() const override { return options_.device(); } - - const torch::TensorOptions& options() const override { return options_; } - - private: - Model model_; - - torch::TensorOptions options_; -}; - -} // namespace xllm diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h index 57ca97bee..080977002 100644 --- a/xllm/core/framework/model/model_args.h +++ b/xllm/core/framework/model/model_args.h @@ -239,6 +239,9 @@ struct ModelArgs { // Vision model's mm_projection_dim PROPERTY(int64_t, mm_projection_dim) = 0; + // Vision model's mm_projector_hidden_size + PROPERTY(int64_t, mm_projector_hidden_size) = 0; + PROPERTY(int64_t, mm_spatial_merge_size) = 0; PROPERTY(int64_t, mm_spatial_patch_size) = 0; @@ -422,6 +425,23 @@ struct ModelArgs { PROPERTY(float, max_shift) = 0; PROPERTY(int64_t, base_image_seq_len) = 0; PROPERTY(int64_t, max_image_seq_len) = 0; + PROPERTY(float, shift_terminal) = 0; + + // qwen_image_edit_2509 vae related args + PROPERTY(int64_t, base_dim) = 0; + PROPERTY(int64_t, z_dim) = 0; + PROPERTY(std::vector, dim_mult) = {}; + PROPERTY(std::vector, attn_scales) = {}; + PROPERTY(std::vector, temperal_downsample) = {}; + PROPERTY(int64_t, num_res_blocks) = 0; + PROPERTY(double, dropout) = 0; + PROPERTY(std::vector, latents_mean) = {}; + PROPERTY(std::vector, latents_std) = {}; + + // qwen_image_edit_2511 dit related args + PROPERTY(bool, zero_cond_t) = false; + PROPERTY(bool, use_additional_t_cond) = false; + PROPERTY(bool, use_layer3d_rope) = false; }; // Qwen hybrid models may describe full-attention layers explicitly via diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 85c945df6..4aeb8ea34 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -26,11 +26,15 @@ limitations under the License. #if defined(USE_NPU) #include "platform/npu/npu_layer_synchronizer.h" #endif +#if defined(USE_MLU) +#include "platform/mlu/mlu_layer_synchronizer.h" +#endif #include "framework/batch/batch_forward_type.h" #include "framework/request/mm_batch_data.h" #include "npu_cp_ep_padding.h" #include "npu_cp_prepare.h" #include "npu_dp_ep_padding.h" +#include "runtime/dit_forward_params.h" #include "util/hash_util.h" #include "util/tensor_helper.h" @@ -180,6 +184,8 @@ struct LlmRecMultiRoundParams { result.full_k_caches.clear(); result.full_v_caches.clear(); + result.full_k_caches.reserve(full_k_caches.size()); + result.full_v_caches.reserve(full_v_caches.size()); for (const auto& t : full_k_caches) { result.full_k_caches.push_back(safe_to(t, device)); } @@ -190,6 +196,10 @@ struct LlmRecMultiRoundParams { result.unshared_v_caches.clear(); result.shared_k_caches.clear(); result.shared_v_caches.clear(); + result.unshared_k_caches.reserve(unshared_k_caches.size()); + result.unshared_v_caches.reserve(unshared_v_caches.size()); + result.shared_k_caches.reserve(shared_k_caches.size()); + result.shared_v_caches.reserve(shared_v_caches.size()); for (const auto& t : unshared_k_caches) { result.unshared_k_caches.push_back(safe_to(t, device)); } @@ -240,6 +250,8 @@ struct LlmRecMultiRoundParams { } result.decode_positions_tensor_list.clear(); + result.decode_positions_tensor_list.reserve( + decode_positions_tensor_list.size()); for (const auto& t : decode_positions_tensor_list) { result.decode_positions_tensor_list.push_back(safe_to(t, device)); } @@ -334,7 +346,9 @@ struct ModelInputParams { params.kv_seq_lens = safe_to(kv_seq_lens, device, true); params.q_seq_lens = safe_to(q_seq_lens, device, true); +#if !defined(USE_CUDA) params.q_cu_seq_lens = safe_to(q_cu_seq_lens, device, true); +#endif params.new_cache_slots = safe_to(new_cache_slots, device, true); params.block_tables = safe_to(block_tables, device, true); @@ -376,7 +390,7 @@ struct ModelInputParams { params.ring_cur_seqlen_host = ring_cur_seqlen_host; params.ring_cache_seqlen = safe_to(ring_cache_seqlen, device); params.ring_cache_seqlen_host = ring_cache_seqlen_host; -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) params.layer_synchronizer = layer_synchronizer; #endif params.expert_load_data = expert_load_data; @@ -406,6 +420,8 @@ struct ModelInputParams { params.batch_id = batch_id; + params.dit_forward_input = dit_forward_input.to(device); + // rec_params device conversion for both OneRec and LLM-Rec variants if (const auto* onerec = onerec_params()) { params.rec_params = onerec->to(device); @@ -464,6 +480,20 @@ struct ModelInputParams { return false; } } +#else + (void)layer_idx; +#endif + return true; + } + + bool record_layer(uint32_t layer_idx, const torch::Device& device) const { +#if defined(USE_MLU) + if (layer_synchronizer != nullptr) { + return layer_synchronizer->record_current(layer_idx, device.index()); + } +#else + (void)layer_idx; + (void)device; #endif return true; } @@ -481,7 +511,7 @@ struct ModelInputParams { int32_t kv_max_seq_len = 0; int32_t q_max_seq_len = 0; - uint64_t batch_id; + uint64_t batch_id = 0; torch::Tensor q_seq_lens; torch::Tensor kv_seq_lens; @@ -552,7 +582,9 @@ struct ModelInputParams { // visual pos mask for Qwen3-VL mutable torch::Tensor visual_pos_masks; -#if defined(USE_NPU) +#if defined(USE_MLU) + std::shared_ptr layer_synchronizer = nullptr; +#elif defined(USE_NPU) std::shared_ptr layer_synchronizer = nullptr; uint32_t layers_per_bacth_copy = std::numeric_limits::max(); std::shared_ptr layer_wise_load_synchronizer = @@ -577,6 +609,9 @@ struct ModelInputParams { RecModelInputParams rec_params; + // dit input data + DiTForwardInput dit_forward_input; + const OneRecModelInputParams* onerec_params() const { return std::get_if(&rec_params); } diff --git a/xllm/core/framework/model/model_traits.h b/xllm/core/framework/model/model_traits.h index 9e351819f..a9ac78b28 100644 --- a/xllm/core/framework/model/model_traits.h +++ b/xllm/core/framework/model/model_traits.h @@ -104,6 +104,16 @@ struct has_reload_model_weights_from_device< decltype(std::declval()->reload_model_weights_from_device())>> : std::true_type {}; +template +struct has_pooler : std::false_type {}; + +template +struct has_pooler()->pooler( + std::declval(), + std::declval()))>> + : std::true_type {}; + #if defined(USE_NPU) template struct has_get_npu_lm_head : std::false_type {}; diff --git a/xllm/core/framework/model_context.cpp b/xllm/core/framework/model_context.cpp index 09a66473c..6e7c46779 100644 --- a/xllm/core/framework/model_context.cpp +++ b/xllm/core/framework/model_context.cpp @@ -17,7 +17,9 @@ limitations under the License. #include +#include "common/global_flags.h" #include "platform/device.h" +#include "util/env_var.h" #if defined(USE_NPU) #ifdef TORCH_HIGHER_THAN_PTA6 // #include @@ -30,6 +32,21 @@ limitations under the License. #endif namespace xllm { + +namespace { + +bool should_enable_async_tiling_copy_stream() { + // ATB copy-stream teardown is not reversible for the same context on the + // current CANN/PTA stack, so contexts that may enter graph capture must not + // pre-create the helper stream. + if (FLAGS_enable_graph) { + return false; + } + return util::get_bool_env("ATB_USE_TILING_COPY_STREAM", false); +} + +} // namespace + ModelContext::ModelContext(const ParallelArgs& input_parallel_args, const ModelArgs& model_args, const QuantArgs& quant_args, @@ -44,7 +61,9 @@ ModelContext::ModelContext(const ParallelArgs& input_parallel_args, atb::CreateContext(&context_); void* stream = c10_npu::getCurrentNPUStream(device_id).stream(); context_->SetExecuteStream(stream); - context_->SetAsyncTilingCopyStatus(true); + if (should_enable_async_tiling_copy_stream()) { + context_->SetAsyncTilingCopyStatus(true); + } atb_workspace_ = std::make_shared(tensor_options.device()); #endif derive_optimization_config(); diff --git a/xllm/core/framework/parallel_state/CMakeLists.txt b/xllm/core/framework/parallel_state/CMakeLists.txt index 252bb95e9..2defafbdf 100644 --- a/xllm/core/framework/parallel_state/CMakeLists.txt +++ b/xllm/core/framework/parallel_state/CMakeLists.txt @@ -6,6 +6,8 @@ cc_library( parallel_state HDRS mapping_npu.h + dit_mapping.h + rank_generator.h parallel_args.h parallel_state.h process_group.h @@ -14,14 +16,18 @@ cc_library( $<$:musa_process_group.h> $<$:cuda_process_group.h> $<$:ilu_process_group.h> + collective_communicator_base.h collective_communicator.h + dit_collective_communicator.h SRCS mapping_npu.cpp + dit_mapping.cpp parallel_state.cpp parallel_state_async.cpp process_group.cpp $<$:npu_process_group.cpp> collective_communicator.cpp + dit_collective_communicator.cpp DEPS :common torch diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp index ac9faf6b1..06d9a922b 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/collective_communicator.cpp @@ -41,7 +41,8 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank, int world_size, int dp_size, int ep_size, - int cp_size) { + int cp_size) + : CollectiveCommunicatorBase(global_rank, world_size) { #if defined(USE_NPU) // create hccl process group with hccl_root_info // std::vector unique_ids; diff --git a/xllm/core/framework/parallel_state/collective_communicator.h b/xllm/core/framework/parallel_state/collective_communicator.h index 72d2cde13..1f2f68d78 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.h +++ b/xllm/core/framework/parallel_state/collective_communicator.h @@ -15,15 +15,11 @@ limitations under the License. #pragma once -#include -#include - -#include "parallel_args.h" -#include "process_group.h" +#include "collective_communicator_base.h" namespace xllm { -class CollectiveCommunicator { +class CollectiveCommunicator : public CollectiveCommunicatorBase { public: CollectiveCommunicator(int global_rank, int world_size, @@ -33,10 +29,10 @@ class CollectiveCommunicator { ~CollectiveCommunicator() = default; void create_process_groups(const std::string& master_addr, - const torch::Device& device); + const torch::Device& device) override; // init communicator and return parallel args. - const ParallelArgs* parallel_args(); + const ParallelArgs* parallel_args() override; private: std::unique_ptr parallel_args_; diff --git a/xllm/core/framework/parallel_state/collective_communicator_base.h b/xllm/core/framework/parallel_state/collective_communicator_base.h new file mode 100644 index 000000000..31628289b --- /dev/null +++ b/xllm/core/framework/parallel_state/collective_communicator_base.h @@ -0,0 +1,48 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include + +#include "parallel_args.h" +#include "process_group.h" + +namespace xllm { + +class CollectiveCommunicatorBase { + public: + CollectiveCommunicatorBase(int global_rank, int world_size) + : global_rank_(global_rank), world_size_(world_size) {} + + virtual ~CollectiveCommunicatorBase() = default; + + virtual void create_process_groups(const std::string& master_addr, + const torch::Device& device) = 0; + + virtual const ParallelArgs* parallel_args() = 0; + + int get_global_rank() const { return global_rank_; } + int get_world_size() const { return world_size_; } + + protected: + int global_rank_; + int world_size_; +}; + +} // namespace xllm diff --git a/xllm/core/framework/parallel_state/dit_collective_communicator.cpp b/xllm/core/framework/parallel_state/dit_collective_communicator.cpp new file mode 100644 index 000000000..4f3304f57 --- /dev/null +++ b/xllm/core/framework/parallel_state/dit_collective_communicator.cpp @@ -0,0 +1,182 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "dit_collective_communicator.h" + +#include "mapping_npu.h" + +#if defined(USE_NPU) +#include "npu_process_group.h" +#elif defined(USE_MLU) +#include "mlu_process_group.h" +#elif defined(USE_CUDA) +#include "cuda_process_group.h" +#elif defined(USE_ILU) +#include "ilu_process_group.h" +#elif defined(USE_MUSA) +#include "musa_process_group.h" +#endif +#include "common/global_flags.h" +#include "parallel_args.h" +#include "platform/device.h" +#include "process_group.h" +#include "util/net.h" +namespace xllm { + +DiTCollectiveCommunicator::DiTCollectiveCommunicator(int32_t global_rank, + int32_t world_size, + int32_t dit_dp_size, + int32_t dit_tp_size, + int32_t dit_sp_size, + int32_t dit_cfg_size) + : CollectiveCommunicatorBase(global_rank, world_size) { + parallel_args_ = std::make_unique(global_rank, + world_size, + dit_dp_size, + dit_tp_size, + dit_sp_size, + dit_cfg_size, + /*process_group=*/nullptr); + DiTMapping::Options dit_mapping_options; + dit_mapping_options.dit_tp_size(dit_tp_size) + .dit_sp_size(dit_sp_size) + .dit_cfg_size(dit_cfg_size) + .dit_dp_size(dit_dp_size); + dit_mapping_ = std::make_unique( + world_size, global_rank, dit_mapping_options); +} + +void DiTCollectiveCommunicator::create_process_groups( + const std::string& master_addr, + const torch::Device& device) { + Device device_(device); + device_.set_device(); + std::string host; + int32_t port; + net::parse_host_port_from_addr(master_addr, host, port); + + int32_t global_rank = parallel_args_->rank(); + int32_t world_size = parallel_args_->world_size(); + int32_t dp_size = parallel_args_->dp_size(); + int32_t tp_size = parallel_args_->tp_size(); + int32_t sp_size = parallel_args_->sp_size(); + int32_t cfg_size = parallel_args_->cfg_size(); + + process_group_ = create_process_group(global_rank, + world_size, + world_size, + ++port, + false, + host, + "world_group", + device); + + parallel_args_->process_group_ = process_group_.get(); + + if (tp_size > 1 && dit_mapping_) { + auto tp_parallel_info = dit_mapping_->get_parallel_info("tp"); + auto group_id = tp_parallel_info.current_group_id(); + auto num_group = tp_parallel_info.num_group(); + auto local_rank = tp_parallel_info.rank(); + auto& rank_per_group = tp_parallel_info.rank_per_group()[group_id]; + int port_offset = group_id + 1; +#if defined(USE_NPU) || defined(USE_MLU) + dit_tp_group_ = create_process_group(global_rank, + local_rank, + rank_per_group, + world_size, + tp_size, + port + port_offset, + host, + "tp_group", + device); +#endif + parallel_args_->dit_tp_group_ = dit_tp_group_.get(); + port += num_group; + } + + if (sp_size > 1 && dit_mapping_) { + auto sp_parallel_info = dit_mapping_->get_parallel_info("sp"); + auto group_id = sp_parallel_info.current_group_id(); + auto num_group = sp_parallel_info.num_group(); + auto local_rank = sp_parallel_info.rank(); + auto& rank_per_group = sp_parallel_info.rank_per_group()[group_id]; + int port_offset = group_id + 1; +#if defined(USE_NPU) || defined(USE_MLU) + dit_sp_group_ = create_process_group(global_rank, + local_rank, + rank_per_group, + world_size, + sp_size, + port + port_offset, + host, + "sp_group", + device); +#endif + parallel_args_->dit_sp_group_ = dit_sp_group_.get(); + port += num_group; + } + + if (cfg_size > 1 && dit_mapping_) { + auto cfg_parallel_info = dit_mapping_->get_parallel_info("cfg"); + auto group_id = cfg_parallel_info.current_group_id(); + auto num_group = cfg_parallel_info.num_group(); + auto local_rank = cfg_parallel_info.rank(); + auto& rank_per_group = cfg_parallel_info.rank_per_group()[group_id]; + int port_offset = group_id + 1; +#if defined(USE_NPU) || defined(USE_MLU) + dit_cfg_group_ = create_process_group(global_rank, + local_rank, + rank_per_group, + world_size, + cfg_size, + port + port_offset, + host, + "cfg_group", + device); +#endif + parallel_args_->dit_cfg_group_ = dit_cfg_group_.get(); + port += num_group; + } + + if (dp_size > 1 && dit_mapping_) { + auto dp_parallel_info = dit_mapping_->get_parallel_info("dp"); + auto group_id = dp_parallel_info.current_group_id(); + auto num_group = dp_parallel_info.num_group(); + auto local_rank = dp_parallel_info.rank(); + auto& rank_per_group = dp_parallel_info.rank_per_group()[group_id]; + int port_offset = group_id + 1; +#if defined(USE_NPU) || defined(USE_MLU) + dit_dp_group_ = create_process_group(global_rank, + local_rank, + rank_per_group, + world_size, + dp_size, + port + port_offset, + host, + "dp_group", + device); +#endif + parallel_args_->dit_dp_group_ = dit_dp_group_.get(); + port += num_group; + } +} + +const ParallelArgs* DiTCollectiveCommunicator::parallel_args() { + // TODO: init communicator + return parallel_args_.get(); +} + +} // namespace xllm diff --git a/xllm/core/framework/parallel_state/dit_collective_communicator.h b/xllm/core/framework/parallel_state/dit_collective_communicator.h new file mode 100644 index 000000000..0acb68431 --- /dev/null +++ b/xllm/core/framework/parallel_state/dit_collective_communicator.h @@ -0,0 +1,50 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "collective_communicator_base.h" +#include "dit_mapping.h" + +namespace xllm { + +class DiTCollectiveCommunicator : public CollectiveCommunicatorBase { + public: + DiTCollectiveCommunicator(int32_t global_rank, + int32_t world_size, + int32_t dit_dp_size, + int32_t dit_tp_size, + int32_t dit_sp_size, + int32_t dit_cfg_size); + + ~DiTCollectiveCommunicator() = default; + + void create_process_groups(const std::string& master_addr, + const torch::Device& device) override; + + // init communicator and return parallel args. + const ParallelArgs* parallel_args() override; + + private: + std::unique_ptr dit_mapping_{nullptr}; + std::unique_ptr parallel_args_; + std::unique_ptr process_group_; + std::unique_ptr dit_tp_group_; + std::unique_ptr dit_sp_group_; + std::unique_ptr dit_dp_group_; + std::unique_ptr dit_cfg_group_; +}; + +} // namespace xllm diff --git a/xllm/core/framework/parallel_state/dit_mapping.cpp b/xllm/core/framework/parallel_state/dit_mapping.cpp new file mode 100644 index 000000000..96608d2c1 --- /dev/null +++ b/xllm/core/framework/parallel_state/dit_mapping.cpp @@ -0,0 +1,145 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "dit_mapping.h" + +#include + +namespace xllm { + +DiTMapping::DiTMapping(const int32_t world_size, + const int32_t rank, + const Options& options) + : rank_(rank), options_(options), world_size_(world_size) { + tp_.backend("hccl"); + sp_.backend("hccl"); + cfg_.backend("hccl"); + dp_.backend("hccl"); + parse_parallel_info(); + validate(); + rank_generator_ = + std::make_unique(tp_.group_size(), + sp_.group_size(), + cfg_.group_size(), + dp_.group_size(), + /*group_order=*/"tp-sp-cfg-dp"); + set_group_by_type(tp_, "tp"); + set_group_by_type(sp_, "sp"); + set_group_by_type(cfg_, "cfg"); + set_group_by_type(dp_, "dp"); +} + +void DiTMapping::parse_parallel_info() { + if (options_.dit_tp_size() != -1) { + tp_.group_size(options_.dit_tp_size()); + } + if (options_.dit_sp_size() != -1) { + sp_.group_size(options_.dit_sp_size()); + } + if (options_.dit_cfg_size() != -1) { + cfg_.group_size(options_.dit_cfg_size()); + } + if (options_.dit_dp_size() != -1) { + dp_.group_size(options_.dit_dp_size()); + } +} + +void DiTMapping::validate() { + CHECK(cfg_.group_size() * tp_.group_size() * sp_.group_size() * + dp_.group_size() == + world_size_) + << "World size must equal to cfg_size * tp_size * sp_size. " + "cfg_size is " + + std::to_string(cfg_.group_size()) + + ". " + "tp_size is " + + std::to_string(tp_.group_size()) + + ". " + "sp_size is " + + std::to_string(sp_.group_size()) + + ". " + "dp_size is " + + std::to_string(dp_.group_size()) + + ". " + "world_size is " + + std::to_string(world_size_) + + ". " + "Please check `cfg`, `tp`, `sp`, `dp` and `world_size`."; + + CHECK(cfg_.group_size() <= 2) << "cfg_size must less than 2 " + "cfg_size is " + + std::to_string(cfg_.group_size()) + + ". Please check `cfg` ."; +} + +void DiTMapping::set_group_by_type(ParallelInfo& parallel_info, + const std::string& group_type) { + auto rank_per_group = rank_generator_->get_ranks(group_type); + parallel_info.rank_per_group(rank_per_group); + auto group_size = rank_per_group[0].size(); + parallel_info.num_group(world_size_ / group_size); + auto [current_group_id, local_rank] = + get_current_group_id(rank_per_group, rank_); + CHECK(current_group_id >= 0 && local_rank >= 0) + << "Failed to get current group id : " << current_group_id + << " local_rank " << local_rank; + parallel_info.current_group_id(current_group_id); + parallel_info.rank(local_rank); +} + +std::tuple DiTMapping::get_current_group_id( + const std::vector>& rank_per_group, + int32_t target_rank_id) { + for (int32_t idx = 0; idx < rank_per_group.size(); ++idx) { + const auto& group = rank_per_group[idx]; + auto it = std::find(group.begin(), group.end(), target_rank_id); + if (it != group.end()) { + return std::make_tuple(idx, std::distance(group.begin(), it)); + } + } + return std::make_tuple(-1, -1); +} + +const ParallelInfo& DiTMapping::get_parallel_info( + const std::string& group_type) const { + if (group_type == "tp") { + return tp_; + } else if (group_type == "sp") { + return sp_; + } else if (group_type == "cfg") { + return cfg_; + } else if (group_type == "dp") { + return dp_; + } else { + LOG(ERROR) << "get unexpected group_type: " << group_type; + } +} + +nlohmann::json DiTMapping::to_json() { + nlohmann::json data; + + data["SpSize"] = options_.dit_sp_size(); + data["TpSize"] = options_.dit_tp_size(); + data["CfgSize"] = options_.dit_cfg_size(); + data["worldSize"] = world_size_; + data["rank"] = rank_; + data["sp"] = sp_.to_json(); + data["tp"] = tp_.to_json(); + data["cfg"] = cfg_.to_json(); + data["dp"] = dp_.to_json(); + return data; +} + +} // namespace xllm diff --git a/xllm/core/framework/parallel_state/dit_mapping.h b/xllm/core/framework/parallel_state/dit_mapping.h new file mode 100644 index 000000000..02b0e1af8 --- /dev/null +++ b/xllm/core/framework/parallel_state/dit_mapping.h @@ -0,0 +1,72 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "core/common/macros.h" +#include "core/util/json_reader.h" +#include "mapping_npu.h" +#include "rank_generator.h" +namespace xllm { + +class DiTMapping final { + public: + struct Options { + // cfg size + PROPERTY(int32_t, dit_cfg_size) = -1; + // tp size + PROPERTY(int32_t, dit_tp_size) = -1; + // sp size + PROPERTY(int32_t, dit_sp_size) = -1; + // dp size + PROPERTY(int32_t, dit_dp_size) = -1; + }; + + DiTMapping(const int32_t world_size, + const int32_t rank, + const Options& options); + + int32_t get_num_nodes(); + + void parse_parallel_info(); + + void validate(); + + void set_group_by_type(ParallelInfo& parallel_info, + const std::string& group_type); + + std::tuple get_current_group_id( + const std::vector>& rank_per_group, + int target_rank_id); + + const ParallelInfo& get_parallel_info(const std::string& group_type) const; + + nlohmann::json to_json(); + + private: + Options options_; + int32_t num_nodes_; + int32_t world_size_ = 0; + int32_t rank_ = 0; + int32_t local_world_size_ = 0; + ParallelInfo sp_ = ParallelInfo(); + ParallelInfo tp_ = ParallelInfo(); + ParallelInfo cfg_ = ParallelInfo(); + ParallelInfo dp_ = ParallelInfo(); + std::unique_ptr rank_generator_{nullptr}; +}; +} // namespace xllm diff --git a/xllm/core/framework/parallel_state/mapping_npu.cpp b/xllm/core/framework/parallel_state/mapping_npu.cpp index b876a3e24..c7873f3cd 100644 --- a/xllm/core/framework/parallel_state/mapping_npu.cpp +++ b/xllm/core/framework/parallel_state/mapping_npu.cpp @@ -42,8 +42,8 @@ MappingNPU::MappingNPU(const std::string rank_table_file, num_nodes_ = get_num_nodes(); world_size_ = world_size; local_world_size_ = world_size / num_nodes_; - attn_o_proj_tp_.backend("lccl"); - attn_inner_sp_.backend("lccl"); + attn_o_proj_tp_.backend("hccl"); + attn_inner_sp_.backend("hccl"); parse_parallel_info(); validate(); get_tp_group(word_embed_tp_); diff --git a/xllm/core/framework/parallel_state/mlu_process_group.h b/xllm/core/framework/parallel_state/mlu_process_group.h index 6caf088fe..9afba2273 100644 --- a/xllm/core/framework/parallel_state/mlu_process_group.h +++ b/xllm/core/framework/parallel_state/mlu_process_group.h @@ -55,6 +55,30 @@ class ProcessGroupImpl : public ProcessGroup { pg_ = std::make_unique( store, rank, rank_size, pg_options); } + + ProcessGroupImpl(int32_t global_rank, + int32_t local_rank, + const std::vector& group_ranks, + int32_t world_size, + int32_t rank_size, + int32_t port, + const std::string& host, + const std::string& group_name, + const torch::Device& device) + : ProcessGroup(global_rank, world_size, device) { + c10::intrusive_ptr pg_options = + torch_mlu::ProcessGroupCNCL::Options::create(); + pg_options->group_name = group_name; + std::vector ranks_unsigned; + ranks_unsigned.reserve(group_ranks.size()); + for (int32_t rank : group_ranks) { + ranks_unsigned.push_back(static_cast(rank)); + } + pg_options->global_ranks_in_group = ranks_unsigned; + auto store = create_tcp_store(host, port, local_rank); + pg_ = std::make_unique( + store, local_rank, rank_size, pg_options); + } }; } // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/parallel_state/npu_process_group.cpp b/xllm/core/framework/parallel_state/npu_process_group.cpp index b5d6405fa..d9dd2ead3 100644 --- a/xllm/core/framework/parallel_state/npu_process_group.cpp +++ b/xllm/core/framework/parallel_state/npu_process_group.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "npu_process_group.h" #include @@ -98,12 +97,55 @@ ProcessGroupImpl::ProcessGroupImpl(int32_t global_rank, hccl_pg_options->global_ranks_in_group = uint32_ranks; rank = local_rank; } - auto store = create_tcp_store(host, port, rank); pg_ = std::make_unique( store, rank, rank_size, hccl_pg_options); } +ProcessGroupImpl::ProcessGroupImpl(int32_t global_rank, + int32_t local_rank, + const std::vector& group_ranks, + int32_t world_size, + int32_t rank_size, + int32_t port, + const std::string& host, + const std::string& group_name, + const torch::Device& device) + : ProcessGroup(global_rank, world_size, device), + comm_stream_(c10_npu::getNPUStreamFromPool(device.index())) { + c10::intrusive_ptr hccl_pg_options = + c10d_npu::ProcessGroupHCCL::Options::create(); + hccl_pg_options->group_id = group_name; + if (world_size != rank_size) { + std::vector uint32_ranks; + for (auto rank : group_ranks) { + uint32_ranks.push_back(static_cast(rank)); + } + hccl_pg_options->global_ranks_in_group = uint32_ranks; + } + + if (FLAGS_dit_debug_print) { + std::stringstream ranks_ss; + ranks_ss << "Group : [" << group_ranks[0]; + for (size_t i = 1; i < group_ranks.size(); i++) { + ranks_ss << ", " << group_ranks[i]; + } + ranks_ss << "]" << std::endl; + + LOG(INFO) << "Creating HccLProcessGroup for " << group_name + << " group, with global rank " << global_rank << ", local rank" + << local_rank << ", with port " << host << ":" << port + << ", rank_size is " << rank_size << ", world_size is " + << world_size + << ", the following ranks should share the same port, " + << ranks_ss.str(); + } + + auto store = create_tcp_store(host, port, local_rank); + pg_ = std::make_unique( + store, local_rank, rank_size, hccl_pg_options); +} + // Destructor. ProcessGroupImpl::~ProcessGroupImpl() { if (pg_) { @@ -122,93 +164,4 @@ ProcessGroupImpl::ProcessGroupImpl(int rank, comm_(comm), comm_stream_(c10_npu::getNPUStreamFromPool(device.index())) {} -void ProcessGroupImpl::allgather(const torch::Tensor& input, - std::vector& outputs) { - CHECK_EQ(input.device(), device()) - << "input should be on the same device as the process group"; - CHECK_EQ(outputs.size(), world_size()) - << "outputs should have the same size as world_size"; - check_input(input); - torch::DeviceGuard device_guard(device()); - - if (pg_) { - std::vector input_tensors = {input}; - std::vector> output_tensors = {outputs}; - pg_->allgather(output_tensors, input_tensors)->wait(); - return; - } - CHECK(comm_ != nullptr) << "HCCL comm is not initialized."; - - torch::Tensor flattened_output = flatten_for_scatter_gather(outputs); - - const auto count = input.numel(); - const auto data_type = to_hccl_data_type(input); - - auto compute_stream = c10_npu::getCurrentNPUStream(); - - auto ready = std::make_shared(); - ready->record(compute_stream); - ready->block(comm_stream_); - - c10_npu::NPUCachingAllocator::recordStream(input.storage().data_ptr(), - comm_stream_); - c10_npu::NPUCachingAllocator::recordStream( - flattened_output.storage().data_ptr(), comm_stream_); - - HCCLCHECK(HcclAllGather( - /*sendbuff=*/input.data_ptr(), - /*recvbuff=*/flattened_output.data_ptr(), - /*sendcount=*/count, - /*datatype=*/data_type, - /*comm=*/comm_, - /*stream=*/comm_stream_.stream())); - - auto done = std::make_shared(); - done->record(comm_stream_); - done->block(compute_stream); - - for (int i = 0; i < static_cast(outputs.size()); ++i) { - outputs[i].copy_(flattened_output[i], /*non_blocking=*/true); - } -} - -void ProcessGroupImpl::allreduce(torch::Tensor& input) { - CHECK_EQ(input.device(), device()) - << "input should be on the same device as the process group"; - check_input(input); - torch::DeviceGuard device_guard(device()); - - if (pg_) { - std::vector input_tensors = {input}; - pg_->allreduce(input_tensors)->wait(); - return; - } - CHECK(comm_ != nullptr) << "HCCL comm is not initialized."; - - const auto count = input.numel(); - const auto data_type = to_hccl_data_type(input); - - auto compute_stream = c10_npu::getCurrentNPUStream(); - - auto ready = std::make_shared(); - ready->record(compute_stream); - ready->block(comm_stream_); - - c10_npu::NPUCachingAllocator::recordStream(input.storage().data_ptr(), - comm_stream_); - - HCCLCHECK(HcclAllReduce( - /*sendbuff=*/input.data_ptr(), - /*recvbuff=*/input.data_ptr(), - /*count=*/count, - /*datatype=*/data_type, - /*op=*/HCCL_REDUCE_SUM, - /*comm=*/comm_, - /*stream=*/comm_stream_.stream())); - - auto done = std::make_shared(); - done->record(comm_stream_); - done->block(compute_stream); -} - } // namespace xllm diff --git a/xllm/core/framework/parallel_state/npu_process_group.h b/xllm/core/framework/parallel_state/npu_process_group.h index e806c29ff..3f9fb3bec 100644 --- a/xllm/core/framework/parallel_state/npu_process_group.h +++ b/xllm/core/framework/parallel_state/npu_process_group.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "core/common/global_flags.h" #include "hccl/hccl.h" #include "process_group.h" @@ -41,14 +42,19 @@ class ProcessGroupImpl : public ProcessGroup { const std::string& group_name, const torch::Device& device); + ProcessGroupImpl(int32_t global_rank, + int32_t local_rank, + const std::vector& group_ranks, + int32_t world_size, + int32_t rank_size, + int32_t port, + const std::string& host, + const std::string& group_name, + const torch::Device& device); + // Destructor. ~ProcessGroupImpl() override; - void allreduce(torch::Tensor& input) override; - - void allgather(const torch::Tensor& input, - std::vector& outputs) override; - private: HcclComm comm_ = nullptr; c10_npu::NPUStream comm_stream_; @@ -64,4 +70,4 @@ class ProcessGroupImpl : public ProcessGroup { } \ } while (0) #endif -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/parallel_state/parallel_args.h b/xllm/core/framework/parallel_state/parallel_args.h index 1e3cefaa7..4505a7191 100644 --- a/xllm/core/framework/parallel_state/parallel_args.h +++ b/xllm/core/framework/parallel_state/parallel_args.h @@ -88,6 +88,21 @@ struct ParallelArgs { dp_local_process_group_(dp_local_process_group), dp_size_(dp_size) {} + ParallelArgs(int32_t rank, + int32_t world_size, + int32_t dp_size, + int32_t tp_size, + int32_t sp_size, + int32_t cfg_size, + ProcessGroup* process_group) + : rank_(rank), + world_size_(world_size), + dp_size_(dp_size), + tp_size_(tp_size), + sp_size_(sp_size), + cfg_size_(cfg_size), + process_group_(process_group) {} + // rank of current process PROPERTY(int32_t, rank) = 0; @@ -103,6 +118,15 @@ struct ParallelArgs { // cp size PROPERTY(int32_t, cp_size) = 1; + // tp size + PROPERTY(int32_t, tp_size) = 1; + + // sp size + PROPERTY(int32_t, sp_size) = 1; + + // cfg size + PROPERTY(int32_t, cfg_size) = 1; + // atb hccl mapping json data PROPERTY(nlohmann::json, mapping_data); @@ -130,6 +154,12 @@ struct ParallelArgs { ProcessGroup* sp_group_ = nullptr; ProcessGroup* moe_ep_group_ = nullptr; ProcessGroup* moe_tp_group_ = nullptr; + + // ProcessGroups for DiT models + ProcessGroup* dit_tp_group_ = nullptr; + ProcessGroup* dit_sp_group_ = nullptr; + ProcessGroup* dit_cfg_group_ = nullptr; + ProcessGroup* dit_dp_group_ = nullptr; }; } // namespace xllm diff --git a/xllm/core/framework/parallel_state/parallel_state.cpp b/xllm/core/framework/parallel_state/parallel_state.cpp index be71a670c..b1aa6fd71 100644 --- a/xllm/core/framework/parallel_state/parallel_state.cpp +++ b/xllm/core/framework/parallel_state/parallel_state.cpp @@ -268,6 +268,156 @@ torch::Tensor scatter(torch::Tensor input, return tensor_list[rank]; } +std::function all_to_all_4D(const torch::Tensor& input, + int32_t scatter_idx, + int32_t gather_idx, + bool async_ops, + ProcessGroup* process_group) { + if (!process_group) { + return [input]() { return input; }; + } + const int32_t group_size = process_group->world_size(); + + if (group_size == 1) { + return [input]() { return input; }; + } + + auto rank = process_group->rank(); + + TORCH_CHECK(input.dim() == 4, + "all_to_all_4D: input must be 4D, got dim=", + input.dim()); + auto send_input = input; + + if (scatter_idx == 2 && gather_idx == 1) { + // branch A : from "sequence shard" -> "head shard" + // input: (bs, seqlen / group_size (shard_seqlen), head_num, head_dim) + // output (bs, seqlen, head_num / group_size, head_dim) + auto sizes = send_input.sizes().vec(); + const int64_t bs = sizes[0]; + const int64_t shard_seqlen = sizes[1]; + const int64_t head_num = sizes[2]; + const int64_t head_size = sizes[3]; + const int64_t seqlen = shard_seqlen * group_size; + TORCH_CHECK(head_num % group_size == 0, + "all_to_all_4D(A): head_num must be divisible by group_size"); + const int64_t shard_head_num = head_num / group_size; + + // prepare expected shape for All2All (group_size, shard_seqlen, bs, + // shard_head_num, head_size) + auto input_t = + send_input + .reshape({bs, shard_seqlen, group_size, shard_head_num, head_size}) + .transpose( + 0, + 2) // (group_size, shard_seqlen, bs, shard_head_num, head_size) + .contiguous(); + torch::Tensor output = torch::empty_like(input_t); + std::vector input_split_size = {}; + std::vector output_split_size = {}; + + if (!async_ops) { + process_group->all_to_all_single( + output, input_t, output_split_size, input_split_size, async_ops); + output = output.reshape({seqlen, bs, shard_head_num, head_size}) + .transpose(0, 1) + .contiguous() + .reshape({bs, seqlen, shard_head_num, head_size}); + return [output]() { return output; }; + } else { + c10::intrusive_ptr all2all_work; + process_group->all_to_all_single(output, + input_t, + output_split_size, + input_split_size, + async_ops, + &all2all_work); + return [output, + all2all_work, + bs, + seqlen, + shard_head_num, + head_size]() mutable -> torch::Tensor { + all2all_work->wait(); + auto comm_output = + output.reshape({seqlen, bs, shard_head_num, head_size}) + .transpose(0, 1) + .contiguous() + .reshape({bs, seqlen, shard_head_num, head_size}); + return comm_output; + }; + } + } else if (scatter_idx == 1 && gather_idx == 2) { + // branch B : from "head shard" -> "sequence shard" + // input: (bs, seqlen, head_num / group_size, head_size) + // output (bs, seqlen / group_size, head_num, haed_size) + auto sizes = send_input.sizes().vec(); + const int64_t bs = sizes[0]; + const int64_t seqlen = sizes[1]; + const int64_t shard_head_num = sizes[2]; + const int64_t head_size = sizes[3]; + TORCH_CHECK(seqlen % group_size == 0, + "all_to_all_4D(B): seqlen must be divisible by group_size"); + const int64_t shard_seqlen = seqlen / group_size; + const int64_t head_num = shard_head_num * group_size; + + // prepare expected shape for All2All (group_size, shard_head_num, + // shard_seqlen, bs, head_size) + auto input_t = + send_input + .reshape({bs, group_size, shard_seqlen, shard_head_num, head_size}) + .transpose( + 0, + 3) // (shard_head_num, group_size, shard_seqlen, bs, head_size) + .transpose( + 0, + 1) // (group_size, shard_head_num, shard_seqlen, bs, head_size) + .contiguous(); + torch::Tensor output = torch::empty_like(input_t); + std::vector input_split_size = {}; + std::vector output_split_size = {}; + + if (!async_ops) { + process_group->all_to_all_single(output, + input_t, + output_split_size, + input_split_size, + /*async_op=*/false); + output = output.reshape({head_num, shard_seqlen, bs, head_size}) + .transpose(0, 2) + .contiguous() + .reshape({bs, shard_seqlen, head_num, head_size}); + return [output]() { return output; }; + } else { + c10::intrusive_ptr all2all_work; + process_group->all_to_all_single(output, + input_t, + output_split_size, + input_split_size, + /*async_op=*/true, + &all2all_work); + return [output, + all2all_work, + head_num, + shard_seqlen, + bs, + head_size]() mutable -> torch::Tensor { + all2all_work->wait(); + auto comm_output = + output.reshape({head_num, shard_seqlen, bs, head_size}) + .transpose(0, 2) + .contiguous() + .reshape({bs, shard_seqlen, head_num, head_size}); + return comm_output; + }; + } + } else { + TORCH_CHECK(false, + "all_to_all_4D: only (scatter_idx,gather_idx)=(2,1) or (1,2) " + "are supported"); + } +} + std::vector> create_npu_process_groups( const std::vector& devices) { #if defined(USE_NPU) diff --git a/xllm/core/framework/parallel_state/parallel_state.h b/xllm/core/framework/parallel_state/parallel_state.h index f356c9905..32caf7c8a 100644 --- a/xllm/core/framework/parallel_state/parallel_state.h +++ b/xllm/core/framework/parallel_state/parallel_state.h @@ -72,6 +72,12 @@ torch::Tensor scatter(torch::Tensor input, ProcessGroup* process_group, int dim = -1); +std::function all_to_all_4D(const torch::Tensor& input_, + int32_t scatter_idx, + int32_t gather_idx, + bool is_sync, + ProcessGroup* pg); + // Create a process group where each process has a single device // devices: list of devices to create process groups on. std::vector> create_npu_process_groups( diff --git a/xllm/core/framework/parallel_state/process_group.cpp b/xllm/core/framework/parallel_state/process_group.cpp index eb22f26a1..4b73661d3 100644 --- a/xllm/core/framework/parallel_state/process_group.cpp +++ b/xllm/core/framework/parallel_state/process_group.cpp @@ -156,6 +156,35 @@ void ProcessGroup::reduce_scatter(const torch::Tensor& input, ->wait(); } +void ProcessGroup::all_to_all_single( + torch::Tensor output, + torch::Tensor input, + std::vector output_split_sizes, + std::vector input_split_sizes, + bool async_op, + c10::intrusive_ptr* async_work) { + CHECK(pg_ != nullptr) << "Process group is not initialized."; + CHECK(output.defined()) + << "Output of all_to_all_single function is not defined"; + CHECK(input.defined()) + << "Input of all_to_all_single function is not defined"; + if (input.is_complex()) { + input = torch::view_as_real(input); + } + if (output.is_complex()) { + output = torch::view_as_real(output); + } + + auto opts = c10d::AllToAllOptions(); + auto work = pg_->alltoall_base( + output, input, output_split_sizes, input_split_sizes, opts); + if (async_op) { + *async_work = work; + } else { + work->wait(); + } +} + std::unique_ptr create_process_group( int32_t rank, int32_t world_size, @@ -169,4 +198,32 @@ std::unique_ptr create_process_group( rank, world_size, rank_size, port, trans, host, group_name, device); } +#if defined(USE_NPU) || defined(USE_MLU) +// we only support DiT models onNPU and MLU for now. +// TODO: This function is used by DiT models, since the DiT communication group +// info have already been calculated by rank_generator, we only need to pass the +// info to create the process groups. For any device that want to reuse the +// function and dit process groups, please implement the corresponding +// ProcessGroupImpl construct function. +std::unique_ptr create_process_group( + int32_t global_rank, + int32_t local_rank, + const std::vector& group_ranks, + int32_t world_size, + int32_t rank_size, + int32_t port, + const std::string& host, + const std::string& group_name, + const torch::Device& device) { + return std::make_unique(global_rank, + local_rank, + group_ranks, + world_size, + rank_size, + port, + host, + group_name, + device); +} +#endif } // namespace xllm diff --git a/xllm/core/framework/parallel_state/process_group.h b/xllm/core/framework/parallel_state/process_group.h index 35855e8ad..cd7c47935 100644 --- a/xllm/core/framework/parallel_state/process_group.h +++ b/xllm/core/framework/parallel_state/process_group.h @@ -45,12 +45,16 @@ class ProcessGroup { virtual ~ProcessGroup() = default; int32_t rank() const { - CHECK(pg_ != nullptr) << "Process group is not initialized."; + if (pg_ == nullptr) { + return rank_; + } return pg_->getRank(); } int32_t world_size() const { - CHECK(pg_ != nullptr) << "Process group is not initialized."; + if (pg_ == nullptr) { + return world_size_; + } return pg_->getSize(); } @@ -88,6 +92,14 @@ class ProcessGroup { virtual void reduce_scatter(const torch::Tensor& input, torch::Tensor& output); + virtual void all_to_all_single( + torch::Tensor output, + torch::Tensor input, + std::vector output_split_sizes = {}, + std::vector input_split_sizes = {}, + bool async_op = false, + c10::intrusive_ptr* async_work = nullptr); + private: // rank of current process. int32_t rank_ = 0; @@ -121,4 +133,18 @@ std::unique_ptr create_process_group( const std::string& group_name, const torch::Device& device); +#if defined(USE_NPU) || defined(USE_MLU) +// for DiT models +std::unique_ptr create_process_group( + int32_t global_rank, + int32_t local_rank, + const std::vector& group_ranks, + int32_t world_size, + int32_t rank_size, + int32_t port, + const std::string& host, + const std::string& group_name, + const torch::Device& device); +#endif + } // namespace xllm diff --git a/xllm/core/framework/parallel_state/rank_generator.h b/xllm/core/framework/parallel_state/rank_generator.h new file mode 100644 index 000000000..e3e185cb0 --- /dev/null +++ b/xllm/core/framework/parallel_state/rank_generator.h @@ -0,0 +1,266 @@ +#include + +#include +#include +#include + +#include "core/common/global_flags.h" + +class RankGenerator { + public: + RankGenerator(int32_t tp, + int32_t sp, + int32_t cfg, + int32_t dp, + const std::string& group_order = "tp-sp-cfg-dp", + int32_t rank_offset = 0) + : tp_(tp), sp_(sp), cfg_(cfg), dp_(dp), rank_offset_(rank_offset) { + world_size_ = tp * sp * cfg * dp; + + group_size_map_["tp"] = tp; + group_size_map_["sp"] = sp; + group_size_map_["cfg"] = cfg; + group_size_map_["dp"] = dp; + + auto full_order = group_order; + for (const auto& group_size_pair : group_size_map_) { + const std::string& group_name = group_size_pair.first; + int32_t group_size = group_size_pair.second; + + if (full_order.find(group_name) == std::string::npos) { + if (group_size != 1) { + LOG(FATAL) << "The size of (" << group_name << ") is (" << group_size + << "), but you haven't specified it in order (" + << full_order << ")."; + } else { + full_order = full_order + "-" + group_name; + } + } + } + + group_order_ = full_order; + + auto split = [](const std::string& s, + char delimiter) -> std::vector { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; + }; + + ordered_group_name_ = split(group_order_, '-'); + for (const std::string& token : ordered_group_name_) { + auto it = group_size_map_.find(token); + if (it != group_size_map_.end()) { + ordered_group_size_.push_back(it->second); + } + } + + LOG(INFO) << "RankGenerator initialized with tp=" << tp << ", sp=" << sp + << ", cfg=" << cfg << ", dp=" << dp << ", order=" << group_order_ + << ", world_size=" << world_size_; + + if (FLAGS_dit_debug_print) { + debug_print(); + } + } + + std::vector> get_ranks(const std::string& group_query) { + std::vector mask = get_mask(group_query); + std::vector> ranks = + generate_masked_orthogonal_rank_groups( + world_size_, ordered_group_size_, mask); + if (rank_offset_ > 0) { + for (auto& rank_group : ranks) { + for (size_t i = 0; i < rank_group.size(); i++) { + rank_group[i] += rank_offset_; + } + } + } + + return ranks; + } + + int32_t get_world_size() const { return world_size_; } + const std::string& get_order() const { return group_order_; } + int32_t get_tp() const { return tp_; } + int32_t get_sp() const { return sp_; } + int32_t get_cfg() const { return cfg_; } + int32_t get_dp() const { return dp_; } + + void debug_print() { + print_ranks("cfg"); + print_ranks("tp"); + print_ranks("sp"); + print_ranks("dp"); + } + + void print_ranks(const std::string& group_query) { + auto ranks = get_ranks(group_query); + + std::stringstream ss; + ss << "Ranks for query '" << group_query << "':" << std::endl; + for (size_t i = 0; i < ranks.size(); i++) { + ss << " Group " << i << ": ["; + for (size_t j = 0; j < ranks[i].size(); j++) { + ss << ranks[i][j]; + if (j < ranks[i].size() - 1) ss << ", "; + } + ss << "]" << std::endl; + } + LOG(INFO) << ss.str(); + } + + private: + std::vector prefix_product(const std::vector& group_size, + int32_t init = 1) { + std::vector prefix_product_sizes; + prefix_product_sizes.push_back(init); + for (int32_t size : group_size) { + init = init * size; + prefix_product_sizes.push_back(init); + } + return prefix_product_sizes; + } + + int32_t inner_product(const std::vector& a, + const std::vector& b) { + int32_t result = 0; + for (size_t i = 0; i < a.size(); i++) { + result += a[i] * b[i]; + } + return result; + } + + std::vector decompose(int32_t index, + const std::vector& shape, + const std::vector& stride = {}) { + std::vector idx; + std::vector actual_stride; + + if (stride.empty()) { + actual_stride = prefix_product(shape); + } else { + actual_stride = stride; + } + + for (size_t i = 0; i < shape.size(); i++) { + int32_t d = actual_stride[i]; + int32_t s = shape[i]; + idx.push_back((index / d) % s); + } + + int32_t sum = 0; + for (size_t i = 0; i < idx.size(); i++) { + sum += idx[i] * actual_stride[i]; + } + + if (sum != index) { + std::stringstream ss; + ss << "idx " << index << " with shape ["; + for (size_t i = 0; i < shape.size(); i++) { + ss << shape[i]; + if (i < shape.size() - 1) ss << ", "; + } + ss << "] mismatch the return idx ["; + for (size_t i = 0; i < idx.size(); i++) { + ss << idx[i]; + if (i < idx.size() - 1) ss << ", "; + } + ss << "]"; + LOG(INFO) << ss.str(); + } + + return idx; + } + + std::vector> generate_masked_orthogonal_rank_groups( + int32_t world_size, + const std::vector& parallel_size, + const std::vector& mask) { + std::vector queried_group_size; + std::vector unqueried_group_size; + for (size_t i = 0; i < parallel_size.size(); i++) { + if (mask[i]) { + queried_group_size.push_back(parallel_size[i]); + } else { + unqueried_group_size.push_back(parallel_size[i]); + } + } + std::vector global_group_stride = prefix_product(parallel_size); + std::vector queried_group_stride; + std::vector unqueried_group_stride; + for (size_t i = 0; i < parallel_size.size(); i++) { + if (mask[i]) { + queried_group_stride.push_back(global_group_stride[i]); + } else { + unqueried_group_stride.push_back(global_group_stride[i]); + } + } + std::vector queried_group_prefix = + prefix_product(queried_group_size); + // group size equals to the product of queryed group type sizes; + int32_t group_size = queried_group_prefix.back(); + int32_t num_of_group = world_size / group_size; + + std::vector> ranks; + for (int32_t group_index = 0; group_index < num_of_group; group_index++) { + std::vector decomposed_group_idx = + decompose(group_index, unqueried_group_size); + std::vector rank; + for (int32_t rank_in_group = 0; rank_in_group < group_size; + rank_in_group++) { + std::vector decomposed_rank_idx = + decompose(rank_in_group, queried_group_size); + int32_t calculated_rank = + inner_product(decomposed_rank_idx, queried_group_stride) + + inner_product(decomposed_group_idx, unqueried_group_stride); + rank.push_back(calculated_rank); + } + ranks.push_back(rank); + } + + return ranks; + } + + std::vector get_mask(const std::string& group_query) { + std::vector query_group_name = split(group_query, '-'); + std::vector mask(ordered_group_name_.size(), false); + + for (const std::string& group_name : query_group_name) { + auto it = std::find( + ordered_group_name_.begin(), ordered_group_name_.end(), group_name); + if (it != ordered_group_name_.end()) { + size_t index = std::distance(ordered_group_name_.begin(), it); + mask[index] = true; + } + } + + return mask; + } + + std::vector split(const std::string& s, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; + } + + private: + int32_t tp_; + int32_t sp_; + int32_t cfg_; + int32_t dp_; + int32_t rank_offset_; + int32_t world_size_; + std::string group_order_; + std::vector ordered_group_size_; + std::vector ordered_group_name_; + std::unordered_map group_size_map_; +}; diff --git a/xllm/core/framework/quant_args.h b/xllm/core/framework/quant_args.h index aa56cb417..6f2b66156 100644 --- a/xllm/core/framework/quant_args.h +++ b/xllm/core/framework/quant_args.h @@ -17,7 +17,9 @@ limitations under the License. #pragma once #include +#include #include +#include #include "common/macros.h" @@ -55,6 +57,35 @@ struct QuantArgs { // weight block size PROPERTY(std::vector, weight_block_size) = {}; + // exact module names or regexes prefixed with "re:" that should bypass + // quantization for compressed-tensors models. + PROPERTY(std::vector, ignored_modules) = {}; + + bool should_ignore_module(const std::string& module_name) const { + for (const auto& pattern : ignored_modules()) { + if (pattern == module_name) { + return true; + } + if (pattern.size() > 3 && pattern.rfind("re:", 0) == 0) { + try { + if (std::regex_match(module_name, std::regex(pattern.substr(3)))) { + return true; + } + } catch (const std::regex_error&) { + } + } + } + return false; + } + + QuantArgs for_module(const std::string& module_name) const { + QuantArgs local_args = *this; + if (should_ignore_module(module_name)) { + local_args.quant_method().clear(); + } + return local_args; + } + // check if weights can be fused bool can_be_fused() const { // can't fuse quantized weights if desc_act is true @@ -72,6 +103,7 @@ inline std::ostream& operator<<(std::ostream& os, const QuantArgs& args) { os << ", is_sym: " << args.is_sym(); os << ", activation_dynamic: " << args.activation_dynamic(); os << ", fmt: " << args.fmt(); + os << ", ignored_modules: " << args.ignored_modules().size(); os << "]"; return os; } diff --git a/xllm/core/framework/request/dit_request_params.cpp b/xllm/core/framework/request/dit_request_params.cpp index 0bffd017f..2525e22dc 100644 --- a/xllm/core/framework/request/dit_request_params.cpp +++ b/xllm/core/framework/request/dit_request_params.cpp @@ -86,6 +86,7 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, if (input.has_latent()) { input_params.latent = util::proto_to_torch(input.latent()); } + if (input.has_masked_image_latent()) { input_params.masked_image_latent = util::proto_to_torch(input.masked_image_latent()); @@ -102,6 +103,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, } } + if (input.has_condition_image()) { + std::string raw_bytes; + if (!butil::Base64Decode(input.condition_image(), &raw_bytes)) { + LOG(ERROR) << "Base64 image decode failed"; + } + if (!decoder.decode(raw_bytes, input_params.condition_image)) { + LOG(ERROR) << "Image decode failed."; + } + } + if (input.has_mask_image()) { std::string raw_bytes; if (!butil::Base64Decode(input.mask_image(), &raw_bytes)) { diff --git a/xllm/core/framework/request/dit_request_state.h b/xllm/core/framework/request/dit_request_state.h index 7201b9662..ba2636eec 100644 --- a/xllm/core/framework/request/dit_request_state.h +++ b/xllm/core/framework/request/dit_request_state.h @@ -98,6 +98,8 @@ struct DiTInputParams { torch::Tensor image; + torch::Tensor condition_image; + torch::Tensor control_image; torch::Tensor mask_image; diff --git a/xllm/core/framework/request/mm_codec.cpp b/xllm/core/framework/request/mm_codec.cpp index e0237117f..8c86f727c 100644 --- a/xllm/core/framework/request/mm_codec.cpp +++ b/xllm/core/framework/request/mm_codec.cpp @@ -263,7 +263,7 @@ class MemoryVideoReader : public MemoryMediaReader { } tensor = torch::stack(frames_); // [T,C,H,W] - metadata.total_num_frames = static_cast(frames_.size()); + metadata.total_num_frames = static_cast(frames_.size()); metadata.duration = (metadata.fps > 0.0) ? static_cast(metadata.total_num_frames) / metadata.fps @@ -376,12 +376,12 @@ class MemoryAudioReader : public MemoryMediaReader { } AVChannelLayout out_layout; - av_channel_layout_default(&out_layout, static_cast(target_ch_)); + av_channel_layout_default(&out_layout, target_ch_); if (swr_alloc_set_opts2(&swr_ctx_, &out_layout, AV_SAMPLE_FMT_FLT, - static_cast(target_sr_), + target_sr_, &in_layout, codec_ctx_->sample_fmt, codec_ctx_->sample_rate, @@ -437,13 +437,13 @@ class MemoryAudioReader : public MemoryMediaReader { // build output tensor and compute metadata if (target_ch_ == 1) { tensor = torch::from_blob(pcm_.data(), - {static_cast(pcm_.size())}, + {static_cast(pcm_.size())}, torch::TensorOptions().dtype(torch::kFloat32)) .clone(); metadata.duration = static_cast(pcm_.size()) / target_sr_; } else { - int64_t T = - static_cast(pcm_.size() / static_cast(target_ch_)); + int32_t T = + static_cast(pcm_.size() / static_cast(target_ch_)); tensor = torch::from_blob(pcm_.data(), {T, target_ch_}, torch::TensorOptions().dtype(torch::kFloat32)) @@ -485,7 +485,7 @@ class MemoryAudioReader : public MemoryMediaReader { } // append converted samples to pcm buffer - const int64_t n = static_cast(converted) * target_ch_; + const int64_t n = static_cast(converted * target_ch_); pcm_.reserve(pcm_.size() + static_cast(n)); pcm_.insert(pcm_.end(), out_buf.data(), out_buf.data() + n); return converted; @@ -493,8 +493,8 @@ class MemoryAudioReader : public MemoryMediaReader { private: SwrContext* swr_ctx_ = nullptr; - int64_t target_sr_ = 16000; - int64_t target_ch_ = 1; + int32_t target_sr_ = 16000; + int32_t target_ch_ = 1; std::vector pcm_; }; diff --git a/xllm/core/framework/request/mm_data_visitor.cpp b/xllm/core/framework/request/mm_data_visitor.cpp index 4755bf7e4..4ce3fe46a 100644 --- a/xllm/core/framework/request/mm_data_visitor.cpp +++ b/xllm/core/framework/request/mm_data_visitor.cpp @@ -70,8 +70,6 @@ bool CollectMMDataTensorVisitor::visit(MMData& data) { } bool EncoderInputGatherVisitor::visit(MMDataItem& item) { - if (item.state().prefix_complete_cached()) return true; - if (item.is_embedded()) return true; for (const auto& [key, value] : item.data()) { @@ -103,8 +101,6 @@ bool EncoderInputGatherVisitor::finish(MMBatchData& mm_data) { } bool EncoderOutputScatterVisitor::visit(MMDataItem& item) { - if (item.state().prefix_complete_cached()) return true; - if (item.is_embedded()) return true; std::string prefix; @@ -157,7 +153,6 @@ bool EncoderOutputScatterVisitor::finish() const { bool EncoderEmbeddingGatherVisitor::visit(MMDataItem& item) { const auto& state = item.state(); - if (state.prefix_complete_cached()) return true; int modality_tokens = state.token_pos().length; uint32_t cached_token_num = state.prefix_cache().cached_token_num; @@ -167,7 +162,7 @@ bool EncoderEmbeddingGatherVisitor::visit(MMDataItem& item) { auto& emb = std::get(value); emb = safe_to(emb, device_, true); if (absl::StartsWith(key, gather_prefix_)) { - datas_[key].push_back(emb.index({mask})); + datas_[key].push_back(emb); } } return true; diff --git a/xllm/core/framework/request/mm_type.h b/xllm/core/framework/request/mm_type.h index 713ea8e30..4fe9cd69d 100644 --- a/xllm/core/framework/request/mm_type.h +++ b/xllm/core/framework/request/mm_type.h @@ -51,13 +51,13 @@ class MMType { }; struct ImageMetadata { - int64_t height = 0; - int64_t width = 0; + int32_t height = 0; + int32_t width = 0; }; struct VideoMetadata { double fps = 0.0; // original fps - int64_t total_num_frames = 0; // original frames + int32_t total_num_frames = 0; // original frames double duration = 0.0; double sampled_fps = 0.0; torch::Tensor frame_indices; @@ -65,8 +65,8 @@ struct VideoMetadata { }; struct AudioMetadata { - int64_t sample_rate = 0; - int64_t num_channels = 0; + int32_t sample_rate = 0; + int32_t num_channels = 0; double duration = 0.0; }; diff --git a/xllm/core/framework/request/request_output.h b/xllm/core/framework/request/request_output.h index 3be00d64b..cfcc8fe7c 100644 --- a/xllm/core/framework/request/request_output.h +++ b/xllm/core/framework/request/request_output.h @@ -71,12 +71,21 @@ struct SequenceOutput { // the token ids of the generated text. std::vector token_ids; + // the decoded item id for constrained recommendation output. + std::optional item_ids; + + // decoded item ids for multi-item recommendation output. + std::vector item_ids_list; + // the reason the sequence finished. std::optional finish_reason; // log probabilities of the generated tokens. std::optional> logprobs; + // token-aligned logprobs for REC / OneRec outputs. + std::vector> token_ids_logprobs; + // the embeddings of the prompt token std::optional> embeddings; diff --git a/xllm/core/framework/request/sample_slot_test.cpp b/xllm/core/framework/request/sample_slot_test.cpp index a59282271..5df4e4479 100644 --- a/xllm/core/framework/request/sample_slot_test.cpp +++ b/xllm/core/framework/request/sample_slot_test.cpp @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "common/global_flags.h" #include "framework/block/block_manager_impl.h" #include "platform/device.h" #include "request.h" @@ -103,6 +104,19 @@ class CharTokenizer final : public Tokenizer { static constexpr size_t kEmbTokenLen = sizeof(kEmbToken) - 1; }; +class ScopedBoolFlag final { + public: + ScopedBoolFlag(bool* flag, bool value) : flag_(flag), old_value_(*flag) { + *flag_ = value; + } + + ~ScopedBoolFlag() { *flag_ = old_value_; } + + private: + bool* flag_; + bool old_value_; +}; + TEST(SampleSlotTest, BuildSampleSlotsKeepsMatchOrderAndSampleIds) { CharTokenizer tokenizer; std::vector sample_slots; @@ -345,5 +359,67 @@ TEST(SampleSlotTest, RequestOutputStableSortsOutOfOrderSampleIds) { EXPECT_EQ(output.outputs[2].text, "C"); } +TEST(SampleSlotTest, OneRecOutputCarriesTokenLogprobsWhenEnabled) { + ScopedBoolFlag enable_output_sku_logprobs(&FLAGS_enable_output_sku_logprobs, + true); + ScopedBoolFlag enable_convert_tokens_to_item( + &FLAGS_enable_convert_tokens_to_item, false); + + CharTokenizer tokenizer; + RequestSamplingParam sampling_param; + sampling_param.logprobs = true; + + StoppingChecker stopping_checker; + stopping_checker.set_max_generated_tokens(3); + + RequestState request_state( + /*prompt=*/"", + std::vector{11, 12}, + sampling_param, + SchedulerParam{}, + stopping_checker, + /*seq_capacity=*/8, + /*n=*/1, + /*best_of=*/1, + /*logprobs=*/true, + /*stream=*/false, + /*echo=*/false, + /*skip_special_tokens=*/true, + /*enable_schedule_overlap=*/false, + [](const RequestOutput&) { return true; }, + OutputsFunc{}); + request_state.rec_type = RecType::kOneRec; + + Request request("onerec-score", + /*x_request_id=*/"", + /*x_request_time=*/"", + request_state); + auto* seq = request.sequences()[0].get(); + + Token first_token(101); + first_token.logprob = -0.10f; + seq->append_token(first_token); + + Token second_token(102); + second_token.logprob = -0.20f; + seq->append_token(second_token); + + Token third_token(103); + third_token.logprob = -0.30f; + seq->append_token(third_token); + + RequestOutput output = request.generate_output(tokenizer); + + ASSERT_EQ(output.outputs.size(), 1); + ASSERT_EQ(output.outputs[0].token_ids.size(), 3U); + ASSERT_EQ(output.outputs[0].token_ids_logprobs.size(), 3U); + ASSERT_TRUE(output.outputs[0].token_ids_logprobs[0].has_value()); + ASSERT_TRUE(output.outputs[0].token_ids_logprobs[1].has_value()); + ASSERT_TRUE(output.outputs[0].token_ids_logprobs[2].has_value()); + EXPECT_FLOAT_EQ(output.outputs[0].token_ids_logprobs[0].value(), -0.10f); + EXPECT_FLOAT_EQ(output.outputs[0].token_ids_logprobs[1].value(), -0.20f); + EXPECT_FLOAT_EQ(output.outputs[0].token_ids_logprobs[2].value(), -0.30f); +} + } // namespace } // namespace xllm diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 13304cce2..06ec4985d 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -21,10 +21,13 @@ limitations under the License. #include #include +#include #include #include #include +#include #include +#include #include #include "core/common/global_flags.h" @@ -41,6 +44,32 @@ namespace { constexpr size_t kDecoderBosTokenCount = 1; constexpr size_t kDecoderMaxTokenCount = kRecTotalSteps + kDecoderBosTokenCount; constexpr char kEmptyLogprobsFinishReason[] = "empty_logprobs"; + +std::vector normalize_rec_item_ids(const std::vector& raw_ids, + size_t sequence_index) { + std::vector item_ids; + item_ids.reserve(raw_ids.size()); + std::unordered_set seen_item_ids; + for (const int64_t item_id : raw_ids) { + if (seen_item_ids.insert(item_id).second) { + item_ids.emplace_back(item_id); + } + } + + const int32_t each_threshold = FLAGS_each_conversion_threshold; + if (each_threshold > 0 && + static_cast(item_ids.size()) > each_threshold) { + uint32_t seed = FLAGS_random_seed >= 0 + ? static_cast(FLAGS_random_seed) + + static_cast(sequence_index) + : std::random_device{}(); + std::mt19937 generator(seed); + std::shuffle(item_ids.begin(), item_ids.end(), generator); + item_ids.resize(each_threshold); + } + + return item_ids; +} } // namespace const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding"; @@ -97,6 +126,7 @@ void Sequence::generate_onerec_streaming_output(const Slice& ids, void Sequence::generate_onerec_output(const Slice& ids, size_t size, + const Tokenizer& tokenizer, SequenceOutput& output) const { output.index = index_; if (output_embedding_.defined()) { @@ -106,6 +136,32 @@ void Sequence::generate_onerec_output(const Slice& ids, output.finish_reason = finish_reason_.to_string(); } output.token_ids = ids.slice(num_prompt_tokens_, size); + if (FLAGS_enable_output_sku_logprobs && logprob_state_ != nullptr) { + const auto& token_logprobs = logprob_state_->get_logprobs(); + output.token_ids_logprobs.reserve(output.token_ids.size()); + for (size_t i = num_prompt_tokens_; i < size; ++i) { + if (i < token_logprobs.size()) { + output.token_ids_logprobs.emplace_back(token_logprobs[i]); + } else { + output.token_ids_logprobs.emplace_back(); + } + } + } + const size_t rec_token_size = static_cast(REC_TOKEN_SIZE); + if (FLAGS_enable_convert_tokens_to_item && + output.token_ids.size() == rec_token_size) { + std::vector item_ids; + const bool ok = tokenizer.decode( + Slice{output.token_ids.data(), output.token_ids.size()}, + sequence_params_.skip_special_tokens, + &item_ids); + if (ok && !item_ids.empty()) { + output.item_ids_list = normalize_rec_item_ids(item_ids, index_); + if (!output.item_ids_list.empty()) { + output.item_ids = output.item_ids_list.front(); + } + } + } } Sequence::Sequence(size_t index, @@ -530,7 +586,7 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { // 3. generate onerec output if (is_onerec_model()) { - generate_onerec_output(ids, size, output); + generate_onerec_output(ids, size, tokenizer, output); return output; } diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index a184d8d92..3e2d00d6f 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -418,6 +418,7 @@ class Sequence final { void generate_onerec_output(const Slice& ids, size_t size, + const Tokenizer& tokenizer, SequenceOutput& output) const; struct OneRecState { diff --git a/xllm/core/framework/request/sequences_group.cpp b/xllm/core/framework/request/sequences_group.cpp index be89e0302..cbdea80e3 100644 --- a/xllm/core/framework/request/sequences_group.cpp +++ b/xllm/core/framework/request/sequences_group.cpp @@ -20,7 +20,7 @@ limitations under the License. #include #include "common/global_flags.h" -#include "core/common/rec_model_utils.h" +#include "core/util/rec_model_utils.h" #include "framework/batch/beam_search.h" #include "util/blocking_counter.h" #include "util/slice.h" @@ -232,17 +232,40 @@ void SequencesGroup::process_beam_search() { const size_t topk = std::max(1, sequence_params_.sampling_param->top_logprobs); - SimpleTopKOptimizerBeamCandidate topk_optimizer(beam_width); - auto add_self_candidate = [&](size_t seq_index, Sequence* seq) { + std::vector source_infos; + source_infos.reserve(sequences_.size()); + + auto build_source_info = [&](Sequence* seq) { + BeamSourceInfo source_info; + source_info.suffix_start_idx = seq->num_prompt_tokens(); + const auto token_ids = seq->tokens(); const auto& log_probs = seq->logprob_state()->get_logprobs(); - const float logprob_sum = seq->get_acc_logprob(); + const size_t generated_token_count = + token_ids.size() - source_info.suffix_start_idx; + source_info.generated_token_ids.reserve(generated_token_count); + source_info.generated_logprobs.reserve(generated_token_count); + for (size_t token_idx = source_info.suffix_start_idx; + token_idx < token_ids.size(); + ++token_idx) { + source_info.generated_token_ids.push_back(token_ids[token_idx]); + source_info.generated_logprobs.push_back(log_probs[token_idx]); + } + source_info.src_blocks.assign(seq->kv_state().kv_blocks().begin(), + seq->kv_state().kv_blocks().end()); + source_infos.emplace_back(std::move(source_info)); + }; + + for (size_t seq_index = 0; seq_index < sequences_.size(); ++seq_index) { + build_source_info(sequences_[seq_index].get()); + } + + SimpleTopKOptimizerBeamCandidate topk_optimizer(beam_width); + auto add_self_candidate = [&](size_t seq_index, Sequence* seq) { BeamCandidate candidate; - candidate.seq_index = seq_index; - candidate.logprob_sum = logprob_sum; - candidate.token_ids = std::vector(token_ids); - candidate.logprobs = log_probs; + candidate.source_index = seq_index; + candidate.logprob_sum = seq->get_acc_logprob(); topk_optimizer.insert(std::move(candidate)); }; @@ -264,8 +287,6 @@ void SequencesGroup::process_beam_search() { } const int32_t last_token_idx = seq->num_tokens() - 1; - Slice token_ids = seq->tokens(); - const auto& log_probs = seq->logprob_state()->get_logprobs(); const auto& top_logprobs = seq->logprob_state()->get_top_logprobs()[last_token_idx]; const auto& top_tokens = @@ -278,19 +299,19 @@ void SequencesGroup::process_beam_search() { } const float base_logprob = seq->get_base_logprob(); + const size_t source_index = i; for (size_t idx = 0; idx < candidate_topk; ++idx) { - float new_logprob = base_logprob + top_logprobs[idx]; + const float new_logprob = base_logprob + top_logprobs[idx]; if (!topk_optimizer.worthInserting(new_logprob)) { break; } BeamCandidate candidate; - candidate.seq_index = i; + candidate.source_index = source_index; candidate.logprob_sum = new_logprob; - candidate.token_ids = std::vector(token_ids); - candidate.logprobs = log_probs; - candidate.token_ids[last_token_idx] = top_tokens[idx]; - candidate.logprobs[last_token_idx] = top_logprobs[idx]; + candidate.override_last_token = true; + candidate.last_token_id = static_cast(top_tokens[idx]); + candidate.last_token_logprob = top_logprobs[idx]; topk_optimizer.insert(std::move(candidate)); } } @@ -303,36 +324,73 @@ void SequencesGroup::process_beam_search() { return; } - std::vector> result; - result.reserve(std::min(beam_width, candidates.size())); + const size_t result_size = std::min(beam_width, candidates.size()); + CHECK(!sequences_.empty()); + + std::vector> replacement_sequences(result_size); + for (size_t i = 0; i < result_size; ++i) { + const BeamCandidate& candidate = candidates[i]; + const BeamSourceInfo& source_info = source_infos[candidate.source_index]; + const bool need_replace = + i >= sequences_.size() || sequences_[i] == nullptr || + sequences_[i]->num_prompt_tokens() != source_info.suffix_start_idx || + sequences_[i]->num_tokens() - source_info.suffix_start_idx != + source_info.generated_token_ids.size(); + if (!need_replace) { + continue; + } + + CHECK_LT(candidate.source_index, sequences_.size()); + CHECK(sequences_[candidate.source_index] != nullptr); + replacement_sequences[i] = + std::make_unique(*sequences_[candidate.source_index]); + } + + if (sequences_.size() < result_size) { + sequences_.resize(result_size); + } + std::unordered_set reused_src; - for (size_t i = 0; i < beam_width && i < candidates.size(); ++i) { - const BeamCandidate& c = candidates[i]; - auto& src_seq = sequences_[c.seq_index]; - auto next_seq = std::make_unique(*src_seq); + for (size_t i = 0; i < result_size; ++i) { + const BeamCandidate& candidate = candidates[i]; + const BeamSourceInfo& source_info = source_infos[candidate.source_index]; + if (replacement_sequences[i] != nullptr) { + sequences_[i] = std::move(replacement_sequences[i]); + } + auto& next_seq = sequences_[i]; + CHECK(next_seq != nullptr); - CHECK_EQ(next_seq->num_tokens(), c.token_ids.size()); - for (size_t token_idx = next_seq->num_prompt_tokens(); + CHECK_EQ(next_seq->num_prompt_tokens(), source_info.suffix_start_idx); + CHECK_EQ(next_seq->num_tokens() - source_info.suffix_start_idx, + source_info.generated_token_ids.size()); + for (size_t token_idx = source_info.suffix_start_idx; token_idx < next_seq->num_tokens(); ++token_idx) { - Token new_token(c.token_ids[token_idx]); - new_token.logprob = c.logprobs[token_idx].has_value() - ? c.logprobs[token_idx].value() - : 0.0f; + const size_t suffix_idx = token_idx - source_info.suffix_start_idx; + Token new_token(source_info.generated_token_ids[suffix_idx]); + new_token.logprob = + source_info.generated_logprobs[suffix_idx].has_value() + ? source_info.generated_logprobs[suffix_idx].value() + : 0.0f; + if (candidate.override_last_token && + token_idx + 1 == next_seq->num_tokens()) { + new_token.id = candidate.last_token_id; + new_token.logprob = candidate.last_token_logprob.has_value() + ? candidate.last_token_logprob.value() + : 0.0f; + } next_seq->update_token(token_idx, new_token); } - next_seq->logprob_state()->set_acc_logprob(c.logprob_sum); + next_seq->logprob_state()->set_acc_logprob(candidate.logprob_sum); next_seq->logprob_state()->set_last_acc_token_idx(next_seq->num_tokens()); next_seq->reset_finish_state_for_beam_search(); - bool need_swap = !reused_src.insert(c.seq_index).second; - auto src_blocks = src_seq->kv_state().kv_blocks(); - next_seq->kv_state().set_src_blocks(src_blocks, need_swap); - result.emplace_back(std::move(next_seq)); + const bool need_swap = !reused_src.insert(candidate.source_index).second; + next_seq->kv_state().set_src_blocks(source_info.src_blocks, need_swap); } - if (!result.empty()) { - sequences_ = std::move(result); + if (sequences_.size() > result_size) { + sequences_.resize(result_size); } } diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.cpp b/xllm/core/framework/sampling/rec_constrained_decoding.cpp index 0b53331cb..7c13024ae 100644 --- a/xllm/core/framework/sampling/rec_constrained_decoding.cpp +++ b/xllm/core/framework/sampling/rec_constrained_decoding.cpp @@ -27,20 +27,19 @@ limitations under the License. #include #include "common/global_flags.h" -#include "common/version_singleton.h" #include "framework/state_dict/rec_vocab_dict.h" #include "util/slice.h" #include "util/tensor_helper.h" namespace xllm { -RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version, +RecConstrainedDecoding::RecConstrainedDecoding(RecVocabDict* vocab_dict, const int32_t vocab_size, torch::ScalarType dtype, torch::Device device, bool use_gen_threadpool) : use_gen_threadpool_(use_gen_threadpool), vocab_size_(vocab_size), - model_version_(model_version), + vocab_dict_(vocab_dict), device_(device), dtype_(dtype) { if (use_gen_threadpool_) { @@ -51,6 +50,8 @@ RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version, } bool RecConstrainedDecoding::build_mask_cache() { + CHECK(vocab_dict_ != nullptr) + << "RecVocabDict must be initialized before constrained decoding."; first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_); std::vector empty_token_ids; @@ -58,9 +59,7 @@ bool RecConstrainedDecoding::build_mask_cache() { empty_token_ids.size()}; const std::unordered_set& first_token_ids = - VersionSingleton::GetInstance( - std::to_string(model_version_)) - ->get_next_tokens_by_prefix_tokens(prefix_token_ids); + vocab_dict_->get_next_tokens_by_prefix_tokens(prefix_token_ids); for (auto token_id : first_token_ids) { first_token_mask_[token_id] = 0; @@ -123,9 +122,7 @@ torch::Tensor RecConstrainedDecoding::generate_decode_mask( Slice tokens_slice(generated_token_list[token_idx]); const std::unordered_set& next_token_ids = - VersionSingleton::GetInstance( - std::to_string(model_version_)) - ->get_next_tokens_by_prefix_tokens(tokens_slice); + vocab_dict_->get_next_tokens_by_prefix_tokens(tokens_slice); if (next_token_ids.size() > 0) { for (int32_t vocab_idx : next_token_ids) { @@ -195,4 +192,4 @@ torch::Tensor RecConstrainedDecoding::generate_decode_mask( return mask; } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.h b/xllm/core/framework/sampling/rec_constrained_decoding.h index 7cf049286..783999cc4 100644 --- a/xllm/core/framework/sampling/rec_constrained_decoding.h +++ b/xllm/core/framework/sampling/rec_constrained_decoding.h @@ -22,9 +22,11 @@ limitations under the License. namespace xllm { +class RecVocabDict; + class RecConstrainedDecoding : public ConstrainedDecoding { public: - RecConstrainedDecoding(uint64_t model_version, + RecConstrainedDecoding(RecVocabDict* vocab_dict, const int32_t vocab_size, torch::ScalarType dtype, torch::Device device, @@ -48,7 +50,7 @@ class RecConstrainedDecoding : public ConstrainedDecoding { bool build_mask_cache_; bool use_gen_threadpool_; int32_t vocab_size_; - uint64_t model_version_; + RecVocabDict* vocab_dict_ = nullptr; torch::Device device_; torch::ScalarType dtype_; torch::Tensor first_token_mask_; diff --git a/xllm/core/framework/sampling/rec_sampler.cpp b/xllm/core/framework/sampling/rec_sampler.cpp index bd8d94379..025dc51ec 100644 --- a/xllm/core/framework/sampling/rec_sampler.cpp +++ b/xllm/core/framework/sampling/rec_sampler.cpp @@ -18,7 +18,9 @@ limitations under the License. #include #include +#include #include +#include #include "common/global_flags.h" #include "logits_utils.h" @@ -71,6 +73,57 @@ static inline torch::Tensor log_softmax_last_dim( return torch::log_softmax(logits, /*dim=*/-1); } +static inline void sample_top_candidates(const torch::Tensor& probs, + const torch::Tensor& logprobs, + int64_t top_count, + torch::Tensor* top_tokens, + torch::Tensor* top_logprobs) { + CHECK(top_tokens != nullptr); + CHECK(top_logprobs != nullptr); + CHECK_EQ(probs.dim(), 2) << "probs must be 2D, got " << probs.sizes(); + CHECK_EQ(logprobs.dim(), 2) + << "logprobs must be 2D, got " << logprobs.sizes(); + CHECK_EQ(probs.sizes(), logprobs.sizes()) + << "probs/logprobs shape mismatch, probs=" << probs.sizes() + << ", logprobs=" << logprobs.sizes(); + CHECK_GT(top_count, 0) << "top_count must be positive"; + + const int64_t batch_size = probs.size(0); + auto device = probs.device(); + auto token_options = + torch::TensorOptions().dtype(torch::kLong).device(device); + auto logprob_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); + *top_tokens = torch::empty({batch_size, top_count}, token_options); + *top_logprobs = torch::empty({batch_size, top_count}, logprob_options); + + auto valid_counts = probs.gt(0).sum(/*dim=*/-1).to(torch::kCPU); + for (int64_t row = 0; row < batch_size; ++row) { + auto probs_row = probs[row]; + auto logprobs_row = logprobs[row]; + int64_t valid_count = valid_counts[row].item(); + if (valid_count >= top_count) { + auto sampled = probs_row.multinomial( + /*num_samples=*/top_count, /*replacement=*/false); + auto sampled_logprobs = logprobs_row.gather(/*dim=*/-1, sampled); + torch::Tensor sorted_values; + torch::Tensor sorted_order; + std::tie(sorted_values, sorted_order) = sampled_logprobs.sort( + /*dim=*/-1, /*descending=*/true); + auto sorted_tokens = sampled.gather(/*dim=*/-1, sorted_order); + (*top_tokens)[row].copy_(sorted_tokens); + (*top_logprobs)[row].copy_(sorted_values); + } else { + torch::Tensor topk_values; + torch::Tensor topk_indices; + std::tie(topk_values, topk_indices) = logprobs_row.topk( + top_count, /*dim=*/-1, /*largest=*/true, /*sorted=*/true); + (*top_tokens)[row].copy_(topk_indices); + (*top_logprobs)[row].copy_(topk_values); + } + } +} + } // namespace RecSampler::RecSampler(RecPipelineType pipeline_type) @@ -80,8 +133,9 @@ RecSampler::RecSampler(RecPipelineType pipeline_type) } SampleOutput RecSampler::forward(torch::Tensor& logits, - const SamplingParameters& params) const { - return strategy_->forward(logits, params); + const SamplingParameters& params, + const torch::Tensor& filter_mask) const { + return strategy_->forward(logits, params, filter_mask); } // --- SamplingStrategy factory --- @@ -94,10 +148,11 @@ RecSampler::create_sampling_strategy(RecPipelineType type, return std::make_unique(sampler); case RecPipelineType::kLlmRecDefault: case RecPipelineType::kLlmRecWithMmData: - case RecPipelineType::kOneRecDefault: return std::make_unique(sampler); + case RecPipelineType::kOneRecDefault: + return std::make_unique(sampler); default: - LOG(FATAL) << "Unknown RecPipelineType: " << static_cast(type); + LOG(FATAL) << "Unknown RecPipelineType: " << static_cast(type); __builtin_unreachable(); } } @@ -110,8 +165,102 @@ RecSampler::DefaultSamplingStrategy::DefaultSamplingStrategy( SampleOutput RecSampler::DefaultSamplingStrategy::forward( torch::Tensor& logits, - const SamplingParameters& params) const { - return sampler_.forward(logits, params); + const SamplingParameters& params, + const torch::Tensor& filter_mask) const { + return sampler_.forward(logits, params, filter_mask); +} + +// --- OneRecConstrainedSamplingStrategy --- + +RecSampler::OneRecConstrainedSamplingStrategy:: + OneRecConstrainedSamplingStrategy(const Sampler& sampler) + : sampler_(sampler) {} + +SampleOutput RecSampler::OneRecConstrainedSamplingStrategy::forward( + torch::Tensor& logits, + const SamplingParameters& params, + const torch::Tensor& filter_mask) const { + if (!(params.use_beam_search && params.all_random_sample && params.logprobs && + params.max_top_logprobs > 0)) { + return sampler_.forward(logits, params, filter_mask); + } + + if (params.frequency_penalties.defined()) { + apply_frequency_presence_penalties(logits, + params.unique_token_ids, + params.unique_token_counts, + params.frequency_penalties, + params.presence_penalties); + } + + if (params.repetition_penalties.defined()) { + apply_repetition_penalties( + logits, params.unique_token_ids, params.repetition_penalties); + } + + torch::Tensor sample_logits = logits; + torch::Tensor sample_temperatures = params.temperatures; + torch::Tensor sample_top_k = params.top_k; + torch::Tensor sample_top_p = params.top_p; + const bool use_sample_indices = + params.selected_token_idxes.numel() != params.sample_idxes.numel(); + if (use_sample_indices) { + sample_logits = logits.index_select(/*dim=*/0, params.sample_idxes); + if (params.temperatures.defined()) { + sample_temperatures = + params.temperatures.index_select(/*dim=*/0, params.sample_idxes); + } + if (params.top_k.defined()) { + sample_top_k = params.top_k.index_select(/*dim=*/0, params.sample_idxes); + } + if (params.top_p.defined()) { + sample_top_p = params.top_p.index_select(/*dim=*/0, params.sample_idxes); + } + } + + if (filter_mask.defined()) { + CHECK_EQ(filter_mask.dim(), 2) + << "filter_mask must be 2-D, dim=" << filter_mask.dim(); + CHECK_EQ(filter_mask.size(0), sample_logits.size(0)) + << "filter_mask batch mismatch, filter_mask.size(0)=" + << filter_mask.size(0) + << ", sample_logits.size(0)=" << sample_logits.size(0); + CHECK_EQ(filter_mask.size(1), sample_logits.size(1)) + << "filter_mask vocab mismatch, filter_mask.size(1)=" + << filter_mask.size(1) + << ", sample_logits.size(1)=" << sample_logits.size(1); + sample_logits = sample_logits + filter_mask; + } + + apply_top_k_top_p( + sample_logits, sample_temperatures, sample_top_k, sample_top_p); + if (use_sample_indices) { + logits.index_copy_(/*dim=*/0, params.sample_idxes, sample_logits); + } + + CHECK(params.do_sample.defined()) << "params.do_sample must be defined"; + CHECK_EQ(params.do_sample.dim(), 1) + << "params.do_sample must be 1D [num_seqs], got " + << params.do_sample.sizes(); + CHECK_EQ(sample_logits.size(0), params.do_sample.size(0)); + + SampleOutput output; + auto probs = + torch::softmax(sample_logits, /*dim=*/-1, /*dtype=*/torch::kFloat32); + output.probs = probs.to(logits.dtype()); + auto logprobs = + torch::log_softmax(sample_logits, /*dim=*/-1, /*dtype=*/torch::kFloat32); + + const int64_t vocab_size = probs.size(-1); + const int64_t top_count = std::min(params.max_top_logprobs, + static_cast(vocab_size)); + sample_top_candidates( + probs, logprobs, top_count, &output.top_tokens, &output.top_logprobs); + output.next_tokens = + output.top_tokens.select(/*dim=*/1, /*index=*/0).to(torch::kLong); + output.logprobs = + output.top_logprobs.select(/*dim=*/1, /*index=*/0).contiguous(); + return output; } // --- MultiRoundFastPathSamplingStrategy --- @@ -122,7 +271,9 @@ RecSampler::MultiRoundFastPathSamplingStrategy:: SampleOutput RecSampler::MultiRoundFastPathSamplingStrategy::forward( torch::Tensor& logits, - const SamplingParameters& params) const { + const SamplingParameters& params, + const torch::Tensor& filter_mask) const { + (void)filter_mask; const bool use_fast_path = can_use_fast_path(params); if (!use_fast_path) { diff --git a/xllm/core/framework/sampling/rec_sampler.h b/xllm/core/framework/sampling/rec_sampler.h index 2cbd9b8f3..088599815 100644 --- a/xllm/core/framework/sampling/rec_sampler.h +++ b/xllm/core/framework/sampling/rec_sampler.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "common/rec_model_utils.h" #include "sampling_params.h" +#include "util/rec_model_utils.h" namespace xllm { @@ -31,22 +31,37 @@ class RecSampler { ~RecSampler() = default; // logits: [batch_size, vocab_size] - SampleOutput forward(torch::Tensor& logits, - const SamplingParameters& params) const; + SampleOutput forward( + torch::Tensor& logits, + const SamplingParameters& params, + const torch::Tensor& filter_mask = torch::Tensor()) const; private: class SamplingStrategy { public: virtual ~SamplingStrategy() = default; virtual SampleOutput forward(torch::Tensor& logits, - const SamplingParameters& params) const = 0; + const SamplingParameters& params, + const torch::Tensor& filter_mask) const = 0; }; class DefaultSamplingStrategy final : public SamplingStrategy { public: explicit DefaultSamplingStrategy(const Sampler& sampler); SampleOutput forward(torch::Tensor& logits, - const SamplingParameters& params) const override; + const SamplingParameters& params, + const torch::Tensor& filter_mask) const override; + + private: + const Sampler& sampler_; + }; + + class OneRecConstrainedSamplingStrategy final : public SamplingStrategy { + public: + explicit OneRecConstrainedSamplingStrategy(const Sampler& sampler); + SampleOutput forward(torch::Tensor& logits, + const SamplingParameters& params, + const torch::Tensor& filter_mask) const override; private: const Sampler& sampler_; @@ -56,7 +71,8 @@ class RecSampler { public: explicit MultiRoundFastPathSamplingStrategy(const Sampler& sampler); SampleOutput forward(torch::Tensor& logits, - const SamplingParameters& params) const override; + const SamplingParameters& params, + const torch::Tensor& filter_mask) const override; private: const Sampler& sampler_; diff --git a/xllm/core/framework/sampling/rejection_sampler.cpp b/xllm/core/framework/sampling/rejection_sampler.cpp index c9c425c6e..424c9dec7 100644 --- a/xllm/core/framework/sampling/rejection_sampler.cpp +++ b/xllm/core/framework/sampling/rejection_sampler.cpp @@ -105,8 +105,11 @@ RejectionSampler::RejectionSampler( rate_controller_(rate_controller), enable_fused_kernel_(enable_fused_kernel) { CHECK(do_sample.defined()); - // [batch_size, 1] - do_sample_ = do_sample.unsqueeze_(/*dim=*/-1); + // Keep a private expanded view and do not mutate the caller-owned tensor. + // The same SamplingParameters object is reused later by MTP draft extend. + // An in-place unsqueeze here corrupts Sampler::forward() mixed-mode shape + // assumptions and can broadcast sampled token ids into 2D. + do_sample_ = do_sample.unsqueeze(/*dim=*/-1); } // draft_token_ids: [batch_size, n_speculative_tokens] diff --git a/xllm/core/framework/sampling/rejection_sampler_test.cpp b/xllm/core/framework/sampling/rejection_sampler_test.cpp index 4e2219a6c..e50e2a4f5 100644 --- a/xllm/core/framework/sampling/rejection_sampler_test.cpp +++ b/xllm/core/framework/sampling/rejection_sampler_test.cpp @@ -254,6 +254,55 @@ TEST(RejectionSamplerTest, LogProbs) { EXPECT_TRUE(torch::equal(output.top_tokens, top_k_indices)); } +TEST(RejectionSamplerTest, ConstructorDoesNotMutateDoSampleShape) { + const auto device = get_test_device(); + auto do_sample = torch::tensor({false, true}, torch::device(device)); + + ASSERT_EQ(do_sample.dim(), 1); + ASSERT_EQ(do_sample.sizes(), torch::IntArrayRef({2})); + + RejectionSampler sampler(do_sample, + do_sample.all().item(), + !do_sample.any().item(), + /*logprobs=*/false, + /*max_top_logprobs=*/0); + + EXPECT_EQ(do_sample.dim(), 1); + EXPECT_EQ(do_sample.sizes(), torch::IntArrayRef({2})); +} + +TEST(RejectionSamplerTest, + ReusingDoSampleAfterRejectionSamplerKeepsSamplerOutput1D) { + const auto options = get_test_options(torch::kFloat32); + const auto device = get_test_device(); + auto do_sample = torch::tensor({false, true}, torch::device(device)); + + RejectionSampler rejection_sampler(do_sample, + do_sample.all().item(), + !do_sample.any().item(), + /*logprobs=*/false, + /*max_top_logprobs=*/0); + (void)rejection_sampler; + + SamplingParameters params; + params.selected_token_idxes = + torch::tensor({0, 1}, torch::dtype(torch::kInt64).device(device)); + params.sample_idxes = + torch::tensor({0, 1}, torch::dtype(torch::kInt64).device(device)); + params.do_sample = do_sample; + params.all_random_sample = false; + params.all_greedy_sample = false; + + auto logits = + torch::tensor({{3.0f, 1.0f, 0.5f}, {0.1f, 0.2f, 4.0f}}, options); + auto output = Sampler().forward(logits, params); + + EXPECT_EQ(output.probs.dim(), 2); + EXPECT_EQ(output.probs.size(0), 2); + EXPECT_EQ(output.next_tokens.dim(), 1); + EXPECT_EQ(output.next_tokens.size(0), 2); +} + TEST(RejectionSamplerTest, Random) { const auto options = get_test_options(torch::kFloat32); diff --git a/xllm/core/framework/sampling/sampler.cpp b/xllm/core/framework/sampling/sampler.cpp index 318b2f6e6..f28803bad 100644 --- a/xllm/core/framework/sampling/sampler.cpp +++ b/xllm/core/framework/sampling/sampler.cpp @@ -26,7 +26,8 @@ limitations under the License. namespace xllm { SampleOutput Sampler::forward(torch::Tensor& logits, - const SamplingParameters& params) const { + const SamplingParameters& params, + const torch::Tensor& filter_mask) const { SampleOutput output; // apply frequency and presence penalties if (params.frequency_penalties.defined()) { @@ -43,37 +44,70 @@ SampleOutput Sampler::forward(torch::Tensor& logits, logits, params.unique_token_ids, params.repetition_penalties); } - // apply temperatures, top-k and top-p - apply_top_k_top_p(logits, params.temperatures, params.top_k, params.top_p); - torch::Tensor sample_logits = logits; - if (params.selected_token_idxes.numel() != params.sample_idxes.numel()) { + torch::Tensor sample_temperatures = params.temperatures; + torch::Tensor sample_top_k = params.top_k; + torch::Tensor sample_top_p = params.top_p; + const bool use_sample_indices = + params.selected_token_idxes.numel() != params.sample_idxes.numel(); + if (use_sample_indices) { sample_logits = logits.index_select(/*dim=*/0, params.sample_idxes); + if (params.temperatures.defined()) { + sample_temperatures = + params.temperatures.index_select(/*dim=*/0, params.sample_idxes); + } + if (params.top_k.defined()) { + sample_top_k = params.top_k.index_select(/*dim=*/0, params.sample_idxes); + } + if (params.top_p.defined()) { + sample_top_p = params.top_p.index_select(/*dim=*/0, params.sample_idxes); + } + } + + if (filter_mask.defined()) { + CHECK_EQ(filter_mask.dim(), 2) + << "filter_mask must be 2-D, dim=" << filter_mask.dim(); + CHECK_EQ(filter_mask.size(0), sample_logits.size(0)) + << "filter_mask batch mismatch, filter_mask.size(0)=" + << filter_mask.size(0) + << ", sample_logits.size(0)=" << sample_logits.size(0); + CHECK_EQ(filter_mask.size(1), sample_logits.size(1)) + << "filter_mask vocab mismatch, filter_mask.size(1)=" + << filter_mask.size(1) + << ", sample_logits.size(1)=" << sample_logits.size(1); + sample_logits = sample_logits + filter_mask; + } + + // apply temperatures, top-k and top-p + apply_top_k_top_p( + sample_logits, sample_temperatures, sample_top_k, sample_top_p); + if (use_sample_indices) { + logits.index_copy_(/*dim=*/0, params.sample_idxes, sample_logits); } + CHECK(params.do_sample.defined()) << "params.do_sample must be defined"; + CHECK_EQ(params.do_sample.dim(), 1) + << "params.do_sample must be 1D [num_seqs], got " + << params.do_sample.sizes(); // same batch size CHECK_EQ(sample_logits.size(0), params.do_sample.size(0)); - auto probs = sample_logits; + auto probs = + torch::softmax(sample_logits, /*dim=*/-1, /*dtype=*/torch::kFloat32); torch::Tensor samples; if (params.all_random_sample) { - // use float32 for probabilities and log probabilities - probs = - torch::softmax(sample_logits, /*dim=*/-1, /*dtype=*/torch::kFloat32); samples = random_sample(probs); } else if (params.all_greedy_sample) { samples = greedy_sample(probs); } else { - // use float32 for probabilities and log probabilities - probs = - torch::softmax(sample_logits, /*dim=*/-1, /*dtype=*/torch::kFloat32); // mixed sample, sample both then choose based on do_sample auto random = random_sample(probs); auto greedy = greedy_sample(probs); samples = torch::where(params.do_sample, random, greedy); } + auto sample_indices = samples.to(torch::kLong); output.probs = probs.to(logits.dtype()); - output.next_tokens = samples; + output.next_tokens = sample_indices; if (params.logprobs) { if (FLAGS_enable_qwen3_reranker) { @@ -92,7 +126,8 @@ SampleOutput Sampler::forward(torch::Tensor& logits, const auto logprobs = torch::log_softmax( sample_logits, /*dim=*/-1, /*dtype=*/torch::kFloat32); // select the logprobs for each sequence - auto selected_logprobs = logprobs.gather(/*dim=*/-1, samples.view({-1, 1})); + auto selected_logprobs = + logprobs.gather(/*dim=*/-1, sample_indices.view({-1, 1})); output.logprobs = selected_logprobs.view({-1}); if (params.max_top_logprobs > 0) { @@ -111,7 +146,7 @@ torch::Tensor Sampler::greedy_sample(const torch::Tensor& probs) { } torch::Tensor Sampler::random_sample(const torch::Tensor& probs) { -#if defined(USE_MLU) +#if defined(USE_MLU) || defined(USE_CUDA) xllm::kernel::RandomSampleParams params; params.logits = probs; return xllm::kernel::random_sample(params); diff --git a/xllm/core/framework/sampling/sampler.h b/xllm/core/framework/sampling/sampler.h index 0e5e5f023..6fbf03573 100644 --- a/xllm/core/framework/sampling/sampler.h +++ b/xllm/core/framework/sampling/sampler.h @@ -34,8 +34,10 @@ class Sampler final { } // logits: [batch_size, vocab_size] - SampleOutput forward(torch::Tensor& logits, - const SamplingParameters& params) const; + SampleOutput forward( + torch::Tensor& logits, + const SamplingParameters& params, + const torch::Tensor& filter_mask = torch::Tensor()) const; // helper functions // probs: [..., vocab_size] diff --git a/xllm/core/framework/sampling/sampling_params.cpp b/xllm/core/framework/sampling/sampling_params.cpp index a626ded14..253d0d073 100644 --- a/xllm/core/framework/sampling/sampling_params.cpp +++ b/xllm/core/framework/sampling/sampling_params.cpp @@ -42,6 +42,12 @@ void SamplingParameters::init( std::vector temperatures; std::vector top_p; std::vector top_k; + frequency_penalties.reserve(req_sampling_params.size()); + presence_penalties.reserve(req_sampling_params.size()); + repetition_penalties.reserve(req_sampling_params.size()); + temperatures.reserve(req_sampling_params.size()); + top_p.reserve(req_sampling_params.size()); + top_k.reserve(req_sampling_params.size()); bool logprobs = false; int64_t max_top_logprobs = 0; bool is_embeddings = false; @@ -128,6 +134,7 @@ void SamplingParameters::init( // construct do sample tensor std::vector do_sample; + do_sample.reserve(sample_idxes.size()); for (const auto idx : sample_idxes) { const auto* p = req_sampling_params[idx]; // need to do sample if any of following is true diff --git a/xllm/core/framework/state_dict/state_dict.cpp b/xllm/core/framework/state_dict/state_dict.cpp index 7c128a1f3..eab17d90e 100644 --- a/xllm/core/framework/state_dict/state_dict.cpp +++ b/xllm/core/framework/state_dict/state_dict.cpp @@ -193,6 +193,11 @@ StateDict StateDict::get_dict_with_prefix( return tensors; } +bool StateDict::has(const std::string& tensor_name) const { + const auto it = dict_.find(tensor_name); + return it != dict_.end(); +} + StateDictFromSafeTensor::StateDictFromSafeTensor( std::unique_ptr mem_map, std::unordered_map dict) diff --git a/xllm/core/framework/state_dict/state_dict.h b/xllm/core/framework/state_dict/state_dict.h index d6ec51859..37208ae1e 100644 --- a/xllm/core/framework/state_dict/state_dict.h +++ b/xllm/core/framework/state_dict/state_dict.h @@ -58,6 +58,8 @@ class StateDict { size_t size() const { return dict_.size(); } + bool has(const std::string& tensor_name) const; + std::string_view prefix() const { return prefix_; } auto begin() const { return dict_.begin(); } diff --git a/xllm/core/framework/tokenizer/tokenizer_proxy.cpp b/xllm/core/framework/tokenizer/tokenizer_proxy.cpp index a174d762c..7f3119be2 100644 --- a/xllm/core/framework/tokenizer/tokenizer_proxy.cpp +++ b/xllm/core/framework/tokenizer/tokenizer_proxy.cpp @@ -36,11 +36,22 @@ bool TokenizerProxy::encode(const std::string_view& text, return get_tls_tokenizer()->encode(text, ids, add_special_tokens); } +bool TokenizerProxy::encode(int64_t item_id, + std::vector* token_ids) const { + return get_tls_tokenizer()->encode(item_id, token_ids); +} + std::string TokenizerProxy::decode(const Slice& ids, bool skip_special_tokens) const { return get_tls_tokenizer()->decode(ids, skip_special_tokens); } +bool TokenizerProxy::decode(const Slice& token_ids, + bool skip_special_tokens, + std::vector* item_ids) const { + return get_tls_tokenizer()->decode(token_ids, skip_special_tokens, item_ids); +} + std::optional TokenizerProxy::token_to_id( const std::string_view& token) const { return get_tls_tokenizer()->token_to_id(token); @@ -58,4 +69,4 @@ Tokenizer* TokenizerProxy::get_tls_tokenizer() const { thread_local std::unique_ptr tls_tokenizer(tokenizer_->clone()); return tls_tokenizer.get(); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/tokenizer/tokenizer_proxy.h b/xllm/core/framework/tokenizer/tokenizer_proxy.h index b214c8136..e22b02f75 100644 --- a/xllm/core/framework/tokenizer/tokenizer_proxy.h +++ b/xllm/core/framework/tokenizer/tokenizer_proxy.h @@ -29,9 +29,15 @@ class TokenizerProxy : public Tokenizer { std::vector* ids, bool add_special_tokens = true) const override; + bool encode(int64_t item_id, std::vector* token_ids) const override; + std::string decode(const Slice& ids, bool skip_special_tokens) const override; + bool decode(const Slice& token_ids, + bool skip_special_tokens, + std::vector* item_ids) const override; + std::optional token_to_id( const std::string_view& token) const override; @@ -46,4 +52,4 @@ class TokenizerProxy : public Tokenizer { std::unique_ptr tokenizer_; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/xtensor/page_allocator.cpp b/xllm/core/framework/xtensor/page_allocator.cpp index f80b7210a..da00d4bcd 100644 --- a/xllm/core/framework/xtensor/page_allocator.cpp +++ b/xllm/core/framework/xtensor/page_allocator.cpp @@ -282,7 +282,7 @@ bool PageAllocator::wakeup_model(const std::string& model_id) { size_t pages_needed = virt_page_ids.size() * phy_pages_per_virt; auto [start_w, end_w] = get_dp_group_worker_range(model_id, dp_rank); total_phy_pages_needed += pages_needed; - groups_to_map.push_back({dp_rank, std::move(virt_page_ids)}); + groups_to_map.emplace_back(dp_rank, std::move(virt_page_ids)); for (int32_t w = start_w; w < end_w && w < max_world_size_; ++w) { pages_to_consume_per_worker[w] += pages_needed; } diff --git a/xllm/core/framework/xtensor/phy_page.h b/xllm/core/framework/xtensor/phy_page.h index d042cf033..89e628f30 100644 --- a/xllm/core/framework/xtensor/phy_page.h +++ b/xllm/core/framework/xtensor/phy_page.h @@ -26,7 +26,7 @@ using page_id_t = int64_t; class PhyPage { public: - // Constructor with page_id (-1 means unassigned, e.g., for zero page) + // Constructor with page_id (-1 means unassigned) PhyPage(torch::Device device, page_id_t page_id = -1); ~PhyPage(); @@ -43,4 +43,4 @@ class PhyPage { PhyMemHandle phy_handle_; page_id_t page_id_; // Unique identifier for this page in the pool }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/xtensor/phy_page_pool.cpp b/xllm/core/framework/xtensor/phy_page_pool.cpp index 7980c6c46..f35335927 100644 --- a/xllm/core/framework/xtensor/phy_page_pool.cpp +++ b/xllm/core/framework/xtensor/phy_page_pool.cpp @@ -36,10 +36,6 @@ void PhyPagePool::init(const torch::Device& device, size_t num_pages) { LOG(INFO) << "PhyPagePool: pre-allocating " << num_pages << " physical pages on device " << device; - // Pre-allocate zero page first (used by all XTensors for initialization) - // Zero page has page_id = -1 - zero_page_ = std::make_unique(device_, -1); - // Pre-allocate all physical pages for data with unique page_ids all_pages_.reserve(num_pages); page_allocated_.resize(num_pages, false); @@ -55,8 +51,7 @@ void PhyPagePool::init(const torch::Device& device, size_t num_pages) { initialized_ = true; LOG(INFO) << "PhyPagePool: successfully pre-allocated " << num_pages - << " physical pages (page_id 0-" << (num_pages - 1) - << ") + 1 zero page"; + << " physical pages (page_id 0-" << (num_pages - 1) << ")"; } std::unique_ptr PhyPagePool::get() { @@ -296,13 +291,6 @@ size_t PhyPagePool::num_available() const { return free_page_ids_.size(); } -PhyPage* PhyPagePool::get_zero_page() { - std::lock_guard lock(mtx_); - CHECK(initialized_) << "PhyPagePool not initialized"; - CHECK(zero_page_) << "Zero page not created"; - return zero_page_.get(); -} - // ============== Global XTensor Support ============== const std::vector& PhyPagePool::get_all_pages() const { diff --git a/xllm/core/framework/xtensor/phy_page_pool.h b/xllm/core/framework/xtensor/phy_page_pool.h index b8886be7d..f2147474f 100644 --- a/xllm/core/framework/xtensor/phy_page_pool.h +++ b/xllm/core/framework/xtensor/phy_page_pool.h @@ -90,10 +90,6 @@ class PhyPagePool { // Get the device const torch::Device& device() const { return device_; } - // Get the zero page (for initializing virtual memory) - // The returned pointer is owned by PhyPagePool, do not delete it - PhyPage* get_zero_page(); - // ============== Global XTensor Support ============== // Get all pages as raw pointers for GlobalXTensor mapping @@ -124,9 +120,6 @@ class PhyPagePool { // Track which pages are allocated (for segment management) std::vector page_allocated_; - - // Zero page for initializing virtual memory (owned by pool) - std::unique_ptr zero_page_; }; } // namespace xllm diff --git a/xllm/core/framework/xtensor/xtensor.cpp b/xllm/core/framework/xtensor/xtensor.cpp index 28f6d0951..b978928e4 100644 --- a/xllm/core/framework/xtensor/xtensor.cpp +++ b/xllm/core/framework/xtensor/xtensor.cpp @@ -53,7 +53,14 @@ static inline void unmap_and_release_virtual_mem(VirPtr vaddr, vmm::release_vir_ptr(vaddr, size); } -static inline void return_owned_pages_to_pool( +static inline void release_virtual_mem(VirPtr vaddr, size_t size) { + if (is_null_vir_ptr(vaddr)) { + return; + } + vmm::release_vir_ptr(vaddr, size); +} + +static inline void return_pages_to_pool( std::unordered_map>& mapping) { std::vector> pages_to_return; pages_to_return.reserve(mapping.size()); @@ -67,6 +74,31 @@ static inline void return_owned_pages_to_pool( } } +static inline void unmap_pages( + VirPtr vaddr, + size_t page_size, + const std::unordered_map>& mapping) { + if (is_null_vir_ptr(vaddr)) { + return; + } + + for (const auto& entry : mapping) { + VirPtr addr = + add_vir_ptr_offset(vaddr, static_cast(entry.first) * page_size); + vmm::unmap(addr, page_size); + } +} + +static inline void cleanup_pages_and_vmem( + VirPtr vaddr, + size_t size, + size_t page_size, + std::unordered_map>& mapping) { + unmap_pages(vaddr, page_size, mapping); + return_pages_to_pool(mapping); + release_virtual_mem(vaddr, size); +} + static inline void free_preallocated_weight_pages( const std::vector& page_ids) { if (page_ids.empty()) { @@ -78,20 +110,15 @@ static inline void free_preallocated_weight_pages( << " preallocated weight pages"; } -XTensor::XTensor(size_t size, - torch::Dtype dtype, - torch::Device dev, - PhyPage* zero_page) +XTensor::XTensor(size_t size, torch::Dtype dtype, torch::Device dev) : vaddr_(0), size_(0), page_size_(FLAGS_phy_page_granularity_size), dtype_(dtype), - dev_(dev), - zero_page_(zero_page) { + dev_(dev) { // Align size to page_size_ size_ = align_up(size, page_size_); vaddr_ = alloc_virtual_mem(size_); - init_with_zero_(); } XTensor::XTensor(const std::vector& page_ids, @@ -102,7 +129,6 @@ XTensor::XTensor(const std::vector& page_ids, page_size_(FLAGS_phy_page_granularity_size), dtype_(dtype), dev_(dev), - zero_page_(nullptr), use_preallocated_pages_(true), preallocated_page_ids_(page_ids) { if (page_ids.empty()) { @@ -128,10 +154,7 @@ XTensor::~XTensor() { return; } - return_owned_pages_to_pool(mapping_); - // zero_page_ is not owned, don't delete it - - unmap_and_release_virtual_mem(vaddr_, size_, page_size_); + cleanup_pages_and_vmem(vaddr_, size_, page_size_, mapping_); } bool XTensor::map(offset_t offset) { @@ -154,8 +177,6 @@ bool XTensor::map(offset_t offset) { // Map the physical page VirPtr vaddr = add_vir_ptr_offset(vaddr_, offset); - vmm::unmap(vaddr, page_size_); - PhyMemHandle phy_handle = phy_pages[0]->get_phy_handle(); vmm::map(vaddr, phy_handle); @@ -178,9 +199,6 @@ bool XTensor::unmap(offset_t offset) { VirPtr vaddr = add_vir_ptr_offset(vaddr_, offset); vmm::unmap(vaddr, page_size_); - // Map the zero page instead to ensure memory integrity - map_phy_page_(zero_page_, offset); - // Return the physical page to pool std::vector> pages_to_return; pages_to_return.push_back(std::move(it->second)); @@ -262,23 +280,6 @@ bool XTensor::map_phy_page_(PhyPage* page, offset_t offset) { return true; } -bool XTensor::init_with_zero_() { - CHECK(vir_ptr_to_uintptr(vaddr_) % page_size_ == 0) - << "vaddr not aligned to page size"; - CHECK(size_ % page_size_ == 0) << "size not aligned to page size"; - - bool succ = true; - - // Initialize all pages with zero page - for (size_t offset = 0; offset < size_; offset += page_size_) { - if (!map_phy_page_(zero_page_, offset)) { - succ = false; - break; - } - } - return succ; -} - bool XTensor::allocate(void*& ptr, size_t size) { // Check if there's enough space if (alloc_offset_ + size > size_) { diff --git a/xllm/core/framework/xtensor/xtensor.h b/xllm/core/framework/xtensor/xtensor.h index 3063a0467..51adb005b 100644 --- a/xllm/core/framework/xtensor/xtensor.h +++ b/xllm/core/framework/xtensor/xtensor.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "common/global_flags.h" #include "common/macros.h" @@ -33,10 +34,7 @@ using offset_t = page_id_t; /* NOTE: XTensorAllocator is thread-safe but XTensor is not. */ class XTensor { public: - XTensor(size_t size, - torch::Dtype dtype, - torch::Device dev, - PhyPage* zero_page); + XTensor(size_t size, torch::Dtype dtype, torch::Device dev); // Constructor for weight tensor using pre-allocated page_ids (non-contiguous) // page_ids: physical page IDs from PhyPagePool (allocated via @@ -105,14 +103,12 @@ class XTensor { private: // Map a single physical page at the given offset bool map_phy_page_(PhyPage* page, offset_t offset); - bool init_with_zero_(); VirPtr vaddr_; size_t size_; size_t page_size_; // Page size (FLAGS_phy_page_granularity_size) torch::Dtype dtype_; torch::Device dev_; - PhyPage* zero_page_; // Not owned, managed by PhyPagePool // Maps page id -> PhyPage (page id = offset / page_size_) std::unordered_map> mapping_; diff --git a/xllm/core/framework/xtensor/xtensor_allocator.cpp b/xllm/core/framework/xtensor/xtensor_allocator.cpp index 4ebdf8db1..46307ae87 100644 --- a/xllm/core/framework/xtensor/xtensor_allocator.cpp +++ b/xllm/core/framework/xtensor/xtensor_allocator.cpp @@ -59,7 +59,6 @@ XTensorAllocator::~XTensorAllocator() { void XTensorAllocator::destroy() { std::lock_guard lock(mtx_); model_tensors_.clear(); - zero_page_ = nullptr; // Not owned, just clear pointer xtensor_dist_clients_.clear(); xtensor_dist_servers_.clear(); initialized_ = false; @@ -553,10 +552,6 @@ std::vector XTensorAllocator::create_kv_tensors_impl_( model.num_layers = num_layers; model.kv_tensor_size_per_layer = size; - if (!zero_page_) { - zero_page_ = PhyPagePool::get_instance().get_zero_page(); - } - return create_tensors_internal_( size, dims, dtype, num_layers, *target_tensors); } @@ -687,12 +682,12 @@ void XTensorAllocator::record_weight_fallback_allocation( if (sorted_pages[i] == sorted_pages[i - 1] + 1) { seg_size += page_size; } else { - tensors.weight_segments.push_back({seg_offset, seg_size}); + tensors.weight_segments.emplace_back(seg_offset, seg_size); seg_offset = static_cast(sorted_pages[i]) * page_size; seg_size = page_size; } } - tensors.weight_segments.push_back({seg_offset, seg_size}); + tensors.weight_segments.emplace_back(seg_offset, seg_size); } LOG(INFO) << "XTensorAllocator: recorded XTensor allocation for model " @@ -763,7 +758,7 @@ std::vector XTensorAllocator::create_tensors_internal_( tensors_out.reserve(num_layers); for (int64_t i = 0; i < num_layers; i++) { - auto xtensor = std::make_unique(size, dtype, dev_, zero_page_); + auto xtensor = std::make_unique(size, dtype, dev_); tensors.push_back(xtensor->to_torch_tensor(0, dims)); tensors_out.push_back(std::move(xtensor)); } diff --git a/xllm/core/framework/xtensor/xtensor_allocator.h b/xllm/core/framework/xtensor/xtensor_allocator.h index 2e3ab7c9d..133e184ee 100644 --- a/xllm/core/framework/xtensor/xtensor_allocator.h +++ b/xllm/core/framework/xtensor/xtensor_allocator.h @@ -273,9 +273,6 @@ class XTensorAllocator { // Per-model tensors storage (key: model_id) std::unordered_map model_tensors_; - // Zero page pointer (owned by PhyPagePool, not this class) - PhyPage* zero_page_ = nullptr; - // Multi-node XTensor dist members int32_t world_size_ = 0; // total workers = dp_size * tp_size int32_t dp_size_ = 1; diff --git a/xllm/core/kernels/cuda/CMakeLists.txt b/xllm/core/kernels/cuda/CMakeLists.txt index 2567d3930..a19686ce7 100644 --- a/xllm/core/kernels/cuda/CMakeLists.txt +++ b/xllm/core/kernels/cuda/CMakeLists.txt @@ -44,7 +44,54 @@ set(CUDA_HEADER_FILES moe/moe_topk_softmax_kernels.cuh ) +# +# Per-SM-architecture CUTLASS libraries. +# +# CUTLASS template instantiations are extremely expensive to compile. Splitting each SM generation +# into its own static library improves incremental build times — changing one SM's CUTLASS code +# only recompiles that library, not all CUTLASS files. + +set(_CUTLASS_SM_LIBS "") + +if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0) + # --- SM90 (Hopper) --- + cc_library( + NAME cutlass_sm90 + SRCS + cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu + cutlass_w8a8/scaled_mm_c3x_sm90.cu + DEPS + torch + tvm_ffi + ) + + # --- SM100 (Blackwell) --- + cc_library( + NAME cutlass_sm100 + SRCS + cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu + cutlass_w8a8/scaled_mm_c3x_sm100.cu + DEPS + torch + tvm_ffi + ) + + # --- SM120 --- + cc_library( + NAME cutlass_sm120 + SRCS + cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu + cutlass_w8a8/scaled_mm_c3x_sm120.cu + DEPS + torch + tvm_ffi + ) + + set(_CUTLASS_SM_LIBS :cutlass_sm90 :cutlass_sm100 :cutlass_sm120) +endif() + # Keep source list explicit to avoid accidentally compiling test files into cuda_kernels. +# NOTE: SM-specific CUTLASS sources are compiled via the per-arch libraries above. set(CUDA_SOURCE_FILES activation.cu air_log_softmax_last_dim.cu @@ -52,13 +99,8 @@ set(CUDA_SOURCE_FILES batch_decode.cpp batch_prefill.cpp batch_chunked_prefill.cpp + block_copy.cu cutlass_extensions/common.cpp - cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu - cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu - cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu - cutlass_w8a8/scaled_mm_c3x_sm100.cu - cutlass_w8a8/scaled_mm_c3x_sm120.cu - cutlass_w8a8/scaled_mm_c3x_sm90.cu cutlass_w8a8/scaled_mm_entry.cu fp8_quant.cu fp8_scaled_matmul.cpp @@ -71,6 +113,7 @@ set(CUDA_SOURCE_FILES rec_beam_search.cu reshape_paged_cache.cu rope.cu + random_sample.cpp utils.cpp xattention/beam_search.cpp xattention/cache_select.cu @@ -92,6 +135,7 @@ cc_library( torch :util :platform + ${_CUTLASS_SM_LIBS} ) cc_test( @@ -106,6 +150,24 @@ cc_test( glog::glog ) +option(XLLM_ENABLE_BLOCK_COPY_TEST + "Build and register the expensive CUDA block_copy_test" + OFF) + +if(XLLM_ENABLE_BLOCK_COPY_TEST) + cc_test( + NAME + block_copy_test + SRCS + block_copy_test.cpp + DEPS + :cuda_kernels + torch + GTest::gtest_main + glog::glog + ) +endif() + cc_test( NAME decoder_reshape_and_cache_test diff --git a/xllm/core/kernels/cuda/air_log_softmax_last_dim.cu b/xllm/core/kernels/cuda/air_log_softmax_last_dim.cu index 8b06a191e..fcc447e37 100644 --- a/xllm/core/kernels/cuda/air_log_softmax_last_dim.cu +++ b/xllm/core/kernels/cuda/air_log_softmax_last_dim.cu @@ -25,6 +25,7 @@ #include #include "cuda_ops_api.h" +#include "device_utils.cuh" #include "utils.h" namespace xllm::kernel::cuda { @@ -69,7 +70,7 @@ __global__ void log_softmax_last_dim_kernel(const scalar_t* __restrict__ input, thread_max = s_data[col] > thread_max ? s_data[col] : thread_max; } float row_max_local = - BlockReduce(reduce_storage).Reduce(thread_max, cub::Max()); + BlockReduce(reduce_storage).Reduce(thread_max, MaxReduceOp()); if (threadIdx.x == 0) { s_row_max = row_max_local; } diff --git a/xllm/core/kernels/cuda/block_copy.cu b/xllm/core/kernels/cuda/block_copy.cu new file mode 100644 index 000000000..5195b5c29 --- /dev/null +++ b/xllm/core/kernels/cuda/block_copy.cu @@ -0,0 +1,209 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "cuda_ops_api.h" +#include "utils.h" + +namespace xllm::kernel::cuda { +namespace { + +template +struct VecType; + +template <> +struct VecType { + using type = uint4; + static constexpr int32_t vec_width = 8; +}; + +template <> +struct VecType { + using type = uint4; + static constexpr int32_t vec_width = 8; +}; + +template <> +struct VecType { + using type = float4; + static constexpr int32_t vec_width = 4; +}; + +DEVICE_INLINE int32_t find_group_idx(const int32_t* __restrict__ cum_sum, + const int32_t num_groups, + const int32_t dst_idx) { + int32_t left = 0; + int32_t right = num_groups - 1; + while (left < right) { + const int32_t mid = left + ((right - left) >> 1); + const bool move_left = dst_idx < cum_sum[mid]; + right = move_left ? mid : right; + left = move_left ? left : mid + 1; + } + return left; +} + +template +__global__ void block_copy_kernel(const int64_t* __restrict__ key_cache_ptrs, + const int64_t* __restrict__ value_cache_ptrs, + const int32_t* __restrict__ src_block_indices, + const int32_t* __restrict__ dst_block_indices, + const int32_t* __restrict__ cum_sum, + const int32_t num_groups, + const int64_t numel_per_block) { + const int64_t layer_idx = static_cast(blockIdx.x); + const int32_t dst_linear_idx = static_cast(blockIdx.y); + const int64_t tile_idx = static_cast(blockIdx.z); + + scalar_t* __restrict__ key_cache = reinterpret_cast( + static_cast(key_cache_ptrs[layer_idx])); + scalar_t* __restrict__ value_cache = reinterpret_cast( + static_cast(value_cache_ptrs[layer_idx])); + + const int32_t group_idx = find_group_idx(cum_sum, num_groups, dst_linear_idx); + const int32_t src_block = src_block_indices[group_idx]; + const int32_t dst_block = dst_block_indices[dst_linear_idx]; + const int64_t src_offset = static_cast(src_block) * numel_per_block; + const int64_t dst_offset = static_cast(dst_block) * numel_per_block; + + if constexpr (kVectorized) { + using VecTypeT = typename VecType::type; + constexpr int32_t kVecWidth = VecType::vec_width; + const int64_t num_vecs_per_block = numel_per_block / kVecWidth; + const int64_t vec_idx = tile_idx * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (vec_idx >= num_vecs_per_block) { + return; + } + + const int64_t elem_offset = vec_idx * kVecWidth; + const auto* key_src_vec = + reinterpret_cast(key_cache + src_offset + elem_offset); + const auto* value_src_vec = reinterpret_cast( + value_cache + src_offset + elem_offset); + auto* key_dst_vec = + reinterpret_cast(key_cache + dst_offset + elem_offset); + auto* value_dst_vec = + reinterpret_cast(value_cache + dst_offset + elem_offset); + *key_dst_vec = *key_src_vec; + *value_dst_vec = *value_src_vec; + } else { + const int64_t elem_idx = tile_idx * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (elem_idx >= numel_per_block) { + return; + } + + key_cache[dst_offset + elem_idx] = key_cache[src_offset + elem_idx]; + value_cache[dst_offset + elem_idx] = value_cache[src_offset + elem_idx]; + } +} + +} // namespace + +void block_copy(torch::Tensor key_cache_ptrs, + torch::Tensor value_cache_ptrs, + torch::Tensor src_block_indices, + torch::Tensor dst_block_indices, + torch::Tensor cum_sum, + int64_t numel_per_block, + torch::ScalarType cache_dtype) { + if (src_block_indices.numel() == 0) { + return; + } + + CHECK(key_cache_ptrs.is_cuda()); + CHECK(value_cache_ptrs.is_cuda()); + CHECK(src_block_indices.is_cuda()); + CHECK(dst_block_indices.is_cuda()); + CHECK(cum_sum.is_cuda()); + CHECK_EQ(key_cache_ptrs.scalar_type(), torch::kInt64); + CHECK_EQ(value_cache_ptrs.scalar_type(), torch::kInt64); + CHECK_EQ(src_block_indices.scalar_type(), torch::kInt32); + CHECK_EQ(dst_block_indices.scalar_type(), torch::kInt32); + CHECK_EQ(cum_sum.scalar_type(), torch::kInt32); + CHECK_EQ(key_cache_ptrs.dim(), 1); + CHECK_EQ(value_cache_ptrs.dim(), 1); + CHECK_EQ(src_block_indices.dim(), 1); + CHECK_EQ(dst_block_indices.dim(), 1); + CHECK_EQ(cum_sum.dim(), 1); + CHECK(key_cache_ptrs.is_contiguous()); + CHECK(value_cache_ptrs.is_contiguous()); + CHECK(src_block_indices.is_contiguous()); + CHECK(dst_block_indices.is_contiguous()); + CHECK(cum_sum.is_contiguous()); + CHECK_EQ(key_cache_ptrs.size(0), value_cache_ptrs.size(0)); + CHECK_EQ(src_block_indices.size(0), cum_sum.size(0)); + CHECK_GT(numel_per_block, 0); + + const at::cuda::OptionalCUDAGuard device_guard(key_cache_ptrs.device()); + constexpr int32_t kThreadsPerBlock = 256; + const int32_t num_layers = static_cast(key_cache_ptrs.size(0)); + const int32_t num_groups = static_cast(src_block_indices.size(0)); + const int32_t num_dst_blocks = + static_cast(dst_block_indices.size(0)); + const cudaStream_t stream = + c10::cuda::getCurrentCUDAStream(key_cache_ptrs.get_device()); + + DISPATCH_FLOATING_TYPES(cache_dtype, "block_copy_kernel", [&] { + constexpr bool kHasVecType = std::is_same_v || + std::is_same_v || + std::is_same_v; + + if constexpr (kHasVecType) { + constexpr int32_t kVecWidth = VecType::vec_width; + if (numel_per_block % kVecWidth == 0) { + const int64_t tiles_per_block = + ceil_div(numel_per_block / kVecWidth, kThreadsPerBlock); + const dim3 grid(num_layers, num_dst_blocks, tiles_per_block); + block_copy_kernel + <<>>( + key_cache_ptrs.data_ptr(), + value_cache_ptrs.data_ptr(), + src_block_indices.data_ptr(), + dst_block_indices.data_ptr(), + cum_sum.data_ptr(), + num_groups, + numel_per_block); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return; + } + } + + const int64_t tiles_per_block = + ceil_div(numel_per_block, kThreadsPerBlock); + const dim3 grid(num_layers, num_dst_blocks, tiles_per_block); + block_copy_kernel<<>>( + key_cache_ptrs.data_ptr(), + value_cache_ptrs.data_ptr(), + src_block_indices.data_ptr(), + dst_block_indices.data_ptr(), + cum_sum.data_ptr(), + num_groups, + numel_per_block); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +} // namespace xllm::kernel::cuda diff --git a/xllm/core/kernels/cuda/block_copy_test.cpp b/xllm/core/kernels/cuda/block_copy_test.cpp new file mode 100644 index 000000000..ed622df14 --- /dev/null +++ b/xllm/core/kernels/cuda/block_copy_test.cpp @@ -0,0 +1,688 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "core/kernels/cuda/cuda_ops_api.h" + +namespace xllm::kernel::cuda { +namespace test { +namespace { + +struct BlockCopyCaseConfig { + int64_t num_layers; + int64_t num_blocks; + int64_t block_size; + int64_t num_heads; + int64_t head_dim; + std::vector src_blocks; + std::vector dst_blocks; + std::vector cum_sum; +}; + +struct PerfBenchmarkCase { + std::string name; + BlockCopyCaseConfig config; + torch::ScalarType dtype; + int32_t warmup_iters; + int32_t measure_iters; +}; + +struct PerfCompareResult { + double kernel_ms; + double native_ms; + double speedup; + double logical_copy_gbps; + double traffic_gbps; +}; + +struct BlockCopyLaunchInputs { + torch::Tensor key_ptr_tensor; + torch::Tensor value_ptr_tensor; + torch::Tensor src_tensor; + torch::Tensor dst_tensor; + torch::Tensor cum_sum_tensor; + int64_t numel_per_block; +}; + +struct NativeBlockCopyLaunchInputs { + torch::Tensor src_tensor; + torch::Tensor dst_tensor; +}; + +void apply_reference_block_copy(const std::vector& key_caches, + const std::vector& value_caches, + const std::vector& src_blocks, + const std::vector& dst_blocks, + const std::vector& cum_sum, + std::vector& ref_k_caches, + std::vector& ref_v_caches) { + for (size_t layer_idx = 0; layer_idx < key_caches.size(); ++layer_idx) { + ref_k_caches[layer_idx] = key_caches[layer_idx].clone(); + ref_v_caches[layer_idx] = value_caches[layer_idx].clone(); + } + + for (size_t group_idx = 0; group_idx < src_blocks.size(); ++group_idx) { + const int32_t src_block = src_blocks[group_idx]; + const int32_t dst_begin = group_idx == 0 ? 0 : cum_sum[group_idx - 1]; + const int32_t dst_end = cum_sum[group_idx]; + for (int32_t dst_idx = dst_begin; dst_idx < dst_end; ++dst_idx) { + const int32_t dst_block = dst_blocks[dst_idx]; + for (size_t layer_idx = 0; layer_idx < ref_k_caches.size(); ++layer_idx) { + ref_k_caches[layer_idx][dst_block].copy_( + ref_k_caches[layer_idx][src_block]); + ref_v_caches[layer_idx][dst_block].copy_( + ref_v_caches[layer_idx][src_block]); + } + } + } +} + +std::vector flatten_src_blocks_for_native( + const std::vector& src_blocks, + const std::vector& cum_sum) { + std::vector flat_src_blocks; + flat_src_blocks.reserve(cum_sum.empty() ? 0 : cum_sum.back()); + for (size_t group_idx = 0; group_idx < src_blocks.size(); ++group_idx) { + const int32_t begin = group_idx == 0 ? 0 : cum_sum[group_idx - 1]; + const int32_t end = cum_sum[group_idx]; + for (int32_t dst_idx = begin; dst_idx < end; ++dst_idx) { + flat_src_blocks.push_back(src_blocks[group_idx]); + } + } + return flat_src_blocks; +} + +void native_block_copy(std::vector& key_caches, + std::vector& value_caches, + const NativeBlockCopyLaunchInputs& launch_inputs) { + for (size_t layer_idx = 0; layer_idx < key_caches.size(); ++layer_idx) { + auto selected_keys = + torch::index_select(key_caches[layer_idx], 0, launch_inputs.src_tensor); + auto selected_values = torch::index_select( + value_caches[layer_idx], 0, launch_inputs.src_tensor); + key_caches[layer_idx].index_copy_( + 0, launch_inputs.dst_tensor, selected_keys); + value_caches[layer_idx].index_copy_( + 0, launch_inputs.dst_tensor, selected_values); + } +} + +BlockCopyLaunchInputs prepare_block_copy_launch_inputs( + const std::vector& key_caches, + const std::vector& value_caches, + const std::vector& src_blocks, + const std::vector& dst_blocks, + const std::vector& cum_sum, + const torch::Device& device) { + std::vector key_ptrs; + std::vector value_ptrs; + key_ptrs.reserve(key_caches.size()); + value_ptrs.reserve(value_caches.size()); + for (size_t layer_idx = 0; layer_idx < key_caches.size(); ++layer_idx) { + key_ptrs.push_back( + reinterpret_cast(key_caches[layer_idx].data_ptr())); + value_ptrs.push_back( + reinterpret_cast(value_caches[layer_idx].data_ptr())); + } + + const auto ptr_opts = + torch::TensorOptions().device(device).dtype(torch::kInt64); + const auto idx_opts = + torch::TensorOptions().device(device).dtype(torch::kInt32); + return { + .key_ptr_tensor = torch::tensor(key_ptrs, ptr_opts), + .value_ptr_tensor = torch::tensor(value_ptrs, ptr_opts), + .src_tensor = torch::tensor(src_blocks, idx_opts), + .dst_tensor = torch::tensor(dst_blocks, idx_opts), + .cum_sum_tensor = torch::tensor(cum_sum, idx_opts), + .numel_per_block = key_caches[0][0].numel(), + }; +} + +NativeBlockCopyLaunchInputs prepare_native_block_copy_launch_inputs( + const std::vector& src_blocks, + const std::vector& dst_blocks, + const std::vector& cum_sum, + const torch::Device& device) { + auto flat_src_blocks = flatten_src_blocks_for_native(src_blocks, cum_sum); + std::vector flat_dst_blocks(dst_blocks.begin(), dst_blocks.end()); + return { + .src_tensor = torch::tensor( + flat_src_blocks, + torch::TensorOptions().device(device).dtype(torch::kLong)), + .dst_tensor = torch::tensor( + flat_dst_blocks, + torch::TensorOptions().device(device).dtype(torch::kLong)), + }; +} + +void kernel_block_copy(const BlockCopyLaunchInputs& launch_inputs, + torch::ScalarType dtype) { + block_copy(launch_inputs.key_ptr_tensor, + launch_inputs.value_ptr_tensor, + launch_inputs.src_tensor, + launch_inputs.dst_tensor, + launch_inputs.cum_sum_tensor, + launch_inputs.numel_per_block, + dtype); +} + +double measure_cuda_time_ms(const std::function& fn, + int32_t warmup_iters, + int32_t measure_iters) { + for (int32_t iter = 0; iter < warmup_iters; ++iter) { + fn(); + } + torch::cuda::synchronize(); + + const auto stream = c10::cuda::getCurrentCUDAStream(); + at::cuda::CUDAEvent start_event(cudaEventDefault); + at::cuda::CUDAEvent stop_event(cudaEventDefault); + start_event.record(stream); + for (int32_t iter = 0; iter < measure_iters; ++iter) { + fn(); + } + stop_event.record(stream); + stop_event.synchronize(); + + const float elapsed_ms = start_event.elapsed_time(stop_event); + return static_cast(elapsed_ms) / measure_iters; +} + +std::vector make_random_caches(const BlockCopyCaseConfig& config, + const torch::Device& device, + torch::ScalarType dtype) { + std::vector caches; + caches.reserve(config.num_layers); + const auto opts = torch::TensorOptions().device(device).dtype(dtype); + for (int64_t layer_idx = 0; layer_idx < config.num_layers; ++layer_idx) { + caches.push_back(torch::randn({config.num_blocks, + config.block_size, + config.num_heads, + config.head_dim}, + opts)); + } + return caches; +} + +void expect_caches_allclose(const std::vector& lhs, + const std::vector& rhs, + double rtol, + double atol) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (size_t idx = 0; idx < lhs.size(); ++idx) { + EXPECT_TRUE(torch::allclose(lhs[idx], rhs[idx], rtol, atol)) + << "cache mismatch at layer=" << idx; + } +} + +void run_accuracy_compare_case(const BlockCopyCaseConfig& config, + torch::ScalarType dtype, + double rtol, + double atol) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA not available, skipping test."; + } + + torch::manual_seed(2026); + const auto device = torch::Device(torch::kCUDA, 0); + + auto base_key_caches = make_random_caches(config, device, dtype); + auto base_value_caches = make_random_caches(config, device, dtype); + auto kernel_key_caches = base_key_caches; + auto kernel_value_caches = base_value_caches; + auto native_key_caches = base_key_caches; + auto native_value_caches = base_value_caches; + + auto kernel_launch_inputs = + prepare_block_copy_launch_inputs(kernel_key_caches, + kernel_value_caches, + config.src_blocks, + config.dst_blocks, + config.cum_sum, + device); + auto native_launch_inputs = prepare_native_block_copy_launch_inputs( + config.src_blocks, config.dst_blocks, config.cum_sum, device); + + kernel_block_copy(kernel_launch_inputs, dtype); + native_block_copy( + native_key_caches, native_value_caches, native_launch_inputs); + torch::cuda::synchronize(); + + expect_caches_allclose(kernel_key_caches, native_key_caches, rtol, atol); + expect_caches_allclose(kernel_value_caches, native_value_caches, rtol, atol); +} + +std::string dtype_to_string(torch::ScalarType dtype) { + switch (dtype) { + case torch::kHalf: + return "fp16"; + case torch::kBFloat16: + return "bf16"; + case torch::kFloat: + return "fp32"; + default: + return c10::toString(dtype); + } +} + +int64_t get_total_dst_copies(const BlockCopyCaseConfig& config) { + return static_cast(config.dst_blocks.size()); +} + +double get_logical_copy_bytes_per_iter(const BlockCopyCaseConfig& config, + torch::ScalarType dtype) { + const int64_t numel_per_block = + config.block_size * config.num_heads * config.head_dim; + const int64_t bytes_per_elem = c10::elementSize(dtype); + const int64_t total_dst_copies = get_total_dst_copies(config); + const int64_t total_elements = + 2LL * config.num_layers * total_dst_copies * numel_per_block; + return static_cast(total_elements) * bytes_per_elem; +} + +double get_traffic_bytes_per_iter(const BlockCopyCaseConfig& config, + torch::ScalarType dtype) { + return get_logical_copy_bytes_per_iter(config, dtype) * 2.0; +} + +std::string format_perf_case_summary(const PerfBenchmarkCase& benchmark_case, + const PerfCompareResult& result) { + const auto& config = benchmark_case.config; + const int64_t total_dst_copies = get_total_dst_copies(config); + const double avg_fanout = + config.src_blocks.empty() + ? 0.0 + : static_cast(total_dst_copies) / + static_cast(config.src_blocks.size()); + std::ostringstream oss; + oss << "block_copy bench [" << benchmark_case.name + << "] dtype=" << dtype_to_string(benchmark_case.dtype) + << ", layers=" << config.num_layers << ", blocks=" << config.num_blocks + << ", block_size=" << config.block_size << ", heads=" << config.num_heads + << ", head_dim=" << config.head_dim + << ", groups=" << config.src_blocks.size() + << ", total_dst=" << total_dst_copies << ", avg_fanout=" << avg_fanout + << ", kernel=" << result.kernel_ms << " ms" + << ", native=" << result.native_ms << " ms" + << ", speedup=" << result.speedup << "x" + << ", logical_bw=" << result.logical_copy_gbps << " GB/s" + << ", traffic_bw=" << result.traffic_gbps << " GB/s"; + return oss.str(); +} + +PerfCompareResult run_perf_compare_case(const BlockCopyCaseConfig& config, + torch::ScalarType dtype, + int32_t warmup_iters, + int32_t measure_iters) { + torch::manual_seed(2026); + const auto device = torch::Device(torch::kCUDA, 0); + + auto kernel_key_caches = make_random_caches(config, device, dtype); + auto kernel_value_caches = make_random_caches(config, device, dtype); + auto native_key_caches = kernel_key_caches; + auto native_value_caches = kernel_value_caches; + + auto kernel_launch_inputs = + prepare_block_copy_launch_inputs(kernel_key_caches, + kernel_value_caches, + config.src_blocks, + config.dst_blocks, + config.cum_sum, + device); + auto native_launch_inputs = prepare_native_block_copy_launch_inputs( + config.src_blocks, config.dst_blocks, config.cum_sum, device); + + const double kernel_ms = measure_cuda_time_ms( + [&]() { kernel_block_copy(kernel_launch_inputs, dtype); }, + warmup_iters, + measure_iters); + + const double native_ms = measure_cuda_time_ms( + [&]() { + native_block_copy( + native_key_caches, native_value_caches, native_launch_inputs); + }, + warmup_iters, + measure_iters); + + expect_caches_allclose(kernel_key_caches, native_key_caches, 1e-5, 1e-5); + expect_caches_allclose(kernel_value_caches, native_value_caches, 1e-5, 1e-5); + + const double logical_copy_bytes = + get_logical_copy_bytes_per_iter(config, dtype); + const double traffic_bytes = get_traffic_bytes_per_iter(config, dtype); + const double speedup = native_ms / kernel_ms; + const double logical_copy_gbps = logical_copy_bytes / (kernel_ms * 1.0e6); + const double traffic_gbps = traffic_bytes / (kernel_ms * 1.0e6); + + EXPECT_GT(kernel_ms, 0.0); + EXPECT_GT(native_ms, 0.0); + return PerfCompareResult{ + .kernel_ms = kernel_ms, + .native_ms = native_ms, + .speedup = speedup, + .logical_copy_gbps = logical_copy_gbps, + .traffic_gbps = traffic_gbps, + }; +} + +void run_multi_shape_perf_benchmark( + const std::vector& benchmark_cases) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA not available, skipping test."; + } + for (const auto& benchmark_case : benchmark_cases) { + SCOPED_TRACE(benchmark_case.name); + const auto result = run_perf_compare_case(benchmark_case.config, + benchmark_case.dtype, + benchmark_case.warmup_iters, + benchmark_case.measure_iters); + LOG(INFO) << format_perf_case_summary(benchmark_case, result); + } +} + +} // namespace + +TEST(BlockCopyTest, KernelMatchesReferenceFp16) { + run_accuracy_compare_case( + BlockCopyCaseConfig{ + .num_layers = 3, + .num_blocks = 8, + .block_size = 4, + .num_heads = 2, + .head_dim = 8, + .src_blocks = {1, 4}, + .dst_blocks = {2, 3, 6}, + .cum_sum = {2, 3}, + }, + torch::kHalf, + 1e-5, + 1e-5); +} + +TEST(BlockCopyTest, KernelMatchesReferenceFp32) { + run_accuracy_compare_case( + BlockCopyCaseConfig{ + .num_layers = 4, + .num_blocks = 10, + .block_size = 8, + .num_heads = 3, + .head_dim = 16, + .src_blocks = {1, 4, 7}, + .dst_blocks = {2, 3, 5, 8, 9}, + .cum_sum = {2, 4, 5}, + }, + torch::kFloat, + 1e-6, + 1e-6); +} + +TEST(BlockCopyTest, KernelMatchesNativeFp16) { + run_accuracy_compare_case( + BlockCopyCaseConfig{ + .num_layers = 6, + .num_blocks = 32, + .block_size = 16, + .num_heads = 4, + .head_dim = 32, + .src_blocks = {1, 4, 9, 12}, + .dst_blocks = {2, 3, 5, 6, 10, 11, 20}, + .cum_sum = {2, 4, 6, 7}, + }, + torch::kHalf, + 1e-5, + 1e-5); +} + +TEST(BlockCopyTest, KernelMatchesNativeFp32) { + run_accuracy_compare_case( + BlockCopyCaseConfig{ + .num_layers = 5, + .num_blocks = 24, + .block_size = 12, + .num_heads = 3, + .head_dim = 24, + .src_blocks = {1, 4, 9}, + .dst_blocks = {2, 3, 5, 6, 10, 11}, + .cum_sum = {2, 4, 6}, + }, + torch::kFloat, + 1e-6, + 1e-6); +} + +TEST(BlockCopyTest, PerfCompareKernelVsNativeMultiShapeFp16) { + run_multi_shape_perf_benchmark({ + PerfBenchmarkCase{ + .name = "tiny_balanced", + .config = + BlockCopyCaseConfig{ + .num_layers = 4, + .num_blocks = 32, + .block_size = 16, + .num_heads = 4, + .head_dim = 32, + .src_blocks = {1, 4, 8, 12}, + .dst_blocks = {2, 3, 5, 6, 9, 10, 13, 14}, + .cum_sum = {2, 4, 6, 8}, + }, + .dtype = torch::kHalf, + .warmup_iters = 20, + .measure_iters = 150, + }, + PerfBenchmarkCase{ + .name = "tiny_high_fanout", + .config = + BlockCopyCaseConfig{ + .num_layers = 4, + .num_blocks = 48, + .block_size = 16, + .num_heads = 4, + .head_dim = 32, + .src_blocks = {1, 8}, + .dst_blocks = {2, 3, 4, 5, 6, 9, 10, 11, 12, 13}, + .cum_sum = {5, 10}, + }, + .dtype = torch::kHalf, + .warmup_iters = 20, + .measure_iters = 150, + }, + PerfBenchmarkCase{ + .name = "medium_balanced", + .config = + BlockCopyCaseConfig{ + .num_layers = 8, + .num_blocks = 64, + .block_size = 64, + .num_heads = 8, + .head_dim = 128, + .src_blocks = {1, 4, 8, 12, 16, 20, 24, 28}, + .dst_blocks = {2, + 3, + 5, + 6, + 9, + 10, + 13, + 14, + 17, + 18, + 21, + 22, + 25, + 26, + 29, + 30}, + .cum_sum = {2, 4, 6, 8, 10, 12, 14, 16}, + }, + .dtype = torch::kHalf, + .warmup_iters = 20, + .measure_iters = 120, + }, + PerfBenchmarkCase{ + .name = "large_many_layers", + .config = + BlockCopyCaseConfig{ + .num_layers = 32, + .num_blocks = 256, + .block_size = 64, + .num_heads = 8, + .head_dim = 128, + .src_blocks = {1, 9, 17, 25, 33, 41, 49, 57}, + .dst_blocks = {2, + 3, + 10, + 11, + 18, + 19, + 26, + 27, + 34, + 35, + 42, + 43, + 50, + 51, + 58, + 59}, + .cum_sum = {2, 4, 6, 8, 10, 12, 14, 16}, + }, + .dtype = torch::kHalf, + .warmup_iters = 20, + .measure_iters = 80, + }, + PerfBenchmarkCase{ + .name = "large_high_fanout", + .config = + BlockCopyCaseConfig{ + .num_layers = 16, + .num_blocks = 256, + .block_size = 64, + .num_heads = 8, + .head_dim = 128, + .src_blocks = {1, 33, 65, 97}, + .dst_blocks = {2, 3, 4, 5, 6, 34, 35, 36, 37, 38, + 66, 67, 68, 69, 70, 98, 99, 100, 101, 102}, + .cum_sum = {5, 10, 15, 20}, + }, + .dtype = torch::kHalf, + .warmup_iters = 20, + .measure_iters = 80, + }, + PerfBenchmarkCase{ + .name = "many_groups_sparse", + .config = + BlockCopyCaseConfig{ + .num_layers = 16, + .num_blocks = 256, + .block_size = 32, + .num_heads = 8, + .head_dim = 128, + .src_blocks = {1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45}, + .dst_blocks = {2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + 34, + 38, + 42, + 46, + 3, + 7, + 11, + 15, + 19, + 23}, + .cum_sum = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 18}, + }, + .dtype = torch::kHalf, + .warmup_iters = 20, + .measure_iters = 100, + }, + }); +} + +TEST(BlockCopyTest, PerfCompareKernelVsNativeMultiShapeFp32) { + run_multi_shape_perf_benchmark({ + PerfBenchmarkCase{ + .name = "fp32_medium_balanced", + .config = + BlockCopyCaseConfig{ + .num_layers = 8, + .num_blocks = 64, + .block_size = 32, + .num_heads = 8, + .head_dim = 64, + .src_blocks = {1, 4, 8, 12, 16, 20}, + .dst_blocks = {2, 3, 5, 6, 9, 10, 13, 14, 17, 18, 21, 22}, + .cum_sum = {2, 4, 6, 8, 10, 12}, + }, + .dtype = torch::kFloat, + .warmup_iters = 20, + .measure_iters = 120, + }, + PerfBenchmarkCase{ + .name = "fp32_large_many_layers", + .config = + BlockCopyCaseConfig{ + .num_layers = 24, + .num_blocks = 192, + .block_size = 64, + .num_heads = 8, + .head_dim = 64, + .src_blocks = {1, 9, 17, 25, 33, 41, 49, 57}, + .dst_blocks = {2, + 3, + 10, + 11, + 18, + 19, + 26, + 27, + 34, + 35, + 42, + 43, + 50, + 51, + 58, + 59}, + .cum_sum = {2, 4, 6, 8, 10, 12, 14, 16}, + }, + .dtype = torch::kFloat, + .warmup_iters = 20, + .measure_iters = 80, + }, + }); +} + +} // namespace test +} // namespace xllm::kernel::cuda diff --git a/xllm/core/kernels/cuda/cuda_ops_api.h b/xllm/core/kernels/cuda/cuda_ops_api.h index b88519212..d9b503dc8 100644 --- a/xllm/core/kernels/cuda/cuda_ops_api.h +++ b/xllm/core/kernels/cuda/cuda_ops_api.h @@ -47,6 +47,14 @@ void reshape_paged_cache( torch::Tensor key_cache, // [n_blocks, block_size, n_heads, head_dim] torch::Tensor value_cache); +void block_copy(torch::Tensor key_cache_ptrs, + torch::Tensor value_cache_ptrs, + torch::Tensor src_block_indices, + torch::Tensor dst_block_indices, + torch::Tensor cum_sum, + int64_t numel_per_block, + torch::ScalarType cache_dtype); + void batch_prefill(const std::string& uri, ffi::Array plan_info, torch::Tensor float_workspace_buffer, @@ -246,4 +254,6 @@ std::tuple moe_fused_topk( bool renormalize, const std::optional& correction_bias, const std::string& scoring_func); + +torch::Tensor random_sample(const torch::Tensor& probs); } // namespace xllm::kernel::cuda diff --git a/xllm/core/kernels/cuda/device_utils.cuh b/xllm/core/kernels/cuda/device_utils.cuh index 2a027651b..e44db2945 100644 --- a/xllm/core/kernels/cuda/device_utils.cuh +++ b/xllm/core/kernels/cuda/device_utils.cuh @@ -42,8 +42,8 @@ class alignas(Alignment) AlignedArray { // Define reduction operators based on CUDA version // CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum #if CUDA_VERSION >= 12090 -using MaxReduceOp = cuda::maximum<>; -using MinReduceOp = cuda::minimum<>; +using MaxReduceOp = ::cuda::maximum<>; +using MinReduceOp = ::cuda::minimum<>; #else using MaxReduceOp = cub::Max; using MinReduceOp = cub::Min; diff --git a/xllm/core/kernels/cuda/norm.cu b/xllm/core/kernels/cuda/norm.cu index 761eb5553..90de7e533 100644 --- a/xllm/core/kernels/cuda/norm.cu +++ b/xllm/core/kernels/cuda/norm.cu @@ -27,8 +27,8 @@ limitations under the License. #if CUB_VERSION >= 200800 #include -using CubAddOp = cuda::std::plus<>; -using CubMaxOp = cuda::maximum<>; +using CubAddOp = ::cuda::std::plus<>; +using CubMaxOp = ::cuda::maximum<>; #else // if CUB_VERSION < 200800 using CubAddOp = cub::Sum; using CubMaxOp = cub::Max; diff --git a/xllm/core/kernels/cuda/random_sample.cpp b/xllm/core/kernels/cuda/random_sample.cpp new file mode 100644 index 000000000..cdfb4bf06 --- /dev/null +++ b/xllm/core/kernels/cuda/random_sample.cpp @@ -0,0 +1,82 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include + +#include "cuda_ops_api.h" + +namespace { + +at::Generator get_default_generator(c10::DeviceIndex device_index) { + static std::unordered_map cache; + static std::mutex mu; + std::lock_guard lock(mu); + auto it = cache.find(device_index); + if (it != cache.end()) { + return it->second; + } + at::globalContext().lazyInitCUDA(); + at::Generator gen = at::cuda::detail::getDefaultCUDAGenerator(device_index); + cache.emplace(device_index, gen); + return gen; +} + +std::tuple get_seed_and_offset( + int64_t increment, + const torch::Device& device, + c10::optional generator = c10::nullopt) { + at::Generator gen = generator.has_value() + ? generator.value() + : get_default_generator(device.index()); + std::lock_guard lock(gen.mutex()); + auto* cuda_gen = at::check_generator(gen); + + int64_t seed = static_cast(cuda_gen->current_seed()); + int64_t offset = static_cast(cuda_gen->get_offset()); + offset += (increment + 3) / 4 * 4; + cuda_gen->set_offset(static_cast(offset)); + + return std::make_tuple(seed, offset); +} +} // namespace + +namespace xllm::kernel::cuda { + +torch::Tensor random_sample(const torch::Tensor& probs) { + CHECK_EQ(probs.dim(), 2) << "probs must be a 2D tensor"; + const torch::Device device = probs.device(); + int64_t batch_size = probs.size(0); + torch::ScalarType out_dtype = torch::kInt32; + torch::Tensor samples = + torch::empty({batch_size}, torch::dtype(out_dtype).device(device)); + auto [seed, offset] = get_seed_and_offset(batch_size, device); + + get_function(/*uri=*/"sampling", + /*func_name=*/"sampling_from_probs")( + to_ffi_tensor(probs), + to_ffi_tensor(samples), + /*maybe_indices=*/ffi::Optional(), + /*deterministic=*/true, + /*philox_seed=*/seed, + /*philox_offset=*/offset); + return samples; +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/tilelang/README.md b/xllm/core/kernels/cuda/tilelang/README.md new file mode 100644 index 000000000..8074637ce --- /dev/null +++ b/xllm/core/kernels/cuda/tilelang/README.md @@ -0,0 +1,10 @@ +# TileLang CUDA Runtime Wrappers + +This directory is reserved for CUDA-side TileLang runtime wrappers. + +Compiler-side Python TileLang kernel definitions live under: + +- `xllm/xllm/compiler/tilelang/targets/cuda/kernels` + +Runtime wrapper code should stay in this directory so the compiler input and +device execution glue remain separated. diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt index 269019061..c37e0ca1d 100644 --- a/xllm/core/kernels/npu/CMakeLists.txt +++ b/xllm/core/kernels/npu/CMakeLists.txt @@ -1,5 +1,7 @@ include(cc_library) +add_subdirectory(tilelang) + add_subdirectory(xllm_ops) cc_library( @@ -12,6 +14,7 @@ cc_library( attention.cpp fused_layernorm.cpp matmul.cpp + npu_gemma_rms_norm.cpp npu_grouped_matmul.cpp npu_moe_gating_topk_softmax.cpp npu_moe_init_routing_v2.cpp @@ -19,4 +22,6 @@ cc_library( rope.cpp DEPS :torch_npu_kernels + :tilelang_kernels + ascendc_ops ) diff --git a/xllm/core/kernels/npu/npu_gemma_rms_norm.cpp b/xllm/core/kernels/npu/npu_gemma_rms_norm.cpp new file mode 100644 index 000000000..3dc2fd733 --- /dev/null +++ b/xllm/core/kernels/npu/npu_gemma_rms_norm.cpp @@ -0,0 +1,31 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "npu_ops_api.h" + +// Include ascendc_ops_api.h for npu_ops::npu_gemma_rms_norm +#include "ascendc_npu/ascendc_ops_api.h" + +namespace xllm::kernel::npu { + +void npu_gemma_rms_norm(const torch::Tensor& x, + const torch::Tensor& gamma, + double epsilon, + torch::Tensor& rstdOut, + torch::Tensor& yOut) { + npu_ops::npu_gemma_rms_norm(x, gamma, epsilon, rstdOut, yOut); +} + +} // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h index ef1a12783..6d85ca0df 100644 --- a/xllm/core/kernels/npu/npu_ops_api.h +++ b/xllm/core/kernels/npu/npu_ops_api.h @@ -68,6 +68,12 @@ torch::Tensor rms_norm(const torch::Tensor& input, double eps, const std::string& mode); +void npu_gemma_rms_norm(const torch::Tensor& x, + const torch::Tensor& gamma, + double epsilon, + torch::Tensor& rstd_out, + torch::Tensor& y_out); + std::tuple add_rms_norm( const torch::Tensor& x1, const torch::Tensor& x2, diff --git a/xllm/core/kernels/npu/tilelang/CMakeLists.txt b/xllm/core/kernels/npu/tilelang/CMakeLists.txt new file mode 100644 index 000000000..5f7b95a1c --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/CMakeLists.txt @@ -0,0 +1,183 @@ +include(cc_library) +include(cc_test) +include(CMakeParseArguments) + +if(NOT USE_NPU) + return() +endif() + +if(NOT DEFINED ENV{NPU_HOME_PATH}) + message(FATAL_ERROR "NPU_HOME_PATH is not set") +endif() + +function(tilelang_import_kernel_manifest) + set(options "") + set(oneValueArgs PREFIX MANIFEST_PATH) + cmake_parse_arguments(TL "${options}" "${oneValueArgs}" "" ${ARGN}) + + if(NOT TL_PREFIX OR NOT TL_MANIFEST_PATH) + message(FATAL_ERROR "tilelang_import_kernel_manifest requires PREFIX and MANIFEST_PATH") + endif() + + if(NOT EXISTS "${TL_MANIFEST_PATH}") + message(FATAL_ERROR + "Missing TileLang manifest: ${TL_MANIFEST_PATH}\n" + "Run the Python TileLang compile entry before invoking CMake build.") + endif() + + file(READ "${TL_MANIFEST_PATH}" TL_MANIFEST_JSON) + string(JSON _schema_version GET "${TL_MANIFEST_JSON}" schema_version) + if(NOT _schema_version EQUAL 2) + message(FATAL_ERROR + "Unsupported TileLang manifest schema_version=${_schema_version}: ${TL_MANIFEST_PATH}") + endif() + + string(JSON _variants_inc GET "${TL_MANIFEST_JSON}" variants_inc) + if(NOT EXISTS "${_variants_inc}") + message(FATAL_ERROR "TileLang variants.inc does not exist: ${_variants_inc}") + endif() + + string(JSON _registry_inc ERROR_VARIABLE _registry_inc_error GET "${TL_MANIFEST_JSON}" registry_inc) + if(_registry_inc_error) + set(_registry_inc "${_variants_inc}") + endif() + if(NOT EXISTS "${_registry_inc}") + message(FATAL_ERROR "TileLang registry.inc does not exist: ${_registry_inc}") + endif() + + set(_kernel_objects "") + string(JSON _variants_len LENGTH "${TL_MANIFEST_JSON}" variants) + if(_variants_len LESS 1) + message(FATAL_ERROR "TileLang manifest contains no variants: ${TL_MANIFEST_PATH}") + endif() + + math(EXPR _variants_last "${_variants_len} - 1") + foreach(_idx RANGE 0 ${_variants_last}) + string(JSON _generated_source GET "${TL_MANIFEST_JSON}" variants ${_idx} generated_source) + string(JSON _compiled_binary GET "${TL_MANIFEST_JSON}" variants ${_idx} compiled_binary) + + if(NOT EXISTS "${_generated_source}") + message(FATAL_ERROR "TileLang generated source does not exist: ${_generated_source}") + endif() + if(NOT EXISTS "${_compiled_binary}") + message(FATAL_ERROR "TileLang compiled binary does not exist: ${_compiled_binary}") + endif() + + set_source_files_properties("${_compiled_binary}" PROPERTIES + GENERATED TRUE + EXTERNAL_OBJECT TRUE + ) + list(APPEND _kernel_objects "${_compiled_binary}") + endforeach() + + set(${TL_PREFIX}_KERNEL_OBJECTS "${_kernel_objects}" PARENT_SCOPE) + set(${TL_PREFIX}_VARIANTS_INC "${_variants_inc}" PARENT_SCOPE) + set(${TL_PREFIX}_REGISTRY_INC "${_registry_inc}" PARENT_SCOPE) +endfunction() + +set(TILELANG_GENERATED_ROOT "${CMAKE_BINARY_DIR}/xllm/compiler/tilelang") + +function(tilelang_register_runtime_kernel) + set(options "") + set(oneValueArgs NAME) + set(multiValueArgs WRAPPER_SRCS) + cmake_parse_arguments(TL_RUNTIME "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if(NOT TL_RUNTIME_NAME OR NOT TL_RUNTIME_WRAPPER_SRCS) + message(FATAL_ERROR + "tilelang_register_runtime_kernel requires NAME and WRAPPER_SRCS") + endif() + if(NOT DEFINED TILELANG_GENERATED_ROOT) + message(FATAL_ERROR "TILELANG_GENERATED_ROOT must be set before registering TileLang kernels") + endif() + + set(_manifest_path + "${TILELANG_GENERATED_ROOT}/targets/ascend/${TL_RUNTIME_NAME}/manifest.json") + set(_kernel_prefix "tilelang_${TL_RUNTIME_NAME}") + tilelang_import_kernel_manifest( + PREFIX "${_kernel_prefix}" + MANIFEST_PATH "${_manifest_path}" + ) + + set(_kernel_srcs ${TILELANG_KERNEL_SRCS}) + list(APPEND _kernel_srcs + ${TL_RUNTIME_WRAPPER_SRCS} + ${${_kernel_prefix}_KERNEL_OBJECTS} + ) + set(TILELANG_KERNEL_SRCS "${_kernel_srcs}" PARENT_SCOPE) + + string(TOUPPER "${TL_RUNTIME_NAME}" _kernel_name_upper) + set(_kernel_private_definitions ${TILELANG_KERNEL_PRIVATE_DEFINITIONS}) + list(APPEND _kernel_private_definitions + XLLM_TL_${_kernel_name_upper}_REGISTRY_INC=\"${${_kernel_prefix}_REGISTRY_INC}\" + ) + set(TILELANG_KERNEL_PRIVATE_DEFINITIONS + "${_kernel_private_definitions}" PARENT_SCOPE) +endfunction() + +# Add more TileLang kernels here: +# tilelang_register_runtime_kernel( +# NAME foo +# WRAPPER_SRCS foo_wrapper.cpp +# ) +set(TILELANG_KERNEL_SRCS) +set(TILELANG_KERNEL_PRIVATE_DEFINITIONS) + +tilelang_register_runtime_kernel( + NAME rope + WRAPPER_SRCS rope_wrapper.cpp +) + +tilelang_register_runtime_kernel( + NAME fused_gdn_gating + WRAPPER_SRCS fused_gdn_gating_wrapper.cpp +) + +cc_library( + NAME + tilelang_kernels + HDRS + dispatch_registry.h + tilelang_ops_api.h + SRCS + ${TILELANG_KERNEL_SRCS} + DEPS + torch + torch_npu +) + +target_compile_definitions(tilelang_kernels PRIVATE + ${TILELANG_KERNEL_PRIVATE_DEFINITIONS} +) + +target_link_libraries(tilelang_kernels + PUBLIC + "$ENV{NPU_HOME_PATH}/lib64/libascendcl.so" + "$ENV{NPU_HOME_PATH}/lib64/libruntime.so" + "$ENV{NPU_HOME_PATH}/lib64/libplatform.so" + "$ENV{NPU_HOME_PATH}/lib64/libc_sec.so" +) + +cc_test( + NAME + rope_wrapper_test + SRCS + rope_wrapper_test.cpp + DEPS + :tilelang_kernels + torch + GTest::gtest_main + glog::glog +) + +cc_test( + NAME + fused_gdn_gating_wrapper_test + SRCS + fused_gdn_gating_wrapper_test.cpp + DEPS + :tilelang_kernels + torch + GTest::gtest_main + glog::glog +) diff --git a/xllm/core/kernels/npu/tilelang/dispatch_registry.h b/xllm/core/kernels/npu/tilelang/dispatch_registry.h new file mode 100644 index 000000000..d4582c80b --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/dispatch_registry.h @@ -0,0 +1,96 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace xllm::kernel::npu::tilelang { + +enum class TilelangDType { + kBF16, + kFP16, + kFP32, + kFloat16, + kFloat32, + kInt8, + kInt32, + kUInt8, +}; + +inline TilelangDType to_tilelang_dtype(c10::ScalarType dtype) { + switch (dtype) { + case c10::ScalarType::BFloat16: + return TilelangDType::kBF16; + case c10::ScalarType::Half: + return TilelangDType::kFloat16; + case c10::ScalarType::Float: + return TilelangDType::kFloat32; + case c10::ScalarType::Char: + return TilelangDType::kInt8; + case c10::ScalarType::Int: + return TilelangDType::kInt32; + case c10::ScalarType::Byte: + return TilelangDType::kUInt8; + default: + LOG(FATAL) << "TileLang: unsupported dtype " << dtype; + } + return TilelangDType::kBF16; +} + +template +using function_type_t = std::remove_pointer_t; + +template +struct KernelEntry { + Specialization spec; + const char* variant_key; + Fn fn; +}; + +template +inline const Entry* find_kernel_entry(const std::array& registry, + const Specialization& specialization) { + for (const auto& entry : registry) { + if (entry.spec == specialization) { + return &entry; + } + } + return nullptr; +} + +template +inline std::string available_variant_keys( + const std::array& registry) { + std::ostringstream oss; + bool first = true; + for (const auto& entry : registry) { + if (!first) { + oss << ", "; + } + first = false; + oss << entry.variant_key; + } + return oss.str(); +} + +} // namespace xllm::kernel::npu::tilelang diff --git a/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp new file mode 100644 index 000000000..07331ab97 --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper.cpp @@ -0,0 +1,269 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "core/kernels/npu/tilelang/dispatch_registry.h" +#include "core/kernels/npu/tilelang/tilelang_ops_api.h" + +#ifndef XLLM_TL_FUSED_GDN_GATING_REGISTRY_INC +#error "XLLM_TL_FUSED_GDN_GATING_REGISTRY_INC is not defined" +#endif + +namespace xllm::kernel::npu::tilelang { +namespace { + +constexpr int64_t kCompileMaxBatch = 4096; +constexpr int64_t kCompileMaxHeads = 128; +constexpr int32_t kBatchSpecializationMin = 2; +constexpr int32_t kBatchSpecializationStep = 2; + +#include XLLM_TL_FUSED_GDN_GATING_REGISTRY_INC + +int32_t max_compiled_batch_size(int32_t num_heads, TilelangDType dtype) { + int32_t max_batch_size = 0; + for (const auto& entry : kFusedGdnGatingRegistry) { + const auto& spec = entry.spec; + if (spec.num_heads == num_heads && spec.dtype == dtype) { + max_batch_size = std::max(max_batch_size, spec.batch_size); + } + } + return max_batch_size; +} + +int32_t select_launch_batch_size(int64_t num_batches, + int32_t num_heads, + TilelangDType dtype) { + CHECK_GT(num_batches, 0) + << "TileLang fused_gdn_gating: num_batches must be > 0"; + const int32_t max_batch_size = max_compiled_batch_size(num_heads, dtype); + CHECK_GT(max_batch_size, 0) + << "TileLang fused_gdn_gating: no compiled batch_size variant for " + << "num_heads=" << num_heads << ", dtype=" << static_cast(dtype); + CHECK_GE(max_batch_size, kBatchSpecializationMin) + << "TileLang fused_gdn_gating: compiled batch_size variants must be >= " + << kBatchSpecializationMin; + + const int64_t capped = std::min(num_batches, max_batch_size); + int64_t rounded_up_even = + ((capped + kBatchSpecializationStep - 1) / kBatchSpecializationStep) * + kBatchSpecializationStep; + rounded_up_even = std::max(rounded_up_even, kBatchSpecializationMin); + rounded_up_even = std::min(rounded_up_even, max_batch_size); + if ((rounded_up_even % kBatchSpecializationStep) != 0) { + rounded_up_even -= 1; + } + return static_cast(rounded_up_even); +} + +void check_supported(const torch::Tensor& A_log, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& dt_bias) { + CHECK(A_log.defined()) << "TileLang fused_gdn_gating: A_log must be defined"; + CHECK(a.defined()) << "TileLang fused_gdn_gating: a must be defined"; + CHECK(b.defined()) << "TileLang fused_gdn_gating: b must be defined"; + CHECK(dt_bias.defined()) + << "TileLang fused_gdn_gating: dt_bias must be defined"; + + CHECK(A_log.device().type() == c10::DeviceType::PrivateUse1 && + a.device().type() == c10::DeviceType::PrivateUse1 && + b.device().type() == c10::DeviceType::PrivateUse1 && + dt_bias.device().type() == c10::DeviceType::PrivateUse1) + << "TileLang fused_gdn_gating: all tensors must be on NPU"; + + CHECK_EQ(A_log.dim(), 1) + << "TileLang fused_gdn_gating: A_log must be 1D [num_heads]"; + CHECK_EQ(dt_bias.dim(), 1) + << "TileLang fused_gdn_gating: dt_bias must be 1D [num_heads]"; + CHECK_EQ(a.dim(), 2) << "TileLang fused_gdn_gating: a must be 2D [B, H]"; + CHECK_EQ(b.dim(), 2) << "TileLang fused_gdn_gating: b must be 2D [B, H]"; + CHECK_EQ(a.sizes(), b.sizes()) + << "TileLang fused_gdn_gating: a/b shape mismatch"; + CHECK_EQ(A_log.size(0), a.size(1)) + << "TileLang fused_gdn_gating: A_log head size mismatch"; + CHECK_EQ(dt_bias.size(0), a.size(1)) + << "TileLang fused_gdn_gating: dt_bias head size mismatch"; + CHECK_GT(a.size(1), 0) << "TileLang fused_gdn_gating: num_heads must be > 0"; + CHECK_LE(a.size(1), kCompileMaxHeads) + << "TileLang fused_gdn_gating: num_heads must be <= " << kCompileMaxHeads + << ", got " << a.size(1); + + CHECK_EQ(A_log.dtype(), torch::kFloat32) + << "TileLang fused_gdn_gating: A_log must be float32"; + CHECK_EQ(dt_bias.dtype(), torch::kFloat32) + << "TileLang fused_gdn_gating: dt_bias must be float32"; + CHECK_EQ(a.dtype(), b.dtype()) + << "TileLang fused_gdn_gating: a/b dtype mismatch"; + CHECK_EQ(a.dtype(), torch::kBFloat16) + << "TileLang fused_gdn_gating: only bf16 inputs are supported"; + + CHECK(A_log.is_contiguous()) + << "TileLang fused_gdn_gating: A_log must be contiguous"; + CHECK(dt_bias.is_contiguous()) + << "TileLang fused_gdn_gating: dt_bias must be contiguous"; + CHECK_EQ(a.stride(1), 1) + << "TileLang fused_gdn_gating: a last-dim stride must be 1"; + CHECK_EQ(b.stride(1), 1) + << "TileLang fused_gdn_gating: b last-dim stride must be 1"; + CHECK_GT(a.stride(0), 0) + << "TileLang fused_gdn_gating: a row stride must be > 0"; + CHECK_GT(b.stride(0), 0) + << "TileLang fused_gdn_gating: b row stride must be > 0"; +} + +FusedGdnGatingSpecialization build_runtime_specialization( + const torch::Tensor& a) { + CHECK_EQ(a.dim(), 2) << "TileLang fused_gdn_gating: a must be 2D"; + const TilelangDType dtype = to_tilelang_dtype(a.scalar_type()); + const int32_t num_heads = static_cast(a.size(1)); + const int32_t batch_size = + select_launch_batch_size(a.size(0), num_heads, dtype); + return make_fused_gdn_gating_specialization( + FusedGdnGatingBatchSize{batch_size}, + FusedGdnGatingNumHeads{num_heads}, + FusedGdnGatingDType{dtype}); +} + +void run_tilelang_fused_gdn_gating_chunk(const torch::Tensor& A_log, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& dt_bias, + torch::Tensor& g_out, + torch::Tensor& beta_out, + float softplus_beta, + float softplus_threshold) { + CHECK_EQ(a.dim(), 2) << "TileLang fused_gdn_gating: a must be 2D"; + CHECK_EQ(b.dim(), 2) << "TileLang fused_gdn_gating: b must be 2D"; + CHECK_EQ(g_out.dim(), 3) + << "TileLang fused_gdn_gating: g_out must be 3D [1, B, H]"; + CHECK_EQ(beta_out.dim(), 3) + << "TileLang fused_gdn_gating: beta_out must be 3D [1, B, H]"; + CHECK_EQ(g_out.size(0), 1) + << "TileLang fused_gdn_gating: g_out first dim must be 1"; + CHECK_EQ(beta_out.size(0), 1) + << "TileLang fused_gdn_gating: beta_out first dim must be 1"; + CHECK_EQ(g_out.size(1), a.size(0)) + << "TileLang fused_gdn_gating: g_out batch mismatch"; + CHECK_EQ(beta_out.size(1), a.size(0)) + << "TileLang fused_gdn_gating: beta_out batch mismatch"; + CHECK_EQ(g_out.size(2), a.size(1)) + << "TileLang fused_gdn_gating: g_out head mismatch"; + CHECK_EQ(beta_out.size(2), a.size(1)) + << "TileLang fused_gdn_gating: beta_out head mismatch"; + CHECK_EQ(g_out.dtype(), torch::kFloat32) + << "TileLang fused_gdn_gating: g_out must be float32"; + CHECK_EQ(beta_out.dtype(), a.dtype()) + << "TileLang fused_gdn_gating: beta_out dtype mismatch"; + CHECK_EQ(g_out.stride(2), 1) + << "TileLang fused_gdn_gating: g_out last-dim stride must be 1"; + CHECK_EQ(beta_out.stride(2), 1) + << "TileLang fused_gdn_gating: beta_out last-dim stride must be 1"; + CHECK_GT(g_out.stride(1), 0) + << "TileLang fused_gdn_gating: g_out row stride must be > 0"; + CHECK_GT(beta_out.stride(1), 0) + << "TileLang fused_gdn_gating: beta_out row stride must be > 0"; + CHECK_LE(a.size(0), kCompileMaxBatch) + << "TileLang fused_gdn_gating: chunk batch exceeds compile limit " + << kCompileMaxBatch; + CHECK_GT(softplus_beta, 0.0F) + << "TileLang fused_gdn_gating: softplus_beta must be > 0"; + + auto specialization = build_runtime_specialization(a); + const auto* entry = find_fused_gdn_gating_kernel_entry(specialization); + // Expected fast path: compiled batch_size variants are dense [2, 4, ..., 48]. + // If a value is missing, fall back to the nearest smaller batch_size. + if (entry == nullptr) { + int32_t fallback_batch_size = + specialization.batch_size - kBatchSpecializationStep; + while (fallback_batch_size >= kBatchSpecializationMin && entry == nullptr) { + specialization = make_fused_gdn_gating_specialization( + FusedGdnGatingBatchSize{fallback_batch_size}, + FusedGdnGatingNumHeads{specialization.num_heads}, + FusedGdnGatingDType{specialization.dtype}); + entry = find_fused_gdn_gating_kernel_entry(specialization); + fallback_batch_size -= kBatchSpecializationStep; + } + } + CHECK(entry != nullptr) + << "TileLang fused_gdn_gating: no compiled variant. Available variants: " + << available_fused_gdn_gating_variant_keys(); + + const int64_t num_batches = a.size(0); + + const int32_t device_id = a.device().index(); + aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream(); + auto g_rows = g_out.squeeze(0); + auto beta_rows = beta_out.squeeze(0); + entry->fn(static_cast(A_log.data_ptr()), + static_cast(a.data_ptr()), + static_cast(b.data_ptr()), + static_cast(dt_bias.data_ptr()), + static_cast(g_rows.data_ptr()), + static_cast(beta_rows.data_ptr()), + static_cast(num_batches), + softplus_beta, + softplus_threshold, + stream); +} + +} // namespace + +std::pair fused_gdn_gating( + const torch::Tensor& A_log, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& dt_bias, + float softplus_beta, + float softplus_threshold) { + check_supported(A_log, a, b, dt_bias); + + const auto num_batches = a.size(0); + const auto num_heads = a.size(1); + auto g_out = torch::empty({1, num_batches, num_heads}, + a.options().dtype(torch::kFloat32)); + auto beta_out = torch::empty({1, num_batches, num_heads}, a.options()); + + for (int64_t start = 0; start < num_batches; start += kCompileMaxBatch) { + const int64_t chunk_batches = + std::min(kCompileMaxBatch, num_batches - start); + auto a_chunk = a.narrow(0, start, chunk_batches); + auto b_chunk = b.narrow(0, start, chunk_batches); + auto g_chunk = g_out.narrow(1, start, chunk_batches); + auto beta_chunk = beta_out.narrow(1, start, chunk_batches); + run_tilelang_fused_gdn_gating_chunk(A_log, + a_chunk, + b_chunk, + dt_bias, + g_chunk, + beta_chunk, + softplus_beta, + softplus_threshold); + } + + return {g_out, beta_out}; +} + +} // namespace xllm::kernel::npu::tilelang diff --git a/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp new file mode 100644 index 000000000..f91d06d7a --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/fused_gdn_gating_wrapper_test.cpp @@ -0,0 +1,164 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include + +#include "core/kernels/npu/tilelang/tilelang_ops_api.h" + +namespace xllm::kernel::npu::tilelang { +namespace { + +class TileLangFusedGdnGatingWrapperTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { torch_npu::init_npu("npu:0"); } + + static void TearDownTestSuite() { torch_npu::finalize_npu(); } +}; + +struct FusedGdnGatingTestCase { + std::string name; + int64_t num_batches; + int64_t num_heads; + int64_t seed; + float beta = 1.0F; + float threshold = 20.0F; +}; + +std::pair torch_fused_gdn_gating( + const torch::Tensor& A_log, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& dt_bias, + float beta = 1.0F, + float threshold = 20.0F) { + namespace F = torch::nn::functional; + + auto softplus_out = + F::softplus(a.to(torch::kFloat32) + dt_bias, + F::SoftplusFuncOptions().beta(beta).threshold(threshold)); + auto g = -A_log.exp() * softplus_out; + auto beta_output = torch::sigmoid(b.to(torch::kFloat32)).to(torch::kBFloat16); + return {g.unsqueeze(0), beta_output.unsqueeze(0)}; +} + +void run_fused_gdn_gating_case(const FusedGdnGatingTestCase& test_case) { + ASSERT_GT(test_case.num_batches, 0); + + const auto device = torch::Device("npu:0"); + torch::manual_seed(test_case.seed); + + auto fp32_opts = torch::TensorOptions().dtype(torch::kFloat32).device(device); + auto bf16_opts = + torch::TensorOptions().dtype(torch::kBFloat16).device(device); + + auto A_log = torch::randn({test_case.num_heads}, fp32_opts); + auto a = + torch::randn({test_case.num_batches, test_case.num_heads}, bf16_opts); + auto b = + torch::randn({test_case.num_batches, test_case.num_heads}, bf16_opts); + auto dt_bias = torch::randn({test_case.num_heads}, fp32_opts); + + auto [g_ref, beta_ref] = torch_fused_gdn_gating( + A_log, a, b, dt_bias, test_case.beta, test_case.threshold); + auto [g_out, beta_out] = fused_gdn_gating( + A_log, a, b, dt_bias, test_case.beta, test_case.threshold); + + auto g_max_diff = (g_out - g_ref).abs().max().item(); + auto beta_max_diff = + (beta_out.to(torch::kFloat32) - beta_ref.to(torch::kFloat32)) + .abs() + .max() + .item(); + + EXPECT_TRUE(torch::allclose(g_out, g_ref, 1e-3, 1e-3)) + << "g mismatch, max_diff=" << g_max_diff; + EXPECT_TRUE(torch::allclose(beta_out, beta_ref, 1e-2, 1e-2)) + << "beta mismatch, max_diff=" << beta_max_diff; +} + +TEST_F(TileLangFusedGdnGatingWrapperTest, MatchesTorchReference) { + const std::vector cases = { + { + .name = "tiny_b1_h8", + .num_batches = 1, + .num_heads = 8, + .seed = 101, + }, + { + .name = "tiny_b17_h8", + .num_batches = 17, + .num_heads = 8, + .seed = 101, + }, + { + .name = "tiny_b1_h16", + .num_batches = 1, + .num_heads = 16, + .seed = 101, + }, + { + .name = "small_b17_h32", + .num_batches = 17, + .num_heads = 32, + .seed = 102, + }, + { + .name = "medium_b29_h48", + .num_batches = 29, + .num_heads = 48, + .seed = 103, + }, + { + .name = "medium_b131_h64", + .num_batches = 131, + .num_heads = 64, + .seed = 104, + }, + { + .name = "medium_b257_h128", + .num_batches = 257, + .num_heads = 128, + .seed = 105, + }, + { + .name = "large_b4096_h32", + .num_batches = 4096, + .num_heads = 32, + .seed = 106, + }, + { + .name = "custom_beta2_threshold0p5_b33_h64", + .num_batches = 33, + .num_heads = 64, + .seed = 107, + .beta = 2.0F, + .threshold = 0.5F, + }, + }; + + for (const auto& test_case : cases) { + SCOPED_TRACE(test_case.name); + run_fused_gdn_gating_case(test_case); + } +} + +} // namespace +} // namespace xllm::kernel::npu::tilelang diff --git a/xllm/core/kernels/npu/tilelang/rope_wrapper.cpp b/xllm/core/kernels/npu/tilelang/rope_wrapper.cpp new file mode 100644 index 000000000..962fe1b5f --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/rope_wrapper.cpp @@ -0,0 +1,168 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include +#include + +#include "acl/acl.h" +#include "dispatch_registry.h" +#include "tilelang_ops_api.h" + +#ifndef XLLM_TL_ROPE_REGISTRY_INC +#error "XLLM_TL_ROPE_REGISTRY_INC is not defined" +#endif + +namespace xllm::kernel::npu::tilelang { +namespace { + +#include XLLM_TL_ROPE_REGISTRY_INC + +void check_supported(const torch::Tensor& input, + const torch::Tensor& sin_cache, + const torch::Tensor& cos_cache) { + CHECK(input.defined()) << "TileLang RoPE: input must be defined"; + CHECK(sin_cache.defined()) << "TileLang RoPE: sin_cache must be defined"; + CHECK(cos_cache.defined()) << "TileLang RoPE: cos_cache must be defined"; + + CHECK(input.device().type() == c10::DeviceType::PrivateUse1 && + sin_cache.device().type() == c10::DeviceType::PrivateUse1 && + cos_cache.device().type() == c10::DeviceType::PrivateUse1) + << "TileLang RoPE: all tensors must be on NPU"; + + CHECK_EQ(input.dtype(), sin_cache.dtype()) + << "TileLang RoPE: input/sin_cache dtype mismatch"; + CHECK_EQ(input.dtype(), cos_cache.dtype()) + << "TileLang RoPE: input/cos_cache dtype mismatch"; + [[maybe_unused]] auto dtype = to_tilelang_dtype(input.scalar_type()); + + CHECK_EQ(input.dim(), 3) << "TileLang RoPE: input must be 3D [T, H, D]"; + CHECK_EQ(input.stride(2), 1) + << "TileLang RoPE: input last dim stride must be 1"; + CHECK_EQ(input.stride(0), input.size(1) * input.stride(1)) + << "TileLang RoPE: unsupported input layout"; + + CHECK_EQ(sin_cache.dim(), 2) + << "TileLang RoPE: sin_cache must be 2D [T, rope_dim]"; + CHECK_EQ(cos_cache.dim(), 2) + << "TileLang RoPE: cos_cache must be 2D [T, rope_dim]"; + CHECK_EQ(sin_cache.sizes(), cos_cache.sizes()) + << "TileLang RoPE: sin_cache/cos_cache shape mismatch"; + CHECK_EQ(sin_cache.size(1), input.size(2)) + << "TileLang RoPE: rope_dim mismatch between input and sin_cache"; + CHECK_EQ(sin_cache.size(0), input.size(0)) + << "TileLang RoPE: sin_cache token size must match input.size(0)"; + + const int64_t row_count = input.size(0) * input.size(1); + CHECK_GT(row_count, 0) << "TileLang RoPE: row_count must be > 0"; +} + +RopeSpecialization build_runtime_specialization(const torch::Tensor& x_rows) { + CHECK_EQ(x_rows.dim(), 2) << "TileLang RoPE: x_rows must be 2D"; + CHECK_GT(x_rows.stride(0), 0) << "TileLang RoPE: x_rows stride must be > 0"; + CHECK_LE(x_rows.stride(0), + static_cast(std::numeric_limits::max())) + << "TileLang RoPE: x_rows stride exceeds int range"; + CHECK_LE(x_rows.size(1), + static_cast(std::numeric_limits::max())) + << "TileLang RoPE: rope_dim exceeds int range"; + + return make_rope_specialization( + RopeHeadDim{static_cast(x_rows.stride(0))}, + RopeRopeDim{static_cast(x_rows.size(1))}, + RopeDType{to_tilelang_dtype(x_rows.scalar_type())}); +} + +void run_tilelang_rope_once(torch::Tensor& x_rows, + const torch::Tensor& sin_rows, + const torch::Tensor& cos_rows) { + CHECK_EQ(x_rows.dim(), 2) << "TileLang RoPE: x_rows must be 2D"; + CHECK_EQ(sin_rows.dim(), 2) << "TileLang RoPE: sin_rows must be 2D"; + CHECK_EQ(cos_rows.dim(), 2) << "TileLang RoPE: cos_rows must be 2D"; + CHECK_EQ(x_rows.size(0), sin_rows.size(0)) + << "TileLang RoPE: x_rows/sin_rows row mismatch"; + CHECK_EQ(x_rows.size(0), cos_rows.size(0)) + << "TileLang RoPE: x_rows/cos_rows row mismatch"; + CHECK_EQ(x_rows.size(1), sin_rows.size(1)) + << "TileLang RoPE: x_rows/sin_rows rope_dim mismatch"; + CHECK_EQ(x_rows.size(1), cos_rows.size(1)) + << "TileLang RoPE: x_rows/cos_rows rope_dim mismatch"; + + CHECK_EQ(x_rows.dtype(), sin_rows.dtype()) + << "TileLang RoPE: x_rows/sin_rows dtype mismatch"; + CHECK_EQ(x_rows.dtype(), cos_rows.dtype()) + << "TileLang RoPE: x_rows/cos_rows dtype mismatch"; + + CHECK_EQ(x_rows.stride(1), 1) + << "TileLang RoPE: x_rows last dim stride must be 1"; + CHECK(sin_rows.is_contiguous()) + << "TileLang RoPE: sin_rows must be contiguous"; + CHECK(cos_rows.is_contiguous()) + << "TileLang RoPE: cos_rows must be contiguous"; + + const int64_t row_count = x_rows.size(0); + CHECK_LE(row_count, static_cast(std::numeric_limits::max())) + << "TileLang RoPE: row_count exceeds int range"; + + const RopeSpecialization specialization = + build_runtime_specialization(x_rows); + const auto* entry = find_rope_kernel_entry(specialization); + CHECK(entry != nullptr) + << "TileLang RoPE: no compiled variant. Available variants: " + << available_rope_variant_keys(); + CHECK_GE(specialization.head_dim, specialization.rope_dim) + << "TileLang RoPE: compiled head_dim must be >= rope_dim"; + + const int32_t device_id = x_rows.device().index(); + aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream(); + const int32_t num_tokens = static_cast(row_count); + const int32_t x_stride = specialization.head_dim; + + entry->fn(reinterpret_cast(x_rows.data_ptr()), + reinterpret_cast(const_cast(sin_rows.data_ptr())), + reinterpret_cast(const_cast(cos_rows.data_ptr())), + reinterpret_cast(x_rows.data_ptr()), + num_tokens, + x_stride, + stream); +} + +} // namespace + +void rope_in_place(torch::Tensor& input, + const torch::Tensor& sin_cache, + const torch::Tensor& cos_cache) { + check_supported(input, sin_cache, cos_cache); + + auto input_rows = + input.as_strided({input.size(0) * input.size(1), input.size(2)}, + {input.stride(1), input.stride(2)}); + auto sin_rows = sin_cache.unsqueeze(1) + .expand({input.size(0), input.size(1), sin_cache.size(1)}) + .contiguous() + .view({input_rows.size(0), sin_cache.size(1)}); + auto cos_rows = cos_cache.unsqueeze(1) + .expand({input.size(0), input.size(1), cos_cache.size(1)}) + .contiguous() + .view({input_rows.size(0), cos_cache.size(1)}); + run_tilelang_rope_once(input_rows, sin_rows, cos_rows); +} + +} // namespace xllm::kernel::npu::tilelang diff --git a/xllm/core/kernels/npu/tilelang/rope_wrapper_test.cpp b/xllm/core/kernels/npu/tilelang/rope_wrapper_test.cpp new file mode 100644 index 000000000..24a1c77ca --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/rope_wrapper_test.cpp @@ -0,0 +1,408 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "acl/acl.h" +#include "tilelang_ops_api.h" + +namespace xllm::kernel::npu::tilelang { +namespace { + +class TileLangRopeWrapperTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { torch_npu::init_npu("npu:0"); } + + static void TearDownTestSuite() { torch_npu::finalize_npu(); } +}; + +struct RopeTestCase { + std::string name; + int64_t num_tokens; + int64_t num_heads; + int64_t full_head_dim; + int64_t rope_dim; + int64_t start_dim; + int64_t seed; +}; + +torch::Tensor torch_rope_ref(const torch::Tensor& x, + const torch::Tensor& sin, + const torch::Tensor& cos) { + auto cos_ref = cos; + auto sin_ref = sin; + if (cos_ref.dim() == 2) { + cos_ref = cos_ref.unsqueeze(1); + sin_ref = sin_ref.unsqueeze(1); + } + + auto x_fp32 = x.to(torch::kFloat32); + auto cos_fp32 = cos_ref.to(torch::kFloat32); + auto sin_fp32 = sin_ref.to(torch::kFloat32); + + auto x_reshaped = + x_fp32.view({x_fp32.size(0), x_fp32.size(1), x_fp32.size(2) / 2, 2}); + auto x0 = x_reshaped.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + 0}); + auto x1 = x_reshaped.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + 1}); + auto x_rotated = torch::stack({-x1, x0}, /*dim=*/-1).flatten(-2); + + auto out = x_fp32 * cos_fp32 + x_rotated * sin_fp32; + return out.to(torch::kBFloat16); +} + +double measure_npu_event_ms(const std::function& fn, + int32_t device_id, + int warmup_iters = 5, + int measure_iters = 100) { + CHECK_GT(measure_iters, 0) << "measure_iters must be > 0"; + CHECK_GE(warmup_iters, 0) << "warmup_iters must be >= 0"; + + const aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream(); + for (int i = 0; i < warmup_iters; ++i) { + fn(); + } + CHECK_EQ(aclrtSynchronizeStream(stream), ACL_SUCCESS) + << "warmup stream synchronize failed"; + + aclrtEvent start_event = nullptr; + aclrtEvent end_event = nullptr; + CHECK_EQ(aclrtCreateEvent(&start_event), ACL_SUCCESS) + << "aclrtCreateEvent(start) failed"; + CHECK_EQ(aclrtCreateEvent(&end_event), ACL_SUCCESS) + << "aclrtCreateEvent(end) failed"; + + CHECK_EQ(aclrtRecordEvent(start_event, stream), ACL_SUCCESS) + << "aclrtRecordEvent(start) failed"; + for (int i = 0; i < measure_iters; ++i) { + fn(); + } + CHECK_EQ(aclrtRecordEvent(end_event, stream), ACL_SUCCESS) + << "aclrtRecordEvent(end) failed"; + CHECK_EQ(aclrtSynchronizeEvent(end_event), ACL_SUCCESS) + << "aclrtSynchronizeEvent(end) failed"; + + float elapsed_ms = 0.0F; + CHECK_EQ(aclrtEventElapsedTime(&elapsed_ms, start_event, end_event), + ACL_SUCCESS) + << "aclrtEventElapsedTime failed"; + CHECK_EQ(aclrtDestroyEvent(start_event), ACL_SUCCESS) + << "aclrtDestroyEvent(start) failed"; + CHECK_EQ(aclrtDestroyEvent(end_event), ACL_SUCCESS) + << "aclrtDestroyEvent(end) failed"; + + return static_cast(elapsed_ms) / static_cast(measure_iters); +} + +torch::Tensor maybe_narrow(const torch::Tensor& tensor, + int64_t start_dim, + int64_t rope_dim) { + if (start_dim == 0 && rope_dim == tensor.size(2)) { + return tensor; + } + return tensor.narrow(/*dim=*/2, /*start=*/start_dim, /*length=*/rope_dim); +} + +void run_apply_rotary_case(const RopeTestCase& test_case) { + ASSERT_GT(test_case.num_tokens, 0); + ASSERT_GT(test_case.num_heads, 0); + ASSERT_GT(test_case.full_head_dim, 0); + ASSERT_GT(test_case.rope_dim, 0); + ASSERT_GE(test_case.start_dim, 0); + ASSERT_LE(test_case.start_dim + test_case.rope_dim, test_case.full_head_dim); + + const auto npu_device = torch::Device("npu:0"); + const int32_t device_id = npu_device.index(); + const auto bf16_opts = + torch::TensorOptions().dtype(torch::kBFloat16).device(npu_device); + + torch::manual_seed(test_case.seed); + auto q_full = torch::randn( + {test_case.num_tokens, test_case.num_heads, test_case.full_head_dim}, + bf16_opts); + auto k_full = torch::randn( + {test_case.num_tokens, test_case.num_heads, test_case.full_head_dim}, + bf16_opts); + auto sin_cache = + torch::randn({test_case.num_tokens, test_case.rope_dim}, bf16_opts); + auto cos_cache = + torch::randn({test_case.num_tokens, test_case.rope_dim}, bf16_opts); + + auto q_input = maybe_narrow(q_full, test_case.start_dim, test_case.rope_dim); + auto k_input = maybe_narrow(k_full, test_case.start_dim, test_case.rope_dim); + + if (test_case.start_dim > 0) { + EXPECT_EQ(q_input.storage_offset(), test_case.start_dim); + EXPECT_EQ(k_input.storage_offset(), test_case.start_dim); + if (test_case.num_tokens * test_case.num_heads > 1) { + EXPECT_FALSE(q_input.is_contiguous()); + EXPECT_FALSE(k_input.is_contiguous()); + } + } + + auto q_ref = torch_rope_ref(q_input, sin_cache, cos_cache); + auto k_ref = torch_rope_ref(k_input, sin_cache, cos_cache); + auto q_runtime_full = q_full.clone(); + auto k_runtime_full = k_full.clone(); + auto q = + maybe_narrow(q_runtime_full, test_case.start_dim, test_case.rope_dim); + auto k = + maybe_narrow(k_runtime_full, test_case.start_dim, test_case.rope_dim); + rope_in_place(q, sin_cache, cos_cache); + rope_in_place(k, sin_cache, cos_cache); + + auto q_bench_full = q_full.clone(); + auto k_bench_full = k_full.clone(); + auto q_bench = + maybe_narrow(q_bench_full, test_case.start_dim, test_case.rope_dim); + auto k_bench = + maybe_narrow(k_bench_full, test_case.start_dim, test_case.rope_dim); + const double ref_elapsed_ms = measure_npu_event_ms( + [&]() { + [[maybe_unused]] auto q_ref_bench = + torch_rope_ref(q_input, sin_cache, cos_cache); + [[maybe_unused]] auto k_ref_bench = + torch_rope_ref(k_input, sin_cache, cos_cache); + }, + device_id); + const double tl_elapsed_ms = measure_npu_event_ms( + [&]() { + rope_in_place(q_bench, sin_cache, cos_cache); + rope_in_place(k_bench, sin_cache, cos_cache); + }, + device_id); + + const double speedup = + tl_elapsed_ms > 0.0 ? ref_elapsed_ms / tl_elapsed_ms : 0.0; + std::cout << "[rope_wrapper_test] case=" << test_case.name + << ", ref_ms=" << ref_elapsed_ms + << ", tilelang_ms=" << tl_elapsed_ms << ", speedup=" << speedup + << "x" << std::endl; + + auto q_max_diff = (q.to(torch::kFloat32) - q_ref.to(torch::kFloat32)) + .abs() + .max() + .item(); + auto k_max_diff = (k.to(torch::kFloat32) - k_ref.to(torch::kFloat32)) + .abs() + .max() + .item(); + + EXPECT_TRUE(torch::allclose(q, q_ref, /*rtol=*/1e-2, /*atol=*/1e-2)) + << "q mismatch: tilelang output differs from interleaved rope reference" + << ", max_diff=" << q_max_diff; + EXPECT_TRUE(torch::allclose(k, k_ref, /*rtol=*/1e-2, /*atol=*/1e-2)) + << "k mismatch: tilelang output differs from interleaved rope reference" + << ", max_diff=" << k_max_diff; +} + +TEST_F(TileLangRopeWrapperTest, ApplyRotaryMatchesNpuReferenceVariant128x128) { + const std::vector cases = { + {.name = "baseline_16x4_hd128_rd128", + .num_tokens = 16, + .num_heads = 4, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 20260213}, + {.name = "large_tokens_2051x2_hd128_rd128", + .num_tokens = 2051, + .num_heads = 2, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 20260214}, + {.name = "tiny_1x1_hd128_rd128", + .num_tokens = 1, + .num_heads = 1, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 101}, + {.name = "odd_tokens_7x3_hd128_rd128", + .num_tokens = 7, + .num_heads = 3, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 102}, + {.name = "token_dim_64x4_hd128_rd128", + .num_tokens = 64, + .num_heads = 4, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 107}, + {.name = "chunk_boundary_8x5_hd128_rd128", + .num_tokens = 8, + .num_heads = 5, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 103}, + {.name = "cross_chunk_9x5_hd128_rd128", + .num_tokens = 9, + .num_heads = 5, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 104}, + {.name = "head_dim_4x64_hd128_rd128", + .num_tokens = 4, + .num_heads = 64, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 108}, + {.name = "medium_127x8_hd128_rd128", + .num_tokens = 127, + .num_heads = 8, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 105}, + {.name = "large_heads_33x16_hd128_rd128", + .num_tokens = 33, + .num_heads = 16, + .full_head_dim = 128, + .rope_dim = 128, + .start_dim = 0, + .seed = 106}, + }; + + for (const auto& test_case : cases) { + SCOPED_TRACE(::testing::Message() << "case=" << test_case.name + << ", num_tokens=" << test_case.num_tokens + << ", num_heads=" << test_case.num_heads); + run_apply_rotary_case(test_case); + } +} + +TEST_F(TileLangRopeWrapperTest, ApplyRotaryMatchesNpuReferenceVariant576x64) { + constexpr int64_t kNumHeads = 1; + constexpr int64_t kFullHeadDim = 576; + constexpr int64_t kStartDim = 512; + constexpr int64_t kRopeDim = 64; + + const std::vector cases = { + {.name = "1x576_start512_rope64", + .num_tokens = 1, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260226}, + {.name = "8x576_start512_rope64", + .num_tokens = 8, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260227}, + {.name = "47x576_start512_rope64", + .num_tokens = 47, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260301}, + {.name = "48x576_start512_rope64", + .num_tokens = 48, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260302}, + {.name = "49x576_start512_rope64", + .num_tokens = 49, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260303}, + {.name = "95x576_start512_rope64", + .num_tokens = 95, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260304}, + {.name = "96x576_start512_rope64", + .num_tokens = 96, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260305}, + {.name = "97x576_start512_rope64", + .num_tokens = 97, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260306}, + {.name = "128x576_start512_rope64", + .num_tokens = 128, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260228}, + {.name = "512x576_start512_rope64", + .num_tokens = 512, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260307}, + {.name = "1024x576_start512_rope64", + .num_tokens = 1024, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260308}, + {.name = "2048x576_start512_rope64", + .num_tokens = 2048, + .num_heads = kNumHeads, + .full_head_dim = kFullHeadDim, + .rope_dim = kRopeDim, + .start_dim = kStartDim, + .seed = 20260225}, + }; + + for (const auto& test_case : cases) { + SCOPED_TRACE(::testing::Message() << "case=" << test_case.name + << ", num_tokens=" << test_case.num_tokens + << ", num_heads=" << test_case.num_heads); + run_apply_rotary_case(test_case); + } +} + +} // namespace +} // namespace xllm::kernel::npu::tilelang diff --git a/xllm/core/kernels/npu/tilelang/tilelang_ops_api.h b/xllm/core/kernels/npu/tilelang/tilelang_ops_api.h new file mode 100644 index 000000000..6351413d7 --- /dev/null +++ b/xllm/core/kernels/npu/tilelang/tilelang_ops_api.h @@ -0,0 +1,43 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +namespace xllm::kernel::npu::tilelang { + +// Public TileLang kernel APIs exported to the xLLM NPU runtime. +// +// Apply TileLang RoPE kernel in-place on a single input tensor. +// Invalid inputs trigger CHECK failures. +// Supports input not contiguous, with stride. +void rope_in_place(torch::Tensor& input, + const torch::Tensor& sin_cache, + const torch::Tensor& cos_cache); + +// Compute fused GDN gating outputs on NPU. +// Invalid inputs trigger CHECK failures. +std::pair fused_gdn_gating( + const torch::Tensor& A_log, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& dt_bias, + float softplus_beta, + float softplus_threshold); + +} // namespace xllm::kernel::npu::tilelang diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 552b8a1cf..d4ea329ff 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -18,6 +18,7 @@ limitations under the License. #if defined(USE_MLU) #include "mlu/mlu_ops_api.h" #elif defined(USE_NPU) +#include "core/kernels/npu/tilelang/tilelang_ops_api.h" #include "npu/npu_ops_api.h" #include "triton_npu/torch_api/triton_ops_api.h" #elif defined(USE_CUDA) @@ -613,6 +614,8 @@ torch::Tensor apply_top_k_top_p(TopKPParams& params) { torch::Tensor random_sample(RandomSampleParams& params) { #if defined(USE_MLU) return mlu::random_sample(params.logits); +#elif defined(USE_CUDA) + return cuda::random_sample(params.logits); #else NOT_IMPLEMENTED(); #endif @@ -786,12 +789,18 @@ std::tuple fp8_scaled_quantize( std::pair fused_gdn_gating( FusedGdnGatingParams& params) { #if defined(USE_NPU) - return npu::npu_fused_gdn_gating(params.A_log, - params.a, - params.b, - params.dt_bias, - params.beta, - params.threshold); + return npu::tilelang::fused_gdn_gating(params.A_log, + params.a, + params.b, + params.dt_bias, + params.beta, + params.threshold); + // return npu::npu_fused_gdn_gating(params.A_log, + // params.a, + // params.b, + // params.dt_bias, + // params.beta, + // params.threshold); #else NOT_IMPLEMENTED(); #endif @@ -908,6 +917,10 @@ std::tuple fused_add_rms_norm_static_fp8_quant( torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) { #if defined(USE_NPU) + if (params.conv_state_indices.has_value()) { + CHECK(params.conv_state_indices.value().is_contiguous()) + << "causal_conv1d_update: conv_state_indices must be contiguous."; + } return npu::npu_causal_conv1d_update(params.x, params.conv_state, params.weight, @@ -955,4 +968,27 @@ std::pair partial_rotary_embedding( NOT_IMPLEMENTED(); #endif } + +std::tuple +fused_qkvzba_split_reshape_cat(FusedQkvzbaSplitReshapeParams& params) { +#if defined(USE_NPU) + return npu::npu_fused_qkvzba_split_reshape_cat(params.mixed_qkvz, + params.mixed_ba, + params.num_heads_qk, + params.num_heads_v, + params.head_qk, + params.head_v); +#else + NOT_IMPLEMENTED(); +#endif +} + +void gemma_rms_norm(GemmaRMSNormParams& params) { +#if defined(USE_NPU) + npu::npu_gemma_rms_norm( + params.x, params.gamma, params.epsilon, params.rstd_out, params.norm_out); +#else + NOT_IMPLEMENTED(); +#endif +} } // namespace xllm::kernel diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index a098d5e81..7e1e4a944 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -138,4 +138,9 @@ torch::Tensor gated_layer_norm(GatedLayerNormParams& params); std::pair partial_rotary_embedding( PartialRotaryEmbeddingParams& params); +std::tuple +fused_qkvzba_split_reshape_cat(FusedQkvzbaSplitReshapeParams& params); + +void gemma_rms_norm(GemmaRMSNormParams& params); + } // namespace xllm::kernel diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 222aa9209..6b76b20e6 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -1382,4 +1382,21 @@ struct PartialRotaryEmbeddingParams { torch::Tensor cos_sin_cache; bool is_neox_style; }; + +struct FusedQkvzbaSplitReshapeParams { + torch::Tensor mixed_qkvz; + torch::Tensor mixed_ba; + int32_t num_heads_qk; + int32_t num_heads_v; + int32_t head_qk; + int32_t head_v; +}; + +struct GemmaRMSNormParams { + torch::Tensor x; + torch::Tensor gamma; + double epsilon; + torch::Tensor rstd_out; + torch::Tensor norm_out; +}; } // namespace xllm::kernel diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 8eeee33dc..e90facd38 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -37,6 +37,7 @@ cc_library( layers HDRS onerec_block_layer.h + oxygen_vision_layer.h qwen2_decoder_layer.h qwen2_vision_layer.h qwen2_5_vision_layer.h @@ -44,6 +45,7 @@ cc_library( qwen3_decoder_layer.h qwen3_moe_decoder_layer.h SRCS + oxygen_vision_layer.cpp qwen2_vision_layer.cpp qwen2_decoder_layer.cpp qwen2_5_vision_layer.cpp diff --git a/xllm/core/layers/common/CMakeLists.txt b/xllm/core/layers/common/CMakeLists.txt index e2c952863..d290c4e85 100755 --- a/xllm/core/layers/common/CMakeLists.txt +++ b/xllm/core/layers/common/CMakeLists.txt @@ -4,6 +4,7 @@ cc_library( NAME common_layers HDRS + oxygen_vision_attention.h qwen2_attention.h qwen2_vision_attention.h qwen3_next_rms_norm.h @@ -24,6 +25,7 @@ cc_library( add_matmul.h moe_fused_topk.h SRCS + oxygen_vision_attention.cpp qwen2_attention.cpp qwen2_vision_attention.cpp qwen3_next_rms_norm.cpp diff --git a/xllm/core/layers/common/add_matmul.cpp b/xllm/core/layers/common/add_matmul.cpp index eb87b7ec4..4771b6338 100644 --- a/xllm/core/layers/common/add_matmul.cpp +++ b/xllm/core/layers/common/add_matmul.cpp @@ -98,5 +98,46 @@ void FusedAddMatmulImpl::load_state_dict( } } +AddMatmulWeightTransposedImpl::AddMatmulWeightTransposedImpl( + int64_t in, + int64_t out, + bool with_bias, + const torch::TensorOptions& options) + : AddMatmulImpl(in, out, with_bias, options) {} + +torch::Tensor AddMatmulWeightTransposedImpl::forward(const torch::Tensor& x) { + // use addmm when bias is provided + if (with_bias_) { + auto sizes = x.sizes(); + if (sizes.size() == 3) { + torch::Tensor t = x.reshape({sizes[0] * sizes[1], sizes[2]}); + return torch::addmm(bias_, t, weight_) + .reshape({sizes[0], sizes[1], weight_.size(1)}); + } else { + return torch::addmm(bias_, x, weight_); + } + } else { + return torch::matmul(x, weight_); + } +} + +void AddMatmulWeightTransposedImpl::load_state_dict( + const StateDict& state_dict) { + // only transpoes weights when state_dict has the key + // or it would be transposed multiple times when having + // multiple state dicts + if (state_dict.has("weight")) { + xllm::weight::load_weight(state_dict, "weight", weight_, weight_is_loaded_); + // weight need to be transposed when using addmm + if (with_bias_) { + torch::Tensor transposed = weight_.data().transpose(0, 1).contiguous(); + weight_.set_data(transposed); + } + } + if (with_bias_) { + weight::load_weight(state_dict, "bias", bias_, bias_is_loaded_); + } +} + } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/add_matmul.h b/xllm/core/layers/common/add_matmul.h index 75530a9e4..2542d3e82 100644 --- a/xllm/core/layers/common/add_matmul.h +++ b/xllm/core/layers/common/add_matmul.h @@ -29,9 +29,9 @@ class AddMatmulImpl : public torch::nn::Module { bool with_bias, const torch::TensorOptions& options); - torch::Tensor forward(const torch::Tensor& x); + virtual torch::Tensor forward(const torch::Tensor& x); - void load_state_dict(const xllm::StateDict& state_dict); + virtual void load_state_dict(const xllm::StateDict& state_dict); void verify_loaded_weights(const std::string& prefix) const; @@ -59,5 +59,18 @@ class FusedAddMatmulImpl : public AddMatmulImpl { }; TORCH_MODULE(FusedAddMatmul); +class AddMatmulWeightTransposedImpl : public AddMatmulImpl { + public: + AddMatmulWeightTransposedImpl(int64_t in, + int64_t out, + bool with_bias, + const torch::TensorOptions& options); + + torch::Tensor forward(const torch::Tensor& x) override; + + void load_state_dict(const xllm::StateDict& state_dict) override; +}; + +TORCH_MODULE(AddMatmulWeightTransposed); } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index 8ff1955c7..399d3eccf 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -101,7 +101,7 @@ AttentionMetadata build_attention_metadata( attn_metadata.is_prefill = params.batch_forward_type.is_prefill(); if (!attn_metadata.is_prefill || enable_mla) { attn_metadata.block_table = params.block_tables; -#if !defined(USE_NPU) +#if !defined(USE_NPU) && !defined(USE_CUDA) attn_metadata.kv_seq_lens = torch::diff(params.kv_seq_lens); // kv seqlens attn_metadata.q_seq_lens = torch::diff(params.q_seq_lens); // q seqlens #endif diff --git a/xllm/core/layers/common/dense_mlp.cpp b/xllm/core/layers/common/dense_mlp.cpp index 79ce81346..bb95dd0ff 100644 --- a/xllm/core/layers/common/dense_mlp.cpp +++ b/xllm/core/layers/common/dense_mlp.cpp @@ -31,7 +31,8 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size, bool enable_result_reduction, const QuantArgs& quant_args, ProcessGroup* process_group, - const torch::TensorOptions& options) + const torch::TensorOptions& options, + const std::string& module_prefix) : is_gated_(is_gated), intermediate_size_(intermediate_size), process_group_(process_group), @@ -73,13 +74,17 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size, act_ = register_module("act", Activation(hidden_act_, is_gated_)); // 2. down + const auto down_proj_quant_args = + module_prefix.empty() + ? quant_args + : quant_args.for_module(module_prefix + ".down_proj"); down_proj_ = register_module("down_proj", RowParallelLinear(intermediate_size_, hidden_size, /*bias=*/has_bias, /*input_is_parallelized=*/true, enable_result_reduction, - quant_args, + down_proj_quant_args, process_group_, options, down_proj_extra_args)); diff --git a/xllm/core/layers/common/dense_mlp.h b/xllm/core/layers/common/dense_mlp.h index 545799558..8b4b2248d 100644 --- a/xllm/core/layers/common/dense_mlp.h +++ b/xllm/core/layers/common/dense_mlp.h @@ -38,7 +38,8 @@ class DenseMLPImpl : public torch::nn::Module { bool enable_result_reduction, const QuantArgs& quant_args, ProcessGroup* process_group, - const torch::TensorOptions& options); + const torch::TensorOptions& options, + const std::string& module_prefix = ""); torch::Tensor forward(const torch::Tensor& hidden_states); diff --git a/xllm/core/layers/common/linear.cpp b/xllm/core/layers/common/linear.cpp index 746209beb..e322fa5f8 100644 --- a/xllm/core/layers/common/linear.cpp +++ b/xllm/core/layers/common/linear.cpp @@ -181,13 +181,14 @@ torch::Tensor fp8_linear_forward( } // namespace ColumnParallelLinearImpl::ColumnParallelLinearImpl(const ModelContext& context) - : ColumnParallelLinearImpl(context.get_model_args().hidden_size(), - context.get_model_args().vocab_size(), - /*bias=*/false, - /*gather_output=*/true, - context.get_quant_args(), - context.get_parallel_args().tp_group_, - context.get_tensor_options()) {} + : ColumnParallelLinearImpl( + context.get_model_args().hidden_size(), + context.get_model_args().vocab_size(), + /*bias=*/false, + /*gather_output=*/true, + QuantArgs{}, // do not use quantization for lm_head + context.get_parallel_args().tp_group_, + context.get_tensor_options()) {} // Linear layer with column parallelism. ColumnParallelLinearImpl::ColumnParallelLinearImpl( @@ -667,17 +668,6 @@ std::optional QKVParallelLinearImpl::get_input_scale() const { return std::nullopt; } -// Linear layer with row parallelism. -RowParallelLinearImpl::RowParallelLinearImpl(const ModelContext& context) - : RowParallelLinearImpl(context.get_model_args().hidden_size(), - context.get_model_args().vocab_size(), - /*bias=*/false, - /*input_is_parallelized=*/false, - /*enable_result_reduction=*/true, - context.get_quant_args(), - context.get_parallel_args().tp_group_, - context.get_tensor_options()) {} - // Linear layer with row parallelism. RowParallelLinearImpl::RowParallelLinearImpl( int64_t in_features, diff --git a/xllm/core/layers/common/linear.h b/xllm/core/layers/common/linear.h index 588943f21..6e7c7aee2 100644 --- a/xllm/core/layers/common/linear.h +++ b/xllm/core/layers/common/linear.h @@ -198,8 +198,6 @@ TORCH_MODULE(QKVParallelLinear); // - - class RowParallelLinearImpl : public torch::nn::Module { public: - RowParallelLinearImpl(const ModelContext& context); - RowParallelLinearImpl( int64_t in_features, int64_t out_features, diff --git a/xllm/core/layers/common/lm_head.h b/xllm/core/layers/common/lm_head.h index 5df501353..03d722df7 100644 --- a/xllm/core/layers/common/lm_head.h +++ b/xllm/core/layers/common/lm_head.h @@ -20,26 +20,13 @@ limitations under the License. namespace xllm { namespace layer { -class LmHead : public torch::nn::ModuleHolder { +class LmHead : public torch::nn::ModuleHolder { public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = RowParallelLinearImpl; + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = ColumnParallelLinearImpl; LmHead(const ModelContext& context) - : ModuleHolder(std::make_shared( - // NOTE: Quantization should NOT be used for the final language - // modeling head (lm_head). The output logits must remain in high - // precision (typically bfloat16/float16) for numerical stability - // and correct evaluation of loss and predictions. Always use - // unquantized weights here. - context.get_model_args().hidden_size(), - context.get_model_args().vocab_size(), - /*bias=*/false, - /*input_is_parallelized=*/false, - /*enable_result_reduction=*/true, - QuantArgs{}, // do not use quantization for lm_head! - context.get_parallel_args().tp_group_, - context.get_tensor_options())) {} + : ModuleHolder(std::make_shared(context)) {} }; } // namespace layer diff --git a/xllm/core/layers/common/oxygen_vision_attention.cpp b/xllm/core/layers/common/oxygen_vision_attention.cpp new file mode 100644 index 000000000..f49f3f2e7 --- /dev/null +++ b/xllm/core/layers/common/oxygen_vision_attention.cpp @@ -0,0 +1,117 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "oxygen_vision_attention.h" + +#if defined(USE_MLU) +#include "kernels/mlu/mlu_ops_api.h" +#endif +#include "kernels/ops_api.h" +namespace xllm { +namespace layer { + +OxygenVisionAttentionImpl::OxygenVisionAttentionImpl( + const ModelContext& context) + : Qwen2VisionAttentionImpl(context, false) {} + +torch::Tensor OxygenVisionAttentionImpl::forward( + torch::Tensor& hidden_states, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& params) { + // 1. qkv projection + auto qkv = qkv_proj_->forward(hidden_states); + // 2. split qkv + auto qkv_split = split_qkv(qkv); + // 3. transpose [s, b, h, d] -> [b, s, h, d] + for (auto& tensor : qkv_split) { + tensor = tensor.transpose(0, 1).contiguous(); + } + auto q = qkv_split[0]; + auto k = qkv_split[1]; + auto v = qkv_split[2]; + int64_t B = q.size(0); + int64_t S = q.size(1); + int64_t head_dim = q.size(3); + CHECK_EQ(head_dim, hidden_size_per_attention_head_) << "head_dim mismatch"; + int32_t max_seqlen = + *std::max_element(cu_seq_len_vec.begin(), cu_seq_len_vec.end()); + + // 4. rope + // Reshape q, k from [B, S, H, D] to [B*S, H, D] before applying RoPE so + // that the RoPE kernel sees the correct total token count (B*S = seq_len), + // not just the batch dimension (B=1). + q = q.reshape({B * S, num_attention_heads_per_partition_, head_dim}); + k = k.reshape({B * S, num_attention_heads_per_partition_, head_dim}); + + // Apply rotary position embedding to both q and k seperately. + xllm::kernel::RotaryParams rotary_params; + rotary_params.q = q; + rotary_params.sin = m_sin_pos; + rotary_params.cos = m_cos_pos; + rotary_params.interleaved = false; + rotary_params.discrete = false; + rotary_params.cu_query_lens = cu_seq_len; + rotary_params.max_query_len = max_seqlen; + xllm::kernel::apply_rotary(rotary_params); + rotary_params.q = k; + xllm::kernel::apply_rotary(rotary_params); + + // q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + // q and k are already [B*S, H, D] after the reshape above; just + // flatten v to the same shape. + v = v.view({B * S, v.size(2), v.size(3)}); + torch::Tensor output = torch::zeros_like(q); + + // 5. store k/v cache and do attention +#if defined(USE_MLU) + std::optional output_lse = std::nullopt; + + xllm::kernel::mlu::batch_prefill(q, + k, + v, + output, + output_lse, + cu_seq_len, + cu_seq_len, + /*alibi_slope=*/std::nullopt, + /*alibi_bias=*/std::nullopt, + /*q_quant_scale=*/std::nullopt, + /*k_quant_scale=*/std::nullopt, + /*v_quant_scale=*/std::nullopt, + /*out_quant_scale=*/std::nullopt, + /*block_table=*/std::nullopt, + max_seqlen, + max_seqlen, + scale_, + /*is_causal=*/false, + /*window_size_left=*/-1, + /*window_size_right=*/-1, + /*compute_dtype=*/"half", + /*return_lse=*/false); +#endif + + // context_layer = rearrange(output, "(b s) h d -> s b (h d)", b=batch_size) + output = output.view({B, S, -1}); + // [B, S, ...] -> [S, B, ...] + output = output.transpose(0, 1).reshape({-1, output.size(-1)}); + // 6. output projection + return proj_->forward(output); +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/oxygen_vision_attention.h b/xllm/core/layers/common/oxygen_vision_attention.h new file mode 100644 index 000000000..cde16aaa0 --- /dev/null +++ b/xllm/core/layers/common/oxygen_vision_attention.h @@ -0,0 +1,42 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "framework/model/model_input_params.h" +#include "framework/model_context.h" +#include "qwen2_vision_attention.h" + +namespace xllm { +namespace layer { + +class OxygenVisionAttentionImpl : public Qwen2VisionAttentionImpl { + public: + OxygenVisionAttentionImpl() = default; + OxygenVisionAttentionImpl(const ModelContext& context); + + torch::Tensor forward(torch::Tensor& hidden_states, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params) override; +}; +TORCH_MODULE(OxygenVisionAttention); + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/qwen2_attention.cpp b/xllm/core/layers/common/qwen2_attention.cpp index ed5fbc852..431e57ead 100644 --- a/xllm/core/layers/common/qwen2_attention.cpp +++ b/xllm/core/layers/common/qwen2_attention.cpp @@ -25,7 +25,7 @@ limitations under the License. namespace { inline bool is_qwen3_model(const std::string& model_type) { static const std::unordered_set qwen3_type_set = { - "qwen3", "qwen3_vl", "qwen3_moe", "qwen3_vl_moe"}; + "qwen3", "qwen3_vl", "qwen3_moe", "qwen3_vl_moe", "oxygenvlm"}; return qwen3_type_set.contains(model_type); } diff --git a/xllm/core/layers/common/qwen2_vision_attention.cpp b/xllm/core/layers/common/qwen2_vision_attention.cpp index 56aa06f6e..599209924 100644 --- a/xllm/core/layers/common/qwen2_vision_attention.cpp +++ b/xllm/core/layers/common/qwen2_vision_attention.cpp @@ -24,8 +24,8 @@ limitations under the License. namespace xllm { namespace layer { -Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl( - const ModelContext& context) { +Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl(const ModelContext& context, + bool has_bias) { const auto& args = context.get_model_args(); const auto& quant_args = context.get_quant_args(); const auto& parallel_args = context.get_parallel_args(); @@ -47,7 +47,7 @@ Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl( num_attention_heads_per_partition_, hidden_size_per_attention_head_, /*num_kv_head_replicas=*/1, - /*bias=*/true, + /*bias=*/has_bias, /*gather_output=*/false, parallel_args, options)); @@ -55,7 +55,7 @@ Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl( proj_ = register_module("proj", RowParallelLinear(hidden_size, hidden_size, - /*bias=*/true, + /*bias=*/has_bias, /*input_is_parallelized=*/true, /*if_reduce_results=*/true, quant_args, diff --git a/xllm/core/layers/common/qwen2_vision_attention.h b/xllm/core/layers/common/qwen2_vision_attention.h index 463ce8620..b59e75b62 100644 --- a/xllm/core/layers/common/qwen2_vision_attention.h +++ b/xllm/core/layers/common/qwen2_vision_attention.h @@ -31,18 +31,18 @@ namespace layer { class Qwen2VisionAttentionImpl : public torch::nn::Module { public: Qwen2VisionAttentionImpl() = default; - Qwen2VisionAttentionImpl(const ModelContext& context); + Qwen2VisionAttentionImpl(const ModelContext& context, bool has_bias = true); - torch::Tensor forward(torch::Tensor& hidden_states, - torch::Tensor& m_cos_pos, - torch::Tensor& m_sin_pos, - torch::Tensor& cu_seq_len, - std::vector& cu_seq_len_vec, - ModelInputParams& input_params); + virtual torch::Tensor forward(torch::Tensor& hidden_states, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params); void load_state_dict(const StateDict& state_dict); - private: + protected: std::vector split_qkv(const torch::Tensor& qkv); int64_t hidden_size_per_attention_head_; diff --git a/xllm/core/layers/common/qwen3_next_rms_norm.cpp b/xllm/core/layers/common/qwen3_next_rms_norm.cpp index fa498a847..8735d8d3f 100644 --- a/xllm/core/layers/common/qwen3_next_rms_norm.cpp +++ b/xllm/core/layers/common/qwen3_next_rms_norm.cpp @@ -17,6 +17,8 @@ limitations under the License. #include +#include "xllm/core/kernels/ops_api.h" + namespace xllm { namespace layer { @@ -28,15 +30,12 @@ Qwen3NextRMSNormImpl::Qwen3NextRMSNormImpl(int64_t dim, } torch::Tensor Qwen3NextRMSNormImpl::forward(torch::Tensor& input) { - auto input_dtype = input.dtype(); - input = input.to(torch::kFloat32); - - // Calculate RMS - auto variance = torch::mean(torch::pow(input, 2), -1, true); - auto normalized = input * torch::rsqrt(variance + eps_); - - // Apply weight and convert back to original dtype - return (normalized * (1.0f + weight_.to(torch::kFloat32))).to(input_dtype); + xllm::kernel::GemmaRMSNormParams norm_params; + norm_params.x = input; + norm_params.gamma = weight_; + norm_params.epsilon = eps_; + xllm::kernel::gemma_rms_norm(norm_params); + return norm_params.norm_out; } void Qwen3NextRMSNormImpl::load_state_dict(const StateDict& state_dict) { diff --git a/xllm/core/layers/common/tests/deepseek_v2_attention_multi_device_tests.cpp b/xllm/core/layers/common/tests/deepseek_v2_attention_multi_device_tests.cpp index da049073b..90ebac045 100644 --- a/xllm/core/layers/common/tests/deepseek_v2_attention_multi_device_tests.cpp +++ b/xllm/core/layers/common/tests/deepseek_v2_attention_multi_device_tests.cpp @@ -901,7 +901,6 @@ int32_t run_attention_prefill_sp_baseline_test_child(int32_t rank, return EXIT_CODE_SKIP; } - FLAGS_enable_mla = true; FLAGS_block_size = 16; const int32_t device_index = rank % dev_count; xllm::Device xllm_device(device_index); diff --git a/xllm/core/layers/common/tests/dense_mlp_tests.cpp b/xllm/core/layers/common/tests/dense_mlp_tests.cpp index 36cf5ae60..dd0b5c046 100644 --- a/xllm/core/layers/common/tests/dense_mlp_tests.cpp +++ b/xllm/core/layers/common/tests/dense_mlp_tests.cpp @@ -318,6 +318,54 @@ TEST_F(DenseMLPTest, SmoothquantLoadStateDictTest) { LOG(INFO) << "State dict loading test passed - output sum: " << output_sum; } +TEST_F(DenseMLPTest, Fp8IgnoredDownProjLoadsAsUnquantized) { + QuantArgs fp8_quant_args; + fp8_quant_args.quant_method() = kQuantMethodFp8; + fp8_quant_args.bits() = 8; + fp8_quant_args.activation_dynamic() = false; + fp8_quant_args.ignored_modules() = {"model.layers.1.mlp.down_proj"}; + + const int64_t hidden_size = 16; + const int64_t intermediate_size = 32; + auto mlp = DenseMLP(DenseMLPImpl(hidden_size, + intermediate_size, + /*is_gated=*/true, + /*has_bias=*/false, + /*hidden_act=*/"silu", + /*enable_result_reduction=*/true, + fp8_quant_args, + parallel_args_.tp_group_, + options_, + "model.layers.1.mlp")); + + std::unordered_map weight_dict; + auto fp8_weight_options = options_.dtype(torch::kFloat8_e4m3fn); + auto scale_options = options_.dtype(torch::kFloat32); + + weight_dict["gate_proj.weight"] = + torch::zeros({intermediate_size, hidden_size}, fp8_weight_options); + weight_dict["gate_proj.weight_scale"] = torch::ones({1}, scale_options); + weight_dict["gate_proj.input_scale"] = torch::ones({1}, scale_options); + + weight_dict["up_proj.weight"] = + torch::zeros({intermediate_size, hidden_size}, fp8_weight_options); + weight_dict["up_proj.weight_scale"] = torch::ones({1}, scale_options); + weight_dict["up_proj.input_scale"] = torch::ones({1}, scale_options); + + weight_dict["down_proj.weight"] = + torch::zeros({hidden_size, intermediate_size}, options_); + + StateDict state_dict(weight_dict); + mlp->load_state_dict(state_dict); + + const auto params = mlp->named_parameters(/*recurse=*/true); + EXPECT_TRUE(params.contains("gate_up_proj.weight_scale")); + EXPECT_TRUE(params.contains("gate_up_proj.input_scale")); + EXPECT_TRUE(params.contains("down_proj.weight")); + EXPECT_FALSE(params.contains("down_proj.weight_scale")); + EXPECT_FALSE(params.contains("down_proj.input_scale")); +} + TEST_F(DenseMLPTest, SmoothquantPrecisionVerificationTest) { // Test precision verification with custom input and expected output const int64_t batch_size = 16; diff --git a/xllm/core/layers/cuda/attention.cpp b/xllm/core/layers/cuda/attention.cpp index 800cbffb0..55703874a 100644 --- a/xllm/core/layers/cuda/attention.cpp +++ b/xllm/core/layers/cuda/attention.cpp @@ -16,7 +16,7 @@ limitations under the License. #include "attention.h" #include "base_attention_impl.h" -#include "core/common/rec_model_utils.h" +#include "core/util/rec_model_utils.h" #include "flashinfer_attention.h" #include "xattention.h" diff --git a/xllm/core/layers/mlu/attention.cpp b/xllm/core/layers/mlu/attention.cpp index e4049777c..d1ffe09e1 100644 --- a/xllm/core/layers/mlu/attention.cpp +++ b/xllm/core/layers/mlu/attention.cpp @@ -191,6 +191,17 @@ void AttentionImpl::prefill_forward( {total_seqlens, num_kv_heads_, head_size_}, query.options()); torch::Tensor value_dequant; + std::optional value_cache_for_dequant = v_cache; + std::optional value_cache_scale_for_dequant = + v_cache_scale; + if (enable_mla_) { + // MLA stores latent cache only in k_cache, but the MLU dequant API + // requires both key/value carriers. Reuse the latent cache as a dummy + // value carrier and slice the real V view from key_dequant below. + value_dequant = torch::zeros_like(key_dequant); + value_cache_for_dequant = k_cache; + value_cache_scale_for_dequant = k_cache_scale; + } if (v_cache_scale.has_value() && v_cache_scale->defined() && v_cache_scale->numel() > 0) { value_dequant = torch::zeros( @@ -202,9 +213,9 @@ void AttentionImpl::prefill_forward( dequant_params.key = key_dequant; dequant_params.value = value_dequant; dequant_params.key_cache = k_cache; - dequant_params.value_cache = v_cache; + dequant_params.value_cache = value_cache_for_dequant; dequant_params.key_cache_quant_scale = k_cache_scale; - dequant_params.value_cache_quant_scale = v_cache_scale; + dequant_params.value_cache_quant_scale = value_cache_scale_for_dequant; dequant_params.context_lengths = attn_metadata.kv_seq_lens; dequant_params.max_context_len = attn_metadata.max_seq_len; dequant_params.context_seq_offset = std::nullopt; diff --git a/xllm/core/layers/npu/loader/base_loader.cpp b/xllm/core/layers/npu/loader/base_loader.cpp index 8c5b2c337..6e5d8ca47 100644 --- a/xllm/core/layers/npu/loader/base_loader.cpp +++ b/xllm/core/layers/npu/loader/base_loader.cpp @@ -176,5 +176,120 @@ torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) { LOG(FATAL) << "Unsupported dtype string: " << dtype_str; } +at::Tensor BaseLoader::pad_vocab_tensor(const at::Tensor& tensor, + int64_t padded_vocab_size) const { + if (tensor.size(0) >= padded_vocab_size) { + return tensor; + } + at::Tensor padded_tensor = + torch::zeros({padded_vocab_size, tensor.size(1)}, tensor.options()); + padded_tensor.slice(0, 0, tensor.size(0)) = tensor; + return padded_tensor; +} + +at::Tensor BaseLoader::shard_padded_tensor(const at::Tensor& padded_tensor, + int dim, + int rank, + int world_size) const { + if (world_size <= 1) { + return padded_tensor; + } + auto chunks = padded_tensor.chunk(world_size, dim); + return chunks[rank]; +} + +void BaseLoader::set_weight_with_padding(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim, + int64_t padded_vocab_size, + bool to_host) { + auto device = to_host ? at::kCPU : device_; + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + at::Tensor mutable_tensor = tensor; + if (padded_vocab_size > tensor.size(0)) { + mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size); + } + correct_tensor_dtype(mutable_tensor, tensor_name); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } + } + } +} + +void BaseLoader::set_weight_with_padding(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim, + int rank, + int world_size, + int64_t padded_vocab_size, + bool to_host) { + auto device = to_host ? at::kCPU : device_; + if (world_size <= 1) { + set_weight_with_padding(state_dict, + tensor_name, + weight_position, + dim, + padded_vocab_size, + to_host); + return; + } + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + at::Tensor mutable_tensor = tensor; + if (padded_vocab_size > tensor.size(0)) { + // Memory-optimized path for vocabulary dimension sharding + if (dim == 0) { + int64_t shard_size = padded_vocab_size / world_size; + int64_t start_idx = rank * shard_size; + int64_t end_idx = (rank + 1) * shard_size; + if (start_idx >= tensor.size(0)) { + mutable_tensor = + torch::zeros({shard_size, tensor.size(1)}, tensor.options()); + } else { + auto valid_part = + tensor.slice(0, start_idx, std::min(end_idx, tensor.size(0))); + if (valid_part.size(0) < shard_size) { + mutable_tensor = + torch::zeros({shard_size, tensor.size(1)}, tensor.options()); + mutable_tensor.slice(0, 0, valid_part.size(0)).copy_(valid_part); + } else { + mutable_tensor = valid_part.clone(); + } + } + } else { + // Non-vocabulary dimension: use original approach + mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size); + mutable_tensor = + shard_padded_tensor(mutable_tensor, dim, rank, world_size); + } + } else { + mutable_tensor = + state_dict.get_sharded_tensor(tensor_name, dim, rank, world_size); + } + correct_tensor_dtype(mutable_tensor, tensor_name); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } + } + } +} + +int64_t BaseLoader::get_padded_vocab_size(const ModelContext& context) const { + int64_t vocab_size = context.get_model_args().vocab_size(); + int32_t local_tp_size = dp_local_tp_size_; + if (vocab_size > 0 && local_tp_size > 1 && vocab_size % local_tp_size != 0) { + return ((vocab_size + local_tp_size - 1) / local_tp_size) * local_tp_size; + } + return vocab_size; +} + } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/npu/loader/base_loader.h b/xllm/core/layers/npu/loader/base_loader.h index 87f1b6689..8b456914d 100644 --- a/xllm/core/layers/npu/loader/base_loader.h +++ b/xllm/core/layers/npu/loader/base_loader.h @@ -111,7 +111,33 @@ class BaseLoader { int rank, int world_size, bool to_host = false); + + void set_weight_with_padding(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim, + int64_t padded_vocab_size, + bool to_host = false); + + void set_weight_with_padding(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim, + int rank, + int world_size, + int64_t padded_vocab_size, + bool to_host = false); + + at::Tensor pad_vocab_tensor(const at::Tensor& tensor, + int64_t padded_vocab_size) const; + + at::Tensor shard_padded_tensor(const at::Tensor& padded_tensor, + int dim, + int rank, + int world_size) const; + + int64_t get_padded_vocab_size(const ModelContext& context) const; }; } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp b/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp index fa24e5458..13324fd14 100644 --- a/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp +++ b/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp @@ -604,14 +604,14 @@ void DeekseekV2DecoderLoader::merge_experts_weights() { device_); } - torch::Tensor mlp_down_weight = - merge_experts_weights(experts_weights_["down_proj.weight"], - device_, - /*transpose=*/false); - // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - // at_npu::native::npu_format_cast(mlp_down_weight, 29); + // Optimization in coordination with MoeGroupedMatmulWeightNZOperation: + // ** Non-quantized weights use the ACL_FORMAT_FRACTAL_NZ layout, + // ** while the quantized version continues to use the ACL_FORMAT_ND layout. + int data_type = quantize_type_ == "" ? ACL_FORMAT_FRACTAL_NZ : ACL_FORMAT_ND; + torch::Tensor mlp_down_weight = merge_experts_weights( + experts_weights_["down_proj.weight"], device_, /*transpose=*/false); at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + at_npu::native::npu_format_cast(mlp_down_weight, data_type).contiguous(); if (quantize_type_ == "w8a8_dynamic") { at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = merge_experts_weights( diff --git a/xllm/core/layers/npu/loader/lm_head_loader.cpp b/xllm/core/layers/npu/loader/lm_head_loader.cpp index 78e7e7961..73e679d5e 100644 --- a/xllm/core/layers/npu/loader/lm_head_loader.cpp +++ b/xllm/core/layers/npu/loader/lm_head_loader.cpp @@ -22,17 +22,32 @@ LmHeadLoader::LmHeadLoader(uint64_t weight_count, const ModelContext& context) : BaseLoader(weight_count, context) { auto options = context.get_tensor_options(); at_weight_tensors_[0] = torch::zeros({1}).to(options); + vocab_size_ = context.get_model_args().vocab_size(); + padded_vocab_size_ = get_padded_vocab_size(context); } void LmHeadLoader::load_state_dict(const StateDict& state_dict) { - if (cp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); - } else if (dp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 1, dp_local_tp_rank_, dp_local_tp_size_); + if (cp_size_ > 1 || dp_size_ > 1) { + set_weight_with_padding(state_dict, + "weight", + 0, + 0, + dp_local_tp_rank_, + dp_local_tp_size_, + padded_vocab_size_, + false); + } else if (parallel_args_.world_size() > 1) { + set_weight_with_padding(state_dict, + "weight", + 0, + 0, + parallel_args_.rank(), + parallel_args_.world_size(), + padded_vocab_size_, + false); } else { - set_weight(state_dict, "weight", 0, 1); + set_weight_with_padding( + state_dict, "weight", 0, 0, padded_vocab_size_, false); } } diff --git a/xllm/core/layers/npu/loader/lm_head_loader.h b/xllm/core/layers/npu/loader/lm_head_loader.h index df21b1ae9..7770390f8 100644 --- a/xllm/core/layers/npu/loader/lm_head_loader.h +++ b/xllm/core/layers/npu/loader/lm_head_loader.h @@ -25,6 +25,10 @@ class LmHeadLoader : public BaseLoader { void load_state_dict(const StateDict& state_dict) override; void verify_loaded_weights(const std::string& weight_str) const override; + + private: + int64_t vocab_size_ = -1; + int64_t padded_vocab_size_ = -1; }; } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/npu/loader/lm_head_manual_loader.cpp b/xllm/core/layers/npu/loader/lm_head_manual_loader.cpp index 7b4fea6a1..6e975835f 100644 --- a/xllm/core/layers/npu/loader/lm_head_manual_loader.cpp +++ b/xllm/core/layers/npu/loader/lm_head_manual_loader.cpp @@ -23,17 +23,32 @@ LmHeadManualLoader::LmHeadManualLoader(uint64_t weight_count, : BaseManualLoader(weight_count, context) { auto options = context.get_tensor_options(); at_weight_tensors_[0] = torch::zeros({1}).to(options); + vocab_size_ = context.get_model_args().vocab_size(); + padded_vocab_size_ = get_padded_vocab_size(context); } void LmHeadManualLoader::load_state_dict(const StateDict& state_dict) { - if (cp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_, true); - } else if (dp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 1, dp_local_tp_rank_, dp_local_tp_size_, true); + if (cp_size_ > 1 || dp_size_ > 1) { + set_weight_with_padding(state_dict, + "weight", + 0, + 0, + dp_local_tp_rank_, + dp_local_tp_size_, + padded_vocab_size_, + true); + } else if (parallel_args_.world_size() > 1) { + set_weight_with_padding(state_dict, + "weight", + 0, + 0, + parallel_args_.rank(), + parallel_args_.world_size(), + padded_vocab_size_, + true); } else { - set_weight(state_dict, "weight", 0, 1, true); + set_weight_with_padding( + state_dict, "weight", 0, 0, padded_vocab_size_, true); } } diff --git a/xllm/core/layers/npu/loader/lm_head_manual_loader.h b/xllm/core/layers/npu/loader/lm_head_manual_loader.h index d93098327..e91b3b58a 100644 --- a/xllm/core/layers/npu/loader/lm_head_manual_loader.h +++ b/xllm/core/layers/npu/loader/lm_head_manual_loader.h @@ -28,6 +28,10 @@ class LmHeadManualLoader : public BaseManualLoader { protected: void merge_host_at_weights() override; + + private: + int64_t vocab_size_ = -1; + int64_t padded_vocab_size_ = -1; }; } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/npu/npu_base_layer.cpp b/xllm/core/layers/npu/npu_base_layer.cpp index c963d2987..aacece869 100644 --- a/xllm/core/layers/npu/npu_base_layer.cpp +++ b/xllm/core/layers/npu/npu_base_layer.cpp @@ -61,7 +61,7 @@ atb::Status BaseLayer::execute_node(atb_speed::Model::Node& node, int node_id, aclrtEvent* event, std::atomic* event_flag) { - // TODO(by zhangminchao1@jd.com): Stream management needs to be refactored + // TODO: Stream management needs to be refactored // for better separation of concerns Current issues: // 1. ACLGraph capture requires execution on a non-default stream, so we // temporarily set the current stream diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp index 25e331d39..1f5bf80d0 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp @@ -120,7 +120,7 @@ void NpuGlm4MoeDecoderImpl::initialize_basic_parameters( param.enableSplitFuse = (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill; - // TODO(zhangminchao1@jd.com): not support MTP model yet + // TODO: not support MTP model yet param.enableAclGraphPagedAttention = FLAGS_enable_graph && !is_prefill && args.n_layers() > 1; diff --git a/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp index 4356b9bd4..a13303a5d 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp @@ -156,8 +156,6 @@ void NpuGlm4MoeDecoderLiteImpl::initialize_basic_parameters( (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill; param.enableAclGraphPagedAttention = false; - // TODO(zhangminchao1@jd.com): not support MTP model yet - // FLAGS_enable_graph && !is_prefill && args.n_layers() > 1; param.moeLinearTransposeType = (layer_id_ < args.first_k_dense_replace()) ? std::vector{-1, -1, -1, -1} diff --git a/xllm/core/layers/npu/npu_lm_head_impl.cpp b/xllm/core/layers/npu/npu_lm_head_impl.cpp index 36964e258..61c540e78 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.cpp +++ b/xllm/core/layers/npu/npu_lm_head_impl.cpp @@ -27,7 +27,6 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, bool isPrefill) { - const bool use_column_parallel = cp_size_ > 1; param.unpadInputs = true; param.gatherAhead = isPrefill; param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads(); @@ -35,8 +34,8 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param, args.dtype() == "bfloat16"; param.linearParallelParam.unpadInputs = true; param.linearParallelParam.fusionLinearParam.transposeType = 1; + if (parallel_args.world_size() > 1) { - int32_t lm_head_tp_world_size = 1; if (parallel_args.mapping_data().empty()) { const bool use_local_tp = (dp_size_ > 1) || (cp_size_ > 1); if (use_local_tp) { @@ -53,20 +52,16 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param, parallel_args.world_size(); } param.linearParallelParam.parallelType = - use_column_parallel ? atb_speed::common::COLUMN_PARALLEL - : atb_speed::common::ROW_PARALLEL; + atb_speed::common::COLUMN_PARALLEL; const int32_t tp_group_id = use_local_tp ? (parallel_args.rank() / dp_local_tp_size_) : 0; param.linearParallelParam.tensorParallelInfo.commDomain = std::to_string(tp_group_id); param.linearParallelParam.tensorParallelInfo.backend = FLAGS_communication_backend; - lm_head_tp_world_size = - param.linearParallelParam.tensorParallelInfo.worldSize; } else { param.linearParallelParam.parallelType = - use_column_parallel ? atb_speed::common::COLUMN_PARALLEL - : atb_speed::common::ROW_PARALLEL; + atb_speed::common::COLUMN_PARALLEL; atb_speed::common::ParallelInfo parallelInfo = parallel_args.mapping().Get(atb_speed::base::LM_HEAD_TP); param.linearParallelParam.tensorParallelInfo.rank = parallelInfo.rank; @@ -77,19 +72,23 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param, parallelInfo.InitCommDomain( param.linearParallelParam.tensorParallelInfo.hcommInfo, param.linearParallelParam.tensorParallelInfo.commDomain); - lm_head_tp_world_size = - param.linearParallelParam.tensorParallelInfo.worldSize; param.contextParallelInfo = parallel_args.mapping().Get(atb_speed::base::ATTN_CP); } - if (!use_column_parallel) { - param.hiddenSizePerAttentionHead = - args.hidden_size() / lm_head_tp_world_size; - } } } NpuLmHeadImpl::NpuLmHeadImpl(const ModelContext& context) : BaseLayer(context) { + vocab_size_ = context.get_model_args().vocab_size(); + if (vocab_size_ > 0 && dp_local_tp_size_ > 1 && + vocab_size_ % dp_local_tp_size_ != 0) { + padded_vocab_size_ = + ((vocab_size_ + dp_local_tp_size_ - 1) / dp_local_tp_size_) * + dp_local_tp_size_; + } else { + padded_vocab_size_ = vocab_size_; + } + param_from_args(lm_head_param_prefill_, context.get_model_args(), context.get_parallel_args(), @@ -166,7 +165,11 @@ torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states, st = execute_node(lm_head_node_prefill_, nodeId); LOG_IF(FATAL, st != 0) << model_name_ << "execute lmhead node fail, error code: " << st; - return atOutTensors_[0]; + torch::Tensor output = atOutTensors_[0]; + if (padded_vocab_size_ > vocab_size_ && vocab_size_ > 0) { + output = output.slice(/*dim=*/-1, /*start=*/0, /*end=*/vocab_size_); + } + return output; } void NpuLmHeadImpl::build_node_variant_pack( diff --git a/xllm/core/layers/npu/npu_lm_head_impl.h b/xllm/core/layers/npu/npu_lm_head_impl.h index e96d487dd..bd3246f67 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.h +++ b/xllm/core/layers/npu/npu_lm_head_impl.h @@ -86,6 +86,9 @@ class NpuLmHeadImpl : public BaseLayer { std::vector> decode_tensor_storage_; atb::Tensor hidden_states_atb_; atb::Tensor seleted_idxes_atb_; + + int64_t vocab_size_ = -1; + int64_t padded_vocab_size_ = -1; }; TORCH_MODULE(NpuLmHead); diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp index 86ecd7a23..6ea2148f7 100644 --- a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp @@ -16,52 +16,1624 @@ limitations under the License. #include "npu_onerec_block_layer_impl.h" #include +#include +#include +#include +#include + +#include "common/global_flags.h" namespace xllm { namespace layer { +namespace { + +// Decoder normal mode: self-attn(29) + cross-attn(28) + layer-norm(4) + mlp(18) +// = 79 +static constexpr uint64_t kOneRecWeightCountPerLayer = 79; + +// Decoder MoE mode weights count (exclude runtime tensors like expert_array). +static constexpr uint64_t kOneRecMoeWeightCountPerLayer = 97; + +enum class OneRecBlockLayerTensorId : int32_t { + // Self-attention layer norm + IN_LAYER_NORM_WEIGHT = 0, + IN_LAYER_NORM_BIAS, + IN_INPUT_NORM_NEW_WEIGHT, + IN_INPUT_NORM_NEW_BIAS, + // Self-attention Q, K, V projections + IN_Q_WEIGHT, + IN_Q_BIAS, + IN_Q_DEQSCALE, + IN_Q_OFFSET, + IN_Q_SCALE, + IN_Q_COMPRESS_IDX, + + IN_K_WEIGHT, + IN_K_BIAS, + IN_K_DEQSCALE, + IN_K_OFFSET, + IN_K_SCALE, + IN_K_COMPRESS_IDX, + + IN_V_WEIGHT, + IN_V_BIAS, + IN_V_DEQSCALE, + IN_V_OFFSET, + IN_V_SCALE, + IN_V_COMPRESS_IDX, + + // Self-attention output projection + IN_SELF_ATTN_OUT_WEIGHT, + IN_SELF_ATTN_OUT_BIAS, + IN_SELF_ATTN_OUT_DEQSCALE, + IN_SELF_ATTN_OUT_OFFSET, + IN_SELF_ATTN_OUT_SCALE, + IN_SELF_ATTN_OUT_COMPRESS_IDX, + + // ONEREC relative attention bias (encoder only) + IN_RELATIVE_ATTENTION_BIAS_WEIGHT, + + // Cross-attention layer norm (decoder only) + IN_CROSS_LAYER_NORM_WEIGHT, + IN_CROSS_LAYER_NORM_BIAS, + IN_CROSS_LAYER_NORM_NEW_WEIGHT, + IN_CROSS_LAYER_NORM_NEW_BIAS, + + // Cross-attention Q, K, V projections (decoder only) + IN_CROSS_Q_WEIGHT, + IN_CROSS_Q_BIAS, + IN_CROSS_Q_DEQSCALE, + IN_CROSS_Q_OFFSET, + IN_CROSS_Q_SCALE, + IN_CROSS_Q_COMPRESS_IDX, + + IN_CROSS_K_WEIGHT, + IN_CROSS_K_BIAS, + IN_CROSS_K_DEQSCALE, + IN_CROSS_K_OFFSET, + IN_CROSS_K_SCALE, + IN_CROSS_K_COMPRESS_IDX, + + IN_CROSS_V_WEIGHT, + IN_CROSS_V_BIAS, + IN_CROSS_V_DEQSCALE, + IN_CROSS_V_OFFSET, + IN_CROSS_V_SCALE, + IN_CROSS_V_COMPRESS_IDX, + + // Cross-attention output projection (decoder only) + IN_CROSS_ATTN_OUT_WEIGHT, + IN_CROSS_ATTN_OUT_BIAS, + IN_CROSS_ATTN_OUT_DEQSCALE, + IN_CROSS_ATTN_OUT_OFFSET, + IN_CROSS_ATTN_OUT_SCALE, + IN_CROSS_ATTN_OUT_COMPRESS_IDX, + + // Final layer norm + IN_FINAL_LAYER_NORM_WEIGHT, + IN_FINAL_LAYER_NORM_BIAS, + IN_FINAL_LAYER_NORM_NEW_WEIGHT, + IN_FINAL_LAYER_NORM_NEW_BIAS, + + // Feed-forward network (gated activation) + IN_FFN_WI_0_WEIGHT = 61, // wi_0 (gate projection) + IN_FFN_WI_0_BIAS, + IN_FFN_WI_0_DEQSCALE, + IN_FFN_WI_0_OFFSET, + IN_FFN_WI_0_SCALE, + IN_FFN_WI_0_COMPRESS_IDX, + + IN_FFN_WI_1_WEIGHT, // wi_1 (up projection) + IN_FFN_WI_1_BIAS, + IN_FFN_WI_1_DEQSCALE, + IN_FFN_WI_1_OFFSET, + IN_FFN_WI_1_SCALE, + IN_FFN_WI_1_COMPRESS_IDX, + + IN_FFN_WO_WEIGHT, // wo (down projection) + IN_FFN_WO_BIAS, + IN_FFN_WO_DEQSCALE, + IN_FFN_WO_OFFSET, + IN_FFN_WO_SCALE, + IN_FFN_WO_COMPRESS_IDX, +}; + +constexpr int32_t kInLayerNormWeight = + static_cast(OneRecBlockLayerTensorId::IN_LAYER_NORM_WEIGHT); +constexpr int32_t kInLayerNormBias = + static_cast(OneRecBlockLayerTensorId::IN_LAYER_NORM_BIAS); +constexpr int32_t kInInputNormNewWeight = + static_cast(OneRecBlockLayerTensorId::IN_INPUT_NORM_NEW_WEIGHT); +constexpr int32_t kInInputNormNewBias = + static_cast(OneRecBlockLayerTensorId::IN_INPUT_NORM_NEW_BIAS); +constexpr int32_t kInQWeight = + static_cast(OneRecBlockLayerTensorId::IN_Q_WEIGHT); +constexpr int32_t kInQBias = + static_cast(OneRecBlockLayerTensorId::IN_Q_BIAS); +constexpr int32_t kInQDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_Q_DEQSCALE); +constexpr int32_t kInQOffset = + static_cast(OneRecBlockLayerTensorId::IN_Q_OFFSET); +constexpr int32_t kInQScale = + static_cast(OneRecBlockLayerTensorId::IN_Q_SCALE); +constexpr int32_t kInQCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_Q_COMPRESS_IDX); +constexpr int32_t kInKWeight = + static_cast(OneRecBlockLayerTensorId::IN_K_WEIGHT); +constexpr int32_t kInKBias = + static_cast(OneRecBlockLayerTensorId::IN_K_BIAS); +constexpr int32_t kInKDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_K_DEQSCALE); +constexpr int32_t kInKOffset = + static_cast(OneRecBlockLayerTensorId::IN_K_OFFSET); +constexpr int32_t kInKScale = + static_cast(OneRecBlockLayerTensorId::IN_K_SCALE); +constexpr int32_t kInKCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_K_COMPRESS_IDX); +constexpr int32_t kInVWeight = + static_cast(OneRecBlockLayerTensorId::IN_V_WEIGHT); +constexpr int32_t kInVBias = + static_cast(OneRecBlockLayerTensorId::IN_V_BIAS); +constexpr int32_t kInVDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_V_DEQSCALE); +constexpr int32_t kInVOffset = + static_cast(OneRecBlockLayerTensorId::IN_V_OFFSET); +constexpr int32_t kInVScale = + static_cast(OneRecBlockLayerTensorId::IN_V_SCALE); +constexpr int32_t kInVCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_V_COMPRESS_IDX); +constexpr int32_t kInSelfAttnOutWeight = + static_cast(OneRecBlockLayerTensorId::IN_SELF_ATTN_OUT_WEIGHT); +constexpr int32_t kInSelfAttnOutBias = + static_cast(OneRecBlockLayerTensorId::IN_SELF_ATTN_OUT_BIAS); +constexpr int32_t kInSelfAttnOutDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_SELF_ATTN_OUT_DEQSCALE); +constexpr int32_t kInSelfAttnOutOffset = + static_cast(OneRecBlockLayerTensorId::IN_SELF_ATTN_OUT_OFFSET); +constexpr int32_t kInSelfAttnOutScale = + static_cast(OneRecBlockLayerTensorId::IN_SELF_ATTN_OUT_SCALE); +constexpr int32_t kInSelfAttnOutCompressIdx = static_cast( + OneRecBlockLayerTensorId::IN_SELF_ATTN_OUT_COMPRESS_IDX); +constexpr int32_t kInRelativeAttentionBiasWeight = static_cast( + OneRecBlockLayerTensorId::IN_RELATIVE_ATTENTION_BIAS_WEIGHT); +constexpr int32_t kInCrossLayerNormWeight = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_LAYER_NORM_WEIGHT); +constexpr int32_t kInCrossLayerNormBias = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_LAYER_NORM_BIAS); +constexpr int32_t kInCrossLayerNormNewWeight = static_cast( + OneRecBlockLayerTensorId::IN_CROSS_LAYER_NORM_NEW_WEIGHT); +constexpr int32_t kInCrossLayerNormNewBias = static_cast( + OneRecBlockLayerTensorId::IN_CROSS_LAYER_NORM_NEW_BIAS); +constexpr int32_t kInCrossQWeight = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_Q_WEIGHT); +constexpr int32_t kInCrossQBias = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_Q_BIAS); +constexpr int32_t kInCrossQDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_Q_DEQSCALE); +constexpr int32_t kInCrossQOffset = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_Q_OFFSET); +constexpr int32_t kInCrossQScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_Q_SCALE); +constexpr int32_t kInCrossQCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_Q_COMPRESS_IDX); +constexpr int32_t kInCrossKWeight = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_K_WEIGHT); +constexpr int32_t kInCrossKBias = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_K_BIAS); +constexpr int32_t kInCrossKDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_K_DEQSCALE); +constexpr int32_t kInCrossKOffset = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_K_OFFSET); +constexpr int32_t kInCrossKScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_K_SCALE); +constexpr int32_t kInCrossKCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_K_COMPRESS_IDX); +constexpr int32_t kInCrossVWeight = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_V_WEIGHT); +constexpr int32_t kInCrossVBias = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_V_BIAS); +constexpr int32_t kInCrossVDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_V_DEQSCALE); +constexpr int32_t kInCrossVOffset = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_V_OFFSET); +constexpr int32_t kInCrossVScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_V_SCALE); +constexpr int32_t kInCrossVCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_V_COMPRESS_IDX); +constexpr int32_t kInCrossAttnOutWeight = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_ATTN_OUT_WEIGHT); +constexpr int32_t kInCrossAttnOutBias = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_ATTN_OUT_BIAS); +constexpr int32_t kInCrossAttnOutDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_ATTN_OUT_DEQSCALE); +constexpr int32_t kInCrossAttnOutOffset = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_ATTN_OUT_OFFSET); +constexpr int32_t kInCrossAttnOutScale = + static_cast(OneRecBlockLayerTensorId::IN_CROSS_ATTN_OUT_SCALE); +constexpr int32_t kInCrossAttnOutCompressIdx = static_cast( + OneRecBlockLayerTensorId::IN_CROSS_ATTN_OUT_COMPRESS_IDX); +constexpr int32_t kInFinalLayerNormWeight = + static_cast(OneRecBlockLayerTensorId::IN_FINAL_LAYER_NORM_WEIGHT); +constexpr int32_t kInFinalLayerNormBias = + static_cast(OneRecBlockLayerTensorId::IN_FINAL_LAYER_NORM_BIAS); +constexpr int32_t kInFinalLayerNormNewWeight = static_cast( + OneRecBlockLayerTensorId::IN_FINAL_LAYER_NORM_NEW_WEIGHT); +constexpr int32_t kInFinalLayerNormNewBias = static_cast( + OneRecBlockLayerTensorId::IN_FINAL_LAYER_NORM_NEW_BIAS); +constexpr int32_t kInFfnWi0Weight = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_0_WEIGHT); +constexpr int32_t kInFfnWi0Bias = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_0_BIAS); +constexpr int32_t kInFfnWi0DeqScale = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_0_DEQSCALE); +constexpr int32_t kInFfnWi0Offset = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_0_OFFSET); +constexpr int32_t kInFfnWi0Scale = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_0_SCALE); +constexpr int32_t kInFfnWi0CompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_0_COMPRESS_IDX); +constexpr int32_t kInFfnWi1Weight = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_1_WEIGHT); +constexpr int32_t kInFfnWi1Bias = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_1_BIAS); +constexpr int32_t kInFfnWi1DeqScale = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_1_DEQSCALE); +constexpr int32_t kInFfnWi1Offset = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_1_OFFSET); +constexpr int32_t kInFfnWi1Scale = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_1_SCALE); +constexpr int32_t kInFfnWi1CompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WI_1_COMPRESS_IDX); +constexpr int32_t kInFfnWoWeight = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WO_WEIGHT); +constexpr int32_t kInFfnWoBias = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WO_BIAS); +constexpr int32_t kInFfnWoDeqScale = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WO_DEQSCALE); +constexpr int32_t kInFfnWoOffset = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WO_OFFSET); +constexpr int32_t kInFfnWoScale = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WO_SCALE); +constexpr int32_t kInFfnWoCompressIdx = + static_cast(OneRecBlockLayerTensorId::IN_FFN_WO_COMPRESS_IDX); + +enum class OneRecMoeBlockLayerTensorId : int32_t { + // MoE weights (only used when use_moe=true) + IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 61, // routing weights + IN_BLOCK_SPARSE_MOE_GATE_BIAS = 62, // routing bias + IN_BLOCK_SPARSE_MOE_GATE_DESCALE, // gate descale + IN_BLOCK_SPARSE_MOE_GATE_OFFSET, // gate offset + IN_BLOCK_SPARSE_MOE_GATE_SCALE, // gate scale + IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX, // gate compress index + + // Shared expert weights + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + IN_MLP_GATEUP_BIAS_SHARED_EXPERT, + IN_MLP_GATEUP_DESCALE_SHARED_EXPERT, + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, + IN_MLP_GATEUP_SCALE_SHARED_EXPERT, + IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT, + + IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, + IN_MLP_DOWN_BIAS_SHARED_EXPERT, + IN_MLP_DOWN_DESCALE_SHARED_EXPERT, + IN_MLP_DOWN_OFFSET_SHARED_EXPERT, + IN_MLP_DOWN_SCALE_SHARED_EXPERT, + IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT, + + // Shared expert gate weights + IN_SHARED_EXPERT_GATE_WEIGHT, + IN_SHARED_EXPERT_GATE_BIAS, + IN_SHARED_EXPERT_GATE_DESCALE, + IN_SHARED_EXPERT_GATE_OFFSET, + IN_SHARED_EXPERT_GATE_SCALE, + IN_SHARED_EXPERT_GATE_COMPRESS_IDX, + + // Expert weights + IN_MLP_GATEUP_WEIGHT_EXPERT, + IN_MLP_GATEUP_BIAS_EXPERT, + IN_MLP_GATEUP_DESCALE_EXPERT, + IN_MLP_GATEUP_OFFSET_EXPERT, + IN_MLP_GATEUP_SCALE_EXPERT, + IN_MLP_GATEUP_COMPRESS_IDX_EXPERT, + + IN_MLP_DOWN_WEIGHT_EXPERT, + IN_MLP_DOWN_BIAS_EXPERT, + IN_MLP_DOWN_DESCALE_EXPERT, + IN_MLP_DOWN_OFFSET_EXPERT, + IN_MLP_DOWN_SCALE_EXPERT, + IN_MLP_DOWN_COMPRESS_IDX_EXPERT = 96, + + // Runtime tensors (not part of weight tensor array) + IN_EXPERT_ARRAY = 97, + IN_EXPERT_GROUP = 98, + IN_ONE_HOT = 99, + IN_ZERO_HOT = 100, + + // Legacy aliases for backward compatibility + IN_MOE_EXPERT_W1_WEIGHT = IN_MLP_GATEUP_WEIGHT_EXPERT, + IN_MOE_EXPERT_W2_WEIGHT = IN_MLP_DOWN_WEIGHT_EXPERT, + IN_MOE_EXPERT_W3_WEIGHT = IN_MLP_GATEUP_WEIGHT_EXPERT, + IN_MOE_SHARED_W1_WEIGHT = IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + IN_MOE_SHARED_W2_WEIGHT = IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, +}; + +constexpr int32_t kInBlockSparseMoeGateWeight = static_cast( + OneRecMoeBlockLayerTensorId::IN_BLOCK_SPARSE_MOE_GATE_WEIGHT); +constexpr int32_t kInBlockSparseMoeGateBias = static_cast( + OneRecMoeBlockLayerTensorId::IN_BLOCK_SPARSE_MOE_GATE_BIAS); +constexpr int32_t kInBlockSparseMoeGateDescale = static_cast( + OneRecMoeBlockLayerTensorId::IN_BLOCK_SPARSE_MOE_GATE_DESCALE); +constexpr int32_t kInBlockSparseMoeGateOffset = static_cast( + OneRecMoeBlockLayerTensorId::IN_BLOCK_SPARSE_MOE_GATE_OFFSET); +constexpr int32_t kInBlockSparseMoeGateScale = static_cast( + OneRecMoeBlockLayerTensorId::IN_BLOCK_SPARSE_MOE_GATE_SCALE); +constexpr int32_t kInBlockSparseMoeGateCompressIdx = static_cast( + OneRecMoeBlockLayerTensorId::IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX); +constexpr int32_t kInMlpGateUpWeightSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT); +constexpr int32_t kInMlpGateUpBiasSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_BIAS_SHARED_EXPERT); +constexpr int32_t kInMlpGateUpDescaleSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_DESCALE_SHARED_EXPERT); +constexpr int32_t kInMlpGateUpOffsetSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_OFFSET_SHARED_EXPERT); +constexpr int32_t kInMlpGateUpScaleSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_SCALE_SHARED_EXPERT); +constexpr int32_t kInMlpGateUpCompressIdxSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT); +constexpr int32_t kInMlpDownWeightSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_WEIGHT_SHARED_EXPERT); +constexpr int32_t kInMlpDownBiasSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_BIAS_SHARED_EXPERT); +constexpr int32_t kInMlpDownDescaleSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_DESCALE_SHARED_EXPERT); +constexpr int32_t kInMlpDownOffsetSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_OFFSET_SHARED_EXPERT); +constexpr int32_t kInMlpDownScaleSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_SCALE_SHARED_EXPERT); +constexpr int32_t kInMlpDownCompressIdxSharedExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT); +constexpr int32_t kInSharedExpertGateWeight = static_cast( + OneRecMoeBlockLayerTensorId::IN_SHARED_EXPERT_GATE_WEIGHT); +constexpr int32_t kInSharedExpertGateBias = static_cast( + OneRecMoeBlockLayerTensorId::IN_SHARED_EXPERT_GATE_BIAS); +constexpr int32_t kInSharedExpertGateDescale = static_cast( + OneRecMoeBlockLayerTensorId::IN_SHARED_EXPERT_GATE_DESCALE); +constexpr int32_t kInSharedExpertGateOffset = static_cast( + OneRecMoeBlockLayerTensorId::IN_SHARED_EXPERT_GATE_OFFSET); +constexpr int32_t kInSharedExpertGateScale = static_cast( + OneRecMoeBlockLayerTensorId::IN_SHARED_EXPERT_GATE_SCALE); +constexpr int32_t kInSharedExpertGateCompressIdx = static_cast( + OneRecMoeBlockLayerTensorId::IN_SHARED_EXPERT_GATE_COMPRESS_IDX); +constexpr int32_t kInMlpGateUpWeightExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_WEIGHT_EXPERT); +constexpr int32_t kInMlpGateUpBiasExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_BIAS_EXPERT); +constexpr int32_t kInMlpGateUpDescaleExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_DESCALE_EXPERT); +constexpr int32_t kInMlpGateUpOffsetExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_OFFSET_EXPERT); +constexpr int32_t kInMlpGateUpScaleExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_SCALE_EXPERT); +constexpr int32_t kInMlpGateUpCompressIdxExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_GATEUP_COMPRESS_IDX_EXPERT); +constexpr int32_t kInMlpDownWeightExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_WEIGHT_EXPERT); +constexpr int32_t kInMlpDownBiasExpert = + static_cast(OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_BIAS_EXPERT); +constexpr int32_t kInMlpDownDescaleExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_DESCALE_EXPERT); +constexpr int32_t kInMlpDownOffsetExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_OFFSET_EXPERT); +constexpr int32_t kInMlpDownScaleExpert = + static_cast(OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_SCALE_EXPERT); +constexpr int32_t kInMlpDownCompressIdxExpert = static_cast( + OneRecMoeBlockLayerTensorId::IN_MLP_DOWN_COMPRESS_IDX_EXPERT); +constexpr int32_t kInExpertArray = + static_cast(OneRecMoeBlockLayerTensorId::IN_EXPERT_ARRAY); +constexpr int32_t kInExpertGroup = + static_cast(OneRecMoeBlockLayerTensorId::IN_EXPERT_GROUP); +constexpr int32_t kInOneHot = + static_cast(OneRecMoeBlockLayerTensorId::IN_ONE_HOT); +constexpr int32_t kInZeroHot = + static_cast(OneRecMoeBlockLayerTensorId::IN_ZERO_HOT); +constexpr int32_t kInMoeExpertW1Weight = + static_cast(OneRecMoeBlockLayerTensorId::IN_MOE_EXPERT_W1_WEIGHT); +constexpr int32_t kInMoeExpertW2Weight = + static_cast(OneRecMoeBlockLayerTensorId::IN_MOE_EXPERT_W2_WEIGHT); +constexpr int32_t kInMoeExpertW3Weight = + static_cast(OneRecMoeBlockLayerTensorId::IN_MOE_EXPERT_W3_WEIGHT); +constexpr int32_t kInMoeSharedW1Weight = + static_cast(OneRecMoeBlockLayerTensorId::IN_MOE_SHARED_W1_WEIGHT); +constexpr int32_t kInMoeSharedW2Weight = + static_cast(OneRecMoeBlockLayerTensorId::IN_MOE_SHARED_W2_WEIGHT); + +static const std::unordered_map + kOneRecEncoderWeightMapping = { + {"layer.0.layer_norm.weight", kInLayerNormWeight}, + {"layer.0.SelfAttention.q.weight", kInQWeight}, + {"layer.0.SelfAttention.k.weight", kInKWeight}, + {"layer.0.SelfAttention.v.weight", kInVWeight}, + {"layer.0.SelfAttention.o.weight", kInSelfAttnOutWeight}, + {"layer.0.SelfAttention.relative_attention_bias.weight", + kInRelativeAttentionBiasWeight}, + {"layer.1.layer_norm.weight", kInFinalLayerNormWeight}, + {"layer.1.DenseReluDense.wi.weight", kInFfnWi1Weight}, + {"layer.1.DenseReluDense.wo.weight", kInFfnWoWeight}, + {"layer.1.DenseReluDense.gate_proj.weight", kInFfnWi0Weight}, + {"layer.1.ffn.wi.weight", kInFfnWi1Weight}, + {"layer.1.ffn.wo.weight", kInFfnWoWeight}, + {"layer.1.ffn.gate_proj.weight", kInFfnWi0Weight}, + // Alternative format + {"0.layer_norm.weight", kInLayerNormWeight}, + {"0.SelfAttention.q.weight", kInQWeight}, + {"0.SelfAttention.k.weight", kInKWeight}, + {"0.SelfAttention.v.weight", kInVWeight}, + {"0.SelfAttention.o.weight", kInSelfAttnOutWeight}, + {"0.SelfAttention.relative_attention_bias.weight", + kInRelativeAttentionBiasWeight}, + {"1.layer_norm.weight", kInFinalLayerNormWeight}, + {"1.DenseReluDense.wi.weight", kInFfnWi1Weight}, + {"1.DenseReluDense.wo.weight", kInFfnWoWeight}, + {"1.DenseReluDense.gate_proj.weight", kInFfnWi0Weight}, + {"1.ffn.wi.weight", kInFfnWi1Weight}, + {"1.ffn.wo.weight", kInFfnWoWeight}, + {"1.ffn.gate_proj.weight", kInFfnWi0Weight}, +}; + +static const std::unordered_map + kOneRecDecoderWeightMapping = { + {"layer.0.layer_norm.weight", kInLayerNormWeight}, + {"layer.0.SelfAttention.q.weight", kInQWeight}, + {"layer.0.SelfAttention.k.weight", kInKWeight}, + {"layer.0.SelfAttention.v.weight", kInVWeight}, + {"layer.0.SelfAttention.o.weight", kInSelfAttnOutWeight}, + {"layer.0.SelfAttention.relative_attention_bias.weight", + kInRelativeAttentionBiasWeight}, + {"layer.1.layer_norm.weight", kInCrossLayerNormWeight}, + {"layer.1.EncDecAttention.q.weight", kInCrossQWeight}, + {"layer.1.EncDecAttention.k.weight", kInCrossKWeight}, + {"layer.1.EncDecAttention.v.weight", kInCrossVWeight}, + {"layer.1.EncDecAttention.o.weight", kInCrossAttnOutWeight}, + {"layer.2.layer_norm.weight", kInFinalLayerNormWeight}, + {"layer.2.DenseReluDense.wi.weight", kInFfnWi1Weight}, + {"layer.2.DenseReluDense.wo.weight", kInFfnWoWeight}, + {"layer.2.DenseReluDense.gate_proj.weight", kInFfnWi0Weight}, + // Alternative format + {"0.layer_norm.weight", kInLayerNormWeight}, + {"0.SelfAttention.q.weight", kInQWeight}, + {"0.SelfAttention.k.weight", kInKWeight}, + {"0.SelfAttention.v.weight", kInVWeight}, + {"0.SelfAttention.o.weight", kInSelfAttnOutWeight}, + {"0.SelfAttention.relative_attention_bias.weight", + kInRelativeAttentionBiasWeight}, + {"1.layer_norm.weight", kInCrossLayerNormWeight}, + {"1.EncDecAttention.q.weight", kInCrossQWeight}, + {"1.EncDecAttention.k.weight", kInCrossKWeight}, + {"1.EncDecAttention.v.weight", kInCrossVWeight}, + {"1.EncDecAttention.o.weight", kInCrossAttnOutWeight}, + {"2.layer_norm.weight", kInFinalLayerNormWeight}, + {"2.DenseReluDense.wi.weight", kInFfnWi1Weight}, + {"2.DenseReluDense.wo.weight", kInFfnWoWeight}, + {"2.DenseReluDense.gate_proj.weight", kInFfnWi0Weight}, + {"2.ffn.wi.weight", kInFfnWi1Weight}, + {"2.ffn.wo.weight", kInFfnWoWeight}, + {"2.ffn.gate_proj.weight", kInFfnWi0Weight}, +}; + +static std::unordered_map +get_onerec_decoder_moe_weight_mapping() { + std::unordered_map mapping = + kOneRecDecoderWeightMapping; + + mapping.emplace("layer.2.ffn.gate.weight", kInBlockSparseMoeGateWeight); + mapping.emplace("2.ffn.gate.weight", kInBlockSparseMoeGateWeight); + + mapping.emplace("layer.2.ffn.shared_experts.w1.weight", + kInMlpGateUpWeightSharedExpert); + mapping.emplace("layer.2.ffn.shared_experts.w3.weight", + kInMlpGateUpWeightSharedExpert); + mapping.emplace("layer.2.ffn.shared_experts.w2.weight", + kInMlpDownWeightSharedExpert); + + mapping.emplace("layer.2.ffn.shared_expert.gate.weight", + kInSharedExpertGateWeight); + mapping.emplace("layer.2.ffn.shared_expert.gate.bias", + kInSharedExpertGateBias); + mapping.emplace("layer.2.ffn.shared_expert.gate.weight_scale", + kInSharedExpertGateScale); + mapping.emplace("layer.2.ffn.shared_expert.gate.weight_offset", + kInSharedExpertGateOffset); + + // Expert weights are handled by + // process_expert_weights()/merge_experts_weights to avoid ambiguous suffix + // matching and keep deterministic loading. + + return mapping; +} + +static const std::unordered_map + kOneRecDecoderMoeWeightMapping = get_onerec_decoder_moe_weight_mapping(); + +static const std::unordered_map kOneRecWeightShard = { + {kInQWeight, 0}, + {kInKWeight, 0}, + {kInVWeight, 0}, + {kInSelfAttnOutWeight, 1}, + {kInCrossQWeight, 0}, + {kInCrossKWeight, 0}, + {kInCrossVWeight, 0}, + {kInCrossAttnOutWeight, 1}, + {kInFfnWi0Weight, 0}, + {kInFfnWi1Weight, 0}, + {kInFfnWoWeight, 1}, + // MoE + {kInBlockSparseMoeGateWeight, 0}, + {kInMlpGateUpWeightExpert, 0}, + {kInMlpDownWeightExpert, 1}, + // Shared experts + {kInMlpGateUpWeightSharedExpert, 0}, + {kInMlpGateUpOffsetSharedExpert, 0}, + {kInMlpGateUpScaleSharedExpert, 0}, + {kInMlpDownWeightSharedExpert, 1}, + {kInMlpDownOffsetSharedExpert, 1}, + {kInMlpDownScaleSharedExpert, 1}, + {kInSharedExpertGateWeight, 0}, + {kInSharedExpertGateBias, 0}, + {kInSharedExpertGateScale, 0}, + {kInSharedExpertGateOffset, 0}, +}; + +} // namespace NpuOneRecBlockLayerImpl::NpuOneRecBlockLayerImpl(const ModelContext& context, bool is_decoder, int32_t layer_id) - : device_(context.get_tensor_options().device()), - is_decoder_(is_decoder), - layer_id_(layer_id) {} - -torch::Tensor NpuOneRecBlockLayerImpl::forward(torch::Tensor& hidden_states, - torch::Tensor& attn_mask, - KVCache& kv_cache, - ModelInputParams& input_params, - torch::Tensor* encoder_output, - int32_t node_id, - aclrtEvent* event, - std::atomic* event_flag) { - (void)attn_mask; - (void)kv_cache; - (void)node_id; - (void)event; - (void)event_flag; + : BaseLayer(context), is_decoder_(is_decoder), layer_id_(layer_id) { + const auto& args = context.get_model_args(); + const auto& parallel_args = context.get_parallel_args(); + param_from_args(prefill_param_, args, parallel_args, /*is_prefill=*/true); + param_from_args(decode_param_, args, parallel_args, /*is_prefill=*/false); + + const int32_t weight_count = prefill_param_.use_moe + ? kOneRecMoeWeightCountPerLayer + : kOneRecWeightCountPerLayer; + at_weight_tensors_.resize(weight_count); + atb_weight_tensors_.resize(weight_count); + + placeholder_vec_ = {1, 1}; + dtype_ = c10::typeMetaToScalarType(context.get_tensor_options().dtype()); + device_id_ = context.get_tensor_options().device().index(); + + auto placeholder_tensor = torch::empty({1, 1}, torch::kInt32).to(device_); + placeholder_ = atb_speed::Utils::AtTensor2Tensor(placeholder_tensor); + at_placeholder_ = torch::empty({1, args.hidden_size()}, dtype_).to(device_); + + for (int32_t i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = + torch::zeros({1, args.hidden_size()}).to(context.get_tensor_options()); + } + + if (prefill_param_.use_moe) { + auto device = context.get_tensor_options().device(); + one_hot_ = torch::tensor({1}, torch::kInt32).to(device); + zero_hot_ = torch::tensor({0}, torch::kInt32).to(device); + expert_group_ = torch::tensor({1}, torch::dtype(torch::kInt32)).to(device); + } +} + +void NpuOneRecBlockLayerImpl::param_from_args( + atb_speed::onerec::BlockLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool is_prefill, + const ModelInputParams* input_params) { + (void)input_params; + + param.isFA = false; + param.isPrefill = is_prefill; + param.isBF16 = args.dtype() == "bfloat16"; + param.isPack = true; + param.supportSwiGLU = true; + param.supportLcoc = is_prefill; + param.supportSpeculate = false; + param.enableSplitFuse = FLAGS_enable_chunked_prefill && is_prefill; + param.supportLora = false; + param.loraEnableGMM = false; + param.enableLogN = false; + param.kvQuant = false; + param.enableIntraLayerAddNorm = false; + param.enableInterLayerAddNorm = false; + param.isDecoder = is_decoder_; + param.isOneRecEncoder = !is_decoder_; + param.enableOneRecPrefillOnly = FLAGS_enable_rec_prefill_only; + param.backend = FLAGS_communication_backend; + param.rank = parallel_args.rank(); + param.worldSize = parallel_args.world_size(); + param.quantType = 0; + param.quantGroupSize = 64; + + const int64_t args_n_heads = + is_decoder_ ? args.decoder_n_heads() : args.n_heads(); + const int64_t args_head_dim = + is_decoder_ ? args.decoder_head_dim() : args.head_dim(); + param.numAttentionHeadsPerRank = args_n_heads / param.worldSize; + param.hiddenSizePerAttentionHead = args_head_dim; + + std::optional optional_value = + is_decoder_ ? args.decoder_n_kv_heads().value_or(args.decoder_n_heads()) + : args.n_kv_heads().value_or(args.n_heads()); + param.numKeyValueHeadsPerRank = + static_cast(optional_value.value()) / param.worldSize; + param.rmsNormEps = args.rms_norm_eps(); + + param.seqLen = {}; + param.tokenOffset = {}; + param.packQuantType = {1, 1}; + param.linearQuantType = {0, -1, -1, 0, 0, -1, 0}; + param.layerId = layer_id_; + param.linearTransposeType = {1, 1, 1, 1, 1, 1, 1}; + + if (param.isBF16) { + param.linearDescs = { + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC)}; + } else { + param.linearDescs = { + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC)}; + } + + param.use_moe = args.use_moe() && is_decoder_; + if (param.use_moe) { + ep_size_ = 1; + const int32_t ep_rank = 0; + ep_local_tp_size_ = parallel_args.world_size() / ep_size_; + CHECK_EQ(parallel_args.world_size(), ep_size_ * ep_local_tp_size_); + ep_local_tp_rank_ = parallel_args.rank() % ep_local_tp_size_; + + num_experts_per_partition_ = args.n_routed_experts() / ep_size_; + start_expert_id_ = ep_rank * num_experts_per_partition_; + end_expert_id_ = start_expert_id_ + num_experts_per_partition_ - 1; + + resize_experts_weights(num_experts_per_partition_); + + param.moe_config = std::make_unique(); + param.moe_config->moe_topk = args.num_experts_per_tok(); + param.moe_config->moe_num_experts = args.n_routed_experts(); + param.moe_config->moe_score_func = "softmax"; + param.moe_config->moe_route_scale = args.moe_route_scale(); + param.moe_config->moe_inter_dim = args.moe_intermediate_size(); + param.moe_config->use_bf16 = param.isBF16; + param.moe_config->hasSharedExpertGate = false; + param.moe_config->moe_use_shared_experts = args.moe_use_shared_experts(); + param.moe_config->moe_num_shared_experts = args.n_shared_experts(); + + param.moeLinearQuantType = {atb_speed::common::LinearType::FP, + atb_speed::common::LinearType::FP, + atb_speed::common::LinearType::INVALID, + atb_speed::common::LinearType::FP}; + } +} + +void NpuOneRecBlockLayerImpl::verify_loaded_weights( + const std::string& prefix) const { + const auto& weight_mapping = + [this]() -> const std::unordered_map& { + if (prefill_param_.use_moe) { + return kOneRecDecoderMoeWeightMapping; + } + return is_decoder_ ? kOneRecDecoderWeightMapping + : kOneRecEncoderWeightMapping; + }(); + + // verify_loaded_weights() runs before merge_loaded_weights(). + // Only allow placeholders for tensors that are intentionally absent before + // merge in the current mode. + std::set allowed_placeholders; + if (prefill_param_.use_moe) { + // MoE decoder path does not consume dense FFN gate/up/down tensors. + allowed_placeholders.insert(kInFfnWi0Weight); + allowed_placeholders.insert(kInFfnWi1Weight); + allowed_placeholders.insert(kInFfnWoWeight); + } + const bool has_shared_experts = + prefill_param_.moe_config != nullptr && + prefill_param_.moe_config->moe_use_shared_experts; + + for (const auto& [name, index] : weight_mapping) { + const auto sizes = at_weight_tensors_[index].sizes(); + const bool is_placeholder = (sizes.size() == 2 && sizes[0] == 1); + const bool expected_placeholder = allowed_placeholders.count(index) > 0; + const bool is_relative_bias = (index == kInRelativeAttentionBiasWeight); + const bool is_shared_optional = prefill_param_.use_moe && + !has_shared_experts && + (index == kInMlpGateUpWeightSharedExpert || + index == kInMlpDownWeightSharedExpert || + index == kInSharedExpertGateWeight || + index == kInSharedExpertGateBias || + index == kInSharedExpertGateOffset || + index == kInSharedExpertGateScale); + if (is_placeholder && !expected_placeholder && !is_relative_bias && + !is_shared_optional) { + CHECK(false) << "weight is not loaded for " << prefix << name; + } + } + + if (prefill_param_.use_moe) { + CHECK(validate_decoder_moe_weights(prefix)) + << "OneRec MoE expert weights are incomplete for " << prefix; + } +} + +bool NpuOneRecBlockLayerImpl::validate_decoder_moe_weights( + const std::string& prefix) const { + const auto gate_it = experts_weights_.find("gate_proj.weight"); + const auto up_it = experts_weights_.find("up_proj.weight"); + const auto down_it = experts_weights_.find("down_proj.weight"); + if (gate_it == experts_weights_.end() || up_it == experts_weights_.end() || + down_it == experts_weights_.end()) { + LOG(ERROR) << "Missing OneRec MoE expert tensors in " << prefix + << " (layer " << layer_id_ + << ", gate/up/down map entry not found)."; + return false; + } + + const auto& gate_weights = gate_it->second; + const auto& up_weights = up_it->second; + const auto& down_weights = down_it->second; + + if (gate_weights.size() != up_weights.size() || + gate_weights.size() != down_weights.size()) { + LOG(ERROR) << "OneRec MoE expert vector size mismatch in " << prefix + << ": gate=" << gate_weights.size() + << ", up=" << up_weights.size() + << ", down=" << down_weights.size() << ", layer " << layer_id_; + return false; + } + + for (size_t i = 0; i < gate_weights.size(); ++i) { + const bool gate_defined = gate_weights[i].defined(); + const bool up_defined = up_weights[i].defined(); + const bool down_defined = down_weights[i].defined(); + if (gate_defined != up_defined || gate_defined != down_defined) { + LOG(ERROR) << "OneRec MoE expert tensor mismatch in " << prefix + << " at local expert " << i << ": gate=" << gate_defined + << ", up=" << up_defined << ", down=" << down_defined + << ", layer " << layer_id_; + return false; + } + if (!gate_defined) { + LOG(ERROR) << "Missing OneRec MoE tensor for local expert " << i << " in " + << prefix << " (layer " << layer_id_ << ")."; + return false; + } + } + return true; +} + +void NpuOneRecBlockLayerImpl::merge_loaded_weights() { + const bool q_loaded = !(at_weight_tensors_[kInQWeight].sizes().size() == 2 && + at_weight_tensors_[kInQWeight].sizes()[0] == 1); + const bool k_loaded = !(at_weight_tensors_[kInKWeight].sizes().size() == 2 && + at_weight_tensors_[kInKWeight].sizes()[0] == 1); + const bool v_loaded = !(at_weight_tensors_[kInVWeight].sizes().size() == 2 && + at_weight_tensors_[kInVWeight].sizes()[0] == 1); + CHECK(q_loaded && k_loaded && v_loaded) + << "OneRec QKV weights are not properly loaded."; + + auto new_q_weight = torch::cat({at_weight_tensors_[kInQWeight], + at_weight_tensors_[kInKWeight], + at_weight_tensors_[kInVWeight]}, + 0); + at_weight_tensors_[kInQWeight] = new_q_weight; + at_weight_tensors_[kInKWeight] = + torch::zeros({1, at_weight_tensors_[kInQWeight].size(1)}) + .to(device_) + .to(dtype_); + at_weight_tensors_[kInVWeight] = + torch::zeros({1, at_weight_tensors_[kInQWeight].size(1)}) + .to(device_) + .to(dtype_); + + // Keep decoder cross-attention Q/K/V unpacked for current OneRec ATB + // contract. Do not merge IN_CROSS_{Q,K,V}_WEIGHT here. + + if (!prefill_param_.use_moe) { + const bool wi0_loaded = + !(at_weight_tensors_[kInFfnWi0Weight].sizes().size() == 2 && + at_weight_tensors_[kInFfnWi0Weight].sizes()[0] == 1); + const bool wi1_loaded = + !(at_weight_tensors_[kInFfnWi1Weight].sizes().size() == 2 && + at_weight_tensors_[kInFfnWi1Weight].sizes()[0] == 1); + CHECK(wi0_loaded && wi1_loaded) + << "OneRec FFN gate/up weights are not properly loaded."; + + auto new_gate_up_weight = torch::cat({at_weight_tensors_[kInFfnWi0Weight], + at_weight_tensors_[kInFfnWi1Weight]}, + 0); + at_weight_tensors_[kInFfnWi0Weight] = new_gate_up_weight; + at_weight_tensors_[kInFfnWi1Weight] = + torch::zeros({1, at_weight_tensors_[kInFfnWi0Weight].size(1)}) + .to(device_) + .to(dtype_); + } else { + merge_experts_weights(); + merge_shared_experts_weights(); + } + + const uint64_t weight_count = prefill_param_.use_moe + ? kOneRecMoeWeightCountPerLayer + : kOneRecWeightCountPerLayer; + for (int32_t i = 0; i < static_cast(weight_count); ++i) { + if (!at_weight_tensors_[i].defined()) { + at_weight_tensors_[i] = torch::zeros( + {1, 1}, torch::TensorOptions().device(device_).dtype(dtype_)); + } + if (!at_weight_tensors_[i].is_contiguous()) { + at_weight_tensors_[i] = at_weight_tensors_[i].contiguous(); + } + } + + for (int32_t i = 0; i < static_cast(weight_count); ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + } + + LOG(INFO) << "OneRec BlockLayer merge_loaded_weights calling init_layer" + << ", layer_role=" << (is_decoder_ ? "decoder" : "encoder") + << ", layer_id=" << layer_id_ << ", weight_count=" << weight_count; + const int64_t init_status = init_layer(); + LOG(INFO) << "OneRec BlockLayer merge_loaded_weights init_layer returned" + << ", layer_role=" << (is_decoder_ ? "decoder" : "encoder") + << ", layer_id=" << layer_id_ << ", status=" << init_status; + CHECK_EQ(init_status, atb::NO_ERROR) + << "OneRec BlockLayer init_layer failed, layer_role=" + << (is_decoder_ ? "decoder" : "encoder") << ", layer_id=" << layer_id_; +} + +void NpuOneRecBlockLayerImpl::load_state_dict(const StateDict& state_dict) { + const auto target_weight_dtype = [this]() -> torch::ScalarType { + if (torch_dtype_.empty()) { + return dtype_; + } + if (torch_dtype_ == "float16") { + return torch::kFloat16; + } + if (torch_dtype_ == "bfloat16") { + return torch::kBFloat16; + } + if (torch_dtype_ == "float32") { + return torch::kFloat32; + } + if (torch_dtype_ == "float64") { + return torch::kFloat64; + } + if (torch_dtype_ == "int8") { + return torch::kInt8; + } + if (torch_dtype_ == "int16") { + return torch::kInt16; + } + if (torch_dtype_ == "int32") { + return torch::kInt32; + } + if (torch_dtype_ == "int64") { + return torch::kInt64; + } + if (torch_dtype_ == "uint8") { + return torch::kUInt8; + } + if (torch_dtype_ == "bool") { + return torch::kBool; + } + LOG(FATAL) << "Unsupported OneRec weight dtype " << torch_dtype_ + << ", layer_id=" << layer_id_; + return dtype_; + }; + const auto correct_tensor_dtype = [this, &target_weight_dtype]( + torch::Tensor& tensor, + const std::string& tensor_name) { + if (absl::EndsWith(tensor_name, "deq_scale") && + torch_dtype_ == "bfloat16") { + return; + } + if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 && + tensor.dtype() != torch::kInt64) { + tensor = tensor.to(target_weight_dtype()); + } + }; + const auto load_weight = [this, &state_dict, &correct_tensor_dtype]( + const std::string& tensor_name, + int32_t weight_position, + int32_t shard_dim = -1) { + for (const auto& [name, tensor] : state_dict) { + if (!absl::EndsWith(name, tensor_name)) { + continue; + } + torch::Tensor mutable_tensor = + (shard_dim >= 0 && parallel_args_.world_size() > 1) + ? state_dict.get_sharded_tensor(tensor_name, + shard_dim, + parallel_args_.rank(), + parallel_args_.world_size()) + : tensor; + correct_tensor_dtype(mutable_tensor, tensor_name); + at_weight_tensors_[weight_position] = mutable_tensor.to(device_); + return; + } + }; + + const auto& weight_mapping = + [this]() -> const std::unordered_map& { + if (prefill_param_.use_moe) { + return kOneRecDecoderMoeWeightMapping; + } + return is_decoder_ ? kOneRecDecoderWeightMapping + : kOneRecEncoderWeightMapping; + }(); + + if (prefill_param_.use_moe) { + for (auto& [key, tensors] : experts_weights_) { + (void)key; + for (auto& t : tensors) { + t = torch::Tensor(); + } + } + shared_expert_weights_map_.clear(); + shared_expert_gate_weights_.clear(); + shared_expert_up_weights_.clear(); + shared_expert_down_weights_.clear(); + + for (const auto& [state_key, tensor] : state_dict) { + if (state_key.find(".ffn.experts.") != std::string::npos) { + process_expert_weights(state_dict, state_key, tensor); + } + } + + for (const auto& [state_key, tensor] : state_dict) { + const bool is_shared_expert = + (state_key.find(".ffn.shared_experts.") != std::string::npos || + state_key.find(".ffn.shared_expert.") != std::string::npos); + if (is_shared_expert) { + process_shared_expert_weights(state_dict, state_key, tensor); + } + } + } + + std::vector> ordered_mapping( + weight_mapping.begin(), weight_mapping.end()); + std::sort( + ordered_mapping.begin(), + ordered_mapping.end(), + [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); + for (const auto& [name, index] : ordered_mapping) { + const bool is_relative_bias = (index == kInRelativeAttentionBiasWeight); + bool weight_exists = false; + for (const auto& [state_key, tensor] : state_dict) { + (void)tensor; + if (absl::EndsWith(state_key, name)) { + weight_exists = true; + break; + } + } + if (is_relative_bias && !weight_exists) { + continue; + } + + const auto it = kOneRecWeightShard.find(index); + if (it != kOneRecWeightShard.end()) { + load_weight(name, index, it->second); + } else { + load_weight(name, index); + } + } +} + +int64_t NpuOneRecBlockLayerImpl::init_layer() { + name_ = + is_decoder_ ? "onerec_decoder_block_layer" : "onerec_encoder_block_layer"; + model_name_ = "onerec"; + CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); + if (is_decoder_) { + if (FLAGS_enable_rec_prefill_only) { + LOG(INFO) << "OneRec BlockLayer init_layer skip decode node because " + "enable_rec_prefill_only is enabled" + << ", layer_id=" << layer_id_; + LOG(INFO) << "OneRec BlockLayer init_layer success" + << ", layer_role=" << (is_decoder_ ? "decoder" : "encoder") + << ", layer_id=" << layer_id_ << ", status=" << atb::NO_ERROR; + return atb::NO_ERROR; + } + const int64_t decode_status = init_node(decode_node_, decode_param_); + LOG(INFO) << "OneRec BlockLayer init_layer node returned" + << ", node=decoder-decode" + << ", layer_id=" << layer_id_ << ", status=" << decode_status; + CHECK_OPERATION_STATUS_RETURN(decode_status); + } else { + LOG(INFO) << "OneRec BlockLayer init_layer skip decode node" + << ", layer_role=" << (is_decoder_ ? "decoder" : "encoder") + << ", layer_id=" << layer_id_; + } + return atb::NO_ERROR; +} + +int64_t NpuOneRecBlockLayerImpl::init_attn_mask() { return atb::NO_ERROR; } + +int64_t NpuOneRecBlockLayerImpl::init_node( + atb_speed::Model::Node& node, + atb_speed::onerec::BlockLayerParam& param) { + atb::Operation* operation = nullptr; + atb::Status status = atb_speed::onerec::BlockLayer(param, &operation); + if (status != atb::NO_ERROR) { + LOG(ERROR) << "Failed to create ONEREC BlockLayer operation, status: " + << status; + return status; + } + + node.operation.reset(operation); + if (node.operation == nullptr) { + LOG(ERROR) << "node.operation is null after creation"; + return -1; + } + + uint32_t input_num = node.operation->GetInputNum(); + uint32_t output_num = node.operation->GetOutputNum(); + node.inTensors.resize(input_num); + node.outTensors.resize(output_num); + + const uint64_t weight_count = param.use_moe ? kOneRecMoeWeightCountPerLayer + : kOneRecWeightCountPerLayer; + for (size_t weight_tensor_id = 0; weight_tensor_id < weight_count; + ++weight_tensor_id) { + if (weight_tensor_id < input_num) { + node.inTensors.at(weight_tensor_id) = + &atb_weight_tensors_[weight_tensor_id]; + } + } + + node.variantPack.inTensors.resize(input_num); + node.variantPack.outTensors.resize(output_num); + + return atb::NO_ERROR; +} + +torch::Tensor NpuOneRecBlockLayerImpl::forward( + torch::Tensor& x, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + torch::Tensor* encoder_output, + int32_t node_id, + aclrtEvent* event, + std::atomic* event_flag, + const torch::Tensor& expert_array) { const auto* onerec_params = input_params.onerec_params(); - CHECK(onerec_params != nullptr) << "OneRec block requires onerec_params()."; + CHECK(onerec_params != nullptr) << "OneRec requires rec_params."; const bool is_prefill = onerec_params->rec_stage == OneRecModelInputParams::RecStage::PREFILL; + + atb::Status st; + if (is_prefill) { + if (is_decoder_) { + if (prefill_param_.use_moe) { + build_decoder_moe_node_variant_pack(prefill_node_, + x, + attn_mask, + kv_cache, + input_params, + true, + encoder_output, + node_id, + expert_array); + } else { + build_decoder_node_variant_pack(prefill_node_, + x, + attn_mask, + kv_cache, + input_params, + true, + encoder_output, + node_id); + } + st = execute_node(prefill_node_, node_id, event, event_flag); + LOG_IF(FATAL, st != 0) + << model_name_ << " execute prefill layer fail, error code: " << st; + } else { + build_encoder_node_variant_pack( + prefill_node_, x, attn_mask, input_params, true, node_id); + st = execute_node(prefill_node_, node_id, event, event_flag); + LOG_IF(FATAL, st != 0) + << model_name_ + << " execute encoder prefill layer fail, error code: " << st; + } + } else { + if (!is_decoder_) { + LOG(FATAL) << model_name_ << " encoder decode stage is not supported."; + } + + if (decode_param_.use_moe) { + build_decoder_moe_node_variant_pack(decode_node_, + x, + attn_mask, + kv_cache, + input_params, + false, + encoder_output, + node_id, + expert_array); + } else { + build_decoder_node_variant_pack(decode_node_, + x, + attn_mask, + kv_cache, + input_params, + false, + encoder_output, + node_id); + } + st = execute_node(decode_node_, node_id + 1000, event, event_flag); + LOG_IF(FATAL, st != 0) << model_name_ + << " execute decode layer fail, error code: " << st; + } + + return at_placeholder_; +} + +void NpuOneRecBlockLayerImpl::build_encoder_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + bool is_prefill, + int32_t layer_id) { + (void)is_prefill; + (void)layer_id; + + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); + + for (size_t i = 0; i < kOneRecWeightCountPerLayer; ++i) { + CHECK(node.inTensors.at(i) != nullptr) + << model_name_ << " inTensor " << i << " is NULL"; + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + const int32_t input_tensor_idx = + static_cast(kOneRecWeightCountPerLayer); + const int32_t attention_mask_idx = input_tensor_idx + 1; + const int32_t token_offset_idx = attention_mask_idx + 1; + const int32_t layer_id_idx = token_offset_idx + 1; + const int32_t seq_len_idx = layer_id_idx + 1; + + node.variantPack.inTensors.at(input_tensor_idx) = internal_tensors_; + node.variantPack.inTensors.at(attention_mask_idx) = + atb_speed::Utils::AtTensor2Tensor(attn_mask); + + node.variantPack.inTensors.at(token_offset_idx) = placeholder_; + node.variantPack.inTensors.at(token_offset_idx).hostData = + placeholder_vec_.data(); + node.variantPack.inTensors.at(layer_id_idx) = placeholder_; + node.variantPack.inTensors.at(layer_id_idx).hostData = + placeholder_vec_.data(); + + const auto* onerec_params = input_params.onerec_params(); + if (onerec_params != nullptr && + onerec_params->encoder_seq_lens_tensor.defined()) { + node.variantPack.inTensors.at(seq_len_idx) = + atb_speed::Utils::AtTensor2Tensor( + onerec_params->encoder_seq_lens_tensor); + node.variantPack.inTensors.at(seq_len_idx).hostData = + const_cast(onerec_params->encoder_seq_lens.data()); + } else { + node.variantPack.inTensors.at(seq_len_idx) = placeholder_; + node.variantPack.inTensors.at(seq_len_idx).hostData = + placeholder_vec_.data(); + } + + node.variantPack.outTensors.at(0) = internal_tensors_; +} + +void NpuOneRecBlockLayerImpl::build_decoder_moe_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output, + int32_t layer_id, + const torch::Tensor& expert_array) { + (void)kv_cache; + (void)is_prefill; + (void)layer_id; + + for (size_t i = 0; i < kOneRecMoeWeightCountPerLayer; ++i) { + CHECK(node.inTensors.at(i) != nullptr) + << model_name_ << " inTensor " << i << " is NULL"; + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + const int32_t moe_tensor_start = + static_cast(kOneRecMoeWeightCountPerLayer); + if (expert_array.defined()) { + node.variantPack.inTensors.at(moe_tensor_start) = + atb_speed::Utils::AtTensor2Tensor(expert_array); + } else { + node.variantPack.inTensors.at(moe_tensor_start) = placeholder_; + } + + node.variantPack.inTensors.at(moe_tensor_start + 1) = + expert_group_.defined() ? atb_speed::Utils::AtTensor2Tensor(expert_group_) + : placeholder_; + node.variantPack.inTensors.at(moe_tensor_start + 2) = + one_hot_.defined() ? atb_speed::Utils::AtTensor2Tensor(one_hot_) + : placeholder_; + node.variantPack.inTensors.at(moe_tensor_start + 3) = + zero_hot_.defined() ? atb_speed::Utils::AtTensor2Tensor(zero_hot_) + : placeholder_; + + int32_t tensor_idx = setup_common_decoder_tensors( + node, x, attn_mask, input_params, encoder_output, moe_tensor_start + 4); + + while (tensor_idx < static_cast(node.variantPack.inTensors.size())) { + node.variantPack.inTensors.at(tensor_idx) = placeholder_; + node.variantPack.inTensors.at(tensor_idx).hostData = + placeholder_vec_.data(); + ++tensor_idx; + } +} + +int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + torch::Tensor* encoder_output, + int32_t start_tensor_idx) { + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); + + int32_t idx = start_tensor_idx; + node.variantPack.inTensors.at(idx++) = internal_tensors_; + node.variantPack.inTensors.at(idx++) = + atb_speed::Utils::AtTensor2Tensor(attn_mask); + + // Token offset and layer id placeholders. + // ATB expects hostData to be valid for these scalar inputs. Keep them as + // placeholders but always provide hostData to avoid undefined reads. + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + + CHECK(input_params.kv_seq_lens.defined()) << "kv_seq_lens is required."; + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); + node.variantPack.inTensors.at(idx).hostData = + input_params.kv_seq_lens_vec.data(); + idx++; + + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + + if (!FLAGS_enable_rec_prefill_only && input_params.block_tables.defined()) { + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(input_params.block_tables); + } else { + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data(); + } + idx++; + + if (!FLAGS_enable_rec_prefill_only && + input_params.new_cache_slots.defined()) { + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); + } else { + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data(); + } + idx++; + + if (encoder_output != nullptr) { + encoder_output_contiguous_ = encoder_output->is_contiguous() + ? *encoder_output + : encoder_output->contiguous(); + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(encoder_output_contiguous_); + } else { + node.variantPack.inTensors.at(idx) = placeholder_; + } + idx++; + + for (int32_t i = 0; i < 3; i++) { + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + } + + const auto* onerec_params = input_params.onerec_params(); + if (onerec_params != nullptr && + onerec_params->encoder_seq_lens_tensor.defined()) { + node.variantPack.inTensors.at(idx) = atb_speed::Utils::AtTensor2Tensor( + onerec_params->encoder_seq_lens_tensor); + node.variantPack.inTensors.at(idx++).hostData = + const_cast(onerec_params->encoder_seq_lens.data()); + } else { + node.variantPack.inTensors.at(idx) = placeholder_; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + } + + node.variantPack.outTensors.at(0) = internal_tensors_; + return idx; +} + +void NpuOneRecBlockLayerImpl::build_decoder_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output, + int32_t layer_id) { + (void)kv_cache; (void)is_prefill; + (void)layer_id; + + for (size_t i = 0; i < kOneRecWeightCountPerLayer; ++i) { + CHECK(node.inTensors.at(i) != nullptr) + << model_name_ << " inTensor " << i << " is NULL"; + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + int32_t tensor_idx = setup_common_decoder_tensors( + node, + x, + attn_mask, + input_params, + encoder_output, + static_cast(kOneRecWeightCountPerLayer)); + while (tensor_idx < static_cast(node.variantPack.inTensors.size())) { + node.variantPack.inTensors.at(tensor_idx) = placeholder_; + node.variantPack.inTensors.at(tensor_idx).hostData = + placeholder_vec_.data(); + ++tensor_idx; + } +} + +void NpuOneRecBlockLayerImpl::resize_experts_weights( + int32_t num_of_device_experts) { + experts_weights_["gate_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight"] = + std::vector(num_of_device_experts); +} + +void NpuOneRecBlockLayerImpl::process_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + (void)state_dict; + std::lock_guard lock(experts_mutex_); + + int32_t expert_id = extract_expert_index(name); + if (expert_id < 0) { + return; + } + + const int32_t local_index = expert_id % num_experts_per_partition_; + std::string weight_suffix = extract_endswith(name); + + std::string suffix; + if (weight_suffix == "gate_proj.weight" || weight_suffix == "w1.weight") { + suffix = "gate_proj.weight"; + } else if (weight_suffix == "up_proj.weight" || + weight_suffix == "w3.weight") { + suffix = "up_proj.weight"; + } else if (weight_suffix == "down_proj.weight" || + weight_suffix == "w2.weight") { + suffix = "down_proj.weight"; + } else { + return; + } + + auto it = experts_weights_.find(suffix); + if (it == experts_weights_.end() || local_index < 0 || + local_index >= static_cast(it->second.size())) { + LOG(ERROR) << "Invalid OneRec MoE local expert index " << local_index + << " for " << suffix << " at layer " << layer_id_ << "."; + return; + } + it->second[local_index] = tensor; +} + +void NpuOneRecBlockLayerImpl::process_shared_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + (void)state_dict; + torch::Tensor tmp_tensor = tensor.to(device_); + + std::string canonical_name; + if (absl::StrContains(name, "gate_proj") || absl::StrContains(name, "w1")) { + canonical_name = "gate_proj.weight"; + } else if (absl::StrContains(name, "up_proj") || + absl::StrContains(name, "w3")) { + canonical_name = "up_proj.weight"; + } else if (absl::StrContains(name, "down_proj") || + absl::StrContains(name, "w2")) { + canonical_name = "down_proj.weight"; + } else { + return; + } + + if (shared_expert_weights_map_.count(canonical_name) > 0) { + LOG(WARNING) << "Duplicate OneRec shared expert tensor for " + << canonical_name << " at layer " << layer_id_ + << ", overriding previous value."; + } + shared_expert_weights_map_[canonical_name] = tmp_tensor; +} + +int32_t NpuOneRecBlockLayerImpl::extract_expert_index(const std::string& name) { + size_t experts_pos = name.find(".experts."); + if (experts_pos == std::string::npos) { + return -1; + } + + size_t start_pos = experts_pos + 9; + size_t end_pos = name.find(".", start_pos); + if (end_pos == std::string::npos) { + return -1; + } + + try { + return std::stoi(name.substr(start_pos, end_pos - start_pos)); + } catch (const std::exception&) { + return -1; + } +} + +std::string NpuOneRecBlockLayerImpl::extract_endswith( + const std::string& input) { + size_t experts_pos = input.find(".experts."); + if (experts_pos == std::string::npos) { + return ""; + } + size_t start_pos = experts_pos + 9; + size_t next_dot = input.find(".", start_pos); + if (next_dot == std::string::npos) { + return ""; + } + return input.substr(next_dot + 1); +} + +torch::Tensor NpuOneRecBlockLayerImpl::merge_experts_weights( + std::vector& experts, + bool transpose) { + std::vector valid; + valid.reserve(experts.size()); + for (auto& t : experts) { + if (t.defined()) { + valid.push_back(t.to(device_)); + } + } + if (valid.empty()) { + LOG(ERROR) << "No expert weights to merge at layer " << layer_id_ << "."; + return torch::Tensor(); + } + torch::Tensor merged_tensor = torch::stack(valid, 0); + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + return merged_tensor.contiguous(); +} + +torch::Tensor NpuOneRecBlockLayerImpl::merge_experts_weights( + std::vector& experts_gate, + std::vector& experts_up, + bool transpose) { + if (experts_gate.size() != experts_up.size()) { + LOG(ERROR) << "OneRec MoE gate/up expert size mismatch: gate=" + << experts_gate.size() << ", up=" << experts_up.size() + << ", layer " << layer_id_; + return torch::Tensor(); + } + for (size_t i = 0; i < experts_gate.size(); ++i) { + const bool gate_defined = experts_gate[i].defined(); + const bool up_defined = experts_up[i].defined(); + if (gate_defined != up_defined) { + LOG(ERROR) << "OneRec MoE gate/up tensor mismatch at local expert " << i + << ": gate=" << gate_defined << ", up=" << up_defined + << ", layer " << layer_id_; + return torch::Tensor(); + } + if (gate_defined) { + experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); + } + } + return merge_experts_weights(experts_gate, transpose); +} + +void NpuOneRecBlockLayerImpl::merge_experts_weights() { + if (experts_weights_.count("gate_proj.weight") == 0 || + experts_weights_.count("up_proj.weight") == 0 || + experts_weights_.count("down_proj.weight") == 0) { + return; + } + + auto merged_gate_up = + merge_experts_weights(experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + /*transpose=*/false); + CHECK(merged_gate_up.defined()) << "OneRec MoE gate/up experts merge failed."; + at_weight_tensors_[kInMoeExpertW1Weight] = + at_npu::native::npu_format_cast(merged_gate_up, /*format=*/2) + .contiguous(); + + auto merged_down = merge_experts_weights(experts_weights_["down_proj.weight"], + /*transpose=*/false); + CHECK(merged_down.defined()) << "OneRec MoE down experts merge failed."; + at_weight_tensors_[kInMoeExpertW2Weight] = + at_npu::native::npu_format_cast(merged_down, /*format=*/2).contiguous(); +} + +void NpuOneRecBlockLayerImpl::merge_shared_experts_weights() { + shared_expert_gate_weights_.clear(); + shared_expert_up_weights_.clear(); + shared_expert_down_weights_.clear(); + + if (const auto it = shared_expert_weights_map_.find("gate_proj.weight"); + it != shared_expert_weights_map_.end()) { + shared_expert_gate_weights_.push_back(it->second); + } + if (const auto it = shared_expert_weights_map_.find("up_proj.weight"); + it != shared_expert_weights_map_.end()) { + shared_expert_up_weights_.push_back(it->second); + } + if (const auto it = shared_expert_weights_map_.find("down_proj.weight"); + it != shared_expert_weights_map_.end()) { + shared_expert_down_weights_.push_back(it->second); + } - if (encoder_output != nullptr && encoder_output->defined() && - encoder_output->device() != device_) { - *encoder_output = encoder_output->to(device_); + if (shared_expert_gate_weights_.empty() && + shared_expert_up_weights_.empty() && + shared_expert_down_weights_.empty()) { + return; } - if (hidden_states.device() != device_) { - hidden_states = hidden_states.to(device_); + if (!shared_expert_gate_weights_.empty() && + !shared_expert_up_weights_.empty()) { + auto merged_gate_up = merge_experts_weights(shared_expert_gate_weights_, + shared_expert_up_weights_, + /*transpose=*/false); + CHECK(merged_gate_up.defined()) + << "OneRec shared gate/up experts merge failed at layer " << layer_id_; + at_weight_tensors_[kInMlpGateUpWeightSharedExpert] = merged_gate_up; + } else if (!shared_expert_gate_weights_.empty()) { + at_weight_tensors_[kInMlpGateUpWeightSharedExpert] = + merge_experts_weights(shared_expert_gate_weights_, false); } - if (!is_decoder_ && hidden_states.dim() > 1) { - return hidden_states.contiguous(); + if (!shared_expert_down_weights_.empty()) { + at_weight_tensors_[kInMlpDownWeightSharedExpert] = + merge_experts_weights(shared_expert_down_weights_, false); } - return hidden_states; + shared_expert_gate_weights_.clear(); + shared_expert_up_weights_.clear(); + shared_expert_down_weights_.clear(); + shared_expert_weights_map_.clear(); } } // namespace layer diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h index 3c4addacd..ab21afe62 100644 --- a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h +++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h @@ -17,34 +17,153 @@ limitations under the License. #include +#include #include +#include +#include +#include +#include +#include "framework/model/model_input_params.h" #include "framework/model_context.h" -#include "layers/onerec_block_layer.h" +#include "framework/state_dict/state_dict.h" +#include "npu_base_layer.h" +#include "xllm_atb_layers/core/include/atb_speed/base/hosttensor_binder.h" +#include "xllm_atb_layers/core/include/atb_speed/base/model.h" +#include "xllm_atb_layers/core/include/atb_speed/log.h" +#include "xllm_atb_layers/core/include/atb_speed/utils/model_factory.h" +#include "xllm_atb_layers/models/onerec/layer/block_layer.h" +#include "xllm_atb_layers/operations/fusion/utils.h" namespace xllm { namespace layer { -class NpuOneRecBlockLayerImpl final : public OneRecBlockLayer { +class NpuOneRecBlockLayerImpl final : public BaseLayer { public: explicit NpuOneRecBlockLayerImpl(const ModelContext& context, bool is_decoder = false, int32_t layer_id = 0); - torch::Tensor forward(torch::Tensor& hidden_states, + ~NpuOneRecBlockLayerImpl() override = default; + + void load_state_dict(const StateDict& state_dict) override; + + void verify_loaded_weights(const std::string& prefix) const; + + void merge_loaded_weights() override; + + int64_t init_layer() override; + + torch::Tensor forward(torch::Tensor& x, torch::Tensor& attn_mask, KVCache& kv_cache, ModelInputParams& input_params, torch::Tensor* encoder_output = nullptr, int32_t node_id = 0, aclrtEvent* event = nullptr, - std::atomic* event_flag = nullptr) override; + std::atomic* event_flag = nullptr, + const torch::Tensor& expert_array = torch::Tensor()); private: - const torch::Device device_; + void param_from_args(atb_speed::onerec::BlockLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool is_prefill, + const ModelInputParams* input_params = nullptr); + + void build_encoder_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + bool is_prefill, + int32_t layer_id = 0); + + void build_decoder_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output = nullptr, + int32_t layer_id = 0); + + void build_decoder_moe_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output = nullptr, + int32_t layer_id = 0, + const torch::Tensor& expert_array = torch::Tensor()); + + int64_t init_node(atb_speed::Model::Node& node, + atb_speed::onerec::BlockLayerParam& param); + + int64_t init_attn_mask(); + + int32_t setup_common_decoder_tensors(atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + torch::Tensor* encoder_output = nullptr, + int32_t start_tensor_idx = 0); + + void resize_experts_weights(int32_t num_of_device_experts); + void process_expert_weights(const StateDict& state_dict, + const std::string& state_key, + const torch::Tensor& tensor); + void process_shared_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + void merge_experts_weights(); + void merge_shared_experts_weights(); + bool validate_decoder_moe_weights(const std::string& prefix) const; + torch::Tensor merge_experts_weights(std::vector& experts, + bool transpose = false); + torch::Tensor merge_experts_weights(std::vector& experts_gate, + std::vector& experts_up, + bool transpose = false); + int32_t extract_expert_index(const std::string& name); + std::string extract_endswith(const std::string& input); + + atb_speed::Model::Node prefill_node_; + atb_speed::Model::Node decode_node_; + std::string model_name_; + atb_speed::onerec::BlockLayerParam prefill_param_; + atb_speed::onerec::BlockLayerParam decode_param_; + + atb::Tensor internal_tensors_; + atb::Tensor placeholder_; + + at::Tensor encoder_output_contiguous_; + at::Tensor at_placeholder_; + std::vector placeholder_vec_; + + int32_t device_id_ = 0; bool is_decoder_ = false; int32_t layer_id_ = 0; + + std::unordered_map> experts_weights_; + std::mutex experts_mutex_; + int32_t start_expert_id_ = 0; + int32_t end_expert_id_ = 0; + int32_t num_experts_per_partition_ = 0; + int32_t ep_size_ = 1; + int32_t ep_local_tp_rank_ = 0; + int32_t ep_local_tp_size_ = 1; + + std::vector shared_expert_gate_weights_; + std::vector shared_expert_up_weights_; + std::vector shared_expert_down_weights_; + std::unordered_map shared_expert_weights_map_; + + torch::Tensor expert_group_; + torch::Tensor one_hot_; + torch::Tensor zero_hot_; }; +TORCH_MODULE(NpuOneRecBlockLayer); } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index adc146b08..d0c6e84e8 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -21,7 +21,7 @@ limitations under the License. #include #include "common/global_flags.h" -#include "common/rec_model_utils.h" +#include "util/rec_model_utils.h" // #include "attn_mask.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index 6936f84ae..fb443c4c8 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -329,37 +329,37 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( const ModelInputParams& input_params) { auto [qkvz_padded, ba_padded] = project_padded_inputs(hidden_states, attn_metadata); + int64_t batch_size = qkvz_padded.size(0); + int64_t seq_len = qkvz_padded.size(1); + + torch::Tensor qkvz_flat = + qkvz_padded.view({batch_size * seq_len, qkvz_padded.size(-1)}); + torch::Tensor ba_flat = + ba_padded.view({batch_size * seq_len, ba_padded.size(-1)}); + xllm::kernel::FusedQkvzbaSplitReshapeParams fused_params; + fused_params.mixed_qkvz = qkvz_flat; + fused_params.mixed_ba = ba_flat; + fused_params.num_heads_qk = static_cast(num_k_heads_ / tp_size_); + fused_params.num_heads_v = static_cast(num_v_heads_ / tp_size_); + fused_params.head_qk = static_cast(head_k_dim_); + fused_params.head_v = static_cast(head_v_dim_); + + torch::Tensor mixed_qkv, z, b, a; + std::tie(mixed_qkv, z, b, a) = + xllm::kernel::fused_qkvzba_split_reshape_cat(fused_params); + + mixed_qkv = mixed_qkv.view({batch_size, seq_len, mixed_qkv.size(-1)}); + z = z.view({batch_size, seq_len, num_v_heads_ / tp_size_, head_v_dim_}); + b = b.view({batch_size, seq_len, num_v_heads_ / tp_size_}); + a = a.view({batch_size, seq_len, num_v_heads_ / tp_size_}); - torch::Tensor q, k, v, z, b, a; - std::tie(q, k, v, z) = process_qkvz_tensor(qkvz_padded); - std::tie(b, a) = process_ba_tensor(ba_padded); - - auto rearrange_merge = [](const torch::Tensor& t) { - TORCH_CHECK( - t.dim() > 2, "Tensor must have at least 2 dims! but got ", t.dim()); - std::vector new_shape; - int64_t slice_end = t.dim() - 2; - auto valid_slice = t.sizes().slice(0, slice_end); - new_shape = std::vector(valid_slice.begin(), valid_slice.end()); - int64_t last_two_dim = t.size(slice_end) * t.size(slice_end + 1); - new_shape.push_back(last_two_dim); - return t.reshape(new_shape); - }; - - q = rearrange_merge(q); - k = rearrange_merge(k); - v = rearrange_merge(v); - - // Run the causal conv update on the mixed QKV states. - torch::Tensor mixed_qkv = torch::cat({q, k, v}, q.dim() - 1); - mixed_qkv = mixed_qkv.transpose(1, 2); - int64_t seq_len = mixed_qkv.size(2); torch::Tensor conv_cache = kv_cache.get_conv_cache(); torch::Tensor ssm_cache = kv_cache.get_ssm_cache(); torch::Tensor g, beta, core_attn_out, last_recurrent_state; auto device = mixed_qkv.device(); auto conv_weight = conv1d_->weight(); + mixed_qkv = mixed_qkv.transpose(1, 2); if (attn_metadata.is_prefill) { torch::Tensor conv_state = (seq_len < conv_kernel_size_ - 1) @@ -382,12 +382,12 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( mixed_qkv = torch::silu(conv_output.slice(2, 0, seq_len)); } else { - const auto state_indices = attn_metadata.block_table.select(1, 0); xllm::kernel::CausalConv1dUpdateParams params; params.x = mixed_qkv; params.conv_state = conv_cache; params.weight = conv_weight; - params.conv_state_indices = state_indices; + params.conv_state_indices = + attn_metadata.block_table.select(1, 0).contiguous(); mixed_qkv = xllm::kernel::causal_conv1d_update(params); } @@ -444,7 +444,8 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( norm_out = norm_out.view({-1, norm_out.size(2), norm_out.size(3)}); // Project the normalized attention output back to hidden size. - auto rearranged_norm = rearrange_merge(norm_out); + auto rearranged_norm = + norm_out.reshape({norm_out.size(0), norm_out.size(1) * norm_out.size(2)}); rearranged_norm = reshape_qkvz_unpad(attn_metadata, rearranged_norm); auto attn_output = o_proj_->forward(rearranged_norm); return attn_output; @@ -519,67 +520,5 @@ Qwen3GatedDeltaNetBaseImpl::process_mixed_qkv(torch::Tensor& mixed_qkv) const { return std::make_tuple(processed_q, processed_k, processed_v); } -std::tuple -Qwen3GatedDeltaNetBaseImpl::process_qkvz_tensor( - const torch::Tensor& qkvz) const { - std::vector new_tensor_shape_qkvz = [&]() { - std::vector dims; - dims.push_back(qkvz.size(0)); - dims.push_back(qkvz.size(1)); - int64_t dim1 = num_k_heads_ / tp_size_; - int64_t dim2 = head_k_dim_ + head_k_dim_ + - (head_v_dim_ + head_v_dim_) * num_v_heads_ / num_k_heads_; - dims.push_back(dim1); - dims.push_back(dim2); - return dims; - }(); - - auto reshaped_qkvz = qkvz.view(new_tensor_shape_qkvz); - auto qkvz_split = torch::split(reshaped_qkvz, - {head_k_dim_, - head_k_dim_, - num_v_heads_ / num_k_heads_ * head_v_dim_, - num_v_heads_ / num_k_heads_ * head_v_dim_}, - reshaped_qkvz.dim() - 1); - - auto q = qkvz_split[0].contiguous(); - auto k = qkvz_split[1].contiguous(); - auto v = qkvz_split[2].contiguous(); - auto z = qkvz_split[3].contiguous(); - - v = v.reshape({v.size(0), v.size(1), num_v_heads_ / tp_size_, head_v_dim_}); - z = z.reshape({z.size(0), z.size(1), num_v_heads_ / tp_size_, head_v_dim_}); - - return std::make_tuple(q, k, v, z); -} - -std::tuple -Qwen3GatedDeltaNetBaseImpl::process_ba_tensor(const torch::Tensor& ba) const { - std::vector new_tensor_shape_ba = [&]() { - std::vector dims; - dims.push_back(ba.size(0)); - dims.push_back(ba.size(1)); - int64_t dim1 = num_k_heads_ / tp_size_; - int64_t dim2 = 2 * num_v_heads_ / num_k_heads_; - dims.push_back(dim1); - dims.push_back(dim2); - return dims; - }(); - - auto reshaped_ba = ba.view(new_tensor_shape_ba); - auto ba_split = - torch::split(reshaped_ba, - {num_v_heads_ / num_k_heads_, num_v_heads_ / num_k_heads_}, - reshaped_ba.dim() - 1); - - auto b = ba_split[0].contiguous(); - auto a = ba_split[1].contiguous(); - - b = b.reshape({b.size(0), b.size(1), num_v_heads_ / tp_size_}); - a = a.reshape({a.size(0), a.size(1), num_v_heads_ / tp_size_}); - - return std::make_tuple(b, a); -} - } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h index 2fa489fc7..dc011e684 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h @@ -63,10 +63,6 @@ class Qwen3GatedDeltaNetBaseImpl : public torch::nn::Module { torch::Tensor reshape_qkvz_unpad(const AttentionMetadata& attn_metadata, const torch::Tensor& padded_qkvz) const; - std::tuple - process_qkvz_tensor(const torch::Tensor& qkvz) const; - std::tuple process_ba_tensor( - const torch::Tensor& ba) const; std::tuple process_mixed_qkv( torch::Tensor& mixed_qkv) const; diff --git a/xllm/core/layers/oxygen_vision_layer.cpp b/xllm/core/layers/oxygen_vision_layer.cpp new file mode 100644 index 000000000..09d62ba4a --- /dev/null +++ b/xllm/core/layers/oxygen_vision_layer.cpp @@ -0,0 +1,72 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "oxygen_vision_layer.h" + +namespace xllm { +namespace layer { + +OxygenVisionLayerImpl::OxygenVisionLayerImpl(const ModelContext& context) { + const auto& args = context.get_model_args(); + const auto& quant_config = context.get_quant_args(); + const auto& parallel_args = context.get_parallel_args(); + const auto& options = context.get_tensor_options(); + int64_t dim = args.mm_hidden_size(); + int64_t mlp_intermediate_size = args.mm_intermediate_size(); + attention_ = register_module("self_attn", OxygenVisionAttention(context)); + norm1_ = register_module("norm1", RMSNorm(dim, args.rms_norm_eps(), options)); + norm2_ = register_module("norm2", RMSNorm(dim, args.rms_norm_eps(), options)); + + mlp_ = register_module("mlp", + DenseMLP(dim, + args.mm_intermediate_size(), + /*is_gated=*/true, + /*has_bias=*/false, + args.mm_hidden_act(), + /*enable_result_reduction=*/true, + quant_config, + parallel_args.tp_group_, + options)); +} + +void OxygenVisionLayerImpl::load_state_dict(const StateDict& state_dict) { + attention_->load_state_dict(state_dict.get_dict_with_prefix("attn.")); + mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp.")); + norm1_->load_state_dict(state_dict.get_dict_with_prefix("norm1.")); + norm2_->load_state_dict(state_dict.get_dict_with_prefix("norm2.")); +} + +torch::Tensor OxygenVisionLayerImpl::forward( + torch::Tensor& hidden_states, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id) { + auto norm_output1 = std::get<0>(norm1_(hidden_states)); + auto output = hidden_states + attention_(norm_output1, + m_cos_pos, + m_sin_pos, + cu_seq_len, + cu_seq_len_vec, + input_params); + auto norm_output2 = std::get<0>(norm2_(output)); + output = output + mlp_(norm_output2); + return output; +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/oxygen_vision_layer.h b/xllm/core/layers/oxygen_vision_layer.h new file mode 100644 index 000000000..4b81a2718 --- /dev/null +++ b/xllm/core/layers/oxygen_vision_layer.h @@ -0,0 +1,56 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "common/dense_mlp.h" +#include "common/oxygen_vision_attention.h" +#include "common/rms_norm.h" +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/model_context.h" +#include "framework/state_dict/state_dict.h" + +namespace xllm { +namespace layer { + +class OxygenVisionLayerImpl : public torch::nn::Module { + public: + OxygenVisionLayerImpl(const ModelContext& context); + + void load_state_dict(const StateDict& state_dict); + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id); + + private: + OxygenVisionAttention attention_{nullptr}; + DenseMLP mlp_{nullptr}; + RMSNorm norm1_{nullptr}; + RMSNorm norm2_{nullptr}; +}; +TORCH_MODULE(OxygenVisionLayer); + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/qwen2_decoder_layer.cpp b/xllm/core/layers/qwen2_decoder_layer.cpp index 5260a9bf2..db5f8395a 100644 --- a/xllm/core/layers/qwen2_decoder_layer.cpp +++ b/xllm/core/layers/qwen2_decoder_layer.cpp @@ -18,11 +18,14 @@ limitations under the License. namespace xllm { namespace layer { -Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context) +Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context, + int32_t layer_id) : parallel_args_(context.get_parallel_args()) { const auto& model_args = context.get_model_args(); const auto& quant_args = context.get_quant_args(); const auto& options = context.get_tensor_options(); + const std::string mlp_module_prefix = + layer_id >= 0 ? "model.layers." + std::to_string(layer_id) + ".mlp" : ""; // Initialize attention layers attention_ = register_module("self_attn", Qwen2Attention(context)); @@ -46,7 +49,8 @@ Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context) /*enable_result_reduction=*/true, quant_args, parallel_args_.tp_group_, - options)); + options, + mlp_module_prefix)); } void Qwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { diff --git a/xllm/core/layers/qwen2_decoder_layer.h b/xllm/core/layers/qwen2_decoder_layer.h index 86892e945..19ed4b601 100644 --- a/xllm/core/layers/qwen2_decoder_layer.h +++ b/xllm/core/layers/qwen2_decoder_layer.h @@ -35,7 +35,8 @@ namespace layer { class Qwen2DecoderLayerImpl : public torch::nn::Module { public: - explicit Qwen2DecoderLayerImpl(const ModelContext& context); + explicit Qwen2DecoderLayerImpl(const ModelContext& context, + int32_t layer_id = -1); void load_state_dict(const StateDict& state_dict); diff --git a/xllm/core/layers/qwen3_moe_decoder_layer.cpp b/xllm/core/layers/qwen3_moe_decoder_layer.cpp index 3c370828d..6fa3f1aef 100644 --- a/xllm/core/layers/qwen3_moe_decoder_layer.cpp +++ b/xllm/core/layers/qwen3_moe_decoder_layer.cpp @@ -80,6 +80,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, parallel_args_, options)); } else { + const std::string mlp_module_prefix = + "model.layers." + std::to_string(layer_id) + ".mlp"; mlp_ = register_module("mlp", DenseMLP(model_args.hidden_size(), model_args.intermediate_size(), @@ -89,7 +91,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, /*enable_result_reduction=*/true, quant_args, parallel_args_.tp_group_, - options)); + options, + mlp_module_prefix)); } } diff --git a/xllm/core/platform/CMakeLists.txt b/xllm/core/platform/CMakeLists.txt index b67240496..7b7465516 100644 --- a/xllm/core/platform/CMakeLists.txt +++ b/xllm/core/platform/CMakeLists.txt @@ -14,6 +14,7 @@ cc_library( vmm_api.h shared_vmm_allocator.h vmm_torch_allocator.h + $<$:mlu/mlu_layer_synchronizer.h> $<$:cuda/cuda_utils.h> $<$:numa_utils.h> SRCS @@ -21,6 +22,7 @@ cc_library( device.cpp vmm_api.cpp shared_vmm_allocator.cpp + $<$:mlu/mlu_layer_synchronizer.cpp> $<$:numa_utils.cpp> DEPS torch @@ -51,4 +53,3 @@ cc_test( glog::glog ) endif() - diff --git a/xllm/core/platform/mlu/mlu_layer_synchronizer.cpp b/xllm/core/platform/mlu/mlu_layer_synchronizer.cpp new file mode 100644 index 000000000..b23ec6a08 --- /dev/null +++ b/xllm/core/platform/mlu/mlu_layer_synchronizer.cpp @@ -0,0 +1,49 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "platform/mlu/mlu_layer_synchronizer.h" + +#include + +#include + +namespace xllm { + +MLULayerSynchronizerImpl::MLULayerSynchronizerImpl(int64_t num_layers) + : events_(), event_record_flags_(static_cast(num_layers)) { + events_.reserve(static_cast(num_layers)); + for (int64_t i = 0; i < num_layers; ++i) { + events_.emplace_back(c10::DeviceType::PrivateUse1); + } +} + +bool MLULayerSynchronizerImpl::synchronize_layer(int64_t layer_index) { + while (!event_record_flags_[layer_index].load(std::memory_order_acquire)) { + std::this_thread::yield(); + } + events_[layer_index].synchronize(); + return true; +} + +bool MLULayerSynchronizerImpl::record_current(int64_t layer_index, + int32_t device_index) { + c10::Stream current_stream = + torch_mlu::getCurrentMLUStream(device_index).unwrap(); + events_[layer_index].record(current_stream); + event_record_flags_[layer_index].store(true, std::memory_order_release); + return true; +} + +} // namespace xllm diff --git a/xllm/core/platform/mlu/mlu_layer_synchronizer.h b/xllm/core/platform/mlu/mlu_layer_synchronizer.h new file mode 100644 index 000000000..912c91ea8 --- /dev/null +++ b/xllm/core/platform/mlu/mlu_layer_synchronizer.h @@ -0,0 +1,40 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include +#include + +namespace xllm { + +class MLULayerSynchronizerImpl final { + public: + explicit MLULayerSynchronizerImpl(int64_t num_layers); + ~MLULayerSynchronizerImpl() = default; + + bool synchronize_layer(int64_t layer_index); + bool record_current(int64_t layer_index, int32_t device_index); + uint32_t get_event_size() const { return events_.size(); } + + private: + std::vector events_; + std::vector> event_record_flags_; +}; + +} // namespace xllm diff --git a/xllm/core/platform/stream.h b/xllm/core/platform/stream.h index c4dd59f25..78b0c4efd 100644 --- a/xllm/core/platform/stream.h +++ b/xllm/core/platform/stream.h @@ -66,6 +66,8 @@ class Stream { c10_npu::NPUStream* get_stream() { return &stream_; } #elif defined(USE_MLU) torch_mlu::MLUStream* get_stream() { return &stream_; } +#elif defined(USE_CUDA) + c10::cuda::CUDAStream* get_stream() { return &stream_; } #endif void wait_stream(const Stream& other_stream); diff --git a/xllm/core/platform/vmm_torch_allocator.h b/xllm/core/platform/vmm_torch_allocator.h index bf7e0a657..c7f61d185 100644 --- a/xllm/core/platform/vmm_torch_allocator.h +++ b/xllm/core/platform/vmm_torch_allocator.h @@ -18,7 +18,9 @@ limitations under the License. // VMMTorchAllocator is only available for platforms using PyTorch's // CUDACachingAllocator interface (CUDA, ILU, ROCm). #if defined(USE_CUDA) || defined(USE_ILU) - +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 10 +#include +#endif #include #include @@ -108,10 +110,6 @@ class VMMTorchAllocator std::string name() override { return "VMMTorchAllocator"; } - void emptyCache() override { - LOG(FATAL) << "VMMTorchAllocator::emptyCache() called unexpectedly!"; - } - c10::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex /*device*/) override { LOG(FATAL) << "VMMTorchAllocator::getDeviceStats() called unexpectedly!"; @@ -167,29 +165,6 @@ class VMMTorchAllocator LOG(FATAL) << "VMMTorchAllocator::recordStream() called unexpectedly!"; } - c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override { - LOG(FATAL) << "VMMTorchAllocator::snapshot() called unexpectedly!"; - return {}; - } - - void beginAllocateToPool( - c10::DeviceIndex /*device*/, - c10::cuda::MempoolId_t /*mempool_id*/, - std::function /*filter*/) override { - LOG(FATAL) - << "VMMTorchAllocator::beginAllocateToPool() called unexpectedly!"; - } - - void endAllocateToPool(c10::DeviceIndex /*device*/, - c10::cuda::MempoolId_t /*mempool_id*/) override { - LOG(FATAL) << "VMMTorchAllocator::endAllocateToPool() called unexpectedly!"; - } - - void releasePool(c10::DeviceIndex /*device*/, - c10::cuda::MempoolId_t /*mempool_id*/) override { - LOG(FATAL) << "VMMTorchAllocator::releasePool() called unexpectedly!"; - } - c10::cuda::CUDACachingAllocator::ShareableHandle shareIpcHandle( void* /*ptr*/) override { LOG(ERROR) << "VMMTorchAllocator::shareIpcHandle() called - not supported!"; @@ -203,14 +178,6 @@ class VMMTorchAllocator return nullptr; } - void recordHistory( - bool /*enabled*/, - c10::cuda::CUDACachingAllocator::CreateContextFn /*context_recorder*/, - size_t /*alloc_trace_max_entries*/, - c10::cuda::CUDACachingAllocator::RecordContext /*when*/) override { - LOG(FATAL) << "VMMTorchAllocator::recordHistory() called unexpectedly!"; - } - void attachOutOfMemoryObserver( c10::cuda::CUDACachingAllocator::OutOfMemoryObserver /*observer*/) override { @@ -241,15 +208,6 @@ class VMMTorchAllocator return cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream); } - std::shared_ptr - getCheckpointState(c10::DeviceIndex /*device*/, - c10::cuda::MempoolId_t /*id*/) override { - LOG(ERROR) - << "VMMTorchAllocator::getCheckpointState() called - not supported!"; - TORCH_CHECK(false, name(), " does not support checkpointing"); - return nullptr; - } - c10::cuda::CUDACachingAllocator::CheckpointDelta setCheckpointPoolState( c10::DeviceIndex /*device*/, std::shared_ptr /*pps*/) @@ -260,6 +218,77 @@ class VMMTorchAllocator return {}; } + void beginAllocateToPool( + c10::DeviceIndex /*device*/, + at::cuda::MempoolId_t /*mempool_id*/, + std::function /*filter*/) override { + LOG(FATAL) + << "VMMTorchAllocator::beginAllocateToPool() called unexpectedly!"; + } + + void endAllocateToPool(c10::DeviceIndex /*device*/, + at::cuda::MempoolId_t /*mempool_id*/) override { + LOG(FATAL) << "VMMTorchAllocator::endAllocateToPool() called unexpectedly!"; + } + + void releasePool(c10::DeviceIndex /*device*/, + at::cuda::MempoolId_t /*mempool_id*/) override { + LOG(FATAL) << "VMMTorchAllocator::releasePool() called unexpectedly!"; + } + + std::shared_ptr + getCheckpointState(c10::DeviceIndex /*device*/, + at::cuda::MempoolId_t /*id*/) override { + LOG(FATAL) + << "VMMTorchAllocator::getCheckpointState() called unexpectedly!"; + return nullptr; + } + +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 10 + void emptyCache(at::cuda::MempoolId_t /*mempool_id*/ = {0, 0}) override { + LOG(FATAL) << "VMMTorchAllocator::emptyCache() called unexpectedly!"; + } + + std::vector + getExpandableSegmentSizes(c10::DeviceIndex /*device*/) override { + LOG(FATAL) << "VMMTorchAllocator::getExpandableSegmentSizes() called " + "unexpectedly!"; + return {}; + } + + c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot( + at::cuda::MempoolId_t /*mempool_id*/ = {0, 0}) override { + LOG(FATAL) << "VMMTorchAllocator::snapshot() called unexpectedly!"; + return {}; + } + + void recordHistory( + bool /*enabled*/, + c10::cuda::CUDACachingAllocator::CreateContextFn /*context_recorder*/, + size_t /*alloc_trace_max_entries*/, + c10::cuda::CUDACachingAllocator::RecordContext /*when*/, + bool /*clearHistory*/) override { + LOG(FATAL) << "VMMTorchAllocator::recordHistory() called unexpectedly!"; + } +#else + void emptyCache() override { + LOG(FATAL) << "VMMTorchAllocator::emptyCache() called unexpectedly!"; + } + + c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override { + LOG(FATAL) << "VMMTorchAllocator::snapshot() called unexpectedly!"; + return {}; + } + + void recordHistory( + bool /*enabled*/, + c10::cuda::CUDACachingAllocator::CreateContextFn /*context_recorder*/, + size_t /*alloc_trace_max_entries*/, + c10::cuda::CUDACachingAllocator::RecordContext /*when*/) override { + LOG(FATAL) << "VMMTorchAllocator::recordHistory() called unexpectedly!"; + } +#endif + private: static void raw_deleter(void* ptr) { // No-op: VMM memory is not freed individually diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index 2055dfab4..d2574da19 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -24,7 +24,7 @@ cc_library( worker_impl.h llm_worker_impl.h vlm_worker_impl.h - dit_worker.h + dit_worker_impl.h embed_worker_impl.h embed_vlm_worker_impl.h rec_worker_impl.h @@ -50,7 +50,7 @@ cc_library( worker_impl.cpp llm_worker_impl.cpp vlm_worker_impl.cpp - dit_worker.cpp + dit_worker_impl.cpp embed_worker_impl.cpp embed_vlm_worker_impl.cpp rec_worker_impl.cpp @@ -74,6 +74,8 @@ cc_library( :state_dict :dit_cache :model + :kv_cache + :kv_cache_transfer :models :sampler :tokenizer diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index e34188e57..728498fc6 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -48,6 +48,19 @@ limitations under the License. namespace xllm::npu { namespace { +std::pair find_attention_plan_kv_cache( + const std::vector& kv_caches) { + for (const auto& cache : kv_caches) { + auto k_cache = cache.get_k_cache(); + auto v_cache = cache.get_v_cache(); + if (k_cache.defined() && v_cache.defined() && k_cache.numel() > 0 && + v_cache.numel() > 0) { + return {std::move(k_cache), std::move(v_cache)}; + } + } + return {torch::Tensor(), torch::Tensor()}; +} + int64_t get_decode_graph_capacity(const runtime::Options& options) { CHECK_GT(options.num_decoding_tokens(), 0) << "num_decoding_tokens must be > 0 for graph capacity"; @@ -265,14 +278,8 @@ std::optional GraphPersistentParam::update( } if (tiling_data_.numel() > 0) { - // Get current stream for tiling tensor update aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); - // Update tiling tensor based on model type - // For models with mixed attention types (e.g., qwen3_next), only update if - // k/v cache is defined - // NOTE: linear attention may pass "defined but empty" k/v cache tensors. - // Only treat k/v cache as valid when they are defined and non-empty. if (need_update_attention_plan_ && k_cache.defined() && v_cache.defined() && k_cache.numel() > 0 && v_cache.numel() > 0) { plan_paged_attention_tiling( @@ -798,9 +805,12 @@ bool AclGraph::capture(CausalLM* model, aclrtStream stream = c10_npu::getCurrentNPUStream(tensor_options.device().index()).stream(); - // Update persistent parameters with input data before capture - const torch::Tensor& k_cache = kv_cache[0].get_k_cache(); - const torch::Tensor& v_cache = kv_cache[0].get_v_cache(); + // For hybrid models (e.g., qwen3_next with mixed GDN/full_attention layers), + // we need to find the first Full Attention layer to get the correct kv_cache. + // GDN layers have empty key_cache_/value_cache_ while Full Attention layers + // have valid kv caches. Using layer 0's cache directly would be incorrect + // if layer 0 is a GDN layer. + auto [k_cache, v_cache] = find_attention_plan_kv_cache(kv_cache); const uint32_t actual_num_tokens = tokens.size(0); CHECK_GE(num_tokens_, actual_num_tokens) << "num_tokens_ >= actual_num_tokens"; @@ -899,8 +909,11 @@ ModelOutput AclGraph::replay(const torch::Tensor& tokens, << actual_num_tokens; // Update persistent parameters with new input data - const torch::Tensor& k_cache = kv_cache[0].get_k_cache(); - const torch::Tensor& v_cache = kv_cache[0].get_v_cache(); + // Note: tiling_data is updated in update() if needed - for hybrid models + // (e.g., qwen3_next with mixed GDN/attention layers), tiling should only + // be updated when Full Attention layers are involved, which is determined + // by k_cache being valid and non-empty + auto [k_cache, v_cache] = find_attention_plan_kv_cache(kv_cache); persistent_param_.update(tokens, k_cache, v_cache, diff --git a/xllm/core/runtime/acl_graph_executor_test.cpp b/xllm/core/runtime/acl_graph_executor_test.cpp index 35408dbe3..9dcc438c6 100644 --- a/xllm/core/runtime/acl_graph_executor_test.cpp +++ b/xllm/core/runtime/acl_graph_executor_test.cpp @@ -337,7 +337,7 @@ class AclGraphExecutorTest : public ::testing::Test { auto kv_cache = torch::randn({n_blocks, block_size * hidden_size}, torch::dtype(torch::kFloat32).device(*device_)); - kv_caches_.push_back({kv_cache, kv_cache}); + kv_caches_.emplace_back(kv_cache, kv_cache); } void TearDown() override { return; } diff --git a/xllm/core/runtime/cuda_graph_executor_impl.cpp b/xllm/core/runtime/cuda_graph_executor_impl.cpp index 998110f2d..9c0e9ebe7 100644 --- a/xllm/core/runtime/cuda_graph_executor_impl.cpp +++ b/xllm/core/runtime/cuda_graph_executor_impl.cpp @@ -29,7 +29,6 @@ limitations under the License. #include "core/common/global_flags.h" #include "core/common/metrics.h" -#include "core/common/rec_model_utils.h" #include "core/layers/common/attention_metadata.h" #include "core/layers/common/attention_metadata_builder.h" #include "core/layers/cuda/flashinfer_planinfo.h" @@ -39,6 +38,7 @@ limitations under the License. #include "core/platform/shared_vmm_allocator.h" #include "core/platform/stream.h" #include "core/platform/vmm_torch_allocator.h" +#include "core/util/rec_model_utils.h" #include "core/util/utils.h" #include "kernels/cuda/global_capture_instance.h" #include "kernels/cuda/utils.h" @@ -391,8 +391,14 @@ std::optional CudaGraphPersistentParam::update( const bool use_two_stage_decode = !FLAGS_enable_xattention_one_stage && is_decode_with_llmrec; const int32_t head_dim = args_.head_dim(); - const int64_t n_heads = args_.n_heads(); - const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads); + const int64_t tp_size = + options_.world_size() / std::max(options_.dp_size(), 1); + const int64_t n_heads = args_.n_heads() / std::max(tp_size, int64_t{1}); + const int64_t total_kv_heads = args_.n_kv_heads().value_or(args_.n_heads()); + const int64_t n_kv_heads = + (total_kv_heads >= tp_size) + ? (total_kv_heads / std::max(tp_size, int64_t{1})) + : 1; const int64_t block_size = options_.block_size(); // Get sliding_window from ModelArgs (default to -1 if not available) @@ -688,7 +694,7 @@ bool CudaGraph::capture(CausalLM* model, std::vector& kv_cache, uint32_t bucket_num_tokens, const at::cuda::MempoolId_t& pool, - c10::cuda::MemPool* pool_ptr) { + TorchMemPool* pool_ptr) { padded_num_tokens_ = bucket_num_tokens; const uint32_t actual_num_tokens = tokens.size(0); CHECK_GE(padded_num_tokens_, actual_num_tokens) @@ -751,14 +757,17 @@ bool CudaGraph::capture(CausalLM* model, kv_cache, graph_params_opt.value()); + // MemPoolContext has been deprecated in torch >= 2.8 +#if TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <= 7 // Activate VMM mempool only for the actual capture to keep plan_info // allocations out of the shared physical memory pool. std::optional mempool_ctx; if (pool_ptr != nullptr) { mempool_ctx.emplace(pool_ptr); } +#endif - // Begin piecewise capture via GlobalCaptureInstance + // Begin piecewise capture via GlobalCaptureInstance. GlobalCaptureInstance::get_instance().begin_capture(pool); // Execute forward pass - attention operations will be captured separately @@ -797,12 +806,16 @@ bool CudaGraph::capture(CausalLM* model, << ", num_runners: " << piecewise_graph_.num_runners(); } else { // Normal capture mode (for decode) + + // MemPoolContext has been deprecated in torch >= 2.8 +#if TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <= 7 // Activate VMM mempool only for the actual capture to keep plan_info // allocations out of the shared physical memory pool. std::optional mempool_ctx; if (pool_ptr != nullptr) { mempool_ctx.emplace(pool_ptr); } +#endif // Begin graph capture (capture_mode defaults to // cudaStreamCaptureModeGlobal) @@ -970,8 +983,7 @@ constexpr uint32_t kPhysicalPoolIdDecode = 1; struct CudaGraphExecutorImpl::VmmPoolState { std::unique_ptr allocator; std::unique_ptr torch_allocator; - std::unordered_map> - mempools_by_shape; + std::unordered_map> mempools_by_shape; }; CudaGraphExecutorImpl::~CudaGraphExecutorImpl() { @@ -999,7 +1011,7 @@ CudaGraphExecutorImpl::get_or_create_vmm_pool_state(uint32_t physical_pool_id) { return *slot; } -c10::cuda::MemPool* CudaGraphExecutorImpl::get_or_create_vmm_mempool( +TorchMemPool* CudaGraphExecutorImpl::get_or_create_vmm_mempool( uint32_t physical_pool_id, uint32_t shape_id) { VmmPoolState& state = get_or_create_vmm_pool_state(physical_pool_id); @@ -1009,9 +1021,9 @@ c10::cuda::MemPool* CudaGraphExecutorImpl::get_or_create_vmm_mempool( if (it != mempools.end()) { return it->second.get(); } - auto pool = std::make_unique(state.torch_allocator.get(), - /*is_user_created=*/true); - c10::cuda::MemPool* ptr = pool.get(); + auto pool = std::make_unique(state.torch_allocator.get(), + /*is_user_created=*/true); + TorchMemPool* ptr = pool.get(); mempools[shape_id] = std::move(pool); VLOG(kGraphExecutorLogVerboseLevel) << "Created per-shape VMM MemPool for executor " << this << ", device " @@ -1021,9 +1033,8 @@ c10::cuda::MemPool* CudaGraphExecutorImpl::get_or_create_vmm_mempool( return ptr; } -c10::cuda::MemPool* CudaGraphExecutorImpl::get_vmm_mempool( - uint32_t physical_pool_id, - uint32_t shape_id) { +TorchMemPool* CudaGraphExecutorImpl::get_vmm_mempool(uint32_t physical_pool_id, + uint32_t shape_id) { std::lock_guard lock(vmm_mutex_); auto it = vmm_pools_.find(physical_pool_id); if (it == vmm_pools_.end() || !it->second) { @@ -1151,7 +1162,7 @@ at::cuda::MempoolId_t CudaGraphExecutorImpl::get_mem_pool( return graph_pool_; } // Per-shape VMM MemPool: look up pool for (physical_pool_id, shape_id). - c10::cuda::MemPool* pool = get_vmm_mempool(physical_pool_id, shape_id); + TorchMemPool* pool = get_vmm_mempool(physical_pool_id, shape_id); CHECK(pool != nullptr) << "VMM MemPool for shape_id=" << shape_id << ", physical_pool_id=" << physical_pool_id @@ -1245,7 +1256,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, VLOG(kGraphExecutorLogVerboseLevel) << "CudaGraphExecutorImpl::run() in prefill piecewise capture mode"; - c10::cuda::MemPool* pool_ptr = nullptr; + TorchMemPool* pool_ptr = nullptr; if (FLAGS_enable_graph_vmm_pool) { reset_vmm_allocator_offset(kPhysicalPoolIdPrefill); const uint32_t shape_id = bucket_num_tokens; @@ -1330,7 +1341,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, VLOG(kGraphExecutorLogVerboseLevel) << "CudaGraphExecutorImpl::run() in decode capture mode"; - c10::cuda::MemPool* pool_ptr = nullptr; + TorchMemPool* pool_ptr = nullptr; if (FLAGS_enable_graph_vmm_pool) { reset_vmm_allocator_offset(kPhysicalPoolIdDecode); const uint32_t shape_id = bucket_num_tokens; diff --git a/xllm/core/runtime/cuda_graph_executor_impl.h b/xllm/core/runtime/cuda_graph_executor_impl.h index 160fd531e..d47d7659b 100644 --- a/xllm/core/runtime/cuda_graph_executor_impl.h +++ b/xllm/core/runtime/cuda_graph_executor_impl.h @@ -21,6 +21,9 @@ limitations under the License. #include #include #include +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 10 +#include +#endif #include #include @@ -40,6 +43,12 @@ limitations under the License. namespace xllm::runtime::cuda { +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 10 +using TorchMemPool = at::cuda::MemPool; +#else +using TorchMemPool = c10::cuda::MemPool; +#endif + // Helper class to hold persistent parameters for CUDA graph execution // Multiple CudaGraph instances can share the same CudaGraphPersistentParam // object @@ -222,7 +231,7 @@ class CudaGraph { std::vector& kv_cache, uint32_t bucket_num_tokens, const at::cuda::MempoolId_t& pool, - c10::cuda::MemPool* pool_ptr = nullptr); + TorchMemPool* pool_ptr = nullptr); // Replay captured graph with new input data ModelOutput replay(const torch::Tensor& tokens, @@ -341,10 +350,9 @@ class CudaGraphExecutorImpl : public ExecutorImpl { }; VmmPoolState& get_or_create_vmm_pool_state(uint32_t physical_pool_id); - c10::cuda::MemPool* get_or_create_vmm_mempool(uint32_t physical_pool_id, - uint32_t shape_id); - c10::cuda::MemPool* get_vmm_mempool(uint32_t physical_pool_id, - uint32_t shape_id); + TorchMemPool* get_or_create_vmm_mempool(uint32_t physical_pool_id, + uint32_t shape_id); + TorchMemPool* get_vmm_mempool(uint32_t physical_pool_id, uint32_t shape_id); GraphMemoryUsageStats get_graph_memory_usage_stats(); void log_graph_memory_after_capture(); diff --git a/xllm/core/runtime/dit_forward_params.h b/xllm/core/runtime/dit_forward_params.h index a96a5118a..861d8adc5 100644 --- a/xllm/core/runtime/dit_forward_params.h +++ b/xllm/core/runtime/dit_forward_params.h @@ -27,8 +27,144 @@ namespace xllm { // dit related forward input params struct DiTForwardInput { + bool valid() const { + return prompts.size() > 0 || prompt_embeds.defined() || + pooled_prompt_embeds.defined() || images.defined(); + } + + void save_with_prefix(std::string prefix) const { + torch::save(images, prefix + "images_cpp.pt"); + torch::save(prompt_embeds, prefix + "prompt_embeds_cpp.pt"); + torch::save(negative_prompt_embeds, prefix + "neg_prompt_embeds_cpp.pt"); + } + void debug_print(std::ostream& os = std::cout) const { + os << "=== DiTForwardInput Debug Info ===" << std::endl; + + // Print basic data types + os << "batch_size: " << batch_size << std::endl; + + // Print prompts vectors + os << "prompts: ["; + for (size_t i = 0; i < prompts.size(); ++i) { + os << "\"" << prompts[i] << "\""; + if (i < prompts.size() - 1) os << ", "; + } + os << "]" << std::endl; + + os << "prompts_2: ["; + for (size_t i = 0; i < prompts_2.size(); ++i) { + os << "\"" << prompts_2[i] << "\""; + if (i < prompts_2.size() - 1) os << ", "; + } + os << "]" << std::endl; + + os << "negative_prompts: ["; + for (size_t i = 0; i < negative_prompts.size(); ++i) { + os << "\"" << negative_prompts[i] << "\""; + if (i < negative_prompts.size() - 1) os << ", "; + } + os << "]" << std::endl; + + os << "negative_prompts_2: ["; + for (size_t i = 0; i < negative_prompts_2.size(); ++i) { + os << "\"" << negative_prompts_2[i] << "\""; + if (i < negative_prompts_2.size() - 1) os << ", "; + } + os << "]" << std::endl; + + // Print tensor shapes + os << "\n--- Tensor Shapes ---" << std::endl; + + os << "images: "; + if (images.defined()) { + os << images.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "condition_images: "; + if (condition_images.defined()) { + os << condition_images.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "mask_images: "; + if (mask_images.defined()) { + os << mask_images.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "control_image: "; + if (control_image.defined()) { + os << control_image.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "masked_image_latents: "; + if (masked_image_latents.defined()) { + os << masked_image_latents.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "prompt_embeds: "; + if (prompt_embeds.defined()) { + os << prompt_embeds.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "pooled_prompt_embeds: "; + if (pooled_prompt_embeds.defined()) { + os << pooled_prompt_embeds.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "negative_prompt_embeds: "; + if (negative_prompt_embeds.defined()) { + os << negative_prompt_embeds.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "negative_pooled_prompt_embeds: "; + if (negative_pooled_prompt_embeds.defined()) { + os << negative_pooled_prompt_embeds.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "latents: "; + if (latents.defined()) { + os << latents.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + // Print generation_params + os << "\n--- Generation Parameters ---" << std::endl; + os << "width: " << generation_params.width << std::endl; + os << "height: " << generation_params.height << std::endl; + os << "num_inference_steps: " << generation_params.num_inference_steps + << std::endl; + os << "true_cfg_scale: " << generation_params.true_cfg_scale << std::endl; + os << "guidance_scale: " << generation_params.guidance_scale << std::endl; + os << "num_images_per_prompt: " << generation_params.num_images_per_prompt + << std::endl; + os << "seed: " << generation_params.seed << std::endl; + os << "max_sequence_length: " << generation_params.max_sequence_length + << std::endl; + os << "strength: " << generation_params.strength << std::endl; + + os << "===============================" << std::endl; + } + DiTForwardInput to(const torch::Device& device, - torch::ScalarType dtype) const { + torch::ScalarType dtype = torch::kBFloat16) const { DiTForwardInput input = *this; if (prompt_embeds.defined()) { @@ -63,6 +199,14 @@ struct DiTForwardInput { if (mask_images.defined()) { input.mask_images = mask_images.to(device, dtype); } + + if (condition_images.defined()) { + input.condition_images = condition_images.to(device, dtype); + } + + if (control_image.defined()) { + input.control_image = control_image.to(device, dtype); + } return input; } @@ -82,6 +226,8 @@ struct DiTForwardInput { torch::Tensor images; + torch::Tensor condition_images; + torch::Tensor mask_images; torch::Tensor control_image; @@ -104,6 +250,9 @@ struct DiTForwardInput { // dit related forward output params struct DiTForwardOutput { + void save_with_prefix(std::string prefix) const { + torch::save(tensors[0], prefix + "dit_images_cpp.pt"); + } // generated tensor std::vector tensors; }; diff --git a/xllm/core/runtime/dit_worker.cpp b/xllm/core/runtime/dit_worker_impl.cpp similarity index 57% rename from xllm/core/runtime/dit_worker.cpp rename to xllm/core/runtime/dit_worker_impl.cpp index e7f34e9fd..f38bde768 100644 --- a/xllm/core/runtime/dit_worker.cpp +++ b/xllm/core/runtime/dit_worker_impl.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "dit_worker.h" +#include "dit_worker_impl.h" #include #include @@ -61,81 +61,123 @@ DiTCacheConfig parse_dit_cache_from_flags() { cache_config.fbcachetaylorseer.warmup_steps = FLAGS_dit_cache_warmup_steps; cache_config.fbcachetaylorseer.residual_diff_threshold = FLAGS_dit_cache_residual_diff_threshold; + } else if (FLAGS_dit_cache_policy == "ResidualCache") { + cache_config.selected_policy = PolicyType::ResidualCache; + cache_config.residual_cache.dit_cache_start_steps = + FLAGS_dit_cache_start_steps; + cache_config.residual_cache.dit_cache_end_steps = FLAGS_dit_cache_end_steps; + cache_config.residual_cache.dit_cache_start_blocks = + FLAGS_dit_cache_start_blocks; + cache_config.residual_cache.dit_cache_end_blocks = + FLAGS_dit_cache_end_blocks; + cache_config.residual_cache.skip_interval_steps = + FLAGS_dit_cache_skip_interval_steps; } else if (FLAGS_dit_cache_policy == "None") { - cache_config.selected_policy = PolicyType::TaylorSeer; + cache_config.selected_policy = PolicyType::None; } return cache_config; } } // namespace -DiTWorker::DiTWorker(const ParallelArgs& parallel_args, - const torch::Device& device, - const runtime::Options& options) - : device_(device), options_(options), parallel_args_(parallel_args) { +DiTWorkerImpl::DiTWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : WorkerImpl(parallel_args, device, options) { device_.set_device(); } -bool DiTWorker::init_model(const std::string& model_weights_path) { +bool DiTWorkerImpl::init_model(ModelContext& context) { + LOG(ERROR) + << "init model with model_context was not implemented for dit models"; + return false; +} + +bool DiTWorkerImpl::init_model(const std::string& model_weights_path, + int32_t random_seed, + MasterStatus master_status) { CHECK(dit_model_ == nullptr) << "Model is already initialized."; + // set same random seed for all worker + device_.set_seed(random_seed); + auto loader = std::make_unique(model_weights_path); dtype_ = util::parse_dtype(loader->get_torch_dtype(), device_); auto tensor_options = torch::dtype(dtype_).device(device_); - context_ = DiTModelContext(parallel_args_, - std::move(loader->get_model_args()), - std::move(loader->get_quant_args()), - tensor_options, - options_.model_id()); + DiTCacheConfig cache_config = parse_dit_cache_from_flags(); - dit_model_ = create_dit_model(context_); + auto model_type = loader->get_model_type(); + + if (!ModelRegistry::has_dit_model_factory(model_type)) { + LOG(WARNING) << "could not find model_type: " << model_type + << ", using model_id: " << options_.model_id() << " instead."; + model_type = options_.model_id(); + } + + dit_context_ = DiTModelContext(parallel_args_, + std::move(loader->get_model_args()), + std::move(loader->get_quant_args()), + tensor_options, + cache_config, + model_type); + + dit_model_ = create_dit_model(dit_context_); CHECK(dit_model_ != nullptr) << "Failed to create model."; dit_model_->load_model(std::move(loader)); dit_model_executor_ = std::make_unique(dit_model_.get(), options_); - DiTCacheConfig cache_config = parse_dit_cache_from_flags(); DiTCache::get_instance().init(cache_config); return true; } -folly::SemiFuture DiTWorker::init_model_async( - const std::string& model_weights_path) { +folly::SemiFuture DiTWorkerImpl::init_model_async( + const std::string& model_weights_path, + int32_t random_seed, + MasterStatus master_status) { auto promise = std::make_shared>(); auto future = promise->getSemiFuture(); - threadpool_.schedule([this, model_weights_path, promise]() mutable { - bool status = this->init_model(model_weights_path); + threadpool_.schedule([this, + model_weights_path, + random_seed, + master_status, + promise]() mutable { + bool status = + this->init_model(model_weights_path, random_seed, master_status); promise->setValue(status); }); return future; } -std::optional DiTWorker::step(const DiTForwardInput& inputs) { +std::optional DiTWorkerImpl::step(const ForwardInput& inputs) { + torch::DeviceGuard device_guard(device_); Timer timer; - - auto output = dit_model_executor_->forward(inputs.to(device_, dtype_)); + auto output = dit_model_executor_->forward( + inputs.input_params.dit_forward_input.to(device_, dtype_)); auto ret = device_.synchronize_default_stream(); COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); - - return output; + ForwardOutput forward_output; + forward_output.dit_forward_output = output; + return forward_output; } -folly::SemiFuture> DiTWorker::step_async( - const DiTForwardInput& inputs) { - auto promise = - std::make_shared>>(); - auto future = promise->getSemiFuture(); - threadpool_.schedule([this, inputs, promise]() mutable { - auto output = this->step(inputs); - promise->setValue(output); +folly::SemiFuture> DiTWorkerImpl::step_async( + const ForwardInput& inputs) { + folly::Promise> promise; + auto future = promise.getSemiFuture(); + threadpool_.schedule([this, + input = std::move(inputs), + promise = std::move(promise)]() mutable { + auto output = this->step(input); + promise.setValue(output); }); return future; } -void DiTWorker::process_group_test() { +void DiTWorkerImpl::process_group_test() { // create random tensors const auto options = torch::dtype(torch::kHalf).device(device_); torch::Tensor tensor = torch::randn({10, 10}, options); @@ -153,7 +195,7 @@ void DiTWorker::process_group_test() { // context_.get_parallel_args().process_group_); } -folly::SemiFuture DiTWorker::process_group_test_async() { +folly::SemiFuture DiTWorkerImpl::process_group_test_async() { folly::Promise promise; auto future = promise.getSemiFuture(); threadpool_.schedule([this, promise = std::move(promise)]() mutable { @@ -164,11 +206,11 @@ folly::SemiFuture DiTWorker::process_group_test_async() { } // prepare input for execution -DiTForwardInput DiTWorker::prepare_inputs(DiTBatch& batch) { +DiTForwardInput DiTWorkerImpl::prepare_inputs(DiTBatch& batch) { return dit_model_executor_->prepare_inputs(batch); } -int64_t DiTWorker::get_active_activation_memory() { +int64_t DiTWorkerImpl::get_active_activation_memory() { return DeviceMonitor::get_instance() .get_device_stats(device_.index()) .active_activation_memory; diff --git a/xllm/core/runtime/dit_worker.h b/xllm/core/runtime/dit_worker_impl.h similarity index 68% rename from xllm/core/runtime/dit_worker.h rename to xllm/core/runtime/dit_worker_impl.h index c9f137f9e..c3a84e4a1 100644 --- a/xllm/core/runtime/dit_worker.h +++ b/xllm/core/runtime/dit_worker_impl.h @@ -27,24 +27,34 @@ limitations under the License. #include "options.h" #include "platform/device.h" #include "util/threadpool.h" +#include "worker_impl.h" namespace xllm { -class DiTWorker { +class DiTWorkerImpl : public WorkerImpl { public: - DiTWorker(const ParallelArgs& parallel_args, - const torch::Device& device, - const runtime::Options& options); + DiTWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); - ~DiTWorker() = default; + ~DiTWorkerImpl() = default; // initialize model, cache manager. blocking call - bool init_model(const std::string& model_weights_path); + bool init_model(const std::string& model_weights_path, + int32_t random_seed, + MasterStatus master_status) override; folly::SemiFuture init_model_async( - const std::string& model_weights_path); + const std::string& model_weights_path, + int32_t random_seed, + MasterStatus master_status) override; - std::optional step(const DiTForwardInput& inputs); + bool init_model(ModelContext& context) override; + + std::optional step(const ForwardInput& inputs) override; + + folly::SemiFuture> step_async( + const ForwardInput& inputs); folly::SemiFuture> step_async( const DiTForwardInput& inputs); @@ -59,20 +69,12 @@ class DiTWorker { int64_t get_active_activation_memory(); private: - runtime::Options options_; - std::unique_ptr dit_model_; std::unique_ptr dit_model_executor_; - Device device_; - - torch::ScalarType dtype_; - // model context, includes model args, parallel args and date type etc. - mutable DiTModelContext context_; - - ParallelArgs parallel_args_; + mutable DiTModelContext dit_context_; ThreadPool threadpool_; }; diff --git a/xllm/core/runtime/embed_vlm_worker_impl.cpp b/xllm/core/runtime/embed_vlm_worker_impl.cpp index 7bb823ccf..e188df2e0 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.cpp +++ b/xllm/core/runtime/embed_vlm_worker_impl.cpp @@ -43,8 +43,8 @@ EmbedVLMWorkerImpl::EmbedVLMWorkerImpl(const ParallelArgs& parallel_args, bool EmbedVLMWorkerImpl::init_model(ModelContext& context) { CHECK(model_ == nullptr) << "Model is already initialized."; - context.set_image_embedding_mode(true); - model_ = create_vlm_embedding_model(context); + context.set_image_embedding_mode(false); + model_ = create_vlm_model(context); CHECK(model_ != nullptr) << "Failed to create model."; model_executor_ = std::make_unique( model_.get(), context.get_model_args(), device_, options_); @@ -70,7 +70,6 @@ std::optional EmbedVLMWorkerImpl::step( auto model_output = model_executor_->forward( flatten_tokens, flatten_positions, kv_caches_, params); auto hidden_states = model_output.hidden_states; - ret = device_.synchronize_default_stream(); COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); @@ -81,13 +80,25 @@ std::optional EmbedVLMWorkerImpl::step( // driver prepare model output ForwardOutput output; SampleOutput sample_output; - if (sampling_params.selected_token_idxes.defined() && input.sampling_params.is_embeddings) { - EmbeddingVLM* em_model = dynamic_cast(model_.get()); auto embeddings = - em_model->pooler(hidden_states, sampling_params.selected_token_idxes); + model_->pooler(hidden_states, sampling_params.selected_token_idxes); sample_output.embeddings = embeddings; + // split full embeddings and add them to mm_embeddings + // so that the user could receive embeddings of images and texts + if (FLAGS_enable_return_mm_full_embeddings) { + auto q_seq_len_vec = input.input_params.q_seq_lens_vec; + sample_output.mm_embeddings.reserve(q_seq_len_vec.size()); + int32_t token_start_idx = 0; + for (auto seq_len : q_seq_len_vec) { + auto image_embed = + embeddings.slice(0, token_start_idx, token_start_idx + seq_len); + sample_output.mm_embeddings.emplace_back(image_embed); + token_start_idx += seq_len; + } + } + output.sample_output = sample_output; output.embedding = embeddings; } diff --git a/xllm/core/runtime/embed_vlm_worker_impl.h b/xllm/core/runtime/embed_vlm_worker_impl.h index 79d100bc6..7d2788237 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.h +++ b/xllm/core/runtime/embed_vlm_worker_impl.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "core/common/global_flags.h" #include "executor.h" #include "forward_params.h" #include "framework/model/causal_vlm.h" diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 07515e14d..99eaadfe0 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -33,6 +33,7 @@ limitations under the License. #include "framework/sampling/beam_searcher.h" #include "framework/sampling/sampling_params.h" #include "platform/device.h" +#include "runtime/dit_forward_params.h" namespace xllm { @@ -203,6 +204,9 @@ struct ForwardOutput { BeamSearchOutput beam_search_output; torch::Tensor beam_sequence_group; + + // dit output data + DiTForwardOutput dit_forward_output; }; // Model input with raw data, which will be @@ -260,6 +264,9 @@ struct RawForwardInput { // multimodal data MMBatchData mm_data; + // dit input data + DiTForwardInput dit_forward_input; + RawForwardInput cp_partition(int32_t cp_rank, int32_t cp_size) const { RawForwardInput outputs = *this; if (cp_size <= 1 || flatten_tokens_vec.empty() || @@ -505,6 +512,8 @@ struct RawForwardOutput { std::vector beam_sequence_group; // flattened 2D // multimodal embedding output std::vector mm_embeddings; + // dit output data + DiTForwardOutput dit_forward_output; }; struct BatchedForwardInputs { diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 88b33fcdb..9764a7c70 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -15,15 +15,24 @@ limitations under the License. #include +#include +#include #include #include +#include #include +#include #include #include "core/common/global_flags.h" #include "platform/stream.h" +#if defined(USE_CUDA) +#include +#endif #if defined(USE_NPU) #include "platform/npu/device_capture_lock.h" +#elif defined(USE_CUDA) +#include "platform/cuda/device_capture_lock.h" #endif #include "core/util/net.h" #include "core/util/tensor_helper.h" @@ -55,6 +64,40 @@ constexpr size_t swap_block_info_fixed_size() { return type_size * 2; // src_block_id + dst_block_id } +constexpr std::uintptr_t kCudaZeroCopyAlignment = 16; +constexpr uint64_t kRawInputTensorArenaAlignment = + static_cast(kCudaZeroCopyAlignment); + +inline uint64_t align_up(uint64_t value, uint64_t alignment) { + if (alignment == 0) { + return value; + } + return ((value + alignment - 1) / alignment) * alignment; +} + +inline uint64_t get_aligned_tensor_arena_bytes(uint64_t data_bytes) { + return align_up(data_bytes, kRawInputTensorArenaAlignment); +} + +inline bool is_aligned_for_cuda_zero_copy(const void* ptr) { + return reinterpret_cast(ptr) % kCudaZeroCopyAlignment == 0; +} + +struct RawInputLayoutHeader final { + uint64_t descriptor_bytes = 0; + uint64_t tensor_arena_bytes = 0; +}; + +struct RawInputSectionCursor final { + char* ptr = nullptr; + uint64_t size = 0; +}; + +struct RawInputSerializeContext final { + RawInputSectionCursor descriptor; + RawInputSectionCursor tensor_arena; +}; + inline size_t get_string_size(const std::string& str) { return type_size + str.size(); } @@ -247,71 +290,51 @@ inline size_t get_mm_batch_data_size(const MMBatchData& mm_data) { return total; } -size_t calculate_raw_forward_input_size(const RawForwardInput& input) { - size_t total = 0; +// calculate dit input size +inline size_t get_dit_generation_params_size( + const DiTGenerationParams& params) { + return type_size * + 4 // width, height, num_inference_steps, max_sequence_length + + type_size * + 4 // true_cfg_scale, guidance_scale, strength, cfg_renorm_min + + type_size // num_images_per_prompt + + type_size // seed + + type_size; // enable_cfg_renorm +} + +inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { + size_t size = type_size; // batch_size + + // Vector of strings + size += get_string_vector_size(input.prompts); + size += get_string_vector_size(input.prompts_2); + size += get_string_vector_size(input.negative_prompts); + size += get_string_vector_size(input.negative_prompts_2); + + // Tensors + size += get_tensor_size(input.images); + size += get_tensor_size(input.condition_images); + size += get_tensor_size(input.mask_images); + size += get_tensor_size(input.control_image); + size += get_tensor_size(input.masked_image_latents); + size += get_tensor_size(input.prompt_embeds); + size += get_tensor_size(input.pooled_prompt_embeds); + size += get_tensor_size(input.negative_prompt_embeds); + size += get_tensor_size(input.negative_pooled_prompt_embeds); + size += get_tensor_size(input.latents); + + // Generation params + size += get_dit_generation_params_size(input.generation_params); - // flatten_tokens_vec - total += get_vector_to_tensor_size(input.flatten_tokens_vec); - if (input.flatten_positions_vec.size() > 0) { - // flatten_positions_vec - total += get_vector_to_tensor_size(input.flatten_positions_vec); - } else { - // m_positions_vec - total += get_2d_vector_to_tensor_size(input.m_positions_vec); - } - - // ModelInputParams - total += type_size // batch_forward_type - + type_size // num_sequences - + type_size * 2 // kv_max_seq_len + q_max_seq_len - + type_size; // batch_id - total += get_vector_to_tensor_size(input.q_seq_lens); - total += get_vector_to_tensor_size(input.seq_lens); - total += get_vector_to_tensor_size(input.new_token_slot_ids); - total += get_2d_vector_to_tensor_size(input.block_tables_vec); - total += get_vector_to_tensor_size(input.paged_kv_indptr); - total += get_vector_to_tensor_size(input.paged_kv_indices); - total += get_vector_to_tensor_size(input.paged_kv_last_page_len); - total += get_vector_to_tensor_size(input.new_cache_slot_offsets); - total += get_vector_to_tensor_size(input.kv_cache_start_offsets); - total += get_2d_vector_to_tensor_size(input.embeddings); - total += get_vector_size(input.dp_global_token_nums); - total += get_vector_size(input.dp_is_decode); - total += get_vector_size(input.embedding_ids); - total += get_string_vector_size(input.request_ids); - total += get_vector_size(input.extra_token_ids); - total += type_size + - input.swap_blocks.size() * swap_block_info_fixed_size(); - total += get_vector_to_tensor_size(input.src_block_indices); - total += get_vector_to_tensor_size(input.dst_block_indices); - total += get_vector_to_tensor_size(input.cum_sum); - total += get_mm_batch_data_size(input.mm_data); - total += get_vector_to_tensor_size(input.kv_cache_tokens_nums); - - // SamplingParameters - total += type_size; // selected_token_idxes.size() - if (input.selected_token_idxes.size() > 0) { - SamplingParameters sampling_params; - sampling_params.init(input.sampling_params, - input.selected_token_idxes, - input.sample_idxes, - input.unique_token_ids_vec, - input.unique_token_counts_vec, - input.unique_token_lens_vec); - total += get_sampling_params_size(sampling_params); - } - // acc_logprob - total += get_vector_to_tensor_size(input.acc_logprob_vec); + return size; +} - // transfer_kv_infos - total += type_size; - for (const auto& t : input.transfer_kv_infos) { - total += get_transfer_kv_info_size(t); +inline size_t get_dit_forward_output_size(const DiTForwardOutput& output) { + size_t size = type_size; // vector size + for (const auto& tensor : output.tensors) { + size += get_tensor_size(tensor); } - // eplb_info - total += get_eplb_info_size(input.eplb_info); - - return total; + return size; } template @@ -320,6 +343,39 @@ inline void write_data(char*& buffer, const T& data) { buffer += type_size; } +template +inline void write_data(RawInputSectionCursor& cursor, const T& data) { + if (cursor.ptr != nullptr) { + *reinterpret_cast(cursor.ptr) = data; + cursor.ptr += type_size; + } + cursor.size += type_size; +} + +inline void write_bytes(RawInputSectionCursor& cursor, + const void* data, + uint64_t bytes) { + if (bytes == 0) { + return; + } + if (cursor.ptr != nullptr) { + std::memcpy(cursor.ptr, data, bytes); + cursor.ptr += bytes; + } + cursor.size += bytes; +} + +inline void write_padding(RawInputSectionCursor& cursor, uint64_t bytes) { + if (bytes == 0) { + return; + } + if (cursor.ptr != nullptr) { + std::memset(cursor.ptr, 0, bytes); + cursor.ptr += bytes; + } + cursor.size += bytes; +} + inline void write_string(char*& buffer, const std::string& str) { const uint64_t len = str.size(); write_data(buffer, len); @@ -329,6 +385,15 @@ inline void write_string(char*& buffer, const std::string& str) { } } +inline void write_string(RawInputSectionCursor& cursor, + const std::string& str) { + const uint64_t len = str.size(); + write_data(cursor, len); + if (len > 0) { + write_bytes(cursor, str.data(), len); + } +} + inline void write_string_vector(char*& buffer, const std::vector& vec) { const uint64_t size = vec.size(); @@ -338,6 +403,15 @@ inline void write_string_vector(char*& buffer, } } +inline void write_string_vector(RawInputSectionCursor& cursor, + const std::vector& vec) { + const uint64_t size = vec.size(); + write_data(cursor, size); + for (const auto& str : vec) { + write_string(cursor, str); + } +} + inline void write_tensor(char*& buffer, const torch::Tensor& tensor) { if (!tensor.defined()) { uint64_t ndim = 0; @@ -366,6 +440,43 @@ inline void write_tensor(char*& buffer, const torch::Tensor& tensor) { } } +inline void write_tensor(RawInputSerializeContext& context, + const torch::Tensor& tensor) { + if (!tensor.defined()) { + uint64_t ndim = 0; + write_data(context.descriptor, ndim); + return; + } + + const uint64_t tensor_ndim = tensor.dim(); + write_data(context.descriptor, tensor_ndim); + for (int64_t i = 0; i < tensor.dim(); ++i) { + write_data(context.descriptor, static_cast(tensor.size(i))); + } + + const int8_t tensor_dtype = static_cast(tensor.scalar_type()); + write_data(context.descriptor, tensor_dtype); + + const uint64_t tensor_data_bytes = tensor.numel() * tensor.element_size(); + write_data(context.descriptor, tensor_data_bytes); + + if (tensor_data_bytes == 0) { + return; + } + + if (context.tensor_arena.ptr != nullptr) { + torch::Tensor contiguous_tensor = tensor.cpu().contiguous(); + write_bytes( + context.tensor_arena, contiguous_tensor.data_ptr(), tensor_data_bytes); + } else { + context.tensor_arena.size += tensor_data_bytes; + } + + const uint64_t aligned_bytes = + get_aligned_tensor_arena_bytes(tensor_data_bytes); + write_padding(context.tensor_arena, aligned_bytes - tensor_data_bytes); +} + inline void write_sampling_param(char*& buffer, const RequestSamplingParam& param) { char* ptr = buffer; @@ -405,6 +516,17 @@ inline void write_vector(char*& buffer, const std::vector& vec) { } } +template +inline void write_vector(RawInputSectionCursor& cursor, + const std::vector& vec) { + const uint64_t size = vec.size(); + write_data(cursor, size); + if (size > 0) { + const uint64_t bytes = size * type_size; + write_bytes(cursor, vec.data(), bytes); + } +} + template void write_vector_to_tensor(char*& buffer, const std::vector& vec) { // write ndim @@ -429,6 +551,30 @@ void write_vector_to_tensor(char*& buffer, const std::vector& vec) { buffer += data_bytes; } +template +void write_vector_to_tensor(RawInputSerializeContext& context, + const std::vector& vec) { + uint64_t ndim; + if (vec.empty()) { + ndim = 0; + write_data(context.descriptor, ndim); + return; + } + + ndim = 1; + write_data(context.descriptor, ndim); + write_data(context.descriptor, vec.size()); + + const int8_t tensor_dtype = static_cast(get_scalar_type()); + write_data(context.descriptor, tensor_dtype); + + const uint64_t data_bytes = vec.size() * type_size; + write_data(context.descriptor, data_bytes); + write_bytes(context.tensor_arena, vec.data(), data_bytes); + const uint64_t aligned_bytes = get_aligned_tensor_arena_bytes(data_bytes); + write_padding(context.tensor_arena, aligned_bytes - data_bytes); +} + template inline void write_2d_vector(char*& buffer, const std::vector>& vec2d) { @@ -438,6 +584,15 @@ inline void write_2d_vector(char*& buffer, } } +template +inline void write_2d_vector(RawInputSectionCursor& cursor, + const std::vector>& vec2d) { + write_data(cursor, static_cast(vec2d.size())); + for (const auto& vec : vec2d) { + write_vector(cursor, vec); + } +} + template void write_2d_vector_to_tensor(char*& buffer, const std::vector>& vec2d) { @@ -467,6 +622,36 @@ void write_2d_vector_to_tensor(char*& buffer, } } +template +void write_2d_vector_to_tensor(RawInputSerializeContext& context, + const std::vector>& vec2d) { + uint64_t ndim; + if (vec2d.empty() || vec2d[0].empty()) { + ndim = 0; + write_data(context.descriptor, ndim); + return; + } + + ndim = 2; + write_data(context.descriptor, ndim); + write_data(context.descriptor, vec2d.size()); + write_data(context.descriptor, vec2d[0].size()); + + const int8_t tensor_dtype = static_cast(get_scalar_type()); + write_data(context.descriptor, tensor_dtype); + + const uint64_t per_data_bytes = vec2d[0].size() * type_size; + const uint64_t data_bytes = vec2d.size() * per_data_bytes; + write_data(context.descriptor, data_bytes); + + for (const auto& vec : vec2d) { + write_bytes(context.tensor_arena, vec.data(), per_data_bytes); + } + + const uint64_t aligned_bytes = get_aligned_tensor_arena_bytes(data_bytes); + write_padding(context.tensor_arena, aligned_bytes - data_bytes); +} + inline void write_instance_info(char*& buffer, const InstanceInfo& info) { write_string(buffer, info.name); write_string(buffer, info.rpc_address); @@ -493,6 +678,32 @@ inline void write_instance_info(char*& buffer, const InstanceInfo& info) { } } +inline void write_instance_info(RawInputSerializeContext& context, + const InstanceInfo& info) { + write_string(context.descriptor, info.name); + write_string(context.descriptor, info.rpc_address); + write_string(context.descriptor, info.type); + + write_vector(context.descriptor, info.cluster_ids); + + write_data(context.descriptor, static_cast(info.addrs.size())); + for (const auto& addr : info.addrs) { + write_string(context.descriptor, addr); + } + + write_vector(context.descriptor, info.k_cache_ids); + write_vector(context.descriptor, info.v_cache_ids); + write_data(context.descriptor, info.dp_size); + + const uint64_t prof_size = info.ttft_profiling_data.size(); + write_data(context.descriptor, prof_size); + if (prof_size > 0) { + write_bytes(context.descriptor, + info.ttft_profiling_data.data(), + prof_size * sizeof(std::pair)); + } +} + inline void write_xtensor_layer_offsets( char*& buffer, const std::vector& offsets) { @@ -503,6 +714,16 @@ inline void write_xtensor_layer_offsets( } } +inline void write_xtensor_layer_offsets( + RawInputSerializeContext& context, + const std::vector& offsets) { + write_data(context.descriptor, static_cast(offsets.size())); + for (const auto& layer : offsets) { + write_vector(context.descriptor, layer.k_offsets); + write_vector(context.descriptor, layer.v_offsets); + } +} + inline void write_transfer_kv_info(char*& buffer, const TransferKVInfo& info) { write_string(buffer, info.request_id); write_vector(buffer, info.local_blocks_ids); @@ -512,12 +733,29 @@ inline void write_transfer_kv_info(char*& buffer, const TransferKVInfo& info) { write_xtensor_layer_offsets(buffer, info.dst_xtensor_layer_offsets); } +inline void write_transfer_kv_info(RawInputSerializeContext& context, + const TransferKVInfo& info) { + write_string(context.descriptor, info.request_id); + write_vector(context.descriptor, info.local_blocks_ids); + write_vector(context.descriptor, info.remote_blocks_ids); + write_data(context.descriptor, info.dp_rank); + write_instance_info(context, info.remote_instance_info); + write_xtensor_layer_offsets(context, info.dst_xtensor_layer_offsets); +} + inline void write_eplb_info(char*& buffer, const EplbInfo& info) { write_data(buffer, info.prepare_layer_id); write_vector(buffer, info.expert_ids); write_data(buffer, info.update_layer_id); } +inline void write_eplb_info(RawInputSerializeContext& context, + const EplbInfo& info) { + write_data(context.descriptor, info.prepare_layer_id); + write_vector(context.descriptor, info.expert_ids); + write_data(context.descriptor, info.update_layer_id); +} + inline void write_swap_blocks(char*& buffer, const std::vector& blocks) { write_data(buffer, (uint64_t)blocks.size()); @@ -528,6 +766,15 @@ inline void write_swap_blocks(char*& buffer, } } +inline void write_swap_blocks(RawInputSerializeContext& context, + const std::vector& blocks) { + write_data(context.descriptor, static_cast(blocks.size())); + for (const auto& b : blocks) { + write_data(context.descriptor, b.src_block_id); + write_data(context.descriptor, b.dst_block_id); + } +} + inline void write_vector_tensor(char*& buffer, const std::vector& tensor_vec) { int32_t tensor_num = tensor_vec.size(); @@ -561,6 +808,28 @@ inline void write_mm_dict(char*& buffer, const MMDict& mm_dict) { } } +inline void write_mm_dict(RawInputSerializeContext& context, + const MMDict& mm_dict) { + size_t size = mm_dict.size(); + write_data(context.descriptor, size); + int32_t tensor_num = 1; + for (auto& [mm_key, mm_value] : mm_dict) { + write_string(context.descriptor, mm_key); + if (std::holds_alternative(mm_value)) { + tensor_num = 1; + write_data(context.descriptor, tensor_num); + write_tensor(context, std::get(mm_value)); + } else if (std::holds_alternative>(mm_value)) { + auto& tensor_vec = std::get>(mm_value); + tensor_num = tensor_vec.size(); + write_data(context.descriptor, tensor_num); + for (const auto& tensor : tensor_vec) { + write_tensor(context, tensor); + } + } + } +} + inline void write_mm_item(char*& buffer, const MMDataItem& item) { write_data(buffer, item.type()); write_mm_dict(buffer, item.data()); @@ -576,6 +845,20 @@ inline void write_mm_item(char*& buffer, const MMDataItem& item) { write_data(buffer, state.prefix_cache().cached_token_num); } +inline void write_mm_item(RawInputSerializeContext& context, + const MMDataItem& item) { + write_data(context.descriptor, item.type()); + write_mm_dict(context, item.data()); + + const auto& state = item.state(); + write_data(context.descriptor, state.token_pos().offset); + write_data(context.descriptor, state.token_pos().length); + write_bytes(context.descriptor, + state.prefix_cache().key.data, + XXH3_128BITS_HASH_VALUE_LEN); + write_data(context.descriptor, state.prefix_cache().cached_token_num); +} + inline void write_mm_data_items(char*& buffer, const MMData& mm_data) { const auto& mm_items = mm_data.items(); write_data(buffer, mm_data.type()); @@ -585,12 +868,28 @@ inline void write_mm_data_items(char*& buffer, const MMData& mm_data) { } } +inline void write_mm_data_items(RawInputSerializeContext& context, + const MMData& mm_data) { + const auto& mm_items = mm_data.items(); + write_data(context.descriptor, mm_data.type()); + write_data(context.descriptor, mm_items.size()); + for (const auto& mm_item : mm_items) { + write_mm_item(context, mm_item); + } +} + inline void write_mm_data_dict(char*& buffer, const MMData& mm_data) { const auto& mm_dict = mm_data.items(); write_data(buffer, mm_data.type()); write_mm_dict(buffer, mm_dict); } +inline void write_mm_data_dict(RawInputSerializeContext& context, + const MMData& mm_data) { + write_data(context.descriptor, mm_data.type()); + write_mm_dict(context, mm_data.items()); +} + inline void write_mm_batch_data(char*& buffer, const MMBatchData& mm_data) { const auto& vec = mm_data.mm_data_vec(); write_data(buffer, vec.size()); @@ -599,24 +898,182 @@ inline void write_mm_batch_data(char*& buffer, const MMBatchData& mm_data) { vec.size() ? static_cast(vec[0].hold()) : 1; write_data(buffer, is_mm_item); std::function write_mm_data = - is_mm_item ? write_mm_data_items : write_mm_data_dict; + is_mm_item + ? static_cast(&write_mm_data_items) + : static_cast(&write_mm_data_dict); for (const auto& mm_data : vec) { write_mm_data(buffer, mm_data); } } +inline void write_mm_batch_data(RawInputSerializeContext& context, + const MMBatchData& mm_data) { + const auto& vec = mm_data.mm_data_vec(); + write_data(context.descriptor, vec.size()); + + uint8_t is_mm_item = + vec.empty() ? 1 : static_cast(vec[0].hold()); + write_data(context.descriptor, is_mm_item); + std::function write_mm_data = + is_mm_item + ? static_cast( + &write_mm_data_items) + : static_cast( + &write_mm_data_dict); + for (const auto& current_mm_data : vec) { + write_mm_data(context, current_mm_data); + } +} + +// write dit data +inline void write_dit_generation_params(char*& buffer, + const DiTGenerationParams& params) { + write_data(buffer, params.width); + write_data(buffer, params.height); + write_data(buffer, params.num_inference_steps); + write_data(buffer, params.true_cfg_scale); + write_data(buffer, params.guidance_scale); + write_data(buffer, params.num_images_per_prompt); + write_data(buffer, params.seed); + write_data(buffer, params.max_sequence_length); + write_data(buffer, params.strength); + write_data(buffer, params.enable_cfg_renorm); + write_data(buffer, params.cfg_renorm_min); +} + +inline void write_dit_generation_params(RawInputSerializeContext& context, + const DiTGenerationParams& params) { + write_data(context.descriptor, params.width); + write_data(context.descriptor, params.height); + write_data(context.descriptor, params.num_inference_steps); + write_data(context.descriptor, params.true_cfg_scale); + write_data(context.descriptor, params.guidance_scale); + write_data(context.descriptor, params.num_images_per_prompt); + write_data(context.descriptor, params.seed); + write_data(context.descriptor, params.max_sequence_length); + write_data(context.descriptor, params.strength); + write_data(context.descriptor, params.enable_cfg_renorm); + write_data(context.descriptor, params.cfg_renorm_min); +} + +inline void write_dit_forward_input(char*& buffer, + const DiTForwardInput& input) { + write_data(buffer, input.batch_size); + + write_string_vector(buffer, input.prompts); + write_string_vector(buffer, input.prompts_2); + write_string_vector(buffer, input.negative_prompts); + write_string_vector(buffer, input.negative_prompts_2); + + write_tensor(buffer, input.images); + write_tensor(buffer, input.condition_images); + write_tensor(buffer, input.mask_images); + write_tensor(buffer, input.control_image); + write_tensor(buffer, input.masked_image_latents); + write_tensor(buffer, input.prompt_embeds); + write_tensor(buffer, input.pooled_prompt_embeds); + write_tensor(buffer, input.negative_prompt_embeds); + write_tensor(buffer, input.negative_pooled_prompt_embeds); + write_tensor(buffer, input.latents); + + write_dit_generation_params(buffer, input.generation_params); +} + +inline void write_dit_forward_input(RawInputSerializeContext& context, + const DiTForwardInput& input) { + write_data(context.descriptor, input.batch_size); + + write_string_vector(context.descriptor, input.prompts); + write_string_vector(context.descriptor, input.prompts_2); + write_string_vector(context.descriptor, input.negative_prompts); + write_string_vector(context.descriptor, input.negative_prompts_2); + + write_tensor(context, input.images); + write_tensor(context, input.condition_images); + write_tensor(context, input.mask_images); + write_tensor(context, input.control_image); + write_tensor(context, input.masked_image_latents); + write_tensor(context, input.prompt_embeds); + write_tensor(context, input.pooled_prompt_embeds); + write_tensor(context, input.negative_prompt_embeds); + write_tensor(context, input.negative_pooled_prompt_embeds); + write_tensor(context, input.latents); + + write_dit_generation_params(context, input.generation_params); +} + +inline void write_dit_forward_output(char*& buffer, + const DiTForwardOutput& output) { + write_data(buffer, static_cast(output.tensors.size())); + for (const auto& tensor : output.tensors) { + write_tensor(buffer, tensor); + } +} + inline void safe_advance_buffer(const char*& buffer, size_t offset) { if (buffer != nullptr) { buffer += offset; } } +struct TensorMeta final { + std::vector shape; + torch::ScalarType dtype = torch::kFloat; + uint64_t data_bytes = 0; +}; + +struct DeviceBufferSession final { + torch::Tensor owner_buffer; + const char* device_cursor = nullptr; + bool active = false; + bool need_finalize_sync = false; +#if defined(USE_NPU) + std::optional> capture_lock_guard; +#elif defined(USE_CUDA) + std::optional> capture_lock_guard; +#endif +}; + +struct ReadContext final { + const char* descriptor_cursor; + const char* tensor_cursor; + DeviceBufferSession* device_session = nullptr; +}; + +inline uint64_t get_tensor_payload_bytes(const torch::Tensor& tensor) { + if (!tensor.defined()) { + return 0; + } + return tensor.numel() * tensor.element_size(); +} + +inline bool has_device_buffer(const ReadContext& context) { + return context.device_session != nullptr && context.device_session->active; +} + +inline void advance_descriptor_cursor(ReadContext& context, size_t offset) { + context.descriptor_cursor += offset; +} + +inline void advance_tensor_cursors(ReadContext& context, size_t offset) { + context.tensor_cursor += offset; + if (has_device_buffer(context)) { + safe_advance_buffer(context.device_session->device_cursor, offset); + } +} + template inline void read_data(const char*& buffer, T& data) { data = *reinterpret_cast(buffer); buffer += type_size; } +template +inline void read_data(ReadContext& context, T& data) { + data = *reinterpret_cast(context.descriptor_cursor); + advance_descriptor_cursor(context, type_size); +} + template inline void read_data(const char*& buffer, T& data, @@ -651,6 +1108,17 @@ inline void read_string(const char*& buffer, } } +inline void read_string(ReadContext& context, std::string& str) { + uint64_t len; + read_data(context, len); + if (len > 0) { + str.assign(context.descriptor_cursor, len); + advance_descriptor_cursor(context, len); + } else { + str.clear(); + } +} + inline void read_string_vector(const char*& buffer, std::vector& vec) { uint64_t size; @@ -672,6 +1140,16 @@ inline void read_string_vector(const char*& buffer, } } +inline void read_string_vector(ReadContext& context, + std::vector& vec) { + uint64_t size; + read_data(context, size); + vec.resize(size); + for (uint64_t i = 0; i < size; ++i) { + read_string(context, vec[i]); + } +} + inline void read_tensor(const char*& buffer, torch::Tensor& tensor) { // read ndim uint64_t ndim; @@ -703,43 +1181,130 @@ inline void read_tensor(const char*& buffer, torch::Tensor& tensor) { buffer += data_bytes; } -void read_tensor(const char*& buffer, - torch::Tensor& tensor, - const char*& device_buffer) { - // read ndim +inline TensorMeta read_tensor_meta(ReadContext& context) { + TensorMeta meta; + uint64_t ndim; + read_data(context, ndim); + if (ndim == 0) { + return meta; + } + + meta.shape.resize(ndim); + for (size_t i = 0; i < ndim; ++i) { + int64_t dim_size; + read_data(context, dim_size); + meta.shape[i] = static_cast(dim_size); + } + + int8_t tensor_dtype; + read_data(context, tensor_dtype); + meta.dtype = static_cast(tensor_dtype); + read_data(context, meta.data_bytes); + return meta; +} + +inline torch::Tensor materialize_tensor_from_current_cursor( + const TensorMeta& meta, + DeviceBufferSession& session, + Stream* stream) { + const char* device_buffer = session.device_cursor; +#if defined(USE_NPU) + return get_tensor_from_blob(meta.shape, meta.dtype, device_buffer); +#elif defined(USE_CUDA) + if (session.owner_buffer.defined() && + is_aligned_for_cuda_zero_copy(device_buffer)) { + return get_tensor_from_blob( + meta.shape, meta.dtype, device_buffer, session.owner_buffer); + } + + auto options = torch::TensorOptions().dtype(meta.dtype).device(torch::kCUDA); + auto tensor = torch::empty(meta.shape, options); + cudaError_t err; + if (stream != nullptr) { + err = cudaMemcpyAsync(tensor.data_ptr(), + device_buffer, + meta.data_bytes, + cudaMemcpyDeviceToDevice, + stream->get_stream()->stream()); + } else { + err = cudaMemcpy(tensor.data_ptr(), + device_buffer, + meta.data_bytes, + cudaMemcpyDeviceToDevice); + } + CHECK_EQ(err, cudaSuccess) + << "CUDA device buffer copy failed: " << cudaGetErrorString(err); + return tensor; +#else + LOG(FATAL) << "Unsupported device buffer backend"; +#endif +} + +inline void read_tensor(ReadContext& context, + torch::Tensor& tensor, + Stream* stream = nullptr, + bool force_host_materialize = false) { + const TensorMeta meta = read_tensor_meta(context); + if (meta.shape.empty()) { + return; + } + + if (!force_host_materialize && has_device_buffer(context)) { + tensor = materialize_tensor_from_current_cursor( + meta, *context.device_session, stream); + } else { + tensor = torch::from_blob( + const_cast(static_cast(context.tensor_cursor)), + meta.shape, + torch::TensorOptions() + .dtype(meta.dtype) + .device(torch::kCPU) + .pinned_memory(true)); + } + advance_tensor_cursors(context, + get_aligned_tensor_arena_bytes(meta.data_bytes)); +} + +inline void read_tensor(const char*& buffer, + torch::Tensor& tensor, + const char*& device_buffer, + Stream* stream = nullptr) { + TensorMeta meta; uint64_t ndim; read_data(buffer, ndim, device_buffer); if (ndim == 0) { return; } - // read shape - std::vector shape(ndim); + + meta.shape.resize(ndim); for (size_t i = 0; i < ndim; ++i) { int64_t dim_size; read_data(buffer, dim_size, device_buffer); - shape[i] = static_cast(dim_size); + meta.shape[i] = static_cast(dim_size); } - // read dtype + int8_t tensor_dtype; read_data(buffer, tensor_dtype, device_buffer); - auto dtype = static_cast(tensor_dtype); - // read data_bytes - uint64_t data_bytes; - read_data(buffer, data_bytes, device_buffer); + meta.dtype = static_cast(tensor_dtype); + read_data(buffer, meta.data_bytes, device_buffer); if (device_buffer != nullptr) { - tensor = get_tensor_from_blob(shape, dtype, device_buffer); + DeviceBufferSession device_session; + device_session.device_cursor = device_buffer; + device_session.active = true; + tensor = + materialize_tensor_from_current_cursor(meta, device_session, stream); + safe_advance_buffer(device_buffer, meta.data_bytes); } else { tensor = torch::from_blob(const_cast(static_cast(buffer)), - shape, + meta.shape, torch::TensorOptions() - .dtype(dtype) + .dtype(meta.dtype) .device(torch::kCPU) .pinned_memory(true)); } - buffer += data_bytes; - safe_advance_buffer(device_buffer, data_bytes); + buffer += meta.data_bytes; } template @@ -770,46 +1335,43 @@ inline void read_vector(const char*& buffer, } template -inline void read_tensor_and_vector(const char*& buffer, +inline void read_vector(ReadContext& context, std::vector& vec) { + uint64_t size; + read_data(context, size); + vec.resize(size); + if (size > 0) { + const size_t bytes = size * type_size; + std::memcpy(vec.data(), context.descriptor_cursor, bytes); + advance_descriptor_cursor(context, bytes); + } +} + +template +inline void read_tensor_and_vector(ReadContext& context, torch::Tensor& tensor, std::vector& vec, - const char*& device_buffer) { - // read ndim - uint64_t ndim; - read_data(buffer, ndim, device_buffer); - if (ndim == 0) { + Stream* stream = nullptr) { + const TensorMeta meta = read_tensor_meta(context); + if (meta.shape.empty()) { return; } - // read shape - std::vector shape(ndim); - for (size_t i = 0; i < ndim; ++i) { - int64_t dim_size; - read_data(buffer, dim_size, device_buffer); - shape[i] = static_cast(dim_size); - } - vec.resize(shape[0]); - // read dtype - int8_t tensor_dtype; - read_data(buffer, tensor_dtype, device_buffer); - auto dtype = static_cast(tensor_dtype); - // read data_bytes - uint64_t data_bytes; - read_data(buffer, data_bytes, device_buffer); - if (device_buffer != nullptr) { - tensor = get_tensor_from_blob(shape, dtype, device_buffer); + vec.resize(meta.shape[0]); + if (has_device_buffer(context)) { + tensor = materialize_tensor_from_current_cursor( + meta, *context.device_session, stream); } else { - tensor = - torch::from_blob(const_cast(static_cast(buffer)), - shape, - torch::TensorOptions() - .dtype(dtype) - .device(torch::kCPU) - .pinned_memory(true)); + tensor = torch::from_blob( + const_cast(static_cast(context.tensor_cursor)), + meta.shape, + torch::TensorOptions() + .dtype(meta.dtype) + .device(torch::kCPU) + .pinned_memory(true)); } - std::memcpy(vec.data(), buffer, data_bytes); - buffer += data_bytes; - safe_advance_buffer(device_buffer, data_bytes); + std::memcpy(vec.data(), context.tensor_cursor, meta.data_bytes); + advance_tensor_cursors(context, + get_aligned_tensor_arena_bytes(meta.data_bytes)); } template @@ -880,6 +1442,36 @@ inline void read_instance_info(const char*& buffer, InstanceInfo& info) { } } +inline void read_instance_info(ReadContext& context, InstanceInfo& info) { + read_string(context, info.name); + read_string(context, info.rpc_address); + read_string(context, info.type); + + read_vector(context, info.cluster_ids); + + uint64_t addr_count; + read_data(context, addr_count); + info.addrs.resize(addr_count); + for (auto& addr : info.addrs) { + read_string(context, addr); + } + + read_vector(context, info.k_cache_ids); + read_vector(context, info.v_cache_ids); + read_data(context, info.dp_size); + + uint64_t prof_size; + read_data(context, prof_size); + info.ttft_profiling_data.resize(prof_size); + if (prof_size > 0) { + std::memcpy(info.ttft_profiling_data.data(), + context.descriptor_cursor, + prof_size * sizeof(std::pair)); + advance_descriptor_cursor(context, + prof_size * sizeof(std::pair)); + } +} + inline void read_xtensor_layer_offsets( const char*& buffer, std::vector& offsets) { @@ -892,6 +1484,18 @@ inline void read_xtensor_layer_offsets( } } +inline void read_xtensor_layer_offsets( + ReadContext& context, + std::vector& offsets) { + uint64_t num_layers; + read_data(context, num_layers); + offsets.resize(num_layers); + for (auto& layer : offsets) { + read_vector(context, layer.k_offsets); + read_vector(context, layer.v_offsets); + } +} + inline void read_transfer_kv_info(const char*& buffer, TransferKVInfo& info) { read_string(buffer, info.request_id); read_vector(buffer, info.local_blocks_ids); @@ -901,12 +1505,27 @@ inline void read_transfer_kv_info(const char*& buffer, TransferKVInfo& info) { read_xtensor_layer_offsets(buffer, info.dst_xtensor_layer_offsets); } +inline void read_transfer_kv_info(ReadContext& context, TransferKVInfo& info) { + read_string(context, info.request_id); + read_vector(context, info.local_blocks_ids); + read_vector(context, info.remote_blocks_ids); + read_data(context, info.dp_rank); + read_instance_info(context, info.remote_instance_info); + read_xtensor_layer_offsets(context, info.dst_xtensor_layer_offsets); +} + inline void read_eplb_info(const char*& buffer, EplbInfo& info) { read_data(buffer, info.prepare_layer_id); read_vector(buffer, info.expert_ids); read_data(buffer, info.update_layer_id); } +inline void read_eplb_info(ReadContext& context, EplbInfo& info) { + read_data(context, info.prepare_layer_id); + read_vector(context, info.expert_ids); + read_data(context, info.update_layer_id); +} + inline void read_swap_blocks(const char*& buffer, std::vector& blocks, const char*& device_buffer) { @@ -923,6 +1542,21 @@ inline void read_swap_blocks(const char*& buffer, } } +inline void read_swap_blocks(ReadContext& context, + std::vector& blocks) { + uint64_t size; + read_data(context, size); + blocks.reserve(size); + + int32_t src_block_id; + int32_t dst_block_id; + for (uint64_t i = 0; i < size; ++i) { + read_data(context, src_block_id); + read_data(context, dst_block_id); + blocks.emplace_back(src_block_id, dst_block_id); + } +} + inline void read_vector_tensor(const char*& buffer, std::vector& tensor_vec) { int32_t tensor_num; @@ -957,6 +1591,28 @@ inline void read_mm_dict(const char*& buffer, } } +inline void read_mm_dict(ReadContext& context, MMDict& mm_dict) { + size_t size; + read_data(context, size); + int32_t tensor_num; + while (size--) { + std::string mm_key; + read_string(context, mm_key); + read_data(context, tensor_num); + if (tensor_num == 1) { + torch::Tensor tensor; + read_tensor(context, tensor); + mm_dict[mm_key] = tensor; + } else { + std::vector tensor_vec(tensor_num); + for (int32_t i = 0; i < tensor_num; ++i) { + read_tensor(context, tensor_vec[i]); + } + mm_dict[mm_key] = tensor_vec; + } + } +} + inline void read_mm_item(const char*& buffer, MMDataItem& item, const char*& device_buffer) { @@ -982,6 +1638,25 @@ inline void read_mm_item(const char*& buffer, buffer, state.mutable_prefix_cache().cached_token_num, device_buffer); } +inline void read_mm_item(ReadContext& context, MMDataItem& item) { + uint32_t type; + read_data(context, type); + MMDict dict; + read_mm_dict(context, dict); + auto mm_type_value = static_cast(type); + item = std::move(MMDataItem(mm_type_value, dict)); + auto& state = item.mutable_state(); + + read_data(context, state.mutable_token_pos().offset); + read_data(context, state.mutable_token_pos().length); + + std::memcpy(state.mutable_prefix_cache().key.data, + context.descriptor_cursor, + XXH3_128BITS_HASH_VALUE_LEN); + advance_descriptor_cursor(context, XXH3_128BITS_HASH_VALUE_LEN); + read_data(context, state.mutable_prefix_cache().cached_token_num); +} + inline void read_mm_data_dict(const char*& buffer, MMData& mm_data, const char*& device_buffer) { @@ -993,6 +1668,15 @@ inline void read_mm_data_dict(const char*& buffer, mm_data = MMData(ty, mm_dict); } +inline void read_mm_data_dict(ReadContext& context, MMData& mm_data) { + uint32_t mm_type; + read_data(context, mm_type); + MMDict mm_dict; + read_mm_dict(context, mm_dict); + MMType ty{static_cast(mm_type)}; + mm_data = MMData(ty, mm_dict); +} + inline void read_mm_data_items(const char*& buffer, MMData& mm_data, const char*& device_buffer) { @@ -1002,10 +1686,24 @@ inline void read_mm_data_items(const char*& buffer, read_data(buffer, mm_items_num, device_buffer); MMItemVec mm_items; mm_items.reserve(mm_items_num); - MMDataItem mm_item(MMType::NONE); for (size_t idx = 0; idx < mm_items_num; ++idx) { - read_mm_item(buffer, mm_item, device_buffer); - mm_items.push_back(std::move(mm_item)); + mm_items.emplace_back(MMType::NONE); + read_mm_item(buffer, mm_items.back(), device_buffer); + } + MMType ty{static_cast(mm_type)}; + mm_data = MMData(ty, std::move(mm_items)); +} + +inline void read_mm_data_items(ReadContext& context, MMData& mm_data) { + uint32_t mm_type; + read_data(context, mm_type); + size_t mm_items_num; + read_data(context, mm_items_num); + MMItemVec mm_items; + mm_items.reserve(mm_items_num); + for (size_t idx = 0; idx < mm_items_num; ++idx) { + mm_items.emplace_back(MMType::NONE); + read_mm_item(context, mm_items.back()); } MMType ty{static_cast(mm_type)}; mm_data = MMData(ty, std::move(mm_items)); @@ -1014,6 +1712,8 @@ inline void read_mm_data_items(const char*& buffer, inline void read_mm_batch_data(const char*& buffer, MMBatchData& batch_mm_data, const char*& device_buffer) { + using ReadMmDataFn = void (*)(const char*&, MMData&, const char*&); + std::vector vec; size_t mm_data_num; @@ -1021,184 +1721,404 @@ inline void read_mm_batch_data(const char*& buffer, uint8_t is_mm_item; read_data(buffer, is_mm_item, device_buffer); vec.reserve(mm_data_num); - MMData mm_data; - std::function read_mm_data = - is_mm_item ? read_mm_data_items : read_mm_data_dict; + ReadMmDataFn read_mm_data = + is_mm_item ? static_cast(&read_mm_data_items) + : static_cast(&read_mm_data_dict); for (size_t i = 0; i < mm_data_num; ++i) { - read_mm_data(buffer, mm_data, device_buffer); - vec.push_back(std::move(mm_data)); + vec.emplace_back(); + read_mm_data(buffer, vec.back(), device_buffer); } batch_mm_data.batch(std::move(vec)); } -inline void deserialize_raw_forward_input(const char*& buffer, - const uint64_t buffer_size, - ForwardInput& forward_input, - const torch::Device& device, - Stream* stream) { - const char* device_buffer = nullptr; +inline void read_mm_batch_data(ReadContext& context, + MMBatchData& batch_mm_data) { + using ReadMmDataFn = void (*)(ReadContext&, MMData&); + + std::vector vec; + + size_t mm_data_num; + read_data(context, mm_data_num); + uint8_t is_mm_item; + read_data(context, is_mm_item); + vec.reserve(mm_data_num); + ReadMmDataFn read_mm_data = + is_mm_item ? static_cast(&read_mm_data_items) + : static_cast(&read_mm_data_dict); + for (size_t i = 0; i < mm_data_num; ++i) { + vec.emplace_back(); + read_mm_data(context, vec.back()); + } + + batch_mm_data.batch(std::move(vec)); +} + +// read dit data +inline void read_dit_generation_params(const char*& buffer, + DiTGenerationParams& params) { + read_data(buffer, params.width); + read_data(buffer, params.height); + read_data(buffer, params.num_inference_steps); + read_data(buffer, params.true_cfg_scale); + read_data(buffer, params.guidance_scale); + read_data(buffer, params.num_images_per_prompt); + read_data(buffer, params.seed); + read_data(buffer, params.max_sequence_length); + read_data(buffer, params.strength); + read_data(buffer, params.enable_cfg_renorm); + read_data(buffer, params.cfg_renorm_min); +} + +inline void read_dit_generation_params(ReadContext& context, + DiTGenerationParams& params) { + read_data(context, params.width); + read_data(context, params.height); + read_data(context, params.num_inference_steps); + read_data(context, params.true_cfg_scale); + read_data(context, params.guidance_scale); + read_data(context, params.num_images_per_prompt); + read_data(context, params.seed); + read_data(context, params.max_sequence_length); + read_data(context, params.strength); + read_data(context, params.enable_cfg_renorm); + read_data(context, params.cfg_renorm_min); +} + +inline void read_dit_forward_input(const char*& buffer, + DiTForwardInput& input) { + read_data(buffer, input.batch_size); + + read_string_vector(buffer, input.prompts); + read_string_vector(buffer, input.prompts_2); + read_string_vector(buffer, input.negative_prompts); + read_string_vector(buffer, input.negative_prompts_2); + + read_tensor(buffer, input.images); + read_tensor(buffer, input.condition_images); + read_tensor(buffer, input.mask_images); + read_tensor(buffer, input.control_image); + read_tensor(buffer, input.masked_image_latents); + read_tensor(buffer, input.prompt_embeds); + read_tensor(buffer, input.pooled_prompt_embeds); + read_tensor(buffer, input.negative_prompt_embeds); + read_tensor(buffer, input.negative_pooled_prompt_embeds); + read_tensor(buffer, input.latents); + + read_dit_generation_params(buffer, input.generation_params); +} + +inline void read_dit_forward_input(ReadContext& context, + DiTForwardInput& input) { + read_data(context, input.batch_size); + + read_string_vector(context, input.prompts); + read_string_vector(context, input.prompts_2); + read_string_vector(context, input.negative_prompts); + read_string_vector(context, input.negative_prompts_2); + + read_tensor(context, + input.images, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.condition_images, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.mask_images, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.control_image, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.masked_image_latents, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.prompt_embeds, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.pooled_prompt_embeds, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.negative_prompt_embeds, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.negative_pooled_prompt_embeds, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input.latents, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + + read_dit_generation_params(context, input.generation_params); +} + +inline void read_dit_forward_output(const char*& buffer, + DiTForwardOutput& output) { + uint64_t size; + read_data(buffer, size); + output.tensors.resize(size); + for (auto& tensor : output.tensors) { + read_tensor(buffer, tensor); + } +} + +inline void initialize_device_buffer_session(ReadContext& context, + ForwardInput& forward_input, + const torch::Device& device, + const uint64_t buffer_size, + Stream* stream) { + if (context.device_session == nullptr) { + return; + } + +#if defined(USE_NPU) || defined(USE_CUDA) + if (!FLAGS_use_contiguous_input_buffer) { + return; + } + + auto& session = *context.device_session; + torch::Tensor host_input_buffer = + torch::from_blob(const_cast(context.tensor_cursor), + {static_cast(buffer_size)}, + torch::TensorOptions() + .dtype(torch::kUInt8) + .device(torch::kCPU) + .pinned_memory(/*pinned_memory=*/true)); + + auto device_options = + torch::TensorOptions().dtype(torch::kUInt8).device(device); + + if (stream != nullptr) { #if defined(USE_NPU) - std::optional> capture_lock_guard; - torch::Tensor host_input_buffer; - if (FLAGS_use_contiguous_input_buffer) { - host_input_buffer = torch::from_blob(const_cast(buffer), - {static_cast(buffer_size)}, - torch::TensorOptions() - .dtype(torch::kUInt8) - .device(torch::kCPU) - .pinned_memory(true)); - - auto device_options = - torch::TensorOptions().dtype(torch::kUInt8).device(device); - - if (stream != nullptr) { - auto& capture_lock = - ::xllm::npu::DeviceCaptureLock::get_instance().get_lock( + auto& capture_lock = + ::xllm::npu::DeviceCaptureLock::get_instance().get_lock(device.index()); + session.capture_lock_guard.emplace(capture_lock); +#elif defined(USE_CUDA) + if (FLAGS_enable_graph) { + auto& replay_lock = + ::xllm::cuda::DeviceCaptureLock::get_instance().get_read_lock( device.index()); - capture_lock_guard.emplace(capture_lock); - c10::StreamGuard stream_guard = stream->set_stream_guard(); - forward_input.device_input_buffer = - safe_to(host_input_buffer, device_options, true); - } else { - forward_input.device_input_buffer = - safe_to(host_input_buffer, device_options); + session.capture_lock_guard.emplace(replay_lock); } +#endif + c10::StreamGuard stream_guard = stream->set_stream_guard(); + forward_input.device_input_buffer = + safe_to(host_input_buffer, device_options, /*non_blocking=*/true); + } else { + forward_input.device_input_buffer = + safe_to(host_input_buffer, device_options); + } - device_buffer = (char*)forward_input.device_input_buffer.data_ptr(); + session.owner_buffer = forward_input.device_input_buffer; + session.device_cursor = + static_cast(forward_input.device_input_buffer.data_ptr()); + session.active = session.device_cursor != nullptr; + session.need_finalize_sync = session.active && stream != nullptr; +#else + (void)context; + (void)forward_input; + (void)device; + (void)buffer_size; + (void)stream; +#endif +} + +inline void finalize_device_buffer_session(DeviceBufferSession& session, + Stream* stream) { +#if defined(USE_NPU) || defined(USE_CUDA) + if (session.need_finalize_sync && stream != nullptr) { + stream->synchronize(); } +#else + (void)session; + (void)stream; #endif +} - read_tensor(buffer, forward_input.token_ids, device_buffer); - read_tensor(buffer, forward_input.positions, device_buffer); +inline void deserialize_raw_forward_input(const char*& buffer, + const uint64_t buffer_size, + ForwardInput& forward_input, + const torch::Device& device, + Stream* stream) { + RawInputLayoutHeader layout; + read_data(buffer, layout.descriptor_bytes); + read_data(buffer, layout.tensor_arena_bytes); + CHECK_GE(buffer_size, sizeof(RawInputLayoutHeader)) + << "raw input layout header overflow"; + CHECK_LE(layout.descriptor_bytes + layout.tensor_arena_bytes, + buffer_size - sizeof(RawInputLayoutHeader)) + << "raw input layout overflow"; + + DeviceBufferSession device_session; + ReadContext context{ + buffer, buffer + layout.descriptor_bytes, &device_session}; + initialize_device_buffer_session( + context, forward_input, device, layout.tensor_arena_bytes, stream); + + read_tensor(context, forward_input.token_ids, stream); + read_tensor(context, forward_input.positions, stream); // input_params auto& input_params = forward_input.input_params; int32_t batch_forward_type; - read_data(buffer, batch_forward_type, device_buffer); + read_data(context, batch_forward_type); input_params.batch_forward_type = BatchForwardType(batch_forward_type); - read_data(buffer, input_params.num_sequences, device_buffer); - read_data(buffer, input_params.kv_max_seq_len, device_buffer); - read_data(buffer, input_params.q_max_seq_len, device_buffer); - read_data(buffer, input_params.batch_id, device_buffer); - read_tensor_and_vector(buffer, - input_params.q_seq_lens, - input_params.q_seq_lens_vec, - device_buffer); - read_tensor_and_vector(buffer, - input_params.kv_seq_lens, - input_params.kv_seq_lens_vec, - device_buffer); - read_tensor(buffer, input_params.paged_kv_indptr, device_buffer); - read_tensor(buffer, input_params.paged_kv_indices, device_buffer); - read_tensor(buffer, input_params.paged_kv_last_page_len, device_buffer); - read_tensor(buffer, input_params.new_cache_slot_offsets, device_buffer); - read_tensor(buffer, input_params.kv_cache_start_offsets, device_buffer); - read_tensor(buffer, input_params.input_embedding, device_buffer); - read_vector(buffer, input_params.dp_global_token_nums, device_buffer); - read_vector(buffer, input_params.dp_is_decode, device_buffer); - read_vector(buffer, input_params.embedding_ids, device_buffer); - read_string_vector(buffer, input_params.request_ids, device_buffer); - read_vector(buffer, input_params.extra_token_ids, device_buffer); - read_swap_blocks(buffer, input_params.swap_blocks, device_buffer); - read_tensor(buffer, input_params.src_block_indices, device_buffer); - read_tensor(buffer, input_params.dst_block_indices, device_buffer); - read_tensor(buffer, input_params.cum_sum, device_buffer); - read_mm_batch_data(buffer, input_params.mm_data, device_buffer); - read_tensor_and_vector(buffer, + read_data(context, input_params.num_sequences); + read_data(context, input_params.kv_max_seq_len); + read_data(context, input_params.q_max_seq_len); + read_data(context, input_params.batch_id); + read_tensor_and_vector( + context, input_params.q_seq_lens, input_params.q_seq_lens_vec, stream); +#if !defined(USE_CUDA) + if (!input_params.q_seq_lens_vec.empty()) { + std::vector cu_lens(input_params.q_seq_lens_vec.size()); + std::partial_sum(input_params.q_seq_lens_vec.begin(), + input_params.q_seq_lens_vec.end(), + cu_lens.begin()); + input_params.q_cu_seq_lens = torch::tensor(cu_lens, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + } +#endif + read_tensor_and_vector( + context, input_params.kv_seq_lens, input_params.kv_seq_lens_vec, stream); + read_tensor(context, input_params.paged_kv_indptr, stream); + read_tensor(context, input_params.paged_kv_indices, stream); + read_tensor(context, input_params.paged_kv_last_page_len, stream); + read_tensor(context, input_params.new_cache_slot_offsets, stream); + read_tensor(context, input_params.kv_cache_start_offsets, stream); + read_tensor(context, input_params.input_embedding, stream); + read_vector(context, input_params.dp_global_token_nums); + read_vector(context, input_params.dp_is_decode); + read_vector(context, input_params.embedding_ids); + read_string_vector(context, input_params.request_ids); + read_vector(context, input_params.extra_token_ids); + read_swap_blocks(context, input_params.swap_blocks); + read_tensor(context, input_params.src_block_indices, stream); + read_tensor(context, input_params.dst_block_indices, stream); + read_tensor(context, input_params.cum_sum, stream); + read_mm_batch_data(context, input_params.mm_data); + read_tensor_and_vector(context, input_params.kv_cache_tokens_nums, input_params.kv_cache_tokens_nums_host, - device_buffer); + stream); // sampling_params uint64_t selected_token_idxes_size; - read_data(buffer, selected_token_idxes_size, device_buffer); + read_data(context, selected_token_idxes_size); if (selected_token_idxes_size > 0) { auto& sampling_params = forward_input.sampling_params; - read_tensor(buffer, sampling_params.selected_token_idxes, device_buffer); - read_tensor(buffer, sampling_params.frequency_penalties, device_buffer); - read_tensor(buffer, sampling_params.presence_penalties, device_buffer); - read_tensor(buffer, sampling_params.repetition_penalties, device_buffer); - read_tensor(buffer, sampling_params.temperatures, device_buffer); - read_tensor(buffer, sampling_params.top_p, device_buffer); - read_tensor(buffer, sampling_params.top_k, device_buffer); - read_tensor(buffer, sampling_params.unique_token_ids, device_buffer); - read_tensor(buffer, sampling_params.unique_token_counts, device_buffer); - read_tensor(buffer, sampling_params.unique_token_ids_lens, device_buffer); - read_tensor(buffer, sampling_params.sample_idxes, device_buffer); - read_tensor(buffer, sampling_params.do_sample, device_buffer); - read_data(buffer, sampling_params.all_random_sample, device_buffer); - read_data(buffer, sampling_params.all_greedy_sample, device_buffer); - read_data(buffer, sampling_params.logprobs, device_buffer); - read_data(buffer, sampling_params.is_embeddings, device_buffer); - read_data(buffer, sampling_params.max_top_logprobs, device_buffer); - read_data(buffer, sampling_params.use_beam_search, device_buffer); + read_tensor(context, sampling_params.selected_token_idxes, stream); + read_tensor(context, sampling_params.frequency_penalties, stream); + read_tensor(context, sampling_params.presence_penalties, stream); + read_tensor(context, sampling_params.repetition_penalties, stream); + read_tensor(context, sampling_params.temperatures, stream); + read_tensor(context, sampling_params.top_p, stream); + read_tensor(context, sampling_params.top_k, stream); + read_tensor(context, sampling_params.unique_token_ids, stream); + read_tensor(context, sampling_params.unique_token_counts, stream); + read_tensor(context, sampling_params.unique_token_ids_lens, stream); + read_tensor(context, sampling_params.sample_idxes, stream); + read_tensor(context, sampling_params.do_sample, stream); + read_data(context, sampling_params.all_random_sample); + read_data(context, sampling_params.all_greedy_sample); + read_data(context, sampling_params.logprobs); + read_data(context, sampling_params.is_embeddings); + read_data(context, sampling_params.max_top_logprobs); + read_data(context, sampling_params.use_beam_search); } // acc_logprob - read_tensor(buffer, forward_input.acc_logprob, device_buffer); + read_tensor(context, forward_input.acc_logprob, stream); - // All inputs below are host data, no need to handle device-side pointers - // transfer_kv_infos + // Keep transfer/eplb host-materialized, but continue advancing the + // device cursor when a contiguous device buffer is active. uint64_t transfer_count; - read_data(buffer, transfer_count); + read_data(context, transfer_count); forward_input.transfer_kv_infos.resize(transfer_count); for (auto& transfer : forward_input.transfer_kv_infos) { - read_transfer_kv_info(buffer, transfer); + read_transfer_kv_info(context, transfer); } - // eplb_info - read_eplb_info(buffer, forward_input.eplb_info); + read_eplb_info(context, forward_input.eplb_info); - // TODO: Optimize this logic. Placing this tensor directly on contiguous - // device memory causes unknown errors. This needs to be optimized after the - // root cause is identified and the error is resolved. - read_tensor(buffer, input_params.new_cache_slots); - read_tensor(buffer, input_params.block_tables); - -#if defined(USE_NPU) - if (device_buffer != nullptr && stream != nullptr) { - stream->synchronize(); +#if defined(USE_CUDA) + if (has_device_buffer(context)) { + read_tensor(context, input_params.new_cache_slots, stream); + read_tensor(context, input_params.block_tables, stream); + } else { + read_tensor(context, + input_params.new_cache_slots, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input_params.block_tables, + /*stream=*/nullptr, + /*force_host_materialize=*/true); } +#else + read_tensor(context, + input_params.new_cache_slots, + /*stream=*/nullptr, + /*force_host_materialize=*/true); + read_tensor(context, + input_params.block_tables, + /*stream=*/nullptr, + /*force_host_materialize=*/true); #endif + + read_dit_forward_input(context, input_params.dit_forward_input); + + finalize_device_buffer_session(device_session, stream); + buffer = context.tensor_cursor; } -inline void serialize_raw_forward_input(const RawForwardInput& input, - char*& buffer) { - write_vector_to_tensor(buffer, input.flatten_tokens_vec); - if (input.flatten_positions_vec.size() > 0) { - write_vector_to_tensor(buffer, input.flatten_positions_vec); +inline void serialize_raw_forward_input_sections( + const RawForwardInput& input, + RawInputSerializeContext& context) { + write_vector_to_tensor(context, input.flatten_tokens_vec); + if (!input.flatten_positions_vec.empty()) { + write_vector_to_tensor(context, input.flatten_positions_vec); } else { - write_2d_vector_to_tensor(buffer, input.m_positions_vec); - } - - // ModelInputParams - write_data(buffer, input.batch_forward_type.value()); - write_data(buffer, input.num_sequences); - write_data(buffer, input.max_seq_len); - write_data(buffer, input.q_max_seq_len); - write_data(buffer, input.batch_id); - write_vector_to_tensor(buffer, input.q_seq_lens); - write_vector_to_tensor(buffer, input.seq_lens); - write_vector_to_tensor(buffer, input.paged_kv_indptr); - write_vector_to_tensor(buffer, input.paged_kv_indices); - write_vector_to_tensor(buffer, input.paged_kv_last_page_len); - write_vector_to_tensor(buffer, input.new_cache_slot_offsets); - write_vector_to_tensor(buffer, input.kv_cache_start_offsets); - write_2d_vector_to_tensor(buffer, input.embeddings); - write_vector(buffer, input.dp_global_token_nums); - write_vector(buffer, input.dp_is_decode); - write_vector(buffer, input.embedding_ids); - write_string_vector(buffer, input.request_ids); - write_vector(buffer, input.extra_token_ids); - write_swap_blocks(buffer, input.swap_blocks); - write_vector_to_tensor(buffer, input.src_block_indices); - write_vector_to_tensor(buffer, input.dst_block_indices); - write_vector_to_tensor(buffer, input.cum_sum); - write_mm_batch_data(buffer, input.mm_data); - write_vector_to_tensor(buffer, input.kv_cache_tokens_nums); - - // SamplingParameters - write_data(buffer, input.selected_token_idxes.size()); - if (input.selected_token_idxes.size() > 0) { + write_2d_vector_to_tensor(context, input.m_positions_vec); + } + + write_data(context.descriptor, input.batch_forward_type.value()); + write_data(context.descriptor, input.num_sequences); + write_data(context.descriptor, input.max_seq_len); + write_data(context.descriptor, input.q_max_seq_len); + write_data(context.descriptor, input.batch_id); + write_vector_to_tensor(context, input.q_seq_lens); + write_vector_to_tensor(context, input.seq_lens); + write_vector_to_tensor(context, input.paged_kv_indptr); + write_vector_to_tensor(context, input.paged_kv_indices); + write_vector_to_tensor(context, input.paged_kv_last_page_len); + write_vector_to_tensor(context, input.new_cache_slot_offsets); + write_vector_to_tensor(context, input.kv_cache_start_offsets); + write_2d_vector_to_tensor(context, input.embeddings); + write_vector(context.descriptor, input.dp_global_token_nums); + write_vector(context.descriptor, input.dp_is_decode); + write_vector(context.descriptor, input.embedding_ids); + write_string_vector(context.descriptor, input.request_ids); + write_vector(context.descriptor, input.extra_token_ids); + write_swap_blocks(context, input.swap_blocks); + write_vector_to_tensor(context, input.src_block_indices); + write_vector_to_tensor(context, input.dst_block_indices); + write_vector_to_tensor(context, input.cum_sum); + write_mm_batch_data(context, input.mm_data); + write_vector_to_tensor(context, input.kv_cache_tokens_nums); + + write_data(context.descriptor, input.selected_token_idxes.size()); + if (!input.selected_token_idxes.empty()) { SamplingParameters sampling_params; sampling_params.init(input.sampling_params, input.selected_token_idxes, @@ -1207,41 +2127,82 @@ inline void serialize_raw_forward_input(const RawForwardInput& input, input.unique_token_counts_vec, input.unique_token_lens_vec); - write_tensor(buffer, sampling_params.selected_token_idxes); - write_tensor(buffer, sampling_params.frequency_penalties); - write_tensor(buffer, sampling_params.presence_penalties); - write_tensor(buffer, sampling_params.repetition_penalties); - write_tensor(buffer, sampling_params.temperatures); - write_tensor(buffer, sampling_params.top_p); - write_tensor(buffer, sampling_params.top_k); - write_tensor(buffer, sampling_params.unique_token_ids); - write_tensor(buffer, sampling_params.unique_token_counts); - write_tensor(buffer, sampling_params.unique_token_ids_lens); - write_tensor(buffer, sampling_params.sample_idxes); - write_tensor(buffer, sampling_params.do_sample); - write_data(buffer, sampling_params.all_random_sample); - write_data(buffer, sampling_params.all_greedy_sample); - write_data(buffer, sampling_params.logprobs); - write_data(buffer, sampling_params.is_embeddings); - write_data(buffer, sampling_params.max_top_logprobs); - write_data(buffer, sampling_params.use_beam_search); + write_tensor(context, sampling_params.selected_token_idxes); + write_tensor(context, sampling_params.frequency_penalties); + write_tensor(context, sampling_params.presence_penalties); + write_tensor(context, sampling_params.repetition_penalties); + write_tensor(context, sampling_params.temperatures); + write_tensor(context, sampling_params.top_p); + write_tensor(context, sampling_params.top_k); + write_tensor(context, sampling_params.unique_token_ids); + write_tensor(context, sampling_params.unique_token_counts); + write_tensor(context, sampling_params.unique_token_ids_lens); + write_tensor(context, sampling_params.sample_idxes); + write_tensor(context, sampling_params.do_sample); + write_data(context.descriptor, sampling_params.all_random_sample); + write_data(context.descriptor, sampling_params.all_greedy_sample); + write_data(context.descriptor, sampling_params.logprobs); + write_data(context.descriptor, sampling_params.is_embeddings); + write_data(context.descriptor, sampling_params.max_top_logprobs); + write_data(context.descriptor, sampling_params.use_beam_search); } - // acc_logprob - write_vector_to_tensor(buffer, input.acc_logprob_vec); - // transfer_kv_infos - write_data(buffer, (uint64_t)input.transfer_kv_infos.size()); - for (const auto& t : input.transfer_kv_infos) { - write_transfer_kv_info(buffer, t); + write_vector_to_tensor(context, input.acc_logprob_vec); + + write_data(context.descriptor, + static_cast(input.transfer_kv_infos.size())); + for (const auto& transfer : input.transfer_kv_infos) { + write_transfer_kv_info(context, transfer); } - // eplb_info - write_eplb_info(buffer, input.eplb_info); + write_eplb_info(context, input.eplb_info); + + write_vector_to_tensor(context, input.new_token_slot_ids); + write_2d_vector_to_tensor(context, input.block_tables_vec); + + write_dit_forward_input(context, input.dit_forward_input); +} - // TODO: Optimize this logic. Placing this tensor directly on contiguous - // device memory causes unknown errors. This needs to be optimized after the - // root cause is identified and the error is resolved. - write_vector_to_tensor(buffer, input.new_token_slot_ids); - write_2d_vector_to_tensor(buffer, input.block_tables_vec); +inline RawInputLayoutHeader calculate_raw_forward_input_layout( + const RawForwardInput& input) { + RawInputSerializeContext context; + serialize_raw_forward_input_sections(input, context); + return RawInputLayoutHeader{context.descriptor.size, + context.tensor_arena.size}; +} + +inline uint64_t get_raw_forward_input_layout_size( + const RawInputLayoutHeader& layout) { + return sizeof(RawInputLayoutHeader) + layout.descriptor_bytes + + layout.tensor_arena_bytes; +} + +size_t calculate_raw_forward_input_size(const RawForwardInput& input) { + const RawInputLayoutHeader layout = calculate_raw_forward_input_layout(input); + return get_raw_forward_input_layout_size(layout); +} + +inline void serialize_raw_forward_input(const RawForwardInput& input, + const RawInputLayoutHeader& layout, + char*& buffer) { + write_data(buffer, layout.descriptor_bytes); + write_data(buffer, layout.tensor_arena_bytes); + + RawInputSerializeContext context{{buffer, 0}, + {buffer + layout.descriptor_bytes, 0}}; + serialize_raw_forward_input_sections(input, context); + + CHECK_EQ(context.descriptor.size, layout.descriptor_bytes) + << "raw input descriptor size mismatch"; + CHECK_EQ(context.tensor_arena.size, layout.tensor_arena_bytes) + << "raw input tensor arena size mismatch"; + + buffer += layout.descriptor_bytes + layout.tensor_arena_bytes; +} + +inline void serialize_raw_forward_input(const RawForwardInput& input, + char*& buffer) { + const RawInputLayoutHeader layout = calculate_raw_forward_input_layout(input); + serialize_raw_forward_input(input, layout, buffer); } size_t calculate_raw_token_size(const RawToken& token) { @@ -1280,6 +2241,8 @@ size_t calculate_raw_forward_output_size(const RawForwardOutput& output) { size += get_vector_size(output.out_tokens); size += get_vector_size(output.out_logprobs); size += type_size; // prepared_layer_id + // dit output data + size += get_dit_forward_output_size(output.dit_forward_output); return size; } @@ -1345,6 +2308,8 @@ void deserialize_raw_forward_output(const char* buffer, read_data(buffer, output.prepared_layer_id); read_vector_tensor(buffer, output.mm_embeddings); + // read dit output + read_dit_forward_output(buffer, output.dit_forward_output); } void serialize_raw_forward_output(const RawForwardOutput& output, @@ -1359,6 +2324,8 @@ void serialize_raw_forward_output(const RawForwardOutput& output, write_data(buffer, output.prepared_layer_id); write_vector_tensor(buffer, output.mm_embeddings); + // write dit output + write_dit_forward_output(buffer, output.dit_forward_output); } void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, @@ -1421,6 +2388,10 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, std::move(raw_input.kv_cache_start_offsets), tensor_options); input_params.mm_data = std::move(raw_input.mm_data); + + // dit input data + input_params.dit_forward_input = std::move(raw_input.dit_forward_input); + if (!raw_input.selected_token_idxes.empty()) { util::pad_2d_vector(raw_input.unique_token_ids_vec, 0); util::pad_2d_vector(raw_input.unique_token_counts_vec, 0); @@ -1447,6 +2418,7 @@ void convert_tensor_to_raw_output( const torch::Tensor& top_logprobs, const torch::Tensor& embeddings, const std::vector& mm_embeddings, + const std::vector& dit_images, const torch::Tensor& expert_load_data, int32_t prepared_layer_id, const torch::Tensor& src_seq_idxes, @@ -1491,6 +2463,7 @@ void convert_tensor_to_raw_output( raw_output.outputs.reserve(num_seqs); raw_output.mm_embeddings = mm_embeddings; + raw_output.dit_forward_output.tensors = dit_images; for (int32_t output_idx = 0; output_idx < num_seqs; ++output_idx) { RawSampleOutput raw_sample_output; @@ -1604,8 +2577,10 @@ std::string ForwardSharedMemoryManager::create_unique_name( } bool ForwardSharedMemoryManager::raw_input_write(const RawForwardInput& input) { + const RawInputLayoutHeader layout = calculate_raw_forward_input_layout(input); + const uint64_t payload_size = get_raw_forward_input_layout_size(layout); uint64_t total_size = sizeof(ControlMetadata); - total_size += type_size + calculate_raw_forward_input_size(input); + total_size += type_size + payload_size; if (unlikely(total_size > size())) { LOG(ERROR) << "raw input size overflow, total_size: " << total_size << ", shm size: " << size(); @@ -1613,16 +2588,14 @@ bool ForwardSharedMemoryManager::raw_input_write(const RawForwardInput& input) { } char* data_ptr = static_cast(base_address()) + sizeof(ControlMetadata); - write_data(data_ptr, - total_size - sizeof(ControlMetadata) - type_size); - serialize_raw_forward_input(input, data_ptr); + write_data(data_ptr, payload_size); + serialize_raw_forward_input(input, layout, data_ptr); uint64_t real_size = (uint64_t)(data_ptr - static_cast(base_address())); CHECK(total_size == real_size) << "total_size != real_size."; std::atomic_thread_fence(std::memory_order_release); control_ptr_->version = ++last_version_; - return true; } @@ -1654,6 +2627,7 @@ bool ForwardSharedMemoryManager::raw_output_write( const torch::Tensor& top_logprobs, const torch::Tensor& embeddings, const std::vector& mm_embeddings, + const std::vector& dit_images, const torch::Tensor& expert_load_data, int32_t prepared_layer_id, const torch::Tensor& src_seq_idxes, @@ -1666,6 +2640,7 @@ bool ForwardSharedMemoryManager::raw_output_write( top_logprobs, embeddings, mm_embeddings, + dit_images, expert_load_data, prepared_layer_id, src_seq_idxes, @@ -1685,7 +2660,6 @@ bool ForwardSharedMemoryManager::raw_output_write( char* test = static_cast(base_address()) + sizeof(ControlMetadata); std::atomic_thread_fence(std::memory_order_release); control_ptr_->version = ++last_version_; - return true; } @@ -1710,4 +2684,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) { void ForwardSharedMemoryManager::clear() { std::memset(base_address(), 0, size()); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/runtime/forward_shared_memory_manager.h b/xllm/core/runtime/forward_shared_memory_manager.h index 487535e9b..583a1929b 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.h +++ b/xllm/core/runtime/forward_shared_memory_manager.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "dit_forward_params.h" #include "forward_params.h" #include "params_utils.h" #include "util/shared_memory_manager.h" @@ -110,6 +111,7 @@ class ForwardSharedMemoryManager : public SharedMemoryManager { const torch::Tensor& top_logprobs, const torch::Tensor& embeddings, const std::vector& mm_embeddings, + const std::vector& dit_images, const torch::Tensor& expert_load_data, int32_t prepared_layer_id, const torch::Tensor& src_seq_idxes, diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 8033067f4..eff440ccb 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -79,7 +79,15 @@ bool LLMWorkerImpl::init_model(ModelContext& context) { std::optional LLMWorkerImpl::step(const ForwardInput& input) { if (FLAGS_enable_manual_loader) { #if defined(USE_NPU) - SET_ATB_EXECUTE_STREAM(compute_stream_, device_, context_); + if (!enable_schedule_overlap() && options_.backend() == "llm") { + aclrtStream current_stream = + c10_npu::getCurrentNPUStream(device_.index()).stream(); + atb::Context* atb_context = + const_cast(context_.get_atb_context()); + atb_context->SetExecuteStream(current_stream); + } else { + SET_ATB_EXECUTE_STREAM(compute_stream_, device_, context_); + } #endif return step_internal(input); } @@ -101,6 +109,12 @@ std::optional LLMWorkerImpl::step_internal( std::shared_ptr layer_synchronizer = std::make_shared( context_.get_model_args().n_layers()); +#elif defined(USE_MLU) + std::shared_ptr layer_synchronizer = + std::make_shared( + context_.get_model_args().n_layers()); +#endif +#if defined(USE_NPU) || defined(USE_MLU) const_cast(&(input.input_params))->layer_synchronizer = layer_synchronizer; diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index 75f2e3af9..bd365015b 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -507,6 +507,16 @@ std::optional MTPWorkerImpl::run_validate( void MTPWorkerImpl::process_draft_sample_output(SampleOutput& sample_output) { if (sample_output.probs.defined()) { + CHECK(sample_output.next_tokens.defined()) + << "draft sample_output.next_tokens must be defined when probs exist"; + CHECK_EQ(sample_output.next_tokens.dim(), 1) + << "MTP draft cache expects next_tokens [batch], got " + << sample_output.next_tokens.sizes(); + CHECK(sample_output.probs.dim() == 1 || sample_output.probs.dim() == 2) + << "MTP draft cache expects probs [batch] or [batch,vocab], got " + << sample_output.probs.sizes(); + CHECK_EQ(sample_output.probs.size(0), sample_output.next_tokens.size(0)) + << "MTP draft cache probs/token batch mismatch"; // Cache always stores selected-only draft probs [batch_size] to reduce HBM. sample_output.probs = specBuilder::draftProbs::compress_for_cache( sample_output.probs, sample_output.next_tokens); diff --git a/xllm/core/runtime/mtp_worker_impl.h b/xllm/core/runtime/mtp_worker_impl.h index 647e5dfe2..3909ca63c 100644 --- a/xllm/core/runtime/mtp_worker_impl.h +++ b/xllm/core/runtime/mtp_worker_impl.h @@ -17,7 +17,7 @@ limitations under the License. #include "framework/kv_cache/embedding_cache.h" #if defined(USE_NPU) -#include "framework/kv_cache/spec_kv_cache_transfer.h" +#include "framework/kv_cache_transfer/spec_kv_cache_transfer.h" #endif #include "runtime/speculative_worker_impl.h" diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index 44d2f33ef..fb302e32f 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -111,6 +111,18 @@ struct Options { // Context parallelism size PROPERTY(int32_t, cp_size) = 1; + // tensor parallelism size + // Default set as 1 + PROPERTY(int32_t, tp_size) = 1; + + // sequence parallelism size + // Default set as 1 + PROPERTY(int32_t, sp_size) = 1; + + // classifier-free guidance parallelism size + // Default set as 1 + PROPERTY(int32_t, cfg_size) = 1; + // enable enable_schedule_overlap to improve runtime execution efficiency. PROPERTY(bool, enable_schedule_overlap) = true; diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 6b67acb8e..6031be87f 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -81,6 +81,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector(pb_forward_input->paged_kv_last_page_len().begin(), pb_forward_input->paged_kv_last_page_len().end()); std::vector> block_tables_vec; + block_tables_vec.reserve(pb_forward_input->block_tables_vec().size()); for (size_t i = 0; i < pb_forward_input->block_tables_vec().size(); ++i) { block_tables_vec.emplace_back(std::vector( pb_forward_input->block_tables_vec()[i].block_tables().begin(), @@ -99,6 +100,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, // aprint(sample_idxes, "sample_idxes", global_rank_); std::vector> unique_token_ids_vec; + unique_token_ids_vec.reserve(pb_forward_input->unique_token_ids_vec().size()); for (size_t i = 0; i < pb_forward_input->unique_token_ids_vec().size(); ++i) { unique_token_ids_vec.emplace_back(std::vector( pb_forward_input->unique_token_ids_vec()[i].unique_token_ids().begin(), @@ -107,6 +109,8 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, // global_rank_); } std::vector> unique_token_counts_vec; + unique_token_counts_vec.reserve( + pb_forward_input->unique_token_counts_vec().size()); for (size_t i = 0; i < pb_forward_input->unique_token_counts_vec().size(); ++i) { unique_token_counts_vec.emplace_back( @@ -136,6 +140,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, pb_forward_input->request_ids().end()); std::vector swap_blocks; + swap_blocks.reserve(pb_forward_input->swap_blocks().size()); for (size_t i = 0; i < pb_forward_input->swap_blocks().size(); ++i) { swap_blocks.emplace_back(pb_forward_input->swap_blocks()[i].src_block_id(), pb_forward_input->swap_blocks()[i].dst_block_id()); @@ -152,6 +157,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector sampling_params; std::vector tmp_sampling_params; + tmp_sampling_params.reserve(pb_forward_input->sampling_params().size()); for (auto sp : pb_forward_input->sampling_params()) { RequestSamplingParam tmp; tmp.frequency_penalty = sp.frequency_penalty(); @@ -167,6 +173,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, tmp.beam_width = sp.beam_width(); tmp_sampling_params.emplace_back(tmp); } + sampling_params.reserve(tmp_sampling_params.size()); for (size_t i = 0; i < tmp_sampling_params.size(); ++i) { sampling_params.emplace_back(&tmp_sampling_params[i]); } @@ -242,6 +249,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, const int32_t rows = pb_forward_input->embeds().size(); const int32_t cols = pb_forward_input->embeds()[0].vals().size(); std::vector> embeddings_vec; + embeddings_vec.reserve(rows); for (size_t i = 0; i < rows; ++i) { embeddings_vec.emplace_back( std::vector(pb_forward_input->embeds()[i].vals().begin(), @@ -356,6 +364,11 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, proto_to_mmdata(pb_forward_input->mm_data(), &input_params.mm_data); } + if (pb_forward_input->has_dit_forward_input()) { + proto_to_dit_forward_input(pb_forward_input->dit_forward_input(), + input_params.dit_forward_input); + } + COUNTER_ADD(proto_latency_seconds_proto2i, timer.elapsed_seconds()); } @@ -369,6 +382,7 @@ void forward_input_to_proto(const RawForwardInput& inputs, ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_acc_logprob_vec(), inputs.acc_logprob_vec); std::vector pb_sampling_params; + pb_sampling_params.reserve(inputs.sampling_params.size()); for (auto sp : inputs.sampling_params) { proto::RequestSamplingParam pb_sp; pb_sp.set_frequency_penalty(sp->frequency_penalty); @@ -525,6 +539,11 @@ void forward_input_to_proto(const RawForwardInput& inputs, mmdata_to_proto(inputs.mm_data, pb_forward_input->mutable_mm_data()); } + if (inputs.dit_forward_input.valid()) { + dit_forward_input_to_proto(inputs.dit_forward_input, + pb_forward_input->mutable_dit_forward_input()); + } + COUNTER_ADD(proto_latency_seconds_i2proto, timer.elapsed_seconds()); } @@ -574,7 +593,8 @@ void proto_to_forward_output(const proto::ForwardOutput& pb_output, } raw_forward_output.outputs.emplace_back(s); } - + proto_to_dit_forward_output(pb_output.dit_forward_output(), + raw_forward_output.dit_forward_output); COUNTER_ADD(proto_latency_seconds_proto2o, timer.elapsed_seconds()); } @@ -588,6 +608,7 @@ void forward_output_to_proto(const torch::Tensor& next_tokens, const torch::Tensor& src_seq_idxes, const torch::Tensor& out_tokens, const torch::Tensor& out_logprobs, + const std::vector& dit_images, proto::ForwardOutput* pb_forward_output) { Timer timer; int32_t num_seqs = next_tokens.size(0); @@ -738,7 +759,11 @@ void forward_output_to_proto(const torch::Tensor& next_tokens, ADD_VECTOR_TO_PROTO(pb_forward_output->mutable_out_logprobs(), out_logprobs_slice); } - + if (!dit_images.empty()) { + TORCH_TENSOR_VEC_TO_PROTO_TENSOR_LIST( + pb_forward_output->mutable_dit_forward_output()->mutable_tensors(), + dit_images); + } COUNTER_ADD(proto_latency_seconds_o2proto, timer.elapsed_seconds()); return; } @@ -819,4 +844,213 @@ bool block_transfer_info_to_proto( return true; } +bool dit_forward_input_to_proto(const DiTForwardInput& dit_inputs, + proto::DiTForwardInput* pb_dit_inputs) { + pb_dit_inputs->set_batch_size(dit_inputs.batch_size); + + ADD_VECTOR_TO_PROTO(pb_dit_inputs->mutable_prompts(), dit_inputs.prompts); + + ADD_VECTOR_TO_PROTO(pb_dit_inputs->mutable_prompts_2(), dit_inputs.prompts_2); + + ADD_VECTOR_TO_PROTO(pb_dit_inputs->mutable_negative_prompts(), + dit_inputs.negative_prompts); + + ADD_VECTOR_TO_PROTO(pb_dit_inputs->mutable_negative_prompts_2(), + dit_inputs.negative_prompts_2); + + torch_tensor_to_proto_tensor(dit_inputs.images, + pb_dit_inputs->mutable_images()); + + torch_tensor_to_proto_tensor(dit_inputs.condition_images, + pb_dit_inputs->mutable_condition_images()); + + torch_tensor_to_proto_tensor(dit_inputs.mask_images, + pb_dit_inputs->mutable_mask_images()); + + torch_tensor_to_proto_tensor(dit_inputs.control_image, + pb_dit_inputs->mutable_control_image()); + + torch_tensor_to_proto_tensor(dit_inputs.masked_image_latents, + pb_dit_inputs->mutable_masked_image_latents()); + + torch_tensor_to_proto_tensor(dit_inputs.prompt_embeds, + pb_dit_inputs->mutable_prompt_embeds()); + + torch_tensor_to_proto_tensor(dit_inputs.pooled_prompt_embeds, + pb_dit_inputs->mutable_pooled_prompt_embeds()); + + torch_tensor_to_proto_tensor(dit_inputs.negative_prompt_embeds, + pb_dit_inputs->mutable_negative_prompt_embeds()); + + torch_tensor_to_proto_tensor( + dit_inputs.negative_pooled_prompt_embeds, + pb_dit_inputs->mutable_negative_pooled_prompt_embeds()); + + torch_tensor_to_proto_tensor(dit_inputs.latents, + pb_dit_inputs->mutable_latents()); + + if (!generation_params_to_proto(dit_inputs.generation_params, + pb_dit_inputs->mutable_generation_params())) { + LOG(ERROR) << "Failed to convert generation_params"; + return false; + } + + return true; +} + +bool generation_params_to_proto( + const DiTGenerationParams& dit_generation_params, + proto::DiTGenerationParams* pb_dit_generation_params) { + pb_dit_generation_params->set_width(dit_generation_params.width); + pb_dit_generation_params->set_height(dit_generation_params.height); + pb_dit_generation_params->set_num_inference_steps( + dit_generation_params.num_inference_steps); + pb_dit_generation_params->set_true_cfg_scale( + dit_generation_params.true_cfg_scale); + pb_dit_generation_params->set_guidance_scale( + dit_generation_params.guidance_scale); + pb_dit_generation_params->set_num_images_per_prompt( + dit_generation_params.num_images_per_prompt); + pb_dit_generation_params->set_seed(dit_generation_params.seed); + pb_dit_generation_params->set_max_sequence_length( + dit_generation_params.max_sequence_length); + pb_dit_generation_params->set_strength(dit_generation_params.strength); + pb_dit_generation_params->set_enable_cfg_renorm( + dit_generation_params.enable_cfg_renorm); + pb_dit_generation_params->set_cfg_renorm_min( + dit_generation_params.cfg_renorm_min); + return true; +} + +bool proto_to_dit_forward_input(const proto::DiTForwardInput& pb_dit_inputs, + DiTForwardInput& dit_inputs) { + dit_inputs.batch_size = pb_dit_inputs.batch_size(); + + std::vector prompts = std::vector( + pb_dit_inputs.prompts().begin(), pb_dit_inputs.prompts().end()); + std::vector prompts_2 = std::vector( + pb_dit_inputs.prompts_2().begin(), pb_dit_inputs.prompts_2().end()); + std::vector negative_prompts = + std::vector(pb_dit_inputs.negative_prompts().begin(), + pb_dit_inputs.negative_prompts().end()); + std::vector negative_prompts_2 = + std::vector(pb_dit_inputs.negative_prompts_2().begin(), + pb_dit_inputs.negative_prompts_2().end()); + dit_inputs.prompts = std::move(prompts); + + dit_inputs.prompts_2 = std::move(prompts_2); + + dit_inputs.negative_prompts = std::move(negative_prompts); + + dit_inputs.negative_prompts_2 = std::move(negative_prompts_2); + + if (pb_dit_inputs.has_images()) { + dit_inputs.images = util::proto_to_torch(pb_dit_inputs.images()); + } + + if (pb_dit_inputs.has_condition_images()) { + dit_inputs.condition_images = + util::proto_to_torch(pb_dit_inputs.condition_images()); + } + + if (pb_dit_inputs.has_mask_images()) { + dit_inputs.mask_images = util::proto_to_torch(pb_dit_inputs.mask_images()); + } + + if (pb_dit_inputs.has_control_image()) { + dit_inputs.control_image = + util::proto_to_torch(pb_dit_inputs.control_image()); + } + + if (pb_dit_inputs.has_masked_image_latents()) { + dit_inputs.masked_image_latents = + util::proto_to_torch(pb_dit_inputs.masked_image_latents()); + } + + if (pb_dit_inputs.has_prompt_embeds()) { + dit_inputs.prompt_embeds = + util::proto_to_torch(pb_dit_inputs.prompt_embeds()); + } + + if (pb_dit_inputs.has_pooled_prompt_embeds()) { + dit_inputs.pooled_prompt_embeds = + util::proto_to_torch(pb_dit_inputs.pooled_prompt_embeds()); + } + + if (pb_dit_inputs.has_negative_prompt_embeds()) { + dit_inputs.negative_prompt_embeds = + util::proto_to_torch(pb_dit_inputs.negative_prompt_embeds()); + } + + if (pb_dit_inputs.has_negative_pooled_prompt_embeds()) { + dit_inputs.negative_pooled_prompt_embeds = + util::proto_to_torch(pb_dit_inputs.negative_pooled_prompt_embeds()); + } + + if (pb_dit_inputs.has_latents()) { + dit_inputs.latents = util::proto_to_torch(pb_dit_inputs.latents()); + } + + if (!proto_to_generation_params(pb_dit_inputs.generation_params(), + dit_inputs.generation_params)) { + LOG(ERROR) << "Failed to convert generation_params"; + return false; + } + + return true; +} + +bool proto_to_generation_params( + const proto::DiTGenerationParams& pb_dit_generation_params, + DiTGenerationParams& dit_generation_params) { + LOG(INFO) << "start brpc transfer"; + dit_generation_params.width = pb_dit_generation_params.width(); + dit_generation_params.height = pb_dit_generation_params.height(); + dit_generation_params.num_inference_steps = + pb_dit_generation_params.num_inference_steps(); + dit_generation_params.true_cfg_scale = + pb_dit_generation_params.true_cfg_scale(); + dit_generation_params.guidance_scale = + pb_dit_generation_params.guidance_scale(); + dit_generation_params.num_images_per_prompt = + pb_dit_generation_params.num_images_per_prompt(); + dit_generation_params.seed = pb_dit_generation_params.seed(); + dit_generation_params.max_sequence_length = + pb_dit_generation_params.max_sequence_length(); + dit_generation_params.strength = pb_dit_generation_params.strength(); + dit_generation_params.enable_cfg_renorm = + pb_dit_generation_params.enable_cfg_renorm(); + dit_generation_params.cfg_renorm_min = + pb_dit_generation_params.cfg_renorm_min(); + return true; +} + +bool proto_to_dit_forward_output(const proto::DiTForwardOutput& pb_dit_outputs, + DiTForwardOutput& dit_outputs) { + const auto& pb_tensor_list = pb_dit_outputs.tensors(); + std::vector torch_tensor_vec; + torch_tensor_vec.reserve(pb_tensor_list.tensors_size()); + for (const auto& pb_tensor : pb_tensor_list.tensors()) { + torch::Tensor torch_tensor = util::proto_to_torch(pb_tensor); + if (!torch_tensor.defined()) { + LOG(ERROR) << "Failed to convert PB Tensor to torch Tensor (list item)"; + return false; + } + torch_tensor_vec.emplace_back(std::move(torch_tensor)); + } + dit_outputs.tensors = std::move(torch_tensor_vec); + return true; +} + +bool torch_tensor_to_proto_tensor(const torch::Tensor& torch_tensor, + proto::Tensor* proto_tensor) { + if (torch_tensor.defined()) { + if (!util::torch_to_proto(torch_tensor, proto_tensor)) { + LOG(ERROR) << "Failed to convert torch Tensor to Pb Tensor "; + return false; + } + } + return true; +} + } // namespace xllm diff --git a/xllm/core/runtime/params_utils.h b/xllm/core/runtime/params_utils.h index 8fca46a80..8788bd404 100644 --- a/xllm/core/runtime/params_utils.h +++ b/xllm/core/runtime/params_utils.h @@ -44,6 +44,7 @@ void forward_output_to_proto(const torch::Tensor& next_tokens, const torch::Tensor& src_seq_idxes, const torch::Tensor& out_tokens, const torch::Tensor& out_logprobs, + const std::vector& dit_images, proto::ForwardOutput* pb_forward_output); Token build_token(int64_t index, @@ -65,4 +66,23 @@ bool block_transfer_info_to_proto( const std::vector& block_transfer_info, proto::BlockTransferInfos* pb_block_transfer_info); +bool dit_forward_input_to_proto(const DiTForwardInput& dit_inputs, + proto::DiTForwardInput* pb_dit_inputs); + +bool generation_params_to_proto( + const DiTGenerationParams& dit_generation_params, + proto::DiTGenerationParams* pb_dit_generation_params); + +bool proto_to_dit_forward_input(const proto::DiTForwardInput& pb_dit_inputs, + DiTForwardInput& dit_inputs); + +bool proto_to_generation_params( + const proto::DiTGenerationParams& pb_dit_generation_params, + DiTGenerationParams& dit_generation_params); + +bool proto_to_dit_forward_output(const proto::DiTForwardOutput& pb_dit_outputs, + DiTForwardOutput& dit_outputs); + +bool torch_tensor_to_proto_tensor(const torch::Tensor& torch_tensor, + proto::Tensor* proto_tensor); } // namespace xllm diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp index 74c88b9cb..dfe450493 100644 --- a/xllm/core/runtime/rec_worker_impl.cpp +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,10 +28,10 @@ limitations under the License. #include "common/device_monitor.h" #include "common/global_flags.h" #include "common/metrics.h" -#include "common/rec_model_utils.h" #include "common/types.h" #include "core/common/global_flags.h" #include "framework/model/model_input_params.h" +#include "util/rec_model_utils.h" #if defined(USE_CUDA) #include "kernels/cuda/cuda_ops_api.h" #include "kernels/cuda/xattention/xattention_ops_api.h" @@ -42,14 +43,30 @@ limitations under the License. #include "kernels/npu/xllm_ops/xllm_ops_api.h" #include "platform/npu/device_capture_lock.h" #endif +#include "common/version_singleton.h" #include "framework/model_loader.h" +#include "framework/sampling/rec_constrained_decoding.h" #include "framework/sampling/rec_sampler.h" +#include "framework/state_dict/rec_vocab_dict.h" #include "models/model_registry.h" #include "util/env_var.h" #include "util/timer.h" namespace xllm { +namespace { + +RecVocabDict* get_onerec_vocab_dict(const std::string& model_weights_path) { + if (model_weights_path.empty()) { + return nullptr; + } + const std::string model_version = + std::filesystem::path(model_weights_path).filename().string(); + return VersionSingleton::GetInstance(model_version); +} + +} // namespace + // ============================================================ // RecWorkerImpl Implementation (base) // ============================================================ @@ -78,32 +95,9 @@ void RecWorkerImpl::RecWorkPipeline::prepare_work_before_execute( processed_inputs = inputs.to(runtime_.worker.device(), runtime_.worker.dtype()); auto& input_params = processed_inputs.input_params; -#if defined(USE_NPU) - if (input_params.swap_blocks.size() > 0 && !FLAGS_enable_block_copy_kernel) { - auto& swap_blocks = input_params.swap_blocks; - - // collect src and dst indices - std::vector src_indices, dst_indices; - src_indices.reserve(swap_blocks.size()); - dst_indices.reserve(swap_blocks.size()); + runtime_.worker.apply_kv_block_swaps(input_params); - for (const auto& block : swap_blocks) { - src_indices.push_back(block.src_block_id); - dst_indices.push_back(block.dst_block_id); - } - - // batch select keys and values - auto src_tensor = torch::tensor( - src_indices, - torch::dtype(torch::kLong).device(runtime_.worker.device_)); - auto dst_tensor = torch::tensor( - dst_indices, - torch::dtype(torch::kLong).device(runtime_.worker.device_)); - const int64_t num_layers = runtime_.context->get_model_args().n_layers(); - for (int layer_id = 0; layer_id < num_layers; layer_id++) { - runtime_.worker.kv_caches_[layer_id].swap_blocks(src_tensor, dst_tensor); - } - } +#if defined(USE_NPU) if (runtime_.context->get_model_args().enable_mla() && input_params.batch_forward_type.is_chunked_prefill()) { runtime_.worker.prepare_mla_prefixcache_inputs(input_params); @@ -149,6 +143,12 @@ std::optional RecWorkerImpl::RecWorkPipeline::step( std::shared_ptr layer_synchronizer = std::make_shared( runtime_.context->get_model_args().n_layers()); +#elif defined(USE_MLU) + std::shared_ptr layer_synchronizer = + std::make_shared( + runtime_.context->get_model_args().n_layers()); +#endif +#if defined(USE_NPU) || defined(USE_MLU) const_cast(&(input.input_params))->layer_synchronizer = layer_synchronizer; @@ -280,6 +280,34 @@ void RecWorkerImpl::LlmRecWorkPipeline::prepare_work_before_execute( runtime_.worker.prepare_multi_modal_data(processed_inputs); } +RecWorkerImpl::OneRecWorkPipeline::OneRecWorkPipeline( + RecPipelineRuntime& runtime) + : RecWorkPipeline(runtime), + rec_sampler_( + std::make_unique(RecPipelineType::kOneRecDefault)), + filter_mask_threadpool_(std::make_unique(1)) { + if (!FLAGS_enable_constrained_decoding) { + return; + } + + auto* vocab_dict = get_onerec_vocab_dict(runtime_.worker.model_weights_path_); + CHECK(vocab_dict != nullptr) + << "Failed to get RecVocabDict for OneRec constrained decoding, " + << "model_path=" << runtime_.worker.model_weights_path_; + + const int32_t vocab_size = + static_cast(runtime_.context->get_model_args().vocab_size()); + constrained_decoding_ = + std::make_unique(vocab_dict, + vocab_size, + runtime_.worker.dtype(), + runtime_.worker.device(), + /*use_gen_threadpool=*/false); + CHECK(constrained_decoding_->build_mask_cache()) + << "Failed to build OneRec constrained decoding cache, vocab_size=" + << vocab_size; +} + ForwardInput RecWorkerImpl::OneRecWorkPipeline::prepare_inputs(Batch& batch) { ThreadPool* thread_pool = runtime_.worker.input_builder_thread_pool_ @@ -293,6 +321,63 @@ ForwardInput RecWorkerImpl::OneRecWorkPipeline::prepare_inputs(Batch& batch) { thread_pool); } +void RecWorkerImpl::OneRecWorkPipeline::prepare_work_before_execute( + const ForwardInput& inputs, + ForwardInput& processed_inputs) { + RecWorkPipeline::prepare_work_before_execute(inputs, processed_inputs); + + auto& onerec_params = processed_inputs.input_params.mutable_onerec_params(); + if (!onerec_params.decoder_context_embedding.defined()) { + return; + } + + if (onerec_params.decoder_context_embedding.scalar_type() == + runtime_.worker.dtype()) { + return; + } + + onerec_params.decoder_context_embedding = + onerec_params.decoder_context_embedding.to(runtime_.worker.dtype()); +} + +folly::SemiFuture +RecWorkerImpl::OneRecWorkPipeline::prepare_filter_mask_async( + const std::vector>& generated_tokens) { + folly::Promise promise; + auto future = promise.getSemiFuture(); + + if (!constrained_decoding_ || !filter_mask_threadpool_ || + generated_tokens.empty()) { + promise.setValue(torch::Tensor()); + return future; + } + + filter_mask_threadpool_->schedule( + [this, generated_tokens, promise = std::move(promise)]() mutable { + try { + auto filter_mask = + constrained_decoding_->generate_mask(generated_tokens); + promise.setValue(filter_mask); + } catch (const std::exception& e) { + const int32_t batch = static_cast(generated_tokens.size()); + const int32_t seq = + batch > 0 ? static_cast(generated_tokens[0].size()) : 0; + LOG(ERROR) << "Failed to generate OneRec filter mask, batch=" << batch + << ", seq=" << seq << ", error=" << e.what(); + promise.setValue(torch::Tensor()); + } catch (...) { + const int32_t batch = static_cast(generated_tokens.size()); + const int32_t seq = + batch > 0 ? static_cast(generated_tokens[0].size()) : 0; + LOG(ERROR) << "Failed to generate OneRec filter mask, batch=" << batch + << ", seq=" << seq << ", error=unknown"; + promise.setValue(torch::Tensor()); + } + }); + + return future; +} + std::optional RecWorkerImpl::OneRecWorkPipeline::step( const ForwardInput& input) { Timer timer; @@ -305,10 +390,24 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( CHECK(onerec_params != nullptr) << "OneRec requires rec_params."; const OneRecModelInputParams& rec_params = *onerec_params; + const bool has_decoder_context = + rec_params.decoder_context_embedding.defined(); + const bool has_encoder_context = + rec_params.has_encoder_output || has_decoder_context; + std::optional> filter_mask_future; + if ((runtime_.worker.driver_ || runtime_.worker.dp_driver_) && + FLAGS_enable_constrained_decoding && constrained_decoding_ != nullptr && + sampling_params.selected_token_idxes.defined()) { + filter_mask_future = prepare_filter_mask_async(rec_params.generated_tokens); + } torch::Tensor hidden_states; if (rec_params.rec_stage == OneRecModelInputParams::RecStage::PREFILL) { if (!rec_params.is_first_prefill) { + if (!has_encoder_context) { + LOG(ERROR) << "OneRec prefill requires encoder context."; + return std::nullopt; + } ModelInputParams decoder_params = input_params; decoder_params.mutable_onerec_params().is_encoder_forward = false; decoder_params.mutable_onerec_params().has_encoder_output = @@ -353,11 +452,6 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = encoder_output.hidden_states.defined(); - if (encoder_output.hidden_states.defined() && - !decoder_onerec_params.decoder_context_embedding.defined()) { - decoder_onerec_params.decoder_context_embedding = - encoder_output.hidden_states; - } auto model_output = runtime_.executor->forward(input.token_ids, input.positions, runtime_.worker.kv_caches_, @@ -365,6 +459,10 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( hidden_states = model_output.hidden_states; } } else { + if (!has_encoder_context) { + LOG(ERROR) << "OneRec decode requires encoder context."; + return std::nullopt; + } ModelInputParams decoder_params = input_params; decoder_params.mutable_onerec_params().is_encoder_forward = false; decoder_params.mutable_onerec_params().has_encoder_output = @@ -398,8 +496,12 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( ForwardOutput output; if (sampling_params.selected_token_idxes.defined()) { + torch::Tensor filter_mask; + if (filter_mask_future.has_value()) { + filter_mask = std::move(filter_mask_future.value()).get(); + } auto sample_output = - runtime_.worker.sampler_->forward(logits, sampling_params); + rec_sampler_->forward(logits, sampling_params, filter_mask); output.logits = logits; output.sample_output = sample_output; output.do_sample = sampling_params.do_sample; diff --git a/xllm/core/runtime/rec_worker_impl.h b/xllm/core/runtime/rec_worker_impl.h index 646dce449..40a7bf44e 100644 --- a/xllm/core/runtime/rec_worker_impl.h +++ b/xllm/core/runtime/rec_worker_impl.h @@ -23,13 +23,14 @@ limitations under the License. #include #include -#include "common/rec_model_utils.h" #include "runtime/llm_worker_impl.h" +#include "util/rec_model_utils.h" #include "util/threadpool.h" namespace xllm { class RecSampler; +class RecConstrainedDecoding; class RecWorkerImpl : public LLMWorkerImpl { friend class RecWorkPipeline; @@ -115,12 +116,22 @@ class RecWorkerImpl : public LLMWorkerImpl { class OneRecWorkPipeline final : public RecWorkPipeline { public: - explicit OneRecWorkPipeline(RecPipelineRuntime& runtime) - : RecWorkPipeline(runtime) {} + explicit OneRecWorkPipeline(RecPipelineRuntime& runtime); ForwardInput prepare_inputs(Batch& batch) override; + void prepare_work_before_execute(const ForwardInput& inputs, + ForwardInput& processed_inputs) override; + std::optional step(const ForwardInput& input) override; + + private: + folly::SemiFuture prepare_filter_mask_async( + const std::vector>& generated_tokens); + + std::unique_ptr rec_sampler_; + std::unique_ptr constrained_decoding_; + std::unique_ptr filter_mask_threadpool_; }; class LlmRecWithMmDataWorkPipeline final : public RecWorkPipeline { diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index 74169ffb4..4cd922ff4 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" #include "framework/state_dict/state_dict.h" +#include "runtime/dit_worker_impl.h" #include "runtime/eagle3_worker_impl.h" #include "runtime/embed_vlm_worker_impl.h" #include "runtime/embed_worker_impl.h" @@ -69,6 +70,8 @@ Worker::Worker(const ParallelArgs& parallel_args, impl_ = new RecWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::MMEVLM) { impl_ = new MMEmbedVLMWorkerImpl(parallel_args, device, options); + } else if (worker_type == WorkerType::DIT) { + impl_ = new DiTWorkerImpl(parallel_args, device, options); } else { LOG(ERROR) << "Unknown worker type, please check logic"; } diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 491714e66..02c59c503 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -43,6 +43,8 @@ limitations under the License. #include "common/metrics.h" #if defined(USE_NPU) #include "platform/npu/device_capture_lock.h" +#elif defined(USE_CUDA) +#include "kernels/cuda/cuda_ops_api.h" #endif #include "core/distributed_runtime/master.h" #include "framework/kv_cache/kv_cache.h" @@ -54,7 +56,6 @@ limitations under the License. #include "framework/xtensor/global_xtensor.h" #include "framework/xtensor/xtensor_allocator.h" #if defined(USE_NPU) -#include "framework/kv_cache/mooncake_weight_transfer.h" #include "layers/npu/loader/rolling_weight_buffer.h" #endif #include "util/net.h" @@ -140,11 +141,12 @@ WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args, compute_stream_ = device_.get_stream_from_pool(); sampler_ = std::make_unique(); -#if !defined(USE_NPU) - // Startup validation: ATB block-copy kernel is NPU-only. We should fail fast - // if CUDA deployment accidentally enables it. - CHECK(!FLAGS_enable_block_copy_kernel) - << "enable_block_copy_kernel must be false on CUDA builds."; +#if !defined(USE_NPU) && !defined(USE_CUDA) + if (FLAGS_enable_block_copy_kernel) { + LOG(WARNING) << "enable_block_copy_kernel is only supported on NPU/CUDA; " + "forcing enable_block_copy_kernel=false."; + FLAGS_enable_block_copy_kernel = false; + } #endif #if defined(USE_NPU) @@ -172,14 +174,15 @@ bool WorkerImpl::allocate_kv_cache( const std::vector>& kv_cache_shape) { CHECK(model_ != nullptr) << "Model is not initialized."; CHECK(kv_caches_.empty()) << "KV caches are already initialized."; - const bool enable_linear_attention = - has_linear_attention_layers(context_.get_model_args()); - const bool enable_lighting_indexer = - context_.get_model_args().index_n_heads() > 0; + const auto& args = context_.get_model_args(); + const bool enable_linear_attention = has_linear_attention_layers(args); + const bool enable_lighting_indexer = args.index_n_heads() > 0; CHECK(!(enable_linear_attention && enable_lighting_indexer)) << "KVCache does not support linear attention and lighting indexer " << "simultaneously."; + const int64_t num_layers = get_num_layers(); + // Check if KV cache quantization is enabled // "auto" (default): cache dtype aligns with model dtype (no quantization) // "int8": enables INT8 quantization @@ -202,11 +205,12 @@ bool WorkerImpl::allocate_kv_cache( } // create a KVCache for each layer - const int64_t num_layers = get_num_layers(); kv_caches_.reserve(num_layers); if (FLAGS_enable_xtensor) { // XTensor mode: create xtensor-backed KV cache tensors. + // For hybrid models, we still create full KV cache for all layers + // since xtensor has its own memory management auto& allocator = XTensorAllocator::get_instance(); const std::string& model_id = options_.model_id(); // Create K tensors for all layers @@ -223,96 +227,140 @@ bool WorkerImpl::allocate_kv_cache( k_tensor = at_npu::native::npu_format_cast(k_tensor, ACL_FORMAT_ND); v_tensor = at_npu::native::npu_format_cast(v_tensor, ACL_FORMAT_ND); #endif + + // For xtensor mode, we still use the full KV cache approach kv_caches_.emplace_back(k_tensor, v_tensor); } } else { // Original mode: create torch tensors with optional int8 kv quantization. torch::ScalarType cache_dtype = enable_kv_cache_quant ? torch::kInt8 : dtype_; + + // Helper function to check if a layer is linear attention + auto is_linear_attention_layer = [&](int64_t layer_idx) { + if (args.full_attention_interval() > 1) { + return (layer_idx + 1) % args.full_attention_interval() != 0; + } + return false; + }; + for (int64_t i = 0; i < num_layers; ++i) { + bool is_linear_layer = is_linear_attention_layer(i); torch::Tensor key_cache, value_cache, index_cache, conv_cache, ssm_cache; torch::Tensor key_cache_scale, value_cache_scale; + + if (is_linear_layer) { + // Linear attention layer: only allocate conv_cache and ssm_cache #if defined(USE_NPU) - aclFormat npu_format_type = - context_.get_model_args().model_type() == "deepseek_v3" && - FLAGS_enable_prefix_cache - ? ACL_FORMAT_FRACTAL_NZ - : ACL_FORMAT_ND; - key_cache = at_npu::native::npu_format_cast( - torch::empty(kv_cache_shape[0], - torch::dtype(cache_dtype).device(device_)), - npu_format_type); - value_cache = at_npu::native::npu_format_cast( - torch::empty(kv_cache_shape[1], - torch::dtype(cache_dtype).device(device_)), - npu_format_type); - if (enable_lighting_indexer) { - index_cache = at_npu::native::npu_format_cast( - torch::empty(kv_cache_shape[2], - torch::dtype(dtype_).device(device_)), - npu_format_type); - } - if (enable_linear_attention) { - conv_cache = at_npu::native::npu_format_cast( - torch::zeros(kv_cache_shape[2], - torch::dtype(dtype_).device(device_)), - 2); - ssm_cache = at_npu::native::npu_format_cast( - torch::zeros(kv_cache_shape[3], - torch::dtype(dtype_).device(device_)), - 2); - } + aclFormat npu_format_type = ACL_FORMAT_ND; + if (enable_linear_attention) { + conv_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[2], + torch::dtype(dtype_).device(device_)), + 2); + ssm_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[3], + torch::dtype(dtype_).device(device_)), + 2); + } #elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA) - key_cache = torch::zeros(kv_cache_shape[0], - torch::dtype(cache_dtype).device(device_)); - if (!kv_cache_shape[1].empty()) { - value_cache = torch::zeros(kv_cache_shape[1], - torch::dtype(cache_dtype).device(device_)); - } - if (enable_lighting_indexer) { - index_cache = torch::zeros(kv_cache_shape[2], + if (enable_linear_attention) { + conv_cache = torch::zeros(kv_cache_shape[2], + torch::dtype(dtype_).device(device_)); + ssm_cache = torch::zeros(kv_cache_shape[3], torch::dtype(dtype_).device(device_)); - } - if (enable_kv_cache_quant) { - std::vector key_scale_shape(kv_cache_shape[0].begin(), - kv_cache_shape[0].end() - 1); - key_cache_scale = torch::zeros( - key_scale_shape, torch::dtype(torch::kFloat32).device(device_)); - if (!kv_cache_shape[1].empty()) { - std::vector value_scale_shape(kv_cache_shape[1].begin(), - kv_cache_shape[1].end() - 1); - value_cache_scale = torch::zeros( - value_scale_shape, torch::dtype(torch::kFloat32).device(device_)); } - } #else - key_cache = torch::empty(kv_cache_shape[0], - torch::dtype(cache_dtype).device(device_)); - if (!kv_cache_shape[1].empty()) { - value_cache = torch::empty(kv_cache_shape[1], - torch::dtype(cache_dtype).device(device_)); - } - if (enable_lighting_indexer) { - index_cache = torch::empty(kv_cache_shape[2], + if (enable_linear_attention) { + conv_cache = torch::zeros(kv_cache_shape[2], + torch::dtype(dtype_).device(device_)); + ssm_cache = torch::zeros(kv_cache_shape[3], torch::dtype(dtype_).device(device_)); - } + } #endif - if (enable_kv_cache_quant) { - kv_caches_.emplace_back(key_cache, - value_cache, - index_cache, - key_cache_scale, - value_cache_scale); - } else if (enable_linear_attention) { - kv_caches_.emplace_back(key_cache, value_cache, conv_cache, ssm_cache); - } else if (enable_lighting_indexer) { - kv_caches_.emplace_back(key_cache, value_cache, index_cache); + // Create empty KVCache with only conv and ssm + kv_caches_.emplace_back( + torch::zeros({0}, torch::dtype(dtype_).device(device_)), + torch::zeros({0}, torch::dtype(dtype_).device(device_)), + conv_cache, + ssm_cache); } else { - kv_caches_.emplace_back(key_cache, value_cache); + // Full attention layer: allocate key_cache and value_cache only +#if defined(USE_NPU) + aclFormat npu_format_type = + context_.get_model_args().model_type() == "deepseek_v3" && + FLAGS_enable_prefix_cache + ? ACL_FORMAT_FRACTAL_NZ + : ACL_FORMAT_ND; + key_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[0], + torch::dtype(cache_dtype).device(device_)), + npu_format_type); + value_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[1], + torch::dtype(cache_dtype).device(device_)), + npu_format_type); + if (enable_lighting_indexer) { + index_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[2], + torch::dtype(dtype_).device(device_)), + npu_format_type); + } +#elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA) + key_cache = torch::zeros(kv_cache_shape[0], + torch::dtype(cache_dtype).device(device_)); + if (!kv_cache_shape[1].empty()) { + value_cache = torch::zeros(kv_cache_shape[1], + torch::dtype(cache_dtype).device(device_)); + } + if (enable_lighting_indexer) { + index_cache = torch::zeros(kv_cache_shape[2], + torch::dtype(dtype_).device(device_)); + } + if (enable_kv_cache_quant) { + std::vector key_scale_shape(kv_cache_shape[0].begin(), + kv_cache_shape[0].end() - 1); + key_cache_scale = torch::zeros( + key_scale_shape, torch::dtype(torch::kFloat32).device(device_)); + if (!kv_cache_shape[1].empty()) { + std::vector value_scale_shape(kv_cache_shape[1].begin(), + kv_cache_shape[1].end() - 1); + value_cache_scale = + torch::zeros(value_scale_shape, + torch::dtype(torch::kFloat32).device(device_)); + } + } +#else + key_cache = torch::zeros(kv_cache_shape[0], + torch::dtype(cache_dtype).device(device_)); + if (!kv_cache_shape[1].empty()) { + value_cache = torch::zeros(kv_cache_shape[1], + torch::dtype(cache_dtype).device(device_)); + } + if (enable_lighting_indexer) { + index_cache = torch::zeros(kv_cache_shape[2], + torch::dtype(dtype_).device(device_)); + } +#endif + if (enable_kv_cache_quant) { + kv_caches_.emplace_back(key_cache, + value_cache, + index_cache, + key_cache_scale, + value_cache_scale); + } else if (enable_lighting_indexer) { + kv_caches_.emplace_back(key_cache, value_cache, index_cache); + } else { + kv_caches_.emplace_back(key_cache, value_cache); + } } } } +#if defined(USE_CUDA) + refresh_cuda_block_copy_runtime_state(); +#endif + init_hierarchy_kv_cache_transfer(); status_ = Status::READY; return true; @@ -387,7 +435,7 @@ void WorkerImpl::get_cache_info(uint64_t& cluster_id, std::string& addr, int64_t& k_cache_id, int64_t& v_cache_id) { -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) kv_cache_transfer_->get_cache_info(cluster_id, addr, k_cache_id, v_cache_id); #endif } @@ -396,7 +444,7 @@ bool WorkerImpl::link_cluster(const std::vector& cluster_ids, const std::vector& addrs, const std::vector& device_ips, const std::vector& ports) { -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) for (int32_t i = 0; i < cluster_ids.size(); ++i) { if (!kv_cache_transfer_->link_cluster( cluster_ids[i], addrs[i], device_ips[i], ports[i])) { @@ -411,7 +459,7 @@ bool WorkerImpl::unlink_cluster(const std::vector& cluster_ids, const std::vector& addrs, const std::vector& device_ips, const std::vector& ports) { -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_MLU) for (int32_t i = 0; i < cluster_ids.size(); ++i) { if (!kv_cache_transfer_->unlink_cluster( cluster_ids[i], addrs[i], device_ips[i], ports[i])) { @@ -549,98 +597,199 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, lock_guard.emplace(capture_lock); } #endif - c10::StreamGuard streamGuard = prepare_stream_->set_stream_guard(); - processed_input = input.to(device_, dtype_); + const bool use_default_stream = + !enable_schedule_overlap() && options_.backend() == "llm"; + auto prepare_input_on_current_stream = [&]() { + processed_input = input.to(device_, dtype_); + auto& input_params = processed_input.input_params; #if defined(USE_NPU) - CpPrefillInputs tmp_cp_inputs; - if (parallel_args_.cp_size() > 1 && - input.input_params.batch_forward_type.is_prefill()) { - tmp_cp_inputs = prepare_cp_prefill_inputs(parallel_args_.cp_size(), - input.token_ids, - input.positions, - input.input_params.q_seq_lens); - processed_input.input_params.cp_prefill_inputs = tmp_cp_inputs.to(device_); - CpEpPadding cp_ep_padding( - input.token_ids, - context_.get_model_args().num_experts_per_tok(), - context_.get_parallel_args().mapping_data(), - /*device=*/device_, - dtype_, - /*is_prefill=*/input.input_params.batch_forward_type.is_prefill()); - processed_input.input_params.cp_ep_padding_data = cp_ep_padding.build(); - } + CpPrefillInputs tmp_cp_inputs; + if (parallel_args_.cp_size() > 1 && + input.input_params.batch_forward_type.is_prefill()) { + tmp_cp_inputs = prepare_cp_prefill_inputs(parallel_args_.cp_size(), + input.token_ids, + input.positions, + input.input_params.q_seq_lens); + processed_input.input_params.cp_prefill_inputs = + tmp_cp_inputs.to(device_); + CpEpPadding cp_ep_padding( + input.token_ids, + context_.get_model_args().num_experts_per_tok(), + context_.get_parallel_args().mapping_data(), + /*device=*/device_, + dtype_, + /*is_prefill=*/input.input_params.batch_forward_type.is_prefill()); + processed_input.input_params.cp_ep_padding_data = cp_ep_padding.build(); + } #endif - auto& input_params = processed_input.input_params; + apply_kv_block_swaps(input_params); #if defined(USE_NPU) - const bool use_block_copy_kernel = FLAGS_enable_block_copy_kernel; -#else - const bool use_block_copy_kernel = false; -#endif + if (context_.get_model_args().enable_mla() && + input_params.batch_forward_type.is_chunked_prefill()) { + prepare_mla_prefixcache_inputs(input_params); + } -#if defined(USE_NPU) || defined(USE_CUDA) - if (input_params.swap_blocks.size() > 0 && !use_block_copy_kernel) { - auto& swap_blocks = input_params.swap_blocks; + if (!context_.get_parallel_args().mapping_data().empty() && + !(context_.get_parallel_args().cp_size() > 1) && + (context_.get_parallel_args().dp_size() > 1 || + context_.get_parallel_args().ep_size() > 1)) { + torch::Tensor token_size_per_dp_group = + torch::tensor(processed_input.input_params.dp_global_token_nums, + torch::TensorOptions() + .device(torch::kCPU) + .dtype(torch::kInt32) + .pinned_memory(true)); + bool is_prefill = + processed_input.input_params.batch_forward_type.is_prefill(); + DpEpPadding dp_ep_padding(token_size_per_dp_group, + context_.get_model_args().num_experts_per_tok(), + context_.get_parallel_args().mapping_data(), + device_, + dtype_, + is_prefill); + processed_input.input_params.dp_ep_padding_data = dp_ep_padding.build(); + if (FLAGS_enable_eplb) { + // expert_load_data_.fill_(0); + processed_input.input_params.expert_load_data = expert_load_data_; + } + } +#endif + }; - // collect src and dst indices - std::vector src_indices, dst_indices; - src_indices.reserve(swap_blocks.size()); - dst_indices.reserve(swap_blocks.size()); + if (use_default_stream) { + prepare_input_on_current_stream(); + } else { + c10::StreamGuard stream_guard = prepare_stream_->set_stream_guard(); + prepare_input_on_current_stream(); + } - for (const auto& block : swap_blocks) { - src_indices.push_back(block.src_block_id); - dst_indices.push_back(block.dst_block_id); - } + if (!use_default_stream) { + prepare_stream_->synchronize(); + } +} - // batch select keys and values - auto src_tensor = - torch::tensor(src_indices, torch::dtype(torch::kLong).device(device_)); - auto dst_tensor = - torch::tensor(dst_indices, torch::dtype(torch::kLong).device(device_)); - const int64_t num_layers = context_.get_model_args().n_layers(); - for (int32_t layer_id = 0; layer_id < num_layers; layer_id++) { - kv_caches_[layer_id].swap_blocks(src_tensor, dst_tensor); - } +void WorkerImpl::apply_kv_block_swaps(const ModelInputParams& input_params) { +#if defined(USE_CUDA) + if (FLAGS_enable_block_copy_kernel && + can_use_cuda_block_copy_kernel(input_params)) { + execute_cuda_block_copy_kernel(input_params); + return; } #endif #if defined(USE_NPU) - if (context_.get_model_args().enable_mla() && - input_params.batch_forward_type.is_chunked_prefill()) { - prepare_mla_prefixcache_inputs(input_params); + if (input_params.swap_blocks.size() == 0 || FLAGS_enable_block_copy_kernel) { + return; + } +#elif defined(USE_CUDA) + if (input_params.swap_blocks.size() == 0) { + return; } +#else + return; +#endif - if (!context_.get_parallel_args().mapping_data().empty() && - !(context_.get_parallel_args().cp_size() > 1) && - (context_.get_parallel_args().dp_size() > 1 || - context_.get_parallel_args().ep_size() > 1)) { - torch::Tensor token_size_per_dp_group = - torch::tensor(processed_input.input_params.dp_global_token_nums, - torch::TensorOptions() - .device(torch::kCPU) - .dtype(torch::kInt32) - .pinned_memory(true)); - bool is_prefill = - processed_input.input_params.batch_forward_type.is_prefill(); - DpEpPadding dp_ep_padding(token_size_per_dp_group, - context_.get_model_args().num_experts_per_tok(), - context_.get_parallel_args().mapping_data(), - device_, - dtype_, - is_prefill); - processed_input.input_params.dp_ep_padding_data = dp_ep_padding.build(); - if (FLAGS_enable_eplb) { - // expert_load_data_.fill_(0); - processed_input.input_params.expert_load_data = expert_load_data_; - } +#if defined(USE_NPU) || defined(USE_CUDA) + std::vector src_indices, dst_indices; + src_indices.reserve(input_params.swap_blocks.size()); + dst_indices.reserve(input_params.swap_blocks.size()); + + for (const auto& block : input_params.swap_blocks) { + src_indices.push_back(block.src_block_id); + dst_indices.push_back(block.dst_block_id); + } + + auto src_tensor = + torch::tensor(src_indices, torch::dtype(torch::kLong).device(device_)); + auto dst_tensor = + torch::tensor(dst_indices, torch::dtype(torch::kLong).device(device_)); + for (size_t layer_id = 0; layer_id < kv_caches_.size(); ++layer_id) { + kv_caches_[layer_id].swap_blocks(src_tensor, dst_tensor); } #endif +} - auto ret = prepare_stream_->synchronize(); +#if defined(USE_CUDA) +void WorkerImpl::refresh_cuda_block_copy_runtime_state() { + cuda_block_copy_runtime_state_ = {}; + if (!FLAGS_enable_block_copy_kernel || kv_caches_.empty()) { + return; + } + + const auto& first_kv_cache = kv_caches_.front(); + auto key_cache = first_kv_cache.get_k_cache(); + auto value_cache = first_kv_cache.get_v_cache(); + if (!key_cache.defined() || !value_cache.defined() || !key_cache.is_cuda() || + !value_cache.is_cuda()) { + return; + } + + CHECK(key_cache.is_contiguous()) + << "CUDA block copy kernel expects contiguous key cache"; + CHECK(value_cache.is_contiguous()) + << "CUDA block copy kernel expects contiguous value cache"; + CHECK_GT(key_cache.size(0), 0); + + const auto cache_dtype = key_cache.scalar_type(); + std::vector key_cache_ptrs; + std::vector value_cache_ptrs; + key_cache_ptrs.reserve(kv_caches_.size()); + value_cache_ptrs.reserve(kv_caches_.size()); + for (const auto& kv_cache : kv_caches_) { + auto layer_k_cache = kv_cache.get_k_cache(); + auto layer_v_cache = kv_cache.get_v_cache(); + CHECK(layer_k_cache.defined() && layer_v_cache.defined()); + CHECK(layer_k_cache.is_cuda() && layer_v_cache.is_cuda()); + CHECK(layer_k_cache.is_contiguous()); + CHECK(layer_v_cache.is_contiguous()); + CHECK(layer_k_cache.scalar_type() == cache_dtype); + CHECK(layer_v_cache.scalar_type() == cache_dtype); + CHECK(layer_k_cache.sizes() == key_cache.sizes()); + CHECK(layer_v_cache.sizes() == value_cache.sizes()); + key_cache_ptrs.push_back( + reinterpret_cast(layer_k_cache.data_ptr())); + value_cache_ptrs.push_back( + reinterpret_cast(layer_v_cache.data_ptr())); + } + + auto ptr_options = + torch::TensorOptions().device(device_).dtype(torch::kInt64); + cuda_block_copy_runtime_state_.k_cache_ptrs_device = + torch::tensor(key_cache_ptrs, ptr_options); + cuda_block_copy_runtime_state_.v_cache_ptrs_device = + torch::tensor(value_cache_ptrs, ptr_options); + cuda_block_copy_runtime_state_.num_layers = kv_caches_.size(); + cuda_block_copy_runtime_state_.numel_per_block = key_cache[0].numel(); +} + +bool WorkerImpl::can_use_cuda_block_copy_kernel( + const ModelInputParams& input_params) const { + return cuda_block_copy_runtime_state_.valid() && + input_params.src_block_indices.defined() && + input_params.dst_block_indices.defined() && + input_params.cum_sum.defined() && + input_params.src_block_indices.numel() > 0 && + input_params.dst_block_indices.numel() > 0 && + input_params.cum_sum.numel() > 0; } +void WorkerImpl::execute_cuda_block_copy_kernel( + const ModelInputParams& input_params) { + CHECK(!kv_caches_.empty()); + xllm::kernel::cuda::block_copy( + cuda_block_copy_runtime_state_.k_cache_ptrs_device, + cuda_block_copy_runtime_state_.v_cache_ptrs_device, + input_params.src_block_indices, + input_params.dst_block_indices, + input_params.cum_sum, + cuda_block_copy_runtime_state_.numel_per_block, + kv_caches_.front().get_k_cache().scalar_type()); +} +#endif + folly::SemiFuture> WorkerImpl::step_async( const ForwardInput& input) { ForwardInput input_on_device; @@ -921,6 +1070,7 @@ bool WorkerImpl::init_model(const std::string& model_weights_path, {"deepseek_v3", "deepseek_v3_mtp"}, {"deepseek_v32", "deepseek_v3_mtp"}, {"glm_moe_dsa", "glm_moe_dsa_mtp"}, + {"joyai_llm_flash", "joyai_llm_flash_mtp"}, }; const std::string& current_type = args.model_type(); auto it = kModelTypeToMtpType.find(current_type); @@ -1072,6 +1222,14 @@ folly::SemiFuture WorkerImpl::pull_kv_blocks_async( src_v_cache_id, src_blocks, dst_blocks); +#elif defined(USE_MLU) + (void)src_cluster_id; + (void)src_addr; + (void)src_k_cache_id; + (void)src_v_cache_id; + (void)src_blocks; + (void)dst_blocks; + LOG(FATAL) << "MLU backend does not support PULL kv cache transfer."; #endif return false; } diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h index 662752ff0..43426d174 100644 --- a/xllm/core/runtime/worker_impl.h +++ b/xllm/core/runtime/worker_impl.h @@ -25,9 +25,9 @@ limitations under the License. #include "executor.h" #include "forward_params.h" #include "framework/eplb/eplb_executor.h" -#include "framework/kv_cache/hierarchy_kv_cache_transfer.h" -#include "framework/kv_cache/kv_cache_store.h" -#include "framework/kv_cache/kv_cache_transfer.h" +#include "framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h" +#include "framework/kv_cache_transfer/kv_cache_store.h" +#include "framework/kv_cache_transfer/kv_cache_transfer.h" #include "framework/model/causal_lm.h" #include "framework/model/model_input_params.h" #include "framework/model_context.h" @@ -41,7 +41,7 @@ limitations under the License. #include "platform/device.h" #include "util/threadpool.h" #if defined(USE_NPU) -#include "framework/kv_cache/mooncake_weight_transfer.h" +#include "framework/kv_cache_transfer/mooncake_weight_transfer.h" #include "layers/npu/loader/rolling_load_manager.h" #endif @@ -115,6 +115,9 @@ class WorkerImpl { virtual void prepare_work_before_execute(const ForwardInput& inputs, ForwardInput& processed_inputs); + // Internal helper shared by worker pipelines before model execution. + virtual void apply_kv_block_swaps(const ModelInputParams& input_params); + virtual std::optional step(const ForwardInput& inputs) = 0; virtual void process_group_test(); @@ -204,6 +207,25 @@ class WorkerImpl { bool wakeup_local(const WakeupOptions& options); +#if defined(USE_CUDA) + void refresh_cuda_block_copy_runtime_state(); + bool can_use_cuda_block_copy_kernel( + const ModelInputParams& input_params) const; + void execute_cuda_block_copy_kernel(const ModelInputParams& input_params); + + struct CudaBlockCopyRuntimeState { + torch::Tensor k_cache_ptrs_device; + torch::Tensor v_cache_ptrs_device; + int64_t num_layers = 0; + int64_t numel_per_block = 0; + + bool valid() const { + return k_cache_ptrs_device.defined() && v_cache_ptrs_device.defined() && + num_layers > 0 && numel_per_block > 0; + } + }; +#endif + #if defined(USE_NPU) bool wakeup_from_remote_weights(const WakeupOptions& options); // Complete rolling initialization by delegating to model-owned rolling @@ -263,6 +285,10 @@ class WorkerImpl { std::shared_ptr kv_cache_transfer_; std::unique_ptr hierarchy_kv_cache_transfer_; +#if defined(USE_CUDA) + CudaBlockCopyRuntimeState cuda_block_copy_runtime_state_; +#endif + #if defined(USE_NPU) std::unique_ptr weight_transfer_; std::unique_ptr load_stream_; diff --git a/xllm/core/runtime/xservice_client.cpp b/xllm/core/runtime/xservice_client.cpp index 55292e749..1007ef9a5 100644 --- a/xllm/core/runtime/xservice_client.cpp +++ b/xllm/core/runtime/xservice_client.cpp @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "util/env_var.h" #include "util/hash_util.h" #include "util/net.h" #include "util/uuid.h" @@ -35,6 +36,8 @@ namespace { static std::string ETCD_MASTER_SERVICE_KEY = "XLLM:SERVICE:MASTER"; static std::string ETCD_XSERVICES_KEY_PREFIX = "XLLM:SERVICE:"; // all xllm_service registeration prefix +constexpr const char* kEtcdUsernameEnvVar = "ETCD_USERNAME"; +constexpr const char* kEtcdPasswordEnvVar = "ETCD_PASSWORD"; static std::unordered_map ETCD_KEYS_PREFIX_MAP = { {xllm_service::proto::InstanceType::DEFAULT, "XLLM:DEFAULT:"}, @@ -71,7 +74,8 @@ bool check_instance_name(const std::string& name) { bool XServiceClient::init(const std::string& etcd_addr, const std::string& instance_name, - const BlockManagerPool* block_manager_pool) { + const BlockManagerPool* block_manager_pool, + const std::string& etcd_namespace) { if (initialize_done_) { LOG(INFO) << "XServiceClient is already initialized, skipping."; return true; @@ -90,7 +94,23 @@ bool XServiceClient::init(const std::string& etcd_addr, chan_options_.max_retry = 3; chan_options_.timeout_ms = FLAGS_rpc_channel_timeout_ms; - etcd_client_ = std::make_unique(etcd_addr); + const std::string etcd_username = + util::get_optional_string_env(kEtcdUsernameEnvVar).value_or(""); + const std::string etcd_password = + util::get_optional_string_env(kEtcdPasswordEnvVar).value_or(""); + const bool has_etcd_auth_user = !etcd_username.empty(); + const bool has_etcd_auth_password = !etcd_password.empty(); + if (has_etcd_auth_user != has_etcd_auth_password) { + LOG(ERROR) << "Both " << kEtcdUsernameEnvVar << " and " + << kEtcdPasswordEnvVar << " must be set together."; + return false; + } + if (has_etcd_auth_user) { + etcd_client_ = std::make_unique( + etcd_addr, etcd_username, etcd_password, etcd_namespace); + } else { + etcd_client_ = std::make_unique(etcd_addr, etcd_namespace); + } // connect master xllm_service while (!etcd_client_->get_master_service(ETCD_MASTER_SERVICE_KEY, @@ -132,12 +152,15 @@ bool XServiceClient::init(const std::string& etcd_addr, // watch master xllm_service change auto master_func = std::bind(&XServiceClient::handle_master_service_watch, this, - std::placeholders::_1); + std::placeholders::_1, + std::placeholders::_2); etcd_client_->add_watch(ETCD_MASTER_SERVICE_KEY, master_func); // watch all xllm_service changes - auto xservices_func = std::bind( - &XServiceClient::handle_xservices_watch, this, std::placeholders::_1); + auto xservices_func = std::bind(&XServiceClient::handle_xservices_watch, + this, + std::placeholders::_1, + std::placeholders::_2); etcd_client_->add_watch(ETCD_XSERVICES_KEY_PREFIX, xservices_func); block_manager_pool_ = block_manager_pool; @@ -771,8 +794,8 @@ void XServiceClient::disconnect_xservice(const std::string& xservice_addr) { } } -void XServiceClient::handle_master_service_watch( - const etcd::Response& response) { +void XServiceClient::handle_master_service_watch(const etcd::Response& response, + const uint64_t& prefix_len) { if (response.events().empty() || exited_.load()) { return; } @@ -807,7 +830,8 @@ void XServiceClient::handle_master_service_watch( } } -void XServiceClient::handle_xservices_watch(const etcd::Response& response) { +void XServiceClient::handle_xservices_watch(const etcd::Response& response, + const uint64_t& prefix_len) { if (response.events().empty() || exited_.load()) { return; } @@ -817,17 +841,17 @@ void XServiceClient::handle_xservices_watch(const etcd::Response& response) { std::string service_addr; if (event.event_type() == etcd::Event::EventType::PUT) { if (event.has_kv()) { - event_key = event.kv().key(); + event_key = event.kv().key().substr(prefix_len); service_addr = event.kv().as_string(); } } else if (event.event_type() == etcd::Event::EventType::DELETE_) { if (event.has_prev_kv()) { - event_key = event.prev_kv().key(); + event_key = event.prev_kv().key().substr(prefix_len); service_addr = event.prev_kv().as_string(); } if (service_addr.empty() && event.has_kv()) { if (event_key.empty()) { - event_key = event.kv().key(); + event_key = event.kv().key().substr(prefix_len); } service_addr = event.kv().as_string(); } diff --git a/xllm/core/runtime/xservice_client.h b/xllm/core/runtime/xservice_client.h index 972bda19d..8d8dda0a5 100644 --- a/xllm/core/runtime/xservice_client.h +++ b/xllm/core/runtime/xservice_client.h @@ -45,7 +45,8 @@ class XServiceClient { ~XServiceClient(); bool init(const std::string& etcd_addr, const std::string& instance_name = "", - const BlockManagerPool* block_manager_pool = nullptr); + const BlockManagerPool* block_manager_pool = nullptr, + const std::string& etcd_namespace = ""); void set_scheduler(Scheduler* scheduler); void set_engine(Engine* engine); bool initialize_done() { return initialize_done_; } @@ -69,8 +70,10 @@ class XServiceClient { bool reconcile_registration(); void reconcile_registration_loop(); - void handle_master_service_watch(const etcd::Response& response); - void handle_xservices_watch(const etcd::Response& response); + void handle_master_service_watch(const etcd::Response& response, + const uint64_t& prefix_len); + void handle_xservices_watch(const etcd::Response& response, + const uint64_t& prefix_len); // connect to specific xllm_service bool connect_to_xservice(const std::string& xservice_addr); diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index a0b2fb93a..7942185d6 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -1044,7 +1044,7 @@ std::vector ContinuousScheduler::schedule_request( break; } // wait for new requests to arrive - constexpr uint64_t kStepSleepTimeMs = 10; + constexpr uint64_t kStepSleepTimeMs = 1; const auto time_to_sleep = std::min(absl::Milliseconds(kStepSleepTimeMs), deadline - now); absl::SleepFor(time_to_sleep); @@ -1119,7 +1119,7 @@ void ContinuousScheduler::generate() { while (num_pending_requests() > 0 || !batch_empty || request_queue_.size() > 0) { // build a batch of requests/sequences - const auto timeout = absl::Milliseconds(500); + const auto timeout = absl::Milliseconds(50); std::vector batch = schedule_request(timeout); batch_empty = true; for (auto& b : batch) { diff --git a/xllm/core/scheduler/continuous_scheduler_test.cpp b/xllm/core/scheduler/continuous_scheduler_test.cpp index 5fd8071bc..0df840c34 100644 --- a/xllm/core/scheduler/continuous_scheduler_test.cpp +++ b/xllm/core/scheduler/continuous_scheduler_test.cpp @@ -40,11 +40,13 @@ class FakeTokenizer : public Tokenizer { class FakeEngine : public Engine { public: - FakeEngine(int32_t num_blocks, int32_t block_size) { + FakeEngine(int32_t num_blocks, + int32_t block_size, + bool enable_prefix_cache = false) { BlockManagerPool::Options opt; opt.num_blocks_ = num_blocks; opt.block_size_ = block_size; - opt.enable_prefix_cache_ = false; // we dont consider prefix cache here + opt.enable_prefix_cache_ = enable_prefix_cache; fake_tokenizer_ = std::make_unique(); fake_block_manager_ = std::make_unique(opt, 1); } @@ -182,6 +184,37 @@ std::vector> generate_request( return requests; } +std::shared_ptr generate_request_with_prompt_tokens( + const std::vector& prompt_token_ids, + int32_t max_tokens, + int32_t max_context_len) { + RequestSamplingParam sampling_param; + SchedulerParam scheduler_param; + + StoppingChecker stopping_checker; + stopping_checker.set_max_generated_tokens(max_tokens); + stopping_checker.set_max_context_len(max_context_len); + stopping_checker.set_ignore_eos(true); + + RequestState req_state("x", + prompt_token_ids, + sampling_param, + scheduler_param, + stopping_checker, + prompt_token_ids.size() + 30000, + 1, + 1, + false, + false, + false, + false, + false, + nullptr, + nullptr); + + return std::make_shared("1", "1", "1", std::move(req_state), "1"); +} + // dont not consider speculative decoding. void update_requests(std::vector> requests) { for (auto req : requests) { @@ -651,4 +684,46 @@ TEST(ContinuousSchedulerTest, LatencySchedule) { // EXPECT_TRUE(scheduler->get_running_requests().size() == 2); } +TEST(BlockManagerPoolTest, AllocateFailureRollsBackSharedPrefixBlocks) { + auto engine = std::make_unique(3, 4, true); + BlockManagerPool* block_manager_pool = engine->block_manager_pool(); + + auto cached_request = + generate_request_with_prompt_tokens({1, 2, 3, 4, 5, 6, 7, 8}, 1, 30000); + auto failed_request = generate_request_with_prompt_tokens( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 1, 30000); + auto later_request = + generate_request_with_prompt_tokens({20, 21, 22, 23}, 1, 30000); + + auto* cached_sequence = cached_request->sequences()[0].get(); + ASSERT_TRUE(block_manager_pool->allocate(cached_sequence, + cached_sequence->num_tokens())); + cached_sequence->kv_state().set_kv_cache_tokens_num( + cached_sequence->num_tokens()); + block_manager_pool->deallocate(cached_sequence); + + const size_t free_blocks_before_failure = + util::max(block_manager_pool->num_free_blocks()); + const size_t used_blocks_before_failure = + util::min(block_manager_pool->num_used_blocks()); + EXPECT_EQ(free_blocks_before_failure, 0); + + auto* failed_sequence = failed_request->sequences()[0].get(); + EXPECT_FALSE(block_manager_pool->allocate(failed_sequence, + failed_sequence->num_tokens())); + EXPECT_EQ(failed_sequence->kv_state().num_kv_blocks(), 0); + EXPECT_EQ(failed_sequence->kv_state().shared_kv_blocks_num(), 0); + EXPECT_EQ(util::max(block_manager_pool->num_free_blocks()), + free_blocks_before_failure); + EXPECT_EQ(util::min(block_manager_pool->num_used_blocks()), + used_blocks_before_failure); + + auto* later_sequence = later_request->sequences()[0].get(); + EXPECT_TRUE(block_manager_pool->allocate(later_sequence, + later_sequence->num_tokens())); + EXPECT_EQ(later_sequence->kv_state().num_kv_blocks(), 1); + + (void)engine.release(); +} + } // namespace xllm diff --git a/xllm/core/scheduler/dit_scheduler.cpp b/xllm/core/scheduler/dit_scheduler.cpp index 8b6d3e37e..7a0553e60 100644 --- a/xllm/core/scheduler/dit_scheduler.cpp +++ b/xllm/core/scheduler/dit_scheduler.cpp @@ -48,9 +48,11 @@ void DiTAsyncResponseProcessor::process_failed_request( std::shared_ptr request, Status status) {} -DiTDynamicBatchScheduler::DiTDynamicBatchScheduler(DiTEngine* engine, +DiTDynamicBatchScheduler::DiTDynamicBatchScheduler(Engine* engine, const Options& options) - : options_(options), engine_(engine), request_queue_(kRequestQueueSize) { + : options_(options), + engine_(dynamic_cast(engine)), + request_queue_(kRequestQueueSize) { CHECK(engine_ != nullptr); response_handler_ = std::make_unique(); diff --git a/xllm/core/scheduler/dit_scheduler.h b/xllm/core/scheduler/dit_scheduler.h index ecade6f3f..46be5bfd8 100644 --- a/xllm/core/scheduler/dit_scheduler.h +++ b/xllm/core/scheduler/dit_scheduler.h @@ -26,13 +26,14 @@ limitations under the License. #include "common/macros.h" #include "common/types.h" +#include "distributed_runtime/dit_engine.h" +#include "distributed_runtime/engine.h" #include "framework/batch/dit_batch.h" #include "framework/request/dit_request.h" #include "scheduler.h" #include "util/threadpool.h" namespace xllm { -class DiTEngine; class DiTAsyncResponseProcessor final { public: @@ -65,7 +66,7 @@ class DiTScheduler : public SchedulerBase { class DiTDynamicBatchScheduler : public DiTScheduler { public: - DiTDynamicBatchScheduler(DiTEngine* engine, const Options& options); + DiTDynamicBatchScheduler(Engine* engine, const Options& options); virtual ~DiTDynamicBatchScheduler(); bool add_request(std::shared_ptr& request) override; diff --git a/xllm/core/scheduler/fixed_steps_scheduler.cpp b/xllm/core/scheduler/fixed_steps_scheduler.cpp index fb1a6cd91..e5c1a207d 100644 --- a/xllm/core/scheduler/fixed_steps_scheduler.cpp +++ b/xllm/core/scheduler/fixed_steps_scheduler.cpp @@ -28,7 +28,6 @@ limitations under the License. #include #include "common/metrics.h" -#include "common/rec_model_utils.h" #include "common/types.h" #include "core/common/global_flags.h" #include "distributed_runtime/engine.h" @@ -36,6 +35,7 @@ limitations under the License. #include "framework/batch/batch_factory.h" #include "framework/request/request.h" #include "framework/request/sequence.h" +#include "util/rec_model_utils.h" namespace xllm { diff --git a/xllm/core/scheduler/profile/profile_manager.cpp b/xllm/core/scheduler/profile/profile_manager.cpp index a2f082f7d..df4a4ddc5 100644 --- a/xllm/core/scheduler/profile/profile_manager.cpp +++ b/xllm/core/scheduler/profile/profile_manager.cpp @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,9 +28,9 @@ limitations under the License. #include #include "common/global_flags.h" -#include "common/rec_model_utils.h" #include "framework/batch/batch_factory.h" #include "framework/request/request_state.h" +#include "util/rec_model_utils.h" namespace xllm { @@ -618,6 +619,62 @@ std::shared_ptr ProfileManager::generate_single_request( return request; } +std::shared_ptr ProfileManager::generate_single_decode_request( + int32_t total_length) { + CHECK_GT(total_length, 1) << "Decode profiling requires total_length > 1."; + + auto& model_args = engine_->model_args(); + int32_t vocab_size = model_args.vocab_size(); + int32_t eos_token_id = model_args.eos_token_id(); + + std::random_device rd; + std::mt19937_64 gen(rd()); + + // If req_state does not initialize the stopchecker, default eos_token_id = 0, + // need to skip it + std::uniform_int_distribution dis(1, vocab_size - 2); + + const int32_t prompt_length = total_length - 1; + std::vector prompt_token_ids(prompt_length); + std::generate(prompt_token_ids.begin(), prompt_token_ids.end(), [&]() { + int32_t token = dis(gen); + return token == eos_token_id ? token + 1 : token; // skip eos + }); + + RequestState req_state(prompt_token_ids); + req_state.enable_schedule_overlap = options_.enable_schedule_overlap(); + req_state.seq_capacity = total_length + 1; + auto request = std::make_shared( + /*request_id=*/"", + /*x_request_id=*/"", + /*x_request_time=*/"", + req_state); + + auto* sequence = request->sequences()[0].get(); + if (!block_manager_pool_->BlockManagerPool::allocate(sequence, + total_length + 1)) { + LOG(FATAL) << "Profiling decode step time failed! Not enough blocks, total " + "length: " + << total_length; + } + sequence->kv_state().incr_kv_cache_tokens_num(prompt_length); + + int32_t generated_token = dis(gen); + generated_token = + generated_token == eos_token_id ? generated_token + 1 : generated_token; + sequence->append_token(generated_token); + + CHECK(sequence->stage() == SequenceStage::DECODE) + << "Decode profiling request is not in DECODE stage. total_length: " + << total_length << ", prompt_length: " << prompt_length + << ", kv_cache_tokens_num: " << sequence->kv_state().kv_cache_tokens_num() + << ", num_tokens: " << sequence->num_tokens(); + CHECK_EQ(sequence->num_generated_tokens(), 1) + << "Decode profiling request should start with one generated token."; + + return request; +} + // collect the latency of each step double ProfileManager::run_request(int32_t token_length, int32_t prefix_length, @@ -627,6 +684,9 @@ double ProfileManager::run_request(int32_t token_length, std::vector sequences; std::vector sequences_budget; std::vector> requests; + sequences.reserve(batch_size); + sequences_budget.reserve(batch_size); + requests.reserve(batch_size); // batch sequences with the same kv cahce and token length for (int32_t i = 0; i < batch_size; i++) { @@ -672,6 +732,9 @@ double ProfileManager::run_request( std::vector sequences; std::vector sequences_budget; std::vector> requests; + sequences.reserve(token_length_vec.size()); + sequences_budget.reserve(token_length_vec.size()); + requests.reserve(token_length_vec.size()); // batch sequences with the same kv cahce and token length for (int32_t i = 0; i < token_length_vec.size(); i++) { @@ -704,42 +767,69 @@ double ProfileManager::run_request( return latency; } -// Generate a batch of decode requests and execute it, then return the step -// latency. +double ProfileManager::run_decode_request( + const std::vector& total_length_vec) { + std::vector sequences; + std::vector sequences_budget; + std::vector> requests; + + for (int32_t total_length : total_length_vec) { + std::shared_ptr request = + generate_single_decode_request(total_length); + requests.emplace_back(request); + sequences.emplace_back(request->sequences()[0].get()); + sequences_budget.emplace_back(1); + } + + auto batches = + BatchFactory::get_instance(options_.dp_size()) + ->create_batches(requests, sequences, sequences_budget, nullptr); + + absl::Time start_time = absl::Now(); + engine_->step(batches); + if (options_.enable_schedule_overlap()) { + engine_->update_last_step_result(batches); + } + double latency = absl::ToDoubleMilliseconds(absl::Now() - start_time); + for (auto& request : requests) { + block_manager_pool_->deallocate_without_cache( + request->sequences()[0].get()); + } + + return latency; +} + +// Generate a batch of decode requests in DECODE stage and execute one decode +// step, then return the step latency. double ProfileManager::profile_decode_step_time(int32_t token_length, int32_t batch_size, int32_t min_context_len, int32_t max_context_len) { - double total_latency = 0; + double total_latency = 0.0; for (int32_t i = 0; i < profile_count_per_step_; ++i) { std::vector token_length_vec; - std::vector prefix_length_vec; generate_random_decode_batch(batch_size * token_length, batch_size, min_context_len, max_context_len, - token_length_vec, - prefix_length_vec); - double latency = run_request(token_length_vec, prefix_length_vec); - total_latency += latency; + token_length_vec); + total_latency += run_decode_request(token_length_vec); } return total_latency / profile_count_per_step_; } -// Generate a batch of random decode requests with an average length of -// token_length. +// Generate a batch of random decode requests with an average total sequence +// length of token_length. void ProfileManager::generate_random_decode_batch( int32_t total_length, int32_t batch_size, int32_t min_context_len, int32_t max_context_len, - std::vector& token_length_vec, - std::vector& prefix_length_vec) { + std::vector& token_length_vec) { CHECK(total_length >= batch_size * min_context_len); CHECK(total_length <= batch_size * max_context_len); token_length_vec.resize(batch_size, min_context_len); - prefix_length_vec.resize(batch_size, min_context_len - 1); int remain = total_length - batch_size * min_context_len; std::random_device rd; @@ -755,7 +845,6 @@ void ProfileManager::generate_random_decode_batch( std::uniform_int_distribution dis(0, max); int add = dis(gen); token_length_vec[i] += add; - prefix_length_vec[i] += add; remain -= add; } @@ -763,7 +852,6 @@ void ProfileManager::generate_random_decode_batch( while (remain > 0) { if (token_length_vec[idx % batch_size] < max_context_len) { token_length_vec[idx % batch_size] += 1; - prefix_length_vec[idx % batch_size] += 1; --remain; } ++idx; diff --git a/xllm/core/scheduler/profile/profile_manager.h b/xllm/core/scheduler/profile/profile_manager.h index 2b18c5c79..68fb26d48 100644 --- a/xllm/core/scheduler/profile/profile_manager.h +++ b/xllm/core/scheduler/profile/profile_manager.h @@ -93,8 +93,8 @@ class ProfileManager { double run_request(const std::vector& token_length_vec, const std::vector& prefix_length_vec); - // Generate a batch of decode requests and execute it, then return the step - // latency. + // Generate a batch of decode requests in DECODE stage and execute one decode + // step, then return the step latency. double profile_decode_step_time(int32_t token_length, int32_t batch_size, int32_t min_context_len, @@ -129,6 +129,7 @@ class ProfileManager { std::shared_ptr generate_single_request(int32_t token_length, int32_t prefix_length); + std::shared_ptr generate_single_decode_request(int32_t total_length); std::string generate_filename(const std::string& file_suffix); @@ -147,14 +148,15 @@ class ProfileManager { int32_t lower_bound, int32_t upper_bound); - // Generate a batch of random decode requests with an average length of - // token_length. + // Generate a batch of random decode requests with an average total sequence + // length of token_length. void generate_random_decode_batch(int32_t total_length, int32_t batch_size, int32_t min_context_len, int32_t max_context_len, - std::vector& token_length_vec, - std::vector& prefix_length_vec); + std::vector& token_length_vec); + + double run_decode_request(const std::vector& total_length_vec); static const std::vector& get_copy_block_profile(); @@ -179,4 +181,4 @@ class ProfileManager { int32_t profile_token_budget_ = std::numeric_limits::max(); }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/scheduler/scheduler_factory.cpp b/xllm/core/scheduler/scheduler_factory.cpp index 9e5120e71..463905016 100644 --- a/xllm/core/scheduler/scheduler_factory.cpp +++ b/xllm/core/scheduler/scheduler_factory.cpp @@ -61,7 +61,7 @@ std::unique_ptr create_continuous_scheduler( } std::unique_ptr create_dit_scheduler( - DiTEngine* engine, + Engine* engine, DiTScheduler::Options options) { return std::make_unique(engine, options); } diff --git a/xllm/core/scheduler/scheduler_factory.h b/xllm/core/scheduler/scheduler_factory.h index e28e74db3..461a6585f 100644 --- a/xllm/core/scheduler/scheduler_factory.h +++ b/xllm/core/scheduler/scheduler_factory.h @@ -27,7 +27,7 @@ std::unique_ptr create_continuous_scheduler( ContinuousScheduler::Options options); std::unique_ptr create_dit_scheduler( - DiTEngine* engine, + Engine* engine, DiTScheduler::Options options); std::unique_ptr create_fixed_steps_scheduler( diff --git a/xllm/core/util/CMakeLists.txt b/xllm/core/util/CMakeLists.txt index 753ee04e6..a05eea53b 100644 --- a/xllm/core/util/CMakeLists.txt +++ b/xllm/core/util/CMakeLists.txt @@ -20,6 +20,7 @@ cc_library( lightweightsemaphore.h net.h pretty_print.h + rec_model_utils.h scope_guard.h slice.h spin_lock.h diff --git a/xllm/core/util/env_var.cpp b/xllm/core/util/env_var.cpp index 4778be943..7be665209 100644 --- a/xllm/core/util/env_var.cpp +++ b/xllm/core/util/env_var.cpp @@ -62,6 +62,14 @@ std::string get_string_env(const std::string& name) { return std::string(val); } +std::optional get_optional_string_env(const std::string& name) { + const char* val = std::getenv(name.c_str()); + if (val == nullptr) { + return std::nullopt; + } + return std::string(val); +} + double get_double_env(const std::string& key, double defaultValue = -1) { const char* val = std::getenv(key.c_str()); if (val == nullptr) { diff --git a/xllm/core/util/env_var.h b/xllm/core/util/env_var.h index 9d314ae24..d81d09cf6 100644 --- a/xllm/core/util/env_var.h +++ b/xllm/core/util/env_var.h @@ -29,6 +29,7 @@ bool get_bool_env(const std::string& key, bool defaultValue); int64_t get_int_env(const std::string& key, int64_t defaultValue); std::string get_string_env(const std::string& name); +std::optional get_optional_string_env(const std::string& name); // Get the timeout in seconds for process group test operations. // This timeout is used when waiting for process group initialization tests diff --git a/xllm/core/common/rec_model_utils.h b/xllm/core/util/rec_model_utils.h similarity index 100% rename from xllm/core/common/rec_model_utils.h rename to xllm/core/util/rec_model_utils.h diff --git a/xllm/core/util/shared_memory_manager.cpp b/xllm/core/util/shared_memory_manager.cpp index 93cd4710e..e3391fdd1 100644 --- a/xllm/core/util/shared_memory_manager.cpp +++ b/xllm/core/util/shared_memory_manager.cpp @@ -44,7 +44,6 @@ SharedMemoryManager::SharedMemoryManager(const std::string& name, // First try to create exclusively (O_CREAT | O_EXCL) fd_ = shm_open(name.c_str(), O_CREAT | O_RDWR | O_EXCL, 0666); is_creator = (fd_ != -1); - // If creation failed, try opening existing if (!is_creator) { fd_ = shm_open(name.c_str(), O_RDWR, 0666); diff --git a/xllm/core/util/tensor_helper.h b/xllm/core/util/tensor_helper.h index 69d358559..b7c468f69 100644 --- a/xllm/core/util/tensor_helper.h +++ b/xllm/core/util/tensor_helper.h @@ -328,32 +328,61 @@ inline std::optional try_get_scalar_type_from_string( inline torch::Tensor get_tensor_from_blob(const std::vector& dims, const torch::ScalarType dtype, const void* dev_addr) { +#if defined(USE_NPU) c10::DeviceType device_type = c10::DeviceType::PrivateUse1; torch::TensorOptions option = torch::TensorOptions().dtype(dtype).device(device_type); auto tensor = torch::empty({0}, option); -#if defined(USE_NPU) auto address = const_cast(dev_addr); torch::DataPtr c10_data_ptr(address, address, [](void*) {}, tensor.device()); size_t tensor_nbytes = at::detail::computeStorageNbytesContiguous( dims, tensor.dtype().itemsize()); torch::Storage storage; - // get npu storage constructor from register and construct storage auto fptr = c10::GetStorageImplCreate(device_type); auto allocator = c10::GetAllocator(device_type); - // PyTorch 2.7+: StorageImpl now takes DataPtr instead of raw allocator storage = fptr(c10::StorageImpl::use_byte_size_t(), c10::SymInt(tensor_nbytes), std::move(c10_data_ptr), allocator, - true); + /*resizable=*/true); tensor.set_(storage, 0, dims); -#endif return tensor; +#elif defined(USE_CUDA) + auto options = torch::TensorOptions() + .dtype(dtype) + .device(torch::kCUDA) + .requires_grad( + /*requires_grad=*/false); + return torch::from_blob(const_cast(dev_addr), dims, options); +#else + LOG(FATAL) << "get_tensor_from_blob only supports NPU and CUDA devices"; +#endif +} + +inline torch::Tensor get_tensor_from_blob(const std::vector& dims, + const torch::ScalarType dtype, + const void* dev_addr, + const torch::Tensor& owner) { +#if defined(USE_CUDA) + CHECK(owner.defined()) + << "get_tensor_from_blob requires a valid owner tensor on CUDA"; + + auto options = torch::TensorOptions() + .dtype(dtype) + .device(torch::kCUDA) + .requires_grad( + /*requires_grad=*/false); + auto owner_ref = owner; + auto deleter = [owner_ref](void*) {}; + return torch::from_blob(const_cast(dev_addr), dims, deleter, options); +#else + (void)owner; + return get_tensor_from_blob(dims, dtype, dev_addr); +#endif } inline int32_t get_dtype_size(torch::ScalarType dtype) { diff --git a/xllm/core/util/threadpool.cpp b/xllm/core/util/threadpool.cpp index 1a6c2f2b7..a275d5ab9 100644 --- a/xllm/core/util/threadpool.cpp +++ b/xllm/core/util/threadpool.cpp @@ -15,20 +15,81 @@ limitations under the License. #include "threadpool.h" +#include +#include +#include + +#include +#include #include namespace xllm { + +namespace { + +int32_t bind_thread_to_cpu_core(int32_t cpu_core) { + if (cpu_core < 0 || cpu_core >= CPU_SETSIZE) { + LOG(ERROR) << "Invalid CPU core " << cpu_core << ", valid range is [0, " + << CPU_SETSIZE - 1 << "]"; + return -1; + } + + cpu_set_t current_affinity; + CPU_ZERO(¤t_affinity); + if (sched_getaffinity(0, sizeof(cpu_set_t), ¤t_affinity) != 0) { + LOG(ERROR) << "Failed to get current process affinity: " << strerror(errno); + return -1; + } + if (!CPU_ISSET(cpu_core, ¤t_affinity)) { + LOG(ERROR) << "CPU core " << cpu_core + << " is not in the current process affinity set"; + return -1; + } + + cpu_set_t cpu_set; + CPU_ZERO(&cpu_set); + CPU_SET(cpu_core, &cpu_set); + + if (pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set) != + 0) { + LOG(ERROR) << "Failed to bind thread to CPU core " << cpu_core << ": " + << strerror(errno); + return -1; + } + + LOG(INFO) << "Successfully bound thread to CPU core " << cpu_core; + return 0; +} + +} // namespace + ThreadPool::ThreadPool(size_t num_threads) : ThreadPool(num_threads, nullptr) {} ThreadPool::ThreadPool(size_t num_threads, Runnable init_func) + : ThreadPool(num_threads, std::move(init_func), {}) {} + +ThreadPool::ThreadPool(size_t num_threads, std::vector cpu_cores) + : ThreadPool(num_threads, nullptr, std::move(cpu_cores)) {} + +ThreadPool::ThreadPool(size_t num_threads, + Runnable init_func, + std::vector cpu_cores) : queues_(num_threads) { + if (!cpu_cores.empty() && cpu_cores.size() != num_threads) { + LOG(WARNING) << "ThreadPool: cpu_cores.size() (" << cpu_cores.size() + << ") != num_threads (" << num_threads + << "), CPU core binding will be skipped"; + cpu_cores.clear(); + } BlockingCounter counter(num_threads); for (size_t i = 0; i < num_threads; ++i) { + int32_t cpu_core = cpu_cores.empty() ? -1 : cpu_cores[i]; threads_.emplace_back([this, i, + cpu_core, init_func_ptr = &init_func, counter_ptr = &counter]() mutable { - internal_loop(i, init_func_ptr, counter_ptr); + internal_loop(i, init_func_ptr, counter_ptr, cpu_core); }); } counter.wait(); @@ -72,7 +133,12 @@ void ThreadPool::schedule_with_tid(Runnable runnable, size_t tid) { void ThreadPool::internal_loop(size_t index, Runnable* init_func, - BlockingCounter* block_counter) { + BlockingCounter* block_counter, + int32_t cpu_core) { + if (cpu_core >= 0 && bind_thread_to_cpu_core(cpu_core) != 0) { + LOG(WARNING) << "Thread " << index << " CPU binding to core " << cpu_core + << " failed, running unbound"; + } if (init_func != nullptr && *init_func != nullptr) { (*init_func)(); } diff --git a/xllm/core/util/threadpool.h b/xllm/core/util/threadpool.h index d639d342d..a0869e851 100644 --- a/xllm/core/util/threadpool.h +++ b/xllm/core/util/threadpool.h @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "concurrent_queue.h" #include "util/blocking_counter.h" @@ -41,6 +42,14 @@ class ThreadPool final { explicit ThreadPool(size_t num_threads); explicit ThreadPool(size_t num_threads, Runnable init_func); + // Bind each worker thread to the corresponding CPU core in cpu_cores. + // cpu_cores[i] is the CPU core ID for thread i. If cpu_cores is empty, + // no binding is performed. If cpu_cores.size() does not equal num_threads, + // a warning will be logged and CPU core binding will be skipped. + explicit ThreadPool(size_t num_threads, std::vector cpu_cores); + explicit ThreadPool(size_t num_threads, + Runnable init_func, + std::vector cpu_cores); // schedule a runnable to be executed int32_t schedule(Runnable runnable); @@ -58,7 +67,8 @@ class ThreadPool final { private: void internal_loop(size_t tid, Runnable* init_func, - BlockingCounter* block_counter); + BlockingCounter* block_counter, + int32_t cpu_core); std::vector threads_; std::vector> queues_; diff --git a/xllm/core/util/threadpool_test.cpp b/xllm/core/util/threadpool_test.cpp index 08cdf3c01..fdbb9f061 100644 --- a/xllm/core/util/threadpool_test.cpp +++ b/xllm/core/util/threadpool_test.cpp @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include namespace xllm { @@ -87,4 +89,84 @@ TEST(ThreadPoolTest, MultipleThreads) { EXPECT_EQ(counter, 10); } +TEST(ThreadPoolTest, CpuCoreBindingConstructor) { + // Construct with cpu_cores binding — should not crash even if binding fails + // (e.g., in containers with restricted affinity). + std::vector cpu_cores = {0, 0}; // bind both threads to core 0 + ThreadPool threadpool(2, cpu_cores); + EXPECT_EQ(threadpool.size(), 2); + + std::atomic counter{0}; + absl::Notification notification; + for (int i = 0; i < 2; ++i) { + threadpool.schedule([&counter, ¬ification]() { + if (++counter == 2) { + notification.Notify(); + } + }); + } + EXPECT_TRUE( + notification.WaitForNotificationWithTimeout(absl::Milliseconds(500))); + EXPECT_EQ(counter, 2); +} + +TEST(ThreadPoolTest, CpuCoreBindingWithInitFunc) { + std::vector cpu_cores = {0}; + std::atomic init_called{false}; + absl::Notification init_done; + ThreadPool threadpool( + 1, + [&init_called, &init_done]() { + init_called = true; + init_done.Notify(); + }, + cpu_cores); + EXPECT_TRUE( + init_done.WaitForNotificationWithTimeout(absl::Milliseconds(500))); + EXPECT_TRUE(init_called); +} + +TEST(ThreadPoolTest, CpuCoreBindingMismatchFallback) { + // Mismatched cpu_cores size — should fall back to no binding gracefully. + std::vector cpu_cores = {0, 1}; // 2 cores but 4 threads + ThreadPool threadpool(4, cpu_cores); + EXPECT_EQ(threadpool.size(), 4); + + std::atomic counter{0}; + absl::Notification notification; + for (int i = 0; i < 4; ++i) { + threadpool.schedule([&counter, ¬ification]() { + if (++counter == 4) { + notification.Notify(); + } + }); + } + EXPECT_TRUE( + notification.WaitForNotificationWithTimeout(absl::Milliseconds(500))); + EXPECT_EQ(counter, 4); +} + +TEST(ThreadPoolTest, CpuCoreBindingVerifyAffinity) { + // Verify that after construction the thread is actually bound to the + // requested core (if the system allows it). + const int32_t target_core = 0; + std::vector cpu_cores = {target_core}; + + absl::Notification done; + std::atomic affinity_ok{false}; + + ThreadPool threadpool(1, cpu_cores); + threadpool.schedule([&done, &affinity_ok, target_core]() { + cpu_set_t cpu_set; + CPU_ZERO(&cpu_set); + if (pthread_getaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set) == + 0) { + affinity_ok = CPU_ISSET(target_core, &cpu_set); + } + done.Notify(); + }); + EXPECT_TRUE(done.WaitForNotificationWithTimeout(absl::Milliseconds(500))); + EXPECT_TRUE(affinity_ok); +} + } // namespace xllm diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp index 2532d615d..1e4757cd4 100644 --- a/xllm/core/util/utils.cpp +++ b/xllm/core/util/utils.cpp @@ -150,7 +150,7 @@ std::vector cal_vec_split_index(uint32_t vec_size, return split_index; } -torch::Dtype convert_rec_type_to_torch(proto::DataType data_type) { +torch::ScalarType convert_rec_type_to_torch(proto::DataType data_type) { // Future extensions go here. switch (data_type) { case proto::DataType::FLOAT: @@ -172,8 +172,7 @@ torch::Dtype convert_rec_type_to_torch(proto::DataType data_type) { return torch::kInt16; default: - throw std::runtime_error("Unsupported data type: " + - std::to_string(static_cast(data_type))); + LOG(FATAL) << "Unsupported data type: " << static_cast(data_type); } } @@ -186,12 +185,12 @@ torch::Tensor convert_rec_tensor_to_torch( } if (!input_tensor.has_contents()) { - throw std::runtime_error("Input tensor '" + input_tensor.name() + - "' has no contents"); + LOG(FATAL) << "Input tensor '" << input_tensor.name() + << "' has no contents"; } const auto& contents = input_tensor.contents(); - torch::Dtype dtype = convert_rec_type_to_torch(input_tensor.data_type()); + torch::ScalarType dtype = convert_rec_type_to_torch(input_tensor.data_type()); switch (dtype) { case torch::kFloat32: { @@ -240,8 +239,8 @@ torch::Tensor convert_rec_tensor_to_torch( } default: - throw std::runtime_error("Unhandled data type conversion for: " + - std::to_string(static_cast(dtype))); + LOG(FATAL) << "Unhandled data type conversion for: " + << static_cast(dtype); } } @@ -483,6 +482,9 @@ torch::Tensor proto_to_torch(const proto::Tensor& proto_tensor) { data_ptr = get_data_from_contents(proto_contents, proto_datatype); data_count = proto_contents.bytes_contents().size() / static_cast(sizeof(torch::Half)); + } else if (proto_datatype == "BYTES") { + data_ptr = get_data_from_contents(proto_contents, proto_datatype); + data_count = proto_contents.bytes_contents().size(); } if (data_ptr == nullptr) { @@ -652,4 +654,4 @@ int32_t ceil_pow2(int32_t n) { } } // namespace util -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 7eeeabc75..9f57375e7 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -124,6 +124,26 @@ inline bool is_mla_model_type(std::string_view model_type) { return mla_model_type_set().contains(std::string(model_type)); } +inline std::string get_model_name( + const std::filesystem::path& normalized_model_path) { + std::string model_name; + + if (normalized_model_path.has_filename()) { + model_name = normalized_model_path.filename().string(); + } else { + model_name = normalized_model_path.parent_path().filename().string(); + } + + if (model_name.empty()) { + LOG(FATAL) << "Cannot extract model name from path, as it appears to be a " + "root directory: " + << normalized_model_path.string(); + return ""; + } + + return model_name; +} + inline std::string get_model_type(const std::filesystem::path& model_path) { JsonReader reader; std::filesystem::path config_json_path = model_path / "config.json"; diff --git a/xllm/models/dit/clip_text_model.h b/xllm/models/dit/clip_text_model.h index 30942e16a..d2e5a5943 100644 --- a/xllm/models/dit/clip_text_model.h +++ b/xllm/models/dit/clip_text_model.h @@ -16,7 +16,9 @@ limitations under the License. #pragma once +#if defined(USE_NPU) #include +#endif #include #include @@ -27,13 +29,17 @@ limitations under the License. #include "core/framework/kv_cache/kv_cache.h" #include "core/framework/model/model_input_params.h" #include "core/framework/model_context.h" +#if defined(USE_NPU) #include "core/layers/npu/npu_siglip_encoder_layer_impl.h" +#endif #include "models/model_registry.h" #include "processors/clip_image_processor.h" -#include "processors/input_processor.h" +#include "processors/clip_input_processor.h" #include "processors/pywarpper_image_processor.h" #include "xllm/core/layers/common/add_matmul.h" +#if defined(USE_NPU) #include "xllm_atb_layers/core/include/atb_speed/log.h" +#endif namespace xllm { // clip_text_model compatible with huggingface weights @@ -59,96 +65,6 @@ torch::Tensor _create_4d_causal_attention_mask(torch::IntArrayRef input_shape, return causal_mask; } -class CLIPVLInputProcessor : public InputProcessor { - enum class TokenType { - INVALID, - IMAGE, - VIDEO, - }; - - public: - explicit CLIPVLInputProcessor(const ModelArgs& args) { - merge_size_ = args.mm_image_merge_size(); - } - void process(std::string& prompt, const MMData& mm_data) override { - torch::Tensor image_grid_thw; - if (auto res = mm_data.get("image_grid_thw")) - image_grid_thw = res.value(); - torch::Tensor video_grid_thw; - if (auto res = mm_data.get("video_grid_thw")) - video_grid_thw = res.value(); - if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; - auto merge_length = merge_size_ * merge_size_; - int total_image_token = 0; - if (image_grid_thw.defined()) { - auto count = image_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_image_token += - image_grid_thw[idx].prod().item() / merge_length; - } - int total_video_token = 0; - if (video_grid_thw.defined()) { - auto count = video_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_video_token += - video_grid_thw[idx].prod().item() / merge_length; - } - size_t total_token_len = total_image_token * image_token_.size() + - total_video_token * video_token_.size(); - std::string data; - data.reserve(prompt.size() + total_token_len); - int image_index = 0; - int video_index = 0; - const torch::Tensor* grid_thw = nullptr; - const std::string* token = nullptr; - int* index = 0; - size_t begin = 0; - auto pair = _find_vision_token(prompt, begin); - while (pair.second != std::string::npos) { - data.append(prompt, begin, pair.second - begin); - if (pair.first == TokenType::IMAGE) { - grid_thw = &image_grid_thw; - token = &image_token_; - index = &image_index; - } else if (pair.first == TokenType::VIDEO) { - grid_thw = &video_grid_thw; - token = &video_token_; - index = &video_index; - } else { - assert(false); - } - auto token_num = (*grid_thw)[(*index)].prod().item() / merge_length; - while (token_num--) data.append(*token); - ++(*index); - begin = pair.second + token->size(); - pair = _find_vision_token(prompt, begin); - } - if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); - prompt = std::move(data); - } - - private: - std::pair _find_vision_token(const std::string& prompt, - size_t begin) { - auto img_pos = prompt.find(image_token_, begin); - auto vid_pos = prompt.find(video_token_, begin); - if (img_pos == std::string::npos && vid_pos == std::string::npos) - return {TokenType::INVALID, std::string::npos}; - else if (vid_pos == std::string::npos) - return {TokenType::IMAGE, img_pos}; - else if (img_pos == std::string::npos) - return {TokenType::VIDEO, vid_pos}; - else - return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) - : std::make_pair(TokenType::VIDEO, vid_pos); - } - - private: - const std::string image_token_ = "<|image_pad|>"; - const std::string video_token_ = "<|video_pad|>"; - int merge_size_ = 0; -}; - class CLIPTextEmbeddingImpl : public torch::nn::Module { public: explicit CLIPTextEmbeddingImpl(const ModelContext& context) { diff --git a/xllm/models/dit/flowmatch_euler_discrete_scheduler.h b/xllm/models/dit/flowmatch_euler_discrete_scheduler.h index 8986f0656..44d7f13e1 100644 --- a/xllm/models/dit/flowmatch_euler_discrete_scheduler.h +++ b/xllm/models/dit/flowmatch_euler_discrete_scheduler.h @@ -42,7 +42,13 @@ class FlowMatchEulerDiscreteSchedulerImpl : public torch::nn::Module { max_shift_ = args_.max_shift(), base_image_seq_len_ = args_.base_image_seq_len(); max_image_seq_len_ = args_.max_image_seq_len(); - shift_terminal_ = std::nullopt; + // shift_terminal_ = static_cast(args_.shift_terminal()) == -1 ? + // std::nullopt : args_.shift_terminal(); + if (static_cast(args_.shift_terminal()) == -1) { + shift_terminal_ = std::nullopt; + } else { + shift_terminal_ = args_.shift_terminal(); + } time_shift_type_ = "exponential"; std::vector timesteps_vec(num_train_timesteps_); for (int i = 0; i < num_train_timesteps_; ++i) { @@ -385,6 +391,7 @@ TORCH_MODULE(FlowMatchEulerDiscreteScheduler); REGISTER_MODEL_ARGS(FlowMatchEulerDiscreteScheduler, [&] { LOAD_ARG_OR(num_train_timesteps, "num_train_timesteps", 1000); LOAD_ARG_OR(shift, "shift", 1); + LOAD_ARG_OR(shift_terminal, "shift_terminal", -1); LOAD_ARG_OR(use_dynamic_shifting, "use_dynamic_shifting", true); LOAD_ARG_OR(base_shift, "base_shift", 0.5f); LOAD_ARG_OR(max_shift, "max_shift", 1.15f); diff --git a/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h b/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h new file mode 100644 index 000000000..ab98b5ea6 --- /dev/null +++ b/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h @@ -0,0 +1,2082 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "core/framework/dit_model_loader.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/state_dict/state_dict.h" +#include "core/layers/common/add_matmul.h" +#include "framework/model_context.h" +#include "models/dit/utils/common_util.h" +#include "models/model_registry.h" + +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#else +#include +#include +#endif + +// VAE model compatible with huggingface weights +// ref to: +// https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py + +namespace xllm::dit::npu { +namespace qwenimage { + +class QwenImageBaseModule : public torch::nn::Module { + public: + virtual torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) = 0; + virtual ~QwenImageBaseModule() = default; +}; + +const int64_t CACHE_T = 2; + +class QwenImageCausalConv3dImpl : public torch::nn::Module { + public: + QwenImageCausalConv3dImpl(const ModelContext& context, + int64_t in_channels, + int64_t out_channels, + torch::IntArrayRef kernel_size, + torch::IntArrayRef stride = 1, + torch::IntArrayRef padding = 0) { + conv_ = register_module( + "conv", + torch::nn::Conv3d( + torch::nn::Conv3dOptions(in_channels, out_channels, kernel_size) + .stride(stride) + .padding(0) + .bias(true))); + + auto p = padding.size() == 1 + ? std::vector{padding[0], padding[0], padding[0]} + : std::vector(padding.begin(), padding.end()); + + padding_ = {p[2], p[2], p[1], p[1], 2 * p[0], 0}; + } + + torch::Tensor forward(const torch::Tensor& x, + const torch::Tensor& cache_x = torch::Tensor()) { + auto padding_vec = padding_; + auto result_x = x; + + if (cache_x.defined() && padding_[4] > 0) { + auto device_x = result_x.device(); + auto cache_device = cache_x.to(device_x); + result_x = torch::cat({cache_device, result_x}, 2); + padding_vec[4] -= cache_x.size(2); + } + + result_x = torch::nn::functional::pad( + result_x, torch::nn::functional::PadFuncOptions(padding_vec)); + return conv_(result_x); + } + + void load_state_dict(const StateDict& state_dict) { + weight::load_weight(state_dict, "weight", conv_->weight, is_weight_loaded_); + weight::load_weight(state_dict, "bias", conv_->bias, is_bias_loaded_); + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_weight_loaded_) + << "weight is not loaded for " << prefix + "weight"; + CHECK(is_bias_loaded_) << "weight is not loaded for " << prefix + "bias"; + } + + private: + bool is_weight_loaded_{false}; + bool is_bias_loaded_{false}; + torch::nn::Conv3d conv_ = nullptr; + std::vector padding_; +}; + +TORCH_MODULE(QwenImageCausalConv3d); + +class QwenImageRMS_normImpl : public torch::nn::Module { + public: + QwenImageRMS_normImpl(const ModelContext& context, + int64_t dim, + bool channel_first = true, + bool images = true, + bool is_bias = false, + bool fused = false) + : channel_first_(channel_first), fused_(fused), is_bias_(is_bias) { + auto broadcastable_dims = + images ? std::vector{1, 1} : std::vector{1, 1, 1}; + auto shape = std::vector{dim}; + if (channel_first) { + shape.insert( + shape.end(), broadcastable_dims.begin(), broadcastable_dims.end()); + } + + scale_ = std::sqrt(dim); + weight_ = register_parameter("gamma", torch::ones(shape)); + + if (is_bias_) { + bias_ = register_parameter("bias", torch::zeros(shape)); + } + } + + torch::Tensor forward(const torch::Tensor& x) { + if (fused_) { + auto [output, rstd] = + at_npu::native::custom_ops::npu_rms_norm(x, weight_, 0); + + if (is_bias_ && bias_.defined()) { + output = output + bias_.to(output.device()); + } + return output; + } else { + auto output = torch::nn::functional::normalize( + x, + torch::nn::functional::NormalizeFuncOptions().dim( + channel_first_ ? 1 : -1)) * + scale_ * weight_; + if (is_bias_) { + output = output + bias_; + } + return output; + } + } + + void load_state_dict(const StateDict& state_dict) { + weight::load_weight(state_dict, "gamma", weight_, is_weight_loaded_); + if (is_bias_) { + weight::load_weight(state_dict, "bias", bias_, is_bias_loaded_); + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_weight_loaded_) + << "weight is not loaded for " << prefix + "weight"; + CHECK(!is_bias_ || is_bias_loaded_) + << "bias is not loaded for " << prefix + "bias"; + } + + private: + bool channel_first_; + double scale_; + bool is_bias_; + bool fused_; + bool is_weight_loaded_{false}; + bool is_bias_loaded_{false}; + torch::Tensor weight_; + torch::Tensor bias_; + torch::TensorOptions options_; +}; + +TORCH_MODULE(QwenImageRMS_norm); + +class QwenImageUpsampleImpl : public torch::nn::Module { + public: + QwenImageUpsampleImpl( + const ModelContext& context, + const torch::nn::functional::InterpolateFuncOptions options) + : options_(options) {} + + torch::Tensor forward(const torch::Tensor& x) { + // auto result = upsample_(x.to(torch::kFloat)); + auto result = + torch::nn::functional::interpolate(x.to(torch::kFloat), options_); + return result.to(x.dtype()); + } + + private: + torch::nn::functional::InterpolateFuncOptions options_; + torch::nn::Upsample upsample_ = nullptr; +}; + +TORCH_MODULE(QwenImageUpsample); + +class QwenImageResampleImpl : public QwenImageBaseModule { + public: + QwenImageResampleImpl(const ModelContext& context, + int64_t dim, + const std::string& mode) + : dim_(dim), mode_(mode) { + if (mode_ == "upsample2d") { + resample_ = register_module( + "resample", + torch::nn::Sequential( + QwenImageUpsample(context, + torch::nn::functional::InterpolateFuncOptions() + .scale_factor(std::vector{2.0, 2.0}) + .mode(torch::kNearestExact)), + torch::nn::Conv2d( + torch::nn::Conv2dOptions(/*in_channels=*/dim, + /*out_channels=*/dim / 2, + /*kernel_size=*/3) + .padding(1)))); + } else if (mode_ == "upsample3d") { + resample_ = register_module( + "resample", + torch::nn::Sequential( + QwenImageUpsample(context, + torch::nn::functional::InterpolateFuncOptions() + .scale_factor(std::vector{2.0, 2.0}) + .mode(torch::kNearestExact)), + torch::nn::Conv2d( + torch::nn::Conv2dOptions(/*in_channels=*/dim, + /*out_channels=*/dim / 2, + /*kernel_size=*/3) + .padding(1)))); + + time_conv_ = register_module( + "time_conv", + QwenImageCausalConv3d(context, + /*in_channels=*/dim, + /*out_channels=*/dim * 2, + /*kernel_size=*/torch::IntArrayRef{3, 1, 1}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{1, 0, 0})); + + } else if (mode_ == "downsample2d") { + resample_ = register_module( + "resample", + torch::nn::Sequential( + torch::nn::ZeroPad2d(torch::nn::ZeroPad2dOptions({/*left=*/0, + /*right=*/1, + /*top=*/0, + /*bottom=*/1})), + torch::nn::Conv2d(torch::nn::Conv2dOptions(/*in_channels=*/dim, + /*out_channels=*/dim, + /*kernel_size=*/3) + .stride(2)))); + } else if (mode_ == "downsample3d") { + resample_ = register_module( + "resample", + torch::nn::Sequential( + torch::nn::ZeroPad2d(torch::nn::ZeroPad2dOptions({/*left=*/0, + /*right=*/1, + /*top=*/0, + /*bottom=*/1})), + torch::nn::Conv2d(torch::nn::Conv2dOptions(/*in_channels=*/dim, + /*out_channels=*/dim, + /*kernel_size=*/3) + .stride(2)))); + time_conv_ = register_module( + "time_conv", + QwenImageCausalConv3d(context, + /*in_channels=*/dim, + /*out_channels=*/dim, + /*kernel_size=*/torch::IntArrayRef{3, 1, 1}, + /*stride=*/torch::IntArrayRef{2, 1, 1}, + /*padding=*/torch::IntArrayRef{0, 0, 0})); + } else { + resample_ = register_module("resample", + torch::nn::Sequential(torch::nn::Identity())); + } + + rep_tensor_ = register_parameter("rep_tensor", torch::tensor({-999.0})); + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) override { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + auto sizes = x.sizes(); + auto b = sizes[0], c = sizes[1], t = sizes[2], h = sizes[3], w = sizes[4]; + auto result_x = x; + + if (mode_ == "upsample3d" && feat_cache && feat_idx) { + auto idx = (*feat_idx)[0]; + + if (idx < feat_cache->size() && feat_cache->at(idx).defined()) { + auto cache_x = result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice( + -CACHE_T, torch::indexing::None)}) + .clone(); + + if (cache_x.size(2) < 2 && feat_cache->at(idx).defined() && + !torch::equal(rep_tensor_, feat_cache->at(idx))) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + if (cache_x.size(2) < 2 && feat_cache->at(idx).defined() && + torch::equal(rep_tensor_, feat_cache->at(idx))) { + cache_x = torch::cat( + {torch::zeros_like(cache_x).to(cache_x.device()), cache_x}, 2); + } + if (torch::equal(rep_tensor_, feat_cache->at(idx))) { + result_x = time_conv_->forward(result_x); + } else { + result_x = time_conv_->forward(result_x, feat_cache->at(idx)); + } + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + + result_x = result_x.reshape({b, 2, c, t, h, w}); + result_x = torch::stack({result_x.index({torch::indexing::Slice(), 0}), + result_x.index({torch::indexing::Slice(), 1})}, + 3); + result_x = result_x.reshape({b, c, t * 2, h, w}); + } else { + feat_cache->at(idx) = rep_tensor_; + (*feat_idx)[0]++; + } + } + + t = result_x.size(2); + result_x = result_x.permute({0, 2, 1, 3, 4}).reshape({b * t, c, h, w}); + result_x = resample_->forward(result_x); + result_x = + result_x + .view({b, t, result_x.size(1), result_x.size(2), result_x.size(3)}) + .permute({0, 2, 1, 3, 4}); + + if (mode_ == "downsample3d" && feat_cache && feat_idx) { + auto idx = (*feat_idx)[0]; + + if (idx < feat_cache->size() && feat_cache->at(idx).defined()) { + auto cache_x = + result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .clone(); + + auto concat_x = torch::cat( + {feat_cache->at(idx).index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}), + result_x}, + 2); + + result_x = time_conv_->forward(concat_x); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + feat_cache->at(idx) = result_x.clone(); + (*feat_idx)[0]++; + } + } + + return result_x; + } + + void load_state_dict(const StateDict& state_dict) { + auto params = resample_->named_parameters(); + for (auto& param : params) { + std::string name = param.key(); + if (name == "1.weight") { + weight::load_weight( + state_dict, "resample.1.weight", param.value(), is_weight_loaded_); + } else if (name == "1.bias") { + weight::load_weight( + state_dict, "resample.1.bias", param.value(), is_bias_loaded_); + } + } + if (time_conv_) { + time_conv_->load_state_dict( + state_dict.get_dict_with_prefix("time_conv.")); + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_weight_loaded_) + << "weight is not loaded for " << prefix + "weight"; + CHECK(is_bias_loaded_) << "bias is not loaded for " << prefix + "bias"; + if (time_conv_) { + time_conv_->verify_loaded_weights("time_conv."); + } + } + + private: + int64_t dim_; + std::string mode_; + bool is_weight_loaded_{false}; + bool is_bias_loaded_{false}; + torch::Tensor rep_tensor_; + torch::nn::Sequential resample_{nullptr}; + QwenImageCausalConv3d time_conv_{nullptr}; +}; + +TORCH_MODULE(QwenImageResample); + +class QwenImageResidualBlockImpl : public QwenImageBaseModule { + public: + QwenImageResidualBlockImpl(const ModelContext& context, + int64_t in_dim, + int64_t out_dim, + double dropout = 0.0, + const std::string& non_linearity = "silu") + : in_dim_(in_dim), out_dim_(out_dim) { + activation_ = register_module("silu", torch::nn::SiLU()); + + norm1_ = register_module("norm1", + QwenImageRMS_norm(context, + in_dim, + /*channel_first=*/true, + /*images=*/false, + /*is_bias=*/false, + /*fused=*/false)); + conv1_ = register_module( + "conv1", + QwenImageCausalConv3d(context, + in_dim, + out_dim, + /*kernel_size=*/torch::IntArrayRef{3, 3, 3}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{1, 1, 1})); + norm2_ = register_module("norm2", + QwenImageRMS_norm(context, + out_dim, + /*channel_first=*/true, + /*images=*/false, + /*is_bias=*/false, + /*fused=*/false)); + dropout_layer_ = register_module("dropout", torch::nn::Dropout(dropout)); + conv2_ = register_module( + "conv2", + QwenImageCausalConv3d(context, + out_dim, + out_dim, + /*kernel_size=*/torch::IntArrayRef{3, 3, 3}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{1, 1, 1})); + + if (in_dim != out_dim) { + conv_shortcut_ = register_module( + "conv_shortcut", + QwenImageCausalConv3d(context, + in_dim, + out_dim, + /*kernel_size=*/torch::IntArrayRef{1, 1, 1}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{0, 0, 0})); + } else { + identity_ = register_module("conv_shortcut", torch::nn::Identity()); + } + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) override { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + torch::Tensor h = torch::empty({0}); + if (conv_shortcut_) { + h = conv_shortcut_->forward(x); + } else { + h = identity_->forward(x); + } + auto result_x = x; + + result_x = norm1_->forward(result_x); + result_x = activation_->forward(result_x); + + if (feat_cache && feat_idx) { + auto idx = (*feat_idx)[0]; + auto cache_x = + result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-CACHE_T, torch::indexing::None)}) + .clone(); + if (cache_x.size(2) < 2 && feat_cache->at(idx).defined()) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + + result_x = conv1_->forward(result_x, feat_cache->at(idx)); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + result_x = conv1_->forward(result_x); + } + result_x = norm2_->forward(result_x); + result_x = activation_->forward(result_x); + result_x = dropout_layer_->forward(result_x); + + if (feat_cache && feat_idx) { + auto idx = (*feat_idx)[0]; + auto cache_x = + result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-CACHE_T, torch::indexing::None)}) + .clone(); + + if (cache_x.size(2) < 2 && feat_cache->at(idx).defined()) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + result_x = conv2_->forward(result_x, feat_cache->at(idx)); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + result_x = conv2_->forward(result_x); + } + + return result_x + h; + } + + void load_state_dict(const StateDict& state_dict) { + norm1_->load_state_dict(state_dict.get_dict_with_prefix("norm1.")); + norm2_->load_state_dict(state_dict.get_dict_with_prefix("norm2.")); + + conv1_->load_state_dict(state_dict.get_dict_with_prefix("conv1.")); + + conv2_->load_state_dict(state_dict.get_dict_with_prefix("conv2.")); + + if (conv_shortcut_) { + conv_shortcut_->load_state_dict( + state_dict.get_dict_with_prefix("conv_shortcut.")); + } + } + + void verify_loaded_weights(const std::string& prefix) const { + norm1_->verify_loaded_weights("norm1."); + norm2_->verify_loaded_weights("norm2."); + conv1_->verify_loaded_weights("conv1."); + conv2_->verify_loaded_weights("conv2."); + if (conv_shortcut_) { + conv_shortcut_->verify_loaded_weights("conv_shortcut."); + } + } + + private: + int64_t in_dim_, out_dim_; + QwenImageRMS_norm norm1_{nullptr}, norm2_{nullptr}; + QwenImageCausalConv3d conv1_{nullptr}, conv2_{nullptr}; + QwenImageCausalConv3d conv_shortcut_{nullptr}; + torch::nn::Dropout dropout_layer_{nullptr}; + torch::nn::SiLU activation_{nullptr}; + torch::nn::Identity identity_{nullptr}; +}; + +TORCH_MODULE(QwenImageResidualBlock); + +class QwenImageAttentionBlockImpl : public QwenImageBaseModule { + public: + QwenImageAttentionBlockImpl(const ModelContext& context, int64_t dim) + : dim_(dim) { + norm_ = register_module("norm", + QwenImageRMS_norm(context, + dim, + /*channel_first=*/true, + /*images=*/true, + /*is_bias=*/false, + /*fused=*/false)); + to_qkv_ = register_module( + "to_qkv", + torch::nn::Conv2d(torch::nn::Conv2dOptions(/*in_channels=*/dim, + /*out_channels=*/dim * 3, + /*kernel_size=*/1))); + proj_ = register_module( + "proj", + torch::nn::Conv2d(torch::nn::Conv2dOptions(/*in_channels=*/dim, + /*out_channels=*/dim, + /*kernel_size=*/1))); + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) override { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + auto identity = x; + auto sizes = x.sizes(); + auto b = sizes[0], c = sizes[1], t = sizes[2], h = sizes[3], w = sizes[4]; + + auto reshaped_x = x.permute({0, 2, 1, 3, 4}).reshape({b * t, c, h, w}); + reshaped_x = norm_->forward(reshaped_x); + + auto qkv = to_qkv_->forward(reshaped_x); + qkv = qkv.reshape({b * t, 1, c * 3, h * w}); + qkv = qkv.permute({0, 1, 3, 2}).contiguous(); + + auto chunks = qkv.chunk(3, -1); + auto q = chunks[0], k = chunks[1], v = chunks[2]; + + auto results = at_npu::native::custom_ops::npu_fusion_attention( + q, + k, + v, + /*head_num=*/1, + /*input_layout=*/"BNSD", + /*pse*/ torch::nullopt, + /*padding_mask=*/torch::nullopt, + /*atten_mask=*/torch::nullopt, + /*scale=*/pow(c, -0.5), + /*keep_prob=*/1.0, + /*pre_tockens=*/65535, + /*next_tockens=*/65535); + auto attn_output = std::get<0>(results); + attn_output = + attn_output.squeeze(1).permute({0, 2, 1}).reshape({b * t, c, h, w}); + + auto output = proj_->forward(attn_output); + + output = output.view({b, t, c, h, w}).permute({0, 2, 1, 3, 4}); + + return output + identity; + } + + void load_state_dict(const StateDict& state_dict) { + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + + weight::load_weight( + state_dict, "to_qkv.weight", to_qkv_->weight, is_qkv_weight_loaded_); + weight::load_weight( + state_dict, "to_qkv.bias", to_qkv_->bias, is_qkv_bias_loaded_); + weight::load_weight( + state_dict, "proj.weight", proj_->weight, is_proj_weight_loaded_); + weight::load_weight( + state_dict, "proj.bias", proj_->bias, is_proj_bias_loaded_); + } + + void verify_loaded_weights(const std::string& prefix) { + norm_->verify_loaded_weights("norm."); + + CHECK(is_qkv_weight_loaded_) + << "weight is not loaded for " << prefix + "weight"; + CHECK(is_qkv_bias_loaded_) + << "weight is not loaded for " << prefix + "bias"; + CHECK(is_proj_weight_loaded_) + << "weight is not loaded for " << prefix + "weight"; + CHECK(is_proj_bias_loaded_) + << "weight is not loaded for " << prefix + "bias"; + } + + private: + int64_t dim_; + QwenImageRMS_norm norm_{nullptr}; + torch::nn::Conv2d to_qkv_{nullptr}; + torch::nn::Conv2d proj_{nullptr}; + bool is_qkv_weight_loaded_{false}; + bool is_qkv_bias_loaded_{false}; + bool is_proj_weight_loaded_{false}; + bool is_proj_bias_loaded_{false}; +}; + +TORCH_MODULE(QwenImageAttentionBlock); + +class QwenImageMidBlockImpl : public torch::nn::Module { + public: + QwenImageMidBlockImpl(const ModelContext& context, + int64_t dim, + double dropout = 0.0, + const std::string& non_linearity = "silu", + int64_t num_layers = 1) + : dim_(dim) { + resnets_ = register_module("resnets", torch::nn::ModuleList()); + attentions_ = register_module("attentions", torch::nn::ModuleList()); + + auto resnet_0 = + QwenImageResidualBlock(context, dim, dim, dropout, non_linearity); + resnets_->push_back(resnet_0); + + for (int64_t i = 0; i < num_layers; i++) { + auto attention = QwenImageAttentionBlock(context, dim); + attentions_->push_back(attention); + + auto resnet = + QwenImageResidualBlock(context, dim, dim, dropout, non_linearity); + resnets_->push_back(resnet); + } + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + auto result_x = x; + + result_x = resnets_[0]->as()->forward( + result_x, feat_cache, feat_idx); + + for (size_t i = 0; i < attentions_->size(); i++) { + result_x = + attentions_[i]->as()->forward(result_x); + result_x = resnets_[i + 1]->as()->forward( + result_x, feat_cache, feat_idx); + } + + return result_x; + } + + void load_state_dict(const StateDict& state_dict) { + for (size_t i = 0; i < resnets_->size(); i++) { + auto prefix = "resnets." + std::to_string(i) + "."; + resnets_[i]->as()->load_state_dict( + state_dict.get_dict_with_prefix(prefix)); + } + + for (size_t i = 0; i < attentions_->size(); i++) { + auto prefix = "attentions." + std::to_string(i) + "."; + attentions_[i]->as()->load_state_dict( + state_dict.get_dict_with_prefix(prefix)); + } + } + + void verify_loaded_weights(const std::string& prefix) { + for (size_t i = 0; i < resnets_->size(); i++) { + auto prefix = "resnets." + std::to_string(i) + "."; + resnets_[i]->as()->verify_loaded_weights(prefix); + } + + for (size_t i = 0; i < attentions_->size(); i++) { + auto prefix = "attentions." + std::to_string(i) + "."; + attentions_[i]->as()->verify_loaded_weights( + prefix); + } + } + + private: + int64_t dim_; + torch::nn::ModuleList resnets_; + torch::nn::ModuleList attentions_; +}; + +TORCH_MODULE(QwenImageMidBlock); + +class QwenImageEncoder3dImpl : public torch::nn::Module { + public: + QwenImageEncoder3dImpl(const ModelContext& context, + int64_t dim = 128, + int64_t z_dim = 4, + std::vector dim_mult = {1, 2, 4, 4}, + int64_t num_res_blocks = 2, + std::vector attn_scales = {}, + std::vector temperal_downsample = {true, + true, + false}, + double dropout = 0.0, + int64_t input_channels = 3, + const std::string& non_linearity = "silu") + : dim_(dim), + z_dim_(z_dim), + dim_mult_(dim_mult), + num_res_blocks_(num_res_blocks), + attn_scales_(attn_scales), + temperal_downsample_(temperal_downsample) { + nonlinearity_ = register_module("silu", torch::nn::SiLU()); + + std::vector dims = {dim * 1}; + for (auto u : dim_mult_) { + dims.push_back(dim * u); + } + + double scale = 1.0; + + conv_in_ = register_module( + "conv_in", + QwenImageCausalConv3d(context, + input_channels, + /*out_channels=*/dims[0], + /*kernel_size=*/torch::IntArrayRef{3, 3, 3}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{1, 1, 1})); + + down_blocks_ = register_module("down_blocks", torch::nn::ModuleList()); + + size_t counter = 0; + for (size_t i = 0; i < dims.size() - 1; i++) { + int64_t in_dim = dims[i]; + int64_t out_dim = dims[i + 1]; + + for (int64_t j = 0; j < num_res_blocks_; j++) { + auto res_block = QwenImageResidualBlock( + context, in_dim, out_dim, dropout, non_linearity); + down_blocks_->push_back(res_block); + resnet_blocks_idx_.push_back(counter); + counter += 1; + + if (std::find(attn_scales_.begin(), attn_scales_.end(), scale) != + attn_scales_.end()) { + auto attn_block = QwenImageAttentionBlock(context, out_dim); + down_blocks_->push_back(attn_block); + attention_blocks_idx_.push_back(counter); + counter += 1; + } + in_dim = out_dim; + } + + if (i != dim_mult_.size() - 1) { + std::string mode = + temperal_downsample_[i] ? "downsample3d" : "downsample2d"; + auto downsample = QwenImageResample(context, out_dim, mode); + down_blocks_->push_back(downsample); + resample_blocks_idx_.push_back(counter); + counter += 1; + scale /= 2.0; + } + } + + mid_block_ = register_module( + "mid_block", + QwenImageMidBlock( + context, dims.back(), dropout, non_linearity, /*num_layers=*/1)); + + norm_out_ = register_module("norm_out", + QwenImageRMS_norm(context, + dims.back(), + /*channel_first=*/true, + /*images=*/false, + /*is_bias=*/false, + /*fused=*/false)); + conv_out_ = register_module( + "conv_out", + QwenImageCausalConv3d(context, + /*in_channels=*/dims.back(), + /*out_channels=*/z_dim, + /*kernel_size=*/torch::IntArrayRef{3, 3, 3}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{1, 1, 1})); + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + torch::Tensor result_x; + + if (feat_cache && feat_idx) { + auto idx = (*feat_idx)[0]; + auto cache_x = + x.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-CACHE_T, torch::indexing::None)}) + .clone(); + + if (cache_x.size(2) < 2 && feat_cache->at(idx).defined()) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + result_x = conv_in_->forward(x, feat_cache->at(idx)); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + result_x = conv_in_->forward(x); + } + + int64_t counter = 0; + for (auto& layer : *down_blocks_) { + if (feat_cache) { + counter = counter + 1; + result_x = + std::dynamic_pointer_cast(layer)->forward( + result_x, feat_cache, feat_idx); + } else { + result_x = + std::dynamic_pointer_cast(layer)->forward( + result_x, + nullptr, + std::make_shared>( + std::vector{0})); + } + } + + result_x = mid_block_->forward(result_x, feat_cache, feat_idx); + + result_x = norm_out_->forward(result_x); + result_x = nonlinearity_->forward(result_x); + + if (feat_cache && feat_idx) { + auto idx = (*feat_idx)[0]; + auto cache_x = + result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-CACHE_T, torch::indexing::None)}) + .clone(); + + if (cache_x.size(2) < 2 && idx < feat_cache->size() && + feat_cache->at(idx).defined()) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + + result_x = conv_out_->forward(result_x, feat_cache->at(idx)); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + result_x = conv_out_->forward(result_x); + } + + return result_x; + } + + void load_state_dict(const StateDict& state_dict) { + conv_in_->load_state_dict(state_dict.get_dict_with_prefix("conv_in.")); + + for (size_t resnet_idx : resnet_blocks_idx_) { + down_blocks_[resnet_idx]->as()->load_state_dict( + state_dict.get_dict_with_prefix("down_blocks." + + std::to_string(resnet_idx) + ".")); + } + + for (size_t attention_idx : attention_blocks_idx_) { + down_blocks_[attention_idx] + ->as() + ->load_state_dict(state_dict.get_dict_with_prefix( + "down_blocks." + std::to_string(attention_idx) + ".")); + } + + for (size_t resample_idx : resample_blocks_idx_) { + down_blocks_[resample_idx]->as()->load_state_dict( + state_dict.get_dict_with_prefix("down_blocks." + + std::to_string(resample_idx) + ".")); + } + + mid_block_->load_state_dict(state_dict.get_dict_with_prefix("mid_block.")); + norm_out_->load_state_dict(state_dict.get_dict_with_prefix("norm_out.")); + conv_out_->load_state_dict(state_dict.get_dict_with_prefix("conv_out.")); + } + + void verify_loaded_weights(const std::string& prefix) { + conv_in_->verify_loaded_weights("conv_in."); + for (size_t resnet_idx : resnet_blocks_idx_) { + down_blocks_[resnet_idx] + ->as() + ->verify_loaded_weights(std::to_string(resnet_idx) + "."); + } + + for (size_t attention_idx : attention_blocks_idx_) { + down_blocks_[attention_idx] + ->as() + ->verify_loaded_weights(std::to_string(attention_idx) + "."); + } + + for (size_t resample_idx : resample_blocks_idx_) { + down_blocks_[resample_idx] + ->as() + ->verify_loaded_weights(std::to_string(resample_idx) + "."); + } + mid_block_->verify_loaded_weights("mid_block."); + norm_out_->verify_loaded_weights("norm_out."); + conv_out_->verify_loaded_weights("conv_out."); + } + + private: + int64_t dim_, z_dim_; + std::vector dim_mult_; + std::vector resnet_blocks_idx_; + std::vector attention_blocks_idx_; + std::vector resample_blocks_idx_; + int64_t num_res_blocks_; + std::vector attn_scales_; + std::vector temperal_downsample_; + + torch::nn::SiLU nonlinearity_{nullptr}; + QwenImageCausalConv3d conv_in_{nullptr}; + torch::nn::ModuleList down_blocks_{nullptr}; + QwenImageMidBlock mid_block_{nullptr}; + QwenImageRMS_norm norm_out_{nullptr}; + QwenImageCausalConv3d conv_out_{nullptr}; +}; + +TORCH_MODULE(QwenImageEncoder3d); + +class QwenImageUpBlockImpl : public torch::nn::Module { + public: + QwenImageUpBlockImpl(const ModelContext& context, + int64_t in_dim, + int64_t out_dim, + int64_t num_res_blocks, + double dropout = 0.0, + const std::string& upsample_mode = "", + const std::string& non_linearity = "silu") + : in_dim_(in_dim), out_dim_(out_dim) { + resnets_ = register_module("resnets", torch::nn::ModuleList()); + int64_t current_dim = in_dim; + + for (int64_t i = 0; i < num_res_blocks + 1; i++) { + auto resnet = QwenImageResidualBlock( + context, current_dim, out_dim, dropout, non_linearity); + resnets_->push_back(resnet); + current_dim = out_dim; + } + + if (!upsample_mode.empty()) { + upsamplers_ = register_module("upsamplers", torch::nn::ModuleList()); + auto upsample = QwenImageResample(context, out_dim, upsample_mode); + upsamplers_->push_back(upsample); + } + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + + auto result_x = x; + + for (auto& resnet : *resnets_) { + if (feat_cache && feat_idx) { + result_x = + std::dynamic_pointer_cast(resnet)->forward( + result_x, feat_cache, feat_idx); + } else { + result_x = + std::dynamic_pointer_cast(resnet)->forward( + result_x, + nullptr, + std::make_shared>( + std::vector{0})); + } + } + + if (upsamplers_) { + if (feat_cache && feat_idx) { + result_x = + std::dynamic_pointer_cast(upsamplers_[0]) + ->forward(result_x, feat_cache, feat_idx); + } else { + result_x = + std::dynamic_pointer_cast(upsamplers_[0]) + ->forward(result_x, + nullptr, + std::make_shared>( + std::vector{0})); + } + } + + return result_x; + } + + void load_state_dict(const StateDict& state_dict) { + for (size_t i = 0; i < resnets_->size(); i++) { + auto prefix = "resnets." + std::to_string(i) + "."; + resnets_[i]->as()->load_state_dict( + state_dict.get_dict_with_prefix(prefix)); + } + + if (upsamplers_) { + upsamplers_[0]->as()->load_state_dict( + state_dict.get_dict_with_prefix("upsamplers.0.")); + } + } + + void verify_loaded_weights(const std::string& prefix) { + for (size_t i = 0; i < resnets_->size(); i++) { + auto prefix = "resnets." + std::to_string(i) + "."; + resnets_[i]->as()->verify_loaded_weights(prefix); + } + + if (upsamplers_) { + upsamplers_[0]->as()->verify_loaded_weights( + "upsamplers.0."); + } + } + + private: + int64_t in_dim_, out_dim_; + torch::nn::ModuleList resnets_{nullptr}; + torch::nn::ModuleList upsamplers_{nullptr}; +}; + +TORCH_MODULE(QwenImageUpBlock); + +class QwenImageDecoder3dImpl : public torch::nn::Module { + public: + QwenImageDecoder3dImpl(const ModelContext& context, + int64_t dim = 128, + int64_t z_dim = 4, + std::vector dim_mult = {1, 2, 4, 4}, + int64_t num_res_blocks = 2, + std::vector attn_scales = {}, + std::vector temperal_upsample = {false, + true, + true}, + double dropout = 0.0, + int64_t input_channels = 3, + const std::string& non_linearity = "silu") + : dim_(dim), + z_dim_(z_dim), + dim_mult_(dim_mult), + num_res_blocks_(num_res_blocks), + attn_scales_(attn_scales), + temperal_upsample_(temperal_upsample) { + nonlinearity_ = register_module("silu", torch::nn::SiLU()); + + std::vector dims = {dim * dim_mult.back()}; + for (int64_t i = dim_mult.size() - 1; i >= 0; i--) { + dims.push_back(dim * dim_mult.at(i)); + } + + double scale = 1.0 / std::pow(2, dim_mult.size() - 2); + + conv_in_ = + register_module("conv_in", + QwenImageCausalConv3d(context, + z_dim, + dims[0], + torch::IntArrayRef{3, 3, 3}, + torch::IntArrayRef{1, 1, 1}, + torch::IntArrayRef{1, 1, 1})); + + mid_block_ = register_module( + "mid_block", + QwenImageMidBlock( + context, dims[0], dropout, non_linearity, /*num_layers=*/1)); + + up_blocks_ = register_module("up_blocks", torch::nn::ModuleList()); + for (size_t i = 0; i < dims.size() - 1; i++) { + int64_t in_dim = dims[i]; + int64_t out_dim = dims[i + 1]; + + if (i > 0) { + in_dim = in_dim / 2; + } + + std::string upsample_mode; + if (i != dim_mult.size() - 1) { + upsample_mode = temperal_upsample[i] ? "upsample3d" : "upsample2d"; + } + + auto up_block = QwenImageUpBlock(context, + in_dim, + out_dim, + num_res_blocks, + dropout, + upsample_mode, + non_linearity); + up_blocks_->push_back(up_block); + + if (!upsample_mode.empty()) { + scale *= 2.0; + } + } + + norm_out_ = register_module( + "norm_out", + QwenImageRMS_norm(context, dims.back(), true, false, false, false)); + conv_out_ = register_module( + "conv_out", + QwenImageCausalConv3d(context, + /*in_channels=*/dims.back(), + /*out_channels=*/input_channels, + /*kernel_size=*/torch::IntArrayRef{3, 3, 3}, + /*stride=*/torch::IntArrayRef{1, 1, 1}, + /*padding=*/torch::IntArrayRef{1, 1, 1})); + } + + torch::Tensor forward( + const torch::Tensor& x, + std::shared_ptr> feat_cache = nullptr, + std::shared_ptr> feat_idx = nullptr) { + if (feat_idx == nullptr) { + feat_idx = + std::make_shared>(std::vector{0}); + } + auto result_x = x; + + if (feat_cache) { + auto idx = (*feat_idx)[0]; + auto cache_x = + result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-CACHE_T, torch::indexing::None)}) + .clone(); + + if (cache_x.size(2) < 2 && feat_cache->at(idx).defined()) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + + result_x = conv_in_->forward(result_x, feat_cache->at(idx)); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + result_x = conv_in_->forward(result_x); + } + + result_x = mid_block_->forward(result_x, feat_cache, feat_idx); + + for (auto& up_block : *up_blocks_) { + result_x = up_block->as()->forward( + result_x, feat_cache, feat_idx); + } + + result_x = norm_out_->forward(result_x); + result_x = nonlinearity_->forward(result_x); + + if (feat_cache) { + auto idx = (*feat_idx)[0]; + auto cache_x = + result_x + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-CACHE_T, torch::indexing::None)}) + .clone(); + + if (cache_x.size(2) < 2 && idx < feat_cache->size() && + feat_cache->at(idx).defined()) { + auto last_frame = + feat_cache->at(idx) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-1, torch::indexing::None)}) + .unsqueeze(2) + .to(cache_x.device()); + cache_x = torch::cat({last_frame, cache_x}, 2); + } + + result_x = conv_out_->forward(result_x, feat_cache->at(idx)); + feat_cache->at(idx) = cache_x; + (*feat_idx)[0]++; + } else { + result_x = conv_out_->forward(result_x); + } + return result_x; + } + + void load_state_dict(const StateDict& state_dict) { + conv_in_->load_state_dict(state_dict.get_dict_with_prefix("conv_in.")); + mid_block_->load_state_dict(state_dict.get_dict_with_prefix("mid_block.")); + + for (size_t i = 0; i < up_blocks_->size(); i++) { + auto prefix = "up_blocks." + std::to_string(i) + "."; + up_blocks_[i]->as()->load_state_dict( + state_dict.get_dict_with_prefix(prefix)); + } + + norm_out_->load_state_dict(state_dict.get_dict_with_prefix("norm_out.")); + conv_out_->load_state_dict(state_dict.get_dict_with_prefix("conv_out.")); + } + + void verify_loaded_weights(const std::string& prefix) { + conv_in_->verify_loaded_weights("conv_in."); + + mid_block_->verify_loaded_weights("mid_block."); + for (size_t i = 0; i < up_blocks_->size(); i++) { + auto prefix = "up_blocks." + std::to_string(i) + "."; + up_blocks_[i]->as()->verify_loaded_weights(prefix); + } + + norm_out_->verify_loaded_weights("norm_out."); + conv_out_->verify_loaded_weights("conv_out."); + } + + std::vector> get_modules() const { + std::vector> module = modules(); + return module; + } + + private: + int64_t dim_, z_dim_; + std::vector dim_mult_; + int64_t num_res_blocks_; + std::vector attn_scales_; + std::vector temperal_upsample_; + + torch::nn::SiLU nonlinearity_{nullptr}; + QwenImageCausalConv3d conv_in_{nullptr}; + QwenImageMidBlock mid_block_{nullptr}; + torch::nn::ModuleList up_blocks_{nullptr}; + QwenImageRMS_norm norm_out_{nullptr}; + QwenImageCausalConv3d conv_out_{nullptr}; +}; + +TORCH_MODULE(QwenImageDecoder3d); + +class DiagonalGaussianDistribution { + public: + DiagonalGaussianDistribution(torch::Tensor parameters, + bool deterministic = false) + : parameters_(std::move(parameters)), deterministic_(deterministic) { + auto chunks = parameters_.chunk(2, 1); + mean_ = chunks[0]; + logvar_ = chunks[1]; + + logvar_ = torch::clamp(logvar_, -30.0f, 20.0f); + + std_ = torch::exp(0.5f * logvar_); + var_ = torch::exp(logvar_); + + if (deterministic_) { + std_.fill_(0.0f); + var_.fill_(0.0f); + } + } + + torch::Tensor sample(int64_t seed) const { + torch::TensorOptions options = mean_.options(); + std::vector shape(mean_.sizes().begin(), mean_.sizes().end()); + return mean_ + std_ * xllm::dit::randn_tensor(shape, seed, options); + } + + torch::Tensor kl(const std::optional& other = + std::nullopt) const { + if (deterministic_) { + return torch::tensor(0.0f, mean_.options()); + } + + if (!other.has_value()) { + return 0.5f * torch::sum(torch::pow(mean_, 2) + var_ - 1.0f - logvar_, + {1, 2, 3}); + } else { + const auto& other_dist = other.value(); + return 0.5f * torch::sum(torch::pow(mean_ - other_dist.mean_, 2) / + other_dist.var_ + + var_ / other_dist.var_ - 1.0f - logvar_ + + other_dist.logvar_, + {1, 2, 3}); + } + } + + torch::Tensor nll(const torch::Tensor& sample, + const std::vector& dims = {1, 2, 3}) const { + if (deterministic_) { + return torch::tensor(0.0f, mean_.options()); + } + const float logtwopi = std::log(2.0f * M_PI); + return 0.5f * + torch::sum(logtwopi + logvar_ + torch::pow(sample - mean_, 2) / var_, + dims); + } + + torch::Tensor mode() const { return mean_; } + + const torch::Tensor& mean() const { return mean_; } + const torch::Tensor& std() const { return std_; } + const torch::Tensor& var() const { return var_; } + const torch::Tensor& logvar() const { return logvar_; } + + private: + torch::Tensor parameters_; + torch::Tensor mean_; + torch::Tensor logvar_; + torch::Tensor std_; + torch::Tensor var_; + bool deterministic_; +}; + +struct AutoencoderKLOutput { + DiagonalGaussianDistribution latent_dist; + AutoencoderKLOutput(DiagonalGaussianDistribution dist) + : latent_dist(std::move(dist)) {} +}; + +struct DecoderOutput { + torch::Tensor sample; + DecoderOutput(torch::Tensor sample) : sample(std::move(sample)) {} +}; + +class AutoencoderKLQwenImageImpl : public torch::nn::Module { + public: + AutoencoderKLQwenImageImpl(const ModelContext& context) + : args_(context.get_model_args()), + z_dim_(context.get_model_args().z_dim()), + temperal_downsample_(context.get_model_args().temperal_downsample()), + base_dim_(context.get_model_args().base_dim()), + dim_mult_(context.get_model_args().dim_mult()), + num_res_blocks_(context.get_model_args().num_res_blocks()), + attn_scales_(context.get_model_args().attn_scales()), + dropout_(context.get_model_args().dropout()) { + temperal_upsample_ = std::vector(temperal_downsample_.rbegin(), + temperal_downsample_.rend()); + + int64_t input_channels = context.get_model_args().in_channels(); + encoder_ = register_module("encoder", + QwenImageEncoder3d(context, + base_dim_, + z_dim_ * 2, + dim_mult_, + num_res_blocks_, + attn_scales_, + temperal_downsample_, + dropout_, + input_channels)); + + quant_conv_ = + register_module("quant_conv", + QwenImageCausalConv3d(context, + z_dim_ * 2, + z_dim_ * 2, + torch::IntArrayRef{1, 1, 1}, + torch::IntArrayRef{1, 1, 1}, + torch::IntArrayRef{0, 0, 0})); + + post_quant_conv_ = + register_module("post_quant_conv", + QwenImageCausalConv3d(context, + z_dim_, + z_dim_, + torch::IntArrayRef{1, 1, 1}, + torch::IntArrayRef{1, 1, 1}, + torch::IntArrayRef{0, 0, 0})); + + decoder_ = register_module("decoder", + QwenImageDecoder3d(context, + base_dim_, + z_dim_, + dim_mult_, + num_res_blocks_, + attn_scales_, + temperal_upsample_, + dropout_, + input_channels)); + + spatial_compression_ratio_ = + static_cast(std::pow(2, temperal_downsample_.size())); + + use_slicing_ = false; + use_tiling_ = false; + tile_sample_min_height_ = 256; + tile_sample_min_width_ = 256; + tile_sample_stride_height_ = 192; + tile_sample_stride_width_ = 192; + + cached_conv_counts_ = {{"decoder", count_conv3d_modules(*decoder_)}, + {"encoder", count_conv3d_modules(*encoder_)}}; + } + + void enable_tiling(int64_t tile_sample_min_height = -1, + int64_t tile_sample_min_width = -1, + int64_t tile_sample_stride_height = -1, + int64_t tile_sample_stride_width = -1) { + use_tiling_ = true; + if (tile_sample_min_height > 0) + tile_sample_min_height_ = tile_sample_min_height; + if (tile_sample_min_width > 0) + tile_sample_min_width_ = tile_sample_min_width; + if (tile_sample_stride_height > 0) + tile_sample_stride_height_ = tile_sample_stride_height; + if (tile_sample_stride_width > 0) + tile_sample_stride_width_ = tile_sample_stride_width; + } + + void clear_cache() { + conv_num_ = count_conv3d_modules(*decoder_); + conv_idx_ = std::make_shared>(std::vector{0}); + feat_map_ = std::make_shared>( + std::vector(conv_num_)); + + enc_conv_num_ = count_conv3d_modules(*encoder_); + enc_conv_idx_ = + std::make_shared>(std::vector{0}); + enc_feat_map_ = std::make_shared>( + std::vector(enc_conv_num_)); + } + + torch::Tensor _encode(const torch::Tensor& x) { + auto sizes = x.sizes(); + auto b = sizes[0], c = sizes[1], num_frame = sizes[2], height = sizes[3], + width = sizes[4]; + + if (use_tiling_ && + (width > tile_sample_min_width_ || height > tile_sample_min_height_)) { + return tiled_encode(x); + } + + clear_cache(); + auto iter = 1 + (num_frame - 1) / 4; + torch::Tensor out; + + for (int64_t i = 0; i < iter; i++) { + enc_conv_idx_->at(0) = 0; + torch::Tensor tile; + + if (i == 0) { + tile = x.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, 1), + torch::indexing::Slice(), + torch::indexing::Slice()}); + } else { + tile = x.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(1 + 4 * (i - 1), 1 + 4 * i), + torch::indexing::Slice(), + torch::indexing::Slice()}); + } + + auto encoded_tile = encoder_->forward(tile, enc_feat_map_, enc_conv_idx_); + + if (i == 0) { + out = encoded_tile; + } else { + out = torch::cat({out, encoded_tile}, 2); + } + } + + auto enc = quant_conv_->forward(out); + clear_cache(); + return enc; + } + + AutoencoderKLOutput encode(const torch::Tensor& x, bool return_dict = true) { + torch::Tensor h; + + if (use_slicing_ && x.size(0) > 1) { + std::vector encoded_slices; + auto slices = x.split(1); + for (auto& slice : slices) { + encoded_slices.push_back(_encode(slice)); + } + h = torch::cat(encoded_slices); + } else { + h = _encode(x); + } + + auto posterior = DiagonalGaussianDistribution(h); + + if (!return_dict) { + return {posterior}; + } + + AutoencoderKLOutput output(posterior); + return output; + } + + DecoderOutput _decode(const torch::Tensor& z, bool return_dict = true) { + auto sizes = z.sizes(); + auto b = sizes[0], c = sizes[1], num_frame = sizes[2], height = sizes[3], + width = sizes[4]; + + auto tile_latent_min_height = + tile_sample_min_height_ / spatial_compression_ratio_; + auto tile_latent_min_width = + tile_sample_min_width_ / spatial_compression_ratio_; + + if (use_tiling_ && + (width > tile_latent_min_width || height > tile_latent_min_height)) { + return tiled_decode(z, return_dict); + } + + clear_cache(); + auto x = post_quant_conv_->forward(z); + torch::Tensor out; + + for (int64_t i = 0; i < num_frame; i++) { + conv_idx_->at(0) = 0; + auto frame = x.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(i, i + 1), + torch::indexing::Slice(), + torch::indexing::Slice()}); + + auto decoded_frame = decoder_->forward(frame, feat_map_, conv_idx_); + + if (i == 0) { + out = decoded_frame; + } else { + out = torch::cat({out, decoded_frame}, 2); + } + } + + out = torch::clamp(out, -1.0, 1.0); + clear_cache(); + + if (!return_dict) { + return {out}; + } + DecoderOutput output(out); + + return output; + } + + DecoderOutput decode(const torch::Tensor& z, bool return_dict = true) { + torch::Tensor decoded; + + if (use_slicing_ && z.size(0) > 1) { + std::vector decoded_slices; + auto slices = z.split(1); + for (auto& slice : slices) { + auto output = _decode(slice, true); + decoded_slices.push_back(output.sample); + } + decoded = torch::cat(decoded_slices); + } else { + auto output = _decode(z, true); + decoded = output.sample; + } + + if (!return_dict) { + return {decoded}; + } + DecoderOutput output(decoded); + + return output; + } + + torch::Tensor blend_v(const torch::Tensor& a, + const torch::Tensor& b, + int64_t blend_extent) { + auto result_b = b.clone(); + blend_extent = std::min({a.size(3), b.size(3), blend_extent}); + + for (int64_t y = 0; y < blend_extent; y++) { + auto weight_a = 1.0 - static_cast(y) / blend_extent; + auto weight_b = static_cast(y) / blend_extent; + + auto a_slice = a.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-blend_extent + y, -blend_extent + y + 1), + torch::indexing::Slice()}); + + auto b_slice = result_b.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(y, y + 1), + torch::indexing::Slice()}); + + auto blended = a_slice * weight_a + b_slice * weight_b; + result_b.index_put_({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(y, y + 1), + torch::indexing::Slice()}, + blended); + } + + return result_b; + } + + torch::Tensor blend_h(const torch::Tensor& a, + const torch::Tensor& b, + int64_t blend_extent) { + auto result_b = b.clone(); + blend_extent = std::min({a.size(4), b.size(4), blend_extent}); + + for (int64_t x = 0; x < blend_extent; x++) { + auto weight_a = 1.0 - static_cast(x) / blend_extent; + auto weight_b = static_cast(x) / blend_extent; + + auto a_slice = a.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(-blend_extent + x, -blend_extent + x + 1)}); + + auto b_slice = result_b.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(x, x + 1)}); + + auto blended = a_slice * weight_a + b_slice * weight_b; + result_b.index_put_({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(x, x + 1)}, + blended); + } + + return result_b; + } + + torch::Tensor tiled_encode(const torch::Tensor& x) { + auto sizes = x.sizes(); + auto b = sizes[0], c = sizes[1], num_frames = sizes[2], height = sizes[3], + width = sizes[4]; + + auto latent_height = height / spatial_compression_ratio_; + auto latent_width = width / spatial_compression_ratio_; + + auto tile_latent_min_height = + tile_sample_min_height_ / spatial_compression_ratio_; + auto tile_latent_min_width = + tile_sample_min_width_ / spatial_compression_ratio_; + auto tile_latent_stride_height = + tile_sample_stride_height_ / spatial_compression_ratio_; + auto tile_latent_stride_width = + tile_sample_stride_width_ / spatial_compression_ratio_; + + auto blend_height = tile_latent_min_height - tile_latent_stride_height; + auto blend_width = tile_latent_min_width - tile_latent_stride_width; + + std::vector> rows; + + for (int64_t i = 0; i < height; i += tile_sample_stride_height_) { + std::vector row; + + for (int64_t j = 0; j < width; j += tile_sample_stride_width_) { + clear_cache(); + std::vector time_frames; + auto frame_range = 1 + (num_frames - 1) / 4; + + for (int64_t k = 0; k < frame_range; k++) { + enc_conv_idx_->at(0) = 0; + torch::Tensor tile; + + if (k == 0) { + tile = x.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, 1), + torch::indexing::Slice(i, i + tile_sample_min_height_), + torch::indexing::Slice(j, j + tile_sample_min_width_)}); + } else { + tile = x.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(1 + 4 * (k - 1), 1 + 4 * k), + torch::indexing::Slice(i, i + tile_sample_min_height_), + torch::indexing::Slice(j, j + tile_sample_min_width_)}); + } + + auto encoded_tile = + encoder_->forward(tile, enc_feat_map_, enc_conv_idx_); + auto quantized_tile = quant_conv_->forward(encoded_tile); + time_frames.push_back(quantized_tile); + } + + row.push_back(torch::cat(time_frames, 2)); + } + rows.push_back(row); + } + clear_cache(); + + std::vector result_rows; + + for (int64_t i = 0; i < static_cast(rows.size()); i++) { + std::vector result_row; + + for (int64_t j = 0; j < static_cast(rows[i].size()); j++) { + auto tile = rows[i][j]; + + if (i > 0) { + tile = blend_v(rows[i - 1][j], tile, blend_height); + } + if (j > 0) { + tile = blend_h(rows[i][j - 1], tile, blend_width); + } + + result_row.push_back( + tile.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, tile_latent_stride_height), + torch::indexing::Slice(0, tile_latent_stride_width)})); + } + + result_rows.push_back(torch::cat(result_row, -1)); + } + + auto enc = torch::cat(result_rows, 3) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, latent_height), + torch::indexing::Slice(0, latent_width)}); + + return enc; + } + + DecoderOutput tiled_decode(const torch::Tensor& z, bool return_dict = true) { + auto sizes = z.sizes(); + auto b = sizes[0], c = sizes[1], num_frames = sizes[2], height = sizes[3], + width = sizes[4]; + + auto sample_height = height * spatial_compression_ratio_; + auto sample_width = width * spatial_compression_ratio_; + + auto tile_latent_min_height = + tile_sample_min_height_ / spatial_compression_ratio_; + auto tile_latent_min_width = + tile_sample_min_width_ / spatial_compression_ratio_; + auto tile_latent_stride_height = + tile_sample_stride_height_ / spatial_compression_ratio_; + auto tile_latent_stride_width = + tile_sample_stride_width_ / spatial_compression_ratio_; + + auto blend_height = tile_sample_min_height_ - tile_sample_stride_height_; + auto blend_width = tile_sample_min_width_ - tile_sample_stride_width_; + + std::vector> rows; + + for (int64_t i = 0; i < height; i += tile_latent_stride_height) { + std::vector row; + + for (int64_t j = 0; j < width; j += tile_latent_stride_width) { + clear_cache(); + std::vector time_frames; + + for (int64_t k = 0; k < num_frames; k++) { + conv_idx_->at(0) = 0; + auto tile = + z.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(k, k + 1), + torch::indexing::Slice(i, i + tile_latent_min_height), + torch::indexing::Slice(j, j + tile_latent_min_width)}); + + auto post_quant_tile = post_quant_conv_->forward(tile); + auto decoded_tile = + decoder_->forward(post_quant_tile, feat_map_, conv_idx_); + time_frames.push_back(decoded_tile); + } + + row.push_back(torch::cat(time_frames, 2)); + } + rows.push_back(row); + } + clear_cache(); + + std::vector result_rows; + + for (int64_t i = 0; i < static_cast(rows.size()); i++) { + std::vector result_row; + + for (int64_t j = 0; j < static_cast(rows[i].size()); j++) { + auto tile = rows[i][j]; + + if (i > 0) { + tile = blend_v(rows[i - 1][j], tile, blend_height); + } + if (j > 0) { + tile = blend_h(rows[i][j - 1], tile, blend_width); + } + + result_row.push_back( + tile.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, tile_sample_stride_height_), + torch::indexing::Slice(0, tile_sample_stride_width_)})); + } + + result_rows.push_back(torch::cat(result_row, -1)); + } + + auto dec = torch::cat(result_rows, 3) + .index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, sample_height), + torch::indexing::Slice(0, sample_width)}); + + if (!return_dict) { + return {dec}; + } + DecoderOutput output(dec); + return output; + } + + DecoderOutput forward(const torch::Tensor& sample, + bool sample_posterior = false, + bool return_dict = true, + int64_t seed = 42) { + auto x = sample; + + auto encode_output = encode(x, true); + auto posterior = encode_output.latent_dist; + + torch::Tensor z; + if (sample_posterior) { + z = posterior.sample(seed); + } else { + z = posterior.mode(); + } + + auto dec = decode(z, return_dict); + return dec; + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + encoder_->load_state_dict(state_dict->get_dict_with_prefix("encoder.")); + decoder_->load_state_dict(state_dict->get_dict_with_prefix("decoder.")); + quant_conv_->load_state_dict( + state_dict->get_dict_with_prefix("quant_conv.")); + post_quant_conv_->load_state_dict( + state_dict->get_dict_with_prefix("post_quant_conv.")); + } + verify_loaded_weights(""); + } + + void verify_loaded_weights(const std::string& prefix) { + encoder_->verify_loaded_weights("encoder."); + decoder_->verify_loaded_weights("decoder."); + quant_conv_->verify_loaded_weights("quant_conv."); + post_quant_conv_->verify_loaded_weights("post_quant_conv."); + } + + private: + template + int64_t count_conv3d_modules(const ModuleType& module) { + int64_t count = 0; + for (const auto& m : module.named_modules()) { + if (auto conv = + dynamic_cast(m.value().get())) { + count++; + } + } + return count; + } + + int64_t base_dim_; + int64_t z_dim_; + std::vector dim_mult_; + int64_t num_res_blocks_; + std::vector attn_scales_; + std::vector temperal_downsample_; + std::vector temperal_upsample_; + double dropout_; + + int64_t spatial_compression_ratio_; + bool use_slicing_, use_tiling_; + int64_t tile_sample_min_height_; + int64_t tile_sample_min_width_; + int64_t tile_sample_stride_height_; + int64_t tile_sample_stride_width_; + + std::unordered_map cached_conv_counts_; + + int64_t conv_num_; + int64_t enc_conv_num_; + std::shared_ptr> conv_idx_; + std::shared_ptr> enc_conv_idx_; + std::shared_ptr> feat_map_; + std::shared_ptr> enc_feat_map_; + + QwenImageEncoder3d encoder_{nullptr}; + QwenImageCausalConv3d quant_conv_{nullptr}; + QwenImageCausalConv3d post_quant_conv_{nullptr}; + QwenImageDecoder3d decoder_{nullptr}; + + ModelArgs args_; +}; + +TORCH_MODULE(AutoencoderKLQwenImage); + +REGISTER_MODEL_ARGS(AutoencoderKLQwenImage, [&] { + LOAD_ARG_OR(base_dim, "base_dim", 96); + LOAD_ARG_OR(z_dim, "z_dim", 16); + LOAD_ARG_OR(in_channels, "in_channels", 3); + LOAD_ARG_OR(dim_mult, "dim_mult", (std::vector{1, 2, 4, 4})); + LOAD_ARG_OR(attn_scales, "attn_scales", (std::vector{})); + LOAD_ARG_OR(temperal_downsample, + "temperal_downsample", + (std::vector{false, true, true})); + LOAD_ARG_OR(num_res_blocks, "num_res_blocks", 2); + LOAD_ARG_OR(dropout, "dropout", 0); + LOAD_ARG_OR(latents_mean, + "latents_mean", + (std::vector{-0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921})); + LOAD_ARG_OR(latents_std, + "latents_std", + (std::vector{2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.916})); +}); + +} // namespace qwenimage +} // namespace xllm::dit::npu diff --git a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_base.h b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_base.h new file mode 100644 index 000000000..2bcda98d4 --- /dev/null +++ b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_base.h @@ -0,0 +1,55 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once +#include +#include + +#include +#include +#include + +#include "autoencoder_kl_qwenimage.h" +#include "core/common/global_flags.h" +#include "core/framework/dit_cache/dit_cache.h" +#include "core/framework/dit_model_loader.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model_context.h" +#include "core/framework/request/dit_request_state.h" +#include "core/framework/state_dict/state_dict.h" +#include "core/framework/state_dict/utils.h" +#include "framework/parallel_state/parallel_state.h" +#include "models/dit/flowmatch_euler_discrete_scheduler.h" +#include "models/dit/utils/common_util.h" +#include "models/model_registry.h" +#include "processors/qwen2_vl_image_processor.h" +#include "transformer_qwen_image.h" +namespace xllm::dit::npu { +namespace qwenimage { + +class QwenImagePipelineBaseImpl : public torch::nn::Module { + protected: + torch::Device device_ = torch::kCPU; + torch::ScalarType dtype_; + torch::TensorOptions options_; + AutoencoderKLQwenImage vae_{nullptr}; + xllm::dit::VAEImageProcessor vae_image_processor_{nullptr}; + std::unique_ptr qwen_image_processor_{nullptr}; + QwenImageTransformer2DModel transformer_{nullptr}; + std::unique_ptr qwen_tokenizer_; + std::unique_ptr tokenizer_; + xllm::FlowMatchEulerDiscreteScheduler scheduler_{nullptr}; +}; +} // namespace qwenimage +} // namespace xllm::dit::npu diff --git a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h new file mode 100644 index 000000000..1b19b2ea0 --- /dev/null +++ b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h @@ -0,0 +1,663 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once +#include "core/framework/state_dict/state_dict.h" +#include "pipeline_qwenimage_base.h" + +#define CONDITION_IMAGE_SIZE 147456 +#define VAE_IMAGE_SIZE 1048576 + +namespace xllm::dit::npu { +namespace qwenimage { +class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { + public: + QwenImageEditPlusPipelineImpl(const DiTModelContext& context) + : parallel_args_(context.get_parallel_args()), + vae_model_args_(context.get_model_args("vae")) { + options_ = context.get_tensor_options(); + dtype_ = options_.dtype().toScalarType(); + device_ = options_.device(); + LOG(INFO) << "model info " << dtype_ << " ; " << options_.device(); + + in_channels_ = context.get_model_args("transformer").in_channels(); + num_layers_ = context.get_model_args("transformer").num_layers(); + + vae_scale_factor_ = static_cast( + std::pow(2, vae_model_args_.temperal_downsample().size())); + latent_channels_ = vae_model_args_.z_dim(); + tokenizer_max_length_ = 1024; + + prompt_template_encode_ = + "<|im_start|>system\nDescribe the key features of the input image " + "(color, shape, size, texture, objects, background), then explain how " + "the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while " + "maintaining consistency with the original input where " + "appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>" + "assistant\n"; + prompt_template_encode_start_idx_ = 64; + default_sample_size_ = 128; + + vae_ = AutoencoderKLQwenImage(context.get_model_context("vae")); + transformer_ = QwenImageTransformer2DModel( + context.get_model_context("transformer"), parallel_args_); + scheduler_ = + FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler")); + + vae_image_processor_ = + xllm::dit::VAEImageProcessor(context.get_model_context("vae"), + true, + true, + false, + false, + false, + latent_channels_); + register_module("vae", vae_); + register_module("scheduler", scheduler_); + register_module("transformer", transformer_); + register_module("vae_image_processor", vae_image_processor_); + } + + std::vector _extract_masked_hidden( + const torch::Tensor& hidden_states, + const torch::Tensor& mask) { + auto bool_mask = mask.to(torch::kBool); + auto valid_lengths = bool_mask.sum(1); + + auto valid_lengths_cpu = valid_lengths.to(torch::kCPU).contiguous(); + + std::vector lengths_list; + lengths_list.reserve(valid_lengths_cpu.numel()); + + int64_t* lengths_ptr = valid_lengths_cpu.data_ptr(); + for (int64_t i = 0; i < valid_lengths_cpu.numel(); ++i) { + lengths_list.push_back(lengths_ptr[i]); + } + + auto selected = hidden_states.index({bool_mask}); + auto split_result = torch::split(selected, lengths_list, 0); + + return std::vector(split_result.begin(), split_result.end()); + } + + std::pair _get_qwen_prompt_embeds( + const std::vector& prompt, + const std::vector& image, + torch::TensorOptions& options) { + std::string img_prompt_template = + "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"; + std::string base_img_prompt = ""; + + torch::Tensor prompt_embeds; + torch::Tensor prompt_embeds_mask; + return std::make_pair(prompt_embeds, prompt_embeds_mask); + } + + void _encode_prompt(const std::vector& image, + const std::vector& prompt, + torch::Tensor& prompt_embeds, + torch::Tensor& prompt_embeds_mask, + torch::TensorOptions& options, + int64_t num_images_per_prompt = 1, + int64_t max_sequence_length = 1024) { + int64_t batch_size = prompt_embeds.defined() ? prompt_embeds.size(0) : 1; + if (!prompt_embeds.defined()) { + std::tie(prompt_embeds, prompt_embeds_mask) = + _get_qwen_prompt_embeds(prompt, image, options); + } + + CHECK(prompt_embeds.defined()) + << "currently, the prompt input is not supported for qwen image, " + << "expected a valid prompt_embeds input, but got empty tensor "; + + auto seq_len = prompt_embeds.size(1); + prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt, 1}); + prompt_embeds = + prompt_embeds.view({batch_size * num_images_per_prompt, seq_len, -1}); + if (prompt_embeds_mask.defined()) { + prompt_embeds_mask = + prompt_embeds_mask.repeat({1, num_images_per_prompt, 1}); + prompt_embeds_mask = prompt_embeds_mask.view( + {batch_size * num_images_per_prompt, seq_len}); + } else { + prompt_embeds_mask = + torch::ones({prompt_embeds.size(0), prompt_embeds.size(1)}); + } + } + + torch::Tensor _retrieve_latents(const AutoencoderKLOutput& encoder_output, + int64_t seed = 42, + const std::string& sample_mode = "sample") { + if (sample_mode == "sample") { + return encoder_output.latent_dist.sample(seed); + } else if (sample_mode == "argmax") { + return encoder_output.latent_dist.mode(); + } else { + CHECK(false) + << "sample_mode is expected to be 'sample' or 'argmax', but get: " + << sample_mode; + return torch::Tensor(); + } + } + + torch::Tensor _pack_latents(torch::Tensor latents, + int64_t batch_size, + int64_t num_channels_latents, + int64_t height, + int64_t width) { + latents = latents.view( + {batch_size, num_channels_latents, height / 2, 2, width / 2, 2}); + latents = latents.permute({0, 2, 4, 1, 3, 5}); + latents = latents.reshape( + {batch_size, (height / 2) * (width / 2), num_channels_latents * 4}); + + return latents; + } + + torch::Tensor _unpack_latents(torch::Tensor latents, + int64_t height, + int64_t width, + int64_t vae_scale_factor) { + auto sizes = latents.sizes(); + int64_t batch_size = sizes[0]; + int64_t num_patches = sizes[1]; + int64_t channels = sizes[2]; + + height = 2 * (height / (vae_scale_factor * 2)); + width = 2 * (width / (vae_scale_factor * 2)); + + latents = + latents.view({batch_size, height / 2, width / 2, channels / 4, 2, 2}); + latents = latents.permute({0, 3, 1, 4, 2, 5}); + latents = + latents.reshape({batch_size, channels / (2 * 2), 1, height, width}); + + return latents; + } + + torch::Tensor _encode_vae_image(torch::Tensor image, + int64_t seed, + torch::Device device) { + auto image_latents = _retrieve_latents(vae_->encode(image), seed, "argmax"); + auto latents_mean = + torch::tensor(vae_model_args_.latents_mean(), torch::kDouble); + latents_mean = latents_mean.view({1, latent_channels_, 1, 1, 1}) + .to(device, image_latents.dtype()); + auto latents_std = + torch::tensor(vae_model_args_.latents_std(), torch::kDouble); + latents_std = latents_std.view({1, latent_channels_, 1, 1, 1}) + .to(device, image_latents.dtype()); + image_latents = (image_latents - latents_mean) / latents_std; + return image_latents; + } + + std::pair _prepare_latents( + const std::vector& images, + int64_t batch_size, + int64_t num_channels_latents, + int64_t height, + int64_t width, + torch::TensorOptions& options, + int64_t seed, + torch::Tensor latents = torch::Tensor()) { + height = 2 * (height / (vae_scale_factor_ * 2)); + width = 2 * (width / (vae_scale_factor_ * 2)); + + std::vector shape = { + batch_size, 1, num_channels_latents, height, width}; + + torch::Tensor image_latents; + if (!images.empty()) { + std::vector all_image_latents; + + for (const auto& image : images) { + auto current_image = image.to(options); + torch::Tensor current_image_latents; + + if (current_image.size(1) != latent_channels_) { + current_image_latents = + _encode_vae_image(current_image, seed, device_); + } else { + current_image_latents = current_image; + } + + current_image_latents = torch::cat({current_image_latents}, 0); + int64_t image_latent_height = current_image_latents.size(3); + int64_t image_latent_width = current_image_latents.size(4); + + current_image_latents = _pack_latents(current_image_latents, + batch_size, + num_channels_latents, + image_latent_height, + image_latent_width); + all_image_latents.emplace_back(current_image_latents); + } + + image_latents = torch::cat(all_image_latents, 1); + } + + if (!latents.defined()) { + latents = xllm::dit::randn_tensor(shape, seed, options); + latents = _pack_latents( + latents, batch_size, num_channels_latents, height, width); + } else { + latents = latents.to(options); + } + return std::make_pair(latents, image_latents); + } + + DiTForwardOutput forward(const DiTForwardInput& input) { + torch::NoGradGuard no_grad; + const auto& generation_params = input.generation_params; + auto height = generation_params.height; + auto width = generation_params.width; + auto true_cfg_scale = generation_params.true_cfg_scale; + auto num_inference_steps = generation_params.num_inference_steps; + DiTCache::get_instance().set_infer_steps(num_inference_steps); + DiTCache::get_instance().set_num_blocks(num_layers_); + auto max_sequence_length = generation_params.max_sequence_length; + auto seed = generation_params.seed >= 0 ? generation_params.seed : 42; + + auto prompts = input.prompts; + auto prompts_2 = input.prompts_2; + auto negative_prompts = input.negative_prompts; + auto negative_prompts_2 = input.negative_prompts_2; + auto latents = input.latents; + if (latents.defined()) { + latents = latents.to(options_.device(), dtype_); + } + + auto prompt_embeds = input.prompt_embeds; + if (prompt_embeds.defined()) { + prompt_embeds = prompt_embeds.to(options_.device(), dtype_); + } + auto pooled_prompt_embeds = input.pooled_prompt_embeds; + torch::Tensor prompt_embeds_mask; + + auto negative_prompt_embeds = input.negative_prompt_embeds; + if (negative_prompt_embeds.defined()) { + negative_prompt_embeds = + negative_prompt_embeds.to(options_.device(), dtype_); + } + auto negative_pooled_prompt_embeds = input.negative_pooled_prompt_embeds; + torch::Tensor negative_prompt_embeds_mask; + + std::vector image_list; + + torch::Tensor images; + + if (FLAGS_dit_debug_print) { + input.debug_print(); + } + + if (input.images.defined()) { + images = input.images.to(options_.device(), dtype_); + if (input.images.dim() == 3) { + image_list.emplace_back(images); + } else if (input.images.dim() == 4) { + if (input.images.size(0) > 1) { + LOG(ERROR) << "currently dit models doesn't support batch inference" + << "batch size: " << input.images.size(0); + } + image_list.emplace_back(images[0]); + } else { + LOG(ERROR) + << "image inputs are expected to be a 4 dim tensor, but got: " + << input.images.dim() << "s tensor"; + } + } else { + LOG(ERROR) << "QwenImageEditPlus pipeline expected to have " + << "image inputs"; + } + + torch::Tensor conditional_images; + if (input.condition_images.defined()) { + conditional_images = input.condition_images.to(options_.device(), dtype_); + if (input.condition_images.dim() == 3) { + image_list.emplace_back(conditional_images); + } else if (input.condition_images.dim() == 4) { + if (input.condition_images.size(0) > 1) { + LOG(ERROR) << "currently dit models doesn't support batch inference" + << "batch size: " << input.condition_images.size(0); + } + image_list.emplace_back(conditional_images[0]); + } else { + LOG(ERROR) + << "image inputs are expected to be a 4 dim tensor, but got: " + << input.condition_images.dim() << "s tensor"; + } + } + double height_size = images.size(2); + double width_size = images.size(3); + int64_t num_images_per_prompt = 1; + + double aspect_ratio = width_size / height_size; + auto [calculated_width, calculated_height] = + xllm::dit::calculate_dimensions(1024 * 1024, aspect_ratio); + + height = (height == 0) ? calculated_height : height; + width = (width == 0) ? calculated_width : width; + + int64_t multiple_of = vae_scale_factor_ * 2; + width = (width / multiple_of) * multiple_of; + height = (height / multiple_of) * multiple_of; + + current_timestep_ = torch::Tensor(); + + int64_t batch_size = prompts.size(); + std::vector condition_images; + std::vector vae_images; + std::vector> condition_image_sizes; + std::vector> vae_image_sizes; + + if (images.defined() && !(images.size(1) == latent_channels_)) { + for (size_t i = 0; i < image_list.size(); i++) { + aspect_ratio = + static_cast(image_list[i].size(2)) / image_list[i].size(1); + auto [condition_width, condition_height] = + xllm::dit::calculate_dimensions(CONDITION_IMAGE_SIZE, aspect_ratio); + auto [vae_width, vae_height] = + xllm::dit::calculate_dimensions(VAE_IMAGE_SIZE, aspect_ratio); + condition_image_sizes.push_back({condition_width, condition_height}); + vae_image_sizes.push_back({vae_width, vae_height}); + + auto img = image_list[i].unsqueeze(0); + auto condition_img = vae_image_processor_->resize( + img, condition_height, condition_width); + auto vae_img = + vae_image_processor_->preprocess(img, vae_height, vae_width) + .unsqueeze(2); + + condition_images.push_back(condition_img); + vae_images.push_back(vae_img); + } + } + + bool has_neg_prompt = negative_prompts.size() > 0; + + bool do_true_cfg = (true_cfg_scale > 1.0) && has_neg_prompt; + // inplace update prompt_embeds and prompt_embeds_mask + _encode_prompt(condition_images, + prompts, + prompt_embeds, + prompt_embeds_mask, + options_, + num_images_per_prompt, + max_sequence_length); + + if (do_true_cfg) { + // inplace update negative_prompt_embeds and negative_prompt_embeds_mask + _encode_prompt(condition_images, + negative_prompts, + negative_prompt_embeds, + negative_prompt_embeds_mask, + options_, + num_images_per_prompt, + max_sequence_length); + } + + int64_t num_channels_latents = in_channels_ / 4; + torch::Tensor final_latents; + torch::Tensor image_latents; + + std::tie(final_latents, image_latents) = + _prepare_latents(vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + options_, + seed, + latents); + + std::vector> main_shape = { + {1, height / vae_scale_factor_ / 2, width / vae_scale_factor_ / 2}}; + + for (const auto& [vae_width, vae_height] : vae_image_sizes) { + main_shape.push_back({1, + vae_height / vae_scale_factor_ / 2, + vae_width / vae_scale_factor_ / 2}); + } + + std::vector new_sigmas; + for (int64_t i = 0; i < num_inference_steps; ++i) { + new_sigmas.push_back(1.0f - static_cast(i) / + (num_inference_steps - 1) * + (1.0f - 1.0f / num_inference_steps)); + } + + int64_t image_seq_len = final_latents.size(1); + float mu = xllm::dit::calculate_shift(image_seq_len, + scheduler_->base_image_seq_len(), + scheduler_->max_image_seq_len(), + scheduler_->base_shift(), + scheduler_->max_shift()); + auto [timesteps, num_inference_steps_actual] = + xllm::dit::retrieve_timesteps( + scheduler_, num_inference_steps, device_, new_sigmas, mu); + int64_t num_warmup_steps = + std::max(static_cast(timesteps.numel()) - + num_inference_steps_actual * scheduler_->order(), + static_cast(0LL)); + + num_timesteps_ = timesteps.size(0); + torch::Tensor txt_seq_lens; + if (prompt_embeds_mask.defined()) { + txt_seq_lens = prompt_embeds_mask.sum(1); + } + torch::Tensor negative_txt_seq_lens; + if (do_true_cfg && negative_prompt_embeds_mask.defined()) { + negative_txt_seq_lens = negative_prompt_embeds_mask.sum(1); + } + /* + if (prompt_embeds.size(1) % FLAGS_sp_size != 0) { + int64_t pad_len = + FLAGS_sp_size - prompt_embeds.size(1) % FLAGS_sp_size; + std::vector pad_with = { + 0, + 0, // 第3维�~Hhe ight�~I� ~Mpad + 0, + pad_len, // 第 2维�~Hchannels�~I�~I~M�~P~Npad + 0, + 0}; // 第1维�~Hbatch�~I�~Mpad + std::vector pad_with_mask = { + // 第3维�~Hhe ight�~I� ~Mpad + 0, + pad_len, // 第 2维�~Hchannels�~I�~I~M�~P~Npad + 0, + 0}; // 第1维�~Hbatch�~I�~Mpad + prompt_embeds = torch::pad(prompt_embeds, pad_with, "constant", 0); + prompt_embeds_mask = + torch::pad(prompt_embeds_mask, pad_with_mask, "constant", 0); + } + + if (negative_prompt_embeds.size(1) % FLAGS_sp_size != 0) { + int64_t pad_len = FLAGS_sp_size - + negative_prompt_embeds.size(1) % FLAGS_sp_size; + std::vector pad_with = { + 0, + 0, // 第3维�~Hhe ight�~I� ~Mpad + 0, + pad_len, // 第 2维�~Hchannels�~I�~I~M�~P~Npad + 0, + 0}; // 第1维�~Hbatch�~I�~Mpad + std::vector pad_with_mask = { + // 第3维�~Hhe ight�~I� ~Mpad + 0, + pad_len, // 第 2维�~Hchannels�~I�~I~M�~P~Npad + 0, + 0}; + negative_prompt_embeds = + torch::pad(negative_prompt_embeds, pad_with, "constant", 0); + negative_prompt_embeds_mask = + torch::pad(negative_prompt_embeds_mask, pad_with_mask, "constant", 0); + } + */ + scheduler_->set_begin_index(0); + for (int64_t i = 0; i < timesteps.size(0); ++i) { + auto t = timesteps[i]; + current_timestep_ = t; + + auto latent_model_input = final_latents; + if (image_latents.defined()) { + latent_model_input = torch::cat({final_latents, image_latents}, 1); + } + + auto timestep_expanded = + t.expand({final_latents.size(0)}).to(final_latents.dtype()); + + torch::Tensor noise_pred; + torch::Tensor neg_noise_pred; + torch::Tensor pos_neg_noise_preds; + if (FLAGS_cfg_size == 2 && do_true_cfg) { + auto rank = parallel_args_.dit_cfg_group_->rank(); + if (rank == 0) { + noise_pred = transformer_->forward(latent_model_input, + prompt_embeds, + prompt_embeds_mask, + timestep_expanded / 1000.0, + main_shape, + txt_seq_lens, + /*use_cfg=*/false, + /*step_index=*/i); + noise_pred = noise_pred.slice(1, 0, final_latents.size(1)); + pos_neg_noise_preds = + xllm::parallel_state::gather(noise_pred, + parallel_args_.dit_cfg_group_, + /*dim=*/0); + } else { + neg_noise_pred = transformer_->forward(latent_model_input, + negative_prompt_embeds, + negative_prompt_embeds_mask, + timestep_expanded / 1000.0, + main_shape, + negative_txt_seq_lens, + /*use_cfg=*/true, + /*step_index=*/i); + + neg_noise_pred = neg_noise_pred.slice(1, 0, final_latents.size(1)); + pos_neg_noise_preds = + xllm::parallel_state::gather(neg_noise_pred, + parallel_args_.dit_cfg_group_, + /*dim=*/0); + } + auto noise_preds = torch::chunk(pos_neg_noise_preds, 2, 0); + auto comb_pred = + noise_preds[1] + true_cfg_scale * (noise_preds[0] - noise_preds[1]); + auto cond_norm = torch::norm(noise_preds[0], 2, -1, true); + auto noise_norm = torch::norm(comb_pred, 2, -1, true); + noise_pred = comb_pred * (cond_norm / noise_norm); + + } else { + noise_pred = transformer_->forward(latent_model_input, + prompt_embeds, + prompt_embeds_mask, + timestep_expanded / 1000.0, + main_shape, + txt_seq_lens, + /*use_cfg=*/false, + /*step_index=*/i); + noise_pred = noise_pred.slice(1, 0, final_latents.size(1)); + if (do_true_cfg) { + neg_noise_pred = transformer_->forward(latent_model_input, + negative_prompt_embeds, + negative_prompt_embeds_mask, + timestep_expanded / 1000.0, + main_shape, + negative_txt_seq_lens, + /*use_cfg=*/true, + /*step_index=*/i); + + neg_noise_pred = neg_noise_pred.slice(1, 0, final_latents.size(1)); + + auto comb_pred = + neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred); + auto cond_norm = torch::norm(noise_pred, 2, -1, true); + auto noise_norm = torch::norm(comb_pred, 2, -1, true); + noise_pred = comb_pred * (cond_norm / noise_norm); + } + } + + auto latents_dtype = final_latents.dtype(); + final_latents = scheduler_->step(noise_pred, t, final_latents); + if (final_latents.dtype() != latents_dtype) { + final_latents = final_latents.to(latents_dtype); + } + } + + current_timestep_ = torch::Tensor(); + + torch::Tensor output_image; + + auto unpacked_latents = + _unpack_latents(final_latents, height, width, vae_scale_factor_) + .to(dtype_); + auto latents_mean = + torch::tensor(vae_model_args_.latents_mean(), torch::kDouble); + latents_mean = latents_mean.view({1, latent_channels_, 1, 1, 1}) + .to(device_, image_latents.dtype()); + auto latents_std = + torch::tensor(vae_model_args_.latents_std(), torch::kDouble); + latents_std = 1.0 / latents_std.view({1, latent_channels_, 1, 1, 1}) + .to(device_, image_latents.dtype()); + + unpacked_latents = unpacked_latents / latents_std + latents_mean; + output_image = vae_->decode(unpacked_latents).sample.squeeze(2); + output_image = vae_image_processor_->postprocess(output_image, "pil"); + auto output = std::vector{{output_image}}; + DiTForwardOutput out; + out.tensors = output; + return out; + } + + void load_model(std::unique_ptr loader) { + LOG(INFO) << "QwenImageEditPlusPipeline loading model from" + << loader->model_root_path(); + std::string model_path = loader->model_root_path(); + auto transformer_loader = loader->take_component_loader("transformer"); + auto vae_loader = loader->take_component_loader("vae"); + auto clip_loader = loader->take_component_loader("text_encoder"); + auto tokenizer_loader = loader->take_component_loader("tokenizer"); + auto processor_loader = loader->take_component_loader("processor"); + LOG(INFO) << " QwenImageEditplus model components loaded, start to load " + "weights to sub models"; + + vae_->load_model(std::move(vae_loader)); + vae_->to(options_.device(), dtype_); + transformer_->load_model(std::move(transformer_loader)); + transformer_->to(options_.device(), dtype_); + } + + private: + int64_t vae_scale_factor_; + int64_t latent_channels_; + int64_t tokenizer_max_length_; + int64_t prompt_template_encode_start_idx_; + int64_t default_sample_size_; + int64_t in_channels_; + int64_t num_timesteps_; + int64_t num_layers_; + const ParallelArgs parallel_args_; + torch::Tensor current_timestep_; + string prompt_template_encode_; + const ModelArgs& vae_model_args_; +}; + +REGISTER_MODEL_ARGS(Qwen2Tokenizer, [&] {}); +TORCH_MODULE(QwenImageEditPlusPipeline); + +REGISTER_DIT_MODEL(QwenImageEditPlusPipeline, QwenImageEditPlusPipeline); +} // namespace qwenimage +} // namespace xllm::dit::npu diff --git a/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h b/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h new file mode 100644 index 000000000..2d4f5b4f0 --- /dev/null +++ b/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h @@ -0,0 +1,2171 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "core/framework/dit_cache/dit_cache.h" +#include "core/framework/dit_model_loader.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/state_dict/state_dict.h" +#include "core/framework/state_dict/utils.h" +#include "core/layers/common/add_matmul.h" +#include "framework/model_context.h" +#include "framework/parallel_state/parallel_state.h" +#include "models/dit/utils/dit_parallel_linear.h" +#include "models/dit/utils/sequence_parallel_pad_manager.h" +#include "models/model_registry.h" + +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#else +#include +#include +#endif + +#include +#include + +#include +namespace xllm::dit::npu { +namespace qwenimage { + +inline torch::Tensor gather_sequence(const torch::Tensor& input_, + int64_t dim, + ProcessGroup* pg) { + auto group_size = pg->world_size(); + auto input = input_.contiguous(); + if (group_size == 1) { + return input; + } + + // all gather + auto tensor_list = parallel_state::gather(input, pg, dim); + + // concat + auto output = torch::cat(tensor_list, dim); + + return output; +} + +inline torch::Tensor split_sequence(const torch::Tensor& input, + int64_t dim, + ProcessGroup* pg) { + auto group_size = pg->world_size(); + auto rank = pg->rank(); + + if (group_size == 1) { + return input; + } + + torch::Tensor input_ = input; + + int64_t dim_size = input_.size(dim); + + auto tensor_list = torch::split(input_, dim_size / group_size, dim); + auto output = tensor_list[rank].contiguous(); + return output; +} + +// TODO: This class should be extracted from dit class and integrated into a +// common class. +class RMSNormImpl : public torch::nn::Module { + public: + // Constructor: dim (normalization dimension), eps (stabilization term) + // elementwise_affine (enable affine transform), bias (enable bias term) + RMSNormImpl(int64_t dim, double eps, bool elementwise_affine, bool bias) + : eps_(eps), elementwise_affine_(elementwise_affine), is_bias_(bias) { + if (elementwise_affine_) { + weight_ = register_parameter("weight", torch::ones({dim})); + if (is_bias_) { + bias_ = register_parameter("bias", torch::zeros({dim})); + } + } + } + + torch::Tensor forward(const torch::Tensor& hidden_states) { + auto [output, rstd] = + at_npu::native::custom_ops::npu_rms_norm(hidden_states, weight_, eps_); + if (is_bias_ && bias_.defined()) { + output = output + bias_; + } + return output; + } + + void load_state_dict(const StateDict& state_dict) { + if (elementwise_affine_) { + weight::load_weight(state_dict, "weight", weight_, weight_is_loaded_); + if (is_bias_) { + weight::load_weight(state_dict, "bias", bias_, bias_is_loaded_); + } + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(weight_is_loaded_) + << "weight is not loaded for " << prefix + "weight"; + CHECK(!is_bias_ || bias_is_loaded_) + << "bias is not loaded for " << prefix + "bias"; + } + + private: + double eps_; // Small epsilon to avoid division by zero + bool elementwise_affine_; // Whether to apply learnable affine parameters + torch::Tensor weight_; // Learnable scale parameter + torch::Tensor bias_; // Learnable bias parameter (optional) + bool is_bias_; + bool weight_is_loaded_{false}; + bool bias_is_loaded_{false}; +}; +TORCH_MODULE(RMSNorm); + +// TODO: This class should be extracted from dit class and integrated into a +// common class. +class AdaLayerNormContinuousImpl : public torch::nn::Module { + public: + explicit AdaLayerNormContinuousImpl(const ModelContext& context, + int64_t embedding_dim, + int64_t conditioning_embedding_dim, + bool elementwise_affine = true, + double eps = 1e-5, + bool bias = true) + : options_(context.get_tensor_options()) { + ModelArgs model_args = context.get_model_args(); + silu_ = register_module("silu", torch::nn::SiLU()); + linear_ = register_module( + "linear", + layer::AddMatmulWeightTransposed( + conditioning_embedding_dim, 2 * embedding_dim, bias, options_)); + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({embedding_dim}) + .elementwise_affine(false) + .eps(eps))); + } + + torch::Tensor forward(const torch::Tensor& x, + const torch::Tensor& conditioning_embedding) { + auto cond_emb = silu_->forward(conditioning_embedding); + cond_emb = cond_emb.to(x.dtype()); + + auto emb = linear_->forward(cond_emb); + auto chunks = torch::chunk(emb, 2, 1); + torch::Tensor scale, shift; + + scale = chunks[0]; + shift = chunks[1]; + auto x_norm = norm_->forward(x); + return x_norm * (1 + scale).unsqueeze(1) + shift.unsqueeze(1); + } + + void load_state_dict(const StateDict& state_dict) { + // linear + linear_->load_state_dict(state_dict.get_dict_with_prefix("linear.")); + } + + void verify_loaded_weights(const std::string& prefix) { + linear_->verify_loaded_weights(prefix + "linear."); + } + + private: + layer::AddMatmulWeightTransposed linear_{nullptr}; + torch::nn::SiLU silu_{nullptr}; + torch::nn::LayerNorm norm_{nullptr}; + double eps_; + std::string norm_type_; + bool elementwise_affine_; + torch::Tensor rms_scale_{nullptr}; + torch::TensorOptions options_; +}; +TORCH_MODULE(AdaLayerNormContinuous); + +// TODO: This class should be extracted from dit class and integrated into a +// common class. +class AdaLayerNormImpl : public torch::nn::Module { + public: + AdaLayerNormImpl(const ModelContext& contex, + int64_t hidden_size, + double eps = 1e-6) + : hidden_size_(hidden_size), eps_(eps) { + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({hidden_size}) + .elementwise_affine(false) + .eps(eps))); + } + + std::tuple forward( + const torch::Tensor& x, + const torch::Tensor& mod_params, + const torch::Tensor& index = torch::Tensor()) { + auto chunks = mod_params.chunk(3, -1); + auto shift = chunks[0]; + auto scale = chunks[1]; + auto gate = chunks[2]; + torch::Tensor shift_result, scale_result, gate_result; + + if (index.defined()) { + // Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + // So shift, scale, gate have shape [2*actual_batch, d] + int64_t actual_batch = shift.size(0) / 2; + + // Split into two parts + auto shift_0 = shift.slice(0, 0, actual_batch); + auto shift_1 = shift.slice(0, actual_batch, shift.size(0)); + + auto scale_0 = scale.slice(0, 0, actual_batch); + auto scale_1 = scale.slice(0, actual_batch, scale.size(0)); + + auto gate_0 = gate.slice(0, 0, actual_batch); + auto gate_1 = gate.slice(0, actual_batch, gate.size(0)); + + // index: [b, l] where b is actual batch size + // Expand to [b, l, 1] to match feature dimension + auto index_expanded = index.unsqueeze(-1); // [b, l, 1] + + // Expand chunks to [b, 1, d] then broadcast to [b, l, d] + auto shift_0_exp = shift_0.unsqueeze(1); // [b, 1, d] + auto shift_1_exp = shift_1.unsqueeze(1); // [b, 1, d] + auto scale_0_exp = scale_0.unsqueeze(1); + auto scale_1_exp = scale_1.unsqueeze(1); + auto gate_0_exp = gate_0.unsqueeze(1); + auto gate_1_exp = gate_1.unsqueeze(1); + + // Use torch::where to select based on index + shift_result = + torch::where(index_expanded == 0, shift_0_exp, shift_1_exp); + scale_result = + torch::where(index_expanded == 0, scale_0_exp, scale_1_exp); + gate_result = torch::where(index_expanded == 0, gate_0_exp, gate_1_exp); + } else { + shift_result = shift.unsqueeze(1); + scale_result = scale.unsqueeze(1); + gate_result = gate.unsqueeze(1); + } + + scale_result = 1 + scale_result; + + // auto result = at_npu::native::custom_ops::npu_layer_norm_eval( + // x, {hidden_size_}, scale_result, shift_result, eps_); + auto x_norm = norm_->forward(x); + auto result = x_norm * scale_result + shift_result; + return std::make_tuple(result, gate_result); + } + + private: + double eps_; + int64_t hidden_size_; + torch::nn::LayerNorm norm_{nullptr}; +}; +TORCH_MODULE(AdaLayerNorm); + +torch::Tensor apply_rotary_emb_qwen(const torch::Tensor& x, + const torch::Tensor& freqs_cis, + bool use_real = true, + int64_t use_real_unbind_dim = -1) { + auto cos = torch::real(freqs_cis); + auto sin = torch::imag(freqs_cis); + + int64_t seqlen = cos.size(0); + + auto cos_expanded = cos.unsqueeze(0) + .unsqueeze(2) + .unsqueeze(-1) + .expand({-1, -1, -1, -1, 2}) + .reshape({1, seqlen, 1, -1}); + auto sin_expanded = sin.unsqueeze(0) + .unsqueeze(2) + .unsqueeze(-1) + .expand({-1, -1, -1, -1, 2}) + .reshape({1, seqlen, 1, -1}); + auto x_out = at_npu::native::custom_ops::npu_rotary_mul( + x, cos_expanded, sin_expanded, "interleave"); + return x_out.to(x.dtype()); +} + +class TimestepsImpl : public torch::nn::Module { + public: + TimestepsImpl(const ModelContext& context, + int64_t num_channels, + bool flip_sin_to_cos, + double downscale_freq_shift, + double scale, + int64_t max_period = 10000) + : embedding_dim_(num_channels), + flip_sin_to_cos_(flip_sin_to_cos), + downscale_freq_shift_(downscale_freq_shift), + scale_(scale), + max_period_(max_period) {} + + torch::Tensor forward(const torch::Tensor& timesteps) { + CHECK(timesteps.dim() == 1) << "Timesteps should be a 1d-array"; + + int64_t half_dim = embedding_dim_ / 2; + + auto exponent = + -std::log(max_period_) * torch::arange(0, + half_dim, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(timesteps.device())); + + exponent = exponent / (half_dim - downscale_freq_shift_); + auto emb = torch::exp(exponent); + emb = timesteps.unsqueeze(1).to(torch::kFloat) * emb.unsqueeze(0); + + emb = scale_ * emb; + + // concat sine and cosine embeddings + auto sin_emb = torch::sin(emb); + auto cos_emb = torch::cos(emb); + emb = torch::cat({sin_emb, cos_emb}, /*dim=*/-1); + // flip sine and cosine embeddings + if (flip_sin_to_cos_) { + emb = torch::cat({cos_emb, sin_emb}, /*dim=*/-1); + } + // zero pad + if (embedding_dim_ % 2 == 1) { + emb = torch::nn::functional::pad( + emb, torch::nn::functional::PadFuncOptions({0, 1})); + } + return emb; + } + + private: + int64_t embedding_dim_; + int64_t max_period_; + bool flip_sin_to_cos_; + double scale_; + double downscale_freq_shift_; +}; +TORCH_MODULE(Timesteps); + +// TODO: a factory function that provides activation functions based on string +// input +std::function get_activation( + const std::string& act_fn) { + if (act_fn == "silu") { + return [](const torch::Tensor& x) { return torch::silu(x); }; + } else if (act_fn == "relu") { + return [](const torch::Tensor& x) { return torch::relu(x); }; + } else if (act_fn == "gelu") { + return [](const torch::Tensor& x) { return torch::gelu(x); }; + } else if (act_fn == "tanh") { + return [](const torch::Tensor& x) { return torch::tanh(x); }; + } else if (act_fn == "sigmoid") { + return [](const torch::Tensor& x) { return torch::sigmoid(x); }; + } else if (act_fn == "none" || act_fn.empty()) { + return [](const torch::Tensor& x) { return x; }; + } else { + LOG(ERROR) << "Unsupported activation function: " << act_fn; + throw std::out_of_range( + "activation function out of range, given activation function: " + + act_fn); + } +} + +class TimestepEmbeddingImpl : public torch::nn::Module { + public: + TimestepEmbeddingImpl(const ModelContext& context, + int64_t in_channels, + int64_t time_embed_dim, + const std::string& act_fn = "silu", + int64_t out_dim = -1, + const std::string& post_act_fn = "", + int64_t cond_proj_dim = -1, + bool sample_proj_bias = true) + : options_(context.get_tensor_options()) { + linear_1_ = register_module( + "linear_1", + layer::AddMatmulWeightTransposed( + in_channels, time_embed_dim, sample_proj_bias, options_)); + + if (cond_proj_dim > 0) { + cond_proj_ = + register_module("cond_proj", + layer::AddMatmulWeightTransposed( + cond_proj_dim, in_channels, false, options_)); + } + + act_fn_ = register_module("act_fn", torch::nn::SiLU()); + + int64_t time_embed_dim_out = (out_dim > 0) ? out_dim : time_embed_dim; + + linear_2_ = register_module( + "linear_2", + layer::AddMatmulWeightTransposed( + time_embed_dim, time_embed_dim_out, sample_proj_bias, options_)); + } + + torch::Tensor forward(const torch::Tensor& sample, + const torch::Tensor& condition = torch::Tensor()) { + torch::Tensor x = sample; + + if (cond_proj_) { + x = x + cond_proj_->forward(condition); + } + x = linear_1_->forward(x); + x = act_fn_(x); + x = linear_2_->forward(x); + + return x; + } + + void load_state_dict(const StateDict& state_dict) { + // linear1 + linear_1_->load_state_dict(state_dict.get_dict_with_prefix("linear_1.")); + // linear2 + linear_2_->load_state_dict(state_dict.get_dict_with_prefix("linear_2.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + linear_1_->verify_loaded_weights(prefix + "linear_1."); + linear_2_->verify_loaded_weights(prefix + "linear_2."); + } + + private: + torch::TensorOptions options_; + torch::nn::SiLU act_fn_{nullptr}; + layer::AddMatmulWeightTransposed linear_1_{nullptr}; + layer::AddMatmulWeightTransposed linear_2_{nullptr}; + layer::AddMatmulWeightTransposed cond_proj_{nullptr}; +}; +TORCH_MODULE(TimestepEmbedding); + +std::tuple, std::optional> +compute_text_seq_len_from_mask( + const torch::Tensor& encoder_hidden_states, + const std::optional& encoder_hidden_states_mask) { + auto batch_size = encoder_hidden_states.size(0); + auto text_seq_len = encoder_hidden_states.size(1); + + if (!encoder_hidden_states_mask.has_value()) { + return std::make_tuple(text_seq_len, std::nullopt, std::nullopt); + } + + auto mask = + encoder_hidden_states_mask.value().to(encoder_hidden_states.device()); + + if (mask.size(0) != batch_size || mask.size(1) != text_seq_len) { + LOG(ERROR) << "`encoder_hidden_states_mask` shape " << mask.sizes() + << " must match (batch_size, text_seq_len)=(" << batch_size + << ", " << text_seq_len << ")."; + } + + if (mask.dtype() != torch::kBool) { + mask = mask.to(torch::kBool); + } + + auto device = encoder_hidden_states.device(); + auto position_ids = torch::arange( + text_seq_len, torch::TensorOptions().device(device).dtype(torch::kLong)); + + // Compute active positions (use position ID where mask is True, else 0) + auto zero_tensor = torch::zeros( + {}, torch::TensorOptions().device(device).dtype(torch::kLong)); + + auto active_positions = torch::where(mask, position_ids, zero_tensor); + + // Check which samples have active positions + auto has_active = mask.any(/*dim=*/1); + + // Compute per-sample length: max position + 1 if active, else use full length + auto max_positions = std::get<0>(active_positions.max(/*dim=*/1)); + auto per_sample_len = torch::where( + has_active, + max_positions + 1, + torch::tensor(text_seq_len, + torch::TensorOptions().device(device).dtype(torch::kLong))); + + return std::make_tuple(text_seq_len, per_sample_len, mask); +} + +class QwenTimestepProjEmbeddingsImpl : public torch::nn::Module { + public: + QwenTimestepProjEmbeddingsImpl(const ModelContext& context, + int64_t embedding_dim, + bool use_additional_t_cond = false) + : use_additional_t_cond_(use_additional_t_cond) { + time_proj_ = register_module("time_proj", + Timesteps(context, + /*num_channels=*/256, + /*flip_sin_to_cos=*/true, + /*downscale_freq_shift=*/0.0, + /*scale=*/1000)); + timestep_embedder_ = + register_module("timestep_embedder", + TimestepEmbedding(context, + /*in_channels=*/256, + /*time_embed_dim*/ embedding_dim)); + if (use_additional_t_cond) { + addition_t_embedding_ = + register_module("addition_t_embedding", + torch::nn::Embedding(torch::nn::EmbeddingOptions( + /*num=*/2, embedding_dim))); + } + } + + torch::Tensor forward( + const torch::Tensor& timestep, + const torch::Tensor& hidden_states, + const torch::Tensor& addition_t_cond = torch::Tensor()) { + auto timesteps_proj = time_proj_->forward(timestep); + auto timesteps_emb = + timestep_embedder_->forward(timesteps_proj.to(hidden_states.dtype())); + + torch::Tensor conditioning = timesteps_emb; + if (use_additional_t_cond_) { + CHECK(addition_t_cond.defined()) + << "expected to pass addition_t_cond when" + << " use_additional_t_cond_ is setup to true"; + auto addition_t_emb = addition_t_embedding_->forward(addition_t_cond); + addition_t_emb = addition_t_emb.to(hidden_states.dtype()); + conditioning = conditioning + addition_t_emb; + } + + return conditioning; + } + void load_state_dict(const StateDict& state_dict) { + timestep_embedder_->load_state_dict( + state_dict.get_dict_with_prefix("timestep_embedder.")); + if (use_additional_t_cond_) { + weight::load_weight(state_dict, + "addition_t_embedding.weight", + addition_t_embedding_->weight, + weight_is_loaded_); + } + } + + void verify_loaded_weights(const std::string& prefix) const { + timestep_embedder_->verify_loaded_weights(prefix + "timestep_embedder."); + if (use_additional_t_cond_) { + CHECK(weight_is_loaded_) + << "weight is not loaded for " << prefix + "weight"; + } + } + + private: + Timesteps time_proj_{nullptr}; + TimestepEmbedding timestep_embedder_{nullptr}; + torch::nn::Embedding addition_t_embedding_{nullptr}; + bool use_additional_t_cond_; + bool weight_is_loaded_{false}; +}; +TORCH_MODULE(QwenTimestepProjEmbeddings); + +class QwenEmbedRopeImpl : public torch::nn::Module { + public: + QwenEmbedRopeImpl(const ModelContext& context, + int64_t theta, + std::vector axes_dim, + bool scale_rope = false) + : theta_(theta), axes_dim_(axes_dim), scale_rope_(scale_rope) { + auto pos_index = torch::arange(4096); + auto neg_index = torch::arange(4096).flip(0) * -1 - 1; + + pos_freqs_ = torch::cat({rope_params(pos_index, axes_dim[0], theta), + rope_params(pos_index, axes_dim[1], theta), + rope_params(pos_index, axes_dim[2], theta)}, + 1); + + neg_freqs_ = torch::cat({rope_params(neg_index, axes_dim[0], theta), + rope_params(neg_index, axes_dim[1], theta), + rope_params(neg_index, axes_dim[2], theta)}, + 1); + } + + std::tuple forward( + const std::vector>& video_fhw, + const std::optional& txt_seq_lens, + torch::Device device, + const std::optional& max_txt_seq_len) { + if (pos_freqs_.device() != device) { + pos_freqs_ = pos_freqs_.to(device); + neg_freqs_ = neg_freqs_.to(device); + } + + std::vector vid_freqs; + int64_t max_vid_index = 0; + + for (size_t idx = 0; idx < video_fhw.size(); idx++) { + const auto& fhw = video_fhw[idx]; + int64_t frame = fhw[0], height = fhw[1], width = fhw[2]; + + std::string rope_key = std::to_string(idx) + "_" + + std::to_string(height) + "_" + + std::to_string(width); + + auto video_freq = _compute_video_freqs(frame, height, width, idx, device); + vid_freqs.push_back(video_freq); + + if (scale_rope_) { + max_vid_index = std::max({height / 2, width / 2, max_vid_index}); + } else { + max_vid_index = std::max({height, width, max_vid_index}); + } + } + + int64_t max_len; + if (txt_seq_lens.has_value() && !max_txt_seq_len.has_value()) { + max_len = txt_seq_lens.value(); + } else if (max_txt_seq_len.has_value()) { + max_len = max_txt_seq_len.value(); + } else { + LOG(ERROR) << "need to pass txt_seq_lens or max_txt_seq_len " + << "to calculate the mrope"; + } + + auto txt_freqs = + pos_freqs_.slice(0, max_vid_index, max_vid_index + max_len); + auto vid_freqs_cat = torch::cat(vid_freqs, 0); + return std::make_tuple(vid_freqs_cat, txt_freqs); + } + + protected: + torch::Tensor rope_params(const torch::Tensor& index, + int64_t dim, + int64_t theta) { + CHECK(dim % 2 == 0) << "dim must be even"; + + auto exponents = + torch::arange( + 0, dim, 2, torch::TensorOptions().dtype(torch::kFloat32)) / + static_cast(dim); + auto freqs = 1.0 / torch::pow(theta, exponents); + + auto outer_result = torch::outer(index.to(torch::kFloat32), freqs); + + auto complex_freqs = + torch::polar(torch::ones_like(outer_result), outer_result); + + return complex_freqs; + } + + torch::Tensor _compute_video_freqs(int64_t frame, + int64_t height, + int64_t width, + int64_t idx, + torch::Device device) { + int64_t seq_lens = frame * height * width; + + auto pos_freqs = pos_freqs_.to(device); + auto neg_freqs = neg_freqs_.to(device); + + std::vector split_sizes; + for (auto dim : axes_dim_) { + split_sizes.push_back(dim / 2); + } + + auto freqs_pos_chunks = pos_freqs_.split_with_sizes(split_sizes, 1); + auto freqs_neg_chunks = neg_freqs_.split_with_sizes(split_sizes, 1); + + auto freqs_frame = freqs_pos_chunks[0] + .slice(0, idx, idx + frame) + .view({frame, 1, 1, -1}) + .expand({frame, height, width, -1}); + + torch::Tensor freqs_height, freqs_width; + if (scale_rope_) { + auto height_neg_part = freqs_neg_chunks[1].slice( + 0, -(height - height / 2), torch::indexing::None); + auto height_pos_part = freqs_pos_chunks[1].slice(0, 0, height / 2); + freqs_height = torch::cat({height_neg_part, height_pos_part}, 0) + .view({1, height, 1, -1}) + .expand({frame, height, width, -1}); + + auto width_neg_part = freqs_neg_chunks[2].slice( + 0, -(width - width / 2), torch::indexing::None); + auto width_pos_part = freqs_pos_chunks[2].slice(0, 0, width / 2); + freqs_width = torch::cat({width_neg_part, width_pos_part}, 0) + .view({1, 1, width, -1}) + .expand({frame, height, width, -1}); + } else { + freqs_height = freqs_pos_chunks[1] + .slice(0, 0, height) + .view({1, height, 1, -1}) + .expand({frame, height, width, -1}); + + freqs_width = freqs_pos_chunks[2] + .slice(0, 0, width) + .view({1, 1, width, -1}) + .expand({frame, height, width, -1}); + } + auto freqs = torch::cat({freqs_frame, freqs_height, freqs_width}, -1) + .reshape({seq_lens, -1}); + return freqs.contiguous(); + } + + int64_t theta_; + std::vector axes_dim_; + bool scale_rope_; + torch::Tensor pos_freqs_; + torch::Tensor neg_freqs_; + std::unordered_map rope_cache_; +}; + +TORCH_MODULE(QwenEmbedRope); + +class QwenEmbedRopeWithCacheImpl : public QwenEmbedRopeImpl { + public: + QwenEmbedRopeWithCacheImpl(const ModelContext& context, + int64_t theta, + std::vector axes_dim, + bool scale_rope = false) + : QwenEmbedRopeImpl(context, theta, axes_dim, scale_rope) {} + + private: + torch::Tensor _compute_video_freqs_cached(int64_t frame, + int64_t height, + int64_t width, + int64_t idx, + torch::Device device) { + std::string key = std::to_string(idx) + "_" + std::to_string(height) + "_" + + std::to_string(width); + + auto it = rope_cache_.find(key); + if (it != rope_cache_.end()) { + return it->second; + } else { + auto result = _compute_video_freqs(frame, height, width, idx, device); + rope_cache_[key] = result; + return result; + } + } + + std::unordered_map rope_cache_; +}; +TORCH_MODULE(QwenEmbedRopeWithCache); + +class QwenEmbedLayer3DRopeImpl : public torch::nn::Module { + public: + QwenEmbedLayer3DRopeImpl(const ModelContext& context, + int64_t theta, + std::vector& axes_dim, + bool scale_rope = false) + : theta_(theta), axes_dim_(axes_dim), scale_rope_(scale_rope) { + auto pos_index = torch::arange(4096); + auto neg_index = torch::arange(4096).flip(0) * -1 - 1; + + std::vector pos_freqs_parts; + pos_freqs_ = torch::cat({rope_params(pos_index, axes_dim[0], theta), + rope_params(pos_index, axes_dim[1], theta), + rope_params(pos_index, axes_dim[2], theta)}, + 1); + + neg_freqs_ = torch::cat({rope_params(neg_index, axes_dim[0], theta), + rope_params(neg_index, axes_dim[1], theta), + rope_params(neg_index, axes_dim[2], theta)}, + 1); + } + + virtual std::pair forward( + const std::vector>& video_fhw, + int64_t max_txt_seq_len, + torch::Device device = torch::Device(torch::kCPU)) { + std::vector vid_freqs_list; + int64_t max_vid_index = 0; + int64_t layer_num = video_fhw.size() - 1; + + for (size_t idx = 0; idx < video_fhw.size(); idx++) { + const std::vector& fhw = video_fhw[idx]; + + int64_t frame = fhw[0]; + int64_t height = fhw[1]; + int64_t width = fhw[2]; + + torch::Tensor video_freq; + + if (idx != layer_num) { + video_freq = _compute_video_freqs(frame, height, width, idx, device); + } else { + video_freq = _compute_condition_freqs(frame, height, width, device); + } + vid_freqs_list.push_back(video_freq); + + if (scale_rope_) { + max_vid_index = std::max({height / 2, width / 2, max_vid_index}); + } else { + max_vid_index = std::max({height, width, max_vid_index}); + } + } + + int64_t max_txt_seq_len_int = std::max(max_vid_index, layer_num); + + torch::Tensor txt_freqs = pos_freqs_.to(device).slice( + 0, max_vid_index, max_vid_index + max_txt_seq_len_int); + + torch::Tensor vid_freqs = torch::cat(vid_freqs_list, 0); + + return {vid_freqs, txt_freqs}; + } + + protected: + torch::Tensor rope_params(torch::Tensor index, int64_t dim, int64_t theta) { + CHECK(dim % 2 == 0) << "dim must be even"; + + auto exponents = + torch::arange( + 0, dim, 2, torch::TensorOptions().dtype(torch::kFloat32)) / + static_cast(dim); + auto freqs = 1.0 / torch::pow(theta, exponents); + + auto outer_result = torch::outer(index.to(torch::kFloat32), freqs); + + auto complex_freqs = + torch::polar(torch::ones_like(outer_result), outer_result); + + return complex_freqs; + } + + torch::Tensor _compute_video_freqs(int64_t frame, + int64_t height, + int64_t width, + int64_t idx, + torch::Device device) { + int64_t seq_lens = frame * height * width; + + torch::Tensor pos_freqs = pos_freqs_.to(device); + torch::Tensor neg_freqs = neg_freqs_.to(device); + + std::vector split_sizes; + for (int64_t dim : axes_dim_) { + split_sizes.push_back(dim / 2); + } + + auto freqs_pos = pos_freqs.split_with_sizes(split_sizes, 1); + auto freqs_neg = neg_freqs.split_with_sizes(split_sizes, 1); + + auto freqs_frame = freqs_pos[0] + .slice(0, idx, idx + frame) + .view({frame, 1, 1, -1}) + .expand({frame, height, width, -1}); + + torch::Tensor freqs_height; + if (scale_rope_) { + auto height_neg_part = + freqs_neg[1].slice(0, -(height / 2), freqs_neg[1].size(0)); + auto height_pos_part = freqs_pos[1].slice(0, 0, height / 2); + freqs_height = torch::cat({height_neg_part, height_pos_part}, 0) + .view({1, height, 1, -1}) + .expand({frame, height, width, -1}); + } else { + freqs_height = freqs_pos[1] + .slice(0, 0, height) + .view({1, height, 1, -1}) + .expand({frame, height, width, -1}); + } + + torch::Tensor freqs_width; + if (scale_rope_) { + auto neg_part = freqs_neg[2].slice(0, -(width / 2), freqs_neg[2].size(0)); + auto pos_part = freqs_pos[2].slice(0, 0, width / 2); + freqs_width = torch::cat({neg_part, pos_part}, 0) + .view({1, 1, width, -1}) + .expand({frame, height, width, -1}); + } else { + freqs_width = freqs_pos[2] + .slice(0, 0, width) + .view({1, 1, width, -1}) + .expand({frame, height, width, -1}); + } + auto freqs = + torch::cat({freqs_frame, freqs_height, freqs_width}, /*dim=*/-1) + .reshape({seq_lens, -1}) + .clone() + .contiguous(); + + return freqs; + } + + torch::Tensor _compute_condition_freqs(int64_t frame, + int64_t height, + int64_t width, + torch::Device device) { + int64_t seq_lens = frame * height * width; + + torch::Tensor pos_freqs = pos_freqs_.to(device); + torch::Tensor neg_freqs = neg_freqs_.to(device); + + std::vector split_sizes; + for (int64_t dim : axes_dim_) { + split_sizes.push_back(dim / 2); + } + + auto freqs_pos = pos_freqs.split_with_sizes(split_sizes, 1); + auto freqs_neg = neg_freqs.split_with_sizes(split_sizes, 1); + + auto freqs_frame = freqs_neg[0] + .slice(0, -1, freqs_neg[0].size(0)) + .view({frame, 1, 1, -1}) + .expand({frame, height, width, -1}); + + torch::Tensor freqs_height; + if (scale_rope_) { + auto neg_part = + freqs_neg[1].slice(0, -(height / 2), freqs_neg[1].size(0)); + auto pos_part = freqs_pos[1].slice(0, 0, height / 2); + freqs_height = torch::cat({neg_part, pos_part}, 0) + .view({1, height, 1, -1}) + .expand({frame, height, width, -1}); + } else { + freqs_height = freqs_pos[1] + .slice(0, 0, height) + .view({1, height, 1, -1}) + .expand({frame, height, width, -1}); + } + torch::Tensor freqs_width; + if (scale_rope_) { + auto neg_part = freqs_neg[2].slice(0, -(width / 2), freqs_neg[2].size(0)); + auto pos_part = freqs_pos[2].slice(0, 0, width / 2); + freqs_width = torch::cat({neg_part, pos_part}, 0) + .view({1, 1, width, -1}) + .expand({frame, height, width, -1}); + } else { + freqs_width = freqs_pos[2] + .slice(0, 0, width) + .view({1, 1, width, -1}) + .expand({frame, height, width, -1}); + } + auto freqs = torch::cat({freqs_frame, freqs_height, freqs_width}, -1) + .reshape({seq_lens, -1}) + .clone() + .contiguous(); + + return freqs; + } + + int64_t theta_; + std::vector& axes_dim_; + bool scale_rope_; + torch::Tensor pos_freqs_; + torch::Tensor neg_freqs_; +}; + +TORCH_MODULE(QwenEmbedLayer3DRope); + +class QwenEmbedLayer3DRopeWithCacheImpl : public QwenEmbedLayer3DRopeImpl { + public: + QwenEmbedLayer3DRopeWithCacheImpl(const ModelContext& context, + int64_t theta, + std::vector& axes_dim, + bool scale_rope = false) + : QwenEmbedLayer3DRopeImpl(context, theta, axes_dim, scale_rope) {} + + std::pair forward( + const std::vector>& video_fhw, + int64_t max_txt_seq_len, + torch::Device device = torch::Device(torch::kCPU)) override { + std::vector vid_freqs_list; + int64_t max_vid_index = 0; + int64_t layer_num = video_fhw.size() - 1; + + for (size_t idx = 0; idx < video_fhw.size(); idx++) { + const std::vector& fhw = video_fhw[idx]; + + int64_t frame = fhw[0]; + int64_t height = fhw[1]; + int64_t width = fhw[2]; + + torch::Tensor video_freq; + + if (idx != layer_num) { + video_freq = + _compute_video_freqs_with_cache(frame, height, width, idx, device); + } else { + video_freq = + _compute_condition_freqs_with_cache(frame, height, width, device); + } + vid_freqs_list.push_back(video_freq); + + if (scale_rope_) { + max_vid_index = std::max({height / 2, width / 2, max_vid_index}); + } else { + max_vid_index = std::max({height, width, max_vid_index}); + } + } + + int64_t max_txt_seq_len_int = std::max(max_vid_index, layer_num); + + torch::Tensor txt_freqs = pos_freqs_.to(device).slice( + 0, max_vid_index, max_vid_index + max_txt_seq_len_int); + + torch::Tensor vid_freqs = torch::cat(vid_freqs_list, 0); + + return {vid_freqs, txt_freqs}; + } + + private: + torch::Tensor _compute_video_freqs_with_cache(int64_t frame, + int64_t height, + int64_t width, + int64_t idx, + torch::Device device) { + std::string key = std::to_string(frame) + "_" + std::to_string(idx) + "_" + + std::to_string(height) + "_" + std::to_string(width); + + // TODO: currently the freqs tensors are cached on device + // need to check whether to swap them to cpu to save device memory + auto it = video_freqs_cache_.find(key); + if (it != video_freqs_cache_.end()) { + return it->second.clone().contiguous(); + } else { + auto result = _compute_video_freqs(frame, height, width, idx, device); + video_freqs_cache_[key] = result.clone(); + return result; + } + } + + torch::Tensor _compute_condition_freqs_with_cache(int64_t frame, + int64_t height, + int64_t width, + torch::Device device) { + std::string key = std::to_string(frame) + "_" + std::to_string(height) + + "_" + std::to_string(width); + + // TODO: currently the freqs tensors are cached on device + // need to check whether to swap them to cpu to save device memory + auto it = condition_cache_.find(key); + if (it != condition_cache_.end()) { + return it->second.clone().contiguous(); + } else { + auto result = _compute_condition_freqs(frame, height, width, device); + condition_cache_[key] = result.clone(); + return result; + } + } + + std::unordered_map video_freqs_cache_; + std::unordered_map condition_cache_; +}; + +TORCH_MODULE(QwenEmbedLayer3DRopeWithCache); + +// A internel class that only register necessary modules for attention +// implementation The attention forward shouldn't be implemented here, but in +// processor classes +// TODO: This class should be extracted from dit class and integrated into a +// common class. +class AttentionImpl : public torch::nn::Module { + public: + AttentionImpl(const ModelContext context, + int64_t query_dim, + std::optional cross_attention_dim = std::nullopt, + int64_t heads = 8, + std::optional kv_heads = std::nullopt, + int64_t dim_head = 64, + double dropout = 0.0, + bool bias = false, + const std::string& qk_norm = "", + const std::string& cross_attention_norm = "", + std::optional added_kv_proj_dim = std::nullopt, + bool added_proj_bias = true, + bool out_bias = true, + bool scale_qk = true, + bool only_cross_attention = false, + double eps = 1e-5, + double rescale_output_factor = 1.0, + bool residual_connection = false, + std::optional out_dim = std::nullopt, + std::optional out_context_dim = std::nullopt, + std::optional context_pre_only = std::nullopt, + bool pre_only = false, + bool elementwise_affine = true, + bool is_causal = false, + ProcessGroup* sp_group = nullptr) + : options_(context.get_tensor_options()), + heads_(heads), + bias_(bias), + out_bias_(out_bias), + added_proj_bias_(added_proj_bias), + sp_group_(sp_group) { + if (qk_norm == "layer_norm") { + layer_norm_q_ = register_module( + "norm_q", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({dim_head}) + .eps(eps) + .elementwise_affine(elementwise_affine))); + layer_norm_k_ = register_module( + "norm_k", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({dim_head}) + .eps(eps) + .elementwise_affine(elementwise_affine))); + } else if (qk_norm == "layer_norm_across_heads") { + // Lumina applies qk norm across all heads + CHECK(kv_heads.has_value()) + << "qk_norm is set to: " + qk_norm + ", but get no kv_heads "; + layer_norm_q_ = register_module( + "norm_q", + torch::nn::LayerNorm( + torch::nn::LayerNormOptions({dim_head * heads}).eps(eps))); + layer_norm_k_ = register_module( + "norm_k", + torch::nn::LayerNorm( + torch::nn::LayerNormOptions({dim_head * kv_heads.value()}) + .eps(eps))); + } else if (qk_norm == "rms_norm") { + // Assuming you have an RMSNorm implementation + norm_q_ = register_module("norm_q", RMSNorm(dim_head, eps, true, false)); + norm_k_ = register_module("norm_k", RMSNorm(dim_head, eps, true, false)); + } else if (qk_norm == "rms_norm_across_heads") { + // LTX applies qk norm across all heads + CHECK(kv_heads.has_value()) + << "qk_norm is set to: " + qk_norm + ", but get no kv_heads "; + + norm_q_ = register_module("norm_q", RMSNorm(dim_head, eps, true, false)); + norm_k_ = register_module( + "norm_k", RMSNorm(dim_head * kv_heads.value(), eps, true, false)); + } else { + CHECK(qk_norm.empty()) << "unknown qk_norm: " + qk_norm + + ". Should be " + "'','layer_norm','rms_norm','layer_norm_" + "across_heads', 'rms_norm_across_heads'"; + } + + if (cross_attention_norm == "layer_norm") { + norm_cross_ = register_module( + "norm_cross", + torch::nn::LayerNorm( + torch::nn::LayerNormOptions({cross_attention_dim.value()}))); + } else { + CHECK(cross_attention_norm.empty()) + << "unknown cross_attention_norm: " + cross_attention_norm + + ". Should be '', 'layer_norm'"; + } + + int64_t q_dim = out_dim.has_value() ? out_dim.value() : dim_head * heads; + int64_t kv_dim = + !kv_heads.has_value() ? q_dim : dim_head * kv_heads.value(); + cross_attention_dim = cross_attention_dim.has_value() + ? cross_attention_dim.value() + : query_dim; + out_context_dim = + out_context_dim.has_value() ? out_context_dim.value() : query_dim; + + xllm::dit::SpOptions q_sp_option; + xllm::dit::SpOptions kv_sp_option; + xllm::dit::LinearType linear_type = xllm::dit::LinearType::Default; + if (FLAGS_sp_size > 1) { + q_sp_option = xllm::dit::SpOptions(/*head_num=*/heads, + /*head_dim=*/dim_head, + /*hidden_size=*/q_dim, + /*before_attention=*/true, + /*process_group=*/sp_group_); + + kv_sp_option = xllm::dit::SpOptions( + /*head_num=*/kv_heads.has_value() ? kv_heads.value() : heads, + /*head_dim=*/dim_head, + /*hidden_size=*/kv_dim, + /*before_attention=*/true, + /*process_group=*/sp_group_); + linear_type = xllm::dit::LinearType::SequenceParallel; + } + + auto q_linear = + layer::AddMatmulWeightTransposed(query_dim, q_dim, bias, options_); + + to_q_ = register_module("q_linear", + xllm::dit::DiTParallelLinear(std::move(q_linear), + /*module_name=*/"to_q", + linear_type, + q_sp_option)); + + // Key-Value projections (if not only cross attention) + if (!only_cross_attention) { + auto k_linear = layer::AddMatmulWeightTransposed( + cross_attention_dim.value(), kv_dim, bias, options_); + + to_k_ = + register_module("k_linear", + xllm::dit::DiTParallelLinear(std::move(k_linear), + /*module_name=*/"to_k", + linear_type, + kv_sp_option)); + + auto v_linear = layer::AddMatmulWeightTransposed( + cross_attention_dim.value(), kv_dim, bias, options_); + + to_v_ = + register_module("v_linear", + xllm::dit::DiTParallelLinear(std::move(v_linear), + /*module_name=*/"to_v", + linear_type, + kv_sp_option)); + } + + if (added_kv_proj_dim.has_value()) { + auto add_k_linear = layer::AddMatmulWeightTransposed( + added_kv_proj_dim.value(), kv_dim, added_proj_bias, options_); + + add_k_proj_ = register_module( + "add_k_linear", + xllm::dit::DiTParallelLinear(std::move(add_k_linear), + /*module_name=*/"add_k_proj", + linear_type, + kv_sp_option)); + + auto add_v_linear = layer::AddMatmulWeightTransposed( + added_kv_proj_dim.value(), kv_dim, added_proj_bias, options_); + + add_v_proj_ = register_module( + "add_v_linear", + xllm::dit::DiTParallelLinear(std::move(add_v_linear), + /*module_name=*/"add_v_proj", + linear_type, + kv_sp_option)); + if (context_pre_only.has_value()) { + auto add_q_linear = layer::AddMatmulWeightTransposed( + added_kv_proj_dim.value(), q_dim, added_proj_bias, options_); + + add_q_proj_ = register_module( + "add_q_linear", + xllm::dit::DiTParallelLinear(std::move(add_q_linear), + /*module_name=*/"add_q_proj", + linear_type, + q_sp_option)); + } + } + + xllm::dit::SpOptions out_sp_option; + if (FLAGS_sp_size > 1) { + out_sp_option = xllm::dit::SpOptions(/*head_num=*/heads, + /*head_dim=*/dim_head, + /*hidden_size=*/q_dim, + /*before_attention=*/false, + /*process_group=*/sp_group_); + } + + // Output projections + if (!pre_only) { + to_out_ = register_module("to_out", torch::nn::Sequential()); + + auto to_out_linear = layer::AddMatmulWeightTransposed( + q_dim, out_dim.value(), out_bias, options_); + + to_out_->push_back(xllm::dit::DiTParallelLinear(std::move(to_out_linear), + /*module_name=*/"out", + linear_type, + out_sp_option)); + to_out_->push_back( + torch::nn::Dropout(torch::nn::DropoutOptions(dropout))); + } + + // Additional output for context + if (context_pre_only.has_value() && context_pre_only) { + auto to_add_out_linear = layer::AddMatmulWeightTransposed( + q_dim, out_context_dim.value(), out_bias, options_); + + to_add_out_ = register_module( + "to_add_out_linear", + xllm::dit::DiTParallelLinear(std::move(to_add_out_linear), + /*module_name=*/"to_add_out", + linear_type, + out_sp_option)); + } + + // Added QK normalization for added KV projections + if (!qk_norm.empty() && added_kv_proj_dim.has_value()) { + if (qk_norm == "rms_norm") { + norm_added_q_ = register_module("norm_added_q", + RMSNorm(dim_head, eps, true, false)); + norm_added_k_ = register_module("norm_added_k", + RMSNorm(dim_head, eps, true, false)); + } else { + CHECK(qk_norm.empty()) << "unknown qk_norm: " + qk_norm + + ". Should be one of '','rms_norm'"; + // For layer_norm, we would register similar layers here + } + } + } + + void load_state_dict(const StateDict& state_dict) { + // to_out + to_out_[0]->as()->load_state_dict( + state_dict.get_dict_with_prefix("to_out.0.")); + // to_add_out + to_add_out_->load_state_dict( + state_dict.get_dict_with_prefix("to_add_out.")); + // norm_q + norm_q_->load_state_dict(state_dict.get_dict_with_prefix("norm_q.")); + // norm_k + norm_k_->load_state_dict(state_dict.get_dict_with_prefix("norm_k.")); + // norm_added_q + norm_added_q_->load_state_dict( + state_dict.get_dict_with_prefix("norm_added_q.")); + // norm_added_k + norm_added_k_->load_state_dict( + state_dict.get_dict_with_prefix("norm_added_k.")); + + to_q_->load_state_dict(state_dict.get_dict_with_prefix("to_q.")); + to_k_->load_state_dict(state_dict.get_dict_with_prefix("to_k.")); + to_v_->load_state_dict(state_dict.get_dict_with_prefix("to_v.")); + + add_q_proj_->load_state_dict( + state_dict.get_dict_with_prefix("add_q_proj.")); + add_k_proj_->load_state_dict( + state_dict.get_dict_with_prefix("add_k_proj.")); + add_v_proj_->load_state_dict( + state_dict.get_dict_with_prefix("add_v_proj.")); + } + + void verify_loaded_weights(const std::string& prefix) { + // to_out + to_out_[0]->as()->verify_loaded_weights( + prefix + "to_out.0."); + // to_add_out + to_add_out_->verify_loaded_weights(prefix + "to_add_out."); + // norm_q + norm_q_->verify_loaded_weights(prefix + "norm_q."); + // norm_k + norm_k_->verify_loaded_weights(prefix + "norm_k."); + // norm_added_q + norm_added_q_->verify_loaded_weights(prefix + "norm_added_q."); + // norm_added_k + norm_added_k_->verify_loaded_weights(prefix + "norm_added_k."); + + to_q_->verify_loaded_weights(prefix + "to_q."); + to_k_->verify_loaded_weights(prefix + "to_k."); + to_v_->verify_loaded_weights(prefix + "to_v."); + + add_q_proj_->verify_loaded_weights(prefix + "add_q_proj."); + add_k_proj_->verify_loaded_weights(prefix + "add_k_proj."); + add_v_proj_->verify_loaded_weights(prefix + "add_v_proj."); + } + + public: + int64_t heads_; + bool bias_; + bool out_bias_; + bool added_proj_bias_; + ProcessGroup* sp_group_; + + torch::TensorOptions options_; + torch::nn::LayerNorm layer_norm_q_{nullptr}, layer_norm_k_{nullptr}, + norm_cross_{nullptr}; + xllm::dit::DiTParallelLinear to_q_{nullptr}, to_k_{nullptr}, to_v_{nullptr}; + xllm::dit::DiTParallelLinear add_k_proj_{nullptr}, add_v_proj_{nullptr}, + add_q_proj_{nullptr}; + torch::nn::Sequential to_out_{nullptr}; + xllm::dit::DiTParallelLinear to_add_out_{nullptr}; + + // Assuming you have RMSNorm implemented + RMSNorm norm_q_{nullptr}, norm_k_{nullptr}, norm_added_q_{nullptr}, + norm_added_k_{nullptr}; +}; +TORCH_MODULE(Attention); + +// Implementation of attention forward +class QwenDoubleStreamAttnProcessor2_0Impl : public torch::nn::Module { + public: + QwenDoubleStreamAttnProcessor2_0Impl(Attention&& attn_module, + const ParallelArgs& parallel_args) + : parallel_args_(parallel_args) { + attn_ = register_module("attn", std::move(attn_module)); + } + + virtual std::tuple forward( + const torch::Tensor& hidden_states, // Image stream + const torch::Tensor& encoder_hidden_states, // Text stream + const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(), + const torch::Tensor& attention_mask = torch::Tensor(), + const std::tuple& image_rotary_emb = {}) { + // int64_t seq_txt = encoder_hidden_states.size(1); + // int64_t seq_img = hidden_states.size(1); + // Compute QKV for image stream (sample projections) + auto img_query = attn_->to_q_->forward(hidden_states); + auto img_key = attn_->to_k_->forward(hidden_states); + auto img_value = attn_->to_v_->forward(hidden_states); + + // Compute QKV for text stream (context projections) + auto txt_query = attn_->add_q_proj_->forward(encoder_hidden_states); + auto txt_key = attn_->add_k_proj_->forward(encoder_hidden_states); + auto txt_value = attn_->add_v_proj_->forward(encoder_hidden_states); + + // Reshape for multi-head attention + int64_t heads = attn_->heads_; + auto reshape_dims = std::vector{heads / FLAGS_sp_size, -1}; + + img_query = img_query.unflatten(-1, reshape_dims); + img_key = img_key.unflatten(-1, reshape_dims); + img_value = img_value.unflatten(-1, reshape_dims); + txt_query = txt_query.unflatten(-1, reshape_dims); + txt_key = txt_key.unflatten(-1, reshape_dims); + txt_value = txt_value.unflatten(-1, reshape_dims); + // Apply QK normalization + if (attn_->norm_q_) { + img_query = attn_->norm_q_->forward(img_query); + } + if (attn_->norm_k_) { + img_key = attn_->norm_k_->forward(img_key); + } + if (attn_->norm_added_q_) { + txt_query = attn_->norm_added_q_->forward(txt_query); + } + if (attn_->norm_added_k_) { + txt_key = attn_->norm_added_k_->forward(txt_key); + } + + // Apply RoPE if provided + auto img_freqs = std::get<0>(image_rotary_emb); + auto txt_freqs = std::get<1>(image_rotary_emb); + + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + txt_query, /*tensor_name=*/"encoder_hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + txt_key, /*tensor_name=*/"encoder_hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + txt_value, /*tensor_name=*/"encoder_hidden_states", /*dim=*/1); + + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + img_query, /*tensor_name=*/"hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + img_key, /*tensor_name=*/"hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + img_value, /*tensor_name=*/"hidden_states", /*dim=*/1); + + img_query = apply_rotary_emb_qwen(img_query, img_freqs, false); + img_key = apply_rotary_emb_qwen(img_key, img_freqs, false); + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, false); + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, false); + + // Concatenate for joint attention - Order: [text, image] + auto joint_query = torch::cat({txt_query, img_query}, 1); + auto joint_key = torch::cat({txt_key, img_key}, 1); + auto joint_value = torch::cat({txt_value, img_value}, 1); + + auto results = at_npu::native::custom_ops::npu_fusion_attention( + joint_query, + joint_key, + joint_value, + heads / FLAGS_sp_size, + /*input_layout=*/"BSND", + /*pse=*/torch::nullopt, + /*padding_mask=*/torch::nullopt, + /*atten_mask*/ torch::nullopt, + /*scale=*/pow(joint_query.size(3), -0.5), + /*keep_prob=*/1.0, + /*pre_tockens=*/65535, + /*next_tockens=*/65535); + + auto joint_hidden_states = std::get<0>(results); + // Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3); + joint_hidden_states = joint_hidden_states.to(joint_query.dtype()); + + int64_t seq_txt = txt_query.size(1); + int64_t seq_img = img_query.size(1); + // Split attention outputs back + auto chunks = torch::split(joint_hidden_states, {seq_txt, seq_img}, 1); + auto txt_attn_output = chunks[0]; + auto img_attn_output = chunks[1]; + + txt_attn_output = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + txt_attn_output, + /*tensor_name=*/"encoder_hidden_states", + /*dim=*/1); + + img_attn_output = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + img_attn_output, /*tensor_name=*/"hidden_states", /*dim=*/1); + + // Apply output projections + img_attn_output = attn_->to_out_->forward(img_attn_output); + + txt_attn_output = attn_->to_add_out_->forward(txt_attn_output); + return std::make_tuple(img_attn_output, txt_attn_output); + } + + void load_state_dict(const StateDict& state_dict) { + attn_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) { + attn_->verify_loaded_weights(prefix); + } + + protected: + Attention attn_{nullptr}; + const ParallelArgs parallel_args_; +}; +TORCH_MODULE(QwenDoubleStreamAttnProcessor2_0); + +class FeedForwardImpl : public torch::nn::Module { + public: + explicit FeedForwardImpl(const ModelContext& context, + int64_t dim, + int64_t dim_out, + int64_t mult = 4, + double dropout = 0.0) + : options_(context.get_tensor_options()) { + auto model_args = context.get_model_args(); + auto inner_dim = dim * 4; + + // linear1 + linear1_ = register_module( + "linear1", + layer::AddMatmulWeightTransposed(dim, inner_dim, true, options_)); + + // activation + activation_ = register_module( + "activation", + torch::nn::Functional(std::function( + [](const at::Tensor& x) { return torch::gelu(x, "tanh"); }))); + + // linear2 + linear2_ = register_module( + "linear2", + layer::AddMatmulWeightTransposed(inner_dim, dim_out, true, options_)); + } + + torch::Tensor forward(const torch::Tensor& hidden_states) { + torch::Tensor out = linear1_->forward(hidden_states); + out = activation_(out); + out = linear2_->forward(out); + return out; + } + + void load_state_dict(const StateDict& state_dict) { + // linear1 + linear1_->load_state_dict(state_dict.get_dict_with_prefix("net.0.proj.")); + // linear2 + linear2_->load_state_dict(state_dict.get_dict_with_prefix("net.2.")); + } + + void verify_loaded_weights(const std::string& prefix) { + linear1_->verify_loaded_weights(prefix + "net.0.proj."); + linear2_->verify_loaded_weights(prefix + "net.2."); + } + + private: + layer::AddMatmulWeightTransposed linear1_{nullptr}; + layer::AddMatmulWeightTransposed linear2_{nullptr}; + torch::nn::Functional activation_{nullptr}; + torch::TensorOptions options_; +}; +TORCH_MODULE(FeedForward); + +bool ADALN_FUSE = true; + +class QwenImageTransformerBlockImpl : public torch::nn::Module { + public: + QwenImageTransformerBlockImpl(const ModelContext& context, + int64_t dim, + int64_t num_attention_heads, + int64_t attention_head_dim, + const ParallelArgs& parallel_args, + bool zero_cond_t = false, + const std::string& qk_norm = "rms_norm", + double eps = 1e-6) + : options_(context.get_tensor_options()), + zero_cond_t_(zero_cond_t), + parallel_args_(parallel_args) { + // Image processing modules + img_mod_ = register_module( + "img_mod", + torch::nn::Sequential( + torch::nn::SiLU(), + layer::AddMatmulWeightTransposed(dim, 6 * dim, true, options_))); + + // Image normalization + img_norm1_ = register_module("img_norm1", AdaLayerNorm(context, dim, eps)); + // Attention module + auto attn_ = Attention(context, + /*query_dim=*/dim, + /*cross_attention_dim=*/std::nullopt, + /*heads=*/num_attention_heads, + /*kv_heads=*/std::nullopt, + /*dim_head=*/attention_head_dim, + /*drop_out=*/0.0, + /*bias=*/true, + /*qk_norm=*/qk_norm, + /*cross_attention_norm=*/"", + /*added_kv_proj_dim=*/dim, + /*added_proj_bias*/ true, + /*out_bias*/ true, + /*scale_qk*/ true, + /*only_cross_attention=*/false, + eps, + /*rescale_output_factor=*/1.0, + /*residual_connection=*/false, + /*out_dim=*/dim, + /*out_context_dim=*/std::nullopt, + /*context_pre_only=*/true, + /*pre_only=*/false, + /*elementwise_affine=*/true, + /*is_causal=*/false, + /*sp_group=*/parallel_args_.dit_sp_group_); + attn_processor_ = register_module( + "attn_processor_", + QwenDoubleStreamAttnProcessor2_0(std::move(attn_), parallel_args_)); + // Image normalization 2 + img_norm2_ = register_module("img_norm2", AdaLayerNorm(context, dim, eps)); + + // Image MLP + img_mlp_ = register_module("img_mlp", FeedForward(context, dim, dim)); + + // Text processing modules + txt_mod_ = register_module( + "txt_mod", + torch::nn::Sequential( + torch::nn::SiLU(), + layer::AddMatmulWeightTransposed(dim, 6 * dim, true, options_))); + + // Text normalization 1 + txt_norm1_ = register_module("txt_norm1", AdaLayerNorm(context, dim, eps)); + + // Text normalization 2 + txt_norm2_ = register_module("txt_norm2", AdaLayerNorm(context, dim, eps)); + + // Text MLP + txt_mlp_ = register_module("txt_mlp", FeedForward(context, dim, dim)); + } + + std::pair _modulate( + const torch::Tensor& x, + const torch::Tensor& mod_params, + const torch::Tensor& index = torch::Tensor()) { + // x: b l d, shift: b d, scale: b d, gate: b d + auto chunks = mod_params.chunk(3, -1); + auto shift = chunks[0]; + auto scale = chunks[1]; + auto gate = chunks[2]; + + torch::Tensor shift_result, scale_result, gate_result; + + if (index.defined()) { + // Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + // So shift, scale, gate have shape [2*actual_batch, d] + int64_t actual_batch = shift.size(0) / 2; + + // Split into two parts + auto shift_0 = shift.slice(0, 0, actual_batch); + auto shift_1 = shift.slice(0, actual_batch, shift.size(0)); + + auto scale_0 = scale.slice(0, 0, actual_batch); + auto scale_1 = scale.slice(0, actual_batch, scale.size(0)); + + auto gate_0 = gate.slice(0, 0, actual_batch); + auto gate_1 = gate.slice(0, actual_batch, gate.size(0)); + + // index: [b, l] where b is actual batch size + // Expand to [b, l, 1] to match feature dimension + auto index_expanded = index.unsqueeze(-1); // [b, l, 1] + + // Expand chunks to [b, 1, d] then broadcast to [b, l, d] + auto shift_0_exp = shift_0.unsqueeze(1); // [b, 1, d] + auto shift_1_exp = shift_1.unsqueeze(1); // [b, 1, d] + auto scale_0_exp = scale_0.unsqueeze(1); + auto scale_1_exp = scale_1.unsqueeze(1); + auto gate_0_exp = gate_0.unsqueeze(1); + auto gate_1_exp = gate_1.unsqueeze(1); + + // Use torch::where to select based on index + shift_result = + torch::where(index_expanded == 0, shift_0_exp, shift_1_exp); + scale_result = + torch::where(index_expanded == 0, scale_0_exp, scale_1_exp); + gate_result = torch::where(index_expanded == 0, gate_0_exp, gate_1_exp); + } else { + shift_result = shift.unsqueeze(1); + scale_result = scale.unsqueeze(1); + gate_result = gate.unsqueeze(1); + } + + // Apply modulation: x * (1 + scale_result) + shift_result + auto modulated_x = x * (1 + scale_result) + shift_result; + + return {modulated_x, gate_result}; + } + + std::tuple forward( + const torch::Tensor& hidden_states, + const torch::Tensor& encoder_hidden_states, + const torch::Tensor& encoder_hidden_states_mask, + const torch::Tensor& temb, + const std::tuple& image_rotary_emb = {}, + const std::unordered_map& + joint_attention_kwargs = {}, + const torch::Tensor& modulate_index = torch::Tensor()) { + // Get modulation parameters for both streams + auto img_mod_params = img_mod_->forward(temb); // [B, 6*dim] + torch::Tensor new_temb; + if (zero_cond_t_) { + new_temb = temb.chunk(2, 0)[0]; + } else { + new_temb = temb; + } + auto txt_mod_params = txt_mod_->forward(new_temb); // [B, 6*dim] + // Split modulation parameters for norm1 and norm2 + auto img_mod_chunks = img_mod_params.chunk(2, -1); + auto img_mod1 = img_mod_chunks[0]; // [B, 3*dim] + auto img_mod2 = img_mod_chunks[1]; // [B, 3*dim] + + auto txt_mod_chunks = txt_mod_params.chunk(2, -1); + auto txt_mod1 = txt_mod_chunks[0]; // [B, 3*dim] + auto txt_mod2 = txt_mod_chunks[1]; // [B, 3*dim] + + // Process image stream - norm1 + modulation + torch::Tensor img_modulated, img_gate1; + std::tie(img_modulated, img_gate1) = + img_norm1_->forward(hidden_states, img_mod1, modulate_index); + // Process text stream - norm1 + modulation + torch::Tensor txt_modulated, txt_gate1; + std::tie(txt_modulated, txt_gate1) = + txt_norm1_->forward(encoder_hidden_states, txt_mod1); + + // Use QwenAttnProcessor2_0 for joint attention computation + auto attn_output = attn_processor_->forward(img_modulated, // Image stream + txt_modulated, // Text stream + encoder_hidden_states_mask, + torch::Tensor(), // timestep + image_rotary_emb); + + // QwenAttnProcessor2_0 returns (img_output, txt_output) + auto img_attn_output = std::get<0>(attn_output); + auto txt_attn_output = std::get<1>(attn_output); + + // Apply attention gates and add residual + auto new_hidden_states = hidden_states + img_gate1 * img_attn_output; + auto new_encoder_hidden_states = + encoder_hidden_states + txt_gate1 * txt_attn_output; + + // Process image stream - norm2 + MLP + torch::Tensor img_modulated2, img_gate2; + std::tie(img_modulated2, img_gate2) = + img_norm2_->forward(new_hidden_states, img_mod2, modulate_index); + + auto img_mlp_output = img_mlp_->forward(img_modulated2); + new_hidden_states = new_hidden_states + img_gate2 * img_mlp_output; + + // Process text stream - norm2 + MLP + torch::Tensor txt_modulated2, txt_gate2; + std::tie(txt_modulated2, txt_gate2) = + txt_norm2_->forward(new_encoder_hidden_states, txt_mod2); + + auto txt_mlp_output = txt_mlp_->forward(txt_modulated2); + new_encoder_hidden_states = + new_encoder_hidden_states + txt_gate2 * txt_mlp_output; + + // Clip to prevent overflow for fp16 + if (new_encoder_hidden_states.dtype() == torch::kFloat16) { + new_encoder_hidden_states = + new_encoder_hidden_states.clamp(-65504, 65504); + } + if (new_hidden_states.dtype() == torch::kFloat16) { + new_hidden_states = new_hidden_states.clamp(-65504, 65504); + } + + return std::make_tuple(new_hidden_states, new_encoder_hidden_states); + } + + void load_state_dict(const StateDict& state_dict) { + img_mod_[1]->as()->load_state_dict( + state_dict.get_dict_with_prefix("img_mod.1.")); + img_mlp_->load_state_dict(state_dict.get_dict_with_prefix("img_mlp.")); + txt_mod_[1]->as()->load_state_dict( + state_dict.get_dict_with_prefix("txt_mod.1.")); + txt_mlp_->load_state_dict(state_dict.get_dict_with_prefix("txt_mlp.")); + attn_processor_->load_state_dict(state_dict.get_dict_with_prefix("attn.")); + } + + void verify_loaded_weights(const std::string& prefix) { + img_mod_[1]->as()->verify_loaded_weights( + prefix + "img_mod.1."); + img_mlp_->verify_loaded_weights(prefix + "img_mlp."); + txt_mod_[1]->as()->verify_loaded_weights( + prefix + "txt_mod.1."); + txt_mlp_->verify_loaded_weights(prefix + "txt_mlp."); + attn_processor_->verify_loaded_weights(prefix + "attn."); + } + + private: + torch::TensorOptions options_; + torch::nn::Sequential img_mod_{nullptr}; + AdaLayerNorm img_norm1_{nullptr}; + AdaLayerNorm img_norm2_{nullptr}; + std::shared_ptr attn_{nullptr}; + QwenDoubleStreamAttnProcessor2_0 attn_processor_{nullptr}; + FeedForward img_mlp_{nullptr}; + + torch::nn::Sequential txt_mod_{nullptr}; + AdaLayerNorm txt_norm1_{nullptr}; + AdaLayerNorm txt_norm2_{nullptr}; + FeedForward txt_mlp_{nullptr}; + bool zero_cond_t_; + const ParallelArgs parallel_args_; +}; + +TORCH_MODULE(QwenImageTransformerBlock); + +class QwenImageTransformer2DModelImpl : public torch::nn::Module { + public: + QwenImageTransformer2DModelImpl(const ModelContext& context, + const ParallelArgs& parallel_args) + : options_(context.get_tensor_options()), parallel_args_(parallel_args) { + auto model_args = context.get_model_args(); + int64_t num_attention_heads = model_args.n_heads(); + int64_t attention_head_dim = model_args.head_dim(); + int64_t joint_attention_dim = model_args.joint_attention_dim(); + std::vector axes_dims_rope = model_args.axes_dims_rope(); + int64_t num_layers = model_args.num_layers(); + int64_t patch_size = model_args.mm_patch_size(); + int64_t in_channels = model_args.in_channels(); + int64_t out_channels = model_args.out_channels(); + bool zero_cond_t = model_args.zero_cond_t(); + bool use_additional_t_cond = model_args.use_additional_t_cond(); + use_layer3d_rope_ = model_args.use_layer3d_rope(); + + out_channels = (out_channels > 0) ? out_channels : in_channels; + auto inner_dim = num_attention_heads * attention_head_dim; + + // Positional embedding + if (use_layer3d_rope_) { + pos_embed_3d_rope_ = register_module( + "pos_embed", + QwenEmbedLayer3DRope(context, /*theta=*/10000, axes_dims_rope, true)); + } else { + pos_embed_ = register_module( + "pos_embed", + QwenEmbedRope(context, /*theta=*/10000, axes_dims_rope, true)); + } + + // Time-text embedding + time_text_embed_ = register_module( + "time_text_embed", + QwenTimestepProjEmbeddings(context, inner_dim, use_additional_t_cond)); + + // Text normalization + txt_norm_ = register_module( + "txt_norm", RMSNorm(joint_attention_dim, 1e-6, true, false)); + + // Input projections + img_in_ = register_module("img_in", + layer::AddMatmulWeightTransposed( + in_channels, inner_dim, true, options_)); + txt_in_ = + register_module("txt_in", + layer::AddMatmulWeightTransposed( + joint_attention_dim, inner_dim, true, options_)); + // Transformer blocks + transformer_blocks_ = + register_module("transformer_blocks", torch::nn::ModuleList()); + for (int64_t i = 0; i < num_layers; ++i) { + transformer_blocks_->push_back( + QwenImageTransformerBlock(context, + inner_dim, + num_attention_heads, + attention_head_dim, + parallel_args_, + zero_cond_t)); + } + + // Output layers + norm_out_ = register_module( + "norm_out", + AdaLayerNormContinuous(context, inner_dim, inner_dim, false, 1e-6)); + proj_out_ = register_module( + "proj_out", + layer::AddMatmulWeightTransposed( + inner_dim, patch_size * patch_size * out_channels, true, options_)); + + // Cache for conditional and unconditional + cache_cond_ = false; + cache_uncond_ = false; + + zero_cond_t_ = zero_cond_t; + } + torch::Tensor forward( + const torch::Tensor& hidden_states, + const torch::Tensor& encoder_hidden_states = torch::Tensor(), + const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(), + torch::Tensor timestep = torch::Tensor(), + std::vector> img_shapes = {}, + torch::Tensor txt_seq_lens = torch::Tensor(), + bool use_cfg = false, + int64_t step_idx = 0, + torch::Tensor addition_t_cond = torch::Tensor(), + torch::Tensor guidance = torch::Tensor(), + const std::unordered_map& attention_kwargs = + {}, + const std::vector& controlnet_block_samples = {}) { + auto new_hidden_states = img_in_->forward(hidden_states); + auto new_timestep = timestep.to(new_hidden_states.dtype()); + torch::Tensor modulate_index; + if (zero_cond_t_) { + new_timestep = torch::cat({new_timestep, new_timestep * 0}, /*dim=*/0); + std::vector modulate_index_list; + for (size_t sample_index = 0; sample_index < 1; sample_index++) { + auto zero_prods = torch::zeros({img_shapes[0][1] * img_shapes[0][2]}, + torch::TensorOptions() + .device(new_timestep.device()) + .dtype(torch::kInt64)); + int64_t one_prods_size = 0; + for (size_t index = 1; index < img_shapes.size(); index++) { + one_prods_size += img_shapes[index][1] * img_shapes[index][2]; + } + auto ones_prods = torch::ones({one_prods_size}, + torch::TensorOptions() + .device(new_timestep.device()) + .dtype(torch::kInt64)); + modulate_index_list.emplace_back( + torch::cat({zero_prods, ones_prods}, /*dim=*/0)); + } + modulate_index = torch::stack(modulate_index_list, /*dim=*/0); + } else { + modulate_index = torch::Tensor(); + } + + auto origin_text_seq_len = encoder_hidden_states.size(1); + + // padding mask for sequence parallel scene + auto padded_encoder_hidden_states_mask = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + encoder_hidden_states_mask, + /*tensor_name=*/"encoder_hidden_states_mask", + /*dim=*/1); + + auto new_encoder_hidden_states = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + encoder_hidden_states, + /*tensor_name=*/"encoder_hidden_states", + /*dim=*/1); + + new_hidden_states = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + new_hidden_states, /*tensor_name=*/"hidden_states", /*dim=*/1); + + modulate_index = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + modulate_index, /*tensor_name=*/"modulate_index", /*dim=*/1); + + new_encoder_hidden_states = txt_norm_->forward(new_encoder_hidden_states); + new_encoder_hidden_states = txt_in_->forward(new_encoder_hidden_states); + + // Use the encoder_hidden_states sequence length for RoPE computation and + // normalize mask + auto [text_seq_len, per_sample_len, new_encoder_hidden_states_mask] = + compute_text_seq_len_from_mask(new_encoder_hidden_states, + padded_encoder_hidden_states_mask); + auto temb = time_text_embed_->forward( + new_timestep, new_hidden_states, addition_t_cond); + std::tuple image_rotary_emb; + if (use_layer3d_rope_) { + image_rotary_emb = pos_embed_3d_rope_->forward( + img_shapes, origin_text_seq_len, new_hidden_states.device()); + } else { + image_rotary_emb = pos_embed_->forward(img_shapes, + origin_text_seq_len, + new_hidden_states.device(), + /*max_txt_seq_len=*/std::nullopt); + } + + std::unordered_map block_attention_kwargs; + if (new_encoder_hidden_states_mask.has_value() && + new_encoder_hidden_states_mask.value().defined()) { + int64_t batch_size = new_hidden_states.size(0); + int64_t image_seq_len = new_hidden_states.size(1); + auto image_mask = torch::ones({batch_size, image_seq_len}, + torch::TensorOptions() + .device(new_hidden_states.device()) + .dtype(torch::kBool)); + auto joint_attention_mask = torch::cat( + {new_encoder_hidden_states_mask.value(), image_mask}, /*dim=*/1); + block_attention_kwargs["attention_mask"] = joint_attention_mask; + } + + if (FLAGS_sp_size > 1) { + new_hidden_states = split_sequence(new_hidden_states, + /*dim=*/1, + parallel_args_.dit_sp_group_); + new_encoder_hidden_states = split_sequence(new_encoder_hidden_states, + /*dim=*/1, + parallel_args_.dit_sp_group_); + if (modulate_index.defined()) { + modulate_index = split_sequence(modulate_index, + /*dim=*/1, + parallel_args_.dit_sp_group_); + } + } + + auto image_rot = std::get<0>(image_rotary_emb); + auto txt_rot = std::get<1>(image_rotary_emb); + + bool use_step_cache = false; + bool use_block_cache = false; + + torch::Tensor original_hidden_states = new_hidden_states; + torch::Tensor original_encoder_hidden_states = new_encoder_hidden_states; + // Step start: prepare inputs (hidden_states, original_hidden_states) + TensorMap step_in_map = { + {"hidden_states", new_hidden_states}, + {"original_hidden_states", original_hidden_states}}; + CacheStepIn stepin_before(step_idx, step_in_map); + use_step_cache = + DiTCache::get_instance().on_before_step(stepin_before, use_cfg); + + if (!use_step_cache) { + for (int64_t index_block = 0; index_block < transformer_blocks_->size(); + ++index_block) { + TensorMap block_in_before_map = {}; + CacheBlockIn blockin_before(index_block, block_in_before_map); + use_block_cache = + DiTCache::get_instance().on_before_block(blockin_before, use_cfg); + + if (!use_block_cache) { + std::tie(new_hidden_states, new_encoder_hidden_states) = + transformer_blocks_[index_block] + ->as() + ->forward(new_hidden_states, + new_encoder_hidden_states, + /*encoder_hidden_states_mask=*/torch::Tensor(), + temb, + image_rotary_emb, + block_attention_kwargs, + modulate_index); + } + + TensorMap block_in_after_map = { + {"hidden_states", new_hidden_states}, + {"encoder_hidden_states", new_encoder_hidden_states}, + {"original_hidden_states", original_hidden_states}, + {"original_encoder_hidden_states", original_encoder_hidden_states}}; + CacheBlockIn blockin_after(index_block, block_in_after_map); + CacheBlockOut blockout_after = + DiTCache::get_instance().on_after_block(blockin_after, use_cfg); + + new_hidden_states = blockout_after.tensors.at("hidden_states"); + new_encoder_hidden_states = + blockout_after.tensors.at("encoder_hidden_states"); + } + } + + // Step end: update outputs (hidden_states, original_hidden_states) + TensorMap step_after_map = { + {"hidden_states", new_hidden_states}, + {"original_hidden_states", original_hidden_states}}; + CacheStepIn stepin_after(step_idx, step_after_map); + CacheStepOut stepout_after = + DiTCache::get_instance().on_after_step(stepin_after, use_cfg); + new_hidden_states = stepout_after.tensors.at("hidden_states"); + + if (zero_cond_t_) { + temb = temb.chunk(2, 0)[0]; + } + + new_hidden_states = norm_out_->forward(new_hidden_states, temb); + new_hidden_states = proj_out_->forward(new_hidden_states); + if (FLAGS_sp_size > 1) { + new_hidden_states = gather_sequence( + new_hidden_states, /*dim=*/1, parallel_args_.dit_sp_group_); + } + return new_hidden_states; + } + + void verify_loaded_weights(const std::string& prefix) { + time_text_embed_->verify_loaded_weights(prefix + "time_text_embed."); + txt_norm_->verify_loaded_weights(prefix + "txt_norm."); + img_in_->verify_loaded_weights(prefix + "img_in."); + txt_in_->verify_loaded_weights(prefix + "txt_in."); + norm_out_->verify_loaded_weights(prefix + "norm_out."); + proj_out_->verify_loaded_weights(prefix + "proj_out."); + for (size_t i = 0; i < transformer_blocks_->size(); i++) { + auto block_prefix = "transformer_blocks." + std::to_string(i) + "."; + transformer_blocks_[i] + ->as() + ->verify_loaded_weights(prefix + block_prefix); + } + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + time_text_embed_->load_state_dict( + state_dict->get_dict_with_prefix("time_text_embed.")); + txt_norm_->load_state_dict(state_dict->get_dict_with_prefix("txt_norm.")); + + img_in_->load_state_dict(state_dict->get_dict_with_prefix("img_in.")); + txt_in_->load_state_dict(state_dict->get_dict_with_prefix("txt_in.")); + + norm_out_->load_state_dict(state_dict->get_dict_with_prefix("norm_out.")); + proj_out_->load_state_dict(state_dict->get_dict_with_prefix("proj_out.")); + + for (size_t i = 0; i < transformer_blocks_->size(); i++) { + auto prefix = "transformer_blocks." + std::to_string(i) + "."; + transformer_blocks_[i] + ->as() + ->load_state_dict(state_dict->get_dict_with_prefix(prefix)); + } + } + verify_loaded_weights(""); + LOG(INFO) << "qwen image vae model loaded successfully."; + } + + private: + torch::TensorOptions options_; + QwenEmbedRope pos_embed_{nullptr}; + QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr}; + QwenTimestepProjEmbeddings time_text_embed_{nullptr}; + RMSNorm txt_norm_{nullptr}; + layer::AddMatmulWeightTransposed img_in_{nullptr}; + layer::AddMatmulWeightTransposed txt_in_{nullptr}; + torch::nn::ModuleList transformer_blocks_{nullptr}; + AdaLayerNormContinuous norm_out_{nullptr}; + layer::AddMatmulWeightTransposed proj_out_{nullptr}; + + const ParallelArgs parallel_args_; + + // Cache objects + bool cache_cond_; + bool cache_uncond_; + + bool zero_cond_t_; + bool use_layer3d_rope_; +}; + +TORCH_MODULE(QwenImageTransformer2DModel); + +REGISTER_MODEL_ARGS(QwenImageTransformer2DModel, [&] { + // qwen-image 2509 params + LOAD_ARG_OR(dtype, "dtype", "bfloat16"); + LOAD_ARG_OR(in_channels, "in_channels", 64); + LOAD_ARG_OR(out_channels, "out_channels", 16); + LOAD_ARG_OR(num_layers, "num_layers", 60); + LOAD_ARG_OR(num_single_layers, "num_single_layers", 24); + LOAD_ARG_OR(head_dim, "attention_head_dim", 128); + LOAD_ARG_OR(n_heads, "num_attention_heads", 24); + LOAD_ARG_OR(joint_attention_dim, "joint_attention_dim", 3584); + LOAD_ARG_OR(mm_patch_size, "patch_size", 2); + LOAD_ARG_OR(guidance_embeds, "guidance_embeds", false); + LOAD_ARG_OR( + axes_dims_rope, "axes_dims_rope", (std::vector{16, 56, 56})); + + // qwen-image 2511 params + LOAD_ARG_OR(zero_cond_t, "zero_cond_t", false); + LOAD_ARG_OR(use_additional_t_cond, "use_additional_t_cond", false); + LOAD_ARG_OR(use_layer3d_rope, "use_layer3d_rope", false); +}); + +} // namespace qwenimage +} // namespace xllm::dit::npu diff --git a/xllm/models/dit/pipeline_flux_base.h b/xllm/models/dit/pipeline_flux_base.h index 604bc6d81..eb5f3de09 100644 --- a/xllm/models/dit/pipeline_flux_base.h +++ b/xllm/models/dit/pipeline_flux_base.h @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #pragma once + +#if defined(USE_NPU) #include +#endif #include #include diff --git a/xllm/models/dit/pipeline_longcat_image.h b/xllm/models/dit/pipeline_longcat_image.h index 5f4dd6558..6a2fe4421 100644 --- a/xllm/models/dit/pipeline_longcat_image.h +++ b/xllm/models/dit/pipeline_longcat_image.h @@ -383,7 +383,7 @@ class LongCatImagePipelineImpl : public torch::nn::Module { char c = prompt_text[i]; if ((c == '\'' || c == '\"') && !in_quotes) { if (!current.empty()) { - result.push_back({current, false}); + result.emplace_back(current, false); current.clear(); } in_quotes = true; @@ -391,7 +391,7 @@ class LongCatImagePipelineImpl : public torch::nn::Module { current += c; } else if (in_quotes && c == quote_char) { current += c; - result.push_back({current, true}); + result.emplace_back(current, true); current.clear(); in_quotes = false; quote_char = '\0'; @@ -400,7 +400,7 @@ class LongCatImagePipelineImpl : public torch::nn::Module { } } if (!current.empty()) { - result.push_back({current, in_quotes}); + result.emplace_back(current, in_quotes); } return result; } diff --git a/xllm/models/dit/pipeline_longcat_image_edit.h b/xllm/models/dit/pipeline_longcat_image_edit.h index 1c6a0af00..2beb9ef6a 100644 --- a/xllm/models/dit/pipeline_longcat_image_edit.h +++ b/xllm/models/dit/pipeline_longcat_image_edit.h @@ -347,7 +347,7 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { char c = prompt_text[i]; if ((c == '\'' || c == '\"') && !in_quotes) { if (!current.empty()) { - result.push_back({current, false}); + result.emplace_back(current, false); current.clear(); } in_quotes = true; @@ -355,7 +355,7 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { current += c; } else if (in_quotes && c == quote_char) { current += c; - result.push_back({current, true}); + result.emplace_back(current, true); current.clear(); in_quotes = false; quote_char = '\0'; @@ -364,7 +364,7 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { } } if (!current.empty()) { - result.push_back({current, in_quotes}); + result.emplace_back(current, in_quotes); } return result; } diff --git a/xllm/models/dit/transformer_flux.h b/xllm/models/dit/transformer_flux.h index 8e0e93618..22974c5e7 100644 --- a/xllm/models/dit/transformer_flux.h +++ b/xllm/models/dit/transformer_flux.h @@ -59,7 +59,7 @@ inline torch::Tensor apply_rotary_emb(const torch::Tensor& x, #if defined(USE_NPU) return at_npu::native::custom_ops::npu_rotary_mul(x, cos, sin, "interleave"); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_MLU) std::vector reshape_shape; for (int64_t i = 0; i < x.dim() - 1; ++i) { reshape_shape.push_back(x.size(i)); @@ -278,7 +278,7 @@ class FluxSingleAttentionImpl : public torch::nn::Module { auto attn_output = std::get<0>(results); attn_output = attn_output.to(query.dtype()); return attn_output.flatten(2); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_MLU) query = query.view({batch_size, -1, attn_heads, head_dim}).transpose(1, 2); key = key.view({batch_size, -1, attn_heads, head_dim}).transpose(1, 2); value = value.view({batch_size, -1, attn_heads, head_dim}).transpose(1, 2); @@ -458,7 +458,7 @@ class FluxAttentionImpl : public torch::nn::Module { auto attn_output = std::get<0>(results); attn_output = attn_output.reshape({batch_size, -1, attn_heads * head_dim}); -#elif defined(USE_CUDA) || defined(USE_MUSA) +#elif defined(USE_CUDA) || defined(USE_MLU) || defined(USE_MUSA) // SDPA expects (B, H, S, D); our query1/key1/value1 are (B, S, H, D). // Transpose to match diffusers dispatch_attention_fn (permute 0,2,1,3). query1 = query1.transpose(1, 2); diff --git a/xllm/models/dit/utils/common_util.h b/xllm/models/dit/utils/common_util.h new file mode 100644 index 000000000..739c53799 --- /dev/null +++ b/xllm/models/dit/utils/common_util.h @@ -0,0 +1,240 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once + +#include + +#include "models/dit/flowmatch_euler_discrete_scheduler.h" + +namespace xllm::dit { + +float calculate_shift(int64_t image_seq_len, + int64_t base_seq_len = 256, + int64_t max_seq_len = 4096, + float base_shift = 0.5f, + float max_shift = 1.15f) { + float m = + (max_shift - base_shift) / static_cast(max_seq_len - base_seq_len); + float b = base_shift - m * static_cast(base_seq_len); + float mu = static_cast(image_seq_len) * m + b; + return mu; +} + +std::pair retrieve_timesteps( + xllm::FlowMatchEulerDiscreteScheduler scheduler, + int64_t num_inference_steps = 0, + torch::Device device = torch::kCPU, + std::optional> sigmas = std::nullopt, + std::optional mu = std::nullopt) { + torch::Tensor scheduler_timesteps; + int64_t steps; + if (sigmas.has_value()) { + steps = sigmas->size(); + scheduler->set_timesteps( + static_cast(steps), device, *sigmas, mu, std::nullopt); + + scheduler_timesteps = scheduler->timesteps(); + } else { + steps = num_inference_steps; + scheduler->set_timesteps( + static_cast(steps), device, std::nullopt, mu, std::nullopt); + scheduler_timesteps = scheduler->timesteps(); + } + if (scheduler_timesteps.device() != device) { + scheduler_timesteps = scheduler_timesteps.to(device); + } + return {scheduler_timesteps, steps}; +} + +std::pair calculate_dimensions(double target_area, + double ratio) { + double width = std::sqrt(target_area * ratio); + double height = width / ratio; + + width = std::round(width / 32) * 32; + height = std::round(height / 32) * 32; + + return {static_cast(width), static_cast(height)}; +} + +torch::Tensor randn_tensor(const std::vector& shape, + int64_t seed, + torch::TensorOptions& options) { + if (shape.empty()) { + LOG(FATAL) << "Shape must not be empty."; + } + at::Generator gen = at::detail::createCPUGenerator(); + gen = gen.clone(); + gen.set_current_seed(seed); + torch::Tensor latents; + latents = torch::randn(shape, gen, options.device(torch::kCPU)); + latents = latents.to(options); + return latents; +} + +class VAEImageProcessorImpl : public torch::nn::Module { + public: + explicit VAEImageProcessorImpl(ModelContext context, + bool do_resize = true, + bool do_normalize = true, + bool do_binarize = false, + bool do_convert_rgb = false, + bool do_convert_grayscale = false, + int64_t latent_channels = 4) { + const auto& model_args = context.get_model_args(); + dtype_ = context.get_tensor_options().dtype().toScalarType(); + scale_factor_ = 1 << model_args.block_out_channels().size(); + latent_channels_ = latent_channels; + do_resize_ = do_resize; + do_normalize_ = do_normalize; + do_binarize_ = do_binarize; + do_convert_rgb_ = do_convert_rgb; + do_convert_grayscale_ = do_convert_grayscale; + } + + std::pair adjust_dimensions(int64_t height, + int64_t width) const { + height = height - (height % scale_factor_); + width = width - (width % scale_factor_); + return {height, width}; + } + + torch::Tensor preprocess( + const torch::Tensor& image, + std::optional height = std::nullopt, + std::optional width = std::nullopt, + const std::string& resize_mode = "default", + std::optional> + crop_coords = std::nullopt) { + torch::Tensor processed = image.clone(); + if (processed.dtype() != torch::kFloat32) { + processed = processed.to(torch::kFloat32); + } + if (processed.max().item() > 1.1f) { + processed = processed / 255.0f; + } + if (crop_coords.has_value()) { + auto [x1, y1, x2, y2] = crop_coords.value(); + x1 = std::max(int64_t(0), x1); + y1 = std::max(int64_t(0), y1); + x2 = std::min(processed.size(-1), x2); + y2 = std::min(processed.size(-2), y2); + + if (processed.dim() == 3) { + processed = processed.index({torch::indexing::Slice(), + torch::indexing::Slice(y1, y2), + torch::indexing::Slice(x1, x2)}); + } else if (processed.dim() == 4) { + processed = processed.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(y1, y2), + torch::indexing::Slice(x1, x2)}); + } + } + int64_t channel = processed.size(1); + if (channel == latent_channels_) { + return image; + } + auto [target_h, target_w] = + get_default_height_width(processed, height, width); + if (do_resize_) { + processed = resize(processed, target_h, target_w); + } + + if (do_normalize_) { + processed = normalize(processed); + } + if (do_binarize_) { + processed = (processed >= 0.5f).to(torch::kFloat32); + } + processed = processed.to(dtype_); + return processed; + } + + torch::Tensor postprocess( + const torch::Tensor& tensor, + const std::string& output_type = "pt", + std::optional> do_denormalize = std::nullopt) { + torch::Tensor processed = tensor.clone(); + if (do_normalize_) { + if (!do_denormalize.has_value()) { + processed = denormalize(processed); + } else { + for (int64_t i = 0; i < processed.size(0); ++i) { + if (i < do_denormalize.value().size() && do_denormalize.value()[i]) { + processed[i] = denormalize(processed[i]); + } + } + } + } + if (output_type == "np") { + return processed.permute({0, 2, 3, 1}).contiguous(); + } + return processed; + } + + private: + std::pair get_default_height_width( + const torch::Tensor& image, + std::optional height = std::nullopt, + std::optional width = std::nullopt) const { + int64_t h, w; + if (image.dim() == 3) { + h = image.size(1); + w = image.size(2); + } else if (image.dim() == 4) { + h = image.size(2); + w = image.size(3); + } else { + LOG(FATAL) << "Unsupported image dimension: " << image.dim(); + } + + int64_t target_h = height.value_or(h); + int64_t target_w = width.value_or(w); + return adjust_dimensions(target_h, target_w); + } + + torch::Tensor normalize(const torch::Tensor& tensor) const { + return 2.0 * tensor - 1.0; + } + + torch::Tensor denormalize(const torch::Tensor& tensor) const { + return (tensor * 0.5 + 0.5).clamp(0.0, 1.0); + } + + public: + torch::Tensor resize(const torch::Tensor& image, + int64_t target_height, + int64_t target_width) const { + return torch::nn::functional::interpolate( + image, + torch::nn::functional::InterpolateFuncOptions() + .size(std::vector{target_height, target_width}) + .mode(torch::kNearest)); + } + + private: + int64_t scale_factor_ = 8; + int64_t latent_channels_ = 4; + bool do_resize_ = true; + bool do_normalize_ = true; + bool do_binarize_ = false; + bool do_convert_rgb_ = false; + bool do_convert_grayscale_ = false; + torch::ScalarType dtype_ = torch::kFloat32; +}; +TORCH_MODULE(VAEImageProcessor); + +} // namespace xllm::dit diff --git a/xllm/models/dit/utils/dit_parallel_linear.h b/xllm/models/dit/utils/dit_parallel_linear.h new file mode 100644 index 000000000..e7ab59e1a --- /dev/null +++ b/xllm/models/dit/utils/dit_parallel_linear.h @@ -0,0 +1,175 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "core/framework/state_dict/utils.h" +#include "core/layers/common/add_matmul.h" +#include "framework/parallel_state/parallel_state.h" + +namespace xllm::dit { +namespace F = torch::nn::functional; + +enum class LinearType { Default, SequenceParallel }; + +// NOTE: The order of linear and all2all Operations depends on the +// before_attention param if before_attention is true, order is: linear->all2all +// if before_attention is false, order is: all2all->linear +struct SpOptions { + // the num of attention heads + int64_t head_num = 0; + + // the size of single attention head + int64_t head_dim = 0; + + // hidden_size + int64_t hidden_size = 0; + + // before_attention: a Bool value that indicates where to apply the all2all, + // According to the classic ulysses sequence parallel, we should apply + // all2all communication for q, k, v, text_q (optional), text_k (optional), + // text_v (optional), before attention operation to gather full sequence + // (splited_sequence * group_size) and scatter the head nums (head_nums / + // group_size) , and we should apply all2all communication for attn_output, + // text_attn_output (optional) after the attention operation to split the + // full sequence (full_sequence / group_size) , and gather the head nums + // (splited_head_num * group_size) + bool before_attention = false; + + // the process_group for sequence parallel + ProcessGroup* process_group = nullptr; + + SpOptions() = default; + + SpOptions(int64_t head_num, + int64_t head_dim, + int64_t hidden_size, + bool before_attention, + ProcessGroup* process_group = nullptr) + : head_num(head_num), + head_dim(head_dim), + hidden_size(hidden_size), + before_attention(before_attention), + process_group(process_group) {} + + void valid() const { + CHECK(head_num > 0) << "head_num should be greater than 0 to initialize " + "DiTParallelLinear for " + "linear type 'sequence_parallel' " + << " but got " << head_num; + CHECK(head_dim > 0) << "head_dim should be greater than 0 to initialize " + "DiTParallelLinear for " + "linear type 'sequence_parallel' " + << " but got " << head_dim; + CHECK(hidden_size > 0) << "head_size should be greater than 0 to " + "initialize DiTParallelLinear for " + "linear type 'sequence_parallel' " + << " but got " << hidden_size; + CHECK(hidden_size == head_dim * head_num) + << "hidden_size should equal to head_dim * head_num" + << "got head_dim " << head_dim << ", head num" << head_num + << ", hidden_size " << hidden_size; + if (!process_group) { + LOG(ERROR) + << "DiTSpLinear expected to receive an initialized processgroup for" + << "all2all communication, but got nullptr"; + } + } +}; + +// TODO : Need to Implement a template funciton, but +// libtorch doesn't allow to creat module holder for +// template class. +// template +class DiTParallelLinearImpl : public torch::nn::Module { + public: + DiTParallelLinearImpl(layer::AddMatmulWeightTransposed linear, + const string& module_name, + LinearType linear_type = LinearType::Default, + const SpOptions& sp_options = SpOptions()) + : sp_options_(sp_options), linear_type_(linear_type) { + linear_ = register_module(module_name, std::move(linear)); + if (linear_type == LinearType::SequenceParallel) { + sp_options_.valid(); + } + } + + torch::Tensor linear_forward(const torch::Tensor& input) { + return linear_->forward(input); + } + + // sp_forward combines the linear operation with all2all communication, + // output: A torch tensor with shape {batch, seq_len, hidden_size} + torch::Tensor sp_forward(const torch::Tensor& input) { + CHECK(input.sizes().size() == 3) + << "Sp linear input is expected to be a tensor " + << "with shape {batch, seq_len, hidden_size}"; + auto group_size = sp_options_.process_group->world_size(); + if (sp_options_.before_attention) { + auto linear_output = this->linear_forward(input); + auto all_to_all_func = parallel_state::all_to_all_4D( + /*input=*/linear_output.view( + {input.size(0), -1, sp_options_.head_num, sp_options_.head_dim}), + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/false, + sp_options_.process_group); + auto output = all_to_all_func(); + return output.view( + {input.size(0), -1, sp_options_.hidden_size / group_size}); + } else { + auto all_to_all_func = parallel_state::all_to_all_4D( + /*input=*/input.view({input.size(0), + -1, + sp_options_.head_num / group_size, + sp_options_.head_dim}), + /*scatter_dim=*/1, + /*gather_dim=*/2, + /*async_ops=*/false, + sp_options_.process_group); + auto all_to_all_output = all_to_all_func(); + all_to_all_output = + all_to_all_output.view({input.size(0), -1, sp_options_.hidden_size}); + auto output = this->linear_forward(all_to_all_output); + return output; + } + } + + torch::Tensor forward(const torch::Tensor& input) { + if (linear_type_ == LinearType::Default) { + return this->linear_forward(input); + } else { + return this->sp_forward(input); + } + } + + void load_state_dict(const StateDict& state_dict) { + linear_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + linear_->verify_loaded_weights(prefix); + } + + private: + layer::AddMatmulWeightTransposed linear_{nullptr}; + SpOptions sp_options_; + LinearType linear_type_; +}; + +TORCH_MODULE(DiTParallelLinear); +} // namespace xllm::dit diff --git a/xllm/models/dit/utils/sequence_parallel_pad_manager.h b/xllm/models/dit/utils/sequence_parallel_pad_manager.h new file mode 100644 index 000000000..82f9154a9 --- /dev/null +++ b/xllm/models/dit/utils/sequence_parallel_pad_manager.h @@ -0,0 +1,93 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +namespace xllm::dit { + +class SequenceParallelPadManager { + public: + static SequenceParallelPadManager& getInstance() { + static SequenceParallelPadManager instance; + return instance; + } + + torch::Tensor pad_tensor(const torch::Tensor& ref_tensor, + const string& tensor_name, + int64_t dim = -1, + bool right_pad = true) { + auto pad_dim = dim; + if (pad_dim == -1) { + pad_dim = ref_tensor.dim() - 1; + } + + if (ref_tensor.defined()) { + if (ref_tensor.size(dim) % FLAGS_sp_size != 0) { + int64_t pad_len = FLAGS_sp_size - ref_tensor.size(dim) % FLAGS_sp_size; + set(tensor_name, pad_len); + + std::vector pad_shape(ref_tensor.dim() * 2); + int64_t pad_shift = right_pad ? 1 : 0; + pad_shape[2 * (ref_tensor.dim() - pad_dim - 1) + pad_shift] = pad_len; + auto pad_tensor = torch::pad(ref_tensor, pad_shape, "constant", 0); + return pad_tensor; + } + set(tensor_name, 0); + } + + return ref_tensor; + } + + void unpad_tensor(torch::Tensor& ref_tensor, + const string& tensor_name, + int64_t dim = -1, + bool right_pad = true) { + if (ref_tensor.defined()) { + auto pad = get(tensor_name); + ref_tensor = ref_tensor.narrow(dim, 0, ref_tensor.size(dim) - pad); + } + } + + void set(const std::string& key, int64_t length) { + pad_lengths_[key] = length; + } + + int64_t get(const std::string& key) const { + auto it = pad_lengths_.find(key); + return it != pad_lengths_.end() ? it->second : 0; + } + + bool has(const std::string& key) const { + return pad_lengths_.find(key) != pad_lengths_.end(); + } + + void remove(const std::string& key) { pad_lengths_.erase(key); } + + void clear() { pad_lengths_.clear(); } + + private: + SequenceParallelPadManager() = default; + ~SequenceParallelPadManager() = default; + SequenceParallelPadManager(const SequenceParallelPadManager&) = delete; + SequenceParallelPadManager& operator=(const SequenceParallelPadManager&) = + delete; + + std::unordered_map pad_lengths_; +}; + +} // namespace xllm::dit diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h index 1ec041de7..4960d22c8 100644 --- a/xllm/models/llm/deepseek_v2.h +++ b/xllm/models/llm/deepseek_v2.h @@ -103,6 +103,10 @@ class DeepseekV2ModelImpl : public torch::nn::Module { attn_metadata, kv_caches[i], modified_input_params); + if (!modified_input_params.record_layer(static_cast(i), + hidden_states.device())) { + return ModelOutput(); + } } auto [h, res] = norm_(hidden_states, residual); return ModelOutput(h, res); @@ -171,16 +175,7 @@ class DeepseekV2ForCausalLMImpl : public LlmForCausalLMImplBase { public: DeepseekV2ForCausalLMImpl(const ModelContext& context) - : LlmForCausalLMImplBase(context) { - // Check if prefix cache or chunked prefill is enabled for unsupported - // models - CHECK(!FLAGS_enable_prefix_cache) - << "deepseek_v2 have not supported " - "enable_prefix_cache yet. Please disable it."; - CHECK(!FLAGS_enable_chunked_prefill) - << "deepseek_v2 have not supported " - "enable_chunked_prefill yet. Please disable it."; - } + : LlmForCausalLMImplBase(context) {} void load_model( std::unique_ptr loader, diff --git a/xllm/models/llm/joyai_llm_flash.h b/xllm/models/llm/joyai_llm_flash.h index 2c72a252a..9719e2c07 100644 --- a/xllm/models/llm/joyai_llm_flash.h +++ b/xllm/models/llm/joyai_llm_flash.h @@ -23,16 +23,7 @@ class JoyAILLMFlashForCausalLMImpl : public LlmForCausalLMImplBase { public: JoyAILLMFlashForCausalLMImpl(const ModelContext& context) - : LlmForCausalLMImplBase(context) { - // Check if prefix cache or chunked prefill is enabled for unsupported - // models - CHECK(!FLAGS_enable_prefix_cache) - << "JoyAILLMFlash have not supported " - "enable_prefix_cache yet. Please disable it."; - CHECK(!FLAGS_enable_chunked_prefill) - << "JoyAILLMFlash have not supported " - "enable_chunked_prefill yet. Please disable it."; - } + : LlmForCausalLMImplBase(context) {} }; TORCH_MODULE(JoyAILLMFlashForCausalLM); @@ -62,7 +53,7 @@ REGISTER_MODEL_ARGS(joyai_llm_flash, [&] { LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 1); LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); - // LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc"); + LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc"); LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256); LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 1); LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8); diff --git a/xllm/models/llm/joyai_llm_flash_mtp.h b/xllm/models/llm/joyai_llm_flash_mtp.h new file mode 100644 index 000000000..bad9f03a1 --- /dev/null +++ b/xllm/models/llm/joyai_llm_flash_mtp.h @@ -0,0 +1,85 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once + +#include "deepseek_mtp.h" + +namespace xllm { + +class JoyAILLMFlashMtpForCausalLMImpl + : public LlmForCausalLMImplBase { + public: + JoyAILLMFlashMtpForCausalLMImpl(const ModelContext& context) + : LlmForCausalLMImplBase(context) {} + + void load_model( + std::unique_ptr loader, + std::string prefix = "model." /*llm model weight prefix*/) override { + // no need to load lm_head since it shares the same weights with main model + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict(state_dict->get_dict_with_prefix(prefix)); + } + model_->verify_loaded_weights(); + } +}; +TORCH_MODULE(JoyAILLMFlashMtpForCausalLM); + +REGISTER_CAUSAL_MODEL(joyai_llm_flash_mtp, JoyAILLMFlashMtpForCausalLM); + +REGISTER_MODEL_ARGS(joyai_llm_flash_mtp, [&] { + LOAD_ARG_OR(model_type, "model_type", "joyai_llm_flash_mtp"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(vocab_size, "vocab_size", 129280); + LOAD_ARG_OR(hidden_size, "hidden_size", 2048); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 40); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 32); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 7168); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 131072); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 1); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 0); + LOAD_ARG_OR(rope_theta, "rope_theta", 32000000.0f); + + LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 1); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); + LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc"); + LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256); + LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 1); + LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8); + LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 768); + LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 2.5f); + LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true); + LOAD_ARG_OR(n_group, "n_group", 1); + LOAD_ARG_OR(topk_group, "topk_group", 1); + LOAD_ARG_OR(scoring_func, "scoring_func", "sigmoid"); + LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128); + LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64); + LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); + LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536); + LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); + LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1); + + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); + }); + LOAD_ARG_OR_FUNC( + rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); }); + + SET_ARG(rope_scaling_rope_type, "default"); + SET_ARG(stop_token_ids, std::unordered_set({1})); +}); +} // namespace xllm diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index 7c823fb66..7bed947ab 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -22,7 +22,6 @@ limitations under the License. #include #include "core/common/interruption_bus.h" -#include "core/common/rec_model_utils.h" #include "core/framework/kv_cache/kv_cache.h" #include "core/framework/model/model_input_params.h" #include "core/framework/model/model_output.h" @@ -30,6 +29,7 @@ limitations under the License. #include "core/layers/common/attention_metadata_builder.h" #include "core/layers/common/lm_head.h" #include "core/layers/common/rms_norm.h" +#include "core/util/rec_model_utils.h" #include "models/model_registry.h" namespace xllm { @@ -121,6 +121,10 @@ class LlmModelImplBase : public torch::nn::Module { attn_metadata, kv_caches[i], modified_input_params); + if (!modified_input_params.record_layer(static_cast(i), + h.device())) { + return ModelOutput(); + } } auto [hidden_states, residual_out] = norm_(h, residual); return ModelOutput(hidden_states, residual_out); @@ -215,15 +219,20 @@ class LlmForCausalLMImplBase : public torch::nn::Module { std::unique_ptr loader, std::string prefix = "model." /*llm model weight prefix*/) { for (const auto& state_dict : loader->get_state_dicts()) { - auto sub_dict = state_dict->get_dict_with_prefix(prefix); - if (sub_dict.size() == 0) { - sub_dict = state_dict->get_dict_with_prefix(""); - } - model_->load_state_dict(sub_dict); + // The same model_type may come from checkpoints with different top-level + // weight prefixes. Try these candidate prefixes in order to improve + // compatibility across such variants. + model_->load_state_dict(state_dict->get_dict_with_prefix( + std::vector{"model.language_model.", + "language_model.model.", + prefix, + "model.", + ""})); if (tie_word_embeddings) { lm_head_->load_state_dict( - state_dict->get_dict_with_prefix(prefix + "embed_tokens.")); + state_dict->get_dict_with_prefix(std::vector{ + prefix + "embed_tokens.", "embed_tokens."})); } else { lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); } diff --git a/xllm/models/llm/mtp_model_base.h b/xllm/models/llm/mtp_model_base.h index 5f3c2fd93..298c314fa 100644 --- a/xllm/models/llm/mtp_model_base.h +++ b/xllm/models/llm/mtp_model_base.h @@ -196,6 +196,10 @@ class MtpModelImplBase : public torch::nn::Module { attn_metadata, kv_caches[i], modified_input_params); + if (!modified_input_params.record_layer(static_cast(i), + hidden_states.device())) { + return ModelOutput(); + } } auto [h_out, r_out] = norm_(hidden_states, residual); return ModelOutput(h_out, r_out); diff --git a/xllm/models/llm/musa/qwen3.h b/xllm/models/llm/musa/qwen3.h index f3b3a65f3..f77161163 100644 --- a/xllm/models/llm/musa/qwen3.h +++ b/xllm/models/llm/musa/qwen3.h @@ -15,10 +15,10 @@ limitations under the License. #pragma once -#include "core/common/rec_model_utils.h" #include "core/framework/model/model_output.h" #include "core/layers/common/rotary_embedding.h" #include "core/layers/musa/musa_qwen3_decoder_layer_impl.h" +#include "core/util/rec_model_utils.h" #include "models/llm/llm_model_base.h" namespace xllm { diff --git a/xllm/models/llm/npu/joyai_llm_flash.h b/xllm/models/llm/npu/joyai_llm_flash.h new file mode 100644 index 000000000..d7ea42ea8 --- /dev/null +++ b/xllm/models/llm/npu/joyai_llm_flash.h @@ -0,0 +1,319 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "deepseek_v2.h" +#include "layers/common/rotary_embedding_util.h" + +namespace xllm::npu::model { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +class JoyAILLMFlashModelImpl : public torch::nn::Module { + public: + JoyAILLMFlashModelImpl(const ModelContext& context) + : device_(context.get_tensor_options().device()) { + auto options = context.get_tensor_options(); + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + // register submodules + device_ = options.device(); + dtype_ = options.dtype().toScalarType(); + num_speculative_tokens_ = model_args.num_speculative_tokens(); + + npu_embed_tokens_ = + register_module("npu_embed_tokens", layer::NpuWordEmbedding(context)); + atb_pos_emb_ = layer::NpuPosEmbedding(context); + cos_sin_ = layer::rotary::get_concat_rotary_embedding( + model_args.qk_rope_head_dim(), + model_args.max_position_embeddings(), + model_args.rope_theta(), + options); + + max_seq_len_ = model_args.max_position_embeddings(); + int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + + for (int32_t i = 0; i < model_args.n_layers(); ++i) { + auto block = DeepseekV2DecoderLayer(context, i); + layers_.push_back(block); + blocks_->push_back(block); + } + + norm_ = register_module("norm", layer::NpuRMSNorm(context)); + + dp_size_ = parallel_args.dp_size(); + dp_local_tp_size_ = parallel_args.world_size() / dp_size_; + dp_rank_ = parallel_args.rank() / dp_local_tp_size_; + rank_ = parallel_args.rank(); + num_experts_per_tok_ = model_args.num_experts_per_tok(); + } + + ModelOutput forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (dp_size_ > 1) { + if (tokens.sizes() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(device_); + positions = torch::tensor({0}).to(torch::kInt32).to(device_); + } + } + + auto h = npu_embed_tokens_(tokens, 0); + auto cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = cos_sin_chunks[0].contiguous(); + auto sin_pos = cos_sin_chunks[1].contiguous(); + + torch::Tensor attn_mask; + if (FLAGS_enable_prefix_cache && + !input_params.batch_forward_type.is_decode()) { + attn_mask = attn_mask_.get_attn_mask(512, dtype_, device_); + } else if (input_params.batch_forward_type.is_prefill()) { + attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); + } else if (num_speculative_tokens_ > 0) { + // TODO :the judgement of gen_free_mask need more check + attn_mask = attn_mask_.gen_free_mask( + num_speculative_tokens_ + 1, dtype_, device_); + } + + RollingLayerGuard rolling_guard(rolling_mgr_); + for (size_t i = 0; i < layers_.size(); i++) { + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (!input_params.synchronize_layer(i)) { + return ModelOutput(); + } + + auto& layer = layers_[i]; + const int32_t layer_index = i; + rolling_guard.before_layer(layer_index); + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + input_params, + event, + event_flag); + rolling_guard.after_layer(layer_index); + } + auto hidden_states = norm_(h, 0); + return ModelOutput(hidden_states); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + npu_embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + // call each layer's load_state_dict function + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + npu_embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + + "."); + } + norm_->verify_loaded_weights(prefix + "norm."); + } + + void merge_loaded_weights() { + npu_embed_tokens_->merge_loaded_weights(); + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->merge_loaded_weights(); + } + norm_->merge_loaded_weights(); + } + + void merge_and_move_pinned_host() { + npu_embed_tokens_->merge_and_move_pinned_host(); + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->merge_and_move_pinned_host(); + } + norm_->merge_and_move_pinned_host(); + } + + void free_weights() { + npu_embed_tokens_->free_weights(); + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->free_weights(); + } + norm_->free_weights(); + } + + void reload_weights() { + npu_embed_tokens_->reload_weights(); + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->reload_weights(); + } + norm_->reload_weights(); + } + + void reload_non_decoder_weights() { + npu_embed_tokens_->reload_weights(); + norm_->reload_weights(); + } + void reload_weights_from_device() { + npu_embed_tokens_->reload_weights_from_device(); + for (size_t i = 0; i < layers_.size(); i++) { + layers_[i]->reload_weights_from_device(); + } + norm_->reload_weights_from_device(); + } + + void refresh_rolling_weights() { + for (auto& layer : layers_) { + layer->refresh_rolling_weights(); + } + } + + std::vector get_decoder_loaders() { + std::vector loaders; + loaders.reserve(layers_.size()); + for (auto& layer : layers_) { + loaders.push_back(layer->get_manual_loader()); + } + return loaders; + } + + void set_rolling_load_manager(RollingLoadManager* mgr) { rolling_mgr_ = mgr; } + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + layers_[layer_id]->prepare_expert_weight(expert_ids); + } + + void update_expert_weight(int32_t layer_id) { + layers_[layer_id]->update_expert_weight(); + } + + layer::NpuWordEmbedding get_npu_word_embedding() { return npu_embed_tokens_; } + + void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) { + npu_embed_tokens_ = npu_word_embedding; + } + + private: + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + int32_t max_seq_len_ = 0; + int32_t dp_rank_; + int32_t rank_; + int32_t dp_size_; + int32_t dp_local_tp_size_; + int32_t num_experts_per_tok_; + int32_t num_speculative_tokens_ = 0; + at::Device device_; + torch::Dtype dtype_; + layer::NpuWordEmbedding npu_embed_tokens_{nullptr}; + torch::Tensor cos_sin_; + layer::NpuPosEmbedding atb_pos_emb_{nullptr}; + layer::AttentionMask attn_mask_; + layer::NpuRMSNorm norm_{nullptr}; + RollingLoadManager* rolling_mgr_ = nullptr; +}; +TORCH_MODULE(JoyAILLMFlashModel); + +class JoyAILLMFlashForCausalLMImpl + : public LlmForCausalLMImplBase { + public: + JoyAILLMFlashForCausalLMImpl(const ModelContext& context) + : LlmForCausalLMImplBase(context), + first_k_dense_replace_( + context.get_model_args().first_k_dense_replace()) {} + + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) override { + model_->prepare_expert_weight(layer_id + first_k_dense_replace_, + expert_ids); + } + + void update_expert_weight(int32_t layer_id) override { + model_->update_expert_weight(layer_id + first_k_dense_replace_); + } + + private: + int32_t first_k_dense_replace_; +}; +TORCH_MODULE(JoyAILLMFlashForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(joyai_llm_flash, JoyAILLMFlashForCausalLM); +// register the model args +// example config: +// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json +REGISTER_MODEL_ARGS(joyai_llm_flash, [&] { + LOAD_ARG_OR(model_type, "model_type", "joyai_llm_flash"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(vocab_size, "vocab_size", 129280); + LOAD_ARG_OR(hidden_size, "hidden_size", 2048); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 40); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 32); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 7168); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 131072); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 1); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 0); + LOAD_ARG_OR(rope_theta, "rope_theta", 32000000.0f); + + LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 1); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); + LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc"); + LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256); + LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 1); + LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8); + LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 768); + LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 2.5f); + LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true); + LOAD_ARG_OR(n_group, "n_group", 1); + LOAD_ARG_OR(topk_group, "topk_group", 1); + LOAD_ARG_OR(scoring_func, "scoring_func", "sigmoid"); + LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128); + LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64); + LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); + LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536); + LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); + LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1); + + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); + }); + LOAD_ARG_OR_FUNC( + rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); }); + + // uses default rope_type, no deepseek_yarn scaling + SET_ARG(rope_scaling_rope_type, "default"); + SET_ARG(stop_token_ids, std::unordered_set({1})); +}); +} // namespace xllm::npu::model diff --git a/xllm/models/llm/npu/mtp_model_base.h b/xllm/models/llm/npu/mtp_model_base.h index 6a4581e11..33e474566 100644 --- a/xllm/models/llm/npu/mtp_model_base.h +++ b/xllm/models/llm/npu/mtp_model_base.h @@ -46,7 +46,8 @@ class MtpModelImplBase : public torch::nn::Module { public: // mode type: qwen2, qwen3 .etc MtpModelImplBase(const std::string& model_type, const ModelContext& context) - : model_type_(model_type) { + : model_type_(model_type), + device_(context.get_tensor_options().device()) { InterruptionBus::get_instance().subscribe([this](bool interrupted) { this->layer_forward_interrupted_ = interrupted; }); @@ -90,9 +91,11 @@ class MtpModelImplBase : public torch::nn::Module { torch::Tensor positions, std::vector& kv_caches, const ModelInputParams& input_params) { - if (dp_size_ > 1 && tokens.numel() == 0) { - tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); - positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + if (dp_size_ > 1 && (!tokens.defined() || tokens.numel() == 0)) { + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(device_); + tokens = torch::tensor({1}, options); + positions = torch::tensor({0}, options); } torch::Tensor h = embed_tokens_(tokens, 0); @@ -134,6 +137,10 @@ class MtpModelImplBase : public torch::nn::Module { req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); + } else { + // handle dp empty case + attn_mask = + attn_mask_.get_attn_mask(128, h.dtype().toScalarType(), h.device()); } } else if (model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache && !input_params.batch_forward_type.is_decode()) { @@ -259,6 +266,7 @@ class MtpModelImplBase : public torch::nn::Module { private: std::string model_type_; + torch::Device device_; }; template diff --git a/xllm/models/llm/oxygen.h b/xllm/models/llm/oxygen.h new file mode 100644 index 000000000..6a9024a8a --- /dev/null +++ b/xllm/models/llm/oxygen.h @@ -0,0 +1,138 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once +#include "core/framework/model/model_output.h" +#include "core/layers/common/rotary_embedding_util.h" +#include "llm_model_base.h" +#include "qwen3.h" + +namespace xllm { + +class OxygenModelImpl : public QWen3ModelImpl { + public: + OxygenModelImpl(const ModelContext& context) : QWen3ModelImpl(context) {} + + virtual ModelOutput forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + bool use_deepstack = input_params.deep_stacks.size() > 0; + ModelInputParams& input_params_new = + const_cast(input_params); + std::vector deep_stacks; + + if (tokens.numel() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + positions = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + } + auto inputs_embeds = input_params.input_embedding; + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { + h = embed_tokens_(tokens); + } + if (use_deepstack) { + deep_stacks = input_params.deep_stacks; // [num_deepstack, hidden_size] + } + + auto& dp_token_nums = input_params_new.dp_global_token_nums; + std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); + if (!input_params_new.attn_metadata) { + input_params_new.attn_metadata = + std::make_shared( + get_attention_metadata(input_params_new, h)); + } + + auto& attn_metadata = *(input_params_new.attn_metadata); + bool only_prefill = + (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill); + if (positions.dim() == 2 && only_prefill && !mrope_section_.empty()) { + std::tie(attn_metadata.mrope_cos, attn_metadata.mrope_sin) = + apply_mrope(positions); + } + + std::optional residual; + for (size_t i = 0; i < layers_.size(); i++) { + if (is_rec_multi_round_mode() && input_params_new.has_llmrec_params()) { + const auto& llmrec_params = input_params_new.llmrec_params(); + attn_metadata.full_k_cache = llmrec_params->full_k_caches[i]; + attn_metadata.full_v_cache = llmrec_params->full_v_caches[i]; + attn_metadata.unshared_k_cache = llmrec_params->unshared_k_caches[i]; + attn_metadata.unshared_v_cache = llmrec_params->unshared_v_caches[i]; + } + auto& layer = layers_[i]; + h = layer(h, + residual, + positions, + attn_metadata, + kv_caches[i], + input_params_new); + if (!input_params_new.record_layer(static_cast(i), + h.device())) { + return ModelOutput(); + } + + if (use_deepstack) { + if (deep_stacks.size() > 0 && i < deep_stacks.size()) { + h = deepstack_process( + h, input_params.visual_pos_masks, deep_stacks[i]); + } + } + } + auto [hidden_states, residual_out] = norm_(h, residual); + return ModelOutput(hidden_states, residual_out); + } + + protected: + std::pair apply_mrope( + const torch::Tensor positions) override { + auto target_cos_sin = cos_sin_.index({positions}); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + auto sin_pos = target_cos_sin_chunks[1].contiguous(); + auto apply = [this](torch::Tensor x) { + auto sections = mrope_section_; + sections.insert(sections.end(), sections.begin(), sections.end()); + + auto vec = x.split(sections, -1); + std::vector selects; + selects.reserve(vec.size()); + + for (int64_t i = 0; i < vec.size(); ++i) { + auto m = vec[i]; + selects.push_back(m[i % mrope_section_.size()]); + } + return torch::cat(selects, -1); + }; + cos_pos = apply(cos_pos.reshape({positions.size(0), -1, cos_pos.size(-1)})); + sin_pos = apply(sin_pos.reshape({positions.size(0), -1, sin_pos.size(-1)})); + return std::make_pair(cos_pos, sin_pos); + } +}; +TORCH_MODULE(OxygenModel); + +class OxygenForCausalLMImpl : public LlmForCausalLMImplBase { + public: + OxygenForCausalLMImpl(const ModelContext& context) + : LlmForCausalLMImplBase(context) {} +}; +TORCH_MODULE(OxygenForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(oxygenvlm_text, OxygenForCausalLM); + +} // namespace xllm diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h index 09267b9f0..a2cd8acc1 100644 --- a/xllm/models/llm/qwen2.h +++ b/xllm/models/llm/qwen2.h @@ -43,7 +43,7 @@ class QWen2ModelImpl : public LlmModelImplBase { register_module("embed_tokens", layer::WordEmbedding(context)); for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto layer = layer::Qwen2DecoderLayer(context); + auto layer = layer::Qwen2DecoderLayer(context, i); layers_.push_back(layer); } } diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h index d716297c7..5d934f615 100644 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -17,8 +17,8 @@ limitations under the License. #include -#include "core/common/rec_model_utils.h" #include "core/framework/model/model_output.h" +#include "core/util/rec_model_utils.h" #if defined(USE_NPU) #include "core/common/global_flags.h" #include "core/layers/common/attention_mask.h" @@ -55,7 +55,7 @@ class QWen3ModelImpl : public LlmModelImplBase { options.device(), options.dtype().toScalarType(), mask_value); #endif for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto layer = layer::Qwen3DecoderLayer(context); + auto layer = layer::Qwen3DecoderLayer(context, i); layers_.push_back(layer); } } @@ -172,6 +172,10 @@ class QWen3ModelImpl : public LlmModelImplBase { attn_metadata, kv_caches[i], input_params_new); + if (!input_params_new.record_layer(static_cast(i), + h.device())) { + return ModelOutput(); + } if (use_deepstack) { if (deep_stacks.size() > 0 && i < deep_stacks.size()) { @@ -184,7 +188,7 @@ class QWen3ModelImpl : public LlmModelImplBase { return ModelOutput(hidden_states, residual_out); } - private: + protected: layer::AttentionMetadata get_attention_metadata( const ModelInputParams& params, const torch::Tensor& h) { @@ -224,6 +228,7 @@ class QWen3ModelImpl : public LlmModelImplBase { #endif } + private: #if defined(USE_NPU) layer::AttentionMask attn_mask_; #endif diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index e6c5a2d90..a2729be8a 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -15,9 +15,9 @@ limitations under the License. #pragma once -#include "core/common/rec_model_utils.h" #include "core/framework/model/model_output.h" #include "core/layers/qwen3_moe_decoder_layer.h" +#include "core/util/rec_model_utils.h" #include "llm_model_base.h" namespace xllm { @@ -164,6 +164,10 @@ class Qwen3MoeModelImpl : public LlmModelImplBase { attn_metadata, kv_caches[i], modified_input_params); + if (!modified_input_params.record_layer(static_cast(i), + h.device())) { + return ModelOutput(); + } if (deep_stack_size && i < deep_stack_size) { h = deepstack_process(h, input_params.visual_pos_masks, deep_stacks[i]); diff --git a/xllm/models/model_registry.cpp b/xllm/models/model_registry.cpp index 738faf9d4..24241d4f5 100644 --- a/xllm/models/model_registry.cpp +++ b/xllm/models/model_registry.cpp @@ -200,20 +200,6 @@ void ModelRegistry::register_causalvlm_factory(const std::string& name, } } -void ModelRegistry::register_vlm_embedding_factory( - const std::string& name, - EmbeddingVLMFactory factory) { - ModelRegistry* instance = get_instance(); - - if (instance->model_registry_[name].embedding_vlm_factory != nullptr) { - SAFE_LOG_WARNING("embedding vlm factory for " << name - << " already registered."); - } else { - instance->model_registry_[name].embedding_vlm_factory = factory; - instance->model_backend_[name] = "vlm"; - } -} - void ModelRegistry::register_mm_embedding_vlm_factory( const std::string& name, MMEmbeddingVLMFactory factory) { @@ -321,13 +307,6 @@ CausalVLMFactory ModelRegistry::get_causalvlm_factory(const std::string& name) { return instance->model_registry_[name].causal_vlm_factory; } -EmbeddingVLMFactory ModelRegistry::get_embeddingvlm_factory( - const std::string& name) { - ModelRegistry* instance = get_instance(); - - return instance->model_registry_[name].embedding_vlm_factory; -} - MMEmbeddingVLMFactory ModelRegistry::get_mm_embedding_vlm_factory( const std::string& name) { ModelRegistry* instance = get_instance(); @@ -373,6 +352,12 @@ TokenizerArgsLoader ModelRegistry::get_tokenizer_args_loader( return instance->model_registry_[name].tokenizer_args_loader; } +bool ModelRegistry::has_dit_model_factory(const std::string& name) { + ModelRegistry* instance = get_instance(); + return (instance->model_registry_.find(name) != + instance->model_registry_.end()); +} + std::string ModelRegistry::get_model_backend(const std::string& name) { ModelRegistry* instance = get_instance(); return instance->model_backend_[name]; @@ -441,28 +426,6 @@ std::unique_ptr create_vlm_model(const ModelContext& context) { return nullptr; } -std::unique_ptr create_vlm_embedding_model( - const ModelContext& context) { - std::string resolved_name; - std::string error_message; - if (!resolve_model_registration_name(context.get_model_args().model_type(), - &resolved_name, - &error_message)) { - LOG(ERROR) << error_message; - return nullptr; - } - - auto factory = ModelRegistry::get_embeddingvlm_factory(resolved_name); - if (factory) { - return factory(context); - } - - LOG(ERROR) << "Unsupported model type: " - << context.get_model_args().model_type(); - - return nullptr; -} - std::unique_ptr create_vlm_mm_embedding_model( const ModelContext& context) { std::string resolved_name; diff --git a/xllm/models/model_registry.h b/xllm/models/model_registry.h index 168f80b5a..77c674146 100644 --- a/xllm/models/model_registry.h +++ b/xllm/models/model_registry.h @@ -24,7 +24,6 @@ limitations under the License. #include "core/framework/model/causal_lm.h" #include "core/framework/model/causal_vlm.h" #include "core/framework/model/dit_model.h" -#include "core/framework/model/embedding_vlm.h" #include "core/framework/model/mm_embedding_vlm.h" #include "core/framework/model_context.h" #include "core/framework/tokenizer/tokenizer_args.h" @@ -44,9 +43,6 @@ using RecModelFactory = using CausalVLMFactory = std::function(const ModelContext& context)>; -using EmbeddingVLMFactory = - std::function(const ModelContext& context)>; - using MMEmbeddingVLMFactory = std::function(const ModelContext& context)>; @@ -73,7 +69,6 @@ struct ModelMeta { CausalLMFactory causal_lm_factory; RecModelFactory rec_model_factory; CausalVLMFactory causal_vlm_factory; - EmbeddingVLMFactory embedding_vlm_factory; MMEmbeddingVLMFactory mm_embedding_vlm_factory; DiTModelFactory dit_model_factory; InputProcessorFactory input_processor_factory; @@ -98,9 +93,6 @@ class ModelRegistry { static void register_causalvlm_factory(const std::string& name, CausalVLMFactory factory); - static void register_vlm_embedding_factory(const std::string& name, - EmbeddingVLMFactory factory); - static void register_mm_embedding_vlm_factory(const std::string& name, MMEmbeddingVLMFactory factory); @@ -127,8 +119,6 @@ class ModelRegistry { static CausalVLMFactory get_causalvlm_factory(const std::string& name); - static EmbeddingVLMFactory get_embeddingvlm_factory(const std::string& name); - static MMEmbeddingVLMFactory get_mm_embedding_vlm_factory( const std::string& name); @@ -146,6 +136,8 @@ class ModelRegistry { static ImageProcessorFactory get_image_processor_factory( const std::string& name); + static bool has_dit_model_factory(const std::string& name); + static std::string get_model_backend(const std::string& name); private: @@ -169,9 +161,6 @@ std::unique_ptr create_rec_model(const ModelContext& context); std::unique_ptr create_vlm_model(const ModelContext& context); -std::unique_ptr create_vlm_embedding_model( - const ModelContext& context); - std::unique_ptr create_vlm_mm_embedding_model( const ModelContext& context); @@ -223,22 +212,6 @@ std::unique_ptr create_dit_model(const DiTModelContext& context); #define REGISTER_CAUSAL_VLM_MODEL(ModelType, ModelClass) \ REGISTER_CAUSAL_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass) -#define REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME( \ - VarName, ModelType, ModelClass) \ - const bool VarName##_registered = []() { \ - ModelRegistry::register_vlm_embedding_factory( \ - #ModelType, [](const ModelContext& context) { \ - ModelClass model(context); \ - model->eval(); \ - return std::make_unique>( \ - std::move(model), context.get_tensor_options()); \ - }); \ - return true; \ - }() - -#define REGISTER_EMBEDDING_VLM_MODEL(ModelType, ModelClass) \ - REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass) - #define REGISTER_MM_EMBEDDING_VLM_MODEL_WITH_VARNAME( \ VarName, ModelType, ModelClass) \ const bool VarName##_registered = []() { \ diff --git a/xllm/models/models.h b/xllm/models/models.h index b07609a38..482dc6af3 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -16,6 +16,7 @@ limitations under the License. #pragma once #if defined(USE_NPU) +#include "dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h" // IWYU pragma: keep #include "dit/pipeline_flux.h" // IWYU pragma: keep #include "dit/pipeline_flux_control.h" // IWYU pragma: keep #include "dit/pipeline_flux_fill.h" // IWYU pragma: keep @@ -30,6 +31,7 @@ limitations under the License. #include "llm/npu/glm4_moe_mtp.h" // IWYU pragma: keep #include "llm/npu/glm5_moe.h" // IWYU pragma: keep #include "llm/npu/glm5_moe_mtp.h" // IWYU pragma: keep +#include "llm/npu/joyai_llm_flash.h" // IWYU pragma: keep #include "llm/npu/kimi_k2.h" // IWYU pragma: keep #include "llm/npu/llama.h" // IWYU pragma: keep #include "llm/npu/llama3.h" // IWYU pragma: keep @@ -42,7 +44,7 @@ limitations under the License. #include "llm/qwen3_5.h" // IWYU pragma: keep #include "llm/qwen3_5_mtp.h" // IWYU pragma: keep #include "llm/qwen3_next.h" // IWYU pragma: keep -#include "rec/onerec.h" // IWYU pragma: keep +#include "rec/npu/onerec.h" // IWYU pragma: keep #include "vlm/npu/glm4v.h" // IWYU pragma: keep #include "vlm/npu/glm4v_moe.h" // IWYU pragma: keep #include "vlm/npu/minicpmv.h" // IWYU pragma: keep @@ -50,27 +52,32 @@ limitations under the License. #include "vlm/npu/qwen2_5_vl.h" // IWYU pragma: keep #include "vlm/npu/qwen2_5_vl_mm_embedding.h" // IWYU pragma: keep #include "vlm/npu/qwen2_vl.h" // IWYU pragma: keep -#include "vlm/npu/qwen2_vl_embedding.h" // IWYU pragma: keep #include "vlm/npu/qwen3_vl.h" // IWYU pragma: keep #include "vlm/npu/qwen3_vl_mm_embedding.h" // IWYU pragma: keep #include "vlm/npu/qwen3_vl_moe.h" // IWYU pragma: keep + #elif defined(USE_MLU) -#include "llm/deepseek_mtp.h" // IWYU pragma: keep -#include "llm/deepseek_v2.h" // IWYU pragma: keep -#include "llm/deepseek_v3.h" // IWYU pragma: keep -#include "llm/deepseek_v32.h" // IWYU pragma: keep -#include "llm/glm5.h" // IWYU pragma: keep -#include "llm/glm5_mtp.h" // IWYU pragma: keep -#include "llm/joyai_llm_flash.h" // IWYU pragma: keep -#include "llm/mtp_model_base.h" // IWYU pragma: keep -#include "llm/qwen2.h" // IWYU pragma: keep -#include "llm/qwen3.h" // IWYU pragma: keep -#include "llm/qwen3_moe.h" // IWYU pragma: keep -#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep -#include "vlm/qwen2_vl.h" // IWYU pragma: keep -#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep -#include "vlm/qwen3_vl.h" // IWYU pragma: keep -#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep +#include "dit/pipeline_flux.h" // IWYU pragma: keep +#include "dit/pipeline_flux_control.h" // IWYU pragma: keep +#include "dit/pipeline_flux_fill.h" // IWYU pragma: keep +#include "llm/deepseek_mtp.h" // IWYU pragma: keep +#include "llm/deepseek_v2.h" // IWYU pragma: keep +#include "llm/deepseek_v3.h" // IWYU pragma: keep +#include "llm/deepseek_v32.h" // IWYU pragma: keep +#include "llm/glm5.h" // IWYU pragma: keep +#include "llm/glm5_mtp.h" // IWYU pragma: keep +#include "llm/joyai_llm_flash.h" // IWYU pragma: keep +#include "llm/joyai_llm_flash_mtp.h" // IWYU pragma: keep +#include "llm/mtp_model_base.h" // IWYU pragma: keep +#include "llm/oxygen.h" // IWYU pragma: keep +#include "llm/qwen2.h" // IWYU pragma: keep +#include "llm/qwen3.h" // IWYU pragma: keep +#include "llm/qwen3_moe.h" // IWYU pragma: keep +#include "vlm/oxygen_vlm.h" // IWYU pragma: keep +#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep +#include "vlm/qwen2_vl.h" // IWYU pragma: keep +#include "vlm/qwen3_vl.h" // IWYU pragma: keep +#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep #elif defined(USE_ILU) #include "llm/qwen2.h" // IWYU pragma: keep #include "llm/qwen3.h" // IWYU pragma: keep @@ -83,7 +90,6 @@ limitations under the License. #include "llm/qwen3_moe.h" // IWYU pragma: keep #include "vlm/qwen2_5_vl.h" // IWYU pragma: keep #include "vlm/qwen2_vl.h" // IWYU pragma: keep -#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep #include "vlm/qwen3_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep #elif defined(USE_MUSA) diff --git a/xllm/models/rec/npu/onerec.h b/xllm/models/rec/npu/onerec.h new file mode 100644 index 000000000..96496c5ea --- /dev/null +++ b/xllm/models/rec/npu/onerec.h @@ -0,0 +1,361 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model/model_output.h" +#include "core/framework/model_context.h" +#include "core/framework/model_loader.h" +#include "core/layers/common/lm_head.h" +#include "core/layers/common/word_embedding.h" +#include "models/model_registry.h" +#include "models/rec/npu/onerec_npu_impl.h" +#include "models/rec/rec_model_base.h" + +namespace xllm { + +class OneRecModelImpl final : public torch::nn::Module { + public: + explicit OneRecModelImpl(const ModelContext& context) { + hidden_size_ = context.get_model_args().hidden_size(); + options_ = context.get_tensor_options(); + shared_ = register_module("shared", layer::WordEmbedding(context)); + + encoder_ = register_module( + "encoder", OneRecStack(context, /*is_decode=*/false, shared_)); + decoder_ = register_module( + "decoder", OneRecStack(context, /*is_decode=*/true, shared_)); + } + + ModelOutput forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (!tokens.defined()) { + return ModelOutput(); + } + (void)positions; + (void)kv_caches; + + const auto* onerec_params = input_params.onerec_params(); + + if (onerec_params != nullptr) { + if (onerec_params->is_encoder_forward) { + std::vector encoder_kv_caches; + auto encoder_output = + encoder_(tokens, positions, encoder_kv_caches, input_params); + + torch::Tensor cached_encoder_output; + if (encoder_output.defined() && + onerec_params->encoder_max_seq_len > 0 && + !onerec_params->encoder_seq_lens.empty()) { + cached_encoder_output = + pad_encoder_output(encoder_output, input_params); + } else { + cached_encoder_output = encoder_output; + } + { + std::lock_guard lock(encoder_output_mutex_); + encoder_output_ = cached_encoder_output; + } + return ModelOutput(cached_encoder_output); + } + + torch::Tensor cached_encoder_output; + if (onerec_params->has_encoder_output) { + std::lock_guard lock(encoder_output_mutex_); + cached_encoder_output = encoder_output_; + } + + const torch::Tensor& decoder_context = + onerec_params->decoder_context_embedding; + + if (!decoder_context.defined() && !cached_encoder_output.defined()) { + LOG(ERROR) + << "OneRec decoder requires decoder_context_embedding or encoder " + "output."; + return ModelOutput(); + } + + auto decoder_output = decoder_( + tokens, positions, kv_caches, input_params, cached_encoder_output); + return ModelOutput(decoder_output); + } + + const bool is_encoder_forward = + (onerec_params != nullptr) && onerec_params->is_encoder_forward; + + auto hidden_states = + build_hidden_states(tokens, onerec_params, is_encoder_forward); + if (!hidden_states.defined()) { + return ModelOutput(); + } + + if (is_encoder_forward) { + return ModelOutput(hidden_states); + } + + auto cross_context = resolve_cross_context(onerec_params); + if (cross_context.defined()) { + auto enriched_hidden_states = + add_cross_context_bias(hidden_states, cross_context); + if (enriched_hidden_states.defined()) { + hidden_states = std::move(enriched_hidden_states); + } + } + + return ModelOutput(hidden_states); + } + + void load_state_dict(const StateDict& state_dict) { + auto shared_dict = state_dict.get_dict_with_prefix("shared."); + if (shared_dict.size() > 0) { + shared_->load_state_dict(shared_dict); + } + + auto encoder_dict = state_dict.get_dict_with_prefix("encoder."); + if (encoder_dict.size() > 0) { + encoder_->load_state_dict(encoder_dict); + } + auto decoder_dict = state_dict.get_dict_with_prefix("decoder."); + if (decoder_dict.size() > 0) { + decoder_->load_state_dict(decoder_dict); + } + } + + void verify_loaded_weights() const { + encoder_->verify_loaded_weights("encoder."); + decoder_->verify_loaded_weights("decoder."); + } + + void merge_loaded_weights() { + encoder_->merge_loaded_weights(); + decoder_->merge_loaded_weights(); + } + + layer::WordEmbedding get_word_embedding() { return shared_; } + + void set_word_embedding(layer::WordEmbedding& embedding) { + shared_ = embedding; + encoder_->set_word_embedding(shared_); + decoder_->set_word_embedding(shared_); + } + + private: + static bool is_token_id_tensor(const torch::Tensor& tokens) { + return tokens.scalar_type() == torch::kLong || + tokens.scalar_type() == torch::kInt; + } + + torch::Tensor build_hidden_states(const torch::Tensor& tokens, + const OneRecModelInputParams* onerec_params, + bool is_encoder_forward) { + if (tokens.numel() == 0) { + return torch::empty({0, hidden_size_}, options_); + } + + if (is_token_id_tensor(tokens)) { + return shared_(tokens); + } + + if (tokens.dim() == 2 && tokens.size(-1) == hidden_size_) { + if (onerec_params != nullptr) { + if (onerec_params->is_hybrid_mode || is_encoder_forward) { + return tokens; + } + if (onerec_params->decoder_context_embedding.defined()) { + return tokens; + } + } + return tokens; + } + + LOG(ERROR) << "Invalid OneRec token tensor shape for non-id path: " + << tokens.sizes(); + return torch::Tensor(); + } + + torch::Tensor resolve_cross_context( + const OneRecModelInputParams* onerec_params) const { + if (onerec_params == nullptr) { + return torch::Tensor(); + } + if (onerec_params->decoder_context_embedding.defined()) { + return onerec_params->decoder_context_embedding; + } + return torch::Tensor(); + } + + torch::Tensor add_cross_context_bias( + const torch::Tensor& hidden_states, + const torch::Tensor& cross_context) const { + if (!hidden_states.defined() || !cross_context.defined()) { + return hidden_states; + } + + if (hidden_states.dim() != 2 || hidden_states.size(-1) != hidden_size_) { + LOG(ERROR) << "Unexpected hidden_states shape in OneRec decoder: " + << hidden_states.sizes(); + return hidden_states; + } + + auto context = cross_context; + if (context.device() != hidden_states.device()) { + context = context.to(hidden_states.device()); + } + if (context.dtype() != hidden_states.dtype()) { + context = context.to(hidden_states.dtype()); + } + + if (context.dim() == 1 && context.size(0) == hidden_size_) { + context = context.unsqueeze(0); + } else if (context.dim() > 2 && context.size(-1) == hidden_size_) { + context = context.reshape({-1, hidden_size_}); + } + + if (context.dim() != 2 || context.size(-1) != hidden_size_) { + LOG(ERROR) << "Unexpected OneRec cross context shape: " + << context.sizes(); + return hidden_states; + } + + auto pooled_context = context.mean(/*dim=*/0, /*keepdim=*/true); + return hidden_states + pooled_context.expand( + {hidden_states.size(0), pooled_context.size(1)}); + } + + torch::TensorOptions options_; + int64_t hidden_size_ = 0; + layer::WordEmbedding shared_{nullptr}; + + OneRecStack encoder_{nullptr}; + OneRecStack decoder_{nullptr}; + torch::Tensor encoder_output_; + std::mutex encoder_output_mutex_; +}; +TORCH_MODULE(OneRecModel); + +class OneRecForConditionalGenerationImpl final + : public RecForCausalLMImplBase { + public: + explicit OneRecForConditionalGenerationImpl(const ModelContext& context) + : RecForCausalLMImplBase(context) {} + + void load_model(std::unique_ptr loader, + std::string prefix = "model.") override { + for (const auto& state_dict : loader->get_state_dicts()) { + StateDict model_state_dict = state_dict->get_dict_with_prefix(prefix); + if (model_state_dict.size() == 0) { + model_state_dict = *state_dict; + } + model_->load_state_dict(model_state_dict); + + if (tie_word_embeddings_) { + auto shared_dict = model_state_dict.get_dict_with_prefix("shared."); + if (shared_dict.size() == 0) { + shared_dict = state_dict->get_dict_with_prefix("shared."); + } + if (shared_dict.size() != 0) { + lm_head_->load_state_dict(shared_dict); + } + } else { + auto lm_head_dict = model_state_dict.get_dict_with_prefix("lm_head."); + if (lm_head_dict.size() == 0) { + lm_head_dict = state_dict->get_dict_with_prefix("lm_head."); + } + if (lm_head_dict.size() != 0) { + lm_head_->load_state_dict(lm_head_dict); + } + } + } + + model_->verify_loaded_weights(); + model_->merge_loaded_weights(); + } +}; +TORCH_MODULE(OneRecForConditionalGeneration); + +using OneRecCausalLM = CausalLMImpl; +static_assert(std::is_base_of_v, + "OneRec must satisfy CausalLM contract."); + +REGISTER_REC_MODEL(onerec, OneRecForConditionalGeneration); + +REGISTER_MODEL_ARGS(onerec, [&] { + LOAD_ARG_OR(model_type, "model_type", "onerec"); + LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16"); + + LOAD_ARG_OR(hidden_size, "d_model", 128); + LOAD_ARG_OR(intermediate_size, "d_ff", 256); + + LOAD_ARG_OR(n_layers, "num_decoder_layers", 4); + LOAD_ARG_OR(n_encoder_layers, "num_layers", 12); + + LOAD_ARG_OR(n_heads, "num_heads", 4); + LOAD_ARG_OR(head_dim, "d_kv", 32); + LOAD_ARG_OR_FUNC( + decoder_n_heads, "decoder_num_heads", [&] { return args->n_heads(); }); + LOAD_ARG_OR_FUNC( + decoder_head_dim, "decoder_d_kv", [&] { return args->head_dim(); }); + + LOAD_ARG(n_kv_heads, "num_key_value_heads"); + LOAD_ARG(decoder_n_kv_heads, "decoder_num_key_value_heads"); + + LOAD_ARG_OR(vocab_size, "vocab_size", 8200); + LOAD_ARG_OR(rms_norm_eps, "layer_norm_epsilon", 1e-6); + LOAD_ARG_OR(max_position_embeddings, "max_length", 500); + LOAD_ARG_OR(use_absolute_position_embedding, + "use_absolute_position_embedding", + false); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", true); + + LOAD_ARG_OR(use_moe, "use_moe", false); + LOAD_ARG_OR(moe_score_func, "moe_score_func", "softmax"); + LOAD_ARG_OR(moe_route_scale, "moe_route_scale", 1.0f); + LOAD_ARG_OR(n_routed_experts, "moe_num_experts", 8); + LOAD_ARG_OR(moe_use_shared_experts, "moe_use_shared_experts", false); + LOAD_ARG_OR(n_shared_experts, "moe_num_shared_experts", 0); + LOAD_ARG_OR(num_experts_per_tok, "moe_topk", 2); + LOAD_ARG_OR(moe_intermediate_size, "moe_inter_dim", 1024); + + LOAD_ARG_OR( + relative_attention_num_buckets, "relative_attention_num_buckets", 32); + LOAD_ARG_OR( + relative_attention_max_distance, "relative_attention_max_distance", 128); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 0); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 128001); + SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); +}); + +REGISTER_TOKENIZER_ARGS(onerec, [&] { + SET_ARG(tokenizer_type, "rec"); + LOAD_ARG_OR(vocab_file, "vocab_file", ""); +}); + +} // namespace xllm diff --git a/xllm/models/rec/npu/onerec_npu_impl.h b/xllm/models/rec/npu/onerec_npu_impl.h new file mode 100644 index 000000000..a04233d6e --- /dev/null +++ b/xllm/models/rec/npu/onerec_npu_impl.h @@ -0,0 +1,478 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "core/common/global_flags.h" +#include "core/layers/common/rms_norm.h" +#include "core/layers/npu/npu_onerec_block_layer_impl.h" + +namespace xllm { + +inline torch::Tensor pad_encoder_output(const torch::Tensor& encoder_output, + const ModelInputParams& input_params) { + const auto* onerec_params = input_params.onerec_params(); + CHECK(onerec_params != nullptr) << "OneRec requires onerec_params()."; + + const int64_t bs = onerec_params->bs; + const int64_t hidden_size = encoder_output.size(1); + const auto& seq_lens = onerec_params->encoder_seq_lens; + const int64_t max_seq_len = onerec_params->encoder_max_seq_len; + + CHECK_EQ(static_cast(seq_lens.size()), bs) + << "encoder_seq_lens size mismatch."; + + std::vector seq_list; + seq_list.reserve(static_cast(bs)); + + int64_t token_offset = 0; + for (int64_t i = 0; i < bs; ++i) { + const int64_t seq_len = seq_lens[i]; + seq_list.emplace_back(encoder_output.narrow(0, token_offset, seq_len)); + token_offset += seq_len; + } + + auto padded_output = torch::nn::utils::rnn::pad_sequence( + seq_list, /*batch_first=*/true, /*padding_value=*/0.0); + + if (padded_output.size(1) < max_seq_len) { + auto extra_padding = + torch::zeros({bs, max_seq_len - padded_output.size(1), hidden_size}, + encoder_output.options()); + padded_output = torch::cat({padded_output, extra_padding}, /*dim=*/1); + } + + return padded_output; +} + +inline torch::Tensor compute_onerec_position_bias( + int64_t query_length, + int64_t key_length, + int64_t num_heads, + bool is_decoder, + layer::WordEmbedding& position_bias_embedding, + int64_t num_buckets = 32, + int64_t max_distance = 128, + const torch::TensorOptions& options = torch::kFloat32, + bool is_decode_stage = false, + const ModelInputParams* input_params = nullptr) { + auto device = options.device(); + auto dtype = options.dtype(); + + int64_t actual_query_length = is_decode_stage ? key_length : query_length; + if (actual_query_length <= 0) { + actual_query_length = 1; + } + if (key_length <= 0) { + key_length = 1; + } + + auto context_position = + torch::arange(actual_query_length, + torch::dtype(torch::kLong).device(device)) + .unsqueeze(1); + auto memory_position = + torch::arange(key_length, torch::dtype(torch::kLong).device(device)) + .unsqueeze(0); + auto relative_position = memory_position - context_position; + + auto relative_buckets = torch::zeros_like(relative_position); + + if (!is_decoder) { + num_buckets = num_buckets / 2; + relative_buckets += (relative_position > 0).to(torch::kLong) * num_buckets; + relative_position = torch::abs(relative_position); + } else { + relative_position = + -torch::min(relative_position, torch::zeros_like(relative_position)); + } + + const int64_t max_exact = num_buckets / 2; + auto is_small = relative_position < max_exact; + auto relative_position_if_large = + max_exact + (torch::log(relative_position.to(torch::kFloat) / max_exact) / + std::log(static_cast(max_distance) / max_exact) * + (num_buckets - max_exact)) + .to(torch::kLong); + + relative_position_if_large = + torch::min(relative_position_if_large, + torch::full_like(relative_position_if_large, num_buckets - 1)); + + relative_buckets += + torch::where(is_small, relative_position, relative_position_if_large); + + auto original_shape = relative_buckets.sizes(); + auto flattened_buckets = relative_buckets.flatten(); + auto values = position_bias_embedding(flattened_buckets); + + if (values.dim() == 2) { + CHECK_EQ(values.size(0), flattened_buckets.size(0)); + values = + values.view({original_shape[0], original_shape[1], values.size(1)}); + } else if (values.dim() == 1) { + values = + values.unsqueeze(-1).expand({flattened_buckets.size(0), num_heads}); + values = values.view({original_shape[0], original_shape[1], num_heads}); + } else { + LOG(FATAL) << "Unexpected OneRec position bias dim: " << values.dim(); + } + + if (values.dim() == 3) { + values = values.permute({2, 0, 1}); + } + + if (is_decode_stage && input_params != nullptr && + !input_params->kv_seq_lens_vec.empty()) { + const int32_t seq_kv_len = input_params->kv_seq_lens_vec[0]; + values = values.slice(1, -1, values.size(1)).slice(2, 0, seq_kv_len); + } else if (is_decode_stage) { + values = values.slice(1, -1, values.size(1)); + } + + return values.to(dtype); +} + +class OneRecStackImpl : public torch::nn::Module { + public: + OneRecStackImpl(const ModelContext& context, + bool is_decode, + layer::WordEmbedding& embed_tokens) { + const auto& args = context.get_model_args(); + const auto& options = context.get_tensor_options(); + + hidden_size_ = args.hidden_size(); + is_decoder_ = is_decode; + use_absolute_position_embedding_ = args.use_absolute_position_embedding(); + use_moe_ = args.use_moe() && is_decoder_; + num_experts_per_tok_ = args.num_experts_per_tok(); + relative_attention_num_buckets_ = args.relative_attention_num_buckets(); + relative_attention_max_distance_ = args.relative_attention_max_distance(); + num_heads_ = is_decode ? args.decoder_n_heads() : args.n_heads(); + + embed_tokens_ = embed_tokens; + if (!use_absolute_position_embedding_) { + position_bias_embedding_ = register_module("position_bias_embedding", + layer::WordEmbedding(context)); + } + + norm_ = register_module("final_layer_norm", layer::RMSNorm(context)); + + blocks_ = register_module("block", torch::nn::ModuleList()); + const uint32_t num_layers = + is_decode ? args.n_layers() : args.n_encoder_layers(); + layers_.reserve(num_layers); + for (uint32_t i = 0; i < num_layers; ++i) { + auto block = layer::NpuOneRecBlockLayer(context, is_decode, i); + layers_.emplace_back(block); + blocks_->push_back(block); + } + + (void)options; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params, + const torch::Tensor& encoder_output = torch::Tensor()) { + (void)positions; + + const auto* onerec_params = input_params.onerec_params(); + CHECK(onerec_params != nullptr) << "OneRec requires onerec_params()."; + + torch::Tensor h; + if (onerec_params->is_hybrid_mode && !is_decoder_) { + h = tokens; + } else if (onerec_params->decoder_context_embedding.defined()) { + if (tokens.numel() == 0) { + h = onerec_params->decoder_context_embedding.reshape( + {-1, onerec_params->decoder_context_embedding.size(-1)}); + } else { + h = embed_tokens_(tokens); + + auto context_emb = onerec_params->decoder_context_embedding.clone(); + const int64_t hidden_size = context_emb.size(3); + const int64_t bs = onerec_params->bs; + const int64_t group_width = onerec_params->group_width; + const int64_t context_total_tokens = context_emb.size(2); + const int64_t token_total_tokens = h.size(0); + + const int64_t bs_group = bs * group_width; + const int64_t seq_len1 = + token_total_tokens / std::max(1, bs_group); + const int64_t seq_len2 = context_total_tokens - seq_len1; + + auto token_embedding_reshaped = + h.view({bs, group_width, seq_len1, hidden_size}); + context_emb.narrow(2, seq_len2, seq_len1) + .copy_(token_embedding_reshaped); + h = context_emb.view({-1, hidden_size}); + } + if (!h.is_contiguous()) { + h = h.contiguous(); + } + if (h.device().type() == torch::DeviceType::PrivateUse1 && + at_npu::native::get_npu_format(h) != ACL_FORMAT_ND) { + h = at_npu::native::npu_format_cast(h, ACL_FORMAT_ND).contiguous(); + } + } else { + h = embed_tokens_(tokens); + } + + torch::Tensor npu_encoder_output = encoder_output; + if (npu_encoder_output.defined() && + npu_encoder_output.device().type() != h.device().type()) { + npu_encoder_output = npu_encoder_output.to(h.device()); + } + + const bool is_prefill = + onerec_params->rec_stage == OneRecModelInputParams::RecStage::PREFILL; + auto [query_length, key_length] = compute_sequence_lengths( + input_params.q_max_seq_len, is_prefill, input_params); + + ModelInputParams input_params_local = input_params; + auto& mutable_onerec_params = input_params_local.mutable_onerec_params(); + + const bool is_decode_stage = is_decoder_ && !is_prefill; + torch::Tensor effective_attn_mask; + if (use_absolute_position_embedding_) { + effective_attn_mask = + create_moe_attention_mask(query_length, h, is_decoder_); + } else { + effective_attn_mask = compute_position_bias_mask( + query_length, key_length, h, is_decode_stage, input_params); + } + + auto preprocessed_attn_mask = + preprocess_attention_mask(effective_attn_mask, h); + + if (mutable_onerec_params.encoder_seq_lens_tensor.defined()) { + auto flattened_tensor = + mutable_onerec_params.encoder_seq_lens_tensor.flatten(); + mutable_onerec_params.encoder_seq_lens_tensor = + flattened_tensor.to(h.device(), torch::kInt).contiguous(); + } + + torch::Tensor expert_array; + if (use_moe_) { + const int64_t input_length = h.size(0); + expert_array = torch::arange( + 0, + input_length * num_experts_per_tok_, + torch::TensorOptions().dtype(torch::kInt32).device(h.device())); + } + + for (size_t i = 0; i < layers_.size(); ++i) { + if (input_params.layer_synchronizer) { + input_params.layer_synchronizer->synchronize_layer(i); + } + + KVCache dummy_kv_cache; + if (is_decoder_) { + CHECK_LT(i, kv_caches.size()) + << "OneRec decoder layer kv_cache is missing at layer " << i; + } + KVCache& kv_cache_ref = is_decoder_ ? kv_caches[i] : dummy_kv_cache; + + layers_[i]->forward( + h, + preprocessed_attn_mask, + kv_cache_ref, + input_params_local, + npu_encoder_output.defined() ? &npu_encoder_output : nullptr, + static_cast(i), + nullptr, + nullptr, + expert_array); + } + + std::optional residual = std::nullopt; + h = std::get<0>(norm_->forward(h, residual)); + return h; + } + + void load_state_dict(const StateDict& state_dict) { + auto embed_dict = state_dict.get_dict_with_prefix("embed_tokens."); + if (embed_dict.size() > 0) { + embed_tokens_->load_state_dict(embed_dict); + } + + if (!use_absolute_position_embedding_ && position_bias_embedding_) { + auto pos_bias_dict = state_dict.get_dict_with_prefix( + "block.0.layer.0.SelfAttention.relative_attention_bias."); + if (pos_bias_dict.size() > 0) { + position_bias_embedding_->load_state_dict(pos_bias_dict); + } + } + + for (int i = 0; i < static_cast(layers_.size()); ++i) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("block." + std::to_string(i) + ".")); + } + + norm_->load_state_dict( + state_dict.get_dict_with_prefix("final_layer_norm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + for (int i = 0; i < static_cast(layers_.size()); ++i) { + layers_[i]->verify_loaded_weights(prefix + "block." + std::to_string(i) + + "."); + } + } + + void merge_loaded_weights() { + for (int i = 0; i < static_cast(layers_.size()); ++i) { + layers_[i]->merge_loaded_weights(); + } + } + + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; + } + + private: + std::pair compute_sequence_lengths( + int64_t seq_length, + bool is_prefill, + const ModelInputParams& input_params) const { + int64_t query_length = seq_length; + int64_t key_length = seq_length; + + const auto* onerec_params = input_params.onerec_params(); + CHECK(onerec_params != nullptr) << "OneRec requires onerec_params()."; + + if (is_decoder_) { + if (is_prefill) { + query_length = seq_length; + key_length = seq_length; + } else { + query_length = 1; + if (!input_params.kv_seq_lens_vec.empty()) { + key_length = *std::max_element(input_params.kv_seq_lens_vec.begin(), + input_params.kv_seq_lens_vec.end()); + } + // Decode keeps a square bias/mask shape expected by OneRec NPU block. + query_length = key_length; + } + } else { + query_length = onerec_params->encoder_max_seq_len; + key_length = onerec_params->encoder_max_seq_len; + } + + return {query_length, key_length}; + } + + torch::Tensor create_moe_attention_mask(int64_t seq_length, + const torch::Tensor& h, + bool is_decoder) const { + if (!is_decoder) { + return torch::ones({num_heads_, seq_length, seq_length}, h.options()); + } + + const float mask_value = -9984.0f; + auto upper_tri_mask = + torch::triu(torch::ones({seq_length, seq_length}, + torch::dtype(h.dtype()).device(h.device())), + 1); + auto expanded_mask = upper_tri_mask.unsqueeze(0).expand( + {num_heads_, seq_length, seq_length}); + auto effective_attn_mask = + torch::zeros({num_heads_, seq_length, seq_length}, + torch::dtype(h.dtype()).device(h.device())); + effective_attn_mask.masked_fill_(expanded_mask.to(torch::kBool), + mask_value); + return effective_attn_mask; + } + + torch::Tensor compute_position_bias_mask( + int64_t query_length, + int64_t key_length, + const torch::Tensor& h, + bool is_decode_stage, + const ModelInputParams& input_params) { + CHECK(!position_bias_embedding_.is_empty()) + << "position_bias_embedding is required for relative attention."; + + auto layer_position_bias = + compute_onerec_position_bias(query_length, + key_length, + num_heads_, + is_decoder_, + position_bias_embedding_, + relative_attention_num_buckets_, + relative_attention_max_distance_, + torch::dtype(h.dtype()).device(h.device()), + is_decode_stage, + &input_params); + + auto effective_attn_mask = layer_position_bias.is_contiguous() + ? layer_position_bias + : layer_position_bias.contiguous(); + + if (is_decoder_ && FLAGS_enable_rec_prefill_only) { + const float mask_value = -9984.0f; + auto upper_tri_mask = + torch::triu(torch::ones({query_length, query_length}, + effective_attn_mask.options()), + 1); + auto expanded_mask = upper_tri_mask.unsqueeze(0).expand( + {num_heads_, query_length, query_length}); + effective_attn_mask.masked_fill_(expanded_mask.to(torch::kBool), + mask_value); + } + + return effective_attn_mask; + } + + torch::Tensor preprocess_attention_mask( + const torch::Tensor& effective_attn_mask, + const torch::Tensor& h) const { + if (!effective_attn_mask.defined()) { + return torch::Tensor(); + } + + if (effective_attn_mask.device() != h.device()) { + return effective_attn_mask.to(h.device()).contiguous(); + } + return effective_attn_mask.is_contiguous() + ? effective_attn_mask + : effective_attn_mask.contiguous(); + } + + int64_t hidden_size_ = 0; + bool is_decoder_ = true; + bool use_absolute_position_embedding_ = false; + bool use_moe_ = false; + int64_t relative_attention_num_buckets_ = 32; + int64_t relative_attention_max_distance_ = 128; + int64_t num_heads_ = 4; + int32_t num_experts_per_tok_ = 2; + + layer::WordEmbedding embed_tokens_{nullptr}; + layer::WordEmbedding position_bias_embedding_{nullptr}; + layer::RMSNorm norm_{nullptr}; + + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; +}; +TORCH_MODULE(OneRecStack); + +} // namespace xllm diff --git a/xllm/models/rec/onerec.h b/xllm/models/rec/onerec.h index 6dc0bcb27..90ce579eb 100644 --- a/xllm/models/rec/onerec.h +++ b/xllm/models/rec/onerec.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include +#include #include #include #include @@ -28,6 +30,7 @@ limitations under the License. #include "core/framework/model/model_output.h" #include "core/framework/model_context.h" #include "core/framework/model_loader.h" +#include "core/layers/common/lm_head.h" #include "core/layers/common/word_embedding.h" #include "models/model_registry.h" #include "models/rec/rec_model_base.h" @@ -46,14 +49,14 @@ class OneRecModelImpl : public torch::nn::Module { const torch::Tensor& positions, std::vector& kv_caches, const ModelInputParams& input_params) { - (void)positions; - (void)kv_caches; - if (!tokens.defined()) { return ModelOutput(); } + (void)positions; + (void)kv_caches; const auto* onerec_params = input_params.onerec_params(); + const bool is_encoder_forward = (onerec_params != nullptr) && onerec_params->is_encoder_forward; @@ -67,7 +70,7 @@ class OneRecModelImpl : public torch::nn::Module { return ModelOutput(hidden_states); } - auto cross_context = get_cross_context_embedding(onerec_params); + auto cross_context = resolve_cross_context(onerec_params); if (cross_context.defined()) { auto enriched_hidden_states = add_cross_context_bias(hidden_states, cross_context); @@ -126,7 +129,7 @@ class OneRecModelImpl : public torch::nn::Module { return torch::Tensor(); } - torch::Tensor get_cross_context_embedding( + torch::Tensor resolve_cross_context( const OneRecModelInputParams* onerec_params) const { if (onerec_params == nullptr) { return torch::Tensor(); diff --git a/xllm/models/vlm/npu/glm4v.h b/xllm/models/vlm/npu/glm4v.h index 7de339fc1..23ffd11b8 100644 --- a/xllm/models/vlm/npu/glm4v.h +++ b/xllm/models/vlm/npu/glm4v.h @@ -30,206 +30,12 @@ limitations under the License. #include "models/llm/npu/glm4.h" #include "models/model_registry.h" #include "processors/glm4v_image_processor.h" -#include "processors/input_processor.h" +#include "processors/glm4v_input_processor.h" #include "torch_npu/csrc/aten/CustomFunctions.h" #include "xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" namespace xllm::npu::model { -class GLM4VInputProcessor : public InputProcessor { - enum class TokenType { - INVALID, - IMAGE, - VIDEO, - }; - - public: - GLM4VInputProcessor(const ModelArgs& args) { - merge_size_ = args.mm_image_merge_size(); - image_start_token_id_ = args.image_start_token_id(); - image_end_token_id_ = args.image_end_token_id(); - video_start_token_id_ = args.video_start_token_id(); - video_end_token_id_ = args.video_end_token_id(); - image_token_id_ = args.image_token_id(); - } - - void process(std::string& prompt, const MMData& mm_data) override { - torch::Tensor image_grid_thw; - if (auto res = mm_data.get("image_grid_thw")) - image_grid_thw = res.value(); - - torch::Tensor video_grid_thw; - if (auto res = mm_data.get("video_grid_thw")) - video_grid_thw = res.value(); - - if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; - - std::vector video_metadata; - mm_data.get_metadata(MMType::VIDEO, video_metadata); - - if (video_metadata.size() > 0) { - CHECK(video_metadata.size() == - static_cast(video_grid_thw.sizes()[0])); - } - - auto merge_length = merge_size_ * merge_size_; - int total_image_token = 0; - - if (image_grid_thw.defined()) { - auto count = image_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_image_token += - image_grid_thw[idx].prod().item() / merge_length; - } - - int total_video_token = 0; - if (video_grid_thw.defined()) { - auto count = video_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_video_token += video_grid_thw[idx].prod().item() / - merge_length / video_grid_thw[idx][0].item(); - } - - size_t total_token_len = total_image_token * image_token_.size() + - total_video_token * image_token_.size(); - std::string data; - data.reserve(prompt.size() + total_token_len); - - int image_index = 0; - int video_index = 0; - - size_t begin = 0; - auto pair = find_vision_token(prompt, begin); - - while (pair.second != std::string::npos) { - data.append(prompt, begin, pair.second - begin); - - if (pair.first == TokenType::IMAGE) { - auto token_num = - image_grid_thw[image_index].prod().item() / merge_length; - while (token_num--) data.append(image_token_); - - image_index++; - begin = pair.second + image_token_.size(); - } else if (pair.first == TokenType::VIDEO) { - auto num_frames = video_grid_thw[video_index][0].item(); - auto timestamps = video_metadata[video_index].timestamps; - CHECK(!timestamps.empty()); - - auto selected = build_timestamps(timestamps, num_frames); - auto token_num = video_grid_thw[video_index].prod().item() / - merge_length / num_frames; - - for (size_t idx = 0; idx < num_frames; ++idx) { - data.append(begin_of_image_token_); - - auto num = token_num; - while (num--) data.append(image_token_); - - data.append(end_of_image_token_); - data.append(format_timestamp_str(selected[idx])); - } - - video_index++; - begin = pair.second + video_token_.size(); - } else { - assert(false); - } - - pair = find_vision_token(prompt, begin); - } - - if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); - - prompt = std::move(data); - } - - void find_mm_spans(const std::vector& prompt, MMData& mm_data) override { - size_t tokens_num = prompt.size(); - uint32_t global_mm_index = 0; - uint32_t offset = 0; - uint32_t length = 0; - bool is_video = false; - auto& mm_items = mm_data.items(); - // TODO:support video info. - for (size_t idx = 0; idx < tokens_num; ++idx) { - auto token = prompt[idx]; - if (token == video_start_token_id_) { - is_video = true; - } else if (token == video_end_token_id_) { - is_video = false; - } - if (is_video) continue; - if (token == image_start_token_id_) { - offset = idx + 1; - } - if (token == image_token_id_) { - length++; - } else if (token == image_end_token_id_) { - auto& item = mm_items[global_mm_index++]; - item.mutable_state().mutable_token_pos() = {offset, length}; - length = 0; - } - } - } - - private: - std::pair find_vision_token(const std::string& prompt, - size_t begin) { - auto img_pos = prompt.find(image_token_, begin); - auto vid_pos = prompt.find(video_token_, begin); - - if (img_pos == std::string::npos && vid_pos == std::string::npos) - return {TokenType::INVALID, std::string::npos}; - else if (vid_pos == std::string::npos) - return {TokenType::IMAGE, img_pos}; - else if (img_pos == std::string::npos) - return {TokenType::VIDEO, vid_pos}; - else - return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) - : std::make_pair(TokenType::VIDEO, vid_pos); - } - - std::vector build_timestamps(const std::vector& timestamps, - size_t num_frames) { - std::vector vec; - vec.reserve(num_frames); - - for (size_t i = 0; i < timestamps.size(); i += 2) { - vec.push_back(timestamps[i]); - if (vec.size() == num_frames) break; - } - - while (vec.size() < num_frames) { - vec.push_back(vec.back()); - } - - return vec; - } - - std::string format_timestamp_str(double timestamp) { - char buffer[32]; - sprintf(buffer, "%.1f seconds", timestamp); - return buffer; - } - - private: - const std::string image_token_ = "<|image|>"; - const std::string video_token_ = "<|video|>"; - - const std::string begin_of_image_token_ = "<|begin_of_image|>"; - const std::string end_of_image_token_ = "<|end_of_image|>"; - - int32_t image_start_token_id_; - int32_t image_end_token_id_; - int32_t video_start_token_id_; - int32_t video_end_token_id_; - int32_t image_token_id_; - - int32_t merge_size_ = 0; -}; - class Glm4VisionRmsNormImpl : public torch::nn::Module { public: torch::Tensor weight; diff --git a/xllm/models/vlm/npu/glm4v_moe.h b/xllm/models/vlm/npu/glm4v_moe.h index f50cb7217..3279681b3 100644 --- a/xllm/models/vlm/npu/glm4v_moe.h +++ b/xllm/models/vlm/npu/glm4v_moe.h @@ -33,8 +33,7 @@ limitations under the License. #include "models/llm/npu/glm4_moe.h" #include "models/model_registry.h" #include "processors/glm4v_image_processor.h" -#include "processors/input_processor.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" +#include "processors/glm4v_input_processor.h" namespace xllm::npu::model { diff --git a/xllm/models/vlm/npu/minicpmv.h b/xllm/models/vlm/npu/minicpmv.h old mode 100755 new mode 100644 index 646af3123..75a7f7443 --- a/xllm/models/vlm/npu/minicpmv.h +++ b/xllm/models/vlm/npu/minicpmv.h @@ -30,188 +30,12 @@ limitations under the License. #include "core/layers/npu/npu_siglip_encoder_layer_impl.h" #include "models/llm/npu/qwen2.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/minicpmv_image_processor.h" +#include "processors/minicpmv_input_processor.h" #include "processors/pywarpper_image_processor.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" namespace xllm::npu::model { -class MiniCPMInputProcessor : public InputProcessor { - public: - MiniCPMInputProcessor(const ModelArgs& args) { - image_feature_size_ = args.mm_image_feature_size(); - max_slice_nums_ = args.vision_max_slice_nums(); - slice_mode_ = args.mm_slice_mode(); - use_image_id_ = args.mm_use_image_id(); - scale_resolution_ = args.mm_scale_resolution(); - } - - void process(std::string& prompt, const MMData& mm_data) override { - std::vector image_sizes; - mm_data.get("image_sizes", image_sizes); - - const std::regex pattern(R"(\([\s\S]*?\))"); - - std::sregex_iterator image_tag_begin(prompt.begin(), prompt.end(), pattern); - std::sregex_iterator image_tag_end; - - if (image_tag_begin == image_tag_end) { - return; - } - - std::vector> image_size_list; - image_size_list.reserve(image_sizes.size()); - for (auto& image_size : image_sizes) { - if (image_size.dim() != 1 || image_size.size(0) != 2) { - const auto& sizes = image_size.sizes(); - LOG(FATAL) << "image_size must be a 1D tensor with 2 " - "elements representing height and width;" - "now sizes: " - << sizes; - } - image_size_list.emplace_back( - std::make_pair(image_size[0].item(), image_size[1].item())); - } - - std::vector text_chunks; - size_t last_pos = 0; - - for (auto it = image_tag_begin; it != image_tag_end; ++it) { - auto match = *it; - text_chunks.push_back( - prompt.substr(last_pos, match.position() - last_pos)); - last_pos = match.position() + match.length(); - } - - text_chunks.push_back(prompt.substr(last_pos)); - - std::string new_prompt; - for (size_t i = 0; i < image_size_list.size(); ++i) { - new_prompt += text_chunks[i]; - new_prompt += get_slice_image_placeholder(image_size_list[i], i); - } - - new_prompt += text_chunks.back(); - prompt = new_prompt; - } - void find_mm_spans(const std::vector& prompt, MMData& mm_data) override { - uint32_t global_mm_index = 0; - uint32_t offset = 0; - uint32_t length = 0; - auto& mm_items = mm_data.items(); - auto start = prompt.begin(); - while (true) { - auto image_start_it = std::find(start, prompt.end(), im_start_id_); - auto image_end_it = std::find(start, prompt.end(), im_end_id_); - if (image_start_it == prompt.end()) { - break; - } - offset = std::distance(prompt.begin(), image_start_it); - length = std::distance(image_start_it + 1, image_end_it); - auto& item = mm_items[global_mm_index++]; - item.mutable_state().mutable_token_pos() = {offset + 1, length}; - start = std::next(image_end_it); - } - } - - private: - std::string get_image_id_placeholder(int idx) const { - return im_id_start_ + std::to_string(idx) + im_id_end_; - } - - std::string get_grid_placeholder(const std::pair& grid) const { - if (grid.first == 0 || grid.second == 0) { - return ""; - } - - // Prepare the slice placeholder - std::string slice_placeholder = slice_start_token_; - - // Append the repeated unk_token_ - for (int i = 0; i < image_feature_size_; ++i) { - slice_placeholder += unk_token_; - } - - slice_placeholder += slice_end_token_; - - // Use a string to accumulate the result - std::string grid_placeholder; - - // Loop over the grid and append placeholders - for (int i = 0; i < grid.second; ++i) { // Iterate through rows - for (int j = 0; j < grid.first; ++j) { // Iterate through columns - grid_placeholder += slice_placeholder; // Append the placeholder - } - if (i < grid.second - 1) { - grid_placeholder += - "\n"; // Add a newline after each row except the last one - } - } - - return grid_placeholder; - } - - std::string get_slice_image_placeholder( - const std::pair& image_size, - int image_idx = 0, - int max_slice_nums = -1, - std::optional use_image_id_opt = std::nullopt) const { - if (max_slice_nums < 0) { - max_slice_nums = max_slice_nums_; - } - - bool use_image_id = - use_image_id_opt.has_value() ? use_image_id_opt.value() : use_image_id_; - - assert(max_slice_nums > 0); - - auto grid = MiniCPMVImageProcessor::get_sliced_grid( - image_size, max_slice_nums, scale_resolution_); - - std::string image_placeholder = im_start_token_; - - for (int i = 0; i < image_feature_size_; ++i) { - image_placeholder += unk_token_; - } - - image_placeholder += im_end_token_; - - std::string final_placeholder; - - if (use_image_id) { - final_placeholder = - get_image_id_placeholder(image_idx) + image_placeholder; - } else { - final_placeholder = image_placeholder; - } - - if (slice_mode_) { - final_placeholder += get_grid_placeholder(grid); - } - - return final_placeholder; - } - - private: - const std::string im_start_token_ = ""; - const std::string im_end_token_ = ""; - const std::string slice_start_token_ = ""; - const std::string slice_end_token_ = ""; - const std::string unk_token_ = ""; - const std::string im_id_start_ = ""; - const std::string im_id_end_ = ""; - - const int im_start_id_ = 151659; - const int im_end_id_ = 151658; - - bool slice_mode_; - bool use_image_id_; - int max_slice_nums_; - int image_feature_size_; - int scale_resolution_; -}; - class BaseResamplerImpl : public torch::nn::Module { public: BaseResamplerImpl(const ModelContext& context) @@ -816,7 +640,7 @@ class VisionAdapterMLPImpl : public torch::nn::Module { auto seq = torch::nn::Sequential(lni, cpl, act, rpl); layers_->push_back(seq); - mlps_.push_back(std::make_tuple(lni, cpl, act, rpl)); + mlps_.emplace_back(lni, cpl, act, rpl); } } @@ -1214,6 +1038,7 @@ class MiniCPMV2_6Impl : public torch::nn::Module { const ModelInputParams& input_params) { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); diff --git a/xllm/models/vlm/npu/oxygen_vlm.h b/xllm/models/vlm/npu/oxygen_vlm.h index a27bbe1a6..edb7ecae2 100644 --- a/xllm/models/vlm/npu/oxygen_vlm.h +++ b/xllm/models/vlm/npu/oxygen_vlm.h @@ -30,8 +30,8 @@ limitations under the License. #include "glm4v.h" #include "models/llm/npu/oxygen.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen2_vl_input_processor.h" #include "qwen2_5_vl.h" #include "torch_npu/csrc/aten/CustomFunctions.h" diff --git a/xllm/models/vlm/npu/qwen2_5_vl.h b/xllm/models/vlm/npu/qwen2_5_vl.h index b507e6854..1e2502c86 100644 --- a/xllm/models/vlm/npu/qwen2_5_vl.h +++ b/xllm/models/vlm/npu/qwen2_5_vl.h @@ -30,155 +30,13 @@ limitations under the License. #include "core/layers/npu/npu_rms_norm_impl.h" #include "models/llm/npu/qwen2.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" +#include "processors/qwen2_vl_input_processor.h" namespace xllm::npu::model { #define PrintTensor(tensor) print_tensor(tensor, #tensor, 10, true, false); -class Qwen2_5_VLInputProcessor : public InputProcessor { - enum class TokenType { - INVALID, - IMAGE, - VIDEO, - }; - - public: - Qwen2_5_VLInputProcessor(const ModelArgs& args) { - merge_size_ = args.mm_image_merge_size(); - vision_start_token_id_ = args.vision_start_token_id(); - vision_end_token_id_ = args.vision_end_token_id(); - image_token_id_ = args.image_token_id(); - video_token_id_ = args.video_token_id(); - } - - void process(std::string& prompt, const MMData& mm_data) override { - torch::Tensor image_grid_thw; - if (auto res = mm_data.get("image_grid_thw")) - image_grid_thw = res.value(); - - torch::Tensor video_grid_thw; - if (auto res = mm_data.get("video_grid_thw")) - video_grid_thw = res.value(); - - if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; - - auto merge_length = merge_size_ * merge_size_; - int total_image_token = 0; - if (image_grid_thw.defined()) { - auto count = image_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_image_token += - image_grid_thw[idx].prod().item() / merge_length; - } - - int total_video_token = 0; - if (video_grid_thw.defined()) { - auto count = video_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_video_token += - video_grid_thw[idx].prod().item() / merge_length; - } - - size_t total_token_len = total_image_token * image_token_.size() + - total_video_token * video_token_.size(); - std::string data; - data.reserve(prompt.size() + total_token_len); - - int image_index = 0; - int video_index = 0; - - const torch::Tensor* grid_thw = nullptr; - const std::string* token = nullptr; - int* index = 0; - - size_t begin = 0; - auto pair = find_vision_token(prompt, begin); - - while (pair.second != std::string::npos) { - data.append(prompt, begin, pair.second - begin); - - if (pair.first == TokenType::IMAGE) { - grid_thw = &image_grid_thw; - token = &image_token_; - index = &image_index; - } else if (pair.first == TokenType::VIDEO) { - grid_thw = &video_grid_thw; - token = &video_token_; - index = &video_index; - } else { - assert(false); - } - - auto token_num = (*grid_thw)[(*index)].prod().item() / merge_length; - while (token_num--) data.append(*token); - - ++(*index); - begin = pair.second + token->size(); - pair = find_vision_token(prompt, begin); - } - - if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); - - prompt = std::move(data); - } - - void find_mm_spans(const std::vector& prompt, MMData& mm_data) { - auto start = prompt.begin(); - uint32_t global_mm_index = 0; - uint32_t offset = 0; - uint32_t length = 0; - auto& mm_items = mm_data.items(); - while (true) { - auto vision_start_it = - std::find(start, prompt.end(), vision_start_token_id_); - auto vision_end_it = std::find(start, prompt.end(), vision_end_token_id_); - if (vision_start_it == prompt.end()) { - break; - } - offset = std::distance(prompt.begin(), vision_start_it); - length = std::distance(vision_start_it + 1, vision_end_it); - - auto& item = mm_items[global_mm_index]; - if (*(vision_start_it + 1) == image_token_id_) { - item.mutable_state().mutable_token_pos() = {offset + 1, length}; - } else if (*(vision_start_it + 1) == video_token_id_) { - item.mutable_state().mutable_token_pos() = {offset + 1, length}; - } - global_mm_index++; - start = std::next(vision_end_it); - } - } - - private: - std::pair find_vision_token(const std::string& prompt, - size_t begin) { - auto img_pos = prompt.find(image_token_, begin); - auto vid_pos = prompt.find(video_token_, begin); - - if (img_pos == std::string::npos && vid_pos == std::string::npos) - return {TokenType::INVALID, std::string::npos}; - else if (vid_pos == std::string::npos) - return {TokenType::IMAGE, img_pos}; - else if (img_pos == std::string::npos) - return {TokenType::VIDEO, vid_pos}; - else - return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) - : std::make_pair(TokenType::VIDEO, vid_pos); - } - - private: - const std::string image_token_ = "<|image_pad|>"; - const std::string video_token_ = "<|video_pad|>"; - int32_t vision_start_token_id_; - int32_t vision_end_token_id_; - int32_t image_token_id_; - int32_t video_token_id_; - int32_t merge_size_ = 0; -}; - class Qwen2_5_VisionBlockImpl : public torch::nn::Module { public: Qwen2_5_VisionBlockImpl(const ModelContext& context) { @@ -624,6 +482,7 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { std::vector cu_w_seqlens_vec( cu_window_seqlens_cpu.data_ptr(), // windows seqlen vec cu_window_seqlens_cpu.data_ptr() + cu_window_seqlens_cpu.numel()); + for (int idx = 0; idx < blocks_->size(); ++idx) { torch::Tensor cu_seqlens_now; std::vector cu_seqlens_now_vec; @@ -837,6 +696,22 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + // return full embeddings if set flag + if (FLAGS_enable_return_mm_full_embeddings) { + return h; + } + + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + auto pooler_output = torch::nn::functional::normalize( + h, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); + return pooler_output; + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); diff --git a/xllm/models/vlm/npu/qwen2_vl.h b/xllm/models/vlm/npu/qwen2_vl.h index af7906b4a..f626e6ba0 100644 --- a/xllm/models/vlm/npu/qwen2_vl.h +++ b/xllm/models/vlm/npu/qwen2_vl.h @@ -29,10 +29,9 @@ limitations under the License. #include "core/layers/npu/npu_rms_norm_impl.h" #include "models/llm/npu/qwen2.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen2_vl_input_processor.h" #include "qwen2_5_vl.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" namespace xllm::npu::model { @@ -578,6 +577,16 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + namespace F = torch::nn::functional; + return F::normalize(h, F::NormalizeFuncOptions().p(2).dim(1)); + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); diff --git a/xllm/models/vlm/npu/qwen2_vl_embedding.h b/xllm/models/vlm/npu/qwen2_vl_embedding.h deleted file mode 100644 index cb1e25964..000000000 --- a/xllm/models/vlm/npu/qwen2_vl_embedding.h +++ /dev/null @@ -1,250 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include "core/framework/model/embedding_vlm.h" -#include "core/framework/model/model_output.h" -#include "models/vlm/npu/qwen2_5_vl.h" -#include "models/vlm/npu/qwen2_vl.h" - -namespace xllm::npu::model { - -class Qwen2_VLForEmbeddingImpl : public torch::nn::Module { - public: - Qwen2_VLForEmbeddingImpl(const ModelContext& context) - : model_args_(context.get_model_args()), - options_(context.get_tensor_options()) { - visual_ = register_module("visual", Qwen2_VisionTransformer(context)); - language_model_ = - register_module("language_model", QWen2ForCausalLM(context)); - } - - void prepare_encoder_input(const ModelInputParams& input_params, - std::optional& image_inputs, - std::optional& video_inputs) { - const auto& mm_data = input_params.mm_data; - torch::Tensor pixel_values; - if (const auto& res = mm_data.get("pixel_values")) - pixel_values = res.value(); - - torch::Tensor image_grid_thw; - if (const auto& res = mm_data.get("image_grid_thw")) - image_grid_thw = res.value(); - - torch::Tensor pixel_values_videos; - if (const auto& res = mm_data.get("pixel_values_videos")) - pixel_values_videos = res.value(); - - if (pixel_values.defined() && image_grid_thw.defined()) - image_inputs = Qwen2_VLImageInputs{pixel_values, image_grid_thw}; - } - - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { - std::optional image_input; - std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); - auto merge_size = model_args_.mm_image_merge_size(); - MMDict multimodal_embeds; - if (image_input) { - // visual - auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); - auto image_tokens = - (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) - .cpu() - .contiguous() - .to(torch::kLong); - - std::vector image_tokens_vec( - image_tokens.data_ptr(), - image_tokens.data_ptr() + image_tokens.numel()); - multimodal_embeds["image|embedding"] = - image_embeds.split(image_tokens_vec, 0 /*dim*/); - } - return multimodal_embeds; - } - - torch::Tensor generate_multimodal_mask(torch::Tensor input_ids) { - auto special_token_ids = torch::tensor( - {model_args_.image_token_id(), model_args_.video_token_id()}, - input_ids.options().dtype(torch::kInt64)); - auto is_multimodal = torch::isin(input_ids, special_token_ids); - return is_multimodal; - } - - torch::Tensor merge_multimodal_embeddings( - torch::Tensor inputs_embeds, - const torch::Tensor& multimodal_embeds, - const torch::Tensor& is_multimodal) { - inputs_embeds.index_put_({is_multimodal}, multimodal_embeds); - return inputs_embeds; - } - - torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.mm_data; - torch::Tensor multimodal_embeds; - if (const auto& emb = mm_data.get("embedding")) { - multimodal_embeds = emb.value(); - } - auto inputs_embeds = language_model_->get_input_embeddings(input_ids); - if (!multimodal_embeds.defined()) { - return inputs_embeds; - } - auto is_multimodal = generate_multimodal_mask(input_ids); - inputs_embeds = merge_multimodal_embeddings( - inputs_embeds, multimodal_embeds, is_multimodal); - return inputs_embeds; - } - - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); - } - - torch::Tensor pooler(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) { - auto h = hidden_states; - if (seleted_idxes.defined()) { - h = h.index_select(/*dim=*/0, seleted_idxes); - } - auto pooler_output = torch::nn::functional::normalize( - h, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); - return pooler_output; - } - - torch::Tensor logits(const torch::Tensor&, const torch::Tensor&) { - NOT_IMPLEMENTED(); - return torch::Tensor(); - } - - torch::Device device() const { return options_.device(); } - - const torch::TensorOptions& options() const { return options_; } - - void load_model(std::unique_ptr loader) { - for (const auto& state_dict : loader->get_state_dicts()) { - visual_->load_state_dict(state_dict->get_dict_with_prefix("visual.")); - } - // verify - visual_->verify_loaded_weights("visual."); - visual_->merge_loaded_weights(); - // if (!model_args_.image_embedding_mode()) { - language_model_->load_model(std::move(loader)); - // } - } - - layer::NpuLmHead get_npu_lm_head() { - return language_model_->get_npu_lm_head(); - } - void set_npu_lm_head(layer::NpuLmHead& head) { - language_model_->set_npu_lm_head(head); - } - - layer::NpuWordEmbedding get_npu_word_embedding() { - return language_model_->get_npu_word_embedding(); - } - - void set_npu_word_embedding(layer::NpuWordEmbedding& npu_word_embedding) { - language_model_->set_npu_word_embedding(npu_word_embedding); - } - - private: - ModelArgs model_args_; - torch::TensorOptions options_; - - Qwen2_VisionTransformer visual_{nullptr}; - QWen2ForCausalLM language_model_{nullptr}; -}; -TORCH_MODULE(Qwen2_VLForEmbedding); - -} // namespace xllm::npu::model - -namespace xllm { - -template <> -class EmbeddingVLMImpl : public EmbeddingVLM { - public: - EmbeddingVLMImpl(npu::model::Qwen2_VLForEmbedding model, - const torch::TensorOptions& options) - : model_(std::move(model)), options_(options) {} - - MMDict encode(const ModelInputParams& input_params) override { - return model_->get_multimodal_embeddings(input_params); - }; - torch::Tensor get_input_embeddings(const torch::Tensor& input_ids, - const ModelInputParams& input_params) { - return model_->get_input_embeddings(input_ids, input_params); - } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - return model_->forward(tokens, positions, kv_caches, parameters); - } - - torch::Tensor logits(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) override { - return model_->logits(hidden_states, seleted_idxes); - } - - torch::Tensor pooler(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) override { - return model_->pooler(hidden_states, seleted_idxes); - } - - void load_model(std::unique_ptr loader) override { - model_->load_model(std::move(loader)); - } - - torch::Device device() const override { return model_->device(); } - - const torch::TensorOptions& options() const override { - return model_->options(); - } - - virtual void prepare_expert_weight(int32_t layer_id, - const std::vector& expert_ids) { - return; - } - virtual void update_expert_weight(int32_t layer_id) { return; } - - // Delegate head/embedding accessors to underlying model implementation. - layer::NpuLmHead get_npu_lm_head() override { - return model_->get_npu_lm_head(); - } - void set_npu_lm_head(layer::NpuLmHead& head) override { - model_->set_npu_lm_head(head); - } - layer::NpuWordEmbedding get_npu_word_embedding() override { - return model_->get_npu_word_embedding(); - } - void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override { - model_->set_npu_word_embedding(embedding); - } - - private: - npu::model::Qwen2_VLForEmbedding model_; - torch::TensorOptions options_; -}; - -REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(qwen2_vl_embedding, - qwen2_vl, - npu::model::Qwen2_VLForEmbedding); -} // namespace xllm diff --git a/xllm/models/vlm/npu/qwen3_vl.h b/xllm/models/vlm/npu/qwen3_vl.h index c43c43dfa..28c6dcc77 100644 --- a/xllm/models/vlm/npu/qwen3_vl.h +++ b/xllm/models/vlm/npu/qwen3_vl.h @@ -25,10 +25,9 @@ limitations under the License. #include "core/layers/npu/npu_rms_norm_impl.h" #include "models/llm/npu/qwen3.h" #include "models/model_registry.h" -#include "processors/input_processor.h" -#include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen3_vl_image_processor.h" +#include "processors/qwen3_vl_input_processor.h" #include "qwen2_5_vl.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" namespace xllm::npu::model { @@ -626,7 +625,6 @@ struct Qwen3_VLImageInputs { struct Qwen3_VLVideoInputs { torch::Tensor pixel_values_videos; torch::Tensor video_grid_thw; - torch::Tensor second_per_grid_ts; }; class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { @@ -659,17 +657,11 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { if (const auto& res = mm_data.get("video_grid_thw")) video_grid_thw = res.value(); - torch::Tensor second_per_grid_ts; - if (const auto& res = mm_data.get("second_per_grid_ts")) - second_per_grid_ts = res.value(); - if (pixel_values.defined() && image_grid_thw.defined()) image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; - if (pixel_values_videos.defined() && video_grid_thw.defined() && - second_per_grid_ts.defined()) - video_inputs = Qwen3_VLVideoInputs{ - pixel_values_videos, video_grid_thw, second_per_grid_ts}; + if (pixel_values_videos.defined() && video_grid_thw.defined()) + video_inputs = Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; } MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { @@ -751,6 +743,7 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { mm_data.get("embedding|deepstack_2").value()}; return deepstacks; } + torch::Tensor merge_multimodal_embeddings( torch::Tensor inputs_embeds, const torch::Tensor& multimodal_embeds, @@ -785,6 +778,11 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->pooler(hidden_states, seleted_idxes); + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); @@ -829,9 +827,9 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { }; TORCH_MODULE(Qwen3_VLForConditionalGeneration); -REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen2_5_VLInputProcessor); +REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen3_VLInputProcessor); REGISTER_CAUSAL_VLM_MODEL(qwen3_vl, Qwen3_VLForConditionalGeneration); -REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen2VLImageProcessor); +REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen3VLImageProcessor); REGISTER_MODEL_ARGS(qwen3_vl, [&] { // text config diff --git a/xllm/models/vlm/npu/qwen3_vl_moe.h b/xllm/models/vlm/npu/qwen3_vl_moe.h index 6e4f31158..507ba6ffb 100644 --- a/xllm/models/vlm/npu/qwen3_vl_moe.h +++ b/xllm/models/vlm/npu/qwen3_vl_moe.h @@ -26,11 +26,10 @@ limitations under the License. #include "core/layers/npu/npu_rms_norm_impl.h" #include "models/llm/npu/qwen3_moe.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen2_vl_input_processor.h" #include "qwen2_5_vl.h" #include "qwen3_vl.h" -#include "xllm_atb_layers/core/include/atb_speed/log.h" namespace xllm::npu::model { diff --git a/xllm/models/vlm/oxygen_vlm.h b/xllm/models/vlm/oxygen_vlm.h new file mode 100644 index 000000000..c53588c3b --- /dev/null +++ b/xllm/models/vlm/oxygen_vlm.h @@ -0,0 +1,791 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "core/framework/model/model_output.h" +#include "core/layers/common/lm_head.h" +#include "core/layers/oxygen_vision_layer.h" +#include "core/layers/qwen2_5_vision_layer.h" +#include "core/layers/qwen2_decoder_layer.h" +#include "models/llm/oxygen.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/qwen2_vl_image_processor.h" +#include "qwen2_5_vl.h" + +namespace xllm { +using OxygenImageInputs = Qwen2_5_VLImageInputs; + +struct OxygenVideoInputs { + torch::Tensor pixel_values_videos; + torch::Tensor video_grid_thw; +}; + +class OxygenVisionPatchEmbedImpl : public torch::nn::Module { + public: + OxygenVisionPatchEmbedImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + auto in_features = model_args.mm_num_channels() * + model_args.mm_temporal_patch_size() * + model_args.mm_patch_size() * model_args.mm_patch_size(); + + auto out_features = model_args.mm_hidden_size(); + + proj_ = register_module( + "proj", + torch::nn::Linear( + torch::nn::LinearOptions(in_features, out_features).bias(true))); + + proj_->weight.set_data(proj_->weight.to(options)); + proj_->bias.set_data(proj_->bias.to(options)); + } + + torch::Tensor forward(torch::Tensor x) { return proj_(x); } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("proj.weight"); + if (weight.defined()) { + weight = weight.reshape({weight.size(0), -1}); + DCHECK_EQ(proj_->weight.sizes(), weight.sizes()) + << "proj weight size mismatch for " << name(); + proj_->weight.data().copy_(weight); + proj_weight_loaded_ = true; + } + auto bias = state_dict.get_tensor("proj.bias"); + if (bias.defined()) { + bias = bias.reshape({bias.size(0)}); + DCHECK_EQ(proj_->bias.sizes(), bias.sizes()) + << "proj bias size mismatch for " << name(); + proj_->bias.data().copy_(bias); + proj_bias_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(proj_weight_loaded_) + << "weight is not loaded for " << prefix + "proj.weight"; + CHECK(proj_bias_loaded_) + << "bias is not loaded for " << prefix + "proj.bias"; + } + + private: + bool proj_weight_loaded_ = false; + bool proj_bias_loaded_ = false; + torch::nn::Linear proj_{nullptr}; +}; +TORCH_MODULE(OxygenVisionPatchEmbed); + +class OxygenVisionEmbeddingsImpl : public torch::nn::Module { + public: + OxygenVisionEmbeddingsImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + embed_dim_ = model_args.mm_hidden_size(); + image_size_ = model_args.mm_image_size(); + patch_size_ = model_args.mm_patch_size(); + num_positions_ = image_size_ / patch_size_; + num_positions_ = num_positions_ * num_positions_; + position_embedding_ = register_module( + "position_embedding", torch::nn::Embedding(num_positions_, embed_dim_)); + position_embedding_->weight.set_data( + position_embedding_->weight.to(options)); + } + torch::Tensor forward(torch::Tensor x, + std::vector lengths, + torch::Tensor image_shapes, + torch::Tensor h_coords, + torch::Tensor w_coords) { + const auto& pos_embed_weight = position_embedding_->weight; + const int64_t hidden_size = pos_embed_weight.size(1); + const int64_t total_seq = x.size(0); + const auto device = pos_embed_weight.device(); + const auto dtype = pos_embed_weight.dtype(); + + image_shapes = image_shapes.to(device); + h_coords = h_coords.to(device); + w_coords = w_coords.to(device); + x = x.to(device, dtype); + + torch::Tensor adapted_pos_embed; + if (total_seq == 0) { + adapted_pos_embed = torch::empty( + {0, hidden_size}, torch::TensorOptions().device(device).dtype(dtype)); + } else { + const int64_t batch_size = static_cast(lengths.size()); + const int64_t orig_size_sq = pos_embed_weight.size(0); + const int64_t orig_size = static_cast(std::sqrt(orig_size_sq)); + auto pos_embed_2d = + pos_embed_weight.view({orig_size, orig_size, hidden_size}) + .permute({2, 0, 1}) + .unsqueeze(0) + .to(torch::kFloat32); + + std::vector target_h_list; + std::vector target_w_list; + target_h_list.reserve(batch_size); + target_w_list.reserve(batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + const int64_t seq_len = lengths[i]; + const auto img_h = image_shapes.index({i, 1}).to(torch::kFloat32); + const auto img_w = image_shapes.index({i, 2}).to(torch::kFloat32); + + target_h_list.push_back(img_h.repeat({seq_len})); + target_w_list.push_back(img_w.repeat({seq_len})); + } + + auto target_h = torch::cat(target_h_list, 0); + auto target_w = torch::cat(target_w_list, 0); + + auto h_coords_fp32 = h_coords.to(torch::kFloat32); + auto w_coords_fp32 = w_coords.to(torch::kFloat32); + + const auto norm_w = ((w_coords_fp32 + 0.5f) / target_w) * 2.0f - 1.0f; + const auto norm_h = ((h_coords_fp32 + 0.5f) / target_h) * 2.0f - 1.0f; + auto grid = torch::stack({norm_w, norm_h}, -1).unsqueeze(0).unsqueeze(2); + namespace F = torch::nn::functional; + auto interpolated_embed = F::grid_sample(pos_embed_2d, + grid, + F::GridSampleFuncOptions() + .mode(torch::kBicubic) + .padding_mode(torch::kBorder) + .align_corners(false)); + adapted_pos_embed = + interpolated_embed.squeeze(0).squeeze(-1).permute({1, 0}).to(dtype); + } + + return x + adapted_pos_embed; + } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("position_embedding.weight"); + if (weight.defined()) { + position_embedding_->weight.data().copy_(weight); + position_embedding_weight_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(position_embedding_weight_loaded_) + << "weight is not loaded for " << prefix + "position_embedding.weight"; + } + + private: + int64_t embed_dim_ = 0; + int64_t image_size_ = 0; + int64_t patch_size_ = 0; + int64_t num_positions_ = 0; + bool position_embedding_weight_loaded_ = false; + torch::nn::Embedding position_embedding_{nullptr}; +}; +TORCH_MODULE(OxygenVisionEmbeddings); + +class OxygenVisionPatchMergerImpl : public torch::nn::Module { + public: + OxygenVisionPatchMergerImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + options_ = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + int64_t dim = model_args.mm_projection_dim(); + int64_t context_dim = model_args.mm_projector_hidden_size(); + norm_ = register_module( + "norm", torch::nn::LayerNorm(torch::nn::LayerNormOptions({dim}))); + norm_->weight.set_data(norm_->weight.to(options_)); + norm_->bias.set_data(norm_->bias.to(options_)); + proj_ = register_module( + "proj", + torch::nn::Linear(torch::nn::LinearOptions(dim, dim).bias(false))); + proj_->weight.set_data(proj_->weight.to(options_)); + act_ = register_module("act", torch::nn::GELU()); + silu_ = register_module("silu", torch::nn::SiLU()); + + gate_ = register_module( + "gate", + torch::nn::Linear( + torch::nn::LinearOptions(dim, context_dim).bias(false))); + gate_->weight.set_data(gate_->weight.to(options_)); + up_ = register_module( + "up", + torch::nn::Linear( + torch::nn::LinearOptions(dim, context_dim).bias(false))); + up_->weight.set_data(up_->weight.to(options_)); + down_ = register_module( + "down", + torch::nn::Linear( + torch::nn::LinearOptions(context_dim, dim).bias(false))); + down_->weight.set_data(down_->weight.to(options_)); + } + + torch::Tensor forward(torch::Tensor x) { + x = proj_(x); + x = act_(norm_(x)); + x = down_(torch::mul(silu_((gate_(x))), up_(x))); + return x; + } + + void load_state_dict(const StateDict& state_dict) { + // norm + const auto& norm_dict = + state_dict.get_dict_with_prefix("post_projection_norm."); + const auto& norm_weight = norm_dict.get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(norm_->weight.sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + norm_->weight.data().copy_(norm_weight); + is_norm_weight_loaded = true; + } + const auto norm_bias = norm_dict.get_tensor("bias"); + if (norm_bias.defined()) { + CHECK_EQ(norm_->bias.sizes(), norm_bias.sizes()) + << "bias size mismatch for " << name(); + norm_->bias.data().copy_(norm_bias); + is_norm_bias_loaded = true; + } + + const auto& proj_dict = state_dict.get_dict_with_prefix("proj."); + const auto& proj_weight = proj_dict.get_tensor("weight"); + if (proj_weight.defined()) { + proj_->weight.data().copy_(proj_weight); + is_proj_weight_loaded = true; + } + + const auto& up_dict = state_dict.get_dict_with_prefix("up_proj."); + const auto& up_weight = up_dict.get_tensor("weight"); + if (up_weight.defined()) { + up_->weight.data().copy_(up_weight); + is_up_weight_loaded = true; + } + + const auto& down_dict = state_dict.get_dict_with_prefix("down_proj."); + const auto& down_weight = down_dict.get_tensor("weight"); + if (down_weight.defined()) { + down_->weight.data().copy_(down_weight); + is_down_weight_loaded = true; + } + + const auto& gate_dict = state_dict.get_dict_with_prefix("gate_proj."); + const auto& gate_weight = gate_dict.get_tensor("weight"); + if (gate_weight.defined()) { + gate_->weight.data().copy_(gate_weight); + is_gate_weight_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_proj_weight_loaded) + << "weight is not loaded for " << prefix + "proj_weight" + ".weight"; + CHECK(is_up_weight_loaded) + << "weight is not loaded for " << prefix + "up_weight" + ".weight"; + CHECK(is_down_weight_loaded) + << "weight is not loaded for " << prefix + "down_weight" + ".weight"; + CHECK(is_gate_weight_loaded) + << "weight is not loaded for " << prefix + "gate_weight" + ".weight"; + CHECK(is_norm_weight_loaded) + << "weight is not loaded for " << prefix + "norm" + ".weight"; + CHECK(is_norm_bias_loaded) + << "bias is not loaded for " << prefix + "norm" + ".bias"; + } + + private: + torch::nn::LayerNorm norm_{nullptr}; + torch::nn::Linear proj_{nullptr}; + torch::nn::Linear up_{nullptr}; + torch::nn::Linear gate_{nullptr}; + torch::nn::Linear down_{nullptr}; + torch::nn::GELU act_{nullptr}; + torch::nn::SiLU silu_{nullptr}; + torch::TensorOptions options_; + + bool is_proj_weight_loaded = false; + bool is_up_weight_loaded = false; + bool is_down_weight_loaded = false; + bool is_gate_weight_loaded = false; + bool is_norm_weight_loaded = false; + bool is_norm_bias_loaded = false; +}; +TORCH_MODULE(OxygenVisionPatchMerger); + +class OxygenVisionTransformerImpl : public torch::nn::Module { + public: + OxygenVisionTransformerImpl(const ModelContext& context) + : options_(context.get_tensor_options()) { + auto model_args = context.get_model_args(); + spatial_merge_size_ = model_args.mm_spatial_merge_size(); + hidden_size_ = model_args.mm_hidden_size(); + out_hidden_size_ = model_args.mm_projection_dim(); + + patch_embed_ = + register_module("patch_embed", OxygenVisionPatchEmbed(context)); + rotary_pos_emb_ = register_module("rotary_pos_emb", + Qwen2_5_VisionRotaryEmbedding(context)); + post_conv_layernorm_ = register_module( + "post_conv_layernorm", + layer::RMSNorm( + model_args.mm_hidden_size(), model_args.rms_norm_eps(), options_)); + + embeddings_ = + register_module("embeddings", OxygenVisionEmbeddings(context)); + + blocks_ = register_module("blocks", torch::nn::ModuleList()); + + for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { + auto block = layer::OxygenVisionLayer(context); + blocks_->push_back(block); + layers_.push_back(block); + } + post_layernorm_ = register_module( + "post_layernorm", + layer::RMSNorm( + model_args.mm_hidden_size(), model_args.rms_norm_eps(), options_)); + + downsample_ = register_module( + "downsample", + torch::nn::Conv2d(torch::nn::Conv2dOptions(hidden_size_, + out_hidden_size_, + spatial_merge_size_) + .stride(spatial_merge_size_) + .bias(true) + .padding(0))); + downsample_->weight.set_data(downsample_->weight.to(options_)); + downsample_->bias.set_data(downsample_->bias.to(options_)); + merger_ = register_module("merger", OxygenVisionPatchMerger(context)); + } + + std::tuple rot_pos_emb(torch::Tensor grid_thw) { + std::vector pos_ids_vec; + auto count = grid_thw.sizes()[0]; + pos_ids_vec.reserve(count); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + auto grid_thw_cpu = grid_thw.cpu(); + for (int idx = 0; idx < count; ++idx) { + auto t = grid_thw_cpu[idx][0].item(); + auto h = grid_thw_cpu[idx][1].item(); + auto w = grid_thw_cpu[idx][2].item(); + auto hpos_ids = torch::arange(h, options).unsqueeze(1).expand({-1, w}); + hpos_ids = hpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + auto wpos_ids = torch::arange(w, options).unsqueeze(0).expand({h, -1}); + wpos_ids = wpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + pos_ids_vec.push_back( + torch::stack({hpos_ids, wpos_ids}, -1).repeat({t, 1})); + } + auto pos_ids = torch::cat(pos_ids_vec, 0); + auto max_grid_size = + grid_thw + .index({torch::indexing::Slice(), + torch::indexing::Slice(1, torch::indexing::None)}) + .max(); + auto rotary_pos_emb_full = rotary_pos_emb_(max_grid_size.item()); + auto rotary_pos_emb = rotary_pos_emb_full.index({pos_ids}).flatten(1); + + return std::make_tuple(rotary_pos_emb, pos_ids); + } + + torch::Tensor forward(torch::Tensor hidden_states, + torch::Tensor grid_thw, + const ModelInputParams& input_params) { + hidden_states = patch_embed_(hidden_states); + hidden_states = std::get<0>(post_conv_layernorm_(hidden_states)); + + auto [rotary_pos_emb, image_type_ids] = rot_pos_emb(grid_thw); + auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1); + auto m_cos = emb.cos().type_as(hidden_states); + auto m_sin = emb.sin().type_as(hidden_states); + + auto device = grid_thw.device(); + auto grid_t = grid_thw.index_select( + 1, + torch::tensor( + {0}, torch::TensorOptions().dtype(torch::kInt).device(device))); + auto grid_h = grid_thw.index_select( + 1, + torch::tensor( + {1}, torch::TensorOptions().dtype(torch::kInt).device(device))); + auto grid_w = grid_thw.index_select( + 1, + torch::tensor( + {2}, torch::TensorOptions().dtype(torch::kInt).device(device))); + auto h_times_w = (grid_h * grid_w).squeeze(1); + auto repeats = grid_t.squeeze(1); + auto repeated = torch::repeat_interleave(h_times_w, repeats, 0); + c10::optional cumsum_dtype; + + cumsum_dtype = torch::kInt32; + auto cu_seqlens = torch::cumsum(repeated, 0, cumsum_dtype); + namespace F = torch::nn::functional; + cu_seqlens = F::pad( + cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0)); + auto seqlens_cpu = torch::diff(cu_seqlens).cpu().to(torch::kInt); + std::vector seqlens; + seqlens.assign(seqlens_cpu.data_ptr(), + seqlens_cpu.data_ptr() + seqlens_cpu.numel()); + + hidden_states = embeddings_(hidden_states, + seqlens, + grid_thw, + image_type_ids.select(1, 0), + image_type_ids.select(1, 1)); + ModelInputParams& input_params_new = + const_cast(input_params); + torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); + std::vector cu_seqlens_vec( + cu_seqlens_cpu.data_ptr(), + cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); + cu_seqlens = cu_seqlens.to(hidden_states.device()); + for (int idx = 0; idx < blocks_->size(); ++idx) { + hidden_states = layers_[idx](hidden_states, + m_cos, + m_sin, + cu_seqlens, + cu_seqlens_vec, + input_params_new, + idx); + } + hidden_states = std::get<0>(post_layernorm_(hidden_states)); + hidden_states = hidden_states.view( + {-1, spatial_merge_size_, spatial_merge_size_, hidden_states.size(-1)}); + hidden_states = hidden_states.permute({0, 3, 1, 2}); + hidden_states = downsample_(hidden_states).view({-1, out_hidden_size_}); + hidden_states = merger_(hidden_states); + return hidden_states; + }; + + void load_state_dict(const StateDict& state_dict) { + patch_embed_->load_state_dict( + state_dict.get_dict_with_prefix("patch_embed.")); + embeddings_->load_state_dict( + state_dict.get_dict_with_prefix("embeddings.")); + const auto& norm_weight = + state_dict.get_dict_with_prefix("post_conv_layernorm.") + .get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(post_conv_layernorm_->weight().sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + post_conv_layernorm_->weight().data().copy_(norm_weight); + is_post_conv_layernorm_weight_loaded = true; + } + for (int idx = 0; idx < layers_.size(); ++idx) { + layers_[idx]->load_state_dict(state_dict.get_dict_with_prefix( + "blocks." + std::to_string(idx) + ".")); + } + + const auto& post_norm_weight = + state_dict.get_dict_with_prefix("post_layernorm.").get_tensor("weight"); + if (post_norm_weight.defined()) { + CHECK_EQ(post_layernorm_->weight().sizes(), post_norm_weight.sizes()) + << "weight size mismatch for " << name(); + post_layernorm_->weight().data().copy_(post_norm_weight); + is_post_layernorm_weight_loaded = true; + } + const auto& downsample_dict = + state_dict.get_dict_with_prefix("downsample."); + const auto& downsample_weight = downsample_dict.get_tensor("weight"); + const auto& downsample_bias = downsample_dict.get_tensor("bias"); + if (downsample_weight.defined()) { + downsample_->weight.data().copy_(downsample_weight); + is_downsample_weight_loaded_ = true; + } + if (downsample_bias.defined()) { + downsample_->bias.data().copy_(downsample_bias); + is_downsample_bias_loaded_ = true; + } + merger_->load_state_dict(state_dict.get_dict_with_prefix("merger.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + patch_embed_->verify_loaded_weights(prefix + "patch_embed."); + embeddings_->verify_loaded_weights(prefix + "embeddings."); + CHECK(is_post_conv_layernorm_weight_loaded) + << "weight is not loaded for " << prefix + "post_conv_layernorm.weight"; + CHECK(is_post_layernorm_weight_loaded) + << "weight is not loaded for " << prefix + "post_layernorm.weight"; + merger_->verify_loaded_weights(prefix + "merger."); + + CHECK(is_downsample_weight_loaded_) + << "weight is not loaded for " << prefix + "downsample.weight"; + CHECK(is_downsample_bias_loaded_) + << "bias is not loaded for " << prefix + "downsample.bias"; + } + + private: + int hidden_size_ = 0; + int out_hidden_size_ = 0; + int spatial_merge_size_ = 0; + + OxygenVisionPatchEmbed patch_embed_{nullptr}; + Qwen2_5_VisionRotaryEmbedding rotary_pos_emb_{nullptr}; + torch::nn::ModuleList blocks_{nullptr}; + OxygenVisionEmbeddings embeddings_{nullptr}; + layer::RMSNorm post_conv_layernorm_{nullptr}; + layer::RMSNorm post_layernorm_{nullptr}; + torch::nn::Conv2d downsample_{nullptr}; + std::vector layers_; + OxygenVisionPatchMerger merger_{nullptr}; + torch::TensorOptions options_; + bool is_post_conv_layernorm_weight_loaded = false; + bool is_post_layernorm_weight_loaded = false; + bool is_downsample_weight_loaded_ = false; + bool is_downsample_bias_loaded_ = false; + torch::Tensor m_cos; + torch::Tensor m_sin; +}; +TORCH_MODULE(OxygenVisionTransformer); + +class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { + public: + OxygenvlmForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", OxygenVisionTransformer(context)); + + language_model_ = + register_module("language_model", OxygenForCausalLM(context)); + } + + void prepare_encoder_input(const ModelInputParams& input_params, + std::optional& image_inputs, + std::optional& video_inputs) { + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + torch::Tensor pixel_values_videos; + if (const auto& res = mm_data.get("pixel_values_videos")) + pixel_values_videos = res.value(); + + torch::Tensor video_grid_thw; + if (const auto& res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = OxygenImageInputs{pixel_values, image_grid_thw}; + + if (pixel_values_videos.defined() && video_grid_thw.defined()) + video_inputs = OxygenVideoInputs{pixel_values_videos, video_grid_thw}; + } + + MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + std::optional image_input; + std::optional video_input; + prepare_encoder_input(input_params, image_input, video_input); + + auto merge_size = model_args_.mm_image_merge_size(); + MMDict multimodal_embeds; + if (image_input) { + // visual + auto image_embeds = visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + auto image_tokens = + (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) + .cpu() + .contiguous() + .to(torch::kLong); + + std::vector image_tokens_vec( + image_tokens.data_ptr(), + image_tokens.data_ptr() + image_tokens.numel()); + multimodal_embeds["image|embedding"] = + image_embeds.split(image_tokens_vec, 0); + } + if (video_input) { + std::vector temp_frames_hw; + for (int i = 0; i < video_input->video_grid_thw.size(0); ++i) { + auto t = video_input->video_grid_thw[i][0].item(); + auto h = video_input->video_grid_thw[i][1].item(); + auto w = video_input->video_grid_thw[i][2].item(); + auto repeated_row = + torch::tensor({1, h, w}).unsqueeze(0).repeat({t, 1}); + temp_frames_hw.push_back(repeated_row); + } + auto flatten_video_grid_thw = torch::cat(temp_frames_hw, 0); + // visual + auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), + flatten_video_grid_thw, + input_params); + // Split based on original video count, not frame count + // video_grid_thw has shape [num_videos, 3], video_embeds is flattened + // We need to split video_embeds back to match num_videos + std::vector split_sizes; + for (int i = 0; i < video_input->video_grid_thw.size(0); ++i) { + auto t = video_input->video_grid_thw[i][0].item(); + auto h = video_input->video_grid_thw[i][1].item(); + auto w = video_input->video_grid_thw[i][2].item(); + // Tokens for this video = t frames * (h * w / merge_size / merge_size) + auto tokens = t * h * w / merge_size / merge_size; + split_sizes.push_back(tokens); + } + multimodal_embeds["video|embedding"] = video_embeds.split(split_sizes, 0); + } + return multimodal_embeds; + } + + torch::Tensor generate_multimodal_mask(torch::Tensor input_ids) { + auto special_token_ids = torch::tensor( + {model_args_.image_token_id(), model_args_.video_token_id()}, + input_ids.options().dtype(torch::kInt64)); + auto is_multimodal = torch::isin(input_ids, special_token_ids); + return is_multimodal; + } + + torch::Tensor merge_multimodal_embeddings( + torch::Tensor inputs_embeds, + const torch::Tensor& multimodal_embeds, + const torch::Tensor& is_multimodal) { + inputs_embeds.index_put_({is_multimodal}, multimodal_embeds); + return inputs_embeds; + } + + torch::Tensor get_input_embeddings(const torch::Tensor input_ids, + const ModelInputParams& input_params) { + const auto& mm_data = input_params.mm_data; + torch::Tensor multimodal_embeds; + if (const auto& emb = mm_data.get("embedding")) { + multimodal_embeds = emb.value(); + } + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + if (!multimodal_embeds.defined()) { + return inputs_embeds; + } + auto is_multimodal = generate_multimodal_mask(input_ids); + inputs_embeds = merge_multimodal_embeddings( + inputs_embeds, multimodal_embeds, is_multimodal); + return inputs_embeds; + } + + ModelOutput forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return language_model_(tokens, positions, kv_caches, input_params); + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict( + state_dict->get_dict_with_prefix("model.visual.")); + } + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader), "model.language_model."); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + OxygenVisionTransformer visual_{nullptr}; + OxygenForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(OxygenvlmForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(oxygenvlm, Qwen2_5_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(oxygenvlm, OxygenvlmForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(oxygenvlm, Qwen2VLImageProcessor); + +// register the model args +REGISTER_MODEL_ARGS(oxygenvlm, [&] { + LOAD_ARG_OR(model_type, "model_type", "oxygenvlm"); + LOAD_ARG_OR(vision_start_token_id, "vision_start_token_id", 151652); + LOAD_ARG_OR(vision_end_token_id, "vision_end_token_id", 151653); + LOAD_ARG_OR(vision_token_id, "vision_token_id", 151654); + LOAD_ARG_OR(video_token_id, "video_token_id", 151656); + LOAD_ARG_OR(image_token_id, "image_token_id", 151655); + + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + + // text config + LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151936); + LOAD_ARG_OR(eos_token_id, "text_config.eos_token_id", 151645); + LOAD_ARG_OR(attention_bias, "text_config.attention_bias", false); + LOAD_ARG_OR(attention_dropout, "text_config.attention_dropout", 0.0f); + LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 5120); + LOAD_ARG_OR(initializer_range, "text_config.initializer_range", 0.02); + LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 25600); + LOAD_ARG_OR( + max_position_embeddings, "text_config.max_position_embeddings", 40960); + LOAD_ARG_OR(n_heads, "text_config.num_attention_heads", 64); + LOAD_ARG_OR(head_dim, "text_config.head_dim", 128); + + LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 64); + LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 8); + LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-05); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(rope_scaling_rope_type, "text_config.rope_scaling.type", "mrope"); + LOAD_ARG(rope_scaling_mrope_section, + "text_config.rope_scaling.mrope_section"); + LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 1000000); + + // vision config + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 24); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "silu"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1536); + LOAD_ARG_OR(mm_image_size, "vision_config.image_size", 336); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3); + LOAD_ARG_OR( + mm_projector_hidden_size, "vision_config.projector_hidden_size", 4096); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 12); + LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 5120); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 14); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + if (args->rope_scaling_rope_type() == "default") + args->rope_scaling_rope_type() = "mrope"; + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4096); +}); + +#undef LOAD_OXYGENVLM_MODEL_ARGS + +} // namespace xllm diff --git a/xllm/models/vlm/qwen2_5_vl.h b/xllm/models/vlm/qwen2_5_vl.h index 116a9f1c5..d1cf9a691 100644 --- a/xllm/models/vlm/qwen2_5_vl.h +++ b/xllm/models/vlm/qwen2_5_vl.h @@ -21,152 +21,11 @@ limitations under the License. #include "core/layers/qwen2_decoder_layer.h" #include "models/llm/qwen2.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen2_vl_input_processor.h" namespace xllm { -class Qwen2_5_VLInputProcessor : public InputProcessor { - enum class TokenType { - INVALID, - IMAGE, - VIDEO, - }; - - public: - Qwen2_5_VLInputProcessor(const ModelArgs& args) { - merge_size_ = args.mm_image_merge_size(); - vision_start_token_id_ = args.vision_start_token_id(); - vision_end_token_id_ = args.vision_end_token_id(); - image_token_id_ = args.image_token_id(); - video_token_id_ = args.video_token_id(); - } - - void process(std::string& prompt, const MMData& mm_data) override { - torch::Tensor image_grid_thw; - if (auto res = mm_data.get("image_grid_thw")) - image_grid_thw = res.value(); - - torch::Tensor video_grid_thw; - if (auto res = mm_data.get("video_grid_thw")) - video_grid_thw = res.value(); - - if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; - - auto merge_length = merge_size_ * merge_size_; - int total_image_token = 0; - if (image_grid_thw.defined()) { - auto count = image_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_image_token += - image_grid_thw[idx].prod().item() / merge_length; - } - - int total_video_token = 0; - if (video_grid_thw.defined()) { - auto count = video_grid_thw.sizes()[0]; - for (int idx = 0; idx < count; ++idx) - total_video_token += - video_grid_thw[idx].prod().item() / merge_length; - } - - size_t total_token_len = total_image_token * image_token_.size() + - total_video_token * video_token_.size(); - std::string data; - data.reserve(prompt.size() + total_token_len); - - int image_index = 0; - int video_index = 0; - - const torch::Tensor* grid_thw = nullptr; - const std::string* token = nullptr; - int* index = 0; - - size_t begin = 0; - auto pair = find_vision_token(prompt, begin); - - while (pair.second != std::string::npos) { - data.append(prompt, begin, pair.second - begin); - - if (pair.first == TokenType::IMAGE) { - grid_thw = &image_grid_thw; - token = &image_token_; - index = &image_index; - } else if (pair.first == TokenType::VIDEO) { - grid_thw = &video_grid_thw; - token = &video_token_; - index = &video_index; - } else { - assert(false); - } - - auto token_num = (*grid_thw)[(*index)].prod().item() / merge_length; - while (token_num--) data.append(*token); - - ++(*index); - begin = pair.second + token->size(); - pair = find_vision_token(prompt, begin); - } - - if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); - - prompt = std::move(data); - } - - void find_mm_spans(const std::vector& prompt, MMData& mm_data) { - auto start = prompt.begin(); - uint32_t global_mm_index = 0; - uint32_t offset = 0; - uint32_t length = 0; - auto& mm_items = mm_data.items(); - while (true) { - auto vision_start_it = - std::find(start, prompt.end(), vision_start_token_id_); - auto vision_end_it = std::find(start, prompt.end(), vision_end_token_id_); - if (vision_start_it == prompt.end()) { - break; - } - offset = std::distance(prompt.begin(), vision_start_it); - length = std::distance(vision_start_it + 1, vision_end_it); - - auto& item = mm_items[global_mm_index]; - if (*(vision_start_it + 1) == image_token_id_) { - item.mutable_state().mutable_token_pos() = {offset + 1, length}; - } else if (*(vision_start_it + 1) == video_token_id_) { - item.mutable_state().mutable_token_pos() = {offset + 1, length}; - } - global_mm_index++; - start = std::next(vision_end_it); - } - } - - private: - std::pair find_vision_token(const std::string& prompt, - size_t begin) { - auto img_pos = prompt.find(image_token_, begin); - auto vid_pos = prompt.find(video_token_, begin); - - if (img_pos == std::string::npos && vid_pos == std::string::npos) - return {TokenType::INVALID, std::string::npos}; - else if (vid_pos == std::string::npos) - return {TokenType::IMAGE, img_pos}; - else if (img_pos == std::string::npos) - return {TokenType::VIDEO, vid_pos}; - else - return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) - : std::make_pair(TokenType::VIDEO, vid_pos); - } - - private: - const std::string image_token_ = "<|image_pad|>"; - const std::string video_token_ = "<|video_pad|>"; - int32_t vision_start_token_id_; - int32_t vision_end_token_id_; - int32_t image_token_id_; - int32_t video_token_id_; - int32_t merge_size_ = 0; -}; - class Qwen2_5_VisionPatchEmbedImpl : public torch::nn::Module { public: Qwen2_5_VisionPatchEmbedImpl(const ModelContext& context) { @@ -766,6 +625,22 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + // return full embeddings if set flag + if (FLAGS_enable_return_mm_full_embeddings) { + return h; + } + + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + auto pooler_output = torch::nn::functional::normalize( + h, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); + return pooler_output; + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); @@ -773,7 +648,8 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { void load_model(std::unique_ptr loader) { for (const auto& state_dict : loader->get_state_dicts()) { - visual_->load_state_dict(state_dict->get_dict_with_prefix("visual.")); + visual_->load_state_dict(state_dict->get_dict_with_prefix( + std::vector{"visual.", "model.visual."})); } if (!model_args_.image_embedding_mode()) { @@ -870,6 +746,9 @@ REGISTER_IMAGE_PROCESSOR(qwen2_5_vl, Qwen2VLImageProcessor); "rope_scaling.mrope_section", \ std::vector({16, 24, 24})); \ LOAD_ARG_OR(vocab_size, "vocab_size", 152064); \ + if (args->rope_scaling_rope_type() == "default") { \ + args->rope_scaling_rope_type() = "mrope"; \ + } \ } while (0) REGISTER_MODEL_ARGS(qwen2_5_vl, [&] { LOAD_QWEN2_5_VL_MODEL_ARGS(); }); diff --git a/xllm/models/vlm/qwen2_vl.h b/xllm/models/vlm/qwen2_vl.h index 0178d6826..d897df0d8 100644 --- a/xllm/models/vlm/qwen2_vl.h +++ b/xllm/models/vlm/qwen2_vl.h @@ -21,8 +21,8 @@ limitations under the License. #include "core/layers/qwen2_vision_layer.h" #include "models/llm/qwen2.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen2_vl_input_processor.h" #include "qwen2_5_vl.h" namespace xllm { @@ -512,6 +512,16 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + namespace F = torch::nn::functional; + return F::normalize(h, F::NormalizeFuncOptions().p(2).dim(1)); + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); diff --git a/xllm/models/vlm/qwen2_vl_embedding.h b/xllm/models/vlm/qwen2_vl_embedding.h deleted file mode 100644 index 8eb54b4bc..000000000 --- a/xllm/models/vlm/qwen2_vl_embedding.h +++ /dev/null @@ -1,235 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include "core/framework/model/embedding_vlm.h" -#include "core/framework/model/model_output.h" -#include "models/llm/qwen2.h" -#include "models/vlm/qwen2_vl.h" - -namespace xllm { - -class Qwen2_VLForEmbeddingImpl : public torch::nn::Module { - public: - Qwen2_VLForEmbeddingImpl(const ModelContext& context) - : model_args_(context.get_model_args()), - options_(context.get_tensor_options()) { - visual_ = register_module("visual", Qwen2_VisionTransformer(context)); - language_model_ = - register_module("language_model", QWen2ForCausalLM(context)); - } - - void prepare_encoder_input(const ModelInputParams& input_params, - std::optional& image_inputs, - std::optional& video_inputs) { - const auto& mm_data = input_params.mm_data; - torch::Tensor pixel_values; - if (const auto& res = mm_data.get("pixel_values")) - pixel_values = res.value(); - - torch::Tensor image_grid_thw; - if (const auto& res = mm_data.get("image_grid_thw")) - image_grid_thw = res.value(); - - torch::Tensor pixel_values_videos; - if (const auto& res = mm_data.get("pixel_values_videos")) - pixel_values_videos = res.value(); - - if (pixel_values.defined() && image_grid_thw.defined()) - image_inputs = Qwen2_VLImageInputs{pixel_values, image_grid_thw}; - } - - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { - std::optional image_input; - std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); - auto merge_size = model_args_.mm_image_merge_size(); - MMDict multimodal_embeds; - if (image_input) { - // visual - auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); - auto image_tokens = - (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) - .cpu() - .contiguous() - .to(torch::kLong); - - std::vector image_tokens_vec( - image_tokens.data_ptr(), - image_tokens.data_ptr() + image_tokens.numel()); - multimodal_embeds["image|embedding"] = - image_embeds.split(image_tokens_vec, 0 /*dim*/); - } - return multimodal_embeds; - } - - torch::Tensor generate_multimodal_mask(torch::Tensor input_ids) { - auto special_token_ids = torch::tensor( - {model_args_.image_token_id(), model_args_.video_token_id()}, - input_ids.options().dtype(torch::kInt64)); - auto is_multimodal = torch::isin(input_ids, special_token_ids); - return is_multimodal; - } - - torch::Tensor merge_multimodal_embeddings( - torch::Tensor inputs_embeds, - const torch::Tensor& multimodal_embeds, - const torch::Tensor& is_multimodal) { - inputs_embeds.index_put_({is_multimodal}, multimodal_embeds); - return inputs_embeds; - } - - torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.mm_data; - torch::Tensor multimodal_embeds; - if (const auto& emb = mm_data.get("embedding")) { - multimodal_embeds = emb.value(); - } - auto inputs_embeds = language_model_->get_input_embeddings(input_ids); - if (!multimodal_embeds.defined()) { - return inputs_embeds; - } - auto is_multimodal = generate_multimodal_mask(input_ids); - inputs_embeds = merge_multimodal_embeddings( - inputs_embeds, multimodal_embeds, is_multimodal); - return inputs_embeds; - } - - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); - } - - torch::Tensor pooler(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) { - auto h = hidden_states; - if (seleted_idxes.defined()) { - h = h.index_select(/*dim=*/0, seleted_idxes); - } - auto pooler_output = torch::nn::functional::normalize( - h, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); - return pooler_output; - } - - torch::Tensor logits(const torch::Tensor&, const torch::Tensor&) { - NOT_IMPLEMENTED(); - return torch::Tensor(); - } - - torch::Device device() const { return options_.device(); } - - const torch::TensorOptions& options() const { return options_; } - - void load_model(std::unique_ptr loader) { - for (const auto& state_dict : loader->get_state_dicts()) { - visual_->load_state_dict(state_dict->get_dict_with_prefix("visual.")); - } - // if (!model_args_.image_embedding_mode()) { - language_model_->load_model(std::move(loader)); - // } - } - - layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } - void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } - - layer::WordEmbedding get_word_embedding() { - return language_model_->get_word_embedding(); - } - - void set_word_embedding(layer::WordEmbedding& word_embedding) { - language_model_->set_word_embedding(word_embedding); - } - - private: - ModelArgs model_args_; - torch::TensorOptions options_; - - Qwen2_VisionTransformer visual_{nullptr}; - QWen2ForCausalLM language_model_{nullptr}; -}; -TORCH_MODULE(Qwen2_VLForEmbedding); - -template <> -class EmbeddingVLMImpl : public EmbeddingVLM { - public: - EmbeddingVLMImpl(Qwen2_VLForEmbedding model, - const torch::TensorOptions& options) - : model_(std::move(model)), options_(options) {} - - MMDict encode(const ModelInputParams& input_params) override { - return model_->get_multimodal_embeddings(input_params); - }; - torch::Tensor get_input_embeddings(const torch::Tensor& input_ids, - const ModelInputParams& input_params) { - return model_->get_input_embeddings(input_ids, input_params); - } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - return model_->forward(tokens, positions, kv_caches, parameters); - } - - torch::Tensor logits(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) override { - return model_->logits(hidden_states, seleted_idxes); - } - - torch::Tensor pooler(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) override { - return model_->pooler(hidden_states, seleted_idxes); - } - - void load_model(std::unique_ptr loader) override { - model_->load_model(std::move(loader)); - } - - torch::Device device() const override { return model_->device(); } - - const torch::TensorOptions& options() const override { - return model_->options(); - } - - virtual void prepare_expert_weight(int32_t layer_id, - const std::vector& expert_ids) { - return; - } - virtual void update_expert_weight(int32_t layer_id) { return; } - - // Delegate head/embedding accessors to underlying model implementation. - layer::LmHead get_lm_head() override { return model_->get_lm_head(); } - void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); } - layer::WordEmbedding get_word_embedding() override { - return model_->get_word_embedding(); - } - void set_word_embedding(layer::WordEmbedding& embedding) override { - model_->set_word_embedding(embedding); - } - - private: - Qwen2_VLForEmbedding model_; - torch::TensorOptions options_; -}; - -REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(qwen2_vl_embedding, - qwen2_vl, - Qwen2_VLForEmbedding); -} // namespace xllm \ No newline at end of file diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h index 9d664b406..ea4034eec 100644 --- a/xllm/models/vlm/qwen3_vl.h +++ b/xllm/models/vlm/qwen3_vl.h @@ -20,8 +20,8 @@ limitations under the License. #include "core/layers/qwen3_vision_layer.h" #include "models/llm/qwen3.h" #include "models/model_registry.h" -#include "processors/input_processor.h" -#include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen3_vl_image_processor.h" +#include "processors/qwen3_vl_input_processor.h" #include "qwen2_5_vl.h" namespace xllm { @@ -552,7 +552,6 @@ struct Qwen3_VLImageInputs { struct Qwen3_VLVideoInputs { torch::Tensor pixel_values_videos; torch::Tensor video_grid_thw; - torch::Tensor second_per_grid_ts; }; class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { @@ -585,17 +584,11 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { if (const auto& res = mm_data.get("video_grid_thw")) video_grid_thw = res.value(); - torch::Tensor second_per_grid_ts; - if (const auto& res = mm_data.get("second_per_grid_ts")) - second_per_grid_ts = res.value(); - if (pixel_values.defined() && image_grid_thw.defined()) image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; - if (pixel_values_videos.defined() && video_grid_thw.defined() && - second_per_grid_ts.defined()) - video_inputs = Qwen3_VLVideoInputs{ - pixel_values_videos, video_grid_thw, second_per_grid_ts}; + if (pixel_values_videos.defined() && video_grid_thw.defined()) + video_inputs = Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; } MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { @@ -677,6 +670,7 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { mm_data.get("embedding|deepstack_2").value()}; return deepstacks; } + torch::Tensor merge_multimodal_embeddings( torch::Tensor inputs_embeds, const torch::Tensor& multimodal_embeds, @@ -711,6 +705,11 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { return language_model_(tokens, positions, kv_caches, input_params); } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->pooler(hidden_states, seleted_idxes); + } + torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { return language_model_->logits(hidden_states, seleted_idxes); @@ -746,9 +745,9 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { }; TORCH_MODULE(Qwen3_VLForConditionalGeneration); -REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen2_5_VLInputProcessor); +REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen3_VLInputProcessor); REGISTER_CAUSAL_VLM_MODEL(qwen3_vl, Qwen3_VLForConditionalGeneration); -REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen2VLImageProcessor); +REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen3VLImageProcessor); REGISTER_MODEL_ARGS(qwen3_vl, [&] { // text config diff --git a/xllm/models/vlm/qwen3_vl_moe.h b/xllm/models/vlm/qwen3_vl_moe.h index f698cec24..523bd734f 100644 --- a/xllm/models/vlm/qwen3_vl_moe.h +++ b/xllm/models/vlm/qwen3_vl_moe.h @@ -20,8 +20,8 @@ limitations under the License. #include "core/layers/qwen3_vision_layer.h" #include "models/llm/qwen3_moe.h" #include "models/model_registry.h" -#include "processors/input_processor.h" #include "processors/qwen2_vl_image_processor.h" +#include "processors/qwen2_vl_input_processor.h" #include "qwen2_5_vl.h" #include "qwen3_vl.h" diff --git a/xllm/processors/CMakeLists.txt b/xllm/processors/CMakeLists.txt index b5a4e1832..347f1578d 100755 --- a/xllm/processors/CMakeLists.txt +++ b/xllm/processors/CMakeLists.txt @@ -23,16 +23,29 @@ cc_library( clip_image_processor.h minicpmv_image_processor.h qwen2_vl_image_processor.h + qwen3_vl_image_processor.h glm4v_image_processor.h pywarpper_image_processor.h input_processor.h + qwen2_vl_input_processor.h + qwen3_vl_input_processor.h + glm4v_input_processor.h + minicpmv_input_processor.h + clip_input_processor.h SRCS image_processor.cpp clip_image_processor.cpp minicpmv_image_processor.cpp qwen2_vl_image_processor.cpp + qwen3_vl_image_processor.cpp glm4v_image_processor.cpp pywarpper_image_processor.cpp + qwen2_vl_input_processor.cpp + qwen3_vl_input_processor.cpp + glm4v_input_processor.cpp + minicpmv_input_processor.cpp + clip_input_processor.cpp DEPS ${BASE_DEPS} + :request ) diff --git a/xllm/processors/clip_input_processor.cpp b/xllm/processors/clip_input_processor.cpp new file mode 100644 index 000000000..9e1e7de70 --- /dev/null +++ b/xllm/processors/clip_input_processor.cpp @@ -0,0 +1,103 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "processors/clip_input_processor.h" + +#include + +#include +#include + +namespace xllm { + +CLIPVLInputProcessor::CLIPVLInputProcessor(const ModelArgs& args) { + merge_size_ = args.mm_image_merge_size(); +} + +void CLIPVLInputProcessor::process(std::string& prompt, const MMData& mm_data) { + torch::Tensor image_grid_thw; + if (auto res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; + auto merge_length = merge_size_ * merge_size_; + int32_t total_image_token = 0; + if (image_grid_thw.defined()) { + auto count = image_grid_thw.sizes()[0]; + for (int32_t idx = 0; idx < count; ++idx) + total_image_token += + image_grid_thw[idx].prod().item() / merge_length; + } + int32_t total_video_token = 0; + if (video_grid_thw.defined()) { + auto count = video_grid_thw.sizes()[0]; + for (int32_t idx = 0; idx < count; ++idx) + total_video_token += + video_grid_thw[idx].prod().item() / merge_length; + } + size_t total_token_len = total_image_token * image_token_.size() + + total_video_token * video_token_.size(); + std::string data; + data.reserve(prompt.size() + total_token_len); + int32_t image_index = 0; + int32_t video_index = 0; + const torch::Tensor* grid_thw = nullptr; + const std::string* token = nullptr; + int32_t* index = nullptr; + size_t begin = 0; + auto pair = find_vision_token(prompt, begin); + while (pair.second != std::string::npos) { + data.append(prompt, begin, pair.second - begin); + if (pair.first == TokenType::IMAGE) { + grid_thw = &image_grid_thw; + token = &image_token_; + index = &image_index; + } else if (pair.first == TokenType::VIDEO) { + grid_thw = &video_grid_thw; + token = &video_token_; + index = &video_index; + } else { + assert(false); + } + auto token_num = + (*grid_thw)[(*index)].prod().item() / merge_length; + while (token_num--) data.append(*token); + ++(*index); + begin = pair.second + token->size(); + pair = find_vision_token(prompt, begin); + } + if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); + prompt = std::move(data); +} + +std::pair +CLIPVLInputProcessor::find_vision_token(const std::string& prompt, + size_t begin) { + auto img_pos = prompt.find(image_token_, begin); + auto vid_pos = prompt.find(video_token_, begin); + if (img_pos == std::string::npos && vid_pos == std::string::npos) + return {TokenType::INVALID, std::string::npos}; + else if (vid_pos == std::string::npos) + return {TokenType::IMAGE, img_pos}; + else if (img_pos == std::string::npos) + return {TokenType::VIDEO, vid_pos}; + else + return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) + : std::make_pair(TokenType::VIDEO, vid_pos); +} + +} // namespace xllm diff --git a/xllm/processors/clip_input_processor.h b/xllm/processors/clip_input_processor.h new file mode 100644 index 000000000..0794f29ba --- /dev/null +++ b/xllm/processors/clip_input_processor.h @@ -0,0 +1,49 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "core/framework/model/model_args.h" +#include "core/framework/request/mm_data.h" +#include "processors/input_processor.h" + +namespace xllm { + +class CLIPVLInputProcessor : public InputProcessor { + enum class TokenType { + INVALID, + IMAGE, + VIDEO, + }; + + public: + explicit CLIPVLInputProcessor(const ModelArgs& args); + + void process(std::string& prompt, const MMData& mm_data) override; + + private: + std::pair find_vision_token(const std::string& prompt, + size_t begin); + + const std::string image_token_ = "<|image_pad|>"; + const std::string video_token_ = "<|video_pad|>"; + int32_t merge_size_ = 0; +}; + +} // namespace xllm diff --git a/xllm/processors/glm4v_input_processor.cpp b/xllm/processors/glm4v_input_processor.cpp new file mode 100644 index 000000000..6ede384ec --- /dev/null +++ b/xllm/processors/glm4v_input_processor.cpp @@ -0,0 +1,198 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "processors/glm4v_input_processor.h" + +#include +#include + +#include +#include +#include + +namespace xllm { + +GLM4VInputProcessor::GLM4VInputProcessor(const ModelArgs& args) { + merge_size_ = args.mm_image_merge_size(); + image_start_token_id_ = args.image_start_token_id(); + image_end_token_id_ = args.image_end_token_id(); + video_start_token_id_ = args.video_start_token_id(); + video_end_token_id_ = args.video_end_token_id(); + image_token_id_ = args.image_token_id(); +} + +void GLM4VInputProcessor::process(std::string& prompt, const MMData& mm_data) { + torch::Tensor image_grid_thw; + if (auto res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; + + std::vector video_metadata; + mm_data.get_metadata(MMType::VIDEO, video_metadata); + + if (video_metadata.size() > 0) { + CHECK(video_metadata.size() == + static_cast(video_grid_thw.sizes()[0])); + } + + auto merge_length = merge_size_ * merge_size_; + int32_t total_image_token = 0; + + if (image_grid_thw.defined()) { + auto count = image_grid_thw.sizes()[0]; + for (int32_t idx = 0; idx < count; ++idx) + total_image_token += + image_grid_thw[idx].prod().item() / merge_length; + } + + int32_t total_video_token = 0; + if (video_grid_thw.defined()) { + auto count = video_grid_thw.sizes()[0]; + for (int32_t idx = 0; idx < count; ++idx) + total_video_token += video_grid_thw[idx].prod().item() / + merge_length / + video_grid_thw[idx][0].item(); + } + + size_t total_token_len = total_image_token * image_token_.size() + + total_video_token * image_token_.size(); + std::string data; + data.reserve(prompt.size() + total_token_len); + + int32_t image_index = 0; + int32_t video_index = 0; + + size_t begin = 0; + auto pair = find_vision_token(prompt, begin); + + while (pair.second != std::string::npos) { + data.append(prompt, begin, pair.second - begin); + + if (pair.first == TokenType::IMAGE) { + auto token_num = + image_grid_thw[image_index].prod().item() / merge_length; + while (token_num--) data.append(image_token_); + + image_index++; + begin = pair.second + image_token_.size(); + } else if (pair.first == TokenType::VIDEO) { + auto num_frames = video_grid_thw[video_index][0].item(); + auto timestamps = video_metadata[video_index].timestamps; + CHECK(!timestamps.empty()); + + auto selected = build_timestamps(timestamps, num_frames); + auto token_num = video_grid_thw[video_index].prod().item() / + merge_length / num_frames; + + for (size_t idx = 0; idx < num_frames; ++idx) { + data.append(begin_of_image_token_); + + auto num = token_num; + while (num--) data.append(image_token_); + + data.append(end_of_image_token_); + data.append(format_timestamp_str(selected[idx])); + } + + video_index++; + begin = pair.second + video_token_.size(); + } else { + assert(false); + } + + pair = find_vision_token(prompt, begin); + } + + if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); + + prompt = std::move(data); +} + +void GLM4VInputProcessor::find_mm_spans(const std::vector& prompt, + MMData& mm_data) { + size_t tokens_num = prompt.size(); + uint32_t global_mm_index = 0; + uint32_t offset = 0; + uint32_t length = 0; + bool is_video = false; + auto& mm_items = mm_data.items(); + for (size_t idx = 0; idx < tokens_num; ++idx) { + auto token = prompt[idx]; + if (token == video_start_token_id_) { + is_video = true; + } else if (token == video_end_token_id_) { + is_video = false; + } + if (is_video) continue; + if (token == image_start_token_id_) { + offset = idx + 1; + } + if (token == image_token_id_) { + length++; + } else if (token == image_end_token_id_) { + auto& item = mm_items[global_mm_index++]; + item.mutable_state().mutable_token_pos() = {offset, length}; + length = 0; + } + } +} + +std::pair +GLM4VInputProcessor::find_vision_token(const std::string& prompt, + size_t begin) { + auto img_pos = prompt.find(image_token_, begin); + auto vid_pos = prompt.find(video_token_, begin); + + if (img_pos == std::string::npos && vid_pos == std::string::npos) + return {TokenType::INVALID, std::string::npos}; + else if (vid_pos == std::string::npos) + return {TokenType::IMAGE, img_pos}; + else if (img_pos == std::string::npos) + return {TokenType::VIDEO, vid_pos}; + else + return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) + : std::make_pair(TokenType::VIDEO, vid_pos); +} + +std::vector GLM4VInputProcessor::build_timestamps( + const std::vector& timestamps, + size_t num_frames) { + std::vector vec; + vec.reserve(num_frames); + + for (size_t i = 0; i < timestamps.size(); i += 2) { + vec.push_back(timestamps[i]); + if (vec.size() == num_frames) break; + } + + while (vec.size() < num_frames) { + vec.push_back(vec.back()); + } + + return vec; +} + +std::string GLM4VInputProcessor::format_timestamp_str(double timestamp) { + char buffer[32]; + snprintf(buffer, sizeof(buffer), "%.1f seconds", timestamp); + return buffer; +} + +} // namespace xllm diff --git a/xllm/processors/glm4v_input_processor.h b/xllm/processors/glm4v_input_processor.h new file mode 100644 index 000000000..7832bba47 --- /dev/null +++ b/xllm/processors/glm4v_input_processor.h @@ -0,0 +1,62 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include "core/framework/model/model_args.h" +#include "core/framework/request/mm_data.h" +#include "processors/input_processor.h" + +namespace xllm { + +class GLM4VInputProcessor : public InputProcessor { + enum class TokenType { + INVALID, + IMAGE, + VIDEO, + }; + + public: + explicit GLM4VInputProcessor(const ModelArgs& args); + + void process(std::string& prompt, const MMData& mm_data) override; + void find_mm_spans(const std::vector& prompt, MMData& mm_data) override; + + private: + std::pair find_vision_token(const std::string& prompt, + size_t begin); + std::vector build_timestamps(const std::vector& timestamps, + size_t num_frames); + std::string format_timestamp_str(double timestamp); + + const std::string image_token_ = "<|image|>"; + const std::string video_token_ = "<|video|>"; + const std::string begin_of_image_token_ = "<|begin_of_image|>"; + const std::string end_of_image_token_ = "<|end_of_image|>"; + + int32_t image_start_token_id_; + int32_t image_end_token_id_; + int32_t video_start_token_id_; + int32_t video_end_token_id_; + int32_t image_token_id_; + int32_t merge_size_ = 0; +}; + +} // namespace xllm diff --git a/xllm/processors/minicpmv_input_processor.cpp b/xllm/processors/minicpmv_input_processor.cpp new file mode 100644 index 000000000..84d302d34 --- /dev/null +++ b/xllm/processors/minicpmv_input_processor.cpp @@ -0,0 +1,178 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "processors/minicpmv_input_processor.h" + +#include +#include + +#include +#include + +#include "processors/minicpmv_image_processor.h" + +namespace xllm { + +MiniCPMInputProcessor::MiniCPMInputProcessor(const ModelArgs& args) { + image_feature_size_ = args.mm_image_feature_size(); + max_slice_nums_ = args.vision_max_slice_nums(); + slice_mode_ = args.mm_slice_mode(); + use_image_id_ = args.mm_use_image_id(); + scale_resolution_ = args.mm_scale_resolution(); +} + +void MiniCPMInputProcessor::process(std::string& prompt, + const MMData& mm_data) { + std::vector image_sizes; + mm_data.get("image_sizes", image_sizes); + + const std::regex pattern(R"(\([\s\S]*?\))"); + + std::sregex_iterator image_tag_begin(prompt.begin(), prompt.end(), pattern); + std::sregex_iterator image_tag_end; + + if (image_tag_begin == image_tag_end) { + return; + } + + std::vector> image_size_list; + image_size_list.reserve(image_sizes.size()); + for (auto& image_size : image_sizes) { + if (image_size.dim() != 1 || image_size.size(0) != 2) { + const auto& sizes = image_size.sizes(); + LOG(FATAL) << "image_size must be a 1D tensor with 2 " + "elements representing height and width;" + "now sizes: " + << sizes; + } + image_size_list.emplace_back(std::make_pair(image_size[0].item(), + image_size[1].item())); + } + + std::vector text_chunks; + size_t last_pos = 0; + + for (auto it = image_tag_begin; it != image_tag_end; ++it) { + auto match = *it; + text_chunks.push_back(prompt.substr(last_pos, match.position() - last_pos)); + last_pos = match.position() + match.length(); + } + + text_chunks.push_back(prompt.substr(last_pos)); + + std::string new_prompt; + for (int32_t i = 0; i < static_cast(image_size_list.size()); ++i) { + new_prompt += text_chunks[i]; + new_prompt += get_slice_image_placeholder(image_size_list[i], i); + } + + new_prompt += text_chunks.back(); + prompt = new_prompt; +} + +void MiniCPMInputProcessor::find_mm_spans(const std::vector& prompt, + MMData& mm_data) { + uint32_t global_mm_index = 0; + uint32_t offset = 0; + uint32_t length = 0; + auto& mm_items = mm_data.items(); + auto start = prompt.begin(); + while (true) { + auto image_start_it = std::find(start, prompt.end(), im_start_id_); + auto image_end_it = std::find(start, prompt.end(), im_end_id_); + if (image_start_it == prompt.end()) { + break; + } + offset = std::distance(prompt.begin(), image_start_it); + length = std::distance(image_start_it + 1, image_end_it); + auto& item = mm_items[global_mm_index++]; + item.mutable_state().mutable_token_pos() = {offset + 1, length}; + start = std::next(image_end_it); + } +} + +std::string MiniCPMInputProcessor::get_image_id_placeholder(int32_t idx) const { + return im_id_start_ + std::to_string(idx) + im_id_end_; +} + +std::string MiniCPMInputProcessor::get_grid_placeholder( + const std::pair& grid) const { + if (grid.first == 0 || grid.second == 0) { + return ""; + } + + std::string slice_placeholder = slice_start_token_; + + for (int32_t i = 0; i < image_feature_size_; ++i) { + slice_placeholder += unk_token_; + } + + slice_placeholder += slice_end_token_; + + std::string grid_placeholder; + + for (int32_t i = 0; i < grid.second; ++i) { + for (int32_t j = 0; j < grid.first; ++j) { + grid_placeholder += slice_placeholder; + } + if (i < grid.second - 1) { + grid_placeholder += "\n"; + } + } + + return grid_placeholder; +} + +std::string MiniCPMInputProcessor::get_slice_image_placeholder( + const std::pair& image_size, + int32_t image_idx, + int32_t max_slice_nums, + std::optional use_image_id_opt) const { + if (max_slice_nums < 0) { + max_slice_nums = max_slice_nums_; + } + + bool use_image_id = + use_image_id_opt.has_value() ? use_image_id_opt.value() : use_image_id_; + + assert(max_slice_nums > 0); + + auto grid = MiniCPMVImageProcessor::get_sliced_grid( + image_size, max_slice_nums, scale_resolution_); + + std::string image_placeholder = im_start_token_; + + for (int i = 0; i < image_feature_size_; ++i) { + image_placeholder += unk_token_; + } + + image_placeholder += im_end_token_; + + std::string final_placeholder; + + if (use_image_id) { + final_placeholder = get_image_id_placeholder(image_idx) + image_placeholder; + } else { + final_placeholder = image_placeholder; + } + + if (slice_mode_) { + final_placeholder += get_grid_placeholder(grid); + } + + return final_placeholder; +} + +} // namespace xllm diff --git a/xllm/processors/minicpmv_input_processor.h b/xllm/processors/minicpmv_input_processor.h new file mode 100644 index 000000000..a42f59400 --- /dev/null +++ b/xllm/processors/minicpmv_input_processor.h @@ -0,0 +1,63 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include "core/framework/model/model_args.h" +#include "core/framework/request/mm_data.h" +#include "processors/input_processor.h" + +namespace xllm { + +class MiniCPMInputProcessor : public InputProcessor { + public: + explicit MiniCPMInputProcessor(const ModelArgs& args); + + void process(std::string& prompt, const MMData& mm_data) override; + void find_mm_spans(const std::vector& prompt, MMData& mm_data) override; + + private: + std::string get_image_id_placeholder(int idx) const; + std::string get_grid_placeholder(const std::pair& grid) const; + std::string get_slice_image_placeholder( + const std::pair& image_size, + int image_idx = 0, + int max_slice_nums = -1, + std::optional use_image_id_opt = std::nullopt) const; + + const std::string im_start_token_ = ""; + const std::string im_end_token_ = ""; + const std::string slice_start_token_ = ""; + const std::string slice_end_token_ = ""; + const std::string unk_token_ = ""; + const std::string im_id_start_ = ""; + const std::string im_id_end_ = ""; + + const int32_t im_start_id_ = 151659; + const int32_t im_end_id_ = 151658; + + bool slice_mode_; + bool use_image_id_; + int32_t max_slice_nums_; + int32_t image_feature_size_; + int32_t scale_resolution_; +}; + +} // namespace xllm diff --git a/xllm/processors/qwen2_vl_image_processor.cpp b/xllm/processors/qwen2_vl_image_processor.cpp index 3c95ede12..d6442e1d8 100644 --- a/xllm/processors/qwen2_vl_image_processor.cpp +++ b/xllm/processors/qwen2_vl_image_processor.cpp @@ -17,55 +17,64 @@ limitations under the License. namespace xllm { -namespace { - -using Size = std::pair; -std::optional smart_resize(int height, - int width, - int factor = 28, - int min_pixels = 56 * 56, - int max_pixels = 14 * 14 * 4 * 1280) { +std::optional +Qwen2VLImageProcessor::smart_resize_image(int32_t height, + int32_t width, + int32_t factor = 28, + int32_t min_pixels = 56 * 56, + int32_t max_pixels = 14 * 14 * 4 * + 1280) const { if (static_cast(std::max(height, width)) / std::min(height, width) > - 200) { + 200.0) { LOG(ERROR) << "Absolute aspect ratio must be smaller than 200"; return std::nullopt; } - int h_bar = - static_cast(std::round(height / static_cast(factor))) * + int32_t h_bar = + static_cast(std::round(height / static_cast(factor))) * factor; - int w_bar = - static_cast(std::round(width / static_cast(factor))) * + int32_t w_bar = + static_cast(std::round(width / static_cast(factor))) * factor; if (h_bar * w_bar > max_pixels) { double beta = std::sqrt((height * width) / static_cast(max_pixels)); - h_bar = static_cast( + h_bar = static_cast( std::floor(height / beta / static_cast(factor))) * factor; - w_bar = static_cast( + w_bar = static_cast( std::floor(width / beta / static_cast(factor))) * factor; } else if (h_bar * w_bar < min_pixels) { double beta = std::sqrt(min_pixels / static_cast(height * width)); - h_bar = static_cast( + h_bar = static_cast( std::ceil(height * beta / static_cast(factor))) * factor; - w_bar = static_cast( + w_bar = static_cast( std::ceil(width * beta / static_cast(factor))) * factor; } return std::make_pair(h_bar, w_bar); } -} // namespace + +std::optional +Qwen2VLImageProcessor::smart_resize_video(int32_t num_frames, + int32_t height, + int32_t width, + int32_t temporal_factor, + int32_t factor, + int32_t min_pixels, + int32_t max_pixels) const { + return smart_resize_image(height, width, factor, min_pixels, max_pixels); +} torch::Tensor Qwen2VLImageProcessor::sample_frames( const VideoMetadata& metadata, - int temporal_patch_size, - int min_frames, - int max_frames, - int num_frames, + int32_t temporal_patch_size, + int32_t min_frames, + int32_t max_frames, + int32_t num_frames, double set_fps) { if (set_fps > 0.0 && num_frames > 0) { LOG(FATAL) << "num_frames and fps are mutually exclusive arguments, please " @@ -74,13 +83,13 @@ torch::Tensor Qwen2VLImageProcessor::sample_frames( double fps = set_fps; - int total_num_frames = metadata.total_num_frames; + int32_t total_num_frames = metadata.total_num_frames; if (num_frames > 0) { double double_num_frames = std::round(static_cast(num_frames) / temporal_patch_size) * temporal_patch_size; - num_frames = static_cast(double_num_frames); + num_frames = static_cast(double_num_frames); } else if (fps > 0.0) { if (metadata.fps <= 0.0) { LOG(FATAL) @@ -88,9 +97,9 @@ torch::Tensor Qwen2VLImageProcessor::sample_frames( "was provided which is required when sampling with `fps`. "; } - max_frames = - (std::min(max_frames, total_num_frames) / temporal_patch_size) * - temporal_patch_size; + max_frames = (std::min(max_frames, static_cast(total_num_frames)) / + temporal_patch_size) * + temporal_patch_size; double double_num_frames = static_cast(total_num_frames) / metadata.fps * fps; double_num_frames = std::min( @@ -100,7 +109,7 @@ torch::Tensor Qwen2VLImageProcessor::sample_frames( double_num_frames = std::floor(double_num_frames / temporal_patch_size) * temporal_patch_size; - num_frames = static_cast(double_num_frames); + num_frames = static_cast(double_num_frames); } if (num_frames > total_num_frames) { @@ -110,19 +119,19 @@ torch::Tensor Qwen2VLImageProcessor::sample_frames( } if (num_frames > 0) { - std::vector indices; + std::vector indices; indices.reserve(num_frames); - for (int i = 0; i < num_frames; ++i) { - int64_t k = static_cast( - (static_cast(i) * total_num_frames) / num_frames); - if (k >= total_num_frames) k = total_num_frames - 1; + for (int32_t i = 0; i < num_frames; ++i) { + int32_t k = i * total_num_frames / num_frames; + if (k >= total_num_frames) { + k = total_num_frames - 1; + } indices.push_back(k); } - return torch::tensor(indices, torch::TensorOptions().dtype(torch::kLong)); + return torch::tensor(indices, torch::TensorOptions().dtype(torch::kInt32)); } else { - return torch::arange(0, - static_cast(total_num_frames), - torch::TensorOptions().dtype(torch::kLong)); + return torch::arange( + 0, total_num_frames, torch::TensorOptions().dtype(torch::kInt32)); } } @@ -263,11 +272,11 @@ bool Qwen2VLImageProcessor::process_image(torch::Tensor image, // resize if (do_resize_) { - auto size = smart_resize(resized_height, - resized_width, - patch_size_ * merge_size_, - min_pixels_, - max_pixels_); + auto size = smart_resize_image(resized_height, + resized_width, + patch_size_ * merge_size_, + min_pixels_, + max_pixels_); // size_["shortest_edge"], // size_["longest_edge"]); if (!size) { @@ -375,19 +384,18 @@ bool Qwen2VLImageProcessor::process_video(torch::Tensor origin_video, /*num_frames=*/-1, /*set_fps=*/2.0); } else { - indices = torch::arange(0, - static_cast(origin_video.size(0)), - torch::TensorOptions().dtype(torch::kLong)); + indices = torch::arange( + 0, origin_video.size(0), torch::TensorOptions().dtype(torch::kInt32)); } auto video = origin_video.index_select(/*dim=*/0, indices); - int64_t sampled_total_frames = video.size(0); + int32_t sampled_total_frames = video.size(0); metadata.frame_indices = indices; metadata.timestamps.clear(); metadata.timestamps.reserve(static_cast(sampled_total_frames)); double fps_for_ts = (metadata.fps > 0.0) ? metadata.fps : 24.0; - for (int64_t i = 0; i < sampled_total_frames; ++i) { - int64_t frame_idx = metadata.frame_indices[i].item(); + for (int32_t i = 0; i < sampled_total_frames; ++i) { + int32_t frame_idx = metadata.frame_indices[i].item(); metadata.timestamps.push_back(static_cast(frame_idx) / fps_for_ts); } @@ -405,11 +413,13 @@ bool Qwen2VLImageProcessor::process_video(torch::Tensor origin_video, auto resized_width = shape[3]; if (do_resize_) { - auto size = smart_resize(resized_height, - resized_width, - patch_size_ * merge_size_, - size_["shortest_edge"], - size_["longest_edge"]); + auto size = smart_resize_video(static_cast(time_len), + resized_height, + resized_width, + temporal_patch_size_, + patch_size_ * merge_size_, + min_pixels_, + max_pixels_); if (!size) { return false; } diff --git a/xllm/processors/qwen2_vl_image_processor.h b/xllm/processors/qwen2_vl_image_processor.h index cc60f4a1a..12325dc83 100644 --- a/xllm/processors/qwen2_vl_image_processor.h +++ b/xllm/processors/qwen2_vl_image_processor.h @@ -30,6 +30,21 @@ class Qwen2VLImageProcessor : public ImageProcessor { bool process(const MMInput& mm_inputs, MMData& mm_datas) override; + using Size = std::pair; + virtual std::optional smart_resize_image(int32_t height, + int32_t width, + int32_t factor, + int32_t min_pixels, + int32_t max_pixels) const; + + virtual std::optional smart_resize_video(int32_t num_frames, + int32_t height, + int32_t width, + int32_t temporal_factor, + int32_t factor, + int32_t min_pixels, + int32_t max_pixels) const; + private: bool process_images(std::vector images, MMData& mm_datas); bool process_image(torch::Tensor image, @@ -47,12 +62,12 @@ class Qwen2VLImageProcessor : public ImageProcessor { VideoMetadata& metadata, torch::Tensor& pixel_values, torch::Tensor& thw); - torch::Tensor sample_frames(const VideoMetadata& metadata, - int temporal_patch_size, - int min_frames, - int max_frames, - int num_frames = -1, - double set_fps = -1.0); + virtual torch::Tensor sample_frames(const VideoMetadata& metadata, + int32_t temporal_patch_size, + int32_t min_frames, + int32_t max_frames, + int32_t num_frames = -1, + double set_fps = -1.0); private: bool do_convert_rgb_ = true; @@ -64,22 +79,22 @@ class Qwen2VLImageProcessor : public ImageProcessor { std::vector image_mean_; std::vector image_std_; - int max_pixels_ = 12845056; - int min_pixels_ = 3136; + int32_t max_pixels_ = 12845056; + int32_t min_pixels_ = 3136; - int merge_size_ = 2; - int patch_size_ = 14; + int32_t merge_size_ = 2; + int32_t patch_size_ = 14; - int resample_ = 3; + int32_t resample_ = 3; double rescale_factor_ = 0.00392156862745098; std::unordered_map size_; - int temporal_patch_size_ = 2; + int32_t temporal_patch_size_ = 2; bool do_sample_frame_ = true; - int min_frames_ = 4; - int max_frames_ = 768; + int32_t min_frames_ = 4; + int32_t max_frames_ = 768; }; } // namespace xllm diff --git a/xllm/processors/qwen2_vl_input_processor.cpp b/xllm/processors/qwen2_vl_input_processor.cpp new file mode 100644 index 000000000..e43c152c3 --- /dev/null +++ b/xllm/processors/qwen2_vl_input_processor.cpp @@ -0,0 +1,151 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "processors/qwen2_vl_input_processor.h" + +#include + +#include +#include + +namespace xllm { + +Qwen2_5_VLInputProcessor::Qwen2_5_VLInputProcessor(const ModelArgs& args) { + merge_size_ = args.mm_image_merge_size(); + vision_start_token_id_ = args.vision_start_token_id(); + vision_end_token_id_ = args.vision_end_token_id(); + image_token_id_ = args.image_token_id(); + video_token_id_ = args.video_token_id(); +} + +void Qwen2_5_VLInputProcessor::process(std::string& prompt, + const MMData& mm_data) { + torch::Tensor image_grid_thw; + if (auto res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; + + auto merge_length = merge_size_ * merge_size_; + int32_t total_image_token = 0; + if (image_grid_thw.defined()) { + auto count = image_grid_thw.sizes()[0]; + for (int32_t idx = 0; idx < count; ++idx) + total_image_token += + image_grid_thw[idx].prod().item() / merge_length; + } + + int32_t total_video_token = 0; + if (video_grid_thw.defined()) { + auto count = video_grid_thw.sizes()[0]; + for (int32_t idx = 0; idx < count; ++idx) + total_video_token += + video_grid_thw[idx].prod().item() / merge_length; + } + + size_t total_token_len = total_image_token * image_token_.size() + + total_video_token * video_token_.size(); + std::string data; + data.reserve(prompt.size() + total_token_len); + + int32_t image_index = 0; + int32_t video_index = 0; + + const torch::Tensor* grid_thw = nullptr; + const std::string* token = nullptr; + int32_t* index = nullptr; + + size_t begin = 0; + auto pair = find_vision_token(prompt, begin); + + while (pair.second != std::string::npos) { + data.append(prompt, begin, pair.second - begin); + + if (pair.first == TokenType::IMAGE) { + grid_thw = &image_grid_thw; + token = &image_token_; + index = &image_index; + } else if (pair.first == TokenType::VIDEO) { + grid_thw = &video_grid_thw; + token = &video_token_; + index = &video_index; + } else { + assert(false); + } + + auto token_num = + (*grid_thw)[(*index)].prod().item() / merge_length; + while (token_num--) data.append(*token); + + ++(*index); + begin = pair.second + token->size(); + pair = find_vision_token(prompt, begin); + } + + if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); + + prompt = std::move(data); +} + +void Qwen2_5_VLInputProcessor::find_mm_spans(const std::vector& prompt, + MMData& mm_data) { + auto start = prompt.begin(); + uint32_t global_mm_index = 0; + uint32_t offset = 0; + uint32_t length = 0; + auto& mm_items = mm_data.items(); + while (true) { + auto vision_start_it = + std::find(start, prompt.end(), vision_start_token_id_); + auto vision_end_it = std::find(start, prompt.end(), vision_end_token_id_); + if (vision_start_it == prompt.end()) { + break; + } + offset = std::distance(prompt.begin(), vision_start_it); + length = std::distance(vision_start_it + 1, vision_end_it); + + auto& item = mm_items[global_mm_index]; + if (*(vision_start_it + 1) == image_token_id_) { + item.mutable_state().mutable_token_pos() = {offset + 1, length}; + } else if (*(vision_start_it + 1) == video_token_id_) { + item.mutable_state().mutable_token_pos() = {offset + 1, length}; + } + global_mm_index++; + start = std::next(vision_end_it); + } +} + +std::pair +Qwen2_5_VLInputProcessor::find_vision_token(const std::string& prompt, + size_t begin) { + auto img_pos = prompt.find(image_token_, begin); + auto vid_pos = prompt.find(video_token_, begin); + + if (img_pos == std::string::npos && vid_pos == std::string::npos) + return {TokenType::INVALID, std::string::npos}; + else if (vid_pos == std::string::npos) + return {TokenType::IMAGE, img_pos}; + else if (img_pos == std::string::npos) + return {TokenType::VIDEO, vid_pos}; + else + return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) + : std::make_pair(TokenType::VIDEO, vid_pos); +} + +} // namespace xllm diff --git a/xllm/processors/qwen2_vl_input_processor.h b/xllm/processors/qwen2_vl_input_processor.h new file mode 100644 index 000000000..dbffd329c --- /dev/null +++ b/xllm/processors/qwen2_vl_input_processor.h @@ -0,0 +1,54 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "core/framework/model/model_args.h" +#include "core/framework/request/mm_data.h" +#include "processors/input_processor.h" + +namespace xllm { + +class Qwen2_5_VLInputProcessor : public InputProcessor { + enum class TokenType { + INVALID, + IMAGE, + VIDEO, + }; + + public: + explicit Qwen2_5_VLInputProcessor(const ModelArgs& args); + + void process(std::string& prompt, const MMData& mm_data) override; + void find_mm_spans(const std::vector& prompt, MMData& mm_data) override; + + private: + std::pair find_vision_token(const std::string& prompt, + size_t begin); + + const std::string image_token_ = "<|image_pad|>"; + const std::string video_token_ = "<|video_pad|>"; + int32_t vision_start_token_id_; + int32_t vision_end_token_id_; + int32_t image_token_id_; + int32_t video_token_id_; + int32_t merge_size_ = 0; +}; + +} // namespace xllm diff --git a/xllm/processors/qwen3_vl_image_processor.cpp b/xllm/processors/qwen3_vl_image_processor.cpp new file mode 100644 index 000000000..2ad2d4c2a --- /dev/null +++ b/xllm/processors/qwen3_vl_image_processor.cpp @@ -0,0 +1,120 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "qwen3_vl_image_processor.h" + +namespace xllm { + +std::optional +Qwen3VLImageProcessor::smart_resize_video(int32_t num_frames, + int32_t height, + int32_t width, + int32_t temporal_factor, + int32_t factor, + int32_t min_pixels, + int32_t max_pixels) const { + if (height < factor || width < factor) { + LOG(ERROR) << "height:" << height << " or width:" << width + << " must be larger than factor:" << factor; + return std::nullopt; + } + if (static_cast(std::max(height, width)) / std::min(height, width) > + 200.0) { + LOG(ERROR) << "Absolute aspect ratio must be smaller than 200"; + return std::nullopt; + } + + int32_t h_bar = + static_cast(std::round(height / static_cast(factor))) * + factor; + int32_t w_bar = + static_cast(std::round(width / static_cast(factor))) * + factor; + int32_t t_bar = static_cast(std::ceil( + num_frames / static_cast(temporal_factor))) * + temporal_factor; + + const double thw_bar = static_cast(t_bar) * + static_cast(h_bar) * + static_cast(w_bar); + + if (thw_bar > static_cast(max_pixels)) { + const double beta = + std::sqrt((static_cast(num_frames) * height * width) / + static_cast(max_pixels)); + int32_t h_new = static_cast(std::floor( + height / beta / static_cast(factor))) * + factor; + int32_t w_new = static_cast(std::floor( + width / beta / static_cast(factor))) * + factor; + h_bar = std::max(factor, h_new); + w_bar = std::max(factor, w_new); + } else if (thw_bar < static_cast(min_pixels)) { + const double beta = + std::sqrt(static_cast(min_pixels) / + (static_cast(num_frames) * height * width)); + h_bar = static_cast( + std::ceil(height * beta / static_cast(factor))) * + factor; + w_bar = static_cast( + std::ceil(width * beta / static_cast(factor))) * + factor; + } + + return std::make_pair(h_bar, w_bar); +} + +torch::Tensor Qwen3VLImageProcessor::sample_frames( + const VideoMetadata& metadata, + int32_t /*temporal_patch_size*/, + int32_t min_frames, + int32_t max_frames, + int32_t num_frames, + double set_fps) { + if (set_fps > 0.0 && num_frames > 0) { + LOG(FATAL) << "num_frames and fps are mutually exclusive arguments, please " + "use only one!"; + } + + double fps = set_fps; + int32_t total_num_frames = metadata.total_num_frames; + + if (num_frames <= 0 && fps > 0.0) { + if (metadata.fps <= 0.0) { + LOG(FATAL) + << "Asked to sample `fps` frames per second but no video metadata " + "was provided which is required when sampling with `fps`. "; + } + num_frames = static_cast(static_cast(total_num_frames) / + metadata.fps * fps); + num_frames = std::min(std::max(num_frames, min_frames), + std::min(max_frames, total_num_frames)); + } + + if (num_frames <= 0) { + num_frames = std::min(std::max(total_num_frames, min_frames), max_frames); + } + + auto lin = torch::linspace(0.0, + total_num_frames - 1, + num_frames, + torch::TensorOptions().dtype(torch::kFloat32)); + auto idx = torch::round(lin).to(torch::kInt32); + idx = torch::clamp(idx, 0, total_num_frames - 1); + return idx; +} + +} // namespace xllm diff --git a/xllm/processors/qwen3_vl_image_processor.h b/xllm/processors/qwen3_vl_image_processor.h new file mode 100644 index 000000000..79f4db6be --- /dev/null +++ b/xllm/processors/qwen3_vl_image_processor.h @@ -0,0 +1,47 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "qwen2_vl_image_processor.h" + +namespace xllm { + +class Qwen3VLImageProcessor : public Qwen2VLImageProcessor { + public: + explicit Qwen3VLImageProcessor(const ModelArgs& args) + : Qwen2VLImageProcessor(args) {} + + std::optional smart_resize_video(int32_t num_frames, + int32_t height, + int32_t width, + int32_t temporal_factor, + int32_t factor, + int32_t min_pixels, + int32_t max_pixels) const override; + + torch::Tensor sample_frames(const VideoMetadata& metadata, + int32_t temporal_patch_size, + int32_t min_frames, + int32_t max_frames, + int32_t num_frames, + double set_fps) override; +}; + +} // namespace xllm diff --git a/xllm/processors/qwen3_vl_input_processor.cpp b/xllm/processors/qwen3_vl_input_processor.cpp new file mode 100644 index 000000000..130e938b8 --- /dev/null +++ b/xllm/processors/qwen3_vl_input_processor.cpp @@ -0,0 +1,265 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "processors/qwen3_vl_input_processor.h" + +#include + +#include +#include + +namespace xllm { + +Qwen3_VLInputProcessor::Qwen3_VLInputProcessor(const ModelArgs& args) { + merge_size_ = args.mm_image_merge_size(); + vision_start_token_id_ = args.vision_start_token_id(); + vision_end_token_id_ = args.vision_end_token_id(); + image_token_id_ = args.image_token_id(); + video_token_id_ = args.video_token_id(); + temporal_patch_size_ = args.mm_temporal_patch_size(); +} + +void Qwen3_VLInputProcessor::process(std::string& prompt, + const MMData& mm_data) { + torch::Tensor image_grid_thw; + if (auto res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; + + std::vector video_metadata; + mm_data.get_metadata(MMType::VIDEO, video_metadata); + if (video_grid_thw.defined()) { + CHECK(video_metadata.size() == static_cast(video_grid_thw.size(0))); + } + + const int32_t merge_length = merge_size_ * merge_size_; + + int32_t total_image_token = 0; + if (image_grid_thw.defined()) { + int32_t count = image_grid_thw.size(0); + for (int32_t idx = 0; idx < count; ++idx) { + total_image_token += + image_grid_thw[idx].prod().item() / merge_length; + } + } + + int32_t total_video_token = 0; + if (video_grid_thw.defined()) { + int32_t count = video_grid_thw.size(0); + for (int32_t idx = 0; idx < count; ++idx) { + total_video_token += + video_grid_thw[idx].prod().item() / merge_length; + } + } + + size_t total_token_len = total_image_token * image_token_.size() + + total_video_token * video_token_.size(); + std::string data; + data.reserve(prompt.size() + total_token_len); + + int32_t image_index = 0; + int32_t video_index = 0; + + size_t begin = 0; + auto pair = find_vision_token(prompt, begin); + + while (pair.second != std::string::npos) { + if (pair.first == TokenType::IMAGE) { + data.append(prompt, begin, pair.second - begin); + + auto token_num = + image_grid_thw[image_index].prod().item() / merge_length; + while (token_num--) { + data.append(image_token_); + } + + ++image_index; + begin = pair.second + image_token_.size(); + + } else if (pair.first == TokenType::VIDEO) { + const size_t pos = pair.second; + const size_t vs_len = vision_start_token_.size(); + const size_t ve_len = vision_end_token_.size(); + const size_t vt_len = video_token_.size(); + + size_t replace_begin = pos; + size_t replace_end = pos + vt_len; + + if (pos >= vs_len && + prompt.compare(pos - vs_len, vs_len, vision_start_token_) == 0 && + prompt.compare(pos + vt_len, ve_len, vision_end_token_) == 0) { + replace_begin = pos - vs_len; + replace_end = pos + vt_len + ve_len; + } + + data.append(prompt, begin, replace_begin - begin); + + const int32_t num_frames = video_grid_thw[video_index][0].item(); + const int32_t token_num = video_grid_thw[video_index][1].item() * + video_grid_thw[video_index][2].item() / + merge_length; + + const auto& timestamps = video_metadata[video_index].timestamps; + CHECK(!timestamps.empty()); + + auto selected = build_timestamps( + timestamps, static_cast(num_frames), temporal_patch_size_); + + for (int32_t idx = 0; idx < num_frames; ++idx) { + data.append(format_timestamp_str(selected[idx])); + data.append(vision_start_token_); + int32_t num = token_num; + + while (num--) { + data.append(video_token_); + } + data.append(vision_end_token_); + } + + ++video_index; + begin = replace_end; + } else { + assert(false); + } + + pair = find_vision_token(prompt, begin); + } + + if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); + prompt = std::move(data); +} + +void Qwen3_VLInputProcessor::find_mm_spans(const std::vector& prompt, + MMData& mm_data) { + auto start = prompt.begin(); + uint32_t global_mm_index = 0; + uint32_t offset = 0; + uint32_t length = 0; + auto& mm_items = mm_data.items(); + + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) { + video_grid_thw = res.value(); + } + + int32_t video_index = 0; + int32_t video_frames_left = 0; + + while (true) { + auto vision_start_it = + std::find(start, prompt.end(), vision_start_token_id_); + if (vision_start_it == prompt.end()) { + break; + } + auto vision_end_it = + std::find(vision_start_it + 1, prompt.end(), vision_end_token_id_); + CHECK(vision_end_it != prompt.end()); + + offset = std::distance(prompt.begin(), vision_start_it); + length = std::distance(vision_start_it + 1, vision_end_it); + + int32_t first_token = *(vision_start_it + 1); + if (first_token == image_token_id_) { + CHECK(global_mm_index < mm_items.size()); + auto& item = mm_items[global_mm_index]; + item.mutable_state().mutable_token_pos() = {offset + 1, length}; + ++global_mm_index; + + } else if (first_token == video_token_id_) { + if (video_frames_left == 0) { + CHECK(video_grid_thw.defined() && video_grid_thw.numel() > 0) + << "video token exists but video_grid_thw is missing"; + CHECK(video_index < video_grid_thw.size(0)); + CHECK(global_mm_index < mm_items.size()); + + video_frames_left = video_grid_thw[video_index][0].item(); + + auto& item = mm_items[global_mm_index]; + item.mutable_state().mutable_token_pos() = {offset + 1, length}; + + ++global_mm_index; + ++video_index; + } + + CHECK(video_frames_left > 0); + --video_frames_left; + } + + start = std::next(vision_end_it); + } +} + +std::pair +Qwen3_VLInputProcessor::find_vision_token(const std::string& prompt, + size_t begin) { + auto img_pos = prompt.find(image_token_, begin); + auto vid_pos = prompt.find(video_token_, begin); + + if (img_pos == std::string::npos && vid_pos == std::string::npos) + return {TokenType::INVALID, std::string::npos}; + else if (vid_pos == std::string::npos) + return {TokenType::IMAGE, img_pos}; + else if (img_pos == std::string::npos) + return {TokenType::VIDEO, vid_pos}; + else + return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) + : std::make_pair(TokenType::VIDEO, vid_pos); +} + +std::vector Qwen3_VLInputProcessor::build_timestamps( + const std::vector& timestamps, + size_t num_frames, + int32_t merge_size) { + CHECK_GT(merge_size, 0); + + if (timestamps.empty()) { + return std::vector(num_frames, 0.0); + } + + std::vector ts = timestamps; + const size_t rem = ts.size() % static_cast(merge_size); + if (rem != 0) { + ts.insert(ts.end(), static_cast(merge_size) - rem, ts.back()); + } + + std::vector out; + out.reserve(ts.size() / static_cast(merge_size)); + + for (size_t i = 0; i < ts.size(); i += static_cast(merge_size)) { + out.push_back((ts[i] + ts[i + static_cast(merge_size) - 1]) / 2.0); + } + + if (out.size() > num_frames) { + out.resize(num_frames); + } + while (out.size() < num_frames) { + out.push_back(out.back()); + } + + return out; +} + +std::string Qwen3_VLInputProcessor::format_timestamp_str(double timestamp) { + char buffer[32]; + snprintf(buffer, sizeof(buffer), "<%.1f seconds>", timestamp); + return buffer; +} + +} // namespace xllm diff --git a/xllm/processors/qwen3_vl_input_processor.h b/xllm/processors/qwen3_vl_input_processor.h new file mode 100644 index 000000000..f2536de06 --- /dev/null +++ b/xllm/processors/qwen3_vl_input_processor.h @@ -0,0 +1,63 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "core/framework/model/model_args.h" +#include "core/framework/request/mm_data.h" +#include "processors/input_processor.h" + +namespace xllm { + +class Qwen3_VLInputProcessor : public InputProcessor { + enum class TokenType { + INVALID, + IMAGE, + VIDEO, + }; + + public: + explicit Qwen3_VLInputProcessor(const ModelArgs& args); + + void process(std::string& prompt, const MMData& mm_data) override; + void find_mm_spans(const std::vector& prompt, + MMData& mm_data) override; + + private: + std::pair find_vision_token(const std::string& prompt, + size_t begin); + + std::vector build_timestamps(const std::vector& timestamps, + size_t num_frames, + int32_t merge_size); + std::string format_timestamp_str(double timestamp); + + const std::string image_token_ = "<|image_pad|>"; + const std::string video_token_ = "<|video_pad|>"; + const std::string vision_start_token_ = "<|vision_start|>"; + const std::string vision_end_token_ = "<|vision_end|>"; + int32_t vision_start_token_id_; + int32_t vision_end_token_id_; + int32_t image_token_id_; + int32_t video_token_id_; + int32_t merge_size_ = 0; + int32_t temporal_patch_size_ = 0; +}; + +} // namespace xllm diff --git a/xllm/proto/image_generation.proto b/xllm/proto/image_generation.proto index 080f3ae09..859b40fa7 100644 --- a/xllm/proto/image_generation.proto +++ b/xllm/proto/image_generation.proto @@ -46,6 +46,9 @@ message Input { // Control Image optional string control_image = 13; + + // Condition Image + optional string condition_image = 14; } // Generation parameters container @@ -142,4 +145,4 @@ message ImageGenerationResponse { // Contains task details and generation results ImageGenerationOutput output = 5; -} \ No newline at end of file +} diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 92443e2f8..b4160b37f 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -193,6 +193,55 @@ message MMData { map dict = 2; } +message DiTForwardInput { + int32 batch_size = 1; + + // Primary input text description for image generation + repeated string prompts = 2; + + // Secondary prompt for additional details (e.g., color, lighting) + repeated string prompts_2 = 3; + + // Negative prompt to exclude low-quality features + repeated string negative_prompts = 4; + + // Secondary negative prompt to exclude additional unwanted features + repeated string negative_prompts_2 = 5; + + // Tensor fields + Tensor images = 6; + Tensor condition_images = 7; + Tensor mask_images = 8; + Tensor control_image = 9; + Tensor masked_image_latents = 10; + Tensor prompt_embeds = 11; + Tensor pooled_prompt_embeds = 12; + Tensor negative_prompt_embeds = 13; + Tensor negative_pooled_prompt_embeds = 14; + Tensor latents = 15; + + // generation params + DiTGenerationParams generation_params = 16; +} + +message DiTForwardOutput { + TensorList tensors = 1; +} + +message DiTGenerationParams { + int32 width = 1; + int32 height = 2; + int32 num_inference_steps = 3; + float true_cfg_scale = 4; + float guidance_scale = 5; + uint32 num_images_per_prompt = 6; + int64 seed = 7; + int32 max_sequence_length = 8; + float strength = 9; + bool enable_cfg_renorm = 10; + float cfg_renorm_min = 11; +} + message ForwardInput { // flatten the token ids and positions repeated int32 flatten_tokens_vec = 1; @@ -241,6 +290,7 @@ message ForwardInput { repeated int32 dp_is_decode = 42; repeated int32 kv_cache_tokens_nums = 43; repeated string request_ids = 44; + DiTForwardInput dit_forward_input = 45; } message BatchedForwardInputs { @@ -273,6 +323,7 @@ message ForwardOutput { repeated int32 src_seq_idxes = 5; repeated int32 out_tokens = 6; repeated float out_logprobs = 7; + DiTForwardOutput dit_forward_output = 8; } // master create Collective service diff --git a/xllm/server/xllm_server.cpp b/xllm/server/xllm_server.cpp index 0e859d53b..0934eb608 100644 --- a/xllm/server/xllm_server.cpp +++ b/xllm/server/xllm_server.cpp @@ -17,12 +17,89 @@ limitations under the License. #include #include +#include + +#include +#include #include "core/common/global_flags.h" #include "health_reporter.h" namespace xllm { +namespace { +volatile std::sig_atomic_t g_quit_flag = 0; + +void quit_signal_handler(int /*signum*/) { g_quit_flag = 1; } + +constexpr const char* kApiServiceRoutes = + "v1/completions => CompletionsHttp," + "v1/sample => SampleHttp," + "v1/chat/completions => ChatCompletionsHttp," + "v1/embeddings => EmbeddingsHttp," + "v1/models => ModelsHttp," + "v1/image/generation => ImageGenerationHttp," + "v1/rerank => RerankHttp," + "v1/messages => AnthropicMessagesHttp," + "v2/repository/index => ModelVersionsHttp," + "fork_master => ForkMasterHttp," + "sleep => SleepHttp," + "wakeup => WakeupHttp," + "link_d2d => LinkD2DHttp," + "unlink_d2d => UnlinkD2DHttp"; + +constexpr const char* kForkOnlyRoute = "fork_master => ForkMasterHttp"; + +struct ApiRouteBinding { + const char* name; + bool (*enabled)(); + const char* routes; +}; + +bool is_master_node() { return FLAGS_node_rank == 0; } + +bool is_xtensor_node() { return FLAGS_node_rank != 0 && FLAGS_enable_xtensor; } + +const char* get_api_service_routes_for_current_mode() { + static constexpr std::array kBindings = {{ + {"master_node", &is_master_node, kApiServiceRoutes}, + {"xtensor_node", &is_xtensor_node, kForkOnlyRoute}, + }}; + for (const auto& binding : kBindings) { + if (binding.enabled()) { + LOG(INFO) << "Use API route mode: " << binding.name; + return binding.routes; + } + } + return nullptr; +} + +void install_quit_signal_handler() { + g_quit_flag = 0; + struct sigaction sa = {}; + sa.sa_handler = quit_signal_handler; + sigemptyset(&sa.sa_mask); + sigaction(SIGINT, &sa, nullptr); + sigaction(SIGTERM, &sa, nullptr); +} + +void wait_for_quit_signal() { + while (!g_quit_flag) { + sleep(1); + } +} + +bool configure_generic_server(brpc::Server* server, + google::protobuf::Service* service, + const std::string& server_name) { + if (server->AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + LOG(ERROR) << "Fail to add " << server_name << " service"; + return false; + } + return true; +} +} // namespace + XllmServer::XllmServer() { butil::AtExitManager exit_manager; } XllmServer::~XllmServer() { @@ -35,33 +112,15 @@ XllmServer::~XllmServer() { bool XllmServer::start(std::unique_ptr service) { server_ = std::make_unique(); - if (FLAGS_node_rank == 0) { - if (server_->AddService(service.get(), - brpc::SERVER_DOESNT_OWN_SERVICE, - "v1/completions => CompletionsHttp," - "v1/sample => SampleHttp," - "v1/chat/completions => ChatCompletionsHttp," - "v1/embeddings => EmbeddingsHttp," - "v1/models => ModelsHttp," - "v1/image/generation => ImageGenerationHttp," - "v1/rerank => RerankHttp," - "v1/messages => AnthropicMessagesHttp," - "v2/repository/index => ModelVersionsHttp," - "fork_master => ForkMasterHttp," - "sleep => SleepHttp," - "wakeup => WakeupHttp," - "link_d2d => LinkD2DHttp," - "unlink_d2d => UnlinkD2DHttp") != 0) { - LOG(ERROR) << "Fail to add api service"; - return false; - } - } else if (FLAGS_enable_xtensor) { - if (server_->AddService(service.get(), - brpc::SERVER_DOESNT_OWN_SERVICE, - "fork_master => ForkMasterHttp") != 0) { + if (const char* routes = get_api_service_routes_for_current_mode(); + routes != nullptr) { + if (server_->AddService( + service.get(), brpc::SERVER_DOESNT_OWN_SERVICE, routes) != 0) { LOG(ERROR) << "Fail to add api service"; return false; } + } else { + LOG(INFO) << "No API routes enabled on current node mode."; } brpc::ServerOptions options; @@ -84,8 +143,22 @@ bool XllmServer::start(std::unique_ptr service) { std::string(butil::endpoint2str(server_->listen_address()).c_str()); listen_port_ = FLAGS_port; has_initialized_ = true; - // Wait until Ctrl-C is pressed, then Stop() and Join() the server. - server_->RunUntilAskedToQuit(); + + auto pid = getpid(); + LOG(INFO) << " Started server process [" << pid << "]"; + LOG(INFO) << " Waiting for application startup."; + LOG(INFO) << " Application startup complete."; + + install_quit_signal_handler(); + wait_for_quit_signal(); + + LOG(INFO) << " Shutting down"; + LOG(INFO) << " Waiting for application shutdown."; + + stop(); + + LOG(INFO) << " Application shutdown complete."; + LOG(INFO) << " Finished server process [" << pid << "]"; return true; } @@ -145,54 +218,18 @@ bool XllmServer::start(std::shared_ptr service, bool XllmServer::start(std::shared_ptr service, const std::string& addr) { - server_ = std::make_unique(); - if (server_->AddService(service.get(), brpc::SERVER_DOESNT_OWN_SERVICE) != - 0) { - LOG(ERROR) << "Fail to add DistributeWorker service"; - return false; - } - - brpc::ServerOptions options; - options.idle_timeout_sec = FLAGS_rpc_idle_timeout_s; - options.num_threads = FLAGS_num_threads; - listen_address_ = addr; - if (server_->Start(addr.c_str(), &options) != 0) { - LOG(ERROR) << "Failed to start distribute server on address: " << addr; - return false; - } - listen_port_ = server_->listen_address().port; - LOG(INFO) << "DistributeWorker started on address " - << server_->listen_address() - << ", idle_timeout_sec: " << FLAGS_rpc_idle_timeout_s - << ", num_threads: " << FLAGS_num_threads; - - return true; + return create_server(static_cast(service.get()), + addr, + -1, + "DistributeWorker"); } bool XllmServer::start(std::shared_ptr service, const std::string& addr) { - server_ = std::make_unique(); - if (server_->AddService(service.get(), brpc::SERVER_DOESNT_OWN_SERVICE) != - 0) { - LOG(ERROR) << "Fail to add XTensorDist service"; - return false; - } - - brpc::ServerOptions options; - options.idle_timeout_sec = FLAGS_rpc_idle_timeout_s; - options.num_threads = FLAGS_num_threads; - listen_address_ = addr; - if (server_->Start(addr.c_str(), &options) != 0) { - LOG(ERROR) << "Failed to start XTensorDist server on address: " << addr; - return false; - } - listen_port_ = server_->listen_address().port; - LOG(INFO) << "XTensorDist server started on address " - << server_->listen_address() - << ", idle_timeout_sec: " << FLAGS_rpc_idle_timeout_s - << ", num_threads: " << FLAGS_num_threads; - - return true; + return create_server(static_cast(service.get()), + addr, + -1, + "XTensorDist"); } bool XllmServer::create_server(google::protobuf::Service* service, @@ -200,8 +237,7 @@ bool XllmServer::create_server(google::protobuf::Service* service, int port, const std::string& server_name) { server_ = std::make_unique(); - if (server_->AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { - LOG(ERROR) << "Fail to add " << server_name << " service"; + if (!configure_generic_server(server_.get(), service, server_name)) { return false; } @@ -218,17 +254,20 @@ bool XllmServer::create_server(google::protobuf::Service* service, } } else { endpoint = butil::EndPoint(butil::IP_ANY, port); - listen_address_ = - std::string(butil::endpoint2str(server_->listen_address()).c_str()); } - listen_port_ = port > 0 ? port : server_->listen_address().port; if (server_->Start(endpoint, &options) != 0) { LOG(ERROR) << "Failed to start " << server_name << " server on address: " << endpoint; return false; } - LOG(INFO) << server_name << " server started on address " << endpoint + + listen_address_ = + std::string(butil::endpoint2str(server_->listen_address()).c_str()); + listen_port_ = server_->listen_address().port; + + LOG(INFO) << server_name << " server started on address " + << server_->listen_address() << ", idle_timeout_sec: " << FLAGS_rpc_idle_timeout_s << ", num_threads: " << FLAGS_num_threads; @@ -245,6 +284,9 @@ void XllmServer::run() { } void XllmServer::stop() { + if (!server_) { + return; + } server_->Stop(0); server_->Join(); } diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index 42a4ab845..413fc456b 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -31,13 +31,12 @@ limitations under the License. #include "core/common/metrics.h" #include "core/common/options.h" #include "core/common/types.h" +#include "core/distributed_runtime/dit_master.h" #include "core/distributed_runtime/master.h" #include "core/framework/xtensor/global_xtensor.h" #include "core/framework/xtensor/options.h" #include "core/framework/xtensor/xtensor_allocator.h" #include "core/util/device_name_utils.h" -#include "core/util/json_reader.h" -#include "core/util/model_config_utils.h" #include "core/util/net.h" #include "core/util/utils.h" #include "function_call/function_call_parser.h" @@ -51,6 +50,45 @@ static const std::unordered_set prefill_sp_supported_model_set = { "deepseek_v32", "glm_moe_dsa"}; +namespace { + +void fix_mlu_disagg_pd_flags() { + if (FLAGS_kv_cache_transfer_type != "Mooncake") { + LOG(WARNING) << "MLU disaggregated PD requires " + << "kv_cache_transfer_type=Mooncake; forcing from " + << FLAGS_kv_cache_transfer_type << " to Mooncake."; + FLAGS_kv_cache_transfer_type = "Mooncake"; + } + if (FLAGS_kv_cache_transfer_mode != "PUSH") { + LOG(WARNING) << "MLU disaggregated PD requires " + << "kv_cache_transfer_mode=PUSH; forcing from " + << FLAGS_kv_cache_transfer_mode << " to PUSH."; + FLAGS_kv_cache_transfer_mode = "PUSH"; + } + if (FLAGS_kv_cache_dtype != "auto") { + LOG(WARNING) << "MLU disaggregated PD requires kv_cache_dtype=auto; " + << "forcing from " << FLAGS_kv_cache_dtype << " to auto."; + FLAGS_kv_cache_dtype = "auto"; + } + if (FLAGS_enable_prefix_cache) { + LOG(WARNING) << "MLU disaggregated PD does not support prefix cache; " + << "forcing enable_prefix_cache=false."; + FLAGS_enable_prefix_cache = false; + } + if (FLAGS_enable_chunked_prefill) { + LOG(WARNING) << "MLU disaggregated PD does not support chunked prefill; " + << "forcing enable_chunked_prefill=false."; + FLAGS_enable_chunked_prefill = false; + } + if (FLAGS_enable_pd_ooc) { + LOG(WARNING) << "MLU disaggregated PD does not support pd_ooc; " + << "forcing enable_pd_ooc=false."; + FLAGS_enable_pd_ooc = false; + } +} + +} // namespace + void shutdown_handler(int signal) { // TODO: gracefully shutdown the server LOG(WARNING) << "Received signal " << signal << ", stopping server..."; @@ -68,11 +106,25 @@ void validate_flags(const std::string& model_type) { << model_type; } #if defined(USE_MLU) + // Disable enable_schedule_overlap for VLM models on MLU backend + if (FLAGS_enable_schedule_overlap && FLAGS_backend == "vlm") { + LOG(WARNING) << "enable_schedule_overlap is not supported for VLM models " + "on MLU backend. " + << "Disabling enable_schedule_overlap."; + FLAGS_enable_schedule_overlap = false; + } // TODO: support other block sizes in the future - if (FLAGS_block_size != 16 && FLAGS_block_size != 1) { + if (FLAGS_block_size != 16 && FLAGS_block_size != 1 && + FLAGS_backend != "dit") { LOG(FATAL) << "Currently, block_size must be 16 for MLU backend, we will " "support other block sizes in the future."; } + if (FLAGS_enable_disagg_pd) { + if (FLAGS_backend != "llm") { + LOG(FATAL) << "MLU disaggregated PD only supports backend=llm."; + } + fix_mlu_disagg_pd_flags(); + } #endif #if defined(USE_NPU) @@ -116,15 +168,11 @@ int run() { std::filesystem::path model_path = std::filesystem::path(FLAGS_model).lexically_normal(); + const std::string default_model_name = xllm::util::get_model_name(model_path); if (FLAGS_model_id.empty()) { // use last part of the path as model id - if (model_path.has_filename()) { - FLAGS_model_id = std::filesystem::path(FLAGS_model).filename(); - } else { - FLAGS_model_id = - std::filesystem::path(FLAGS_model).parent_path().filename(); - } + FLAGS_model_id = default_model_name; } if (FLAGS_backend.empty()) { @@ -158,13 +206,13 @@ int run() { FLAGS_max_tokens_per_chunk_for_prefill = FLAGS_max_tokens_per_batch; } -// disable block copy kernel on non-NPU backend -#if !defined(USE_NPU) +// disable block copy kernel on unsupported backends +#if !defined(USE_NPU) && !defined(USE_CUDA) FLAGS_enable_block_copy_kernel = false; #endif - - std::string model_type = xllm::util::get_model_type(model_path); + std::string model_type = ""; if (FLAGS_backend != "dit") { + model_type = xllm::util::get_model_type(model_path); FLAGS_tool_call_parser = function_call::FunctionCallParser::get_parser_auto( FLAGS_tool_call_parser, model_type); FLAGS_reasoning_parser = @@ -226,12 +274,16 @@ int run() { .dp_size(FLAGS_dp_size) .cp_size(FLAGS_cp_size) .ep_size(FLAGS_ep_size) + .tp_size(FLAGS_tp_size) + .sp_size(FLAGS_sp_size) + .cfg_size(FLAGS_cfg_size) .instance_name(FLAGS_host + ":" + std::to_string(FLAGS_port)) .enable_disagg_pd(FLAGS_enable_disagg_pd) .enable_pd_ooc(FLAGS_enable_pd_ooc) .enable_schedule_overlap(FLAGS_enable_schedule_overlap) .kv_cache_transfer_mode(FLAGS_kv_cache_transfer_mode) .etcd_addr(FLAGS_etcd_addr) + .etcd_namespace(FLAGS_etcd_namespace) .enable_service_routing(FLAGS_enable_service_routing || FLAGS_enable_disagg_pd) .tool_call_parser(FLAGS_tool_call_parser) @@ -307,7 +359,11 @@ int run() { std::unique_ptr master; // working node if (options.node_rank() != 0) { - master = std::make_unique(options); + if (FLAGS_backend == "dit") { + master = std::make_unique(options); + } else { + master = std::make_unique(options); + } } else { if (FLAGS_random_seed < 0) { FLAGS_random_seed = std::random_device{}() % (1 << 30); @@ -319,12 +375,7 @@ int run() { // supported models std::vector model_names = {FLAGS_model_id}; - std::string model_version; - if (model_path.has_filename()) { - model_version = std::filesystem::path(FLAGS_model).filename(); - } else { - model_version = std::filesystem::path(FLAGS_model).parent_path().filename(); - } + std::string model_version = default_model_name; std::vector model_versions = {model_version}; if (FLAGS_node_rank == 0 || FLAGS_enable_xtensor) {