Skip to content

Commit 402b84a

Browse files
milancurcicjvdp1
andauthored
Generic conv & maxpool (#220)
* Generic conv constructor for specific conv1d and conv2d layers * Generic maxpool constructor for maxpool1d_layer and maxpool2d_layer * Fix arguments in 2d CNN * Update src/nf/nf_layer_constructors.f90 Co-authored-by: Jeremie Vandenplas <[email protected]> * Update src/nf/nf_layer_constructors.f90 Co-authored-by: Jeremie Vandenplas <[email protected]> * Update src/nf/nf_layer_constructors.f90 Co-authored-by: Jeremie Vandenplas <[email protected]> * Update src/nf/nf_layer_constructors.f90 Co-authored-by: Jeremie Vandenplas <[email protected]> * Add generic locally_connected wrapper around locally_connected1d --------- Co-authored-by: Jeremie Vandenplas <[email protected]>
1 parent 2ed7b6a commit 402b84a

17 files changed

+231
-220
lines changed

README.md

+3-5
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
3333
| Embedding | `embedding` | n/a | 2 |||
3434
| Dense (fully-connected) | `dense` | `input1d`, `dense`, `dropout`, `flatten` | 1 |||
3535
| Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 |||
36-
| Locally connected (1-d) | `locally_connected1d` | `input2d`, `locally_connected1d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 |||
37-
| Convolutional (1-d) | `conv1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 |||
38-
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 |||
39-
| Max-pooling (1-d) | `maxpool1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 |||
40-
| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 |||
36+
| Locally connected (1-d) | `locally_connected` | `input`, `locally_connected`, `conv`, `maxpool`, `reshape` | 2 |||
37+
| Convolutional (1-d and 2-d) | `conv` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 |||
38+
| Max-pooling (1-d and 2-d) | `maxpool` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 |||
4139
| Linear (2-d) | `linear2d` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 |||
4240
| Self-attention | `self_attention` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 |||
4341
| Layer Normalization | `layernorm` | `linear2d`, `self_attention` | 2 |||

example/cnn_mnist.f90

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist
22

33
use nf, only: network, sgd, &
4-
input, conv2d, maxpool2d, flatten, dense, reshape, &
4+
input, conv, maxpool, flatten, dense, reshape, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -21,10 +21,10 @@ program cnn_mnist
2121
net = network([ &
2222
input(784), &
2323
reshape(1, 28, 28), &
24-
conv2d(filters=8, kernel_size=3, activation=relu()), &
25-
maxpool2d(pool_size=2), &
26-
conv2d(filters=16, kernel_size=3, activation=relu()), &
27-
maxpool2d(pool_size=2), &
24+
conv(filters=8, kernel_width=3, kernel_height=3, activation=relu()), &
25+
maxpool(pool_width=2, pool_height=2, stride=2), &
26+
conv(filters=16, kernel_width=3, kernel_height=3, activation=relu()), &
27+
maxpool(pool_width=2, pool_height=2, stride=2), &
2828
dense(10, activation=softmax()) &
2929
])
3030

example/cnn_mnist_1d.f90

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist_1d
22

33
use nf, only: network, sgd, &
4-
input, conv1d, maxpool1d, flatten, dense, reshape, locally_connected1d, &
4+
input, maxpool, flatten, dense, reshape, locally_connected, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -21,10 +21,10 @@ program cnn_mnist_1d
2121
net = network([ &
2222
input(784), &
2323
reshape(28, 28), &
24-
locally_connected1d(filters=8, kernel_size=3, activation=relu()), &
25-
maxpool1d(pool_size=2), &
26-
locally_connected1d(filters=16, kernel_size=3, activation=relu()), &
27-
maxpool1d(pool_size=2), &
24+
locally_connected(filters=8, kernel_size=3, activation=relu()), &
25+
maxpool(pool_width=2, stride=2), &
26+
locally_connected(filters=16, kernel_size=3, activation=relu()), &
27+
maxpool(pool_width=2, stride=2), &
2828
dense(10, activation=softmax()) &
2929
])
3030

fpm.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "neural-fortran"
2-
version = "0.21.0"
2+
version = "0.22.0"
33
license = "MIT"
44
author = "Milan Curcic"
55
maintainer = "[email protected]"

src/nf.f90

+3-5
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,16 @@ module nf
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
55
use nf_layer_constructors, only: &
6-
conv1d, &
7-
conv2d, &
6+
conv, &
87
dense, &
98
dropout, &
109
embedding, &
1110
flatten, &
1211
input, &
1312
layernorm, &
1413
linear2d, &
15-
locally_connected1d, &
16-
maxpool1d, &
17-
maxpool2d, &
14+
locally_connected, &
15+
maxpool, &
1816
reshape, &
1917
self_attention
2018
use nf_loss, only: mse, quadratic

0 commit comments

Comments
 (0)