Skip to content

Commit e20aaa8

Browse files
committed
transformers-generate-from_pretrained.py
1 parent 5e59929 commit e20aaa8

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
"""Load a fine-tuned causal LM checkpoint and run a short generation demo.

Loads the base ``gpt2`` tokenizer and a locally fine-tuned causal language
model from ``results/checkpoint-4``, tokenizes a tiny fixed prompt, runs one
forward pass, and prints a short greedy generation.
"""

print("Loading...")  # emitted before the heavy imports so startup isn't silent

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM

# Checkpoint produced by a previous Trainer run. NOTE(review): the tokenizer
# apparently was not saved alongside the checkpoint, so the base "gpt2"
# tokenizer is loaded instead — confirm it matches the fine-tuned vocab.
MODEL_DIR = "results/checkpoint-4"


def main() -> None:
    """Tokenize the prompt, run a forward pass, and print a generation."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # For a classification head use AutoModelForSequenceClassification instead.
    model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
    model.eval()  # disable dropout etc. for deterministic inference

    text = "0 1"
    inputs = tokenizer(text, return_tensors="pt")
    print(inputs)

    # Forward pass without gradient tracking (inference only); kept as a
    # sanity check that the checkpoint loads and runs.
    with torch.no_grad():
        outputs = model(**inputs)  # noqa: F841 — logits available for inspection

    # Greedy generation. max_length counts the prompt tokens too, so only a
    # few new tokens are produced. Passing attention_mask avoids transformers
    # warning about (and possibly mis-inferring) padding.
    generated = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=4,
    )
    generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(generated_text)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)