Skip to content

Commit 0b07561

Browse files
authored
Implement addedTokens in BertTokenizer (#193)
1 parent eec56ed commit 0b07561

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

Sources/Tokenizers/BertTokenizer.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -46,9 +46,12 @@ public class BertTokenizer {
     }
 
     public required convenience init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
-        guard let vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else {
-            throw TokenizerError.missingVocab
+        guard var vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else { throw TokenizerError.missingVocab }
+        if let addedTokens = tokenizerData.added_tokens?.dictionary["value"] as? [[String: Any]],
+           let pairs = addedTokens.compactMap({ ($0["content"] as? String, $0["id"] as? Int) }) as? [(String, Int)] {
+            vocab.merge(pairs, uniquingKeysWith: {$1})
         }
+        vocab.merge(addedTokens, uniquingKeysWith: {$1})
         let merges = tokenizerData.model?.merges?.value as? [String]
         let tokenizeChineseChars = tokenizerConfig.handleChineseChars?.boolValue ?? true
         let eosToken = tokenizerConfig.eosToken?.stringValue
```

Tests/TokenizersTests/BertTokenizerTests.swift

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -8,7 +8,7 @@
 
 import XCTest
 @testable import Tokenizers
-
+@testable import Hub
 
 
 class BertTokenizerTests: XCTestCase {
```
```diff
@@ -178,4 +178,35 @@ class BertTokenizerTests: XCTestCase {
             XCTAssertEqual(decoded, String(expected))
         }
     }
+
+    func testBertTokenizerAddedTokensRecognized() async throws {
+        let base: URL = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests")
+        let hubApi = HubApi(downloadBase: base)
+        let configuration = LanguageModelConfigurationFromHub(modelName: "google-bert/bert-base-uncased", hubApi: hubApi)
+        guard let tokenizerConfig = try await configuration.tokenizerConfig else { fatalError("missing tokenizer config") }
+        let tokenizerData = try await configuration.tokenizerData
+        let addedTokens = [
+            "[ROAD]": 60_001,
+            "[RIVER]": 60_002,
+            "[BUILDING]": 60_003,
+            "[PARK]": 60_004,
+            "[BUFFER]": 60_005,
+            "[INTERSECT]": 60_006,
+            "[UNION]": 60_007,
+        ]
+        let tokenizer = try BertTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
+        for (token, idx) in addedTokens {
+            XCTAssertEqual(tokenizer.convertTokenToId(token), idx)
+        }
+        for (token, idx) in addedTokens {
+            XCTAssertEqual(tokenizer.convertIdToToken(idx), token)
+        }
+
+        // Reading added_tokens from tokenizer.json
+        XCTAssertEqual(tokenizer.convertTokenToId("[PAD]"), 0)
+        XCTAssertEqual(tokenizer.convertTokenToId("[UNK]"), 100)
+        XCTAssertEqual(tokenizer.convertTokenToId("[CLS]"), 101)
+        XCTAssertEqual(tokenizer.convertTokenToId("[SEP]"), 102)
+        XCTAssertEqual(tokenizer.convertTokenToId("[MASK]"), 103)
+    }
 }
```

0 commit comments

Comments (0)