@@ -22,8 +22,10 @@
 from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
+from jetstream_pt import quantize, torchjax

 import jax
+import jax.numpy as jnp


 class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     else:
       return self.forward_for_short_seq_len(x, expert_indices)

+  def _int_ti_eoi_teo(self, lhs, rhs):
+    # integer version of torch.einsum("ti,eoi -> teo", x, self.w1)
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((1,), (2,)), ((), ())),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    return result
+
+  def _int_teo_eio_tei(self, lhs, rhs):
+    # integer version of torch.einsum("teo, eio -> tei", (x1 * x3), self.w2)
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((2,), (2,)), ((1,), (0,))),
+        None,
+        jnp.bfloat16.dtype,
+    )  # dot_general puts the batch dim first, so the output is (e, t, i)
+    return result.transpose(0, 1)
+
+
   def forward_for_short_seq_len(
       self, x: Tensor, expert_indices: Tensor
   ) -> Tensor:
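For reference, the `dimension_numbers` tuples in these two helpers encode the same contractions as the einsums they replace. Below is a minimal standalone sketch with illustrative shapes (t=4, e=8, o=16, i=32); it assumes a recent JAX where integer `dot_general` inputs can be combined with a bfloat16 `preferred_element_type`, and the `jnp.clip`/`jnp.round` step is only a toy stand-in for real re-quantization:

import jax
import jax.numpy as jnp

t, e, o, i = 4, 8, 16, 32
x = jnp.ones((t, i), jnp.int8)      # quantized activations
w1 = jnp.ones((e, o, i), jnp.int8)  # quantized expert weights

# "ti,eoi->teo": contract x axis 1 with w1 axis 2; no batch dims.
teo = jax.lax.dot_general(
    x, w1, (((1,), (2,)), ((), ())), None, jnp.bfloat16.dtype
)
assert teo.shape == (t, e, o)

# Toy re-quantization so both operands of the next matmul are integers.
teo_int = jnp.clip(jnp.round(teo), -127, 127).astype(jnp.int8)

w2 = jnp.ones((e, i, o), jnp.int8)
# "teo,eio->tei": contract axis 2 with axis 2; e is a batch dim
# (lhs axis 1, rhs axis 0). dot_general emits batch dims first,
# so the raw result is (e, t, i) and must be transposed to (t, e, i).
eti = jax.lax.dot_general(
    teo_int, w2, (((2,), (2,)), ((1,), (0,))), None, jnp.bfloat16.dtype
)
assert eti.shape == (e, t, i)
tei = eti.transpose(1, 0, 2)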
@@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
     # o = config.intermediate_size
     # i = config.dim
     with jax.named_scope("conditional_ff"):
-      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-      x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
+      x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+      x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
+      x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+      x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
+      x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+      x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
       expert_outs = (
-          torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+          self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
       )
       # e = 8; need to reduce to 2
       seq_indexes = torch.arange(seqlen).unsqueeze(1)
-      return expert_outs[seq_indexes, expert_indices]
+      return expert_outs[seq_indexes, expert_indices] * x1x3_scaler


 class ConditionalFeedForward(nn.Module):
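The long-sequence path now quantizes the activations on the fly, runs the matmuls on integer tensors, and folds each per-token scale back in afterwards; that is why the gathered expert outputs are multiplied by `x1x3_scaler` at the end. The arithmetic behind this pattern is sketched below with a toy symmetric int8 quantizer standing in for `quantize.quantize_tensor` (whose actual signature and quantization scheme in jetstream_pt may differ):

import torch

def toy_quantize(x, reduce_dims):
  # One scale per remaining index, symmetric int8 range.
  amax = x.abs().amax(dim=reduce_dims, keepdim=True).clamp(min=1e-8)
  scale = amax / 127.0
  x_int = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
  return x_int, scale

t, i, o = 4, 32, 16
x = torch.randn(t, i)
w = torch.randn(i, o)

x_int, x_scale = toy_quantize(x, (1,))  # per-token scale, shape (t, 1)
w_int, w_scale = toy_quantize(w, (0,))  # per-column scale, shape (1, o)

# Integer matmul, then multiply both scales back in.
y_approx = (x_int.float() @ w_int.float()) * x_scale * w_scale
y_exact = x @ w
print((y_approx - y_exact).abs().max())  # small quantization error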