-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
476 additions
and
365 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,21 @@ | ||
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.12 | ||
FROM python:3.12-slim | ||
|
||
# 设置工作目录 | ||
WORKDIR /app | ||
RUN apt-get update && apt-get install -y git ffmpeg && \ | ||
apt-get clean && rm -rf /var/lib/apt/lists/* | ||
RUN useradd -m -u 1000 user | ||
USER user | ||
ENV HOME=/home/user \ | ||
PATH=/home/user/.local/bin:$PATH | ||
|
||
# 将当前目录内容复制到容器的 /app 中 | ||
COPY . /app | ||
# 设置工作目录 | ||
WORKDIR $HOME/app | ||
|
||
# 安装项目依赖 | ||
COPY --chown=user requirements.txt . | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
|
||
# 安装 FFmpeg | ||
RUN apt-get update && apt-get install -y ffmpeg | ||
|
||
# 设置环境变量 | ||
ENV FIREWORKS_API_KEY=${FIREWORKS_API_KEY} | ||
COPY --chown=user . . | ||
|
||
# 暴露端口 8000 供应用使用 | ||
EXPOSE 8000 | ||
EXPOSE 7860 | ||
|
||
# 运行应用 | ||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] | ||
|
||
|
||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
--- | ||
title: PodcastLM Backend | ||
emoji: 👀 | ||
colorFrom: gray | ||
colorTo: green | ||
sdk: docker | ||
pinned: false | ||
license: mit | ||
--- | ||
|
||
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,104 @@ | ||
import os | ||
import random | ||
|
||
from fastapi import APIRouter, Query | ||
from constants import GRADIO_CACHE_DIR, MELO_TTS_LANGUAGE_MAPPING, SUNO_LANGUAGE_MAPPING | ||
from utils import generate_podcast_audio, generate_script | ||
from prompts import LANGUAGE_MODIFIER, LENGTH_MODIFIERS, QUESTION_MODIFIER, SYSTEM_PROMPT, TONE_MODIFIER | ||
from schema import ShortDialogue | ||
from loguru import logger | ||
from pydub import AudioSegment | ||
|
||
from tempfile import NamedTemporaryFile | ||
import uuid | ||
from fastapi import APIRouter, BackgroundTasks, Form, HTTPException, UploadFile, File | ||
from fastapi.responses import StreamingResponse, JSONResponse | ||
import json | ||
from typing import Dict, Optional | ||
from utils import combine_audio, generate_dialogue, generate_podcast_info, generate_podcast_summary, get_pdf_text | ||
|
||
router = APIRouter() | ||
|
||
@router.get("/") | ||
def generate(input: str = Query(..., description="Input string")): | ||
random_voice_number = random.randint(1, 9) | ||
|
||
modified_system_prompt = SYSTEM_PROMPT | ||
question = "introduce chatgpt" | ||
tone = "funny" | ||
language = "English" | ||
length = "Short (1-2 min)" | ||
|
||
if question: | ||
modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {question}" | ||
if tone: | ||
modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}." | ||
if length: | ||
modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}" | ||
if language: | ||
modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}." | ||
@router.post("/generate_transcript") | ||
async def generate_transcript( | ||
pdfFile: Optional[UploadFile] = File(None), | ||
textInput: str = Form(...), | ||
tone: str = Form(...), | ||
duration: str = Form(...), | ||
language: str = Form(...), | ||
|
||
llm_output = generate_script(modified_system_prompt, "introduce chatgpt", ShortDialogue) | ||
|
||
logger.info(f"Generated dialogue: {llm_output}") | ||
): | ||
pdfContent = await get_pdf_text(pdfFile) | ||
new_text = pdfContent | ||
return StreamingResponse(generate_dialogue(new_text,textInput, tone, duration, language), media_type="application/json") | ||
|
||
|
||
@router.get("/test")
def test():
    """Return a fixed greeting payload; the endpoint takes no input."""
    greeting = {"message": "Hello World"}
    return greeting
|
||
@router.post("/summarize")
async def get_summary(
    textInput: str = Form(...),
    tone: str = Form(...),
    duration: str = Form(...),
    language: str = Form(...),
    pdfFile: Optional[UploadFile] = File(None),
):
    """Stream the output of ``generate_podcast_summary`` for one request.

    Args:
        textInput: Required form field, forwarded unchanged.
        tone: Required form field, forwarded unchanged.
        duration: Required form field, forwarded unchanged.
        language: Required form field, forwarded unchanged.
        pdfFile: Optional uploaded file; handed to ``get_pdf_text``.

    Returns:
        A ``StreamingResponse`` with media type ``application/json`` that
        wraps whatever ``generate_podcast_summary`` produces.
    """
    # The result of get_pdf_text is forwarded as-is as the first argument.
    source_text = await get_pdf_text(pdfFile)
    summary_stream = generate_podcast_summary(
        source_text, textInput, tone, duration, language
    )
    return StreamingResponse(summary_stream, media_type="application/json")
|
||
audio_segments = [] | ||
transcript = "" | ||
total_characters = 0 | ||
@router.post("/pod_info")
async def get_pod_info(
    textInput: str = Form(...),
    tone: str = Form(...),
    duration: str = Form(...),
    language: str = Form(...),
    pdfFile: Optional[UploadFile] = File(None),
):
    """Stream the output of ``generate_podcast_info`` for one request.

    Args:
        textInput: Required form field, forwarded unchanged.
        tone: Required form field, forwarded unchanged.
        duration: Required form field, forwarded unchanged.
        language: Required form field, forwarded unchanged.
        pdfFile: Optional uploaded file; handed to ``get_pdf_text``.

    Returns:
        A ``StreamingResponse`` with media type ``application/json`` that
        wraps whatever ``generate_podcast_info`` produces.
    """
    extracted = await get_pdf_text(pdfFile)
    # Only the leading 100 elements of the extracted content are forwarded
    # (presumably characters of a string -- confirm against get_pdf_text).
    excerpt = extracted[:100]

    info_stream = generate_podcast_info(excerpt, textInput, tone, duration, language)
    return StreamingResponse(info_stream, media_type="application/json")
|
||
for line in llm_output.dialogue: | ||
print(f"Generating audio for {line.speaker}: {line.text}") | ||
logger.info(f"Generating audio for {line.speaker}: {line.text}") | ||
if line.speaker == "Host (Jane)": | ||
speaker = f"**Host**: {line.text}" | ||
else: | ||
speaker = f"**{llm_output.name_of_guest}**: {line.text}" | ||
transcript += speaker + "\n\n" | ||
total_characters += len(line.text) | ||
|
||
language_for_tts = SUNO_LANGUAGE_MAPPING[language] | ||
task_status: Dict[str, Dict] = {} | ||
|
||
# Get audio file path | ||
audio_file_path = generate_podcast_audio( | ||
line.text, line.speaker, language_for_tts, random_voice_number | ||
) | ||
# Read the audio file into an AudioSegment | ||
audio_segment = AudioSegment.from_file(audio_file_path) | ||
audio_segments.append(audio_segment) | ||
|
||
# Concatenate all audio segments | ||
combined_audio = sum(audio_segments) | ||
@router.post("/generate_audio") | ||
async def audio( | ||
background_tasks: BackgroundTasks, | ||
text: str = Form(...), | ||
language: str = Form(...) | ||
): | ||
task_id = str(uuid.uuid4()) | ||
task_status[task_id] = {"status": "processing"} | ||
|
||
background_tasks.add_task(combine_audio, task_status, task_id, text, language) | ||
|
||
# Export the combined audio to a temporary file | ||
temporary_directory = GRADIO_CACHE_DIR | ||
os.makedirs(temporary_directory, exist_ok=True) | ||
return JSONResponse(content={"task_id": task_id, "status": "processing"}) | ||
|
||
temporary_file = NamedTemporaryFile( | ||
dir=temporary_directory, | ||
delete=False, | ||
suffix=".mp3", | ||
) | ||
combined_audio.export(temporary_file.name, format="mp3") | ||
logger.info(f"Generated {total_characters} characters of audio") | ||
|
||
# Delete any files in the temp directory that end with .mp3 and are over a day old | ||
# for file in glob.glob(f"{temporary_directory}*.mp3"): | ||
# if ( | ||
# os.path.isfile(file) | ||
# and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN | ||
# ): | ||
# os.remove(file) | ||
@router.get("/audio_status/{task_id}")
async def get_audio_status(task_id: str):
    """Report the recorded state of an audio task.

    Args:
        task_id: Key to look up in the module-level ``task_status`` mapping.

    Returns:
        A ``JSONResponse`` whose body depends on the entry's ``"status"``:
        ``"completed"`` adds ``"audio_url"``, ``"failed"`` adds ``"error"``,
        and any other value is reported as ``"processing"``.

    Raises:
        HTTPException: 404 when ``task_id`` has no entry in ``task_status``.
    """
    if task_id not in task_status:
        raise HTTPException(status_code=404, detail="Task not found")

    record = task_status[task_id]
    state = record["status"]

    # Build the payload once, then respond from a single place.
    if state == "completed":
        payload = {"status": "completed", "audio_url": record["audio_url"]}
    elif state == "failed":
        payload = {"status": "failed", "error": record["error"]}
    else:
        payload = {"status": "processing"}

    return JSONResponse(content=payload)
|
||
|
||
|
||
print(temporary_file.name) | ||
print(transcript) | ||
return {"message": f"Hello World, input: {temporary_file}"} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.