
Commit 5a734f9

Add get_vocabulary, id_to_token and token_to_id methods to ByteTokenizer and UnicodeCodepointTokenizer. (#1664)
1 parent 50e0414 commit 5a734f9

4 files changed: 76 insertions(+), 0 deletions(-)

keras_nlp/src/tokenizers/byte_tokenizer.py

Lines changed: 24 additions & 0 deletions
@@ -209,6 +209,12 @@ def vocabulary_size(self):
         """Get the integer size of the tokenizer vocabulary."""
         return 256
 
+    def get_vocabulary(self):
+        vocab = {}
+        for i in range(self.vocabulary_size()):
+            vocab[chr(i)] = i
+        return vocab
+
     def tokenize(self, inputs):
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
             inputs = tf.convert_to_tensor(inputs)

@@ -264,6 +270,24 @@ def detokenize(self, inputs):
             outputs = tf.squeeze(outputs, 0)
         return outputs
 
+    def id_to_token(self, id):
+        """Convert an integer id to a string token."""
+        if id >= self.vocabulary_size() or id < 0:
+            raise ValueError(
+                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
+                f"Received: {id}"
+            )
+        return chr(id)
+
+    def token_to_id(self, token):
+        """Convert a string token to an integer id."""
+        id = ord(token)
+        if id >= self.vocabulary_size():
+            raise ValueError(
+                f"Token {token} is not supported by `ByteTokenizer`."
+            )
+        return id
+
     def get_config(self):
         config = super().get_config()
         config.update(
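
For illustration, a minimal usage sketch of the new ByteTokenizer methods. It is not part of the commit diff and assumes the public keras_nlp.tokenizers export of ByteTokenizer:

    from keras_nlp.tokenizers import ByteTokenizer

    tokenizer = ByteTokenizer()
    # token_to_id / id_to_token round-trip a single byte character.
    assert tokenizer.token_to_id("a") == 97
    assert tokenizer.id_to_token(97) == "a"
    # get_vocabulary maps each of the 256 byte characters to its id.
    vocab = tokenizer.get_vocabulary()
    assert len(vocab) == 256 and vocab["a"] == 97

Ids outside [0, 255], or tokens whose ordinal exceeds the vocabulary size, raise a ValueError per the checks added above.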

keras_nlp/src/tokenizers/byte_tokenizer_test.py

Lines changed: 14 additions & 0 deletions
@@ -222,3 +222,17 @@ def test_config(self):
             tokenizer(input_data),
             cloned_tokenizer(input_data),
         )
+
+    def test_token_to_id(self):
+        input_tokens = ["f", "u", "n"]
+        expected_ids = [102, 117, 110]
+        tokenizer = ByteTokenizer()
+        ids = [tokenizer.token_to_id(t) for t in input_tokens]
+        self.assertAllEqual(ids, expected_ids)
+
+    def test_id_to_token(self):
+        input_ids = [102, 117, 110]
+        expected_tokens = ["f", "u", "n"]
+        tokenizer = ByteTokenizer()
+        tokens = [tokenizer.id_to_token(i) for i in input_ids]
+        self.assertAllEqual(tokens, expected_tokens)

keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py

Lines changed: 24 additions & 0 deletions
@@ -280,6 +280,12 @@ def vocabulary_size(self):
         size was provided"""
         return self._vocabulary_size
 
+    def get_vocabulary(self):
+        vocab = {}
+        for i in range(self.vocabulary_size()):
+            vocab[chr(i)] = i
+        return vocab
+
     def tokenize(self, inputs):
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
             inputs = tf.convert_to_tensor(inputs)

@@ -331,3 +337,21 @@ def detokenize(self, inputs):
         if unbatched:
             outputs = tf.squeeze(outputs, 0)
         return outputs
+
+    def id_to_token(self, id):
+        """Convert an integer id to a string token."""
+        if id >= self.vocabulary_size() or id < 0:
+            raise ValueError(
+                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
+                f"Received: {id}"
+            )
+        return chr(id)
+
+    def token_to_id(self, token):
+        """Convert a string token to an integer id."""
+        id = ord(token)
+        if id >= self.vocabulary_size():
+            raise ValueError(
+                f"Token {token} is not supported by `UnicodeCodepointTokenizer`."
+            )
+        return id
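
Likewise, a short sketch of the same methods on UnicodeCodepointTokenizer. It is not part of the commit diff and assumes the public keras_nlp.tokenizers export plus an explicit vocabulary_size:

    from keras_nlp.tokenizers import UnicodeCodepointTokenizer

    tokenizer = UnicodeCodepointTokenizer(vocabulary_size=2000)
    # Tokens map to their Unicode codepoints, bounded by vocabulary_size.
    assert tokenizer.token_to_id("ب") == 1576
    assert tokenizer.id_to_token(1576) == "ب"
    # get_vocabulary enumerates codepoints 0 .. vocabulary_size - 1.
    assert len(tokenizer.get_vocabulary()) == 2000

Tokens whose codepoint is at or above vocabulary_size raise a ValueError, mirroring the ByteTokenizer behavior.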

keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py

Lines changed: 14 additions & 0 deletions
@@ -280,3 +280,17 @@ def test_config(self):
             tokenizer(input_data),
             cloned_tokenizer(input_data),
         )
+
+    def test_token_to_id(self):
+        input_tokens = ["ب", "و", "خ"]
+        expected_ids = [1576, 1608, 1582]
+        tokenizer = UnicodeCodepointTokenizer(vocabulary_size=2000)
+        ids = [tokenizer.token_to_id(t) for t in input_tokens]
+        self.assertAllEqual(ids, expected_ids)
+
+    def test_id_to_token(self):
+        input_ids = [1576, 1608, 1582]
+        expected_tokens = ["ب", "و", "خ"]
+        tokenizer = UnicodeCodepointTokenizer(vocabulary_size=2000)
+        tokens = [tokenizer.id_to_token(i) for i in input_ids]
+        self.assertAllEqual(tokens, expected_tokens)
