From 63263350f5dbc149e64101fb0b4749db5d08772e Mon Sep 17 00:00:00 2001 From: mrorigo Date: Thu, 23 Jan 2025 19:38:56 +0100 Subject: [PATCH] feat: Support ollama instead of llamafile - Removed llamafile dependency from requirements.txt - Use Ollama instead of llamafile - Introduces LLMBackendType enum --- podcastfy/client.py | 52 ++-- podcastfy/content_generator.py | 515 ++++++++++++++++++--------------- podcastfy/text_to_speech.py | 36 ++- requirements.txt | 1 + tests/test_genai_podcast.py | 37 ++- tests/test_generate_podcast.py | 26 +- 6 files changed, 377 insertions(+), 290 deletions(-) diff --git a/podcastfy/client.py b/podcastfy/client.py index e13f54a..ba469a3 100644 --- a/podcastfy/client.py +++ b/podcastfy/client.py @@ -11,7 +11,7 @@ import typer import yaml from podcastfy.content_parser.content_extractor import ContentExtractor -from podcastfy.content_generator import ContentGenerator +from podcastfy.content_generator import ContentGenerator, LLMBackendType from podcastfy.text_to_speech import TextToSpeech from podcastfy.utils.config import Config, load_config from podcastfy.utils.config_conversation import load_conversation_config @@ -47,12 +47,12 @@ def process_content( config: Optional[Dict[str, Any]] = None, conversation_config: Optional[Dict[str, Any]] = None, image_paths: Optional[List[str]] = None, - is_local: bool = False, + llm_type: LLMBackendType = LLMBackendType.LITELLM, text: Optional[str] = None, model_name: Optional[str] = None, api_key_label: Optional[str] = None, topic: Optional[str] = None, - longform: bool = False + longform: bool = False, ): """ Process URLs, a transcript file, image paths, or raw text to generate a podcast or transcript. @@ -82,14 +82,14 @@ def process_content( content_extractor = ContentExtractor() content_generator = ContentGenerator( - is_local=is_local, + llm_type, model_name=model_name, api_key_label=api_key_label, - conversation_config=conv_config.to_dict() + conversation_config=conv_config.to_dict(), ) combined_content = "" - + if urls: logger.info(f"Processing {len(urls)} links") contents = [content_extractor.extract_content(link) for link in urls] @@ -97,7 +97,9 @@ def process_content( if text: if longform and len(text.strip()) < 100: - logger.info("Text too short for direct long-form generation. Extracting context...") + logger.info( + "Text too short for direct long-form generation. Extracting context..." 
+ ) expanded_content = content_extractor.generate_topic_content(text) combined_content += f"\n\n{expanded_content}" else: @@ -117,13 +119,15 @@ def process_content( combined_content, image_file_paths=image_paths or [], output_filepath=transcript_filepath, - longform=longform + longform=longform, ) if generate_audio: api_key = None if tts_model != "edge": - api_key = getattr(config, f"{tts_model.upper().replace('MULTI', '')}_API_KEY") + api_key = getattr( + config, f"{tts_model.upper().replace('MULTI', '')}_API_KEY" + ) text_to_speech = TextToSpeech( model=tts_model, @@ -183,6 +187,12 @@ def main( text: str = typer.Option( None, "--text", "-txt", help="Raw text input to be processed" ), + llm_type: str = typer.Option( + None, + "--llm-type", + "-lt", + help="LLM type for content generation (litellm(default), ollama, google) ", + ), llm_model_name: str = typer.Option( None, "--llm-model-name", "-m", help="LLM model name for transcript generation" ), @@ -193,10 +203,10 @@ def main( None, "--topic", "-tp", help="Topic to generate podcast about" ), longform: bool = typer.Option( - False, - "--longform", - "-lf", - help="Generate long-form content (only available for text input without images)" + False, + "--longform", + "-lf", + help="Generate long-form content (only available for text input without images)", ), ): """ @@ -226,12 +236,12 @@ def main( generate_audio=not transcript_only, conversation_config=conversation_config, config=config, - is_local=is_local, + llm_type=llm_type, text=text, model_name=llm_model_name, api_key_label=api_key_label, topic=topic, - longform=longform + longform=longform, ) else: urls_list = urls or [] @@ -250,12 +260,12 @@ def main( config=config, conversation_config=conversation_config, image_paths=image_paths, - is_local=is_local, + llm_type=llm_type, text=text, model_name=llm_model_name, api_key_label=api_key_label, topic=topic, - longform=longform + longform=longform, ) if transcript_only: @@ -283,7 +293,7 @@ def generate_podcast( config: Optional[Dict[str, Any]] = None, conversation_config: Optional[Dict[str, Any]] = None, image_paths: Optional[List[str]] = None, - is_local: bool = False, + llm_type: LLMBackendType = LLMBackendType.LITELLM, text: Optional[str] = None, llm_model_name: Optional[str] = None, api_key_label: Optional[str] = None, @@ -302,7 +312,7 @@ def generate_podcast( config (Optional[Dict[str, Any]]): User-provided configuration dictionary. conversation_config (Optional[Dict[str, Any]]): User-provided conversation configuration dictionary. image_paths (Optional[List[str]]): List of image file paths to process. - is_local (bool): Whether to use a local LLM. Defaults to False. + llm_type (LLMBackendType): LLM backend type for content generation. text (Optional[str]): Raw text input to be processed. llm_model_name (Optional[str]): LLM model name for content generation. api_key_label (Optional[str]): Environment variable name for LLM API key. 
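For reference, a minimal usage sketch of the reworked generate_podcast API with the new llm_type parameter (the URL and model tag are illustrative assumptions, not values from this patch; note that LLMBackend dispatches on the lowercase strings "litellm"/"google"/"ollama" -- which is what the --llm-type CLI flag passes through and what LLMBackendType.OLLAMA.value resolves to):

    from podcastfy.client import generate_podcast

    # Build a transcript with a locally served Ollama model instead of Gemini.
    # No API key is needed on the Ollama branch; the model must already be
    # pulled into the local Ollama daemon (the tag below is hypothetical).
    transcript_path = generate_podcast(
        urls=["https://example.com/article"],        # assumed input
        transcript_only=True,                        # pre-existing flag, skips TTS
        llm_type="ollama",                           # == LLMBackendType.OLLAMA.value
        llm_model_name="llama3.2:3b-instruct-fp16",  # assumed local model tag
    )

The other arguments follow the pre-existing generate_podcast API and are unchanged by this patch.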
@@ -355,7 +365,7 @@ def generate_podcast( model_name=llm_model_name, api_key_label=api_key_label, topic=topic, - longform=longform + longform=longform, ) else: urls_list = urls or [] @@ -381,7 +391,7 @@ def generate_podcast( model_name=llm_model_name, api_key_label=api_key_label, topic=topic, - longform=longform + longform=longform, ) except Exception as e: diff --git a/podcastfy/content_generator.py b/podcastfy/content_generator.py index f3cfd91..fcd97f9 100644 --- a/podcastfy/content_generator.py +++ b/podcastfy/content_generator.py @@ -8,12 +8,12 @@ import os from typing import Optional, Dict, Any, List +from enum import Enum import re - from langchain_community.chat_models import ChatLiteLLM from langchain_google_genai import ChatGoogleGenerativeAI -from langchain_community.llms.llamafile import Llamafile +from langchain_ollama.llms import OllamaLLM from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain import hub @@ -26,67 +26,101 @@ logger = logging.getLogger(__name__) -class LLMBackend: +class LLMBackendType(Enum): + LITELLM = "litellm" + GOOGLE = "google" + OLLAMA = "ollama" + # OPENAI = "openai" + + +class LLMBackendConfig: + type: LLMBackendType + temperature: float + max_output_tokens: int + model_name: str + api_key_label: str + is_local: bool = False + def __init__( self, - is_local: bool, + type: LLMBackendType, temperature: float, max_output_tokens: int, model_name: str, api_key_label: str = "GEMINI_API_KEY", ): """ - Initialize the LLMBackend. - Args: - is_local (bool): Whether to use a local LLM or not. - temperature (float): The temperature for text generation. - max_output_tokens (int): The maximum number of output tokens. - model_name (str): The name of the model to use. + is_local (bool): Whether to use a local LLM or not. + temperature (float): The temperature for text generation. + max_output_tokens (int): The maximum number of output tokens. + model_name (str): The name of the model to use. """ - self.is_local = is_local + self.type = type + self.is_local = self.type == LLMBackendType.OLLAMA self.temperature = temperature self.max_output_tokens = max_output_tokens self.model_name = model_name - self.is_multimodal = not is_local # Does not assume local LLM is multimodal + self.api_key_label = api_key_label + logger.debug(f"LLMBackendConfig: {self.__dict__}") + + +class LLMBackend: + def __init__(self, llm_config: LLMBackendConfig): + """ + Initialize the LLMBackend with the specified configuration. 
+ + Args: + + """ + self.config = llm_config common_params = { - "temperature": temperature, + "temperature": self.config.temperature, "presence_penalty": 0.75, # Encourage diverse content "frequency_penalty": 0.75, # Avoid repetition } - if is_local: - self.llm = Llamafile() # replace with ollama - elif ( - "gemini" in self.model_name.lower() - ): # keeping original gemini as a special case while we build confidence on LiteLLM + logger.debug(f"Initializing LLM backend: {self.config.type}") + if self.config.type == "ollama": + logger.debug(f"Initializing OllamaLLM with model {self.config.model_name}") + self.llm = OllamaLLM( + model=self.config.model_name, + temperature=self.config.temperature, + top_p=0.9, # default 0.9 + top_k=40, # default 40 + num_predict=self.config.max_output_tokens, + ) + elif self.config.type == "google": + logger.debug(f"Initializing ChatGoogleGenerativeAI with model {self.config.model_name}") self.llm = ChatGoogleGenerativeAI( api_key=os.environ["GEMINI_API_KEY"], - model=model_name, - max_output_tokens=max_output_tokens, + model=self.config.model_name, + max_output_tokens=self.config.max_output_tokens, **common_params, ) else: # user should set api_key_label from input + logger.debug(f"Initializing ChatLiteLLM for type {self.config.type} with model {self.config.model_name}") self.llm = ChatLiteLLM( - model=self.model_name, - temperature=temperature, - api_key=os.environ[api_key_label], + model=self.config.model_name, + temperature=self.config.temperature, + api_key=os.environ[self.config.api_key_label], ) class LongFormContentGenerator: """ Handles generation of long-form podcast conversations by breaking content into manageable chunks. - + Uses a "Content Chunking with Contextual Linking" strategy to maintain context between segments while generating longer conversations. - + Attributes: LONGFORM_INSTRUCTIONS (str): Constant containing instructions for long-form generation llm_chain: The LangChain chain used for content generation """ + # Add constant for long-form instructions LONGFORM_INSTRUCTIONS = """ Additional Instructions: @@ -98,27 +132,36 @@ class LongFormContentGenerator: 6. Maintain consistent voice throughout the extended discussion 7. Generate a long conversation - output max_output_tokens tokens """ - - def __init__(self, chain, llm, config_conversation: Dict[str, Any], ): + + def __init__( + self, + chain, + llm, + config_conversation: Dict[str, Any], + ): """ Initialize ConversationGenerator. - + Args: llm_chain: The LangChain chain to use for generation config_conversation: Conversation configuration dictionary """ self.llm_chain = chain self.llm = llm - self.max_num_chunks = config_conversation.get("max_num_chunks", 10) # Default if not in config - self.min_chunk_size = config_conversation.get("min_chunk_size", 200) # Default if not in config + self.max_num_chunks = config_conversation.get( + "max_num_chunks", 10 + ) # Default if not in config + self.min_chunk_size = config_conversation.get( + "min_chunk_size", 200 + ) # Default if not in config def __calculate_chunk_size(self, input_content: str) -> int: """ Calculate chunk size based on input content length. 
- + Args: input_content: Input text content - + Returns: Calculated chunk size that ensures: - Returns 1 if content length <= min_chunk_size @@ -128,114 +171,117 @@ def __calculate_chunk_size(self, input_content: str) -> int: input_length = len(input_content) if input_length <= self.min_chunk_size: return input_length - + maximum_chunk_size = input_length // self.max_num_chunks if maximum_chunk_size >= self.min_chunk_size: return maximum_chunk_size - + # Calculate chunk size that maximizes size while maintaining minimum chunks return input_length // (input_length // self.min_chunk_size) def chunk_content(self, input_content: str, chunk_size: int) -> List[str]: """ Split input content into manageable chunks while preserving context. - + Args: input_content (str): The input text to chunk chunk_size (int): Maximum size of each chunk - + Returns: List[str]: List of content chunks """ - sentences = input_content.split('. ') + sentences = input_content.split(". ") chunks = [] current_chunk = [] current_length = 0 - + for sentence in sentences: sentence_length = len(sentence) if current_length + sentence_length > chunk_size and current_chunk: - chunks.append('. '.join(current_chunk) + '.') + chunks.append(". ".join(current_chunk) + ".") current_chunk = [] current_length = 0 current_chunk.append(sentence) current_length += sentence_length - + if current_chunk: - chunks.append('. '.join(current_chunk) + '.') + chunks.append(". ".join(current_chunk) + ".") return chunks - def enhance_prompt_params(self, prompt_params: Dict, - part_idx: int, - total_parts: int, - chat_context: str) -> Dict: + def enhance_prompt_params( + self, prompt_params: Dict, part_idx: int, total_parts: int, chat_context: str + ) -> Dict: """ Enhance prompt parameters for long-form content generation. - + Args: prompt_params (Dict): Original prompt parameters part_idx (int): Index of current conversation part total_parts (int): Total number of conversation parts chat_context (str): Chat context from previous parts - + Returns: Dict: Enhanced prompt parameters with part-specific instructions """ enhanced_params = prompt_params.copy() - # Initialize part_instructions with chat context + # Initialize part_instructions with chat context enhanced_params["context"] = chat_context - + COMMON_INSTRUCTIONS = """ Podcast conversation so far is given in CONTEXT. Continue the natural flow of conversation. Follow-up on the very previous point/question without repeating topics or points already discussed! Hence, the transition should be smooth and natural. Avoid abrupt transitions. - Make sure the first to speak is different from the previous speaker. Look at the last tag in CONTEXT to determine the previous speaker. + Make sure the first to speak is different from the previous speaker. Look at the last tag in CONTEXT to determine the previous speaker. If last tag in CONTEXT is , then the first to speak now should be . If last tag in CONTEXT is , then the first to speak now should be . This is a live conversation without any breaks. Hence, avoid statemeents such as "we'll discuss after a short break. Stay tuned" or "Okay, so, picking up where we left off". - """ + """ # Add part-specific instructions if part_idx == 0: - enhanced_params["instruction"] = f""" + enhanced_params[ + "instruction" + ] = f""" ALWAYS START THE CONVERSATION GREETING THE AUDIENCE: Welcome to {enhanced_params["podcast_name"]} - {enhanced_params["podcast_tagline"]}. You are generating the Introduction part of a long podcast conversation. 
Don't cover any topics yet, just introduce yourself and the topic. Leave the rest for later parts, following these guidelines: """ elif part_idx == total_parts - 1: - enhanced_params["instruction"] = f""" - You are generating the last part of a long podcast conversation. + enhanced_params[ + "instruction" + ] = f""" + You are generating the last part of a long podcast conversation. {COMMON_INSTRUCTIONS} For this part, discuss the below INPUT and then make concluding remarks in a podcast conversation format and END THE CONVERSATION GREETING THE AUDIENCE WITH PERSON1 ALSO SAYING A GOOD BYE MESSAGE, following these guidelines: """ else: - enhanced_params["instruction"] = f""" + enhanced_params[ + "instruction" + ] = f""" You are generating part {part_idx+1} of {total_parts} parts of a long podcast conversation. {COMMON_INSTRUCTIONS} For this part, discuss the below INPUT in a podcast conversation format, following these guidelines: """ - + return enhanced_params - def generate_long_form( - self, - input_content: str, - prompt_params: Dict - ) -> str: + def generate_long_form(self, input_content: str, prompt_params: Dict) -> str: """ Generate a complete long-form conversation using chunked content. - + Args: input_content (str): Input text for conversation prompt_params (Dict): Base prompt parameters - + Returns: str: Generated long-form conversation """ # Add long-form instructions once at the beginning - prompt_params["user_instructions"] = prompt_params.get("user_instructions", "") + self.LONGFORM_INSTRUCTIONS - + prompt_params["user_instructions"] = ( + prompt_params.get("user_instructions", "") + self.LONGFORM_INSTRUCTIONS + ) + # Get chunk size chunk_size = self.__calculate_chunk_size(input_content) @@ -244,13 +290,13 @@ def generate_long_form( chat_context = input_content num_parts = len(chunks) print(f"Generating {num_parts} parts") - + for i, chunk in enumerate(chunks): enhanced_params = self.enhance_prompt_params( prompt_params, part_idx=i, total_parts=num_parts, - chat_context=chat_context + chat_context=chat_context, ) enhanced_params["input_text"] = chunk response = self.llm_chain.invoke(enhanced_params) @@ -259,20 +305,20 @@ def generate_long_form( else: chat_context = chat_context + response print(f"Generated part {i+1}/{num_parts}: Size {len(chunk)} characters.") - #print(f"[LLM-START] Step: {i+1} ##############################") - #print(response) - #print(f"[LLM-END] Step: {i+1} ##############################") + # print(f"[LLM-START] Step: {i+1} ##############################") + # print(response) + # print(f"[LLM-END] Step: {i+1} ##############################") conversation_parts.append(response) return self.stitch_conversations(conversation_parts) - + def stitch_conversations(self, parts: List[str]) -> str: """ Combine conversation parts with smooth transitions. - + Args: parts (List[str]): List of conversation parts - + Returns: str: Combined conversation """ @@ -284,12 +330,12 @@ def stitch_conversations(self, parts: List[str]) -> str: class ContentCleanerMixin: """ Mixin class containing common transcript cleaning operations. - + Provides reusable cleaning methods that can be used by different content generation strategies. Methods use protected naming convention (_method_name) as they are intended for internal use by the strategies. 
""" - + @staticmethod def _clean_scratchpad(text: str) -> str: """ @@ -297,12 +343,13 @@ def _clean_scratchpad(text: str) -> str: """ try: import re - pattern = r'```scratchpad\n.*?```\n?|```plaintext\n.*?```\n?|```\n?|\[.*?\]' - cleaned_text = re.sub(pattern, '', text, flags=re.DOTALL) + + pattern = r"```scratchpad\n.*?```\n?|```plaintext\n.*?```\n?|```\n?|\[.*?\]" + cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL) # Remove "xml" if followed by or cleaned_text = re.sub(r"xml(?=\s*)", "", cleaned_text) # Remove underscores around words - cleaned_text = re.sub(r'_(.*?)_', r'\1', cleaned_text) + cleaned_text = re.sub(r"_(.*?)_", r"\1", cleaned_text) return cleaned_text.strip() except Exception as e: logger.error(f"Error cleaning scratchpad content: {str(e)}") @@ -310,8 +357,7 @@ def _clean_scratchpad(text: str) -> str: @staticmethod def _clean_tss_markup( - input_text: str, - additional_tags: List[str] = ["Person1", "Person2"] + input_text: str, additional_tags: List[str] = ["Person1", "Person2"] ) -> str: """ Remove unsupported TSS markup tags while preserving supported ones. @@ -333,11 +379,9 @@ def _clean_tss_markup( cleaned_text, flags=re.DOTALL, ) - - return cleaned_text.strip() - + except Exception as e: logger.error(f"Error cleaning TSS markup: {str(e)}") return input_text @@ -346,38 +390,36 @@ def _clean_tss_markup( class ContentGenerationStrategy(ABC): """ Abstract base class defining the interface for content generation strategies. - + Defines the contract that all concrete strategies must implement, including validation, generation, and cleaning operations. """ - + @abstractmethod def validate(self, input_texts: str, image_file_paths: List[str]) -> None: """Validate inputs for this strategy.""" pass - + @abstractmethod - def generate(self, - chain, - input_texts: str, - prompt_params: Dict[str, Any], - **kwargs) -> str: + def generate( + self, chain, input_texts: str, prompt_params: Dict[str, Any], **kwargs + ) -> str: """Generate content using this strategy.""" pass - + @abstractmethod - def clean(self, - response: str, - config: Dict[str, Any]) -> str: + def clean(self, response: str, config: Dict[str, Any]) -> str: """Clean the generated response according to strategy.""" pass @abstractmethod - def compose_prompt_params(self, - config_conversation: Dict[str, Any], - image_file_paths: List[str] = [], - image_path_keys: List[str] = [], - input_texts: str = "") -> Dict[str, Any]: + def compose_prompt_params( + self, + config_conversation: Dict[str, Any], + image_file_paths: List[str] = [], + image_path_keys: List[str] = [], + input_texts: str = "", + ) -> Dict[str, Any]: """Compose prompt parameters according to strategy.""" pass @@ -385,15 +427,20 @@ def compose_prompt_params(self, class StandardContentStrategy(ContentGenerationStrategy, ContentCleanerMixin): """ Strategy for generating standard-length content. - + Implements basic content generation without chunking or special handling. Uses common cleaning operations from ContentCleanerMixin. """ - - def __init__(self, llm, content_generator_config: Dict[str, Any], config_conversation: Dict[str, Any]): + + def __init__( + self, + llm, + content_generator_config: Dict[str, Any], + config_conversation: Dict[str, Any], + ): """ Initialize StandardContentStrategy. 
- + Args: content_generator_config (Dict[str, Any]): Configuration for content generation config_conversation (Dict[str, Any]): Conversation configuration @@ -401,30 +448,28 @@ def __init__(self, llm, content_generator_config: Dict[str, Any], config_convers self.llm = llm self.content_generator_config = content_generator_config self.config_conversation = config_conversation - + def validate(self, input_texts: str, image_file_paths: List[str]) -> None: """No specific validation needed for standard content.""" pass - - def generate(self, - chain, - input_texts: str, - prompt_params: Dict[str, Any], - **kwargs) -> str: + + def generate( + self, chain, input_texts: str, prompt_params: Dict[str, Any], **kwargs + ) -> str: """Generate standard-length content.""" return chain.invoke(prompt_params) - - def clean(self, - response: str, - config: Dict[str, Any]) -> str: + + def clean(self, response: str, config: Dict[str, Any]) -> str: """Apply basic TSS markup cleaning.""" return self._clean_tss_markup(response) - def compose_prompt_params(self, - config_conversation: Dict[str, Any], - image_file_paths: List[str] = [], - image_path_keys: List[str] = [], - input_texts: str = "") -> Dict[str, Any]: + def compose_prompt_params( + self, + config_conversation: Dict[str, Any], + image_file_paths: List[str] = [], + image_path_keys: List[str] = [], + input_texts: str = "", + ) -> Dict[str, Any]: """Compose prompt parameters for standard content generation.""" prompt_params = { "input_text": input_texts, @@ -454,19 +499,24 @@ def compose_prompt_params(self, class LongFormContentStrategy(ContentGenerationStrategy, ContentCleanerMixin): """ Strategy for generating long-form content. - + Implements advanced content generation using chunking and context maintenance. Includes additional cleaning operations specific to long-form content. - + Note: - Only works with text input (no images) - Requires non-empty input text """ - - def __init__(self, llm, content_generator_config: Dict[str, Any], config_conversation: Dict[str, Any]): + + def __init__( + self, + llm, + content_generator_config: Dict[str, Any], + config_conversation: Dict[str, Any], + ): """ Initialize LongFormContentStrategy. 
- + Args: content_generator_config (Dict[str, Any]): Configuration for content generation config_conversation (Dict[str, Any]): Conversation configuration @@ -474,75 +524,71 @@ def __init__(self, llm, content_generator_config: Dict[str, Any], config_convers self.llm = llm self.content_generator_config = content_generator_config self.config_conversation = config_conversation - + def validate(self, input_texts: str, image_file_paths: List[str]) -> None: """Validate inputs for long-form generation.""" if not input_texts.strip(): raise ValueError("Long-form generation requires non-empty input text") if image_file_paths: raise ValueError("Long-form generation is not available with image inputs") - - def generate(self, - chain, - input_texts: str, - prompt_params: Dict[str, Any], - **kwargs) -> str: + + def generate( + self, chain, input_texts: str, prompt_params: Dict[str, Any], **kwargs + ) -> str: """Generate long-form content.""" generator = LongFormContentGenerator(chain, self.llm, self.config_conversation) - return generator.generate_long_form( - input_texts, - prompt_params - ) - - def clean(self, - response: str, - config: Dict[str, Any]) -> str: + return generator.generate_long_form(input_texts, prompt_params) + + def clean(self, response: str, config: Dict[str, Any]) -> str: """Apply enhanced cleaning for long-form content.""" # First apply standard cleaning using common method standard_clean = self._clean_tss_markup(response) # Then apply additional long-form specific cleaning return self._clean_transcript_response(standard_clean, config) - - def _clean_transcript_response(self, transcript: str, config: Dict[str, Any]) -> str: + + def _clean_transcript_response( + self, transcript: str, config: Dict[str, Any] + ) -> str: """ Clean transcript using a two-step process with LLM-based cleaning. - + First cleans the markup using a specialized prompt template, then rewrites for better flow and consistency using a second prompt template. - + Args: transcript (str): Raw transcript text that may contain scratchpad blocks config (Dict[str, Any]): Configuration dictionary containing LLM and prompt settings - + Returns: str: Cleaned and rewritten transcript with proper tags and improved flow - + Note: Falls back to original or partially cleaned transcript if any cleaning step fails """ logger.debug("Starting transcript cleaning process") final_transcript = self._fix_alternating_tags(transcript) - + logger.debug("Completed transcript cleaning process") - + return final_transcript - - def _clean_transcript_response_DEPRECATED(self, transcript: str, config: Dict[str, Any]) -> str: + def _clean_transcript_response_DEPRECATED( + self, transcript: str, config: Dict[str, Any] + ) -> str: """ Clean transcript using a two-step process with LLM-based cleaning. - + First cleans the markup using a specialized prompt template, then rewrites for better flow and consistency using a second prompt template. 
- + Args: transcript (str): Raw transcript text that may contain scratchpad blocks config (Dict[str, Any]): Configuration dictionary containing LLM and prompt settings - + Returns: str: Cleaned and rewritten transcript with proper tags and improved flow - + Note: Falls back to original or partially cleaned transcript if any cleaning step fails """ @@ -550,30 +596,34 @@ def _clean_transcript_response_DEPRECATED(self, transcript: str, config: Dict[st try: logger.debug("Initializing LLM model for cleaning") # Initialize model with config values for consistent cleaning - #llm = ChatGoogleGenerativeAI( + # llm = ChatGoogleGenerativeAI( # model=self.content_generator_config["meta_llm_model"], # temperature=0, # presence_penalty=0.75, # Encourage diverse content # frequency_penalty=0.75 # Avoid repetition - #) + # ) llm = self.llm logger.debug("LLM model initialized successfully") # Get prompt templates from hub logger.debug("Pulling prompt templates from hub") try: - clean_transcript_prompt = hub.pull(f"{self.content_generator_config['cleaner_prompt_template']}:{self.content_generator_config['cleaner_prompt_commit']}") - rewrite_prompt = hub.pull(f"{self.content_generator_config['rewriter_prompt_template']}:{self.content_generator_config['rewriter_prompt_commit']}") + clean_transcript_prompt = hub.pull( + f"{self.content_generator_config['cleaner_prompt_template']}:{self.content_generator_config['cleaner_prompt_commit']}" + ) + rewrite_prompt = hub.pull( + f"{self.content_generator_config['rewriter_prompt_template']}:{self.content_generator_config['rewriter_prompt_commit']}" + ) logger.debug("Successfully pulled prompt templates") except Exception as e: logger.error(f"Error pulling prompt templates: {str(e)}") return transcript - + logger.debug("Creating cleaning and rewriting chains") # Create chains clean_chain = clean_transcript_prompt | llm | StrOutputParser() rewrite_chain = rewrite_prompt | llm | StrOutputParser() - + # Run cleaning chain logger.debug("Executing cleaning chain") try: @@ -585,11 +635,13 @@ def _clean_transcript_response_DEPRECATED(self, transcript: str, config: Dict[st except Exception as e: logger.error(f"Error in cleaning chain: {str(e)}") return transcript - + # Run rewriting chain logger.debug("Executing rewriting chain") try: - rewritten_response = rewrite_chain.invoke({"transcript": cleaned_response}) + rewritten_response = rewrite_chain.invoke( + {"transcript": cleaned_response} + ) if not rewritten_response: logger.warning("Rewriting chain returned empty response") return cleaned_response # Fall back to cleaned version @@ -597,14 +649,14 @@ def _clean_transcript_response_DEPRECATED(self, transcript: str, config: Dict[st except Exception as e: logger.error(f"Error in rewriting chain: {str(e)}") return cleaned_response # Fall back to cleaned version - + # Fix alternating tags in the final response logger.debug("Fixing alternating tags") final_transcript = self._fix_alternating_tags(rewritten_response) logger.debug("Completed transcript cleaning process") - + return final_transcript - + except Exception as e: logger.error(f"Error in transcript cleaning process: {str(e)}") return transcript # Return original if cleaning fails @@ -612,16 +664,16 @@ def _clean_transcript_response_DEPRECATED(self, transcript: str, config: Dict[st def _fix_alternating_tags(self, transcript: str) -> str: """ Ensures transcript has properly alternating Person1 and Person2 tags. - + Merges consecutive same-person tags and ensures proper tag alternation throughout the transcript. 
- + Args: transcript (str): Input transcript text that may have consecutive same-person tags - + Returns: str: Transcript with properly alternating tags and merged content - + Example: Input: Hello @@ -630,31 +682,31 @@ def _fix_alternating_tags(self, transcript: str) -> str: Output: Hello World Hi - + Note: Returns original transcript if cleaning fails """ try: # Split into individual tag blocks while preserving tags - pattern = r'(.*?)' + pattern = r"(.*?)" blocks = re.split(pattern, transcript, flags=re.DOTALL) - + # Filter out empty/whitespace blocks blocks = [b.strip() for b in blocks if b.strip()] - + merged_blocks = [] current_content = [] current_person = None - + for block in blocks: # Extract person number and content - match = re.match(r'(.*?)', block, re.DOTALL) + match = re.match(r"(.*?)", block, re.DOTALL) if not match: continue - + person_num, content = match.groups() content = content.strip() - + if current_person == person_num: # Same person - append content current_content.append(content) @@ -662,27 +714,33 @@ def _fix_alternating_tags(self, transcript: str) -> str: # Different person - flush current content if any if current_content: merged_text = " ".join(current_content) - merged_blocks.append(f"{merged_text}") + merged_blocks.append( + f"{merged_text}" + ) # Start new person current_person = person_num current_content = [content] - + # Flush final content if current_content: merged_text = " ".join(current_content) - merged_blocks.append(f"{merged_text}") - + merged_blocks.append( + f"{merged_text}" + ) + return "\n".join(merged_blocks) - + except Exception as e: logger.error(f"Error fixing alternating tags: {str(e)}") return transcript # Return original if fixing fails - def compose_prompt_params(self, - config_conversation: Dict[str, Any], - image_file_paths: List[str] = [], - image_path_keys: List[str] = [], - input_texts: str = "") -> Dict[str, Any]: + def compose_prompt_params( + self, + config_conversation: Dict[str, Any], + image_file_paths: List[str] = [], + image_path_keys: List[str] = [], + input_texts: str = "", + ) -> Dict[str, Any]: """Compose prompt parameters for long-form content generation.""" return { "conversation_style": ", ".join( @@ -704,11 +762,11 @@ def compose_prompt_params(self, class ContentGenerator: def __init__( - self, - is_local: bool=False, - model_name: str="gemini-1.5-pro-latest", - api_key_label: str="GEMINI_API_KEY", - conversation_config: Optional[Dict[str, Any]] = None + self, + backend_type: LLMBackendType = LLMBackendType.GOOGLE, + model_name: str|None = "gemini-1.5-pro-latest", + api_key_label: str|None = "GEMINI_API_KEY", + conversation_config: Optional[Dict[str, Any]] = None, ): """ Initialize the ContentGenerator. @@ -717,7 +775,6 @@ def __init__( api_key (str): API key for Google's Generative AI. conversation_config (Optional[Dict[str, Any]]): Custom conversation configuration. 
""" - #os.environ["GOOGLE_API_KEY"] = api_key self.config = load_config() self.content_generator_config = self.config.get("content_generator", {}) @@ -732,53 +789,45 @@ def __init__( if transcripts_dir and not os.path.exists(transcripts_dir): os.makedirs(transcripts_dir) - - self.is_local = is_local - # Initialize LLM backend + # Initialize LLM backend if not model_name: model_name = self.content_generator_config.get("llm_model") - if is_local: - model_name = "User provided local model" + # if is_local: + # model_name = "User provided local model" llm_backend = LLMBackend( - is_local=is_local, - temperature=self.config_conversation.get("creativity", 1), - max_output_tokens=self.content_generator_config.get( - "max_output_tokens", 8192 - ), - model_name=model_name, - api_key_label=api_key_label, + LLMBackendConfig( + backend_type, + self.config_conversation.get("creativity", 1), + self.content_generator_config.get("max_output_tokens", 8192), + model_name or "gemini-1.5-pro-latest", + api_key_label or "GEMINI_API_KEY", + ) ) self.llm = llm_backend.llm - - # Initialize strategies with configs self.strategies = { True: LongFormContentStrategy( - self.llm, - self.content_generator_config, - self.config_conversation + self.llm, self.content_generator_config, self.config_conversation ), False: StandardContentStrategy( - self.llm, - self.content_generator_config, - self.config_conversation - ) + self.llm, self.content_generator_config, self.config_conversation + ), } - def __compose_prompt(self, num_images: int, longform: bool=False): + def __compose_prompt(self, num_images: int, longform: bool = False): """ Compose the prompt for the LLM based on the content list. """ content_generator_config = self.config.get("content_generator", {}) - + # Get base template and commit values base_template = content_generator_config.get("prompt_template") base_commit = content_generator_config.get("prompt_commit") - + # Modify template and commit for longform if configured if longform: template = content_generator_config.get("longform_prompt_template") @@ -839,7 +888,7 @@ def generate_qa_content( input_texts: str = "", image_file_paths: List[str] = [], output_filepath: Optional[str] = None, - longform: bool = False + longform: bool = False, ) -> str: """ Generate Q&A content based on input texts. @@ -848,7 +897,6 @@ def generate_qa_content( input_texts (str): Input texts to generate content from. image_file_paths (List[str]): List of image file paths. output_filepath (Optional[str]): Filepath to save the response content. - is_local (bool): Whether to use a local LLM or not. model_name (str): Model name to use for generation. api_key_label (str): Environment variable name for API key. longform (bool): Whether to generate long-form content. Defaults to False. 
@@ -863,38 +911,29 @@ def generate_qa_content( try: # Get appropriate strategy strategy = self.strategies[longform] - + # Validate inputs for chosen strategy strategy.validate(input_texts, image_file_paths) # Setup chain - num_images = 0 if self.is_local else len(image_file_paths) - self.prompt_template, image_path_keys = self.__compose_prompt(num_images, longform) + num_images = len(image_file_paths) + self.prompt_template, image_path_keys = self.__compose_prompt( + num_images, longform + ) self.parser = StrOutputParser() self.chain = self.prompt_template | self.llm | self.parser - # Prepare parameters using strategy prompt_params = strategy.compose_prompt_params( - self.config_conversation, - image_file_paths, - image_path_keys, - input_texts + self.config_conversation, image_file_paths, image_path_keys, input_texts ) # Generate content using selected strategy - self.response = strategy.generate( - self.chain, - input_texts, - prompt_params - ) + self.response = strategy.generate(self.chain, input_texts, prompt_params) # Clean response using the same strategy - self.response = strategy.clean( - self.response, - self.content_generator_config - ) - + self.response = strategy.clean(self.response, self.content_generator_config) + logger.info(f"Content generated successfully") # Save output if requested @@ -905,7 +944,7 @@ def generate_qa_content( print(f"Transcript saved to {output_filepath}") return self.response - + except Exception as e: logger.error(f"Error generating content: {str(e)}") raise diff --git a/podcastfy/text_to_speech.py b/podcastfy/text_to_speech.py index 5a2beed..b95275c 100644 --- a/podcastfy/text_to_speech.py +++ b/podcastfy/text_to_speech.py @@ -43,7 +43,9 @@ def __init__( # Get API key from config if not provided if not api_key: - api_key = getattr(self.config, f"{model.upper().replace('MULTI', '')}_API_KEY", None) + api_key = getattr( + self.config, f"{model.upper().replace('MULTI', '')}_API_KEY", None + ) # Initialize provider using factory self.provider = TTSProviderFactory.create( @@ -114,29 +116,33 @@ def convert_to_speech(self, text: str, output_file: str) -> None: if not audio_data_list: raise ValueError("No audio data chunks provided") - logger.info(f"Starting audio processing with {len(audio_data_list)} chunks") + logger.info( + f"Starting audio processing with {len(audio_data_list)} chunks" + ) combined = AudioSegment.empty() - + for i, chunk in enumerate(audio_data_list): # Save chunk to temporary file - #temp_file = "./tmp.mp3" - #with open(temp_file, "wb") as f: + # temp_file = "./tmp.mp3" + # with open(temp_file, "wb") as f: # f.write(chunk) - + segment = AudioSegment.from_file(io.BytesIO(chunk)) - logger.info(f"################### Loaded chunk {i}, duration: {len(segment)}ms") - + logger.info( + f"################### Loaded chunk {i}, duration: {len(segment)}ms" + ) + combined += segment - + # Export with high quality settings os.makedirs(os.path.dirname(output_file), exist_ok=True) combined.export( - output_file, + output_file, format=self.audio_format, codec="libmp3lame", - bitrate="320k" + bitrate="320k", ) - + except Exception as e: logger.error(f"Error during audio processing: {str(e)}") raise @@ -223,7 +229,11 @@ def get_sort_key(file_path: str) -> Tuple[int, int]: def _setup_directories(self) -> None: """Setup required directories for audio processing.""" self.output_directories = self.tts_config.get("output_directories", {}) - temp_dir = self.tts_config.get("temp_audio_dir", "data/audio/tmp/").rstrip("/").split("/") + temp_dir = ( + 
self.tts_config.get("temp_audio_dir", "data/audio/tmp/") + .rstrip("/") + .split("/") + ) self.temp_audio_dir = os.path.join(*temp_dir) base_dir = os.path.abspath(os.path.dirname(__file__)) self.temp_audio_dir = os.path.join(base_dir, self.temp_audio_dir) diff --git a/requirements.txt b/requirements.txt index 1ecbc63..02ab73d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -75,6 +75,7 @@ langchain-google-genai==2.0.4 ; python_version >= "3.11" and python_version < "4 langchain-google-vertexai==2.0.7 ; python_version >= "3.11" and python_version < "4.0" langchain-text-splitters==0.3.2 ; python_version >= "3.11" and python_version < "4.0" langchain==0.3.7 ; python_version >= "3.11" and python_version < "4.0" +langchain-ollama==0.2.2 ; python_version >= "3.11" and python_version < "4.0" langsmith==0.1.141 ; python_version >= "3.11" and python_version < "4.0" levenshtein==0.26.1 ; python_version >= "3.11" and python_version < "4.0" litellm==1.52.2 ; python_version >= "3.11" and python_version < "4.0" diff --git a/tests/test_genai_podcast.py b/tests/test_genai_podcast.py index 691d78c..a134b7c 100644 --- a/tests/test_genai_podcast.py +++ b/tests/test_genai_podcast.py @@ -3,7 +3,7 @@ from unittest.mock import patch, MagicMock import tempfile import os -from podcastfy.content_generator import ContentGenerator +from podcastfy.content_generator import ContentGenerator, LLMBackendType from podcastfy.utils.config import Config from podcastfy.utils.config_conversation import ConversationConfig from podcastfy.content_parser.pdf_extractor import PDFExtractor @@ -15,6 +15,11 @@ "https://raw.githubusercontent.com/souzatharsis/podcastfy/refs/heads/main/data/images/connection.jpg", ] +# BACKEND_TYPE = LLMBackendType.OLLAMA +# MODEL_NAME = "llama3.2:3b-instruct-fp16" # "gemini-1.5-pro-latest" +# API_KEY_LABEL = "OLLAMA_API_KEY" + +BACKEND_TYPE = LLMBackendType.GOOGLE MODEL_NAME = "gemini-1.5-pro-latest" API_KEY_LABEL = "GEMINI_API_KEY" @@ -44,7 +49,12 @@ def test_generate_qa_content(self): """ Test the generate_qa_content method of ContentGenerator. """ - content_generator = ContentGenerator(model_name=MODEL_NAME, api_key_label=API_KEY_LABEL) + print( + f"Creating backend with model_name={MODEL_NAME} and api_key_label={API_KEY_LABEL}" + ) + content_generator = ContentGenerator( + BACKEND_TYPE, model_name=MODEL_NAME, api_key_label=API_KEY_LABEL + ) input_text = "United States of America" result = content_generator.generate_qa_content(input_text) self.assertIsNotNone(result) @@ -56,7 +66,12 @@ def test_custom_conversation_config(self): Test the generation of content using a custom conversation configuration file. 
""" conversation_config = sample_conversation_config() - content_generator = ContentGenerator(model_name=MODEL_NAME, api_key_label=API_KEY_LABEL, conversation_config=conversation_config) + content_generator = ContentGenerator( + BACKEND_TYPE, + model_name=MODEL_NAME, + api_key_label=API_KEY_LABEL, + conversation_config=conversation_config, + ) input_text = "United States of America" result = content_generator.generate_qa_content(input_text) @@ -73,7 +88,9 @@ def test_generate_qa_content_from_images(self): """Test generating Q&A content from two input images.""" image_paths = MOCK_IMAGE_PATHS - content_generator = ContentGenerator(model_name=MODEL_NAME, api_key_label=API_KEY_LABEL) + content_generator = ContentGenerator( + BACKEND_TYPE, model_name=MODEL_NAME, api_key_label=API_KEY_LABEL + ) with tempfile.NamedTemporaryFile( mode="w+", suffix=".txt", delete=False @@ -100,7 +117,9 @@ def test_generate_qa_content_from_images(self): def test_generate_qa_content_from_pdf(self): """Test generating Q&A content from a PDF file.""" pdf_file = "tests/data/pdf/file.pdf" - content_generator = ContentGenerator(model_name=MODEL_NAME, api_key_label=API_KEY_LABEL) + content_generator = ContentGenerator( + BACKEND_TYPE, model_name=MODEL_NAME, api_key_label=API_KEY_LABEL + ) pdf_extractor = PDFExtractor() # Extract content from the PDF file @@ -116,7 +135,9 @@ def test_generate_qa_content_from_pdf(self): def test_generate_qa_content_from_raw_text(self): """Test generating Q&A content from raw input text.""" raw_text = "The wonderful world of LLMs." - content_generator = ContentGenerator(model_name=MODEL_NAME, api_key_label=API_KEY_LABEL) + content_generator = ContentGenerator( + BACKEND_TYPE, model_name=MODEL_NAME, api_key_label=API_KEY_LABEL + ) result = content_generator.generate_qa_content(input_texts=raw_text) @@ -128,7 +149,9 @@ def test_generate_qa_content_from_raw_text(self): def test_generate_qa_content_from_topic(self): """Test generating Q&A content from a specific topic.""" topic = "Latest news about OpenAI" - content_generator = ContentGenerator(model_name=MODEL_NAME, api_key_label=API_KEY_LABEL) + content_generator = ContentGenerator( + BACKEND_TYPE, model_name=MODEL_NAME, api_key_label=API_KEY_LABEL + ) extractor = ContentExtractor() topic = "Latest news about OpenAI" diff --git a/tests/test_generate_podcast.py b/tests/test_generate_podcast.py index 0ad8c7f..9a867a6 100644 --- a/tests/test_generate_podcast.py +++ b/tests/test_generate_podcast.py @@ -223,6 +223,7 @@ def test_generate_from_local_pdf(sample_config): assert audio_file.endswith(".mp3") assert os.path.getsize(audio_file) > 1024 # Check if larger than 1KB + @pytest.mark.skip(reason="Testing edge only on Github Action as it's free") def test_generate_from_local_pdf_multispeaker(sample_config): """Test generating a podcast from a local PDF file.""" @@ -235,6 +236,7 @@ def test_generate_from_local_pdf_multispeaker(sample_config): assert audio_file.endswith(".mp3") assert os.path.getsize(audio_file) > 1024 # Check if larger than 1KB + @pytest.mark.skip(reason="Testing edge only on Github Action as it's free") def test_generate_from_local_pdf_multispeaker_longform(sample_config): """Test generating a podcast from a local PDF file.""" @@ -247,6 +249,7 @@ def test_generate_from_local_pdf_multispeaker_longform(sample_config): assert audio_file.endswith(".mp3") assert os.path.getsize(audio_file) > 1024 # Check if larger than 1KB + def test_generate_podcast_no_urls_or_transcript(): """Test that an error is raised when no URLs or transcript file 
is provided.""" with pytest.raises(ValueError): @@ -413,33 +416,34 @@ def test_generate_transcript_only_with_custom_llm( def test_generate_longform_transcript(sample_config, default_conversation_config): """Test generating a longform podcast transcript from a PDF file.""" pdf_file = "tests/data/pdf/file.pdf" - + # Generate transcript with longform=True result = generate_podcast( - urls=[pdf_file], - config=sample_config, - transcript_only=True, - longform=True + urls=[pdf_file], config=sample_config, transcript_only=True, longform=True ) assert result is not None assert os.path.exists(result) assert result.endswith(".txt") - + # Read and verify the content with open(result, "r") as f: content = f.read() - + # Verify the content follows the Person1/Person2 format assert "" in content assert "" in content - + # Verify it's a long-form transcript (>1000 characters) - assert len(content) > 1000, f"Content length ({len(content)}) is less than minimum expected for longform (1000)" - + assert ( + len(content) > 1000 + ), f"Content length ({len(content)}) is less than minimum expected for longform (1000)" + # Verify multiple discussion rounds exist (characteristic of longform) person1_segments = content.count("") - assert person1_segments > 3, f"Expected more than 3 discussion rounds, got {person1_segments}" + assert ( + person1_segments > 3 + ), f"Expected more than 3 discussion rounds, got {person1_segments}" if __name__ == "__main__":
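The new langchain-ollama requirement and the backend changes above assume a reachable local Ollama daemon. A quick standalone sanity check of that dependency, mirroring what LLMBackend constructs for the "ollama" backend type (model tag assumed and must already be pulled with `ollama pull`; temperature comes from the conversation "creativity" setting and num_predict from content_generator.max_output_tokens, default 8192):

    from langchain_ollama.llms import OllamaLLM

    # Same parameters LLMBackend passes on the Ollama branch.
    llm = OllamaLLM(
        model="llama3.2:3b-instruct-fp16",
        temperature=1.0,
        top_p=0.9,
        top_k=40,
        num_predict=8192,
    )
    print(llm.invoke("Reply with one short sentence."))

To exercise the content-generation tests against Ollama instead of Gemini, the commented-out BACKEND_TYPE/MODEL_NAME/API_KEY_LABEL block at the top of tests/test_genai_podcast.py can be swapped in for the Google defaults.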