Commit 1e8979a

Merge pull request #166 from DiffEqML/neuralsde
Merges the neuralsde branch into master, adding SDE support as a new feature.
2 parents c708c4c + 0bbdd3e commit 1e8979a

9 files changed: +1085 −156 lines

test/test_sdeint.py

+120
@@ -0,0 +1,120 @@
import pytest
from torch import nn
import torch
import torchsde
import numpy as np
from torchdyn.numerics import sdeint
from numpy.testing import assert_almost_equal


@pytest.mark.parametrize("solver", ["euler", "milstein_ito"])
def test_geo_brownian_ito(solver):
    torch.manual_seed(0)
    np.random.seed(0)

    t0, t1 = 0, 1
    size = (1, 1)
    device = "cpu"

    alpha = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    beta = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    x0 = torch.normal(mean=0.0, std=1.1, size=size).to(device)
    t_size = 1000
    ts = torch.linspace(t0, t1, t_size).to(device)

    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=size, device=device, levy_area_approximation="space-time"
    )

    def get_bm_queries(bm, ts):
        bm_increments = torch.stack(
            [bm(t0, t1) for t0, t1 in zip(ts[:-1], ts[1:])], dim=0
        )
        bm_queries = torch.cat(
            (torch.zeros(1, 1, 1).to(device), torch.cumsum(bm_increments, dim=0))
        )
        return bm_queries

    class SDE(nn.Module):
        def __init__(self, alpha, beta):
            super().__init__()
            self.alpha = nn.Parameter(alpha, requires_grad=True)
            self.beta = nn.Parameter(beta, requires_grad=True)
            self.noise_type = "diagonal"
            self.sde_type = "ito"

        def f(self, t, x):
            return self.alpha * x

        def g(self, t, x):
            return self.beta * x

    sde = SDE(alpha, beta).to(device)

    with torch.no_grad():
        _, xs_torchdyn = sdeint(sde, x0, ts, solver=solver, bm=bm)

    bm_queries = get_bm_queries(bm, ts)
    xs_true = x0.cpu() * np.exp(
        (alpha.cpu() - 0.5 * beta.cpu() ** 2) * ts.cpu()
        + beta.cpu() * bm_queries[:, 0, 0].cpu()
    )

    assert_almost_equal(xs_true[0][-1], xs_torchdyn[-1], decimal=2)


@pytest.mark.parametrize("solver", ["eulerHeun", "milstein_stratonovich"])
def test_geo_brownian_stratonovich(solver):
    torch.manual_seed(0)
    np.random.seed(0)

    t0, t1 = 0, 1
    size = (1, 1)
    device = "cpu"

    alpha = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    beta = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device)
    x0 = torch.normal(mean=0.0, std=1.1, size=size).to(device)
    t_size = 1000
    ts = torch.linspace(t0, t1, t_size).to(device)

    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=size, device=device, levy_area_approximation="space-time"
    )

    def get_bm_queries(bm, ts):
        bm_increments = torch.stack(
            [bm(t0, t1) for t0, t1 in zip(ts[:-1], ts[1:])], dim=0
        )
        bm_queries = torch.cat(
            (torch.zeros(1, 1, 1).to(device), torch.cumsum(bm_increments, dim=0))
        )
        return bm_queries

    class SDE(nn.Module):
        def __init__(self, alpha, beta):
            super().__init__()
            self.alpha = nn.Parameter(alpha, requires_grad=True)
            self.beta = nn.Parameter(beta, requires_grad=True)
            self.noise_type = "diagonal"
            self.sde_type = "stratonovich"

        def f(self, t, x):
            return self.alpha * x

        def g(self, t, x):
            return self.beta * x

    sde = SDE(alpha, beta).to(device)

    with torch.no_grad():
        _, xs_torchdyn = sdeint(sde, x0, ts, solver=solver, bm=bm)

    bm_queries = get_bm_queries(bm, ts)
    xs_true = x0.cpu() * np.exp(
        (alpha.cpu() - 0.5 * beta.cpu() ** 2) * ts.cpu()
        + beta.cpu() * bm_queries[:, 0, 0].cpu()
    )

    assert_almost_equal(xs_true[0][-1] - xs_torchdyn[-1], 1, decimal=0)
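For context (standard geometric Brownian motion results, not part of the diff): both tests integrate dX_t = \alpha X_t\,dt + \beta X_t\,dW_t and build `xs_true` from the closed-form solution evaluated on the same Brownian path queried from `bm`:

    X_t = X_0 \exp\big((\alpha - \tfrac{1}{2}\beta^2)\,t + \beta W_t\big)    (Itô)
    X_t = X_0 \exp\big(\alpha t + \beta W_t\big)                             (Stratonovich)

The Itô test compares the final solver state against the first formula to two decimals; the Stratonovich test reuses the Itô formula as its reference and only checks the gap coarsely (decimal=0), consistent with the \exp(\beta^2 t / 2) drift correction between the two interpretations.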

torchdyn/core/__init__.py

+2-2
@@ -10,12 +10,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from torchdyn.core.defunc import DEFunc
+from torchdyn.core.defunc import DEFunc, SDEFunc
 from torchdyn.core.neuralde import NeuralODE, NeuralSDE, MultipleShootingLayer
 from torchdyn.core.problems import ODEProblem, SDEProblem, MultipleShootingProblem

 # backward-compatibility (pre v0.2.0)
 NeuralDE = NeuralODE

-__all__ = ['DEFunc', 'NeuralODE', 'NeuralDE', 'NeuralSDE', 'ODEProblem', 'SDEProblem',
+__all__ = ['DEFunc', 'SDEFunc', 'NeuralODE', 'NeuralDE', 'NeuralSDE', 'ODEProblem', 'SDEProblem',
            'MultipleShootingProblem', 'MultipleShootingLayer']
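With this change the SDE wrapper becomes part of the public API; a trivial import check, using only names that appear in the `__all__` list above:

    from torchdyn.core import DEFunc, SDEFunc, NeuralSDE, SDEProblem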

torchdyn/core/defunc.py

+56-31
@@ -9,32 +9,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from inspect import getfullargspec
 from typing import Callable, Dict
 import torch
 from torch import Tensor, cat
 import torch.nn as nn


 class DEFuncBase(nn.Module):
-    def __init__(self, vector_field:Callable, has_time_arg:bool=True):
+    def __init__(self, vector_field: Callable, has_time_arg: bool = True):
         """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields.

         Args:
             vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function
             has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its `__call__'
                 or `forward` method. Defaults to True.
         """
         super().__init__()
-        self.nfe, self.vf, self.has_time_arg = 0., vector_field, has_time_arg
+        self.nfe, self.vf, self.has_time_arg = 0.0, vector_field, has_time_arg

-    def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
         self.nfe += 1
-        if self.has_time_arg: return self.vf(t, x, args=args)
-        else: return self.vf(x)
+        if self.has_time_arg:
+            return self.vf(t, x, args=args)
+        else:
+            return self.vf(x)


 class DEFunc(nn.Module):
-    def __init__(self, vector_field:Callable, order:int=1):
+    def __init__(self, vector_field: Callable, order: int = 1):
         """Special vector field wrapper for Neural ODEs.

         Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and forward propagated integral losses.
@@ -51,43 +53,50 @@ def __init__(self, vector_field:Callable, order:int=1):
         (3) in case of higher-order dynamics, adjusts the vector field forward to recursively compute various orders.
         """
         super().__init__()
-        self.vf, self.nfe, = vector_field, 0.
+        self.vf, self.nfe, = vector_field, 0.0
         self.order, self.integral_loss, self.sensitivity = order, None, None
         # identify whether vector field already has time arg

-    def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
         self.nfe += 1
         # set `t` depth-variable to DepthCat modules
         for _, module in self.vf.named_modules():
-            if hasattr(module, 't'):
+            if hasattr(module, "t"):
                 module.t = t

         # if-else to handle autograd training with integral loss propagated in x[:, 0]
-        if (self.integral_loss is not None) and self.sensitivity == 'autograd':
+        if (self.integral_loss is not None) and self.sensitivity == "autograd":
             x_dyn = x[:, 1:]
             dlds = self.integral_loss(t, x_dyn)
-            if len(dlds.shape) == 1: dlds = dlds[:, None]
-            if self.order > 1: x_dyn = self.horder_forward(t, x_dyn, args)
-            else: x_dyn = self.vf(t, x_dyn)
+            if len(dlds.shape) == 1:
+                dlds = dlds[:, None]
+            if self.order > 1:
+                x_dyn = self.horder_forward(t, x_dyn, args)
+            else:
+                x_dyn = self.vf(t, x_dyn)
             return cat([dlds, x_dyn], 1).to(x_dyn)

         # regular forward
         else:
-            if self.order > 1: x = self.higher_order_forward(t, x)
-            else: x = self.vf(t, x, args=args)
+            if self.order > 1:
+                x = self.higher_order_forward(t, x)
+            else:
+                x = self.vf(t, x, args=args)
             return x

-    def higher_order_forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def higher_order_forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
         x_new = []
         size_order = x.size(1) // self.order
         for i in range(1, self.order):
-            x_new.append(x[:, size_order*i : size_order*(i+1)])
+            x_new.append(x[:, size_order * i : size_order * (i + 1)])
         x_new.append(self.vf(t, x))
         return cat(x_new, dim=1).to(x)


 class SDEFunc(nn.Module):
-    def __init__(self, f:Callable, g:Callable, order:int=1):
+    def __init__(
+        self, f: Callable, g: Callable, order: int = 1, noise_type=None, sde_type=None
+    ):
         """"Special vector field wrapper for Neural SDEs.

         Args:
@@ -99,19 +108,35 @@ def __init__(self, f:Callable, g:Callable, order:int=1):
         self.order, self.intloss, self.sensitivity = order, None, None
         self.f_func, self.g_func = f, g
         self.nfe = 0
+        self.noise_type = noise_type
+        self.sde_type = sde_type

-    def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
-        pass
+    def forward(self, t: Tensor, x: Tensor) -> Tensor:
+        raise NotImplementedError("Hopefully soon...")

-    def f(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
+    def f(self, t: Tensor, x: Tensor) -> Tensor:
         self.nfe += 1
-        for _, module in self.f_func.named_modules():
-            if hasattr(module, 't'):
-                module.t = t
-        return self.f_func(x, args)
+        if issubclass(type(self.f_func), nn.Module):
+            if "t" not in getfullargspec(self.f_func.forward).args:
+                return self.f_func(x)
+            else:
+                return self.f_func(t, x)
+        else:
+            if "t" not in getfullargspec(self.f_func).args:
+                return self.f_func(x)
+            else:
+                return self.f_func(t, x)

-    def g(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor:
-        for _, module in self.g_func.named_modules():
-            if hasattr(module, 't'):
-                module.t = t
-        return self.g_func(x, args)
+    def g(self, t: Tensor, x: Tensor) -> Tensor:
+        self.nfe += 1
+        if issubclass(type(self.g_func), nn.Module):
+
+            if "t" not in getfullargspec(self.g_func.forward).args:
+                return self.g_func(x)
+            else:
+                return self.g_func(t, x)
+        else:
+            if "t" not in getfullargspec(self.g_func).args:
+                return self.g_func(x)
+            else:
+                return self.g_func(t, x)
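
For orientation (not part of the diff), a minimal sketch of how the new `SDEFunc` signature dispatch behaves; the drift and diffusion callables below are hypothetical, and the import path assumes the `__init__.py` change above:

    import torch
    import torch.nn as nn
    from torchdyn.core import SDEFunc

    # Hypothetical vector fields, for illustration only.
    f = nn.Linear(2, 2)            # nn.Module whose forward(x) has no `t` argument

    def g(t, x):                   # plain callable with an explicit `t` argument
        return 0.1 * torch.ones_like(x)

    sde = SDEFunc(f=f, g=g, noise_type="diagonal", sde_type="ito")
    t, x = torch.tensor(0.0), torch.randn(3, 2)
    drift = sde.f(t, x)            # no `t` in f's signature -> wrapper calls f(x)
    diffusion = sde.g(t, x)        # `t` found in g's signature -> wrapper calls g(t, x)

Per the diff, both drift and diffusion may be either an `nn.Module` (its `forward` is inspected via `getfullargspec`) or a plain callable.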
