Ray Disaggregated Serving MVP #106
Conversation
allenwang28 left a comment:
High-level comment: it looks like the main difference is_disaggregated makes within PyTorchRayEngine is whether or not prefill returns outputs.
If the prefill/decode/interleave functionality is essentially the same, then I guess triggering the transfer is an implementation detail for the orchestrator. If so, is it possible to exclude is_disaggregated from the worker? That would reduce the complexity.
Simplified the prefill call on the engine side. On the worker side: yes, insert and decode are the same in both modes, but I feel it's better to keep separate disaggregated and interleave paths for prefill, for several reasons:
I think that makes sense to me, thanks!
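For context, here is a minimal sketch of the engine-side prefill dispatch being discussed: the only mode-dependent part is whether the prefill output is handed back to the orchestrator. All names here (PyTorchRayEngineSketch, the worker prefill signature) are illustrative assumptions, not the actual code in this PR.

```python
# Sketch only: assumes each worker is a Ray actor handle exposing a `prefill` method.
import ray


class PyTorchRayEngineSketch:
    """Engine-side view: workers run the same prefill computation in both modes."""

    def __init__(self, workers, is_disaggregated: bool = False):
        self.workers = workers            # Ray actor handles, one per host
        self.is_disaggregated = is_disaggregated

    def prefill(self, params, padded_tokens, true_length):
        # Fan out the same prefill computation to every worker in the pod slice.
        futures = [
            w.prefill.remote(params, padded_tokens, true_length)
            for w in self.workers
        ]
        results = ray.get(futures)
        if self.is_disaggregated:
            # Disaggregated mode: return the prefill output to the orchestrator
            # so it can be transferred to the decode pod slice.
            return results[0]
        # Interleave mode: the output stays resident on the workers,
        # so there is nothing to return.
        return None
```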
This PR enables PyTorch engine disaggregated serving across multiple TPU Pod slices.
It delivers:
Result validation:
Command:
python /home/{user}/jetstream-pytorch/run_interactive_disaggregated.py --size=7b --batch_size=1 --is_disaggregated=True --num_hosts=8 --decode_pod_slice_name={user}-tpu-vm-2 --model_name=llama-2 --max_cache_length=2048 --quantize_weights=False --quantize_kv_cache=False --checkpoint_path=/home/{user}/data/llama-2-7b-chat-safetensor/model.safetensors --tokenizer_path=/home/{user}/data/tokenizer.model --sharding_config=/home/{user}/jetstream-pytorch/default_shardings/llama.yaml
Interleave result:
Disaggregated result:
Next Steps:
5. Support multiple prefill engines and multiple decode engines
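As background for the flow this PR targets, here is a minimal self-contained sketch: prefill runs on one pod slice, the prefill result is transferred and inserted into a decode slot on a second pod slice, and decode proceeds there. The Ray actors and methods below are placeholders and do not correspond to the actual jetstream-pytorch API.

```python
# Placeholder actors standing in for the prefill and decode pod slices.
import ray

ray.init()


@ray.remote
class PrefillWorker:
    def prefill(self, tokens):
        # Stand-in for running the model's prefill step and building the KV cache.
        kv_cache = [t * 2 for t in tokens]    # placeholder "cache"
        first_token = tokens[-1] + 1          # placeholder next token
        return kv_cache, first_token


@ray.remote
class DecodeWorker:
    def __init__(self):
        self.slots = {}

    def insert(self, slot, kv_cache, first_token):
        # Stand-in for transferring the prefill result into a decode slot.
        self.slots[slot] = (kv_cache, first_token)

    def generate(self, slot, steps):
        # Stand-in for autoregressive decode on the decode pod slice.
        _, token = self.slots[slot]
        return [token + i for i in range(steps)]


# Orchestrator: prefill on one "slice", transfer/insert, then decode on another.
prefill_worker = PrefillWorker.remote()
decode_worker = DecodeWorker.remote()

kv_cache, first_token = ray.get(prefill_worker.prefill.remote([1, 2, 3]))
ray.get(decode_worker.insert.remote(0, kv_cache, first_token))
print(ray.get(decode_worker.generate.remote(0, steps=4)))
```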