diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp old mode 100644 new mode 100755 index fa24d4d9c..385a5081a --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -607,8 +607,48 @@ DEFINE_string(dit_cache_policy, "The policy of dit cache(e.g. None, FBCache, TaylorSeer, " "FBCacheTaylorSeer, ResidualCache)."); -DEFINE_int64(dit_cache_warmup_steps, 0, "The number of warmup steps."); +DEFINE_int64(dit_cache_warmup_steps, 5, "The number of warmup steps."); +DEFINE_int64( + dit_cache_probe_depth, + 2, + "Number of DiT blocks to run before RACFGCache makes a joint decision."); + +DEFINE_double(dit_cache_tau, + 0.28, + "Accumulated guided-risk threshold for RACFGCache."); + +DEFINE_bool(dit_cache_use_prop_weight, + true, + "Whether RACFGCache uses propagation-aware reweighting."); + +DEFINE_double(dit_cache_prop_a, + 1.7997948016, + "Propagation-aware fitted parameter a for RACFGCache."); + +DEFINE_double(dit_cache_prop_alpha, + 1.1729645804, + "Propagation-aware fitted parameter alpha for RACFGCache."); + +DEFINE_double(dit_cache_prop_b, + -0.1016712615, + "Propagation-aware fitted parameter b for RACFGCache."); + +DEFINE_int64(dit_cache_proxy_error_type, + 0, + "Proxy error type for RACFGCache: 0=delta_y, 1=delta_minus."); + +DEFINE_string(dit_cache_rho_table_path, + "qwen_image_edit_cfg4_steps40", + "Path to the offline rho lookup table used by RACFGCache."); + +DEFINE_string( + dit_cache_model_name, + "qwen_image_edit_plus", + "Model name for selecting hardcoded rho table when using RACFGCache."); +DEFINE_double(true_cfg_scale, + 4.0, + "Optional CFG scale recorded with the rho table for RACFGCache."); DEFINE_int64(dit_cache_n_derivatives, 3, "The number of derivatives to use in TaylorSeer."); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 59412570d..40e3c6b74 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -294,6 +294,7 @@ 
DECLARE_int64(dit_cache_skip_interval_steps); DECLARE_double(dit_cache_residual_diff_threshold); DECLARE_bool(enable_constrained_decoding); + DECLARE_bool(enable_convert_tokens_to_item); DECLARE_bool(enable_output_sku_logprobs); DECLARE_int32(each_conversion_threshold); @@ -309,6 +310,26 @@ DECLARE_int64(dit_cache_start_blocks); DECLARE_int64(dit_cache_end_blocks); +DECLARE_int64(dit_cache_probe_depth); + +DECLARE_double(dit_cache_tau); + +DECLARE_bool(dit_cache_use_prop_weight); + +DECLARE_double(dit_cache_prop_a); + +DECLARE_double(dit_cache_prop_alpha); + +DECLARE_double(dit_cache_prop_b); + +DECLARE_int64(dit_cache_proxy_error_type); + +DECLARE_string(dit_cache_rho_table_path); + +DECLARE_string(dit_cache_model_name); + +DECLARE_double(true_cfg_scale); + DECLARE_int64(tp_size); DECLARE_int64(sp_size); diff --git a/xllm/core/framework/dit_cache/CMakeLists.txt b/xllm/core/framework/dit_cache/CMakeLists.txt old mode 100644 new mode 100755 index 7509fd99b..6185fcf4f --- a/xllm/core/framework/dit_cache/CMakeLists.txt +++ b/xllm/core/framework/dit_cache/CMakeLists.txt @@ -14,6 +14,8 @@ cc_library( fbcache_taylorseer.h taylorseer.h residual_cache.h + racfgcache.h + calibration/racfgcache_calibration_tables.h SRCS dit_cache_impl.cpp dit_cache.cpp @@ -22,6 +24,8 @@ cc_library( fbcache_taylorseer.cpp taylorseer.cpp residual_cache.cpp + racfgcache.cpp + calibration/racfgcache_calibration_tables.cpp DEPS torch glog::glog diff --git a/xllm/core/framework/dit_cache/calibration/racfgcache_calibration_tables.cpp b/xllm/core/framework/dit_cache/calibration/racfgcache_calibration_tables.cpp new file mode 100644 index 000000000..b82092263 --- /dev/null +++ b/xllm/core/framework/dit_cache/calibration/racfgcache_calibration_tables.cpp @@ -0,0 +1,308 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "racfgcache_calibration_tables.h"
+
+#include <cmath>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <sstream>
+#include <unordered_map>
+#include <vector>
+
+namespace xllm {
+namespace {
+
+using RhoTableBuilder = torch::Tensor (*)();
+
+inline float NaN() { return std::numeric_limits<float>::quiet_NaN(); }
+
+// Build a square rho table with shape [steps, steps].
+torch::Tensor make_square_table(int64_t steps,
+                                const std::vector<float>& values) {
+  CHECK(steps > 0) << "steps must be positive, got " << steps;
+  CHECK(values.size() == static_cast<std::size_t>(steps * steps))
+      << "rho table size mismatch, expected " << steps * steps << ", got "
+      << values.size();
+
+  return torch::tensor(values, torch::TensorOptions().dtype(torch::kFloat32))
+      .view({steps, steps})
+      .contiguous()
+      .clone();
+}
+
+// Registered hardcoded tables
+// Add new tables by:
+//
+// 1. defining a new build_xxx() function
+// 2.
adding one entry to GetRhoRegistry() + +torch::Tensor build_qwen_image_edit_plus_cfg4_steps40() { + static torch::Tensor table = make_square_table( + 40, + {NaN(), 0.7310f, 0.8270f, 0.8686f, 0.8948f, 0.9094f, 0.9218f, 0.9295f, + 0.9373f, 0.9442f, 0.9500f, 0.9553f, 0.9587f, 0.9630f, 0.9662f, 0.9694f, + 0.9723f, 0.9749f, 0.9773f, 0.9795f, 0.9818f, 0.9836f, 0.9853f, 0.9868f, + 0.9883f, 0.9898f, 0.9909f, 0.9921f, 0.9932f, 0.9941f, 0.9951f, 0.9957f, + 0.9966f, 0.9974f, 0.9980f, 0.9988f, 0.9994f, 1.0002f, 1.0019f, 1.0050f, + NaN(), NaN(), 0.6363f, 0.7810f, 0.8425f, 0.8759f, 0.9009f, 0.9150f, + 0.9273f, 0.9381f, 0.9461f, 0.9536f, 0.9580f, 0.9632f, 0.9672f, 0.9707f, + 0.9742f, 0.9770f, 0.9796f, 0.9818f, 0.9841f, 0.9859f, 0.9874f, 0.9888f, + 0.9901f, 0.9915f, 0.9926f, 0.9935f, 0.9945f, 0.9953f, 0.9961f, 0.9966f, + 0.9974f, 0.9980f, 0.9985f, 0.9991f, 0.9996f, 1.0005f, 1.0022f, 1.0054f, + NaN(), NaN(), NaN(), 0.5467f, 0.7181f, 0.8004f, 0.8535f, 0.8786f, + 0.9002f, 0.9190f, 0.9311f, 0.9421f, 0.9488f, 0.9562f, 0.9616f, 0.9665f, + 0.9709f, 0.9745f, 0.9777f, 0.9804f, 0.9829f, 0.9850f, 0.9868f, 0.9882f, + 0.9897f, 0.9913f, 0.9923f, 0.9935f, 0.9945f, 0.9953f, 0.9961f, 0.9966f, + 0.9973f, 0.9981f, 0.9985f, 0.9992f, 0.9997f, 1.0005f, 1.0023f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), 0.4969f, 0.6716f, 0.7815f, 0.8289f, + 0.8667f, 0.8951f, 0.9137f, 0.9292f, 0.9388f, 0.9487f, 0.9558f, 0.9621f, + 0.9676f, 0.9719f, 0.9756f, 0.9788f, 0.9818f, 0.9841f, 0.9860f, 0.9877f, + 0.9893f, 0.9910f, 0.9922f, 0.9934f, 0.9945f, 0.9953f, 0.9961f, 0.9967f, + 0.9974f, 0.9981f, 0.9986f, 0.9992f, 0.9998f, 1.0006f, 1.0023f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), 0.4771f, 0.6648f, 0.7509f, + 0.8174f, 0.8629f, 0.8908f, 0.9134f, 0.9259f, 0.9390f, 0.9486f, 0.9567f, + 0.9636f, 0.9688f, 0.9731f, 0.9769f, 0.9803f, 0.9829f, 0.9850f, 0.9870f, + 0.9886f, 0.9905f, 0.9917f, 0.9930f, 0.9942f, 0.9950f, 0.9959f, 0.9964f, + 0.9973f, 0.9980f, 0.9985f, 0.9992f, 0.9997f, 1.0005f, 1.0023f, 1.0055f, + NaN(), NaN(), NaN(), 
NaN(), NaN(), NaN(), 0.4765f, 0.6389f, + 0.7531f, 0.8145f, 0.8585f, 0.8910f, 0.9094f, 0.9266f, 0.9389f, 0.9490f, + 0.9579f, 0.9642f, 0.9697f, 0.9740f, 0.9779f, 0.9810f, 0.9836f, 0.9856f, + 0.9876f, 0.9897f, 0.9911f, 0.9924f, 0.9937f, 0.9946f, 0.9955f, 0.9962f, + 0.9970f, 0.9978f, 0.9983f, 0.9990f, 0.9996f, 1.0005f, 1.0023f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.4385f, + 0.6383f, 0.7278f, 0.8023f, 0.8565f, 0.8829f, 0.9081f, 0.9250f, 0.9379f, + 0.9497f, 0.9577f, 0.9647f, 0.9699f, 0.9746f, 0.9785f, 0.9815f, 0.9839f, + 0.9864f, 0.9886f, 0.9902f, 0.9917f, 0.9931f, 0.9941f, 0.9951f, 0.9959f, + 0.9968f, 0.9976f, 0.9982f, 0.9990f, 0.9996f, 1.0004f, 1.0023f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + 0.4180f, 0.6034f, 0.7215f, 0.8044f, 0.8456f, 0.8825f, 0.9067f, 0.9246f, + 0.9397f, 0.9502f, 0.9593f, 0.9658f, 0.9713f, 0.9758f, 0.9793f, 0.9823f, + 0.9849f, 0.9874f, 0.9892f, 0.9910f, 0.9925f, 0.9936f, 0.9947f, 0.9955f, + 0.9966f, 0.9974f, 0.9980f, 0.9988f, 0.9995f, 1.0004f, 1.0023f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), 0.4438f, 0.6026f, 0.7271f, 0.7919f, 0.8471f, 0.8820f, 0.9070f, + 0.9275f, 0.9409f, 0.9521f, 0.9602f, 0.9670f, 0.9724f, 0.9766f, 0.9800f, + 0.9832f, 0.9860f, 0.9881f, 0.9900f, 0.9917f, 0.9930f, 0.9942f, 0.9951f, + 0.9961f, 0.9971f, 0.9978f, 0.9986f, 0.9994f, 1.0003f, 1.0022f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), 0.4017f, 0.6266f, 0.7222f, 0.8027f, 0.8508f, 0.8841f, + 0.9121f, 0.9289f, 0.9434f, 0.9536f, 0.9622f, 0.9687f, 0.9738f, 0.9777f, + 0.9815f, 0.9847f, 0.9870f, 0.9891f, 0.9910f, 0.9924f, 0.9937f, 0.9947f, + 0.9958f, 0.9969f, 0.9976f, 0.9985f, 0.9992f, 1.0003f, 1.0022f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), 0.4453f, 0.5923f, 0.7231f, 0.7986f, 0.8504f, + 0.8889f, 0.9125f, 0.9317f, 0.9447f, 0.9557f, 0.9638f, 0.9700f, 0.9747f, + 0.9791f, 0.9828f, 0.9855f, 0.9879f, 0.9901f, 
0.9917f, 0.9931f, 0.9942f, + 0.9955f, 0.9966f, 0.9974f, 0.9983f, 0.9991f, 1.0002f, 1.0022f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), 0.4092f, 0.6031f, 0.7228f, 0.8052f, + 0.8584f, 0.8902f, 0.9172f, 0.9340f, 0.9479f, 0.9579f, 0.9653f, 0.9710f, + 0.9763f, 0.9807f, 0.9838f, 0.9864f, 0.9889f, 0.9907f, 0.9924f, 0.9936f, + 0.9950f, 0.9962f, 0.9971f, 0.9980f, 0.9989f, 1.0000f, 1.0021f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), 0.4017f, 0.6019f, 0.7334f, + 0.8134f, 0.8612f, 0.8973f, 0.9194f, 0.9373f, 0.9502f, 0.9595f, 0.9667f, + 0.9730f, 0.9781f, 0.9818f, 0.9849f, 0.9879f, 0.9898f, 0.9917f, 0.9931f, + 0.9945f, 0.9958f, 0.9969f, 0.9979f, 0.9988f, 1.0000f, 1.0020f, 1.0055f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.3901f, 0.6117f, + 0.7383f, 0.8139f, 0.8672f, 0.8994f, 0.9227f, 0.9401f, 0.9517f, 0.9610f, + 0.9686f, 0.9750f, 0.9792f, 0.9830f, 0.9863f, 0.9886f, 0.9907f, 0.9923f, + 0.9939f, 0.9954f, 0.9965f, 0.9976f, 0.9986f, 0.9999f, 1.0019f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.4308f, + 0.6192f, 0.7413f, 0.8231f, 0.8700f, 0.9035f, 0.9262f, 0.9416f, 0.9536f, + 0.9632f, 0.9708f, 0.9760f, 0.9806f, 0.9844f, 0.9872f, 0.9897f, 0.9914f, + 0.9933f, 0.9949f, 0.9962f, 0.9973f, 0.9984f, 0.9998f, 1.0019f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + 0.4524f, 0.6337f, 0.7564f, 0.8279f, 0.8762f, 0.9078f, 0.9293f, 0.9446f, + 0.9569f, 0.9662f, 0.9727f, 0.9780f, 0.9825f, 0.9856f, 0.9885f, 0.9905f, + 0.9926f, 0.9944f, 0.9958f, 0.9971f, 0.9982f, 0.9996f, 1.0018f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), 0.4232f, 0.6396f, 0.7544f, 0.8324f, 0.8793f, 0.9100f, 0.9314f, 
+ 0.9475f, 0.9593f, 0.9676f, 0.9743f, 0.9798f, 0.9837f, 0.9870f, 0.9893f, + 0.9917f, 0.9938f, 0.9953f, 0.9967f, 0.9980f, 0.9995f, 1.0017f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), 0.4490f, 0.6393f, 0.7659f, 0.8381f, 0.8835f, 0.9133f, + 0.9358f, 0.9510f, 0.9616f, 0.9698f, 0.9765f, 0.9812f, 0.9851f, 0.9878f, + 0.9906f, 0.9929f, 0.9947f, 0.9963f, 0.9976f, 0.9993f, 1.0016f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), 0.4237f, 0.6431f, 0.7648f, 0.8406f, 0.8864f, + 0.9180f, 0.9391f, 0.9533f, 0.9641f, 0.9725f, 0.9781f, 0.9829f, 0.9862f, + 0.9895f, 0.9921f, 0.9940f, 0.9959f, 0.9973f, 0.9991f, 1.0015f, 1.0054f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), 0.4558f, 0.6601f, 0.7805f, 0.8502f, + 0.8960f, 0.9246f, 0.9437f, 0.9575f, 0.9678f, 0.9747f, 0.9805f, 0.9845f, + 0.9883f, 0.9912f, 0.9934f, 0.9954f, 0.9969f, 0.9989f, 1.0014f, 1.0053f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), 0.4558f, 0.6737f, 0.7860f, + 0.8590f, 0.9014f, 0.9293f, 0.9476f, 0.9613f, 0.9702f, 0.9771f, 0.9819f, + 0.9866f, 0.9899f, 0.9924f, 0.9948f, 0.9966f, 0.9986f, 1.0013f, 1.0053f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.4770f, 0.6794f, + 0.8035f, 0.8691f, 0.9088f, 0.9343f, 0.9524f, 0.9643f, 0.9730f, 0.9791f, + 0.9846f, 0.9885f, 0.9914f, 0.9941f, 0.9961f, 0.9984f, 1.0012f, 1.0053f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.4699f, + 0.6968f, 0.8121f, 
0.8755f, 0.9139f, 0.9397f, 0.9560f, 0.9677f, 0.9752f, + 0.9819f, 0.9867f, 0.9902f, 0.9932f, 0.9955f, 0.9981f, 1.0011f, 1.0052f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + 0.4956f, 0.7028f, 0.8182f, 0.8815f, 0.9205f, 0.9439f, 0.9598f, 0.9699f, + 0.9784f, 0.9844f, 0.9887f, 0.9922f, 0.9948f, 0.9977f, 1.0009f, 1.0052f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), 0.4918f, 0.7106f, 0.8276f, 0.8906f, 0.9265f, 0.9490f, 0.9628f, + 0.9740f, 0.9815f, 0.9867f, 0.9911f, 0.9941f, 0.9973f, 1.0007f, 1.0051f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), 0.5015f, 0.7317f, 0.8442f, 0.9012f, 0.9347f, 0.9541f, + 0.9685f, 0.9783f, 0.9846f, 0.9897f, 0.9932f, 0.9969f, 1.0006f, 1.0051f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), 0.5291f, 0.7533f, 0.8578f, 0.9116f, 0.9407f, + 0.9610f, 0.9733f, 0.9816f, 0.9880f, 0.9922f, 0.9964f, 1.0004f, 1.0051f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), 0.5507f, 0.7709f, 0.8700f, 0.9187f, + 0.9489f, 0.9668f, 0.9777f, 0.9856f, 0.9908f, 0.9957f, 1.0001f, 1.0050f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), 0.5809f, 0.7915f, 0.8818f, + 0.9310f, 0.9574f, 0.9723f, 0.9827f, 
0.9892f, 0.9950f, 0.9999f, 1.0050f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.6050f, 0.8052f, + 0.8975f, 0.9415f, 0.9641f, 0.9783f, 0.9869f, 0.9939f, 0.9994f, 1.0049f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.6376f, + 0.8337f, 0.9148f, 0.9514f, 0.9719f, 0.9836f, 0.9925f, 0.9990f, 1.0048f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + 0.6782f, 0.8583f, 0.9272f, 0.9616f, 0.9791f, 0.9907f, 0.9984f, 1.0047f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), 0.7215f, 0.8841f, 0.9450f, 0.9720f, 0.9882f, 0.9976f, 1.0046f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), 0.7453f, 0.9067f, 0.9592f, 0.9842f, 0.9965f, 1.0045f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), 0.7948f, 0.9312f, 0.9766f, 0.9945f, 1.0042f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), 
NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), 0.8412f, 0.9607f, 0.9912f, 1.0039f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), 0.9070f, 0.9836f, 1.0032f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.9564f, 1.0019f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), 0.9973f, + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), + NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN(), NaN()}); + return table; +} + +// Registry of exact hardcoded matches. +// Future extension only requires adding new entries here. +const std::unordered_map& +GetRhoRegistry() { + static const std:: + unordered_map + registry = { + {RhoTableSpec{"qwen_image_edit_plus", 4.0f, 40}, + &build_qwen_image_edit_plus_cfg4_steps40}, + }; + return registry; +} + +} // namespace + +std::size_t RhoTableSpecHash::operator()(const RhoTableSpec& spec) const { + std::size_t h1 = std::hash{}(spec.model_name); + std::size_t h2 = + std::hash{}(static_cast(std::round(spec.cfg_scale * 1000.0f))); + std::size_t h3 = std::hash{}(spec.infer_steps); + + // A simple hash combine. 
+  std::size_t seed = h1;
+  seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+  seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+  return seed;
+}
+
+torch::Tensor get_hardcoded_rho_table(const RhoTableSpec& spec) {
+  const auto& registry = GetRhoRegistry();
+  auto it = registry.find(spec);
+  if (it == registry.end()) {
+    return torch::Tensor();
+  }
+  return it->second();
+}
+
+bool has_hardcoded_rho_table(const RhoTableSpec& spec) {
+  const auto& registry = GetRhoRegistry();
+  return registry.find(spec) != registry.end();
+}
+
+std::string to_string(const RhoTableSpec& spec) {
+  std::ostringstream oss;
+  oss << "{model_name=" << spec.model_name << ", cfg_scale=" << spec.cfg_scale
+      << ", infer_steps=" << spec.infer_steps << "}";
+  return oss.str();
+}
+
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/framework/dit_cache/calibration/racfgcache_calibration_tables.h b/xllm/core/framework/dit_cache/calibration/racfgcache_calibration_tables.h
new file mode 100755
index 000000000..4b5bb832e
--- /dev/null
+++ b/xllm/core/framework/dit_cache/calibration/racfgcache_calibration_tables.h
@@ -0,0 +1,55 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <string>
+
+#include <torch/torch.h>
+
+namespace xllm {
+
+// Structured identifier for a hardcoded rho table.
+// This is designed for future extensibility across: +// - different model families +// - different CFG scales +// - different inference step counts +struct RhoTableSpec { + std::string model_name; + float cfg_scale = 0.0f; + int64_t infer_steps = 0; + + bool operator==(const RhoTableSpec& other) const { + return model_name == other.model_name && cfg_scale == other.cfg_scale && + infer_steps == other.infer_steps; + } +}; + +// Hash support for unordered_map. +struct RhoTableSpecHash { + std::size_t operator()(const RhoTableSpec& spec) const; +}; + +// Return a hardcoded rho table if the exact spec is registered. +// Return an undefined tensor if not found. +torch::Tensor get_hardcoded_rho_table(const RhoTableSpec& spec); + +// Return whether an exact hardcoded rho table exists for the given spec. +bool has_hardcoded_rho_table(const RhoTableSpec& spec); + +// Convert spec to a readable string for logging/debugging. +std::string to_string(const RhoTableSpec& spec); + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/dit_cache/dit_cache.cpp b/xllm/core/framework/dit_cache/dit_cache.cpp old mode 100644 new mode 100755 index 156161915..64fabdd31 --- a/xllm/core/framework/dit_cache/dit_cache.cpp +++ b/xllm/core/framework/dit_cache/dit_cache.cpp @@ -64,4 +64,9 @@ CacheStepOut DiTCache::on_after_step(const CacheStepIn& stepin, bool use_cfg) { return active_cache_->on_after_step(stepin); } +void DiTCache::set_runtime_context(const DiTCacheRuntimeContext& ctx) { + if (active_cache_) active_cache_->set_runtime_context(ctx); + if (active_cond_cache_) active_cond_cache_->set_runtime_context(ctx); +} + } // namespace xllm diff --git a/xllm/core/framework/dit_cache/dit_cache.h b/xllm/core/framework/dit_cache/dit_cache.h old mode 100644 new mode 100755 index f536f5685..9d967668f --- a/xllm/core/framework/dit_cache/dit_cache.h +++ b/xllm/core/framework/dit_cache/dit_cache.h @@ -65,6 +65,8 @@ class DiTCache { 
     active_cond_cache_->set_num_blocks(num_blocks);
   }
 
