Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion xllm/core/common/global_flags.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,48 @@ DEFINE_string(dit_cache_policy,
"The policy of dit cache(e.g. None, FBCache, TaylorSeer, "
"FBCacheTaylorSeer, ResidualCache).");

DEFINE_int64(dit_cache_warmup_steps, 0, "The number of warmup steps.");
DEFINE_int64(dit_cache_warmup_steps, 5, "The number of warmup steps.");
// --- RACFGCache tuning flags -----------------------------------------------
// These map onto RACFGCacheOptions; see dit_cache_config.h for the
// corresponding struct fields.
DEFINE_int64(
dit_cache_probe_depth,
2,
"Number of DiT blocks to run before RACFGCache makes a joint decision.");

DEFINE_double(dit_cache_tau,
0.28,
"Accumulated guided-risk threshold for RACFGCache.");

DEFINE_bool(dit_cache_use_prop_weight,
true,
"Whether RACFGCache uses propagation-aware reweighting.");

// Fitted parameters for the propagation-aware reweighting model.
DEFINE_double(dit_cache_prop_a,
1.7997948016,
"Propagation-aware fitted parameter a for RACFGCache.");

DEFINE_double(dit_cache_prop_alpha,
1.1729645804,
"Propagation-aware fitted parameter alpha for RACFGCache.");

DEFINE_double(dit_cache_prop_b,
-0.1016712615,
"Propagation-aware fitted parameter b for RACFGCache.");

DEFINE_int64(dit_cache_proxy_error_type,
0,
"Proxy error type for RACFGCache: 0=delta_y, 1=delta_minus.");

// NOTE(review): despite the name, the default value looks like a hardcoded
// table key (see dit_cache_model_name below) rather than a filesystem path —
// confirm how RACFGCache resolves this when loading the rho table.
DEFINE_string(dit_cache_rho_table_path,
"qwen_image_edit_cfg4_steps40",
"Path to the offline rho lookup table used by RACFGCache.");

DEFINE_string(
dit_cache_model_name,
"qwen_image_edit_plus",
"Model name for selecting hardcoded rho table when using RACFGCache.");

// NOTE(review): unlike every other RACFGCache flag this one lacks the
// dit_cache_ prefix; renaming would break existing users — consider adding a
// prefixed alias instead.
DEFINE_double(true_cfg_scale,
4.0,
"Optional CFG scale recorded with the rho table for RACFGCache.");
DEFINE_int64(dit_cache_n_derivatives,
3,
"The number of derivatives to use in TaylorSeer.");
Expand Down
21 changes: 21 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ DECLARE_int64(dit_cache_skip_interval_steps);
DECLARE_double(dit_cache_residual_diff_threshold);

DECLARE_bool(enable_constrained_decoding);

DECLARE_bool(enable_convert_tokens_to_item);
DECLARE_bool(enable_output_sku_logprobs);
DECLARE_int32(each_conversion_threshold);
Expand All @@ -309,6 +310,26 @@ DECLARE_int64(dit_cache_start_blocks);

DECLARE_int64(dit_cache_end_blocks);

// RACFGCache flags — definitions (defaults and help text) live in
// global_flags.cpp.
DECLARE_int64(dit_cache_probe_depth);

DECLARE_double(dit_cache_tau);

DECLARE_bool(dit_cache_use_prop_weight);

DECLARE_double(dit_cache_prop_a);

DECLARE_double(dit_cache_prop_alpha);

DECLARE_double(dit_cache_prop_b);

DECLARE_int64(dit_cache_proxy_error_type);

DECLARE_string(dit_cache_rho_table_path);

DECLARE_string(dit_cache_model_name);

// NOTE(review): intentionally un-prefixed name; see definition site.
DECLARE_double(true_cfg_scale);

DECLARE_int64(tp_size);

