Skip to content

Commit 5b10eb1

Browse files
committed
Benchmark pytorch: change save_dir default value
1 parent 45c85a5 commit 5b10eb1

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import colorcet as cc
88
import cv2
99
import h5py
10+
from pathlib import Path
1011
from PIL import ImageColor
1112
from pip._internal.operations import freeze
1213
import torch
@@ -95,7 +96,7 @@ def benchmark(
9596
cropping=None, # Adding cropping to the function parameters
9697
dynamic=(False, 0.5, 10),
9798
save_poses=False,
98-
save_dir="model_predictions",
99+
save_dir=None,
99100
draw_keypoint_names=False,
100101
cmap="bmy",
101102
get_sys_info=True,
@@ -130,8 +131,9 @@ def benchmark(
130131
Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. <detection treshold). The current position is utilized for updating the crop window for the next frame (this is why the margin is important and should be set large enough given the movement of the animal).
131132
save_poses : bool, optional, default=False
132133
Whether to save the detected poses to CSV and HDF5 files.
133-
save_dir : str, optional, default='model_predictions'
134+
save_dir : str, optional
134135
Directory to save output data and labeled video.
136+
If not specified, will use the directory of video_path, by default None
135137
draw_keypoint_names : bool, optional, default=False
136138
Whether to display keypoint names on video frames in the saved video.
137139
cmap : str, optional, default='bmy'
@@ -164,8 +166,10 @@ def benchmark(
164166
display_cmap=cmap,
165167
)
166168

169+
if save_dir is None:
170+
save_dir = Path(video_path).resolve().parent
167171
# Ensure save directory exists
168-
os.makedirs(name=save_dir, exist_ok=True)
172+
save_dir.mkdir(parents=True, exist_ok=True)
169173

170174
# Get the current date and time as a string
171175
timestamp = time.strftime("%Y%m%d_%H%M%S")

0 commit comments

Comments
 (0)