9
9
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
10
# See the License for the specific language governing permissions and
11
11
# limitations under the License.
12
from inspect import getfullargspec
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor, cat
17
17
18
18
19
19
class DEFuncBase(nn.Module):
    def __init__(self, vector_field: Callable, has_time_arg: bool = True):
        """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields.

        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function
            has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its `__call__`
                or `forward` method. Defaults to True.
        """
        super().__init__()
        # nfe: running count of vector-field evaluations, bumped on every forward
        self.nfe, self.vf, self.has_time_arg = 0.0, vector_field, has_time_arg

    def forward(self, t: Tensor, x: Tensor, args: Optional[Dict] = None) -> Tensor:
        """Evaluate the wrapped vector field at time `t` and state `x`.

        Args:
            t (Tensor): current integration time ("depth") variable.
            x (Tensor): current state.
            args (Dict, optional): extra arguments forwarded to the vector field.
                Defaults to an empty dict.

        Returns:
            Tensor: the vector field evaluated at `(t, x)`.
        """
        # Fresh dict per call: a `{}` default would be a single shared object that
        # the wrapped callable could mutate across invocations.
        args = {} if args is None else args
        self.nfe += 1
        if self.has_time_arg:
            return self.vf(t, x, args=args)
        return self.vf(x)
34
36
35
37
36
38
class DEFunc (nn .Module ):
37
- def __init__ (self , vector_field :Callable , order :int = 1 ):
39
+ def __init__ (self , vector_field : Callable , order : int = 1 ):
38
40
"""Special vector field wrapper for Neural ODEs.
39
41
40
42
Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and forward propagated integral losses.
@@ -51,43 +53,50 @@ def __init__(self, vector_field:Callable, order:int=1):
51
53
(3) in case of higher-order dynamics, adjusts the vector field forward to recursively compute various orders.
52
54
"""
53
55
super ().__init__ ()
54
- self .vf , self .nfe , = vector_field , 0.
56
+ self .vf , self .nfe , = vector_field , 0.0
55
57
self .order , self .integral_loss , self .sensitivity = order , None , None
56
58
# identify whether vector field already has time arg
57
59
58
- def forward (self , t :Tensor , x :Tensor , args :Dict = {}) -> Tensor :
60
+ def forward (self , t : Tensor , x : Tensor , args : Dict = {}) -> Tensor :
59
61
self .nfe += 1
60
62
# set `t` depth-variable to DepthCat modules
61
63
for _ , module in self .vf .named_modules ():
62
- if hasattr (module , 't' ):
64
+ if hasattr (module , "t" ):
63
65
module .t = t
64
66
65
67
# if-else to handle autograd training with integral loss propagated in x[:, 0]
66
- if (self .integral_loss is not None ) and self .sensitivity == ' autograd' :
68
+ if (self .integral_loss is not None ) and self .sensitivity == " autograd" :
67
69
x_dyn = x [:, 1 :]
68
70
dlds = self .integral_loss (t , x_dyn )
69
- if len (dlds .shape ) == 1 : dlds = dlds [:, None ]
70
- if self .order > 1 : x_dyn = self .horder_forward (t , x_dyn , args )
71
- else : x_dyn = self .vf (t , x_dyn )
71
+ if len (dlds .shape ) == 1 :
72
+ dlds = dlds [:, None ]
73
+ if self .order > 1 :
74
+ x_dyn = self .horder_forward (t , x_dyn , args )
75
+ else :
76
+ x_dyn = self .vf (t , x_dyn )
72
77
return cat ([dlds , x_dyn ], 1 ).to (x_dyn )
73
78
74
79
# regular forward
75
80
else :
76
- if self .order > 1 : x = self .higher_order_forward (t , x )
77
- else : x = self .vf (t , x , args = args )
81
+ if self .order > 1 :
82
+ x = self .higher_order_forward (t , x )
83
+ else :
84
+ x = self .vf (t , x , args = args )
78
85
return x
79
86
80
- def higher_order_forward (self , t :Tensor , x :Tensor , args :Dict = {}) -> Tensor :
87
+ def higher_order_forward (self , t : Tensor , x : Tensor , args : Dict = {}) -> Tensor :
81
88
x_new = []
82
89
size_order = x .size (1 ) // self .order
83
90
for i in range (1 , self .order ):
84
- x_new .append (x [:, size_order * i : size_order * ( i + 1 )])
91
+ x_new .append (x [:, size_order * i : size_order * ( i + 1 )])
85
92
x_new .append (self .vf (t , x ))
86
93
return cat (x_new , dim = 1 ).to (x )
87
94
88
95
89
96
class SDEFunc (nn .Module ):
90
- def __init__ (self , f :Callable , g :Callable , order :int = 1 ):
97
+ def __init__ (
98
+ self , f : Callable , g : Callable , order : int = 1 , noise_type = None , sde_type = None
99
+ ):
91
100
""""Special vector field wrapper for Neural SDEs.
92
101
93
102
Args:
@@ -99,19 +108,35 @@ def __init__(self, f:Callable, g:Callable, order:int=1):
99
108
self .order , self .intloss , self .sensitivity = order , None , None
100
109
self .f_func , self .g_func = f , g
101
110
self .nfe = 0
111
+ self .noise_type = noise_type
112
+ self .sde_type = sde_type
102
113
103
    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Joint drift/diffusion evaluation is not supported; call `f` and `g` separately.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Hopefully soon...")
105
116
106
- def f (self , t :Tensor , x :Tensor , args : Dict = {} ) -> Tensor :
117
+ def f (self , t : Tensor , x : Tensor ) -> Tensor :
107
118
self .nfe += 1
108
- for _ , module in self .f_func .named_modules ():
109
- if hasattr (module , 't' ):
110
- module .t = t
111
- return self .f_func (x , args )
119
+ if issubclass (type (self .f_func ), nn .Module ):
120
+ if "t" not in getfullargspec (self .f_func .forward ).args :
121
+ return self .f_func (x )
122
+ else :
123
+ return self .f_func (t , x )
124
+ else :
125
+ if "t" not in getfullargspec (self .f_func ).args :
126
+ return self .f_func (x )
127
+ else :
128
+ return self .f_func (t , x )
112
129
113
- def g (self , t :Tensor , x :Tensor , args :Dict = {}) -> Tensor :
114
- for _ , module in self .g_func .named_modules ():
115
- if hasattr (module , 't' ):
116
- module .t = t
117
- return self .g_func (x , args )
130
+ def g (self , t : Tensor , x : Tensor ) -> Tensor :
131
+ self .nfe += 1
132
+ if issubclass (type (self .g_func ), nn .Module ):
133
+
134
+ if "t" not in getfullargspec (self .g_func .forward ).args :
135
+ return self .g_func (x )
136
+ else :
137
+ return self .g_func (t , x )
138
+ else :
139
+ if "t" not in getfullargspec (self .g_func ).args :
140
+ return self .g_func (x )
141
+ else :
142
+ return self .g_func (t , x )
0 commit comments