DECLARE_int64(sp_size);
Expand Down
4 changes: 4 additions & 0 deletions xllm/core/framework/dit_cache/CMakeLists.txt
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ cc_library(
fbcache_taylorseer.h
taylorseer.h
residual_cache.h
racfgcache.h
calibration/racfgcache_calibration_tables.h
SRCS
dit_cache_impl.cpp
dit_cache.cpp
Expand All @@ -22,6 +24,8 @@ cc_library(
fbcache_taylorseer.cpp
taylorseer.cpp
residual_cache.cpp
racfgcache.cpp
calibration/racfgcache_calibration_tables.cpp
DEPS
torch
glog::glog
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include <cstddef>
#include <cstdint>
#include <string>

namespace xllm {

// Structured identifier for a hardcoded rho table.
// This is designed for future extensibility across:
// - different model families
// - different CFG scales
// - different inference step counts
struct RhoTableSpec {
  std::string model_name;
  float cfg_scale = 0.0f;
  int64_t infer_steps = 0;

  // Exact equality (including bitwise float comparison on cfg_scale): specs
  // are discrete registry keys taken from configuration, not measured values,
  // so exact matching is intentional.
  bool operator==(const RhoTableSpec& other) const {
    return model_name == other.model_name && cfg_scale == other.cfg_scale &&
           infer_steps == other.infer_steps;
  }
};

// Hash support so RhoTableSpec can key an std::unordered_map.
struct RhoTableSpecHash {
  std::size_t operator()(const RhoTableSpec& spec) const;
};

// Return a hardcoded rho table if the exact spec is registered.
// Return an undefined tensor if not found (check with Tensor::defined()).
torch::Tensor get_hardcoded_rho_table(const RhoTableSpec& spec);

// Return whether an exact hardcoded rho table exists for the given spec.
bool has_hardcoded_rho_table(const RhoTableSpec& spec);

// Convert spec to a readable string for logging/debugging.
std::string to_string(const RhoTableSpec& spec);

}  // namespace xllm
5 changes: 5 additions & 0 deletions xllm/core/framework/dit_cache/dit_cache.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,9 @@ CacheStepOut DiTCache::on_after_step(const CacheStepIn& stepin, bool use_cfg) {
return active_cache_->on_after_step(stepin);
}

void DiTCache::set_runtime_context(const DiTCacheRuntimeContext& ctx) {
  // Forward the shared runtime context to each active per-branch cache;
  // a branch whose cache is not configured is simply skipped.
  if (active_cond_cache_) {
    active_cond_cache_->set_runtime_context(ctx);
  }
  if (active_cache_) {
    active_cache_->set_runtime_context(ctx);
  }
}

} // namespace xllm
2 changes: 2 additions & 0 deletions xllm/core/framework/dit_cache/dit_cache.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class DiTCache {
active_cond_cache_->set_num_blocks(num_blocks);
}

void set_runtime_context(const DiTCacheRuntimeContext& ctx);

private:
torch::Tensor get_tensor_or_empty(const TensorMap& m, const std::string& k);

Expand Down
38 changes: 37 additions & 1 deletion xllm/core/framework/dit_cache/dit_cache_config.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/

#pragma once
#include <cstdint>
#include <string>

namespace xllm {

Expand All @@ -22,7 +24,8 @@ enum class PolicyType {
FBCache,
TaylorSeer,
FBCacheTaylorSeer,
ResidualCache
ResidualCache,
RACFGCache
};

struct DiTBaseCacheOptions {
Expand Down Expand Up @@ -68,6 +71,36 @@ struct ResidualCacheOptions {
int64_t skip_interval_steps = 3;
};

// Configuration for the RACFGCache policy.  Field defaults here are struct
// fallbacks; NOTE(review): several differ from the gflag defaults in
// global_flags.cpp (e.g. tau 0.0 vs 0.28, prop_a 0.48 vs 1.80,
// true_cfg_scale 3.0 vs 4.0) — presumably the flags overwrite these at
// config-construction time; confirm at the wiring site.
struct RACFGCacheOptions : public DiTBaseCacheOptions {
// Number of blocks to run before making the joint decision.
int64_t probe_depth = 2;

// Joint accumulated-risk threshold.
float tau = 0.0f;

// True CFG scale used in the guided combination.
float true_cfg_scale = 3.0f;

// Whether to use propagation-aware reweighting.
bool use_prop_weight = true;

// Propagation-aware fitted parameters.
float prop_a = 0.4806166f;
float prop_alpha = 0.4782565f;
float prop_b = 0.0641170f;

// Branch-local proxy error choice:
// 0 -> delta_y
// 1 -> delta_minus
int64_t proxy_error_type = 0;

// Offline rho table path (empty => fall back to hardcoded table selection
// via model_name).
std::string rho_table_path = "";
std::string model_name = "";
// Optional matched CFG scale recorded with the rho table; -1 means unset.
float matched_cfg_scale = -1.0f;
};

struct DiTCacheConfig {
DiTCacheConfig() = default;

Expand All @@ -85,6 +118,9 @@ struct DiTCacheConfig {

// the configuration for ResidualCache policy.
ResidualCacheOptions residual_cache;

// the configuration for RACFGCache policy.
RACFGCacheOptions racfgcache;
};

} // namespace xllm
31 changes: 27 additions & 4 deletions xllm/core/framework/dit_cache/dit_cache_impl.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ limitations under the License.

#include "dit_cache_impl.h"

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include "dit_non_cache.h"
#include "fbcache.h"
#include "fbcache_taylorseer.h"
#include "framework/parallel_state/parallel_state.h"
#include "racfgcache.h"
#include "residual_cache.h"
#include "taylorseer.h"

Expand All @@ -32,17 +36,34 @@ torch::Tensor DitCacheImpl::get_tensor_or_empty(const TensorMap& m,

bool DitCacheImpl::is_similar(const torch::Tensor& lhs,
const torch::Tensor& rhs,
float threshold) {
float threshold) const {
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs.sizes() != rhs.sizes()) return false;

if (threshold <= 0.0f) {
return torch::allclose(lhs, rhs);
}

auto diff = (lhs - rhs).abs();
auto mean_diff = diff.mean();
auto mean_lhs = lhs.abs().mean();
torch::Device dev = lhs.device();
auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(dev);

auto sum_abs_diff = (lhs - rhs).abs().sum().to(torch::kFloat32);
auto sum_abs_lhs = lhs.abs().sum().to(torch::kFloat32);
auto count = torch::tensor({static_cast<float>(lhs.numel())}, opts);

if (runtime_ctx_.sp_enabled && runtime_ctx_.sp_world_size > 1 &&
runtime_ctx_.sp_group != nullptr) {
auto* sp_group = static_cast<ProcessGroup*>(runtime_ctx_.sp_group);
CHECK(sp_group != nullptr)
<< "sp_group is null in DitCacheImpl::is_similar";

sum_abs_diff = xllm::parallel_state::reduce(sum_abs_diff, sp_group);
sum_abs_lhs = xllm::parallel_state::reduce(sum_abs_lhs, sp_group);
count = xllm::parallel_state::reduce(count, sp_group);
}

auto mean_diff = sum_abs_diff / count;
auto mean_lhs = sum_abs_lhs / count;

auto rel = mean_diff / (mean_lhs + 1e-6);
return (rel < threshold).item<bool>();
Expand All @@ -58,6 +79,8 @@ std::unique_ptr<DitCacheImpl> create_dit_cache(const DiTCacheConfig& cfg) {
return std::make_unique<FBCacheTaylorSeer>();
case PolicyType::ResidualCache:
return std::make_unique<ResidualCache>();
case PolicyType::RACFGCache:
return std::make_unique<RACFGCache>();
default:
return std::make_unique<DiTNonCache>();
}
Expand Down
34 changes: 31 additions & 3 deletions xllm/core/framework/dit_cache/dit_cache_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#pragma once
#include <cstdint>
#include <string>
#include <unordered_map>

Expand All @@ -24,6 +25,24 @@ namespace xllm {

using TensorMap = std::unordered_map<std::string, torch::Tensor>;

// Per-run runtime information handed to cache policies via
// set_runtime_context().  The group handles are type-erased process-group
// pointers: sp_group is cast back to a c10d ProcessGroup* inside
// DitCacheImpl::is_similar; cfg_group is presumably the same — confirm at
// the producer side.
struct DiTCacheRuntimeContext {
std::string model_name;
// cfg parallel runtime info
void* cfg_group = nullptr;
int64_t cfg_rank = 0;
int64_t cfg_world_size = 1;
bool cfg_enabled = false;
float true_cfg_scale = 1.0f;
// sequence parallel runtime info
void* sp_group = nullptr;
int64_t sp_rank = 0;
int64_t sp_world_size = 1;
bool sp_enabled = false;
// optional runtime metadata
int64_t infer_steps = 0;
int64_t num_blocks = 0;
};

class DitCacheImpl {
public:
DitCacheImpl() = default;
Expand All @@ -45,19 +64,28 @@ class DitCacheImpl {
num_blocks_ = num_blocks;
}

// Store the shared runtime context (parallel groups, model metadata) for use
// by cache decisions; virtual so subclasses may react to context changes.
virtual void set_runtime_context(const DiTCacheRuntimeContext& ctx) {
runtime_ctx_ = ctx;
}

// Read-only access to the last context passed to set_runtime_context().
const DiTCacheRuntimeContext& get_runtime_context() const {
return runtime_ctx_;
}

protected:
int64_t num_inference_steps_;
int64_t warmup_steps_;
int64_t current_step_;
int64_t infer_steps_;
int64_t num_blocks_;
TensorMap buffers;
DiTCacheRuntimeContext runtime_ctx_;

static torch::Tensor get_tensor_or_empty(const TensorMap& m,
const std::string& k);
static bool is_similar(const torch::Tensor& lhs,
const torch::Tensor& rhs,
float threshold);
bool is_similar(const torch::Tensor& lhs,
const torch::Tensor& rhs,
float threshold) const;
};

std::unique_ptr<DitCacheImpl> create_dit_cache(const DiTCacheConfig& cfg);
Expand Down
Loading
Loading