Skip to content

Commit bc18834

Browse files
authored
[multimodal] Allow float32 image input (#14359)
This change lets the `Image` class support both `uint8_t` and `float` data types, and updates the `MultimodalPrefiller` class to support text, image, and audio modalities with error checking and modularity. **Image Data Handling and Type Safety:** * Refactored the `Image` class in `image.h` from a simple struct to a class that uses a `std::variant` to support both `uint8_t` and `float` image data, providing type-safe accessors and a `toTensor` method for conversion to tensors. * Updated `load_image` in Llava `main.cpp` to construct `Image` objects using the new class interface and move semantics, ensuring correct data layout and encapsulation. * Added a runtime check in `LlavaImagePrefiller` to ensure only `uint8_t` images are processed, using the new type-checking methods. **Multimodal Prefill Logic and Flexibility:** * Updated the `MultimodalPrefiller` class in `multimodal_prefiller.h` to dynamically check input types, validate tensor types against model expectations, and handle encoder/decoder execution with improved error handling and modularity.
1 parent d25c35a commit bc18834

File tree

6 files changed

+205
-101
lines changed

6 files changed

+205
-101
lines changed

examples/models/llava/main.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,24 +81,20 @@ void load_image(const std::string& image_path, Image& image) {
8181
new_height,
8282
0,
8383
channels);
84-
// transpose to CHW
85-
image.data.resize(channels * new_width * new_height);
84+
std::vector<uint8_t> chw_data(channels * new_width * new_height);
8685
for (int i = 0; i < new_width * new_height; ++i) {
8786
for (int c = 0; c < channels; ++c) {
88-
image.data[c * new_width * new_height + i] =
89-
resized_data[i * channels + c];
87+
chw_data[c * new_width * new_height + i] = resized_data[i * channels + c];
9088
}
9189
}
92-
image.width = new_width;
93-
image.height = new_height;
94-
image.channels = channels;
90+
image = Image(std::move(chw_data), new_width, new_height, channels);
9591
// convert to tensor
9692
ET_LOG(
9793
Info,
9894
"image Channels: %" PRId32 ", Height: %" PRId32 ", Width: %" PRId32,
99-
image.channels,
100-
image.height,
101-
image.width);
95+
image.channels(),
96+
image.height(),
97+
image.width());
10298
stbi_image_free(data);
10399
}
104100

