
Commit c284bec

9.30 Project (#1223)
* Create 结项报告.md
* Add files via upload
* Add files via upload
* Update 结项报告.md
* Add files via upload
* Delete jointContribution/AI_Climate_disease/ERA5_land_download.ipynb
* Delete jointContribution/AI_Climate_disease/XGBoost_imputer.ipynb
* Delete jointContribution/AI_Climate_disease/dataprocessing.ipynb
* Delete jointContribution/AI_Climate_disease/demo_200000_impute_report.csv
* Delete jointContribution/AI_Climate_disease/paddle_TabM.ipynb
* Delete jointContribution/AI_Climate_disease/paddle_model_samples.ipynb
* Delete jointContribution/AI_Climate_disease/结项报告.md
* Create model.py
* add
* add
* Update __init__.py
* add
* Update __init__.py
* Create era5_ukb.yaml
* Delete jointContribution/AI_Disease_Climate/TabM.py
* Delete jointContribution/AI_Disease_Climate/main.py
* Delete jointContribution/AI_Disease_Climate/model.py
* Add files via upload
* Update bce.py
* Update era5_land_dataset.py
* fix
* fix
* fix
* fix
* Update bce.py
* fix
1 parent 5350fc9 commit c284bec

File tree

7 files changed (+1441, -0 lines changed)

Lines changed: 245 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Literal
from typing import Optional

import paddle
import paddle.nn as nn


def init_rsqrt_uniform_(w: paddle.Tensor) -> paddle.Tensor:
    bound = 1.0 / math.sqrt(w.shape[-1])
    noise = paddle.uniform(w.shape, min=-bound, max=bound, dtype=w.dtype)
    w.set_value(noise)
    return w


def init_random_signs_(w: paddle.Tensor) -> paddle.Tensor:
    with paddle.no_grad():
        p = paddle.full(w.shape, 0.5, dtype="float32")
        s = paddle.bernoulli(p) * 2.0 - 1.0
        s = paddle.cast(s, w.dtype)
        w.set_value(s)
    return w


class NLinear(nn.Layer):
    def __init__(self, k: int, in_f: int, out_f: int, bias: bool = True):
        super().__init__()
        self.k, self.in_f, self.out_f = k, in_f, out_f
        self.weight = self.create_parameter(shape=[k, in_f, out_f])
        self.bias_e = self.create_parameter(shape=[k, out_f]) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        init_rsqrt_uniform_(self.weight)
        if self.bias_e is not None:
            init_rsqrt_uniform_(self.bias_e)

    def forward(self, x):  # x: (B, K, I)
        xk = paddle.transpose(x, [1, 0, 2])  # (K, B, I)
        yk = paddle.bmm(xk, self.weight)  # (K, B, O)
        y = paddle.transpose(yk, [1, 0, 2])  # (B, K, O)
        if self.bias_e is not None:
            y = y + self.bias_e
        return y
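
# NOTE: NLinear keeps K fully independent weight tensors (a "packed" ensemble),
# while LinearBE below shares a single weight matrix across members and
# differentiates them via per-member rank-1 scalers r and s (BatchEnsemble-style).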

class ScaleEnsemble(nn.Layer):
    def __init__(self, k: int, d: int, init="ones"):
        super().__init__()
        self.k, self.d = k, d
        self.weight = self.create_parameter(shape=[k, d])
        self.init = init
        self.reset_parameters()

    def reset_parameters(self):
        if self.init == "ones":
            self.weight.set_value(paddle.ones_like(self.weight))
        else:
            init_random_signs_(self.weight)

    def forward(self, x):  # (B, K, D)
        return x * self.weight

class LinearBE(nn.Layer):
    def __init__(
        self, in_f: int, out_f: int, k: int, scale_init="ones", bias: bool = True
    ):
        super().__init__()
        self.k, self.in_f, self.out_f = k, in_f, out_f
        self.weight = self.create_parameter(shape=[in_f, out_f])
        self.r = self.create_parameter(shape=[k, in_f])
        self.s = self.create_parameter(shape=[k, out_f])
        self.use_bias = bias
        self.bias_e = self.create_parameter(shape=[k, out_f]) if bias else None
        self.scale_init = scale_init
        self.reset_parameters()

    def reset_parameters(self):
        init_rsqrt_uniform_(self.weight)
        if self.scale_init == "ones":
            self.r.set_value(paddle.ones_like(self.r))
            self.s.set_value(paddle.ones_like(self.s))
        else:
            init_random_signs_(self.r)
            init_random_signs_(self.s)
        if self.use_bias:
            init_rsqrt_uniform_(self.bias_e)

    def forward(self, x):  # (B, K, I)
        xr = x * self.r  # (B, K, I)
        y = paddle.matmul(xr, self.weight)  # (B, K, O)
        y = y * self.s
        if self.use_bias:
            y = y + self.bias_e
        return y

class MLPBlock(nn.Layer):
    def __init__(self, d_in, d_hid, dropout, act="ReLU"):
        super().__init__()
        Act = getattr(nn, act)
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hid),
            Act(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class BackboneMLP(nn.Layer):
    def __init__(self, n_blocks: int, d_in: int, d_hidden: int, dropout: float):
        super().__init__()
        blocks = []
        for i in range(n_blocks):
            blocks.append(MLPBlock(d_in if i == 0 else d_hidden, d_hidden, dropout))
        self.blocks = nn.LayerList(blocks)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x


def _get_parent_by_path(root: nn.Layer, path_list):
    cur = root
    for p in path_list:
        if hasattr(cur, p):
            cur = getattr(cur, p)
        else:
            sub_layers = getattr(cur, "_sub_layers", None)
            if sub_layers is None or p not in sub_layers:
                raise AttributeError(
                    f"Cannot locate sublayer '{p}' under '{type(cur).__name__}'"
                )
            cur = sub_layers[p]
    return cur

def _replace_linear(module: nn.Layer, k: int, mode: Literal["be", "packed"]):
    to_replace = []
    for full_name, layer in module.named_sublayers(include_self=False):
        if isinstance(layer, nn.Linear):
            parts = full_name.split(".")
            parent_path, child_name = parts[:-1], parts[-1]
            parent = _get_parent_by_path(module, parent_path) if parent_path else module
            in_f = layer.weight.shape[0]
            out_f = layer.weight.shape[1]
            if mode == "be":
                new_layer = LinearBE(in_f, out_f, k)
                with paddle.no_grad():
                    new_layer.weight.set_value(layer.weight.clone())
                    if layer.bias is not None and new_layer.bias_e is not None:
                        b = layer.bias.reshape([1, -1]).tile([k, 1])
                        new_layer.bias_e.set_value(b)
            else:  # packed
                new_layer = NLinear(k, in_f, out_f, bias=layer.bias is not None)
                with paddle.no_grad():
                    w = layer.weight.unsqueeze(0).tile([k, 1, 1])
                    new_layer.weight.set_value(w)
                    if layer.bias is not None and new_layer.bias_e is not None:
                        b = layer.bias.unsqueeze(0).tile([k, 1])
                        new_layer.bias_e.set_value(b)
            to_replace.append((parent, child_name, new_layer))
    for parent, child_name, new_layer in to_replace:
        if hasattr(parent, child_name):
            setattr(parent, child_name, new_layer)
        else:
            parent._sub_layers[child_name] = new_layer
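
# NOTE: both modes initialize every ensemble member from the replaced layer's
# parameters (cloned for "be", tiled along K for "packed"), so with biases
# present a converted backbone initially behaves like the original plain MLP
# (LinearBE's default scale_init="ones" keeps r and s neutral at the start).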

class TabMFeatureExtractor(nn.Layer):
    """arch_type: 'plain' | 'tabm' | 'tabm-mini' | 'tabm-packed'"""

    def __init__(
        self,
        num_features: int,
        arch_type: Literal["plain", "tabm", "tabm-mini", "tabm-packed"] = "tabm",
        k: int = 32,
        backbone_cfg: Optional[dict] = None,
        reduce: bool = True,
    ):
        super().__init__()
        if arch_type == "plain":
            k = 1
        self.k = k
        self.reduce = reduce
        cfg = backbone_cfg or dict(n_blocks=3, d_hidden=512, dropout=0.1)
        self.d_hidden = cfg["d_hidden"]
        self.backbone = BackboneMLP(**cfg, d_in=num_features)

        if arch_type == "tabm":
            _replace_linear(self.backbone, k, mode="be")
            self.min_adapter = None
        elif arch_type == "tabm-mini":
            self.min_adapter = ScaleEnsemble(k, num_features, init="random-signs")
        elif arch_type == "tabm-packed":
            _replace_linear(self.backbone, k, mode="packed")
            self.min_adapter = None
        else:
            self.min_adapter = None

    def forward(self, x_num: paddle.Tensor):  # x_num: (B, D)
        if self.k > 1:
            x = x_num.unsqueeze(1).tile([1, self.k, 1])  # (B, K, D)
        else:
            x = x_num.unsqueeze(1)  # (B, 1, D)
        if self.min_adapter is not None:
            x = self.min_adapter(x)
        feats = self.backbone(x)  # (B, K, H)
        return feats.mean(axis=1) if self.reduce else feats  # (B, H) or (B, K, H)
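
For orientation, a minimal usage sketch of the extractor above (the batch size, feature count, and backbone_cfg values here are illustrative assumptions, not taken from this commit):

    # Hypothetical example: a batch of 32 rows with 24 numeric features.
    extractor = TabMFeatureExtractor(
        num_features=24,
        arch_type="tabm-mini",
        k=8,
        backbone_cfg=dict(n_blocks=3, d_hidden=512, dropout=0.1),
    )
    x_num = paddle.randn([32, 24])  # (B, D)
    feats = extractor(x_num)  # (B, d_hidden): [32, 512], averaged over the K members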
Lines changed: 63 additions & 0 deletions
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: outputs_era5_land_ukb/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 1

DATA_DIR: ./data/

# model settings
MODEL:
  input_keys: ["video", "vec"]
  output_keys: ["y"]
  T: 12
  H: 10
  W: 10
  C: 10
  N: 24

# training settings
TRAIN:
  epochs: 2
  iters_per_epoch: 95
  save_freq: 1
  eval_during_train: false
  eval_freq: 5
  weight_decay: 1.0e-5
  batch_size: 32
  pretrained_model_path: null
  checkpoint_path: null

# evaluation settings
EVAL:
  pretrained_model_path: null

# inference settings
INFER:
  max_batch_size: 128
  batch_size: 32
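
As a quick sanity check, the raw values of this config can be inspected with OmegaConf outside of a Hydra run (a sketch; it assumes the file is saved as era5_ukb.yaml in the working directory, and note that the defaults list and ${hydra:...} interpolations only resolve when launched through Hydra):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("era5_ukb.yaml")
    print(cfg.MODEL.input_keys)  # ['video', 'vec']
    print(cfg.TRAIN.batch_size)  # 32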
