diff --git a/backend/Dockerfile b/backend/Dockerfile
index 5b04719..b0871d6 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -1,24 +1,21 @@
-FROM tiangolo/uvicorn-gunicorn-fastapi:python3.12
+FROM python:3.12-slim
-# 设置工作目录
-WORKDIR /app
+RUN apt-get update && apt-get install -y --no-install-recommends git ffmpeg && \
+ apt-get clean && rm -rf /var/lib/apt/lists/*
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
-# 将当前目录内容复制到容器的 /app 中
-COPY . /app
+# 设置工作目录
+WORKDIR $HOME/app
-# 安装项目依赖
+COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
-# 安装 FFmpeg
-RUN apt-get update && apt-get install -y ffmpeg
-
-# 设置环境变量
-ENV FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
+COPY --chown=user . .
-# 暴露端口 8000 供应用使用
-EXPOSE 8000
+EXPOSE 7860
# 运行应用
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
-
-
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
diff --git a/backend/README.md b/backend/README.md
new file mode 100644
index 0000000..1e9845d
--- /dev/null
+++ b/backend/README.md
@@ -0,0 +1,11 @@
+---
+title: PodcastLM Backend
+emoji: 👀
+colorFrom: gray
+colorTo: green
+sdk: docker
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/backend/api/routes/chat.py b/backend/api/routes/chat.py
index d1cec64..a5209ad 100644
--- a/backend/api/routes/chat.py
+++ b/backend/api/routes/chat.py
@@ -1,89 +1,104 @@
-import os
-import random
-
-from fastapi import APIRouter, Query
-from constants import GRADIO_CACHE_DIR, MELO_TTS_LANGUAGE_MAPPING, SUNO_LANGUAGE_MAPPING
-from utils import generate_podcast_audio, generate_script
-from prompts import LANGUAGE_MODIFIER, LENGTH_MODIFIERS, QUESTION_MODIFIER, SYSTEM_PROMPT, TONE_MODIFIER
-from schema import ShortDialogue
-from loguru import logger
-from pydub import AudioSegment
-
-from tempfile import NamedTemporaryFile
+import uuid
+from fastapi import APIRouter, BackgroundTasks, Form, HTTPException, UploadFile, File
+from fastapi.responses import StreamingResponse, JSONResponse
+import json
+from typing import Dict, Optional
+from utils import combine_audio, generate_dialogue, generate_podcast_info, generate_podcast_summary, get_pdf_text
router = APIRouter()
-@router.get("/")
-def generate(input: str = Query(..., description="Input string")):
- random_voice_number = random.randint(1, 9)
-
- modified_system_prompt = SYSTEM_PROMPT
- question = "introduce chatgpt"
- tone = "funny"
- language = "English"
- length = "Short (1-2 min)"
-
- if question:
- modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {question}"
- if tone:
- modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}."
- if length:
- modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}"
- if language:
- modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}."
+@router.post("/generate_transcript")
+async def generate_transcript(
+ pdfFile: Optional[UploadFile] = File(None),
+ textInput: str = Form(...),
+ tone: str = Form(...),
+ duration: str = Form(...),
+ language: str = Form(...),
- llm_output = generate_script(modified_system_prompt, "introduce chatgpt", ShortDialogue)
-
- logger.info(f"Generated dialogue: {llm_output}")
+):
+ pdfContent = await get_pdf_text(pdfFile)
+ new_text = pdfContent
+ return StreamingResponse(generate_dialogue(new_text,textInput, tone, duration, language), media_type="application/json")
+
+
+@router.get("/test")
+def test():
+ return {"message": "Hello World"}
+
+@router.post("/summarize")
+async def get_summary(
+ textInput: str = Form(...),
+ tone: str = Form(...),
+ duration: str = Form(...),
+ language: str = Form(...),
+ pdfFile: Optional[UploadFile] = File(None)
+):
+ pdfContent = await get_pdf_text(pdfFile)
+ new_text = pdfContent
+ return StreamingResponse(
+ generate_podcast_summary(
+ new_text,
+ textInput,
+ tone,
+ duration,
+ language,
+ ),
+ media_type="application/json"
+ )
- audio_segments = []
- transcript = ""
- total_characters = 0
+@router.post("/pod_info")
+async def get_pod_info(
+ textInput: str = Form(...),
+ tone: str = Form(...),
+ duration: str = Form(...),
+ language: str = Form(...),
+ pdfFile: Optional[UploadFile] = File(None)
+):
+ pdfContent = await get_pdf_text(pdfFile)
+ new_text = pdfContent[:100]
+
+ return StreamingResponse(generate_podcast_info(new_text, textInput, tone, duration, language), media_type="application/json")
- for line in llm_output.dialogue:
- print(f"Generating audio for {line.speaker}: {line.text}")
- logger.info(f"Generating audio for {line.speaker}: {line.text}")
- if line.speaker == "Host (Jane)":
- speaker = f"**Host**: {line.text}"
- else:
- speaker = f"**{llm_output.name_of_guest}**: {line.text}"
- transcript += speaker + "\n\n"
- total_characters += len(line.text)
- language_for_tts = SUNO_LANGUAGE_MAPPING[language]
+task_status: Dict[str, Dict] = {}
- # Get audio file path
- audio_file_path = generate_podcast_audio(
- line.text, line.speaker, language_for_tts, random_voice_number
- )
- # Read the audio file into an AudioSegment
- audio_segment = AudioSegment.from_file(audio_file_path)
- audio_segments.append(audio_segment)
- # Concatenate all audio segments
- combined_audio = sum(audio_segments)
+@router.post("/generate_audio")
+async def audio(
+ background_tasks: BackgroundTasks,
+ text: str = Form(...),
+ language: str = Form(...)
+):
+ task_id = str(uuid.uuid4())
+ task_status[task_id] = {"status": "processing"}
+
+ background_tasks.add_task(combine_audio, task_status, task_id, text, language)
- # Export the combined audio to a temporary file
- temporary_directory = GRADIO_CACHE_DIR
- os.makedirs(temporary_directory, exist_ok=True)
+ return JSONResponse(content={"task_id": task_id, "status": "processing"})
- temporary_file = NamedTemporaryFile(
- dir=temporary_directory,
- delete=False,
- suffix=".mp3",
- )
- combined_audio.export(temporary_file.name, format="mp3")
- logger.info(f"Generated {total_characters} characters of audio")
- # Delete any files in the temp directory that end with .mp3 and are over a day old
- # for file in glob.glob(f"{temporary_directory}*.mp3"):
- # if (
- # os.path.isfile(file)
- # and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
- # ):
- # os.remove(file)
+@router.get("/audio_status/{task_id}")
+async def get_audio_status(task_id: str):
+ if task_id not in task_status:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+ status = task_status[task_id]
+
+ if status["status"] == "completed":
+ return JSONResponse(content={
+ "status": "completed",
+ "audio_url": status["audio_url"]
+ })
+ elif status["status"] == "failed":
+ return JSONResponse(content={
+ "status": "failed",
+ "error": status["error"]
+ })
+ else:
+ return JSONResponse(content={
+ "status": "processing"
+ })
+
+
- print(temporary_file.name)
- print(transcript)
- return {"message": f"Hello World, input: {temporary_file}"}
diff --git a/backend/constants.py b/backend/constants.py
index da82de7..93ae7f3 100644
--- a/backend/constants.py
+++ b/backend/constants.py
@@ -7,12 +7,12 @@
from pathlib import Path
# Key constants
-APP_TITLE = "Open NotebookLM"
CHARACTER_LIMIT = 100_000
# Gradio-related constants
-GRADIO_CACHE_DIR = "./gradio_cached_examples/tmp/"
-GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 24 * 60 * 60 # 1 day
+GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 2 * 60 * 60 # 2 hours
+
+AUDIO_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'tmp', 'cache')
# Error messages-related constants
ERROR_MESSAGE_NO_INPUT = "Please provide at least one PDF file or a URL."
@@ -21,30 +21,15 @@
ERROR_MESSAGE_READING_PDF = "Error reading the PDF file"
ERROR_MESSAGE_TOO_LONG = "The total content is too long. Please ensure the combined text from PDFs and URL is fewer than {CHARACTER_LIMIT} characters."
+SPEECH_KEY = os.getenv('SPEECH_KEY')
+SPEECH_REGION = "japaneast"
# Fireworks API-related constants
-FIREWORKS_API_KEY = os.getenv['FIREWORKS_API_KEY']
+FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
FIREWORKS_MAX_TOKENS = 16_384
FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
FIREWORKS_TEMPERATURE = 0.1
FIREWORKS_JSON_RETRY_ATTEMPTS = 3
-
-# MeloTTS
-MELO_API_NAME = "/synthesize"
-MELO_TTS_SPACES_ID = "mrfakename/MeloTTS"
-MELO_RETRY_ATTEMPTS = 3
-MELO_RETRY_DELAY = 5 # in seconds
-
-MELO_TTS_LANGUAGE_MAPPING = {
- "en": "EN",
- "es": "ES",
- "fr": "FR",
- "zh": "ZJ",
- "ja": "JP",
- "ko": "KR",
-}
-
-
# Suno related constants
SUNO_LANGUAGE_MAPPING = {
"English": "en",
@@ -62,105 +47,3 @@
"Turkish": "tr",
}
-# General audio-related constants
-NOT_SUPPORTED_IN_MELO_TTS = list(
- set(SUNO_LANGUAGE_MAPPING.values()) - set(MELO_TTS_LANGUAGE_MAPPING.keys())
-)
-NOT_SUPPORTED_IN_MELO_TTS = [
- key for key, id in SUNO_LANGUAGE_MAPPING.items() if id in NOT_SUPPORTED_IN_MELO_TTS
-]
-
-# Jina Reader-related constants
-JINA_READER_URL = "https://r.jina.ai/"
-JINA_RETRY_ATTEMPTS = 3
-JINA_RETRY_DELAY = 5 # in seconds
-
-# UI-related constants
-UI_DESCRIPTION = """
-
-
-
-
- |
-
- Convert your PDFs into podcasts with open-source AI models (Llama 3.1 405B via Fireworks AI, MeloTTS, Bark).
- Note: Only the text content of the PDFs will be processed. Images and tables are not included. The total content should be no more than 100,000 characters due to the context length of Llama 3.1 405B.
- |
-
-
-"""
-UI_AVAILABLE_LANGUAGES = list(set(SUNO_LANGUAGE_MAPPING.keys()))
-UI_INPUTS = {
- "file_upload": {
- "label": "1. 📄 Upload your PDF(s)",
- "file_types": [".pdf"],
- "file_count": "multiple",
- },
- "url": {
- "label": "2. 🔗 Paste a URL (optional)",
- "placeholder": "Enter a URL to include its content",
- },
- "question": {
- "label": "3. 🤔 Do you have a specific question or topic in mind?",
- "placeholder": "Enter a question or topic",
- },
- "tone": {
- "label": "4. 🎭 Choose the tone",
- "choices": ["Fun", "Formal"],
- "value": "Fun",
- },
- "length": {
- "label": "5. ⏱️ Choose the length",
- "choices": ["Short (1-2 min)", "Medium (3-5 min)"],
- "value": "Medium (3-5 min)",
- },
- "language": {
- "label": "6. 🌐 Choose the language",
- "choices": UI_AVAILABLE_LANGUAGES,
- "value": "English",
- },
- "advanced_audio": {
- "label": "7. 🔄 Use advanced audio generation? (Experimental)",
- "value": True,
- },
-}
-UI_OUTPUTS = {
- "audio": {"label": "🔊 Podcast", "format": "mp3"},
- "transcript": {
- "label": "📜 Transcript",
- },
-}
-UI_API_NAME = "generate_podcast"
-UI_ALLOW_FLAGGING = "never"
-UI_CONCURRENCY_LIMIT = 3
-UI_EXAMPLES = [
- [
- [str(Path("examples/1310.4546v1.pdf"))],
- "",
- "Explain this paper to me like I'm 5 years old",
- "Fun",
- "Short (1-2 min)",
- "English",
- True,
- ],
- [
- [],
- "https://en.wikipedia.org/wiki/Hugging_Face",
- "How did Hugging Face become so successful?",
- "Fun",
- "Short (1-2 min)",
- "English",
- False,
- ],
- [
- [],
- "https://simple.wikipedia.org/wiki/Taylor_Swift",
- "Why is Taylor Swift so popular?",
- "Fun",
- "Short (1-2 min)",
- "English",
- False,
- ],
-]
-UI_CACHE_EXAMPLES = True
-UI_SHOW_API = True
diff --git a/backend/main.py b/backend/main.py
index 6158279..2395825 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,16 +1,36 @@
+import asyncio
+import os
+
+from fastapi.responses import JSONResponse
+from constants import AUDIO_CACHE_DIR
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
from api.main import api_router
app = FastAPI()
+os.makedirs(AUDIO_CACHE_DIR, exist_ok=True)
+app.mount("/audio", StaticFiles(directory=AUDIO_CACHE_DIR), name="audio")
+
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
- allow_origins=["http://localhost:5173"],
+    allow_origins=["https://ai.podcastlm.fun"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router, prefix="/api/v1")
+
+@app.middleware("http")
+async def add_process_time_header(request, call_next):
+ try:
+        response = await asyncio.wait_for(call_next(request), timeout=2400)  # 2400 s = 40分钟超时
+ return response
+ except asyncio.TimeoutError:
+ return JSONResponse(
+ status_code=504,
+ content={"detail": "Request processing time exceeded the limit."}
+ )
diff --git a/backend/prompts.py b/backend/prompts.py
index 0b86950..84b399b 100644
--- a/backend/prompts.py
+++ b/backend/prompts.py
@@ -3,66 +3,85 @@
"""
SYSTEM_PROMPT = """
-You are a world-class podcast producer tasked with transforming the provided input text into an engaging and informative podcast script. The input may be unstructured or messy, sourced from PDFs or web pages. Your goal is to extract the most interesting and insightful content for a compelling podcast discussion.
-
-# Steps to Follow:
-
-1. **Analyze the Input:**
- Carefully examine the text, identifying key topics, points, and interesting facts or anecdotes that could drive an engaging podcast conversation. Disregard irrelevant information or formatting issues.
-
-2. **Brainstorm Ideas:**
- In the ``, creatively brainstorm ways to present the key points engagingly. Consider:
- - Analogies, storytelling techniques, or hypothetical scenarios to make content relatable
- - Ways to make complex topics accessible to a general audience
- - Thought-provoking questions to explore during the podcast
- - Creative approaches to fill any gaps in the information
-
-3. **Craft the Dialogue:**
- Develop a natural, conversational flow between the host (Jane) and the guest speaker (the author or an expert on the topic). Incorporate:
- - The best ideas from your brainstorming session
- - Clear explanations of complex topics
- - An engaging and lively tone to captivate listeners
- - A balance of information and entertainment
-
- Rules for the dialogue:
- - The host (Jane) always initiates the conversation and interviews the guest
- - Include thoughtful questions from the host to guide the discussion
- - Incorporate natural speech patterns, including occasional verbal fillers (e.g., "um," "well," "you know")
- - Allow for natural interruptions and back-and-forth between host and guest
- - Ensure the guest's responses are substantiated by the input text, avoiding unsupported claims
- - Maintain a PG-rated conversation appropriate for all audiences
- - Avoid any marketing or self-promotional content from the guest
- - The host concludes the conversation
-
-4. **Summarize Key Insights:**
- Naturally weave a summary of key points into the closing part of the dialogue. This should feel like a casual conversation rather than a formal recap, reinforcing the main takeaways before signing off.
-
-5. **Maintain Authenticity:**
- Throughout the script, strive for authenticity in the conversation. Include:
- - Moments of genuine curiosity or surprise from the host
- - Instances where the guest might briefly struggle to articulate a complex idea
- - Light-hearted moments or humor when appropriate
- - Brief personal anecdotes or examples that relate to the topic (within the bounds of the input text)
-
-6. **Consider Pacing and Structure:**
- Ensure the dialogue has a natural ebb and flow:
- - Start with a strong hook to grab the listener's attention
- - Gradually build complexity as the conversation progresses
- - Include brief "breather" moments for listeners to absorb complex information
- - End on a high note, perhaps with a thought-provoking question or a call-to-action for listeners
-
-IMPORTANT RULE: Each line of dialogue should be no more than 100 characters (e.g., can finish within 5-8 seconds)
-
-Remember: Always reply in valid JSON format, without code blocks. Begin directly with the JSON output.
+你是一位世界级的播客制作人,任务是将提供的输入文本转化为引人入胜且内容丰富的播客脚本。输入内容可能是非结构化或杂乱的,来源于PDF或网页。你的目标是提取最有趣、最有洞察力的内容,形成一场引人入胜的播客讨论。
+
+操作步骤:
+
+ 1. 分析输入:
+仔细检查文本,识别出关键主题、要点,以及能推动播客对话的有趣事实或轶事。忽略无关的信息或格式问题。
+ 2. 编写对话:
+发展主持人与嘉宾(作者或该主题的专家)之间自然的对话流程,包含:
+ • 来自头脑风暴的最佳创意
+ • 对复杂话题的清晰解释
+ • 引人入胜的、活泼的语气以吸引听众
+ • 信息与娱乐的平衡
+对话规则:
+ • 主持人始终发起对话并采访嘉宾
+ • 包含主持人引导讨论的深思熟虑的问题
+ • 融入自然的口语模式,包括偶尔的语气词(如“嗯”,“好吧”,“你知道”)
+ • 允许主持人和嘉宾之间的自然打断和互动
+ • 嘉宾的回答必须基于输入文本,避免不支持的说法
+ • 保持PG级别的对话,适合所有观众
+ • 避免嘉宾的营销或自我推销内容
+ • 主持人结束对话
+ 3. 总结关键见解:
+在对话的结尾,自然地融入关键点总结。这应像是随意对话,而不是正式的回顾,强化主要的收获,然后结束。
+ 4. 保持真实性:
+在整个脚本中,努力保持对话的真实性,包含:
+ • 主持人表达出真实的好奇或惊讶时刻
+ • 嘉宾在表达复杂想法时可能短暂地有些卡顿
+ • 适当时加入轻松的时刻或幽默
+ • 简短的个人轶事或与主题相关的例子(以输入文本为基础)
+ 5. 考虑节奏与结构:
+确保对话有自然的起伏:
+ • 以强有力的引子吸引听众的注意力
+ • 随着对话进行,逐渐增加复杂性
+ • 包含短暂的“喘息”时刻,让听众消化复杂信息
+ • 以有力的方式收尾,或许以发人深省的问题或对听众的号召结束
+
+重要规则:每句对话不应超过100个字符(例如,可以在5-8秒内完成)。
+
+示例格式:
+**Host**: 欢迎来到节目!今天我们讨论的是[话题]。我们的嘉宾是[嘉宾姓名].
+**[Guest Name]**: 谢谢邀请,Jane。我很高兴分享我对[话题]的见解.
+
+记住,在整个对话中保持这种格式。
"""
-QUESTION_MODIFIER = "PLEASE ANSWER THE FOLLOWING QN:"
+QUESTION_MODIFIER = "请回答这个问题:"
-TONE_MODIFIER = "TONE: The tone of the podcast should be"
+TONE_MODIFIER = "语气: 播客的语气应该是"
-LANGUAGE_MODIFIER = "OUTPUT LANGUAGE : The the podcast should be"
+LANGUAGE_MODIFIER = "输出的语言<重要>:播客的语言应该是"
LENGTH_MODIFIERS = {
- "Short (1-2 min)": "Keep the podcast brief, around 5s long.",
- "Medium (3-5 min)": "Aim for a moderate length, about 3-5 minutes.",
+ "short": "保持播客的简短, 大约 1-2 分钟.",
+ "medium": "中等长度, 大约 3-5 分钟.",
}
+
+
+SUMMARY_INFO_PROMPT = """
+根据以下输入内容,生成一个播客梗概,使用 markdown 格式,遵循以下具体指南:
+
+ • 提供播客内容的概述(200-300字)。
+ • 突出3个关键点或收获。
+
+"""
+PODCAST_INFO_PROMPT = """
+根据以下输入内容,生成一个吸引人的标题和一个富有创意的主持人名字。请遵循以下具体指南:
+
+ 1. 标题:
+ • 创建一个引人入胜且简洁的标题,准确反映播客内容。
+ 2. 主持人名字:
+ • 为播客主持人创造一个有创意且易记的名字。
+
+请以以下JSON格式提供输出:
+
+{
+ "title": "An engaging and relevant podcast title",
+ "host_name": "A creative name for the host"
+}
+
+确保你的回复是一个有效的 JSON 对象,且不包含其他内容。
+
+"""
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 0fe2cc5..d6c5aa9 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -12,7 +12,9 @@ typing_extensions==4.12.2
uvicorn==0.31.1
openai==1.50.2
pydub==0.25.1
-gradio_client==1.4.0
loguru==0.7.2
suno-bark @ git+https://github.com/suno-ai/bark.git@f4f32d4cd480dfec1c245d258174bc9bde3c2148
numpy==2.1.1
+python-multipart==0.0.12
+PyPDF2==3.0.1
+azure-cognitiveservices-speech==1.41.1
diff --git a/backend/schema.py b/backend/schema.py
index dcad9aa..5e502fe 100644
--- a/backend/schema.py
+++ b/backend/schema.py
@@ -6,6 +6,17 @@
from pydantic import BaseModel, Field
+class Summary(BaseModel):
+ """Summary."""
+
+ summary: str
+
+class PodcastInfo(BaseModel):
+    """Podcast title and host name."""
+
+ title: str
+ host_name: str
+
class DialogueItem(BaseModel):
"""A single dialogue item."""
@@ -17,18 +28,7 @@ class DialogueItem(BaseModel):
class ShortDialogue(BaseModel):
"""The dialogue between the host and guest."""
- scratchpad: str
name_of_guest: str
dialogue: List[DialogueItem] = Field(
..., description="A list of dialogue items, typically between 11 to 17 items"
)
-
-
-class MediumDialogue(BaseModel):
- """The dialogue between the host and guest."""
-
- scratchpad: str
- name_of_guest: str
- dialogue: List[DialogueItem] = Field(
- ..., description="A list of dialogue items, typically between 19 to 29 items"
- )
diff --git a/backend/utils.py b/backend/utils.py
index a19d597..3f08a8e 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -1,80 +1,215 @@
-from typing import Any, Union
-from openai import OpenAI
-from pydantic import ValidationError
+import asyncio
+import glob
+import io
+import os
+import re
+import time
+import hashlib
-from schema import MediumDialogue, ShortDialogue
+from typing import Any, Dict, Generator
+import uuid
+from openai import OpenAI
+from prompts import LANGUAGE_MODIFIER, LENGTH_MODIFIERS, PODCAST_INFO_PROMPT, QUESTION_MODIFIER, SUMMARY_INFO_PROMPT, SYSTEM_PROMPT, TONE_MODIFIER
+import json
+from pydub import AudioSegment
+from fastapi import UploadFile
+from PyPDF2 import PdfReader
+from schema import PodcastInfo, ShortDialogue, Summary
from constants import (
+ AUDIO_CACHE_DIR,
FIREWORKS_API_KEY,
FIREWORKS_BASE_URL,
FIREWORKS_MODEL_ID,
FIREWORKS_MAX_TOKENS,
FIREWORKS_TEMPERATURE,
- FIREWORKS_JSON_RETRY_ATTEMPTS,
+ GRADIO_CLEAR_CACHE_OLDER_THAN,
+ SPEECH_KEY,
+ SPEECH_REGION,
)
+import azure.cognitiveservices.speech as speechsdk
-from bark.generation import SUPPORTED_LANGS
+fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
-from bark import SAMPLE_RATE, generate_audio, preload_models
-from scipy.io.wavfile import write as write_wav
-fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
-preload_models()
-print(SUPPORTED_LANGS)
-
-def generate_script(
- system_prompt: str,
- input_text: str,
- output_model: Union[ShortDialogue, MediumDialogue],
-) -> Union[ShortDialogue, MediumDialogue]:
- """Get the dialogue from the LLM."""
-
- # Call the LLM
- response = call_llm(system_prompt, input_text, output_model)
- response_json = response.choices[0].message.content
-
- # Validate the response
- for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
- try:
- first_draft_dialogue = output_model.model_validate_json(response_json)
- break
- except ValidationError as e:
- if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
- raise ValueError(
- f"Failed to parse dialogue JSON after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
- ) from e
- error_message = (
- f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
- )
- # Re-call the LLM with the error message
- system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
- response = call_llm(system_prompt_with_error, input_text, output_model)
- response_json = response.choices[0].message.content
- first_draft_dialogue = output_model.model_validate_json(response_json)
-
- # Call the LLM a second time to improve the dialogue
- system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."
-
- # Validate the response
- for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
- try:
- response = call_llm(
- system_prompt_with_dialogue,
- "Please improve the dialogue. Make it more natural and engaging.",
- output_model,
- )
- final_dialogue = output_model.model_validate_json(
- response.choices[0].message.content
- )
- break
- except ValidationError as e:
- if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
- raise ValueError(
- f"Failed to improve dialogue after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
- ) from e
- error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
- system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
- return final_dialogue
+def generate_dialogue(pdfFile, textInput, tone, duration, language) -> Generator[str, None, None]:
+ modified_system_prompt = get_prompt(pdfFile, textInput, tone, duration, language)
+ if (modified_system_prompt == False):
+ yield json.dumps({
+ "type": "error",
+ "content": "Prompt is too long"
+ }) + "\n"
+ return
+ full_response = ""
+ llm_stream = call_llm_stream(SYSTEM_PROMPT, modified_system_prompt, ShortDialogue, isJSON=False)
+
+ for chunk in llm_stream:
+ yield json.dumps({"type": "chunk", "content": chunk}) + "\n"
+ full_response += chunk
+
+ yield json.dumps({"type": "final", "content": full_response})
+
+async def process_line(line, voice):
+ return await generate_podcast_audio(line['content'], voice)
+
+async def generate_podcast_audio(text: str, voice: str) -> str:
+ try:
+ speech_config = speechsdk.SpeechConfig(subscription=SPEECH_KEY, region=SPEECH_REGION)
+ speech_config.speech_synthesis_voice_name = voice
+
+ synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=None)
+        future = await asyncio.to_thread(synthesizer.speak_text_async, text)
+
+ result = await asyncio.to_thread(future.get)
+
+ print("Speech synthesis completed")
+
+ if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
+ print("Audio synthesized successfully")
+ audio_data = result.audio_data
+ audio_segment = AudioSegment.from_wav(io.BytesIO(audio_data))
+ return audio_segment
+ else:
+ print(f"Speech synthesis failed: {result.reason}")
+ if hasattr(result, 'cancellation_details'):
+ print(f"Cancellation details: {result.cancellation_details.reason}")
+ print(f"Cancellation error details: {result.cancellation_details.error_details}")
+ return None
+
+ except Exception as e:
+ print(f"Error in generate_podcast_audio: {e}")
+ raise
+
+async def combine_audio(task_status: Dict[str, Dict], task_id: str, text: str, language: str):
+ try:
+ dialogue_regex = r'\*\*([\s\S]*?)\*\*[::]\s*([\s\S]*?)(?=\*\*|$)'
+ matches = re.findall(dialogue_regex, text, re.DOTALL)
+
+ lines = [
+ {
+ "speaker": match[0],
+ "content": match[1].strip(),
+ }
+ for match in matches
+ ]
+
+ host_voice = "zh-CN-YunxiNeural"
+ guest_voice = "zh-CN-YunzeNeural"
+
+ print("Starting audio generation")
+ audio_segments = await asyncio.gather(
+ *[process_line(line, host_voice if line['speaker'] == '主持人' else guest_voice) for line in lines]
+ )
+ print("Audio generation completed")
+
+ # 合并音频
+ combined_audio = await asyncio.to_thread(sum, audio_segments)
+
+ print("Audio combined")
+
+ # 只在最后写入文件
+ unique_filename = f"{uuid.uuid4()}.mp3"
+
+ os.makedirs(AUDIO_CACHE_DIR, exist_ok=True)
+ file_path = os.path.join(AUDIO_CACHE_DIR, unique_filename)
+
+ # 异步导出音频文件
+ await asyncio.to_thread(combined_audio.export, file_path, format="mp3")
+
+ audio_url = f"/audio/{unique_filename}"
+ task_status[task_id] = {"status": "completed", "audio_url": audio_url}
+
+    for file in glob.glob(os.path.join(AUDIO_CACHE_DIR, "*.mp3")):
+ if (
+ os.path.isfile(file)
+ and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
+ ):
+            os.remove(file)
+
+
+ clear_pdf_cache()
+ return audio_url
+
+ except Exception as e:
+ # 如果发生错误,更新状态为失败
+ task_status[task_id] = {"status": "failed", "error": str(e)}
+
+
+def generate_podcast_summary(pdf_content: str, text: str, tone: str, length: str, language: str) -> Generator[str, None, None]:
+ modified_system_prompt = get_prompt(pdf_content, text, '', '', '')
+ if (modified_system_prompt == False):
+ yield json.dumps({
+ "type": "error",
+ "content": "Prompt is too long"
+ }) + "\n"
+ return
+ stream = call_llm_stream(SUMMARY_INFO_PROMPT, modified_system_prompt, Summary, False)
+ full_response = ""
+ for chunk in stream:
+        full_response += chunk
+        yield json.dumps({"type": "chunk", "content": chunk}) + "\n"
+
+ yield json.dumps({"type": "final", "content": full_response})
+
+def generate_podcast_info(pdfContent: str, text: str, tone: str, length: str, language: str) -> Generator[str, None, None]:
+ modified_system_prompt = get_prompt(pdfContent, text, '', '', '')
+ if (modified_system_prompt == False):
+ yield json.dumps({
+ "type": "error",
+ "content": "Prompt is too long"
+ }) + "\n"
+ return
+
+ full_response = ""
+ for chunk in call_llm_stream(PODCAST_INFO_PROMPT, modified_system_prompt, PodcastInfo):
+ full_response += chunk
+ try:
+ result = json.loads(full_response)
+
+ yield json.dumps({
+ "type": "podcast_info",
+ "content": result
+ }) + "\n"
+ except Exception as e:
+ yield json.dumps({
+ "type": "error",
+ "content": f"An unexpected error occurred: {str(e)}"
+ }) + "\n"
+
+def call_llm_stream(system_prompt: str, text: str, dialogue_format: Any, isJSON: bool = True) -> Generator[str, None, None]:
+ """Call the LLM with the given prompt and dialogue format, returning a stream of responses."""
+ request_params = {
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": text},
+ ],
+ "model": FIREWORKS_MODEL_ID,
+ "max_tokens": FIREWORKS_MAX_TOKENS,
+ "temperature": FIREWORKS_TEMPERATURE,
+ "stream": True # 启用流式输出
+ }
+
+ # 如果需要 JSON 响应,添加 response_format 参数
+ if isJSON:
+ request_params["response_format"] = {
+ "type": "json_object",
+ "schema": dialogue_format.model_json_schema(),
+ }
+ stream = fw_client.chat.completions.create(**request_params)
+
+ full_response = ""
+ for chunk in stream:
+ if chunk.choices[0].delta.content is not None:
+ content = chunk.choices[0].delta.content
+ full_response += content
+ yield content
+
+ # 在流结束时,尝试解析完整的 JSON 响应
+ # try:
+ # parsed_response = json.loads(full_response)
+ # yield json.dumps({"type": "final", "content": parsed_response})
+ # except json.JSONDecodeError:
+ # yield json.dumps({"type": "error", "content": "Failed to parse JSON response"})
def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
"""Call the LLM with the given prompt and dialogue format."""
@@ -93,18 +228,49 @@ def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
)
return response
-def generate_podcast_audio(
- text: str, speaker: str, language: str, random_voice_number: int
-) -> str:
- host_voice_num = str(random_voice_number)
- guest_voice_num = str(random_voice_number + 1)
+pdf_cache = {}
+def clear_pdf_cache():
+ global pdf_cache
+ pdf_cache.clear()
- print(f"v2/{language}_speaker_{host_voice_num if speaker == 'Host (Jane)' else guest_voice_num}")
- audio_array = generate_audio(
- text,
- history_prompt=f"v2/{language}_speaker_{host_voice_num if speaker == 'Host (Jane)' else guest_voice_num}",
- )
- file_path = f"audio_{language}_{speaker}.mp3"
- print(SAMPLE_RATE)
- write_wav(file_path, SAMPLE_RATE, audio_array)
- return file_path
+async def get_pdf_text(pdf_file: UploadFile):
+ text = ""
+ print(pdf_file)
+ try:
+ # 读取上传文件的内容
+ contents = await pdf_file.read()
+ file_hash = hashlib.md5(contents).hexdigest()
+
+ if file_hash in pdf_cache:
+ return pdf_cache[file_hash]
+
+ # 使用 BytesIO 创建一个内存中的文件对象
+ pdf_file_obj = io.BytesIO(contents)
+
+ # 使用 PdfReader 读取 PDF 内容
+ pdf_reader = PdfReader(pdf_file_obj)
+
+ # 提取所有页面的文本
+ text = "\n\n".join([page.extract_text() for page in pdf_reader.pages])
+
+ # 重置文件指针,以防后续还需要读取文件
+ await pdf_file.seek(0)
+
+ return text
+
+ except Exception as e:
+ return {"error": str(e)}
+
+def get_prompt(pdfContent: str, text: str, tone: str, length: str, language: str):
+ modified_system_prompt = ""
+ new_text = pdfContent +text
+ if pdfContent:
+ modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {new_text}"
+ if tone:
+ modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}."
+ if length:
+ modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}"
+ if language:
+ modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}."
+
+ return modified_system_prompt
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 7a1effe..e5fac99 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -35,7 +35,7 @@ function App() {
fetchSummaryText(`${BASE_URL}/summarize`, formData);
}
return (
-