From ba3950c737344454dcb2c8f53194f43b934ab953 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Tue, 4 Feb 2025 16:48:19 -0800 Subject: [PATCH] =?UTF-8?q?[lmi]=20support=20auto=20configuration=20of=20m?= =?UTF-8?q?istral=20models=20provided=20as=20mistra=E2=80=A6=20(#2714)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/integration/llm/prepare.py | 7 +-- .../java/ai/djl/serving/wlm/LmiUtils.java | 52 ++++++++++++++++++- .../ai/djl/serving/wlm/ModelInfoTest.java | 3 +- .../resources/local-mistral-model/params.json | 11 ++++ 4 files changed, 65 insertions(+), 8 deletions(-) create mode 100644 wlm/src/test/resources/local-mistral-model/params.json diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py index 2a7f1f15c..a820a96c9 100644 --- a/tests/integration/llm/prepare.py +++ b/tests/integration/llm/prepare.py @@ -683,12 +683,11 @@ "option.max_model_len": 8192, }, "pixtral-12b": { - "option.model_id": "s3://djl-llm/pixtral-12b/", + "option.model_id": "s3://djl-llm/pixtral-12b-2409/", "option.max_model_len": 8192, "option.max_rolling_batch_size": 16, "option.tokenizer_mode": "mistral", "option.limit_mm_per_prompt": "image=4", - "option.entryPoint": "djl_python.huggingface" }, "llama32-11b-multimodal": { "option.model_id": "s3://djl-llm/llama-3-2-11b-vision-instruct/", @@ -1061,13 +1060,11 @@ "option.max_model_len": 8192, }, "pixtral-12b": { - "option.model_id": "s3://djl-llm/pixtral-12b/", + "option.model_id": "s3://djl-llm/pixtral-12b-2409/", "option.max_model_len": 8192, "option.max_rolling_batch_size": 16, "option.tokenizer_mode": "mistral", "option.limit_mm_per_prompt": "image=4", - "option.entryPoint": "djl_python.huggingface", - "option.tensor_parallel_degree": "max" }, "llama32-11b-multimodal": { "option.model_id": "s3://djl-llm/llama-3-2-11b-vision-instruct/", diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 21d2b9161..336473057 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -286,7 +286,9 @@ static void convertRustModel(ModelInfo info) throws IOException { * @return the Huggingface config.json file URI */ public static URI generateHuggingFaceConfigUri(ModelInfo modelInfo, String modelId) { - String[] possibleConfigFiles = {"config.json", "adapter_config.json", "model_index.json"}; + String[] possibleConfigFiles = { + "config.json", "adapter_config.json", "model_index.json", "params.json" + }; URI configUri; for (String configFile : possibleConfigFiles) { configUri = findHuggingFaceConfigUriForConfigFile(modelInfo, modelId, configFile); @@ -369,8 +371,12 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo if (hubToken != null) { headers.put("Authorization", "Bearer " + hubToken); } - try (InputStream is = Utils.openUrl(modelConfigUri.toURL(), headers)) { + if (modelConfigUri.toString().endsWith("params.json")) { + MistralModelConfig mistralConfig = + JsonUtils.GSON.fromJson(Utils.toString(is), MistralModelConfig.class); + return new HuggingFaceModelConfig(mistralConfig); + } return JsonUtils.GSON.fromJson(Utils.toString(is), HuggingFaceModelConfig.class); } catch (IOException | JsonSyntaxException e) { throw new ModelNotFoundException("Invalid huggingface model id: " + modelId, e); @@ -518,6 +524,17 @@ public static final class HuggingFaceModelConfig { private Set allArchitectures; + HuggingFaceModelConfig(MistralModelConfig mistralModelConfig) { + this.modelType = "mistral"; + this.configArchitectures = List.of("MistralForCausalLM"); + this.hiddenSize = mistralModelConfig.dim; + this.intermediateSize = mistralModelConfig.hiddenDim; + this.numAttentionHeads = mistralModelConfig.nHeads; + this.numHiddenLayers = mistralModelConfig.nLayers; + this.numKeyValueHeads = mistralModelConfig.nKvHeads; + this.vocabSize = mistralModelConfig.vocabSize; + } + /** * Returns the model type of this HuggingFace model. * @@ -659,4 +676,35 @@ private void determineAllArchitectures() { } } } + + /** + * This represents a Mistral Model Config. Mistral artifacts are different from HuggingFace + * artifacts. Some Mistral vended models only come in Mistral artifact form. + */ + static final class MistralModelConfig { + + @SerializedName("dim") + private int dim; + + @SerializedName("n_layers") + private int nLayers; + + @SerializedName("head_dim") + private int headDim; + + @SerializedName("hidden_dim") + private int hiddenDim; + + @SerializedName("n_heads") + private int nHeads; + + @SerializedName("n_kv_heads") + private int nKvHeads; + + @SerializedName("vocab_size") + private int vocabSize; + + @SerializedName("vision_encoder") + private Map visionEncoder; + } } diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java index d41f77587..e7318287c 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java @@ -248,7 +248,8 @@ public void testInferLmiEngine() throws IOException, ModelException { "NousResearch/Hermes-2-Pro-Mistral-7B", "lmi-dist", "src/test/resources/local-hf-model", "lmi-dist", "HuggingFaceH4/tiny-random-LlamaForSequenceClassification", "disable", - "THUDM/chatglm3-6b", "lmi-dist"); + "THUDM/chatglm3-6b", "lmi-dist", + "src/test/resources/local-mistral-model", "lmi-dist"); Path modelStore = Paths.get("build/models"); Path modelDir = modelStore.resolve("lmi_test_model"); Path prop = modelDir.resolve("serving.properties"); diff --git a/wlm/src/test/resources/local-mistral-model/params.json b/wlm/src/test/resources/local-mistral-model/params.json new file mode 100644 index 000000000..b2a72fc63 --- /dev/null +++ b/wlm/src/test/resources/local-mistral-model/params.json @@ -0,0 +1,11 @@ +{ + "dim": 4096, + "n_layers": 32, + "head_dim": 128, + "hidden_dim": 14336, + "n_heads": 32, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "vocab_size": 32768, + "rope_theta": 1000000.0 +}