Skip to content

Commit 2e7b892

Browse files
Update attention_is_all_you_need sample readme
* Update attention_is_all_you_need sample readme
1 parent f721019 commit 2e7b892

File tree

5 files changed

+18
-39
lines changed

5 files changed

+18
-39
lines changed

PyTorch/1.13/README.md

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,22 @@ Follow the steps below to get set up with PyTorch on DirectML.
1414

1515
2. Clone this repo.
1616

17-
3. Install prerequisites
18-
```
19-
pip install torchvision==0.14.0
20-
pip install torch==1.13
21-
pip install torch-directml
22-
```
17+
3. Install **torch-directml**
2318

24-
4. _(optional)_ Run `pip list`. The following packages should be installed:
25-
```
26-
torch 1.13.0
27-
torch-directml 0.1.13.*
28-
torchvision 0.14.0
19+
>⚠️ Since torch-directml 0.1.13.1.*, **torch** and **torchvision** will be installed as dependencies
20+
21+
```ps
22+
pip install torch-directml
2923
```
3024

31-
5. Create a DML Device and Test
25+
4. Create a DML Device and Test
26+
3227
```
3328
import torch
3429
import torch_directml
3530
dml = torch_directml.device()
3631
```
37-
>⚠️ Note that device creation has changed in torch-directml 1.13 from previous versions. The torch-directml backend is currently mapped to “PrivateUse1." The new `torch_directml.device()` API is a convenient wrapper for creating your tenors on the correct device.
32+
>⚠️ Note that device creation has changed in torch-directml 0.1.13 from previous versions. The torch-directml backend is currently mapped to “PrivateUse1." The new `torch_directml.device()` API is a convenient wrapper for creating your tenors on the correct device.
3833
3934
## Samples
4035

PyTorch/1.13/attention_is_all_you_need/README.md

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,22 @@ This is a PyTorch Directml implementation of the Transformer model in "[Attentio
55

66
This sample is extracted from [pytorch benchmark](https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/attention_is_all_you_need_pytorch), and has been slightly changed to apply torch-directml.
77

8-
9-
# Requirement
10-
- python 3.8
11-
- torch 1.13
12-
- torch_directml 0.1.13
13-
- torchtext 0.14.0
14-
- spacy
15-
- tqdm
16-
- dill
17-
- numpy
18-
19-
208
# Usage
219

22-
## 1) Run install.py to download and preprocess data
10+
## 1) Install the prerequisites and prepare data
11+
2312
From inside the `attention_is_all_you_need` directory, run the following script:
2413
```ps
2514
python install.py
2615
```
2716

2817
## 2) Train the model
2918
```bash
30-
python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 128 -warmup 128000 -epoch 400 -use_dml
19+
python train.py -data_pkl .data/m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 128 -warmup 128000 -epoch 400 -use_dml
3120
```
3221

3322
## 3) Test the model
3423
```bash
35-
python translate.py -data_pkl m30k_deen_shr.pkl -model trained.chkpt -output prediction.txt -use_dml
24+
python translate.py -data_pkl .data/m30k_deen_shr.pkl -model trained.chkpt -output prediction.txt -use_dml
3625
```
3726

PyTorch/1.13/attention_is_all_you_need/install.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ def spacy_download(language):
1212

1313
def preprocess():
1414
current_dir = Path(os.path.dirname(os.path.realpath(__file__)))
15-
multi30k_data_dir = os.path.join(current_dir.parent.parent, ".data", "multi30k")
16-
root = os.path.join(str(Path(__file__).parent), ".data")
17-
os.makedirs(root, exist_ok=True)
15+
data_root = os.path.join(current_dir, ".data")
16+
os.makedirs(data_root, exist_ok=True)
1817
subprocess.check_call([sys.executable, os.path.join(current_dir, 'preprocess.py'), '-lang_src', 'de', '-lang_trg', 'en', '-share_vocab',
19-
'-save_data', os.path.join(root, 'm30k_deen_shr.pkl')])
18+
'-save_data', os.path.join(data_root, 'm30k_deen_shr.pkl')])
2019

2120
if __name__ == '__main__':
2221
pip_install_requirements()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
torchtext
12
dill==0.3.4
23
tqdm
34
iopath
45
numpy
56
spacy==2.3.5
6-
torchtext==0.14.0

PyTorch/README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,11 @@ PyTorch on DirectML is supported on both the latest versions of Windows 10 and t
1111
## Pytorch with DirectML Versions
1212
| torch-directml | pytorch |
1313
|-----------------------|-------|
14-
| [0.1.13.\*](https://pypi.org/project/torch-directml/) | 1.13 |
14+
| [0.1.13+](https://pypi.org/project/torch-directml/) | 1.13+ |
1515
| [1.8.0a0.\*](https://pypi.org/project/pytorch-directml/) | 1.8 |
1616

17-
## Setup
18-
* For users of Pytorch-DirectML forked from Pytorch __1.13__, see the setup instructions in the [1.13](./1.13/) folder.
19-
* For users of Pytorch-DirectML forked from Pytorch __1.8__, see the setup instructions in the [1.8](./1.8/) folder.
20-
2117
## Samples
22-
For users of Pytorch-DirectML forked from Pytorch 1.13, the samples can be found below or in the [1.13](./1.13/) folder:
18+
For users of Pytorch-DirectML forked from Pytorch 1.13 or higher, the samples can be found below or in the [1.13](./1.13/) folder:
2319
* [attenion is all you need- the original transformer model](./1.13/attention_is_all_you_need/)
2420
* [yolov3- a real-time object detection model](./1.13/yolov3/)
2521
* [squeezenet - a small image classification model](./1.13/squeezenet)

0 commit comments

Comments
 (0)