
Commit a7e15a6

feat: support new model glm4-flash.
1 parent 9f72f29 commit a7e15a6

File tree: 3 files changed, +19 −31 lines


xllm/core/runtime/vlm_master.cpp

Lines changed: 5 additions & 2 deletions
@@ -220,7 +220,11 @@ void VLMMaster::handle_request(const std::vector<Message>& messages,
             "Image processor process failed.");
     return;
   }
-
+  if (const auto& res = mm_data.get<torch::Tensor>("image_grid_thw"))
+  {
+    auto image_grid_thw = res.value();
+    LOG(INFO) << "image_grid_thw:" << image_grid_thw;
+  }
   this->handle_request(messages, mm_data, sp, callback);
 }
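The added guard uses a C++17 if-initializer over the optional-like handle returned by `mm_data.get<T>()`: the grid is logged only when the image processor actually produced it. A minimal sketch of the same pattern with `std::optional` (the real return type of `mm_data.get<torch::Tensor>()` in xllm may be a different wrapper; `get_mm` and the sample value are hypothetical):

```cpp
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for mm_data.get<torch::Tensor>(key): yields a value
// only when the key exists in the multimodal data bundle.
std::optional<std::string> get_mm(const std::string& key) {
  if (key == "image_grid_thw") return std::string("[[1, 34, 46]]");
  return std::nullopt;
}

int main() {
  // Bind the optional in the if-condition; the body runs only when it is
  // set, mirroring the new image_grid_thw logging guard.
  if (const auto& res = get_mm("image_grid_thw")) {
    std::cout << "image_grid_thw:" << res.value() << "\n";
  }
  return 0;
}
```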

@@ -307,7 +311,6 @@ std::shared_ptr<Request> VLMMaster::generate_request(std::string prompt,
   }
   Timer timer;
   input_processor_->process(prompt, mm_data);
-
   std::vector<int> prompt_tokens;
   if (!tokenizer_->encode(prompt, &prompt_tokens)) {
     LOG(ERROR) << "Failed to encode prompt: " << prompt;

xllm/models/llm/glm4.h

Lines changed: 11 additions & 13 deletions
@@ -93,20 +93,18 @@ class Glm4ModelImpl : public LlmModelImplBase<Glm4DecoderLayer> {
 
   if (positions.dim() == 2) {  // mrope
     auto apply = [this](torch::Tensor x) {
-      auto freqs_t = x[0].clone();
-      for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) {
-        int64_t offset = dim_idx;
-        int64_t section_len = mrope_section_[dim_idx];
-        int64_t length = section_len * 2;
-        auto idx_first_half = torch::arange(offset, length, 3, torch::kLong);
-        auto idx_second_half = torch::arange(offset, length, 3, torch::kLong);
-        auto idx_tensor =
-            torch::cat({idx_first_half, idx_second_half}, 0).to(x.device());
-        // freqs_t[..., idx] = freqs[dim_idx][..., idx]
-        auto src = x[dim_idx].index_select(-1, idx_tensor);
-        freqs_t.index_copy_(-1, idx_tensor, src);
+      auto sections = mrope_section_;
+      sections.insert(sections.end(), sections.begin(), sections.end());
+
+      auto vec = x.split(sections, -1);
+      std::vector<torch::Tensor> selects;
+      selects.reserve(vec.size());
+
+      for (int64_t i = 0; i < vec.size(); ++i) {
+        auto m = vec[i];
+        selects.push_back(m[i % mrope_section_.size()]);
       }
-      return freqs_t;
+      return torch::cat(selects, -1);
     };
     cos_pos = apply(cos_pos.reshape(
         {positions.sizes().front(), -1, cos_pos.sizes().back()}));
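The rewritten `apply` lambda replaces the per-dimension index bookkeeping with a split-and-select: the T/H/W section list is doubled (one copy per rotary half), the last dimension is split into those chunks, and chunk `i` keeps the position stream `i % 3`. A standalone libtorch sketch of the same selection; the section sizes `{16, 24, 24}` are illustrative (the real values live in `mrope_section_`), and `split_with_sizes` is used as a portable spelling of the diff's `x.split(sections, -1)`:

```cpp
#include <torch/torch.h>
#include <vector>

// x: [3, seq_len, head_dim] -- one cos/sin table per T/H/W position stream.
// sections: per-stream rotary section widths; 2 * sum(sections) == head_dim.
torch::Tensor mrope_select(const torch::Tensor& x,
                           std::vector<int64_t> sections) {
  const int64_t n = static_cast<int64_t>(sections.size());
  // Double the list: the table holds two halves along the last dimension.
  sections.insert(sections.end(), sections.begin(), sections.end());

  auto chunks = x.split_with_sizes(sections, /*dim=*/-1);
  std::vector<torch::Tensor> selects;
  selects.reserve(chunks.size());
  for (size_t i = 0; i < chunks.size(); ++i) {
    // Chunk i belongs to stream i % n: T, H, W, T, H, W for n == 3.
    selects.push_back(chunks[i][static_cast<int64_t>(i) % n]);
  }
  return torch::cat(selects, /*dim=*/-1);  // [seq_len, head_dim]
}

// Usage: head_dim 128 with sections {16, 24, 24} (2 * 64 == 128).
// auto cos = torch::randn({3, 10, 128});
// auto out = mrope_select(cos, {16, 24, 24});  // -> [10, 128]
```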

xllm/models/vlm/glm4v.h

Lines changed: 3 additions & 16 deletions
@@ -605,7 +605,6 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
       blocks_->push_back(block);
       layers_.push_back(block);
     }
-    // TODO: fused operator
     post_layernorm_ = register_module("post_layernorm", Glm4VisionRmsNorm(context));
 
     downsample_ = register_module("downsample", torch::nn::Conv2d(torch::nn::Conv2dOptions(hidden_size_, out_hidden_size_, spatial_merge_size_)
@@ -672,8 +671,6 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
     auto repeated = torch::repeat_interleave(h_times_w, repeats, 0);
     c10::optional<torch::ScalarType> cumsum_dtype;
 
-    LOG(INFO) << " Glm4VisionTransformerImpl repeated " << repeated;
-
     cumsum_dtype = torch::kInt32;
     auto cu_seqlens = torch::cumsum(repeated, 0, cumsum_dtype);
     namespace F = torch::nn::functional;
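The lines above derive varlen-attention boundaries from the image grid: `h_times_w` patches per frame, repeated `t` times per image, with the running sum giving cumulative sequence lengths. A hedged sketch of that computation; the `(t, h, w)` column order of `grid_thw` is inferred from the variable names, and the leading-zero pad (which the `F` alias in the context suggests comes next) is an assumption:

```cpp
#include <torch/torch.h>

// grid_thw: [num_images, 3] int tensor holding (t, h, w) patch counts.
// Returns cumulative sequence lengths, one boundary per frame, with a
// leading zero as varlen attention kernels usually expect.
torch::Tensor build_cu_seqlens(const torch::Tensor& grid_thw) {
  auto repeats = grid_thw.select(/*dim=*/1, /*index=*/0);          // t
  auto h_times_w = grid_thw.select(1, 1) * grid_thw.select(1, 2);  // h * w
  auto repeated = torch::repeat_interleave(h_times_w, repeats, /*dim=*/0);
  auto cu_seqlens = torch::cumsum(repeated, /*dim=*/0, torch::kInt32);
  namespace F = torch::nn::functional;
  return F::pad(cu_seqlens, F::PadFuncOptions({1, 0}));  // prepend 0
}

// e.g. grid_thw = [[2, 4, 4]] -> repeated = [16, 16] -> [0, 16, 32]
```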
@@ -682,27 +679,21 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
     std::vector<int> seqlens;
     seqlens.assign(cu_seqlens.data_ptr<int>(), cu_seqlens.data_ptr<int>() + cu_seqlens.numel());
 
-    LOG(INFO) << " Glm4VisionTransformerImpl forward embedding before cu_seqlens " << cu_seqlens << "seqlens.size()" << seqlens.size();
     hidden_states = embeddings_(hidden_states, seqlens, grid_thw, image_type_ids.select(1, 0), image_type_ids.select(1, 1));
-    LOG(INFO) << " Glm4VisionTransformerImpl forward embedding after ";
     ModelInputParams& input_params_new = const_cast<ModelInputParams&>(input_params);
     torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
     std::vector<int> cu_seqlens_vec(
-        cu_seqlens_cpu.data_ptr<int>(), // full seqlen vec
+        cu_seqlens_cpu.data_ptr<int>(),
         cu_seqlens_cpu.data_ptr<int>() + cu_seqlens_cpu.numel());
+    cu_seqlens = cu_seqlens.to(hidden_states.device());
     for (int idx = 0; idx < blocks_->size(); ++idx) {
-      hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx); // TODO
-      LOG(INFO) << " Glm4VisionTransformerImpl forward layer " << idx;
+      hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx);
     }
-    LOG(INFO) << " Glm4VisionTransformerImpl forward layer after ";
     hidden_states = post_layernorm_(hidden_states);
     hidden_states = hidden_states.view({-1, spatial_merge_size_, spatial_merge_size_, hidden_states.size(-1)});
-    // TO down sample merge op
     hidden_states = hidden_states.permute({0, 3, 1, 2});
     hidden_states = downsample_(hidden_states).view({-1, out_hidden_size_});
-    LOG(INFO) << " Glm4VisionTransformerImpl downsample after";
     hidden_states = merger_(hidden_states);
-    LOG(INFO) << " Glm4VisionTransformerImpl forward end";
     return hidden_states;
   };
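At the tail of `forward`, patch features are regrouped into `spatial_merge_size × spatial_merge_size` windows and projected to the output width by the strided `downsample_` Conv2d before the final `merger_`. A sketch of that reshaping; the Conv2d options beyond the kernel size are cut off in the diff, so the stride setting below is an assumption:

```cpp
#include <torch/torch.h>

// hidden_states: [num_patches, hidden]; every spatial_merge_size^2 rows form
// one merge window. The conv collapses each window to a single out-hidden
// vector, matching the view/permute/downsample/view chain in the model.
torch::Tensor spatial_merge(const torch::Tensor& hidden_states,
                            torch::nn::Conv2d& downsample,
                            int64_t spatial_merge_size,
                            int64_t out_hidden_size) {
  auto x = hidden_states.view(
      {-1, spatial_merge_size, spatial_merge_size, hidden_states.size(-1)});
  x = x.permute({0, 3, 1, 2});  // NHWC -> NCHW for Conv2d
  return downsample(x).view({-1, out_hidden_size});
}

// Construction sketch (kernel == merge window; stride assumed to match):
// torch::nn::Conv2d downsample(
//     torch::nn::Conv2dOptions(hidden, out_hidden, merge).stride(merge));
```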

@@ -820,12 +811,10 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
                               const ModelInputParams& input_params) {
     auto inputs_embeds = language_model_->get_input_embeddings(input_ids);
     if (image_input) {
-      // visual
       auto image_embeds =
           visual_(image_input->pixel_values.to(options_),
                   image_input->image_grid_thw,
                   input_params);
-      // merge
       auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id());
       input_params.visual_pos_masks = is_multimodal;
       inputs_embeds.index_put_({is_multimodal}, image_embeds);
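The `isin`/`index_put_` pair above scatters the vision tower's output into the text embedding sequence wherever the prompt holds the image placeholder token. A self-contained sketch of that merge; the shapes and the standalone `image_token_id` parameter are illustrative (in the model it comes from `model_args_.image_token_id()`):

```cpp
#include <torch/torch.h>

// inputs_embeds: [seq_len, hidden]; image_embeds: [num_image_tokens, hidden].
// The boolean mask must select exactly image_embeds.size(0) positions, i.e.
// the prompt carries one placeholder token per visual embedding row.
torch::Tensor merge_image_embeds(torch::Tensor inputs_embeds,
                                 const torch::Tensor& input_ids,
                                 const torch::Tensor& image_embeds,
                                 int64_t image_token_id) {
  auto is_multimodal = torch::isin(input_ids, image_token_id);
  // Advanced indexing with a bool mask writes image_embeds row-by-row into
  // the masked positions, in order.
  inputs_embeds.index_put_({is_multimodal}, image_embeds);
  return inputs_embeds;
}
```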
@@ -851,7 +840,6 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
 
     if (pixel_values.defined() && image_grid_thw.defined())
       image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw};
-
     auto inputs_embeds = get_input_embeddings(tokens, image_inputs, video_inputs, input_params);
     input_params.input_embedding = inputs_embeds;
     auto emb = language_model_(tokens, positions, kv_caches, input_params);
@@ -869,7 +857,6 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
       visual_->load_state_dict(
           state_dict->get_dict_with_prefix("model.visual."));
     }
-    // verify
     visual_->verify_loaded_weights("model.visual.");
     visual_->merge_loaded_weights();
     if (!model_args_.image_embedding_mode()) {
