Skip to content

Commit

Permalink
BUG: Set default embedding to be True (#236)
Browse files: browse the repository at this point in the history
  • Loading branch information
jiayini1119 authored Jul 24, 2023
1 parent 18cb8d0 commit b753f98
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 34 deletions.
26 changes: 0 additions & 26 deletions xinference/core/tests/test_restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ async def test_restful_api(setup):
"model_uid": "test_restful_api2",
"model_name": "orca",
"quantization": "q4_0",
"embedding": "True",
}

response = requests.post(url, json=payload)
Expand All @@ -177,28 +176,3 @@ async def test_restful_api(setup):

url = f"{endpoint}/v1/models/test_restful_api2"
response = requests.delete(url)

# test for model that does not specify embedding
url = f"{endpoint}/v1/models"

payload = {
"model_uid": "test_restful_api3",
"model_name": "orca",
"quantization": "q4_0",
}

response = requests.post(url, json=payload)
response_data = response.json()
model_uid_res = response_data["model_uid"]
assert model_uid_res == "test_restful_api3"

url = f"{endpoint}/v1/embeddings"
payload = {
"model": "test_restful_api3",
"input": "The food was delicious and the waiter...",
}
response = requests.post(url, json=payload)
assert response.status_code == 400

url = f"{endpoint}/v1/models/test_restful_api3"
response = requests.delete(url)
1 change: 1 addition & 0 deletions xinference/model/llm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _sanitize_model_config(
else:
llamacpp_model_config.setdefault("n_ctx", 2048)

llamacpp_model_config.setdefault("embedding", True)
llamacpp_model_config.setdefault("use_mmap", False)
llamacpp_model_config.setdefault("use_mlock", True)

Expand Down
8 changes: 0 additions & 8 deletions xinference/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ async def test_sync_client(setup):
model = client.get_model(model_uid=model_uid)
assert isinstance(model, ChatModelHandle)

with pytest.raises(RuntimeError):
model.create_embedding("The food was delicious and the waiter...")

completion = model.chat("write a poem.")
assert "content" in completion["choices"][0]["message"]

Expand All @@ -44,7 +41,6 @@ async def test_sync_client(setup):
model_name="orca",
model_size_in_billions=3,
quantization="q4_0",
embedding="True",
)

model = client.get_model(model_uid=model_uid)
Expand Down Expand Up @@ -106,9 +102,6 @@ async def test_RESTful_client(setup):
for chunk in streaming_response:
assert "content" or "role" in chunk["choices"][0]["delta"]

with pytest.raises(RuntimeError):
model.create_embedding("The food was delicious and the waiter...")

client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0

Expand All @@ -119,7 +112,6 @@ async def test_RESTful_client(setup):
model_name="orca",
model_size_in_billions=3,
quantization="q4_0",
embedding="True",
)

model2 = client.get_model(model_uid=model_uid2)
Expand Down

0 comments on commit b753f98

Please sign in to comment.