Skip to content

Commit 2a1ba84

Browse files
authored
Merge pull request #57 Resolve #41, #16
Resolve #41 Resolve #16
2 parents 72779d9 + 5c662d4 commit 2a1ba84

File tree

4 files changed

+6216
-14
lines changed

4 files changed

+6216
-14
lines changed

.env.example

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
###############################################
2+
############## LLM API SELECTION ##############
3+
###############################################
4+
5+
# LLM_PROVIDER=openai
6+
# OPEN_AI_LLM_KEY=
7+
# OPEN_AI_LLM_MODEL=gpt-4o
8+
9+
# LLM_PROVIDER=gemini
10+
# GEMINI_API_KEY=
11+
# GEMINI_LLM_MODEL=gemini-2.0-flash-lite
12+
13+
# LLM_PROVIDER=azure
14+
# AZURE_OPENAI_LLM_ENDPOINT=
15+
# AZURE_OPENAI_LLM_KEY=
16+
# AZURE_OPENAI_LLM_MODEL=
17+
# AZURE_OPENAI_LLM_API_VERSION=
18+
19+
# LLM_PROVIDER=ollama
20+
# OLLAMA_LLM_BASE_URL=
21+
# OLLAMA_LLM_MODEL=
22+
23+
# LLM_PROVIDER=huggingface
24+
# HUGGING_FACE_LLM_REPO_ID=
25+
# HUGGING_FACE_LLM_ENDPOINT=
26+
# HUGGING_FACE_LLM_API_TOKEN=
27+
28+
# LLM_PROVIDER=bedrock
29+
# AWS_BEDROCK_LLM_ACCESS_KEY_ID=
30+
# AWS_BEDROCK_LLM_SECRET_ACCESS_KEY=
31+
# AWS_BEDROCK_LLM_REGION=us-west-2
32+
# AWS_BEDROCK_LLM_ENDPOINT_URL=https://bedrock.us-west-2.amazonaws.com
33+
# AWS_BEDROCK_LLM_MODEL=anthropic.claude-3-5-sonnet-20241022-v2:0
34+
35+
###############################################
36+
########### Embedding API SELECTION ###########
37+
###############################################
38+
# Only used if you are using an LLM that does not natively support embedding (openai or Azure)
39+
# EMBEDDING_ENGINE='openai'
40+
# OPEN_AI_KEY=sk-xxxx
41+
# EMBEDDING_MODEL_PREF='text-embedding-ada-002'
42+
43+
# EMBEDDING_ENGINE='azure'
44+
# AZURE_OPENAI_ENDPOINT=
45+
# AZURE_OPENAI_KEY=
46+
# EMBEDDING_MODEL_PREF='my-embedder-model' # This is the "deployment" on Azure you want to use for embeddings. Not the base model. Valid base model is text-embedding-ada-002
47+
48+
# EMBEDDING_ENGINE='ollama'
49+
# EMBEDDING_BASE_PATH='http://host.docker.internal:11434'
50+
# EMBEDDING_MODEL_PREF='nomic-embed-text:latest'
51+
# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192
52+
53+
# EMBEDDING_ENGINE='bedrock'
54+
# AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID=
55+
# AWS_BEDROCK_EMBEDDING_ACCESS_KEY=
56+
# AWS_BEDROCK_EMBEDDING_REGION=us-west-2
57+
# AWS_BEDROCK_EMBEDDING_MODEL_PREF=amazon.embedding-embedding-ada-002:0
58+
59+
# EMBEDDING_ENGINE='gemini'
60+
# GEMINI_EMBEDDING_API_KEY=
61+
# EMBEDDING_MODEL_PREF='text-embedding-004'
62+
63+
# EMBEDDING_ENGINE='huggingface'
64+
# HUGGING_FACE_EMBEDDING_REPO_ID=
65+
# HUGGING_FACE_EMBEDDING_MODEL=
66+
# HUGGING_FACE_EMBEDDING_API_TOKEN=

