Skip to content

Commit

Permalink
FEAT: Support F5 TTS (xorbitsai#2626)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Dec 9, 2024
1 parent 45cedde commit 53cddf3
Show file tree
Hide file tree
Showing 60 changed files with 11,902 additions and 1 deletion.
10 changes: 10 additions & 0 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip install -U silero-vad
${{ env.SELF_HOST_PYTHON }} -m pip install -U pydantic
${{ env.SELF_HOST_PYTHON }} -m pip install -U diffusers
${{ env.SELF_HOST_PYTHON }} -m pip install -U torchdiffeq
${{ env.SELF_HOST_PYTHON }} -m pip install -U "x_transformers>=1.31.14"
${{ env.SELF_HOST_PYTHON }} -m pip install -U pypinyin
${{ env.SELF_HOST_PYTHON }} -m pip install -U tomli
${{ env.SELF_HOST_PYTHON }} -m pip install -U vocos
${{ env.SELF_HOST_PYTHON }} -m pip install -U jieba
${{ env.SELF_HOST_PYTHON }} -m pip install -U soundfile
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
--disable-warnings \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/core/tests/test_continuous_batching.py && \
Expand All @@ -207,6 +214,9 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_cosyvoice.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_f5tts.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_fish_speech.py
Expand Down
16 changes: 16 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ all =
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
torchdiffeq # For F5-TTS
x_transformers>=1.31.14 # For F5-TTS
pypinyin # For F5-TTS
tomli # For F5-TTS
vocos # For F5-TTS
librosa # For F5-TTS
jieba # For F5-TTS
soundfile # For F5-TTS
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down Expand Up @@ -216,6 +224,14 @@ audio =
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
torchdiffeq # For F5-TTS
x_transformers>=1.31.14 # For F5-TTS
pypinyin # For F5-TTS
tomli # For F5-TTS
vocos # For F5-TTS
librosa # For F5-TTS
jieba # For F5-TTS
soundfile # For F5-TTS
doc =
ipython>=6.5.0
sphinx>=3.0.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

import os
import platform
import sys
import shutil
import subprocess
import sys
import warnings
from sysconfig import get_config_vars

Expand Down
8 changes: 8 additions & 0 deletions xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ loralib # For Fish Speech
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
torchdiffeq # For F5-TTS
x_transformers>=1.31.14 # For F5-TTS
pypinyin # For F5-TTS
tomli # For F5-TTS
vocos # For F5-TTS
librosa # For F5-TTS
jieba # For F5-TTS
soundfile # For F5-TTS
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down
8 changes: 8 additions & 0 deletions xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ loralib # For Fish Speech
ormsgpack # For Fish Speech
cachetools # For Fish Speech
silero-vad # For Fish Speech
torchdiffeq # For F5-TTS
x_transformers>=1.31.14 # For F5-TTS
pypinyin # For F5-TTS
tomli # For F5-TTS
vocos # For F5-TTS
librosa # For F5-TTS
jieba # For F5-TTS
soundfile # For F5-TTS
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down
5 changes: 5 additions & 0 deletions xinference/model/audio/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..utils import valid_model_revision
from .chattts import ChatTTSModel
from .cosyvoice import CosyVoiceModel
from .f5tts import F5TTSModel
from .fish_speech import FishSpeechModel
from .funasr import FunASRModel
from .whisper import WhisperModel
Expand Down Expand Up @@ -169,6 +170,7 @@ def create_audio_model_instance(
ChatTTSModel,
CosyVoiceModel,
FishSpeechModel,
F5TTSModel,
],
AudioModelDescription,
]:
Expand All @@ -182,6 +184,7 @@ def create_audio_model_instance(
ChatTTSModel,
CosyVoiceModel,
FishSpeechModel,
F5TTSModel,
]
if model_spec.model_family == "whisper":
if not model_spec.engine:
Expand All @@ -196,6 +199,8 @@ def create_audio_model_instance(
model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
elif model_spec.model_family == "FishAudio":
model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
elif model_spec.model_family == "F5-TTS":
model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
else:
raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
model_description = AudioModelDescription(
Expand Down
195 changes: 195 additions & 0 deletions xinference/model/audio/f5tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import re
from io import BytesIO
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from .core import AudioModelFamilyV1

logger = logging.getLogger(__name__)


class F5TTSModel:
def __init__(
self,
model_uid: str,
model_path: str,
model_spec: "AudioModelFamilyV1",
device: Optional[str] = None,
**kwargs,
):
self._model_uid = model_uid
self._model_path = model_path
self._model_spec = model_spec
self._device = device
self._model = None
self._vocoder = None
self._kwargs = kwargs

@property
def model_ability(self):
return self._model_spec.model_ability

def load(self):
import os
import sys

# The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT

vocoder_name = self._kwargs.get("vocoder_name", "vocos")
vocoder_path = self._kwargs.get("vocoder_path")

if vocoder_name not in ["vocos", "bigvgan"]:
raise Exception(f"Unsupported vocoder name: {vocoder_name}")

if vocoder_path is not None:
self._vocoder = load_vocoder(
vocoder_name=vocoder_name, is_local=True, local_path=vocoder_path
)
else:
self._vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=False)

model_cls = DiT
model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
if vocoder_name == "vocos":
exp_name = "F5TTS_Base"
ckpt_step = 1200000
elif vocoder_name == "bigvgan":
exp_name = "F5TTS_Base_bigvgan"
ckpt_step = 1250000
else:
assert False
ckpt_file = os.path.join(
self._model_path, exp_name, f"model_{ckpt_step}.safetensors"
)
logger.info(f"Loading %s...", ckpt_file)
self._model = load_model(
model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name
)

def _infer(self, ref_audio, ref_text, text_gen, model_obj, mel_spec_type, speed):
import numpy as np
from f5_tts.infer.utils_infer import infer_process, preprocess_ref_audio_text

config = {}
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
else:
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
(
voices[voice]["ref_audio"],
voices[voice]["ref_text"],
) = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("Voice:", voice)
print("Ref_audio:", voices[voice]["ref_audio"])
print("Ref_text:", voices[voice]["ref_text"])

final_sample_rate = None
generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, text_gen)
reg2 = r"\[(\w+)\]"
for text in chunks:
if not text.strip():
continue
match = re.match(reg2, text)
if match:
voice = match[1]
else:
print("No voice tag found, using main.")
voice = "main"
if voice not in voices:
print(f"Voice {voice} not found, using main.")
voice = "main"
text = re.sub(reg2, "", text)
gen_text = text.strip()
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(
ref_audio,
ref_text,
gen_text,
model_obj,
self._vocoder,
mel_spec_type=mel_spec_type,
speed=speed,
)
generated_audio_segments.append(audio)

