|
7 | 7 | import colorcet as cc |
8 | 8 | import cv2 |
9 | 9 | import h5py |
| 10 | +from pathlib import Path |
10 | 11 | from PIL import ImageColor |
11 | 12 | from pip._internal.operations import freeze |
12 | 13 | import torch |
@@ -95,7 +96,7 @@ def benchmark( |
95 | 96 | cropping=None, # Adding cropping to the function parameters |
96 | 97 | dynamic=(False, 0.5, 10), |
97 | 98 | save_poses=False, |
98 | | - save_dir="model_predictions", |
| 99 | + save_dir=None, |
99 | 100 | draw_keypoint_names=False, |
100 | 101 | cmap="bmy", |
101 | 102 | get_sys_info=True, |
@@ -130,8 +131,9 @@ def benchmark( |
130 | 131 | Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. <detection treshold). The current position is utilized for updating the crop window for the next frame (this is why the margin is important and should be set large enough given the movement of the animal). |
131 | 132 | save_poses : bool, optional, default=False |
132 | 133 | Whether to save the detected poses to CSV and HDF5 files. |
133 | | - save_dir : str, optional, default='model_predictions' |
| 134 | + save_dir : str, optional |
134 | 135 | Directory to save output data and labeled video. |
| 136 | + If not specified, will use the directory of video_path, by default None |
135 | 137 | draw_keypoint_names : bool, optional, default=False |
136 | 138 | Whether to display keypoint names on video frames in the saved video. |
137 | 139 | cmap : str, optional, default='bmy' |
@@ -164,8 +166,10 @@ def benchmark( |
164 | 166 | display_cmap=cmap, |
165 | 167 | ) |
166 | 168 |
|
| 169 | + if save_dir is None: |
| 170 | + save_dir = Path(video_path).resolve().parent |
167 | 171 | # Ensure save directory exists |
168 | | - os.makedirs(name=save_dir, exist_ok=True) |
| 172 | + save_dir.mkdir(parents=True, exist_ok=True) |
169 | 173 |
|
170 | 174 | # Get the current date and time as a string |
171 | 175 | timestamp = time.strftime("%Y%m%d_%H%M%S") |
|
0 commit comments