diff --git a/examples/smart_minion/brain.py b/examples/smart_minion/brain.py
index e76304ab..343c43fa 100644
--- a/examples/smart_minion/brain.py
+++ b/examples/smart_minion/brain.py
@@ -31,7 +31,10 @@ async def smart_brain():
         llm=llm,
         #llms={"route": [ "llama3.2","llama3.1"]}
     )
-    obs, score, *_ = await brain.step(query="what's the solution for game of 24 for 1 3 4 6", route="python")
+    # obs, score, *_ = await brain.step(query="what's the solution for game of 24 for 1 3 4 6", route="python")
+    # print(obs)
+
+    obs, score, *_ = await brain.step(query="what's the solution for game of 24 for 2 3 5 12", route="python")
     print(obs)
 
     current_file_dir = os.path.dirname(os.path.abspath(__file__))
diff --git a/minion/main/pre_processing.py b/minion/main/pre_processing.py
new file mode 100644
index 00000000..0aaa17e9
--- /dev/null
+++ b/minion/main/pre_processing.py
@@ -0,0 +1,55 @@
+from typing import Dict, Any
+from minion.main.minion import Minion, register_pre_processing_minion
+from minion.main.prompt import PROBLEM_REFLECT_PROMPT, EXAMPLE_REASONING_PROMPT
+from minion.actions.lmp_action_node import LmpActionNode
+from jinja2 import Template
+from minion.logs import logger
+
+class PreProcessingMinion(Minion):
+    """Base class for all pre-processing minions"""
+    pass
+
+@register_pre_processing_minion
+class ProblemReflectMinion(PreProcessingMinion):
+    """Minion that performs problem reflection before solving"""
+
+    async def execute(self):
+        """Execute the problem reflection process"""
+        prompt = Template(PROBLEM_REFLECT_PROMPT)
+        prompt = prompt.render(input=self.input)
+
+        node = LmpActionNode(self.brain.llm)
+        reflection = await node.execute(prompt)
+
+        # Store reflection in input metadata for later use
+        self.input.info["problem_reflection"] = reflection
+
+        logger.info(f"Problem reflection completed: {reflection}")
+        return reflection
+
+@register_pre_processing_minion
+class ExampleReasoningMinion(PreProcessingMinion):
+    """Minion that analyzes and reasons about examples in the query"""
+
+    async def execute(self):
+        """Execute the example reasoning process"""
+        # Check if the input contains examples
+        if not self._has_examples():
+            logger.info("No examples found in the input, skipping example reasoning")
+            return None
+
+        prompt = Template(EXAMPLE_REASONING_PROMPT)
+        prompt = prompt.render(input=self.input)
+
+        node = LmpActionNode(self.brain.llm)
+        reasoning = await node.execute(prompt)
+
+        self.input.info["example_reasoning"] = reasoning
+
+        logger.info(f"Example reasoning completed: {reasoning}")
+        return reasoning
+
+    def _has_examples(self) -> bool:
+        """Check if the input contains examples"""
+        # This can be implemented with more complex example detection logic as needed
+        return 'example' in self.input.query.lower() or 'examples' in self.input.query.lower()
\ No newline at end of file
diff --git a/minion/main/result_strategy.py b/minion/main/result_strategy.py
new file mode 100644
index 00000000..6e42e33d
--- /dev/null
+++ b/minion/main/result_strategy.py
@@ -0,0 +1,181 @@
+from typing import Dict, Any, List
+from collections import Counter
+from minion.main.minion import Minion, register_result_strategy
+
+class ResultStrategy(Minion):
+    """Base class for result processing strategies"""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.workers = kwargs.get('workers', [])  # List of actual worker instances
+
+    async def execute(self) -> str:
+        """Process the results according to the strategy
+
+        Returns:
+            The processed final result
+        """
+        raise NotImplementedError
+
+@register_result_strategy
+class MajorityVotingStrategy(ResultStrategy):
+    """Strategy that selects the result with the highest vote count"""
+
+    async def execute(self) -> str:
+        if not self.workers:
+            return ""
+
+        # Count answers from workers
+        results = Counter(worker.answer for worker in self.workers)
+        total_count = len(self.workers)
+        majority_count = total_count // 2 + 1
+
+        # Check for majority
+        for result, count in results.items():
+            if count >= majority_count:
+                return result
+
+        # No majority reached, return most common
+        return results.most_common(1)[0][0]
+
+@register_result_strategy
+class BestOfNStrategy(ResultStrategy):
+    """Strategy that selects the best result based on a scoring function"""
+
+    async def execute(self) -> str:
+        if not self.workers:
+            return ""
+
+        # Score each worker's answer
+        scores = {}
+        for worker in self.workers:
+            result = worker.answer
+            # Score could be based on worker configuration and state
+            check_count = worker.worker_config.get("check", 0) if worker.worker_config else 0
+            # Add more sophisticated scoring logic here
+            if result in scores:
+                scores[result] = max(scores[result], check_count)
+            else:
+                scores[result] = check_count
+
+        if not scores:
+            return self.workers[0].answer if self.workers else ""
+
+        return max(scores.items(), key=lambda x: x[1])[0]
+
+@register_result_strategy
+class UscStrategy(ResultStrategy):
+    """Strategy that combines self-consistent results"""
+
+    async def execute(self) -> str:
+        if not self.workers:
+            return ""
+
+        # Count answers for now, but could implement more sophisticated
+        # consistency checking in the future
+        results = Counter(worker.answer for worker in self.workers)
+        return results.most_common(1)[0][0]
+
+
+@register_result_strategy
+class UsccStrategy(ResultStrategy):
+    """Strategy that combines self-consistent results"""
+
+    async def execute(self) -> str:
+        if not self.workers:
+            return ""
+
+        # Count answers for now, but could implement more sophisticated
+        # consistency checking in the future
+        results = Counter(worker.answer for worker in self.workers)
+        return results.most_common(1)[0][0]
+
+@register_result_strategy
+class CodiumStrategy(ResultStrategy):
+    """Strategy that implements Codium's solution ranking and improvement process"""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.solid_test_cases = {
+            'public': [],  # todo: bootstrap a list of verified public test cases
+            'ai': []       # todo: bootstrap a list of verified AI-generated test cases
+        }
+        self.max_improvement_attempts = 3
+
+    def _process_test_cases(self, test_cases, entry_point="main"):
+        """Process test cases from metadata format to internal format"""
+        if not test_cases or not isinstance(test_cases, dict):
+            return []
+
+        inputs = test_cases.get('input', [])
+        outputs = test_cases.get('output', [])
+
+        # Ensure we have matching input/output pairs
+        return list(zip(inputs, outputs))
+
+    async def rank_solutions(self):
+        """Rank solutions based on initial quality metrics"""
+        ranked_solutions = []
+        for worker in self.workers:
+            score = 0
+            # Add scoring logic here - could be based on:
+            # - Code complexity
+            # - Test case coverage
+            # - Worker's confidence score
+            # - Previous success rate
+            score += worker.worker_config.get("check", 0) if worker.worker_config else 0
+            ranked_solutions.append((worker, score))
+
+        return sorted(ranked_solutions, key=lambda x: x[1], reverse=True)
+
+    async def verify_test_case(self, test_case, solution):
+        """Verify if a test case is solid by checking if the solution passes it"""
+        try:
+            # Implementation would depend on your test execution framework
+            # This is a placeholder for the actual test execution logic
+            result = await self.execute_test(test_case, solution)
+            return result.success
+        except Exception as e:
+            return False
+
+    async def improve_solution(self, worker, iteration=0):
+        """Improve a solution using public tests and AI-generated tests"""
+        # First, ensure the solution passes all solid test cases
+        for test_type, test_cases in self.solid_test_cases.items():
+            for test_case in test_cases:
+                if not await self.verify_test_case(test_case, worker.answer):
+                    return False
+
+        # Try to improve the solution
+        if not await worker.improve():
+            return False
+
+        # Verify and add passing public tests to solid test cases
+        for test in worker.input.metadata.get("test_cases", []):
+            if await self.verify_test_case(test, worker.answer):
+                if test not in self.solid_test_cases['public']:
+                    self.solid_test_cases['public'].append(test)
+
+        # Verify and add passing AI tests to solid test cases
+        for test in worker.input.metadata.get("ai_test_cases", []):
+            if await self.verify_test_case(test, worker.answer):
+                if test not in self.solid_test_cases['ai']:
+                    self.solid_test_cases['ai'].append(test)
+
+        return True
+
+    async def execute(self) -> str:
+        if not self.workers:
+            return ""
+
+        # Rank initial solutions
+        ranked_solutions = await self.rank_solutions()
+
+        # Try solutions in ranked order
+        for worker, score in ranked_solutions:
+            for attempt in range(self.max_improvement_attempts):
+                if await self.improve_solution(worker, attempt):
+                    return worker.answer
+
+        # If no solution passes all tests, return the highest ranked solution
+        return ranked_solutions[0][0].answer
\ No newline at end of file
diff --git a/minion/providers/azure_inference_provider.py b/minion/providers/azure_inference_provider.py
new file mode 100644
index 00000000..fa090e8d
--- /dev/null
+++ b/minion/providers/azure_inference_provider.py
@@ -0,0 +1,109 @@
+from typing import List, Optional, Dict, Any
+from azure.ai.inference import ChatCompletionsClient
+from azure.ai.inference.models import SystemMessage, UserMessage, AssistantMessage
+from azure.core.credentials import AzureKeyCredential
+
+from minion.logs import log_llm_stream
+from minion.providers.base_provider import BaseProvider
+from minion.providers.llm_provider_registry import llm_registry
+from minion.message_types import Message, MessageContent, ContentType
+from minion.providers.openai_provider import OpenAIProvider
+
+
+@llm_registry.register("azure_inference")
+class AzureInferenceProvider(OpenAIProvider):
+    def _setup(self) -> None:
+        """Setup Azure Inference SDK client"""
+        endpoint = self.config.base_url
+        key = self.config.api_key
+        self.model = self.config.model or self.config.deployment_name
+
+        self.client = ChatCompletionsClient(
+            endpoint=endpoint,
+            credential=AzureKeyCredential(key)
+        )
+
+    # def _prepare_messages(self, messages: List[Message] | Message | str) -> List[Any]:
+    #     """Convert minion Message objects to Azure Inference SDK message objects"""
+    #     # Convert single message or string to list
+    #     if isinstance(messages, (str, Message)):
+    #         messages = [messages if isinstance(messages, Message) else Message(role="user", content=messages)]
+    #
+    #     azure_messages = []
+    #     for msg in messages:
+    #         content = msg.content
+    #         if isinstance(content, str):
+    #             text_content = content
+    #         elif isinstance(content, MessageContent):
+    #             if content.type == ContentType.TEXT:
+    #                 text_content = content.text
+    #             else:
+    #                 # For now, we only handle text content for Azure Inference
+    #                 # TODO: Add support for image content if Azure Inference SDK supports it
+    #                 text_content = content.text if content.text else ""
+    #         else:
+    #             text_content = str(content)
+    #
+    #         if msg.role == "system":
+    #             azure_messages.append(SystemMessage(content=text_content))
+    #         elif msg.role == "user":
+    #             azure_messages.append(UserMessage(content=text_content))
+    #         elif msg.role == "assistant":
+    #             azure_messages.append(AssistantMessage(content=text_content))
+    #     return azure_messages
+
+    async def generate(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str:
+        """Generate completion using Azure Inference SDK"""
+        azure_messages = self._prepare_messages(messages)
+
+        # Prepare parameters
+        params = {
+            "messages": azure_messages,
+            "model": self.model,
+            "temperature": temperature if temperature is not None else 0.6,
+            "max_tokens": kwargs.get("max_tokens", 1000)
+        }
+
+        # Add optional parameters if provided
+        if "top_p" in kwargs:
+            params["top_p"] = kwargs["top_p"]
+        if "frequency_penalty" in kwargs:
+            params["frequency_penalty"] = kwargs["frequency_penalty"]
+        if "presence_penalty" in kwargs:
+            params["presence_penalty"] = kwargs["presence_penalty"]
+
+        response = self.client.complete(**params)
+        return response.choices[0].message.content  # ChatCompletions exposes choices, not messages
+
+    async def generate_stream(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str:
+        """Generate streaming completion using Azure Inference SDK"""
+        azure_messages = self._prepare_messages(messages)
+
+        # Prepare parameters
+        params = {
+            "messages": azure_messages,
+            "model": self.model,
+            "temperature": temperature if temperature is not None else 0.6,
+            "max_tokens": kwargs.get("max_tokens", 1000),
+            "stream": True
+        }
+
+        # Add optional parameters if provided
+        if "top_p" in kwargs:
+            params["top_p"] = kwargs["top_p"]
+        if "frequency_penalty" in kwargs:
+            params["frequency_penalty"] = kwargs["frequency_penalty"]
+        if "presence_penalty" in kwargs:
+            params["presence_penalty"] = kwargs["presence_penalty"]
+
+        response = self.client.complete(**params)
+        full_content = ""
+        completion_tokens = 0
+        for chunk in response:
+            if chunk.choices:
+                completion_tokens += 1
+                chunk_message = chunk.choices[0].delta.content or ""  # delta.content can be None
+                full_content += chunk_message
+                log_llm_stream(chunk_message)
+                #yield chunk_message
+        return full_content
\ No newline at end of file
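
Usage sketch for the result strategies added above: a minimal, hedged illustration of how MajorityVotingStrategy might be driven. It assumes ResultStrategy really does accept a workers= keyword (as its __init__ suggests) and that only the .answer attribute of each worker is consulted; the FakeWorker class is hypothetical and exists only for this example, since real workers come from the brain's routing machinery and may require additional constructor arguments.

    import asyncio
    from minion.main.result_strategy import MajorityVotingStrategy

    class FakeWorker:
        # Hypothetical stand-in for a worker: only .answer is used by the voting strategy
        def __init__(self, answer):
            self.answer = answer

    async def main():
        workers = [FakeWorker("42"), FakeWorker("42"), FakeWorker("41")]
        strategy = MajorityVotingStrategy(workers=workers)
        print(await strategy.execute())  # "42" holds the majority (2 of 3 votes)

    asyncio.run(main())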