```diff
@@ -37,6 +37,7 @@
 import test_config
 
 from muagent.schemas.db import *
+from muagent.schemas.apis.ekg_api_schema import LLMFCRequest
 from muagent.db_handler import *
 from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
 from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService
@@ -46,7 +47,8 @@
 
 from pydantic import BaseModel
 
-
+from muagent.schemas.models import ModelConfig
+from muagent.models import get_model
 
 
 cur_dir = os.path.dirname(__file__)
```
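The two new imports expose muagent's model registry. As a rough, hedged sketch of how they fit together — the `model_type` strings come from the lists used later in this diff, while the model name, key, and URL are placeholders, not values from the test:

```python
from muagent.schemas.models import ModelConfig
from muagent.models import get_model

# Placeholder values; only the field names are taken from this diff.
config = ModelConfig(
    model_type="openai_chat",             # one of the *_chat registry types listed below
    model_name="gpt-4o-mini",
    api_key="sk-...",
    api_url="https://api.openai.com/v1",
    temperature=0.2,
)
model = get_model(config)                 # returns a chat-capable model client
print(model.predict("Reply with one word: pong"))
```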
```diff
@@ -92,56 +94,71 @@ def update_params(self, **kwargs):
 
     def _llm_type(self, *args):
         return ""
-
-    def predict(self, prompt: str, stop = None) -> str:
-        return self._call(prompt, stop)
-
-    def _call(self, prompt: str,
-              stop = None) -> str:
+
+    def _get_model(self):
         """_call
         """
-        return_str = ""
-        stop = stop or self.stop
-
-        if self.model_type == "ollama":
-            stream = ollama.chat(
-                model=self.model_name,
-                messages=[{'role': 'user', 'content': prompt}],
-                stream=True,
-            )
-            answer = ""
-            for chunk in stream:
-                answer += chunk['message']['content']
-
-            return answer
-        elif self.model_type == "openai":
+        if self.model_type in [
+            "ollama", "qwen", "openai", "lingyiwanwu",
+            "kimi", "moonshot",
+        ]:
             from muagent.llm_models.openai_model import getChatModelFromConfig
             llm_config = LLMConfig(
                 model_name=self.model_name,
-                model_engine="openai",
+                model_engine=self.model_type,
                 api_key=self.api_key,
                 api_base_url=self.url,
                 temperature=self.temperature,
                 stop=self.stop
             )
             model = getChatModelFromConfig(llm_config)
-            return model.predict(prompt, stop=self.stop)
-        elif self.model_type in ["lingyiwanwu", "kimi", "moonshot", "qwen"]:
-            from muagent.llm_models.openai_model import getChatModelFromConfig
-            llm_config = LLMConfig(
+        else:
+            model_config = ModelConfig(
+                model_type=self.model_type,
                 model_name=self.model_name,
-                model_engine=self.model_type,
                 api_key=self.api_key,
-                api_base_url=self.url,
+                api_url=self.url,
                 temperature=self.temperature,
-                stop=self.stop
             )
-            model = getChatModelFromConfig(llm_config)
-            return model.predict(prompt, stop=self.stop)
-        else:
-            pass
+            model = get_model(model_config)
+        return model
+
+    def predict(self, prompt: str, stop = None) -> str:
+        return self._call(prompt, stop)
 
-        return return_str
+    def fc(self, request: LLMFCRequest) -> str:
+        """_function_call
+        """
+        if self.model_type not in [
+            "openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen"
+        ]:
+            return f"{self.model_type} not in valid model range"
+
+        model = self._get_model()
+        return model.fc(
+            messages=request.messages,
+            tools=request.tools,
+            tool_choice=request.tool_choice,
+            parallel_tool_calls=request.parallel_tool_calls,
+        )
+
+    def _call(self, prompt: str,
+              stop = None) -> str:
+        """_call
+        """
+        stop = stop or self.stop
+        # single guard replacing the old two-step check, whose error branch
+        # was unreachable: reject any model_type that _get_model cannot route
+        if self.model_type not in [
+            "openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen",
+            "dashscope_chat", "moonshot_chat", "ollama_chat",
+            "openai_chat", "qwen_chat", "yi_chat",
+            "dashscope_text_embedding", "ollama_embedding", "openai_embedding", "qwen_text_embedding",
+        ]:
+            return f"{self.model_type} not in valid model range"
+
+        model = self._get_model()
+        return model.predict(prompt, stop=stop)
 
 
 class CustomEmbeddings(Embeddings):
```
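Taken together, the rewritten wrapper routes the legacy engine names ("openai", "ollama", "qwen", ...) through `getChatModelFromConfig` and every other `model_type` through `get_model(ModelConfig(...))`, and it now exposes function calling via `fc`. A hedged usage sketch — the class name `CustomLLM` is an assumption (the counterpart to the `CustomEmbeddings` class below), as is the keyword-argument constructor and the OpenAI-style shape of `LLMFCRequest.messages`/`tools`:

```python
# Legacy engine name, so _get_model takes the getChatModelFromConfig branch;
# a *_chat registry type would take the get_model(ModelConfig(...)) branch,
# though note that fc() as written only accepts the legacy names.
llm = CustomLLM(
    model_type="openai",
    model_name="gpt-4o-mini",                 # placeholder
    api_key="sk-...",                         # placeholder
    url="https://api.openai.com/v1",
)

# Plain completion: predict -> _call -> _get_model().predict(...)
print(llm.predict("Say hello in one word."))

# Function calling reuses the same client via fc(...)
request = LLMFCRequest(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    tool_choice="auto",
    parallel_tool_calls=False,
)
print(llm.fc(request))
```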
```diff
@@ -185,6 +206,17 @@ def _get_sentence_emb(self, sentence: str) -> dict:
             )
             text2vector_dict = get_embedding("openai", [sentence], embed_config=embed_config)
             return text2vector_dict[sentence]
+        elif self.embedding_type in [
+            "dashscope_text_embedding", "ollama_embedding", "openai_embedding", "qwen_text_embedding"
+        ]:
+            model_config = ModelConfig(
+                model_type=self.embedding_type,
+                model_name=self.model_name,
+                api_key=self.api_key,
+                api_url=self.url,
+            )
+            model = get_model(model_config)
+            return model.embed_query(sentence)
         else:
             pass
 
```
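The embedding side gets the same registry treatment: any of the four `*_embedding` types is turned into a `ModelConfig` and served by `get_model`. A minimal sketch of that branch in isolation, with placeholder credentials (`embed_query` returning a list of floats is inferred from the `Embeddings` interface this class implements):

```python
embed_config = ModelConfig(
    model_type="openai_embedding",
    model_name="text-embedding-ada-002",      # placeholder embedding model
    api_key="sk-...",
    api_url="https://api.openai.com/v1",
)
embedder = get_model(embed_config)
vector = embedder.embed_query("knowledge graphs link entities and relations")
print(len(vector))                            # dimensionality of the embedding
```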
```diff
@@ -280,6 +312,7 @@ def embed_query(self, text: str) -> List[float]:
     llm_config=llm_config,
     tb_config=tb_config,
     gb_config=gb_config,
+    initialize_space=True,
     clear_history_data=clear_history_data
 )
 
```