Skip to content

Commit ef4fef5

Browse files
authored
Fix hf regression due to prompt truncation (patched-codes#994)
1 parent bbae0ec commit ef4fef5

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

patchwork/common/client/llm/aio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ def get_models(self) -> set[str]:
2929
def is_model_supported(self, model: str) -> bool:
3030
return any(client.is_model_supported(model) for client in self.__clients)
3131

32-
def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> bool:
32+
def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
3333
for client in self.__clients:
3434
if client.is_model_supported(model):
3535
return client.is_prompt_supported(messages, model)
36-
return False
36+
return -1
3737

3838
def truncate_messages(
3939
self, messages: Iterable[ChatCompletionMessageParam], model: str

patchwork/common/client/llm/openai_.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def __get_model_limits(self, model: str) -> int:
6262
return self.__MODEL_LIMITS.get(model, 128_000)
6363

6464
def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
65+
# might not implement model endpoint
66+
if self.__is_not_openai_url():
67+
return 1
68+
6569
model_limit = self.__get_model_limits(model)
6670
token_count = 0
6771
encoding = tiktoken.encoding_for_model(model)

patchwork/steps/JoinList/JoinList.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ def run(self):
2424
if isinstance(item, str):
2525
items.append(item)
2626
elif isinstance(item, dict):
27+
is_added = False
2728
for possible_key in self.possible_keys:
2829
if possible_key in item.keys():
2930
items.append(item.get(possible_key))
30-
else:
31-
items.append(json.dumps(item))
31+
is_added = True
32+
break
33+
if not is_added:
34+
items.append(json.dumps(item))
3235
else:
3336
items.append(str(item))
3437

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "patchwork-cli"
3-
version = "0.0.75"
3+
version = "0.0.76"
44
description = ""
55
authors = ["patched.codes"]
66
license = "AGPL"

0 commit comments

Comments (0)