Skip to content

Commit b585280

Browse files
committed
Replacing throw with TORCH_CHECK
Replaced all the occurrences of throw in the cpp code with TORCH_CHECK. TORCH_CHECK throws a runtime error, so the cpp test suite had to be updated as well.
1 parent f4a351c commit b585280

File tree

5 files changed

+104
-112
lines changed

5 files changed

+104
-112
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(
3737
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
3838
: DeviceInterface(device) {
3939
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
40-
if (device_.type() != torch::kCPU) {
41-
throw std::runtime_error("Unsupported device: " + device_.str());
42-
}
40+
TORCH_CHECK(
41+
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
4342
}
4443

4544
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -161,9 +160,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
161160
frameOutput.data = outputTensor;
162161
}
163162
} else {
164-
throw std::runtime_error(
165-
"Invalid color conversion library: " +
166-
std::to_string(static_cast<int>(colorConversionLibrary)));
163+
TORCH_CHECK(
164+
false,
165+
"Invalid color conversion library: ",
166+
static_cast<int>(colorConversionLibrary));
167167
}
168168
}
169169

@@ -189,9 +189,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
189189
const UniqueAVFrame& avFrame) {
190190
int status = av_buffersrc_write_frame(
191191
filterGraphContext_.sourceContext, avFrame.get());
192-
if (status < AVSUCCESS) {
193-
throw std::runtime_error("Failed to add frame to buffer source context");
194-
}
192+
TORCH_CHECK(
193+
status >= AVSUCCESS, "Failed to add frame to buffer source context");
195194

196195
UniqueAVFrame filteredAVFrame(av_frame_alloc());
197196
status = av_buffersink_get_frame(
@@ -241,11 +240,12 @@ void CpuDeviceInterface::createFilterGraph(
241240
filterArgs.str().c_str(),
242241
nullptr,
243242
filterGraphContext_.filterGraph.get());
244-
if (status < 0) {
245-
throw std::runtime_error(
246-
std::string("Failed to create filter graph: ") + filterArgs.str() +
247-
": " + getFFMPEGErrorStringFromErrorCode(status));
248-
}
243+
TORCH_CHECK(
244+
status >= 0,
245+
"Failed to create filter graph: ",
246+
filterArgs.str(),
247+
": ",
248+
getFFMPEGErrorStringFromErrorCode(status));
249249

250250
status = avfilter_graph_create_filter(
251251
&filterGraphContext_.sinkContext,
@@ -254,11 +254,10 @@ void CpuDeviceInterface::createFilterGraph(
254254
nullptr,
255255
nullptr,
256256
filterGraphContext_.filterGraph.get());
257-
if (status < 0) {
258-
throw std::runtime_error(
259-
"Failed to create filter graph: " +
260-
getFFMPEGErrorStringFromErrorCode(status));
261-
}
257+
TORCH_CHECK(
258+
status >= 0,
259+
"Failed to create filter graph: ",
260+
getFFMPEGErrorStringFromErrorCode(status));
262261

263262
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
264263

@@ -268,11 +267,10 @@ void CpuDeviceInterface::createFilterGraph(
268267
pix_fmts,
269268
AV_PIX_FMT_NONE,
270269
AV_OPT_SEARCH_CHILDREN);
271-
if (status < 0) {
272-
throw std::runtime_error(
273-
"Failed to set output pixel formats: " +
274-
getFFMPEGErrorStringFromErrorCode(status));
275-
}
270+
TORCH_CHECK(
271+
status >= 0,
272+
"Failed to set output pixel formats: ",
273+
getFFMPEGErrorStringFromErrorCode(status));
276274

277275
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
278276
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
@@ -301,19 +299,17 @@ void CpuDeviceInterface::createFilterGraph(
301299
nullptr);
302300
outputs.reset(outputsTmp);
303301
inputs.reset(inputsTmp);
304-
if (status < 0) {
305-
throw std::runtime_error(
306-
"Failed to parse filter description: " +
307-
getFFMPEGErrorStringFromErrorCode(status));
308-
}
302+
TORCH_CHECK(
303+
status >= 0,
304+
"Failed to parse filter description: ",
305+
getFFMPEGErrorStringFromErrorCode(status));
309306

310307
status =
311308
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
312-
if (status < 0) {
313-
throw std::runtime_error(
314-
"Failed to configure filter graph: " +
315-
getFFMPEGErrorStringFromErrorCode(status));
316-
}
309+
TORCH_CHECK(
310+
status >= 0,
311+
"Failed to configure filter graph: ",
312+
getFFMPEGErrorStringFromErrorCode(status));
317313
}
318314

319315
void CpuDeviceInterface::createSwsContext(

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,8 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
166166
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
167167
: DeviceInterface(device) {
168168
TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
169-
if (device_.type() != torch::kCUDA) {
170-
throw std::runtime_error("Unsupported device: " + device_.str());
171-
}
169+
TORCH_CHECK(
170+
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
172171
}
173172

174173
CudaDeviceInterface::~CudaDeviceInterface() {

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 61 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,10 @@ void SingleStreamDecoder::initializeDecoder() {
103103
// which decodes a few frames to get missing info. For more, see:
104104
// https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html
105105
int status = avformat_find_stream_info(formatContext_.get(), nullptr);
106-
if (status < 0) {
107-
throw std::runtime_error(
108-
"Failed to find stream info: " +
109-
getFFMPEGErrorStringFromErrorCode(status));
110-
}
106+
TORCH_CHECK(
107+
status >= 0,
108+
"Failed to find stream info: ",
109+
getFFMPEGErrorStringFromErrorCode(status));
111110

112111
for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
113112
AVStream* avStream = formatContext_->streams[i];
@@ -222,11 +221,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
222221
break;
223222
}
224223

225-
if (status != AVSUCCESS) {
226-
throw std::runtime_error(
227-
"Failed to read frame from input file: " +
228-
getFFMPEGErrorStringFromErrorCode(status));
229-
}
224+
TORCH_CHECK(
225+
status == AVSUCCESS,
226+
"Failed to read frame from input file: ",
227+
getFFMPEGErrorStringFromErrorCode(status));
230228

231229
if (packet->flags & AV_PKT_FLAG_DISCARD) {
232230
continue;
@@ -279,11 +277,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
279277

280278
// Reset the seek-cursor back to the beginning.
281279
int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
282-
if (status < 0) {
283-
throw std::runtime_error(
284-
"Could not seek file to pts=0: " +
285-
getFFMPEGErrorStringFromErrorCode(status));
286-
}
280+
TORCH_CHECK(
281+
status >= 0,
282+
"Could not seek file to pts=0: ",
283+
getFFMPEGErrorStringFromErrorCode(status));
287284

288285
// Sort all frames by their pts.
289286
for (auto& [streamIndex, streamInfo] : streamInfos_) {
@@ -363,11 +360,11 @@ void SingleStreamDecoder::addStream(
363360
activeStreamIndex_ = av_find_best_stream(
364361
formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0);
365362

366-
if (activeStreamIndex_ < 0) {
367-
throw std::invalid_argument(
368-
"No valid stream found in input file. Is " +
369-
std::to_string(streamIndex) + " of the desired media type?");
370-
}
363+
TORCH_CHECK(
364+
activeStreamIndex_ >= 0,
365+
"No valid stream found in input file. Is ",
366+
std::to_string(streamIndex),
367+
" of the desired media type?");
371368

372369
TORCH_CHECK(avCodec != nullptr);
373370

@@ -415,9 +412,7 @@ void SingleStreamDecoder::addStream(
415412
}
416413

417414
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
418-
if (retVal < AVSUCCESS) {
419-
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
420-
}
415+
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
421416

422417
codecContext->time_base = streamInfo.stream->time_base;
423418
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
@@ -446,13 +441,12 @@ void SingleStreamDecoder::addVideoStream(
446441
auto& streamMetadata =
447442
containerMetadata_.allStreamMetadata[activeStreamIndex_];
448443

449-
if (seekMode_ == SeekMode::approximate &&
450-
!streamMetadata.averageFpsFromHeader.has_value()) {
451-
throw std::runtime_error(
452-
"Seek mode is approximate, but stream " +
453-
std::to_string(activeStreamIndex_) +
454-
" does not have an average fps in its metadata.");
455-
}
444+
TORCH_CHECK(
445+
!(seekMode_ == SeekMode::approximate &&
446+
!streamMetadata.averageFps.has_value()),
447+
"Seek mode is approximate, but stream ",
448+
std::to_string(activeStreamIndex_),
449+
" does not have an average fps in its metadata.");
456450

457451
auto& streamInfo = streamInfos_[activeStreamIndex_];
458452
streamInfo.videoStreamOptions = videoStreamOptions;
@@ -1048,11 +1042,13 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
10481042
desiredPts,
10491043
desiredPts,
10501044
0);
1051-
if (status < 0) {
1052-
throw std::runtime_error(
1053-
"Could not seek file to pts=" + std::to_string(desiredPts) + ": " +
1054-
getFFMPEGErrorStringFromErrorCode(status));
1055-
}
1045+
TORCH_CHECK(
1046+
status >= 0,
1047+
"Could not seek file to pts=",
1048+
std::to_string(desiredPts),
1049+
": ",
1050+
getFFMPEGErrorStringFromErrorCode(status));
1051+
10561052
decodeStats_.numFlushes++;
10571053
avcodec_flush_buffers(streamInfo.codecContext.get());
10581054
}
@@ -1121,21 +1117,20 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11211117
status = avcodec_send_packet(
11221118
streamInfo.codecContext.get(),
11231119
/*avpkt=*/nullptr);
1124-
if (status < AVSUCCESS) {
1125-
throw std::runtime_error(
1126-
"Could not flush decoder: " +
1127-
getFFMPEGErrorStringFromErrorCode(status));
1128-
}
1120+
TORCH_CHECK(
1121+
status >= AVSUCCESS,
1122+
"Could not flush decoder: ",
1123+
getFFMPEGErrorStringFromErrorCode(status));
11291124

11301125
reachedEOF = true;
11311126
break;
11321127
}
11331128

1134-
if (status < AVSUCCESS) {
1135-
throw std::runtime_error(
1136-
"Could not read frame from input file: " +
1137-
getFFMPEGErrorStringFromErrorCode(status));
1138-
}
1129+
TORCH_CHECK(
1130+
status >= AVSUCCESS,
1131+
"Could not read frame from input file: ",
1132+
getFFMPEGErrorStringFromErrorCode(status));
1133+
11391134
} while (packet->stream_index != activeStreamIndex_);
11401135

11411136
if (reachedEOF) {
@@ -1147,23 +1142,22 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11471142
// We got a valid packet. Send it to the decoder, and we'll receive it in
11481143
// the next iteration.
11491144
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
1150-
if (status < AVSUCCESS) {
1151-
throw std::runtime_error(
1152-
"Could not push packet to decoder: " +
1153-
getFFMPEGErrorStringFromErrorCode(status));
1154-
}
1145+
TORCH_CHECK(
1146+
status >= AVSUCCESS,
1147+
"Could not push packet to decoder: ",
1148+
getFFMPEGErrorStringFromErrorCode(status));
11551149

11561150
decodeStats_.numPacketsSentToDecoder++;
11571151
}
11581152

11591153
if (status < AVSUCCESS) {
1160-
if (reachedEOF || status == AVERROR_EOF) {
1161-
throw SingleStreamDecoder::EndOfFileException(
1162-
"Requested next frame while there are no more frames left to "
1163-
"decode.");
1164-
}
1165-
throw std::runtime_error(
1166-
"Could not receive frame from decoder: " +
1154+
TORCH_CHECK(
1155+
!(reachedEOF || status == AVERROR_EOF),
1156+
"Requested next frame while there are no more frames left to decode.");
1157+
1158+
TORCH_CHECK(
1159+
false,
1160+
"Could not receive frame from decoder: ",
11671161
getFFMPEGErrorStringFromErrorCode(status));
11681162
}
11691163

@@ -1429,7 +1423,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
14291423
return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
14301424
}
14311425
default:
1432-
throw std::runtime_error("Unknown SeekMode");
1426+
TORCH_CHECK(false, "Unknown SeekMode");
14331427
}
14341428
}
14351429

@@ -1456,7 +1450,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
14561450
return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
14571451
}
14581452
default:
1459-
throw std::runtime_error("Unknown SeekMode");
1453+
TORCH_CHECK(false, "Unknown SeekMode");
14601454
}
14611455
}
14621456

