Commit 245d06b

deepseek-r1

femto committed Feb 17, 2025
1 parent 7faa8a3 commit 245d06b
Showing 4 changed files with 349 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/smart_minion/brain.py
@@ -31,7 +31,10 @@ async def smart_brain():
llm=llm,
#llms={"route": [ "llama3.2","llama3.1"]}
)
obs, score, *_ = await brain.step(query="what's the solution for game of 24 for 1 3 4 6", route="python")
# obs, score, *_ = await brain.step(query="what's the solution for game of 24 for 1 3 4 6", route="python")
# print(obs)

obs, score, *_ = await brain.step(query="what's the solution for game of 24 for 2 3 5 12", route="python")
print(obs)

current_file_dir = os.path.dirname(os.path.abspath(__file__))
55 changes: 55 additions & 0 deletions minion/main/pre_processing.py
@@ -0,0 +1,55 @@
from typing import Dict, Any
from minion.main.minion import Minion, register_pre_processing_minion
from minion.main.prompt import PROBLEM_REFLECT_PROMPT, EXAMPLE_REASONING_PROMPT
from minion.actions.lmp_action_node import LmpActionNode
from jinja2 import Template
from minion.logs import logger

class PreProcessingMinion(Minion):
"""Base class for all pre-processing minions"""
pass

@register_pre_processing_minion
class ProblemReflectMinion(PreProcessingMinion):
"""Minion that performs problem reflection before solving"""

async def execute(self):
"""Execute the problem reflection process"""
prompt = Template(PROBLEM_REFLECT_PROMPT)
prompt = prompt.render(input=self.input)

node = LmpActionNode(self.brain.llm)
reflection = await node.execute(prompt)

# Store reflection in input metadata for later use
self.input.info["problem_reflection"] = reflection

logger.info(f"Problem reflection completed: {reflection}")
return reflection

@register_pre_processing_minion
class ExampleReasoningMinion(PreProcessingMinion):
"""Minion that analyzes and reasons about examples in the query"""

async def execute(self):
"""Execute the example reasoning process"""
# Check if the input contains examples
if not self._has_examples():
logger.info("No examples found in the input, skipping example reasoning")
return None

prompt = Template(EXAMPLE_REASONING_PROMPT)
prompt = prompt.render(input=self.input)

node = LmpActionNode(self.brain.llm)
reasoning = await node.execute(prompt)

self.input.info["example_reasoning"] = reasoning

logger.info(f"Example reasoning completed: {reasoning}")
return reasoning

def _has_examples(self) -> bool:
"""Check if the input contains examples"""
# This can be implemented with more complex example detection logic as needed
        # 'examples' contains 'example', so a single substring check covers both
        return 'example' in self.input.query.lower()
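
A new pre-processing step would follow the same shape as the two minions in this file: subclass PreProcessingMinion, register it with register_pre_processing_minion, and have async execute() render a prompt and store the result in self.input.info. A minimal sketch, assuming a hypothetical QUERY_NORMALIZE_PROMPT template that is not part of this commit:

@register_pre_processing_minion
class QueryNormalizeMinion(PreProcessingMinion):
    """Hypothetical minion that rewrites the query into a normalized form"""

    async def execute(self):
        # QUERY_NORMALIZE_PROMPT is an assumed Jinja2 template, not defined in this commit
        prompt = Template(QUERY_NORMALIZE_PROMPT).render(input=self.input)
        node = LmpActionNode(self.brain.llm)
        normalized = await node.execute(prompt)
        self.input.info["normalized_query"] = normalized
        logger.info(f"Query normalization completed: {normalized}")
        return normalized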
181 changes: 181 additions & 0 deletions minion/main/result_strategy.py
@@ -0,0 +1,181 @@
from typing import Dict, Any, List
from collections import Counter
from minion.main.minion import Minion, register_result_strategy

class ResultStrategy(Minion):
"""Base class for result processing strategies"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.workers = kwargs.get('workers', []) # List of actual worker instances

async def execute(self) -> str:
"""Process the results according to the strategy
Returns:
The processed final result
"""
raise NotImplementedError

@register_result_strategy
class MajorityVotingStrategy(ResultStrategy):
"""Strategy that selects result with highest vote count"""

async def execute(self) -> str:
if not self.workers:
return ""

# Count answers from workers
results = Counter(worker.answer for worker in self.workers)
total_count = len(self.workers)
majority_count = total_count // 2 + 1

# Check for majority
for result, count in results.items():
if count >= majority_count:
return result

# No majority reached, return most common
return results.most_common(1)[0][0]
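
To illustrate the vote counting above with plain Counter logic (the values are made up; no Minion machinery involved):

from collections import Counter

answers = ["24", "24", "no solution"]        # e.g. three workers' answers
results = Counter(answers)                   # Counter({'24': 2, 'no solution': 1})
majority_count = len(answers) // 2 + 1       # 2
winner = next((r for r, c in results.items() if c >= majority_count),
              results.most_common(1)[0][0])  # fall back to the most common answer
print(winner)                                # '24'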

@register_result_strategy
class BestOfNStrategy(ResultStrategy):
"""Strategy that selects the best result based on a scoring function"""

async def execute(self) -> str:
if not self.workers:
return ""

# Score each worker's answer
scores = {}
for worker in self.workers:
result = worker.answer
# Score could be based on worker configuration and state
check_count = worker.worker_config.get("check", 0) if worker.worker_config else 0
# Add more sophisticated scoring logic here
if result in scores:
scores[result] = max(scores[result], check_count)
else:
scores[result] = check_count

if not scores:
return self.workers[0].answer if self.workers else ""

return max(scores.items(), key=lambda x: x[1])[0]

@register_result_strategy
class UscStrategy(ResultStrategy):
"""Strategy that combines self-consistent results"""

async def execute(self) -> str:
if not self.workers:
return ""

# Count answers for now, but could implement more sophisticated
# consistency checking in the future
results = Counter(worker.answer for worker in self.workers)
return results.most_common(1)[0][0]


@register_result_strategy
class UsccStrategy(ResultStrategy):
"""Strategy that combines self-consistent results"""

async def execute(self) -> str:
if not self.workers:
return ""

# Count answers for now, but could implement more sophisticated
# consistency checking in the future
results = Counter(worker.answer for worker in self.workers)
return results.most_common(1)[0][0]

@register_result_strategy
class CodiumStrategy(ResultStrategy):
"""Strategy that implements Codium's solution ranking and improvement process"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.solid_test_cases = {
'public': [], # todo: bootstrap a List of verified public test cases
'ai': [] # todo: bootstrap a List of verified AI-generated test cases
}
self.max_improvement_attempts = 3

def _process_test_cases(self, test_cases, entry_point="main"):
"""Process test cases from metadata format to internal format"""
if not test_cases or not isinstance(test_cases, dict):
return []

inputs = test_cases.get('input', [])
outputs = test_cases.get('output', [])

# Ensure we have matching input/output pairs
return list(zip(inputs, outputs))

async def rank_solutions(self):
"""Rank solutions based on initial quality metrics"""
ranked_solutions = []
for worker in self.workers:
score = 0
# Add scoring logic here - could be based on:
# - Code complexity
# - Test case coverage
# - Worker's confidence score
# - Previous success rate
score += worker.worker_config.get("check", 0) if worker.worker_config else 0
ranked_solutions.append((worker, score))

return sorted(ranked_solutions, key=lambda x: x[1], reverse=True)

async def verify_test_case(self, test_case, solution):
"""Verify if a test case is solid by checking if solution passes it"""
try:
# Implementation would depend on your test execution framework
# This is a placeholder for the actual test execution logic
result = await self.execute_test(test_case, solution)
return result.success
except Exception as e:
return False

async def improve_solution(self, worker, iteration=0):
"""Improve a solution using public tests and AI-generated tests"""
# First, ensure solution passes all solid test cases
for test_type, test_cases in self.solid_test_cases.items():
for test_case in test_cases:
if not await self.verify_test_case(test_case, worker.answer):
return False

# Try to improve the solution
if not await worker.improve():
return False

# Verify and add passing public tests to solid test cases
for test in worker.input.metadata.get("test_cases", []):
if await self.verify_test_case(test, worker.answer):
if test not in self.solid_test_cases['public']:
self.solid_test_cases['public'].append(test)

# Verify and add passing AI tests to solid test cases
for test in worker.input.metadata.get("ai_test_cases", []):
if await self.verify_test_case(test, worker.answer):
if test not in self.solid_test_cases['ai']:
self.solid_test_cases['ai'].append(test)

return True

async def execute(self) -> str:
if not self.workers:
return ""

# Rank initial solutions
ranked_solutions = await self.rank_solutions()

# Try solutions in ranked order
for worker, score in ranked_solutions:
for attempt in range(self.max_improvement_attempts):
if await self.improve_solution(worker, attempt):
return worker.answer

# If no solution passes all tests, return the highest ranked solution
return ranked_solutions[0][0].answer
109 changes: 109 additions & 0 deletions minion/providers/azure_inference_provider.py
@@ -0,0 +1,109 @@
from typing import List, Optional, Dict, Any
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage, AssistantMessage
from azure.core.credentials import AzureKeyCredential

from minion.logs import log_llm_stream
from minion.providers.base_provider import BaseProvider
from minion.providers.llm_provider_registry import llm_registry
from minion.message_types import Message, MessageContent, ContentType
from minion.providers.openai_provider import OpenAIProvider


@llm_registry.register("azure_inference")
class AzureInferenceProvider(OpenAIProvider):
def _setup(self) -> None:
"""Setup Azure Inference SDK client"""
endpoint = self.config.base_url
key = self.config.api_key
self.model = self.config.model or self.config.deployment_name

self.client = ChatCompletionsClient(
endpoint=endpoint,
credential=AzureKeyCredential(key)
)

# def _prepare_messages(self, messages: List[Message] | Message | str) -> List[Any]:
# """Convert minion Message objects to Azure Inference SDK message objects"""
# # Convert single message or string to list
# if isinstance(messages, (str, Message)):
# messages = [messages if isinstance(messages, Message) else Message(role="user", content=messages)]
#
# azure_messages = []
# for msg in messages:
# content = msg.content
# if isinstance(content, str):
# text_content = content
# elif isinstance(content, MessageContent):
# if content.type == ContentType.TEXT:
# text_content = content.text
# else:
# # For now, we only handle text content for Azure Inference
# # TODO: Add support for image content if Azure Inference SDK supports it
# text_content = content.text if content.text else ""
# else:
# text_content = str(content)
#
# if msg.role == "system":
# azure_messages.append(SystemMessage(content=text_content))
# elif msg.role == "user":
# azure_messages.append(UserMessage(content=text_content))
# elif msg.role == "assistant":
# azure_messages.append(AssistantMessage(content=text_content))
# return azure_messages

async def generate(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str:
"""Generate completion using Azure Inference SDK"""
azure_messages = self._prepare_messages(messages)

# Prepare parameters
params = {
"messages": azure_messages,
"model": self.model,
"temperature": temperature if temperature is not None else 0.6,
"max_tokens": kwargs.get("max_tokens", 1000)
}

# Add optional parameters if provided
if "top_p" in kwargs:
params["top_p"] = kwargs["top_p"]
if "frequency_penalty" in kwargs:
params["frequency_penalty"] = kwargs["frequency_penalty"]
if "presence_penalty" in kwargs:
params["presence_penalty"] = kwargs["presence_penalty"]

        response = self.client.complete(**params)
        # ChatCompletions responses expose the text on choices[0].message, not .messages
        return response.choices[0].message.content

async def generate_stream(self, messages: List[Message], temperature: Optional[float] = None, **kwargs) -> str:
"""Generate streaming completion using Azure Inference SDK"""
azure_messages = self._prepare_messages(messages)

# Prepare parameters
params = {
"messages": azure_messages,
"model": self.model,
"temperature": temperature if temperature is not None else 0.6,
"max_tokens": kwargs.get("max_tokens", 1000),
"stream": True
}

# Add optional parameters if provided
if "top_p" in kwargs:
params["top_p"] = kwargs["top_p"]
if "frequency_penalty" in kwargs:
params["frequency_penalty"] = kwargs["frequency_penalty"]
if "presence_penalty" in kwargs:
params["presence_penalty"] = kwargs["presence_penalty"]

response = self.client.complete(**params)
full_content = ""
completion_tokens = 0
for chunk in response:
if chunk.choices:
completion_tokens += 1
                chunk_message = chunk.choices[0].delta.content
                # delta.content can be None on role-only or final chunks; skip those
                if chunk_message:
                    full_content += chunk_message
                    log_llm_stream(chunk_message)
#yield chunk_message
return full_content
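
For reference, the underlying Azure AI Inference SDK call that generate() wraps looks roughly like this when used directly (the endpoint, key, and model name below are placeholders, not values from this commit):

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

client = ChatCompletionsClient(
    endpoint="https://<your-resource>.services.ai.azure.com/models",  # placeholder endpoint
    credential=AzureKeyCredential("<your-api-key>"),                  # placeholder key
)
response = client.complete(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the solution for game of 24 for 2 3 5 12?"),
    ],
    model="DeepSeek-R1",  # deployment name is an assumption
    temperature=0.6,
    max_tokens=1000,
)
print(response.choices[0].message.content)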
