@@ -101,41 +101,55 @@ def tokenize_and_pad(
101
101
102
102
103
103
def chunk_and_pad_tokens (
104
- tokens ,
104
+ tokens : np . ndarray ,
105
105
bos_id : int ,
106
106
pad_id : int ,
107
- is_bos : bool = True ,
107
+ is_bos : bool ,
108
+ chunk_size : int ,
108
109
prefill_lengths : Optional [List [int ]] = None ,
109
110
max_prefill_length : Optional [int ] = None ,
110
- chunk_size : Optional [int ] = None ,
111
111
jax_padding : bool = True ,
112
112
) -> Tuple [
113
113
List [Union [jax .Array , np .ndarray ]],
114
- List [Union [ jax . Array , np . ndarray ] ],
115
- List [Union [ jax .Array , np . ndarray ] ],
114
+ List [int ],
115
+ List [jax .Array ],
116
116
]:
117
- """Chunks and pads tokens for chunked prefill
118
- if total token size is 520 and chunk size is 256,
117
+ """Chunks and pads tokens for chunked prefill.
118
+
119
+ If total token size is 520 and chunk size is 256,
119
120
the function will return 3 chunks and the returned tuple is as follows:
120
121
[[t0,..t255][t256,..t511][t512,..t519]],
121
122
[256, 256, 8],
122
- [[0,..255],[ 256,..511],[ 512..518..]]
123
+ [[[ 0,..255]],[[ 256,..511]],[[ 512..518..] ]]
123
124
124
125
Args:
125
126
tokens: Tokens.
126
127
bos_id: Bos ID.
127
128
pad_id: Pad ID.
128
129
is_bos: Add a beginning of sequence token if this is true.
130
+ chunk_size: maximum size of each chunk
129
131
prefill_lengths: Buckets to pad the sequence to for static compilation.
130
132
max_prefill_length: Maximum bucket to use.
131
- chunk_size: maximum size of each chunk
132
133
jax_padding: convert to JAX padded tokens if True.
133
134
134
135
Returns:
135
136
chunk_padded_tokens: List of chunked and padded tokens.
136
137
padded_chunk_true_lengths: List of integers - true length of each chunk
137
138
positions: List of positions of each token in the chunk
138
139
"""
140
+ # Add a beginning of sequence token if this is the beginning.
141
+ if is_bos :
142
+ tokens = np .concatenate (
143
+ [
144
+ np .array (
145
+ [
146
+ bos_id ,
147
+ ]
148
+ ),
149
+ tokens ,
150
+ ],
151
+ axis = - 1 ,
152
+ )
139
153
140
154
num_tokens = len (tokens )
141
155
num_chunks = int (math .ceil (num_tokens / chunk_size ))
@@ -147,33 +161,22 @@ def chunk_and_pad_tokens(
147
161
148
162
# positions of tokens in each chunk
149
163
positions = []
150
- # to be able to slice the tokens
151
- tokens = jnp .array (tokens )
164
+
152
165
for chunk_num in range (num_chunks ):
153
- start = int (chunk_num * chunk_size )
154
- end = jnp .minimum ((chunk_num + 1 ) * chunk_size , num_tokens )
155
- chunk_tokens = jax .lax .slice (tokens , (start ,), (end ,))
156
- if chunk_num == 0 :
157
- padded_chunk , padded_chunk_true_length = pad_tokens (
158
- chunk_tokens ,
159
- bos_id ,
160
- pad_id ,
161
- is_bos ,
162
- prefill_lengths ,
163
- max_prefill_length ,
164
- jax_padding ,
165
- )
166
- else :
167
- # is_bos should be false in subsequent chunks.
168
- padded_chunk , padded_chunk_true_length = pad_tokens (
169
- chunk_tokens ,
170
- bos_id ,
171
- pad_id ,
172
- False ,
173
- prefill_lengths ,
174
- max_prefill_length ,
175
- jax_padding ,
176
- )
166
+ start : int = chunk_num * chunk_size
167
+ end : int = min ((chunk_num + 1 ) * chunk_size , num_tokens )
168
+ chunk_tokens = tokens [start :end ]
169
+ # the bos is added at the beginning of the function.
170
+ # is_bos should be false in chunks.
171
+ padded_chunk , padded_chunk_true_length = pad_tokens (
172
+ chunk_tokens ,
173
+ bos_id ,
174
+ pad_id ,
175
+ False ,
176
+ prefill_lengths ,
177
+ max_prefill_length ,
178
+ jax_padding ,
179
+ )
177
180
178
181
positions_chunk = jnp .expand_dims (
179
182
jnp .arange (start , start + len (padded_chunk ), dtype = jnp .int32 ), 0
0 commit comments