16 changes: 16 additions & 0 deletions WORKSPACE
@@ -640,3 +640,19 @@ cc_library(
)
""",
)

new_git_repository(
name = "dr_libs",
remote = "https://github.com/mackron/dr_libs",
commit = "24d738be2349fd4b6fe50eeaa81f5bd586267fd0",
build_file_content = """
cc_library(
name = "dr",
hdrs = ["dr_flac.h", "dr_mp3.h", "dr_wav.h"],
visibility = ["//visibility:public"],
local_defines = [
],
)
""",
)

28 changes: 28 additions & 0 deletions demos/audio/README.md
@@ -0,0 +1,28 @@
# Audio endpoints

This demo shows how to export text-to-speech and speech-to-text models and serve them through the OpenAI-compatible `/v3/audio` endpoints of OpenVINO Model Server.

## Audio synthesis

Export the SpeechT5 model together with its vocoder:

```console
python export_model.py speech --source_model microsoft/speecht5_tts --vocoder microsoft/speecht5_hifigan --weight-format fp16
```

Start the model server with the exported model:

```console
docker run -p 8000:8000 -d -v $(pwd)/models/:/models openvino/model_server --model_name speecht5_tts --model_path /models/microsoft/speecht5_tts --rest_port 8000
```

Request speech synthesis and save the result to a WAV file (the server is published on host port 8000):

```console
curl http://localhost:8000/v3/audio/speech -H "Content-Type: application/json" -d "{\"model\": \"speecht5_tts\", \"input\": \"The quick brown fox jumped over the lazy dog.\"}" -o audio.wav
```
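The endpoint can also be called from the OpenAI Python client, as in `demos/audio/openai_speech.py`. A minimal sketch, assuming the server above is listening on port 8000; the `voice` value is only a placeholder required by the client and may not be honored by the servable:

```python
from pathlib import Path
from openai import OpenAI

# Point the OpenAI client at the model server REST endpoint; the API key is not used.
client = OpenAI(base_url="http://localhost:8000/v3", api_key="not_used")

speech_file_path = Path("audio.wav")
# Stream the synthesized audio straight to a file.
with client.audio.speech.with_streaming_response.create(
    model="speecht5_tts",
    voice="alloy",  # placeholder; required by the client API
    input="The quick brown fox jumped over the lazy dog.",
) as response:
    response.stream_to_file(speech_file_path)
```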





## Audio transcription

Export the Whisper model for GPU execution:

```console
python export_model.py transcription --source_model openai/whisper-large-v2 --weight-format fp16 --target_device GPU
```

Start the model server with access to the GPU device:

```console
docker run -p 8000:8000 -it --device /dev/dri -u 0 -v $(pwd)/models/:/models openvino/model_server --model_name whisper --model_path /models/openai/whisper-large-v2 --rest_port 8000
```

Send an audio file to the transcription endpoint:

```console
curl http://localhost:8000/v3/audio/transcriptions -H "Content-Type: multipart/form-data" -F file="@audio.wav" -F model="whisper"
```
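Transcription can likewise be requested with the OpenAI Python client (see `demos/audio/openai_speech.py`); a minimal sketch, assuming the whisper deployment above and an `audio.wav` file produced by the synthesis example:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v3", api_key="not_used")

# Upload the local audio file to the transcription endpoint.
with open("audio.wav", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="whisper",
        file=audio_file,
    )

print(transcript)
```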




42 changes: 42 additions & 0 deletions demos/audio/openai_speech.py
@@ -0,0 +1,42 @@
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pathlib import Path
from openai import OpenAI

prompt = "Intel Corporation is an American multinational technology company headquartered in Santa Clara, California.[3] Intel designs, manufactures, and sells computer components such as central processing units (CPUs) and related products for business and consumer markets. It was the world's third-largest semiconductor chip manufacturer by revenue in 2024[4] and has been included in the Fortune 500 list of the largest United States corporations by revenue since 2007. It was one of the first companies listed on Nasdaq. Since 2025, it is partially owned by the United States government."
filename = "speech.wav"
url="http://localhost:80/v3"


speech_file_path = Path(__file__).parent / "speech.wav"
client = OpenAI(base_url=url, api_key="not_used")

# with client.audio.speech.with_streaming_response.create(
# model="whisper",
# voice="alloy",
# input=prompt
# ) as response:
# response.stream_to_file(speech_file_path)

audio_file = open("speech.wav", "rb")
transcript = client.audio.transcriptions.create(
model="whisper",
file=audio_file
)


print(transcript)
2 changes: 1 addition & 1 deletion demos/common/export_models/README.md
@@ -97,7 +97,7 @@ options:

#### Text Generation CPU Deployment
```console
python export_model.py text_generation --source_model meta-llama/Meta-Llama-3-8B-Instruct --weight-format fp16 --kv_cache_precision u8 --config_file_path models/config_all.json --model_repository_path models
python export_model.py text_generation --source_model meta-llama/Llama-3.2-1B-Instruct --weight-format int4 --kv_cache_precision u8 --config_file_path models/config_all.json --model_repository_path models
```

#### GPU Deployment (Low Concurrency, Limited Memory)
86 changes: 83 additions & 3 deletions demos/common/export_models/export_model.py
@@ -92,8 +92,58 @@ def add_common_arguments(parser):
parser_image_generation.add_argument('--max_num_images_per_prompt', type=int, default=0, help='Max allowed number of images client is allowed to request for a given prompt', dest='max_num_images_per_prompt')
parser_image_generation.add_argument('--default_num_inference_steps', type=int, default=0, help='Default number of inference steps when not specified by client', dest='default_num_inference_steps')
parser_image_generation.add_argument('--max_num_inference_steps', type=int, default=0, help='Max allowed number of inference steps client is allowed to request for a given prompt', dest='max_num_inference_steps')

parser_speech_generation = subparsers.add_parser('speech', help='export model for speech synthesis endpoint')
add_common_arguments(parser_speech_generation)
parser_speech_generation.add_argument('--num_streams', default=0, type=int, help='The number of parallel execution streams to use for the models in the pipeline.', dest='num_streams')
parser_speech_generation.add_argument('--vocoder', type=str, help='The vocoder model to use for speech synthesis. For example microsoft/speecht5_hifigan', dest='vocoder')

parser_transcription_generation = subparsers.add_parser('transcription', help='export model for speech transcription endpoint')
add_common_arguments(parser_transcription_generation)
parser_transcription_generation.add_argument('--num_streams', default=0, type=int, help='The number of parallel execution streams to use for the models in the pipeline.', dest='num_streams')
args = vars(parser.parse_args())

speech_graph_template = """
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node {
name: "SpeechExecutor"
input_side_packet: "SPEECH_NODE_RESOURCES:speech_servable"
calculator: "SpeechCalculator"
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.SpeechCalculatorOptions]: {
models_path: "{{model_path}}",
mode: TEXT_TO_SPEECH,
plugin_config: '{ "NUM_STREAMS": "{{num_streams|default(1, true)}}" }',
device: "{{target_device|default("CPU", true)}}"
}
}
}
"""

transcription_graph_template = """
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node {
name: "SpeechExecutor"
input_side_packet: "SPEECH_NODE_RESOURCES:speech_servable"
calculator: "SpeechCalculator"
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.SpeechCalculatorOptions]: {
models_path: "{{model_path}}",
mode: SPEECH_TO_TEXT,
plugin_config: '{ "NUM_STREAMS": "{{num_streams|default(1, true)}}" }',
device: "{{target_device|default("CPU", true)}}"
}
}
}
"""


embedding_graph_template = """input_stream: "REQUEST_PAYLOAD:input"
output_stream: "RESPONSE_PAYLOAD:output"
node {
@@ -526,7 +576,7 @@ def export_embeddings_model(model_repository_path, source_model, model_name, pre
print("Created subconfig {}".format(os.path.join(model_repository_path, model_name, 'subconfig.json')))
add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))

def export_embeddings_model_ov(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, truncate=True):
def export_embeddings_model_ov(model_repository_path, source_model, model_name, precision, task_parameters):
set_max_context_length = ""
destination_path = os.path.join(model_repository_path, model_name)
print("Exporting embeddings model to ",destination_path)
@@ -543,7 +593,32 @@ def export_embeddings_model_ov(model_repository_path, source_model, model_name,
with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
f.write(graph_content)
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))
add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))

def export_speech_model(model_repository_path, source_model, model_name, precision, task_parameters):
destination_path = os.path.join(model_repository_path, model_name)
print("Exporting speech model to ",destination_path)
if not os.path.isdir(destination_path) or args['overwrite_models']:
optimum_command = "optimum-cli export openvino --model {} --weight-format {} --trust-remote-code --model-kwargs \"{{\\\"vocoder\\\": \\\"{}\\\"}}\" {}".format(source_model, precision, task_parameters['vocoder'], destination_path)
if os.system(optimum_command):
raise ValueError("Failed to export speech model", source_model)
gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(speech_graph_template)
graph_content = gtemplate.render(model_path="./", **task_parameters)
with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
f.write(graph_content)
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))

def export_transcription_model(model_repository_path, source_model, model_name, precision, task_parameters):
destination_path = os.path.join(model_repository_path, model_name)
print("Exporting transcription model to ",destination_path)
if not os.path.isdir(destination_path) or args['overwrite_models']:
optimum_command = "optimum-cli export openvino --model {} --weight-format {} --trust-remote-code {}".format(source_model, precision, destination_path)
if os.system(optimum_command):
raise ValueError("Failed to export transcription model", source_model)
gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(transcription_graph_template)
graph_content = gtemplate.render(model_path="./", **task_parameters)
with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
f.write(graph_content)
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))

def export_rerank_model_ov(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, max_doc_length):
destination_path = os.path.join(model_repository_path, model_name)
@@ -674,14 +749,19 @@ def export_image_generation_model(model_repository_path, source_model, model_nam
export_embeddings_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, str(args['version']), args['config_file_path'], args['truncate'])

elif args['task'] == 'embeddings_ov':
export_embeddings_model_ov(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['truncate'])
export_embeddings_model_ov(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters)

elif args['task'] == 'rerank':
export_rerank_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, str(args['version']), args['config_file_path'], args['max_doc_length'])

elif args['task'] == 'rerank_ov':
export_rerank_model_ov(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, args['config_file_path'], args['max_doc_length'])

elif args['task'] == 'speech':
export_speech_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters)

elif args['task'] == 'transcription':
export_transcription_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters)
elif args['task'] == 'image_generation':
template_parameters = {k: v for k, v in args.items() if k in [
'ov_cache_dir',
1 change: 1 addition & 0 deletions src/BUILD
@@ -562,6 +562,7 @@ ovms_cc_library(
"//conditions:default": [],
"//:not_disable_mediapipe" : [
"//src/image_gen:image_gen_calculator",
"//src/speech:speech_calculator",
"//src/image_gen:imagegen_init",
"//src/llm:openai_completions_api_handler",
"//src/embeddings:embeddingscalculator",
2 changes: 2 additions & 0 deletions src/http_rest_api_handler.cpp
@@ -506,6 +506,8 @@ static Status createV3HttpPayload(
} else {
SPDLOG_DEBUG("Model name from deduced from MultiPart field: {}", modelName);
}
auto stream = multiPartParser->getFieldByName("stream");
SPDLOG_DEBUG("Stream field from MultiPart: {}", stream);
ensureJsonParserInErrorState(parsedJson);
} else if (isApplicationJson) {
{
23 changes: 23 additions & 0 deletions src/mediapipe_internal/mediapipegraphdefinition.cpp
@@ -61,6 +61,7 @@ const std::string MediapipeGraphDefinition::SCHEDULER_CLASS_NAME{"Mediapipe"};
const std::string MediapipeGraphDefinition::PYTHON_NODE_CALCULATOR_NAME{"PythonExecutorCalculator"};
const std::string MediapipeGraphDefinition::LLM_NODE_CALCULATOR_NAME{"LLMCalculator"};
const std::string MediapipeGraphDefinition::IMAGE_GEN_CALCULATOR_NAME{"ImageGenCalculator"};
const std::string MediapipeGraphDefinition::SPEECH_NODE_CALCULATOR_NAME{"SpeechCalculator"};
const std::string MediapipeGraphDefinition::EMBEDDINGS_NODE_CALCULATOR_NAME{"EmbeddingsCalculatorOV"};
const std::string MediapipeGraphDefinition::RERANK_NODE_CALCULATOR_NAME{"RerankCalculatorOV"};

@@ -554,6 +555,28 @@ Status MediapipeGraphDefinition::initializeNodes() {
rerankServableMap.insert(std::pair<std::string, std::shared_ptr<RerankServable>>(nodeName, std::move(servable)));
rerankServablesCleaningGuard.disableCleaning();
}
if (endsWith(config.node(i).calculator(), SPEECH_NODE_CALCULATOR_NAME)) {
auto& speechServableMap = this->sidePacketMaps.speechServableMap;
ResourcesCleaningGuard<SpeechServableMap> speechServablesCleaningGuard(speechServableMap);
if (!config.node(i).node_options().size()) {
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Speech node missing options in graph: {}. ", this->name);
return StatusCode::LLM_NODE_MISSING_OPTIONS;
}
if (config.node(i).name().empty()) {
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Speech node name is missing in graph: {}. ", this->name);
return StatusCode::LLM_NODE_MISSING_NAME;
}
std::string nodeName = config.node(i).name();
if (speechServableMap.find(nodeName) != speechServableMap.end()) {
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Speech node name: {} already used in graph: {}. ", nodeName, this->name);
return StatusCode::LLM_NODE_NAME_ALREADY_EXISTS;
}
mediapipe::SpeechCalculatorOptions nodeOptions;
config.node(i).node_options(0).UnpackTo(&nodeOptions);
std::shared_ptr<SpeechServable> servable = std::make_shared<SpeechServable>(nodeOptions.models_path(), nodeOptions.device(), mgconfig.getBasePath(), nodeOptions.mode());
speechServableMap.insert(std::pair<std::string, std::shared_ptr<SpeechServable>>(nodeName, std::move(servable)));
speechServablesCleaningGuard.disableCleaning();
}
}
return StatusCode::OK;
}
8 changes: 7 additions & 1 deletion src/mediapipe_internal/mediapipegraphdefinition.hpp
@@ -46,6 +46,7 @@
#include "../sidepacket_servable.hpp"
#include "../embeddings/embeddings_servable.hpp"
#include "../rerank/rerank_servable.hpp"
#include "../speech/speech_servable.hpp"

namespace ovms {
class MediapipeGraphDefinitionUnloadGuard;
@@ -62,6 +63,7 @@ struct ImageGenerationPipelines;
using PythonNodeResourcesMap = std::unordered_map<std::string, std::shared_ptr<PythonNodeResources>>;
using GenAiServableMap = std::unordered_map<std::string, std::shared_ptr<GenAiServable>>;
using RerankServableMap = std::unordered_map<std::string, std::shared_ptr<RerankServable>>;
using SpeechServableMap = std::unordered_map<std::string, std::shared_ptr<SpeechServable>>;
using EmbeddingsServableMap = std::unordered_map<std::string, std::shared_ptr<EmbeddingsServable>>;
using ImageGenerationPipelinesMap = std::unordered_map<std::string, std::shared_ptr<ImageGenerationPipelines>>;

@@ -71,19 +73,22 @@ struct GraphSidePackets {
ImageGenerationPipelinesMap imageGenPipelinesMap;
EmbeddingsServableMap embeddingsServableMap;
RerankServableMap rerankServableMap;
SpeechServableMap speechServableMap;
void clear() {
pythonNodeResourcesMap.clear();
genAiServableMap.clear();
imageGenPipelinesMap.clear();
embeddingsServableMap.clear();
rerankServableMap.clear();
speechServableMap.clear();
}
bool empty() {
return (pythonNodeResourcesMap.empty() &&
genAiServableMap.empty() &&
imageGenPipelinesMap.empty() &&
embeddingsServableMap.empty() &&
rerankServableMap.empty());
rerankServableMap.empty() &&
speechServableMap.empty());
}
};

@@ -124,6 +129,7 @@ class MediapipeGraphDefinition {
static const std::string IMAGE_GEN_CALCULATOR_NAME;
static const std::string EMBEDDINGS_NODE_CALCULATOR_NAME;
static const std::string RERANK_NODE_CALCULATOR_NAME;
static const std::string SPEECH_NODE_CALCULATOR_NAME;
Status waitForLoaded(std::unique_ptr<MediapipeGraphDefinitionUnloadGuard>& unloadGuard, const uint32_t waitForLoadedTimeoutMicroseconds = WAIT_FOR_LOADED_DEFAULT_TIMEOUT_MICROSECONDS);

// Pipelines are not versioned and any available definition has constant version equal 1.
4 changes: 3 additions & 1 deletion src/mediapipe_internal/mediapipegraphexecutor.cpp
@@ -47,6 +47,7 @@ MediapipeGraphExecutor::MediapipeGraphExecutor(
const GenAiServableMap& llmNodeResourcesMap,
const EmbeddingsServableMap& embeddingsServableMap,
const RerankServableMap& rerankServableMap,
const SpeechServableMap& speechServableMap,
PythonBackend* pythonBackend,
MediapipeServableMetricReporter* mediapipeServableMetricReporter) :
name(name),
@@ -56,7 +57,7 @@
outputTypes(std::move(outputTypes)),
inputNames(std::move(inputNames)),
outputNames(std::move(outputNames)),
sidePacketMaps({pythonNodeResourcesMap, llmNodeResourcesMap, {}, embeddingsServableMap, rerankServableMap}),
sidePacketMaps({pythonNodeResourcesMap, llmNodeResourcesMap, {}, embeddingsServableMap, rerankServableMap, speechServableMap}),
pythonBackend(pythonBackend),
currentStreamTimestamp(STARTING_TIMESTAMP),
mediapipeServableMetricReporter(mediapipeServableMetricReporter) {}
@@ -88,6 +89,7 @@ const std::string MediapipeGraphExecutor::LLM_SESSION_SIDE_PACKET_TAG = "llm";
const std::string MediapipeGraphExecutor::IMAGE_GEN_SESSION_SIDE_PACKET_TAG = "pipes";
const std::string MediapipeGraphExecutor::EMBEDDINGS_SESSION_SIDE_PACKET_TAG = "embeddings_servable";
const std::string MediapipeGraphExecutor::RERANK_SESSION_SIDE_PACKET_TAG = "rerank_servable";
const std::string MediapipeGraphExecutor::SPEECH_SESSION_SIDE_PACKET_TAG = "speech_servable";
const ::mediapipe::Timestamp MediapipeGraphExecutor::STARTING_TIMESTAMP = ::mediapipe::Timestamp(0);

} // namespace ovms
Loading