llm_utils/llm_factory.py

Lines changed: 168 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,179 @@
11
# llm_factory.py
2+
import os
23
from typing import Optional
34

5+
from dotenv import load_dotenv
46
from langchain.llms.base import BaseLanguageModel
5-
from langchain_openai import ChatOpenAI
7+
from langchain_aws import ChatBedrockConverse, BedrockEmbeddings
8+
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
9+
from langchain_huggingface import (
10+
ChatHuggingFace,
11+
HuggingFaceEndpoint,
12+
HuggingFaceEndpointEmbeddings,
13+
)
14+
from langchain_ollama import ChatOllama, OllamaEmbeddings
15+
from langchain_openai import (
16+
AzureOpenAIEmbeddings,
17+
ChatOpenAI,
18+
AzureChatOpenAI,
19+
OpenAIEmbeddings,
20+
)
21+
from langchain_community.llms.bedrock import Bedrock
622

23+
# Load environment variables from the .env file
24+
load_dotenv()
725

8-
def get_llm(
9-
model_type: str,
10-
model_name: Optional[str] = None,
11-
openai_api_key: Optional[str] = None,
12-
**kwargs,
13-
) -> BaseLanguageModel:
26+
27+
def get_llm() -> BaseLanguageModel:
    """Build and return the chat-model interface selected by ``LLM_PROVIDER``.

    Returns:
        A provider-specific chat model implementing ``BaseLanguageModel``.

    Raises:
        ValueError: if ``LLM_PROVIDER`` is unset or names an unknown provider.
    """
    provider = os.getenv("LLM_PROVIDER")

    if provider is None:
        raise ValueError("LLM_PROVIDER environment variable is not set.")

    # Dispatch table instead of an if/elif chain; each value is a
    # zero-argument factory that reads its own configuration from the env.
    factories = {
        "openai": get_llm_openai,
        "azure": get_llm_azure,
        "bedrock": get_llm_bedrock,
        "gemini": get_llm_gemini,
        "ollama": get_llm_ollama,
        "huggingface": get_llm_huggingface,
    }

    factory = factories.get(provider)
    if factory is None:
        raise ValueError(f"Invalid LLM API Provider: {provider}")
    return factory()
56+
57+
58+
def get_llm_openai() -> BaseLanguageModel:
    """Build a ChatOpenAI client from environment variables.

    Prefers the ``OPEN_AI_LLM_MODEL`` / ``OPEN_AI_LLM_KEY`` names documented
    in ``.env.example`` (the original code read ``OPEN_MODEL_PREF`` /
    ``OPEN_AI_KEY``, which ``.env.example`` never mentions) and falls back to
    the legacy names so existing configurations keep working.

    Returns:
        A configured ``ChatOpenAI`` instance (default model ``gpt-4o``).
    """
    model = (
        os.getenv("OPEN_AI_LLM_MODEL")
        or os.getenv("OPEN_MODEL_PREF")  # legacy name, kept for back-compat
        or "gpt-4o"
    )
    api_key = os.getenv("OPEN_AI_LLM_KEY") or os.getenv("OPEN_AI_KEY")
    return ChatOpenAI(model=model, api_key=api_key)
63+
64+
65+
def get_llm_azure() -> BaseLanguageModel:
    """Build an AzureChatOpenAI client from the ``AZURE_OPENAI_LLM_*`` env vars."""
    key = os.getenv("AZURE_OPENAI_LLM_KEY")
    endpoint = os.getenv("AZURE_OPENAI_LLM_ENDPOINT")
    # This is the Azure *deployment* name, not the base model name.
    deployment = os.getenv("AZURE_OPENAI_LLM_MODEL")
    api_version = os.getenv("AZURE_OPENAI_LLM_API_VERSION", "2023-07-01-preview")
    return AzureChatOpenAI(
        api_key=key,
        azure_endpoint=endpoint,
        azure_deployment=deployment,
        api_version=api_version,
    )
