1 file changed: +18 −6 lines changed
@@ -1,3 +1,4 @@
+
 # isort: skip_file
 import copy
 import warnings
@@ -16,6 +17,10 @@
 logger = logging.get_logger(__name__)
 
 
+import argparse
+
+
+
 @dataclass
 class GenerationConfig:
     # this config is used for chat to provide more diversity
@@ -163,9 +168,13 @@ def on_btn_click():
 
 
 @st.cache_resource
-def load_model():
-    model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct').cuda()
-    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct', trust_remote_code=True)
+def load_model(arg1):
+    # model = AutoModelForCausalLM.from_pretrained(args.m).cuda()
+    # tokenizer = AutoTokenizer.from_pretrained(args.m, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(arg1, torch_dtype=torch.float16).cuda()
+    tokenizer = AutoTokenizer.from_pretrained(arg1, trust_remote_code=True)
+
+
     return model, tokenizer
 
 
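For context on the two changes in this hunk: Streamlit re-executes the entire script on every user interaction, and @st.cache_resource memoizes the decorated function per argument value, so the model load runs once per server process rather than on every rerun (and a different arg1 triggers a fresh load). Separately, torch_dtype=torch.float16 loads the weights at 2 bytes per parameter instead of the fp32 default's 4, roughly halving GPU memory for an 8B-parameter model. A minimal sketch of the caching behavior, runnable with streamlit run (the expensive_resource name is hypothetical):

import time

import streamlit as st


@st.cache_resource
def expensive_resource(key):
    # Stand-in for the model load: runs once per distinct `key` per server
    # process; later reruns get the same cached object back.
    time.sleep(2)
    return {'key': key, 'created_at': time.time()}


resource = expensive_resource('demo')
st.write(resource)  # 'created_at' stays constant across reruns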
@@ -207,10 +216,10 @@ def combine_history(prompt):
     return total_prompt
 
 
-def main():
+def main(arg1):
     # torch.cuda.empty_cache()
     print('load model begin.')
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(arg1)
     print('load model end.')
 
     st.title('Llama3-Instruct')
@@ -259,4 +268,7 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+
+    import sys
+    arg1 = sys.argv[1]
+    main(arg1)
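Two notes on this entry point, both assumptions beyond what the diff shows. First, streamlit run consumes its own command-line flags, so the model path must be passed after a -- separator to reach sys.argv, e.g. streamlit run web_demo.py -- meta-llama/Meta-Llama-3-8B-Instruct (the filename web_demo.py is hypothetical). Second, the patch imports argparse at the top of the file and the commented-out lines in load_model reference args.m, which suggests an argparse-based variant was considered; a minimal sketch of that alternative, relying on the file's existing main:

import argparse

# Hypothetical argparse-based entry point, inferred from the `import argparse`
# and the commented-out `args.m` references elsewhere in this patch.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Llama3-Instruct Streamlit demo')
    parser.add_argument('model_path',
                        help='Hugging Face model id or local checkpoint path')
    args = parser.parse_args()
    main(args.model_path)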