Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
43a28d5
Create 结项报告.md
gsy19971111 Sep 30, 2025
d5354bc
Add files via upload
gsy19971111 Sep 30, 2025
8dd06e6
Add files via upload
gsy19971111 Sep 30, 2025
f9bdce3
Update 结项报告.md
gsy19971111 Sep 30, 2025
b9d8b4f
Add files via upload
gsy19971111 Sep 30, 2025
c0a6f5a
Delete jointContribution/AI_Climate_disease/ERA5_land_download.ipynb
gsy19971111 Oct 24, 2025
248f748
Delete jointContribution/AI_Climate_disease/XGBoost_imputer.ipynb
gsy19971111 Oct 24, 2025
ed535f4
Delete jointContribution/AI_Climate_disease/dataprocessing.ipynb
gsy19971111 Oct 24, 2025
66dc828
Delete jointContribution/AI_Climate_disease/demo_200000_impute_report…
gsy19971111 Oct 24, 2025
92c5292
Delete jointContribution/AI_Climate_disease/paddle_TabM.ipynb
gsy19971111 Oct 24, 2025
ba246fe
Delete jointContribution/AI_Climate_disease/paddle_model_samples.ipynb
gsy19971111 Oct 24, 2025
0c76205
Delete jointContribution/AI_Climate_disease/结项报告.md
gsy19971111 Oct 24, 2025
1da524a
Create model.py
gsy19971111 Oct 24, 2025
c48657b
add
gsy19971111 Oct 24, 2025
73a8f69
add
gsy19971111 Oct 24, 2025
3653d4c
Update __init__.py
gsy19971111 Oct 24, 2025
8b9d77e
add
gsy19971111 Oct 24, 2025
3454fdd
Update __init__.py
gsy19971111 Oct 24, 2025
f2263f1
Create era5_ukb.yaml
gsy19971111 Oct 24, 2025
451aef3
Delete jointContribution/AI_Disease_Climate/TabM.py
gsy19971111 Oct 24, 2025
9e6e6f0
Delete jointContribution/AI_Disease_Climate/main.py
gsy19971111 Oct 24, 2025
4c90118
Delete jointContribution/AI_Disease_Climate/model.py
gsy19971111 Oct 24, 2025
61ea405
Add files via upload
gsy19971111 Oct 24, 2025
b25501a
Update bce.py
gsy19971111 Oct 24, 2025
3e05444
Update era5_land_dataset.py
gsy19971111 Oct 24, 2025
cc0a7d6
fix
gsy19971111 Oct 24, 2025
72d974e
Merge branch 'develop' into develop
gsy19971111 Oct 24, 2025
13d08ae
fix
gsy19971111 Oct 27, 2025
cd2ad79
fix
gsy19971111 Oct 27, 2025
c36af46
fix
gsy19971111 Oct 27, 2025
f9b1ca7
Merge branch 'develop' into develop
gsy19971111 Oct 27, 2025
f353ecb
Update bce.py
gsy19971111 Oct 28, 2025
a56e33f
Merge branch 'develop' into develop
gsy19971111 Oct 28, 2025
b99aefc
Merge branch 'develop' into develop
gsy19971111 Oct 29, 2025
322db54
fix
gsy19971111 Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 245 additions & 0 deletions jointContribution/AI_Disease_Climate/TabM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Literal
from typing import Optional

import paddle
import paddle.nn as nn


def init_rsqrt_uniform_(w: paddle.Tensor) -> paddle.Tensor:
bound = 1.0 / math.sqrt(w.shape[-1])
noise = paddle.uniform(w.shape, min=-bound, max=bound, dtype=w.dtype)
w.set_value(noise)
return w


def init_random_signs_(w: paddle.Tensor) -> paddle.Tensor:
with paddle.no_grad():
p = paddle.full(w.shape, 0.5, dtype="float32")
s = paddle.bernoulli(p) * 2.0 - 1.0
s = paddle.cast(s, w.dtype)
w.set_value(s)
return w


class NLinear(nn.Layer):
def __init__(self, k: int, in_f: int, out_f: int, bias: bool = True):
super().__init__()
self.k, self.in_f, self.out_f = k, in_f, out_f
self.weight = self.create_parameter(shape=[k, in_f, out_f])
self.bias_e = self.create_parameter(shape=[k, out_f]) if bias else None
self.reset_parameters()

def reset_parameters(self):
init_rsqrt_uniform_(self.weight)

if self.bias_e is not None:
init_rsqrt_uniform_(self.bias_e)

def forward(self, x): # x: (B,K,I)
xk = paddle.transpose(x, [1, 0, 2]) # (K,B,I)
yk = paddle.bmm(xk, self.weight) # (K,B,O)
y = paddle.transpose(yk, [1, 0, 2]) # (B,K,O)

if self.bias_e is not None:
y = y + self.bias_e
return y


class ScaleEnsemble(nn.Layer):
def __init__(self, k: int, d: int, init="ones"):
super().__init__()
self.k, self.d = k, d
self.weight = self.create_parameter(shape=[k, d])
self.init = init
self.reset_parameters()

def reset_parameters(self):

if self.init == "ones":
self.weight.set_value(paddle.ones_like(self.weight))
else:
init_random_signs_(self.weight)

def forward(self, x): # (B,K,D)
return x * self.weight


class LinearBE(nn.Layer):
def __init__(
self, in_f: int, out_f: int, k: int, scale_init="ones", bias: bool = True
):
super().__init__()
self.k, self.in_f, self.out_f = k, in_f, out_f
self.weight = self.create_parameter(shape=[in_f, out_f])
self.r = self.create_parameter(shape=[k, in_f])
self.s = self.create_parameter(shape=[k, out_f])
self.use_bias = bias
self.bias_e = self.create_parameter(shape=[k, out_f]) if bias else None
self.scale_init = scale_init
self.reset_parameters()

