Merged · 24 commits
89 changes: 89 additions & 0 deletions keras_hub/src/utils/transformers/export/gemma.py
@@ -0,0 +1,89 @@
import keras.ops as ops


def get_gemma_config(backbone):
    hf_config = {
        "vocab_size": backbone.vocabulary_size,
        "num_hidden_layers": backbone.num_layers,
        "num_attention_heads": backbone.num_query_heads,
        "num_key_value_heads": backbone.num_key_value_heads,
        "hidden_size": backbone.hidden_dim,
        "intermediate_size": backbone.intermediate_dim // 2,
        "head_dim": backbone.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config


def get_gemma_weights_map(backbone):
    weights_dict = {}

    # Map token embedding
    token_embedding_layer = backbone.get_layer("token_embedding")
    weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            decoder_layer.pre_attention_norm.weights[0]
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.weights[0]
        query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
        query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
        query_kernel = ops.transpose(query_kernel)
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.weights[0][0]
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
            ops.transpose(key_kernel)
        )

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.weights[0][0]
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
            ops.transpose(value_kernel)
        )

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.weights[0]
        out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
        out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            decoder_layer.pre_ffw_norm.weights[0]
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.weights[0]
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
            gate_kernel
        )

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.weights[0]
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
            up_kernel
        )

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.weights[0]
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
            down_kernel
        )

    # Map final normalization
    weights_dict["model.norm.weight"] = backbone.get_layer(
        "final_normalization"
    ).weights[0]

    # Tie weights, but clone to avoid sharing memory issues
    weights_dict["lm_head.weight"] = ops.copy(token_embedding_layer.weights[0])

    return weights_dict
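For context, a minimal sketch of how these two helpers fit together (illustrative only, not part of the diff; it assumes a GemmaBackbone instance named backbone):

hf_config = get_gemma_config(backbone)  # HF-style config dict, later written to config.json
weights_dict = get_gemma_weights_map(backbone)  # HF parameter names -> Keras weight tensors
# Keys follow the transformers Gemma naming scheme, e.g.
# "model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight".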
143 changes: 143 additions & 0 deletions keras_hub/src/utils/transformers/export/gemma_test.py
@@ -0,0 +1,143 @@
import os

import numpy as np
import torch
from sentencepiece import SentencePieceTrainer
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer as HFGemmaTokenizer

from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
    GemmaCausalLMPreprocessor,
)
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export.hf_exporter import (
    export_to_safetensors,
)


class TestGemmaExport(TestCase):
    def test_export_to_hf(self):
        # Create a dummy tokenizer
        train_sentences = [
            "The quick brown fox jumped.",
            "I like pizza.",
            "This is a test.",
        ]
        # TODO: Consider using keras_hub/src/tests/test_data/gemma_test_vocab.spm
        # instead of retraining a new vocab here. Will be faster.
        proto_prefix = os.path.join(self.get_temp_dir(), "dummy_vocab")
        SentencePieceTrainer.train(
            sentence_iterator=iter(train_sentences),
            model_prefix=proto_prefix,
            vocab_size=290,
            model_type="unigram",
            pad_id=0,
            bos_id=1,
            eos_id=2,
            unk_id=3,
            byte_fallback=True,
            pad_piece="<pad>",
            bos_piece="<bos>",
            eos_piece="<eos>",
            unk_piece="<unk>",
            user_defined_symbols=["<start_of_turn>", "<end_of_turn>"],
        )
        tokenizer = GemmaTokenizer(proto=f"{proto_prefix}.model")

        # Create a small backbone
        backbone = GemmaBackbone(
            vocabulary_size=tokenizer.vocabulary_size(),
            num_layers=2,
            num_query_heads=4,
            num_key_value_heads=1,
            hidden_dim=512,
            intermediate_dim=1028,
            head_dim=128,
        )
        # Create preprocessor
        preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)

        # Create the causal LM model
        keras_model = GemmaCausalLM(
            backbone=backbone, preprocessor=preprocessor
        )

        # Set all weights to random values
        rng = np.random.default_rng(42)
        weights = keras_model.get_weights()
        for i in range(len(weights)):
            weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype)
        keras_model.set_weights(weights)

        # Export to Hugging Face format
        export_path = os.path.join(self.get_temp_dir(), "export_small_model")
        export_to_safetensors(keras_model, export_path)
        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = HFGemmaTokenizer.from_pretrained(export_path)

        # Verify configuration
        hf_config = hf_model.config
        self.assertEqual(
            hf_config.vocab_size,
            backbone.vocabulary_size,
            "Vocabulary sizes do not match",
        )
        self.assertEqual(
            hf_config.num_hidden_layers,
            backbone.num_layers,
            "Number of layers does not match",
        )
        self.assertEqual(
            hf_config.num_attention_heads,
            backbone.num_query_heads,
            "Number of query heads does not match",
        )
        self.assertEqual(
            hf_config.num_key_value_heads,
            backbone.num_key_value_heads,
            "Number of key value heads does not match",
        )
        self.assertEqual(
            hf_config.hidden_size,
            backbone.hidden_dim,
            "Hidden dimensions do not match",
        )
        self.assertEqual(
            hf_config.intermediate_size,
            backbone.intermediate_dim // 2,
            "Intermediate sizes do not match",
        )
        self.assertEqual(
            hf_config.head_dim,
            backbone.head_dim,
            "Head dimensions do not match",
        )
        self.assertEqual(
            hf_config.max_position_embeddings,
            8192,
            "Max position embeddings do not match",
        )

        # Verify tokenizer compatibility
        self.assertEqual(
            hf_tokenizer.vocab_size,
            tokenizer.vocabulary_size(),
            "Tokenizer vocabulary sizes do not match",
        )

        # Compare generated outputs
        prompt = "the quick"
        keras_output = keras_model.generate(prompt, max_length=20)
        input_ids = hf_tokenizer.encode(prompt, return_tensors="pt")
        with torch.no_grad():
            output_ids = hf_model.generate(
                input_ids, max_length=20, do_sample=False
            )
        hf_output = hf_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        self.assertEqual(
            keras_output, hf_output, "Generated outputs do not match"
        )
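To run just this test locally (a standard pytest invocation; exact CI wiring is outside this diff, and torch plus transformers must be installed):

pytest keras_hub/src/utils/transformers/export/gemma_test.py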
98 changes: 98 additions & 0 deletions keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -0,0 +1,98 @@
import json
import os
import shutil
import warnings

import keras

from keras_hub.src.utils.transformers.export.gemma import get_gemma_config
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map

MODEL_CONFIGS = {
    "GemmaBackbone": get_gemma_config,
    # Add future models here, e.g., "LlamaBackbone": get_llama_config,
}

MODEL_EXPORTERS = {
    "GemmaBackbone": get_gemma_weights_map,
    # Add future models here, e.g., "LlamaBackbone": get_llama_weights_map,
}


def export_to_safetensors(keras_model, path):
    """Converts a Keras model to the Hugging Face safetensors format.

    It does the following:
    - Extracts and maps weights from the Keras backbone to safetensors.
    - Saves the configuration as 'config.json'.
    - Saves weights in 'model.safetensors'.
    - Saves tokenizer assets.

    Args:
        keras_model: The Keras model to convert.
        path: str. Path of the directory to which the safetensors file,
            config and tokenizer will be saved.
    """
    backend = keras.config.backend()
    backbone = keras_model.backbone
    model_type = backbone.__class__.__name__

    if model_type not in MODEL_CONFIGS:
        raise ValueError(f"Config not implemented for {model_type}")

    if model_type not in MODEL_EXPORTERS:
        raise ValueError(f"Exporter not implemented for {model_type}")

    get_config_fn = MODEL_CONFIGS[model_type]
    hf_config = get_config_fn(backbone)

    get_weights_fn = MODEL_EXPORTERS[model_type]
    weights_dict = get_weights_fn(backbone)

    if not weights_dict:
        raise ValueError("No weights to save.")

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Save weights based on backend
    weights_path = os.path.join(path, "model.safetensors")
    if backend == "torch":
        from safetensors.torch import save_file

        weights_dict_contiguous = {
            k: v.value.contiguous() if hasattr(v, "value") else v.contiguous()
            for k, v in weights_dict.items()
        }
        save_file(
            weights_dict_contiguous, weights_path, metadata={"format": "pt"}
        )
    elif backend == "tensorflow":
        from safetensors.tensorflow import save_file

        save_file(weights_dict, weights_path, metadata={"format": "pt"})
    elif backend == "jax":
        from safetensors.flax import save_file

        save_file(weights_dict, weights_path, metadata={"format": "pt"})
    else:
        raise ValueError(f"Unsupported backend: {backend}")

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)

    # Rename vocabulary file
    vocab_spm_path = os.path.join(path, "vocabulary.spm")
    tokenizer_model_path = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm_path):
        shutil.move(vocab_spm_path, tokenizer_model_path)
    else:
        warnings.warn(
            f"{vocab_spm_path} not found. Tokenizer may not load "
            "correctly. Ensure that the tokenizer configuration "
            "is correct and that the vocabulary file is present "
            "in the original model."
        )
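As a usage sketch, exporting a KerasHub Gemma checkpoint and reloading it with transformers would look roughly like the snippet below. The preset name is an assumption used only for illustration; the loading calls mirror what the test above does.

import keras_hub
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.utils.transformers.export.hf_exporter import (
    export_to_safetensors,
)

# "gemma_2b_en" is an example preset name, not something this diff pins down.
keras_model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
export_to_safetensors(keras_model, "./gemma_hf_export")

# Reload with Hugging Face transformers.
hf_model = GemmaForCausalLM.from_pretrained("./gemma_hf_export")
hf_tokenizer = GemmaTokenizer.from_pretrained("./gemma_hf_export")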
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -18,3 +18,4 @@ sentencepiece
tensorflow-datasets
safetensors
pillow
transformers