72+
73+
74+
def get_llm_bedrock() -> BaseLanguageModel:
    """Build a ChatBedrockConverse client from the ``AWS_BEDROCK_LLM_*`` env vars."""
    model_id = os.getenv("AWS_BEDROCK_LLM_MODEL")
    access_key = os.getenv("AWS_BEDROCK_LLM_ACCESS_KEY_ID")
    secret_key = os.getenv("AWS_BEDROCK_LLM_SECRET_ACCESS_KEY")
    # NOTE(review): the fallback region here is us-east-1, while .env.example
    # suggests us-west-2 — confirm which default is intended.
    region = os.getenv("AWS_BEDROCK_LLM_REGION", "us-east-1")
    return ChatBedrockConverse(
        model=model_id,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        region_name=region,
    )
81+
82+
83+
def get_llm_gemini() -> BaseLanguageModel:
    """Build a ChatGoogleGenerativeAI client for the model named in ``GEMINI_LLM_MODEL``."""
    # presumably the API key (GEMINI_API_KEY in .env.example) is picked up from
    # the environment by the library itself — confirm against the
    # langchain_google_genai documentation.
    model_name = os.getenv("GEMINI_LLM_MODEL")
    return ChatGoogleGenerativeAI(model=model_name)
85+
86+
87+
def get_llm_ollama() -> BaseLanguageModel:
    """Build a ChatOllama client; ``OLLAMA_LLM_BASE_URL`` is optional."""
    kwargs = {"model": os.getenv("OLLAMA_LLM_MODEL")}
    base_url = os.getenv("OLLAMA_LLM_BASE_URL")
    if base_url:
        # Only pass base_url when set so the client's own default applies otherwise.
        kwargs["base_url"] = base_url
    return ChatOllama(**kwargs)
93+
94+
95+
def get_llm_huggingface() -> BaseLanguageModel:
    """Build a ChatHuggingFace client backed by a HuggingFaceEndpoint.

    Reads the ``HUGGING_FACE_LLM_*`` environment variables; the endpoint is
    configured for the "text-generation" task.
    """
    endpoint = HuggingFaceEndpoint(
        model=os.getenv("HUGGING_FACE_LLM_MODEL"),
        repo_id=os.getenv("HUGGING_FACE_LLM_REPO_ID"),
        task="text-generation",
        endpoint_url=os.getenv("HUGGING_FACE_LLM_ENDPOINT"),
        huggingfacehub_api_token=os.getenv("HUGGING_FACE_LLM_API_TOKEN"),
    )
    return ChatHuggingFace(llm=endpoint)
105+
106+
107+
def get_embeddings() -> Optional[BaseLanguageModel]:
    """Build and return the embedding interface selected by the environment.

    Reads ``EMBEDDING_PROVIDER``, falling back to ``EMBEDDING_ENGINE`` — the
    name actually documented in ``.env.example`` — so either spelling works.
    Also adds the previously missing "huggingface" branch:
    ``get_embeddings_huggingface`` was defined but unreachable from here.

    Returns:
        A provider-specific embeddings client.

    Raises:
        ValueError: if no provider is configured or the value is unknown.
    """
    provider = os.getenv("EMBEDDING_PROVIDER") or os.getenv("EMBEDDING_ENGINE")

    if provider is None:
        raise ValueError("EMBEDDING_PROVIDER environment variable is not set.")

    factories = {
        "openai": get_embeddings_openai,
        "bedrock": get_embeddings_bedrock,
        "azure": get_embeddings_azure,
        "gemini": get_embeddings_gemini,
        "ollama": get_embeddings_ollama,
        "huggingface": get_embeddings_huggingface,  # was missing from the dispatch
    }

    factory = factories.get(provider)
    if factory is None:
        raise ValueError(f"Invalid Embedding API Provider: {provider}")
    return factory()