extension/android/jni/jni_layer_llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
268268
for (int i = 0; i < image_size; i++) {
269269
image_data[i] = image_data_jint[i];
270270
}
271-
llm::Image image_runner{image_data, width, height, channels};
271+
llm::Image image_runner{std::move(image_data), width, height, channels};
272272
prefill_inputs_.emplace_back(
273273
llm::MultimodalInput{std::move(image_runner)});
274274
}

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ - (BOOL)generate:(NSArray<ExecuTorchLLMMultimodalInput *> *)inputs
172172
case ExecuTorchLLMMultimodalInputTypeImage: {
173173
ExecuTorchLLMImage *image = input.image;
174174
std::vector<uint8_t> data((uint8_t *)image.data.bytes, (uint8_t *)image.data.bytes + image.data.length);
175-
nativeInputs.emplace_back(llm::MultimodalInput(llm::Image{
176-
.data = std::move(data),
177-
.width = (int32_t)image.width,
178-
.height = (int32_t)image.height,
179-
.channels = (int32_t)image.channels
180-
}));
175+
nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
176+
std::move(data),
177+
(int32_t)image.width,
178+
(int32_t)image.height,
179+
(int32_t)image.channels
180+
)));
181181
break;
182182
}
183183
default: {

extension/llm/runner/image.h

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,112 @@
1010

1111
#pragma once
1212
#include <executorch/runtime/platform/compiler.h>
13+
#include <cstddef>
1314
#include <cstdint>
15+
#include <variant>
1416
#include <vector>
1517

18+
#include <executorch/extension/tensor/tensor.h>
19+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
20+
1621
namespace executorch {
1722
namespace extension {
1823
namespace llm {
1924

20-
struct ET_EXPERIMENTAL Image {
25+
class ET_EXPERIMENTAL Image {
26+
public:
27+
// Default constructor
28+
Image() : width_(0), height_(0), channels_(0) {}
29+
30+
// Constructor for uint8_t data
31+
Image(
32+
std::vector<uint8_t>&& data,
33+
int32_t width,
34+
int32_t height,
35+
int32_t channels)
36+
: data_(std::move(data)),
37+
width_(width),
38+
height_(height),
39+
channels_(channels) {}
40+
41+
// Constructor for float data
42+
Image(
43+
std::vector<float>&& data,
44+
int32_t width,
45+
int32_t height,
46+
int32_t channels)
47+
: data_(std::move(data)),
48+
width_(width),
49+
height_(height),
50+
channels_(channels) {}
51+
52+
// Getters
53+
int32_t width() const {
54+
return width_;
55+
}
56+
int32_t height() const {
57+
return height_;
58+
}
59+
int32_t channels() const {
60+
return channels_;
61+
}
62+
63+
// Data access
64+
bool is_uint8() const {
65+
return std::holds_alternative<std::vector<uint8_t>>(data_);
66+
}
67+
68+
bool is_float() const {
69+
return std::holds_alternative<std::vector<float>>(data_);
70+
}
71+
72+
const std::vector<uint8_t>& get_uint8_data() const& {
73+
return std::get<std::vector<uint8_t>>(data_);
74+
}
75+
76+
std::vector<uint8_t>& get_uint8_data() & {
77+
return std::get<std::vector<uint8_t>>(data_);
78+
}
79+
80+
const std::vector<float>& get_float_data() const& {
81+
return std::get<std::vector<float>>(data_);
82+
}
83+
84+
std::vector<float>& get_float_data() & {
85+
return std::get<std::vector<float>>(data_);
86+
}
87+
88+
executorch::runtime::Result<executorch::extension::TensorPtr> toTensor(
89+
bool with_batch = false) const {
90+
// Note: This creates a 3D tensor (CHW). The model might expect a 4D
91+
// tensor (NCHW). The caller should handle reshaping if needed.
92+
std::vector<executorch::aten::SizesType> sizes = {
93+
channels(), height(), width()};
94+
if (with_batch) {
95+
sizes.insert(sizes.begin(), 1);
96+
}
97+
if (is_float()) {
98+
return executorch::extension::from_blob(
99+
const_cast<float*>(get_float_data().data()),
100+
sizes,
101+
::executorch::aten::ScalarType::Float);
102+
} else if (is_uint8()) {
103+
return executorch::extension::from_blob(
104+
const_cast<uint8_t*>(get_uint8_data().data()),
105+
sizes,
106+
::executorch::aten::ScalarType::Byte);
107+
}
108+
ET_LOG(
109+
Error, "Image data is not initialized with uint8_t or float vector.");
110+
return ::executorch::runtime::Error::NotSupported;
111+
}
112+
113+
private:
21114
// Assuming NCHW format
22-
std::vector<uint8_t> data;
23-
int32_t width;
24-
int32_t height;
25-
int32_t channels;
115+
std::variant<std::vector<uint8_t>, std::vector<float>> data_;
116+
int32_t width_;
117+
int32_t height_;
118+
int32_t channels_;
26119
};
27120

28121
} // namespace llm

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,42 @@ Result<uint64_t> MultimodalPrefiller::prefill(
4141
::executorch::runtime::EValue encoder_output;
4242
if (input.is_image()) {
4343
Image image = input.get_image();
44-
auto image_tensor = executorch::extension::from_blob(
45-
image.data.data(),
46-
{3, image.height, image.width},
47-
::executorch::aten::ScalarType::Byte);
44+
45+
auto method_meta = ET_UNWRAP(
46+
module_->method_meta(kImageEncoderMethod),
47+
"Failed to get method_meta for %s",
48+
kImageEncoderMethod);
49+
50+
ET_CHECK_MSG(
51+
method_meta.num_inputs() > 0,
52+
"Image encoder should have at least 1 input");
53+
auto input_meta = ET_UNWRAP(
54+
method_meta.input_tensor_meta(0),
55+
"Cannot get input tensor meta at index 0");
56+
auto expected_dtype = input_meta.scalar_type();
57+
58+
if (expected_dtype == ::executorch::aten::ScalarType::Float) {
59+
ET_CHECK_MSG(
60+
image.is_float(),
61+
"Model expects float image data, but image has uint8_t data.");
62+
} else if (expected_dtype == ::executorch::aten::ScalarType::Byte) {
63+
ET_CHECK_MSG(
64+
image.is_uint8(),
65+
"Model expects uint8_t image data, but image has float data.");
66+
} else {
67+
ET_LOG(
68+
Error,
69+
"Unsupported image encoder input dtype: %s",
70+
::executorch::runtime::toString(expected_dtype));
71+
return ::executorch::runtime::Error::NotSupported;
72+
}
73+
74+
// The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D
75+
// tensor (CHW). Add a batch dimension of 1 if needed.
76+
auto expected_dims = input_meta.sizes();
77+
auto image_tensor = ET_UNWRAP(
78+
image.toTensor(/*with_batch*/ expected_dims.size() == 4),
79+
"Failed to convert image to tensor");
4880

4981
// Run image encoder
5082
auto image_encoder_outputs =

0 commit comments

Comments
 (0)