
Commit 01f2c50

feat: support new model glm4-flash.

1 parent 6cda3ee commit 01f2c50

12 files changed, +250 −245 lines changed

xllm/core/framework/chat_template/jinja_chat_template.cpp

Lines changed: 19 additions & 14 deletions

```diff
@@ -121,6 +121,24 @@ std::optional<std::string> JinjaChatTemplate::apply(
     nlohmann::ordered_json& messages,
     const nlohmann::ordered_json& tools,
     const nlohmann::ordered_json& chat_template_kwargs) const {
+  for (auto& msg : messages) {
+    if (!msg.contains("content")) continue;
+    auto& content = msg["content"];
+    auto normalize_item = [](nlohmann::ordered_json& item) {
+      if (item.contains("type") && item["type"].is_string()) {
+        std::string t = item["type"].get<std::string>();
+        if (t == "video_url") item["type"] = "video";
+      }
+      if (item.contains("video_url") && !item.contains("video"))
+        item["video"] = item["video_url"];
+    };
+
+    if (content.is_array()) {
+      for (auto& it : content) normalize_item(it);
+    } else if (content.is_object()) {
+      normalize_item(content);
+    }
+  }
   minja::chat_template_inputs input;
   input.messages = messages;
   input.tools = tools;
@@ -137,23 +155,10 @@ nlohmann::ordered_json JinjaChatTemplate::get_mm_content(
 
   for (const auto& item : vec) {
     nlohmann::ordered_json item_json;
-    if (item.type == "video_url") {
-      item_json["type"] = "video";
-    } else {
-      item_json["type"] = item.type;
-    }
+    item_json["type"] = item.type;
 
     if (item.type == "text") {
       item_json["text"] = item.text;
-    } else if (item.type == "video_url") {
-      item_json["video"] = "mm place holder";
-      item_json["video_url"] = "mm place holder";
-    } else if (item.type == "image_url") {
-      item_json["image"] = "mm place holder";
-      item_json["image_url"] = "mm place holder";
-    } else if (item.type == "audio_url") {
-      item_json["audio"] = "mm place holder";
-      item_json["audio_url"] = "mm place holder";
     } else {
       item_json[item.type] = "mm place holder";
     }
```
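The first hunk normalizes incoming OpenAI-style message parts before templating: `video_url` parts are retyped to `video`, and the payload is mirrored under a `video` key so GLM-style chat templates can find it. The second hunk then collapses the per-type placeholder branches, since the normalization now happens upstream. A minimal standalone sketch of the effect, assuming nlohmann/json is available (the message content and file path are illustrative):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  using json = nlohmann::ordered_json;

  json item;
  item["type"] = "video_url";
  item["video_url"] = {{"url", "file:///tmp/demo.mp4"}};  // illustrative path

  json msg;
  msg["role"] = "user";
  msg["content"] = json::array({item});

  // Mirror of the normalization added in apply():
  for (auto& it : msg["content"]) {
    if (it.contains("type") && it["type"] == "video_url") it["type"] = "video";
    if (it.contains("video_url") && !it.contains("video"))
      it["video"] = it["video_url"];
  }

  std::cout << msg.dump(2) << std::endl;  // "type" is now "video"
}
```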

xllm/core/framework/request/mm_codec.cpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -159,7 +159,8 @@ bool OpenCVVideoDecoder::decode(const std::string& raw_data,
   av_dict_set(&opts, "probesize", "20000000", 0);
   av_dict_set(&opts, "analyzeduration", "5000000", 0);
 
-  int ret = avformat_open_input(&fmt, nullptr, nullptr, &opts);
+  const AVInputFormat* in_fmt = av_find_input_format("mp4");
+  int ret = avformat_open_input(&fmt, nullptr, in_fmt, &opts);
   av_dict_free(&opts);
 
   if (ret < 0) {
```
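This pins the demuxer instead of letting FFmpeg probe the in-memory stream: `av_find_input_format("mp4")` resolves to the combined `mov,mp4,m4a,3gp,3g2,mj2` input format. Two caveats: a pinned format means non-MP4 buffers will now fail to open, and the lookup itself returns `nullptr` on builds without that demuxer. A hedged sketch of a slightly more defensive variant (the fallback to probing is an assumption, not what the commit does):

```cpp
// Sketch only: guards the lookup, since av_find_input_format() returns
// nullptr when the named demuxer is absent from the linked FFmpeg build;
// passing nullptr back to avformat_open_input() restores auto-probing.
extern "C" {
#include <libavformat/avformat.h>
}

int open_as_mp4(AVFormatContext** fmt, AVDictionary** opts) {
  const AVInputFormat* in_fmt = av_find_input_format("mp4");
  return avformat_open_input(fmt, /*url=*/nullptr, in_fmt, opts);
}
```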

xllm/core/framework/request/mm_input.h

Lines changed: 4 additions & 2 deletions

```diff
@@ -58,11 +58,13 @@ struct MMInput {
     return std::move(vec);
   }
 
-  std::vector<VideoMetadata> get_video_metadata() const {
+  std::vector<VideoMetadata> get_video_metadata(MMType type) const {
     std::vector<VideoMetadata> metas;
     metas.reserve(items_.size());
     for (auto& item : items_) {
-      metas.push_back(item.video_meta_);
+      if (item.type_ == type) {
+        metas.push_back(item.video_meta_);
+      }
     }
     return metas;
   }
```
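Previously `get_video_metadata()` returned one (possibly default-constructed) `VideoMetadata` per item, images included; it now filters by the requested `MMType`. A self-contained sketch of the behavior with stand-in types (simplified stand-ins, not xllm's real definitions):

```cpp
#include <vector>

enum class MMType { IMAGE, VIDEO };
struct VideoMetadata { int total_num_frames = 0; };
struct Item { MMType type_; VideoMetadata video_meta_; };

std::vector<VideoMetadata> get_video_metadata(const std::vector<Item>& items,
                                              MMType type) {
  std::vector<VideoMetadata> metas;
  metas.reserve(items.size());
  for (const auto& item : items) {
    if (item.type_ == type) metas.push_back(item.video_meta_);  // filter by type
  }
  return metas;
}

int main() {
  std::vector<Item> items = {{MMType::IMAGE, {}}, {MMType::VIDEO, {32}}};
  // Only the VIDEO item's metadata survives: size() == 1.
  return get_video_metadata(items, MMType::VIDEO).size() == 1 ? 0 : 1;
}
```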

xllm/core/runtime/vlm_master.cpp

Lines changed: 5 additions & 2 deletions

```diff
@@ -220,7 +220,11 @@ void VLMMaster::handle_request(const std::vector<Message>& messages,
                     "Image processor process failed.");
     return;
   }
-
+  if (const auto& res = mm_data.get<torch::Tensor>("image_grid_thw"))
+  {
+    auto image_grid_thw = res.value();
+    LOG(INFO) << "image_grid_thw:" << image_grid_thw;
+  }
   this->handle_request(messages, mm_data, sp, callback);
 }
 
@@ -307,7 +311,6 @@ std::shared_ptr<Request> VLMMaster::generate_request(std::string prompt,
   }
   Timer timer;
   input_processor_->process(prompt, mm_data);
-
   std::vector<int> prompt_tokens;
   if (!tokenizer_->encode(prompt, &prompt_tokens)) {
     LOG(ERROR) << "Failed to encode prompt: " << prompt;
```
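The added guard uses the if-initializer idiom: `mm_data.get<torch::Tensor>(...)` evidently returns an optional-like handle, so the log fires only when the image processor actually produced a grid. A stand-alone sketch of the same idiom (stand-in types; MMData's real `get<T>()` signature is not shown in this diff):

```cpp
#include <iostream>
#include <optional>
#include <string>

// Stand-in for MMData::get<T>(): empty when the key was never produced.
std::optional<std::string> lookup(const std::string& key) {
  if (key == "image_grid_thw") return "[[1, 34, 52]]";  // illustrative value
  return std::nullopt;
}

int main() {
  // Binding in the if-condition scopes `res` to the branch, exactly like
  // the guard added around the image_grid_thw log.
  if (const auto res = lookup("image_grid_thw")) {
    std::cout << "image_grid_thw:" << *res << std::endl;
  }
}
```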

xllm/models/llm/glm4.h

Lines changed: 11 additions & 13 deletions

```diff
@@ -93,20 +93,18 @@ class Glm4ModelImpl : public LlmModelImplBase<Glm4DecoderLayer> {
 
     if (positions.dim() == 2) {  // mrope
       auto apply = [this](torch::Tensor x) {
-        auto freqs_t = x[0].clone();
-        for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) {
-          int64_t offset = dim_idx;
-          int64_t section_len = mrope_section_[dim_idx];
-          int64_t length = section_len * 2;
-          auto idx_first_half = torch::arange(offset, length, 3, torch::kLong);
-          auto idx_second_half = torch::arange(offset, length, 3, torch::kLong);
-          auto idx_tensor =
-              torch::cat({idx_first_half, idx_second_half}, 0).to(x.device());
-          // freqs_t[..., idx] = freqs[dim_idx][..., idx]
-          auto src = x[dim_idx].index_select(-1, idx_tensor);
-          freqs_t.index_copy_(-1, idx_tensor, src);
+        auto sections = mrope_section_;
+        sections.insert(sections.end(), sections.begin(), sections.end());
+
+        auto vec = x.split(sections, -1);
+        std::vector<torch::Tensor> selects;
+        selects.reserve(vec.size());
+
+        for (int64_t i = 0; i < vec.size(); ++i) {
+          auto m = vec[i];
+          selects.push_back(m[i % mrope_section_.size()]);
         }
-        return freqs_t;
+        return torch::cat(selects, -1);
       };
       cos_pos = apply(cos_pos.reshape(
          {positions.sizes().front(), -1, cos_pos.sizes().back()}));
```
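The loop-and-`index_copy_` construction is replaced by the split/concat formulation familiar from Hugging Face's Qwen2-VL mrope code: the t/h/w section list is repeated once so it spans both rotary halves, the last dimension is split into those six chunks, and chunk i is taken from position stream i mod 3. A self-contained sketch with illustrative section sizes (note the sketch copies from a second vector; inserting a vector's own range into itself, as the committed lambda does, is not guaranteed safe by `std::vector` if reallocation occurs):

```cpp
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  // Illustrative t/h/w sections summing to half of the last dim (8 / 2 = 4).
  std::vector<int64_t> section = {2, 1, 1};
  auto x = torch::arange(3 * 5 * 8, torch::kFloat).reshape({3, 5, 8});

  // Repeat the section list so it covers both rotary halves of the last dim.
  std::vector<int64_t> sections = section;
  sections.insert(sections.end(), section.begin(), section.end());

  auto chunks = x.split_with_sizes(sections, /*dim=*/-1);
  std::vector<torch::Tensor> selects;
  selects.reserve(chunks.size());
  for (size_t i = 0; i < chunks.size(); ++i) {
    // Chunk i comes from position stream i mod 3: t, h, w, t, h, w.
    selects.push_back(chunks[i][static_cast<int64_t>(i % section.size())]);
  }
  auto merged = torch::cat(selects, /*dim=*/-1);
  std::cout << merged.sizes() << std::endl;  // [5, 8]
}
```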

xllm/models/vlm/glm4v.h

Lines changed: 3 additions & 16 deletions

```diff
@@ -605,7 +605,6 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
       blocks_->push_back(block);
       layers_.push_back(block);
     }
-    // TODO: fused operators
     post_layernorm_ = register_module("post_layernorm", Glm4VisionRmsNorm(context));
 
     downsample_ = register_module("downsample", torch::nn::Conv2d(torch::nn::Conv2dOptions(hidden_size_, out_hidden_size_, spatial_merge_size_)
@@ -672,8 +671,6 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
     auto repeated = torch::repeat_interleave(h_times_w, repeats, 0);
     c10::optional<torch::ScalarType> cumsum_dtype;
 
-    LOG(INFO) << " Glm4VisionTransformerImpl repeated " << repeated;
-
     cumsum_dtype = torch::kInt32;
     auto cu_seqlens = torch::cumsum(repeated, 0, cumsum_dtype);
     namespace F = torch::nn::functional;
@@ -682,27 +679,21 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
     std::vector<int> seqlens;
     seqlens.assign(cu_seqlens.data_ptr<int>(), cu_seqlens.data_ptr<int>() + cu_seqlens.numel());
 
-    LOG(INFO) << " Glm4VisionTransformerImpl forward embedding before cu_seqlens " << cu_seqlens << "seqlens.size()" << seqlens.size();
     hidden_states = embeddings_(hidden_states, seqlens, grid_thw, image_type_ids.select(1, 0), image_type_ids.select(1, 1));
-    LOG(INFO) << " Glm4VisionTransformerImpl forward embedding after ";
     ModelInputParams& input_params_new = const_cast<ModelInputParams&>(input_params);
     torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
     std::vector<int> cu_seqlens_vec(
-        cu_seqlens_cpu.data_ptr<int>(),  // full seqlen vec
+        cu_seqlens_cpu.data_ptr<int>(),
         cu_seqlens_cpu.data_ptr<int>() + cu_seqlens_cpu.numel());
+    cu_seqlens = cu_seqlens.to(hidden_states.device());
     for (int idx = 0; idx < blocks_->size(); ++idx) {
-      hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx);  // TODO
-      LOG(INFO) << " Glm4VisionTransformerImpl forward layer " << idx;
+      hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx);
     }
-    LOG(INFO) << " Glm4VisionTransformerImpl forward layer after ";
     hidden_states = post_layernorm_(hidden_states);
     hidden_states = hidden_states.view({-1, spatial_merge_size_, spatial_merge_size_, hidden_states.size(-1)});
-    // TO down sample merge op
     hidden_states = hidden_states.permute({0, 3, 1, 2});
     hidden_states = downsample_(hidden_states).view({-1, out_hidden_size_});
-    LOG(INFO) << " Glm4VisionTransformerImpl downsample after";
     hidden_states = merger_(hidden_states);
-    LOG(INFO) << " Glm4VisionTransformerImpl forward end";
     return hidden_states;
   };
@@ -820,12 +811,10 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
                            const ModelInputParams& input_params) {
     auto inputs_embeds = language_model_->get_input_embeddings(input_ids);
     if (image_input) {
-      // visual
       auto image_embeds =
           visual_(image_input->pixel_values.to(options_),
                   image_input->image_grid_thw,
                   input_params);
-      // merge
       auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id());
       input_params.visual_pos_masks = is_multimodal;
       inputs_embeds.index_put_({is_multimodal}, image_embeds);
@@ -851,7 +840,6 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
 
     if (pixel_values.defined() && image_grid_thw.defined())
       image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw};
-
     auto inputs_embeds = get_input_embeddings(tokens, image_inputs, video_inputs, input_params);
     input_params.input_embedding = inputs_embeds;
     auto emb = language_model_(tokens, positions, kv_caches, input_params);
@@ -869,7 +857,6 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
       visual_->load_state_dict(
           state_dict->get_dict_with_prefix("model.visual."));
     }
-    // verify
     visual_->verify_loaded_weights("model.visual.");
     visual_->merge_loaded_weights();
     if (!model_args_.image_embedding_mode()) {
```
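Besides the debug-log cleanup, one functional change hides here: `cu_seqlens` is now explicitly moved to the device of `hidden_states` before the block loop, while the int32 CPU copy keeps feeding `cu_seqlens_vec` (`data_ptr<int>()` requires host memory). A minimal sketch of that host/device split, with illustrative values:

```cpp
#include <torch/torch.h>
#include <vector>

int main() {
  // Illustrative per-image patch counts; cumsum yields cu_seqlens = [4, 10].
  auto cu_seqlens = torch::cumsum(torch::tensor({4, 6}), 0, torch::kInt32);

  // Host copy for C++-side bookkeeping: data_ptr<int>() needs CPU memory.
  auto cpu = cu_seqlens.cpu();
  std::vector<int> vec(cpu.data_ptr<int>(), cpu.data_ptr<int>() + cpu.numel());

  // Device copy for the attention kernels; a no-op if already resident there.
  auto device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  cu_seqlens = cu_seqlens.to(device);
}
```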

xllm/processors/glm4v_image_processor.cpp

Lines changed: 6 additions & 118 deletions

```diff
@@ -77,117 +77,6 @@ std::optional<Size> smart_resize(int num_frames,
 }
 }  // namespace
 
-torch::Tensor Glm4VImageProcessor::sample_frames(const VideoMetadata& metadata,
-                                                 int temporal_patch_size) {
-  // video: [T, C, H, W]
-  const int total_frames = metadata.total_num_frames;
-  if (total_frames <= 0) {
-    return torch::empty({0}, torch::dtype(torch::kLong));
-  }
-
-  if (metadata.fps <= 0.0) {
-    LOG(FATAL) << "invalid metadata.fps <= 0";
-  }
-
-  const int max_frame_idx = total_frames - 1;
-
-  // duration = metadata.duration or round(max_idx / fps) + 1
-  double duration = metadata.duration;
-  if (duration <= 0.0) {
-    duration =
-        std::round(static_cast<double>(max_frame_idx) / metadata.fps) + 1.0;
-  }
-
-  constexpr double DYN_FPS_30 = 3.0;
-  constexpr double DYN_FPS_300 = 1.0;
-  constexpr double DYN_FPS_2400 = 0.5;
-  constexpr int MAX_FRAME_COUNT_DYNAMIC = 640;
-  constexpr double MAX_DURATION = 2400.0;
-
-  const double effective_duration = std::min(duration, MAX_DURATION);
-
-  double target_fps = 0.0;
-  if (effective_duration <= 30.0) {
-    target_fps = DYN_FPS_30;
-  } else if (effective_duration <= 300.0) {
-    target_fps = DYN_FPS_300;
-  } else {
-    target_fps = DYN_FPS_2400;
-  }
-
-  // extract_t = int(effective_duration * target_fps * temporal_patch_size)
-  int extract_t = static_cast<int>(effective_duration * target_fps *
-                                   static_cast<double>(temporal_patch_size));
-  extract_t = std::min(extract_t, MAX_FRAME_COUNT_DYNAMIC);
-
-  const double duration_per_frame = 1.0 / metadata.fps;
-  std::vector<double> timestamps(total_frames);
-  for (int i = 0; i < total_frames; ++i) {
-    timestamps[i] = static_cast<double>(i) * duration_per_frame;
-  }
-  const int max_second = static_cast<int>(duration);
-
-  torch::Tensor frame_indices;
-
-  if (total_frames < extract_t) {
-    frame_indices = torch::linspace(
-        0, total_frames - 1, extract_t, torch::dtype(torch::kLong));
-  } else {
-    std::vector<int64_t> tmp;
-    tmp.reserve(static_cast<size_t>(total_frames));
-    double current_second = 0.0;
-    const double inv_fps =
-        1.0 / (static_cast<double>(temporal_patch_size) * target_fps);
-
-    for (int frame_index = 0; frame_index < total_frames; frame_index++) {
-      if (timestamps[frame_index] >= current_second) {
-        current_second += inv_fps;
-        tmp.push_back(frame_index);
-        if (current_second >= static_cast<double>(max_second)) {
-          break;
-        }
-      }
-    }
-    frame_indices =
-        torch::tensor(tmp, torch::TensorOptions().dtype(torch::kLong));
-  }
-  int64_t len = frame_indices.size(0);
-  if (len < extract_t) {
-    int64_t start, end;
-    if (len == 0) {
-      start = 0;
-      end = std::max<int64_t>(total_frames - 1, 0);
-    } else {
-      start = frame_indices[0].item<int64_t>();
-      end = frame_indices[len - 1].item<int64_t>();
-    }
-    frame_indices =
-        torch::linspace(start, end, extract_t, torch::dtype(torch::kLong));
-  } else if (len > extract_t) {
-    frame_indices = torch::linspace(
-        0, total_frames - 1, extract_t, torch::dtype(torch::kLong));
-  }
-
-  len = frame_indices.size(0);
-  std::unordered_set<int64_t> seen;
-  seen.reserve(static_cast<size_t>(len) * 2);
-  std::vector<int64_t> uniq;
-  uniq.reserve(static_cast<size_t>(len));
-
-  for (int64_t i = 0; i < len; ++i) {
-    auto idx = frame_indices[i].item<int64_t>();
-    if (seen.insert(idx).second) {
-      uniq.push_back(idx);
-    }
-  }
-
-  if (!uniq.empty() && (uniq.size() & 1)) {
-    uniq.push_back(uniq.back());
-  }
-
-  return torch::tensor(uniq, torch::TensorOptions().dtype(torch::kLong));
-}
-
 Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) {
   image_mean_ = args.mm_image_normalize_mean();
   image_std_ = args.mm_image_normalize_std();
@@ -223,7 +112,8 @@ Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) {
 bool Glm4VImageProcessor::process(const MMInput& inputs, MMData& datas) {
   std::vector<torch::Tensor> images = inputs.get_decode_data(MMType::IMAGE);
   std::vector<torch::Tensor> videos = inputs.get_decode_data(MMType::VIDEO);
-  std::vector<VideoMetadata> video_meta_list = inputs.get_video_metadata();
+  std::vector<VideoMetadata> video_meta_list =
+      inputs.get_video_metadata(MMType::VIDEO);
 
   if (images.empty() && (videos.empty() || video_meta_list.empty())) {
     LOG(ERROR) << "no image/video tensor found.";
@@ -359,8 +249,8 @@ bool Glm4VImageProcessor::process_videos(
 
   auto values = torch::cat(pixel_values);
   auto thw = torch::tensor(grids).clone().reshape({-1, 3});
-  mm_datas.add(MMType::VIDEO, "video_grid_thw", thw);
-  mm_datas.add(MMType::VIDEO, "pixel_values_videos", values);
+  mm_datas.update(MMType::VIDEO, "video_grid_thw", thw);
+  mm_datas.update(MMType::VIDEO, "pixel_values_videos", values);
 
   return true;
 }
@@ -376,11 +266,9 @@ bool Glm4VImageProcessor::process_video(
 
   torch::Tensor indices;
   if (do_sample_frame_) {
-    indices = this->sample_frames(metadata, temporal_patch_size_);
+    indices = this->GLM_sample_frames(metadata, temporal_patch_size_);
   } else {
-    indices = torch::arange(0,
-                            static_cast<int64_t>(origin_video.size(0)),
-                            torch::TensorOptions().dtype(torch::kLong));
+    indices = this->init_frames(metadata);  // default sample to 32 frames
  }
  auto video = origin_video.index_select(/*dim=*/0, indices);
  int64_t sampled_total_frames = video.size(0);
```
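The bespoke `sample_frames` (dynamic fps: 3 fps under 30 s, 1 fps under 300 s, 0.5 fps up to 2400 s, capped at 640 frames, deduplicated and padded to an even count) is deleted in favor of `GLM_sample_frames`, and the old take-every-frame fallback becomes `init_frames`, which per the inline comment samples a default of 32 frames. Neither new body appears in this diff; as a rough sketch only, a uniform sampler consistent with that comment might look like this (the name and spacing policy are assumptions):

```cpp
#include <torch/torch.h>
#include <algorithm>

// Hypothetical sketch: evenly spaced frame indices, capped at `target`.
// init_frames' real implementation is not shown in this commit.
torch::Tensor uniform_frame_indices(int64_t total_frames, int64_t target = 32) {
  if (total_frames <= 0) return torch::empty({0}, torch::kLong);
  const int64_t n = std::min(total_frames, target);
  return torch::linspace(0, total_frames - 1, n).to(torch::kLong);
}
```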

xllm/processors/glm4v_image_processor.h

File mode changed: 100644 → 100755
Lines changed: 0 additions & 2 deletions

```diff
@@ -42,8 +42,6 @@ class Glm4VImageProcessor : public ImageProcessor {
                      VideoMetadata& metadata,
                      std::vector<torch::Tensor>& pixel_values,
                      std::vector<int64_t>& grids);
-  torch::Tensor sample_frames(const VideoMetadata& metadata,
-                              int temporal_patch_size);
 
  private:
   bool do_convert_rgb_ = true;
```
