Skip to content

Commit e25ad2d

Browse files
authored
Update web_demo.py
1 parent 30a9c60 commit e25ad2d

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

web_demo.py

+18 -6
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
12
# isort: skip_file
23
import copy
34
import warnings
@@ -16,6 +17,10 @@
1617
logger = logging.get_logger(__name__)
1718

1819

20+
import argparse
21+
22+
23+
1924
@dataclass
2025
class GenerationConfig:
2126
# this config is used for chat to provide more diversity
@@ -163,9 +168,13 @@ def on_btn_click():
163168

164169

165170
@st.cache_resource
166-
def load_model():
167-
model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct').cuda()
168-
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct', trust_remote_code=True)
171+
def load_model(arg1):
172+
# model = AutoModelForCausalLM.from_pretrained(args.m).cuda()
173+
# tokenizer = AutoTokenizer.from_pretrained(args.m, trust_remote_code=True)
174+
model = AutoModelForCausalLM.from_pretrained(arg1, torch_dtype=torch.float16).cuda()
175+
tokenizer = AutoTokenizer.from_pretrained(arg1, trust_remote_code=True)
176+
177+
169178
return model, tokenizer
170179

171180

@@ -207,10 +216,10 @@ def combine_history(prompt):
207216
return total_prompt
208217

209218

210-
def main():
219+
def main(arg1):
211220
# torch.cuda.empty_cache()
212221
print('load model begin.')
213-
model, tokenizer = load_model()
222+
model, tokenizer = load_model(arg1)
214223
print('load model end.')
215224

216225
st.title('Llama3-Instruct')
@@ -259,4 +268,7 @@ def main():
259268

260269

261270
if __name__ == '__main__':
262-
main()
271+
272+
import sys
273+
arg1 = sys.argv[1]
274+
main(arg1)

0 commit comments

Comments
 (0)