Skip to content

Commit 3a66fd9

Browse files
committed
Small fixup
1 parent a6c575d commit 3a66fd9

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

transformer/data/shakespeare_char/prepare.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
from transformers import AutoTokenizer
66
from pathlib import Path
77

8+
BASE_DIR = Path(__file__).parent
9+
10+
PYTHON_URL = 'hf://datasets/iamtarun/python_code_instructions_18k_alpaca/data/train-00000-of-00001-8b6e212f3e1ece96.parquet'
11+
SHAKESPEARE_URL = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
12+
813
def main():
9-
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
10-
shakespeare_text = requests.get(data_url).text
14+
shakespeare_text = requests.get(SHAKESPEARE_URL).text
1115

1216
# Add in some Python code training data so the model learns both Shakespeare and Python
13-
df = pd.read_parquet(
14-
'hf://datasets/iamtarun/python_code_instructions_18k_alpaca/data/train-00000-of-00001-8b6e212f3e1ece96.parquet'
15-
)
17+
df = pd.read_parquet(PYTHON_URL)
1618
python_code = '\n###\n'.join(df['output'].dropna().astype(str))
1719
python_code = python_code.encode('ascii', 'ignore').decode() # there are a few non-ASCII characters but I don't want to deal with them
1820

@@ -37,8 +39,8 @@ def main():
3739
# export to bin files
3840
train_ids = np.array(train_ids, dtype=np.uint16)
3941
val_ids = np.array(val_ids, dtype=np.uint16)
40-
train_ids.tofile(Path(__file__).parent / 'train.bin')
41-
val_ids.tofile(Path(__file__).parent / 'val.bin')
42+
train_ids.tofile(BASE_DIR / 'train.bin')
43+
val_ids.tofile(BASE_DIR / 'val.bin')
4244

4345
# save the meta information as well, to help us encode/decode later
4446
meta = {
@@ -47,7 +49,7 @@ def main():
4749
'encode': new_tokenizer.encode,
4850
'decode': new_tokenizer.decode,
4951
}
50-
with open(Path(__file__).parent / 'meta.pkl', 'wb') as f:
52+
with open(BASE_DIR / 'meta.pkl', 'wb') as f:
5153
pickle.dump(meta, f)
5254

5355
if __name__ == '__main__':

0 commit comments

Comments
 (0)