
Commit 452097b

multi dim bugfixes

1 parent c3014a7 commit 452097b

9 files changed: +40 -56 lines changed

setup.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@

from setuptools import setup, find_packages

-__version__ = '1.0.1'
+__version__ = '1.0.2'
url = 'https://github.com/rusty1s/pytorch_scatter'

install_requires = ['cffi']

test/test_backward.py

Lines changed: 8 additions & 8 deletions

@@ -14,30 +14,30 @@
@pytest.mark.parametrize('func,device', product(funcs, devices))
def test_backward(func, device):
    index = torch.tensor(indices, dtype=torch.long, device=device)
-    src = torch.rand(index.size(), dtype=torch.double, device=device)
+    src = torch.rand((index.size(0), 2), dtype=torch.double, device=device)
    src.requires_grad_()

    op = getattr(torch_scatter, 'scatter_{}'.format(func))
-    data = (src, index)
+    data = (src, index, 0)
    assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True


tests = [{
    'name': 'max',
-    'src': [1, 2, 3, 4, 5],
+    'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
    'index': [2, 0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
-    'grad': [4, 8, 6],
-    'expected': [6, 0, 0, 8, 4]
+    'grad': [[4, 4], [8, 8], [6, 6]],
+    'expected': [[6, 6], [0, 0], [0, 0], [8, 8], [4, 4]],
}, {
    'name': 'min',
-    'src': [1, 2, 3, 4, 5],
+    'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
    'index': [2, 0, 1, 1, 0],
    'dim': 0,
    'fill_value': 3,
-    'grad': [4, 8, 6],
-    'expected': [6, 4, 8, 0, 0]
+    'grad': [[4, 4], [8, 8], [6, 6]],
+    'expected': [[6, 6], [4, 4], [8, 8], [0, 0], [0, 0]],
}]
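A minimal standalone sketch of what the updated gradcheck case now exercises: a two-dimensional src scattered along dim 0, so the backward pass has to route gradients through the trailing feature dimension as well. The index values are taken from the test cases above, and scatter_max stands in for the parametrized func:

import torch
from torch.autograd import gradcheck
from torch_scatter import scatter_max

index = torch.tensor([2, 0, 1, 1, 0], dtype=torch.long)
# two feature columns per scattered element, double precision for gradcheck
src = torch.rand((index.size(0), 2), dtype=torch.double, requires_grad=True)

# compares the analytic backward pass against finite differences
assert gradcheck(scatter_max, (src, index, 0), eps=1e-6, atol=1e-4)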

test/test_forward.py

Lines changed: 21 additions & 21 deletions

@@ -12,102 +12,102 @@
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 0,
-    'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
+    'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]],
}, {
    'name': 'add',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
-    'expected': [[6, 5], [6, 8]]
+    'expected': [[6, 5], [6, 8]],
}, {
    'name': 'sub',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 9,
-    'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]]
+    'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]],
}, {
    'name': 'sub',
    'src': [[5, 2], [2, 2], [4, 2], [1, 3]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 9,
-    'expected': [[3, 4], [3, 5]]
+    'expected': [[3, 4], [3, 5]],
}, {
    'name': 'mul',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 1,
-    'expected': [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]]
+    'expected': [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]],
}, {
    'name': 'mul',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 1,
-    'expected': [[5, 6], [8, 15]]
+    'expected': [[5, 6], [8, 15]],
}, {
    'name': 'div',
    'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 1,
-    'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]]
+    'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]],
}, {
    'name': 'div',
    'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 1,
-    'expected': [[0.25, 0.25], [0.125, 0.5]]
+    'expected': [[0.25, 0.25], [0.125, 0.5]],
}, {
    'name': 'mean',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 0,
-    'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
+    'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]],
}, {
    'name': 'mean',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
-    'expected': [[3, 2.5], [3, 4]]
+    'expected': [[3, 2.5], [3, 4]],
}, {
    'name': 'max',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 0,
    'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
-    'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
+    'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]],
}, {
    'name': 'max',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 0,
    'expected': [[5, 3], [4, 5]],
-    'expected_arg': [[0, 3], [2, 1]]
+    'expected_arg': [[0, 3], [2, 1]],
}, {
    'name': 'min',
    'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
    'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
    'dim': -1,
    'fill_value': 9,
    'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
-    'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
+    'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]],
}, {
    'name': 'min',
    'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
-    'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
+    'index': [0, 1, 1, 0],
    'dim': 0,
    'fill_value': 9,
    'expected': [[1, 2], [2, 3]],
-    'expected_arg': [[3, 0], [1, 2]]
+    'expected_arg': [[3, 0], [1, 2]],
}]
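The dim=0 cases above now pass a one-dimensional index against a two-dimensional src, relying on the index being broadcast across the trailing dimension. A sketch of the 'add' case, with values copied from the test table (the broadcasting behaviour is exactly what the new tests assert):

import torch
from torch_scatter import scatter_add

src = torch.tensor([[5., 2.], [2., 5.], [4., 3.], [1., 3.]])
index = torch.tensor([0, 1, 1, 0])

# rows 0 and 3 are summed into output row 0, rows 1 and 2 into output row 1:
# [[5 + 1, 2 + 3], [2 + 4, 5 + 3]] == [[6, 5], [6, 8]]
out = scatter_add(src, index, dim=0)
print(out)  # tensor([[6., 5.], [6., 8.]])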

torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
from .max import scatter_max
from .min import scatter_min

-__version__ = '1.0.1'
+__version__ = '1.0.2'

__all__ = [
    'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_div', 'scatter_mean',

torch_scatter/add.py

Lines changed: 1 addition & 21 deletions

@@ -1,26 +1,6 @@
-from torch.autograd import Function
-
from .utils.gen import gen


-class ScatterAdd(Function):
-    @staticmethod
-    def forward(ctx, out, src, index, dim):
-        ctx.mark_dirty(out)
-        ctx.save_for_backward(index)
-        return out.scatter_add_(dim, index, src)
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        index, = ctx.saved_variables
-
-        grad_src = None
-        if ctx.needs_input_grad[1]:
-            grad_src = grad_out[index]
-
-        return None, grad_src, None, None
-
-
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    r"""
    |

@@ -90,4 +70,4 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
               [ 2, 4, 4, 0, 0, 0]])
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
-    return ScatterAdd.apply(out, src, index, dim)
+    return out.scatter_add_(dim, index, src)
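The custom ScatterAdd autograd Function can be dropped because torch.Tensor.scatter_add_ is already differentiable with respect to src, so PyTorch derives the backward pass on its own. A small sketch of that (example tensors are made up; the 2-D index is the explicit form scatter_add_ expects):

import torch

src = torch.tensor([[5., 2.], [2., 5.], [4., 3.], [1., 3.]], requires_grad=True)
index = torch.tensor([[0, 0], [1, 1], [1, 1], [0, 0]])

# in-place scatter_add_ is still recorded by autograd
out = torch.zeros(2, 2).scatter_add_(0, index, src)
out.sum().backward()
print(src.grad)  # all ones: every src element contributes exactly once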

torch_scatter/div.py

Lines changed: 2 additions & 1 deletion

@@ -12,6 +12,7 @@ def forward(ctx, out, src, index, dim):

        ctx.mark_dirty(out)
        ctx.save_for_backward(out, src, index)
+        ctx.dim = dim

        return out

@@ -21,7 +22,7 @@ def backward(ctx, grad_out):

        grad_src = None
        if ctx.needs_input_grad[1]:
-            grad_src = -(out * grad_out)[index] / src
+            grad_src = -(out * grad_out).gather(ctx.dim, index) / src

        return None, grad_src, None, None
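The actual bugfix here (and likewise in mean.py and mul.py below) is replacing advanced indexing with gather, so the gradient is pulled from the output slot each element was scattered into along the requested dim, instead of always indexing rows of the first dimension. A small illustration with made-up tensors:

import torch

grad_out = torch.tensor([[10., 20.], [30., 40.]])
index = torch.tensor([[0, 0], [1, 1], [1, 1], [0, 0]])

# gather(0, index): result[i][j] = grad_out[index[i][j]][j], one value per src element
print(grad_out.gather(0, index))
# tensor([[10., 20.], [30., 40.], [30., 40.], [10., 20.]])

# the old advanced indexing selects whole rows per index entry, which ignores
# the requested dim and yields the wrong shape for a multi-dimensional index
print(grad_out[index].shape)  # torch.Size([4, 2, 2])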

torch_scatter/mean.py

Lines changed: 3 additions & 1 deletion

@@ -15,6 +15,7 @@ def forward(ctx, out, src, index, dim):

        ctx.mark_dirty(out)
        ctx.save_for_backward(index, count)
+        ctx.dim = dim

        return out

@@ -24,7 +25,8 @@ def backward(ctx, grad_out):

        grad_src = None
        if ctx.needs_input_grad[1]:
-            grad_src = grad_out[index] / count[index]
+            grad_src = grad_out.gather(ctx.dim, index)
+            grad_src /= count.gather(ctx.dim, index)

        return None, grad_src, None, None
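With the dim stored on the context, the mean backward divides each gathered gradient by the number of elements averaged into its output slot. A quick end-to-end sketch, reusing the data from the forward 'mean' test case above:

import torch
from torch_scatter import scatter_mean

src = torch.tensor([[5., 2.], [2., 5.], [4., 3.], [1., 3.]], requires_grad=True)
index = torch.tensor([0, 1, 1, 0])

out = scatter_mean(src, index, dim=0)  # tensor([[3.0, 2.5], [3.0, 4.0]])
out.sum().backward()
# two rows are averaged into every output row, so each src element gets 1 / count = 0.5
print(src.grad)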

torch_scatter/mul.py

Lines changed: 2 additions & 1 deletion

@@ -12,6 +12,7 @@ def forward(ctx, out, src, index, dim):

        ctx.mark_dirty(out)
        ctx.save_for_backward(out, src, index)
+        ctx.dim = dim

        return out

@@ -21,7 +22,7 @@ def backward(ctx, grad_out):

        grad_src = None
        if ctx.needs_input_grad[1]:
-            grad_src = (grad_out * out)[index] / src
+            grad_src = (grad_out * out).gather(ctx.dim, index) / src

        return None, grad_src, None, None

torch_scatter/utils/gen.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):

    # Generate output tensor if not given.
    if out is None:
-        dim_size = index.max() + 1 if dim_size is None else dim_size
+        dim_size = index.max().item() + 1 if dim_size is None else dim_size
        out_size = list(src.size())
        out_size[dim] = dim_size
        out = src.new_full(out_size, fill_value)
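The .item() call matters because index.max() returns a zero-dimensional tensor, so without it dim_size, and hence one entry of out_size, would be a tensor rather than a plain Python int, which can trip up size arguments:

import torch

index = torch.tensor([0, 1, 1, 0])
print(index.max())         # tensor(1) -- a 0-dim tensor, not an int
print(index.max().item())  # 1 -- a plain int, safe to use as a size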
