Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion swift/llm/template/template/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
agent_template: str = 'glm4_0414'


# Template metadata for GLM-4.5: identical to GLM4_0414TemplateMeta except that
# the agent_template default is 'glm4_5' instead of the parent's 'glm4_0414'.
@dataclass
class GLM4_5TemplateMeta(GLM4_0414TemplateMeta):
    agent_template: str = 'glm4_5'


# Template metadata for GLM-4.1V: declares a system_prefix default of
# '[gMASK]<sop><|system|>{{SYSTEM}}' on top of GLM4_0414TemplateMeta.
# NOTE(review): no @dataclass decorator is visible on this class. If the parent
# is a dataclass, an undecorated subclass keeps the parent's generated
# __init__, so this field(...) default would not be applied to instances --
# confirm whether @dataclass was intended here.
class GLM4_1VTemplateMeta(GLM4_0414TemplateMeta):
    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>{{SYSTEM}}'])

Expand Down Expand Up @@ -234,8 +239,18 @@ class GLM4_5Template(ThinkingTemplate):
no_think_prefix = '<think></think>\n'
history_think_prefix = '<think></think>\n'

def _swift_encode(self, inputs: StdTemplateInputs):
    """Encode ``inputs`` with the parent template, then trim a dangling ``<|user|>``.

    The parent call yields a context list, a loss-scale list and an answer
    length. When the final context is exactly ``'<|user|>'`` and the one
    before it is a string ending in ``'<|observation|>'`` (the tool-call
    case), the final context is dropped and the answer length is reduced by
    one, so the sequence does not end in ``<|observation|><|user|>``.

    Returns:
        The (possibly shortened) context list, the loss-scale list exactly as
        the parent returned it, and the (possibly decremented) answer length.
    """
    contexts, loss_scales, answer_len = super()._swift_encode(inputs)
    if len(contexts) < 2:
        # Fewer than two contexts: there is no penultimate entry to inspect.
        return contexts, loss_scales, answer_len

    before_last, last = contexts[-2], contexts[-1]
    # Context entries are not necessarily strings, hence the isinstance guard.
    closes_tool_call = isinstance(before_last, str) and before_last.endswith('<|observation|>')
    if closes_tool_call and last == '<|user|>':
        # NOTE(review): loss_scales is returned untouched even though contexts
        # loses its last entry -- confirm callers do not require the two lists
        # to have equal length.
        contexts = contexts[:-1]
        answer_len -= 1
    return contexts, loss_scales, answer_len


register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))
register_template(GLM4_5TemplateMeta(LLMTemplateType.glm4_5, template_cls=GLM4_5Template))

register_template(GLM4_1VTemplateMeta(MLLMTemplateType.glm4_1v, template_cls=GLM4_1VTemplate))

Expand Down
3 changes: 2 additions & 1 deletion swift/plugin/agent_template/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .base import BaseAgentTemplate
from .extra import ReactGRPOAgentTemplate
from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate
from .glm4 import GLM4_5AgentTemplate, GLM4_0414AgentTemplate, GLM4AgentTemplate
from .hermes import HermesAgentTemplate, HunyuanHermesAgentTemplate
from .llama import Llama3AgentTemplate, Llama4AgentTemplate
from .mistral import MistralAgentTemplate
Expand All @@ -23,6 +23,7 @@
'toolbench': ToolBenchAgentTemplate, # ref: https://modelscope.cn/datasets/swift/ToolBench
'glm4': GLM4AgentTemplate,
'glm4_0414': GLM4_0414AgentTemplate, # ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414
'glm4_5': GLM4_5AgentTemplate,
'llama3': Llama3AgentTemplate,
'llama4': Llama4AgentTemplate,
# extra
Expand Down
70 changes: 69 additions & 1 deletion swift/plugin/agent_template/glm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def _find_function_call(single_content: str) -> Optional['Function']:
matches = pattern.findall(single_content)
if not matches:
return

name, arguments = matches[0]
return Function(name=name, arguments=arguments)

Expand Down Expand Up @@ -77,3 +76,72 @@ def _format_tool_calls(self, tool_call_messages) -> str:

# GLM4 agent template variant registered under the 'glm4_0414' key; the only
# difference from GLM4AgentTemplate declared here is the is_glm4_0414 flag.
class GLM4_0414AgentTemplate(GLM4AgentTemplate):
    is_glm4_0414 = True


class GLM4_5AgentTemplate(BaseAgentTemplate):