+  void set_runtime_context(const DiTCacheRuntimeContext& ctx);
+
  private:
   torch::Tensor get_tensor_or_empty(const TensorMap& m, const std::string& k);
 
diff --git a/xllm/core/framework/dit_cache/dit_cache_config.h b/xllm/core/framework/dit_cache/dit_cache_config.h
old mode 100644
new mode 100755
index 289aab985..e72808dc0
--- a/xllm/core/framework/dit_cache/dit_cache_config.h
+++ b/xllm/core/framework/dit_cache/dit_cache_config.h
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #pragma once
 
+#include <cstdint>
+#include <string>
 
 namespace xllm {
 
@@ -22,7 +24,8 @@ enum class PolicyType {
   FBCache,
   TaylorSeer,
   FBCacheTaylorSeer,
-  ResidualCache
+  ResidualCache,
+  RACFGCache
 };
 
 struct DiTBaseCacheOptions {
@@ -68,6 +71,36 @@ struct ResidualCacheOptions {
   int64_t skip_interval_steps = 3;
 };
 
+struct RACFGCacheOptions : public DiTBaseCacheOptions {
+  // Number of blocks to run before making the joint decision.
+  int64_t probe_depth = 2;
+
+  // Joint accumulated-risk threshold.
+  float tau = 0.0f;
+
+  // True CFG scale used in the guided combination.
+  float true_cfg_scale = 3.0f;
+
+  // Whether to use propagation-aware reweighting.
+  bool use_prop_weight = true;
+
+  // Propagation-aware fitted parameters.
+  float prop_a = 0.4806166f;
+  float prop_alpha = 0.4782565f;
+  float prop_b = 0.0641170f;
+
+  // Branch-local proxy error choice:
+  // 0 -> delta_y
+  // 1 -> delta_minus
+  int64_t proxy_error_type = 0;
+
+  // Offline rho table path.
+  std::string rho_table_path = "";
+  std::string model_name = "";
+  // Optional matched CFG scale recorded with the rho table.
+  float matched_cfg_scale = -1.0f;
+};
+
 struct DiTCacheConfig {
   DiTCacheConfig() = default;
 
@@ -85,6 +118,9 @@ struct DiTCacheConfig {
 
   // the configuration for ResidualCache policy.
   ResidualCacheOptions residual_cache;
+
+  // the configuration for RACFGCache policy.
+ RACFGCacheOptions racfgcache; }; } // namespace xllm diff --git a/xllm/core/framework/dit_cache/dit_cache_impl.cpp b/xllm/core/framework/dit_cache/dit_cache_impl.cpp old mode 100644 new mode 100755 index 1c0fe7b7a..7edf75efb --- a/xllm/core/framework/dit_cache/dit_cache_impl.cpp +++ b/xllm/core/framework/dit_cache/dit_cache_impl.cpp @@ -15,9 +15,13 @@ limitations under the License. #include "dit_cache_impl.h" +#include + #include "dit_non_cache.h" #include "fbcache.h" #include "fbcache_taylorseer.h" +#include "framework/parallel_state/parallel_state.h" +#include "racfgcache.h" #include "residual_cache.h" #include "taylorseer.h" @@ -32,7 +36,7 @@ torch::Tensor DitCacheImpl::get_tensor_or_empty(const TensorMap& m, bool DitCacheImpl::is_similar(const torch::Tensor& lhs, const torch::Tensor& rhs, - float threshold) { + float threshold) const { if (!lhs.defined() || !rhs.defined()) return false; if (lhs.sizes() != rhs.sizes()) return false; @@ -40,9 +44,26 @@ bool DitCacheImpl::is_similar(const torch::Tensor& lhs, return torch::allclose(lhs, rhs); } - auto diff = (lhs - rhs).abs(); - auto mean_diff = diff.mean(); - auto mean_lhs = lhs.abs().mean(); + torch::Device dev = lhs.device(); + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(dev); + + auto sum_abs_diff = (lhs - rhs).abs().sum().to(torch::kFloat32); + auto sum_abs_lhs = lhs.abs().sum().to(torch::kFloat32); + auto count = torch::tensor({static_cast(lhs.numel())}, opts); + + if (runtime_ctx_.sp_enabled && runtime_ctx_.sp_world_size > 1 && + runtime_ctx_.sp_group != nullptr) { + auto* sp_group = static_cast(runtime_ctx_.sp_group); + CHECK(sp_group != nullptr) + << "sp_group is null in DitCacheImpl::is_similar"; + + sum_abs_diff = xllm::parallel_state::reduce(sum_abs_diff, sp_group); + sum_abs_lhs = xllm::parallel_state::reduce(sum_abs_lhs, sp_group); + count = xllm::parallel_state::reduce(count, sp_group); + } + + auto mean_diff = sum_abs_diff / count; + auto mean_lhs = sum_abs_lhs / count; auto 
rel = mean_diff / (mean_lhs + 1e-6); return (rel < threshold).item(); @@ -58,6 +79,8 @@ std::unique_ptr create_dit_cache(const DiTCacheConfig& cfg) { return std::make_unique(); case PolicyType::ResidualCache: return std::make_unique(); + case PolicyType::RACFGCache: + return std::make_unique(); default: return std::make_unique(); } diff --git a/xllm/core/framework/dit_cache/dit_cache_impl.h b/xllm/core/framework/dit_cache/dit_cache_impl.h index f6096d303..3cf3986f5 100644 --- a/xllm/core/framework/dit_cache/dit_cache_impl.h +++ b/xllm/core/framework/dit_cache/dit_cache_impl.h @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #pragma once +#include #include #include @@ -24,6 +25,24 @@ namespace xllm { using TensorMap = std::unordered_map; +struct DiTCacheRuntimeContext { + std::string model_name; + // cfg parallel runtime info + void* cfg_group = nullptr; + int64_t cfg_rank = 0; + int64_t cfg_world_size = 1; + bool cfg_enabled = false; + float true_cfg_scale = 1.0f; + // sequence parallel runtime info + void* sp_group = nullptr; + int64_t sp_rank = 0; + int64_t sp_world_size = 1; + bool sp_enabled = false; + // optional runtime metadata + int64_t infer_steps = 0; + int64_t num_blocks = 0; +}; + class DitCacheImpl { public: DitCacheImpl() = default; @@ -45,6 +64,14 @@ class DitCacheImpl { num_blocks_ = num_blocks; } + virtual void set_runtime_context(const DiTCacheRuntimeContext& ctx) { + runtime_ctx_ = ctx; + } + + const DiTCacheRuntimeContext& get_runtime_context() const { + return runtime_ctx_; + } + protected: int64_t num_inference_steps_; int64_t warmup_steps_; @@ -52,12 +79,13 @@ class DitCacheImpl { int64_t infer_steps_; int64_t num_blocks_; TensorMap buffers; + DiTCacheRuntimeContext runtime_ctx_; static torch::Tensor get_tensor_or_empty(const TensorMap& m, const std::string& k); - static bool is_similar(const torch::Tensor& lhs, - const torch::Tensor& rhs, - float threshold); + 
bool is_similar(const torch::Tensor& lhs, + const torch::Tensor& rhs, + float threshold) const; }; std::unique_ptr create_dit_cache(const DiTCacheConfig& cfg); diff --git a/xllm/core/framework/dit_cache/racfgcache.cpp b/xllm/core/framework/dit_cache/racfgcache.cpp new file mode 100644 index 000000000..c8b9754da --- /dev/null +++ b/xllm/core/framework/dit_cache/racfgcache.cpp @@ -0,0 +1,605 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "racfgcache.h" + +#include +#include + +#include +#include +#include + +#include "calibration/racfgcache_calibration_tables.h" +#include "framework/parallel_state/parallel_state.h" + +namespace xllm { + +namespace { +inline bool is_nan_f(float x) { return std::isnan(x); } +} // namespace + +void RACFGCache::init(const DiTCacheConfig& cfg) { + warmup_steps_ = cfg.racfgcache.warmup_steps; + probe_depth_ = cfg.racfgcache.probe_depth; + tau_ = cfg.racfgcache.tau; + true_cfg_scale_ = cfg.racfgcache.true_cfg_scale; + + use_prop_weight_ = cfg.racfgcache.use_prop_weight; + prop_a_ = cfg.racfgcache.prop_a; + prop_alpha_ = cfg.racfgcache.prop_alpha; + prop_b_ = cfg.racfgcache.prop_b; + + proxy_error_type_ = + static_cast(cfg.racfgcache.proxy_error_type); + + rho_table_path_ = cfg.racfgcache.rho_table_path; + model_name_ = cfg.racfgcache.model_name; + validate_config_(); + reset_all_state_(); + 
build_prop_weight_schedule_(); +} + +void RACFGCache::set_runtime_context(const DiTCacheRuntimeContext& ctx) { + DitCacheImpl::set_runtime_context(ctx); + if (ctx.infer_steps > 0) { + infer_steps_ = ctx.infer_steps; + } + if (ctx.num_blocks > 0) { + num_blocks_ = ctx.num_blocks; + } + if (ctx.true_cfg_scale > 0.0f) { + true_cfg_scale_ = ctx.true_cfg_scale; + } + build_prop_weight_schedule_(); + load_rho_table_(); +} + +bool RACFGCache::on_before_step(const CacheStepIn& stepin) { + current_step_ = stepin.step_id; + + if (current_step_ == 0) { + reset_all_state_(); + build_prop_weight_schedule_(); + } else { + reset_step_state_(current_step_); + } + + auto hidden_states = get_tensor_or_empty(stepin.tensors, "hidden_states"); + auto original_hidden_states = + get_tensor_or_empty(stepin.tensors, "original_hidden_states"); + auto encoder_hidden_states = + get_tensor_or_empty(stepin.tensors, "encoder_hidden_states"); + auto original_encoder_hidden_states = + get_tensor_or_empty(stepin.tensors, "original_encoder_hidden_states"); + + local_.base_hidden_states = + original_hidden_states.defined() ? original_hidden_states : hidden_states; + local_.base_encoder_hidden_states = original_encoder_hidden_states.defined() + ? 
original_encoder_hidden_states + : encoder_hidden_states; + + if (!cfg_parallel_enabled_()) { + force_full_this_step_ = true; + } + + if (current_step_ <= warmup_steps_) { + force_full_this_step_ = true; + } + return false; +} + +CacheStepOut RACFGCache::on_after_step(const CacheStepIn& stepin) { + auto hidden_states = get_tensor_or_empty(stepin.tensors, "hidden_states"); + auto original_hidden_states = + get_tensor_or_empty(stepin.tensors, "original_hidden_states"); + auto encoder_hidden_states = + get_tensor_or_empty(stepin.tensors, "encoder_hidden_states"); + auto original_encoder_hidden_states = + get_tensor_or_empty(stepin.tensors, "original_encoder_hidden_states"); + + if (!use_cache_) { + if (hidden_states.defined() && original_hidden_states.defined()) { + local_.previous_residual = + (hidden_states - original_hidden_states).detach().contiguous(); + } + + if (encoder_hidden_states.defined() && + original_encoder_hidden_states.defined()) { + local_.previous_encoder_residual = + (encoder_hidden_states - original_encoder_hidden_states) + .detach() + .contiguous(); + } + + update_local_history_after_full_(); + update_joint_state_after_full_(); + } else { + update_joint_state_after_reuse_(); + } + + TensorMap out_map; + if (hidden_states.defined()) { + out_map["hidden_states"] = hidden_states; + } + if (encoder_hidden_states.defined()) { + out_map["encoder_hidden_states"] = encoder_hidden_states; + } + return CacheStepOut(out_map); +} + +bool RACFGCache::on_before_block(const CacheBlockIn& blockin) { + if (force_full_this_step_) { + return false; + } + + if (!joint_decision_ready_) { + return false; + } + + if (!use_cache_) { + return false; + } + + if (blockin.block_id < probe_depth_) { + return false; + } + return true; +} + +CacheBlockOut RACFGCache::on_after_block(const CacheBlockIn& blockin) { + auto hidden_states = get_tensor_or_empty(blockin.tensors, "hidden_states"); + auto encoder_hidden_states = + get_tensor_or_empty(blockin.tensors, 
"encoder_hidden_states"); + auto original_hidden_states = + get_tensor_or_empty(blockin.tensors, "original_hidden_states"); + auto original_encoder_hidden_states = + get_tensor_or_empty(blockin.tensors, "original_encoder_hidden_states"); + + TensorMap out_map; + out_map["hidden_states"] = hidden_states; + if (encoder_hidden_states.defined()) { + out_map["encoder_hidden_states"] = encoder_hidden_states; + } + + if (force_full_this_step_) { + return CacheBlockOut(out_map); + } + + if (blockin.block_id != probe_depth_ - 1) { + return CacheBlockOut(out_map); + } + + prepare_probe_at_block_(blockin); + + if (!local_probe_available_() || !local_history_available_()) { + current_decision_.ready = true; + current_decision_.reuse_both = false; + joint_decision_ready_ = true; + use_cache_ = false; + return CacheBlockOut(out_map); + } + + current_decision_ = make_joint_decision_(current_step_); + joint_decision_ready_ = true; + use_cache_ = current_decision_.ready && current_decision_.reuse_both; + + if (!use_cache_) { + return CacheBlockOut(out_map); + } + + auto residual_applied = apply_prev_residual_pair_( + original_hidden_states, original_encoder_hidden_states); + + TensorMap cached_out_map; + cached_out_map["hidden_states"] = std::move(residual_applied.first); + if (residual_applied.second.defined()) { + cached_out_map["encoder_hidden_states"] = + std::move(residual_applied.second); + } + + return CacheBlockOut(cached_out_map); +} + +void RACFGCache::reset_all_state_() { + use_cache_ = false; + joint_decision_ready_ = false; + force_full_this_step_ = false; + + local_ = BranchLocalState{}; + joint_ = JointState{}; + current_decision_ = JointDecision{}; + + buffers.clear(); +} + +void RACFGCache::reset_step_state_(int64_t step_id) { + use_cache_ = false; + joint_decision_ready_ = false; + force_full_this_step_ = false; + + current_decision_ = JointDecision{}; + + local_.base_hidden_states = torch::Tensor(); + local_.base_encoder_hidden_states = torch::Tensor(); + + 
local_.current_probe_input = torch::Tensor(); + local_.current_probe_hidden = torch::Tensor(); + + local_.current_dx = std::numeric_limits::quiet_NaN(); + local_.current_dy = std::numeric_limits::quiet_NaN(); + local_.current_branch_error = std::numeric_limits::quiet_NaN(); + + local_.probe_ready_this_step = false; +} + +torch::Tensor RACFGCache::apply_prev_hidden_states_residual_( + const torch::Tensor& hidden_states) const { + if (!hidden_states.defined() || !local_.previous_residual.defined()) { + return hidden_states; + } + return (hidden_states + local_.previous_residual).contiguous(); +} + +std::pair RACFGCache::apply_prev_residual_pair_( + const torch::Tensor& original_hidden_states, + const torch::Tensor& original_encoder_hidden_states) const { + torch::Tensor new_hidden = original_hidden_states; + torch::Tensor new_encoder = original_encoder_hidden_states; + + if (original_hidden_states.defined() && local_.previous_residual.defined()) { + new_hidden = + (original_hidden_states + local_.previous_residual).contiguous(); + } + + if (original_encoder_hidden_states.defined() && + local_.previous_encoder_residual.defined()) { + new_encoder = + (original_encoder_hidden_states + local_.previous_encoder_residual) + .contiguous(); + } + + return {new_hidden, new_encoder}; +} + +float RACFGCache::compute_rel_l1_(const torch::Tensor& curr, + const torch::Tensor& prev, + float eps) const { + if (!curr.defined() || !prev.defined()) { + return std::numeric_limits::quiet_NaN(); + } + + torch::Device dev = curr.device(); + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(dev); + + auto sum_abs_diff = (curr - prev).abs().sum().to(torch::kFloat32); + auto sum_abs_prev = prev.abs().sum().to(torch::kFloat32); + auto count = torch::tensor({static_cast(curr.numel())}, opts); + + if (runtime_ctx_.sp_enabled && runtime_ctx_.sp_world_size > 1 && + runtime_ctx_.sp_group != nullptr) { + auto* sp_group = static_cast(runtime_ctx_.sp_group); + CHECK(sp_group != nullptr) 
<< "sp_group is null in RACFGCache"; + + sum_abs_diff = xllm::parallel_state::reduce(sum_abs_diff, sp_group); + sum_abs_prev = xllm::parallel_state::reduce(sum_abs_prev, sp_group); + count = xllm::parallel_state::reduce(count, sp_group); + } + + auto denom = sum_abs_prev / count; + float denom_f = denom.item(); + if (std::abs(denom_f) < eps) { + return 0.0f; + } + + auto mean_abs_diff = sum_abs_diff / count; + auto rel = mean_abs_diff / (denom + eps); + return rel.item(); +} + +float RACFGCache::compute_branch_error_(float dx, float dy) const { + if (is_nan_f(dy)) { + return std::numeric_limits::quiet_NaN(); + } + + if (proxy_error_type_ == ProxyErrorType::DeltaMinus) { + if (is_nan_f(dx)) { + return std::numeric_limits::quiet_NaN(); + } + return std::abs(dy - dx); + } + + // default: delta_y + return dy; +} + +bool RACFGCache::local_probe_available_() const { + return local_.probe_ready_this_step && !is_nan_f(local_.current_branch_error); +} + +bool RACFGCache::local_history_available_() const { + return local_.previous_residual.defined() && + local_.first_probe_input_prev.defined() && + local_.probe_hidden_prev.defined(); +} + +void RACFGCache::prepare_probe_at_block_(const CacheBlockIn& blockin) { + auto hidden_states = get_tensor_or_empty(blockin.tensors, "hidden_states"); + auto original_hidden_states = + get_tensor_or_empty(blockin.tensors, "original_hidden_states"); + + local_.current_probe_input = + original_hidden_states.defined() ? 
original_hidden_states : hidden_states; + local_.current_probe_hidden = hidden_states; + + local_.current_dx = compute_rel_l1_(local_.current_probe_input, + local_.first_probe_input_prev); + local_.current_dy = + compute_rel_l1_(local_.current_probe_hidden, local_.probe_hidden_prev); + + local_.current_branch_error = + compute_branch_error_(local_.current_dx, local_.current_dy); + + local_.probe_ready_this_step = true; +} + +bool RACFGCache::cfg_parallel_enabled_() const { + return runtime_ctx_.cfg_enabled && runtime_ctx_.cfg_world_size == 2 && + runtime_ctx_.cfg_group != nullptr; +} + +bool RACFGCache::is_cond_rank_() const { return runtime_ctx_.cfg_rank == 0; } + +bool RACFGCache::is_uncond_rank_() const { return runtime_ctx_.cfg_rank == 1; } + +std::pair RACFGCache::exchange_branch_errors_( + float local_err) const { + if (!cfg_parallel_enabled_()) { + return {local_err, local_err}; + } + + auto* cfg_group = static_cast(runtime_ctx_.cfg_group); + CHECK(cfg_group != nullptr) << "cfg_group is null in RACFGCache"; + + torch::Device dev = local_.base_hidden_states.defined() + ? 
local_.base_hidden_states.device() + : torch::Device(torch::kCPU); + + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(dev); + auto local_tensor = torch::tensor({local_err}, opts); + + auto gathered = xllm::parallel_state::gather(local_tensor, cfg_group, 0); + auto gathered_cpu = gathered.to(torch::kCPU).contiguous().view({-1}); + + CHECK_EQ(gathered_cpu.numel(), 2) + << "RACFGCache expects cfg world size = 2, but got gathered numel = " + << gathered_cpu.numel(); + + // rank 0 = cond, rank 1 = uncond + float ec = gathered_cpu[0].item(); + float eu = gathered_cpu[1].item(); + return {ec, eu}; +} + +float RACFGCache::lookup_rho_(int64_t anchor_step, int64_t step_id) const { + if (!rho_table_at_cpu_.defined() || rho_table_at_cpu_.numel() == 0) { + return 0.0f; + } + + CHECK_EQ(rho_table_at_cpu_.dim(), 2) + << "rho_table_at is expected to be a 2D tensor"; + + int64_t a_max = rho_table_at_cpu_.size(0) - 1; + int64_t t_max = rho_table_at_cpu_.size(1) - 1; + + int64_t a = std::max(0, std::min(anchor_step, a_max)); + int64_t t = std::max(0, std::min(step_id, t_max)); + + return rho_table_at_cpu_.index({a, t}).item(); +} + +float RACFGCache::build_prop_weight_(int64_t step_i) const { + if (!use_prop_weight_) { + return 1.0f; + } + + if (!joint_.prop_weight_schedule.empty()) { + int64_t idx = std::max( + 0, + std::min( + step_i, + static_cast(joint_.prop_weight_schedule.size()) - 1)); + return joint_.prop_weight_schedule[idx]; + } + + int64_t T = infer_steps_ > 0 ? 
infer_steps_ : runtime_ctx_.infer_steps; + if (T <= 1) { + return 1.0f; + } + + int64_t k = std::max(0, std::min(step_i, T - 1)); + double x = static_cast(T - 1 - k) / static_cast(T - 1); + double raw_t = static_cast(prop_a_) * std::pow(x, prop_alpha_) + + static_cast(prop_b_); + + double raw_sum = 0.0; + for (int64_t i = 0; i < T; ++i) { + double xi = static_cast(T - 1 - i) / static_cast(T - 1); + raw_sum += static_cast(prop_a_) * std::pow(xi, prop_alpha_) + + static_cast(prop_b_); + } + double raw_mean = std::max(raw_sum / static_cast(T), 1e-12); + + return static_cast(raw_t / raw_mean); +} + +void RACFGCache::build_prop_weight_schedule_() { + joint_.prop_weight_schedule.clear(); + + if (!use_prop_weight_) { + return; + } + + int64_t T = infer_steps_ > 0 ? infer_steps_ : runtime_ctx_.infer_steps; + if (T <= 0) { + return; + } + + joint_.prop_weight_schedule.resize(T, 1.0f); + + if (T == 1) { + joint_.prop_weight_schedule[0] = 1.0f; + return; + } + + std::vector raws(T, 0.0); + double raw_sum = 0.0; + for (int64_t i = 0; i < T; ++i) { + double x = static_cast(T - 1 - i) / static_cast(T - 1); + double raw = static_cast(prop_a_) * std::pow(x, prop_alpha_) + + static_cast(prop_b_); + raws[i] = raw; + raw_sum += raw; + } + double raw_mean = std::max(raw_sum / static_cast(T), 1e-12); + + for (int64_t i = 0; i < T; ++i) { + joint_.prop_weight_schedule[i] = static_cast(raws[i] / raw_mean); + } +} + +RACFGCache::JointDecision RACFGCache::make_joint_decision_(int64_t step_id) { + JointDecision out; + out.ready = true; + + if (force_full_this_step_) { + out.reuse_both = false; + return out; + } + + if (!local_history_available_()) { + out.reuse_both = false; + return out; + } + + auto local_err = local_.current_branch_error; + if (is_nan_f(local_err)) { + out.reuse_both = false; + return out; + } + + auto branch_errors = exchange_branch_errors_(local_err); + out.ec = branch_errors.first; + out.eu = branch_errors.second; + + if (is_nan_f(out.ec) || is_nan_f(out.eu)) { + 
out.reuse_both = false; + return out; + } + + float s = true_cfg_scale_; + float rho = lookup_rho_(joint_.anchor_step, step_id); + + float term_u = (1.0f - s) * (1.0f - s) * out.eu * out.eu; + float term_c = s * s * out.ec * out.ec; + float term_cross = 2.0f * s * (1.0f - s) * out.eu * out.ec * rho; + + float val_raw = term_u + term_c + term_cross; + float val_clamped = std::max(val_raw, 0.0f); + + out.dhat = std::sqrt(val_clamped); + out.ghat = build_prop_weight_(step_id); + out.rhat = std::max(out.dhat * out.ghat, 0.0f); + + bool budget_ok = (joint_.accumulated_risk + out.rhat <= tau_); + out.reuse_both = budget_ok; + return out; +} + +void RACFGCache::update_local_history_after_full_() { + if (local_.current_probe_input.defined()) { + local_.first_probe_input_prev = + local_.current_probe_input.detach().contiguous(); + } + + if (local_.current_probe_hidden.defined()) { + local_.probe_hidden_prev = + local_.current_probe_hidden.detach().contiguous(); + } +} + +void RACFGCache::update_joint_state_after_full_() { + joint_.anchor_step = current_step_; + joint_.accumulated_risk = 0.0f; + joint_.consecutive_reuse = 0; + joint_.last_reuse = false; +} + +void RACFGCache::update_joint_state_after_reuse_() { + if (!is_nan_f(current_decision_.rhat)) { + joint_.accumulated_risk += current_decision_.rhat; + } + joint_.consecutive_reuse += 1; + joint_.last_reuse = true; +} + +void RACFGCache::validate_config_() const { + CHECK_GE(warmup_steps_, 0) << "warmup_steps must be >= 0"; + CHECK_GE(probe_depth_, 1) << "probe_depth must be >= 1"; + CHECK_GT(tau_, 0.0f) << "tau must be > 0"; + CHECK_GE(true_cfg_scale_, 1.0f) << "true_cfg_scale must be >= 1.0"; + CHECK(proxy_error_type_ == ProxyErrorType::DeltaY || + proxy_error_type_ == ProxyErrorType::DeltaMinus) + << "unsupported proxy_error_type"; +} + +void RACFGCache::load_rho_table_() { + if (rho_table_path_.empty()) { + rho_table_at_cpu_ = torch::Tensor(); + LOG(INFO) << "[RACFG][RHO] no rho key configured, fallback to rho=0"; + 
return; + } + + RhoTableSpec spec; + spec.model_name = model_name_; + spec.cfg_scale = true_cfg_scale_; + spec.infer_steps = infer_steps_ > 0 ? infer_steps_ : runtime_ctx_.infer_steps; + LOG(INFO) << "[RACFG][RHO] loading rho table with spec: " + << "model_name=" << spec.model_name + << " cfg_scale=" << spec.cfg_scale + << " infer_steps=" << spec.infer_steps; + auto hardcoded = get_hardcoded_rho_table(spec); + if (hardcoded.defined()) { + CHECK_EQ(hardcoded.dim(), 2) << "hardcoded rho table must be 2D"; + rho_table_at_cpu_ = + hardcoded.to(torch::kCPU).to(torch::kFloat32).contiguous(); + LOG(INFO) << "[RACFG][RHO] loaded hardcoded rho table with key=" + << rho_table_path_ << " shape=" << rho_table_at_cpu_.sizes(); + return; + } + + // fallback: old file-based load + torch::Tensor table; + torch::load(table, rho_table_path_); + CHECK(table.defined()) << "failed to load rho table from " << rho_table_path_; + CHECK_EQ(table.dim(), 2) << "rho table must be a 2D tensor"; + + rho_table_at_cpu_ = table.to(torch::kCPU).to(torch::kFloat32).contiguous(); + + LOG(INFO) << "[RACFG][RHO] loaded file rho table from path=" + << rho_table_path_ << " shape=" << rho_table_at_cpu_.sizes(); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/dit_cache/racfgcache.h b/xllm/core/framework/dit_cache/racfgcache.h new file mode 100755 index 000000000..b26d28759 --- /dev/null +++ b/xllm/core/framework/dit_cache/racfgcache.h @@ -0,0 +1,171 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "dit_cache_impl.h" + +namespace xllm { + +class RACFGCache : public DitCacheImpl { + public: + RACFGCache() = default; + ~RACFGCache() override = default; + + RACFGCache(const RACFGCache&) = delete; + RACFGCache& operator=(const RACFGCache&) = delete; + RACFGCache(RACFGCache&&) = default; + RACFGCache& operator=(RACFGCache&&) = default; + + void init(const DiTCacheConfig& cfg) override; + + void set_runtime_context(const DiTCacheRuntimeContext& ctx) override; + + bool on_before_step(const CacheStepIn& stepin) override; + CacheStepOut on_after_step(const CacheStepIn& stepin) override; + + bool on_before_block(const CacheBlockIn& blockin) override; + CacheBlockOut on_after_block(const CacheBlockIn& blockin) override; + + private: + static constexpr float kEps = 1e-6f; + + enum class ProxyErrorType : int64_t { + DeltaY = 0, + DeltaMinus = 1, + }; + + struct BranchLocalState { + // residual cache + torch::Tensor previous_residual; + torch::Tensor previous_encoder_residual; + + // base state for this step + torch::Tensor base_hidden_states; + torch::Tensor base_encoder_hidden_states; + + torch::Tensor first_probe_input_prev; + torch::Tensor probe_hidden_prev; + // previous probe reference + torch::Tensor proxy_prev_input; + torch::Tensor proxy_prev_probe_states; + + // current-step probe tensors + torch::Tensor current_probe_input; + torch::Tensor current_probe_hidden; + + // current-step probe scalars + float current_dx = std::numeric_limits::quiet_NaN(); + float current_dy = std::numeric_limits::quiet_NaN(); + float current_branch_error = std::numeric_limits::quiet_NaN(); + + bool probe_ready_this_step = false; + }; + + struct JointState { + int64_t anchor_step = 0; + float accumulated_risk = 0.0f; + 
int64_t consecutive_reuse = 0; + bool last_reuse = false; + + // precomputed propagation-aware weights, size = infer_steps + std::vector prop_weight_schedule; + }; + + struct JointDecision { + bool ready = false; + bool reuse_both = false; + + float ec = std::numeric_limits::quiet_NaN(); + float eu = std::numeric_limits::quiet_NaN(); + + float dhat = std::numeric_limits::quiet_NaN(); + float ghat = 1.0f; + float rhat = std::numeric_limits::quiet_NaN(); + }; + + private: + // config + int64_t probe_depth_ = 2; + float tau_ = 0.24f; + float true_cfg_scale_ = 3.0f; + + bool use_prop_weight_ = true; + float prop_a_ = 0.4806166f; + float prop_alpha_ = 0.4782565f; + float prop_b_ = 0.0641170f; + + ProxyErrorType proxy_error_type_ = ProxyErrorType::DeltaY; + + std::string rho_table_path_; + std::string model_name_; + + // runtime state + bool use_cache_ = false; + bool joint_decision_ready_ = false; + bool force_full_this_step_ = false; + + BranchLocalState local_; + JointState joint_; + JointDecision current_decision_; + + // rho_table_at[a, t], stored on CPU + torch::Tensor rho_table_at_cpu_; + + private: + void reset_all_state_(); + void reset_step_state_(int64_t step_id); + torch::Tensor apply_prev_hidden_states_residual_( + const torch::Tensor& hidden_states) const; + + std::pair apply_prev_residual_pair_( + const torch::Tensor& original_hidden_states, + const torch::Tensor& original_encoder_hidden_states) const; + float compute_rel_l1_(const torch::Tensor& curr, + const torch::Tensor& prev, + float eps = kEps) const; + + float compute_branch_error_(float dx, float dy) const; + + bool local_probe_available_() const; + bool local_history_available_() const; + + void prepare_probe_at_block_(const CacheBlockIn& blockin); + bool cfg_parallel_enabled_() const; + bool is_cond_rank_() const; + bool is_uncond_rank_() const; + + std::pair exchange_branch_errors_(float local_err) const; + + float lookup_rho_(int64_t anchor_step, int64_t step_id) const; + float 
build_prop_weight_(int64_t step_i) const; + void build_prop_weight_schedule_(); + + JointDecision make_joint_decision_(int64_t step_id); + void update_local_history_after_full_(); + void update_joint_state_after_full_(); + void update_joint_state_after_reuse_(); + void validate_config_() const; + void load_rho_table_(); +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/dit_worker_impl.cpp b/xllm/core/runtime/dit_worker_impl.cpp old mode 100644 new mode 100755 index f38bde768..bd6ddec3e --- a/xllm/core/runtime/dit_worker_impl.cpp +++ b/xllm/core/runtime/dit_worker_impl.cpp @@ -72,6 +72,19 @@ DiTCacheConfig parse_dit_cache_from_flags() { FLAGS_dit_cache_end_blocks; cache_config.residual_cache.skip_interval_steps = FLAGS_dit_cache_skip_interval_steps; + } else if (FLAGS_dit_cache_policy == "RACFGCache") { + cache_config.selected_policy = PolicyType::RACFGCache; + cache_config.racfgcache.warmup_steps = FLAGS_dit_cache_warmup_steps; + cache_config.racfgcache.probe_depth = FLAGS_dit_cache_probe_depth; + cache_config.racfgcache.tau = FLAGS_dit_cache_tau; + cache_config.racfgcache.true_cfg_scale = FLAGS_true_cfg_scale; + cache_config.racfgcache.use_prop_weight = FLAGS_dit_cache_use_prop_weight; + cache_config.racfgcache.prop_a = FLAGS_dit_cache_prop_a; + cache_config.racfgcache.prop_alpha = FLAGS_dit_cache_prop_alpha; + cache_config.racfgcache.prop_b = FLAGS_dit_cache_prop_b; + cache_config.racfgcache.proxy_error_type = FLAGS_dit_cache_proxy_error_type; + cache_config.racfgcache.rho_table_path = FLAGS_dit_cache_rho_table_path; + cache_config.racfgcache.model_name = FLAGS_dit_cache_model_name; } else if (FLAGS_dit_cache_policy == "None") { cache_config.selected_policy = PolicyType::None; } diff --git a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h old mode 100644 new mode 100755 index 1b19b2ea0..10d857128 --- 
a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h +++ b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h @@ -271,9 +271,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { auto seed = generation_params.seed >= 0 ? generation_params.seed : 42; auto prompts = input.prompts; - auto prompts_2 = input.prompts_2; auto negative_prompts = input.negative_prompts; - auto negative_prompts_2 = input.negative_prompts_2; auto latents = input.latents; if (latents.defined()) { latents = latents.to(options_.device(), dtype_); @@ -283,7 +281,6 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { if (prompt_embeds.defined()) { prompt_embeds = prompt_embeds.to(options_.device(), dtype_); } - auto pooled_prompt_embeds = input.pooled_prompt_embeds; torch::Tensor prompt_embeds_mask; auto negative_prompt_embeds = input.negative_prompt_embeds; @@ -291,7 +288,6 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { negative_prompt_embeds = negative_prompt_embeds.to(options_.device(), dtype_); } - auto negative_pooled_prompt_embeds = input.negative_pooled_prompt_embeds; torch::Tensor negative_prompt_embeds_mask; std::vector image_list; @@ -385,9 +381,37 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { } } - bool has_neg_prompt = negative_prompts.size() > 0; + bool has_neg_prompt = + negative_prompts.size() > 0 || negative_prompt_embeds.defined(); bool do_true_cfg = (true_cfg_scale > 1.0) && has_neg_prompt; + DiTCacheRuntimeContext ctx; + // cfg parallel + ctx.cfg_group = static_cast(parallel_args_.dit_cfg_group_); + ctx.cfg_rank = parallel_args_.dit_cfg_group_ + ? parallel_args_.dit_cfg_group_->rank() + : 0; + ctx.cfg_world_size = parallel_args_.dit_cfg_group_ + ? 
parallel_args_.dit_cfg_group_->world_size() + : 1; + LOG(INFO) << "CHECk cfg enabled: FLAGS_cfg_size=" << FLAGS_cfg_size + << ", true_cfg_scale=" << true_cfg_scale + << ", cfg_group=" << parallel_args_.dit_cfg_group_; + ctx.cfg_enabled = (FLAGS_cfg_size == 2 && do_true_cfg && + parallel_args_.dit_cfg_group_ != nullptr); + // sequence parallel + ctx.sp_group = static_cast(parallel_args_.dit_sp_group_); + ctx.sp_rank = + parallel_args_.dit_sp_group_ ? parallel_args_.dit_sp_group_->rank() : 0; + ctx.sp_world_size = parallel_args_.dit_sp_group_ + ? parallel_args_.dit_sp_group_->world_size() + : 1; + ctx.sp_enabled = + (FLAGS_sp_size > 1 && parallel_args_.dit_sp_group_ != nullptr); + ctx.true_cfg_scale = generation_params.true_cfg_scale; + ctx.infer_steps = num_inference_steps; + ctx.num_blocks = num_layers_; + DiTCache::get_instance().set_runtime_context(ctx); // inplace update prompt_embeds and prompt_embeds_mask _encode_prompt(condition_images, prompts, diff --git a/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h b/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h index 2d4f5b4f0..a05f7e871 100644 --- a/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h +++ b/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h @@ -2021,7 +2021,9 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { // Step start: prepare inputs (hidden_states, original_hidden_states) TensorMap step_in_map = { {"hidden_states", new_hidden_states}, - {"original_hidden_states", original_hidden_states}}; + {"original_hidden_states", original_hidden_states}, + {"encoder_hidden_states", new_encoder_hidden_states}, + {"original_encoder_hidden_states", original_encoder_hidden_states}}; CacheStepIn stepin_before(step_idx, step_in_map); use_step_cache = DiTCache::get_instance().on_before_step(stepin_before, use_cfg); @@ -2065,7 +2067,10 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { // Step end: update outputs (hidden_states, 
original_hidden_states) TensorMap step_after_map = { {"hidden_states", new_hidden_states}, - {"original_hidden_states", original_hidden_states}}; + {"original_hidden_states", original_hidden_states}, + {"encoder_hidden_states", new_encoder_hidden_states}, + {"original_encoder_hidden_states", original_encoder_hidden_states}}; + CacheStepIn stepin_after(step_idx, step_after_map); CacheStepOut stepout_after = DiTCache::get_instance().on_after_step(stepin_after, use_cfg);