Skip to content

Commit f02f227

Browse files
authored
Add gpt-neo-x support (#745)
* add gpt-neo-x configs
* fixes
* name fix
1 parent cd21dca commit f02f227

File tree

6 files changed

+18
-0
lines changed

6 files changed

+18
-0
lines changed

docs/source/exporters/onnx/package_reference/configuration.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ They specify which input generators should be used for the dummy inputs, but rem
9292
- GPT-2
9393
- GPT-J
9494
- GPT-Neo
95+
- GPT-NeoX
9596
- GroupVit
9697
- Hubert
9798
- IBert

optimum/exporters/onnx/model_configs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ class GPTNeoOnnxConfig(TextDecoderOnnxConfig):
191191
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
192192

193193

194+
class GPTNeoXOnnxConfig(TextDecoderOnnxConfig):
195+
DEFAULT_ONNX_OPSET = 13
196+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
197+
198+
194199
class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
195200
def generate(self, input_name: str, framework: str = "pt"):
196201
past_key_shape = (

optimum/exporters/tasks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,13 @@ class TasksManager:
383383
"sequence-classification",
384384
onnx="GPTNeoOnnxConfig",
385385
),
386+
"gpt-neox": supported_tasks_mapping(
387+
"default",
388+
"default-with-past",
389+
"causal-lm",
390+
"causal-lm-with-past",
391+
onnx="GPTNeoXOnnxConfig",
392+
),
386393
"groupvit": supported_tasks_mapping(
387394
"default",
388395
onnx="GroupViTOnnxConfig",

optimum/utils/normalized_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ class NormalizedConfigManager:
193193
"electra": NormalizedTextConfig,
194194
"gpt2": GPT2LikeNormalizedTextConfig,
195195
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
196+
"gpt_neox": NormalizedTextConfig,
196197
"gptj": GPT2LikeNormalizedTextConfig,
197198
"longt5": T5LikeNormalizedTextConfig,
198199
"marian": BartLikeNormalizedTextConfig,

tests/exporters/exporters_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"flaubert": "hf-internal-testing/tiny-random-flaubert",
5454
"gpt2": "hf-internal-testing/tiny-random-gpt2",
5555
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
56+
"gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
5657
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
5758
"groupvit": "hf-internal-testing/tiny-random-groupvit",
5859
"ibert": "hf-internal-testing/tiny-random-IBertModel",
@@ -151,6 +152,7 @@
151152
"flaubert": "hf-internal-testing/tiny-random-flaubert", # TODO
152153
"gpt2": "gpt2",
153154
"gpt-neo": "EleutherAI/gpt-neo-125M",
155+
"gpt-neox": "EleutherAI/gpt-neox-20b",
154156
"gptj": "anton-l/gpt-j-tiny-random", # TODO
155157
"groupvit": "nvidia/groupvit-gcc-yfcc",
156158
"ibert": "kssteven/ibert-roberta-base",

tests/onnxruntime/test_modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
"flaubert": "hf-internal-testing/tiny-random-flaubert",
112112
"gpt2": "hf-internal-testing/tiny-random-gpt2",
113113
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
114+
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
114115
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
115116
"groupvit": "hf-internal-testing/tiny-random-groupvit",
116117
"ibert": "hf-internal-testing/tiny-random-IBertModel",
@@ -1731,6 +1732,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
17311732
"codegen",
17321733
"gpt2",
17331734
"gpt_neo",
1735+
"gpt_neox",
17341736
"gptj",
17351737
]
17361738

0 commit comments

Comments
 (0)