Commit ba5fc7d

Add support for stop words in TRTLLM (#2678)

* feat(trtllm): rewrite health to not account for current state
* chore(looper): cleanup a bit more
* feat(post_processing): max_new_tokens is const evaluated now
* chore(ffi): formatting
* feat(trtllm): add stop words handling
  # Conflicts:
  #	backends/trtllm/lib/backend.cpp
* chore(trtllm): create specific parallelconfig factory and logging init methods
* chore(trtllm): define a macro for SizeType cast
* chore(trtllm): use GetParallelConfig
* chore(trtllm): minor refactoring
* chore(trtllm): validate there are enough GPUs on the system for the desired model
* chore(trtllm): ensure max throughput scheduling policy is selected
* chore(trtllm): minor fix
* chore(router): minor refactorings
* feat(docker): build with-slurm ompi
* feat(docker): add python3.10 dev to runtime deps
* chore(docker): add mpi to ld_library_path
* chore(docker): install transformers
* feat(trtllm): detect stop_words from generation_config.json
1 parent db68bd0 commit ba5fc7d

6 files changed: +144 additions, -69 deletions

Dockerfile_trtllm

Lines changed: 5 additions & 4 deletions
@@ -43,7 +43,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE
     mkdir /usr/src/mpi && \
     tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
     cd /usr/src/mpi && \
-    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
+    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
     make -j all && \
     make install && \
     rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
@@ -84,12 +84,13 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
     CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release

 FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
-RUN apt update && apt install -y python3 && \
-    rm -rf /var/lib/{apt,dpkg,cache,log}/
+RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \
+    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
+    python3 -m pip install transformers tokenizers

 WORKDIR /usr/local/tgi/bin

-ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
 ENV TOKENIZERS_PARALLELISM=false
 ENV OMPI_MCA_plm_rsh_agent=""

backends/trtllm/include/backend.h

Lines changed: 38 additions & 9 deletions
@@ -5,6 +5,7 @@
 #ifndef TGI_TRTLLM_BACKEND_H
 #define TGI_TRTLLM_BACKEND_H

+#include <array>
 #include <cmath>
 #include <filesystem>
 #include <span>
@@ -19,11 +20,16 @@
 using json = nlohmann::json;
 namespace tle = tensorrt_llm::executor;

+
+#define CAST_SIZETYPE(x) static_cast<tle::SizeType32>(x)
+
 namespace huggingface::tgi::backends {
     using RequestId = tle::IdType;
     using TokenId = tle::TokenIdType;

     const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
+    constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
+            "Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
     constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
             "Submitting inference [{}] to the executor ({:d} already in-flight)");
     constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
@@ -35,6 +41,12 @@ namespace huggingface::tgi::backends {
      */
     void InitializeBackend();

+    /**
+     * Initialize logging mechanism
+     */
+    void InitializeLogging();
+
+
     /**
      *
      * @param config TensorRT-LLM configuration object
@@ -43,6 +55,14 @@ namespace huggingface::tgi::backends {
      */
     tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);

+    /**
+     *
+     * @param worldSize
+     * @param workerPath
+     * @return
+     */
+    tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
+
     /**
      * Get the sampling configuration from the parameters provided by TGI
      * @param topK
@@ -62,6 +82,14 @@ namespace huggingface::tgi::backends {
             uint64_t seed
     ) noexcept;

+    /**
+     * Attempt to retrieve the
+     * @param generationConfigPath
+     * @return
+     */
+    std::optional<std::list<std::vector<TokenId>>>
+    GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
+
     /**
      *
      */
@@ -72,6 +100,7 @@ namespace huggingface::tgi::backends {

         /** Frequently accessed variables cached here **/
         uint32_t maxNumTokens;
+        std::list<std::vector<TokenId>> stopWords;

     public:
         explicit TensorRtLlmBackend(
@@ -91,20 +120,20 @@ namespace huggingface::tgi::backends {
         * @param topK
         * @param topP
         * @param temperature
-        * @param repetition_penalty
-        * @param frequency_penalty
+        * @param repetitionPenalty
+        * @param frequencyPenalty
         * @param seed
         * @return Request id related to this generation for reference
         */
        [[nodiscard]] RequestId Submit(
                const std::vector<TokenId> &tokens,
-               const uint32_t maxNewTokens,
-               const int32_t topK,
-               const float_t topP,
-               const float_t temperature,
-               const float_t repetition_penalty,
-               const float_t frequency_penalty,
-               const uint64_t seed
+               uint32_t maxNewTokens,
+               int32_t topK,
+               float_t topP,
+               float_t temperature,
+               float_t repetitionPenalty,
+               float_t frequencyPenalty,
+               uint64_t seed
        );

        [[nodiscard]] std::vector<tle::Response> PullNewTokens();
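
The new stop-words plumbing in this header is easiest to read from the data type outward: `GetStopWordsFromConfig` returns an optional list where each entry is a token sequence, and the backend caches that list in the new `stopWords` member. A minimal sketch (not part of the commit; the EOS token ids are made-up example values, and `TokenId` is assumed to be a 32-bit integer alias) of what that member ends up holding:

```cpp
// Illustrative only: the shape of the stopWords member declared above.
// Each stop word is a token sequence; each EOS id from generation_config.json
// becomes a single-token sequence. The ids below are hypothetical examples.
#include <cstdint>
#include <list>
#include <vector>

using TokenId = int32_t;  // assumption: tle::TokenIdType behaves like an int32 id

int main() {
    const std::list<std::vector<TokenId>> stopWords = {
        {128001},  // one single-token stop word per EOS id
        {128009},
    };
    // TensorRtLlmBackend caches such a list and attaches it to every request it
    // submits (see request.setStopWords(stopWords) in backend.cpp below).
    return static_cast<int>(stopWords.size());  // 2
}
```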

backends/trtllm/include/hardware.h

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@ namespace huggingface::hardware::cuda {
         int32_t major;
         int32_t minor;

-        [[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
+        [[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }

-        [[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
+        [[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
     };

     CudaComputeCapabilities GetCudaComputeCapabilities() {
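
The renamed predicates only compare the SM major version. A hedged, self-contained sketch of the intended semantics, assuming `AMPERE_SM_MAJOR` is 8 and `HOPPER_SM_MAJOR` is 9 (those constants are defined elsewhere in hardware.h and are not part of this diff):

```cpp
// Sketch of the capability checks; the SM-major constants are assumed values here.
#include <cstdint>

constexpr int32_t AMPERE_SM_MAJOR = 8;  // assumption: Ampere GPUs report SM 8.x
constexpr int32_t HOPPER_SM_MAJOR = 9;  // assumption: Hopper GPUs report SM 9.x

struct CudaComputeCapabilities {
    int32_t major;
    int32_t minor;

    [[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
    [[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
};

// A100 reports SM 8.0 (post-Ampere, not post-Hopper); H100 reports SM 9.0 (both).
static_assert(CudaComputeCapabilities{8, 0}.IsPostAmpere());
static_assert(!CudaComputeCapabilities{8, 0}.IsPostHopper());
static_assert(CudaComputeCapabilities{9, 0}.IsPostHopper());

int main() { return 0; }
```

In backend.cpp, `IsPostAmpere()` gates chunked-context support on the executor config.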

backends/trtllm/lib/backend.cpp

Lines changed: 86 additions & 33 deletions
@@ -8,7 +8,9 @@
 #include "backend.h"
 #include "hardware.h"

-void huggingface::tgi::backends::InitializeBackend() {
+
+void huggingface::tgi::backends::InitializeLogging() {
+#ifdef NDEBUG
     if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
         std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
         std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
@@ -20,11 +22,18 @@ void huggingface::tgi::backends::InitializeBackend() {
         else
             spdlog::set_level(spdlog::level::info);
     }
+#else
+    spdlog::set_level(spdlog::level::debug);
+#endif
+}

+void huggingface::tgi::backends::InitializeBackend() {
     SPDLOG_INFO("Initializing Backend...");
     nvmlInit_v2();
     initTrtLlmPlugins();

+    InitializeLogging();
+
     SPDLOG_INFO("Backend Executor Version: {}", tle::version());
     const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
     if (numGpus.has_value()) {
@@ -34,6 +43,23 @@ void huggingface::tgi::backends::InitializeBackend() {
     }
 }

+[[nodiscard]]
+tle::ParallelConfig
+huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
+    auto mode = tle::CommunicationMode::kLEADER;
+    std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;
+
+    if (worldSize > 1) {
+        SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
+        mode = tle::CommunicationMode::kORCHESTRATOR;
+        orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, workerPath, nullptr, true);
+    } else {
+        SPDLOG_INFO("Detected single engine deployment, using leader mode");
+    }
+
+    return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
+}
+
 [[nodiscard]]
 tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
     tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
@@ -42,29 +68,13 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
     const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();

     // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
-    if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
-        SPDLOG_INFO("Detected single engine deployment, using leader mode");
-        execConfig.setParallelConfig(tle::ParallelConfig(
-                tle::CommunicationType::kMPI,
-                tle::CommunicationMode::kLEADER,
-                std::nullopt,
-                std::nullopt,
-                std::nullopt
-        ));
-    } else { // Multiple engines -> using orchestrator mode (MPI involved)
-        SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
-        execConfig.setParallelConfig(tle::ParallelConfig(
-                tle::CommunicationType::kMPI,
-                tle::CommunicationMode::kORCHESTRATOR,
-                std::nullopt,
-                std::nullopt,
-                tle::OrchestratorConfig(true, workerPath, nullptr, true)
-        ));
-    }
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));

     // Define some configuration variables
     execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
-    execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
+    execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
+    execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
     return execConfig;
 }

@@ -93,28 +103,66 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
     );
 }

+std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
+huggingface::tgi::backends::GetStopWordsFromConfig(
+        const std::filesystem::path &generationConfigPath) noexcept {
+    if (exists(generationConfigPath)) {
+        const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
+        if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
+            SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
+            std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());
+
+            const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
+                return {tokenIdObj.template get<tle::TokenIdType>()};
+            };
+
+            std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
+            return stopWords;
+        } else {
+            SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
+        }
+    } else {
+        SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
+    }
+
+    return std::nullopt;
+}
+
 huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         const std::filesystem::path &enginesFolder,
         const std::filesystem::path &executorWorker
 ) :
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
         executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                  GetExecutorConfig(config, executorWorker.string())) {
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
+
+    // Ensure we have enough GPUs on the system
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+    if (numGpus < worldSize) {
+        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+        // todo : raise exception to catch on rust side
+    }

     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
+
+    // Attempt to discover stopWords from the generation_config.json
+    const auto generationConfigPath = enginesFolder / "generation_config.json";
+    stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
 }

 [[nodiscard("Returned number of requests needs to be consumed")]]
 size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
+#ifdef NDEBUG
+    return executor.getNumResponsesReady();
+#else
     const auto numResponses = executor.getNumResponsesReady();
-
-#ifndef NDEBUG
-    if(numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
-#endif
-
+    if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
     return numResponses;
+#endif
 }

 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
@@ -124,8 +172,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
-        const float_t repetition_penalty,
-        const float_t frequency_penalty,
+        const float_t repetitionPenalty,
+        const float_t frequencyPenalty,
         const uint64_t seed
 ) {
     const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
@@ -135,14 +183,19 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const auto &lastIteration = iterations.front();

         SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
-        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
         SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
     }
 #endif

-    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-    const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
-    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
+    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
+
+    // Build the request
+    auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
+    request.setStopWords(stopWords);
+
+    // Submit to the executor for batching
+    return executor.enqueueRequest(request);
 }

 std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
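
For context on the new generation_config.json handling: the sketch below is a standalone illustration of the same mapping the commit performs, under the assumptions that nlohmann::json is available and that token ids fit in int32_t. An array-valued `eos_token_id` becomes a list of single-token stop words, while a scalar or missing entry yields nothing, matching the constructor's fallback to an empty list via `value_or`. The JSON content and helper name are hypothetical, not from the commit.

```cpp
// Standalone illustration of the eos_token_id -> stop words mapping above.
// Assumptions: nlohmann::json is on the include path; token ids fit in int32_t.
#include <cstdint>
#include <iostream>
#include <list>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

std::list<std::vector<int32_t>> StopWordsFromGenerationConfig(const json &generationConfig) {
    std::list<std::vector<int32_t>> stopWords;
    // Only an array-valued "eos_token_id" produces stop words; a scalar id or a
    // missing key is ignored, mirroring the "not an array" branch in the commit.
    if (generationConfig.contains("eos_token_id") && generationConfig["eos_token_id"].is_array()) {
        for (const auto &tokenId : generationConfig["eos_token_id"])
            stopWords.push_back({tokenId.get<int32_t>()});  // one single-token stop word per EOS id
    }
    return stopWords;
}

int main() {
    // Hypothetical generation_config.json content; the ids are example values.
    const auto cfg = json::parse(R"({"eos_token_id": [128001, 128009], "do_sample": true})");
    std::cout << StopWordsFromGenerationConfig(cfg).size() << " stop words\n";  // prints "2 stop words"
    return 0;
}
```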

backends/trtllm/src/ffi.cpp

Lines changed: 8 additions & 3 deletions
@@ -23,9 +23,14 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(


 uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
-        rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
-        int32_t topK, float_t topP, float_t temperature,
-        float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {
+        rust::Slice<const uint32_t> tokens,
+        uint32_t maxNewTokens,
+        int32_t topK,
+        float_t topP,
+        float_t temperature,
+        float_t repetition_penalty,
+        float_t frequency_penalty,
+        uint64_t seed) {

     // This will copy all the items from the initial slice
     std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
