@@ -75,7 +75,12 @@ This section details the necessary modifications to make to a Transformers compa
 To make your model compatible with the Transformers backend, it needs:
 
 1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
-1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
+    - If your model is encoder-only:
+        1. Add `is_causal = False` to `MyAttention`.
+    - If your model is mixture-of-experts (MoE):
+        1. Your sparse MoE block must have an attribute called `experts`.
+        2. The class of `experts` (`MyExperts`) must inherit from `nn.ModuleList`.
+        3. `MyExperts.forward` must accept `hidden_states`, `top_k_index`, `top_k_weights`.
 2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
 3. `MyModel` must contain `_supports_attention_backend = True`.
 
@@ -102,6 +107,23 @@ class MyAttention(nn.Module):
         )
         ...
 
+# Only do this for mixture-of-experts models
+class MyExperts(nn.ModuleList):
+    def forward(self, hidden_states, top_k_index, top_k_weights):
+        ...
+
+# Only do this for mixture-of-experts models
+class MySparseMoEBlock(nn.Module):
+    def __init__(self, config):
+        ...
+        self.experts = MyExperts(config)
+        ...
+
+    def forward(self, hidden_states: torch.Tensor):
+        ...
+        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
+        ...
+
 class MyModel(PreTrainedModel):
     _supports_attention_backend = True
 ```
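
For reference, here is a minimal sketch of what the elided `MyAttention` body might look like. Only `ALL_ATTENTION_FUNCTIONS`, the attention-interface call signature, and the `is_causal` flag come from the Transformers attention API; the projection layers and config fields (`hidden_size`, `num_attention_heads`) are assumptions made for illustration, and the lookup assumes `config._attn_implementation` names a registered backend.

```python
from typing import Callable, Optional

import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class MyAttention(nn.Module):
    # Encoder-only models must set this to False so the attention
    # backends do not apply a causal mask.
    is_causal = True

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # kwargs arrive here untouched from MyModel
    ) -> torch.Tensor:
        batch, seq_len, _ = hidden_states.shape
        hidden_shape = (batch, seq_len, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Look up the configured implementation (sdpa, flash_attention_2,
        # or a backend registered at runtime) instead of hard-coding one.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[
            self.config._attn_implementation
        ]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch, seq_len, -1)
        return self.o_proj(attn_output)
```

Because the implementation is looked up at call time, an inference engine can register its own attention function and have the model pick it up without any changes to the modeling code.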
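Similarly, a minimal sketch of the MoE contract described above, assuming a Mixtral-style top-k router. `MyMLP`, `num_experts`, and `num_experts_per_tok` are hypothetical names used only for illustration; the actual requirements are only the `experts` attribute, the `nn.ModuleList` parent class, and the `(hidden_states, top_k_index, top_k_weights)` signature.

```python
import torch
from torch import nn
from torch.nn import functional as F


class MyMLP(nn.Module):
    # Hypothetical per-expert MLP, only here to make the sketch runnable.
    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.hidden_size, config.intermediate_size)
        self.down = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x):
        return self.down(F.silu(self.up(x)))


class MyExperts(nn.ModuleList):  # must inherit from nn.ModuleList
    def __init__(self, config):
        super().__init__(MyMLP(config) for _ in range(config.num_experts))

    def forward(self, hidden_states, top_k_index, top_k_weights):
        # hidden_states: (num_tokens, hidden_size)
        # top_k_index / top_k_weights: (num_tokens, top_k)
        out = torch.zeros_like(hidden_states)
        for expert_id, expert in enumerate(self):
            token_idx, k_idx = torch.where(top_k_index == expert_id)
            if token_idx.numel() == 0:
                continue
            expert_out = expert(hidden_states[token_idx])
            out.index_add_(
                0, token_idx, expert_out * top_k_weights[token_idx, k_idx, None]
            )
        return out


class MySparseMoEBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = MyExperts(config)  # the attribute must be named `experts`
        self.top_k = config.num_experts_per_tok

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, hidden = hidden_states.shape
        flat = hidden_states.view(-1, hidden)
        # Route each token to its top-k experts and renormalize the weights.
        router_logits = self.gate(flat)
        top_k_weights, top_k_index = torch.topk(
            router_logits.softmax(dim=-1), self.top_k, dim=-1
        )
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        out = self.experts(flat, top_k_index, top_k_weights)
        return out.view(batch, seq_len, hidden)
```

The routing itself can be implemented however the model prefers; what matters is that the dispatch to the experts goes through `self.experts(hidden_states, top_k_index, top_k_weights)`, so that the expert computation can be swapped out behind that single call.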