Enhance Functionality: Allow Full LLM Max Tokens #1631

Open · wants to merge 5 commits into main
24 changes: 24 additions & 0 deletions patchwork/common/client/llm/aio.py
@@ -99,6 +99,23 @@ def system(self) -> str:

def is_model_supported(self, model: str) -> bool:
return any(client.is_model_supported(model) for client in self.__clients)

def get_model_limit(self, model: str) -> int:
"""
Get the model's context length limit from the appropriate client.

Args:
model: The model name

Returns:
The maximum context length in tokens, or a default of 128,000 if no supporting client is found
"""
for client in self.__clients:
if client.is_model_supported(model) and hasattr(client, 'get_model_limit'):
return client.get_model_limit(model)

# Default value if no client found or client doesn't have the method
return 128_000

def is_prompt_supported(
self,
@@ -119,6 +136,13 @@ def is_prompt_supported(
top_p: Optional[float] | NotGiven = NOT_GIVEN,
file: Path | NotGiven = NOT_GIVEN,
) -> int:
"""
Check whether the prompt fits the model's context window and return the available tokens.

Returns:
int: If > 0, the number of tokens remaining after the prompt.
If <= 0, the prompt is too large for the model.
"""
for client in self.__clients:
if client.is_model_supported(model):
inputs = dict(
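Taken together, get_model_limit and is_prompt_supported let a caller size the completion budget from whatever room is left in the context window. A minimal usage sketch, assuming aio_client is an already-constructed AioLlmClient (the model name and message content are illustrative, not part of the diff):

model = "gpt-4o-mini"
messages = [{"role": "user", "content": "Summarize this repository."}]

context_limit = aio_client.get_model_limit(model)  # e.g. 128_000 by default
available_tokens = aio_client.is_prompt_supported(model=model, messages=messages)

if available_tokens <= 0:
    # The prompt alone exceeds the context window; shrink it before generating.
    messages = aio_client.truncate_messages(messages, model)
    available_tokens = aio_client.is_prompt_supported(model=model, messages=messages)

# Whatever room remains becomes the generation budget.
max_tokens = max(available_tokens, 1)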
52 changes: 14 additions & 38 deletions patchwork/common/client/llm/openai_.py
@@ -110,46 +110,22 @@ def is_model_supported(self, model: str) -> bool:
return model in _cached_list_models_from_openai(self.__api_key)

def __get_model_limits(self, model: str) -> int:
"""Return the token limit for a given model."""
return self.__MODEL_LIMITS.get(model, 128_000)

def get_model_limit(self, model: str) -> int:
"""
Public method to get the model's context length limit.

Args:
model: The model name

Returns:
The maximum context length in tokens
"""
return self.__get_model_limits(model)


def is_prompt_supported(
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
file: Path | NotGiven = NOT_GIVEN,
) -> int:
# might not implement model endpoint
if self.__is_not_openai_url():
return 1

model_limit = self.__get_model_limits(model)
token_count = 0
encoding = None
try:
encoding = tiktoken.encoding_for_model(model)
except Exception as e:
logger.error(f"Error getting encoding for model {model}: {e}, using gpt-4o as fallback")
encoding = tiktoken.encoding_for_model("gpt-4o")
for message in messages:
message_token_count = len(encoding.encode(message.get("content")))
token_count = token_count + message_token_count
if token_count > model_limit:
return -1

return model_limit - token_count

def truncate_messages(
self, messages: Iterable[ChatCompletionMessageParam], model: str
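The prompt check above counts tokens per message with tiktoken and falls back to the gpt-4o encoding when the model name is not recognised. A standalone sketch of that counting pattern, with the helper name and sample prompt introduced only for illustration:

import tiktoken

def count_prompt_tokens(messages: list[dict], model: str) -> int:
    # Mirror the diff: try the model's own encoding, fall back to gpt-4o on failure
    # (tiktoken raises KeyError for unknown model names; the diff catches broadly).
    try:
        encoding = tiktoken.encoding_for_model(model)
    except Exception:
        encoding = tiktoken.encoding_for_model("gpt-4o")
    return sum(len(encoding.encode(m.get("content") or "")) for m in messages)

print(count_prompt_tokens([{"role": "user", "content": "Hello"}], "gpt-4o-mini"))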
2 changes: 1 addition & 1 deletion patchwork/patchflows/GenerateREADME/defaults.yml
@@ -11,7 +11,7 @@
# model: codellama/CodeLlama-70b-Instruct-hf
# model_temperature: 0.2
# model_top_p: 0.95
# model_max_tokens: 2000
# model_max_tokens: 2000 # Use -1 to automatically use the maximum tokens available for the model

# CommitChanges Inputs
disable_branch: false
61 changes: 23 additions & 38 deletions patchwork/steps/CallLLM/CallLLM.py
@@ -26,43 +26,7 @@ class _InnerCallLLMResponse:


class CallLLM(Step, input_class=CallLLMInputs, output_class=CallLLMOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
# Set 'openai_key' from inputs or environment if not already set
inputs.setdefault("openai_api_key", os.environ.get("OPENAI_API_KEY"))

prompt_file = inputs.get("prompt_file")
if prompt_file is not None:
prompt_file_path = Path(prompt_file)
if not prompt_file_path.is_file():
raise ValueError(f'Unable to find Prompt file: "{prompt_file}"')
try:
with open(prompt_file_path, "r") as fp:
self.prompts = json.load(fp)
except json.JSONDecodeError as e:
raise ValueError(f'Invalid Json Prompt file "{prompt_file}": {e}')
elif "prompts" in inputs.keys():
self.prompts = inputs["prompts"]
else:
raise ValueError('Missing required data: "prompt_file" or "prompts"')

self.call_limit = int(inputs.get("max_llm_calls", -1))
self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
self.save_responses_to_file = inputs.get("save_responses_to_file", None)
self.model = inputs.get("model", "gpt-4o-mini")
self.allow_truncated = inputs.get("allow_truncated", False)
self.file = inputs.get("file", None)
self.client = AioLlmClient.create_aio_client(inputs)
if self.client is None:
raise ValueError(
f"Model API key not found.\n"
f'Please login at: "{TOKEN_URL}",\n'
"Please go to the Integration's tab and generate an API key.\n"
"Please copy the access token that is generated, "
"and add `--patched_api_key=<token>` to the command line.\n"
"\n"
"If you are using an OpenAI API Key, please set `--openai_api_key=<token>`.\n"
)


def __persist_to_file(self, contents):
# Convert relative path to absolute path
@@ -121,10 +85,22 @@ def __call(self, prompts: list[list[dict]]) -> list[_InnerCallLLMResponse]:
kwargs["file"] = Path(self.file)

for prompt in prompts:
is_input_accepted = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs) > 0
available_tokens = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs)
is_input_accepted = available_tokens > 0

if not is_input_accepted:
self.set_status(StepStatus.WARNING, "Input token limit exceeded.")
prompt = self.client.truncate_messages(prompt, self.model)

# Handle the case where model_max_tokens was set to -1
# Calculate max_tokens based on available tokens from the model after prompt
if hasattr(self, '_use_max_tokens') and self._use_max_tokens:
if available_tokens > 0:
kwargs['max_tokens'] = available_tokens
logger.info(f"Setting max_tokens to {available_tokens} based on available model context")
else:
# If we can't determine available tokens, set a reasonable default
logger.warning("Could not determine available tokens. Using model default.")

logger.trace(f"Message sent: \n{escape(indent(pformat(prompt), ' '))}")
try:
@@ -184,4 +160,13 @@ def __parse_model_args(self) -> dict:
else:
new_model_args[key] = arg

# Handle special case for max_tokens = -1 (use maximum available tokens)
if 'max_tokens' in new_model_args and new_model_args['max_tokens'] == -1:
# Will be handled during the chat completion call
logger.info("Using maximum available tokens for the model")
del new_model_args['max_tokens'] # Remove it for now, we'll calculate it later
self._use_max_tokens = True
else:
self._use_max_tokens = False

return new_model_args
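In isolation, the -1 sentinel handling amounts to: strip max_tokens during argument parsing, remember the intent, and fill it in later from the prompt check. A simplified sketch of that flow (the function and variable names are illustrative, not the step's actual API):

def resolve_max_tokens(model_args: dict, available_tokens: int) -> dict:
    # Expand a max_tokens of -1 into the remaining context budget, if known.
    args = dict(model_args)
    if args.get("max_tokens") == -1:
        if available_tokens > 0:
            args["max_tokens"] = available_tokens
        else:
            # Budget unknown: omit max_tokens and let the provider default apply.
            args.pop("max_tokens")
    return args

# Example: a prompt that leaves roughly 7,500 tokens of headroom.
print(resolve_max_tokens({"max_tokens": -1, "temperature": 0.2}, 7_500))
# {'max_tokens': 7500, 'temperature': 0.2}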
9 changes: 8 additions & 1 deletion patchwork/steps/CallLLM/typed.py
@@ -1,4 +1,4 @@
from typing_extensions import Annotated, Dict, List, TypedDict
from typing_extensions import Annotated, Dict, List, TypedDict, Literal

from patchwork.common.constants import TOKEN_URL
from patchwork.common.utils.step_typing import StepTypeConfig
@@ -12,6 +12,13 @@ class CallLLMInputs(TypedDict, total=False):
allow_truncated: Annotated[bool, StepTypeConfig(is_config=True)]
model_args: Annotated[str, StepTypeConfig(is_config=True)]
client_args: Annotated[str, StepTypeConfig(is_config=True)]
model_max_tokens: Annotated[
int | Literal[-1],
StepTypeConfig(
is_config=True,
description="Maximum number of tokens to generate. Use -1 to use maximum available tokens for the model."
)
]
openai_api_key: Annotated[
str,
StepTypeConfig(
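For reference, a hedged sketch of how the new field could be supplied to the step; the surrounding keys are examples drawn from the inputs read in CallLLM.__init__, and the API key value is a placeholder:

inputs = {
    "prompts": [[{"role": "user", "content": "Explain this diff."}]],
    "model": "gpt-4o-mini",
    "model_max_tokens": -1,  # -1 requests the full remaining context window
    "openai_api_key": "sk-...",  # placeholder
}
step = CallLLM(inputs)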