@@ -22,6 +22,7 @@
 import dlclive.pose_estimation_pytorch.models as models
 import dlclive.pose_estimation_pytorch.dynamic_cropping as dynamic_cropping
 from dlclive.core.runner import BaseRunner
+from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor
 
 
 @dataclass
@@ -142,7 +143,8 @@ def __init__(
         self.cfg = None
         self.detector = None
         self.model = None
-        self.transform = None
+        self.detector_transform = None
+        self.pose_transform = None
 
         # Parse Dynamic Cropping parameters
         if isinstance(dynamic, dict):
@@ -172,13 +174,7 @@ def close(self) -> None:
     @torch.inference_mode()
     def get_pose(self, frame: np.ndarray) -> np.ndarray:
         c, h, w = frame.shape
-        frame = (
-            self.transform(torch.from_numpy(frame).permute(2, 0, 1))
-            .unsqueeze(0)
-            .to(self.device)
-        )
-        if self.precision == "FP16":
-            frame = frame.half()
+        tensor = torch.from_numpy(frame).permute(2, 0, 1)  # CHW, still on CPU
 
         offsets_and_scales = None
         if self.detector is not None:
@@ -187,18 +183,32 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray:
             detections = self.top_down_config.skip_frames.get_detections()
 
             if detections is None:
-                detections = self.detector(frame)[0]
+                # Apply detector transform before inference
+                detector_input = self.detector_transform(tensor).unsqueeze(0).to(self.device)
+                if self.precision == "FP16":
+                    detector_input = detector_input.half()
+                detections = self.detector(detector_input)[0]
 
-            frame_batch, offsets_and_scales = self._prepare_top_down(frame, detections)
+            frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections)
             if len(frame_batch) == 0:
                 offsets_and_scales = [(0, 0), 1]
             else:
-                frame = frame_batch.to(self.device)
+                tensor = frame_batch  # batched crops (N, C, H, W), still on CPU
 
         if self.dynamic is not None:
-            frame = self.dynamic.crop(frame)
+            tensor = self.dynamic.crop(tensor)
+
+        # Apply pose transform
+        model_input = self.pose_transform(tensor)
+        # Ensure 4D input: (N, C, H, W)
+        if model_input.dim() == 3:
+            model_input = model_input.unsqueeze(0)
+        # Send to device
+        model_input = model_input.to(self.device)
+        if self.precision == "FP16":
+            model_input = model_input.half()
 
-        outputs = self.model(frame)
+        outputs = self.model(model_input)
         batch_pose = self.model.get_predictions(outputs)["bodypart"]["poses"]
 
         if self.dynamic is not None:
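For reference, the split preprocessing above can be sketched in isolation: the frame stays a uint8 CHW tensor on the CPU until each consumer (detector or pose model) applies its own transform. A minimal sketch, assuming torchvision >= 0.16 for transforms.v2; the names below are illustrative, not the runner's actual attributes:

    import numpy as np
    import torch
    from torchvision.transforms import v2

    detector_transform = v2.Compose([v2.ToDtype(torch.float32, scale=True)])
    pose_transform = v2.Compose([
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    frame = np.zeros((480, 640, 3), dtype=np.uint8)           # HWC uint8 camera frame
    tensor = torch.from_numpy(frame).permute(2, 0, 1)         # CHW uint8, still on CPU
    detector_input = detector_transform(tensor).unsqueeze(0)  # (1, C, H, W) float in [0, 1]
    pose_input = pose_transform(tensor).unsqueeze(0)          # normalized only for the pose model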
@@ -264,15 +274,18 @@ def load_model(self) -> None:
             self.detector.to(self.device)
             self.detector.load_state_dict(raw_data["detector"])
             self.detector.eval()
-
             if self.precision == "FP16":
                 self.detector = self.detector.half()
 
             if self.top_down_config is None:
                 self.top_down_config = TopDownConfig()
-
             self.top_down_config.read_config(self.cfg)
 
+            detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
+            if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
+                detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+            self.detector_transform = v2.Compose(detector_transforms)
+
         if isinstance(self.dynamic, dynamic_cropping.TopDownDynamicCropper):
             crop = self.cfg["data"]["inference"].get("top_down_crop", {})
             w, h = crop.get("width", 256), crop.get("height", 256)
@@ -287,12 +300,18 @@ def load_model(self) -> None:
287300 "Top-down models must either use a detector or a TopDownDynamicCropper."
288301 )
289302
290- self .transform = v2 .Compose (
291- [
292- v2 .ToDtype (torch .float32 , scale = True ),
293- v2 .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
294- ]
295- )
303+ pose_transforms = [v2 .ToDtype (torch .float32 , scale = True )]
304+ auto_padding_cfg = self .cfg ["data" ]["inference" ].get ("auto_padding" , None )
305+ if auto_padding_cfg :
306+ pose_transforms .append (
307+ AutoPadToDivisor (
308+ pad_height_divisor = auto_padding_cfg .get ("pad_height_divisor" , 1 ),
309+ pad_width_divisor = auto_padding_cfg .get ("pad_width_divisor" , 1 ),
310+ )
311+ )
312+ if self .cfg ["data" ]["inference" ].get ("normalize_images" , False ):
313+ pose_transforms .append (v2 .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]))
314+ self .pose_transform = v2 .Compose (pose_transforms )
296315
297316 def read_config (self ) -> dict :
298317 """Reads the configuration file"""
@@ -306,8 +325,17 @@ def _prepare_top_down(
         self, frame: torch.Tensor, detections: dict[str, torch.Tensor]
     ):
         """Prepares a frame for top-down pose estimation."""
+        # Accept an unbatched frame (C, H, W) or a batched frame (1, C, H, W)
+        if frame.dim() == 4:
+            if frame.size(0) != 1:
+                raise ValueError(f"Expected batch size 1, got {frame.size(0)}")
+            frame = frame[0]  # (C, H, W)
+        elif frame.dim() != 3:
+            raise ValueError(f"Expected frame of shape (C, H, W) or (1, C, H, W), got {frame.shape}")
+
         bboxes, scores = detections["boxes"], detections["scores"]
         bboxes = bboxes[scores >= self.top_down_config.bbox_cutoff]
+
         if len(bboxes) > 0 and self.top_down_config.max_detections is not None:
             bboxes = bboxes[: self.top_down_config.max_detections]
 
@@ -316,7 +344,7 @@ def _prepare_top_down(
         for bbox in bboxes:
             x1, y1, x2, y2 = bbox.tolist()
             cropped_frame, offset, scale = data.top_down_crop_torch(
-                frame[0],
+                frame,
                 (x1, y1, x2 - x1, y2 - y1),
                 output_size=self.top_down_config.crop_size,
                 margin=0,
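For reference, the guard added above gives _prepare_top_down a simple shape contract: a single frame, batched or not. A minimal sketch of that contract, using a hypothetical normalize_frame helper that is not part of the diff:

    import torch

    def normalize_frame(frame: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper mirroring the guard in _prepare_top_down above.
        if frame.dim() == 4:
            if frame.size(0) != 1:
                raise ValueError(f"Expected batch size 1, got {frame.size(0)}")
            return frame[0]  # (1, C, H, W) -> (C, H, W)
        if frame.dim() != 3:
            raise ValueError(f"Expected (C, H, W) or (1, C, H, W), got {tuple(frame.shape)}")
        return frame

    assert normalize_frame(torch.zeros(3, 64, 64)).shape == (3, 64, 64)
    assert normalize_frame(torch.zeros(1, 3, 64, 64)).shape == (3, 64, 64)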