Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.nio.ByteBuffer;
import java.util.List;
import org.pytorch.executorch.ExecuTorchRuntime;
import org.pytorch.executorch.annotations.Experimental;
Expand Down Expand Up @@ -383,6 +384,7 @@ public int generate(
* @throws RuntimeException if the prefill failed
*/
@Experimental
@Deprecated
public long prefillImages(int[] image, int width, int height, int channels) {
int nativeResult = appendImagesInput(image, width, height, channels);
if (nativeResult != 0) {
Expand All @@ -391,8 +393,102 @@ public long prefillImages(int[] image, int width, int height, int channels) {
return 0;
}

/**
 * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data
 * is accessed directly without JNI array copies, unlike {@link #prefillImages(int[], int, int,
 * int)}. The ByteBuffer must contain raw uint8 pixel data in CHW format with at least channels *
 * height * width bytes remaining. Only the first channels * height * width bytes from the
 * buffer's current position are consumed. The buffer's own position and limit are not modified.
 *
 * @param image Input image as a direct ByteBuffer containing uint8 pixel data
 * @param width Input image width; must be positive
 * @param height Input image height; must be positive
 * @param channels Input image number of channels; must be positive
 * @throws IllegalArgumentException if the ByteBuffer is not direct, any dimension is not
 *     positive, or the ByteBuffer has insufficient remaining bytes
 * @throws RuntimeException if the prefill failed
 */
@Experimental
public void prefillImages(ByteBuffer image, int width, int height, int channels) {
  if (!image.isDirect()) {
    throw new IllegalArgumentException("Input ByteBuffer must be direct.");
  }
  // Report bad dimensions separately so the message points at the real problem
  // instead of blaming the buffer size.
  if (width <= 0 || height <= 0 || channels <= 0) {
    throw new IllegalArgumentException(
        "Image dimensions must be positive (width="
            + width
            + ", height="
            + height
            + ", channels="
            + channels
            + ").");
  }
  // width * height always fits in a long (both factors are below 2^31), but the
  // additional multiplication by channels can overflow. Saturate instead of
  // wrapping so that an oversized request can never pass the size check below.
  long planeBytes = (long) width * height;
  long expectedBytes =
      planeBytes > Long.MAX_VALUE / channels ? Long.MAX_VALUE : planeBytes * channels;
  if (image.remaining() < expectedBytes) {
    throw new IllegalArgumentException(
        "ByteBuffer remaining ("
            + image.remaining()
            + ") must be at least width*height*channels ("
            + expectedBytes
            + ").");
  }
  // slice() so that getDirectBufferAddress on the native side returns a pointer
  // starting at the current position, not the base address.
  int nativeResult = appendImagesInputBuffer(image.slice(), width, height, channels);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
}

/**
 * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The
 * buffer data is accessed directly without JNI array copies, unlike {@link
 * #prefillImages(float[], int, int, int)}. The ByteBuffer must contain normalized float pixel
 * data in CHW format with at least channels * height * width * 4 bytes remaining. Only the first
 * channels * height * width floats from the buffer's current position are consumed. The buffer
 * must use the platform's native byte order (set via {@code
 * buffer.order(ByteOrder.nativeOrder())}). The buffer's own position and limit are not modified.
 *
 * @param image Input normalized image as a direct ByteBuffer containing float pixel data in
 *     native byte order
 * @param width Input image width; must be positive
 * @param height Input image height; must be positive
 * @param channels Input image number of channels; must be positive
 * @throws IllegalArgumentException if the ByteBuffer is not direct, does not use native byte
 *     order, is not float-aligned, any dimension is not positive, or the ByteBuffer has
 *     insufficient remaining bytes
 * @throws RuntimeException if the prefill failed
 */
@Experimental
public void prefillNormalizedImages(ByteBuffer image, int width, int height, int channels) {
  if (!image.isDirect()) {
    throw new IllegalArgumentException("Input ByteBuffer must be direct.");
  }
  if (image.order() != java.nio.ByteOrder.nativeOrder()) {
    throw new IllegalArgumentException(
        "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder()).");
  }
  if (image.position() % Float.BYTES != 0) {
    throw new IllegalArgumentException(
        "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned.");
  }
  // Report bad dimensions separately so the message points at the real problem
  // instead of blaming the buffer size.
  if (width <= 0 || height <= 0 || channels <= 0) {
    throw new IllegalArgumentException(
        "Image dimensions must be positive (width="
            + width
            + ", height="
            + height
            + ", channels="
            + channels
            + ").");
  }
  // width * height always fits in a long (both factors are below 2^31); every
  // further multiplication can overflow, so saturate at Long.MAX_VALUE rather
  // than wrap. A saturated value can never pass the size check below.
  long planeElements = (long) width * height;
  long elements =
      planeElements > Long.MAX_VALUE / channels ? Long.MAX_VALUE : planeElements * channels;
  long expectedBytes =
      elements > Long.MAX_VALUE / Float.BYTES ? Long.MAX_VALUE : elements * Float.BYTES;
  if (image.remaining() < expectedBytes) {
    throw new IllegalArgumentException(
        "ByteBuffer remaining ("
            + image.remaining()
            + ") must be at least width*height*channels*4 ("
            + expectedBytes
            + ").");
  }
  // The native side rejects slices whose size is not a whole number of floats,
  // so fail here with a descriptive message instead of a bare error code.
  if (image.remaining() % Float.BYTES != 0) {
    throw new IllegalArgumentException(
        "ByteBuffer remaining (" + image.remaining() + ") must be a multiple of 4 (float size).");
  }
  // slice() so that getDirectBufferAddress on the native side returns a pointer
  // starting at the current position, not the base address.
  int nativeResult = appendNormalizedImagesInputBuffer(image.slice(), width, height, channels);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
}

// Native entry point used by prefillImages(int[], ...). Callers treat any non-zero return
// value as a failure code.
private native int appendImagesInput(int[] image, int width, int height, int channels);

// Native entry point used by prefillImages(ByteBuffer, ...). Callers pass a slice of a direct
// buffer holding uint8 pixel data and treat any non-zero return value as a failure code.
private native int appendImagesInputBuffer(ByteBuffer image, int width, int height, int channels);

// Native entry point used by prefillNormalizedImages(ByteBuffer, ...). Callers pass a slice of
// a direct, native-order buffer holding float pixel data and treat any non-zero return value
// as a failure code.
private native int appendNormalizedImagesInputBuffer(
ByteBuffer image, int width, int height, int channels);

/**
* Prefill a multimodal Module with the given images input.
*
Expand All @@ -405,6 +501,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
* @throws RuntimeException if the prefill failed
*/
@Experimental
@Deprecated
public long prefillImages(float[] image, int width, int height, int channels) {
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
if (nativeResult != 0) {
Expand Down
50 changes: 50 additions & 0 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <chrono>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -280,6 +281,49 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Queues a uint8 image, read from a direct ByteBuffer, as a prefill input.
//
// Copies the first width * height * channels bytes of `image` into an
// llm::Image and appends it to prefill_inputs_.
//
// Returns status_code: 0 on success, Error::InvalidArgument if `image` is
// null, any dimension is not positive, the buffer exposes no data, or the
// buffer holds fewer than width * height * channels bytes.
jint append_images_input_buffer(
    facebook::jni::alias_ref<facebook::jni::JByteBuffer> image,
    jint width,
    jint height,
    jint channels) {
  if (image == nullptr || width <= 0 || height <= 0 || channels <= 0) {
    return static_cast<jint>(Error::InvalidArgument);
  }
  auto* data = image->getDirectBytes();
  const size_t size = image->getDirectSize();
  const auto w = static_cast<size_t>(width);
  const auto h = static_cast<size_t>(height);
  const auto c = static_cast<size_t>(channels);
  // Compare by division instead of multiplying the dimensions up front: the
  // product w * h * c can wrap around when size_t is narrower than 64 bits,
  // which would let an undersized buffer pass the check. For positive
  // integers, w * h * c <= size is equivalent to w <= size / h / c.
  if (data == nullptr || w > size / h / c) {
    return static_cast<jint>(Error::InvalidArgument);
  }
  // Cannot overflow: the check above bounds the product by `size`.
  const size_t expected = w * h * c;
  std::vector<uint8_t> image_data(data, data + expected);
  llm::Image image_runner{std::move(image_data), width, height, channels};
  prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(image_runner)});
  return 0;
}

// Queues a float image, read from a direct ByteBuffer, as a prefill input.
//
// Copies the first width * height * channels floats of `image` (interpreted
// in the buffer's in-memory byte layout) into an llm::Image and appends it to
// prefill_inputs_.
//
// Returns status_code: 0 on success, Error::InvalidArgument if `image` is
// null, any dimension is not positive, the buffer exposes no data, the buffer
// size is not a multiple of sizeof(float), or the buffer holds fewer than
// width * height * channels floats.
jint append_normalized_images_input_buffer(
    facebook::jni::alias_ref<facebook::jni::JByteBuffer> image,
    jint width,
    jint height,
    jint channels) {
  if (image == nullptr || width <= 0 || height <= 0 || channels <= 0) {
    return static_cast<jint>(Error::InvalidArgument);
  }
  auto* data = image->getDirectBytes();
  const size_t size = image->getDirectSize();
  if (data == nullptr || size % sizeof(float) != 0) {
    return static_cast<jint>(Error::InvalidArgument);
  }
  const auto w = static_cast<size_t>(width);
  const auto h = static_cast<size_t>(height);
  const auto c = static_cast<size_t>(channels);
  // Compare by division instead of multiplying the dimensions up front: the
  // product w * h * c * sizeof(float) can wrap around when size_t is narrower
  // than 64 bits, which would let an undersized buffer pass the check. For
  // positive integers, w * h * c <= n is equivalent to w <= n / h / c.
  const size_t available_floats = size / sizeof(float);
  if (w > available_floats / h / c) {
    return static_cast<jint>(Error::InvalidArgument);
  }
  // Cannot overflow: num_floats <= available_floats, so the byte count below
  // is bounded by `size`.
  const size_t num_floats = w * h * c;
  std::vector<float> image_data(num_floats);
  // memcpy rather than a float* cast: the source pointer is a byte pointer
  // and copying bytes avoids any alignment or aliasing assumptions.
  std::memcpy(image_data.data(), data, num_floats * sizeof(float));
  llm::Image image_runner{std::move(image_data), width, height, channels};
  prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(image_runner)});
  return 0;
}

// Returns status_code
jint append_normalized_images_input(
facebook::jni::alias_ref<jfloatArray> image,
Expand Down Expand Up @@ -427,9 +471,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod("load", ExecuTorchLlmJni::load),
makeNativeMethod(
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
makeNativeMethod(
"appendImagesInputBuffer",
ExecuTorchLlmJni::append_images_input_buffer),
makeNativeMethod(
"appendNormalizedImagesInput",
ExecuTorchLlmJni::append_normalized_images_input),
makeNativeMethod(
"appendNormalizedImagesInputBuffer",
ExecuTorchLlmJni::append_normalized_images_input_buffer),
makeNativeMethod(
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
makeNativeMethod(
Expand Down
Loading