diff --git a/example.py b/example.py
index b32c8c8..547fa37 100644
--- a/example.py
+++ b/example.py
@@ -1,11 +1,14 @@
-import torch
-from transformers import AutoTokenizer
+# python>=3.10
 
-from smoe.models.llama_moe import LlamaMoEForCausalLM
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-model_dir = "/mnt/petrelfs/share_data/quxiaoye/runs/llama2_random_split_112gpus_16_2/outputs/cpt-llama2_random_split_112gpus_16_2_scale_factor_8-2342244/checkpoint-13600/"
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-model = LlamaMoEForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
+model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-2_8"
+tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
+)
+model.eval()
 model.to("cuda:0")
 
 input_text = "Suzhou is famous of"