# The MIT License (MIT)
# Copyright © 2024 Yuma Rao

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.

# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import time
import os
import argparse
from functools import partial
from typing import Dict, Awaitable

import bittensor as bt
from starlette.types import Send
from dotenv import load_dotenv, find_dotenv
from deprecated import deprecated

from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.base import RunnableSequence

# Bittensor miner template: the base class takes care of most of the boilerplate.
from prompting.base.prompting_miner import BaseStreamPromptingMiner
from prompting.protocol import StreamPromptingSynapse
from prompting.miners.utils import OpenAIUtils

@deprecated(
    version="2.4.1+",
    reason="Class is deprecated; use the OpenAI miner as a reference example miner.",
)
class LangchainMiner(BaseStreamPromptingMiner, OpenAIUtils):
    """Langchain-based miner that uses OpenAI's API as the LLM.

    This miner does not use any tools or external APIs when processing requests; it
    relies entirely on the model's own representation and world model, which in some
    cases can produce lower-quality results. Install this miner's dependencies from
    the requirements.txt file in this directory.
    """

    @classmethod
    def add_args(cls, parser: argparse.ArgumentParser):
        """Adds miner-specific arguments to the command-line parser."""
        super().add_args(parser)

    def __init__(self, config=None):
        super().__init__(config=config)

        bt.logging.info(f"Initializing with model {self.config.neuron.model_id}...")

        if self.config.wandb.on:
            self.identity_tags = ("openai_miner",) + (self.config.neuron.model_id,)

        # Load the OpenAI API key from the environment (e.g. a local .env file).
        _ = load_dotenv(find_dotenv())
        api_key = os.environ.get("OPENAI_API_KEY")

        # Configure the chat model with the generation settings from the CLI config.
        self.model = ChatOpenAI(
            api_key=api_key,
            model_name=self.config.neuron.model_id,
            max_tokens=self.config.neuron.max_tokens,
            temperature=self.config.neuron.temperature,
        )

        self.system_prompt = self.config.neuron.system_prompt

        # Running totals used for token and cost accounting.
        self.accumulated_total_tokens = 0
        self.accumulated_prompt_tokens = 0
        self.accumulated_completion_tokens = 0
        self.accumulated_total_cost = 0
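
    # NOTE: `forward` does not generate text itself; it returns a streaming
    # response whose body is produced by the nested `_forward` coroutine,
    # which Starlette drives through the `send` callable.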
    def forward(self, synapse: StreamPromptingSynapse) -> Awaitable:
        async def _forward(
            self,
            message: str,
            init_time: float,
            timeout_threshold: float,
            chain: RunnableSequence,
            chain_formatter: Dict[str, str],
            send: Send,
        ):
            buffer = []
            temp_completion = ""  # for wandb logging
            timeout_reached = False

            try:
                # Langchain's built-in streaming; 'astream' is also available for async.
                for token in chain.stream(chain_formatter):
                    buffer.append(token)

                    if time.time() - init_time > timeout_threshold:
                        bt.logging.debug("⏰ Timeout reached, stopping streaming")
                        timeout_reached = True
                        break

                    # Flush the buffer to the client once it reaches the batch size.
                    if len(buffer) == self.config.neuron.streaming_batch_size:
                        joined_buffer = "".join(buffer)
                        temp_completion += joined_buffer
                        bt.logging.debug(f"Streamed tokens: {joined_buffer}")

                        await send(
                            {
                                "type": "http.response.body",
                                "body": joined_buffer.encode("utf-8"),
                                "more_body": True,
                            }
                        )
                        buffer = []

                # Don't send the last buffer of data if the timeout was reached.
                if buffer and not timeout_reached:
                    joined_buffer = "".join(buffer)
                    await send(
                        {
                            "type": "http.response.body",
                            "body": joined_buffer.encode("utf-8"),
                            "more_body": False,
                        }
                    )

            except Exception as e:
                bt.logging.error(f"Error in forward: {e}")
                if self.config.neuron.stop_on_forward_exception:
                    self.should_exit = True

            finally:
                synapse_latency = time.time() - init_time
                if self.config.wandb.on:
                    self.log_event(
                        timing=synapse_latency,
                        prompt=message,
                        completion=temp_completion,
                        system_prompt=self.system_prompt,
                    )

        bt.logging.debug(f"📧 Message received, forwarding synapse: {synapse}")

        # Compose an LCEL pipeline: prompt template -> chat model -> string output.
        prompt = ChatPromptTemplate.from_messages(
            [("system", self.system_prompt), ("user", "{input}")]
        )
        chain = prompt | self.model | StrOutputParser()
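
        # For reference, the chain can also be called synchronously outside the
        # streaming path: chain.invoke({"input": "Hello"}) returns the whole
        # completion as a single string.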

        # Respond to the most recent message in the conversation.
        role = synapse.roles[-1]
        message = synapse.messages[-1]

        chain_formatter = {"role": role, "input": message}

        init_time = time.time()
        timeout_threshold = synapse.timeout

        token_streamer = partial(
            _forward,
            self,
            message,
            init_time,
            timeout_threshold,
            chain,
            chain_formatter,
        )
        return synapse.create_streaming_response(token_streamer)
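

# A minimal sketch of running this miner as a standalone process, following the
# pattern used by the other example miners in this repo. The context-manager
# usage and the `log_status` call are assumptions about the
# BaseStreamPromptingMiner API; adjust if the base class differs.
if __name__ == "__main__":
    with LangchainMiner() as miner:
        while True:
            miner.log_status()
            time.sleep(5)
            if miner.should_exit:
                bt.logging.warning("Ending miner...")
                break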