56 changes: 54 additions & 2 deletions README.md
@@ -38,7 +38,7 @@ Server.

## Table of Contents

| [Pre-requisites](#pre-requisites) | [Installation](#installation) | [Quickstart](#quickstart) | [Serving LLM Models](#serving-llm-models) | [Serving a vLLM Model](#serving-a-vllm-model) | [Serving a TRT-LLM Model](#serving-a-trt-llm-model) | [Serving a LLM model with OpenAI API](#serving-a-llm-model-with-openai-api) | [Additional Dependencies for Custom Environments](#additional-dependencies-for-custom-environments) | [Known Limitations](#known-limitations) |
| [Pre-requisites](#pre-requisites) | [Installation](#installation) | [Quickstart](#quickstart) | [Serving LLM Models](#serving-llm-models) | [Serving a vLLM Model](#serving-a-vllm-model) | [Serving a TRT-LLM Model](#serving-a-trt-llm-model) | [Serving a LLM model with OpenAI API](#serving-a-llm-model-with-openai-api) | [Serving a HuggingFace LLM Model with LLM API](#serving-a-huggingface-llm-model-with-llm-api) | [Additional Dependencies for Custom Environments](#additional-dependencies-for-custom-environments) | [Known Limitations](#known-limitations) |

## Pre-requisites

@@ -351,7 +351,59 @@ triton start --frontend openai
# Interact with model at http://localhost:9000
curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{
"model": "llama-3.1-8b-instruct",
"messages": [{"role": "user", "content": "What is machine learning?"}]
"messages": [{"role": "user", "content": "What is machine learning?"}],
"max_tokens": 256
}'

# Profile model with GenAI-Perf
triton profile -m llama-3.1-8b-instruct --service-kind openai --endpoint-type chat --url localhost:9000 --streaming
```

## Serving a HuggingFace LLM Model with LLM API

The LLM API is a high-level Python API designed for TensorRT-LLM workflows. It can
convert an LLM in Hugging Face format into a TensorRT-LLM engine and serve that engine through a unified Python API, without invoking
separate engine-build and conversion scripts.
To use the LLM API with the Triton CLI, import the model with `--backend llmapi`:
```bash
triton import -m "llama-3.1-8b-instruct" --backend llmapi
```
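
For background, the `llmapi` backend builds on TensorRT-LLM's Python LLM API. The snippet below is a minimal, standalone sketch of that API (assuming `tensorrt_llm>=0.18.0` and an illustrative Hugging Face model id), not the exact code the backend runs:

```python
from tensorrt_llm import LLM, SamplingParams

# Build (or load) the engine and serve it through a single Python object.
# The model id below is illustrative; any Hugging Face causal LM supported
# by the LLM API can be used.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Generate a completion; engine construction already happened above, so no
# separate convert/build scripts are needed.
outputs = llm.generate(
    ["What is machine learning?"], SamplingParams(max_tokens=64)
)
print(outputs[0].outputs[0].text)
```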

Hugging Face models are downloaded at runtime when the LLM API engine starts if they
are not already in the local Hugging Face cache. No offline engine-building step is required,
but you can pre-download the model to avoid the download at server startup.
`tensorrt_llm>=0.18.0` is required.
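
If you want to pre-download, one option is the `huggingface_hub` package; this is a sketch with an illustrative repo id, and gated models still require `huggingface-cli login` first:

```python
from huggingface_hub import snapshot_download

# Populate the local Hugging Face cache ahead of time so the LLM API
# backend does not have to download the model at server startup.
snapshot_download(repo_id="meta-llama/Llama-3.1-8B-Instruct")
```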

#### Example

```bash
docker run -ti \
--gpus all \
--network=host \
--shm-size=1g --ulimit memlock=-1 \
-v /tmp:/tmp \
-v ${HOME}/models:/root/models \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
nvcr.io/nvidia/tritonserver:25.03-trtllm-python-py3

# Install the Triton CLI
pip install git+https://github.com/triton-inference-server/triton_cli.git@main

# Authenticate with huggingface for restricted models like Llama-2 and Llama-3
huggingface-cli login

# Generate a Triton model repository for the LLM API backend (no offline engine build is needed)
triton remove -m all
triton import -m llama-3.1-8b-instruct --backend llmapi

# Start Triton pointing at the default model repository
triton start --frontend openai

# Interact with model at http://localhost:9000
curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{
"model": "llama-3.1-8b-instruct",
"messages": [{"role": "user", "content": "What is machine learning?"}],
"max_tokens": 256
}'

# Profile model with GenAI-Perf
3 changes: 3 additions & 0 deletions src/triton_cli/.gitignore
@@ -1,2 +1,5 @@
*.json
*.cache

# Except model.json from the llmapi template
!templates/llmapi/1/model.json
2 changes: 1 addition & 1 deletion src/triton_cli/common.py
@@ -55,4 +55,4 @@ class TritonCLIException(Exception):
DEFAULT_MODEL_REPO: Path = Path.home() / "models"
DEFAULT_HF_CACHE: Path = Path.home() / ".cache" / "huggingface"
HF_CACHE: Path = Path(os.environ.get("HF_HOME", DEFAULT_HF_CACHE))
SUPPORTED_BACKENDS: set = {"vllm", "tensorrtllm"}
SUPPORTED_BACKENDS: set = {"vllm", "tensorrtllm", "llmapi"}
1 change: 1 addition & 0 deletions src/triton_cli/parser.py
@@ -68,6 +68,7 @@
"opt125m": "hf:facebook/opt-125m",
"mistral-7b": "hf:mistralai/Mistral-7B-v0.1",
"falcon-7b": "hf:tiiuae/falcon-7b",
"tinyllama-1.1b-chat-v1.0": "hf:TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}


34 changes: 32 additions & 2 deletions src/triton_cli/repository.py
@@ -67,7 +67,9 @@
SOURCE_PREFIX_NGC = "ngc:"
SOURCE_PREFIX_LOCAL = "local:"

TRT_TEMPLATES_PATH = Path(__file__).parent / "templates" / "trt_llm"
TEMPLATES_PATH = Path(__file__).parent / "templates"
TRTLLM_TEMPLATES_PATH = TEMPLATES_PATH / "trt_llm"
LLMAPI_TEMPLATES_PATH = TEMPLATES_PATH / "llmapi"

# Support changing destination dynamically to point at
# pre-downloaded checkpoints in various circumstances
@@ -266,6 +268,8 @@ def __add_huggingface_model(
self.remove(model, verbose=False)
# Let detailed traceback be reported for TRT-LLM errors for debugging
raise e
elif backend == "llmapi":
self.__generate_llmapi_model(version_dir, huggingface_id)
else:
# TODO: Add generic support for HuggingFace models with HF API.
# For now, use vLLM as a means of deploying HuggingFace Transformers
@@ -322,6 +326,21 @@ def __generate_vllm_model(self, huggingface_id: str):
model_files = {"model.json": model_contents}
return model_config, model_files

def __generate_llmapi_model(self, version_dir, huggingface_id: str):
    # Load the model.json copied from the llmapi template
    model_config_file = version_dir / "model.json"
    with open(model_config_file) as f:
        model_config_str = f.read()

    model_config_json = json.loads(model_config_str)

    # Point the "model" field at the requested Hugging Face id
    model_config_json["model"] = huggingface_id

    # Write the updated model.json back to the version directory
    with open(model_config_file, "w") as f:
        f.write(json.dumps(model_config_json))

def __generate_ngc_model(self, name: str, source: str):
engines_path = ENGINE_DEST_PATH + "/" + source
parse_and_substitute(
@@ -392,7 +411,7 @@ def __create_model_repository(
)

shutil.copytree(
TRT_TEMPLATES_PATH,
TRTLLM_TEMPLATES_PATH,
self.repo,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns("__pycache__"),
@@ -402,6 +421,17 @@
logger.debug(f"Adding TensorRT-LLM models at: {self.repo}")
else:
version_dir.mkdir(parents=True, exist_ok=False)
if backend == "llmapi":
shutil.copytree(
LLMAPI_TEMPLATES_PATH / "1",
version_dir,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns("__pycache__"),
)
shutil.copy(
LLMAPI_TEMPLATES_PATH / "config.pbtxt",
model_dir,
)
logger.debug(f"Adding new model to repo at: {version_dir}")
except FileExistsError:
logger.warning(f"Overwriting existing model in repo at: {version_dir}")
5 changes: 4 additions & 1 deletion src/triton_cli/server/server_factory.py
@@ -25,7 +25,7 @@
LOGGER_NAME,
TritonCLIException,
)
from .server_utils import TRTLLMUtils, VLLMUtils
from .server_utils import TRTLLMUtils, VLLMUtils, LLMAPIUtils

logger = logging.getLogger(LOGGER_NAME)

@@ -198,11 +198,14 @@ def _get_openai_chat_template_tokenizer(config):
)
trtllm_utils = TRTLLMUtils(config.model_repository)
vllm_utils = VLLMUtils(config.model_repository)
llmapi_utils = LLMAPIUtils(config.model_repository)

if trtllm_utils.has_trtllm_model():
tokenizer_path = trtllm_utils.get_engine_path()
elif vllm_utils.has_vllm_model():
tokenizer_path = vllm_utils.get_vllm_model_huggingface_id_or_path()
elif llmapi_utils.has_llmapi_model():
tokenizer_path = llmapi_utils.get_llmapi_model_huggingface_id_or_path()
else:
raise TritonCLIException(
"Unable to find a tokenizer to start the Triton OpenAI RESTful API, please use '--openai-chat-template-tokenizer' to specify a tokenizer."
85 changes: 84 additions & 1 deletion src/triton_cli/server/server_utils.py
@@ -48,7 +48,7 @@ def get_launch_command(
Parameters
----------
server_config : TritonServerConfig
A TritonServerConfig object containing command-line arguments to run tritonserver
A TritonServerConfig object containing command-line arguments to run tritonserver.
cmd_as_list : bool
Whether the command string needs to be returned as a list of string (local requires list,
docker requires str)
@@ -304,3 +304,86 @@ def _find_vllm_model_huggingface_id_or_path(self) -> str:
return model_id
except OSError:
raise Exception(f"Unable to open {model_config_json_file}")


class LLMAPIUtils:
    """
    A utility class for handling LLMAPI-specific models.
    """

    def __init__(self, model_path: Path):
        self._model_repo_path = model_path
        self._llmapi_model_path = self._find_llmapi_model_path()
        self._is_llmapi_model = self._llmapi_model_path is not None

    def has_llmapi_model(self) -> bool:
        """
        Returns
        -------
        A boolean indicating whether a LLMAPI model exists in the model repo
        """
        return self._is_llmapi_model

    def get_llmapi_model_huggingface_id_or_path(self) -> str:
        """
        Returns
        -------
        The LLMAPI model's Huggingface Id or path
        """
        return self._find_llmapi_model_huggingface_id_or_path()

    def _find_llmapi_model_path(self) -> Path:
        """
        Returns
        -------
        A pathlib.Path object containing the path to the LLMAPI model folder.

        Assumptions
        -----------
        - Assumes only a single model in the repository uses the LLMAPI backend;
          if several exist, the first one found is returned.
        """
        # Search for the LLMAPI model among all models in the model repository
        model_dirs = [
            model_dir
            for model_dir in self._model_repo_path.iterdir()
            if model_dir.is_dir()
        ]
        for model_dir in model_dirs:
            model_config_file = Path(self._model_repo_path) / model_dir / "config.pbtxt"
            model_json_path = model_config_file.parent / "1" / "model.json"
            # Check if config.pbtxt exists
            if model_config_file.is_file():
                # Read the config.pbtxt file and identify the backend
                with open(model_config_file) as config_file:
                    config = text_format.Parse(config_file.read(), mc.ModelConfig())
                    json_config = json.loads(
                        json_format.MessageToJson(
                            config, preserving_proto_field_name=True
                        )
                    )
                    # Check that the model.json also exists
                    if json_config["backend"] == "python" and model_json_path.is_file():
                        return model_config_file.parent

        return None

    def _find_llmapi_model_huggingface_id_or_path(self) -> str:
        """
        Returns
        -------
        The LLMAPI model's Huggingface Id or path
        """
        assert self._is_llmapi_model, "A model Huggingface Id or path cannot be parsed from a model repository that does not contain a LLMAPI model."
        try:
            # Assume the version is always "1"
            model_version_path = self._llmapi_model_path / "1"
            model_config_json_file = model_version_path / "model.json"
            with open(model_config_json_file) as json_data:
                data = json.load(json_data)
                model_id = data.get("model")
                if not model_id:
                    raise Exception(
                        f"Unable to parse config from {model_config_json_file}"
                    )
                return model_id
        except OSError:
            raise Exception(f"Unable to open {model_config_json_file}")
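
A rough usage sketch of the class above, assuming the default model repository under `~/models`:

```python
from pathlib import Path

# Hypothetical caller: probe a repository for an LLMAPI model and, if one
# exists, recover the Hugging Face id recorded in its model.json.
utils = LLMAPIUtils(Path.home() / "models")
if utils.has_llmapi_model():
    print(utils.get_llmapi_model_huggingface_id_or_path())
```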
28 changes: 28 additions & 0 deletions src/triton_cli/templates/llmapi/1/model.json
@@ -0,0 +1,28 @@
{
"max_batch_size": 64,
"decoupled": true,

"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"tokenizer": null,
"tokenizer_mode": null,
"skip_tokenizer_init": null,
"trust_remote_code": null,
"tensor_parallel_size": null,
"pipeline_parallel_size": null,
"dtype": null,
"revision": null,
"tokenizer_revision": null,
"speculative_model": null,
"enable_chunked_prefill": null,
"num_instances": null,

"use_cuda_graph": null,
"cuda_graph_batch_sizes": null,
"cuda_graph_max_batch_size": null,
"cuda_graph_padding_enabled": null,
"enable_overlap_scheduler": null,
"kv_cache_dtype": null,
"torch_compile_enabled": null,
"torch_compile_fullgraph": null,
"torch_compile_inductor_enabled": null
}
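
Tying this template to the repository code above: during `triton import --backend llmapi`, the CLI rewrites only the `"model"` field of this file. A self-contained sketch of that transformation, with an illustrative path and model id, is:

```python
import json
from pathlib import Path

# Illustrative location; the Triton CLI supplies the real version directory at import time.
version_dir = Path.home() / "models" / "llama-3.1-8b-instruct" / "1"
model_json = version_dir / "model.json"

config = json.loads(model_json.read_text())  # start from the llmapi template above
config["model"] = "meta-llama/Llama-3.1-8B-Instruct"  # point at the requested HF model
model_json.write_text(json.dumps(config))  # write the specialized config back
```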