Unify Android JNI to single IRunner, wire prefill to runner#17756
Unify Android JNI to single IRunner, wire prefill to runner#17756kirklandsign wants to merge 9 commits intomainfrom
Conversation
Replace the dual-runner pattern (runner_ + multi_modal_runner_) with a single IRunner* that holds either TextLLMRunner or MultimodalRunner, leveraging MultimodalRunner's new IRunner inheritance from #17741. Each prefill method (text, images, audio) now immediately calls IRunner::prefill(vector<MultimodalInput>) instead of buffering inputs for later consumption by generate(). A needs_bos_ flag tracks whether the next prefill should apply BOS tokens — MultimodalRunner also guards this via pos_==0 internally, but TextLLMRunner trusts the caller. generate(), stop(), load(), and reset() no longer branch on model_type_category_; all dispatch through the unified runner_. Rename all JNI native methods from append* to prefill* to match the existing Java public API naming.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17756
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (5 Unrelated Failures)As of commit 0a3a5e4 with merge base 67bc28b ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
Remove "multimodal Module" wording since prefill methods now work through the unified IRunner for both text-only and multimodal models. Simplify return value docs.
Return Error::InvalidState when runner_ is null instead of silently returning 0 (success). Use needs_bos_ to gate num_bos in GenerationConfig so that a prior prefill() call prevents generate() from adding BOS a second time.
There was a problem hiding this comment.
Pull request overview
This pull request unifies the Android JNI layer to use a single IRunner* interface for both text-only and multimodal models, eliminating the dual-runner pattern. It changes the prefill workflow from buffering inputs to immediately populating the KV cache, and renames all JNI methods from append* to prefill* to match the public Java API.
Changes:
- Replaced dual-runner pattern (
runner_+multi_modal_runner_) with unifiedIRunner* runner_leveragingMultimodalRunner's newIRunnerinheritance from PR #17741 - Changed prefill methods to immediately call
IRunner::prefill()instead of buffering inputs inprefill_inputs_vector for later consumption - Added
needs_bos_flag to track whether the next prefill should apply BOS tokens, set to true on construction and afterreset_context(), false after prefill operations
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| extension/android/jni/jni_layer_llama.cpp | Unified runner implementation, removed multi_modal_runner_ and prefill_inputs_ members, refactored all methods to use single runner_ interface, renamed native methods from append* to prefill* |
| extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java | Updated native method declarations from append* to prefill* and revised documentation to reflect immediate KV cache population behavior |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if (!runner_) { | ||
| return 0; | ||
| } | ||
| needs_bos_ = false; |
There was a problem hiding this comment.
Setting needs_bos_ to false before calling runner_->generate() is problematic because if generate() fails (e.g., due to encoding errors or other issues), needs_bos_ will remain false even though no actual generation occurred. This corrupts the state and means subsequent prefill calls won't apply BOS tokens even though they should. The needs_bos_ flag should only be set to false after verifying that generate() succeeded. However, since generate() currently doesn't check the return value (see separate comment), this should be fixed together with proper error handling.
| runner_->generate( | ||
| prompt->toStdString(), | ||
| config, | ||
| token_callback, | ||
| [callback](const llm::Stats& result) { callback->onStats(result); }); |
There was a problem hiding this comment.
The generate method does not check the return value of runner_->generate(), which can return Error::InvalidArgument or other errors. If generate() fails, the error is silently ignored and 0 (success) is returned to the Java layer. This can lead to incorrect behavior and makes debugging difficult. The method should check the return value and propagate errors appropriately, similar to how the prefill_* methods handle errors.
| runner_->generate( | |
| prompt->toStdString(), | |
| config, | |
| token_callback, | |
| [callback](const llm::Stats& result) { callback->onStats(result); }); | |
| auto result = runner_->generate( | |
| prompt->toStdString(), | |
| config, | |
| token_callback, | |
| [callback](const llm::Stats& result) { callback->onStats(result); }); | |
| if (!result.ok()) { | |
| return static_cast<jint>(result.error()); | |
| } |
Use num_bos_ (constructor) as the single BOS source in generate(), matching what prefill methods already use. This prevents inconsistency when the per-call num_bos parameter differs from the constructor value. Add block comment above prefill methods documenting the eager execution model, BOS tracking via needs_bos_, and the expected prefill+generate interaction (including echo behavior).
The compiler needs the full TextLLMRunner definition (not just the forward declaration from llm_runner_helper.h) to verify the inheritance from IRunner and allow unique_ptr<TextLLMRunner> to convert to unique_ptr<IRunner>.
runner_->generate() return value was silently ignored, always returning 0 (success). Now propagates the error code back to Java, consistent with how the prefill methods handle errors.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| int32_t bos = needs_bos_ ? num_bos_ : 0; | ||
| needs_bos_ = false; | ||
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | ||
| if (!result.ok()) { | ||
| return static_cast<jint>(result.error()); | ||
| } |
There was a problem hiding this comment.
needs_bos_ is cleared before checking whether runner_->prefill(...) succeeded. If prefill fails, subsequent prefill/generate calls won’t add BOS even though no successful prefill occurred. Only set needs_bos_ = false after result.ok() (or restore it on error).
| int32_t bos = needs_bos_ ? num_bos_ : 0; | ||
| needs_bos_ = false; | ||
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | ||
| if (!result.ok()) { | ||
| return static_cast<jint>(result.error()); | ||
| } |
There was a problem hiding this comment.
needs_bos_ is cleared before checking whether runner_->prefill(...) succeeded. If prefill fails, subsequent prefill/generate calls won’t add BOS even though no successful prefill occurred. Only set needs_bos_ = false after result.ok() (or restore it on error).
| int32_t bos = needs_bos_ ? num_bos_ : 0; | ||
| needs_bos_ = false; | ||
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | ||
| if (!result.ok()) { | ||
| return static_cast<jint>(result.error()); | ||
| } |
There was a problem hiding this comment.
needs_bos_ is cleared before checking whether runner_->prefill(...) succeeded. If prefill fails, subsequent prefill/generate calls won’t add BOS even though no successful prefill occurred. Only set needs_bos_ = false after result.ok() (or restore it on error).
| int32_t bos = needs_bos_ ? num_bos_ : 0; | ||
| needs_bos_ = false; | ||
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | ||
| if (!result.ok()) { | ||
| return static_cast<jint>(result.error()); | ||
| } |
There was a problem hiding this comment.
needs_bos_ is cleared before checking whether runner_->prefill(...) succeeded. If prefill fails, subsequent prefill/generate calls won’t add BOS even though no successful prefill occurred. Only set needs_bos_ = false after result.ok() (or restore it on error).
| .echo = static_cast<bool>(echo), | ||
| .seq_len = seq_len, | ||
| .temperature = effective_temperature, | ||
| .num_bos = needs_bos_ ? num_bos_ : 0, |
There was a problem hiding this comment.
generate() ignores the num_bos argument and instead uses the constructor default num_bos_ when needs_bos_ is true. This changes the behavior of LlmGenerationConfig.getNumBos() / the JNI signature and makes per-call BOS control ineffective. Use the num_bos parameter for this call (or remove it from the JNI/Java API if it’s intentionally fixed at init).
| .num_bos = needs_bos_ ? num_bos_ : 0, | |
| .num_bos = needs_bos_ ? num_bos : 0, |
| ss << "Invalid model type category: " << model_type_category_ | ||
| << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " | ||
| << MODEL_TYPE_CATEGORY_MULTIMODAL; |
There was a problem hiding this comment.
When runner_ is null, load() always throws an exception saying the model type category is invalid. But create_text_llm_runner(...) / create_multimodal_runner(...) can also return nullptr on initialization failures (e.g., tokenizer not loaded, metadata read failure), which would make this message misleading. Consider distinguishing “unsupported model_type_category_” from “failed to create runner” and reporting the latter accurately.
| ss << "Invalid model type category: " << model_type_category_ | |
| << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " | |
| << MODEL_TYPE_CATEGORY_MULTIMODAL; | |
| if (model_type_category_ != MODEL_TYPE_CATEGORY_LLM && | |
| model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) { | |
| ss << "Invalid model type category: " << model_type_category_ | |
| << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " | |
| << MODEL_TYPE_CATEGORY_MULTIMODAL; | |
| } else { | |
| ss << "Failed to create runner for model type category: " | |
| << model_type_category_; | |
| } |
| /** | ||
| * Prefill a multimodal Module with the given images input. | ||
| * Prefill the KV cache with the given image input. | ||
| * | ||
| * @param image Input image as a byte array | ||
| * @param width Input image width | ||
| * @param height Input image height | ||
| * @param channels Input image number of channels | ||
| * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer | ||
| * exposed to user. | ||
| * @return 0 on success |
There was a problem hiding this comment.
Javadoc says the image parameter is a “byte array”, but the API type is int[]. Update the parameter description to match the actual type (and expected value range) so callers don’t pass the wrong format.
| int32_t bos = needs_bos_ ? num_bos_ : 0; | ||
| needs_bos_ = false; | ||
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | ||
| if (!result.ok()) { | ||
| return static_cast<jint>(result.error()); | ||
| } |
There was a problem hiding this comment.
needs_bos_ is cleared before checking whether runner_->prefill(...) succeeded. If prefill fails, subsequent prefill/generate calls won’t add BOS even though no successful prefill occurred. Only set needs_bos_ = false after result.ok() (or restore it on error).
| needs_bos_ = false; | ||
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | ||
| if (!result.ok()) { | ||
| return static_cast<jint>(result.error()); | ||
| } |
There was a problem hiding this comment.
needs_bos_ is cleared before checking whether runner_->prefill(...) succeeded. If prefill fails, subsequent prefill/generate calls won’t add BOS even though no successful prefill occurred. Only set needs_bos_ = false after result.ok() (or restore it on error).
| needs_bos_ = false; | |
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | |
| if (!result.ok()) { | |
| return static_cast<jint>(result.error()); | |
| } | |
| auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); | |
| if (!result.ok()) { | |
| return static_cast<jint>(result.error()); | |
| } | |
| needs_bos_ = false; |
| needs_bos_ = false; | ||
| auto err = runner_->generate( | ||
| prompt->toStdString(), | ||
| config, | ||
| token_callback, | ||
| [callback](const llm::Stats& result) { callback->onStats(result); }); | ||
| return static_cast<jint>(err); |
There was a problem hiding this comment.
needs_bos_ is set to false before verifying that runner_->generate(...) succeeded, and the return value from generate() is ignored. If generation fails, the JNI method still returns success (0) and future calls won’t prepend BOS. Capture and return/throw on the runtime::Error from runner_->generate(...), and only clear needs_bos_ on success.
| needs_bos_ = false; | |
| auto err = runner_->generate( | |
| prompt->toStdString(), | |
| config, | |
| token_callback, | |
| [callback](const llm::Stats& result) { callback->onStats(result); }); | |
| return static_cast<jint>(err); | |
| auto err = runner_->generate( | |
| prompt->toStdString(), | |
| config, | |
| token_callback, | |
| [callback](const llm::Stats& result) { callback->onStats(result); }); | |
| if (err != Error::Ok) { | |
| return static_cast<jint>(err); | |
| } | |
| needs_bos_ = false; | |
| return static_cast<jint>(Error::Ok); |
Summary
Replace the dual-runner pattern (runner_ + multi_modal_runner_) with a single IRunner* that holds either TextLLMRunner or MultimodalRunner, leveraging MultimodalRunner's new IRunner inheritance from #17741.
Each prefill method (text, images, audio) now immediately calls IRunner::prefill(vector) instead of buffering inputs for later consumption by generate(). A needs_bos_ flag tracks whether the next prefill should apply BOS tokens — MultimodalRunner also guards this via pos_==0 internally, but TextLLMRunner trusts the caller.
generate(), stop(), load(), and reset() no longer branch on model_type_category_; all dispatch through the unified runner_.
Rename all JNI native methods from append* to prefill* to match the existing Java public API naming.
Test plan
CI