def reset_parameters(self):
init_rsqrt_uniform_(self.weight)

if self.scale_init == "ones":
self.r.set_value(paddle.ones_like(self.r))
self.s.set_value(paddle.ones_like(self.s))
else:
init_random_signs_(self.r)
init_random_signs_(self.s)

if self.use_bias:
init_rsqrt_uniform_(self.bias_e)

def forward(self, x): # (B,K,I)
xr = x * self.r # (B,K,I)
y = paddle.matmul(xr, self.weight) # (B,K,O)
y = y * self.s

if self.use_bias:
y = y + self.bias_e
return y


class MLPBlock(nn.Layer):
def __init__(self, d_in, d_hid, dropout, act="ReLU"):
super().__init__()
Act = getattr(nn, act)
self.net = nn.Sequential(
nn.Linear(d_in, d_hid),
Act(),
nn.Dropout(dropout),
)

def forward(self, x):
return self.net(x)


class BackboneMLP(nn.Layer):
def __init__(self, n_blocks: int, d_in: int, d_hidden: int, dropout: float):
super().__init__()
blocks = []
for i in range(n_blocks):
blocks.append(MLPBlock(d_in if i == 0 else d_hidden, d_hidden, dropout))
self.blocks = nn.LayerList(blocks)

def forward(self, x):
for blk in self.blocks:
x = blk(x)
return x


def _get_parent_by_path(root: nn.Layer, path_list):
cur = root
for p in path_list:

if hasattr(cur, p):
cur = getattr(cur, p)
else:
sub_layers = getattr(cur, "_sub_layers", None)

if sub_layers is None or p not in sub_layers:
raise AttributeError(
f"Cannot locate sublayer '{p}' under '{type(cur).__name__}'"
)

cur = sub_layers[p]
return cur


def _replace_linear(module: nn.Layer, k: int, mode: Literal["be", "packed"]):
to_replace = []
for full_name, layer in module.named_sublayers(include_self=False):

if isinstance(layer, nn.Linear):
parts = full_name.split(".")
parent_path, child_name = parts[:-1], parts[-1]
parent = _get_parent_by_path(module, parent_path) if parent_path else module
in_f = layer.weight.shape[0]
out_f = layer.weight.shape[1]

if mode == "be":
new_layer = LinearBE(in_f, out_f, k)
with paddle.no_grad():
new_layer.weight.set_value(layer.weight.clone())
if layer.bias is not None and new_layer.bias_e is not None:
b = layer.bias.reshape([1, -1]).tile([k, 1])
new_layer.bias_e.set_value(b)
else: # packed
new_layer = NLinear(k, in_f, out_f, bias=layer.bias is not None)
with paddle.no_grad():
w = layer.weight.unsqueeze(0).tile([k, 1, 1])
new_layer.weight.set_value(w)

if layer.bias is not None and new_layer.bias_e is not None:
b = layer.bias.unsqueeze(0).tile([k, 1])
new_layer.bias_e.set_value(b)

to_replace.append((parent, child_name, new_layer))
for parent, child_name, new_layer in to_replace:

if hasattr(parent, child_name):
setattr(parent, child_name, new_layer)
else:
parent._sub_layers[child_name] = new_layer


class TabMFeatureExtractor(nn.Layer):
"""arch_type: 'plain' | 'tabm' | 'tabm-mini' | 'tabm-packed'"""

def __init__(
self,
num_features: int,
arch_type: Literal["plain", "tabm", "tabm-mini", "tabm-packed"] = "tabm",
k: int = 32,
backbone_cfg: Optional[dict] = None,
reduce: bool = True,
):
super().__init__()

if arch_type == "plain":
k = 1

self.k = k
self.reduce = reduce
cfg = backbone_cfg or dict(n_blocks=3, d_hidden=512, dropout=0.1)
self.d_hidden = cfg["d_hidden"]
self.backbone = BackboneMLP(**cfg, d_in=num_features)

if arch_type == "tabm":
_replace_linear(self.backbone, k, mode="be")
self.min_adapter = None
elif arch_type == "tabm-mini":
self.min_adapter = ScaleEnsemble(k, num_features, init="random-signs")
elif arch_type == "tabm-packed":
_replace_linear(self.backbone, k, mode="packed")
self.min_adapter = None
else:
self.min_adapter = None

def forward(self, x_num: paddle.Tensor): # x_num: (B, D)

if self.k > 1:
x = x_num.unsqueeze(1).tile([1, self.k, 1]) # (B,K,D)
else:
x = x_num.unsqueeze(1) # (B,1,D)

if self.min_adapter is not None:
x = self.min_adapter(x)

feats = self.backbone(x) # (B,K,H)
return feats.mean(axis=1) if self.reduce else feats # (B,H) or (B,K,H)
63 changes: 63 additions & 0 deletions jointContribution/AI_Disease_Climate/config/era5_ukb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_era5_land_ukb/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode} # name of logfile
chdir: false # keep current working directory unchanged
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 1

DATA_DIR: ./data/

# model settings
MODEL:
input_keys: ["video", "vec"]
output_keys: ["y"]
T: 12
H: 10
W: 10
C: 10
N: 24

# training settings
TRAIN:
epochs: 2
iters_per_epoch: 95
save_freq: 1
eval_during_train: false
eval_freq: 5
weight_decay: 1.0e-5
batch_size: 32
pretrained_model_path: null
checkpoint_path: null

# evaluation settings
EVAL:
pretrained_model_path: null

# inference settings
INFER:
max_batch_size: 128
batch_size: 32
Loading