|
| 1 | +import sys |
| 2 | + |
| 3 | +# import torch_xla2 first! |
| 4 | +import torch_xla2 # pylint: disable |
| 5 | +import jax |
| 6 | +from absl import app, flags |
| 7 | +from jetstream.core import server_lib |
| 8 | +from jetstream.core.config_lib import ServerConfig, MetricsServerConfig |
| 9 | +import torch |
| 10 | + |
| 11 | +from jetstream_pt import fetch_models |
| 12 | +from jetstream_pt import environment, engine, quantize_model, torchjax |
| 13 | +from jetstream_pt import config |
| 14 | + |
| 15 | + |
| 16 | +FLAGS = flags.FLAGS |
| 17 | + |
| 18 | +flags.DEFINE_string("model_id", "", "") |
| 19 | +flags.DEFINE_integer("override_batch_size", 32, "The batch size") |
| 20 | +flags.DEFINE_integer("max_input_length", 1024, "The batch size") |
| 21 | +flags.DEFINE_integer("max_output_length", 1024, "The batch size") |
| 22 | +flags.DEFINE_integer("port", 9000, "port to listen on") |
| 23 | +flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool") |
| 24 | + |
| 25 | + |
| 26 | +def shard_weights(env, weights, weight_shardings): |
| 27 | + """Shard weights according to weight_shardings""" |
| 28 | + for k, v in weight_shardings.items(): |
| 29 | + print("SHARDING", k, v) |
| 30 | + sharded = {} |
| 31 | + for key, val in weights.items(): |
| 32 | + sharding = env.sharding_by_axis(weight_shardings.get(key, -1)) |
| 33 | + with jax.default_device(jax.devices("cpu")[0]): |
| 34 | + arr = torch_xla2.tensor.t2j(val) |
| 35 | + arr = jax.device_put(arr, sharding) |
| 36 | + sharded[key] = torchjax.to_torch(arr) |
| 37 | + return sharded |
| 38 | + |
| 39 | + |
| 40 | +def create_engine(devices): |
| 41 | + """Create Pytorch engine from flags""" |
| 42 | + torch.set_default_dtype(torch.bfloat16) |
| 43 | + quant_config = config.create_quantization_config_from_flags() |
| 44 | + env_data = fetch_models.construct_env_data_from_model_id( |
| 45 | + FLAGS.model_id, |
| 46 | + FLAGS.override_batch_size, |
| 47 | + FLAGS.max_input_length, |
| 48 | + FLAGS.max_output_length, |
| 49 | + quant_config.enable_weight_quantization, |
| 50 | + ) |
| 51 | + env = environment.JetEngineEnvironment(env_data) |
| 52 | + model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env) |
| 53 | + |
| 54 | + weight_shardings = model.get_sharding_annotations() |
| 55 | + sharded_weights = shard_weights(env, model.state_dict(), weight_shardings) |
| 56 | + |
| 57 | + if quant_config.enable_weight_quantization: |
| 58 | + model.load_state_dict(sharded_weights, assign=True, strict=False) |
| 59 | + quantize_model.quantize_model(model, quant_config) |
| 60 | + sharded_weights = model.state_dict() |
| 61 | + |
| 62 | + return engine.PyTorchEngine( |
| 63 | + pt_model=model, |
| 64 | + env=env, |
| 65 | + weights=torchjax.from_torch_with_copy(sharded_weights), |
| 66 | + ) |
| 67 | + |
| 68 | + |
| 69 | +def list_model(): |
| 70 | + """Print list of models.""" |
| 71 | + for model_id in fetch_models.model_id_to_class: |
| 72 | + print(model_id) |
| 73 | + |
| 74 | + |
| 75 | +def serve(): |
| 76 | + """Run gRPC server.""" |
| 77 | + if FLAGS.model_id == "": |
| 78 | + print("Please specify model_id with --model_id") |
| 79 | + print("valid model ids are:") |
| 80 | + list_model() |
| 81 | + sys.exit(1) |
| 82 | + devices = server_lib.get_devices() |
| 83 | + print(f"devices: {devices}") |
| 84 | + |
| 85 | + server_config = ServerConfig( |
| 86 | + interleaved_slices=(f"tpu={len(jax.devices())}",), |
| 87 | + interleaved_engine_create_fns=[create_engine], |
| 88 | + ) |
| 89 | + print(f"server_config: {server_config}") |
| 90 | + |
| 91 | + metrics_server_config: MetricsServerConfig | None = None |
| 92 | + |
| 93 | + # We separate credential from run so that we can unit test it with local credentials. |
| 94 | + # We would like to add grpc credentials for OSS. |
| 95 | + jetstream_server = server_lib.run( |
| 96 | + threads=FLAGS.threads, |
| 97 | + port=FLAGS.port, |
| 98 | + config=server_config, |
| 99 | + devices=devices, |
| 100 | + metrics_server_config=metrics_server_config, |
| 101 | + ) |
| 102 | + print("Started jetstream_server....") |
| 103 | + jetstream_server.wait_for_termination() |
| 104 | + |
| 105 | + |
| 106 | +def interactive(): |
| 107 | + """Run interactive""" |
| 108 | + raise RuntimeError("Not implemented") |
| 109 | + |
| 110 | + |
| 111 | +def main(argv): |
| 112 | + """Entry point""" |
| 113 | + if len(argv) < 2: |
| 114 | + print("Invalid arguments. please specify 'list' or 'serve'") |
| 115 | + |
| 116 | + if argv[1] == "list": |
| 117 | + list_model() |
| 118 | + return |
| 119 | + |
| 120 | + if argv[1] == "serve": |
| 121 | + serve() |
| 122 | + return |
| 123 | + |
| 124 | + if argv[1] == "interative": |
| 125 | + interactive() |
| 126 | + return |
| 127 | + |
| 128 | + |
| 129 | +if __name__ == "__main__": |
| 130 | + app.run(main) |
0 commit comments