@@ -1476,7 +1470,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14761470
streamInfo.timeBase);
14771471
}
14781472
default:
1479-
throw std::runtime_error("Unknown SeekMode");
1473+
TORCH_CHECK(false, "Unknown SeekMode");
14801474
}
14811475
}
14821476

@@ -1493,7 +1487,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14931487
return streamMetadata.numFramesFromHeader;
14941488
}
14951489
default:
1496-
throw std::runtime_error("Unknown SeekMode");
1490+
TORCH_CHECK(false, "Unknown SeekMode");
14971491
}
14981492
}
14991493

@@ -1505,7 +1499,7 @@ double SingleStreamDecoder::getMinSeconds(
15051499
case SeekMode::approximate:
15061500
return 0;
15071501
default:
1508-
throw std::runtime_error("Unknown SeekMode");
1502+
TORCH_CHECK(false, "Unknown SeekMode");
15091503
}
15101504
}
15111505

@@ -1518,7 +1512,7 @@ std::optional<double> SingleStreamDecoder::getMaxSeconds(
15181512
return streamMetadata.durationSecondsFromHeader;
15191513
}
15201514
default:
1521-
throw std::runtime_error("Unknown SeekMode");
1515+
TORCH_CHECK(false, "Unknown SeekMode");
15221516
}
15231517
}
15241518

@@ -1552,10 +1546,10 @@ void SingleStreamDecoder::validateActiveStream(
15521546
}
15531547

15541548
void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
1555-
if (!scannedAllStreams_) {
1556-
throw std::runtime_error(
1557-
"Must scan all streams to update metadata before calling " + msg);
1558-
}
1549+
TORCH_CHECK(
1550+
scannedAllStreams_,
1551+
"Must scan all streams to update metadata before calling ",
1552+
msg);
15591553
}
15601554

15611555
void SingleStreamDecoder::validateFrameIndex(

0 commit comments

Comments
 (0)