Add LLaDA 8b Diffusion model #14771

Open · wants to merge 4 commits into master
61 changes: 43 additions & 18 deletions common/arg.cpp
@@ -3431,34 +3431,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

-    // diffusion parameters
+    // shared diffusion parameters
     add_opt(common_arg(
         { "--diffusion-steps" }, "N",
-        string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
-        [](common_params & params, int value) { params.diffusion.steps = value; }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        string_format("number of diffusion steps (default: %d)", params.diffusion_dream.steps),
+        [](common_params & params, int value) {
+            params.diffusion_dream.steps = value;
+            params.diffusion_llada.steps = value;
+        }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM, LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+    add_opt(common_arg(
+        { "--diffusion-visual" },
+        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+                      params.diffusion_dream.visual_mode ? "true" : "false"),
+        [](common_params & params) {
+            params.diffusion_dream.visual_mode = true;
+            params.diffusion_llada.visual_mode = true;
+        }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM, LLAMA_EXAMPLE_DIFFUSION_LLADA }));

+    // DREAM-specific diffusion parameters
     add_opt(common_arg(
         { "--diffusion-eps" }, "F",
-        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
-        [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion_dream.eps),
+        [](common_params & params, const std::string & value) { params.diffusion_dream.eps = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM }));
     add_opt(common_arg(
         { "--diffusion-algorithm" }, "N",
         string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
-                      params.diffusion.algorithm),
-        [](common_params & params, int value) { params.diffusion.algorithm = value; }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+                      params.diffusion_dream.algorithm),
+        [](common_params & params, int value) { params.diffusion_dream.algorithm = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM }));
     add_opt(common_arg(
         { "--diffusion-alg-temp" }, "F",
-        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
-        [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion_dream.alg_temp),
+        [](common_params & params, const std::string & value) { params.diffusion_dream.alg_temp = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM }));

+    // LLADA-specific diffusion parameters
     add_opt(common_arg(
-        { "--diffusion-visual" },
-        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
-                      params.diffusion.visual_mode ? "true" : "false"),
-        [](common_params & params) { params.diffusion.visual_mode = true; }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        { "--diffusion-block-length" }, "N",
+        string_format("block length for generation (default: %d)", params.diffusion_llada.block_length),
+        [](common_params & params, int value) { params.diffusion_llada.block_length = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+    add_opt(common_arg(
+        { "--diffusion-cfg-scale" }, "F",
+        string_format("classifier-free guidance scale (default: %.3f)", (double) params.diffusion_llada.cfg_scale),
+        [](common_params & params, const std::string & value) { params.diffusion_llada.cfg_scale = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+    add_opt(common_arg(
+        { "--diffusion-algorithm" }, "N",
+        string_format("remasking algorithm: 0=LOW_CONFIDENCE, 1=RANDOM (default: %d)", params.diffusion_llada.remasking),
+        [](common_params & params, int value) { params.diffusion_llada.remasking = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_LLADA }));

return ctx_arg;
}
18 changes: 14 additions & 4 deletions common/common.h
@@ -81,7 +81,8 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
-    LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_DIFFUSION_DREAM,
+    LLAMA_EXAMPLE_DIFFUSION_LLADA,

LLAMA_EXAMPLE_COUNT,
};
@@ -219,14 +220,22 @@ struct common_params_vocoder {
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};

-struct common_params_diffusion {
+struct common_params_diffusion_dream {
     int32_t steps       = 64;    // number of diffusion steps
     float   eps         = 1e-3f; // epsilon for timesteps
     int32_t algorithm   = 0;     // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
     float   alg_temp    = 0.0f;  // algorithm temperature
     bool    visual_mode = false; // show progressive diffusion on screen
 };

+struct common_params_diffusion_llada {
+    int32_t steps        = 64;   // number of diffusion steps
+    int32_t block_length = 32;   // block length for generation
+    float   cfg_scale    = 0.2f; // classifier-free guidance scale
+    int32_t remasking    = 0;    // remasking algorithm: 0=LOW_CONFIDENCE, 1=RANDOM
+    bool    visual_mode  = false; // show progressive diffusion on screen
+};

enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@@ -277,8 +286,9 @@ struct common_params {

struct common_params_sampling sampling;
struct common_params_speculative speculative;
-    struct common_params_vocoder   vocoder;
-    struct common_params_diffusion diffusion;
+    struct common_params_vocoder         vocoder;
+    struct common_params_diffusion_dream diffusion_dream;
+    struct common_params_diffusion_llada diffusion_llada;

struct common_params_model model;

109 changes: 109 additions & 0 deletions convert_hf_to_gguf.py
@@ -2851,6 +2851,115 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("LLaDAModelLM")
class LLaDAModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLADA
undo_permute = True

def get_vocab_base(self) -> tuple[list[str], list[int], str]:
tokens: list[str] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

vocab_dict = tokenizer.get_vocab()
vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
assert max(vocab_dict.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
added_vocab = tokenizer.get_added_vocab()

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
# Check if it's a special token - treat special tokens as CONTROL tokens
if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
# Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
toktypes.append(gguf.TokenType.CONTROL)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

return tokens, toktypes, tokpre

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()

# Add parameters similar to LlamaModel
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])

if (rope_dim := hparams.get("head_dim")) is None:
n_heads = hparams.get("num_attention_heads", hparams.get("n_heads"))
rope_dim = hparams.get("hidden_size", hparams.get("d_model")) // n_heads
self.gguf_writer.add_rope_dimension_count(rope_dim)

# Set context length for LLaDA
context_length = self.hparams.get("max_sequence_length", 4096)
self.gguf_writer.add_context_length(context_length)

# Set embedding length (dimension size)
embedding_length = self.hparams.get("d_model", 4096)
self.gguf_writer.add_embedding_length(embedding_length)

# Set feed forward length (MLP hidden size)
feed_forward_length = self.hparams.get("mlp_hidden_size", 12288)
self.gguf_writer.add_feed_forward_length(feed_forward_length)

# Set RoPE parameters
if "rope_theta" in self.hparams:
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])

# Set RMS norm epsilon
if "rms_norm_eps" in self.hparams:
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

# LLaDA models use non-causal attention for diffusion, similar to Dream
self.gguf_writer.add_causal_attention(False)
# Handle RoPE scaling similar to LlamaModel and Dream

# Add LLaDA-specific parameters
mask_token_id = self.hparams.get("mask_token_id")
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)

@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams.get("num_attention_heads", self.hparams.get("n_heads"))
n_kv_head = self.hparams.get("num_key_value_heads", self.hparams.get("n_kv_heads"))

if self.undo_permute:
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LLaDAModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LLaDAModel.permute(data_torch, n_head, n_kv_head)

# LLaDA model tensors should be mapped directly since it's the base model
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Ernie4_5_ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
10 changes: 8 additions & 2 deletions examples/diffusion/CMakeLists.txt
@@ -1,5 +1,11 @@
-set(TARGET llama-diffusion-cli)
-add_executable(${TARGET} diffusion-cli.cpp)
+set(TARGET llama-diffusion-dream-cli)
+add_executable(${TARGET} diffusion-dream-cli.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)

+set(TARGET llama-diffusion-llada-cli)
+add_executable(${TARGET} diffusion-llada-cli.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
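
With the targets split, the two CLIs build independently. A hypothetical build sequence, assuming a standard CMake setup from the repository root:

```sh
# Hypothetical: configure once, then build both diffusion CLIs
cmake -B build
cmake --build build --target llama-diffusion-dream-cli llama-diffusion-llada-cli
```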
39 changes: 39 additions & 0 deletions examples/diffusion/README.md
@@ -0,0 +1,39 @@
# Diffusion Text Generation Examples

This directory contains implementations for diffusion-based text generation using two different model architectures: **Dream** and **LLaDA-8B**. Both models use iterative denoising processes to generate text, but employ different sampling strategies and algorithms.

## Supported Models

### 1. Dream Model (`llama-diffusion-dream-cli`)

- https://huggingface.co/Dream-org/Dream-v0-Base-7B
- Original PR - https://github.com/ggml-org/llama.cpp/pull/14644

The Dream model supports four different sampling algorithms controlled by the `--diffusion-algorithm` parameter:

1. **ORIGIN (0)** - Original diffusion algorithm
- Uses probability transfer based on timestep ratios
- Default algorithm with standard confidence-based token selection

2. **MASKGIT_PLUS (1)** - Enhanced MaskGIT sampling
- Improved version of the MaskGIT algorithm

3. **TOPK_MARGIN (2)** - Top-K margin-based sampling
- Confidence calculated as the margin between top-1 and top-2 probabilities

4. **ENTROPY (3)** - Entropy-based sampling (recommended)
- Uses entropy calculation for confidence estimation

### 2. LLaDA-8B Model (`llama-diffusion-llada-cli`)

- https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct

### LLaDA Model Remasking Strategies

The LLaDA model uses two remasking approaches controlled by the `--diffusion-algorithm` parameter:

1. **REMASKING_LOW_CONFIDENCE (0)** - Default strategy
- Remasks tokens with lowest confidence scores
- Uses softmax probabilities to determine confidence

2. **REMASKING_RANDOM (1)** - Random remasking
   - Remasks a randomly chosen subset of tokens at each step, independent of confidence
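
A LLaDA run is analogous, adding the LLaDA-specific flags described above (model path and prompt are again placeholders):

```sh
# Hypothetical invocation: model path and prompt are placeholders
./build/bin/llama-diffusion-llada-cli \
    -m LLaDA-8B-Instruct.Q8_0.gguf \
    -p "Explain diffusion language models in one paragraph." \
    --diffusion-steps 128 \
    --diffusion-block-length 32 \
    --diffusion-cfg-scale 0.2 \
    --diffusion-algorithm 0
```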
examples/diffusion/diffusion-dream-cli.cpp (renamed from examples/diffusion/diffusion-cli.cpp)
@@ -332,9 +332,9 @@ static std::string format_input_text(const std::string & prompt, bool use_chat_t
}

struct callback_data {
-    const common_params_diffusion * diff_params;
-    const llama_vocab *             vocab;
-    int32_t                         n_input;
+    const common_params_diffusion_dream * diff_params;
+    const llama_vocab *                   vocab;
+    int32_t                               n_input;
};

static bool diffusion_step_callback(int32_t step,
@@ -396,13 +396,13 @@ int main(int argc, char ** argv) {

common_params params;

-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION_DREAM)) {
return 1;
}

const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
-    const char * alg_name  = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
-                                 alg_names[params.diffusion.algorithm] :
+    const char * alg_name  = (params.diffusion_dream.algorithm >= 0 && params.diffusion_dream.algorithm <= 3) ?
+                                 alg_names[params.diffusion_dream.algorithm] :
"UNKNOWN";

common_init();
@@ -421,6 +421,11 @@ int main(int argc, char ** argv) {
return 1;
}

+    // Check if the model architecture is Dream
+    char arch_str[128];
+    GGML_ASSERT(llama_model_meta_val_str(model, "general.architecture", arch_str, 128) >= 0 &&
+                std::string(arch_str) == "dream");

llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
@@ -445,7 +450,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
/*add special tokens*/ true,
/*parse special*/ true);
    int n_input = input_tokens.size();

if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
@@ -455,28 +460,28 @@ int main(int argc, char ** argv) {
}

struct diffusion_params ldiff_params = diffusion_default_params();
-    ldiff_params.steps       = params.diffusion.steps;
-    ldiff_params.eps         = params.diffusion.eps;
+    ldiff_params.steps       = params.diffusion_dream.steps;
+    ldiff_params.eps         = params.diffusion_dream.eps;
ldiff_params.temperature = params.sampling.temp;
ldiff_params.top_p = params.sampling.top_p;
ldiff_params.top_k = params.sampling.top_k;
-    ldiff_params.algorithm   = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
-    ldiff_params.alg_temp    = params.diffusion.alg_temp;
+    ldiff_params.algorithm   = static_cast<enum diffusion_alg>(params.diffusion_dream.algorithm);
+    ldiff_params.alg_temp    = params.diffusion_dream.alg_temp;
ldiff_params.seed = params.sampling.seed;

llama_token mask_token_id = llama_vocab_mask(vocab);
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
LOG_INF("dream_diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
LOG_INF("dream_diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion_dream.steps);
LOG_INF("dream_diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion_dream.eps);
LOG_INF("dream_diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion_dream.algorithm,
alg_name);
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
LOG_INF("dream_diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion_dream.alg_temp);

ldiff_params.mask_token_id = mask_token_id;

callback_data cb_data = { &params.diffusion, vocab, n_input };
callback_data cb_data = { &params.diffusion_dream, vocab, n_input };

ldiff_params.step_callback = diffusion_step_callback;
ldiff_params.step_callback_user_data = &cb_data;
@@ -488,7 +493,7 @@ int main(int argc, char ** argv) {
ldiff_params, n_generated);

if (n_generated > 0) {
-        if (params.diffusion.visual_mode) {
+        if (params.diffusion_dream.visual_mode) {
//clear screen and move cursor to top-left
LOG_INF("\033[2J\033[H");
}