25 changes: 24 additions & 1 deletion gemma/activations.h
@@ -55,7 +55,9 @@ struct AttentionActivations {
       size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
       size_t max_workers, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : rep_factor(max_workers *
+      : heads(layer_config.heads),
+        qkv_dim(layer_config.qkv_dim),
+        rep_factor(max_workers *
                    AttentionActivations::kThreadReplicationFactor /
                    layer_config.heads),
         // `vocab_size == 0` means it is for Vit part, VitAttention
@@ -115,6 +117,9 @@ struct AttentionActivations {
     // query tiles, which is not known here.
     flash_params.reserve(batch_size * layer_config.heads);
     split_flash_params.reserve(batch_size * layer_config.heads);
+    bf16_queries.resize(batch_size * heads * qkv_dim);
+    int16_queries.resize(batch_size * heads * qkv_dim);
+    q_scales.resize(batch_size * heads);
 
     // For MatMul outputs, precompute their row pointers.
     // If we forget any MatMul outputs here, debug builds print a warning but
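A note on the sizing, added here as an observation rather than text from the patch: the two query buffers store one element per (query, head, channel), while q_scales stores one float per (query, head). With illustrative numbers batch_size = 4, heads = 8, qkv_dim = 256, bf16_queries and int16_queries each hold 4 * 8 * 256 = 8192 elements and q_scales holds 4 * 8 = 32 floats. The same formulas reappear in SetBatchSize below so that the buffers track the current batch size.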
@@ -140,10 +145,19 @@ struct AttentionActivations {
     softmax_max.OverrideRows(batch_size);
     softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+    bf16_queries.resize(batch_size * heads * qkv_dim);
+    int16_queries.resize(batch_size * heads * qkv_dim);
+    q_scales.resize(batch_size * heads);
 
     // `inv_timescale*` are not batched.
   }
 
+  size_t heads;
+  size_t qkv_dim;
+  AlignedBF16Vector bf16_queries;
+  std::vector<int16_t, hwy::AlignedAllocator<int16_t>> int16_queries;
+  AlignedFloatVector q_scales;
+
   // Maximum factor by which we might scale-up work to maximize parallelism.
   size_t rep_factor = 1;
   // Parameters for flash attention. The size of the vector is somewhere between
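The new members point to a bf16/int16 query path with one scale per (query, head). Below is a minimal sketch of how such buffers might be filled from float queries; the function name, the assumed [batch][head][qkv_dim] layout of the float source, and the symmetric int16 scaling are assumptions for illustration, not part of this change. The bf16 buffer would simply hold a bf16 conversion of the same values and is omitted here.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: quantizes float queries to int16 with one scale per
// (query, head) pair, matching the batch_size * heads * qkv_dim element count
// and batch_size * heads scale count used in the constructor and SetBatchSize.
void QuantizeQueriesSketch(const float* q, size_t batch_size, size_t heads,
                           size_t qkv_dim, int16_t* out_i16, float* out_scales) {
  for (size_t b = 0; b < batch_size; ++b) {
    for (size_t h = 0; h < heads; ++h) {
      const float* src = q + (b * heads + h) * qkv_dim;
      int16_t* dst = out_i16 + (b * heads + h) * qkv_dim;
      // Per-head max magnitude sets the scale; the epsilon avoids div-by-zero.
      float max_abs = 1e-6f;
      for (size_t i = 0; i < qkv_dim; ++i) {
        max_abs = std::max(max_abs, std::abs(src[i]));
      }
      const float scale = max_abs / 32767.0f;
      out_scales[b * heads + h] = scale;
      for (size_t i = 0; i < qkv_dim; ++i) {
        dst[i] = static_cast<int16_t>(std::lround(src[i] / scale));
      }
    }
  }
}

Whatever the real kernel does, its indexing has to agree with the sizing above, which is why the constructor and SetBatchSize use the same batch_size * heads * qkv_dim and batch_size * heads formulas.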
@@ -191,6 +205,9 @@ struct AttentionActivationsPtrs {
       : config(config),
         flash_params(flash_params),
         split_flash_params(split_flash_params),
+        bf16_queries(nullptr),
+        int16_queries(nullptr),
+        q_scales(nullptr),
         div_seq_len(static_cast<uint32_t>(seq_len)),
         div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
         query_scale(ChooseQueryScale(config)) {}
@@ -212,6 +229,9 @@ struct AttentionActivationsPtrs {
     att_sums = activations.att_sums;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
+    bf16_queries = &activations.bf16_queries;
+    int16_queries = &activations.int16_queries;
+    q_scales = &activations.q_scales;
   }
 
   void SetBatchSize(size_t batch_size) {
Expand Down Expand Up @@ -277,6 +297,9 @@ struct AttentionActivationsPtrs {
sub_task_exp_denominator_sums;
std::vector<AlignedFloatVector>*
sub_task_max_logits;
AlignedBF16Vector* bf16_queries;
std::vector<int16_t, hwy::AlignedAllocator<int16_t>>* int16_queries;
AlignedFloatVector* q_scales;
// Inverse timescales for RoPE computation.
MatPtrT<float> inv_timescale;
// Inverse timescales for global RoPE computation.
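Since AttentionActivationsPtrs only holds pointers, the new members stay nullptr until SetActivations() binds them to the owning AttentionActivations. A sketch of how a consumer might then reach a quantized query row and its scale, assuming gemma/activations.h is included and using hypothetical helper names:

// Sketch only: indexing mirrors the batch_size * heads layout used when the
// buffers are resized; qi is the query index within the batch.
const int16_t* QueryRowI16(const AttentionActivationsPtrs& act, size_t qi,
                           size_t head, size_t heads, size_t qkv_dim) {
  return act.int16_queries->data() + (qi * heads + head) * qkv_dim;
}

float QueryScale(const AttentionActivationsPtrs& act, size_t qi, size_t head,
                 size_t heads) {
  return (*act.q_scales)[qi * heads + head];
}

Any such accessor is only valid after SetActivations() has run; calling it on a freshly constructed AttentionActivationsPtrs would dereference the nullptr defaults.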