diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index 4065fc5561..054ce193ce 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -157,7 +157,6 @@ async def test_restful_api(setup): "model_uid": "test_restful_api2", "model_name": "orca", "quantization": "q4_0", - "embedding": "True", } response = requests.post(url, json=payload) @@ -177,28 +176,3 @@ async def test_restful_api(setup): url = f"{endpoint}/v1/models/test_restful_api2" response = requests.delete(url) - - # test for model that does not specify embedding - url = f"{endpoint}/v1/models" - - payload = { - "model_uid": "test_restful_api3", - "model_name": "orca", - "quantization": "q4_0", - } - - response = requests.post(url, json=payload) - response_data = response.json() - model_uid_res = response_data["model_uid"] - assert model_uid_res == "test_restful_api3" - - url = f"{endpoint}/v1/embeddings" - payload = { - "model": "test_restful_api3", - "input": "The food was delicious and the waiter...", - } - response = requests.post(url, json=payload) - assert response.status_code == 400 - - url = f"{endpoint}/v1/models/test_restful_api3" - response = requests.delete(url) diff --git a/xinference/model/llm/core.py b/xinference/model/llm/core.py index 09ee276629..0872233945 100644 --- a/xinference/model/llm/core.py +++ b/xinference/model/llm/core.py @@ -145,6 +145,7 @@ def _sanitize_model_config( else: llamacpp_model_config.setdefault("n_ctx", 2048) + llamacpp_model_config.setdefault("embedding", True) llamacpp_model_config.setdefault("use_mmap", False) llamacpp_model_config.setdefault("use_mlock", True) diff --git a/xinference/tests/test_client.py b/xinference/tests/test_client.py index ea2beebde9..78ab50da2e 100644 --- a/xinference/tests/test_client.py +++ b/xinference/tests/test_client.py @@ -31,9 +31,6 @@ async def test_sync_client(setup): model = client.get_model(model_uid=model_uid) assert isinstance(model, ChatModelHandle) - with pytest.raises(RuntimeError): - model.create_embedding("The food was delicious and the waiter...") - completion = model.chat("write a poem.") assert "content" in completion["choices"][0]["message"] @@ -44,7 +41,6 @@ async def test_sync_client(setup): model_name="orca", model_size_in_billions=3, quantization="q4_0", - embedding="True", ) model = client.get_model(model_uid=model_uid) @@ -106,9 +102,6 @@ async def test_RESTful_client(setup): for chunk in streaming_response: assert "content" or "role" in chunk["choices"][0]["delta"] - with pytest.raises(RuntimeError): - model.create_embedding("The food was delicious and the waiter...") - client.terminate_model(model_uid=model_uid) assert len(client.list_models()) == 0 @@ -119,7 +112,6 @@ async def test_RESTful_client(setup): model_name="orca", model_size_in_billions=3, quantization="q4_0", - embedding="True", ) model2 = client.get_model(model_uid=model_uid2)