@@ -77,6 +77,117 @@ std::optional<Size> smart_resize(int num_frames,
7777}
7878} // namespace
7979
80+ torch::Tensor Glm4VImageProcessor::sample_frames (const VideoMetadata& metadata,
81+ int temporal_patch_size) {
82+ // video: [T, C, H, W]
83+ const int total_frames = metadata.total_num_frames ;
84+ if (total_frames <= 0 ) {
85+ return torch::empty ({0 }, torch::dtype (torch::kLong ));
86+ }
87+
88+ if (metadata.fps <= 0.0 ) {
89+ LOG (FATAL) << " invalid metadata.fps <= 0" ;
90+ }
91+
92+ const int max_frame_idx = total_frames - 1 ;
93+
94+ // duration = metadata.duration or round(max_idx / fps) + 1
95+ double duration = metadata.duration ;
96+ if (duration <= 0.0 ) {
97+ duration =
98+ std::round (static_cast <double >(max_frame_idx) / metadata.fps ) + 1.0 ;
99+ }
100+
101+ constexpr double DYN_FPS_30 = 3.0 ;
102+ constexpr double DYN_FPS_300 = 1.0 ;
103+ constexpr double DYN_FPS_2400 = 0.5 ;
104+ constexpr int MAX_FRAME_COUNT_DYNAMIC = 640 ;
105+ constexpr double MAX_DURATION = 2400.0 ;
106+
107+ const double effective_duration = std::min (duration, MAX_DURATION);
108+
109+ double target_fps = 0.0 ;
110+ if (effective_duration <= 30.0 ) {
111+ target_fps = DYN_FPS_30;
112+ } else if (effective_duration <= 300.0 ) {
113+ target_fps = DYN_FPS_300;
114+ } else {
115+ target_fps = DYN_FPS_2400;
116+ }
117+
118+ // extract_t = int(effective_duration * target_fps * temporal_patch_size)
119+ int extract_t = static_cast <int >(effective_duration * target_fps *
120+ static_cast <double >(temporal_patch_size));
121+ extract_t = std::min (extract_t , MAX_FRAME_COUNT_DYNAMIC);
122+
123+ const double duration_per_frame = 1.0 / metadata.fps ;
124+ std::vector<double > timestamps (total_frames);
125+ for (int i = 0 ; i < total_frames; ++i) {
126+ timestamps[i] = static_cast <double >(i) * duration_per_frame;
127+ }
128+ const int max_second = static_cast <int >(duration);
129+
130+ torch::Tensor frame_indices;
131+
132+ if (total_frames < extract_t ) {
133+ frame_indices = torch::linspace (
134+ 0 , total_frames - 1 , extract_t , torch::dtype (torch::kLong ));
135+ } else {
136+ std::vector<int64_t > tmp;
137+ tmp.reserve (static_cast <size_t >(total_frames));
138+ double current_second = 0.0 ;
139+ const double inv_fps =
140+ 1.0 / (static_cast <double >(temporal_patch_size) * target_fps);
141+
142+ for (int frame_index = 0 ; frame_index < total_frames; frame_index++) {
143+ if (timestamps[frame_index] >= current_second) {
144+ current_second += inv_fps;
145+ tmp.push_back (frame_index);
146+ if (current_second >= static_cast <double >(max_second)) {
147+ break ;
148+ }
149+ }
150+ }
151+ frame_indices =
152+ torch::tensor (tmp, torch::TensorOptions ().dtype (torch::kLong ));
153+ }
154+ int64_t len = frame_indices.size (0 );
155+ if (len < extract_t ) {
156+ int64_t start, end;
157+ if (len == 0 ) {
158+ start = 0 ;
159+ end = std::max<int64_t >(total_frames - 1 , 0 );
160+ } else {
161+ start = frame_indices[0 ].item <int64_t >();
162+ end = frame_indices[len - 1 ].item <int64_t >();
163+ }
164+ frame_indices =
165+ torch::linspace (start, end, extract_t , torch::dtype (torch::kLong ));
166+ } else if (len > extract_t ) {
167+ frame_indices = torch::linspace (
168+ 0 , total_frames - 1 , extract_t , torch::dtype (torch::kLong ));
169+ }
170+
171+ len = frame_indices.size (0 );
172+ std::unordered_set<int64_t > seen;
173+ seen.reserve (static_cast <size_t >(len) * 2 );
174+ std::vector<int64_t > uniq;
175+ uniq.reserve (static_cast <size_t >(len));
176+
177+ for (int64_t i = 0 ; i < len; ++i) {
178+ auto idx = frame_indices[i].item <int64_t >();
179+ if (seen.insert (idx).second ) {
180+ uniq.push_back (idx);
181+ }
182+ }
183+
184+ if (!uniq.empty () && (uniq.size () & 1 )) {
185+ uniq.push_back (uniq.back ());
186+ }
187+
188+ return torch::tensor (uniq, torch::TensorOptions ().dtype (torch::kLong ));
189+ }
190+
80191Glm4VImageProcessor::Glm4VImageProcessor (const ModelArgs& args) {
81192 image_mean_ = args.mm_image_normalize_mean ();
82193 image_std_ = args.mm_image_normalize_std ();
@@ -112,8 +223,7 @@ Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) {
112223bool Glm4VImageProcessor::process (const MMInput& inputs, MMData& datas) {
113224 std::vector<torch::Tensor> images = inputs.get_decode_data (MMType::IMAGE);
114225 std::vector<torch::Tensor> videos = inputs.get_decode_data (MMType::VIDEO);
115- std::vector<VideoMetadata> video_meta_list =
116- inputs.get_video_metadata (MMType::VIDEO);
226+ std::vector<VideoMetadata> video_meta_list = inputs.get_video_metadata ();
117227
118228 if (images.empty () && (videos.empty () || video_meta_list.empty ())) {
119229 LOG (ERROR) << " no image/video tensor found." ;
@@ -249,8 +359,8 @@ bool Glm4VImageProcessor::process_videos(
249359
250360 auto values = torch::cat (pixel_values);
251361 auto thw = torch::tensor (grids).clone ().reshape ({-1 , 3 });
252- mm_datas.update (MMType::VIDEO, " video_grid_thw" , thw);
253- mm_datas.update (MMType::VIDEO, " pixel_values_videos" , values);
362+ mm_datas.add (MMType::VIDEO, " video_grid_thw" , thw);
363+ mm_datas.add (MMType::VIDEO, " pixel_values_videos" , values);
254364
255365 return true ;
256366}
@@ -266,9 +376,11 @@ bool Glm4VImageProcessor::process_video(
266376
267377 torch::Tensor indices;
268378 if (do_sample_frame_) {
269- indices = this ->GLM_sample_frames (metadata, temporal_patch_size_);
379+ indices = this ->sample_frames (metadata, temporal_patch_size_);
270380 } else {
271- indices = this ->init_frames (metadata); // default sample to 32 frames
381+ indices = torch::arange (0 ,
382+ static_cast <int64_t >(origin_video.size (0 )),
383+ torch::TensorOptions ().dtype (torch::kLong ));
272384 }
273385 auto video = origin_video.index_select (/* dim=*/ 0 , indices);
274386 int64_t sampled_total_frames = video.size (0 );
0 commit comments