[lmi] support auto configuration of mistral models provided as mistra…
siddvenk authored Feb 5, 2025
1 parent 509b7bd commit ba3950c
Showing 4 changed files with 65 additions and 8 deletions.
7 changes: 2 additions & 5 deletions tests/integration/llm/prepare.py
@@ -683,12 +683,11 @@
"option.max_model_len": 8192,
},
"pixtral-12b": {
"option.model_id": "s3://djl-llm/pixtral-12b/",
"option.model_id": "s3://djl-llm/pixtral-12b-2409/",
"option.max_model_len": 8192,
"option.max_rolling_batch_size": 16,
"option.tokenizer_mode": "mistral",
"option.limit_mm_per_prompt": "image=4",
"option.entryPoint": "djl_python.huggingface"
},
"llama32-11b-multimodal": {
"option.model_id": "s3://djl-llm/llama-3-2-11b-vision-instruct/",
@@ -1061,13 +1060,11 @@
"option.max_model_len": 8192,
},
"pixtral-12b": {
"option.model_id": "s3://djl-llm/pixtral-12b/",
"option.model_id": "s3://djl-llm/pixtral-12b-2409/",
"option.max_model_len": 8192,
"option.max_rolling_batch_size": 16,
"option.tokenizer_mode": "mistral",
"option.limit_mm_per_prompt": "image=4",
"option.entryPoint": "djl_python.huggingface",
"option.tensor_parallel_degree": "max"
},
"llama32-11b-multimodal": {
"option.model_id": "s3://djl-llm/llama-3-2-11b-vision-instruct/",
52 changes: 50 additions & 2 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
@@ -286,7 +286,9 @@ static void convertRustModel(ModelInfo<?, ?> info) throws IOException {
* @return the Huggingface config.json file URI
*/
public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String modelId) {
- String[] possibleConfigFiles = {"config.json", "adapter_config.json", "model_index.json"};
+ String[] possibleConfigFiles = {
+         "config.json", "adapter_config.json", "model_index.json", "params.json"
+ };
URI configUri;
for (String configFile : possibleConfigFiles) {
configUri = findHuggingFaceConfigUriForConfigFile(modelInfo, modelId, configFile);
@@ -369,8 +371,12 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo<?, ?>
if (hubToken != null) {
headers.put("Authorization", "Bearer " + hubToken);
}

try (InputStream is = Utils.openUrl(modelConfigUri.toURL(), headers)) {
+     if (modelConfigUri.toString().endsWith("params.json")) {
+         MistralModelConfig mistralConfig =
+                 JsonUtils.GSON.fromJson(Utils.toString(is), MistralModelConfig.class);
+         return new HuggingFaceModelConfig(mistralConfig);
+     }
return JsonUtils.GSON.fromJson(Utils.toString(is), HuggingFaceModelConfig.class);
} catch (IOException | JsonSyntaxException e) {
throw new ModelNotFoundException("Invalid huggingface model id: " + modelId, e);
Expand Down Expand Up @@ -518,6 +524,17 @@ public static final class HuggingFaceModelConfig {

private Set<String> allArchitectures;

+ HuggingFaceModelConfig(MistralModelConfig mistralModelConfig) {
+     this.modelType = "mistral";
+     this.configArchitectures = List.of("MistralForCausalLM");
+     this.hiddenSize = mistralModelConfig.dim;
+     this.intermediateSize = mistralModelConfig.hiddenDim;
+     this.numAttentionHeads = mistralModelConfig.nHeads;
+     this.numHiddenLayers = mistralModelConfig.nLayers;
+     this.numKeyValueHeads = mistralModelConfig.nKvHeads;
+     this.vocabSize = mistralModelConfig.vocabSize;
+ }

/**
* Returns the model type of this HuggingFace model.
*
@@ -659,4 +676,35 @@ private void determineAllArchitectures() {
}
}
}

+ /**
+  * This represents a Mistral Model Config. Mistral artifacts are different from HuggingFace
+  * artifacts. Some Mistral vended models only come in Mistral artifact form.
+  */
+ static final class MistralModelConfig {
+
+     @SerializedName("dim")
+     private int dim;
+
+     @SerializedName("n_layers")
+     private int nLayers;
+
+     @SerializedName("head_dim")
+     private int headDim;
+
+     @SerializedName("hidden_dim")
+     private int hiddenDim;
+
+     @SerializedName("n_heads")
+     private int nHeads;
+
+     @SerializedName("n_kv_heads")
+     private int nKvHeads;
+
+     @SerializedName("vocab_size")
+     private int vocabSize;
+
+     @SerializedName("vision_encoder")
+     private Map<String, String> visionEncoder;
+ }
}
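
For context, here is a minimal standalone sketch (not part of the commit) of the parsing path the new code enables: Gson deserializes a Mistral params.json into fields named like those of MistralModelConfig above, and the new HuggingFaceModelConfig constructor copies them into the HuggingFace-style config fields. The Params class and main method below are illustrative stand-ins, not the actual wlm classes; the only assumption is that Gson is on the classpath, as it already is for the @SerializedName usage in the diff.

// Illustrative sketch only: mirrors the field mapping performed by the new
// HuggingFaceModelConfig(MistralModelConfig) constructor in LmiUtils.java.
import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;

public final class MistralParamsSketch {

    // Stand-in for LmiUtils.MistralModelConfig (same @SerializedName keys).
    static final class Params {
        @SerializedName("dim") int dim;
        @SerializedName("hidden_dim") int hiddenDim;
        @SerializedName("n_heads") int nHeads;
        @SerializedName("n_kv_heads") int nKvHeads;
        @SerializedName("n_layers") int nLayers;
        @SerializedName("vocab_size") int vocabSize;
    }

    public static void main(String[] args) {
        String json =
                "{\"dim\":4096,\"n_layers\":32,\"head_dim\":128,\"hidden_dim\":14336,"
                        + "\"n_heads\":32,\"n_kv_heads\":8,\"vocab_size\":32768}";
        Params p = new Gson().fromJson(json, Params.class);
        // The commit maps these onto HuggingFace-style config fields:
        // hiddenSize <- dim, intermediateSize <- hidden_dim,
        // numAttentionHeads <- n_heads, numKeyValueHeads <- n_kv_heads,
        // numHiddenLayers <- n_layers, vocabSize <- vocab_size,
        // and fixes modelType to "mistral".
        System.out.println(
                "hiddenSize=" + p.dim + " intermediateSize=" + p.hiddenDim + " kvHeads=" + p.nKvHeads);
    }
}
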
3 changes: 2 additions & 1 deletion wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
@@ -248,7 +248,8 @@ public void testInferLmiEngine() throws IOException, ModelException {
"NousResearch/Hermes-2-Pro-Mistral-7B", "lmi-dist",
"src/test/resources/local-hf-model", "lmi-dist",
"HuggingFaceH4/tiny-random-LlamaForSequenceClassification", "disable",
"THUDM/chatglm3-6b", "lmi-dist");
"THUDM/chatglm3-6b", "lmi-dist",
"src/test/resources/local-mistral-model", "lmi-dist");
Path modelStore = Paths.get("build/models");
Path modelDir = modelStore.resolve("lmi_test_model");
Path prop = modelDir.resolve("serving.properties");
11 changes: 11 additions & 0 deletions wlm/src/test/resources/local-mistral-model/params.json
@@ -0,0 +1,11 @@
+ {
+     "dim": 4096,
+     "n_layers": 32,
+     "head_dim": 128,
+     "hidden_dim": 14336,
+     "n_heads": 32,
+     "n_kv_heads": 8,
+     "norm_eps": 1e-05,
+     "vocab_size": 32768,
+     "rope_theta": 1000000.0
+ }
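
As a rough illustration of why this fixture is enough for auto configuration: generateHuggingFaceConfigUri now probes for params.json alongside the HuggingFace config file names, so a model directory containing only this file still resolves to a config URI and, per the test change above, infers the lmi-dist engine. The sketch below imitates that probe loop against a local directory; the probe method is a hypothetical stand-in, not the actual wlm code path, which goes through findHuggingFaceConfigUriForConfigFile with a ModelInfo and model id.

// Sketch of the probe order added in this commit (stand-in code, not wlm).
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

final class ConfigProbeSketch {

    // Returns the first config file that exists in the model directory,
    // mirroring the loop in generateHuggingFaceConfigUri.
    static Optional<URI> probe(Path modelDir) {
        String[] possibleConfigFiles = {
            "config.json", "adapter_config.json", "model_index.json", "params.json"
        };
        for (String configFile : possibleConfigFiles) {
            Path candidate = modelDir.resolve(configFile);
            if (Files.isRegularFile(candidate)) {
                return Optional.of(candidate.toUri());
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // For the test fixture above, only params.json exists, so the probe
        // falls through to it and Mistral-style parsing is used downstream.
        Path fixture = Path.of("wlm/src/test/resources/local-mistral-model");
        System.out.println(probe(fixture).orElse(null));
    }
}
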
