13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
@@ -42,7 +42,7 @@ serde_json = "1.0"
thiserror = "1.0"
rand = "0.9"
serial_test = "2.0.0"
-cudarc = { version = "0.13", features =["cuda-12020"], default-features = false }
+cudarc = { version = "0.13", features = ["cuda-12020"], default-features = false }
intel-mkl-src = { version = "0.8", default-features = false }
candle = { version = "0.8", package = "candle-core" }
candle-nn = { version = "0.8" }
@@ -52,10 +52,11 @@ candle-cublaslt = { version = "0.0.1" }
candle-layer-norm = { version = "0.0.1" }
candle-rotary = { version = "0.0.1" }
candle-flash-attn-v1 = { version = "0.0.1" }
+candle-moe = { git = "https://github.com/kozistr/candle-moe", rev = "990ac1f42248dd441c51c9b5bcb73c5b77c03f99" }
half = { version = "2.3.1", features = ["num-traits"] }

[patch.crates-io]
-cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9"}
+cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9" }
candle = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-core" }
candle-nn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-nn" }
candle-transformers = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-transformers" }
3 changes: 2 additions & 1 deletion backends/candle/Cargo.toml
@@ -17,6 +17,7 @@ candle-flash-attn-v1 = { workspace = true, optional = true }
candle-cublaslt = { workspace = true, optional = true }
candle-layer-norm = { workspace = true, optional = true }
candle-rotary = { workspace = true, optional = true }
+candle-moe = { workspace = true, optional = true }
nohash-hasher = { workspace = true }
text-embeddings-backend-core = { path = "../core" }
tracing = { workspace = true }
@@ -41,6 +42,6 @@ anyhow = { version = "1", features = ["backtrace"] }
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
metal = ["candle/metal", "candle-nn/metal"]
mkl = ["dep:intel-mkl-src", "candle/_mkl"]
-cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"]
+cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary", "dep:candle-moe"]
flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"]
flash-attn = ["dep:candle-flash-attn", "cuda"]
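
Since candle-moe is optional and wired only into the cuda feature, the fused path is a compile-time choice. A build along these lines should pull it in (a sketch: the package name text-embeddings-backend-candle and the bare cuda feature set are assumptions about this workspace):

cargo build -p text-embeddings-backend-candle --features cuda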
184 changes: 183 additions & 1 deletion backends/candle/src/models/nomic.rs
@@ -3,6 +3,8 @@ use crate::layers::{
};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
#[cfg(feature = "cuda")]
use candle_moe;
use candle_nn::{Embedding, VarBuilder};
use candle_transformers::models::deepseek2::{BincountOp, NonZeroOp, TopKLastDimOp, TopKOutput};
use serde::Deserialize;
@@ -239,6 +241,55 @@ impl NomicRouter {
}
}

#[cfg(feature = "cuda")]
pub struct NomicFusedRouter {
layer: Linear,
top_k: usize,

span: tracing::Span,
}

#[cfg(feature = "cuda")]
impl NomicFusedRouter {
pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
let num_experts = config.num_experts.unwrap();
let top_k = config.moe_top_k.unwrap();

let layer_weight = vb.pp("layer").get((num_experts, config.n_embd), "weight")?;
let layer = Linear::new(layer_weight, None, None);

Ok(Self {
layer,
top_k,
span: tracing::span!(tracing::Level::TRACE, "router"),
})
}

pub fn forward(&self, hidden_states: &Tensor) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();

let device = hidden_states.device();

let weights = hidden_states.reshape(((), hidden_states.dim(D::Minus1)?))?;
let weights = self.layer.forward(&weights)?.to_dtype(DType::F32)?;

let (seq_len, _) = weights.shape().dims2()?;

let topk_weight = Tensor::zeros((seq_len, self.top_k), DType::F32, device)?;
let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?;
let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?;

candle_moe::apply_topk_softmax_inplace(
&weights,
&topk_weight,
&topk_indices,
&token_expert_indices,
)?;

Ok((topk_weight, topk_indices))
}
}
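
// Reference semantics for apply_topk_softmax_inplace above — a minimal unfused
// sketch, assuming the kernel matches the existing NomicRouter path (softmax
// over the expert logits, then top-k per token; topk comes from the
// TopKLastDimOp trait imported at the top of this file):
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn topk_softmax_reference(logits: &Tensor, top_k: usize) -> Result<(Tensor, Tensor)> {
    // logits: (seq_len, num_experts), f32
    let probs = candle_nn::ops::softmax(logits, D::Minus1)?;
    // values and indices: both (seq_len, top_k), indices as u32
    let TopKOutput { values, indices } = probs.topk(top_k)?;
    Ok((values, indices))
}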

pub struct NomicExpertMLP {
w1: Tensor,
w2: Tensor,
@@ -363,6 +414,95 @@ impl NomicExperts {
}
}

#[cfg(feature = "cuda")]
pub struct NomicFusedExperts {
gate_weight: Tensor,
up_weight: Tensor,
bias: Tensor,
fused_moe: candle_moe::FusedMoeForward,

span: tracing::Span,
}

#[cfg(feature = "cuda")]
impl NomicFusedExperts {
pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
let hidden_size = config.n_embd;
let ffn_hidden_size = config.n_inner;
let num_experts = config.num_experts.unwrap();
let top_k = config.moe_top_k.unwrap();
let activation = config.activation_function.clone();

let gate_weight = vb
.pp("mlp")
.get((num_experts * ffn_hidden_size, hidden_size), "w1")?
.reshape((num_experts, ffn_hidden_size, hidden_size))?
.permute((0, 2, 1))?
.contiguous()?;
let up_weight = vb
.pp("mlp")
.get((num_experts * ffn_hidden_size, hidden_size), "w2")?
.reshape((num_experts, ffn_hidden_size, hidden_size))?
.permute((0, 2, 1))?
.contiguous()?;

let bias = vb.get((config.n_embd,), "bias")?;

let moe_act = match activation {
HiddenAct::Silu => candle_moe::Activation::Silu,
HiddenAct::Gelu => candle_moe::Activation::Gelu,
HiddenAct::Relu => candle_moe::Activation::Relu,
_ => candle::bail!("unsupported activation type"),
};

let fused_moe = candle_moe::FusedMoeForward::new(num_experts, top_k, moe_act);

Ok(Self {
gate_weight,
up_weight,
bias,
fused_moe,
span: tracing::span!(tracing::Level::TRACE, "experts"),
})
}

pub fn forward(
&self,
hidden_states: &Tensor,
top_weights: &Tensor,
top_experts: &Tensor,
) -> Result<Tensor> {
let _enter = self.span.enter();

let dims = hidden_states.dims();
let ndim = dims.len();

let (bs, seq_len, hidden_size) = match ndim {
3 => (dims[0], dims[1], dims[2]),
2 => (1, dims[0], dims[1]),
_ => unreachable!(),
};

let hidden_states = hidden_states.reshape(((), hidden_size))?;

let mut out = self.fused_moe.forward(
&hidden_states,
&self.gate_weight,
&self.up_weight,
None,
&top_weights,
&top_experts,
1_u32, // Nomic MoE
)?;

if ndim == 3 {
out = out.reshape((bs, seq_len, hidden_size))?;
}

out.broadcast_add(&self.bias)
}
}
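
// Weight-repack walkthrough for the loader above (the kernel-side layout is an
// assumption inferred from the permute; the checkpoint layout matches the
// unfused NomicExpertMLP):
//
//   "w1": (num_experts * ffn_hidden_size, hidden_size)      as stored
//     .reshape    -> (num_experts, ffn_hidden_size, hidden_size)
//     .permute    -> (num_experts, hidden_size, ffn_hidden_size)
//     .contiguous -> materializes the transpose so each expert's block is a
//                    dense, row-major (in_features, out_features) matrix
//   "w2" is repacked identically, and the shared output bias (n_embd,) is
//   broadcast-added after the fused forward.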

pub struct NomicMoELayer {
router: NomicRouter,
experts: NomicExperts,
@@ -392,8 +532,41 @@ impl NomicMoELayer {
}
}

#[cfg(feature = "cuda")]
pub struct NomicFusedMoELayer {
router: NomicFusedRouter,
experts: NomicFusedExperts,

span: tracing::Span,
}

#[cfg(feature = "cuda")]
impl NomicFusedMoELayer {
pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
let router = NomicFusedRouter::load(vb.pp("router"), config)?;
let experts = NomicFusedExperts::load(vb.pp("experts"), config)?;

Ok(Self {
router,
experts,
span: tracing::span!(tracing::Level::TRACE, "moe"),
})
}

pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let (top_weights, top_experts) = self.router.forward(hidden_states)?;

self.experts
.forward(hidden_states, &top_weights, &top_experts)
}
}

pub enum NomicMLP {
MoE(NomicMoELayer),
#[cfg(feature = "cuda")]
FusedMoE(NomicFusedMoELayer),
GatedMLP(NomicBertGatedMLP),
Mlp(NomicBertMLP),
}
@@ -403,7 +576,14 @@ impl NomicMLP {
let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1);

if use_moe {
-    Ok(Self::MoE(NomicMoELayer::load(vb, config)?))
+    #[cfg(feature = "cuda")]
+    {
+        Ok(Self::FusedMoE(NomicFusedMoELayer::load(vb, config)?))
+    }
+    #[cfg(not(feature = "cuda"))]
+    {
+        Ok(Self::MoE(NomicMoELayer::load(vb, config)?))
+    }
} else if config.activation_function == HiddenAct::Gelu {
Ok(Self::Mlp(NomicBertMLP::load(vb, config)?))
} else {
@@ -414,6 +594,8 @@ pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
match self {
Self::MoE(layer) => layer.forward(hidden_states),
#[cfg(feature = "cuda")]
Self::FusedMoE(layer) => layer.forward(hidden_states),
Self::GatedMLP(layer) => layer.forward(hidden_states),
Self::Mlp(layer) => layer.forward(hidden_states),
}
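
The dispatch in NomicMLP::load is config-driven plus compile-time: CUDA builds construct the fused layer, everything else falls back to the existing NomicMoELayer. A worked example of the use_moe predicate (the same matches! expression as above):

fn uses_moe(index: usize, moe_every_n_layers: Option<usize>) -> bool {
    matches!(moe_every_n_layers, Some(n) if n > 0 && index % n == 1)
}

// With moe_every_n_layers = Some(2), layers 1, 3, 5, ... build the MoE MLP and
// layers 0, 2, 4, ... keep the dense variants.
assert!(uses_moe(1, Some(2)));
assert!(!uses_moe(0, Some(2)));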