Skip to content

Commit 93f2c9c

Browse files
authored
fix get_toolcall & fix ci (#3999)
1 parent 60718fa commit 93f2c9c

File tree

4 files changed

+12
-9
lines changed

4 files changed

+12
-9
lines changed

swift/llm/infer/infer_engine/infer_engine.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,10 @@ def infer(self,
176176

177177
@staticmethod
178178
def _get_toolcall(response: str, template: Template) -> Optional[List[ChatCompletionMessageToolCall]]:
179-
functions = template.agent_template.get_toolcall(response)
179+
try:
180+
functions = template.agent_template.get_toolcall(response)
181+
except Exception:
182+
functions = None
180183
if functions:
181184
return [ChatCompletionMessageToolCall(function=function) for function in functions]
182185

swift/plugin/agent_template/hermes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def get_toolcall(self, response: str) -> List['Function']:
1919
functions = []
2020
for res in res_list:
2121
res = self._parse_json(res)
22-
if res is not None:
22+
if isinstance(res, dict) and 'name' in res and 'arguments' in res:
2323
functions.append(Function(name=res['name'], arguments=res['arguments']))
2424
if len(functions) == 0:
2525
# compat react_en

swift/plugin/agent_template/llama.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_toolcall(self, response: str) -> List['Function']:
2525
res_list = re.findall(r'{[^{]*?"name":.*?"parameters":\s*?{.*?}\s*?}', response, re.DOTALL)
2626
for res in res_list:
2727
res = self._parse_json(res)
28-
if res is not None:
28+
if isinstance(res, dict) and 'name' in res and 'parameters' in res:
2929
functions.append(Function(name=res['name'], arguments=res['parameters']))
3030
if len(functions) == 0:
3131
# compat react_en

tests/llm/test_template.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import os
12
import unittest
23

3-
if __name__ == '__main__':
4-
import os
5-
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
6-
os.environ['SWIFT_DEBUG'] = '1'
7-
from swift.llm import PtEngine, RequestConfig, get_model_tokenizer, get_template
8-
from swift.utils import get_logger, seed_everything
4+
from swift.llm import PtEngine, RequestConfig, get_model_tokenizer, get_template
5+
from swift.utils import get_logger, seed_everything
6+
7+
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
8+
os.environ['SWIFT_DEBUG'] = '1'
99

1010
logger = get_logger()
1111

0 commit comments

Comments
 (0)