@@ -77,117 +77,6 @@ std::optional<Size> smart_resize(int num_frames,
7777}
7878} // namespace
7979
80- torch::Tensor Glm4VImageProcessor::sample_frames (const VideoMetadata& metadata,
81- int temporal_patch_size) {
82- // video: [T, C, H, W]
83- const int total_frames = metadata.total_num_frames ;
84- if (total_frames <= 0 ) {
85- return torch::empty ({0 }, torch::dtype (torch::kLong ));
86- }
87-
88- if (metadata.fps <= 0.0 ) {
89- LOG (FATAL) << " invalid metadata.fps <= 0" ;
90- }
91-
92- const int max_frame_idx = total_frames - 1 ;
93-
94- // duration = metadata.duration or round(max_idx / fps) + 1
95- double duration = metadata.duration ;
96- if (duration <= 0.0 ) {
97- duration =
98- std::round (static_cast <double >(max_frame_idx) / metadata.fps ) + 1.0 ;
99- }
100-
101- constexpr double DYN_FPS_30 = 3.0 ;
102- constexpr double DYN_FPS_300 = 1.0 ;
103- constexpr double DYN_FPS_2400 = 0.5 ;
104- constexpr int MAX_FRAME_COUNT_DYNAMIC = 640 ;
105- constexpr double MAX_DURATION = 2400.0 ;
106-
107- const double effective_duration = std::min (duration, MAX_DURATION);
108-
109- double target_fps = 0.0 ;
110- if (effective_duration <= 30.0 ) {
111- target_fps = DYN_FPS_30;
112- } else if (effective_duration <= 300.0 ) {
113- target_fps = DYN_FPS_300;
114- } else {
115- target_fps = DYN_FPS_2400;
116- }
117-
118- // extract_t = int(effective_duration * target_fps * temporal_patch_size)
119- int extract_t = static_cast <int >(effective_duration * target_fps *
120- static_cast <double >(temporal_patch_size));
121- extract_t = std::min (extract_t , MAX_FRAME_COUNT_DYNAMIC);
122-
123- const double duration_per_frame = 1.0 / metadata.fps ;
124- std::vector<double > timestamps (total_frames);
125- for (int i = 0 ; i < total_frames; ++i) {
126- timestamps[i] = static_cast <double >(i) * duration_per_frame;
127- }
128- const int max_second = static_cast <int >(duration);
129-
130- torch::Tensor frame_indices;
131-
132- if (total_frames < extract_t ) {
133- frame_indices = torch::linspace (
134- 0 , total_frames - 1 , extract_t , torch::dtype (torch::kLong ));
135- } else {
136- std::vector<int64_t > tmp;
137- tmp.reserve (static_cast <size_t >(total_frames));
138- double current_second = 0.0 ;
139- const double inv_fps =
140- 1.0 / (static_cast <double >(temporal_patch_size) * target_fps);
141-
142- for (int frame_index = 0 ; frame_index < total_frames; frame_index++) {
143- if (timestamps[frame_index] >= current_second) {
144- current_second += inv_fps;
145- tmp.push_back (frame_index);
146- if (current_second >= static_cast <double >(max_second)) {
147- break ;
148- }
149- }
150- }
151- frame_indices =
152- torch::tensor (tmp, torch::TensorOptions ().dtype (torch::kLong ));
153- }
154- int64_t len = frame_indices.size (0 );
155- if (len < extract_t ) {
156- int64_t start, end;
157- if (len == 0 ) {
158- start = 0 ;
159- end = std::max<int64_t >(total_frames - 1 , 0 );
160- } else {
161- start = frame_indices[0 ].item <int64_t >();
162- end = frame_indices[len - 1 ].item <int64_t >();
163- }
164- frame_indices =
165- torch::linspace (start, end, extract_t , torch::dtype (torch::kLong ));
166- } else if (len > extract_t ) {
167- frame_indices = torch::linspace (
168- 0 , total_frames - 1 , extract_t , torch::dtype (torch::kLong ));
169- }
170-
171- len = frame_indices.size (0 );
172- std::unordered_set<int64_t > seen;
173- seen.reserve (static_cast <size_t >(len) * 2 );
174- std::vector<int64_t > uniq;
175- uniq.reserve (static_cast <size_t >(len));
176-
177- for (int64_t i = 0 ; i < len; ++i) {
178- auto idx = frame_indices[i].item <int64_t >();
179- if (seen.insert (idx).second ) {
180- uniq.push_back (idx);
181- }
182- }
183-
184- if (!uniq.empty () && (uniq.size () & 1 )) {
185- uniq.push_back (uniq.back ());
186- }
187-
188- return torch::tensor (uniq, torch::TensorOptions ().dtype (torch::kLong ));
189- }
190-
19180Glm4VImageProcessor::Glm4VImageProcessor (const ModelArgs& args) {
19281 image_mean_ = args.mm_image_normalize_mean ();
19382 image_std_ = args.mm_image_normalize_std ();
@@ -223,7 +112,8 @@ Glm4VImageProcessor::Glm4VImageProcessor(const ModelArgs& args) {
223112bool Glm4VImageProcessor::process (const MMInput& inputs, MMData& datas) {
224113 std::vector<torch::Tensor> images = inputs.get_decode_data (MMType::IMAGE);
225114 std::vector<torch::Tensor> videos = inputs.get_decode_data (MMType::VIDEO);
226- std::vector<VideoMetadata> video_meta_list = inputs.get_video_metadata ();
115+ std::vector<VideoMetadata> video_meta_list =
116+ inputs.get_video_metadata (MMType::VIDEO);
227117
228118 if (images.empty () && (videos.empty () || video_meta_list.empty ())) {
229119 LOG (ERROR) << " no image/video tensor found." ;
@@ -359,8 +249,8 @@ bool Glm4VImageProcessor::process_videos(
359249
360250 auto values = torch::cat (pixel_values);
361251 auto thw = torch::tensor (grids).clone ().reshape ({-1 , 3 });
362- mm_datas.add (MMType::VIDEO, " video_grid_thw" , thw);
363- mm_datas.add (MMType::VIDEO, " pixel_values_videos" , values);
252+ mm_datas.update (MMType::VIDEO, " video_grid_thw" , thw);
253+ mm_datas.update (MMType::VIDEO, " pixel_values_videos" , values);
364254
365255 return true ;
366256}
@@ -376,11 +266,9 @@ bool Glm4VImageProcessor::process_video(
376266
377267 torch::Tensor indices;
378268 if (do_sample_frame_) {
379- indices = this ->sample_frames (metadata, temporal_patch_size_);
269+ indices = this ->GLM_sample_frames (metadata, temporal_patch_size_);
380270 } else {
381- indices = torch::arange (0 ,
382- static_cast <int64_t >(origin_video.size (0 )),
383- torch::TensorOptions ().dtype (torch::kLong ));
271+ indices = this ->init_frames (metadata); // default sample to 32 frames
384272 }
385273 auto video = origin_video.index_select (/* dim=*/ 0 , indices);
386274 int64_t sampled_total_frames = video.size (0 );
0 commit comments