From 552cf942d044a371da1e298f4e77cfc90a8f44f6 Mon Sep 17 00:00:00 2001 From: jmwang0117 <1021347250@qq.com> Date: Fri, 27 Sep 2024 22:00:45 +0000 Subject: [PATCH] Update --- README.md | 71 +++-- environment.yaml | 251 +++++++++++++++++ networks/bev_net.py | 124 ++++----- networks/completion.py | 25 +- networks/{dsc.py => occrwkv.py} | 6 +- networks/semantic_segmentation.py | 94 +++++-- networks/vrwkv.py | 439 ++++++++++++++++++++++++++++++ test.py | 4 +- train.py | 23 +- validate.py | 20 +- 10 files changed, 911 insertions(+), 146 deletions(-) create mode 100644 environment.yaml rename networks/{dsc.py => occrwkv.py} (98%) create mode 100644 networks/vrwkv.py diff --git a/README.md b/README.md index d82934c..5208ae8 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,39 @@ -# OccRWKV +
+# 🤖 OccRWKV: Rethinking Efficient 3D Semantic Occupancy Prediction with Linear Complexity +*We will open source the complete code after the paper is accepted* +
-## OccRWKV: Rethinking Sparse Latent Representation for 3D Semantic Occupancy Prediction +## 📢 News -## Preperation +- [2024/09]: OccRWKV's logs are available for download: +
+ +| OccRWKV Results | Experiment Log | +|:------------------------------------------------------------------:|:----------:| +|OccRWKV on the SemanticKITTI hidden official test dataset | [link](https://connecthkuhk-my.sharepoint.com/:t:/g/personal/u3009632_connect_hku_hk/EYqFDMD6xexCqXwfZ_nPxEUB0akfqePg4TwuGiuf4fQK0Q?e=PFM1ma) | +|OccRWKV train log | [link](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/u3009632_connect_hku_hk/EcKG5MgDCTJJuu8DJ7VoS9sB0euzAEaMkpLjlY9LvRJ0GA?e=lwddX3) | + +
+ +- [2024/08]: The pre-trained model can be downloaded at [OneDrive](https://connecthkuhk-my.sharepoint.com/:u:/g/personal/u3009632_connect_hku_hk/ETCUIJ7rPnFJniQYMsDsPyIBHkzirRP4c3n-eU9fcBZTaA?e=P8AkQ2). +- [2024/07]: 🔥 We released the code of OccRWKV. *The First Receptance Weighted Key Value (RWKV)-based 3D Semantic Occupancy Network* + +
+ +``` +@article{wang2024omega, +title={OccRWKV: Rethinking Efficient 3D Semantic Occupancy Prediction with Linear Complexity}, +author={Wang, Junming and Yin, Wei and Long, Xiaoxiao and Zhang, Xinyu and Xing, Zebing and Guo, Xiaoyang and Qian, Zhang}, +year={2024} + } +``` + +Please kindly star ⭐️ this project if it helps you. We take great efforts to develop and maintain it 😁. + + +## 🛠️ Installation -### Prerequisites ``` conda create -n occ_rwkv python=3.10 -y conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y @@ -13,9 +41,12 @@ pip install spconv-cu120 pip install tensorboardX pip install dropblock pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html +pip install -U openmim +mim install mmcv-full +pip install mmcls==0.25.0 ``` -### Dataset +## 💽 Dataset Please download the Semantic Scene Completion dataset (v1.1) from the [SemanticKITTI website](http://www.semantic-kitti.org/dataset.html) and extract it. @@ -35,7 +66,7 @@ SemanticKITTI │ │ ├── ... ... ``` -## Getting Start +## 🤗 Getting Start Clone the repository: ``` https://github.com/jmwang0117/OccRWKV.git @@ -49,7 +80,7 @@ We provide training routine examples in the `cfgs` folder. Make sure to change t ``` $ cd -$ python train.py --cfg cfgs/DSC-Base.yaml --dset_root +$ bash scripts/run_train.sh ``` ### Validation @@ -57,7 +88,7 @@ Validation passes are done during training routine. Additional pass in the valid ``` $ cd -$ python validate.py --weights --dset_root +$ bash scripts/run_val.sh ``` ### Test @@ -69,21 +100,15 @@ You can provide which checkpoints you want to use for testing. We used the ones ``` $ cd -$ python test.py --weights --dset_root --out_path +$ bash scripts/run_test.sh ``` -### Pretrained Model - -You can download the models with the scores below from this [Google drive link](https://drive.google.com/file/d/1-b3O7QS6hBQIGFTO-7qSG7Zb9kbQuxdO/view?usp=sharing), - -| Model | Segmentation | Completion | -|--|--|--| -| SSC-RS | 24.2 | 59.7 | -* Results reported to SemanticKITTI: Semantic Scene Completion leaderboard ([link](https://codalab.lisn.upsaclay.fr/competitions/7170\#results)). -## Acknowledgement -This project is not possible without multiple great opensourced codebases. -* [spconv](https://github.com/traveller59/spconv) -* [LMSCNet](https://github.com/cv-rits/LMSCNet) -* [SSA-SC](https://github.com/jokester-zzz/SSA-SC) -* [GASN](https://github.com/ItIsFriday/PcdSeg) +## 🏆 Acknowledgement +Many thanks to these excellent open source projects: +- [AGRNav](https://github.com/jmwang0117/AGRNav) +- [Prometheus](https://github.com/amov-lab/Prometheus) +- [SSC-RS](https://github.com/Jieqianyu/SSC-RS) +- [semantic-kitti-api](https://github.com/PRBonn/semantic-kitti-api) +- [Terrestrial-Aerial-Navigation](https://github.com/ZJU-FAST-Lab/Terrestrial-Aerial-Navigation) +- [EGO-Planner](https://github.com/ZJU-FAST-Lab/ego-planner-swarm) diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..c91f139 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,251 @@ +name: occ_rwkv +channels: + - pytorch + - nvidia + - conda-forge +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_kmp_llvm + - aom=3.9.0=hac33072_0 + - blas=2.116=mkl + - blas-devel=3.9.0=16_linux64_mkl + - brotli-python=1.1.0=py310hc6cd4ac_1 + - bzip2=1.0.8=hd590300_5 + - ca-certificates=2024.6.2=hbcca054_0 + - cairo=1.18.0=h3faef2a_0 + - certifi=2024.6.2=pyhd8ed1ab_0 + - charset-normalizer=3.3.2=pyhd8ed1ab_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.5.39=0 + - cuda-runtime=12.1.0=0 + - cuda-version=12.5=3 + - dav1d=1.2.1=hd590300_0 + - expat=2.6.2=h59595ed_0 + - ffmpeg=7.0.1=gpl_hb399a10_100 + - filelock=3.14.0=pyhd8ed1ab_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_2 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - freetype=2.12.1=h267a509_2 + - fribidi=1.0.10=h36c2ea0_0 + - gettext=0.22.5=h59595ed_2 + - gettext-tools=0.22.5=h59595ed_2 + - gmp=6.3.0=h59595ed_1 + - gmpy2=2.1.5=py310hc7909c9_1 + - gnutls=3.7.9=hb077bed_0 + - graphite2=1.3.13=h59595ed_1003 + - harfbuzz=8.5.0=hfac3d4d_0 + - icu=73.2=h59595ed_0 + - idna=3.7=pyhd8ed1ab_0 + - jinja2=3.1.4=pyhd8ed1ab_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.16=hb7c19ff_0 + - ld_impl_linux-64=2.40=hf3520f5_3 + - lerc=4.0.0=h27087fc_0 + - libabseil=20240116.2=cxx17_h59595ed_0 + - libasprintf=0.22.5=h661eb56_2 + - libasprintf-devel=0.22.5=h661eb56_2 + - libass=0.17.1=h8fe9dca_1 + - libblas=3.9.0=16_linux64_mkl + - libcblas=3.9.0=16_linux64_mkl + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.10.0.4=0 + - libcurand=10.3.6.39=0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.20=hd590300_0 + - libdrm=2.4.120=hd590300_0 + - libexpat=2.6.2=h59595ed_0 + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=13.2.0=h77fa898_7 + - libgettextpo=0.22.5=h59595ed_2 + - libgettextpo-devel=0.22.5=h59595ed_2 + - libgfortran-ng=13.2.0=h69a702a_7 + - libgfortran5=13.2.0=hca663fb_7 + - libglib=2.80.2=hf974151_0 + - libgomp=13.2.0=h77fa898_7 + - libhwloc=2.10.0=default_h5622ce7_1001 + - libiconv=1.17=hd590300_2 + - libidn2=2.3.7=hd590300_0 + - libjpeg-turbo=3.0.0=hd590300_1 + - liblapack=3.9.0=16_linux64_mkl + - liblapacke=3.9.0=16_linux64_mkl + - libnpp=12.0.2.50=0 + - libnsl=2.0.1=hd590300_0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libopenvino=2024.1.0=h2da1b83_7 + - libopenvino-auto-batch-plugin=2024.1.0=hb045406_7 + - libopenvino-auto-plugin=2024.1.0=hb045406_7 + - libopenvino-hetero-plugin=2024.1.0=h5c03a75_7 + - libopenvino-intel-cpu-plugin=2024.1.0=h2da1b83_7 + - libopenvino-intel-gpu-plugin=2024.1.0=h2da1b83_7 + - libopenvino-intel-npu-plugin=2024.1.0=he02047a_7 + - libopenvino-ir-frontend=2024.1.0=h5c03a75_7 + - libopenvino-onnx-frontend=2024.1.0=h07e8aee_7 + - libopenvino-paddle-frontend=2024.1.0=h07e8aee_7 + - libopenvino-pytorch-frontend=2024.1.0=he02047a_7 + - libopenvino-tensorflow-frontend=2024.1.0=h39126c6_7 + - libopenvino-tensorflow-lite-frontend=2024.1.0=he02047a_7 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.18=hd590300_0 + - libpng=1.6.43=h2797004_0 + - libprotobuf=4.25.3=h08a7969_0 + - libsqlite=3.46.0=hde9e2c9_0 + - libstdcxx-ng=13.2.0=hc0a3c3a_7 + - libtasn1=4.19.0=h166bdaf_0 + - libtiff=4.6.0=h1dd3fc0_3 + - libunistring=0.9.10=h7f98852_0 + - libuuid=2.38.1=h0b41bf4_0 + - libva=2.21.0=h4ab18f5_2 + - libvpx=1.14.1=hac33072_0 + - libwebp-base=1.4.0=hd590300_0 + - libxcb=1.15=h0b41bf4_0 + - libxcrypt=4.4.36=hd590300_1 + - libxml2=2.12.7=hc051c1a_1 + - libzlib=1.3.1=h4ab18f5_1 + - llvm-openmp=15.0.7=h0cdce71_0 + - markupsafe=2.1.5=py310h2372a71_0 + - mkl=2022.1.0=h84fe81f_915 + - mkl-devel=2022.1.0=ha770c72_916 + - mkl-include=2022.1.0=h84fe81f_915 + - mpc=1.3.1=hfe3b2da_0 + - mpfr=4.2.1=h9458935_1 + - mpmath=1.3.0=pyhd8ed1ab_0 + - ncurses=6.5=h59595ed_0 + - nettle=3.9.1=h7ab15ed_0 + - networkx=3.3=pyhd8ed1ab_1 + - numpy=1.26.4=py310hb13e2d6_0 + - ocl-icd=2.3.2=hd590300_1 + - openh264=2.4.1=h59595ed_0 + - openjpeg=2.5.2=h488ebb8_0 + - openssl=3.3.1=h4ab18f5_0 + - p11-kit=0.24.1=hc5aa10d_0 + - pcre2=10.43=hcad00b1_0 + - pillow=10.3.0=py310hf73ecf8_0 + - pip=24.0=pyhd8ed1ab_0 + - pixman=0.43.2=h59595ed_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - pugixml=1.14=h59595ed_0 + - pysocks=1.7.1=pyha2e5f31_6 + - python=3.10.14=hd12c33a_0_cpython + - python_abi=3.10=4_cp310 + - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py310h2372a71_1 + - readline=8.2=h8228510_1 + - snappy=1.2.0=hdb0a2a9_1 + - svt-av1=2.1.0=hac33072_0 + - sympy=1.12=pypyh9d50eac_103 + - tbb=2021.12.0=h297d8ca_1 + - tk=8.6.13=noxft_h4845f30_101 + - torchaudio=2.1.0=py310_cu121 + - torchtriton=2.1.0=py310 + - torchvision=0.16.0=py310_cu121 + - typing_extensions=4.12.2=pyha770c72_0 + - wheel=0.43.0=pyhd8ed1ab_1 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xorg-fixesproto=5.0=h7f98852_1002 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.9=h8ee46fc_0 + - xorg-libxau=1.0.11=hd590300_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxrender=0.9.11=hd590300_0 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.6=h166bdaf_0 + - yaml=0.2.5=h7f98852_2 + - zlib=1.3.1=h4ab18f5_1 + - zstd=1.5.6=ha6fb4c9_0 + - pip: + - absl-py==2.1.0 + - addict==2.4.0 + - aliyun-python-sdk-core==2.15.1 + - aliyun-python-sdk-kms==2.16.3 + - ccimport==0.4.2 + - cffi==1.16.0 + - click==8.1.7 + - colorama==0.4.6 + - contourpy==1.2.1 + - crcmod==1.7 + - cryptography==42.0.8 + - cumm-cu120==0.4.11 + - cycler==0.12.1 + - dropblock==0.3.0 + - fire==0.6.0 + - fonttools==4.53.0 + - fsspec==2024.6.0 + - grpcio==1.64.1 + - importlib-metadata==7.1.0 + - jmespath==0.10.0 + - kiwisolver==1.4.5 + - lark==1.1.9 + - markdown==3.6 + - markdown-it-py==3.0.0 + - matplotlib==3.9.0 + - mdurl==0.1.2 + - mmcls==0.25.0 + - mmcv-full==1.7.2 + - model-index==0.1.11 + - ninja==1.11.1.1 + - opencv-python==4.10.0.82 + - opendatalab==0.0.10 + - openmim==0.3.9 + - openxlab==0.1.0 + - ordered-set==4.1.0 + - oss2==2.17.0 + - packaging==24.1 + - pandas==2.2.2 + - pccm==0.4.11 + - platformdirs==4.2.2 + - portalocker==2.8.2 + - protobuf==4.25.3 + - pybind11==2.12.0 + - pycparser==2.22 + - pycryptodome==3.20.0 + - pyg-lib==0.4.0+pt21cu121 + - pygments==2.18.0 + - pyparsing==3.1.2 + - python-dateutil==2.9.0.post0 + - pytz==2023.4 + - requests==2.28.2 + - rich==13.4.2 + - scipy==1.13.1 + - setuptools==60.2.0 + - six==1.16.0 + - spconv-cu120==2.3.6 + - tabulate==0.9.0 + - tensorboard==2.17.0 + - tensorboard-data-server==0.7.2 + - tensorboardx==2.6.2.2 + - termcolor==2.4.0 + - thop==0.1.1-2209072238 + - tomli==2.0.1 + - torch-cluster==1.6.3+pt21cu121 + - torch-scatter==2.1.2+pt21cu121 + - torch-sparse==0.6.18+pt21cu121 + - torch-spline-conv==1.2.2+pt21cu121 + - torchprofile==0.0.4 + - tqdm==4.65.2 + - tzdata==2024.1 + - urllib3==1.26.18 + - werkzeug==3.0.3 + - yapf==0.40.2 + - zipp==3.19.2 +prefix: /home/jmwang/miniforge3/envs/occ_rwkv diff --git a/networks/bev_net.py b/networks/bev_net.py index 8b43c4a..5e6a62c 100644 --- a/networks/bev_net.py +++ b/networks/bev_net.py @@ -3,67 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from dropblock import DropBlock2D - - -class BEVFusion(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, bev_features, sem_features, com_features): - return torch.cat([bev_features, sem_features, com_features], dim=1) - - @staticmethod - def channel_reduction(x, out_channels): - """ - Args: - x: (B, C1, H, W) - out_channels: C2 - - Returns: - - """ - B, in_channels, H, W = x.shape - assert (in_channels % out_channels == 0) and (in_channels >= out_channels) - - x = x.view(B, out_channels, -1, H, W) - # x = torch.max(x, dim=2)[0] - x = x.sum(dim=2) - return x - - -class BEVUNet(nn.Module): - def __init__(self, n_class, n_height, dilation, bilinear, group_conv, input_batch_norm, dropout, circular_padding, dropblock): - super().__init__() - self.inc = inconv(64, 64, dilation, input_batch_norm, circular_padding) - self.down1 = down(64, 128, dilation, group_conv, circular_padding) - self.down2 = down(256, 256, dilation, group_conv, circular_padding) - self.down3 = down(512, 512, dilation, group_conv, circular_padding) - self.down4 = down(1024, 512, dilation, group_conv, circular_padding) - self.up1 = up(1536, 512, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.up2 = up(1024, 256, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.up3 = up(512, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.up4 = up(192, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.dropout = nn.Dropout(p=0. if dropblock else dropout) - self.outc = outconv(128, n_class) - - self.bev_fusions = nn.ModuleList([BEVFusion() for _ in range(3)]) - - def forward(self, x, sem_fea_list, com_fea_list): - x1 = self.inc(x) # [B, 64, 256, 256] - x2 = self.down1(x1) # [B, 128, 128, 128] - x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) # 128, 64, 64 -> 256 - x3 = self.down2(x2_f) # [B, 256, 64, 64] - x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) # 256, 128, 128 -> 512 - x4 = self.down3(x3_f) # [B, 512, 32, 32] - x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) # 512, 256, 256 -> 1024 - x5 = self.down4(x4_f) # [B, 512, 16, 16] - x = self.up1(x5, x4_f) - x = self.up2(x, x3_f) - x = self.up3(x, x2_f) - x = self.up4(x, x1) - x = self.outc(self.dropout(x)) - return x - +from networks.vrwkv import Block as RWKVBlock class BEVFusionv1(nn.Module): def __init__(self, channel): @@ -106,30 +46,62 @@ def forward(self, bev_features, sem_features, com_features): class BEVUNetv1(nn.Module): def __init__(self, n_class, n_height, dilation, bilinear, group_conv, input_batch_norm, dropout, circular_padding, dropblock): super().__init__() + # Encoder self.inc = inconv(64, 64, dilation, input_batch_norm, circular_padding) self.down1 = down(64, 128, dilation, group_conv, circular_padding) self.down2 = down(128, 256, dilation, group_conv, circular_padding) self.down3 = down(256, 512, dilation, group_conv, circular_padding) self.down4 = down(512, 512, dilation, group_conv, circular_padding) - self.up1 = up(1024, 512, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.up2 = up(768, 256, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.up3 = up(384, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.up4 = up(192, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) - self.dropout = nn.Dropout(p=0. if dropblock else dropout) - self.outc = outconv(128, n_class) - + + # RWKV blocks for the encoder + + self.rwkv_block1 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0) + self.rwkv_block2 = RWKVBlock(n_embd=256, n_layer=18, layer_id=0) + self.rwkv_block3 = RWKVBlock(n_embd=512, n_layer=18, layer_id=0) + self.rwkv_block4 = RWKVBlock(n_embd=512, n_layer=18, layer_id=0) + self.rwkv_block_up1 = RWKVBlock(n_embd=512, n_layer=18, layer_id=0) + self.rwkv_block_up2 = RWKVBlock(n_embd=256, n_layer=18, layer_id=0) + self.rwkv_block_up3 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0) + self.rwkv_block_up4 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0) + + # BEV fusion modules channels = [128, 256, 512] self.bev_fusions = nn.ModuleList([BEVFusionv1(channels[i]) for i in range(3)]) + # Decoder + self.up1 = up(1024, 512, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up2 = up(768, 256, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up3 = up(384, 128, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up4 = up(192, 128, circular_padding, bilinear=bilinear, group_conv=group_conv, use_dropblock=dropblock, drop_p=dropout) + self.outc = outconv(128, n_class) + + self.dropout = nn.Dropout(p=0. if dropblock else dropout) + def forward(self, x, sem_fea_list, com_fea_list): - x1 = self.inc(x) # [B, 64, 256, 256] - x2 = self.down1(x1) # [B, 128, 128, 128] - x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) # 128, 64, 64 -> 128 - x3 = self.down2(x2_f) # [B, 256, 64, 64] - x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) # 256, 128, 128 -> 256 - x4 = self.down3(x3_f) # [B, 512, 32, 32] - x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) # 512, 256, 256 -> 512 - x5 = self.down4(x4_f) # [B, 512, 16, 16] + # Encoder with RWKV blocks + x1 = self.inc(x) + + x2 = self.down1(x1) + B, C, H, W = x2.shape + x2 = self.rwkv_block1(x2.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2) + x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) + + x3 = self.down2(x2_f) + B, C, H, W = x3.shape + x3 = self.rwkv_block2(x3.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2) + x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) + + x4 = self.down3(x3_f) + B, C, H, W = x4.shape + x4 = self.rwkv_block3(x4.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2) + x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) + + x5 = self.down4(x4_f) + B, C, H, W = x5.shape + x5 = self.rwkv_block4(x5.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=(H, W)).view(B, H, W, C).permute(0, 3, 1, 2) + + + # return x x = self.up1(x5, x4_f) # 512, 512 x = self.up2(x, x3_f) # 512, 256 x = self.up3(x, x2_f) # 256, 128 diff --git a/networks/completion.py b/networks/completion.py index a4c1adf..7dc8798 100644 --- a/networks/completion.py +++ b/networks/completion.py @@ -1,9 +1,9 @@ import torch.nn as nn import torch.nn.functional as F import torch - +from networks.vrwkv import Block as RWKVBlock from utils.lovasz_losses import lovasz_softmax - +from utils.ssc_loss import geo_scal_loss class ResBlock(nn.Module): def __init__(self, in_dim, out_dim, kernel_size, padding, stride, dilation=1): @@ -36,6 +36,11 @@ def __init__(self, init_size=32, nbr_class=20, phase='trainval'): self.block_1 = make_layers(16, 16, kernel_size=3, padding=1, stride=1, dilation=1, blocks=1) # 1/2, 16 self.block_2 = make_layers(16, 32, kernel_size=3, padding=1, stride=1, dilation=1, downsample=True, blocks=1) # 1/4, 32 self.block_3 = make_layers(32, 64, kernel_size=3, padding=2, stride=1, dilation=2, downsample=True, blocks=1) # 1/8, 64 + + # RWKV blocks + self.rwkv_block1 = RWKVBlock(n_embd=64, n_layer=18, layer_id=0) + self.rwkv_block2 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0) + self.rwkv_block3 = RWKVBlock(n_embd=256,n_layer=18, layer_id=0) self.reduction_1 = nn.Sequential( nn.Conv2d(256, 128, kernel_size=1), @@ -72,7 +77,19 @@ def forward_once(self, inputs): bev_1 = self.reduction_1(res1.flatten(1, 2)) # B, 64, 128, 128 bev_2 = self.reduction_2(res2.flatten(1, 2)) # B, 128, 64, 64 bev_3 = res3.flatten(1, 2) # B, 256, 32, 32 - + + B, C, H, W = bev_1.shape + patch_resolution = (H, W) + bev_1 = self.rwkv_block1(bev_1.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2) + + B, C, H, W = bev_2.shape + patch_resolution = (H, W) + bev_2 = self.rwkv_block2(bev_2.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2) + + B, C, H, W = bev_3.shape + patch_resolution = (H, W) + bev_3 = self.rwkv_block3(bev_3.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2) + if self.phase == 'trainval': logits_2 = self.out2(res1) logits_4 = self.out4(res2) @@ -123,4 +140,4 @@ def forward(self, data_dict, example): else: out_dict = self.forward_once(data_dict['vw_dense']) return out_dict - + diff --git a/networks/dsc.py b/networks/occrwkv.py similarity index 98% rename from networks/dsc.py rename to networks/occrwkv.py index 013f64b..56f699e 100644 --- a/networks/dsc.py +++ b/networks/occrwkv.py @@ -7,12 +7,13 @@ import torch.nn.functional as F from .preprocess import PcPreprocessor -from .bev_net import BEVUNet, BEVUNetv1 +from .bev_net import BEVUNetv1 from .completion import CompletionBranch from .semantic_segmentation import SemanticBranch from utils.lovasz_losses import lovasz_softmax -class DSC(nn.Module): + +class OccRWKV(nn.Module): def __init__(self, cfg, phase='trainval'): super().__init__() self.phase = phase @@ -83,7 +84,6 @@ def compute_loss(self, scores, labels, ss_loss_dict, sc_loss_dict): :param: prediction: the predicted tensor, must be [BS, C, H, W, D] ''' class_weights = self.get_class_weights().to(device=scores.device, dtype=scores.dtype) - loss_1_1 = F.cross_entropy(scores, labels.long(), weight=class_weights, ignore_index=255) loss_1_1 += lovasz_softmax(torch.nn.functional.softmax(scores, dim=1), labels.long(), ignore=255) loss_1_1 *= 3 diff --git a/networks/semantic_segmentation.py b/networks/semantic_segmentation.py index 50856b9..42dc189 100644 --- a/networks/semantic_segmentation.py +++ b/networks/semantic_segmentation.py @@ -3,13 +3,14 @@ import torch import torch.nn.functional as F import torch_scatter - import spconv.pytorch as spconv import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from utils.lovasz_losses import lovasz_softmax +from networks.vrwkv import Block as RWKVBlock +from utils.ssc_loss import sem_scal_loss class BasicBlock(spconv.SparseModule): def __init__(self, C_in, C_out, indice_key): @@ -174,6 +175,11 @@ def __init__(self, sizes=[256, 256, 32], nbr_class=19, init_size=32, class_frequ self.proj3_block = SGFE(input_channels=128, output_channels=256,\ reduce_channels=128, name="proj3") + # RWKV blocks + self.rwkv_block1 = RWKVBlock(n_embd=64, n_layer=18, layer_id=0) + self.rwkv_block2 = RWKVBlock(n_embd=128, n_layer=18, layer_id=0) + self.rwkv_block3 = RWKVBlock(n_embd=256,n_layer=18, layer_id=0) + self.phase = phase if phase == 'trainval': num_class = self.nbr_class # SemanticKITTI: 19 @@ -215,39 +221,40 @@ def forward_once(self, vw_features, coord_ind, full_coord, pw_label, info): vw_features, coord.int(), np.array(self.sizes, np.int32)[::-1], batch_size ) conv1_output = self.conv1_block(input_tensor) - proj1_vw, vw1_coord, pw1_coord = self.proj1_block(info, conv1_output.features, output_scale=2, input_coords=coord.int(), - input_coords_inv=full_coord) + + proj1_vw, vw1_coord, pw1_coord = self.proj1_block(info, conv1_output.features, output_scale=2, input_coords=coord.int(), input_coords_inv=full_coord) proj1_bev = self.bev_projection(proj1_vw, vw1_coord, (np.array(self.sizes, np.int32) // 2)[::-1], batch_size) - - conv2_input_tensor = spconv.SparseConvTensor( - proj1_vw, vw1_coord.int(), (np.array(self.sizes, np.int32) // 2)[::-1], batch_size - ) + B, C, H, W = proj1_bev.shape + patch_resolution = (H, W) + proj1_bev_rwkv = self.rwkv_block1(proj1_bev.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2) + + + conv2_input_tensor = spconv.SparseConvTensor(proj1_vw, vw1_coord.int(), (np.array(self.sizes, np.int32) // 2)[::-1], batch_size) conv2_output = self.conv2_block(conv2_input_tensor) - proj2_vw, vw2_coord, pw2_coord = self.proj2_block(info, conv2_output.features, output_scale=4, input_coords=vw1_coord.int(), - input_coords_inv=pw1_coord) + proj2_vw, vw2_coord, pw2_coord = self.proj2_block(info, conv2_output.features, output_scale=4, input_coords=vw1_coord.int(), input_coords_inv=pw1_coord) proj2_bev = self.bev_projection(proj2_vw, vw2_coord, (np.array(self.sizes, np.int32) // 4)[::-1], batch_size) + B, C, H, W = proj2_bev.shape + patch_resolution = (H, W) + proj2_bev_rwkv = self.rwkv_block2(proj2_bev.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2) - conv3_input_tensor = spconv.SparseConvTensor( - proj2_vw, vw2_coord.int(), (np.array(self.sizes, np.int32) // 4)[::-1], batch_size - ) + conv3_input_tensor = spconv.SparseConvTensor(proj2_vw, vw2_coord.int(), (np.array(self.sizes, np.int32) // 4)[::-1], batch_size) conv3_output = self.conv3_block(conv3_input_tensor) - proj3_vw, vw3_coord, _ = self.proj3_block(info, conv3_output.features, output_scale=8, input_coords=vw2_coord.int(), - input_coords_inv=pw2_coord) + proj3_vw, vw3_coord, _ = self.proj3_block(info, conv3_output.features, output_scale=8, input_coords=vw2_coord.int(), input_coords_inv=pw2_coord) proj3_bev = self.bev_projection(proj3_vw, vw3_coord, (np.array(self.sizes, np.int32) // 8)[::-1], batch_size) + B, C, H, W = proj3_bev.shape + patch_resolution = (H, W) + proj3_bev_rwkv = self.rwkv_block3(proj3_bev.permute(0, 2, 3, 1).reshape(B, H * W, C), patch_resolution=patch_resolution).view(B, H, W, C).permute(0, 3, 1, 2) if self.phase == 'trainval': - index_02 = torch.cat([info[2]['bxyz_indx'][:, 0].unsqueeze(-1), - torch.flip(info[2]['bxyz_indx'], dims=[1])[:, :3]], dim=1) - index_04 = torch.cat([info[4]['bxyz_indx'][:, 0].unsqueeze(-1), - torch.flip(info[4]['bxyz_indx'], dims=[1])[:, :3]], dim=1) - index_08 = torch.cat([info[8]['bxyz_indx'][:, 0].unsqueeze(-1), - torch.flip(info[8]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + index_02 = torch.cat([info[2]['bxyz_indx'][:, 0].unsqueeze(-1), torch.flip(info[2]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + index_04 = torch.cat([info[4]['bxyz_indx'][:, 0].unsqueeze(-1), torch.flip(info[4]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + index_08 = torch.cat([info[8]['bxyz_indx'][:, 0].unsqueeze(-1), torch.flip(info[8]['bxyz_indx'], dims=[1])[:, :3]], dim=1) vw_label_02 = voxel_sem_target(index_02.int(), pw_label.int())[0] vw_label_04 = voxel_sem_target(index_04.int(), pw_label.int())[0] vw_label_08 = voxel_sem_target(index_08.int(), pw_label.int())[0] return dict( - mss_bev_dense = [proj1_bev, proj2_bev, proj3_bev], + mss_bev_dense = [proj1_bev_rwkv, proj2_bev_rwkv, proj3_bev_rwkv], mss_logits_list = [ [vw_label_02.clone(), self.out2(proj1_vw)], [vw_label_04.clone(), self.out4(proj2_vw)], @@ -255,7 +262,7 @@ def forward_once(self, vw_features, coord_ind, full_coord, pw_label, info): ) return dict( - mss_bev_dense = [proj1_bev, proj2_bev, proj3_bev] + mss_bev_dense = [proj1_bev_rwkv, proj2_bev_rwkv, proj3_bev_rwkv] ) def forward(self, data_dict, example): @@ -284,6 +291,49 @@ def forward(self, data_dict, example): out_dict = self.forward_once(data_dict['vw_features'], data_dict['coord_ind'], data_dict['full_coord'], None, data_dict['info']) return out_dict + + + + # def forward(self, data_dict, example): + # if self.phase == 'trainval': + # out_dict = self.forward_once(data_dict['vw_features'], + # data_dict['coord_ind'], data_dict['full_coord'], + # example['points_label'], data_dict['info']) + # all_teach_pair = out_dict['mss_logits_list'] + + # # Define weights for each scale. Adjust these as needed. + # # For example, giving the first scale (0 index) a larger weight. + # scale_weights = [3.0, 1.0, 1.0] + # loss_dict = {} + # for i, teach_pair in enumerate(all_teach_pair): + # voxel_labels_copy = teach_pair[0].long().clone() + # voxel_labels_copy[voxel_labels_copy == 0] = 256 + # voxel_labels_copy = voxel_labels_copy - 1 + + # #res04_loss = lovasz_softmax(F.softmax(teach_pair[1], dim=1), voxel_labels_copy, ignore=255) + # sem_scal_loss_val = sem_scal_loss(teach_pair[1], voxel_labels_copy) # Calculate semantic scale loss + + # # Apply the weight to the scale-specific losses + # #weighted_res04_loss = res04_loss * scale_weights[i] + # weighted_sem_scal_loss_val = sem_scal_loss_val * scale_weights[i] + + # # Update the loss dictionary with weighted losses + # #loss_dict["vw_" + str(i) + "_lovasz_loss"] = weighted_res04_loss + # loss_dict["vw_" + str(i) + "_sem_scal_loss"] = weighted_sem_scal_loss_val + + # # Sum up all computed losses to form the total loss + # total_loss = sum(loss_dict.values()) + # loss_dict["total_loss"] = total_loss + + # return dict( + # mss_bev_dense=out_dict['mss_bev_dense'], + # loss=loss_dict + # ) + # else: + # out_dict = self.forward_once(data_dict['vw_features'], + # data_dict['coord_ind'], data_dict['full_coord'], + # None, data_dict['info']) + # return out_dict def get_class_weights(self): ''' diff --git a/networks/vrwkv.py b/networks/vrwkv.py new file mode 100644 index 0000000..81f070b --- /dev/null +++ b/networks/vrwkv.py @@ -0,0 +1,439 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +from typing import Sequence +import math, os +import logging +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +import torch.utils.checkpoint as cp + + + + + +def resize_pos_embed(pos_embed, + src_shape, + dst_shape, + mode='bicubic', + num_extra_tokens=1): + """Resize pos_embed weights. + + Args: + pos_embed (torch.Tensor): Position embedding weights with shape + [1, L, C]. + src_shape (tuple): The resolution of downsampled origin training + image, in format (H, W). + dst_shape (tuple): The resolution of downsampled new training + image, in format (H, W). + mode (str): Algorithm used for upsampling. Choose one from 'nearest', + 'linear', 'bilinear', 'bicubic' and 'trilinear'. + Defaults to 'bicubic'. + num_extra_tokens (int): The number of extra tokens, such as cls_token. + Defaults to 1. + + Returns: + torch.Tensor: The resized pos_embed of shape [1, L_new, C] + """ + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, \ + f"The length of `pos_embed` ({L}) doesn't match the expected " \ + f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \ + '`img_size` argument.' + extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + dst_weight = F.interpolate( + src_weight, size=dst_shape, align_corners=False, mode=mode) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + + return torch.cat((extra_tokens, dst_weight), dim=1) + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, patch_size, in_chans=3, embed_dim=768): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + x = self.proj(x).flatten(2).transpose(1, 2) + return x + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + + +logger = logging.getLogger(__name__) + +T_MAX = 66000 + + +from torch.utils.cpp_extension import load +wkv_cuda = load(name="wkv", sources=["/home/jmwang/OccRWKV/networks/cuda/wkv_op.cpp", "/home/jmwang/OccRWKV/networks/cuda/wkv_cuda.cu"], + verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) + + +class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + # print(T) + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + + half_mode = (w.dtype == torch.half) + bf_mode = (w.dtype == torch.bfloat16) + ctx.save_for_backward(w, u, k, v) + w = w.float().contiguous() + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + if half_mode: + y = y.half() + elif bf_mode: + y = y.bfloat16() + return y + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda').contiguous() + gu = torch.zeros((B, C), device='cuda').contiguous() + gk = torch.zeros((B, T, C), device='cuda').contiguous() + gv = torch.zeros((B, T, C), device='cuda').contiguous() + half_mode = (w.dtype == torch.half) + bf_mode = (w.dtype == torch.bfloat16) + wkv_cuda.backward(B, T, C, + w.float().contiguous(), + u.float().contiguous(), + k.float().contiguous(), + v.float().contiguous(), + gy.float().contiguous(), + gw, gu, gk, gv) + if half_mode: + gw = torch.sum(gw.half(), dim=0) + gu = torch.sum(gu.half(), dim=0) + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) + elif bf_mode: + gw = torch.sum(gw.bfloat16(), dim=0) + gu = torch.sum(gu.bfloat16(), dim=0) + return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + else: + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw, gu, gk, gv) + + +def RUN_CUDA(B, T, C, w, u, k, v): + return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + + +def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None): + assert gamma <= 1/4 + B, N, C = input.shape + input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1]) + B, C, H, W = input.shape + output = torch.zeros_like(input) + output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel] + output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W] + output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :] + output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :] + output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...] + return output.flatten(2).transpose(1, 2) + + +class VRWKV_SpatialMix(nn.Module): + def __init__(self, n_embd, n_layer, layer_id, shift_mode='q_shift', + channel_gamma=1/4, shift_pixel=1, init_mode='fancy', + key_norm=False, with_cp=False): + super().__init__() + self.layer_id = layer_id + self.n_layer = n_layer + self.n_embd = n_embd + self.device = None + attn_sz = n_embd + self._init_weights(init_mode) + self.shift_pixel = shift_pixel + self.shift_mode = shift_mode + if shift_pixel > 0: + self.shift_func = eval(shift_mode) + self.channel_gamma = channel_gamma + else: + self.spatial_mix_k = None + self.spatial_mix_v = None + self.spatial_mix_r = None + + self.key = nn.Linear(n_embd, attn_sz, bias=False) + self.value = nn.Linear(n_embd, attn_sz, bias=False) + self.receptance = nn.Linear(n_embd, attn_sz, bias=False) + if key_norm: + self.key_norm = nn.LayerNorm(n_embd) + else: + self.key_norm = None + self.output = nn.Linear(attn_sz, n_embd, bias=False) + + self.key.scale_init = 0 + self.receptance.scale_init = 0 + self.output.scale_init = 0 + + self.with_cp = with_cp + + def _init_weights(self, init_mode): + if init_mode=='fancy': + with torch.no_grad(): # fancy init + ratio_0_to_1 = (self.layer_id / (self.n_layer - 1)) # 0 to 1 + ratio_1_to_almost0 = (1.0 - (self.layer_id / self.n_layer)) # 1 to ~0 + + # fancy time_decay + decay_speed = torch.ones(self.n_embd) + for h in range(self.n_embd): + decay_speed[h] = -5 + 8 * (h / (self.n_embd-1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.spatial_decay = nn.Parameter(decay_speed) + + # fancy time_first + zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(self.n_embd)]) * 0.5) + self.spatial_first = nn.Parameter(torch.ones(self.n_embd) * math.log(0.3) + zigzag) + + # fancy time_mix + x = torch.ones(1, 1, self.n_embd) + for i in range(self.n_embd): + x[0, 0, i] = i / self.n_embd + self.spatial_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.spatial_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.spatial_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0)) + elif init_mode=='local': + self.spatial_decay = nn.Parameter(torch.ones(self.n_embd)) + self.spatial_first = nn.Parameter(torch.ones(self.n_embd)) + self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd])) + self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd])) + self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd])) + elif init_mode=='global': + self.spatial_decay = nn.Parameter(torch.zeros(self.n_embd)) + self.spatial_first = nn.Parameter(torch.zeros(self.n_embd)) + self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5) + self.spatial_mix_v = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5) + self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5) + else: + raise NotImplementedError + + def jit_func(self, x, patch_resolution): + # Mix x with the previous timestep to produce xk, xv, xr + B, T, C = x.size() + if self.shift_pixel > 0: + xx = self.shift_func(x, self.shift_pixel, self.channel_gamma, patch_resolution) + xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k) + xv = x * self.spatial_mix_v + xx * (1 - self.spatial_mix_v) + xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r) + else: + xk = x + xv = x + xr = x + + # Use xk, xv, xr to produce k, v, r + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + sr = torch.sigmoid(r) + + return sr, k, v + + def forward(self, x, patch_resolution=None): + + def _inner_forward(x): + + B, T, C = x.size() + self.device = x.device + + sr, k, v = self.jit_func(x, patch_resolution) + x = RUN_CUDA(B, T, C, self.spatial_decay / T, self.spatial_first / T, k, v) + if self.key_norm is not None: + x = self.key_norm(x) + x = sr * x + x = self.output(x) + return x + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class VRWKV_ChannelMix(nn.Module): + def __init__(self, n_embd, n_layer, layer_id, shift_mode='q_shift', + channel_gamma=1/4, shift_pixel=1, hidden_rate=4, init_mode='fancy', + key_norm=False, with_cp=False): + super().__init__() + self.layer_id = layer_id + self.n_layer = n_layer + self.n_embd = n_embd + self.with_cp = with_cp + self._init_weights(init_mode) + self.shift_pixel = shift_pixel + self.shift_mode = shift_mode + if shift_pixel > 0: + self.shift_func = eval(shift_mode) + self.channel_gamma = channel_gamma + else: + self.spatial_mix_k = None + self.spatial_mix_r = None + + hidden_sz = hidden_rate * n_embd + self.key = nn.Linear(n_embd, hidden_sz, bias=False) + if key_norm: + self.key_norm = nn.LayerNorm(hidden_sz) + else: + self.key_norm = None + self.receptance = nn.Linear(n_embd, n_embd, bias=False) + self.value = nn.Linear(hidden_sz, n_embd, bias=False) + + self.value.scale_init = 0 + self.receptance.scale_init = 0 + + def _init_weights(self, init_mode): + if init_mode == 'fancy': + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = (1.0 - (self.layer_id / self.n_layer)) # 1 to ~0 + x = torch.ones(1, 1, self.n_embd) + for i in range(self.n_embd): + x[0, 0, i] = i / self.n_embd + self.spatial_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.spatial_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + elif init_mode == 'local': + self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd])) + self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd])) + elif init_mode == 'global': + self.spatial_mix_k = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5) + self.spatial_mix_r = nn.Parameter(torch.ones([1, 1, self.n_embd]) * 0.5) + else: + raise NotImplementedError + + def forward(self, x, patch_resolution=None): + def _inner_forward(x): + if self.shift_pixel > 0: + xx = self.shift_func(x, self.shift_pixel, self.channel_gamma, patch_resolution) + xk = x * self.spatial_mix_k + xx * (1 - self.spatial_mix_k) + xr = x * self.spatial_mix_r + xx * (1 - self.spatial_mix_r) + else: + xk = x + xr = x + + k = self.key(xk) + k = torch.square(torch.relu(k)) + if self.key_norm is not None: + k = self.key_norm(k) + kv = self.value(k) + x = torch.sigmoid(self.receptance(xr)) * kv + return x + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class Block(nn.Module): + def __init__(self, n_embd, n_layer, layer_id, shift_mode='q_shift', + channel_gamma=1/4, shift_pixel=1, drop_path=0., hidden_rate=4, + init_mode='fancy', init_values=None, post_norm=False, key_norm=False, + with_cp=False): + super().__init__() + self.layer_id = layer_id + self.ln1 = nn.LayerNorm(n_embd) + self.ln2 = nn.LayerNorm(n_embd) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(n_embd) + + self.att = VRWKV_SpatialMix(n_embd, n_layer, layer_id, shift_mode, + channel_gamma, shift_pixel, init_mode, + key_norm=key_norm) + + self.ffn = VRWKV_ChannelMix(n_embd, n_layer, layer_id, shift_mode, + channel_gamma, shift_pixel, hidden_rate, + init_mode, key_norm=key_norm) + self.layer_scale = (init_values is not None) + self.post_norm = post_norm + if self.layer_scale: + self.gamma1 = nn.Parameter(init_values * torch.ones((n_embd)), requires_grad=True) + self.gamma2 = nn.Parameter(init_values * torch.ones((n_embd)), requires_grad=True) + self.with_cp = with_cp + + def forward(self, x, patch_resolution=None): + def _inner_forward(x): + if self.layer_id == 0: + x = self.ln0(x) + if self.post_norm: + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.ln1(self.att(x, patch_resolution))) + x = x + self.drop_path(self.gamma2 * self.ln2(self.ffn(x, patch_resolution))) + else: + x = x + self.drop_path(self.ln1(self.att(x, patch_resolution))) + x = x + self.drop_path(self.ln2(self.ffn(x, patch_resolution))) + else: + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.att(self.ln1(x), patch_resolution)) + x = x + self.drop_path(self.gamma2 * self.ffn(self.ln2(x), patch_resolution)) + else: + x = x + self.drop_path(self.att(self.ln1(x), patch_resolution)) + x = x + self.drop_path(self.ffn(self.ln2(x), patch_resolution)) + return x + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + diff --git a/test.py b/test.py index 6c9d916..f1b495b 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ import sys import numpy as np import time - +os.environ["CUDA_VISIBLE_DEVICES"] = "3" # Append root directory to system path for imports repo_path, _ = os.path.split(os.path.realpath(__file__)) repo_path, _ = os.path.split(repo_path) @@ -21,7 +21,7 @@ def parse_args(): - parser = argparse.ArgumentParser(description='DSC validating') + parser = argparse.ArgumentParser(description='OCCRWKV validating') parser.add_argument( '--weights', dest='weights_file', diff --git a/train.py b/train.py index b012155..530720d 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,6 @@ # -*- coding:utf-8 -*- import os -os.environ["CUDA_VISIBLE_DEVICES"] = "2" +os.environ["CUDA_VISIBLE_DEVICES"] = "1" import argparse import torch import torch.nn as nn @@ -24,11 +24,11 @@ import utils.checkpoint as checkpoint def parse_args(): - parser = argparse.ArgumentParser(description='DSC training') + parser = argparse.ArgumentParser(description='OccRWKV training') parser.add_argument( '--cfg', dest='config_file', - default='cfgs/SSC-RS.yaml', + default='cfgs/OccRWKV.yaml', metavar='FILE', help='path to config file', type=str, @@ -147,7 +147,6 @@ def train(model, optimizer, scheduler, dataset, _cfg, start_epoch, logger, tbwri for scale in metrics.evaluator.keys(): tbwriter.add_scalar('train_performance/{}/mIoU'.format(scale), metrics.get_semantics_mIoU(scale).item(), epoch - 1) tbwriter.add_scalar('train_performance/{}/IoU'.format(scale), metrics.get_occupancy_IoU(scale).item(), epoch - 1) - tbwriter.add_scalar('train_performance/{}/Seg_mIoU'.format(scale), seg_miou, epoch - 1) tbwriter.add_scalar('train_performance/{}/Precision'.format(scale), metrics.get_occupancy_Precision(scale).item(), epoch-1) tbwriter.add_scalar('train_performance/{}/Recall'.format(scale), metrics.get_occupancy_Recall(scale).item(), epoch-1) tbwriter.add_scalar('train_performance/{}/F1'.format(scale), metrics.get_occupancy_F1(scale).item(), epoch-1) @@ -256,10 +255,10 @@ def validate(model, dset, _cfg, epoch, logger, tbwriter, metrics): checkpoint_info = {} - if epoch_loss < _cfg._dict['OUTPUT']['BEST_LOSS']: - logger.info('=> Best loss on validation set encountered: ({} < {})'.format(epoch_loss, _cfg._dict['OUTPUT']['BEST_LOSS'])) - _cfg._dict['OUTPUT']['BEST_LOSS'] = epoch_loss.item() - checkpoint_info['best-loss'] = 'BEST_LOSS' + # if epoch_loss < _cfg._dict['OUTPUT']['BEST_LOSS']: + # logger.info('=> Best loss on validation set encountered: ({} < {})'.format(epoch_loss, _cfg._dict['OUTPUT']['BEST_LOSS'])) + # _cfg._dict['OUTPUT']['BEST_LOSS'] = epoch_loss.item() + # checkpoint_info['best-loss'] = 'BEST_LOSS' mIoU_1_1 = metrics.get_semantics_mIoU('1_1') IoU_1_1 = metrics.get_occupancy_IoU('1_1') @@ -276,9 +275,9 @@ def validate(model, dset, _cfg, epoch, logger, tbwriter, metrics): def main(): # https://github.com/pytorch/pytorch/issues/27588 - torch.backends.cudnn.enabled = False + torch.backends.cudnn.enabled = True - seed_all(10) + seed_all(42) args = parse_args() @@ -304,9 +303,7 @@ def main(): logger.info('=> Loading network architecture...') model = get_model(_cfg._dict, phase='trainval') - if torch.cuda.device_count() > 1: - model = nn.DataParallel(model) - model = model.module + logger.info(f'=> Model Parameters: {sum(p.numel() for p in model.parameters())/1000000.0} M') logger.info('=> Loading optimizer...') diff --git a/validate.py b/validate.py index 88f52c1..d7de92d 100644 --- a/validate.py +++ b/validate.py @@ -3,12 +3,12 @@ import torch import torch.nn as nn import sys - +from thop import profile # Append root directory to system path for imports repo_path, _ = os.path.split(os.path.realpath(__file__)) repo_path, _ = os.path.split(repo_path) sys.path.append(repo_path) - +os.environ["CUDA_VISIBLE_DEVICES"] = "3" from utils.seed import seed_all from utils.config import CFG from utils.dataset import get_dataset @@ -44,6 +44,7 @@ def parse_args(): def validate(model, dset, _cfg, logger, metrics): + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') dtype = torch.float32 # Tensor type to be used @@ -54,7 +55,8 @@ def validate(model, dset, _cfg, logger, metrics): logger.info('=> Passing the network on the validation set...') time_list = [] model.eval() - + flops_list = [] + total_flops = 0 with torch.no_grad(): for t, (data, indices) in enumerate(tqdm(dset, ncols=100)): @@ -62,6 +64,15 @@ def validate(model, dset, _cfg, logger, metrics): start_time = time.time() scores, loss = model(data) time_list.append(time.time() - start_time) + input_data = next(iter(data.values())) + + # 计算FLOPs + flops, params = profile(model, inputs=(data,), verbose=False) + total_flops += flops + current_gflops = flops / 1e9 + logger.info(f'Current batch GFLOPs: {current_gflops}') + + # Updating batch losses to then get mean for epoch loss metrics.losses_track.update_validaiton_losses(loss) @@ -93,6 +104,9 @@ def validate(model, dset, _cfg, logger, metrics): class_score = metrics.evaluator['1_1'].getIoU()[1][i] logger.info(' => IoU {}: {:.6f}'.format(class_name, class_score)) + # 计算总FLOPs并转换为GFLOPs + total_gflops = (total_flops / len(dset.dataset)) / 1e9 # 转换为GFLOPs + logger.info(f'Average GFLOPs per input for validation set: {total_gflops}') return time_list