feat: backend
YOYZHANG committed Oct 22, 2024
1 parent 785bba8 commit 25cbc35
Showing 11 changed files with 476 additions and 365 deletions.
29 changes: 13 additions & 16 deletions backend/Dockerfile
@@ -1,24 +1,21 @@
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.12
FROM python:3.12-slim

# Set the working directory
WORKDIR /app
RUN apt-get update && apt-get install -y git ffmpeg && \
apt-get clean && rm -rf /var/lib/apt/lists/*
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
PATH=/home/user/.local/bin:$PATH

# Copy the current directory contents into /app in the container
COPY . /app
# Set the working directory
WORKDIR $HOME/app

# Install project dependencies
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Install FFmpeg
RUN apt-get update && apt-get install -y ffmpeg

# Set environment variables
ENV FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
COPY --chown=user . .

# Expose port 8000 for the application
EXPOSE 8000
EXPOSE 7860

# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]


CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
11 changes: 11 additions & 0 deletions backend/README.md
@@ -0,0 +1,11 @@
---
title: PodcastLM Backend
emoji: 👀
colorFrom: gray
colorTo: green
sdk: docker
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
165 changes: 90 additions & 75 deletions backend/api/routes/chat.py
@@ -1,89 +1,104 @@
import os
import random

from fastapi import APIRouter, Query
from constants import GRADIO_CACHE_DIR, MELO_TTS_LANGUAGE_MAPPING, SUNO_LANGUAGE_MAPPING
from utils import generate_podcast_audio, generate_script
from prompts import LANGUAGE_MODIFIER, LENGTH_MODIFIERS, QUESTION_MODIFIER, SYSTEM_PROMPT, TONE_MODIFIER
from schema import ShortDialogue
from loguru import logger
from pydub import AudioSegment

from tempfile import NamedTemporaryFile
import uuid
from fastapi import APIRouter, BackgroundTasks, Form, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse, JSONResponse
import json
from typing import Dict, Optional
from utils import combine_audio, generate_dialogue, generate_podcast_info, generate_podcast_summary, get_pdf_text

router = APIRouter()

@router.get("/")
def generate(input: str = Query(..., description="Input string")):
random_voice_number = random.randint(1, 9)

modified_system_prompt = SYSTEM_PROMPT
question = "introduce chatgpt"
tone = "funny"
language = "English"
length = "Short (1-2 min)"

if question:
modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {question}"
if tone:
modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}."
if length:
modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}"
if language:
modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}."
@router.post("/generate_transcript")
async def generate_transcript(
pdfFile: Optional[UploadFile] = File(None),
textInput: str = Form(...),
tone: str = Form(...),
duration: str = Form(...),
language: str = Form(...),

llm_output = generate_script(modified_system_prompt, "introduce chatgpt", ShortDialogue)

logger.info(f"Generated dialogue: {llm_output}")
):
pdfContent = await get_pdf_text(pdfFile)
new_text = pdfContent
return StreamingResponse(generate_dialogue(new_text, textInput, tone, duration, language), media_type="application/json")


@router.get("/test")
def test():
return {"message": "Hello World"}

@router.post("/summarize")
async def get_summary(
textInput: str = Form(...),
tone: str = Form(...),
duration: str = Form(...),
language: str = Form(...),
pdfFile: Optional[UploadFile] = File(None)
):
pdfContent = await get_pdf_text(pdfFile)
new_text = pdfContent
return StreamingResponse(
generate_podcast_summary(
new_text,
textInput,
tone,
duration,
language,
),
media_type="application/json"
)

audio_segments = []
transcript = ""
total_characters = 0
@router.post("/pod_info")
async def get_pod_info(
textInput: str = Form(...),
tone: str = Form(...),
duration: str = Form(...),
language: str = Form(...),
pdfFile: Optional[UploadFile] = File(None)
):
pdfContent = await get_pdf_text(pdfFile)
new_text = pdfContent[:100]

return StreamingResponse(generate_podcast_info(new_text, textInput, tone, duration, language), media_type="application/json")

for line in llm_output.dialogue:
print(f"Generating audio for {line.speaker}: {line.text}")
logger.info(f"Generating audio for {line.speaker}: {line.text}")
if line.speaker == "Host (Jane)":
speaker = f"**Host**: {line.text}"
else:
speaker = f"**{llm_output.name_of_guest}**: {line.text}"
transcript += speaker + "\n\n"
total_characters += len(line.text)

language_for_tts = SUNO_LANGUAGE_MAPPING[language]
task_status: Dict[str, Dict] = {}

# Get audio file path
audio_file_path = generate_podcast_audio(
line.text, line.speaker, language_for_tts, random_voice_number
)
# Read the audio file into an AudioSegment
audio_segment = AudioSegment.from_file(audio_file_path)
audio_segments.append(audio_segment)

# Concatenate all audio segments
combined_audio = sum(audio_segments)
@router.post("/generate_audio")
async def audio(
background_tasks: BackgroundTasks,
text: str = Form(...),
language: str = Form(...)
):
task_id = str(uuid.uuid4())
task_status[task_id] = {"status": "processing"}

background_tasks.add_task(combine_audio, task_status, task_id, text, language)

# Export the combined audio to a temporary file
temporary_directory = GRADIO_CACHE_DIR
os.makedirs(temporary_directory, exist_ok=True)
return JSONResponse(content={"task_id": task_id, "status": "processing"})

temporary_file = NamedTemporaryFile(
dir=temporary_directory,
delete=False,
suffix=".mp3",
)
combined_audio.export(temporary_file.name, format="mp3")
logger.info(f"Generated {total_characters} characters of audio")

# Delete any files in the temp directory that end with .mp3 and are over a day old
# for file in glob.glob(f"{temporary_directory}*.mp3"):
# if (
# os.path.isfile(file)
# and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
# ):
# os.remove(file)
@router.get("/audio_status/{task_id}")
async def get_audio_status(task_id: str):
if task_id not in task_status:
raise HTTPException(status_code=404, detail="Task not found")

status = task_status[task_id]

if status["status"] == "completed":
return JSONResponse(content={
"status": "completed",
"audio_url": status["audio_url"]
})
elif status["status"] == "failed":
return JSONResponse(content={
"status": "failed",
"error": status["error"]
})
else:
return JSONResponse(content={
"status": "processing"
})



print(temporary_file.name)
print(transcript)
return {"message": f"Hello World, input: {temporary_file}"}

129 changes: 6 additions & 123 deletions backend/constants.py
@@ -7,12 +7,12 @@
from pathlib import Path

# Key constants
APP_TITLE = "Open NotebookLM"
CHARACTER_LIMIT = 100_000

# Gradio-related constants
GRADIO_CACHE_DIR = "./gradio_cached_examples/tmp/"
GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 24 * 60 * 60 # 1 day
GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 2 * 60 * 60 # 2 hours

AUDIO_CACHE_DIR = os.path.join(os.path.dirname(__file__), 'tmp', 'cache')

# Error messages-related constants
ERROR_MESSAGE_NO_INPUT = "Please provide at least one PDF file or a URL."
@@ -21,30 +21,15 @@
ERROR_MESSAGE_READING_PDF = "Error reading the PDF file"
ERROR_MESSAGE_TOO_LONG = "The total content is too long. Please ensure the combined text from PDFs and URL is fewer than {CHARACTER_LIMIT} characters."

SPEECH_KEY = os.getenv('SPEECH_KEY')
SPEECH_REGION = "japaneast"
# Fireworks API-related constants
FIREWORKS_API_KEY = os.getenv['FIREWORKS_API_KEY']
FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
FIREWORKS_MAX_TOKENS = 16_384
FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
FIREWORKS_TEMPERATURE = 0.1
FIREWORKS_JSON_RETRY_ATTEMPTS = 3

# MeloTTS
MELO_API_NAME = "/synthesize"
MELO_TTS_SPACES_ID = "mrfakename/MeloTTS"
MELO_RETRY_ATTEMPTS = 3
MELO_RETRY_DELAY = 5 # in seconds

MELO_TTS_LANGUAGE_MAPPING = {
"en": "EN",
"es": "ES",
"fr": "FR",
"zh": "ZJ",
"ja": "JP",
"ko": "KR",
}


# Suno related constants
SUNO_LANGUAGE_MAPPING = {
"English": "en",
@@ -62,105 +47,3 @@
"Turkish": "tr",
}

# General audio-related constants
NOT_SUPPORTED_IN_MELO_TTS = list(
set(SUNO_LANGUAGE_MAPPING.values()) - set(MELO_TTS_LANGUAGE_MAPPING.keys())
)
NOT_SUPPORTED_IN_MELO_TTS = [
key for key, id in SUNO_LANGUAGE_MAPPING.items() if id in NOT_SUPPORTED_IN_MELO_TTS
]

# Jina Reader-related constants
JINA_READER_URL = "https://r.jina.ai/"
JINA_RETRY_ATTEMPTS = 3
JINA_RETRY_DELAY = 5 # in seconds

# UI-related constants
UI_DESCRIPTION = """
<table style="border-collapse: collapse; border: none; padding: 20px;">
<tr style="border: none;">
<td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
<img src="https://raw.githubusercontent.com/gabrielchua/daily-ai-papers/main/_includes/icon.png" alt="Open NotebookLM" width="120" style="margin-bottom: 10px;">
</td>
<td style="border: none; vertical-align: top; padding: 10px;">
<p style="margin-bottom: 15px;">Convert your PDFs into podcasts with open-source AI models (<a href="https://huggingface.co/meta-llama/Llama-3.1-405B">Llama 3.1 405B</a> via <a href="https://fireworks.ai/">Fireworks AI</a>, <a href="https://huggingface.co/myshell-ai/MeloTTS-English">MeloTTS</a>, <a href="https://huggingface.co/suno/bark">Bark</a>).</p>
<p style="margin-top: 15px;">Note: Only the text content of the PDFs will be processed. Images and tables are not included. The total content should be no more than 100,000 characters due to the context length of Llama 3.1 405B.</p>
</td>
</tr>
</table>
"""
UI_AVAILABLE_LANGUAGES = list(set(SUNO_LANGUAGE_MAPPING.keys()))
UI_INPUTS = {
"file_upload": {
"label": "1. 📄 Upload your PDF(s)",
"file_types": [".pdf"],
"file_count": "multiple",
},
"url": {
"label": "2. 🔗 Paste a URL (optional)",
"placeholder": "Enter a URL to include its content",
},
"question": {
"label": "3. 🤔 Do you have a specific question or topic in mind?",
"placeholder": "Enter a question or topic",
},
"tone": {
"label": "4. 🎭 Choose the tone",
"choices": ["Fun", "Formal"],
"value": "Fun",
},
"length": {
"label": "5. ⏱️ Choose the length",
"choices": ["Short (1-2 min)", "Medium (3-5 min)"],
"value": "Medium (3-5 min)",
},
"language": {
"label": "6. 🌐 Choose the language",
"choices": UI_AVAILABLE_LANGUAGES,
"value": "English",
},
"advanced_audio": {
"label": "7. 🔄 Use advanced audio generation? (Experimental)",
"value": True,
},
}
UI_OUTPUTS = {
"audio": {"label": "🔊 Podcast", "format": "mp3"},
"transcript": {
"label": "📜 Transcript",
},
}
UI_API_NAME = "generate_podcast"
UI_ALLOW_FLAGGING = "never"
UI_CONCURRENCY_LIMIT = 3
UI_EXAMPLES = [
[
[str(Path("examples/1310.4546v1.pdf"))],
"",
"Explain this paper to me like I'm 5 years old",
"Fun",
"Short (1-2 min)",
"English",
True,
],
[
[],
"https://en.wikipedia.org/wiki/Hugging_Face",
"How did Hugging Face become so successful?",
"Fun",
"Short (1-2 min)",
"English",
False,
],
[
[],
"https://simple.wikipedia.org/wiki/Taylor_Swift",
"Why is Taylor Swift so popular?",
"Fun",
"Short (1-2 min)",
"English",
False,
],
]
UI_CACHE_EXAMPLES = True
UI_SHOW_API = True