Commit dad9a12

Merge pull request #106 from basf/develop
Version 0.2.1
2 parents 57ca8fc + a383411 commit dad9a12

Some content is hidden: large commits have some content hidden by default.

41 files changed, +2801 -862 lines changed

.github/workflows/draft-pdf.yml

-22
This file was deleted.

README.md

+352-263
Large diffs are not rendered by default.

mambular/__version__.py

+1-1
@@ -1,4 +1,4 @@
 """Version information."""

 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.7"
+__version__ = "0.2.1"
+102
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn


class Reshape(nn.Module):
    def __init__(self, j, dim, method="linear"):
        super(Reshape, self).__init__()
        self.j = j
        self.dim = dim
        self.method = method

        if self.method == "linear":
            # Use nn.Linear approach
            self.layer = nn.Linear(dim, j * dim)
        elif self.method == "embedding":
            # Use nn.Embedding approach
            self.layer = nn.Embedding(dim, j * dim)
        elif self.method == "conv1d":
            # Use nn.Conv1d approach
            self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
        else:
            raise ValueError(f"Unsupported method '{method}' for reshaping.")

    def forward(self, x):
        batch_size = x.shape[0]

        if self.method == "linear" or self.method == "embedding":
            x_reshaped = self.layer(x)  # shape: (batch_size, j * dim)
            x_reshaped = x_reshaped.view(
                batch_size, self.j, self.dim
            )  # shape: (batch_size, j, dim)
        elif self.method == "conv1d":
            # For Conv1d, add dummy dimension and reshape
            x = x.unsqueeze(-1)  # Add dummy dimension for convolution
            x_reshaped = self.layer(x)  # shape: (batch_size, j * dim, 1)
            x_reshaped = x_reshaped.squeeze(-1)  # Remove dummy dimension
            x_reshaped = x_reshaped.view(
                batch_size, self.j, self.dim
            )  # shape: (batch_size, j, dim)

        return x_reshaped


class AttentionNetBlock(nn.Module):
    def __init__(
        self,
        channels,
        in_channels,
        d_model,
        n_heads,
        n_layers,
        dim_feedforward,
        transformer_activation,
        output_dim,
        attn_dropout,
        layer_norm_eps,
        norm_first,
        bias,
        activation,
        embedding_activation,
        norm_f,
        method,
    ):
        super(AttentionNetBlock, self).__init__()

        self.reshape = Reshape(channels, in_channels, method)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            batch_first=True,
            dim_feedforward=dim_feedforward,
            dropout=attn_dropout,
            activation=transformer_activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
            bias=bias,
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            norm=norm_f,
        )

        self.linear = nn.Linear(d_model, output_dim)
        self.activation = activation
        self.embedding_activation = embedding_activation

    def forward(self, x):
        z = self.reshape(x)
        x = self.embedding_activation(z)
        x = self.encoder(x)
        x = z + x
        x = torch.sum(x, dim=1)
        x = self.linear(x)
        x = self.activation(x)
        return x
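
For orientation, here is a minimal usage sketch of the new AttentionNetBlock. All shapes and hyperparameter values below are illustrative assumptions, not values taken from this commit; note that d_model must equal in_channels, since the reshaped sequence of shape (batch_size, channels, in_channels) is fed directly to the transformer encoder.

import torch
import torch.nn as nn

# Hypothetical configuration, for illustration only.
block = AttentionNetBlock(
    channels=8,              # j: number of "tokens" produced by Reshape
    in_channels=64,          # input feature dimension
    d_model=64,              # must match in_channels (see Reshape output shape)
    n_heads=4,
    n_layers=2,
    dim_feedforward=128,
    transformer_activation="relu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=None,
    method="linear",
)

x = torch.randn(32, 64)   # (batch_size, in_channels)
out = block(x)            # (batch_size, output_dim) -> torch.Size([32, 1])
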
+97
@@ -0,0 +1,97 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim),
    )


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.rotary = rotary
        dim = np.int64(dim / 2)
        self.rotary_embedding = RotaryEmbedding(dim=dim)

    def forward(self, x):
        h = self.heads
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        if self.rotary:
            q = self.rotary_embedding.rotate_queries_or_keys(q)
            k = self.rotary_embedding.rotate_queries_or_keys(k)
        q = q * self.scale

        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

        attn = sim.softmax(dim=-1)
        dropped_attn = self.dropout(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        out = self.to_out(out)

        return out, attn


class Transformer(nn.Module):
    def __init__(
        self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim,
                            heads=heads,
                            dim_head=dim_head,
                            dropout=attn_dropout,
                            rotary=rotary,
                        ),
                        FeedForward(dim, dropout=ff_dropout),
                    ]
                )
            )

    def forward(self, x, return_attn=False):
        post_softmax_attns = []

        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x)
            post_softmax_attns.append(post_softmax_attn)

            x = attn_out + x
            x = ff(x) + x

        if not return_attn:
            return x

        return x, torch.stack(post_softmax_attns)
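
A brief usage sketch of the Transformer with rotary attention follows; the sizes are assumptions chosen for illustration. Since the rotary frequencies are constructed from dim / 2, dim_head is set to dim // 2 here so the rotation spans the full head dimension.

import torch

# Illustrative sizes only (not taken from this commit).
model = Transformer(
    dim=64,
    depth=2,
    heads=8,
    dim_head=32,   # dim // 2, matching RotaryEmbedding(dim=np.int64(dim / 2))
    attn_dropout=0.1,
    ff_dropout=0.1,
    rotary=True,
)

tokens = torch.randn(16, 10, 64)             # (batch_size, seq_len, dim)
out = model(tokens)                          # (16, 10, 64)
out, attn = model(tokens, return_attn=True)  # attn: (depth, batch, heads, seq, seq)
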
+163
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn


class EmbeddingLayer(nn.Module):
    def __init__(
        self,
        num_feature_info,
        cat_feature_info,
        d_model,
        embedding_activation=nn.Identity(),
        layer_norm_after_embedding=False,
        use_cls=False,
        cls_position=0,
        cat_encoding="int",
    ):
        """
        Embedding layer that handles numerical and categorical embeddings.

        Parameters
        ----------
        num_feature_info : dict
            Dictionary where keys are numerical feature names and values are their respective input dimensions.
        cat_feature_info : dict
            Dictionary where keys are categorical feature names and values are the number of categories for each feature.
        d_model : int
            Dimensionality of the embeddings.
        embedding_activation : nn.Module, optional
            Activation function to apply after embedding. Default is `nn.Identity()`.
        layer_norm_after_embedding : bool, optional
            If True, applies layer normalization after embeddings. Default is `False`.
        use_cls : bool, optional
            If True, includes a class token in the embeddings. Default is `False`.
        cls_position : int, optional
            Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.
        cat_encoding : str, optional
            Encoding for categorical features, either "int" (integer embeddings) or "one-hot". Default is `"int"`.

        Methods
        -------
        forward(num_features=None, cat_features=None)
            Defines the forward pass of the model.
        """
        super(EmbeddingLayer, self).__init__()

        self.d_model = d_model
        self.embedding_activation = embedding_activation
        self.layer_norm_after_embedding = layer_norm_after_embedding
        self.use_cls = use_cls
        self.cls_position = cls_position

        self.num_embeddings = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_shape, d_model, bias=False),
                    self.embedding_activation,
                )
                for feature_name, input_shape in num_feature_info.items()
            ]
        )

        self.cat_embeddings = nn.ModuleList()
        for feature_name, num_categories in cat_feature_info.items():
            if cat_encoding == "int":
                self.cat_embeddings.append(
                    nn.Sequential(
                        nn.Embedding(num_categories + 1, d_model),
                        self.embedding_activation,
                    )
                )
            elif cat_encoding == "one-hot":
                self.cat_embeddings.append(
                    nn.Sequential(
                        OneHotEncoding(num_categories),
                        nn.Linear(num_categories, d_model, bias=False),
                        self.embedding_activation,
                    )
                )

        if self.use_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        if layer_norm_after_embedding:
            self.embedding_norm = nn.LayerNorm(d_model)

        self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

    def forward(self, num_features=None, cat_features=None):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : Tensor, optional
            Tensor containing the numerical features.
        cat_features : Tensor, optional
            Tensor containing the categorical features.

        Returns
        -------
        Tensor
            The output embeddings of the model.

        Raises
        ------
        ValueError
            If no features are provided to the model.
        """
        if self.use_cls:
            batch_size = (
                cat_features[0].size(0)
                if cat_features != []
                else num_features[0].size(0)
            )
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        if self.cat_embeddings and cat_features is not None:
            cat_embeddings = [
                emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
            ]
            cat_embeddings = torch.stack(cat_embeddings, dim=1)
            cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
            if self.layer_norm_after_embedding:
                cat_embeddings = self.embedding_norm(cat_embeddings)
        else:
            cat_embeddings = None

        if self.num_embeddings and num_features is not None:
            num_embeddings = [
                emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
            ]
            num_embeddings = torch.stack(num_embeddings, dim=1)
            if self.layer_norm_after_embedding:
                num_embeddings = self.embedding_norm(num_embeddings)
        else:
            num_embeddings = None

        if cat_embeddings is not None and num_embeddings is not None:
            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
        elif cat_embeddings is not None:
            x = cat_embeddings
        elif num_embeddings is not None:
            x = num_embeddings
        else:
            raise ValueError("No features provided to the model.")

        if self.use_cls:
            if self.cls_position == 0:
                x = torch.cat([cls_tokens, x], dim=1)
            elif self.cls_position == 1:
                x = torch.cat([x, cls_tokens], dim=1)
            else:
                raise ValueError(
                    "Invalid cls_position value. It should be either 0 or 1."
                )

        return x


class OneHotEncoding(nn.Module):
    def __init__(self, num_categories):
        super(OneHotEncoding, self).__init__()
        self.num_categories = num_categories

    def forward(self, x):
        return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
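
Finally, a short usage sketch of the new EmbeddingLayer. The feature names, sizes, and tensor layouts below are assumptions for illustration: numerical features are passed as a list of (batch, input_dim) float tensors and categorical features as a list of (batch, 1) integer tensors, matching how the embedding modules index into them.

import torch
import torch.nn as nn

# Hypothetical feature metadata (names and sizes are assumptions).
num_feature_info = {"age": 1, "income": 1}  # numerical feature -> input dimension
cat_feature_info = {"city": 10}             # categorical feature -> number of categories

layer = EmbeddingLayer(
    num_feature_info,
    cat_feature_info,
    d_model=32,
    embedding_activation=nn.ReLU(),
    use_cls=True,
    cls_position=0,
    cat_encoding="int",
)

batch_size = 8
num_features = [torch.randn(batch_size, 1) for _ in num_feature_info]
cat_features = [torch.randint(0, 10, (batch_size, 1)) for _ in cat_feature_info]

x = layer(num_features=num_features, cat_features=cat_features)
print(x.shape)  # torch.Size([8, 4, 32]): [CLS] token + 1 categorical + 2 numerical embeddings
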
