v0.1.2 Updates #22

Merged · 6 commits · Nov 16, 2024
4 changes: 4 additions & 0 deletions README.md
@@ -27,6 +27,10 @@ An intelligent system that automatically generates engaging podcast conversation

Listen to sample podcasts generated using Podcast-LLM:

### Structured JSON Output from LLMs (Google multispeaker voices)

[![Play Podcast Sample](https://img.shields.io/badge/Play%20Podcast-brightgreen?style=for-the-badge&logo=soundcloud)](https://soundcloud.com/evan-dempsey-153309617/llm-structured-output)

### UFO Crash Retrieval (Elevenlabs voices)

[![Play Podcast Sample](https://img.shields.io/badge/Play%20Podcast-brightgreen?style=for-the-badge&logo=soundcloud)](https://soundcloud.com/evan-dempsey-153309617/ufo-crash-retrieval-elevenlabs-voices)
77 changes: 60 additions & 17 deletions podcast_llm/utils/llm.py
@@ -25,6 +25,7 @@
import pydantic
from typing import Any, Optional, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
@@ -74,7 +75,7 @@
self.provider = provider
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens,
self.max_tokens = max_tokens
self.rate_limiter = rate_limiter
self.parser = StrOutputParser()
self.schema = None
@@ -89,7 +90,48 @@
raise ValueError(f"The LLM provider value '{self.provider}' is not supported.")

model_class = provider_to_model[self.provider]
self.llm = model_class(model=self.model, rate_limiter=self.rate_limiter)
self.llm = model_class(model=self.model, rate_limiter=self.rate_limiter, max_tokens=self.max_tokens)

def coerce_to_schema(self, llm_output: str):
"""
Coerce raw LLM output into a structured schema object.

Takes unstructured text output from the LLM and attempts to parse it into
a structured Pydantic object based on the defined schema. Currently supports
Question and Answer schema types.

Args:
llm_output (str): Raw text output from the LLM to be coerced

Returns:
BaseModel: Pydantic object matching the defined schema type

Raises:
ValueError: If no schema is defined
OutputParserException: If output cannot be coerced to the schema

The coercion maps the raw text to the appropriate schema field:
- Question schema -> 'question' field
- Answer schema -> 'answer' field
"""
if not self.schema:
raise ValueError('Schema is not defined.')

schema_class_name = self.schema.__name__

if schema_class_name == 'Question':
schema_field_name = 'question'
elif schema_class_name == 'Answer':
schema_field_name = 'answer'
else:
raise OutputParserException(
f"Unable to coerce output to schema: {schema_class_name}",
llm_output=llm_output
)

schema_values = {schema_field_name: llm_output}
pydantic_object = self.schema(**schema_values)
return pydantic_object

def invoke(
self,
@@ -117,22 +159,23 @@
- Google: Custom handling for structured output via parser and format instructions
"""
logger.debug(f"Invoking LLM with prompt:\n{input.to_string()}")
prompt = input

Codecov warning (podcast_llm/utils/llm.py#L162): added line not covered by tests.

if self.provider == 'google' and self.schema is not None:
format_instructions = self.parser.get_format_instructions()

Codecov warning (podcast_llm/utils/llm.py#L164-L165): added lines not covered by tests.

logger.debug(f"LLM provider is {self.provider} and schema is provided. Adding format instructions to prompt:\n{format_instructions}")
messages = input.to_messages()
messages[0] = SystemMessage(content=f"{messages[0].content}\n{format_instructions}")
prompt = ChatPromptValue(messages=messages)
logger.debug(f"Modified prompt:\n{prompt.to_string()}")

Codecov warning (podcast_llm/utils/llm.py#L167-L171): added lines not covered by tests.

try:
return self.llm.invoke(input=prompt, config=config)
except OutputParserException as ex:
logger.debug(f"Error parsing LLM output. Coercing to fit schema.\n{ex.llm_output}")
return self.coerce_to_schema(ex.llm_output)

Codecov warning (podcast_llm/utils/llm.py#L173-L177): added lines not covered by tests.

if self.provider in ('openai', 'anthropic',):
return self.llm.invoke(input=input, config=config)
elif self.provider == 'google':
if self.schema is not None:
format_instructions = self.parser.get_format_instructions()

logger.debug(f"LLM provider is {self.provider} and schema is provided. Adding format instructions to prompt:\n{format_instructions}")
messages = input.to_messages()
messages[0] = SystemMessage(content=f"{messages[0].content}\n{format_instructions}")
prompt = ChatPromptValue(messages=messages)
logger.debug(f"Modified prompt:\n{prompt.to_string()}")

return self.llm.invoke(input=prompt, config=config)
else:
return self.llm.invoke(input=input, config=config)

def with_structured_output(
self,
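The main behavioural change in `podcast_llm/utils/llm.py` is the fallback in `invoke()`: when the Google provider returns text that the Pydantic parser cannot handle, the `OutputParserException` is caught and the raw output is coerced into the expected schema. A minimal sketch of that coercion, using the stand-in `Question`/`Answer` models from the new tests rather than anything from the library's public API:

```python
# Sketch of the coercion fallback added in this PR, assuming the simple
# Question/Answer models used in tests/test_utils_llm.py.
import pydantic
from langchain_core.exceptions import OutputParserException


class Question(pydantic.BaseModel):
    question: str


class Answer(pydantic.BaseModel):
    answer: str


def coerce_to_schema(schema: type[pydantic.BaseModel], llm_output: str) -> pydantic.BaseModel:
    """Map raw LLM text onto the single field of a Question or Answer schema."""
    field_by_schema = {'Question': 'question', 'Answer': 'answer'}
    field_name = field_by_schema.get(schema.__name__)
    if field_name is None:
        raise OutputParserException(
            f"Unable to coerce output to schema: {schema.__name__}",
            llm_output=llm_output,
        )
    return schema(**{field_name: llm_output})


print(coerce_to_schema(Question, "What is structured output?"))
# question='What is structured output?'
```

In the wrapper itself the fallback keys off `self.schema`, so callers only hit this path when `with_structured_output()` was used with the Google provider.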
3 changes: 2 additions & 1 deletion podcast_llm/writer.py
@@ -172,7 +172,7 @@

return interviewee_chain.invoke({
'topic': topic,
'outline': outline,
'outline': outline.as_str,
'section': section.title,
'subsection': subsection.title,
'word_count': 100,
@@ -388,6 +388,7 @@

# Process script in batches of batch_size
for i in range(0, len(draft_script), batch_size):
logger.info(f"Rewriting lines {i+1} to {i+batch_size} of {len(draft_script)}")

Codecov warning (podcast_llm/writer.py#L391): added line not covered by tests.

batch = draft_script[i:i + batch_size]
final_script.extend(rewrite_script_section(batch, rewriter_chain))

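The `writer.py` change only adds a progress log inside the existing batch loop. A rough illustration of that loop follows; apart from the logging line, the names here are placeholders for `rewrite_script_section` and `rewriter_chain`:

```python
# Illustrative sketch of the batch-rewrite loop; batch_size and the rewriter
# callable are placeholders. Only the logger.info line is new in this PR.
import logging

logger = logging.getLogger(__name__)


def rewrite_in_batches(draft_script: list, rewriter, batch_size: int = 4) -> list:
    final_script = []
    for i in range(0, len(draft_script), batch_size):
        # New in this PR: report progress before rewriting each batch.
        logger.info(f"Rewriting lines {i+1} to {i+batch_size} of {len(draft_script)}")
        batch = draft_script[i:i + batch_size]
        final_script.extend(rewriter(batch))
    return final_script
```

Note that the logged upper bound can overshoot `len(draft_script)` on the final batch; `min(i + batch_size, len(draft_script))` would report the exact range.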
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "podcast-llm"
version = "0.1.1"
version = "0.1.2"
description = "An intelligent system that automatically generates engaging podcast conversations using LLMs and text-to-speech technology."
authors = ["Evan Dempsey <[email protected]>"]
license = "Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)"
132 changes: 132 additions & 0 deletions tests/test_utils_llm.py
@@ -0,0 +1,132 @@
import pydantic
import pytest
from langchain_core.exceptions import OutputParserException
from podcast_llm.utils.llm import LLMWrapper, get_fast_llm, get_long_context_llm
from podcast_llm.config import PodcastConfig


def test_llm_wrapper_initialization_with_supported_providers():
"""Test that LLMWrapper initializes correctly with supported providers."""
supported_providers = ['openai', 'google', 'anthropic']
model_name = 'test-model-name'
for provider_name in supported_providers:
llm_wrapper_instance = LLMWrapper(
provider=provider_name,
model=model_name
)
assert llm_wrapper_instance.provider == provider_name
assert llm_wrapper_instance.model == model_name
assert llm_wrapper_instance.temperature == 1.0
assert llm_wrapper_instance.max_tokens == 8192
assert llm_wrapper_instance.rate_limiter is None
assert llm_wrapper_instance.llm is not None

def test_llm_wrapper_initialization_with_unsupported_provider():
"""Test that LLMWrapper raises ValueError when initialized with an unsupported provider."""
with pytest.raises(ValueError) as exception_info:
LLMWrapper(provider='unsupported_provider', model='test-model-name')
assert "The LLM provider value 'unsupported_provider' is not supported." in str(exception_info.value)

def test_llm_wrapper_with_structured_output_method():
"""Test that LLMWrapper configures structured output correctly."""
class MockSchema(pydantic.BaseModel):
example_field: str

llm_wrapper_instance = LLMWrapper(
provider='openai',
model='test-model-name'
)
llm_wrapper_instance = llm_wrapper_instance.with_structured_output(MockSchema)
assert llm_wrapper_instance.llm is not None
assert llm_wrapper_instance.schema is None # For OpenAI provider, schema should not be set
# Assuming with_structured_output returns self

def test_llm_wrapper_coerce_to_schema():
"""Test that LLMWrapper.coerce_to_schema correctly converts output to schema objects."""
class Question(pydantic.BaseModel):
question: str

class Answer(pydantic.BaseModel):
answer: str

class OtherSchema(pydantic.BaseModel):
other: str

llm_wrapper = LLMWrapper(provider='openai', model='test-model')

# Test with Question schema
llm_wrapper.schema = Question
question_output = llm_wrapper.coerce_to_schema("What is the meaning of life?")
assert isinstance(question_output, Question)
assert question_output.question == "What is the meaning of life?"

# Test with Answer schema
llm_wrapper.schema = Answer
answer_output = llm_wrapper.coerce_to_schema("42")
assert isinstance(answer_output, Answer)
assert answer_output.answer == "42"

# Test with no schema defined
llm_wrapper.schema = None
with pytest.raises(ValueError) as exc_info:
llm_wrapper.coerce_to_schema("test output")
assert "Schema is not defined" in str(exc_info.value)

# Test with unsupported schema type
llm_wrapper.schema = OtherSchema
with pytest.raises(OutputParserException) as exc_info:
llm_wrapper.coerce_to_schema("test output")
assert "Unable to coerce output to schema: OtherSchema" in str(exc_info.value)


def test_get_fast_llm_with_supported_provider():
"""Test that get_fast_llm returns an LLMWrapper with the correct fast model."""
config_instance = PodcastConfig.load()
config_instance.fast_llm_provider='openai'
rate_limiter_instance = None
fast_llm_instance = get_fast_llm(
config=config_instance,
rate_limiter=rate_limiter_instance
)
assert isinstance(fast_llm_instance, LLMWrapper)
assert fast_llm_instance.provider == 'openai'
assert fast_llm_instance.model == 'gpt-4o-mini'

def test_get_fast_llm_with_unsupported_provider():
"""Test that get_fast_llm raises ValueError when given an unsupported provider."""
config_instance = PodcastConfig.load()
config_instance.fast_llm_provider='unsupported_provider'
rate_limiter_instance = None
with pytest.raises(ValueError) as exception_info:
get_fast_llm(
config=config_instance,
rate_limiter=rate_limiter_instance
)
assert "The fast_llm_provider value 'unsupported_provider' is not supported." in str(exception_info.value)

def test_get_long_context_llm_with_supported_provider():
"""Test that get_long_context_llm returns an LLMWrapper with the correct long context model."""
config_instance = PodcastConfig.load()
config_instance.long_context_llm_provider='anthropic'
rate_limiter_instance = None
long_context_llm_instance = get_long_context_llm(
config=config_instance,
rate_limiter=rate_limiter_instance
)
assert isinstance(long_context_llm_instance, LLMWrapper)
assert long_context_llm_instance.provider == 'anthropic'
assert long_context_llm_instance.model == 'claude-3-5-sonnet-20241022'

def test_get_long_context_llm_with_unsupported_provider():
"""Test that get_long_context_llm raises ValueError when given an unsupported provider."""
config_instance = PodcastConfig.load()
config_instance.long_context_llm_provider='unsupported_provider'
rate_limiter_instance = None
with pytest.raises(ValueError) as exception_info:
get_long_context_llm(
config=config_instance,
rate_limiter=rate_limiter_instance
)
assert "The long_context_llm_provider value 'unsupported_provider' is not supported." in str(exception_info.value)

