
Commit 1970418

prototyping better UX (#134)
* Load weights directly from HF for llama
1 parent 94449c3 commit 1970418

14 files changed: +646, -24 lines changed

convert_checkpoints.py

Lines changed: 0 additions & 3 deletions
@@ -357,9 +357,6 @@ def _load_from_local(input_ckpt_dir: epath.Path):
   if not _FROM_HF.value:
     return _load_orig_llama_weight(input_ckpt_dir)
   else:
-    assert (
-        not FLAGS.quantize_weights
-    ), "Quantization not supported for HF checkpoint."
     return _load_hf_llama_weight(input_ckpt_dir)



install_everything.sh

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install flax
 pip install tensorflow-text
 pip install tensorflow
+pip install huggingface_hub

 pip install ray[default]==2.22.0
 # torch cpu
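
The new huggingface_hub dependency is what lets checkpoints be pulled straight from the Hub rather than from a locally converted directory. The commit's own download logic lives in jetstream_pt.fetch_models (not shown in this diff); as a rough sketch of the library call involved, assuming a hypothetical repo id and an already-configured HF token for gated repos:

# Illustrative only: snapshot_download fetches a whole repo snapshot and returns
# the local cache directory containing config.json and the weight shards.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="meta-llama/Llama-2-7b-hf")  # hypothetical repo id
print(local_dir)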

jetstream_pt/cli.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+import sys
+
+# import torch_xla2 first!
+import torch_xla2  # pylint: disable
+import jax
+from absl import app, flags
+from jetstream.core import server_lib
+from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
+import torch
+
+from jetstream_pt import fetch_models
+from jetstream_pt import environment, engine, quantize_model, torchjax
+from jetstream_pt import config
+
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_id", "", "HuggingFace model id")
+flags.DEFINE_integer("override_batch_size", 32, "The batch size")
+flags.DEFINE_integer("max_input_length", 1024, "The maximum input length")
+flags.DEFINE_integer("max_output_length", 1024, "The maximum output length")
+flags.DEFINE_integer("port", 9000, "port to listen on")
+flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+
+
+def shard_weights(env, weights, weight_shardings):
+  """Shard weights according to weight_shardings"""
+  for k, v in weight_shardings.items():
+    print("SHARDING", k, v)
+  sharded = {}
+  for key, val in weights.items():
+    sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
+    with jax.default_device(jax.devices("cpu")[0]):
+      arr = torch_xla2.tensor.t2j(val)
+    arr = jax.device_put(arr, sharding)
+    sharded[key] = torchjax.to_torch(arr)
+  return sharded
+
+
+def create_engine(devices):
+  """Create Pytorch engine from flags"""
+  torch.set_default_dtype(torch.bfloat16)
+  quant_config = config.create_quantization_config_from_flags()
+  env_data = fetch_models.construct_env_data_from_model_id(
+      FLAGS.model_id,
+      FLAGS.override_batch_size,
+      FLAGS.max_input_length,
+      FLAGS.max_output_length,
+      quant_config.enable_weight_quantization,
+  )
+  env = environment.JetEngineEnvironment(env_data)
+  model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
+
+  weight_shardings = model.get_sharding_annotations()
+  sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
+
+  if quant_config.enable_weight_quantization:
+    model.load_state_dict(sharded_weights, assign=True, strict=False)
+    quantize_model.quantize_model(model, quant_config)
+    sharded_weights = model.state_dict()
+
+  return engine.PyTorchEngine(
+      pt_model=model,
+      env=env,
+      weights=torchjax.from_torch_with_copy(sharded_weights),
+  )
+
+
+def list_model():
+  """Print list of models."""
+  for model_id in fetch_models.model_id_to_class:
+    print(model_id)
+
+
+def serve():
+  """Run gRPC server."""
+  if FLAGS.model_id == "":
+    print("Please specify model_id with --model_id")
+    print("valid model ids are:")
+    list_model()
+    sys.exit(1)
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+
+  server_config = ServerConfig(
+      interleaved_slices=(f"tpu={len(jax.devices())}",),
+      interleaved_engine_create_fns=[create_engine],
+  )
+  print(f"server_config: {server_config}")
+
+  metrics_server_config: MetricsServerConfig | None = None
+
+  # We separate credential from run so that we can unit test it with local credentials.
+  # We would like to add grpc credentials for OSS.
+  jetstream_server = server_lib.run(
+      threads=FLAGS.threads,
+      port=FLAGS.port,
+      config=server_config,
+      devices=devices,
+      metrics_server_config=metrics_server_config,
+  )
+  print("Started jetstream_server....")
+  jetstream_server.wait_for_termination()
+
+
+def interactive():
+  """Run interactive"""
+  raise RuntimeError("Not implemented")
+
+
+def main(argv):
+  """Entry point"""
+  if len(argv) < 2:
+    sys.exit("Invalid arguments. Please specify 'list' or 'serve'.")
+
+  if argv[1] == "list":
+    list_model()
+    return
+
+  if argv[1] == "serve":
+    serve()
+    return
+
+  if argv[1] == "interactive":
+    interactive()
+    return
+
+
+if __name__ == "__main__":
+  app.run(main)
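
Taken together, cli.py gives the package a flag-driven entry point with three commands: list, serve, and the still-unimplemented interactive. Assuming the module is launched directly (the commit does not add a console-script entry point), invocation would look roughly like `python -m jetstream_pt.cli serve --model_id=<hf repo id> --port=9000`. A minimal in-process sketch of the same dispatch, useful because the list command touches only fetch_models and never builds an engine or starts the gRPC server:

# Illustrative only: drive the new CLI without absl's app.run or a TPU server.
# argv[0] is the program name; argv[1] selects the command.
from jetstream_pt import cli

cli.main(["cli", "list"])  # prints the model ids known to fetch_models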

jetstream_pt/engine.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@
 import functools
 import os

+import glob
 from etils import epath
 from flax import struct
 import jax
@@ -40,6 +41,9 @@
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
 from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model

+from absl import flags
+
+FLAGS = flags.FLAGS

 Mesh = jax.sharding.Mesh
 P = jax.sharding.PartitionSpec
@@ -82,11 +86,13 @@ def __init__(
       self,
       pt_model: torch.nn.Module,
       env: JetEngineEnvironment,
+      weights=None,
   ):
     self.pt_model = pt_model
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
     self.rng = jax.random.PRNGKey(0)
+    self.weights = weights

     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
@@ -713,6 +719,8 @@ def _load_from_state_dict(self, path):

   # pylint: disable-next=all
   def load_params(self) -> Params:
+    if self.weights is not None:
+      return self.weights
     # We want to fix this: load from files
     with jax.default_device(self.colocated_cpus):
       if self.env.checkpoint_path:
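
The net effect of the engine.py changes: PyTorchEngine can now be constructed with weights that are already sharded and resident on device, and load_params simply hands them back instead of reading a checkpoint. A sketch of that path using only names introduced in this commit; model, env, and sharded_weights are assumed to come from a create_engine-style setup as in cli.py:

# Illustrative only: build an engine from pre-loaded weights, bypassing checkpoint_path.
from jetstream_pt import engine, torchjax

eng = engine.PyTorchEngine(
    pt_model=model,
    env=env,
    weights=torchjax.from_torch_with_copy(sharded_weights),
)
params = eng.load_params()  # returns the pre-loaded weights; no file I/O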

jetstream_pt/environment.py

Lines changed: 2 additions & 9 deletions
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple, Dict
-
 import dataclasses
-import yaml
+from typing import Tuple

 import jax
 import jax.sharding as jsharding
 from jax.experimental import mesh_utils
 import torch_xla2
+import yaml


 from jetstream_pt import cache_manager
@@ -36,7 +35,6 @@ class QuantizationConfig:
   is_symmetric_weight: bool = True

   enable_activation_quantization: bool = False
-
   enable_kv_quantization: bool = False


@@ -75,11 +73,6 @@ class JetEngineEnvironmentData:
   # This string must be one of the values of attention_kv_axis_names above
   kv_cache_shard_axis: str = "num_attn_heads"

-  # Override sharding axis of a weight by name
-  experimental_sharding_axis_override: Dict[str, int] = dataclasses.field(
-      default_factory=dict
-  )
-
   # QKV fusion has negative performance on TPU, slicing takes longer
   qkv_fusion: bool = False

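With experimental_sharding_axis_override removed from JetEngineEnvironmentData, per-weight sharding decisions now come from the model itself via get_sharding_annotations(), combined with env.sharding_by_axis() as shard_weights in cli.py does above. A minimal sketch of that lookup for a single weight; the weight name and annotation value are hypothetical, and -1 (the default) is taken to mean the weight is replicated rather than sharded:

# Illustrative only: resolve the sharding for one weight from model annotations.
weight_shardings = model.get_sharding_annotations()  # e.g. {"layers.0.wq.weight": 0}
axis = weight_shardings.get("layers.0.wq.weight", -1)
sharding = env.sharding_by_axis(axis)  # usable with jax.device_put, as in shard_weights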