diff --git a/gemma/activations.h b/gemma/activations.h
index f00f81e2..2529041e 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -55,7 +55,9 @@ struct AttentionActivations {
       size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
       size_t max_workers, const Allocator& allocator,
       std::vector>& row_ptrs)
-      : rep_factor(max_workers *
+      : heads(layer_config.heads),
+        qkv_dim(layer_config.qkv_dim),
+        rep_factor(max_workers *
                    AttentionActivations::kThreadReplicationFactor /
                    layer_config.heads),
         // `vocab_size == 0` means it is for Vit part, VitAttention
@@ -115,6 +117,9 @@ struct AttentionActivations {
     // query tiles, which is not known here.
     flash_params.reserve(batch_size * layer_config.heads);
     split_flash_params.reserve(batch_size * layer_config.heads);
+    bf16_queries.resize(batch_size * heads * qkv_dim);
+    int16_queries.resize(batch_size * heads * qkv_dim);
+    q_scales.resize(batch_size * heads);
 
     // For MatMul outputs, precompute their row pointers.
     // If we forget any MatMul outputs here, debug builds print a warning but
@@ -140,10 +145,19 @@ struct AttentionActivations {
     softmax_max.OverrideRows(batch_size);
     softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+    bf16_queries.resize(batch_size * heads * qkv_dim);
+    int16_queries.resize(batch_size * heads * qkv_dim);
+    q_scales.resize(batch_size * heads);
 
     // `inv_timescale*` are not batched.
   }
 
+  size_t heads;
+  size_t qkv_dim;
+  AlignedBF16Vector bf16_queries;
+  std::vector> int16_queries;
+  AlignedFloatVector q_scales;
+
   // Maximum factor by which we might scale-up work to maximize parallelism.
   size_t rep_factor = 1;
   // Parameters for flash attention. The size of the vector is somewhere between
@@ -191,6 +205,9 @@ struct AttentionActivationsPtrs {
       : config(config),
         flash_params(flash_params),
         split_flash_params(split_flash_params),
+        bf16_queries(nullptr),
+        int16_queries(nullptr),
+        q_scales(nullptr),
         div_seq_len(static_cast(seq_len)),
         div_heads(static_cast(config.layer_configs[0].heads)),
         query_scale(ChooseQueryScale(config)) {}
@@ -212,6 +229,9 @@ struct AttentionActivationsPtrs {
     att_sums = activations.att_sums;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
+    bf16_queries = &activations.bf16_queries;
+    int16_queries = &activations.int16_queries;
+    q_scales = &activations.q_scales;
   }
 
   void SetBatchSize(size_t batch_size) {
@@ -277,6 +297,9 @@ struct AttentionActivationsPtrs {
       sub_task_exp_denominator_sums;
   std::vector* sub_task_max_logits;
 
+  AlignedBF16Vector* bf16_queries;
+  std::vector>* int16_queries;
+  AlignedFloatVector* q_scales;
   // Inverse timescales for RoPE computation.
   MatPtrT inv_timescale;
   // Inverse timescales for global RoPE computation.
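A minimal sketch of the layout implied by the resize() calls above: bf16_queries and int16_queries hold one row of qkv_dim elements per (batch, head) pair, and q_scales holds one dequantization scale per row. The helper names below are illustrative, not from the patch.

#include <cstddef>

// Hypothetical helpers; `b` is the batch index, `h` the head index.
inline size_t QueryRowOffset(size_t b, size_t heads, size_t h, size_t qkv_dim) {
  // Start of the (b, h) query row inside bf16_queries / int16_queries.
  return (b * heads + h) * qkv_dim;
}

inline size_t QueryScaleIndex(size_t b, size_t heads, size_t h) {
  // One float scale per query row inside q_scales.
  return b * heads + h;
}
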
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 016f8a08..ce52d783 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -857,40 +857,33 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
     VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
     VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
     VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
-    float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t q_group_idx,
-    size_t kNumQueriesPerGroup, float* HWY_RESTRICT q_scales_s = nullptr,
-    float max_v_scale = 1.0f) {
+    float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t query_idx,
+    float* HWY_RESTRICT q_scales_s = nullptr, float max_v_scale = 1.0f) {
   constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
   [[maybe_unused]] constexpr int kSecondHalfAmountOfQueries =
       kNumQueries - kFirstHalfAmountOfQueries;
   if constexpr (kNumQueries <= 4) {
     FlashAttentionTileStepAndApplySoftCap4(
         df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
-        x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
-        old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
-        max_v_scale);
+        x_2_p1, x_3_p0, x_3_p1, old_max + query_idx, old_d + query_idx, scales,
+        q_scales_s, max_v_scale);
   } else {
 #if HWY_MAX_BYTES <= 16
     FlashAttentionTileStepAndApplySoftCap4<4>(
         df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
-        x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
-        old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
-        max_v_scale);
+        x_2_p1, x_3_p0, x_3_p1, old_max + query_idx, old_d + query_idx, scales,
+        q_scales_s, max_v_scale);
     FlashAttentionTileStepAndApplySoftCap4(
         df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0,
-        x_6_p1, x_7_p0, x_7_p1,
-        old_max + (q_group_idx + 1) * kNumQueriesPerGroup,
-        old_d + (q_group_idx + 1) * kNumQueriesPerGroup,
-        scales + kNumQueriesPerGroup,
-        q_scales_s == nullptr ? nullptr : q_scales_s + kNumQueriesPerGroup,
+        x_6_p1, x_7_p0, x_7_p1, old_max + query_idx + 4, old_d + query_idx + 4,
+        scales + 4, q_scales_s == nullptr ? nullptr : q_scales_s + 4,
         max_v_scale);
 #else
     FlashAttentionTileStepAndApplySoftCap8(
         df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
         x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1,
-        x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
-        old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
-        max_v_scale);
+        x_7_p0, x_7_p1, old_max + query_idx, old_d + query_idx, scales,
+        q_scales_s, max_v_scale);
 #endif
   }
 }
@@ -898,7 +891,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
 template , typename T>
 static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
-    DQ_T df, const Q_T* HWY_RESTRICT q, const Q_T* HWY_RESTRICT q2,
+    DQ_T df, const Q_T* HWY_RESTRICT q_base,
     const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VQ_T& sum0_p0,
     VQ_T& sum0_p1, VQ_T& sum1_p0, VQ_T& sum1_p1, VQ_T& sum2_p0, VQ_T& sum2_p1,
     VQ_T& sum3_p0, VQ_T& sum3_p1, VQ_T& sum4_p0, VQ_T& sum4_p1, VQ_T& sum5_p0,
@@ -939,10 +932,6 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
     sum7_p1 = hn::Zero(df);
   }
 
-  constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
-  constexpr int kSecondHalfAmountOfQueries =
-      kNumQueries - kFirstHalfAmountOfQueries;
-
   HWY_UNROLL(1)
   for (size_t i = 0; i < qkv_dim; ++i) {
     VQ_T k_vec1, k_vec2;
     if constexpr (HWY_TARGET == HWY_AVX2) {
@@ -951,58 +940,42 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
     }
     Decompress2(df, k_transposed_span, i * gcpp::KVCache::kTileSize, k_vec1,
                 k_vec2);
-    sum0_p0 = hn::MulAdd(
-        k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p0);
-    sum0_p1 = hn::MulAdd(
-        k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p1);
+    auto mul_add = [&](auto& sum_p0, auto& sum_p1, size_t q_idx) HWY_ATTR {
+      float q_scalar;
+      std::memcpy(&q_scalar, &q_base[q_idx * qkv_dim + i], sizeof(float));
+      auto q_val = hn::Set(df, q_scalar);
+      sum_p0 = hn::MulAdd(k_vec1, q_val, sum_p0);
+      sum_p1 = hn::MulAdd(k_vec2, q_val, sum_p1);
+    };
+
+    mul_add(sum0_p0, sum0_p1, 0);
     if constexpr (kNumQueries >= 2) {
-      sum1_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p0);
-      sum1_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p1);
+      mul_add(sum1_p0, sum1_p1, 1);
     }
     if constexpr (kNumQueries >= 3) {
-      sum2_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p0);
-      sum2_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p1);
+      mul_add(sum2_p0, sum2_p1, 2);
    }
     if constexpr (kNumQueries >= 4) {
-      sum3_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p0);
-      sum3_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p1);
+      mul_add(sum3_p0, sum3_p1, 3);
     }
     if constexpr (kNumQueries >= 5) {
-      sum4_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p0);
-      sum4_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p1);
+      mul_add(sum4_p0, sum4_p1, 4);
     }
     if constexpr (kNumQueries >= 6) {
-      sum5_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p0);
-      sum5_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p1);
+      mul_add(sum5_p0, sum5_p1, 5);
     }
     if constexpr (kNumQueries >= 7) {
-      sum6_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p0);
-      sum6_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p1);
+      mul_add(sum6_p0, sum6_p1, 6);
     }
     if constexpr (kNumQueries >= 8) {
-      sum7_p0 = hn::MulAdd(
-          k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p0);
-      sum7_p1 = hn::MulAdd(
-          k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p1);
+      mul_add(sum7_p0, sum7_p1, 7);
     }
   }
 }
 
 template >
 static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16(
-    DF df, const int16_t* HWY_RESTRICT q, const int16_t* HWY_RESTRICT q2,
+    DF df, const int16_t* HWY_RESTRICT q_base,
     const int8_t* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0,
     VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1,
     VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0,
@@ -1035,11 +1008,8 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16(
   VI32 isum6_odd_p0 = hn::Zero(di32), isum6_odd_p1 = hn::Zero(di32);
   VI32 isum7_odd_p0 = hn::Zero(di32), isum7_odd_p1 = hn::Zero(di32);
 
-  const int32_t* q_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q);
-  const int32_t* q2_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q2);
-  constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
-  constexpr int kSecondHalfAmountOfQueries =
-      kNumQueries - kFirstHalfAmountOfQueries;
+  const int32_t* q_base_i32 = HWY_RCAST_ALIGNED(const int32_t*, q_base);
+  const size_t q_stride = qkv_dim / 2;
 
   const hn::Repartition di8;
   const hn::Half di8_half;
@@ -1061,35 +1031,34 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16(
                  sum_odd_p1);
     };
 
-    accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries], isum0_p0, isum0_p1,
-               isum0_odd_p0, isum0_odd_p1);
+    accumulate(q_base_i32[i], isum0_p0, isum0_p1, isum0_odd_p0, isum0_odd_p1);
     if constexpr (kNumQueries >= 2) {
-      accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 1], isum1_p0,
-                 isum1_p1, isum1_odd_p0, isum1_odd_p1);
+      accumulate(q_base_i32[q_stride + i], isum1_p0, isum1_p1, isum1_odd_p0,
+                 isum1_odd_p1);
     }
     if constexpr (kNumQueries >= 3) {
-      accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 2], isum2_p0,
-                 isum2_p1, isum2_odd_p0, isum2_odd_p1);
+      accumulate(q_base_i32[2 * q_stride + i], isum2_p0, isum2_p1, isum2_odd_p0,
+                 isum2_odd_p1);
     }
     if constexpr (kNumQueries >= 4) {
-      accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 3], isum3_p0,
-                 isum3_p1, isum3_odd_p0, isum3_odd_p1);
+      accumulate(q_base_i32[3 * q_stride + i], isum3_p0, isum3_p1, isum3_odd_p0,
+                 isum3_odd_p1);
     }
     if constexpr (kNumQueries >= 5) {
-      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 0], isum4_p0,
-                 isum4_p1, isum4_odd_p0, isum4_odd_p1);
+      accumulate(q_base_i32[4 * q_stride + i], isum4_p0, isum4_p1, isum4_odd_p0,
+                 isum4_odd_p1);
     }
     if constexpr (kNumQueries >= 6) {
-      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 1], isum5_p0,
-                 isum5_p1, isum5_odd_p0, isum5_odd_p1);
+      accumulate(q_base_i32[5 * q_stride + i], isum5_p0, isum5_p1, isum5_odd_p0,
+                 isum5_odd_p1);
     }
     if constexpr (kNumQueries >= 7) {
-      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 2], isum6_p0,
-                 isum6_p1, isum6_odd_p0, isum6_odd_p1);
+      accumulate(q_base_i32[6 * q_stride + i], isum6_p0, isum6_p1, isum6_odd_p0,
+                 isum6_odd_p1);
    }
     if constexpr (kNumQueries >= 8) {
-      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 3], isum7_p0,
-                 isum7_p1, isum7_odd_p0, isum7_odd_p1);
+      accumulate(q_base_i32[7 * q_stride + i], isum7_p0, isum7_p1, isum7_odd_p0,
+                 isum7_odd_p1);
     }
   }
 
@@ -1134,7 +1103,7 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16(
 template , typename T>
 static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
-    DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2,
+    DF df, const BF16* HWY_RESTRICT q_base,
     const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0,
     VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1,
     VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0,
@@ -1187,85 +1156,46 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
   VF helper_sum5_p0 = hn::Zero(df), helper_sum5_p1 = hn::Zero(df);
   VF helper_sum6_p0 = hn::Zero(df), helper_sum6_p1 = hn::Zero(df);
   VF helper_sum7_p0 = hn::Zero(df), helper_sum7_p1 = hn::Zero(df);
-  const float* q_float_ptr = HWY_RCAST_ALIGNED(const float*, q);
-  const float* q2_float_ptr = HWY_RCAST_ALIGNED(const float*, q2);
-  constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
-  constexpr int kSecondHalfAmountOfQueries =
-      kNumQueries - kFirstHalfAmountOfQueries;
+
+  const float* q_base_f32 = HWY_RCAST_ALIGNED(const float*, q_base);
+  const size_t q_stride = qkv_dim / 2;
 
   for (size_t i = 0; i < qkv_dim / 2; i++) {
     VBF k_vec1, k_vec2;
     Decompress2(dbf, k_transposed_span, i * 2 * gcpp::KVCache::kTileSize,
                 k_vec1, k_vec2);
-    VF q_0_as_float = hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries]);
-    VBF q_0 = hn::BitCast(dbf, q_0_as_float);
-    sum0_p0 =
-        hn::ReorderWidenMulAccumulate(df, k_vec1, q_0, sum0_p0, helper_sum0_p0);
-    sum0_p1 =
-        hn::ReorderWidenMulAccumulate(df, k_vec2, q_0, sum0_p1, helper_sum0_p1);
+    auto mul_accumulate = [&](auto& sum_p0, auto& sum_p1, auto& helper_p0,
+                              auto& helper_p1, size_t q_idx) HWY_ATTR {
+      VBF q_val =
+          hn::BitCast(dbf, hn::Set(df, q_base_f32[q_idx * q_stride + i]));
+      sum_p0 =
+          hn::ReorderWidenMulAccumulate(df, k_vec1, q_val, sum_p0, helper_p0);
+      sum_p1 =
+          hn::ReorderWidenMulAccumulate(df, k_vec2, q_val, sum_p1, helper_p1);
+    };
+
+    mul_accumulate(sum0_p0, sum0_p1, helper_sum0_p0, helper_sum0_p1, 0);
     if constexpr (kNumQueries >= 2) {
-      VF q_1_as_float =
-          hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 1]);
-      VBF q_1 = hn::BitCast(dbf, q_1_as_float);
-      sum1_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_1, sum1_p0,
-                                              helper_sum1_p0);
-      sum1_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_1, sum1_p1,
-                                              helper_sum1_p1);
+      mul_accumulate(sum1_p0, sum1_p1, helper_sum1_p0, helper_sum1_p1, 1);
     }
     if constexpr (kNumQueries >= 3) {
-      VF q_2_as_float =
-          hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 2]);
-      VBF q_2 = hn::BitCast(dbf, q_2_as_float);
-      sum2_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_2, sum2_p0,
-                                              helper_sum2_p0);
-      sum2_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_2, sum2_p1,
-                                              helper_sum2_p1);
+      mul_accumulate(sum2_p0, sum2_p1, helper_sum2_p0, helper_sum2_p1, 2);
     }
     if constexpr (kNumQueries >= 4) {
-      VF q_3_as_float =
-          hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 3]);
-      VBF q_3 = hn::BitCast(dbf, q_3_as_float);
-      sum3_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_3, sum3_p0,
-                                              helper_sum3_p0);
-      sum3_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_3, sum3_p1,
-                                              helper_sum3_p1);
+      mul_accumulate(sum3_p0, sum3_p1, helper_sum3_p0, helper_sum3_p1, 3);
     }
     if constexpr (kNumQueries >= 5) {
-      VF q_4_as_float =
-          hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 0]);
-      VBF q_4 = hn::BitCast(dbf, q_4_as_float);
-      sum4_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_4, sum4_p0,
-                                              helper_sum4_p0);
-      sum4_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_4, sum4_p1,
-                                              helper_sum4_p1);
+      mul_accumulate(sum4_p0, sum4_p1, helper_sum4_p0, helper_sum4_p1, 4);
     }
     if constexpr (kNumQueries >= 6) {
-      VF q_5_as_float =
-          hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 1]);
-      VBF q_5 = hn::BitCast(dbf, q_5_as_float);
-      sum5_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_5, sum5_p0,
-                                              helper_sum5_p0);
-      sum5_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_5, sum5_p1,
-                                              helper_sum5_p1);
+      mul_accumulate(sum5_p0, sum5_p1, helper_sum5_p0, helper_sum5_p1, 5);
    }
     if constexpr (kNumQueries >= 7) {
-      VF q_6_as_float =
-          hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 2]);
-      VBF q_6 = hn::BitCast(dbf, q_6_as_float);
-      sum6_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_6, sum6_p0,
-                                              helper_sum6_p0);
-      sum6_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_6, sum6_p1,
-                                              helper_sum6_p1);
+      mul_accumulate(sum6_p0, sum6_p1, helper_sum6_p0, helper_sum6_p1, 6);
     }
     if constexpr (kNumQueries >= 8) {
-      VF q_7_as_float =
-          hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 3]);
-      VBF q_7 = hn::BitCast(dbf, q_7_as_float);
-      sum7_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_7, sum7_p0,
-                                              helper_sum7_p0);
-      sum7_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_7, sum7_p1,
-                                              helper_sum7_p1);
+      mul_accumulate(sum7_p0, sum7_p1, helper_sum7_p0, helper_sum7_p1, 7);
     }
   }
 #if HWY_NATIVE_DOT_BF16 == 0
@@ -1442,42 +1372,40 @@ static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
 
 template >
 static HWY_INLINE void ApplyQuantizationScale(
-    DF df, const float* HWY_RESTRICT q_scales, int q_group_idx,
-    int kNumQueriesPerGroup, VF& x0_p0, VF& x0_p1, VF& x1_p0, VF& x1_p1,
-    VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0, VF& x4_p1, VF& x5_p0,
-    VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, VF& x7_p1) {
-  auto apply_scale = [&](int group_offset, int query_offset, VF& x_p0,
-                         VF& x_p1) HWY_ATTR {
-    int scale_idx =
-        (q_group_idx + group_offset) * kNumQueriesPerGroup + query_offset;
+    DF df, const float* HWY_RESTRICT q_scales, size_t query_idx, VF& x0_p0,
+    VF& x0_p1, VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1,
+    VF& x4_p0, VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0,
+    VF& x7_p1) {
+  auto apply_scale = [&](size_t i, VF& x_p0, VF& x_p1) HWY_ATTR {
+    size_t scale_idx = query_idx + i;
     VF s = hn::Set(df, q_scales[scale_idx]);
     x_p0 = hn::Mul(x_p0, s);
     x_p1 = hn::Mul(x_p1, s);
   };
   if constexpr (kNumQueries >= 1) {
-    apply_scale(0, 0, x0_p0, x0_p1);
+    apply_scale(0, x0_p0, x0_p1);
   }
   if constexpr (kNumQueries >= 2) {
-    apply_scale(0, 1, x1_p0, x1_p1);
+    apply_scale(1, x1_p0, x1_p1);
   }
   if constexpr (kNumQueries >= 3) {
-    apply_scale(0, 2, x2_p0, x2_p1);
+    apply_scale(2, x2_p0, x2_p1);
  }
   if constexpr (kNumQueries >= 4) {
-    apply_scale(0, 3, x3_p0, x3_p1);
+    apply_scale(3, x3_p0, x3_p1);
   }
   if constexpr (kNumQueries >= 5) {
-    apply_scale(1, 0, x4_p0, x4_p1);
+    apply_scale(4, x4_p0, x4_p1);
   }
   if constexpr (kNumQueries >= 6) {
-    apply_scale(1, 1, x5_p0, x5_p1);
+    apply_scale(5, x5_p0, x5_p1);
   }
   if constexpr (kNumQueries >= 7) {
-    apply_scale(1, 2, x6_p0, x6_p1);
+    apply_scale(6, x6_p0, x6_p1);
   }
   if constexpr (kNumQueries >= 8) {
-    apply_scale(1, 3, x7_p0, x7_p1);
+    apply_scale(7, x7_p0, x7_p1);
   }
 }
 
@@ -1506,9 +1434,8 @@ static HWY_INLINE void ApplyQuantizationScale(
 // keep values between calls to this function and avoid explicit merge.
 template 
 HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
-    const hwy::Span> kvs, int q_count,
-    const hwy::Span q_T_in_groups_up_to_4,
-    const hwy::Span q_scales,
+    const hwy::Span> kvs, size_t q_count,
+    const Q_T* HWY_RESTRICT q_base, const hwy::Span q_scales,
     hwy::Span start_pos_per_query,
     hwy::Span last_pos_per_query, const float att_cap,
     MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,
@@ -1520,15 +1447,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
   [[maybe_unused]] const DU du;
   constexpr int kTileSize = gcpp::KVCache::kTileSize;
   HWY_LANES_CONSTEXPR size_t kHTileSize = hn::Lanes(df);
-  constexpr int kNumQueriesPerGroup = 4;
+
   constexpr int kNumQueriesPerLoop =
       (!HWY_ARCH_X86 || (HWY_TARGET <= HWY_AVX3)) ? 8 : 4;
-  constexpr int kNumGroupsPerLoop = kNumQueriesPerLoop / kNumQueriesPerGroup;
-  const size_t full_groups_of_queries = q_count / kNumQueriesPerGroup;
+
+  const size_t num_loops = hwy::DivCeil(q_count, kNumQueriesPerLoop);
   const size_t qkv_dim = att_out.Cols();
   HWY_DASSERT(kHTileSize <= hn::MaxLanes(df));
-  HWY_LANES_CONSTEXPR size_t step_size = kHTileSize * 2;
+
+  HWY_LANES_CONSTEXPR size_t step_size = 2 * kHTileSize;
   size_t smallest_start_pos = std::numeric_limits::max();
   size_t largest_last_pos = std::numeric_limits::min();
   for (size_t i = 0; i < start_pos_per_query.size(); ++i) {
@@ -1589,8 +1516,8 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
     q_scales_s_ptr = q_scales_s;
   }
   float max_v_scale = 1.0f;
-  auto inner_loop = [&](int q_group_idx) HWY_ATTR {
-    int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup);
+  auto inner_loop = [&](size_t query_idx) HWY_ATTR {
+    size_t loop_idx = query_idx / kNumQueriesPerLoop;
     if (position + step_size <= min_start_pos_per_group[loop_idx] ||
         position > max_last_pos_per_group[loop_idx]) {
       return;
@@ -1607,30 +1534,27 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
       const KV_T* v_tile =
          tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim;
-      const Q_T* q_group = q_T_in_groups_up_to_4[q_group_idx];
-      const Q_T* q2_group = nullptr;
-      if (kNumQueries > 4) {
-        q2_group = q_T_in_groups_up_to_4[q_group_idx + 1];
-      }
+
+      const Q_T* HWY_RESTRICT q_group = q_base + query_idx * qkv_dim;
       if constexpr (IsF32()) {
         const KV_T* k_transposed_tile = tile_base + pos_in_tile;
         QDotKTilexUpTo8TransposedKDoubleWidth(
-            df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
-            x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
-            x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
+            df, q_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1, x_1_p_0,
+            x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1,
+            x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
       } else if constexpr (IsBF16()) {
         const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2;
         QDotKTilexUpTo8TransposedKDoubleWidthBF16(
-            df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
-            x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
-            x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
+            df, q_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1, x_1_p_0,
+            x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1,
+            x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
       } else if constexpr (IsInt16()) {
         const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2;
         QDotKTilexUpTo8TransposedKDoubleWidthInt16(
-            df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
-            x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
-            x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
+            df, q_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1, x_1_p_0,
+            x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1,
+            x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
       } else {
         static_assert(false,
                       "Query type not supported, only float, BF16, and "
@@ -1653,10 +1577,9 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
       }
       if constexpr (IsInt16()) {
         ApplyQuantizationScale(
-            df, q_scales.data(), q_group_idx, kNumQueriesPerGroup, x_0_p_0,
-            x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1,
-            x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0,
-            x_7_p_1);
+            df, q_scales.data(), query_idx, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1,
+            x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0,
+            x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
       }
 
       constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
@@ -1674,12 +1597,10 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
       if (position < max_start_pos_per_group[loop_idx] ||
           position + step_size - 1 > min_last_pos_per_group[loop_idx]) {
         ApplyMasking(
-            df, du, position,
-            start_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup,
-            last_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup,
-            x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
-            x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
-            x_7_p_0, x_7_p_1);
+            df, du, position, start_pos_per_query.data() + query_idx,
+            last_pos_per_query.data() + query_idx, x_0_p_0, x_0_p_1, x_1_p_0,
+            x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1,
+            x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
       }
 
       HWY_ALIGN float scales[kNumQueriesPerLoop];
@@ -1688,7 +1609,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
       }
       if constexpr (IsInt16() && kUseMicroScaling) {
-        if (q_group_idx == 0) {  // update only when needed
+        if (query_idx == 0) {  // update only when needed
           const BF16* microscaling_scales_v =
               reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) +
               kTileSize + pos_in_tile;
@@ -1703,8 +1624,8 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
       FlashAttentionTileStepAndApplySoftCap(
           df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
           x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
-          x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
-          kNumQueriesPerGroup, q_scales_s_ptr, max_v_scale);
+          x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, query_idx,
+          q_scales_s_ptr, max_v_scale);
       if constexpr (kUseMicroScaling) {
         const BF16* microscaling_scales_v =
             reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) +
@@ -1739,27 +1660,31 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
       current_kv_start_offset += kvs[current_kv_idx].Rows() * kTileSize;
       current_kv_idx++;
     }
-    int group_idx = 0;
-    for (; group_idx + kNumGroupsPerLoop <= full_groups_of_queries;
-         group_idx += kNumGroupsPerLoop) {
-      inner_loop.template operator()(group_idx);
-    }
-    if (group_idx < full_groups_of_queries) {
-      inner_loop.template operator()<4>(group_idx);
-      group_idx++;
-    }
-    switch (q_count % kNumQueriesPerGroup) {
-      case 1:
-        inner_loop.template operator()<1>(group_idx);
-        break;
-      case 2:
-        inner_loop.template operator()<2>(group_idx);
-        break;
-      case 3:
-        inner_loop.template operator()<3>(group_idx);
-        break;
-      default:
-        break;
+    size_t query_idx = 0;
+    for (; query_idx + kNumQueriesPerLoop <= q_count;
+         query_idx += kNumQueriesPerLoop) {
+      inner_loop.template operator()(query_idx);
+    }
+    if (query_idx < q_count) {
+      size_t rem = q_count - query_idx;
+      if (rem >= 4) {
+        inner_loop.template operator()<4>(query_idx);
+        query_idx += 4;
+        rem -= 4;
+      }
+      switch (rem) {
+        case 1:
+          inner_loop.template operator()<1>(query_idx);
+          break;
+        case 2:
+          inner_loop.template operator()<2>(query_idx);
+          break;
+        case 3:
+          inner_loop.template operator()<3>(query_idx);
+          break;
+        default:
+          break;
+      }
     }
 
     position += step_size;
@@ -1767,37 +1692,36 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
 }
 
 void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
-    hwy::Span kvs, int q_count,
-    const hwy::Span q_T_in_groups_up_to_4,
+    hwy::Span kvs, size_t q_count,
+    const float* HWY_RESTRICT q_base,
     hwy::Span start_pos_per_query,
     hwy::Span last_pos_per_query, const float att_cap,
     MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,
     float* HWY_RESTRICT max_logits) {
   CallUpcastedKVs(kvs, [&](const auto& kv_t) {
     return TileFlashAttentionReturnExpSumsAndMaxLogits(
-        kv_t, q_count, q_T_in_groups_up_to_4, {}, start_pos_per_query,
-        last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
+        kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query,
+        att_cap, att_out, exp_denominator_sums, max_logits);
  });
 }
 
 void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
-    hwy::Span kvs, int q_count,
-    const hwy::Span q_T_in_groups_up_to_4,
+    hwy::Span kvs, size_t q_count,
+    const BF16* HWY_RESTRICT q_base,
     hwy::Span start_pos_per_query,
     hwy::Span last_pos_per_query, const float att_cap,
     MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,
     float* HWY_RESTRICT max_logits) {
   CallUpcastedKVs(kvs, [&](const auto& kv_t) {
     return TileFlashAttentionReturnExpSumsAndMaxLogits(
-        kv_t, q_count, q_T_in_groups_up_to_4, {}, start_pos_per_query,
-        last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
+        kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query,
+        att_cap, att_out, exp_denominator_sums, max_logits);
   });
 }
 
 void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
-    hwy::Span kvs, int q_count,
-    const hwy::Span q_T_in_groups_up_to_4,
-    const hwy::Span q_scales,
+    hwy::Span kvs, size_t q_count,
+    const int16_t* HWY_RESTRICT q_base, const hwy::Span q_scales,
     hwy::Span start_pos_per_query,
     hwy::Span last_pos_per_query, float att_cap,
     MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,
@@ -1809,9 +1733,8 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
   hwy::Span> matptrs_span(matptrs.data(), matptrs.size());
   return TileFlashAttentionReturnExpSumsAndMaxLogits(
-      matptrs_span, q_count, q_T_in_groups_up_to_4, q_scales,
-      start_pos_per_query, last_pos_per_query, att_cap, att_out,
-      exp_denominator_sums, max_logits);
+      matptrs_span, q_count, q_base, q_scales, start_pos_per_query,
+      last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
 }
 
 // Implements flash attention for a strip of tiles of size 1, 4 or 8 query
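A scalar reference of what the SIMD kernels above compute per query (a sketch under the patch's layout assumptions, not code from the repository): q_base is row-major with one qkv_dim-long row per query, so row query_idx starts at q_base + query_idx * qkv_dim; on the int16 path the dot product accumulates in int32 and is then multiplied by the per-query scale, mirroring ApplyQuantizationScale's q_scales[query_idx + i] indexing. The key's own decompression/microscaling is omitted here.

#include <cstddef>
#include <cstdint>

// One Q.K dot product for the f32 path.
float QDotKScalar(const float* q_base, size_t query_idx, size_t qkv_dim,
                  const float* k) {
  const float* q = q_base + query_idx * qkv_dim;  // contiguous query row
  float sum = 0.0f;
  for (size_t i = 0; i < qkv_dim; ++i) sum += q[i] * k[i];
  return sum;
}

// Int16 path: integer accumulation, then undo the per-query quantization.
float QDotKScalarInt16(const int16_t* q_base, size_t query_idx, size_t qkv_dim,
                       const float* q_scales, const int8_t* k) {
  const int16_t* q = q_base + query_idx * qkv_dim;
  int32_t sum = 0;
  for (size_t i = 0; i < qkv_dim; ++i) {
    sum += static_cast<int32_t>(q[i]) * static_cast<int32_t>(k[i]);
  }
  return static_cast<float>(sum) * q_scales[query_idx];
}
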
diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h
index 07410761..47618c48 100644
--- a/gemma/flash_attention.h
+++ b/gemma/flash_attention.h
@@ -51,25 +51,24 @@ namespace gcpp {
       ThreadingContext& ctx, AttentionImpl attention_impl);               \
                                                                           \
   void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(               \
-      hwy::Span kvs, int q_count,                                         \
-      const hwy::Span q_T_in_groups_up_to_4,                              \
+      hwy::Span kvs, size_t q_count,                                      \
+      const float* HWY_RESTRICT q_base,                                   \
       hwy::Span start_pos_per_query,                                      \
       hwy::Span last_pos_per_query, const float att_cap,                  \
       MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,         \
       float* HWY_RESTRICT max_logits);                                    \
                                                                           \
   void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(           \
-      hwy::Span kvs, int q_count,                                         \
-      const hwy::Span q_T_in_groups_up_to_4,                              \
+      hwy::Span kvs, size_t q_count,                                      \
+      const BF16* HWY_RESTRICT q_base,                                    \
       hwy::Span start_pos_per_query,                                      \
       hwy::Span last_pos_per_query, const float att_cap,                  \
       MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,         \
       float* HWY_RESTRICT max_logits);                                    \
                                                                           \
   void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(          \
-      hwy::Span kvs, int q_count,                                         \
-      const hwy::Span q_T_in_groups_up_to_4,                              \
-      hwy::Span q_scales,                                                 \
+      hwy::Span kvs, size_t q_count,                                      \
+      const int16_t* HWY_RESTRICT q_base, hwy::Span q_scales,             \
       hwy::Span start_pos_per_query,                                      \
       hwy::Span last_pos_per_query, const float att_cap,                  \
       MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums,         \
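The query-coverage pattern behind these q_count signatures, as used at the tail of TileFlashAttentionReturnExpSumsAndMaxLogits above: full strips of kNumQueriesPerLoop first, then one strip of 4, then a 1-3 remainder. A self-contained sketch with a stand-in kernel (RunKernel and CoverQueries are illustrative names, not from the patch):

#include <cstddef>
#include <cstdio>

template <size_t kN>
void RunKernel(size_t query_idx) {  // stand-in for inner_loop<kN>
  std::printf("kernel<%zu> at query %zu\n", kN, query_idx);
}

template <size_t kPerLoop>  // 8 on wide-SIMD targets, else 4
void CoverQueries(size_t q_count) {
  size_t query_idx = 0;
  for (; query_idx + kPerLoop <= q_count; query_idx += kPerLoop) {
    RunKernel<kPerLoop>(query_idx);
  }
  size_t rem = q_count - query_idx;
  if (rem >= 4) {
    RunKernel<4>(query_idx);
    query_idx += 4;
    rem -= 4;
  }
  switch (rem) {
    case 1: RunKernel<1>(query_idx); break;
    case 2: RunKernel<2>(query_idx); break;
    case 3: RunKernel<3>(query_idx); break;
    default: break;
  }
}

For example, CoverQueries<8>(15) covers queries 0..7 with the 8-wide kernel, 8..11 with the 4-wide kernel, and 12..14 with the 3-query kernel.
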
diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc
index 04c00051..427bdf10 100644
--- a/gemma/flash_attention_test.cc
+++ b/gemma/flash_attention_test.cc
@@ -286,6 +286,17 @@ void PopulateTestKVCache(MatStorageT& kv, gcpp::KVEncoding encoding,
   }
 }
 
+AlignedFloatVector PopulateTestQueries(size_t num_queries, size_t qkv_dim) {
+  AlignedFloatVector q_all(num_queries * qkv_dim);
+  const float unpredictable_factor = 0.01f * hwy::Unpredictable1();
+  for (size_t i = 0; i < num_queries; ++i) {
+    for (size_t j = 0; j < qkv_dim; ++j) {
+      q_all[i * qkv_dim + j] = unpredictable_factor * (i + 1) / (j + 1);
+    }
+  }
+  return q_all;
+}
+
 struct AttentionTestEnv {
   AttentionTestEnv(size_t num_queries, size_t kv_seq_len, size_t qkv_dim,
                    AttentionImpl attention_impl);
@@ -492,18 +503,7 @@ void TestTiledFlashAttention() {
                               2 * qkv_dim * gcpp::KVCache::kTileSize),
                     ctx.allocator, MatPadding::kPacked);
   PopulateTestKVCache(kv, gcpp::KVEncoding::kF32, qkv_dim);
-  std::vector q_float(4 * qkv_dim);
-  std::vector q_float2(4 * qkv_dim);
-  // fill in qs with predictable, synthetic data
-  for (size_t i = 0; i < 4; ++i) {
-    for (size_t j = 0; j < qkv_dim; j++) {
-      float val_1 = 0.01f * (i + 1) / (j + 1);
-      float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
-      q_float[j * 4 + i] = val_1;
-      q_float2[j * 4 + i] = val_2;
-    }
-  }
-  const float* q_T[2] = {q_float.data(), q_float2.data()};
+  AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
 
   MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim),
                       ctx.allocator, MatPadding::kPacked);
@@ -536,7 +536,7 @@ void TestTiledFlashAttention() {
   hwy::Span kvs(&kv, 1);
 
   DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
-      kvs, num_queries, hwy::Span(q_T, 2),
+      kvs, num_queries, q_all.data(),
       hwy::Span(start_pos_per_query),
       hwy::Span(last_pos_per_query), att_cap, att_out,
       exp_denominator_sums.data(), max_logits.data());
@@ -578,22 +578,11 @@ void TestTiledFlashAttentionBF16() {
                     ctx.allocator, MatPadding::kPacked);
   PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim);
 
-  std::vector q_all(num_queries * qkv_dim);
-  for (size_t i = 0; i < num_queries; ++i) {
-    for (size_t j = 0; j < qkv_dim; ++j) {
-      q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
-    }
-  }
-  std::vector q_ptrs(num_queries);
-  for (int i = 0; i < num_queries; ++i) {
-    q_ptrs[i] = q_all.data() + i * qkv_dim;
-  }
-  auto [transposed_queries, transposed_queries_ptrs, _] =
-      TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span(q_ptrs),
-                                             qkv_dim, /*group_size=*/4);
-  hwy::Span q_T(
-      const_cast(transposed_queries_ptrs.data()),
-      transposed_queries_ptrs.size());
+  AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
+  std::vector> bf16_queries(num_queries *
+                                      qkv_dim);
+  CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries,
+                                bf16_queries.data());
 
   MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim),
                       ctx.allocator, MatPadding::kPacked);
@@ -624,7 +613,8 @@ void TestTiledFlashAttentionBF16() {
   }
   hwy::Span kvs(&kv, 1);
   DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
-      kvs, num_queries, q_T, hwy::Span(start_pos_per_query),
+      kvs, num_queries, bf16_queries.data(),
+      hwy::Span(start_pos_per_query),
       hwy::Span(last_pos_per_query), att_cap, att_out,
       exp_denominator_sums.data(), max_logits.data());
 
@@ -673,18 +663,7 @@ void TestTiledFlashAttentionInt8() {
                     ctx.allocator, MatPadding::kPacked);
   PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8, qkv_dim);
 
-  std::vector q_float(4 * qkv_dim);
-  std::vector q_float2(4 * qkv_dim);
-  // fill in qs with predictable, synthetic data
-  for (size_t i = 0; i < 4; ++i) {
-    for (size_t j = 0; j < qkv_dim; j++) {
-      float val_1 = 0.01f * (i + 1) / (j + 1);
-      float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
-      q_float[j * 4 + i] = val_1;
-      q_float2[j * 4 + i] = val_2;
-    }
-  }
-  const float* q_T[2] = {q_float.data(), q_float2.data()};
+  AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
 
   MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim),
                       ctx.allocator, MatPadding::kPacked);
@@ -717,7 +696,7 @@ void TestTiledFlashAttentionInt8() {
   hwy::Span kvs(&kv, 1);
 
   DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
-      kvs, num_queries, hwy::Span(q_T, 2),
+      kvs, num_queries, q_all.data(),
       hwy::Span(start_pos_per_query),
       hwy::Span(last_pos_per_query), att_cap, att_out,
       exp_denominator_sums.data(), max_logits.data());
@@ -741,22 +720,23 @@ void TestTiledFlashAttentionInt8BF16() {
-  int qkv_dim = 64;
-  int kv_seq_len = 60;  // number of tokens we will attend to. Not divisible by
-                        // tiles size to test the padding logic.
-  int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
+  size_t qkv_dim = 64;
+  size_t kv_seq_len = 60;  // number of tokens we will attend to. Not divisible
+                           // by tile size to test the padding logic.
+  size_t padded_kv_seq_len =
+      hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
   float att_cap = 10.0f;
-  int num_queries = 8;
-  int num_queries_per_timestep = 4;
-  int num_tokens = num_queries / num_queries_per_timestep;
-  int kv_seq_end =
+  size_t num_queries = 8;
+  size_t num_queries_per_timestep = 4;
+  size_t num_tokens = num_queries / num_queries_per_timestep;
+  size_t kv_seq_end =
       kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
 
   ThreadingArgs threading_args;
   ThreadingContext ctx(threading_args);
-  int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
-  int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
-                        2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
+  size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
+  size_t tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
+                           2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
 
   MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes),
                  ctx.allocator, MatPadding::kPacked);
@@ -764,22 +744,11 @@ void TestTiledFlashAttentionInt8BF16() {
   // fill in kvs with predictable, synthetic data matching BF16 paired layout
   PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);
 
-  std::vector q_all(num_queries * qkv_dim);
-  for (int i = 0; i < num_queries; ++i) {
-    for (int j = 0; j < qkv_dim; ++j) {
-      q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
-    }
-  }
-  std::vector q_ptrs(num_queries);
-  for (int i = 0; i < num_queries; ++i) {
-    q_ptrs[i] = q_all.data() + i * qkv_dim;
-  }
-  auto [transposed_queries, transposed_queries_ptrs, _] =
-      TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span(q_ptrs),
-                                             qkv_dim, /*group_size=*/4);
-  hwy::Span q_T(
-      const_cast(transposed_queries_ptrs.data()),
-      transposed_queries_ptrs.size());
+  AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
+  std::vector> bf16_queries(num_queries *
+                                      qkv_dim);
+  CompressQueriesBF16Contiguous(q_all.data(), qkv_dim, num_queries,
+                                bf16_queries.data());
 
   MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim),
                       ctx.allocator, MatPadding::kPacked);
@@ -812,7 +781,8 @@ void TestTiledFlashAttentionInt8BF16() {
   hwy::Span kvs(&kv, 1);
 
   DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
-      kvs, num_queries, q_T, hwy::Span(start_pos_per_query),
+      kvs, num_queries, bf16_queries.data(),
+      hwy::Span(start_pos_per_query),
       hwy::Span(last_pos_per_query), att_cap, att_out,
       exp_denominator_sums.data(), max_logits.data());
 
@@ -853,23 +823,12 @@ void TestTiledFlashAttentionInt8Int16() {
   // fill in kvs with predictable, synthetic data matching BF16 paired layout
   PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);
 
-  std::vector q_all(num_queries * qkv_dim);
-  for (int i = 0; i < num_queries; ++i) {
-    for (int j = 0; j < qkv_dim; ++j) {
-      q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
-    }
-  }
-  std::vector q_ptrs(num_queries);
-  for (int i = 0; i < num_queries; ++i) {
-    q_ptrs[i] = q_all.data() + i * qkv_dim;
-  }
-  auto [transposed_queries, transposed_queries_ptrs, q_scales] =
-      TransposeQueriesToGroupsOfNBF16orInt16(
-          hwy::Span(q_ptrs), qkv_dim, /*group_size=*/4);
-  hwy::Span q_T(
-      const_cast(transposed_queries_ptrs.data()),
-      transposed_queries_ptrs.size());
-
+  AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim);
+  std::vector> int16_queries(
+      num_queries * qkv_dim);
+  AlignedFloatVector q_scales(num_queries);
+  CompressQueriesInt16Contiguous(q_all.data(), qkv_dim, num_queries,
+                                 int16_queries.data(), q_scales.data());
   MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim),
                       ctx.allocator, MatPadding::kPacked);
 
   using DF = hn::ScalableTag;
@@ -901,7 +860,7 @@ void TestTiledFlashAttentionInt8Int16() {
   hwy::Span kvs(&kv, 1);
 
   DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
-      kvs, num_queries, q_T, q_scales,
+      kvs, num_queries, int16_queries.data(), q_scales,
      hwy::Span(start_pos_per_query),
       hwy::Span(last_pos_per_query), att_cap, att_out,
       exp_denominator_sums.data(), max_logits.data());
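Scalar semantics of the per-query int16 quantization that the compression helpers in the next file implement (a sketch; the real code vectorizes this with hn::NearestInt and hn::OrderedDemote2To): each query row is scaled so its absolute maximum maps to 32767, and the inverse scale is stored for later use by ApplyQuantizationScale.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

void QuantizeQueryInt16(const float* q, size_t qkv_dim, int16_t* out,
                        float* inv_scale_out) {
  float max_abs = 0.0f;
  for (size_t i = 0; i < qkv_dim; ++i) {
    max_abs = std::max(max_abs, std::fabs(q[i]));
  }
  const float s = (max_abs == 0.0f) ? 1.0f : 32767.0f / max_abs;
  *inv_scale_out = 1.0f / s;  // multiplied back in after the int32 dot product
  for (size_t i = 0; i < qkv_dim; ++i) {
    // |q[i] * s| <= 32767, so the rounded value always fits in int16.
    out[i] = static_cast<int16_t>(std::lrintf(q[i] * s));
  }
}
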
diff --git a/gemma/tiled_attention.cc b/gemma/tiled_attention.cc
index 9c6a2a77..06b06468 100644
--- a/gemma/tiled_attention.cc
+++ b/gemma/tiled_attention.cc
@@ -295,241 +295,94 @@ static HWY_INLINE void ComputeQKVTransposedTile(
   });
 }
 
-// Transposes queries
-// Input: vector of pointers to subsequent queries. (allows for arbitrary
-// strides)
-// qkv_dim: dimension of query
-// allocator: aligned allocator to use for temporary storage
-//
-// Output: Pointer to contiguous memory with shape (qkv_dim,
-// queries.size())
-void TransposeStridedQueries(
-    hwy::Span queries, int qkv_dim,
-    hwy::Span transposed_queries) {
+template 
+static HWY_INLINE void CompressSingleQueryBF16orInt16(
+    DF df, DOut d_out, const float* q_ptr, int qkv_dim, OutT* out_ptr,
+    float* scale_out = nullptr) {
   namespace hn = hwy::HWY_NAMESPACE;
-  using DF = hn::ScalableTag;
-  const DF df;
-  using VF = hn::Vec;
-  using DI = hn::ScalableTag;
-  const DI di;
-  using VI = hn::Vec;
   const size_t lanes = hn::Lanes(df);
-  const size_t num_queries = queries.size();
-  const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, lanes);
-  std::vector> query_offsets(
-      num_queries_rounded_up);
-  for (size_t i = 0; i < num_queries; ++i) {
-    query_offsets[i] = queries[i] - queries[0];
-  }
-  for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
-    // last offset is the same so gather doesn't read out of bounds
-    query_offsets[i] = query_offsets[num_queries - 1];
+  const hn::ScalableTag d_out_full;
+  float s = 1.0f;
+  if constexpr (IsInt16()) {
+    float max_abs = AbsMaxOfSpan(hwy::Span(q_ptr, qkv_dim));
+    s = max_abs == 0.0f ? 1.0f : 32767.0f / max_abs;
+    *scale_out = 1.0f / s;
   }
+  auto scale_vec = hn::Set(df, s);
 
-  for (size_t i = 0; i < qkv_dim; i++) {
-    size_t j = 0;
-    if (num_queries >= lanes) {
-      for (; j <= num_queries - lanes; j += lanes) {
-        const VI offsets = hn::LoadU(di, query_offsets.data() + j);
-        VF x = hn::GatherIndex(df, queries[0] + i, offsets);
-        hn::StoreU(x, df, transposed_queries.data() + i * num_queries + j);
-      }
-    }
-    if (j < num_queries) {
-      const VI offsets = hn::LoadU(di, query_offsets.data() + j);
-      VF x = hn::GatherIndex(df, queries[0] + i, offsets);
-      hn::StoreN(x, df, transposed_queries.data() + i * num_queries + j,
-                 num_queries - j);
+  for (size_t i = 0; i < qkv_dim; i += 2 * lanes) {
+    auto x0 = hn::LoadU(df, q_ptr + i);
+    auto x1 = hn::LoadU(df, q_ptr + i + lanes);
+    if constexpr (IsInt16()) {
+      x0 = hn::Mul(x0, scale_vec);
+      x1 = hn::Mul(x1, scale_vec);
+      auto demoted = hn::OrderedDemote2To(d_out_full, hn::NearestInt(x0),
+                                          hn::NearestInt(x1));
+      hn::StoreU(demoted, d_out_full, out_ptr + i);
+    } else {
+      auto demoted = hn::OrderedDemote2To(d_out_full, x0, x1);
+      hn::StoreU(demoted, d_out_full, out_ptr + i);
     }
   }
 }
 
-std::pair> TransposeQueriesToGroupsOf4(
-    hwy::Span queries_ptrs, int qkv_dim) {
-  int num_queries = queries_ptrs.size();
-  int num_groups = hwy::DivCeil(num_queries, 4);
-  AlignedFloatVector transposed_queries(num_groups * 4 * qkv_dim);
-  std::vector transposed_queries_ptrs;
-  for (int group_idx = 0; group_idx < num_groups; ++group_idx) {
-    int group_size = std::min(4, num_queries - group_idx * 4);
-    transposed_queries_ptrs.push_back(transposed_queries.data() +
-                                      group_idx * qkv_dim * 4);
-    TransposeStridedQueries(
-        hwy::Span(queries_ptrs.data() + group_idx * 4,
-                  group_size),
-        qkv_dim,
-        hwy::Span(transposed_queries_ptrs.back(), qkv_dim * group_size));
+template 
+static HWY_INLINE void CompressQueriesBF16orInt16(
+    hwy::Span input, int qkv_dim, OutT* HWY_RESTRICT output,
+    float* HWY_RESTRICT scale = nullptr) {
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag;
   const DF df;
   auto d_out = hn::Rebind();
   const size_t num_queries = input.size();
+
+  for (size_t q = 0; q < num_queries; ++q) {
+    CompressSingleQueryBF16orInt16(
+        df, d_out, input[q], qkv_dim, output + q * qkv_dim,
+        scale == nullptr ? nullptr : scale + q);
   }
-  return std::make_pair(std::move(transposed_queries),
-                        std::move(transposed_queries_ptrs));
 }
 
 template 
-static HWY_INLINE void TransposeStridedQueriesBF16orInt16(
-    hwy::Span queries, int qkv_dim,
-    hwy::Span transposed_queries, hwy::Span q_scales) {
+static HWY_INLINE void CompressQueriesBF16orInt16Contiguous(
+    const float* HWY_RESTRICT input, int qkv_dim, size_t num_queries,
+    OutT* HWY_RESTRICT output, float* HWY_RESTRICT scale = nullptr) {
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag;
   const DF df;
-  using VF = hn::Vec;
-  // doubles to avoid moving between int/float domains when gathering
-  using DF64 = hn::ScalableTag;
-  const DF64 dd64;
-  using DI64 = hn::ScalableTag;
-  const DI64 di64;
-  using VI64 = hn::Vec;
   auto d_out = hn::Rebind();
-  const size_t lanes = hn::Lanes(df);
-  const size_t half_lanes = lanes / 2;
-  const size_t num_queries = queries.size();
-  const size_t num_numbers_to_gather = num_queries * 2;
-  const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, half_lanes);
-  const size_t num_scales_rounded_up =
-      hwy::RoundUpTo(num_numbers_to_gather, lanes);
-  // We store scales twice so we will be able to just load them without a need
-  // to duplicate for multiplication
-  AlignedFloatVector inverted_q_scales_doubled(num_scales_rounded_up);
-
-  if constexpr (IsInt16()) {
-    // compute microscales
-    for (size_t i = 0; i < num_queries; ++i) {
-      float max_abs = AbsMaxOfSpan(hwy::Span(queries[i], qkv_dim));
-      float scale = max_abs == 0.0f ? 1.0f : 32767.0f / max_abs;
-      inverted_q_scales_doubled[2 * i] = scale;
-      inverted_q_scales_doubled[2 * i + 1] = scale;
-      q_scales[i] = 1.0f / scale;
-    }
-  }
-
-  std::vector> query_offsets(
-      num_queries_rounded_up);
-  for (size_t i = 0; i < num_queries; ++i) {
-    query_offsets[i] = (queries[i] - queries[0]) / 2;
-  }
-  for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
-    // last offset is the same so gather doesn't read out of bounds
-    query_offsets[i] = query_offsets[num_queries > 0 ? num_queries - 1 : 0];
+  for (size_t q = 0; q < num_queries; ++q) {
+    CompressSingleQueryBF16orInt16(
+        df, d_out, input + q * qkv_dim, qkv_dim, output + q * qkv_dim,
+        scale == nullptr ? nullptr : scale + q);
   }
+}
 
-  const double* queries_0_double = HWY_RCAST_ALIGNED(const double*, queries[0]);
+void CompressQueriesBF16(hwy::Span input, int qkv_dim,
+                         BF16* HWY_RESTRICT output) {
+  CompressQueriesBF16orInt16(input, qkv_dim, output);
+}
 
-  // Lambda to handle the scaling and demotion for Int16 types.
-  auto process_values = [&]() HWY_ATTR {
-    if constexpr (IsInt16()) {
-      return [&](VF& x, size_t j) HWY_ATTR {
-        VF scales = hn::Load(df, inverted_q_scales_doubled.data() + j * 2);
-        x = hn::Mul(x, scales);
-        return hn::DemoteTo(d_out, hn::NearestInt(x));
-      };
-    } else {
-      return [&](VF& x, size_t j) HWY_ATTR { return hn::DemoteTo(d_out, x); };
-    }
-  }();
-
-  for (size_t i = 0; i < qkv_dim; i += 2) {
-    size_t j = 0;
-    if (num_queries >= half_lanes) {
-      for (; j <= num_queries - half_lanes; j += half_lanes) {
-        const VI64 offsets = hn::LoadU(di64, query_offsets.data() + j);
-        auto x64 = hn::GatherIndex(dd64, queries_0_double + i / 2, offsets);
-        VF x = hn::BitCast(df, x64);
-        if constexpr (IsInt16()) {
-          auto demoted = process_values(x, j);
-          hn::Store(demoted, d_out,
-                    transposed_queries.data() + i * num_queries + j * 2);
-        } else if constexpr (IsBF16()) {
-          auto demoted = hn::DemoteTo(d_out, x);
-          hn::Store(demoted, d_out,
-                    transposed_queries.data() + i * num_queries + j * 2);
-        } else {
-          static_assert(false, "Unsupported type");
-        }
-      }
-    }
-    if (j < num_queries) {
-      const VI64 offsets = hn::LoadU(di64, query_offsets.data() + j);
-      auto x64 = hn::GatherIndex(dd64, queries_0_double + i / 2, offsets);
-      VF x = hn::BitCast(df, x64);
-      if constexpr (IsInt16()) {
-        auto demoted = process_values(x, j);
-        hn::StoreN(demoted, d_out,
-                   transposed_queries.data() + i * num_queries + j * 2,
-                   num_numbers_to_gather - j * 2);
-      } else if constexpr (IsBF16()) {
-        auto demoted = hn::DemoteTo(d_out, x);
-        hn::StoreN(demoted, d_out,
-                   transposed_queries.data() + i * num_queries + j * 2,
-                   num_numbers_to_gather - j * 2);
-      } else {
-        static_assert(false, "Unsupported type");
-      }
-    }
-  }
+void CompressQueriesBF16Contiguous(const float* HWY_RESTRICT input, int qkv_dim,
+                                   size_t num_queries,
+                                   BF16* HWY_RESTRICT output) {
+  CompressQueriesBF16orInt16Contiguous(input, qkv_dim, num_queries,
+                                       output);
 }
 
-// Transposed queries data, vector of pointers to transposed queries, vector of
-// scales
-template 
-std::tuple>, std::vector,
-           AlignedFloatVector>
-TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span queries_ptrs,
-                                       int qkv_dim, size_t group_size) {
-  size_t num_queries = queries_ptrs.size();
-  size_t num_groups = hwy::DivCeil(num_queries, group_size);
-  std::vector> transposed_queries(
-      num_groups * group_size * qkv_dim);
-  std::vector transposed_queries_ptrs;
-  AlignedFloatVector q_scales(num_groups * 4);
-  for (size_t group_idx = 0; group_idx < num_groups; ++group_idx) {
-    size_t current_group_size =
-        std::min(group_size, num_queries - group_idx * group_size);
-    transposed_queries_ptrs.push_back(transposed_queries.data() +
-                                      group_idx * qkv_dim * group_size);
-    TransposeStridedQueriesBF16orInt16(
-        hwy::Span(
-            const_cast(queries_ptrs.data() +
-                       group_idx * group_size),
-            current_group_size),
-        qkv_dim,
-        hwy::Span(transposed_queries_ptrs.back(),
-                  qkv_dim * current_group_size),
-        hwy::Span(q_scales.data() + group_idx * group_size,
-                  current_group_size));
-  }
-  return std::make_tuple(std::move(transposed_queries),
-                         std::move(transposed_queries_ptrs),
-                         std::move(q_scales));
+void CompressQueriesInt16(hwy::Span input, int qkv_dim,
+                          int16_t* HWY_RESTRICT output,
+                          float* HWY_RESTRICT scale) {
+  CompressQueriesBF16orInt16(input, qkv_dim, output, scale);
 }
 
-std::pair>
-TransposeTransposedQueriesAndPackIntoBF16(hwy::Span queries_ptrs,
-                                          int qkv_dim, int num_queries) {
-  constexpr int kMaxGroupSize = 4;
-  int num_groups = queries_ptrs.size();
-  AlignedBF16Vector transposed_queries(num_groups * kMaxGroupSize * qkv_dim);
-  std::vector transposed_queries_ptrs;
-  transposed_queries_ptrs.reserve(num_groups);
-  for (int group_idx = 0; group_idx < num_groups; ++group_idx) {
-    int group_size =
-        std::min(kMaxGroupSize, num_queries - group_idx * kMaxGroupSize);
-    transposed_queries_ptrs.push_back(transposed_queries.data() +
-                                      group_idx * qkv_dim * kMaxGroupSize);
-    for (int dim_idx = 0; dim_idx < qkv_dim; dim_idx += 2) {
-      for (int query_idx = 0; query_idx < group_size; ++query_idx) {
-        transposed_queries_ptrs.back()[dim_idx * group_size + query_idx * 2] =
-            hwy::ConvertScalarTo(
-                queries_ptrs[group_idx][dim_idx * group_size + query_idx]);
-        transposed_queries_ptrs
-            .back()[dim_idx * group_size + query_idx * 2 + 1] =
-            hwy::ConvertScalarTo(
-                queries_ptrs[group_idx]
-                            [(dim_idx + 1) * group_size + query_idx]);
-      }
-    }
-  }
-  return std::make_pair(std::move(transposed_queries),
-                        std::move(transposed_queries_ptrs));
+void CompressQueriesInt16Contiguous(const float* HWY_RESTRICT input,
+                                    int qkv_dim, size_t num_queries,
+                                    int16_t* HWY_RESTRICT output,
+                                    float* HWY_RESTRICT scale) {
+  CompressQueriesBF16orInt16Contiguous(input, qkv_dim, num_queries,
+                                       output, scale);
 }
 
 template 
@@ -730,40 +583,38 @@ void LocalAttentionForAllHeadsTokensAndBatch(
   }
 
   if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
-    // pack transposed queries into BF16
-    auto [transposed_queries, transposed_queries_ptrs, _] =
-        TransposeQueriesToGroupsOfNBF16orInt16(
-            queries_ptrs_span, qkv_dim, /*group_size=*/4);
-    hwy::Span queries_span(
-        const_cast(transposed_queries_ptrs.data()),
-        transposed_queries_ptrs.size());
+    HWY_DASSERT(activations.bf16_queries != nullptr);
+    BF16* bf16_queries_ptr = activations.bf16_queries->data();
+    CompressQueriesBF16(queries_ptrs_span, qkv_dim, bf16_queries_ptr);
     DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
-        kv_ptrs, num_queries, queries_span,
+        kv_ptrs, num_queries, bf16_queries_ptr,
        hwy::Span(start_pos_per_query),
         hwy::Span(last_pos_per_query), activations.config.att_cap,
         att_out, exp_denominator_sums.data(), max_logits.data());
   } else if (attention_impl == AttentionImpl::kFlashTransposedQsInt16) {
-    auto [transposed_queries, transposed_queries_ptrs, q_scales] =
-        TransposeQueriesToGroupsOfNBF16orInt16(
-            queries_ptrs_span, qkv_dim, /*group_size=*/4);
-    hwy::Span queries_span(
-        const_cast(transposed_queries_ptrs.data()),
-        transposed_queries_ptrs.size());
+    HWY_DASSERT(activations.int16_queries != nullptr);
+    HWY_DASSERT(activations.q_scales != nullptr);
+    int16_t* int16_queries_ptr = activations.int16_queries->data();
+    float* q_scales_ptr = activations.q_scales->data();
+    CompressQueriesInt16(queries_ptrs_span, qkv_dim, int16_queries_ptr,
+                         q_scales_ptr);
     DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
-        kv_ptrs, num_queries, queries_span, q_scales,
+        kv_ptrs, num_queries, int16_queries_ptr, *activations.q_scales,
        hwy::Span(start_pos_per_query),
        hwy::Span(last_pos_per_query), activations.config.att_cap,
         att_out, exp_denominator_sums.data(), max_logits.data());
   } else {
-    auto [transposed_queries, transposed_queries_ptrs] =
-        TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
+    std::vector> contiguous_queries(
+        num_queries * qkv_dim);
+    for (int i = 0; i < num_queries; ++i) {
+      hwy::CopyBytes(queries_ptrs_span[i],
+                     contiguous_queries.data() + i * qkv_dim,
+                     qkv_dim * sizeof(float));
+    }
     DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
-        kv_ptrs, num_queries,
-        hwy::Span(
-            const_cast(transposed_queries_ptrs.data()),
-            transposed_queries_ptrs.size()),
+        kv_ptrs, num_queries, contiguous_queries.data(),
        hwy::Span(start_pos_per_query),
        hwy::Span(last_pos_per_query), activations.config.att_cap,
         att_out, exp_denominator_sums.data(), max_logits.data());
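Typical use of the new contiguous compression helpers, mirroring the updated tests (a sketch: the real declarations live in per-SIMD-target namespaces and use aligned buffer types; plain vectors are used here for illustration):

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified declaration; see the GEMMA_DECL_TILED_ATTENTION macro below for
// the real one.
void CompressQueriesInt16Contiguous(const float* input, int qkv_dim,
                                    size_t num_queries, int16_t* output,
                                    float* scale);

void Example(const std::vector<float>& q_all, int qkv_dim,
             size_t num_queries) {
  std::vector<int16_t> int16_queries(num_queries * qkv_dim);
  std::vector<float> q_scales(num_queries);  // one scale per query row
  CompressQueriesInt16Contiguous(q_all.data(), qkv_dim, num_queries,
                                 int16_queries.data(), q_scales.data());
  // Row q of int16_queries starts at q * qkv_dim; multiply the int32 dot
  // product by q_scales[q] to recover float-scale logits.
}
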
diff --git a/gemma/tiled_attention.h b/gemma/tiled_attention.h
index 9b1a2ce6..bc06bf6c 100644
--- a/gemma/tiled_attention.h
+++ b/gemma/tiled_attention.h
@@ -4,6 +4,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
@@ -15,27 +16,33 @@ namespace gcpp {
 
 // Passed to HWY_VISIT_TARGETS; declares for one target.
-#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE)                        \
-  namespace NAMESPACE {                                                      \
-  void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,       \
-                      size_t layer_idx, const LayerWeightsPtrs& layer,       \
-                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
-                      MatMulEnv& env, int flags);                            \
-  void TransposeStridedQueries(hwy::Span queries, int qkv_dim,               \
-                               hwy::Span transposed_queries);                \
-  void LocalAttentionForAllHeadsTokensAndBatch(                              \
-      AttentionImpl attention_impl, const size_t num_tokens,                 \
-      const size_t layer_idx, const LayerWeightsPtrs& layer,                 \
-      AttentionActivationsPtrs& activations, QBatch& qbatch,                 \
-      ThreadingContext& ctx);                                                \
-                                                                             \
-  template                                                                   \
-  std::tuple>,                                                               \
-             std::vector, AlignedFloatVector>                                \
-  TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span queries_ptrs,             \
-                                         int qkv_dim, size_t group_size);    \
-                                                                             \
-  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
+#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE)                        \
+  namespace NAMESPACE {                                                      \
+  void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,       \
+                      size_t layer_idx, const LayerWeightsPtrs& layer,       \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
+                      MatMulEnv& env, int flags);                            \
+  void LocalAttentionForAllHeadsTokensAndBatch(                              \
+      AttentionImpl attention_impl, const size_t num_tokens,                 \
+      const size_t layer_idx, const LayerWeightsPtrs& layer,                 \
+      AttentionActivationsPtrs& activations, QBatch& qbatch,                 \
+      ThreadingContext& ctx);                                                \
+                                                                             \
+  void CompressQueriesBF16(hwy::Span input, int qkv_dim,                     \
+                           BF16* HWY_RESTRICT output);                       \
+  void CompressQueriesBF16Contiguous(const float* HWY_RESTRICT input,        \
+                                     int qkv_dim, size_t num_queries,        \
+                                     BF16* HWY_RESTRICT output);             \
+                                                                             \
+  void CompressQueriesInt16(hwy::Span input, int qkv_dim,                    \
+                            int16_t* HWY_RESTRICT output,                    \
+                            float* HWY_RESTRICT scale);                      \
+                                                                             \
+  void CompressQueriesInt16Contiguous(const float* HWY_RESTRICT input,       \
+                                      int qkv_dim, size_t num_queries,       \
+                                      int16_t* HWY_RESTRICT output,          \
+                                      float* HWY_RESTRICT scale);            \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
   }  // namespace NAMESPACE
 
 // Function declarations for each SIMD target. Allows direct call from the
diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc
index f93338b2..5e09c2a3 100644
--- a/gemma/tiled_attention_test.cc
+++ b/gemma/tiled_attention_test.cc
@@ -17,6 +17,7 @@
 #include "gemma/weights.h"
 #include "util/mat.h"
 #include "util/threading_context.h"
+#include "hwy/base.h"
 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
 #endif  // HWY_DISABLED_TARGETS
@@ -197,33 +198,32 @@ struct AttentionTestEnv {
   MatStorageT key_norm_scale;
 };
 
-void TestTransposeStridedQueries() {
+void TestCompressQueries() {
   ThreadingArgs threading_args;
   ThreadingContext ctx(threading_args);
   size_t qkv_dim = 64;
   size_t num_queries = 24;
   AlignedPtr input_queries =
       ctx.allocator.Alloc(qkv_dim * num_queries);
-  AlignedPtr output_queries =
-      ctx.allocator.Alloc(qkv_dim * num_queries);
+  AlignedPtr output_queries =
+      ctx.allocator.Alloc(qkv_dim * num_queries);
   for (size_t i = 0; i < num_queries; ++i) {
     for (size_t j = 0; j < qkv_dim; ++j) {
-      input_queries[i * qkv_dim + j] = i * qkv_dim + j;
+      input_queries[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
     }
   }
-  std::vector queries;
+  std::vector queries;
   for (size_t i = 0; i < num_queries; ++i) {
     queries.push_back(input_queries.get() + i * qkv_dim);
   }
-  hwy::Span queries_span(queries.data(), queries.size());
-  TransposeStridedQueries(
-      queries_span, qkv_dim,
-      hwy::Span(output_queries.get(), qkv_dim * num_queries));
+  CompressQueriesBF16(
+      hwy::Span(queries.data(), queries.size()), qkv_dim,
+      output_queries.get());
   for (size_t i = 0; i < num_queries; ++i) {
     for (size_t j = 0; j < qkv_dim; ++j) {
-      EXPECT_EQ(output_queries[j * num_queries + i],
-                input_queries[i * qkv_dim + j])
+      EXPECT_NEAR(hwy::ConvertScalarTo(output_queries[i * qkv_dim + j]),
+                  input_queries[i * qkv_dim + j], 1e-3)
          << "i=" << i << " j=" << j;
     }
   }
@@ -777,7 +777,7 @@ HWY_AFTER_NAMESPACE();
 namespace gcpp {
 HWY_BEFORE_TEST(TiledAttentionTest);
-HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries);
+HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestCompressQueries);
 // TODO() Fix the goldens for the change in KV_t to BF16
 // HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
 //                       TestLocalAttentionForAllHeadsTokensAndBatch);
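A back-of-envelope check (not from the patch) of why the EXPECT_NEAR tolerance of 1e-3 above is plausible, assuming round-to-nearest BF16 demotion: BF16 keeps 8 significand bits, so the relative rounding error is at most 2^-8, and the largest test input here is 0.01 * 24 / 1 = 0.24.

#include <cassert>

int main() {
  const float max_input = 0.01f * 24.0f / 1.0f;  // largest test value: 0.24
  const float bf16_rel_err = 1.0f / 256.0f;      // 2^-8 half-ulp bound
  // 0.24 * 2^-8 ~= 9.4e-4 < 1e-3, so the tolerance holds (tightly).
  assert(max_input * bf16_rel_err < 1e-3f);
  return 0;
}
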