
WIP [DeepSeek R1] Add DeepSeekV3 Base + Weight Conversion #2171


Open · wants to merge 20 commits into master

Conversation

@DavidLandup0 (Collaborator) commented on Mar 27, 2025

Adds DeepSeekV3 base and weight conversion script.

The architecture itself builds and runs, but requires massive RAM. Example of a one-block model (0.5B parameters) running on some tokens below:

[screenshot: one-block model output]

Or a 2-layer model (1B parameters):

[screenshot: 2-layer model output]

Needs more refactoring and simplification.

WIP/TODOs

  • The weight download takes around 880 GB of disk space, and instantiating the Keras model and the torch weights in memory at the same time requires massive RAM. Figure out whether the conversion can be done iteratively, shard by shard (see the sketch after this list).
  • Figure out how Keras weight sharding interacts with this.
  • Move from the ModelArgs dataclass to a config.json-style config.
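
Roughly, the iterative approach could look like this (a sketch only; the shard naming and the assign_to_keras callback are assumptions, not code from this PR):

from safetensors import safe_open

NUM_SHARDS = 163  # shard count of the public DeepSeek-V3 checkpoint

def convert_iteratively(assign_to_keras):
    """Stream shards one at a time instead of holding all torch weights in RAM."""
    for i in range(NUM_SHARDS):
        path = f"model-{i + 1:05d}-of-000163.safetensors"  # naming assumed
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                # `assign_to_keras` is a hypothetical callback that maps a
                # checkpoint key to the matching Keras variable and assigns it.
                assign_to_keras(key, f.get_tensor(key))
        # Tensors from this shard go out of scope here, keeping peak memory low.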



@dataclass
class ModelArgs:
@DavidLandup0 (author) commented:

Comes from the original implementation; currently here for sanity checking. Will be removed in favor of JSON configs.

return logits


if __name__ == "__main__":
@DavidLandup0 (author) commented:

Sanity-check main call; will be removed.

rank = 0


class Embedding(layers.Layer):
@DavidLandup0 (author) commented:

TODO: Remove custom class and just use layers.Embedding.

return linear(x, self.weight, self.bias)


class ColumnParallelLinear(Linear):
@DavidLandup0 (author) commented:

Do we still need the custom XParallel classes if we don't use torch.distributed? Without it, most of them reduce back to the standard implementations.

@divyashreepathihalli self-requested a review on March 28, 2025 03:57
@sachinprasadhs added the WIP label (pull requests which are work in progress and not ready yet for review) on Apr 25, 2025
@sachinprasadhs (Collaborator) left a comment:

Did one normal pass of review and left some general comments.

from keras_hub.api import tokenizers
from keras_hub.api import utils
from keras_hub.src.utils.preset_utils import upload_preset
from keras_hub.api import layers as layers
Collaborator review comment:

Can you rebase onto the latest Keras Hub master and regenerate the API, so that the diff only shows the DeepSeek changes?

"vocab_size": 129280,
"dim": 7168,
"inter_dim": 10944,
"moe_inter_dim": 1408,
Collaborator review comment:

You can remove both config files and the deepseek directory; these will be part of the checkpoint files that will be uploaded.

Comment on lines +73 to +99
n_layers,
n_dense_layers,
n_heads,
# moe
n_routed_experts,
n_shared_experts,
n_activated_experts,
n_expert_groups,
n_limited_groups,
score_func,
route_scale,
# mla
q_lora_rank,
kv_lora_rank,
qk_nope_head_dim,
qk_rope_head_dim,
v_head_dim,
# yarn
original_seq_len,
rope_theta,
rope_factor,
beta_fast,
beta_slow,
mscale,
**kwargs,
):
# === Layers ===
Collaborator review comment:

Replace all the n_ prefixes with num_ (e.g. n_layers --> num_layers), so that this is consistent with other implementations.

Comment on lines +46 to +49
preprocessor = keras_hub.models.DeepSeekR1CausalLMPreprocessor.from_preset(
    "..."
)

Collaborator review comment:

You can mention the probable preset name here; the same name can be used in the conversion script and on Kaggle wherever applicable.

"""
Multi-Headed Attention Layer (MLA).

Attributes:
Collaborator review comment:

Attributes --> Args

Comment on lines +279 to +280
"""
Multi-Headed Attention Layer (MLA).
Collaborator review comment:

Change this to --> `"""Multi-Headed Attention Layer (MLA).`

Basically, the opening `"""` should be followed by the string on the same line.

@divyashreepathihalli (Collaborator) commented:

/gemini review

@gemini-code-assist (bot) left a comment:

Code Review

The pull request introduces the DeepSeekV3 model, including its architecture, tokenizer, and a weight conversion script. The implementation is detailed, but requires attention to backend-agnosticism, graph compatibility, serialization, and the weight conversion script.

Comment on lines +187 to +208
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocab_size": self.vocab_size,
                "n_layers": self.n_layers,
                "dim": self.dim,
                "n_heads": self.n_heads,
                "q_lora_rank": self.q_lora_rank,
                "kv_lora_rank": self.kv_lora_rank,
                "inter_dim": self.inter_dim,
                "n_dense_layers": self.n_dense_layers,
                "n_routed_experts": self.n_routed_experts,
                "n_activated_experts": self.n_activated_experts,
                "n_expert_groups": self.n_expert_groups,
                "n_limited_groups": self.n_limited_groups,
                "score_func": self.score_func,
                "route_scale": self.route_scale,
                "max_seq_len": self.max_seq_len,
            }
        )
        return config


critical

The get_config method is missing several parameters, preventing correct model serialization and deserialization. Ensure all __init__ arguments are saved as instance attributes and included in the get_config dictionary. Missing parameters include: max_batch_size, moe_inter_dim, n_shared_experts, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, original_seq_len, rope_theta, rope_factor, beta_fast, beta_slow, mscale.
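
A minimal sketch of the missing entries, assuming each argument is stored as a same-named attribute in `__init__` (only the added keys are shown):

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # ... keys already present above ...
                "max_batch_size": self.max_batch_size,
                "moe_inter_dim": self.moe_inter_dim,
                "n_shared_experts": self.n_shared_experts,
                "qk_nope_head_dim": self.qk_nope_head_dim,
                "qk_rope_head_dim": self.qk_rope_head_dim,
                "v_head_dim": self.v_head_dim,
                "original_seq_len": self.original_seq_len,
                "rope_theta": self.rope_theta,
                "rope_factor": self.rope_factor,
                "beta_fast": self.beta_fast,
                "beta_slow": self.beta_slow,
                "mscale": self.mscale,
            }
        )
        return config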

import math

import keras
import torch


critical

This file contains backend-specific implementations, preventing it from working with TensorFlow or JAX. Keras models should be backend-agnostic. Issues include: direct torch import and usage (line 4), Torch-specific tensor methods (e.g., .expand(), .squeeze()), non-graph-compatible operations (.tolist()), and in-place updates.
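
For illustration only (not code from this PR), these are the kinds of backend-agnostic `keras.ops` substitutions meant here:

import keras
from keras import ops

x = keras.random.normal((2, 1, 8))

# torch `.expand()` -> ops.broadcast_to
expanded = ops.broadcast_to(x, (2, 4, 8))

# torch `.squeeze()` -> ops.squeeze with an explicit axis
squeezed = ops.squeeze(x, axis=1)

# torch in-place slice assignment -> functional ops.slice_update
cache = ops.zeros((2, 16, 8))
cache = ops.slice_update(cache, (0, 0, 0), expanded)

# `.tolist()` breaks graph tracing; keep gathering symbolic with ops.take
indices = ops.convert_to_tensor([0, 3])
gathered = ops.take(cache, indices, axis=1)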

Comment on lines +192 to +195
with safe_open(
    f"model-0000{i + 1}-of-000163.safetensors",
    framework="pt",
    device="cpu",


critical

The filename for loading safetensors is constructed incorrectly. Use zfill to correctly pad the file index.

Suggested change
-with safe_open(
-    f"model-0000{i + 1}-of-000163.safetensors",
-    framework="pt",
-    device="cpu",
+with safe_open(
+    f"model-{str(i + 1).zfill(5)}-of-00163.safetensors",
+    framework="pt",
+    device="cpu",

)

# === Functional Model ===
tokens = keras.Input(shape=(128,), dtype="int32", name="tokens")


high

The input shape for tokens is hardcoded to (128,). To allow for variable sequence lengths, the shape should be set to (None,).

Suggested change
-tokens = keras.Input(shape=(128,), dtype="int32", name="tokens")
+tokens = keras.Input(shape=(None,), dtype="int32", name="tokens")

from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization

attn_impl = "absorb"
rank = 0


high

The distributed training rank is hardcoded to 0. In a distributed environment, each process needs a unique rank obtained from the distributed strategy context.
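
For illustration (assuming the torch backend; not code from this PR), the rank could be derived from the process context with a fallback for single-process runs:

import torch.distributed as dist

def get_rank():
    """Return the process rank, falling back to 0 for single-process runs."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

rank = get_rank()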

Comment on lines +17 to +27
    def __init__(
        self,
        vocabulary=None,
        merges=None,
        bos_token="<|begin▁of▁sentence|>",
        eos_token="<|end▁of▁sentence|>",
        misc_special_tokens={
            "<|▁pad▁|>",
        },
        **kwargs,
    ):


high

The __init__ method modifies the mutable default argument misc_special_tokens, which can cause unexpected behavior. Initialize it to None and create an empty set if not provided.

Suggested change
-    def __init__(
-        self,
-        vocabulary=None,
-        merges=None,
-        bos_token="<|begin▁of▁sentence|>",
-        eos_token="<|end▁of▁sentence|>",
-        misc_special_tokens={
-            "<|▁pad▁|>",
-        },
-        **kwargs,
-    ):
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        bos_token="<|begin▁of▁sentence|>",
+        eos_token="<|end▁of▁sentence|>",
+        misc_special_tokens=None,
+        **kwargs,
+    ):
+        if misc_special_tokens is None:
+            misc_special_tokens = {
+                "<|▁pad▁|>",
+            }

Comment on lines +116 to +154
def convert_block(keras_block, torch_weights, index):
    print("Weights and shapes")
    for i, w in enumerate(keras_block.weights):
        print(i, w.path, w.shape)
    print()
    for i, w in enumerate(torch_weights):
        if f"layers.{index - 1}" in w:
            print(i - 1, w, torch_weights[w].shape)

    keras_block.weights[0].assign(
        torch_weights[f"layers.{index - 1}.attn.wq.weight"]
    )
    keras_block.weights[1].assign(
        torch_weights[f"layers.{index - 1}.attn.wkv_a.weight"]
    )
    keras_block.weights[2].assign(
        torch_weights[f"layers.{index - 1}.attn.kv_norm.weight"]
    )
    keras_block.weights[3].assign(
        torch_weights[f"layers.{index - 1}.attn.wkv_b.weight"]
    )
    keras_block.weights[4].assign(
        torch_weights[f"layers.{index - 1}.attn.wo.weight"]
    )
    keras_block.weights[5].assign(
        torch_weights[f"layers.{index - 1}.ffn.w1.weight"]
    )
    keras_block.weights[6].assign(
        torch_weights[f"layers.{index - 1}.ffn.w2.weight"]
    )
    keras_block.weights[7].assign(
        torch_weights[f"layers.{index - 1}.ffn.w3.weight"]
    )
    keras_block.weights[8].assign(
        torch_weights[f"layers.{index - 1}.attn_norm.weight"]
    )
    keras_block.weights[9].assign(
        torch_weights[f"layers.{index - 1}.ffn_norm.weight"]
    )


high

Assigning weights by hardcoded indices is brittle. Access layers and weights by name using keras_block.get_layer(name) for more robust weight loading.
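
For illustration, name-based assignment could look roughly like the sketch below; the sublayer names are assumptions, not the PR's actual layer names:

def convert_block_by_name(keras_block, torch_weights, index):
    """Assign checkpoint weights to sublayers looked up by name."""
    prefix = f"layers.{index - 1}"
    # Illustrative mapping from (assumed) Keras sublayer names to checkpoint keys.
    name_map = {
        "wq": f"{prefix}.attn.wq.weight",
        "wkv_a": f"{prefix}.attn.wkv_a.weight",
        "kv_norm": f"{prefix}.attn.kv_norm.weight",
        "wkv_b": f"{prefix}.attn.wkv_b.weight",
        "wo": f"{prefix}.attn.wo.weight",
        "w1": f"{prefix}.ffn.w1.weight",
        "w2": f"{prefix}.ffn.w2.weight",
        "w3": f"{prefix}.ffn.w3.weight",
        "attn_norm": f"{prefix}.attn_norm.weight",
        "ffn_norm": f"{prefix}.ffn_norm.weight",
    }
    for layer_name, torch_key in name_map.items():
        layer = keras_block.get_layer(layer_name)
        # Assumes each of these sublayers exposes a single weight variable.
        layer.weights[0].assign(torch_weights[torch_key])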

Comment on lines +13 to +28
"""An end-to-end Llama 3 model for causal language modeling.

A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a LLaMA 3 model, simply by calling `fit()`.

This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_hub.samplers` objects to control the generation. By
default, `"top_k"` sampling will be used.

Args:
backbone: A `keras_hub.models.DeepSeekR1Backbone` instance.


medium

The docstring is partially copied from a Llama 3 model. Update it to refer to the DeepSeekV3 model. Also, the Args section on line 28 refers to DeepSeekR1Backbone, but the class used is DeepSeekV3Backbone. Ensure consistency.

Suggested change
-"""An end-to-end Llama 3 model for causal language modeling.
-
-A causal language model (LM) predicts the next token based on previous
-tokens. This task setup can be used to train the model unsupervised on
-plain text input, or to autoregressively generate plain text similar to
-the data used for training. This task can be used for pre-training or
-fine-tuning a LLaMA 3 model, simply by calling `fit()`.
-
-This model has a `generate()` method, which generates text based on a
-prompt. The generation strategy used is controlled by an additional
-`sampler` argument on `compile()`. You can recompile the model with
-different `keras_hub.samplers` objects to control the generation. By
-default, `"top_k"` sampling will be used.
-
-Args:
-    backbone: A `keras_hub.models.DeepSeekR1Backbone` instance.
+"""An end-to-end DeepSeekV3 model for causal language modeling.
+
+A causal language model (LM) predicts the next token based on previous
+tokens. This task setup can be used to train the model unsupervised on
+plain text input, or to autoregressively generate plain text similar to
+the data used for training. This task can be used for pre-training or
+fine-tuning a DeepSeekV3 model, simply by calling `fit()`.
+
+This model has a `generate()` method, which generates text based on a
+prompt. The generation strategy used is controlled by an additional
+`sampler` argument on `compile()`. You can recompile the model with
+different `keras_hub.samplers` objects to control the generation. By
+default, `"top_k"` sampling will be used.
+
+Args:
+    backbone: A `keras_hub.models.DeepSeekV3Backbone` instance.

def download_and_rename_weight_files():
    for i in range(end):
        print(f"Downloading model part {i + 1}/{end}")
        part = str(i + 1).zfill(5)


medium

Remove the stray trailing space.

Suggested change
-        part = str(i + 1).zfill(5) 
+        part = str(i + 1).zfill(5)
