base/science-tech-maths/machine-learning/algorithms/neural-nets/conv-neural-nets/diffusion-models/diffusion-models.md (+1)

base/science-tech-maths/machine-learning/algorithms/neural-nets/transformers/transformers.md (+36 −1)

In GPT2, there are two matrices called WTE (word token embedding) and WPE (word position embedding).
WPE is 1024×768, which means that the maximum number of tokens we can use in a prompt to GPT2 is 1024.

More information about the reasoning behind the positional encoding: <https://fleetwood.dev/posts/you-could-have-designed-SOTA-positional-encoding>
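
As a quick sanity check, here is a minimal sketch (assuming the Hugging Face `transformers` package with its PyTorch backend is installed) that loads GPT2 and prints the shapes of these two embedding matrices:

```python
# Minimal sketch: inspect GPT-2's token (WTE) and position (WPE) embedding matrices.
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")

print(model.wte.weight.shape)  # torch.Size([50257, 768]) -> vocab size x embedding dim
print(model.wpe.weight.shape)  # torch.Size([1024, 768])  -> max context length x embedding dim
```
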
### Transformer decoder

<img src="transformer-decoder.png" width="200">

## KV cache

Imagine you're writing a story, and for each new word you write, you need to re-read the entire story so far to maintain consistency. The longer your story gets, the more time you spend re-reading.

The key insight behind KV caching is that we're doing a lot of redundant work: when generating each new token, we recompute the keys and values of all the previous tokens that we've already processed.

For each token, we compute and store two things:

- A key ($k$): think of this as an addressing mechanism; it helps determine how relevant this token is to future tokens.
- A value ($v$): think of this as the actual information that gets used when this token is found to be relevant.

The KV cache is a cache of these key-value pairs, kept in every self-attention layer for all the tokens processed so far. It is used to speed up the inference process.

With the cache, generating a sequence of n tokens costs O(n²) attention work in total, a dramatic improvement over the O(n³) of recomputing everything from scratch at every step! We still have to do the fundamental work of looking at all previous tokens, but we avoid the costly recomputation at each step.
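
Here is a toy NumPy sketch of the mechanism (a single attention head with random weights standing in for a trained model; `d_model`, `W_q`, `W_k`, `W_v` and the decode loop are made up for illustration). At each step we only project the *new* token and append its key and value to the cache, then attend over everything cached:

```python
# Toy single-head decode loop with a KV cache (illustrative only, not a real model).
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

def attend(q, K, V):
    """Attention of one query against all cached keys/values."""
    scores = K @ q / np.sqrt(d_model)    # one score per past token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()             # softmax over past tokens
    return weights @ V                   # weighted sum of cached values

K_cache, V_cache = [], []                # the KV cache: one (k, v) per processed token
for step in range(5):                    # pretend we decode 5 tokens
    x = rng.standard_normal(d_model)     # embedding of the current token (stand-in)
    q, k, v = W_q @ x, W_k @ x, W_v @ x  # only the new token gets projected
    K_cache.append(k)                    # old keys/values are reused, never recomputed
    V_cache.append(v)
    out = attend(q, np.stack(K_cache), np.stack(V_cache))

print(len(K_cache), out.shape)           # 5 cached keys, (16,) output for the last token
```

Without the cache, every iteration would recompute `W_k @ x` and `W_v @ x` for all previous tokens as well, which is where the extra factor of n comes from.
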
Let's look at the memory cost of KV caching with a concrete example.

For a modern large language model like Llama3 70B with:

- $L=80$ layers
- $H=64$ attention heads
- $B=8$ batch size
- $d_k=128$ key/value dimension
- a factor of $2$ for storing both K and V
- 16-bit precision ($2$ bytes per value)

For a batch of $B=8$ sequences of $n=1000$ tokens each, the memory required would be:

$$L \times H \times B \times n \times d_k \times 2 \times 2 = 80 \times 64 \times 8 \times 1000 \times 128 \times 2 \times 2 \approx 21\ \text{GB}$$

where:

- $L \times H \times B \times n$ gives us the total number of key-value pairs
- $d_k$ is the dimension of each key/value vector
- the first $\times 2$ is for storing both keys and values
- the second $\times 2$ is for 16-bit precision (2 bytes per value)

This shows that while KV caching provides a significant speedup by avoiding redundant computation, it comes with substantial memory requirements that grow linearly with sequence length and batch size.
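
The same arithmetic as a tiny script, in case you want to plug in other model shapes (the variable names are just for this sketch):

```python
# Back-of-the-envelope KV cache size for the Llama3 70B example above.
L, H, B, n, d_k = 80, 64, 8, 1000, 128   # layers, heads, batch size, tokens, head dim
bytes_per_value = 2                      # 16-bit precision

kv_cache_bytes = L * H * B * n * d_k * 2 * bytes_per_value  # x2 for keys and values
print(f"{kv_cache_bytes / 1e9:.1f} GB")  # ~21.0 GB, growing linearly with n and B
```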