Skip to content

Add Gemini 1.5 Pro and Flash Support #452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
except ImportError:
openai_new_api = False # old openai api version

from google.generativeai.types import GenerateContentResponse

@dataclass(frozen=True)
class ChatAgentResponse:
Expand Down Expand Up @@ -237,7 +238,25 @@ def step(

if num_tokens < self.model_token_limit:
response = self.model_backend.run(messages=openai_messages)
if openai_new_api:
if isinstance(response, GenerateContentResponse):
candidate = response.candidates[0]

output_messages = [
ChatMessage(role_name=self.role_name, role_type=self.role_type,
meta_dict=dict(), role=candidate.content.role, content=part.text)
for part in candidate.content.parts
]
info = self.get_info(
candidate.index,
{
"total_tokens": response.usage_metadata.total_token_count,
"prompt_tokens": response.usage_metadata.prompt_token_count,
"completion_tokens": response.usage_metadata.candidates_token_count,
},
candidate.finish_reason,
num_tokens,
)
elif openai_new_api:
if not isinstance(response, ChatCompletion):
raise RuntimeError("OpenAI returned unexpected struct")
output_messages = [
Expand Down
2 changes: 1 addition & 1 deletion camel/messages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def to_openai_message(self, role: Optional[str] = None) -> OpenAIMessage:
OpenAIMessage: The converted :obj:`OpenAIMessage` object.
"""
role = role or self.role
if role not in {"system", "user", "assistant"}:
if role not in {"system", "user", "assistant", "model"}:
raise ValueError(f"Unrecognized role: {role}")
return {"role": role, "content": self.content}

Expand Down
77 changes: 76 additions & 1 deletion camel/model_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from camel.typing import ModelType
from chatdev.statistics import prompt_cost
from chatdev.utils import log_visualize
import google.generativeai as genai

try:
from openai.types.chat import ChatCompletion
Expand All @@ -30,12 +31,21 @@

import os

OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
if 'OPENAI_API_KEY' in os.environ:
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
else:
OPENAI_API_KEY = None

if 'BASE_URL' in os.environ:
BASE_URL = os.environ['BASE_URL']
else:
BASE_URL = None

if 'GEMINI_API_KEY' in os.environ:
GEMINI_API_KEY = os.environ['GEMINI_API_KEY']
else:
GEMINI_API_KEY = None


class ModelBackend(ABC):
r"""Base class for different model backends.
Expand Down Expand Up @@ -148,6 +158,66 @@ def run(self, *args, **kwargs):
raise RuntimeError("Unexpected return from OpenAI API")
return response

class GeminiModel(ModelBackend):
    r"""Gemini API in a unified ModelBackend interface.

    Translates OpenAI-style chat messages into the Gemini ``contents``
    format and calls ``google.generativeai`` to generate a completion.
    """

    def __init__(self, model_type: ModelType, model_config_dict: Dict) -> None:
        r"""Store the model configuration and configure the Gemini SDK.

        Args:
            model_type (ModelType): The Gemini model variant to use
                (e.g. ``ModelType.GEMINI_1_5_PRO``).
            model_config_dict (Dict): Backend configuration options.

        Raises:
            KeyError: If the ``GEMINI_API_KEY`` environment variable is
                not set.
        """
        super().__init__()
        self.model_type = model_type
        self.model_config_dict = model_config_dict
        # Configure the SDK once at construction time; the previous
        # implementation redundantly re-configured on every run() call.
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])

    def run(self, *args, **kwargs):
        r"""Run inference on the Gemini backend.

        Args:
            **kwargs: Must contain ``messages``, a list of OpenAI-style
                dicts with ``role`` and ``content`` keys.

        Returns:
            The raw ``GenerateContentResponse`` from the Gemini SDK.
        """
        model = genai.GenerativeModel(self.model_type.value)

        # Gemini expects alternating "user"/"model" turns: fold "system"
        # into "user" and merge consecutive messages with the same role.
        contents = []
        for openai_message in kwargs["messages"]:
            if openai_message["role"] in ("user", "system"):
                role = "user"
            else:
                role = "model"
            # NOTE(review): stripping newlines was in the original
            # implementation — confirm it is intentional.
            text = openai_message["content"].replace("\n", "")

            if contents and contents[-1]["role"] == role:
                contents[-1]["parts"][0]["text"] += text
            else:
                contents.append({"role": role, "parts": [{"text": text}]})

        response = model.generate_content(*args, contents=contents)

        # Cost accounting uses the token counts reported by the API
        # response itself rather than a local pre-count.
        cost = prompt_cost(
            self.model_type.value,
            num_prompt_tokens=response.usage_metadata.prompt_token_count,
            num_completion_tokens=response.usage_metadata.candidates_token_count,
        )

        log_visualize(
            "**[Gemini_Usage_Info Receive]**\nprompt_tokens: {}\ncompletion_tokens: {}\ntotal_tokens: {}\ncost: ${:.6f}\n".format(
                response.usage_metadata.prompt_token_count,
                response.usage_metadata.candidates_token_count,
                response.usage_metadata.total_token_count, cost))

        return response

