
Commit 3546418 (1 parent: 01f2c50)

xanecdotex authored and DongheJin committed

feat: 1. move sample_frames from image_processor. 2. support different types in chat_template content.

File tree: 9 files changed (+209 −230 lines)


xllm/core/framework/chat_template/jinja_chat_template.cpp (9 additions, 18 deletions)

@@ -121,24 +121,6 @@ std::optional<std::string> JinjaChatTemplate::apply(
     nlohmann::ordered_json& messages,
     const nlohmann::ordered_json& tools,
     const nlohmann::ordered_json& chat_template_kwargs) const {
-  for (auto& msg : messages) {
-    if (!msg.contains("content")) continue;
-    auto& content = msg["content"];
-    auto normalize_item = [](nlohmann::ordered_json& item) {
-      if (item.contains("type") && item["type"].is_string()) {
-        std::string t = item["type"].get<std::string>();
-        if (t == "video_url") item["type"] = "video";
-      }
-      if (item.contains("video_url") && !item.contains("video"))
-        item["video"] = item["video_url"];
-    };
-
-    if (content.is_array()) {
-      for (auto& it : content) normalize_item(it);
-    } else if (content.is_object()) {
-      normalize_item(content);
-    }
-  }
   minja::chat_template_inputs input;
   input.messages = messages;
   input.tools = tools;

@@ -159,6 +141,15 @@ nlohmann::ordered_json JinjaChatTemplate::get_mm_content(

   if (item.type == "text") {
     item_json["text"] = item.text;
+  } else if (item.type == "video_url") {
+    item_json["video"] = "mm place holder";
+    item_json["video_url"] = "mm place holder";
+  } else if (item.type == "image_url") {
+    item_json["image"] = "mm place holder";
+    item_json["image_url"] = "mm place holder";
+  } else if (item.type == "audio_url") {
+    item_json["audio"] = "mm place holder";
+    item_json["audio_url"] = "mm place holder";
   } else {
     item_json[item.type] = "mm place holder";
   }
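
The effect of this pair of hunks: the in-place normalization pass that rewrote "video_url" items to "video" inside apply() is deleted, and get_mm_content() instead emits a placeholder under both the raw type key and its alias, so chat templates written against either name keep rendering. A self-contained sketch of that mapping (the helper name make_mm_item is mine, not from the commit):

// Minimal sketch of the placeholder mapping introduced in this commit.
// make_mm_item is a hypothetical helper for illustration only.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

nlohmann::ordered_json make_mm_item(const std::string& type) {
  nlohmann::ordered_json item_json;
  item_json["type"] = type;
  if (type == "video_url") {
    item_json["video"] = "mm place holder";
    item_json["video_url"] = "mm place holder";
  } else if (type == "image_url") {
    item_json["image"] = "mm place holder";
    item_json["image_url"] = "mm place holder";
  } else if (type == "audio_url") {
    item_json["audio"] = "mm place holder";
    item_json["audio_url"] = "mm place holder";
  } else {
    item_json[type] = "mm place holder";
  }
  return item_json;
}

int main() {
  // Prints both "video" and "video_url" placeholder keys, so a template
  // that references either one resolves.
  std::cout << make_mm_item("video_url").dump(2) << "\n";
}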

xllm/core/framework/request/mm_codec.cpp (1 addition, 2 deletions)

@@ -159,8 +159,7 @@ bool OpenCVVideoDecoder::decode(const std::string& raw_data,
   av_dict_set(&opts, "probesize", "20000000", 0);
   av_dict_set(&opts, "analyzeduration", "5000000", 0);

-  const AVInputFormat* in_fmt = av_find_input_format("mp4");
-  int ret = avformat_open_input(&fmt, nullptr, in_fmt, &opts);
+  int ret = avformat_open_input(&fmt, nullptr, nullptr, &opts);
   av_dict_free(&opts);

   if (ret < 0) {
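
The old code forced the "mp4" demuxer via av_find_input_format("mp4"), which breaks any input that is not MP4; passing nullptr lets FFmpeg probe the container from the data itself, with probesize and analyzeduration bounding how much it reads. A minimal sketch of the auto-detect path, assuming a plain file path for brevity (the real decoder presumably feeds raw_data through a custom AVIOContext, which this hunk does not show):

// Sketch only: open a media input and let FFmpeg detect the demuxer.
extern "C" {
#include <libavformat/avformat.h>
}
#include <cstdio>

int probe_container(const char* path) {
  AVFormatContext* fmt = avformat_alloc_context();
  AVDictionary* opts = nullptr;
  av_dict_set(&opts, "probesize", "20000000", 0);       // max bytes to probe
  av_dict_set(&opts, "analyzeduration", "5000000", 0);  // max microseconds to analyze
  // Third argument nullptr: probe the format instead of forcing one.
  int ret = avformat_open_input(&fmt, path, nullptr, &opts);
  av_dict_free(&opts);
  if (ret < 0) {
    return ret;  // fmt is freed by avformat_open_input on failure
  }
  std::printf("detected demuxer: %s\n", fmt->iformat->name);
  avformat_close_input(&fmt);
  return 0;
}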

xllm/core/framework/request/mm_input.h (2 additions, 4 deletions)

@@ -58,13 +58,11 @@ struct MMInput {
     return std::move(vec);
   }

-  std::vector<VideoMetadata> get_video_metadata(MMType type) const {
+  std::vector<VideoMetadata> get_video_metadata() const {
     std::vector<VideoMetadata> metas;
     metas.reserve(items_.size());
     for (auto& item : items_) {
-      if (item.type_ == type) {
-        metas.push_back(item.video_meta_);
-      }
+      metas.push_back(item.video_meta_);
     }
     return metas;
  }
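
One side effect worth flagging (my reading, not stated in the commit message): without the MMType filter, get_video_metadata() now returns one entry per item in items_, including whatever video_meta_ non-video items carry; the call site in Glm4VImageProcessor::process pairs this list with get_decode_data(MMType::VIDEO), which presumably stays consistent only when the request carries a single modality.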

xllm/processors/glm4v_image_processor.cpp (118 additions, 6 deletions)

@@ -77,6 +77,117 @@ std::optional<Size> smart_resize(int num_frames,
 }
 }  // namespace

+torch::Tensor Glm4VImageProcessor::sample_frames(const VideoMetadata& metadata,
+                                                 int temporal_patch_size) {
+  // video: [T, C, H, W]
+  const int total_frames = metadata.total_num_frames;
+  if (total_frames <= 0) {
+    return torch::empty({0}, torch::dtype(torch::kLong));
+  }
+
+  if (metadata.fps <= 0.0) {
+    LOG(FATAL) << "invalid metadata.fps <= 0";
+  }
+
+  const int max_frame_idx = total_frames - 1;
+
+  // duration = metadata.duration or round(max_idx / fps) + 1
+  double duration = metadata.duration;
+  if (duration <= 0.0) {
+    duration =
+        std::round(static_cast<double>(max_frame_idx) / metadata.fps) + 1.0;
+  }
+
+  constexpr double DYN_FPS_30 = 3.0;
+  constexpr double DYN_FPS_300 = 1.0;
+  constexpr double DYN_FPS_2400 = 0.5;
+  constexpr int MAX_FRAME_COUNT_DYNAMIC = 640;
+  constexpr double MAX_DURATION = 2400.0;
+
+  const double effective_duration = std::min(duration, MAX_DURATION);
+
+  double target_fps = 0.0;
+  if (effective_duration <= 30.0) {
+    target_fps = DYN_FPS_30;
+  } else if (effective_duration <= 300.0) {
+    target_fps = DYN_FPS_300;
+  } else {
+    target_fps = DYN_FPS_2400;
+  }
+
+  // extract_t = int(effective_duration * target_fps * temporal_patch_size)
+  int extract_t = static_cast<int>(effective_duration * target_fps *
+                                   static_cast<double>(temporal_patch_size));
+  extract_t = std::min(extract_t, MAX_FRAME_COUNT_DYNAMIC);

+  const double duration_per_frame = 1.0 / metadata.fps;
+  std::vector<double> timestamps(total_frames);
+  for (int i = 0; i < total_frames; ++i) {
+    timestamps[i] = static_cast<double>(i) * duration_per_frame;
+  }
+  const int max_second = static_cast<int>(duration);
+
+  torch::Tensor frame_indices;
+
+  if (total_frames < extract_t) {
+    frame_indices = torch::linspace(
+        0, total_frames - 1, extract_t, torch::dtype(torch::kLong));
+  } else {
+    std::vector<int64_t> tmp;
+    tmp.reserve(static_cast<size_t>(total_frames));
+    double current_second = 0.0;
+    const double inv_fps =
+        1.0 / (static_cast<double>(temporal_patch_size) * target_fps);
+
+    for (int frame_index = 0; frame_index < total_frames; frame_index++) {
+      if (timestamps[frame_index] >= current_second) {
+        current_second += inv_fps;
+        tmp.push_back(frame_index);
+        if (current_second >= static_cast<double>(max_second)) {
+          break;
+        }
+      }
+    }
+    frame_indices =
+        torch::tensor(tmp, torch::TensorOptions().dtype(torch::kLong));
+  }
+  int64_t len = frame_indices.size(0);
+  if (len < extract_t) {
+    int64_t start, end;
+    if (len == 0) {
+      start = 0;
+      end = std::max<int64_t>(total_frames - 1, 0);
+    } else {
+      start = frame_indices[0].item<int64_t>();
+      end = frame_indices[len - 1].item<int64_t>();
+    }
+    frame_indices =
+        torch::linspace(start, end, extract_t, torch::dtype(torch::kLong));
+  } else if (len > extract_t) {
+    frame_indices = torch::linspace(
+        0, total_frames - 1, extract_t, torch::dtype(torch::kLong));
+  }
+
+  len = frame_indices.size(0);
+  std::unordered_set<int64_t> seen;
+  seen.reserve(static_cast<size_t>(len) * 2);
+  std::vector<int64_t> uniq;
+  uniq.reserve(static_cast<size_t>(len));
+
+  for (int64_t i = 0; i < len; ++i) {
+    auto idx = frame_indices[i].item<int64_t>();
+    if (seen.insert(idx).second) {
+      uniq.push_back(idx);
+    }
+  }
+
+  if (!uniq.empty() && (uniq.size() & 1)) {
+    uniq.push_back(uniq.back());
+  }
+
+  return torch::tensor(uniq, torch::TensorOptions().dtype(torch::kLong));
+}
+
 Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) {
   image_mean_ = args.mm_image_normalize_mean();
   image_std_ = args.mm_image_normalize_std();
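
A worked example of the new sampling arithmetic (mine, not from the commit): for a 120 s clip at 24 fps with temporal_patch_size = 2, effective_duration = 120 falls in the (30, 300] bucket, so target_fps = DYN_FPS_300 = 1.0 and extract_t = int(120 * 1.0 * 2) = 240, well under the 640-frame cap. total_frames = 2880 >= 240, so the timestamp walk keeps one frame roughly every 1 / (2 * 1.0) = 0.5 s until max_second = 120; the indices are then deduplicated and padded to an even count so sampled frames pair cleanly into temporal patches of size 2.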
@@ -112,8 +223,7 @@ Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) {
 bool Glm4VImageProcessor::process(const MMInput& inputs, MMData& datas) {
   std::vector<torch::Tensor> images = inputs.get_decode_data(MMType::IMAGE);
   std::vector<torch::Tensor> videos = inputs.get_decode_data(MMType::VIDEO);
-  std::vector<VideoMetadata> video_meta_list =
-      inputs.get_video_metadata(MMType::VIDEO);
+  std::vector<VideoMetadata> video_meta_list = inputs.get_video_metadata();

   if (images.empty() && (videos.empty() || video_meta_list.empty())) {
     LOG(ERROR) << "no image/video tensor found.";

@@ -249,8 +359,8 @@ bool Glm4VImageProcessor::process_videos(

   auto values = torch::cat(pixel_values);
   auto thw = torch::tensor(grids).clone().reshape({-1, 3});
-  mm_datas.update(MMType::VIDEO, "video_grid_thw", thw);
-  mm_datas.update(MMType::VIDEO, "pixel_values_videos", values);
+  mm_datas.add(MMType::VIDEO, "video_grid_thw", thw);
+  mm_datas.add(MMType::VIDEO, "pixel_values_videos", values);

   return true;
 }

@@ -266,9 +376,11 @@ bool Glm4VImageProcessor::process_video(

   torch::Tensor indices;
   if (do_sample_frame_) {
-    indices = this->GLM_sample_frames(metadata, temporal_patch_size_);
+    indices = this->sample_frames(metadata, temporal_patch_size_);
   } else {
-    indices = this->init_frames(metadata);  // default sample to 32 frames
+    indices = torch::arange(0,
+                            static_cast<int64_t>(origin_video.size(0)),
+                            torch::TensorOptions().dtype(torch::kLong));
   }
   auto video = origin_video.index_select(/*dim=*/0, indices);
   int64_t sampled_total_frames = video.size(0);
xllm/processors/glm4v_image_processor.h (mode changed 100755 to 100644; 2 additions, 0 deletions)

@@ -42,6 +42,8 @@ class Glm4VImageProcessor : public ImageProcessor {
       VideoMetadata& metadata,
       std::vector<torch::Tensor>& pixel_values,
       std::vector<int64_t>& grids);
+  torch::Tensor sample_frames(const VideoMetadata& metadata,
+                              int temporal_patch_size);

  private:
   bool do_convert_rgb_ = true;
