17 | 17 | from typing import Any, List, Optional, Tuple, Union |
18 | 18 | import threading |
19 | 19 | import functools |
| 20 | +import os |
20 | 21 | import humanize |
21 | 22 |
22 | 23 |
39 | 40 | from jetstream_pt import cache_manager |
40 | 41 | from jetstream_pt import quantize |
41 | 42 | from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData |
| 43 | +from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model |
42 | 44 |
43 | 45 |
44 | 46 | Mesh = jax.sharding.Mesh |
@@ -103,6 +105,7 @@ def __init__( |
103 | 105 | quantize_weights=False, |
104 | 106 | quantize_kv=False, |
105 | 107 | max_cache_length=1024, |
| 108 | + sharding_config=None, |
106 | 109 | ): |
107 | 110 |
108 | 111 | jax.config.update("jax_default_prng_impl", "unsafe_rbg") |
@@ -144,38 +147,61 @@ def __init__( |
144 | 147 | checkpoint_format = "safetensors" |
145 | 148 | checkpoint_path = paths[0] |
146 | 149 |
| 150 | + if not sharding_config: |
| 151 | + sharding_config = os.path.join("default_shardings", model_name + ".yaml") |
| 152 | + |
147 | 153 | env_data = JetEngineEnvironmentData( |
148 | 154 | tokenizer_path=tokenizer_path, |
149 | 155 | checkpoint_path=checkpoint_path, |
150 | 156 | checkpoint_format=checkpoint_format, |
151 | | - model_type="llama-2-" + param_size, |
152 | 157 | batch_size=batch_size, |
153 | 158 | max_decode_length=max_decode_length, |
154 | 159 | max_input_sequence_length=context_length, |
155 | 160 | enable_weight_quantization=quantize_weights, |
156 | 161 | enable_kv_quantization=quantize_kv, |
157 | 162 | cache_sequence_length=max_cache_length, |
158 | 163 | bf16_enable=bf16_enable, |
| 164 | + sharding_config_path=sharding_config, |
159 | 165 | ) |
160 | 166 | env = JetEngineEnvironment(env_data) |
161 | 167 |
162 | | - pt_model = None |
163 | | - if "llama" in model_name: |
| 168 | + if model_name.startswith("llama"): |
| 169 | + |
164 | 170 | args = model_args.get_model_args( |
165 | | - model_name + "-" + param_size, |
166 | | - context_length, |
167 | | - batch_size, |
168 | | - bf16_enable, |
| 171 | + model_name + "-" + param_size, context_length, batch_size, bf16_enable |
169 | 172 | ) |
170 | 173 | args.device = "meta" |
171 | 174 | args.quantize = quantize_weights |
| 175 | + env_data.cache_shape = ( |
| 176 | + batch_size, |
| 177 | + args.n_kv_heads, |
| 178 | + max_cache_length, |
| 179 | + args.dim // args.n_heads, |
| 180 | + ) |
| 181 | + env_data.model_type = "llama-2-" + param_size |
| 182 | + env_data.num_layers = args.n_layers |
| 183 | + env = JetEngineEnvironment(env_data) |
172 | 184 | pt_model = model_exportable.Transformer(args, env) |
| 185 | + elif model_name == "gemma": |
| 186 | + args = gemma_config.get_model_config(param_size) |
| 187 | + env_data.cache_shape = ( |
| 188 | + batch_size, |
| 189 | + args.num_key_value_heads, |
| 190 | + max_cache_length, |
| 191 | + args.head_dim, |
| 192 | + ) |
| 193 | + env_data.model_type = "gemma-" + param_size |
| 194 | + env_data.num_layers = args.num_hidden_layers |
| 195 | + env = JetEngineEnvironment(env_data) |
| 196 | + pt_model = gemma_model.GemmaModel(args, env) |
| 197 | + else: |
| 198 | + raise RuntimeError(f"Model with name {model_name} not found") |
173 | 199 |
174 | | - num_params_size = 0 |
175 | | - num_params = 0 |
176 | | - for _, v in pt_model.state_dict().items(): |
177 | | - num_params += 1 |
178 | | - num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2) |
| 200 | + num_params_size = 0 |
| 201 | + num_params = 0 |
| 202 | + for _, v in pt_model.state_dict().items(): |
| 203 | + num_params += 1 |
| 204 | + num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2) |
179 | 205 | print("Number of param Gbytes:", num_params_size / (1 << 30)) |
180 | 206 | print("Number of param: ", num_params) |
181 | 207 |
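For reference, a minimal standalone sketch of the new default-path behaviour introduced in this diff: when the caller does not pass `sharding_config`, the engine falls back to a per-model YAML under `default_shardings/`. The helper name `resolve_sharding_config` below is hypothetical and exists only to illustrate the added branch; it is not part of the PR.

import os

# Illustrative sketch (not code from this PR): mirrors the
# `if not sharding_config:` fallback added above.
def resolve_sharding_config(model_name, sharding_config=None):
    # Fall back to a per-model YAML under default_shardings/.
    if not sharding_config:
        sharding_config = os.path.join("default_shardings", model_name + ".yaml")
    return sharding_config

print(resolve_sharding_config("gemma"))             # default_shardings/gemma.yaml
print(resolve_sharding_config("llama", "my.yaml"))  # my.yaml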