-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import List, Optional | ||
|
||
from minion.providers.openai_provider import OpenAIProvider | ||
from minion.providers.llm_provider_registry import llm_registry | ||
from minion.message_types import Message | ||
|
||
|
||
@llm_registry.register("azure") | ||
class AzureProvider(OpenAIProvider): | ||
def _setup(self) -> None: | ||
import openai | ||
|
||
# Azure OpenAI 需要特定的配置 | ||
client_kwargs = { | ||
"api_key": self.config.api_key, | ||
"azure_endpoint": self.config.base_url, # Azure 使用 endpoint 而不是 base_url | ||
"api_version": self.config.api_version or "2024-05-01-preview", # Azure 需要 api_version | ||
} | ||
|
||
# Azure OpenAI 使用 azure_deployment_name 而不是 model | ||
if hasattr(self.config, "deployment_name"): | ||
client_kwargs["azure_deployment"] = self.config.deployment_name | ||
else: | ||
client_kwargs["azure_deployment"] = self.config.model | ||
|
||
self.client_ell = openai.AzureOpenAI(**client_kwargs) | ||
self.client = openai.AsyncAzureOpenAI(**client_kwargs) | ||
|
||
async def generate(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str: | ||
# 移除 model 参数,因为 Azure 使用 deployment_name | ||
kwargs.pop("model", None) | ||
return await super().generate(messages, temperature, **kwargs) | ||
|
||
async def generate_stream(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str: | ||
# 移除 model 参数,因为 Azure 使用 deployment_name | ||
kwargs.pop("model", None) | ||
return await super().generate_stream(messages, temperature, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import sys | ||
import subprocess | ||
import tempfile | ||
import os | ||
import re | ||
from typing import NamedTuple | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class ProcessResult: | ||
"""Result of running code in a separate process.""" | ||
stdout: str | ||
stderr: str | ||
return_code: int | ||
|
||
@property | ||
def success(self) -> bool: | ||
"""Whether the process executed successfully.""" | ||
return self.return_code == 0 | ||
|
||
|
||
def _has_main_structure(code: str) -> tuple[bool, bool]: | ||
"""Check if code already has main function and/or main guard. | ||
Returns: | ||
Tuple of (has_main_func, has_main_guard) | ||
""" | ||
# Simple pattern matching for main function and guard | ||
has_main_func = bool(re.search(r'def\s+main\s*\(', code)) | ||
has_main_guard = 'if __name__ == "__main__":' in code or "if __name__ == '__main__':" in code | ||
return has_main_func, has_main_guard | ||
|
||
|
||
def run_code_in_separate_process( | ||
code: str, | ||
input_data: str = "", | ||
timeout: int = 120, | ||
indent: str = " " # Default indentation for the code | ||
) -> ProcessResult: | ||
"""Run Python code in a separate process with the given input data. | ||
Args: | ||
code: The Python code to execute | ||
input_data: Input data to feed to the process's stdin | ||
timeout: Maximum execution time in seconds | ||
indent: Indentation to use for the code in the main function | ||
Returns: | ||
ProcessResult containing stdout, stderr, and return code | ||
Raises: | ||
TimeoutError: If the code execution exceeds the timeout | ||
subprocess.SubprocessError: If there's an error running the subprocess | ||
""" | ||
has_main_func, has_main_guard = _has_main_structure(code) | ||
|
||
# If code already has both main structures, use it as is | ||
if has_main_func and has_main_guard: | ||
wrapper_code = code | ||
# If code has main function but no guard, add the guard | ||
elif has_main_func: | ||
wrapper_code = f'''{code} | ||
if __name__ == "__main__": | ||
main()''' | ||
# If code has neither, wrap it in both | ||
else: | ||
# Indent the code | ||
indented_code = "\n".join(indent + line if line.strip() else line | ||
for line in code.splitlines()) | ||
|
||
wrapper_code = f''' | ||
import sys | ||
def main(): | ||
{indented_code} | ||
if __name__ == "__main__": | ||
main() | ||
''' | ||
|
||
# Create a temporary Python file with the code | ||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: | ||
f.write(wrapper_code) | ||
temp_file = f.name | ||
|
||
try: | ||
# Run the code in a separate process | ||
process = subprocess.Popen( | ||
[sys.executable, temp_file], | ||
stdin=subprocess.PIPE, | ||
stdout=subprocess.PIPE, | ||
stderr=subprocess.PIPE, | ||
text=True | ||
) | ||
|
||
# Send input and get output with timeout | ||
try: | ||
stdout, stderr = process.communicate(input=input_data, timeout=timeout) | ||
return ProcessResult( | ||
stdout=stdout, | ||
stderr=stderr, | ||
return_code=process.returncode | ||
) | ||
except subprocess.TimeoutExpired: | ||
process.kill() | ||
raise TimeoutError("Code execution timed out") | ||
|
||
finally: | ||
# Clean up the temporary file | ||
os.unlink(temp_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from jinja2 import Environment, BaseLoader | ||
|
||
def render_template_with_variables(template_str: str, **kwargs) -> str: | ||
"""使用 Jinja2 渲染模板字符串 | ||
Args: | ||
template_str (str): 模板字符串 | ||
**kwargs: 传递给模板的变量字典 | ||
Returns: | ||
str: 渲染后的字符串 | ||
""" | ||
env = Environment(loader=BaseLoader()) | ||
template = env.from_string(template_str) | ||
return template.render(**kwargs) |