if generated_audio_segments:
final_wave = np.concatenate(generated_audio_segments)
return final_sample_rate, final_wave
return None, None

def speech(
self,
input: str,
voice: str,
response_format: str = "mp3",
speed: float = 1.0,
stream: bool = False,
**kwargs,
):
import f5_tts
import soundfile
import tomli

if stream:
raise Exception("F5-TTS does not support stream generation.")

prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
prompt_text: Optional[str] = kwargs.pop("prompt_text", None)

if prompt_speech is None:
base = os.path.dirname(f5_tts.__file__)
config = os.path.join(base, "infer/examples/basic/basic.toml")
with open(config, "rb") as f:
config_dict = tomli.load(f)
prompt_speech = os.path.join(base, config_dict["ref_audio"])
prompt_text = config_dict["ref_text"]

assert self._model is not None
vocoder_name = self._kwargs.get("vocoder_name", "vocos")
sample_rate, wav = self._infer(
ref_audio=prompt_speech,
ref_text=prompt_text,
text_gen=input,
model_obj=self._model,
mel_spec_type=vocoder_name,
speed=speed,
)

# Save the generated audio
with BytesIO() as out:
with soundfile.SoundFile(
out, "w", sample_rate, 1, format=response_format.upper()
) as f:
f.write(wav)
return out.getvalue()
8 changes: 8 additions & 0 deletions xinference/model/audio/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,13 @@
"model_revision": "069c573759936b35191d3380deb89183c0656f59",
"model_ability": "text-to-audio",
"multilingual": true
},
{
"model_name": "F5-TTS",
"model_family": "F5-TTS",
"model_id": "SWivid/F5-TTS",
"model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
"model_ability": "text-to-audio",
"multilingual": true
}
]
9 changes: 9 additions & 0 deletions xinference/model/audio/model_spec_modelscope.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,14 @@
"model_revision": "master",
"model_ability": "text-to-audio",
"multilingual": true
},
{
"model_name": "F5-TTS",
"model_family": "F5-TTS",
"model_hub": "modelscope",
"model_id": "SWivid/F5-TTS_Emilia-ZH-EN",
"model_revision": "master",
"model_ability": "text-to-audio",
"multilingual": true
}
]
Loading

0 comments on commit 53cddf3

Please sign in to comment.