Skip to content

Commit 934cad4

Browse files
JacobSzwejbka and facebook-github-bot
authored and committed
IOManager Interface
Summary: Hopefully this is sufficient for the contract. Going to do 2 follow-up tests: add a basic CPU implementation, and add a static attention implementation. Differential Revision: D73450877
1 parent 954f2cb commit 934cad4

File tree

3 files changed

+125
-0
lines changed

3 files changed

+125
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
# Tags this build file with the "executorch" oncall.
# NOTE(review): presumably Buck's package-level ownership annotation — confirm.
oncall("executorch")
7+
8+
# Invokes the macro loaded from :targets.bzl, which holds the targets shared
# between fbcode and xplat.
define_common_targets()
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <vector>

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/executor/method_meta.h>
// NOTE(review): the sibling includes use the <executorch/...> prefix; confirm
// whether this one should be <executorch/runtime/executor/method.h>.
#include <runtime/executor/method.h>
14+
15+
namespace executorch {
16+
namespace extension {
17+
namespace llm {
18+
19+
/**
20+
* @brief Base class for managing input/output operations for LLM inference.
21+
*
22+
* IOManagerBase provides an interface for handling the input preparation and output
23+
* processing for both prefill and decode phases of LLM inference. Derived classes
24+
* must implement the virtual methods to provide specific IO management functionality.
25+
*/
26+
class ET_EXPERIMENTAL IOManagerBase {
27+
public:
28+
29+
/**
30+
* @brief Virtual destructor to allow proper cleanup in derived classes.
31+
*/
32+
ET_EXPERIMENTAL virtual ~IOManagerBase() = default;
33+
34+
35+
/**
36+
* @brief Initialize the IO manager with method metadata for prefill and decode operations.
37+
*
38+
* @param prefill_method The prefill method to initialize with.
39+
* @param decode_method The decode method to initialize with.
40+
*/
41+
ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error init(
42+
executorch::runtime::Method& prefill_method,
43+
executorch::runtime::Method& decode_method) = 0;
44+
45+
/**
46+
* @brief Reset the IO manager state.
47+
*
48+
* @param prefill_method The prefill method to reset with.
49+
* @param decode_method The decode method to reset with.
50+
*/
51+
ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error reset(
52+
executorch::runtime::Method& prefill_method,
53+
executorch::runtime::Method& decode_method) = 0;
54+
55+
/**
56+
* @brief Prepare inputs for the prefill phase of LLM inference.
57+
*
58+
* @param input The input tensor containing token IDs.
59+
* @param start_pos The tensor containing the starting position of the current input within the context.
60+
* @param prefill_method The prefill method to prepare inputs for.
61+
* @return std::vector<executorch::runtime::EValue> Vector of prepared inputs for the prefill method.
62+
*/
63+
ET_EXPERIMENTAL virtual runtime::Result<std::vector<executorch::runtime::EValue>> prepare_prefill(
64+
const executorch::extension::TensorPtr& input,
65+
const executorch::extension::TensorPtr& start_pos,
66+
executorch::runtime::Method& prefill_method) = 0;
67+
68+
/**
69+
* @brief Prepare inputs for the decode phase of LLM inference.
70+
*
71+
* @param input The input tensor containing token IDs.
72+
* @param start_pos The tensor containing the starting position of the current input within the context.
73+
* @param decode_method The decode method to prepare inputs for.
74+
* @return std::vector<executorch::runtime::EValue> Vector of prepared inputs for the decode method.
75+
*/
76+
ET_EXPERIMENTAL virtual runtime::Result<std::vector<executorch::runtime::EValue>> prepare_decode(
77+
const executorch::extension::TensorPtr& input,
78+
const executorch::extension::TensorPtr& start_pos,
79+
executorch::runtime::Method& decode_method) = 0;
80+
81+
/**
82+
* @brief Process and update internal state with outputs from the prefill phase.
83+
*
84+
* @param prefill_method The prefill method to update with outputs.
85+
* @param model_outputs Vector of outputs from the prefill method execution.
86+
*/
87+
ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error update_prefill(
88+
executorch::runtime::Method& prefill_method,
89+
const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
90+
91+
/**
92+
* @brief Process and update internal state with outputs from the decode phase.
93+
*
94+
* @param decode_method The decode method to update with outputs.
95+
* @param model_outputs Vector of outputs from the decode method execution.
96+
*/
97+
ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error update_decode(
98+
const executorch::runtime::Method& decode_method,
99+
const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
100+
101+
};
102+
103+
} // namespace llm
104+
} // namespace extension
105+
} // namespace executorch
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Defines the io_manager targets shared between fbcode and xplat.

    Declares a header-only `io_manager` library that exports `io_manager.h`
    to the clients listed in `visibility`.
    """
    runtime.cxx_library(
        name = "io_manager",
        exported_headers = [
            "io_manager.h",
        ],
        # io_manager.h includes the tensor extension and the executor's
        # method/method_meta headers, so the libraries providing them are
        # exported to dependents instead of relying on each dependent to
        # pull them in itself.
        # NOTE(review): confirm these target labels against the owning
        # targets.bzl files.
        exported_deps = [
            "//executorch/extension/tensor:tensor",
            "//executorch/runtime/executor:program",
        ],
        visibility = [
            "@EXECUTORCH_CLIENTS",
        ],
    )

0 commit comments

Comments
 (0)