-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
45 lines (31 loc) · 1.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
from dotenv import load_dotenv
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_openai import ChatOpenAI
from pydantic.v1 import BaseModel, Field
class Country(BaseModel):
    # Structured shape that the model's JSON reply is parsed into.
    # Intentionally documented with comments rather than a class docstring:
    # pydantic v1 copies a model docstring into schema()["description"], and
    # the parser's format instructions (embedded in the prompt) are built from
    # that schema, so a docstring would alter the text sent to the model.
    # Field order and the description strings below are part of that schema too.

    # Both fields are required (`...`), same as a Field() with no default.
    capital: str = Field(..., description="The capital of the country")
    name: str = Field(..., description="The name of the country")
# Pull variables from a local .env file into os.environ (where one is found)
# before the settings below are read from the environment.
load_dotenv()

# Chat model requested for every call.
OPENAI_MODEL = "gpt-4"

# None when the variable is not set in the environment or the .env file.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Human-message template: `{country}` receives the user's input and
# `{format_instructions}` receives the output parser's JSON instructions.
PROMPT_COUNTRY_INFO = (
    "\n"
    "Provide information about the country of {country}. "
    "If the country doesn't exist, make up something.\n"
    "{format_instructions}\n"
)
def main():
    """Ask for a country and print its capital as reported by the chat model.

    Reads a country name from stdin and formats it into ``PROMPT_COUNTRY_INFO``
    together with the JSON format instructions derived from ``Country``. It then
    sends the resulting messages to the chat model, parses the reply into a
    ``Country``, and prints ``"The capital of <name> is <capital>"``.

    Blank (empty or whitespace-only) input is rejected before any request is made.

    Any exception from the model call or from ``parser.parse`` propagates to the
    caller. LangChain's parser raises ``OutputParserException`` when the reply is
    not valid ``Country`` JSON.
    """
    parser = PydanticOutputParser(pydantic_object=Country)
    llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name=OPENAI_MODEL)

    # Trim stray whitespace so it does not leak into the prompt, and skip the
    # API request entirely when nothing was entered. Otherwise the prompt's
    # "make up something" instruction would invent a country from an empty name.
    country = input("Enter a country: ").strip()
    if not country:
        print("No country entered.")
        return

    message = HumanMessagePromptTemplate.from_template(template=PROMPT_COUNTRY_INFO)
    chat_prompt = ChatPromptTemplate.from_messages(messages=[message])
    messages = chat_prompt.format_prompt(
        country=country, format_instructions=parser.get_format_instructions()
    ).to_messages()

    # Calling the model directly (`llm(messages)`, BaseChatModel.__call__) is
    # deprecated; `invoke` is the Runnable entry point and returns the AIMessage.
    response = llm.invoke(messages)
    data = parser.parse(response.content)
    print(f"The capital of {data.name} is {data.capital}")
# Run the interactive prompt only when executed as a script, not when imported.
if __name__ == "__main__":
    main()