
Commit 705bef6

Polish Ollama support
Committed Nov 15, 2024
1 parent 3d90631 · commit 705bef6

6 files changed: +45 −50 lines changed
 

README.md (+13 −3)

````diff
@@ -101,7 +101,7 @@ The experiment results will be stored in the directory named `results_low_level`
 ## Support for New Models
 We rely on LangChain to provide a common interface to access different model APIs.
 You can add new supported models in the `netconfeval/common/model_configs.py` file.
-We currently support OpenAI models (`'type': 'openai'`) and HuggingFace models (`'type': 'HF'`) through a custom LangChain-compatible class (`netconfeval/foundation/langchain/hf.py`).
+We currently support OpenAI models (`'type': 'openai'`), Ollama models (`'type': 'Ollama'`), and HuggingFace models (`'type': 'HF'`) through a custom LangChain-compatible class (`netconfeval/foundation/langchain/hf.py`).
 
 To add a model, just add a new Dict element to the `model_configurations` Dict, by providing a unique key for it.
 The new model key is then automatically visible using the `--model` command line parameter of the `.py` tests of the benchmarks.
@@ -131,6 +131,16 @@ model_configurations = {
 }
 ```
 
+### Ollama Models
+The Ollama model Dict contains the following keys:
+```python
+{
+    'type': 'Ollama', # The type of the model, in this case 'Ollama'
+    'model_name': 'llama3:8b-instruct-fp16', # The model name taken from the Ollama library
+    'num_predict': 4096, # Max output length
+}
+```
+
 ### HuggingFace Models
 The HuggingFace model Dict contains the following keys:
 ```python
@@ -158,7 +168,7 @@ model_configurations = {
 ```
 
 ### Adding new model types
-Aside from adding OpenAI and HuggingFace models, it is also possible to add new model types (for example Gemini by Google).
+Aside from adding OpenAI, Ollama, and HuggingFace models, it is also possible to add new model types (for example Gemini by Google).
 
 We will continuously improve support for different APIs, but if you want to contribute:
 - Define a new `type`, coherent with the model types (e.g., `google` for Google models);
@@ -187,4 +197,4 @@ If you use NetConfEval, please cite our paper:
 ```
 
 ## Help
-If you have any questions regarding our code or the paper, you can contact [Changjie Wang](https://www.kth.se/profile/changjie) (changjie at kth.se) and/or [Mariano Scazzariello](https://www.kth.se/profile/marianos) (mariano.scazzariello at ri.se).
+If you have any questions regarding our code or the paper, you can contact [Changjie Wang](https://www.kth.se/profile/changjie) (changjie at kth.se) and/or [Mariano Scazzariello](https://www.ri.se/en/person/mariano-scazzariello) (mariano.scazzariello at ri.se).
````
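Following the documented pattern, registering one more locally served model is a one-entry change. A minimal sketch (the `mistral-ollama` key and `mistral:7b-instruct` tag are hypothetical examples, not part of this commit):

```python
from netconfeval.common.model_configs import model_configurations

# Hypothetical entry following the documented Ollama pattern; the key and
# tag below are illustrative only.
model_configurations['mistral-ollama'] = {
    'type': 'Ollama',                     # routes get_model_instance() to the Ollama branch
    'model_name': 'mistral:7b-instruct',  # any tag available in the local Ollama library
    'num_predict': 4096,                  # max output length
}
```

The new key then becomes selectable as `--model mistral-ollama` in the benchmark scripts, as the README describes.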

netconfeval/common/model_configs.py (+28 −43)
```diff
@@ -106,11 +106,10 @@ def _build_mistral_lite_prompt(messages):
 
     return prompt
 
+
 def _build_qwen2_prompt(messages):
-
     start_turn = "<|im_start|>"
     end_turn = "<|im_end|>\n"
-
 
     conversation = []
     for index, message in enumerate(messages):
@@ -122,14 +121,14 @@ def _build_qwen2_prompt(messages):
         content = content.strip()
 
         if role.lower() in ['user', 'system', 'assistant']:
-          conversation.append(start_turn + role.lower() + '\n' + content + end_turn)
+            conversation.append(start_turn + role.lower() + '\n' + content + end_turn)
         else:
             raise ValueError(f"Unexpected role: {role}")
 
-        # Assemble the prompt with the start and end tokens and start the turn of assistant to prime the generation process
-        prompt = ' '.join(conversation) + start_turn + "assistant\n"
+    # Assemble the prompt with the start and end tokens and start the turn of assistant to prime the generation process
+    prompt = ' '.join(conversation) + start_turn + "assistant\n"
 
-        return prompt
+    return prompt
 
 
 def _build_llama3_prompt(messages):
```
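The indentation fix above is more than cosmetic: the diff suggests the assembly and `return` previously sat inside the `for` loop, which would cut the prompt off after the first message. A standalone sketch of the corrected ChatML assembly (messages simplified to plain dicts; the real builder receives LangChain message objects):

```python
# ChatML delimiters, as in _build_qwen2_prompt
start_turn = "<|im_start|>"
end_turn = "<|im_end|>\n"

messages = [
    {"role": "system", "content": "You are a network configuration assistant."},
    {"role": "user", "content": "Translate the requirement into the formal specification."},
]

# One <|im_start|>role\ncontent<|im_end|> block per message
conversation = [
    start_turn + m["role"] + "\n" + m["content"].strip() + end_turn
    for m in messages
]

# Assembled once, after all messages, ending with an open assistant turn
# to prime generation
prompt = ' '.join(conversation) + start_turn + "assistant\n"
print(prompt)
```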
```diff
@@ -148,18 +147,17 @@ def _build_llama3_prompt(messages):
         content = content.strip()
 
         if role.lower() in ['user', 'system', 'assistant']:
-            conversation.append(start_role + role.lower() + end_role + content + end_turn ) # Example: <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+            # Example: <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+            conversation.append(start_role + role.lower() + end_role + content + end_turn)
         else:
             raise ValueError(f"Unexpected role: {role}")
 
-    # Assemble the prompt with the start and end tokens and start the turn of assistant to prime the generation process
-    prompt = start_prompt + ' '.join(conversation) + '\n' + start_role + "assistant" + end_role
-    print(prompt)
+    prompt = start_prompt + ' '.join(conversation) + '\n' + start_role + "assistant" + end_role
+
     return prompt
 
 
 def get_model_instance(model_name: str) -> Any:
-
     if model_configurations[model_name]['type'] == 'HF':
         from netconfeval.foundation.langchain.chat_models.hf import ChatHF
 
```
```diff
@@ -169,15 +167,16 @@ def get_model_instance(model_name: str) -> Any:
             use_quantization=model_configurations[model_name]['use_quantization'],
             prompt_func=model_configurations[model_name]['prompt_builder'],
         )
-
+
     elif model_configurations[model_name]['type'] == 'Ollama':
         from langchain_community.llms import Ollama
 
-        return Ollama(model = model_configurations[model_name]['model_name'],
-                      num_predict = model_configurations[model_name]['num_predict'],
-                      num_gpu=-1
-                      )
-
+        return Ollama(
+            model=model_configurations[model_name]['model_name'],
+            num_predict=model_configurations[model_name]['num_predict'],
+            num_gpu=-1
+        )
+
     elif model_configurations[model_name]['type'] == 'openai':
         from langchain_openai import ChatOpenAI
 
```
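A quick smoke test of the reformatted branch; a sketch assuming a local Ollama server is running and the tag has been pulled (e.g. `ollama pull llama3:8b-instruct-fp16`):

```python
from netconfeval.common.model_configs import get_model_instance

# 'llama3-ollama' is one of the keys defined in model_configurations below
llm = get_model_instance('llama3-ollama')

# langchain_community's Ollama is a plain LLM: string prompt in, string out
print(llm.invoke("Reply with the single word: pong"))
```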

```diff
@@ -190,6 +189,7 @@ def get_model_instance(model_name: str) -> Any:
 
 
 model_configurations = {
+    # OpenAI Models
     'gpt-3.5-turbo': {
        'model_name': 'gpt-3.5-turbo',
        'type': 'openai',
```
```diff
@@ -226,74 +226,59 @@ def get_model_instance(model_name: str) -> Any:
             'seed': 5000,
         }
     },
-
-    #### Start Ollama models ####
-
+    # Ollama Models
     'llama3.1-ollama': {
         'type': 'Ollama',
         'model_name': 'llama3.1:8b-instruct-fp16',
         'num_predict': 4096
     },
-
-    'llama3-ollama': {
+    'llama3.1-4bit-ollama': {
         'type': 'Ollama',
-        'model_name': 'llama3:8b-instruct-fp16',
+        'model_name': 'llama3.1:latest',
         'num_predict': 4096
     },
-
-    'neural-chat-ollama': {
+    'llama3-ollama': {
         'type': 'Ollama',
-        'model_name': 'neural-chat:7b-v3.3-fp16',
+        'model_name': 'llama3:8b-instruct-fp16',
         'num_predict': 4096
     },
-
-    # 4-bit quantization version
-
-    'llama3.1-4bit-ollama': {
+    'llama3-4bit-ollama': {
         'type': 'Ollama',
-        'model_name': 'llama3.1:latest',
+        'model_name': 'llama3:latest',
         'num_predict': 4096
     },
-
-    'llama3-4bit-ollama': {
+    'neural-chat-ollama': {
         'type': 'Ollama',
-        'model_name': 'llama3:latest',
+        'model_name': 'neural-chat:7b-v3.3-fp16',
         'num_predict': 4096
     },
-
     'neural-chat-4bit-ollama': {
         'type': 'Ollama',
         'model_name': 'neural-chat:latest',
         'num_predict': 4096
     },
-
-    #### End Ollama models #####
-
-
+    # HuggingFace Models
     'qwen2.5-7b-instruct': {
         'type': 'HF',
-        'model_name':'Qwen/Qwen2.5-7B-Instruct',
+        'model_name': 'Qwen/Qwen2.5-7B-Instruct',
         'prompt_builder': _build_qwen2_prompt,
         'max_length': 4096,
         'use_quantization': False
     },
-
     'llama3-8b-instruct': {
         'type': 'HF',
         'model_name': 'meta-llama/Meta-Llama-3-8B-Instruct',
         'prompt_builder': _build_llama3_prompt,
         'max_length': 4096,
         'use_quantization': False
     },
-
     'llama3.1-8b-instruct': {
         'type': 'HF',
         'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',
         'prompt_builder': _build_llama3_prompt,
         'max_length': 4096,
         'use_quantization': False
     },
-
     'llama2-7b-chat': {
         'model_name': 'meta-llama/Llama-2-7b-chat-hf',
         'prompt_builder': _build_llama2_prompt,
```
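Since `model_configurations` is a plain Dict, the reorganized block is easy to sanity-check by enumeration; the expected output mirrors the insertion order above. Note that the `-4bit-` keys map to `:latest` tags, which in the Ollama library typically resolve to 4-bit quantized builds (hence the old `# 4-bit quantization version` comment):

```python
from netconfeval.common.model_configs import model_configurations

# Dicts preserve insertion order, so this mirrors the layout above
ollama_keys = [key for key, cfg in model_configurations.items()
               if cfg.get('type') == 'Ollama']
print(ollama_keys)
# ['llama3.1-ollama', 'llama3.1-4bit-ollama', 'llama3-ollama',
#  'llama3-4bit-ollama', 'neural-chat-ollama', 'neural-chat-4bit-ollama']
```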

netconfeval/step_1_formal_spec_conflict_detection.py (+1 −1)

```diff
@@ -143,7 +143,7 @@ def main(args: argparse.Namespace) -> None:
 
     skip_compare = False
     start_time = time.time()
-    if model_configurations[args.model]['type'] in ['HF','Ollama']:
+    if model_configurations[args.model]['type'] in ['HF', 'Ollama']:
         # Combine all system prompts with a new line separator
         combined_system_prompt = f"{SETUP_PROMPT}\n{FUNCTION_PROMPT}\n{ASK_FOR_RESULT_PROMPT}"
         messages = [
```
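The touched line guards a branch that flattens the benchmark's separate system prompts into a single system message for HF and Ollama backends, which take one system prompt rather than several chat turns. A sketch of the resulting shape, with placeholder prompt texts and an assumed tuple-style message format (the real constants and message classes live in the script):

```python
# Placeholder prompt contents; the real constants are defined in the script
SETUP_PROMPT = "You translate network requirements into a formal specification."
FUNCTION_PROMPT = "Express the answer as a call to the provided function."
ASK_FOR_RESULT_PROMPT = "Return only the result."

# One newline-joined system message instead of multiple chat turns
combined_system_prompt = f"{SETUP_PROMPT}\n{FUNCTION_PROMPT}\n{ASK_FOR_RESULT_PROMPT}"
messages = [
    ("system", combined_system_prompt),
    ("human", "{input}"),
]
```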

netconfeval/step_1_formal_spec_conflict_distance.py (+1 −1)

```diff
@@ -137,7 +137,7 @@ def main(args: argparse.Namespace) -> None:
 
     skip_compare = False
     start_time = time.time()
-    if model_configurations[args.model]['type'] in ['HF','Ollama']:
+    if model_configurations[args.model]['type'] in ['HF', 'Ollama']:
         # Combine all system prompts with a new line separator
         combined_system_prompt = f"{SETUP_PROMPT}\n{FUNCTION_PROMPT}\n{ASK_FOR_RESULT_PROMPT}"
         messages = [
```

netconfeval/step_1_formal_spec_translation.py (+1 −1)

```diff
@@ -135,7 +135,7 @@ def main(args: argparse.Namespace) -> None:
 
     skip_compare = False
     start_time = time.time()
-    if model_configurations[args.model]['type'] in ['HF','Ollama']:
+    if model_configurations[args.model]['type'] in ['HF', 'Ollama']:
         # Combine all system prompts with a new line separator
         combined_system_prompt = f"{SETUP_PROMPT}\n{FUNCTION_PROMPT}\n{ASK_FOR_RESULT_PROMPT}"
         messages = [
```

netconfeval/step_2_code_gen.py (+1 −1)

```diff
@@ -126,7 +126,7 @@ def main(args: argparse.Namespace) -> None:
         w = csv.DictWriter(f, result_row.keys())
         w.writeheader()
 
-    if model_configurations[args.model]['type'] in ['HF','Ollama']:
+    if model_configurations[args.model]['type'] in ['HF', 'Ollama']:
         combined_system_prompt = f"{SETUP_PROMPT}\n{ASK_FOR_CODE_PROMPT}"
         combined_human_prompt = f"{INPUT_OUTPUT_PROMPT}\n{INSTRUCTION_PROMPT}\n{{input}}"
         if with_feedback:
```
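One detail worth noting in this hunk: the doubled braces in the f-string survive formatting, leaving a literal `{input}` slot for the LangChain prompt template to fill in later. A minimal demonstration (prompt texts are placeholders):

```python
# Placeholder contents; the real constants live in the script
INPUT_OUTPUT_PROMPT = "The function receives a specification and must return code."
INSTRUCTION_PROMPT = "Write the Python function."

combined_human_prompt = f"{INPUT_OUTPUT_PROMPT}\n{INSTRUCTION_PROMPT}\n{{input}}"
# {{ and }} escape the braces, so the template slot survives verbatim
assert combined_human_prompt.endswith("{input}")
```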
