Commit e6cc76f

Implement DeepSeek V2 (#2744)

* Add deepseek v2
* Fix
* Remove unused
* Add kv cache
* Remove from cargo.toml
* Fix dtype selection logic
* Fix unnecessary u32->f32->gather->u32
* Remove fromstr impl
* Use local scopes for some clarity
* Typo
* Repeat k_pe
* Chain calls to remove mut
* Actually, remove all muts
* Update readme
1 parent fd7f724 commit e6cc76f

File tree

4 files changed: +1367 -0 lines changed

@@ -0,0 +1,33 @@
# DeepSeek V2

DeepSeek V2 is an MoE model featuring MLA (Multi-head Latent Attention). There is a Lite (16B) and a full (236B) model.

- Context length of **32k tokens** (Lite model), **128k tokens** (full model)
- 64 routed experts (Lite model), 160 routed experts (full model)

## Running the example

```bash
$ cargo run --example deepseekv2 --release --features metal -- --prompt "Recursive fibonacci code in Rust:" --which lite --sample-len 150

fn fibonacci(n: u32) -> u32 {
    if n <= 1 {
        return n;
    } else {
        return fibonacci(n - 1) + fibonacci(n - 2);
    }
}

## Fibonacci code in Python:

def fibonacci(n):
    if n <= 1:
        return n
    else:
        return fibonacci(n-1) + fibonacci(n-2)

## Fibonacci code in JavaScript:

function fibonacci(n) {
    if (n <= 1
```
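The command above uses the `metal` feature, which assumes an Apple Silicon machine. On other hardware the same example should work with whichever backend candle was built for; a sketch, assuming the standard `cuda` feature flag used by the other candle examples (the `--cpu` flag is defined by this example's own arguments):

```bash
# NVIDIA GPU, assuming candle was built with the cuda feature
$ cargo run --example deepseekv2 --release --features cuda -- --prompt "Recursive fibonacci code in Rust:" --which lite --sample-len 150

# CPU only (slow, realistic only for the lite model)
$ cargo run --example deepseekv2 --release -- --cpu --prompt "Recursive fibonacci code in Rust:" --which lite --sample-len 150
```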
+282
@@ -0,0 +1,282 @@
```rust
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::deepseek2::{DeepSeekV2, DeepSeekV2Config};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

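/// Everything needed to stream tokens from the model: the weights, the
/// tokenizer (wrapped for incremental decoding), and the sampling state.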
struct TextGeneration {
    model: DeepSeekV2,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: DeepSeekV2,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        top_k: Option<usize>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
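        // Pick the sampling strategy once, up front: greedy argmax when the
        // temperature is zero, otherwise top-k and/or top-p as requested.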
        let logits_processor = {
            let temperature = temp.unwrap_or(0.);
            let sampling = if temperature <= 0. {
                Sampling::ArgMax
            } else {
                match (top_k, top_p) {
                    (None, None) => Sampling::All { temperature },
                    (Some(k), None) => Sampling::TopK { k, temperature },
                    (None, Some(p)) => Sampling::TopP { p, temperature },
                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
                }
            };
            LogitsProcessor::from_sampling(seed, sampling)
        };

        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

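    /// Generates up to `sample_len` tokens from `prompt`, streaming them to
    /// stdout as they are decoded.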
    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

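        // DeepSeek models mark the end of generation with this dedicated
        // token (note the `▁` characters: it is not a plain ASCII string).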
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|end▁of▁sentence|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|end▁of▁sentence|> token"),
        };
        let start_gen = std::time::Instant::now();
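        // The first step feeds the whole prompt; every later step feeds only
        // the freshly sampled token and relies on the model's KV cache for
        // the rest of the context.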
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
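            // Optionally damp the logits of recently generated tokens to
            // discourage loops; a penalty of 1.0 disables this entirely.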
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

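/// The weight variants that can be fetched from the Hugging Face Hub.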
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "lite")]
    Lite,
    #[value(name = "lite-chat")]
    LiteChat,
    #[value(name = "coder-lite-chat")]
    CoderLiteChat,
    #[value(name = "v2")]
    V2,
    #[value(name = "v2-chat")]
    V2Chat,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    /// The model size to use.
    #[arg(long, default_value = "lite")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
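    // An explicit --model-id wins; otherwise map the --which variant to the
    // matching deepseek-ai repository on the Hub.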
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => match args.which {
            Which::CoderLiteChat => "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct".to_string(),
            Which::LiteChat => "deepseek-ai/DeepSeek-V2-Lite-Chat".to_string(),
            Which::Lite => "deepseek-ai/DeepSeek-V2-Lite".to_string(),
            Which::V2 => "deepseek-ai/DeepSeek-V2".to_string(),
            Which::V2Chat => "deepseek-ai/DeepSeek-V2-Chat".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config: DeepSeekV2Config = {
        let config_file = repo.get("config.json")?;
        serde_json::from_slice(&std::fs::read(config_file)?)?
    };
    let device = candle_examples::device(args.cpu)?;
    let (model, device) = {
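        // Use bf16 on accelerator backends and fall back to f16 on CPU.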
        let dtype = if device.is_cpu() {
            DType::F16
        } else {
            DType::BF16
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = DeepSeekV2::new(&config, vb)?;
        (model, device)
    };

    println!("loaded the model in {:?}", start.elapsed());

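    // Wire everything into the generation pipeline and stream the completion.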
    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.top_k,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
```

0 commit comments
