
Commit 1691557

Add support for ModernBert (#1104)
* Fix token decode in fill-mask pipeline
* Add support for ModernBERT
* Add modernbert unit tests
* Cleanup bert unit tests
* Add unit test for `sequence_length > local_attention_window`
1 parent 610391d

File tree: 7 files changed, +377 −93 lines changed

README.md (+1)

@@ -366,6 +366,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **MobileNetV4** (from Google Inc.) released with the paper [MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) by Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard.
 1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
+1. **[ModernBERT](https://huggingface.co/docs/transformers/model_doc/modernbert)** (from Answer.AI) released with the paper [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard, Iacopo Poli.
 1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
 1. **[Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine)** (from Useful Sensors) released with the paper [Moonshine: Speech Recognition for Live Transcription and Voice Commands](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.

docs/snippets/6_supported-models.snippet (+1)

@@ -81,6 +81,7 @@
 1. **MobileNetV4** (from Google Inc.) released with the paper [MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) by Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard.
 1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
+1. **[ModernBERT](https://huggingface.co/docs/transformers/model_doc/modernbert)** (from Answer.AI) released with the paper [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Gallagher, Raja Biswas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Griffin Adams, Jeremy Howard, Iacopo Poli.
 1. **Moondream1** released in the repository [moondream](https://github.com/vikhyat/moondream) by vikhyat.
 1. **[Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine)** (from Useful Sensors) released with the paper [Moonshine: Speech Recognition for Live Transcription and Voice Commands](https://arxiv.org/abs/2410.15608) by Nat Jeffries, Evan King, Manjunath Kudlur, Guy Nicholson, James Wang, Pete Warden.
 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.

src/models.js (+47)

@@ -1951,6 +1951,49 @@ export class BertForQuestionAnswering extends BertPreTrainedModel {
 }
 //////////////////////////////////////////////////

+//////////////////////////////////////////////////
+// ModernBert models
+export class ModernBertPreTrainedModel extends PreTrainedModel { }
+export class ModernBertModel extends ModernBertPreTrainedModel { }
+
+export class ModernBertForMaskedLM extends ModernBertPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<MaskedLMOutput>} An object containing the model's output logits for masked language modeling.
+     */
+    async _call(model_inputs) {
+        return new MaskedLMOutput(await super._call(model_inputs));
+    }
+}
+
+export class ModernBertForSequenceClassification extends ModernBertPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<SequenceClassifierOutput>} An object containing the model's output logits for sequence classification.
+     */
+    async _call(model_inputs) {
+        return new SequenceClassifierOutput(await super._call(model_inputs));
+    }
+}
+
+export class ModernBertForTokenClassification extends ModernBertPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     *
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
+
 //////////////////////////////////////////////////
 // NomicBert models
 export class NomicBertPreTrainedModel extends PreTrainedModel { }

@@ -6921,6 +6964,7 @@ export class PretrainedMixin {

 const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['bert', ['BertModel', BertModel]],
+    ['modernbert', ['ModernBertModel', ModernBertModel]],
     ['nomic_bert', ['NomicBertModel', NomicBertModel]],
     ['roformer', ['RoFormerModel', RoFormerModel]],
     ['electra', ['ElectraModel', ElectraModel]],

@@ -7059,6 +7103,7 @@ const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([

 const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
+    ['modernbert', ['ModernBertForSequenceClassification', ModernBertForSequenceClassification]],
     ['roformer', ['RoFormerForSequenceClassification', RoFormerForSequenceClassification]],
     ['electra', ['ElectraForSequenceClassification', ElectraForSequenceClassification]],
     ['esm', ['EsmForSequenceClassification', EsmForSequenceClassification]],

@@ -7080,6 +7125,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([

 const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['bert', ['BertForTokenClassification', BertForTokenClassification]],
+    ['modernbert', ['ModernBertForTokenClassification', ModernBertForTokenClassification]],
     ['roformer', ['RoFormerForTokenClassification', RoFormerForTokenClassification]],
     ['electra', ['ElectraForTokenClassification', ElectraForTokenClassification]],
     ['esm', ['EsmForTokenClassification', EsmForTokenClassification]],

@@ -7148,6 +7194,7 @@ const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([

 const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
     ['bert', ['BertForMaskedLM', BertForMaskedLM]],
+    ['modernbert', ['ModernBertForMaskedLM', ModernBertForMaskedLM]],
     ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
     ['electra', ['ElectraForMaskedLM', ElectraForMaskedLM]],
     ['esm', ['EsmForMaskedLM', EsmForMaskedLM]],
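
With the class definitions and the four mapping entries above, a config whose `model_type` is `modernbert` now resolves through the auto classes, so the high-level pipeline API picks it up without further changes. A minimal usage sketch (not part of this diff; the checkpoint id is an assumption and requires ONNX weights to be available on the Hub):

import { pipeline } from '@huggingface/transformers';

// Hypothetical checkpoint id — any ModernBERT fill-mask model with ONNX weights works here.
const unmasker = await pipeline('fill-mask', 'answerdotai/ModernBERT-base');
const output = await unmasker('The capital of France is [MASK].');
// Each entry has the shape produced by FillMaskPipeline: { score, token, token_str, sequence }.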

src/pipelines.js (+1 −1)

@@ -688,7 +688,7 @@ export class FillMaskPipeline extends (/** @type {new (options: TextPipelineCons
             return {
                 score: values[i],
                 token: Number(x),
-                token_str: this.tokenizer.model.vocab[x],
+                token_str: this.tokenizer.decode([x]),
                 sequence: this.tokenizer.decode(sequence, { skip_special_tokens: true }),
             }
         }));
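
This one-line fix matters for tokenizers whose vocabulary stores byte-level BPE strings, as ModernBERT's does: indexing the raw vocab returns the internal token with its encoding markers, while decoding the single id runs it through the tokenizer's decoder chain and yields readable text. An illustrative sketch (the token id and strings are hypothetical):

const x = 5347;                  // hypothetical id of the token 'Ġcapital'
this.tokenizer.model.vocab[x];   // => 'Ġcapital'  (raw vocab entry — old behaviour)
this.tokenizer.decode([x]);      // => ' capital'  (decoded text — new behaviour)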

tests/models/bert/test_modeling_bert.js (+4 −14)

@@ -100,14 +100,9 @@ export default () => {
       async () => {
         const inputs = tokenizer("hello");
         const { logits } = await model(inputs);
-        const target = [[0.00043986947275698185, -0.030218850821256638]].flat();
+        const target = [[0.00043986947275698185, -0.030218850821256638]];
         expect(logits.dims).toEqual([1, 2]);
-        logits
-          .tolist()
-          .flat()
-          .forEach((item, i) => {
-            expect(item).toBeCloseTo(target[i], 5);
-          });
+        expect(logits.tolist()).toBeCloseToNested(target, 5);
       },
       MAX_TEST_EXECUTION_TIME,
     );

@@ -120,14 +115,9 @@
         const target = [
           [0.00043986947275698185, -0.030218850821256638],
           [0.0003853091038763523, -0.03022204339504242],
-        ].flat();
+        ];
         expect(logits.dims).toEqual([2, 2]);
-        logits
-          .tolist()
-          .flat()
-          .forEach((item, i) => {
-            expect(item).toBeCloseTo(target[i], 5);
-          });
+        expect(logits.tolist()).toBeCloseToNested(target, 5);
       },
       MAX_TEST_EXECUTION_TIME,
     );
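
`toBeCloseToNested` is a custom Jest matcher from the project's test helpers; the cleanup replaces the manual flatten-and-loop with a single nested comparison that also preserves the target's shape. A minimal sketch of how such a matcher can be implemented (the repository's actual helper may differ):

expect.extend({
  toBeCloseToNested(received, expected, precision = 2) {
    // Recursively compare nested arrays of numbers, element by element,
    // using the same tolerance rule as Jest's built-in toBeCloseTo.
    const close = (a, b) =>
      Array.isArray(b)
        ? Array.isArray(a) && a.length === b.length && b.every((v, i) => close(a[i], v))
        : Math.abs(a - b) < 10 ** -precision / 2;
    const pass = close(received, expected);
    return {
      pass,
      message: () => `expected nested values to ${pass ? "not " : ""}be close to the target (precision ${precision})`,
    };
  },
});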
tests/models/modernbert/test_modeling_modernbert.js (new file, +180)
@@ -0,0 +1,180 @@
import { PreTrainedTokenizer, ModernBertModel, ModernBertForMaskedLM, ModernBertForSequenceClassification, ModernBertForTokenClassification } from "../../../src/transformers.js";

import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../../init.js";

export default () => {
  describe("ModernBertModel", () => {
    const model_id = "hf-internal-testing/tiny-random-ModernBertModel";

    /** @type {ModernBertModel} */
    let model;
    /** @type {PreTrainedTokenizer} */
    let tokenizer;
    beforeAll(async () => {
      model = await ModernBertModel.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
      tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
    }, MAX_MODEL_LOAD_TIME);

    it(
      "batch_size=1",
      async () => {
        const inputs = tokenizer("hello");
        const { last_hidden_state } = await model(inputs);
        expect(last_hidden_state.dims).toEqual([1, 3, 32]);
        expect(last_hidden_state.mean().item()).toBeCloseTo(-0.08922556787729263, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "batch_size>1",
      async () => {
        const inputs = tokenizer(["hello", "hello world"], { padding: true });
        const { last_hidden_state } = await model(inputs);
        expect(last_hidden_state.dims).toEqual([2, 4, 32]);
        expect(last_hidden_state.mean().item()).toBeCloseTo(0.048988230526447296, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "sequence_length > local_attention_window",
      async () => {
        const text = "The sun cast long shadows across the weathered cobblestones as Thomas made his way through the ancient city. The evening air carried whispers of autumn, rustling through the golden leaves that danced and swirled around his feet. His thoughts wandered to the events that had brought him here, to this moment, in this forgotten corner of the world. The old buildings loomed above him, their facades telling stories of centuries past. Windows reflected the dying light of day, creating a kaleidoscope of amber and rose that painted the narrow streets. The distant sound of church bells echoed through the maze of alleyways, marking time's steady march forward. In his pocket, he fingered the small brass key that had belonged to his grandfather. Its weight seemed to grow heavier with each step, a tangible reminder of the promise he had made. The mystery of its purpose had consumed his thoughts for weeks, leading him through archives and dusty libraries, through conversations with local historians and elderly residents who remembered the old days. As the evening deepened into dusk, streetlamps flickered to life one by one, creating pools of warm light that guided his way. The smell of wood smoke and distant cooking fires drifted through the air, reminding him of childhood evenings spent by the hearth, listening to his grandfather's tales of hidden treasures and secret passages. His footsteps echoed against the stone walls, a rhythmic accompaniment to his journey. Each step brought him closer to his destination, though uncertainty still clouded his mind about what he might find. The old map in his other pocket, creased and worn from constant consultation, had led him this far. The street ahead narrowed, and the buildings seemed to lean in closer, their upper stories nearly touching above his head. The air grew cooler in this shadowed passage, and his breath formed small clouds in front of him. Something about this place felt different, charged with possibility and ancient secrets. He walked down the [MASK]";
        const inputs = tokenizer(text);
        const { last_hidden_state } = await model(inputs);
        expect(last_hidden_state.dims).toEqual([1, 397, 32]);
        expect(last_hidden_state.mean().item()).toBeCloseTo(-0.06889555603265762, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    afterAll(async () => {
      await model?.dispose();
    }, MAX_MODEL_DISPOSE_TIME);
  });

  describe("ModernBertForMaskedLM", () => {
    const model_id = "hf-internal-testing/tiny-random-ModernBertForMaskedLM";

    const texts = ["The goal of life is [MASK].", "Paris is the [MASK] of France."];

    /** @type {ModernBertForMaskedLM} */
    let model;
    /** @type {PreTrainedTokenizer} */
    let tokenizer;
    beforeAll(async () => {
      model = await ModernBertForMaskedLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
      tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
    }, MAX_MODEL_LOAD_TIME);

    it(
      "batch_size=1",
      async () => {
        const inputs = tokenizer(texts[0]);
        const { logits } = await model(inputs);
        expect(logits.dims).toEqual([1, 9, 50368]);
        expect(logits.mean().item()).toBeCloseTo(0.0053214821964502335, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "batch_size>1",
      async () => {
        const inputs = tokenizer(texts, { padding: true });
        const { logits } = await model(inputs);
        expect(logits.dims).toEqual([2, 9, 50368]);
        expect(logits.mean().item()).toBeCloseTo(0.009154772385954857, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    afterAll(async () => {
      await model?.dispose();
    }, MAX_MODEL_DISPOSE_TIME);
  });

  describe("ModernBertForSequenceClassification", () => {
    const model_id = "hf-internal-testing/tiny-random-ModernBertForSequenceClassification";

    /** @type {ModernBertForSequenceClassification} */
    let model;
    /** @type {PreTrainedTokenizer} */
    let tokenizer;
    beforeAll(async () => {
      model = await ModernBertForSequenceClassification.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
      tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
    }, MAX_MODEL_LOAD_TIME);

    it(
      "batch_size=1",
      async () => {
        const inputs = tokenizer("hello");
        const { logits } = await model(inputs);
        const target = [[-0.7050137519836426, 2.343430519104004]];
        expect(logits.dims).toEqual([1, 2]);
        expect(logits.tolist()).toBeCloseToNested(target, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "batch_size>1",
      async () => {
        const inputs = tokenizer(["hello", "hello world"], { padding: true });
        const { logits } = await model(inputs);
        const target = [
          [-0.7050137519836426, 2.343430519104004],
          [-2.6860175132751465, 3.993380546569824],
        ];
        expect(logits.dims).toEqual([2, 2]);
        expect(logits.tolist()).toBeCloseToNested(target, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    afterAll(async () => {
      await model?.dispose();
    }, MAX_MODEL_DISPOSE_TIME);
  });

  describe("ModernBertForTokenClassification", () => {
    const model_id = "hf-internal-testing/tiny-random-ModernBertForTokenClassification";

    /** @type {ModernBertForTokenClassification} */
    let model;
    /** @type {PreTrainedTokenizer} */
    let tokenizer;
    beforeAll(async () => {
      model = await ModernBertForTokenClassification.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
      tokenizer = await PreTrainedTokenizer.from_pretrained(model_id);
    }, MAX_MODEL_LOAD_TIME);

    it(
      "batch_size=1",
      async () => {
        const inputs = tokenizer("hello");
        const { logits } = await model(inputs);
        expect(logits.dims).toEqual([1, 3, 2]);
        expect(logits.mean().item()).toBeCloseTo(1.0337047576904297, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    it(
      "batch_size>1",
      async () => {
        const inputs = tokenizer(["hello", "hello world"], { padding: true });
        const { logits } = await model(inputs);
        expect(logits.dims).toEqual([2, 4, 2]);
        expect(logits.mean().item()).toBeCloseTo(-1.3397092819213867, 5);
      },
      MAX_TEST_EXECUTION_TIME,
    );

    afterAll(async () => {
      await model?.dispose();
    }, MAX_MODEL_DISPOSE_TIME);
  });
};
