diff --git a/README.md b/README.md
index d82934c..5208ae8 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,39 @@
-# OccRWKV
+
+# 🤖 OccRWKV: Rethinking Efficient 3D Semantic Occupancy Prediction with Linear Complexity
+*We will open source the complete code after the paper is accepted*
+
-## OccRWKV: Rethinking Sparse Latent Representation for 3D Semantic Occupancy Prediction
+## 📢 News
-## Preperation
+- [2024/09]: OccRWKV's logs are available for download:
+
+
+| OccRWKV Results | Experiment Log |
+|:------------------------------------------------------------------:|:----------:|
+|OccRWKV on the SemanticKITTI hidden official test dataset | [link](https://connecthkuhk-my.sharepoint.com/:t:/g/personal/u3009632_connect_hku_hk/EYqFDMD6xexCqXwfZ_nPxEUB0akfqePg4TwuGiuf4fQK0Q?e=PFM1ma) |
+|OccRWKV train log | [link](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/u3009632_connect_hku_hk/EcKG5MgDCTJJuu8DJ7VoS9sB0euzAEaMkpLjlY9LvRJ0GA?e=lwddX3) |
+
+
+
+- [2024/08]: The pre-trained model can be downloaded at [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/u3009632_connect_hku_hk/ETCUIJ7rPnFJniQYMsDsPyIBHkzirRP4c3n-eU9fcBZTaA?e=P8AkQ2).
+- [2024/07]: 🔥 We released the code of OccRWKV. *The First Receptance Weighted Key Value (RWKV)-based 3D Semantic Occupancy Network*
+
+
+
+```
+@article{wang2024omega,
+title={OccRWKV: Rethinking Efficient 3D Semantic Occupancy Prediction with Linear Complexity},
+author={Wang, Junming and Yin, Wei and Long, Xiaoxiao and Zhang, Xinyu and Xing, Zebing and Guo, Xiaoyang and Qian, Zhang},
+year={2024}
+ }
+```
+
+Please kindly star ⭐️ this project if it helps you. We take great efforts to develop and maintain it 😁.
+
+
+## 🛠️ Installation
-### Prerequisites
```
conda create -n occ_rwkv python=3.10 -y
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y
@@ -13,9 +41,12 @@ pip install spconv-cu120
pip install tensorboardX
pip install dropblock
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
+pip install -U openmim
+mim install mmcv-full
+pip install mmcls==0.25.0
```
-### Dataset
+## 💽 Dataset
Please download the Semantic Scene Completion dataset (v1.1) from the [SemanticKITTI website](http://www.semantic-kitti.org/dataset.html) and extract it.
@@ -35,7 +66,7 @@ SemanticKITTI
│ │ ├── ... ...
```
-## Getting Start
+## 🤗 Getting Start
Clone the repository:
```
https://github.com/jmwang0117/OccRWKV.git
@@ -49,7 +80,7 @@ We provide training routine examples in the `cfgs` folder. Make sure to change t
```
$ cd
-$ python train.py --cfg cfgs/DSC-Base.yaml --dset_root
+$ bash scripts/run_train.sh
```
### Validation
@@ -57,7 +88,7 @@ Validation passes are done during training routine. Additional pass in the valid
```
$ cd
-$ python validate.py --weights --dset_root
+$ bash scripts/run_val.sh
```
### Test
@@ -69,21 +100,15 @@ You can provide which checkpoints you want to use for testing. We used the ones
```
$ cd
-$ python test.py --weights --dset_root --out_path
+$ bash scripts/run_test.sh
```
-### Pretrained Model
-
-You can download the models with the scores below from this [Google drive link](https://drive.google.com/file/d/1-b3O7QS6hBQIGFTO-7qSG7Zb9kbQuxdO/view?usp=sharing),
-
-| Model | Segmentation | Completion |
-|--|--|--|
-| SSC-RS | 24.2 | 59.7 |
-* Results reported to SemanticKITTI: Semantic Scene Completion leaderboard ([link](https://codalab.lisn.upsaclay.fr/competitions/7170\#results)).
-## Acknowledgement
-This project is not possible without multiple great opensourced codebases.
-* [spconv](https://github.com/traveller59/spconv)
-* [LMSCNet](https://github.com/cv-rits/LMSCNet)
-* [SSA-SC](https://github.com/jokester-zzz/SSA-SC)
-* [GASN](https://github.com/ItIsFriday/PcdSeg)
+## 🏆 Acknowledgement
+Many thanks to these excellent open source projects:
+- [AGRNav](https://github.com/jmwang0117/AGRNav)
+- [Prometheus](https://github.com/amov-lab/Prometheus)
+- [SSC-RS](https://github.com/Jieqianyu/SSC-RS)
+- [semantic-kitti-api](https://github.com/PRBonn/semantic-kitti-api)
+- [Terrestrial-Aerial-Navigation](https://github.com/ZJU-FAST-Lab/Terrestrial-Aerial-Navigation)
+- [EGO-Planner](https://github.com/ZJU-FAST-Lab/ego-planner-swarm)
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..c91f139
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,251 @@
+name: occ_rwkv
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_kmp_llvm
+ - aom=3.9.0=hac33072_0
+ - blas=2.116=mkl
+ - blas-devel=3.9.0=16_linux64_mkl
+ - brotli-python=1.1.0=py310hc6cd4ac_1
+ - bzip2=1.0.8=hd590300_5
+ - ca-certificates=2024.6.2=hbcca054_0
+ - cairo=1.18.0=h3faef2a_0
+ - certifi=2024.6.2=pyhd8ed1ab_0
+ - charset-normalizer=3.3.2=pyhd8ed1ab_0
+ - cuda-cudart=12.1.105=0
+ - cuda-cupti=12.1.105=0
+ - cuda-libraries=12.1.0=0
+ - cuda-nvrtc=12.1.105=0
+ - cuda-nvtx=12.1.105=0
+ - cuda-opencl=12.5.39=0
+ - cuda-runtime=12.1.0=0
+ - cuda-version=12.5=3
+ - dav1d=1.2.1=hd590300_0
+ - expat=2.6.2=h59595ed_0
+ - ffmpeg=7.0.1=gpl_hb399a10_100
+ - filelock=3.14.0=pyhd8ed1ab_0
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
+ - font-ttf-inconsolata=3.000=h77eed37_0
+ - font-ttf-source-code-pro=2.038=h77eed37_0
+ - font-ttf-ubuntu=0.83=h77eed37_2
+ - fontconfig=2.14.2=h14ed4e7_0
+ - fonts-conda-ecosystem=1=0
+ - fonts-conda-forge=1=0
+ - freetype=2.12.1=h267a509_2
+ - fribidi=1.0.10=h36c2ea0_0
+ - gettext=0.22.5=h59595ed_2
+ - gettext-tools=0.22.5=h59595ed_2
+ - gmp=6.3.0=h59595ed_1
+ - gmpy2=2.1.5=py310hc7909c9_1
+ - gnutls=3.7.9=hb077bed_0
+ - graphite2=1.3.13=h59595ed_1003
+ - harfbuzz=8.5.0=hfac3d4d_0
+ - icu=73.2=h59595ed_0
+ - idna=3.7=pyhd8ed1ab_0
+ - jinja2=3.1.4=pyhd8ed1ab_0
+ - lame=3.100=h166bdaf_1003
+ - lcms2=2.16=hb7c19ff_0
+ - ld_impl_linux-64=2.40=hf3520f5_3
+ - lerc=4.0.0=h27087fc_0
+ - libabseil=20240116.2=cxx17_h59595ed_0
+ - libasprintf=0.22.5=h661eb56_2
+ - libasprintf-devel=0.22.5=h661eb56_2
+ - libass=0.17.1=h8fe9dca_1
+ - libblas=3.9.0=16_linux64_mkl
+ - libcblas=3.9.0=16_linux64_mkl
+ - libcublas=12.1.0.26=0
+ - libcufft=11.0.2.4=0
+ - libcufile=1.10.0.4=0
+ - libcurand=10.3.6.39=0
+ - libcusolver=11.4.4.55=0
+ - libcusparse=12.0.2.55=0
+ - libdeflate=1.20=hd590300_0
+ - libdrm=2.4.120=hd590300_0
+ - libexpat=2.6.2=h59595ed_0
+ - libffi=3.4.2=h7f98852_5
+ - libgcc-ng=13.2.0=h77fa898_7
+ - libgettextpo=0.22.5=h59595ed_2
+ - libgettextpo-devel=0.22.5=h59595ed_2
+ - libgfortran-ng=13.2.0=h69a702a_7
+ - libgfortran5=13.2.0=hca663fb_7
+ - libglib=2.80.2=hf974151_0
+ - libgomp=13.2.0=h77fa898_7
+ - libhwloc=2.10.0=default_h5622ce7_1001
+ - libiconv=1.17=hd590300_2
+ - libidn2=2.3.7=hd590300_0
+ - libjpeg-turbo=3.0.0=hd590300_1
+ - liblapack=3.9.0=16_linux64_mkl
+ - liblapacke=3.9.0=16_linux64_mkl
+ - libnpp=12.0.2.50=0
+ - libnsl=2.0.1=hd590300_0
+ - libnvjitlink=12.1.105=0
+ - libnvjpeg=12.1.1.14=0
+ - libopenvino=2024.1.0=h2da1b83_7
+ - libopenvino-auto-batch-plugin=2024.1.0=hb045406_7
+ - libopenvino-auto-plugin=2024.1.0=hb045406_7
+ - libopenvino-hetero-plugin=2024.1.0=h5c03a75_7
+ - libopenvino-intel-cpu-plugin=2024.1.0=h2da1b83_7
+ - libopenvino-intel-gpu-plugin=2024.1.0=h2da1b83_7
+ - libopenvino-intel-npu-plugin=2024.1.0=he02047a_7
+ - libopenvino-ir-frontend=2024.1.0=h5c03a75_7
+ - libopenvino-onnx-frontend=2024.1.0=h07e8aee_7
+ - libopenvino-paddle-frontend=2024.1.0=h07e8aee_7
+ - libopenvino-pytorch-frontend=2024.1.0=he02047a_7
+ - libopenvino-tensorflow-frontend=2024.1.0=h39126c6_7
+ - libopenvino-tensorflow-lite-frontend=2024.1.0=he02047a_7
+ - libopus=1.3.1=h7f98852_1
+ - libpciaccess=0.18=hd590300_0
+ - libpng=1.6.43=h2797004_0
+ - libprotobuf=4.25.3=h08a7969_0
+ - libsqlite=3.46.0=hde9e2c9_0
+ - libstdcxx-ng=13.2.0=hc0a3c3a_7
+ - libtasn1=4.19.0=h166bdaf_0
+ - libtiff=4.6.0=h1dd3fc0_3
+ - libunistring=0.9.10=h7f98852_0
+ - libuuid=2.38.1=h0b41bf4_0
+ - libva=2.21.0=h4ab18f5_2
+ - libvpx=1.14.1=hac33072_0
+ - libwebp-base=1.4.0=hd590300_0
+ - libxcb=1.15=h0b41bf4_0
+ - libxcrypt=4.4.36=hd590300_1
+ - libxml2=2.12.7=hc051c1a_1
+ - libzlib=1.3.1=h4ab18f5_1
+ - llvm-openmp=15.0.7=h0cdce71_0
+ - markupsafe=2.1.5=py310h2372a71_0
+ - mkl=2022.1.0=h84fe81f_915
+ - mkl-devel=2022.1.0=ha770c72_916
+ - mkl-include=2022.1.0=h84fe81f_915
+ - mpc=1.3.1=hfe3b2da_0
+ - mpfr=4.2.1=h9458935_1
+ - mpmath=1.3.0=pyhd8ed1ab_0
+ - ncurses=6.5=h59595ed_0
+ - nettle=3.9.1=h7ab15ed_0
+ - networkx=3.3=pyhd8ed1ab_1
+ - numpy=1.26.4=py310hb13e2d6_0
+ - ocl-icd=2.3.2=hd590300_1
+ - openh264=2.4.1=h59595ed_0
+ - openjpeg=2.5.2=h488ebb8_0
+ - openssl=3.3.1=h4ab18f5_0
+ - p11-kit=0.24.1=hc5aa10d_0
+ - pcre2=10.43=hcad00b1_0
+ - pillow=10.3.0=py310hf73ecf8_0
+ - pip=24.0=pyhd8ed1ab_0
+ - pixman=0.43.2=h59595ed_0
+ - pthread-stubs=0.4=h36c2ea0_1001
+ - pugixml=1.14=h59595ed_0
+ - pysocks=1.7.1=pyha2e5f31_6
+ - python=3.10.14=hd12c33a_0_cpython
+ - python_abi=3.10=4_cp310
+ - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0
+ - pytorch-cuda=12.1=ha16c6d3_5
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.1=py310h2372a71_1
+ - readline=8.2=h8228510_1
+ - snappy=1.2.0=hdb0a2a9_1
+ - svt-av1=2.1.0=hac33072_0
+ - sympy=1.12=pypyh9d50eac_103
+ - tbb=2021.12.0=h297d8ca_1
+ - tk=8.6.13=noxft_h4845f30_101
+ - torchaudio=2.1.0=py310_cu121
+ - torchtriton=2.1.0=py310
+ - torchvision=0.16.0=py310_cu121
+ - typing_extensions=4.12.2=pyha770c72_0
+ - wheel=0.43.0=pyhd8ed1ab_1
+ - x264=1!164.3095=h166bdaf_2
+ - x265=3.5=h924138e_3
+ - xorg-fixesproto=5.0=h7f98852_1002
+ - xorg-kbproto=1.0.7=h7f98852_1002
+ - xorg-libice=1.1.1=hd590300_0
+ - xorg-libsm=1.2.4=h7391055_0
+ - xorg-libx11=1.8.9=h8ee46fc_0
+ - xorg-libxau=1.0.11=hd590300_0
+ - xorg-libxdmcp=1.1.3=h7f98852_0
+ - xorg-libxext=1.3.4=h0b41bf4_2
+ - xorg-libxfixes=5.0.3=h7f98852_1004
+ - xorg-libxrender=0.9.11=hd590300_0
+ - xorg-renderproto=0.11.1=h7f98852_1002
+ - xorg-xextproto=7.3.0=h0b41bf4_1003
+ - xorg-xproto=7.0.31=h7f98852_1007
+ - xz=5.2.6=h166bdaf_0
+ - yaml=0.2.5=h7f98852_2
+ - zlib=1.3.1=h4ab18f5_1
+ - zstd=1.5.6=ha6fb4c9_0
+ - pip:
+ - absl-py==2.1.0
+ - addict==2.4.0
+ - aliyun-python-sdk-core==2.15.1
+ - aliyun-python-sdk-kms==2.16.3
+ - ccimport==0.4.2
+ - cffi==1.16.0
+ - click==8.1.7
+ - colorama==0.4.6
+ - contourpy==1.2.1
+ - crcmod==1.7
+ - cryptography==42.0.8
+ - cumm-cu120==0.4.11
+ - cycler==0.12.1
+ - dropblock==0.3.0
+ - fire==0.6.0
+ - fonttools==4.53.0
+ - fsspec==2024.6.0
+ - grpcio==1.64.1
+ - importlib-metadata==7.1.0
+ - jmespath==0.10.0
+ - kiwisolver==1.4.5
+ - lark==1.1.9
+ - markdown==3.6
+ - markdown-it-py==3.0.0
+ - matplotlib==3.9.0
+ - mdurl==0.1.2
+ - mmcls==0.25.0
+ - mmcv-full==1.7.2
+ - model-index==0.1.11
+ - ninja==1.11.1.1
+ - opencv-python==4.10.0.82
+ - opendatalab==0.0.10
+ - openmim==0.3.9
+ - openxlab==0.1.0
+ - ordered-set==4.1.0
+ - oss2==2.17.0
+ - packaging==24.1
+ - pandas==2.2.2
+ - pccm==0.4.11
+ - platformdirs==4.2.2
+ - portalocker==2.8.2
+ - protobuf==4.25.3
+ - pybind11==2.12.0
+ - pycparser==2.22
+ - pycryptodome==3.20.0
+ - pyg-lib==0.4.0+pt21cu121
+ - pygments==2.18.0
+ - pyparsing==3.1.2
+ - python-dateutil==2.9.0.post0
+ - pytz==2023.4
+ - requests==2.28.2
+ - rich==13.4.2
+ - scipy==1.13.1
+ - setuptools==60.2.0
+ - six==1.16.0
+ - spconv-cu120==2.3.6
+ - tabulate==0.9.0
+ - tensorboard==2.17.0
+ - tensorboard-data-server==0.7.2
+ - tensorboardx==2.6.2.2
+ - termcolor==2.4.0
+ - thop==0.1.1-2209072238
+ - tomli==2.0.1
+ - torch-cluster==1.6.3+pt21cu121
+ - torch-scatter==2.1.2+pt21cu121
+ - torch-sparse==0.6.18+pt21cu121
+ - torch-spline-conv==1.2.2+pt21cu121
+ - torchprofile==0.0.4
+ - tqdm==4.65.2
+ - tzdata==2024.1
+ - urllib3==1.26.18
+ - werkzeug==3.0.3
+ - yapf==0.40.2
+ - zipp==3.19.2
+prefix: /home/jmwang/miniforge3/envs/occ_rwkv
diff --git a/networks/bev_net.py b/networks/bev_net.py
index 8b43c4a..5e6a62c 100644
--- a/networks/bev_net.py
+++ b/networks/bev_net.py
@@ -3,67 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F
from dropblock import DropBlock2D
-
-
-class BEVFusion(nn.Module):
- def __init__(self):
- super().__init__()
-
- def forward(self, bev_features, sem_features, com_features):
- return torch.cat([bev_features, sem_features, com_features], dim=1)
-
- @staticmethod
- def channel_reduction(x, out_channels):
- """
- Args:
- x: (B, C1, H, W)
- out_channels: C2
-
- Returns:
-
- """
- B, in_channels, H, W = x.shape
- assert (in_channels % out_channels == 0) and (in_channels >= out_channels)
-
- x = x.view(B, out_channels, -1, H, W)
- # x = torch.max(x, dim=2)[0]
- x = x.sum(dim=2)
- return x
-
-
-class BEVUNet(nn.Module):
- def __init__(self, n_class, n_height, dilation, bilinear, group_conv, input_batch_norm, dropout, circular_padding, dropblock):
- super().__init__()
- self.inc = inconv(64, 64, dilation, input_batch_norm, circular_padding)
- self.down1 = down(64, 128, dilation, group_conv, circular_padding)
- self.down2 = down(256, 256, dilation, group_conv, circular_padding)
- self.down3 = down(512, 512, dilation, group_conv, circular_padding)
- self.down4 = down(1024, 512, dilation, group_conv, circular_padding)
- self.up1 = up(1536, 512, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.up2 = up(1024, 256, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.up3 = up(512, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.up4 = up(192, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.dropout = nn.Dropout(p=0. if dropblock else dropout)
- self.outc = outconv(128, n_class)
-
- self.bev_fusions = nn.ModuleList([BEVFusion() for _ in range(3)])
-
- def forward(self, x, sem_fea_list, com_fea_list):
- x1 = self.inc(x) # [B, 64, 256, 256]
- x2 = self.down1(x1) # [B, 128, 128, 128]
- x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) # 128, 64, 64 -> 256
- x3 = self.down2(x2_f) # [B, 256, 64, 64]
- x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) # 256, 128, 128 -> 512
- x4 = self.down3(x3_f) # [B, 512, 32, 32]
- x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) # 512, 256, 256 -> 1024
- x5 = self.down4(x4_f) # [B, 512, 16, 16]
- x = self.up1(x5, x4_f)
- x = self.up2(x, x3_f)
- x = self.up3(x, x2_f)
- x = self.up4(x, x1)
- x = self.outc(self.dropout(x))
- return x
-
+from networks.vrwkv import Block as RWKVBlock
class BEVFusionv1(nn.Module):
def __init__(self, channel):
@@ -106,30 +46,62 @@ def forward(self, bev_features, sem_features, com_features):
class BEVUNetv1(nn.Module):
def __init__(self, n_class, n_height, dilation, bilinear, group_conv, input_batch_norm, dropout, circular_padding, dropblock):
super().__init__()
+ # Encoder
self.inc = inconv(64, 64, dilation, input_batch_norm, circular_padding)
self.down1 = down(64, 128, dilation, group_conv, circular_padding)
self.down2 = down(128, 256, dilation, group_conv, circular_padding)
self.down3 = down(256, 512, dilation, group_conv, circular_padding)
self.down4 = down(512, 512, dilation, group_conv, circular_padding)
- self.up1 = up(1024, 512, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.up2 = up(768, 256, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.up3 = up(384, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.up4 = up(192, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout)
- self.dropout = nn.Dropout(p=0. if dropblock else dropout)
- self.outc = outconv(128, n_class)
-
+
+ # RWKV blocks for the encoder
+
+ self.rwkv_block1 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0)
+ self.rwkv_block2 = RWKVBlock(n_embd=256, n_layer=18, layer_id=0)
+ self.rwkv_block3 = RWKVBlock(n_embd=512, n_layer=18, layer_id=0)
+ self.rwkv_block4 = RWKVBlock(n_embd=512, n_layer=18, layer_id=0)
+ self.rwkv_block_up1 = RWKVBlock(n_embd=512, n_layer=18, layer_id=0)
+ self.rwkv_block_up2 = RWKVBlock(n_embd=256, n_layer=18, layer_id=0)
+ self.rwkv_block_up3 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0)
+ self.rwkv_block_up4 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0)
+
+ # BEV fusion modules
channels = [128, 256, 512]
self.bev_fusions = nn.ModuleList([BEVFusionv1(channels[i]) for i in range(3)])
+ # Decoder
+ self.up1 = up(1024, 512, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout)
+ self.up2 = up(768, 256, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout)
+ self.up3 = up(384, 128, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout)
+ self.up4 = up(192, 128, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout)
+ self.outc = outconv(128, n_class)
+
+ self.dropout = nn.Dropout(p=0. if dropblock else dropout)
+
def forward(self, x, sem_fea_list, com_fea_list):
- x1 = self.inc(x) # [B, 64, 256, 256]
- x2 = self.down1(x1) # [B, 128, 128, 128]
- x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) # 128, 64, 64 -> 128
- x3 = self.down2(x2_f) # [B, 256, 64, 64]
- x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) # 256, 128, 128 -> 256
- x4 = self.down3(x3_f) # [B, 512, 32, 32]
- x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) # 512, 256, 256 -> 512
- x5 = self.down4(x4_f) # [B, 512, 16, 16]
+ # Encoder with RWKV blocks
+ x1 = self.inc(x)
+
+ x2 = self.down1(x1)
+ B, C, H, W = x2.shape
+ x2 = self.rwkv_block1(x2.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2)
+ x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0])
+
+ x3 = self.down2(x2_f)
+ B, C, H, W = x3.shape
+ x3 = self.rwkv_block2(x3.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2)
+ x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1])
+
+ x4 = self.down3(x3_f)
+ B, C, H, W = x4.shape
+ x4 = self.rwkv_block3(x4.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2)
+ x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2])
+
+ x5 = self.down4(x4_f)
+ B, C, H, W = x5.shape
+ x5 = self.rwkv_block4(x5.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2)
+
+
+ # return x
x = self.up1(x5, x4_f) # 512, 512
x = self.up2(x, x3_f) # 512, 256
x = self.up3(x, x2_f) # 256, 128
diff --git a/networks/completion.py b/networks/completion.py
index a4c1adf..7dc8798 100644
--- a/networks/completion.py
+++ b/networks/completion.py
@@ -1,9 +1,9 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
-
+from networks.vrwkv import Block as RWKVBlock
from utils.lovasz_losses import lovasz_softmax
-
+from utils.ssc_loss import geo_scal_loss
class ResBlock(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, padding, stride, dilation=1):
@@ -36,6 +36,11 @@ def __init__(self, init_size=32, nbr_class=20, phase='trainval'):
self.block_1 = make_layers(16, 16, kernel_size=3, padding=1, stride=1, dilation=1, blocks=1) # 1/2, 16
self.block_2 = make_layers(16, 32, kernel_size=3, padding=1, stride=1, dilation=1, downsample=True, blocks=1) # 1/4, 32
self.block_3 = make_layers(32, 64, kernel_size=3, padding=2, stride=1, dilation=2, downsample=True, blocks=1) # 1/8, 64
+
+ # RWKV blocks
+ self.rwkv_block1 = RWKVBlock(n_embd=64, n_layer=18, layer_id=0)
+ self.rwkv_block2 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0)
+ self.rwkv_block3 = RWKVBlock(n_embd=256,n_layer=18, layer_id=0)
self.reduction_1 = nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
@@ -72,7 +77,19 @@ def forward_once(self, inputs):
bev_1 = self.reduction_1(res1.flatten(1, 2)) # B, 64, 128, 128
bev_2 = self.reduction_2(res2.flatten(1, 2)) # B, 128, 64, 64
bev_3 = res3.flatten(1, 2) # B, 256, 32, 32
-
+
+ B, C, H, W = bev_1.shape
+ patch_resolution = (H, W)
+ bev_1 = self.rwkv_block1(bev_1.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2)
+
+ B, C, H, W = bev_2.shape
+ patch_resolution = (H, W)
+ bev_2 = self.rwkv_block2(bev_2.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2)
+
+ B, C, H, W = bev_3.shape
+ patch_resolution = (H, W)
+ bev_3 = self.rwkv_block3(bev_3.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2)
+
if self.phase == 'trainval':
logits_2 = self.out2(res1)
logits_4 = self.out4(res2)
@@ -123,4 +140,4 @@ def forward(self, data_dict, example):
else:
out_dict = self.forward_once(data_dict['vw_dense'])
return out_dict
-
+
diff --git a/networks/dsc.py b/networks/occrwkv.py
similarity index 98%
rename from networks/dsc.py
rename to networks/occrwkv.py
index 013f64b..56f699e 100644
--- a/networks/dsc.py
+++ b/networks/occrwkv.py
@@ -7,12 +7,13 @@
import torch.nn.functional as F
from .preprocess import PcPreprocessor
-from .bev_net import BEVUNet, BEVUNetv1
+from .bev_net import BEVUNetv1
from .completion import CompletionBranch
from .semantic_segmentation import SemanticBranch
from utils.lovasz_losses import lovasz_softmax
-class DSC(nn.Module):
+
+class OccRWKV(nn.Module):
def __init__(self, cfg, phase='trainval'):
super().__init__()
self.phase = phase
@@ -83,7 +84,6 @@ def compute_loss(self, scores, labels, ss_loss_dict, sc_loss_dict):
:param: prediction: the predicted tensor, must be [BS, C, H, W, D]
'''
class_weights = self.get_class_weights().to(device=scores.device, dtype=scores.dtype)
-
loss_1_1 = F.cross_entropy(scores, labels.long(), weight=class_weights, ignore_index=255)
loss_1_1 += lovasz_softmax(torch.nn.functional.softmax(scores, dim=1), labels.long(), ignore=255)
loss_1_1 *= 3
diff --git a/networks/semantic_segmentation.py b/networks/semantic_segmentation.py
index 50856b9..42dc189 100644
--- a/networks/semantic_segmentation.py
+++ b/networks/semantic_segmentation.py
@@ -3,13 +3,14 @@
import torch
import torch.nn.functional as F
import torch_scatter
-
import spconv.pytorch as spconv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils.lovasz_losses import lovasz_softmax
+from networks.vrwkv import Block as RWKVBlock
+from utils.ssc_loss import sem_scal_loss
class BasicBlock(spconv.SparseModule):
def __init__(self, C_in, C_out, indice_key):
@@ -174,6 +175,11 @@ def __init__(self, sizes=[256, 256, 32], nbr_class=19, init_size=32, class_frequ
self.proj3_block = SGFE(input_channels=128, output_channels=256,\
reduce_channels=128, name="proj3")
+ # RWKV blocks
+ self.rwkv_block1 = RWKVBlock(n_embd=64, n_layer=18, layer_id=0)
+ self.rwkv_block2 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0)
+ self.rwkv_block3 = RWKVBlock(n_embd=256,n_layer=18, layer_id=0)
+
self.phase = phase
if phase == 'trainval':
num_class = self.nbr_class # SemanticKITTI: 19
@@ -215,39 +221,40 @@ def forward_once(self, vw_features, coord_ind, full_coord, pw_label, info):
vw_features, coord.int(), np.array(self.sizes, np.int32)[::-1], batch_size
)
conv1_output = self.conv1_block(input_tensor)
- proj1_vw, vw1_coord, pw1_coord = self.proj1_block(info, conv1_output.features, output_scale=2, input_coords=coord.int(),
- input_coords_inv=full_coord)
+
+ proj1_vw, vw1_coord, pw1_coord = self.proj1_block(info, conv1_output.features, output_scale=2, input_coords=coord.int(), input_coords_inv=full_coord)
proj1_bev = self.bev_projection(proj1_vw, vw1_coord, (np.array(self.sizes, np.int32) // 2)[::-1], batch_size)
-
- conv2_input_tensor = spconv.SparseConvTensor(
- proj1_vw, vw1_coord.int(), (np.array(self.sizes, np.int32) // 2)[::-1], batch_size
- )
+ B, C, H, W = proj1_bev.shape
+ patch_resolution = (H, W)
+ proj1_bev_rwkv = self.rwkv_block1(proj1_bev.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2)
+
+
+ conv2_input_tensor = spconv.SparseConvTensor(proj1_vw, vw1_coord.int(), (np.array(self.sizes, np.int32) // 2)[::-1], batch_size)
conv2_output = self.conv2_block(conv2_input_tensor)
- proj2_vw, vw2_coord, pw2_coord = self.proj2_block(info, conv2_output.features, output_scale=4, input_coords=vw1_coord.int(),
- input_coords_inv=pw1_coord)
+ proj2_vw, vw2_coord, pw2_coord = self.proj2_block(info, conv2_output.features, output_scale=4, input_coords=vw1_coord.int(), input_coords_inv=pw1_coord)
proj2_bev = self.bev_projection(proj2_vw, vw2_coord, (np.array(self.sizes, np.int32) // 4)[::-1], batch_size)
+ B, C, H, W = proj2_bev.shape
+ patch_resolution = (H, W)
+ proj2_bev_rwkv = self.rwkv_block2(proj2_bev.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2)
- conv3_input_tensor = spconv.SparseConvTensor(
- proj2_vw, vw2_coord.int(), (np.array(self.sizes, np.int32) // 4)[::-1], batch_size
- )
+ conv3_input_tensor = spconv.SparseConvTensor(proj2_vw, vw2_coord.int(), (np.array(self.sizes, np.int32) // 4)[::-1], batch_size)
conv3_output = self.conv3_block(conv3_input_tensor)
- proj3_vw, vw3_coord, _ = self.proj3_block(info, conv3_output.features, output_scale=8, input_coords=vw2_coord.int(),
- input_coords_inv=pw2_coord)
+ proj3_vw, vw3_coord, _ = self.proj3_block(info, conv3_output.features, output_scale=8, input_coords=vw2_coord.int(), input_coords_inv=pw2_coord)
proj3_bev = self.bev_projection(proj3_vw, vw3_coord, (np.array(self.sizes, np.int32) // 8)[::-1], batch_size)
+ B, C, H, W = proj3_bev.shape
+ patch_resolution = (H, W)
+ proj3_bev_rwkv = self.rwkv_block3(proj3_bev.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2)
if self.phase == 'trainval':
- index_02 = torch.cat([info[2]['bxyz_indx'][:, 0].unsqueeze(-1),
- torch.flip(info[2]['bxyz_indx'], dims=[1])[:, :3]], dim=1)
- index_04 = torch.cat([info[4]['bxyz_indx'][:, 0].unsqueeze(-1),
- torch.flip(info[4]['bxyz_indx'], dims=[1])[:, :3]], dim=1)
- index_08 = torch.cat([info[8]['bxyz_indx'][:, 0].unsqueeze(-1),
- torch.flip(info[8]['bxyz_indx'], dims=[1])[:, :3]], dim=1)
+ index_02 = torch.cat([info[2]['bxyz_indx'][:, 0].unsqueeze(-1), torch.flip(info[2]['bxyz_indx'], dims=[1])[:, :3]], dim=1)
+ index_04 = torch.cat([info[4]['bxyz_indx'][:, 0].unsqueeze(-1), torch.flip(info[4]['bxyz_indx'], dims=[1])[:, :3]], dim=1)
+ index_08 = torch.cat([info[8]['bxyz_indx'][:, 0].unsqueeze(-1), torch.flip(info[8]['bxyz_indx'], dims=[1])[:, :3]], dim=1)
vw_label_02 = voxel_sem_target(index_02.int(), pw_label.int())[0]
vw_label_04 = voxel_sem_target(index_04.int(), pw_label.int())[0]
vw_label_08 = voxel_sem_target(index_08.int(), pw_label.int())[0]
return dict(
- mss_bev_dense = [proj1_bev, proj2_bev, proj3_bev],
+ mss_bev_dense = [proj1_bev_rwkv, proj2_bev_rwkv, proj3_bev_rwkv],
mss_logits_list = [
[vw_label_02.clone(), self.out2(proj1_vw)],
[vw_label_04.clone(), self.out4(proj2_vw)],
@@ -255,7 +262,7 @@ def forward_once(self, vw_features, coord_ind, full_coord, pw_label, info):
)
return dict(
- mss_bev_dense = [proj1_bev, proj2_bev, proj3_bev]
+ mss_bev_dense = [proj1_bev_rwkv, proj2_bev_rwkv, proj3_bev_rwkv]
)
def forward(self, data_dict, example):
@@ -284,6 +291,49 @@ def forward(self, data_dict, example):
out_dict = self.forward_once(data_dict['vw_features'],
data_dict['coord_ind'], data_dict['full_coord'], None, data_dict['info'])
return out_dict
+
+
+
+ # def forward(self, data_dict, example):
+ # if self.phase == 'trainval':
+ # out_dict = self.forward_once(data_dict['vw_features'],
+ # data_dict['coord_ind'], data_dict['full_coord'],
+ # example['points_label'], data_dict['info'])
+ # all_teach_pair = out_dict['mss_logits_list']
+
+ # # Define weights for each scale. Adjust these as needed.
+ # # For example, giving the first scale (0 index) a larger weight.
+ # scale_weights = [3.0, 1.0, 1.0]
+ # loss_dict = {}
+ # for i, teach_pair in enumerate(all_teach_pair):
+ # voxel_labels_copy = teach_pair[0].long().clone()
+ # voxel_labels_copy[voxel_labels_copy == 0] = 256
+ # voxel_labels_copy = voxel_labels_copy - 1
+
+ # #res04_loss = lovasz_softmax(F.softmax(teach_pair[1], dim=1), voxel_labels_copy, ignore=255)
+ # sem_scal_loss_val = sem_scal_loss(teach_pair[1], voxel_labels_copy) # Calculate semantic scale loss
+
+ # # Apply the weight to the scale-specific losses
+ # #weighted_res04_loss = res04_loss * scale_weights[i]
+ # weighted_sem_scal_loss_val = sem_scal_loss_val * scale_weights[i]
+
+ # # Update the loss dictionary with weighted losses
+ # #loss_dict["vw_" + str(i) + "_lovasz_loss"] = weighted_res04_loss
+ # loss_dict["vw_" + str(i) + "_sem_scal_loss"] = weighted_sem_scal_loss_val
+
+ # # Sum up all computed losses to form the total loss
+ # total_loss = sum(loss_dict.values())
+ # loss_dict["total_loss"] = total_loss
+
+ # return dict(
+ # mss_bev_dense=out_dict['mss_bev_dense'],
+ # loss=loss_dict
+ # )
+ # else:
+ # out_dict = self.forward_once(data_dict['vw_features'],
+ # data_dict['coord_ind'], data_dict['full_coord'],
+ # None, data_dict['info'])
+ # return out_dict
def get_class_weights(self):
'''
diff --git a/networks/vrwkv.py b/networks/vrwkv.py
new file mode 100644
index 0000000..81f070b
--- /dev/null
+++ b/networks/vrwkv.py
@@ -0,0 +1,439 @@
+# Copyright (c) Shanghai AI Lab. All rights reserved.
+from typing import Sequence
+import math, os
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torch.utils.checkpoint as cp
+
+
+
+
+
+def resize_pos_embed(pos_embed,
+ src_shape,
+ dst_shape,
+ mode='bicubic',
+ num_extra_tokens=1):
+ """Resize pos_embed weights.
+
+ Args:
+ pos_embed (torch.Tensor): Position embedding weights with shape
+ [1, L, C].
+ src_shape (tuple): The resolution of downsampled origin training
+ image, in format (H, W).
+ dst_shape (tuple): The resolution of downsampled new training
+ image, in format (H, W).
+ mode (str): Algorithm used for upsampling. Choose one from 'nearest',
+ 'linear', 'bilinear', 'bicubic' and 'trilinear'.
+ Defaults to 'bicubic'.
+ num_extra_tokens (int): The number of extra tokens, such as cls_token.
+ Defaults to 1.
+
+ Returns:
+ torch.Tensor: The resized pos_embed of shape [1, L_new, C]
+ """
+ if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
+ return pos_embed
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
+ _, L, C = pos_embed.shape
+ src_h, src_w = src_shape
+ assert L == src_h * src_w + num_extra_tokens, \
+ f"The length of `pos_embed` ({L}) doesn't match the expected " \
+ f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
+ '`img_size` argument.'
+ extra_tokens = pos_embed[:, :num_extra_tokens]
+
+ src_weight = pos_embed[:, num_extra_tokens:]
+ src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
+
+ dst_weight = F.interpolate(
+ src_weight, size=dst_shape, align_corners=False, mode=mode)
+ dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
+
+ return torch.cat((extra_tokens, dst_weight), dim=1)
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, patch_size, in_chans=3, embed_dim=768):
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+
+logger = logging.getLogger(__name__)
+
+T_MAX = 66000
+
+
+from torch.utils.cpp_extension import load
+wkv_cuda = load(name="wkv", sources=["/home/jmwang/OccRWKV/networks/cuda/wkv_op.cpp", "/home/jmwang/OccRWKV/networks/cuda/wkv_cuda.cu"],
+ verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
+
+
+class WKV(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, B, T, C, w, u, k, v):
+ ctx.B = B
+ ctx.T = T
+ ctx.C = C
+ # print(T)
+ assert T <= T_MAX
+ assert B * C % min(C, 1024) == 0
+
+ half_mode = (w.dtype == torch.half)
+ bf_mode = (w.dtype == torch.bfloat16)
+ ctx.save_for_backward(w, u, k, v)
+ w = w.float().contiguous()
+ u = u.float().contiguous()
+ k = k.float().contiguous()
+ v = v.float().contiguous()
+ y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
+ wkv_cuda.forward(B, T, C, w, u, k, v, y)
+ if half_mode:
+ y = y.half()
+ elif bf_mode:
+ y = y.bfloat16()
+ return y
+
+ @staticmethod
+ def backward(ctx, gy):
+ B = ctx.B
+ T = ctx.T
+ C = ctx.C
+ assert T <= T_MAX
+ assert B * C % min(C, 1024) == 0
+ w, u, k, v = ctx.saved_tensors
+ gw = torch.zeros((B, C), device='cuda').contiguous()
+ gu = torch.zeros((B, C), device='cuda').contiguous()
+ gk = torch.zeros((B, T, C), device='cuda').contiguous()
+ gv = torch.zeros((B, T, C), device='cuda').contiguous()
+ half_mode = (w.dtype == torch.half)
+ bf_mode = (w.dtype == torch.bfloat16)
+ wkv_cuda.backward(B, T, C,
+ w.float().contiguous(),
+ u.float().contiguous(),
+ k.float().contiguous(),
+ v.float().contiguous(),
+ gy.float().contiguous(),
+ gw, gu, gk, gv)
+ if half_mode:
+ gw = torch.sum(gw.half(), dim=0)
+ gu = torch.sum(gu.half(), dim=0)
+ return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
+ elif bf_mode:
+ gw = torch.sum(gw.bfloat16(), dim=0)
+ gu = torch.sum(gu.bfloat16(), dim=0)
+ return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+ else:
+ gw = torch.sum(gw, dim=0)
+ gu = torch.sum(gu, dim=0)
+ return (None, None, None, gw, gu, gk, gv)
+
+
+def RUN_CUDA(B, T, C, w, u, k, v):
+ return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
+
+
+def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):
+ assert gamma <= 1/4
+ B, N, C = input.shape
+ input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1])
+ B, C, H, W = input.shape
+ output = torch.zeros_like(input)
+ output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel]
+ output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W]
+ output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :]
+ output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :]
+ output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...]
+ return output.flatten(2).transpose(1, 2)
+
+
+class VRWKV_SpatialMix(nn.Module):
+ def __init__(self, n_embd, n_layer, layer_id, shift_mode='q_shift',
+ channel_gamma=1/4, shift_pixel=1, init_mode='fancy',
+ key_norm=False, with_cp=False):
+ super().__init__()
+ self.layer_id = layer_id
+ self.n_layer = n_layer
+ self.n_embd = n_embd
+ self.device = None
+ attn_sz = n_embd
+ self._init_weights(init_mode)
+ self.shift_pixel = shift_pixel
+ self.shift_mode = shift_mode
+ if shift_pixel > 0:
+ self.shift_func = eval(shift_mode)
+ self.channel_gamma = channel_gamma
+ else:
+ self.spatial_mix_k = None
+ self.spatial_mix_v = None
+ self.spatial_mix_r = None
+
+ self.key = nn.Linear(n_embd, attn_sz, bias=False)
+ self.value = nn.Linear(n_embd, attn_sz, bias=False)
+ self.receptance = nn.Linear(n_embd, attn_sz, bias=False)
+ if key_norm:
+ self.key_norm = nn.LayerNorm(n_embd)
+ else:
+ self.key_norm = None
+ self.output = nn.Linear(attn_sz, n_embd, bias=False)
+
+ self.key.scale_init = 0
+ self.receptance.scale_init = 0
+ self.output.scale_init = 0
+
+ self.with_cp = with_cp
+
+ def _init_weights(self, init_mode):
+ if init_mode=='fancy':
+ with torch.no_grad(): # fancy init
+ ratio_0_to_1 = (self.layer_id / (self.n_layer - 1)) # 0 to 1
+ ratio_1_to_almost0 = (1.0 - (self.layer_id / self.n_layer)) # 1 to ~0
+
+ # fancy time_decay
+ decay_speed = torch.ones(self.n_embd)
+ for h in range(self.n_embd):
+ decay_speed[h] = -5 + 8 * (h / (self.n_embd-1)) ** (0.7 + 1.3 * ratio_0_to_1)
+ self.spatial_decay = nn.Parameter(decay_speed)
+
+ # fancy time_first
+ zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(self.n_embd)]) * 0.5)
+ self.spatial_first = nn.Parameter(torch.ones(self.n_embd) * math.log(0.3) + zigzag)
+
+ # fancy time_mix
+ x = torch.ones(1, 1, self.n_embd)
+ for i in range(self.n_embd):
+ x[0, 0, i] = i / self.n_embd
+ self.spatial_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ self.spatial_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+ self.spatial_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+ elif init_mode=='local':
+ self.spatial_decay = nn.Parameter(torch.ones(self.n_embd))
+ self.spatial_first = nn.Parameter(torch.ones(self.n_embd))
+ self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]))
+ self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]))
+ self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]))
+ elif init_mode=='global':
+ self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd))
+ self.spatial_first = nn.Parameter(torch.zeros(self.n_embd))
+ self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
+ self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
+ self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
+ else:
+ raise NotImplementedError
+
+ def jit_func(self, x, patch_resolution):
+ # Mix x with the previous timestep to produce xk, xv, xr
+ B, T, C = x.size()
+ if self.shift_pixel > 0:
+ xx = self.shift_func(x, self.shift_pixel, self.channel_gamma, patch_resolution)
+ xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)
+ xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v)
+ xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)
+ else:
+ xk = x
+ xv = x
+ xr = x
+
+ # Use xk, xv, xr to produce k, v, r
+ k = self.key(xk)
+ v = self.value(xv)
+ r = self.receptance(xr)
+ sr = torch.sigmoid(r)
+
+ return sr, k, v
+
+ def forward(self, x, patch_resolution=None):
+
+ def _inner_forward(x):
+
+ B, T, C = x.size()
+ self.device = x.device
+
+ sr, k, v = self.jit_func(x, patch_resolution)
+ x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v)
+ if self.key_norm is not None:
+ x = self.key_norm(x)
+ x = sr * x
+ x = self.output(x)
+ return x
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+ return x
+
+
+class VRWKV_ChannelMix(nn.Module):
+ def __init__(self, n_embd, n_layer, layer_id, shift_mode='q_shift',
+ channel_gamma=1/4, shift_pixel=1, hidden_rate=4, init_mode='fancy',
+ key_norm=False, with_cp=False):
+ super().__init__()
+ self.layer_id = layer_id
+ self.n_layer = n_layer
+ self.n_embd = n_embd
+ self.with_cp = with_cp
+ self._init_weights(init_mode)
+ self.shift_pixel = shift_pixel
+ self.shift_mode = shift_mode
+ if shift_pixel > 0:
+ self.shift_func = eval(shift_mode)
+ self.channel_gamma = channel_gamma
+ else:
+ self.spatial_mix_k = None
+ self.spatial_mix_r = None
+
+ hidden_sz = hidden_rate * n_embd
+ self.key = nn.Linear(n_embd, hidden_sz, bias=False)
+ if key_norm:
+ self.key_norm = nn.LayerNorm(hidden_sz)
+ else:
+ self.key_norm = None
+ self.receptance = nn.Linear(n_embd, n_embd, bias=False)
+ self.value = nn.Linear(hidden_sz, n_embd, bias=False)
+
+ self.value.scale_init = 0
+ self.receptance.scale_init = 0
+
+ def _init_weights(self, init_mode):
+ if init_mode == 'fancy':
+ with torch.no_grad(): # fancy init of time_mix
+ ratio_1_to_almost0 = (1.0 - (self.layer_id / self.n_layer)) # 1 to ~0
+ x = torch.ones(1, 1, self.n_embd)
+ for i in range(self.n_embd):
+ x[0, 0, i] = i / self.n_embd
+ self.spatial_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ self.spatial_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+ elif init_mode == 'local':
+ self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]))
+ self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]))
+ elif init_mode == 'global':
+ self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
+ self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5)
+ else:
+ raise NotImplementedError
+
+ def forward(self, x, patch_resolution=None):
+ def _inner_forward(x):
+ if self.shift_pixel > 0:
+ xx = self.shift_func(x, self.shift_pixel, self.channel_gamma, patch_resolution)
+ xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k)
+ xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r)
+ else:
+ xk = x
+ xr = x
+
+ k = self.key(xk)
+ k = torch.square(torch.relu(k))
+ if self.key_norm is not None:
+ k = self.key_norm(k)
+ kv = self.value(k)
+ x = torch.sigmoid(self.receptance(xr)) * kv
+ return x
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(self, n_embd, n_layer, layer_id, shift_mode='q_shift',
+ channel_gamma=1/4, shift_pixel=1, drop_path=0., hidden_rate=4,
+ init_mode='fancy', init_values=None, post_norm=False, key_norm=False,
+ with_cp=False):
+ super().__init__()
+ self.layer_id = layer_id
+ self.ln1 = nn.LayerNorm(n_embd)
+ self.ln2 = nn.LayerNorm(n_embd)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ if self.layer_id == 0:
+ self.ln0 = nn.LayerNorm(n_embd)
+
+ self.att = VRWKV_SpatialMix(n_embd, n_layer, layer_id, shift_mode,
+ channel_gamma, shift_pixel, init_mode,
+ key_norm=key_norm)
+
+ self.ffn = VRWKV_ChannelMix(n_embd, n_layer, layer_id, shift_mode,
+ channel_gamma, shift_pixel, hidden_rate,
+ init_mode, key_norm=key_norm)
+ self.layer_scale = (init_values is not None)
+ self.post_norm = post_norm
+ if self.layer_scale:
+ self.gamma1 = nn.Parameter(init_values * torch.ones((n_embd)), requires_grad=True)
+ self.gamma2 = nn.Parameter(init_values * torch.ones((n_embd)), requires_grad=True)
+ self.with_cp = with_cp
+
+ def forward(self, x, patch_resolution=None):
+ def _inner_forward(x):
+ if self.layer_id == 0:
+ x = self.ln0(x)
+ if self.post_norm:
+ if self.layer_scale:
+ x = x + self.drop_path(self.gamma1 * self.ln1(self.att(x, patch_resolution)))
+ x = x + self.drop_path(self.gamma2 * self.ln2(self.ffn(x, patch_resolution)))
+ else:
+ x = x + self.drop_path(self.ln1(self.att(x, patch_resolution)))
+ x = x + self.drop_path(self.ln2(self.ffn(x, patch_resolution)))
+ else:
+ if self.layer_scale:
+ x = x + self.drop_path(self.gamma1 * self.att(self.ln1(x), patch_resolution))
+ x = x + self.drop_path(self.gamma2 * self.ffn(self.ln2(x), patch_resolution))
+ else:
+ x = x + self.drop_path(self.att(self.ln1(x), patch_resolution))
+ x = x + self.drop_path(self.ffn(self.ln2(x), patch_resolution))
+ return x
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+ return x
+
+
diff --git a/test.py b/test.py
index 6c9d916..f1b495b 100644
--- a/test.py
+++ b/test.py
@@ -5,7 +5,7 @@
import sys
import numpy as np
import time
-
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Append root directory to system path for imports
repo_path, _ = os.path.split(os.path.realpath(__file__))
repo_path, _ = os.path.split(repo_path)
@@ -21,7 +21,7 @@
def parse_args():
- parser = argparse.ArgumentParser(description='DSC validating')
+ parser = argparse.ArgumentParser(description='OCCRWKV validating')
parser.add_argument(
'--weights',
dest='weights_file',
diff --git a/train.py b/train.py
index b012155..530720d 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse
import torch
import torch.nn as nn
@@ -24,11 +24,11 @@
import utils.checkpoint as checkpoint
def parse_args():
- parser = argparse.ArgumentParser(description='DSC training')
+ parser = argparse.ArgumentParser(description='OccRWKV training')
parser.add_argument(
'--cfg',
dest='config_file',
- default='cfgs/SSC-RS.yaml',
+ default='cfgs/OccRWKV.yaml',
metavar='FILE',
help='path to config file',
type=str,
@@ -147,7 +147,6 @@ def train(model, optimizer, scheduler, dataset, _cfg, start_epoch, logger, tbwri
for scale in metrics.evaluator.keys():
tbwriter.add_scalar('train_performance/{}/mIoU'.format(scale), metrics.get_semantics_mIoU(scale).item(), epoch - 1)
tbwriter.add_scalar('train_performance/{}/IoU'.format(scale), metrics.get_occupancy_IoU(scale).item(), epoch - 1)
- tbwriter.add_scalar('train_performance/{}/Seg_mIoU'.format(scale), seg_miou, epoch - 1)
tbwriter.add_scalar('train_performance/{}/Precision'.format(scale), metrics.get_occupancy_Precision(scale).item(), epoch-1)
tbwriter.add_scalar('train_performance/{}/Recall'.format(scale), metrics.get_occupancy_Recall(scale).item(), epoch-1)
tbwriter.add_scalar('train_performance/{}/F1'.format(scale), metrics.get_occupancy_F1(scale).item(), epoch-1)
@@ -256,10 +255,10 @@ def validate(model, dset, _cfg, epoch, logger, tbwriter, metrics):
checkpoint_info = {}
- if epoch_loss < _cfg._dict['OUTPUT']['BEST_LOSS']:
- logger.info('=> Best loss on validation set encountered: ({} < {})'.format(epoch_loss, _cfg._dict['OUTPUT']['BEST_LOSS']))
- _cfg._dict['OUTPUT']['BEST_LOSS'] = epoch_loss.item()
- checkpoint_info['best-loss'] = 'BEST_LOSS'
+ # if epoch_loss < _cfg._dict['OUTPUT']['BEST_LOSS']:
+ # logger.info('=> Best loss on validation set encountered: ({} < {})'.format(epoch_loss, _cfg._dict['OUTPUT']['BEST_LOSS']))
+ # _cfg._dict['OUTPUT']['BEST_LOSS'] = epoch_loss.item()
+ # checkpoint_info['best-loss'] = 'BEST_LOSS'
mIoU_1_1 = metrics.get_semantics_mIoU('1_1')
IoU_1_1 = metrics.get_occupancy_IoU('1_1')
@@ -276,9 +275,9 @@ def validate(model, dset, _cfg, epoch, logger, tbwriter, metrics):
def main():
# https://github.com/pytorch/pytorch/issues/27588
- torch.backends.cudnn.enabled = False
+ torch.backends.cudnn.enabled = True
- seed_all(10)
+ seed_all(42)
args = parse_args()
@@ -304,9 +303,7 @@ def main():
logger.info('=> Loading network architecture...')
model = get_model(_cfg._dict, phase='trainval')
- if torch.cuda.device_count() > 1:
- model = nn.DataParallel(model)
- model = model.module
+
logger.info(f'=> Model Parameters: {sum(p.numel() for p in model.parameters())/1000000.0} M')
logger.info('=> Loading optimizer...')
diff --git a/validate.py b/validate.py
index 88f52c1..d7de92d 100644
--- a/validate.py
+++ b/validate.py
@@ -3,12 +3,12 @@
import torch
import torch.nn as nn
import sys
-
+from thop import profile
# Append root directory to system path for imports
repo_path, _ = os.path.split(os.path.realpath(__file__))
repo_path, _ = os.path.split(repo_path)
sys.path.append(repo_path)
-
+os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from utils.seed import seed_all
from utils.config import CFG
from utils.dataset import get_dataset
@@ -44,6 +44,7 @@ def parse_args():
def validate(model, dset, _cfg, logger, metrics):
+
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float32 # Tensor type to be used
@@ -54,7 +55,8 @@ def validate(model, dset, _cfg, logger, metrics):
logger.info('=> Passing the network on the validation set...')
time_list = []
model.eval()
-
+ flops_list = []
+ total_flops = 0
with torch.no_grad():
for t, (data, indices) in enumerate(tqdm(dset, ncols=100)):
@@ -62,6 +64,15 @@ def validate(model, dset, _cfg, logger, metrics):
start_time = time.time()
scores, loss = model(data)
time_list.append(time.time() - start_time)
+ input_data = next(iter(data.values()))
+
+ # 计算FLOPs
+ flops, params = profile(model, inputs=(data,), verbose=False)
+ total_flops += flops
+ current_gflops = flops / 1e9
+ logger.info(f'Current batch GFLOPs: {current_gflops}')
+
+
# Updating batch losses to then get mean for epoch loss
metrics.losses_track.update_validaiton_losses(loss)
@@ -93,6 +104,9 @@ def validate(model, dset, _cfg, logger, metrics):
class_score = metrics.evaluator['1_1'].getIoU()[1][i]
logger.info(' => IoU {}: {:.6f}'.format(class_name, class_score))
+ # 计算总FLOPs并转换为GFLOPs
+ total_gflops = (total_flops / len(dset.dataset)) / 1e9 # 转换为GFLOPs
+ logger.info(f'Average GFLOPs per input for validation set: {total_gflops}')
return time_list