133+
134+
135+
def get_embeddings_openai() -> BaseLanguageModel:
    """Build an OpenAIEmbeddings client from the ``OPEN_AI_EMBEDDING_*`` env vars."""
    # NOTE(review): .env.example documents OPEN_AI_KEY / EMBEDDING_MODEL_PREF
    # rather than these names — confirm which set is canonical.
    model_name = os.getenv("OPEN_AI_EMBEDDING_MODEL")
    api_key = os.getenv("OPEN_AI_EMBEDDING_KEY")
    return OpenAIEmbeddings(model=model_name, openai_api_key=api_key)
140+
141+
142+
def get_embeddings_azure() -> BaseLanguageModel:
    """Build an AzureOpenAIEmbeddings client from the ``AZURE_OPENAI_EMBEDDING_*`` env vars."""
    # NOTE(review): unlike get_llm_azure, no default api_version is supplied
    # here, and .env.example documents AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY
    # instead of these names — verify the intended configuration keys.
    key = os.getenv("AZURE_OPENAI_EMBEDDING_KEY")
    endpoint = os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
    deployment = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL")  # Azure deployment name
    api_version = os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION")
    return AzureOpenAIEmbeddings(
        api_key=key,
        azure_endpoint=endpoint,
        azure_deployment=deployment,
        api_version=api_version,
    )
149+
150+
151+
def get_embeddings_bedrock() -> BaseLanguageModel:
    """Build a BedrockEmbeddings client from the ``AWS_BEDROCK_EMBEDDING_*`` env vars."""
    # NOTE(review): this reads AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY, but
    # .env.example documents AWS_BEDROCK_EMBEDDING_ACCESS_KEY — confirm the
    # intended variable name.
    credentials = {
        "aws_access_key_id": os.getenv("AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.getenv("AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY"),
    }
    return BedrockEmbeddings(
        model_id=os.getenv("AWS_BEDROCK_EMBEDDING_MODEL"),
        region_name=os.getenv("AWS_BEDROCK_EMBEDDING_REGION", "us-east-1"),
        **credentials,
    )
158+
159+
160+
def get_embeddings_gemini() -> BaseLanguageModel:
    """Build a GoogleGenerativeAIEmbeddings client.

    Reads the model name from ``GEMINI_EMBEDDING_MODEL``. For the API key it
    accepts either ``GEMINI_EMBEDDING_KEY`` (the name the original code read)
    or ``GEMINI_EMBEDDING_API_KEY`` (the name documented in ``.env.example``),
    preferring the former for backward compatibility.
    """
    api_key = os.getenv("GEMINI_EMBEDDING_KEY") or os.getenv("GEMINI_EMBEDDING_API_KEY")
    return GoogleGenerativeAIEmbeddings(
        model=os.getenv("GEMINI_EMBEDDING_MODEL"),
        api_key=api_key,
    )
165+
166+
167+
def get_embeddings_ollama() -> BaseLanguageModel:
    """Build an OllamaEmbeddings client from the ``OLLAMA_EMBEDDING_*`` env vars."""
    model_name = os.getenv("OLLAMA_EMBEDDING_MODEL")
    server_url = os.getenv("OLLAMA_EMBEDDING_BASE_URL")
    return OllamaEmbeddings(model=model_name, base_url=server_url)
172+
173+
174+
def get_embeddings_huggingface() -> BaseLanguageModel:
    """Build a HuggingFaceEndpointEmbeddings client from the ``HUGGING_FACE_EMBEDDING_*`` env vars."""
    token = os.getenv("HUGGING_FACE_EMBEDDING_API_TOKEN")
    return HuggingFaceEndpointEmbeddings(
        model=os.getenv("HUGGING_FACE_EMBEDDING_MODEL"),
        repo_id=os.getenv("HUGGING_FACE_EMBEDDING_REPO_ID"),
        huggingfacehub_api_token=token,
    )

0 commit comments

Comments
 (0)