diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp b/src/torchcodec/_core/AVIOTensorContext.cpp index 263ce2228..238475761 100644 --- a/src/torchcodec/_core/AVIOTensorContext.cpp +++ b/src/torchcodec/_core/AVIOTensorContext.cpp @@ -123,10 +123,8 @@ AVIOToTensorContext::AVIOToTensorContext() } torch::Tensor AVIOToTensorContext::getOutputTensor() { - throw std::runtime_error( - "AVIOToTensorContext::getOutputTensor is not implemented yet."); - // return tensorContext_.data.narrow( - // /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos); + return tensorContext_.data.narrow( + /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index bd87f12d3..8d9e9f651 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1030,15 +1030,15 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( frames.push_back(*lastSamples); } - // TORCH_CHECK( - // frames.size() > 0 && firstFramePtsSeconds.has_value(), - // "No audio frames were decoded. ", - // "This is probably because start_seconds is too high(", - // startSeconds, - // "),", - // "or because stop_seconds(", - // stopSecondsOptional, - // ") is too low."); + TORCH_CHECK( + frames.size() > 0 && firstFramePtsSeconds.has_value(), + "No audio frames were decoded. ", + "This is probably because start_seconds is too high(", + startSeconds, + "),", + "or because stop_seconds(", + stopSecondsOptional, + ") is too low."); return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds}; } @@ -1419,11 +1419,8 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() { auto actualNumRemainingSamples = swr_convert( swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); - throw std::runtime_error( - "SingleStreamDecoder::maybeFlushSwrBuffers is not implemented yet."); - - // return lastSamples.narrow( - // /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); + return lastSamples.narrow( + /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); } // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 13ad3be35..f7bfceec1 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -10,12 +10,12 @@ #include #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" -#include "torch/library.h" #include "src/torchcodec/_core/AVIOFileLikeContext.h" #include "src/torchcodec/_core/AVIOTensorContext.h" #include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" #include "src/torchcodec/_core/ValidationUtils.h" +#include "torch/library.h" namespace facebook::torchcodec { @@ -118,7 +118,7 @@ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) { // frame.data, // torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)), // torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64))); - return std::make_tuple( + return std::make_tuple( frame.data, torch::full({}, frame.ptsSeconds, torch::kFloat64), torch::full({}, frame.durationSeconds, torch::kFloat64)); @@ -920,15 +920,15 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) { videoDecoder->scanFileAndUpdateMetadataAndIndex(); }
-TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { +TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); m.impl("_create_from_file_like", &_create_from_file_like); m.impl( "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions); -// } +} -// TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { +TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("encode_audio_to_file", &encode_audio_to_file); m.impl("encode_audio_to_tensor", &encode_audio_to_tensor); m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like); diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 6b524f119..cd1839c57 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -146,6 +146,7 @@ def __init__( # if isinstance(device, torch_device): # device = str(device) import paddle + if isinstance(device, paddle.base.core.Place): if device.is_cpu_place(): return "cpu" @@ -158,12 +159,11 @@ def __init__( core.add_video_stream( self._decoder, - num_threads=num_ffmpeg_threads, - dimension_order=dimension_order, stream_index=stream_index, + dimension_order=dimension_order, + num_threads=num_ffmpeg_threads, device=device, device_variant=device_variant, - transform_specs="", custom_frame_mappings=custom_frame_mappings_data, ) @@ -265,9 +265,6 @@ def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch: FrameBatch: The frames at the given indices. """ - if isinstance(indices, list): - indices = torch.tensor(indices, dtype=torch.int64).cpu() - data, pts_seconds, duration_seconds = core.get_frames_at_indices( self._decoder, frame_indices=indices ) @@ -347,9 +344,6 @@ def get_frames_played_at( FrameBatch: The frames that are played at ``seconds``. 
""" - if isinstance(seconds, list): - seconds = torch.tensor(seconds, dtype=torch.float32).cpu() - data, pts_seconds, duration_seconds = core.get_frames_by_pts( self._decoder, timestamps=seconds ) diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py index 2620c171c..d8f107c5e 100644 --- a/src/torchcodec/samplers/_index_based.py +++ b/src/torchcodec/samplers/_index_based.py @@ -151,7 +151,7 @@ def _generic_index_based_sampler( if kind == "random": clip_start_indices = torch.randint( - sampling_range_start, sampling_range_end, (num_clips,) + low=sampling_range_start, high=sampling_range_end, size=(num_clips,) ) else: # Note [num clips larger than sampling range]