@staticmethod
def _find_function_call(single_content: str) -> Optional['Function']:
    """Parse the inside of one ``<tool_call>...</tool_call>`` element.

    Expected layout (surrounding whitespace is ignored)::

        function-name
        <arg_key>k1</arg_key>
        <arg_value>v1</arg_value>
        ...

    Args:
        single_content: Text found between ``<tool_call>`` and ``</tool_call>``.

    Returns:
        A ``Function`` whose ``name`` is the leading text up to the first
        newline or ``<``, and whose ``arguments`` is a JSON object string
        mapping each stripped key to its stripped value (values stay strings).
        ``None`` when no name can be read or when the number of
        ``<arg_key>`` and ``<arg_value>`` elements differs.
    """
    from swift.llm.infer import Function

    body = single_content.strip()
    name_match = re.match(r'^([^\n<]+)', body)
    if name_match is None:
        return None

    # Keys and values are collected independently and paired by position.
    arg_keys = re.findall(r'<arg_key>(.*?)</arg_key>', body, re.DOTALL)
    arg_values = re.findall(r'<arg_value>(.*?)</arg_value>', body, re.DOTALL)
    if len(arg_keys) != len(arg_values):
        return None

    arguments = {}
    for raw_key, raw_value in zip(arg_keys, arg_values):
        arguments[raw_key.strip()] = raw_value.strip()

    return Function(
        name=name_match.group(1).strip(),
        arguments=json.dumps(arguments, ensure_ascii=False),
    )

def get_toolcall(self, response: str) -> List['Function']:
toolcall_list = re.findall(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL)
functions = []
for toolcall in toolcall_list:
function = self._find_function_call(toolcall)
if function:
functions.append(function)
Comment on lines +100 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loop for collecting functions can be made more concise and Pythonic by using a list comprehension with a walrus operator (PEP 572). This can improve readability and reduce boilerplate code.

Suggested change
functions = []
for toolcall in toolcall_list:
function = self._find_function_call(toolcall)
if function:
functions.append(function)
functions = [func for toolcall in toolcall_list if (func := self._find_function_call(toolcall))]

if len(functions) == 0:
# compat react_en
return super().get_toolcall(response)
return functions

def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
    """Build the tool-description prompt, followed by the system text if any.

    Args:
        tools: Tool definitions; each one is serialized with ``json.dumps``
            (``ensure_ascii=False``) onto its own line inside ``<tools>``.
        system: System prompt. When it is non-blank, ``'<|system|>\\n'`` plus
            the stripped text is appended after the tool instructions.
        user_message: Not used by this implementation.

    Returns:
        The assembled prompt string.
    """
    header = (
        '# Tools\n\n'
        'You may call one or more functions to assist with the user query.\n\n'
        'You are provided with function signatures within <tools></tools> XML tags:\n'
        '<tools>'
    )
    footer = (
        '</tools>\n\n'
        'For each function call, output the function name and arguments within the following XML format:\n'
        '<tool_call>{function-name}\n'
        '<arg_key>{arg-key-1}</arg_key>\n'
        '<arg_value>{arg-value-1}</arg_value>\n'
        '<arg_key>{arg-key-2}</arg_key>\n'
        '<arg_value>{arg-value-2}</arg_value>\n'
        '...\n'
        '</tool_call>'
    )
    serialized_tools = [json.dumps(tool, ensure_ascii=False) for tool in tools]
    prompt = '\n'.join([header] + serialized_tools + [footer])

    system_text = system.strip()
    if system_text:
        prompt += '<|system|>\n' + system_text
    return prompt

def _format_tool_responses(
self,
assistant_content: str,
tool_messages,
) -> Tuple[str, 'Prompt']:
with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
if with_action:
return super()._format_tool_responses(assistant_content, tool_messages)
res = []
for _, tool_message in enumerate(tool_messages):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of enumerate here is unnecessary as the index _ is not used. You can directly iterate over tool_messages for cleaner and more idiomatic code.

Suggested change
for _, tool_message in enumerate(tool_messages):
for tool_message in tool_messages:

tool_content = tool_message['content']
res.append(f'\n<tool_response>\n{tool_content}\n</tool_response>')
res.append('<|assistant|>\n')
return assistant_content, res

def _format_tool_calls(self, tool_call_messages) -> str:
tool_calls = []
for message in tool_call_messages:
tool_call = self._parse_tool_call(message['content'])
tool_calls.append(f"<tool_call>{tool_call['name']}")
for arg_key, arg_value in tool_call['arguments'].items():
tool_calls.append(f'<arg_key>{arg_key}</arg_key>\n<arg_value>{arg_value}</arg_value>')
Comment on lines +144 to +145
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code assumes that tool_call['arguments'] is always a dictionary, which may not be guaranteed by _parse_tool_call. If arguments is not a dictionary, tool_call['arguments'].items() will raise a runtime error. It's safer to check if arguments is a dictionary before iterating over it to make the function more robust against malformed inputs.

Suggested change
for arg_key, arg_value in tool_call['arguments'].items():
tool_calls.append(f'<arg_key>{arg_key}</arg_key>\n<arg_value>{arg_value}</arg_value>')
arguments = tool_call.get('arguments')
if isinstance(arguments, dict):
for arg_key, arg_value in arguments.items():
tool_calls.append(f'<arg_key>{arg_key}</arg_key>\n<arg_value>{arg_value}</arg_value>')

tool_calls.append('</tool_call>')
return '\n'.join(tool_calls) + '<|observation|>'
Loading