Skip to content

Commit 7abb2d8

Browse files
JacobSzwejbka authored and facebook-github-bot committed
IOManager Interface (#10418)
Summary: Hopefully this is sufficient for the contract. Going to do 2 follow-ups: add a basic CPU implementation, and add a static-attention implementation. Differential Revision: D73450877
1 parent de98d25 commit 7abb2d8

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

# Oncall rotation that owns this directory (used by internal tooling).
oncall("executorch")

# Instantiate the shared targets declared in targets.bzl for this package.
define_common_targets()
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once

#include <vector>

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>
// ET_EXPERIMENTAL / ET_NODISCARD are defined here; include directly rather
// than relying on a transitive include.
#include <executorch/runtime/platform/compiler.h>
15+
16+
namespace executorch {
17+
namespace extension {
18+
namespace llm {
19+
20+
/**
 * @brief Base class for managing input/output operations for LLM inference.
 *
 * IOManagerBase provides an interface for handling the input preparation and
 * output processing for both prefill and decode phases of LLM inference.
 * Derived classes must implement the virtual methods to provide specific IO
 * management functionality (e.g. KV-cache handling for a particular backend).
 *
 * Expected call order (per the method names; confirm against concrete
 * implementations): init() once, then for each generation pass
 * prepare_prefill()/prepare_decode() before execution and
 * update_prefill()/update_decode() with the resulting outputs; reset() to
 * restart.
 */
class ET_EXPERIMENTAL IOManagerBase {
 public:
  /**
   * @brief Virtual destructor to allow proper cleanup in derived classes.
   */
  virtual ~IOManagerBase() = default;

  /**
   * @brief Initialize the IO manager with method metadata for prefill and
   * decode operations.
   *
   * @param prefill_method The prefill method to initialize with.
   * @param decode_method The decode method to initialize with.
   * @return Error indicating success or failure of the initialization.
   */
  ET_NODISCARD virtual runtime::Error init(
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) = 0;

  /**
   * @brief Reset the IO manager state.
   *
   * @param prefill_method The prefill method to reset with.
   * @param decode_method The decode method to reset with.
   * @return Error indicating success or failure of the reset.
   */
  ET_NODISCARD virtual runtime::Error reset(
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) = 0;

  /**
   * @brief Prepare inputs for the prefill phase of LLM inference.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the current
   * input within the context.
   * @param prefill_method The prefill method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
   * for the prefill method, or an error on failure.
   */
  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
  prepare_prefill(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) = 0;

  /**
   * @brief Prepare inputs for the decode phase of LLM inference.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the current
   * input within the context.
   * @param decode_method The decode method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
   * for the decode method, or an error on failure.
   */
  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
  prepare_decode(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) = 0;

  /**
   * @brief Process and update internal state with outputs from the prefill
   * phase.
   *
   * @param prefill_method The prefill method to update with outputs.
   * @param model_outputs Vector of outputs from the prefill method execution.
   * @return Error indicating success or failure of the update.
   */
  ET_NODISCARD virtual runtime::Error update_prefill(
      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) = 0;

  /**
   * @brief Process and update internal state with outputs from the decode
   * phase.
   *
   * NOTE(review): this takes the Method by const reference while
   * update_prefill() takes it by non-const reference — confirm the asymmetry
   * is intentional before implementing; changing it later would break
   * overriders.
   *
   * @param decode_method The decode method to update with outputs.
   * @param model_outputs Vector of outputs from the decode method execution.
   * @return Error indicating success or failure of the update.
   */
  ET_NODISCARD virtual runtime::Error update_decode(
      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
};
110+
111+
} // namespace llm
112+
} // namespace extension
113+
} // namespace executorch
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Defines the io_manager targets shared between fbcode and xplat.

    Emits one header-only library per mode: ":io_manager" (core/portable
    tensors) and ":io_manager_aten" (ATen tensors).
    """
    for use_aten in (True, False):
        suffix = "_aten" if use_aten else ""

        # Pure interface target for IOManager — exports the header only;
        # depending on it pulls in no concrete implementation.
        runtime.cxx_library(
            name = "io_manager" + suffix,
            exported_headers = [
                "io_manager.h",
            ],
            deps = [
                "//executorch/extension/module:module" + suffix,
                "//executorch/extension/tensor:tensor" + suffix,
                "//executorch/runtime/core/exec_aten:lib" + suffix,
                "//executorch/runtime/executor:program_no_prim_ops" + suffix,
            ],
            visibility = [
                "@EXECUTORCH_CLIENTS",
            ],
        )

0 commit comments

Comments
 (0)