class StubModel(ModelBackend):
r"""A dummy model used for unit tests."""
Expand Down Expand Up @@ -191,6 +261,11 @@ def create(model_type: ModelType, model_config_dict: Dict) -> ModelBackend:
None
}:
model_class = OpenAIModel
elif model_type in {
ModelType.GEMINI_1_5_FLASH,
ModelType.GEMINI_1_5_PRO
}:
model_class = GeminiModel
elif model_type == ModelType.STUB:
model_class = StubModel
else:
Expand Down
2 changes: 2 additions & 0 deletions camel/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class ModelType(Enum):
GPT_4_TURBO_V = "gpt-4-turbo"
GPT_4O = "gpt-4o"
GPT_4O_MINI = "gpt-4o-mini"
GEMINI_1_5_PRO="gemini-1.5-pro"
GEMINI_1_5_FLASH = "gemini-1.5-flash"

STUB = "stub"

Expand Down
10 changes: 9 additions & 1 deletion camel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def num_tokens_from_messages(
ModelType.GPT_4_TURBO_V,
ModelType.GPT_4O,
ModelType.GPT_4O_MINI,
ModelType.STUB
ModelType.STUB,
ModelType.GEMINI_1_5_FLASH,
ModelType.GEMINI_1_5_PRO
}:
return count_tokens_openai_chat_models(messages, encoding)
else:
Expand Down Expand Up @@ -130,6 +132,10 @@ def get_model_token_limit(model: ModelType) -> int:
return 128000
elif model == ModelType.GPT_4O_MINI:
return 128000
elif model == ModelType.GEMINI_1_5_FLASH:
return 1000000
elif model == ModelType.GEMINI_1_5_PRO:
return 1000000
else:
raise ValueError("Unknown model type")

Expand Down Expand Up @@ -158,6 +164,8 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
elif 'OPENAI_API_KEY' in os.environ:
return func(self, *args, **kwargs)
elif 'GEMINI_API_KEY' in os.environ:
return func(self, *args, **kwargs);
else:
raise ValueError('OpenAI API key not found.')

Expand Down
16 changes: 14 additions & 2 deletions camel/web_spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,20 @@
import os
import time

self_api_key = os.environ.get('OPENAI_API_KEY')
BASE_URL = os.environ.get('BASE_URL')
if 'OPENAI_API_KEY' in os.environ:
self_api_key = os.environ['OPENAI_API_KEY']
else:
self_api_key = None

if 'BASE_URL' in os.environ:
BASE_URL = os.environ['BASE_URL']
else:
BASE_URL = None

if 'GEMINI_API_KEY' in os.environ:
self_api_key = os.environ['GEMINI_API_KEY']
else:
self_api_key = None

if BASE_URL:
client = openai.OpenAI(
Expand Down
11 changes: 10 additions & 1 deletion ecl/embedding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import os
import openai
from openai import OpenAI
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
if 'OPENAI_API_KEY' in os.environ:
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
else:
OPENAI_API_KEY = None

if 'BASE_URL' in os.environ:
BASE_URL = os.environ['BASE_URL']
else:
BASE_URL = None

if 'GEMINI_API_KEY' in os.environ:
GEMINI_API_KEY = os.environ['GEMINI_API_KEY']
else:
GEMINI_API_KEY = None
import sys
import time
from tenacity import (
Expand Down
11 changes: 10 additions & 1 deletion ecl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@
stop_after_attempt,
wait_exponential
)
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
if 'OPENAI_API_KEY' in os.environ:
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
else:
OPENAI_API_KEY = None

if 'BASE_URL' in os.environ:
BASE_URL = os.environ['BASE_URL']
else:
BASE_URL = None

if 'GEMINI_API_KEY' in os.environ:
GEMINI_API_KEY = os.environ['GEMINI_API_KEY']
else:
GEMINI_API_KEY = None

def getFilesFromType(sourceDir, filetype):
files = []
for root, directories, filenames in os.walk(sourceDir):
Expand Down
2 changes: 2 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def get_config(company):
# 'GPT_4_TURBO_V': ModelType.GPT_4_TURBO_V
'GPT_4O': ModelType.GPT_4O,
'GPT_4O_MINI': ModelType.GPT_4O_MINI,
'GEMINI_1_5_FLASH': ModelType.GEMINI_1_5_FLASH,
'GEMINI_1_5_PRO': ModelType.GEMINI_1_5_PRO
}
if openai_new_api:
args2type['GPT_3_5_TURBO'] = ModelType.GPT_3_5_TURBO_NEW
Expand Down