Showing 4 changed files with 349 additions and 1 deletion.
@@ -0,0 +1,55 @@
from typing import Dict, Any
from minion.main.minion import Minion, register_pre_processing_minion
from minion.main.prompt import PROBLEM_REFLECT_PROMPT, EXAMPLE_REASONING_PROMPT
from minion.actions.lmp_action_node import LmpActionNode
from jinja2 import Template
from minion.logs import logger


class PreProcessingMinion(Minion):
    """Base class for all pre-processing minions"""
    pass


@register_pre_processing_minion
class ProblemReflectMinion(PreProcessingMinion):
    """Minion that performs problem reflection before solving"""

    async def execute(self):
        """Execute the problem reflection process"""
        prompt = Template(PROBLEM_REFLECT_PROMPT)
        prompt = prompt.render(input=self.input)

        node = LmpActionNode(self.brain.llm)
        reflection = await node.execute(prompt)

        # Store reflection in input metadata for later use
        self.input.info["problem_reflection"] = reflection

        logger.info(f"Problem reflection completed: {reflection}")
        return reflection


@register_pre_processing_minion
class ExampleReasoningMinion(PreProcessingMinion):
    """Minion that analyzes and reasons about examples in the query"""

    async def execute(self):
        """Execute the example reasoning process"""
        # Check if the input contains examples
        if not self._has_examples():
            logger.info("No examples found in the input, skipping example reasoning")
            return None

        prompt = Template(EXAMPLE_REASONING_PROMPT)
        prompt = prompt.render(input=self.input)

        node = LmpActionNode(self.brain.llm)
        reasoning = await node.execute(prompt)

        self.input.info["example_reasoning"] = reasoning

        logger.info(f"Example reasoning completed: {reasoning}")
        return reasoning

    def _has_examples(self) -> bool:
        """Check if the input contains examples"""
        # This can be implemented with more complex example detection logic as needed
        return 'example' in self.input.query.lower() or 'examples' in self.input.query.lower()
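For orientation, here is a minimal usage sketch (not part of the diff) showing how these pre-processing minions could be driven before a solver runs. The `input=` and `brain=` constructor keywords are assumptions inferred from the attributes the minions read (`self.input.query`, `self.input.info`, `self.brain.llm`); check the actual `Minion` base class before relying on them.

async def run_pre_processing(brain, task_input):
    # `brain` and `task_input` are placeholders: the minions above only rely on
    # brain.llm, task_input.query and task_input.info.
    reflection = await ProblemReflectMinion(input=task_input, brain=brain).execute()
    reasoning = await ExampleReasoningMinion(input=task_input, brain=brain).execute()
    # Both results are also stored in task_input.info for downstream minions.
    return reflection, reasoning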
@@ -0,0 +1,181 @@
from typing import Dict, Any, List
from collections import Counter
from minion.main.minion import Minion, register_result_strategy


class ResultStrategy(Minion):
    """Base class for result processing strategies"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.workers = kwargs.get('workers', [])  # List of actual worker instances

    async def execute(self) -> str:
        """Process the results according to the strategy

        Returns:
            The processed final result
        """
        raise NotImplementedError


@register_result_strategy
class MajorityVotingStrategy(ResultStrategy):
    """Strategy that selects the result with the highest vote count"""

    async def execute(self) -> str:
        if not self.workers:
            return ""

        # Count answers from workers
        results = Counter(worker.answer for worker in self.workers)
        total_count = len(self.workers)
        majority_count = total_count // 2 + 1

        # Check for majority
        for result, count in results.items():
            if count >= majority_count:
                return result

        # No majority reached, return most common
        return results.most_common(1)[0][0]


@register_result_strategy
class BestOfNStrategy(ResultStrategy):
    """Strategy that selects the best result based on a scoring function"""

    async def execute(self) -> str:
        if not self.workers:
            return ""

        # Score each worker's answer
        scores = {}
        for worker in self.workers:
            result = worker.answer
            # Score could be based on worker configuration and state
            check_count = worker.worker_config.get("check", 0) if worker.worker_config else 0
            # Add more sophisticated scoring logic here
            if result in scores:
                scores[result] = max(scores[result], check_count)
            else:
                scores[result] = check_count

        if not scores:
            return self.workers[0].answer if self.workers else ""

        return max(scores.items(), key=lambda x: x[1])[0]


@register_result_strategy
class UscStrategy(ResultStrategy):
    """Strategy that combines self-consistent results"""

    async def execute(self) -> str:
        if not self.workers:
            return ""

        # Count answers for now, but could implement more sophisticated
        # consistency checking in the future
        results = Counter(worker.answer for worker in self.workers)
        return results.most_common(1)[0][0]


@register_result_strategy
class UsccStrategy(ResultStrategy):
    """Strategy that combines self-consistent results"""

    async def execute(self) -> str:
        if not self.workers:
            return ""

        # Count answers for now, but could implement more sophisticated
        # consistency checking in the future
        results = Counter(worker.answer for worker in self.workers)
        return results.most_common(1)[0][0]
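To make the voting threshold concrete: with three workers answering "42", "42" and "7", `total_count` is 3 and `majority_count` is 3 // 2 + 1 = 2, so `MajorityVotingStrategy` returns "42" on its first pass. A minimal sketch with stand-in workers follows; passing `workers=` straight through to `ResultStrategy.__init__` is an assumption about how `Minion` construction works.

from types import SimpleNamespace

async def demo_majority_voting():
    # Stand-in workers: the strategies above only read `answer`
    # (plus `worker_config` in BestOfNStrategy).
    workers = [SimpleNamespace(answer="42"), SimpleNamespace(answer="42"), SimpleNamespace(answer="7")]
    strategy = MajorityVotingStrategy(workers=workers)  # assumes Minion accepts this
    return await strategy.execute()  # "42": 2 votes meet the 3 // 2 + 1 == 2 threshold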
@register_result_strategy
class CodiumStrategy(ResultStrategy):
    """Strategy that implements Codium's solution ranking and improvement process"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.solid_test_cases = {
            'public': [],  # todo: bootstrap a list of verified public test cases
            'ai': []       # todo: bootstrap a list of verified AI-generated test cases
        }
        self.max_improvement_attempts = 3

    def _process_test_cases(self, test_cases, entry_point="main"):
        """Process test cases from metadata format to internal format"""
        if not test_cases or not isinstance(test_cases, dict):
            return []

        inputs = test_cases.get('input', [])
        outputs = test_cases.get('output', [])

        # Ensure we have matching input/output pairs
        return list(zip(inputs, outputs))

    async def rank_solutions(self):
        """Rank solutions based on initial quality metrics"""
        ranked_solutions = []
        for worker in self.workers:
            score = 0
            # Add scoring logic here - could be based on:
            # - Code complexity
            # - Test case coverage
            # - Worker's confidence score
            # - Previous success rate
            score += worker.worker_config.get("check", 0) if worker.worker_config else 0
            ranked_solutions.append((worker, score))

        return sorted(ranked_solutions, key=lambda x: x[1], reverse=True)

    async def verify_test_case(self, test_case, solution):
        """Verify if a test case is solid by checking if solution passes it"""
        try:
            # Implementation would depend on your test execution framework
            # This is a placeholder for the actual test execution logic
            result = await self.execute_test(test_case, solution)
            return result.success
        except Exception:
            return False

    async def improve_solution(self, worker, iteration=0):
        """Improve a solution using public tests and AI-generated tests"""
        # First, ensure the solution passes all solid test cases
        for test_type, test_cases in self.solid_test_cases.items():
            for test_case in test_cases:
                if not await self.verify_test_case(test_case, worker.answer):
                    return False

        # Try to improve the solution
        if not await worker.improve():
            return False

        # Verify and add passing public tests to solid test cases
        for test in worker.input.metadata.get("test_cases", []):
            if await self.verify_test_case(test, worker.answer):
                if test not in self.solid_test_cases['public']:
                    self.solid_test_cases['public'].append(test)

        # Verify and add passing AI tests to solid test cases
        for test in worker.input.metadata.get("ai_test_cases", []):
            if await self.verify_test_case(test, worker.answer):
                if test not in self.solid_test_cases['ai']:
                    self.solid_test_cases['ai'].append(test)

        return True

    async def execute(self) -> str:
        if not self.workers:
            return ""

        # Rank initial solutions
        ranked_solutions = await self.rank_solutions()

        # Try solutions in ranked order
        for worker, score in ranked_solutions:
            for attempt in range(self.max_improvement_attempts):
                if await self.improve_solution(worker, attempt):
                    return worker.answer

        # If no solution passes all tests, return the highest ranked solution
        return ranked_solutions[0][0].answer
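Note that `verify_test_case` calls `self.execute_test`, which is not defined anywhere in this commit. A possible shape for that hook is sketched below as a method one might add to `CodiumStrategy`; it assumes a solution is Python source exposing a `main` entry point (consistent with the `entry_point="main"` default above) and that a test case is an `(input, expected_output)` pair as produced by `_process_test_cases`. A real implementation would need sandboxing and timeouts.

from types import SimpleNamespace

async def execute_test(self, test_case, solution):
    """Hypothetical test hook: run `solution` against one (input, expected) pair."""
    test_input, expected = test_case
    namespace = {}
    exec(solution, namespace)               # illustration only: no sandboxing
    actual = namespace["main"](test_input)  # assumes a single-argument `main` entry point
    return SimpleNamespace(success=str(actual).strip() == str(expected).strip())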
@@ -0,0 +1,109 @@
from typing import List, Optional, Dict, Any
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage, AssistantMessage
from azure.core.credentials import AzureKeyCredential

from minion.logs import log_llm_stream
from minion.providers.base_provider import BaseProvider
from minion.providers.llm_provider_registry import llm_registry
from minion.message_types import Message, MessageContent, ContentType
from minion.providers.openai_provider import OpenAIProvider


@llm_registry.register("azure_inference")
class AzureInferenceProvider(OpenAIProvider):
    def _setup(self) -> None:
        """Setup Azure Inference SDK client"""
        endpoint = self.config.base_url
        key = self.config.api_key
        self.model = self.config.model or self.config.deployment_name

        self.client = ChatCompletionsClient(
            endpoint=endpoint,
            credential=AzureKeyCredential(key)
        )

    # def _prepare_messages(self, messages: List[Message] | Message | str) -> List[Any]:
    #     """Convert minion Message objects to Azure Inference SDK message objects"""
    #     # Convert single message or string to list
    #     if isinstance(messages, (str, Message)):
    #         messages = [messages if isinstance(messages, Message) else Message(role="user", content=messages)]
    #
    #     azure_messages = []
    #     for msg in messages:
    #         content = msg.content
    #         if isinstance(content, str):
    #             text_content = content
    #         elif isinstance(content, MessageContent):
    #             if content.type == ContentType.TEXT:
    #                 text_content = content.text
    #             else:
    #                 # For now, we only handle text content for Azure Inference
    #                 # TODO: Add support for image content if Azure Inference SDK supports it
    #                 text_content = content.text if content.text else ""
    #         else:
    #             text_content = str(content)
    #
    #         if msg.role == "system":
    #             azure_messages.append(SystemMessage(content=text_content))
    #         elif msg.role == "user":
    #             azure_messages.append(UserMessage(content=text_content))
    #         elif msg.role == "assistant":
    #             azure_messages.append(AssistantMessage(content=text_content))
    #     return azure_messages

    async def generate(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str:
        """Generate completion using Azure Inference SDK"""
        azure_messages = self._prepare_messages(messages)

        # Prepare parameters
        params = {
            "messages": azure_messages,
            "model": self.model,
            "temperature": temperature if temperature is not None else 0.6,
            "max_tokens": kwargs.get("max_tokens", 1000)
        }

        # Add optional parameters if provided
        if "top_p" in kwargs:
            params["top_p"] = kwargs["top_p"]
        if "frequency_penalty" in kwargs:
            params["frequency_penalty"] = kwargs["frequency_penalty"]
        if "presence_penalty" in kwargs:
            params["presence_penalty"] = kwargs["presence_penalty"]

        response = self.client.complete(**params)
        # The Azure Inference SDK returns an OpenAI-style response object;
        # the generated text lives in choices[0].message.content
        return response.choices[0].message.content

    async def generate_stream(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str:
        """Generate streaming completion using Azure Inference SDK"""
        azure_messages = self._prepare_messages(messages)

        # Prepare parameters
        params = {
            "messages": azure_messages,
            "model": self.model,
            "temperature": temperature if temperature is not None else 0.6,
            "max_tokens": kwargs.get("max_tokens", 1000),
            "stream": True
        }

        # Add optional parameters if provided
        if "top_p" in kwargs:
            params["top_p"] = kwargs["top_p"]
        if "frequency_penalty" in kwargs:
            params["frequency_penalty"] = kwargs["frequency_penalty"]
        if "presence_penalty" in kwargs:
            params["presence_penalty"] = kwargs["presence_penalty"]

        response = self.client.complete(**params)
        full_content = ""
        completion_tokens = 0
        for chunk in response:
            if chunk.choices:
                completion_tokens += 1
                chunk_message = chunk.choices[0].delta.content
                if chunk_message:  # the delta content can be None on the final chunk
                    full_content += chunk_message
                    log_llm_stream(chunk_message)
                    # yield chunk_message
        return full_content
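For completeness, the provider registers under the "azure_inference" key and reads `base_url`, `api_key`, `model` and `deployment_name` from its config in `_setup`. The sketch below only illustrates that wiring with placeholder values; constructing the provider with `config=` and the surrounding `BaseProvider` initialization behaviour are assumptions, not the documented API.

import asyncio
from types import SimpleNamespace
from minion.message_types import Message

async def demo_azure_inference():
    # Stand-in config: only the fields read by _setup are populated here.
    config = SimpleNamespace(
        base_url="https://<your-endpoint>.services.ai.azure.com/models",  # placeholder
        api_key="<key>",                                                  # placeholder
        model="<model-or-deployment-name>",                               # placeholder
        deployment_name=None,
    )
    provider = AzureInferenceProvider(config=config)  # assumes config= is accepted
    reply = await provider.generate(
        [Message(role="user", content="Say hello.")],
        temperature=0.2,
        max_tokens=64,
    )
    print(reply)

if __name__ == "__main__":
    asyncio.run(demo_azure_inference())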