@@ -219,32 +219,17 @@ def benchmark(
219219 times .append (inf_time )
220220
221221 if save_video :
222- # Visualize keypoints
223- this_pose = pose ["poses" ][0 ][0 ]
224- for j in range (this_pose .shape [0 ]):
225- if this_pose [j , 2 ] > pcutoff :
226- x , y = map (int , this_pose [j , :2 ])
227- cv2 .circle (
228- frame ,
229- center = (x , y ),
230- radius = display_radius ,
231- color = colors [j ],
232- thickness = - 1 ,
233- )
222+ draw_pose_and_write (
223+ frame = frame ,
224+ pose = pose ,
225+ colors = colors ,
226+ bodyparts = bodyparts ,
227+ pcutoff = pcutoff ,
228+ display_radius = display_radius ,
229+ draw_keypoint_names = draw_keypoint_names ,
230+ vwriter = vwriter
231+ )
234232
235- if draw_keypoint_names :
236- cv2 .putText (
237- frame ,
238- text = bodyparts [j ],
239- org = (x + 10 , y ),
240- fontFace = cv2 .FONT_HERSHEY_SIMPLEX ,
241- fontScale = 0.5 ,
242- color = colors [j ],
243- thickness = 1 ,
244- lineType = cv2 .LINE_AA ,
245- )
246-
247- vwriter .write (image = frame )
248233 frame_index += 1
249234
250235 cap .release ()
@@ -291,6 +276,47 @@ def setup_video_writer(
291276
292277 return colors , vwriter
293278
279+ def draw_pose_and_write (
280+ frame : np .ndarray ,
281+ pose : np .ndarray ,
282+ colors : list [tuple [int , int , int ]],
283+ bodyparts : list [str ],
284+ pcutoff : float ,
285+ display_radius : int ,
286+ draw_keypoint_names : bool ,
287+ vwriter : cv2 .VideoWriter ,
288+ ):
289+ if len (pose .shape ) == 2 :
290+ pose = pose [None ]
291+
292+ # Visualize keypoints
293+ for i in range (pose .shape [0 ]):
294+ for j in range (pose .shape [1 ]):
295+ if pose [i , j , 2 ] > pcutoff :
296+ x , y = map (int , pose [i , j , :2 ])
297+ cv2 .circle (
298+ frame ,
299+ center = (x , y ),
300+ radius = display_radius ,
301+ color = colors [j ],
302+ thickness = - 1 ,
303+ )
304+
305+ if draw_keypoint_names :
306+ cv2 .putText (
307+ frame ,
308+ text = bodyparts [j ],
309+ org = (x + 10 , y ),
310+ fontFace = cv2 .FONT_HERSHEY_SIMPLEX ,
311+ fontScale = 0.5 ,
312+ color = colors [j ],
313+ thickness = 1 ,
314+ lineType = cv2 .LINE_AA ,
315+ )
316+
317+
318+ vwriter .write (image = frame )
319+
294320def save_poses_to_files (video_path , save_dir , bodyparts , poses , timestamp ):
295321 """
296322 Saves the detected keypoint poses from the video to CSV and HDF5 files.
0 commit comments