diff --git a/xllm/core/runtime/embed_vlm_worker_impl.cpp b/xllm/core/runtime/embed_vlm_worker_impl.cpp index e188df2e0..398a0ff0a 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.cpp +++ b/xllm/core/runtime/embed_vlm_worker_impl.cpp @@ -84,7 +84,6 @@ std::optional<ForwardOutput> EmbedVLMWorkerImpl::step( input.sampling_params.is_embeddings) { auto embeddings = model_->pooler(hidden_states, sampling_params.selected_token_idxes); - sample_output.embeddings = embeddings; // split full embeddings and add them to mm_embeddings // so that the user could receive embeddings of images and texts if (FLAGS_enable_return_mm_full_embeddings) { @@ -97,10 +96,12 @@ std::optional<ForwardOutput> EmbedVLMWorkerImpl::step( sample_output.mm_embeddings.emplace_back(image_embed); token_start_idx += seq_len; } + output.sample_output = sample_output; + } else { + sample_output.embeddings = embeddings; + output.sample_output = sample_output; + output.embedding = embeddings; } - - output.sample_output = sample_output; - output.embedding = embeddings; } ret = device_.synchronize_default_stream(); return output; diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index e3e74395a..35fedcf71 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -344,9 +344,10 @@ inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { return size; } -inline size_t get_dit_forward_output_size(const DiTForwardOutput& output) { - size_t size = type_size; // vector size - for (const auto& tensor : output.tensors) { +inline size_t get_vector_tensor_size( + const std::vector<torch::Tensor>& tensor_vec) { + size_t size = type_size; // vector size + for (const auto& tensor : tensor_vec) { size += get_tensor_size(tensor); } return size; } @@ -1017,14 +1018,6 @@ inline void write_dit_forward_input(RawInputSerializeContext& context, write_dit_generation_params(context, input.generation_params); } -inline void 
write_dit_forward_output(char*& buffer, - const DiTForwardOutput& output) { - write_data(buffer, static_cast<uint64_t>(output.tensors.size())); - for (const auto& tensor : output.tensors) { - write_tensor(buffer, tensor); - } -} - inline void safe_advance_buffer(const char*& buffer, size_t offset) { if (buffer != nullptr) { buffer += offset; } } @@ -1876,16 +1869,6 @@ inline void read_dit_forward_input(ReadContext& context, read_dit_generation_params(context, input.generation_params); } -inline void read_dit_forward_output(const char*& buffer, - DiTForwardOutput& output) { - uint64_t size; - read_data(buffer, size); - output.tensors.resize(size); - for (auto& tensor : output.tensors) { - read_tensor(buffer, tensor); - } -} - inline void initialize_device_buffer_session(ReadContext& context, ForwardInput& forward_input, const torch::Device& device, @@ -2094,7 +2077,9 @@ inline void deserialize_raw_forward_input(const char*& buffer, /*force_host_materialize=*/true); #endif - read_dit_forward_input(context, input_params.dit_forward_input); + if (FLAGS_backend == "dit") { + read_dit_forward_input(context, input_params.dit_forward_input); + } finalize_device_buffer_session(device_session, stream); buffer = context.tensor_cursor; @@ -2177,8 +2162,9 @@ inline void serialize_raw_forward_input_sections( write_vector_to_tensor(context, input.new_token_slot_ids); write_2d_vector_to_tensor(context, input.block_tables_vec); - - write_dit_forward_input(context, input.dit_forward_input); + if (FLAGS_backend == "dit") { + write_dit_forward_input(context, input.dit_forward_input); + } } inline RawInputLayoutHeader calculate_raw_forward_input_layout( @@ -2260,8 +2246,12 @@ size_t calculate_raw_forward_output_size(const RawForwardOutput& output) { size += get_vector_size(output.out_tokens); size += get_vector_size(output.out_logprobs); size += type_size; // prepared_layer_id + // mm_embedding_data + size += get_vector_tensor_size(output.mm_embeddings); // dit output data - size += 
get_dit_forward_output_size(output.dit_forward_output); + if (FLAGS_backend == "dit") { + size += get_vector_tensor_size(output.dit_forward_output.tensors); + } return size; } @@ -2327,8 +2317,11 @@ void deserialize_raw_forward_output(const char* buffer, read_data(buffer, output.prepared_layer_id); read_vector_tensor(buffer, output.mm_embeddings); + // read dit output - read_dit_forward_output(buffer, output.dit_forward_output); + if (FLAGS_backend == "dit") { + read_vector_tensor(buffer, output.dit_forward_output.tensors); + } } void serialize_raw_forward_output(const RawForwardOutput& output, @@ -2344,7 +2337,9 @@ void serialize_raw_forward_output(const RawForwardOutput& output, write_vector_tensor(buffer, output.mm_embeddings); // write dit output - write_dit_forward_output(buffer, output.dit_forward_output); + if (FLAGS_backend == "dit") { + write_vector_tensor(buffer, output.dit_forward_output.tensors); + } } void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,