-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f165d3f
commit 379948f
Showing
1 changed file
with
175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import json | ||
import subprocess | ||
import requests | ||
|
||
|
||
class ServerWrapper:
    """Small HTTP client for an inference server.

    The wrapper talks to the ``/health``, ``/tokenize``, ``/detokenize``,
    ``/embedding``, ``/completion`` and ``/metrics`` endpoints of the server
    located at ``api_url``.
    NOTE(review): the endpoint set and payload fields look like the llama.cpp
    example server API - confirm against the server actually deployed.

    Network failures are not caught here: any exception raised by
    ``requests`` propagates to the caller.
    """

    # Base URL used when the config does not supply one.
    DEFAULT_URL = "http://localhost:8080"

    # Completion options that are forwarded to the server only when the
    # caller passes them explicitly (i.e. with a value other than None).
    _OPTIONAL_COMPLETION_PARAMS = (
        "dynatemp_range", "dynatemp_exponent", "top_k", "top_p", "min_p",
        "n_predict", "n_keep", "stream", "stop", "tfs_z", "typical_p",
        "repeat_penalty", "repeat_last_n", "penalize_nl", "presence_penalty",
        "frequency_penalty", "penalty_prompt", "mirostat", "mirostat_tau",
        "mirostat_eta", "grammar", "seed", "ignore_eos", "logit_bias",
        "n_probs", "min_keep", "image_data", "slot_id", "cache_prompt",
        "system_prompt", "samplers",
    )

    def __init__(self, config=None):
        """Create a wrapper.

        Args:
            config: Optional mapping. Recognised keys:
                ``"url"``: base URL of the server; falls back to
                ``DEFAULT_URL`` when the key is missing or empty.
                ``"timeout"``: value forwarded as ``timeout=`` to every
                ``requests`` call; ``None`` (the default) means no timeout.
        """
        config = config or {}
        # A config without "url" must not leave api_url as None, otherwise
        # every request would be sent to "None/<endpoint>".
        self.api_url = config.get("url") or self.DEFAULT_URL
        self.timeout = config.get("timeout")

    def _url(self, endpoint):
        """Return the full URL for *endpoint* without doubling slashes."""
        return f"{self.api_url.rstrip('/')}/{endpoint}"

    def _post_json(self, endpoint, payload, **request_kwargs):
        """POST *payload* as a JSON body to *endpoint* and return the response."""
        return requests.post(
            self._url(endpoint),
            headers={"Content-Type": "application/json"},
            data=json.dumps(payload),
            timeout=self.timeout,
            **request_kwargs,
        )

    @staticmethod
    def _raise_for_error(response, message):
        """Raise ``ValueError`` with *message* unless the status code is 200."""
        if response.status_code != 200:
            raise ValueError(f"{message}: {response.text}")

    def check_server(self):
        """Check the status of the server.

        Returns:
            On HTTP 200, a dict with the ``status``, ``slots_idle`` and
            ``slots_processing`` values taken from the JSON body (``None``
            for any key the server did not send). Otherwise a dict with
            ``status`` set to ``"error"`` plus the HTTP ``code`` and the
            response text as ``message``.
        """
        response = requests.get(self._url("health"), timeout=self.timeout)
        if response.status_code != 200:
            return {
                "status": "error",
                "code": response.status_code,
                "message": response.text,
            }

        body = response.json()
        return {
            key: body.get(key)
            for key in ("status", "slots_idle", "slots_processing")
        }

    def tokenize(self, content):
        """Send *content* to ``/tokenize`` and return the decoded JSON reply.

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        response = self._post_json("tokenize", {"content": content})
        self._raise_for_error(response, "Error from server while tokenizing")
        return response.json()

    def detokenize(self, tokens):
        """Send *tokens* to ``/detokenize`` and return the decoded JSON reply.

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        response = self._post_json("detokenize", {"tokens": tokens})
        self._raise_for_error(response, "Error from server while detokenizing")
        return response.json()

    def generate_embedding(self, content, image_data=None):
        """Request an embedding for *content* from ``/embedding``.

        Args:
            content: Value sent as the ``content`` field.
            image_data: Optional value sent as ``image_data``; an empty list
                is sent when it is ``None`` or otherwise falsy.

        Returns:
            The decoded JSON reply.

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        payload = {"content": content, "image_data": image_data or []}
        response = self._post_json("embedding", payload)
        self._raise_for_error(response, "Error from server")
        return response.json()

    def generate_completion(self, prompt, **kwargs):
        """Generate a completion for *prompt* and return it as one string.

        The prompt is wrapped in a ``### Human: ... ### Assistant:`` template
        before being sent to ``/completion``.

        Args:
            prompt: The user text.
            **kwargs: Optional sampling options. ``temperature`` defaults to
                0.8. Any name listed in ``_OPTIONAL_COMPLETION_PARAMS`` whose
                value is not ``None`` overrides or extends the default
                payload. Other names are ignored.

        Returns:
            The generated text with surrounding whitespace stripped.

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        temperature = kwargs.get("temperature")
        payload = {
            "prompt": f"### Human: {prompt}\n### Assistant: ",
            # An explicit temperature=None means "use the default", the same
            # way None is treated for every other optional parameter.
            "temperature": 0.8 if temperature is None else temperature,
            "top_k": 40,
            "top_p": 0.9,
            "n_predict": 16,
            "cache_prompt": True,
            "stop": ["\n### Human:"],
            "stream": True,
        }
        payload.update(
            (name, kwargs[name])
            for name in self._OPTIONAL_COMPLETION_PARAMS
            if kwargs.get(name) is not None
        )
        streaming = bool(payload["stream"])

        print("Sending request: ", json.dumps(payload))
        # The context manager guarantees the connection is released even if
        # reading the stream fails part-way through.
        with self._post_json("completion", payload, stream=streaming) as response:
            self._raise_for_error(response, "Error in chat completion")

            if not streaming:
                # With stream disabled the reply is a single JSON document
                # rather than a sequence of "data: ..." event lines.
                return response.json().get("content", "").strip()

            pieces = []
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                if not line.startswith("data: "):
                    continue
                segment_text = line[len("data: "):]
                try:
                    segment = json.loads(segment_text)
                except json.JSONDecodeError:
                    # Best effort: report the bad segment and keep reading.
                    print(f"Error decoding JSON segment: {segment_text}")
                    continue
                pieces.append(segment.get("content", ""))

        return "".join(pieces).strip()

    def get_metrics(self):
        """Fetch ``/metrics`` and return the raw response text.

        Metrics are typically returned in a text-based format that Prometheus
        can parse rather than JSON, so the body is returned unparsed for
        further processing by the caller.

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        response = requests.get(self._url("metrics"), timeout=self.timeout)
        self._raise_for_error(response, "Error from server while fetching metrics")
        return response.text
|
||
|
||
# Demo: exercise each wrapper call against a locally running server.
if __name__ == "__main__":
    client = ServerWrapper({"url": "http://localhost:8080"})

    # Health check.
    health = client.check_server()
    print("Check server status: ", health)

    # Completion, overriding two of the default sampling options.
    # Further options can be passed as extra keyword arguments.
    reply = client.generate_completion(
        "Building a website can be done in 10 simple steps:",
        temperature=0.7,
        top_k=50,
    )
    print("Generate Completion: ", reply)

    # Example usage for embedding
    # content = "This is a sample text for which we want to generate an embedding."
    # image_data = [
    #     {
    #         "data": "base64_encoded_string_of_your_image",
    #         "id": 21
    #     }
    # ]
    # embedding = server_wrapper.generate_embedding(content, None)
    # print("Generate embedding: ", embedding)

    # Tokenization.
    tokenized = client.tokenize("Tokenize this text.")
    print("Generate Tokens: ", tokenized)

    # Detokenization round trip; the token list is assumed to live under
    # the 'tokens' key of the tokenize reply.
    restored = client.detokenize(tokenized['tokens'])
    print("Generate Detokenized text:", restored)

    # Example usage for metrics
    # metrics = server_wrapper.get_metrics()
    # print(metrics)
|
||
|