Skip to content

Commit 07377a6

Browse files
update whisper sample (#629)
1 parent a87891e commit 07377a6

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

PyTorch/audio/whisper/whisper/model.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,9 @@ def forward(
177177

178178
class AudioEncoder(nn.Module):
179179
def __init__(
180-
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, use_dml_attn: bool = False,
180+
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
181181
):
182182
super().__init__()
183-
self.use_dml_attn = use_dml_attn
184183
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
185184
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
186185
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
@@ -203,7 +202,7 @@ def forward(self, x: Tensor):
203202
x = (x + self.positional_embedding).to(x.dtype)
204203

205204
for block in self.blocks:
206-
x = block(x, use_dml_attn=self.use_dml_attn)
205+
x = block(x)
207206

208207
x = self.ln_post(x)
209208
return x
@@ -270,7 +269,6 @@ def __init__(self, dims: ModelDimensions, use_dml_attn=False):
270269
self.dims.n_audio_state,
271270
self.dims.n_audio_head,
272271
self.dims.n_audio_layer,
273-
use_dml_attn
274272
)
275273
self.decoder = TextDecoder(
276274
self.dims.n_vocab,

0 commit comments

Comments
 (0)