Skip to content
9 changes: 8 additions & 1 deletion onnxruntime_extensions/_hf_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,15 @@ def convert_json_vocab(hf_tokenizer):
# get vocab object from json file
vocab = tokenizer_json.get("model", {}).get("vocab", {})
sorted_merges = tokenizer_json.get("model", {}).get("merges", [])
sorted_merges = [v_.replace("\n", "<0x0A>") for v_ in sorted_merges]

attrs = {"vocab": json.dumps(vocab, separators=(",", ":"))}

# merges data can be a list of strings ("a b") or a list of pairs (["a", "b"]).
# NOTE: all() must consume the generator. Wrapping the whole generator in the
# `if` parentheses instead makes the condition always truthy (a generator
# object is never falsy), which would send plain-string merges down the pair
# branch and split them character by character.
if all(isinstance(v_, (list, tuple)) for v_ in sorted_merges):
    # Pair form: join the parts with a space, escaping a bare newline token
    # so it cannot collide with the "\n" separator used for the merges attr.
    sorted_merges = [
        " ".join(v if v != "\n" else "<0x0A>" for v in v_)
        for v_ in sorted_merges
    ]
else:
    # String form: escape embedded newlines for the same reason.
    sorted_merges = [v_.replace("\n", "<0x0A>") for v_ in sorted_merges]

attrs["merges"] = "\n".join(sorted_merges)
if hf_tokenizer.added_tokens_encoder:
token_map = [f"{_k}={_v}" for _k,
Expand Down
14 changes: 14 additions & 0 deletions test/test_autotokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
import unittest

import os
import numpy as np
from transformers import AutoTokenizer, GPT2Tokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models, ort_inference, util
Expand Down Expand Up @@ -128,6 +129,19 @@ def print_prime(n):
self.assertEqual(len(ids["input_ids"].shape), len(actual_ids.shape))
np.testing.assert_array_equal(ids["input_ids"], actual_ids)

def test_microsoft_phi4(self):
    """Check the generated ORT tokenizer against the HF tokenizer.

    Loads the tokenizer stored under ``data/phi-4-mini-reasoning`` next to
    this test file, tokenizes one sentence with both implementations and
    asserts the resulting ids have the same rank and the same values.
    """
    model_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'data',
        "phi-4-mini-reasoning",
    )
    hf_tok = AutoTokenizer.from_pretrained(model_path, torch_dtype="auto")

    text = 'This is a sample Code'

    # Reference ids from the Hugging Face tokenizer.
    expected = tokenizer_output = hf_tok(text, return_tensors="np")
    expected = tokenizer_output["input_ids"]

    # Ids from the converted pre-processing model; only the first output
    # is compared.
    pre_model, _ = gen_processing_models(hf_tok, pre_kwargs={})
    produced, *_ = ort_inference(pre_model, [text])

    self.assertEqual(len(expected.shape), len(produced.shape))
    np.testing.assert_array_equal(expected, produced)

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()