Skip to content

Commit f4a351c

Browse files
authored
Turn BytesContext into FromTensorContext (#721)
1 parent efe7553 commit f4a351c

11 files changed

+214
-205
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 0 additions & 137 deletions
This file was deleted.

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 0 additions & 54 deletions
This file was deleted.

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ namespace facebook::torchcodec {
2727
// tracks the custom behavior of reading, seeking and writing. It is
2828
// provided upon AVIOContext creation and to the read, seek and
2929
// write callback functions.
30-
// While it's not required, it is natural for the derived classes to make
31-
// all of the above members. Base classes need to call
30+
// The callback functions do not need to be members of the derived class,
31+
// but the derived class must have access to them. The context object must
32+
// be a member of the derived class. Derived classes need to call
3233
// createAVIOContext(), ideally in their constructor.
3334
// 3. A generic handle for those that just need to manage having access to an
3435
// AVIOContext, but aren't necessarily concerned with how it was customized:
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/AVIOTensorContext.h"
8+
#include <torch/types.h>
9+
10+
namespace facebook::torchcodec {
11+
12+
namespace {
13+
14+
constexpr int64_t INITIAL_TENSOR_SIZE = 10'000'000; // 10 MB
15+
constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
16+
17+
// The signature of this function is defined by FFMPEG.
18+
int read(void* opaque, uint8_t* buf, int buf_size) {
19+
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
20+
TORCH_CHECK(
21+
tensorContext->current <= tensorContext->data.numel(),
22+
"Tried to read outside of the buffer: current=",
23+
tensorContext->current,
24+
", size=",
25+
tensorContext->data.numel());
26+
27+
int64_t numBytesRead = std::min(
28+
static_cast<int64_t>(buf_size),
29+
tensorContext->data.numel() - tensorContext->current);
30+
31+
TORCH_CHECK(
32+
numBytesRead >= 0,
33+
"Tried to read negative bytes: numBytesRead=",
34+
numBytesRead,
35+
", size=",
36+
tensorContext->data.numel(),
37+
", current=",
38+
tensorContext->current);
39+
40+
if (numBytesRead == 0) {
41+
return AVERROR_EOF;
42+
}
43+
44+
std::memcpy(
45+
buf,
46+
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
47+
numBytesRead);
48+
tensorContext->current += numBytesRead;
49+
return numBytesRead;
50+
}
51+
52+
// The signature of this function is defined by FFMPEG.
53+
int write(void* opaque, const uint8_t* buf, int buf_size) {
54+
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
55+
56+
int64_t bufSize = static_cast<int64_t>(buf_size);
57+
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
58+
TORCH_CHECK(
59+
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
60+
"We tried to allocate an output encoded tensor larger than ",
61+
MAX_TENSOR_SIZE,
62+
" bytes. If you think this should be supported, please report.");
63+
64+
// We double the size of the outpout tensor. Calling cat() may not be the
65+
// most efficient, but it's simple.
66+
tensorContext->data =
67+
torch::cat({tensorContext->data, tensorContext->data});
68+
}
69+
70+
TORCH_CHECK(
71+
tensorContext->current + bufSize <= tensorContext->data.numel(),
72+
"Re-allocation of the output tensor didn't work. ",
73+
"This should not happen, please report on TorchCodec bug tracker");
74+
75+
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
76+
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
77+
tensorContext->current += bufSize;
78+
return buf_size;
79+
}
80+
81+
// The signature of this function is defined by FFMPEG.
82+
int64_t seek(void* opaque, int64_t offset, int whence) {
83+
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
84+
int64_t ret = -1;
85+
86+
switch (whence) {
87+
case AVSEEK_SIZE:
88+
ret = tensorContext->data.numel();
89+
break;
90+
case SEEK_SET:
91+
tensorContext->current = offset;
92+
ret = offset;
93+
break;
94+
default:
95+
break;
96+
}
97+
98+
return ret;
99+
}
100+
101+
} // namespace
102+
103+
AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
104+
: tensorContext_{data, 0} {
105+
TORCH_CHECK(data.numel() > 0, "data must not be empty");
106+
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
107+
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
108+
createAVIOContext(&read, nullptr, &seek, &tensorContext_);
109+
}
110+
111+
AVIOToTensorContext::AVIOToTensorContext()
112+
: tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
113+
createAVIOContext(nullptr, &write, &seek, &tensorContext_);
114+
}
115+
116+
torch::Tensor AVIOToTensorContext::getOutputTensor() {
117+
return tensorContext_.data.narrow(
118+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
119+
}
120+
121+
} // namespace facebook::torchcodec
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torch/types.h>
10+
#include "src/torchcodec/_core/AVIOContextHolder.h"
11+
12+
namespace facebook::torchcodec {
13+
14+
namespace detail {
15+
16+
struct TensorContext {
17+
torch::Tensor data;
18+
int64_t current;
19+
};
20+
21+
} // namespace detail
22+
23+
// For Decoding: enables users to pass in the entire video or audio as bytes.
24+
// Our read and seek functions then traverse the bytes in memory.
25+
class AVIOFromTensorContext : public AVIOContextHolder {
26+
public:
27+
explicit AVIOFromTensorContext(torch::Tensor data);
28+
29+
private:
30+
detail::TensorContext tensorContext_;
31+
};
32+
33+
// For Encoding: used to encode into an output uint8 (bytes) tensor.
34+
class AVIOToTensorContext : public AVIOContextHolder {
35+
public:
36+
explicit AVIOToTensorContext();
37+
torch::Tensor getOutputTensor();
38+
39+
private:
40+
detail::TensorContext tensorContext_;
41+
};
42+
43+
} // namespace facebook::torchcodec

src/torchcodec/_core/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function(make_torchcodec_libraries
6565
set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}")
6666
set(decoder_sources
6767
AVIOContextHolder.cpp
68-
AVIOBytesContext.cpp
68+
AVIOTensorContext.cpp
6969
FFMPEGCommon.cpp
7070
Frame.cpp
7171
DeviceInterface.cpp
@@ -102,7 +102,7 @@ function(make_torchcodec_libraries
102102
# 2. Create libtorchcodec_custom_opsN.{ext}.
103103
set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}")
104104
set(custom_ops_sources
105-
AVIOBytesContext.cpp
105+
AVIOTensorContext.cpp
106106
custom_ops.cpp
107107
)
108108
set(custom_ops_dependencies

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include <sstream>
22

3-
#include "src/torchcodec/_core/AVIOBytesContext.h"
3+
#include "src/torchcodec/_core/AVIOTensorContext.h"
44
#include "src/torchcodec/_core/Encoder.h"
55
#include "torch/types.h"
66

0 commit comments

Comments
 (0)