Skip to content

Commit 1199c20

Browse files
committed
Cleanup
1 parent b0cdf6b commit 1199c20

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,5 @@ venv.bak/
109109
/site
110110

111111
# mypy
112-
.mypy_cache/
112+
.mypy_cache/cloze_data
113+
cloze_data/

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ The ROCStories dataset can be downloaded from the associated [website](http://cs
4949
As with the [TensorFlow code](https://github.com/openai/finetune-transformer-lm), this code implements the ROCStories Cloze Test result reported in the paper which can be reproduced by running:
5050

5151
```bash
52+
python -m spacy download en
5253
python train.py --dataset rocstories --desc rocstories --submit --analysis --data_dir [path to data here]
5354
```
5455

model_py.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import re
2-
import math
3-
import json
41
import copy
5-
import numpy as np
2+
import json
3+
import math
4+
import re
65

6+
import numpy as np
77
import torch
88
import torch.nn as nn
99
import torch.nn.functional as F
1010
from torch.nn.parameter import Parameter
1111

12+
1213
def gelu(x):
1314
return 0.5*x*(1+torch.tanh(math.sqrt(2/math.pi)*(x+0.044715*torch.pow(x, 3))))
1415

@@ -63,7 +64,7 @@ def forward(self, x):
6364
class Attention(nn.Module):
6465
def __init__(self, nx, n_ctx, cfg, scale=False):
6566
super(Attention, self).__init__()
66-
n_state = nx # in Attention: n_state=768 (nx=n_embd)
67+
n_state = nx # in Attention: n_state=768 (nx=n_embd)
6768
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
6869
assert n_state % cfg.n_head==0
6970
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))

setup.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import sys
2+
3+
# py_version = (sys.version_info.major, sys.version_info.minor)
4+
# if py_version < (3, 6):
5+
# raise ValueError(
6+
# "This module is only compatible with Python 3.6+, but you are running "
7+
# "Python {}. We recommend installing conda and adding it to your PATH:"
8+
# "https://conda.io/docs/user-guide/install/index.html".format(py_version))
9+
10+
from setuptools import setup
11+
12+
13+
setup(
14+
name='lm',
15+
packages=['lm'],
16+
version='0.0.1',
17+
install_requires=[
18+
"ipdb",
19+
'ftfy',
20+
'spacy',
21+
'pytorch',
22+
],
23+
author='Tom B Brown',
24+
author_email='[email protected]',
25+
)

0 commit comments

Comments
 (0)