llama : add support for MiniMax-Text-01 model
sszymczy committed Jan 21, 2025
1 parent e28245f commit 5373298
Showing 14 changed files with 593 additions and 1 deletion.
67 changes: 67 additions & 0 deletions convert_hf_to_gguf.py
@@ -699,6 +699,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5":
# ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
res = "deepseek-r1-qwen"
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
# ref: https://huggingface.co/MiniMaxAI/MiniMax-Text-01
res = "minimax-01"

if res is None:
logger.warning("\n")
@@ -4906,6 +4909,70 @@ def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
return data_torch


@Model.register("MiniMaxText01ForCausalLM")
class MiniMaxText01Model(Model):
model_arch = gguf.MODEL_ARCH.MINIMAX01

def set_gguf_parameters(self):
super().set_gguf_parameters()

layernorm_full_attention_alpha = self.hparams["layernorm_full_attention_alpha"]
layernorm_full_attention_beta = self.hparams["layernorm_full_attention_beta"]
layernorm_linear_attention_alpha = self.hparams["layernorm_linear_attention_alpha"]
layernorm_linear_attention_beta = self.hparams["layernorm_linear_attention_beta"]
layernorm_mlp_alpha = self.hparams["layernorm_mlp_alpha"]
layernorm_mlp_beta = self.hparams["layernorm_mlp_beta"]
assert layernorm_full_attention_alpha == layernorm_linear_attention_alpha == layernorm_mlp_alpha
assert layernorm_full_attention_beta == layernorm_linear_attention_beta == layernorm_mlp_beta == 1.0
# we do not store the layernorm betas as they are all 1.0
# layernorm alphas are stored as a single residual_scale hparam
self.gguf_writer.add_residual_scale(layernorm_full_attention_alpha)

self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")

# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams["num_local_experts"]

assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for wid in ["w1", "w2", "w3"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]


###### CONVERSION LOGIC ######


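The modify_tensors() logic above buffers the per-expert w1/w2/w3 weights of each layer and, once all n_experts * 3 tensors have arrived, stacks them into one 3D tensor per projection. A minimal standalone sketch of that stacking step, using hypothetical shapes rather than the real MiniMax-Text-01 dimensions:

import torch

# hypothetical sizes, for illustration only
n_experts, n_embd, n_ff, bid = 4, 8, 16, 0

# simulate the per-expert tensors as they appear in the HF checkpoint
experts = {}
for xid in range(n_experts):
    for wid in ("w1", "w2", "w3"):
        shape = (n_embd, n_ff) if wid == "w2" else (n_ff, n_embd)
        experts[f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"] = torch.randn(shape)

# once the layer is complete, merge each projection into a single 3D tensor
for wid in ("w1", "w2", "w3"):
    datas = [experts[f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"]
             for xid in range(n_experts)]
    merged = torch.stack(datas, dim=0)  # experts along dim 0
    print(f"layers.{bid}.feed_forward.experts.{wid}.weight", tuple(merged.shape))
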
22 changes: 22 additions & 0 deletions gguf-py/gguf/constants.py
@@ -279,6 +279,7 @@ class MODEL_ARCH(IntEnum):
GRANITE_MOE = auto()
CHAMELEON = auto()
WAVTOKENIZER_DEC = auto()
MINIMAX01 = auto()


class MODEL_TENSOR(IntEnum):
@@ -301,6 +302,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_OUT_NORM = auto()
ATTN_POST_NORM = auto()
ATTN_ROT_EMBD = auto()
ATTN_GATE = auto()
FFN_GATE_INP = auto()
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
@@ -466,6 +468,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.MINIMAX01: "minimax01",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -490,6 +493,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
@@ -1535,6 +1539,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.POSNET_ATTN_OUT,
],
# TODO
MODEL_ARCH.MINIMAX01: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
}

# tensors that will not be serialized
6 changes: 6 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -130,6 +130,7 @@ class TensorNameMap:
"transformer.h.{bid}.ln_attn", # falcon40b
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
"rwkv.blocks.{bid}.ln2", # rwkv
"model.layers.{bid}.self_attn.norm", # minimax_text-01
),

# Attention query-key-value
@@ -214,6 +215,7 @@ class TensorNameMap:
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
"model.layers.{bid}.self_attn.out_proj", # minimax_text-01
),

# Attention output norm
@@ -236,6 +238,10 @@ class TensorNameMap:
"transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell
),

MODEL_TENSOR.ATTN_GATE: (
"model.layers.{bid}.self_attn.output_gate", # minimax-text-01
),

# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
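The three additions above register MiniMax-Text-01's checkpoint names with the generic HF-to-GGUF tensor name mapping, so that e.g. model.layers.3.self_attn.output_gate resolves to blk.3.attn_gate. A toy version of the template expansion this mapping performs (illustrative only; the real TensorNameMap in gguf-py is more general):

# illustrative only -- not the real gguf-py implementation
hf_to_gguf = {
    "model.layers.{bid}.self_attn.out_proj":    "blk.{bid}.attn_output",
    "model.layers.{bid}.self_attn.output_gate": "blk.{bid}.attn_gate",
}

def map_name(name: str, n_blocks: int) -> str | None:
    for bid in range(n_blocks):
        for src, dst in hf_to_gguf.items():
            if name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

print(map_name("model.layers.3.self_attn.output_gate", n_blocks=8))  # -> blk.3.attn_gate
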
1 change: 1 addition & 0 deletions include/llama.h
@@ -105,6 +105,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_MINIMAX = 29,
};

enum llama_rope_type {
23 changes: 23 additions & 0 deletions src/llama-arch.cpp
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_MINIMAX01, "minimax01" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -1292,6 +1293,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
{
LLM_ARCH_MINIMAX01,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -1319,6 +1341,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -66,6 +66,7 @@ enum llm_arch {
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_MINIMAX01,
LLM_ARCH_UNKNOWN,
};

@@ -214,6 +215,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_QKV,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_OUT_NORM,
67 changes: 67 additions & 0 deletions src/llama-context.cpp
@@ -463,6 +463,73 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
}
}
}

if (lctx.inp_slopes) {
const int64_t n_head = hparams.n_head();

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_slopes->buffer));

float * data = (float *) lctx.inp_slopes->data;

float start = powf(2, -powf(2, -(log2f(n_head) - 3)));
float ratio = start;

for (int h = 0; h < n_head; ++h) {
data[h] = start * powf(ratio, h);
}
}

if (lctx.inp_q_decay) {
const int64_t n_head = hparams.n_head();
const int64_t n_tokens = ubatch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_q_decay->buffer));

float * slopes = (float *) lctx.inp_slopes->data;
float * data = (float *) lctx.inp_q_decay->data;

for (int i = 0; i < n_tokens; ++i) {
for (int h = 0; h < n_head; ++h) {
data[i * n_head + h] = -slopes[h] * (i + 1);
}
}
}

if (lctx.inp_k_decay) {
const int64_t n_head = hparams.n_head();
const int64_t n_tokens = ubatch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_k_decay->buffer));

float * slopes = (float *) lctx.inp_slopes->data;
float * data = (float *) lctx.inp_k_decay->data;

for (int i = 0; i < n_tokens; ++i) {
for (int h = 0; h < n_head; ++h) {
data[i * n_head + h] = -slopes[h] * (n_tokens - i - 1);
}
}
}

if (lctx.inp_diag_decay) {
const int64_t n_head = hparams.n_head();
const int64_t n_tokens = ubatch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_diag_decay->buffer));

float * slopes = (float *) lctx.inp_slopes->data;
float * data = (float *) lctx.inp_diag_decay->data;

for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_tokens; ++i) {
int index = j - i;
for (int h = 0; h < n_head; ++h) {
float s_index = index >= 0 ? -slopes[h] * index : -INFINITY;
data[j * n_head * n_tokens + i * n_head + h] = s_index;
}
}
}
}
}

// llama output
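The four new inputs carry the decay terms used by MiniMax-01's linear-attention layers: inp_slopes holds per-head ALiBi-style slopes derived from the head count, while inp_q_decay, inp_k_decay and inp_diag_decay hold the position-dependent log-decay factors filled in by the loops above. A small Python sketch of the same arithmetic, for illustration only:

import math

def decay_inputs(n_head: int, n_tokens: int):
    # ALiBi-style slopes: start = 2**(-2**(-(log2(n_head) - 3))), ratio == start
    start = 2.0 ** (-(2.0 ** (-(math.log2(n_head) - 3))))
    slopes = [start * start ** h for h in range(n_head)]

    # per-token decay applied on the query and key side
    q_decay = [[-slopes[h] * (i + 1) for h in range(n_head)] for i in range(n_tokens)]
    k_decay = [[-slopes[h] * (n_tokens - i - 1) for h in range(n_head)] for i in range(n_tokens)]

    # causal, distance-dependent decay for the within-batch (diagonal) part
    diag_decay = [[[(-slopes[h] * (j - i)) if j >= i else -math.inf
                    for h in range(n_head)]
                   for i in range(n_tokens)]
                  for j in range(n_tokens)]
    return slopes, q_decay, k_decay, diag_decay

slopes, q_decay, k_decay, diag_decay = decay_inputs(n_head=8, n_tokens=4)
print(slopes[0], q_decay[0][0], k_decay[0][0], diag_decay[1][0][0])
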
4 changes: 4 additions & 0 deletions src/llama-context.h
@@ -107,6 +107,10 @@ struct llama_context {
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
struct ggml_tensor * inp_slopes; // F32 [n_head]
struct ggml_tensor * inp_q_decay; // F32 [n_batch, n_head]
struct ggml_tensor * inp_k_decay; // F32 [n_batch, n_head]
struct ggml_tensor * inp_diag_decay; // F32 [n_batch, n_batch, n_head]
};

// TODO: make these methods of llama_context
8 changes: 7 additions & 1 deletion src/llama-kv-cache.cpp
@@ -27,6 +27,8 @@ bool llama_kv_cache_init(
const struct llama_hparams & hparams = model.hparams;

const int32_t n_layer = hparams.n_layer;
const int n_head = hparams.n_head();
const int n_embd_head_k = hparams.n_embd_head_k;

cache.has_shift = false;

@@ -53,7 +55,7 @@
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
struct ggml_init_params params = {
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead() + 4u*n_head*n_embd_head_k*n_embd_head_k),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@@ -70,6 +72,7 @@

cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);
cache.kv_l.reserve(n_layer);

for (int i = 0; i < n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
@@ -93,10 +96,13 @@

ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_tensor * kv = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head_k, n_embd_head_k, n_head);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
cache.kv_l.push_back(kv);
}

// allocate tensors and initialize the buffers to avoid NaNs in the padding
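Besides the usual per-layer K and V caches, each layer now also gets a fixed-size F32 state tensor kv of shape [n_embd_head_k, n_embd_head_k, n_head], i.e. one head_dim x head_dim matrix per head, whose size is independent of the context length. A rough sizing sketch with hypothetical dimensions (not the actual MiniMax-Text-01 configuration):

# hypothetical dimensions, for illustration only
n_layer, n_head, n_embd_head_k, bytes_f32 = 4, 8, 64, 4

per_layer = n_head * n_embd_head_k * n_embd_head_k * bytes_f32
total = n_layer * per_layer
print(f"per-layer kv state: {per_layer / 1024:.0f} KiB, total: {total / 1024:.0f} KiB")
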
2 changes: 2 additions & 0 deletions src/llama-kv-cache.h
@@ -54,6 +54,8 @@ struct llama_kv_cache {
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;

std::vector<struct ggml_tensor *> kv_l;

std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
