From f63627fd7cc56b24b7c36725c53591c7b7fb2595 Mon Sep 17 00:00:00 2001 From: jmwang0117 <1021347250@qq.com> Date: Tue, 11 Jun 2024 00:08:51 +0000 Subject: [PATCH] Initial --- README.md | 90 ++++- cfgs/2024.6.11.yaml | 93 +++++ cfgs/DSC-Base.yaml | 88 +++++ data/semantic-kitti.yaml | 211 +++++++++++ datasets/__init__.py | 15 + datasets/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 821 bytes .../semantic_kitti.cpython-310.pyc | Bin 0 -> 6713 bytes datasets/io_data.py | 233 ++++++++++++ datasets/label_downsample.py | 110 ++++++ datasets/semantic_kitti.py | 217 ++++++++++++ networks/__init__.py | 0 networks/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 134 bytes networks/__pycache__/bev_net.cpython-310.pyc | Bin 0 -> 8523 bytes .../__pycache__/completion.cpython-310.pyc | Bin 0 -> 4131 bytes networks/__pycache__/dsc.cpython-310.pyc | Bin 0 -> 5197 bytes .../__pycache__/preprocess.cpython-310.pyc | Bin 0 -> 4872 bytes .../semantic_segmentation.cpython-310.pyc | Bin 0 -> 10031 bytes networks/bev_net.py | 294 ++++++++++++++++ networks/completion.py | 126 +++++++ networks/dsc.py | 138 ++++++++ networks/preprocess.py | 139 ++++++++ networks/semantic_segmentation.py | 295 ++++++++++++++++ scripts/run_train.sh | 1 + test.py | 139 ++++++++ train.py | 332 ++++++++++++++++++ utils/__init__.py | 0 utils/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 131 bytes utils/__pycache__/checkpoint.cpython-310.pyc | Bin 0 -> 2377 bytes utils/__pycache__/config.cpython-310.pyc | Bin 0 -> 2614 bytes utils/__pycache__/dataset.cpython-310.pyc | Bin 0 -> 967 bytes utils/__pycache__/io_tools.cpython-310.pyc | Bin 0 -> 1508 bytes utils/__pycache__/logger.cpython-310.pyc | Bin 0 -> 747 bytes .../__pycache__/lovasz_losses.cpython-310.pyc | Bin 0 -> 10398 bytes utils/__pycache__/metrics.cpython-310.pyc | Bin 0 -> 6844 bytes utils/__pycache__/model.cpython-310.pyc | Bin 0 -> 313 bytes utils/__pycache__/optimizer.cpython-310.pyc | Bin 0 -> 1189 bytes utils/__pycache__/seed.cpython-310.pyc | Bin 0 -> 503 bytes utils/__pycache__/time.cpython-310.pyc | Bin 0 -> 477 bytes utils/checkpoint.py | 98 ++++++ utils/config.py | 95 +++++ utils/dataset.py | 24 ++ utils/io_tools.py | 44 +++ utils/logger.py | 33 ++ utils/lovasz_losses.py | 321 +++++++++++++++++ utils/metrics.py | 201 +++++++++++ utils/model.py | 5 + utils/optimizer.py | 36 ++ utils/seed.py | 15 + utils/time.py | 12 + validate.py | 151 ++++++++ 50 files changed, 3555 insertions(+), 1 deletion(-) create mode 100644 cfgs/2024.6.11.yaml create mode 100644 cfgs/DSC-Base.yaml create mode 100644 data/semantic-kitti.yaml create mode 100644 datasets/__init__.py create mode 100644 datasets/__pycache__/__init__.cpython-310.pyc create mode 100644 datasets/__pycache__/semantic_kitti.cpython-310.pyc create mode 100644 datasets/io_data.py create mode 100644 datasets/label_downsample.py create mode 100644 datasets/semantic_kitti.py create mode 100644 networks/__init__.py create mode 100644 networks/__pycache__/__init__.cpython-310.pyc create mode 100644 networks/__pycache__/bev_net.cpython-310.pyc create mode 100644 networks/__pycache__/completion.cpython-310.pyc create mode 100644 networks/__pycache__/dsc.cpython-310.pyc create mode 100644 networks/__pycache__/preprocess.cpython-310.pyc create mode 100644 networks/__pycache__/semantic_segmentation.cpython-310.pyc create mode 100644 networks/bev_net.py create mode 100644 networks/completion.py create mode 100644 networks/dsc.py create mode 100644 networks/preprocess.py create mode 100644 networks/semantic_segmentation.py create mode 
100644 scripts/run_train.sh create mode 100644 test.py create mode 100644 train.py create mode 100644 utils/__init__.py create mode 100644 utils/__pycache__/__init__.cpython-310.pyc create mode 100644 utils/__pycache__/checkpoint.cpython-310.pyc create mode 100644 utils/__pycache__/config.cpython-310.pyc create mode 100644 utils/__pycache__/dataset.cpython-310.pyc create mode 100644 utils/__pycache__/io_tools.cpython-310.pyc create mode 100644 utils/__pycache__/logger.cpython-310.pyc create mode 100644 utils/__pycache__/lovasz_losses.cpython-310.pyc create mode 100644 utils/__pycache__/metrics.cpython-310.pyc create mode 100644 utils/__pycache__/model.cpython-310.pyc create mode 100644 utils/__pycache__/optimizer.cpython-310.pyc create mode 100644 utils/__pycache__/seed.cpython-310.pyc create mode 100644 utils/__pycache__/time.cpython-310.pyc create mode 100644 utils/checkpoint.py create mode 100644 utils/config.py create mode 100644 utils/dataset.py create mode 100644 utils/io_tools.py create mode 100644 utils/logger.py create mode 100644 utils/lovasz_losses.py create mode 100644 utils/metrics.py create mode 100644 utils/model.py create mode 100644 utils/optimizer.py create mode 100644 utils/seed.py create mode 100644 utils/time.py create mode 100644 validate.py
diff --git a/README.md b/README.md index 9030b40..d82934c 100644 --- a/README.md +++ b/README.md @@ -1 +1,89 @@ -# OccRWKV \ No newline at end of file
+# OccRWKV
+
+## OccRWKV: Rethinking Sparse Latent Representation for 3D Semantic Occupancy Prediction
+
+
+## Preparation
+
+### Prerequisites
+```
+conda create -n occ_rwkv python=3.10 -y
+conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y
+pip install spconv-cu120
+pip install tensorboardX
+pip install dropblock
+pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
+```
+
+### Dataset
+
+Please download the Semantic Scene Completion dataset (v1.1) from the [SemanticKITTI website](http://www.semantic-kitti.org/dataset.html) and extract it.
+
+Alternatively, you can use [voxelizer](https://github.com/jbehley/voxelizer) to generate the semantic scene completion ground truth yourself.
+
+The dataset folder should be organized as follows.
+```
+SemanticKITTI
+├── dataset
+│   ├── sequences
+│   │   ├── 00
+│   │   │   ├── labels
+│   │   │   ├── velodyne
+│   │   │   ├── voxels
+│   │   │   ├── [OTHER FILES OR FOLDERS]
+│   │   ├── 01
+│   │   ├── ... ...
+```
+
+## Getting Started
+Clone the repository:
+```
+$ git clone https://github.com/jmwang0117/OccRWKV.git
+```
+
+We provide training routine examples in the `cfgs` folder. If you want to use them for training, make sure to change the dataset path in these files to the location of your extracted dataset. You can also change the folder where the performance logs and model states will be stored.
+* `config_dict['DATASET']['DATA_ROOT']` should be changed to the root directory of the SemanticKITTI dataset (`/.../SemanticKITTI/dataset/sequences`)
+* `config_dict['OUTPUT']['OUT_ROOT']` should be changed to the desired output folder.
+
+### Training
+
+```
+$ cd <root dir of this repo>
+$ python train.py --cfg cfgs/DSC-Base.yaml --dset_root <path/to/kitti/dataset>
+```
+### Validation
+
+Validation passes are run during the training routine. An additional pass over the validation set with a saved model can be done with `validate.py`. You need to provide the path to the saved model and the dataset root directory.
+
+```
+$ cd <root dir of this repo>
+$ python validate.py --weights <path/to/model.pth> --dset_root <path/to/kitti/dataset>
+```
+### Test
+
+Since SemanticKITTI contains a hidden test set, we provide a test routine that saves the predicted output in the same format as SemanticKITTI, which can be compressed and uploaded to the [SemanticKITTI Semantic Scene Completion Benchmark](http://www.semantic-kitti.org/tasks.html#ssc).
+
+We recommend passing the compressed data through the official checking script provided in the [SemanticKITTI Development Kit](http://www.semantic-kitti.org/resources.html#devkit) to avoid any issues.
+
+You can specify which checkpoint to use for testing; we used the one that performed best on the validation set during training. For testing, use the following command.
+
+```
+$ cd <root dir of this repo>
+$ python test.py --weights <path/to/model.pth> --dset_root <path/to/kitti/dataset> --out_path <path/to/predictions/output>
+```
+### Pretrained Model
+
+You can download the model with the scores below from this [Google drive link](https://drive.google.com/file/d/1-b3O7QS6hBQIGFTO-7qSG7Zb9kbQuxdO/view?usp=sharing).
+
+| Model | Segmentation | Completion |
+|--|--|--|
+| SSC-RS | 24.2 | 59.7 |
+
+* Results reported on the SemanticKITTI Semantic Scene Completion leaderboard ([link](https://codalab.lisn.upsaclay.fr/competitions/7170\#results)).
+
+## Acknowledgement
+This project would not be possible without multiple great open-sourced codebases.
+* [spconv](https://github.com/traveller59/spconv)
+* [LMSCNet](https://github.com/cv-rits/LMSCNet)
+* [SSA-SC](https://github.com/jokester-zzz/SSA-SC)
+* [GASN](https://github.com/ItIsFriday/PcdSeg)
diff --git a/cfgs/2024.6.11.yaml b/cfgs/2024.6.11.yaml new file mode 100644 index 0000000..f49de0f --- /dev/null +++ b/cfgs/2024.6.11.yaml @@ -0,0 +1,93 @@ +DATALOADER: + NUM_WORKERS: 4 +DATASET: + CONFIG_FILE: data/semantic-kitti.yaml + DATA_ROOT: data/SemanticKITTI/dataset/sequences + GRID_METERS: + - 0.2 + - 0.2 + - 0.2 + LIMS: + - - 0 + - 51.2 + - - -25.6 + - 25.6 + - - -2 + - 4.4 + NCLASS: 20 + SC_CLASS_FREQ: + - 7632350044 + - 15783539 + - 125136 + - 118809 + - 646799 + - 821951 + - 262978 + - 283696 + -
204750 + - 61688703 + - 4502961 + - 44883650 + - 2269923 + - 56840218 + - 15719652 + - 158442623 + - 2061623 + - 36970522 + - 1151988 + - 334146 + SIZES: + - 256 + - 256 + - 32 + SS_CLASS_FREQ: + - 55437630 + - 320797 + - 541736 + - 2578735 + - 3274484 + - 552662 + - 184064 + - 78858 + - 240942562 + - 17294618 + - 170599734 + - 6369672 + - 230413074 + - 101130274 + - 476491114 + - 9833174 + - 129609852 + - 4506626 + - 1168181 + TYPE: SemanticKITTI +MODEL: + TYPE: DSC-AFC +OPTIMIZER: + BASE_LR: 0.001 + BETA1: 0.9 + BETA2: 0.999 + MOMENTUM: NA + TYPE: Adam + WEIGHT_DECAY: NA +OUTPUT: + OUT_ROOT: ./outputs/ +SCHEDULER: + FREQUENCY: epoch + LR_POWER: 0.98 + TYPE: power_iteration +STATUS: + RESUME: false +TRAIN: + BATCH_SIZE: 2 + CHECKPOINT_PERIOD: 15 + EPOCHS: 80 + SUMMARY_PERIOD: 50 +VAL: + BATCH_SIZE: 2 + SUMMARY_PERIOD: 20 diff --git a/data/semantic-kitti.yaml b/data/semantic-kitti.yaml new file mode 100644 index 0000000..6281065 --- /dev/null +++ b/data/semantic-kitti.yaml @@ -0,0 +1,211 @@ +# This file is covered by the LICENSE file in the root of this project. +labels: + 0 : "unlabeled" + 1 : "outlier" + 10: "car" + 11: "bicycle" + 13: "bus" + 15: "motorcycle" + 16: "on-rails" + 18: "truck" + 20: "other-vehicle" + 30: "person" + 31: "bicyclist" + 32: "motorcyclist" + 40: "road" + 44: "parking" + 48: "sidewalk" + 49: "other-ground" + 50: "building" + 51: "fence" + 52: "other-structure" + 60: "lane-marking" + 70: "vegetation" + 71: "trunk" + 72: "terrain" + 80: "pole" + 81: "traffic-sign" + 99: "other-object" + 252: "moving-car" + 253: "moving-bicyclist" + 254: "moving-person" + 255: "moving-motorcyclist" + 256: "moving-on-rails" + 257: "moving-bus" + 258: "moving-truck" + 259: "moving-other-vehicle" +color_map: # bgr + 0 : [0, 0, 0] + 1 : [0, 0, 255] + 10: [245, 150, 100] + 11: [245, 230, 100] + 13: [250, 80, 100] + 15: [150, 60, 30] + 16: [255, 0, 0] + 18: [180, 30, 80] + 20: [255, 0, 0] + 30: [30, 30, 255] + 31: [200, 40, 255] + 32: [90, 30, 150] + 40: [255, 0, 255] + 44: [255, 150, 255] + 48: [75, 0, 75] + 49: [75, 0, 175] + 50: [0, 200, 255] + 51: [50, 120, 255] + 52: [0, 150, 255] + 60: [170, 255, 150] + 70: [0, 175, 0] + 71: [0, 60, 135] + 72: [80, 240, 150] + 80: [150, 240, 255] + 81: [0, 0, 255] + 99: [255, 255, 50] + 252: [245, 150, 100] + 256: [255, 0, 0] + 253: [200, 40, 255] + 254: [30, 30, 255] + 255: [90, 30, 150] + 257: [250, 80, 100] + 258: [180, 30, 80] + 259: [255, 0, 0] +content: # as a ratio with the total number of points + 0: 0.018889854628292943 + 1: 0.0002937197336781505 + 10: 0.040818519255974316 + 11: 0.00016609538710764618 + 13: 2.7879693665067774e-05 + 15: 0.00039838616015114444 + 16: 0.0 + 18: 0.0020633612104619787 + 20: 0.0016218197275284021 + 30: 0.00017698551338515307 + 31: 1.1065903904919655e-08 + 32: 5.532951952459828e-09 + 40: 0.1987493871255525 + 44: 0.014717169549888214 + 48: 0.14392298360372 + 49: 0.0039048553037472045 + 50: 0.1326861944777486 + 51: 0.0723592229456223 + 52: 0.002395131480328884 + 60: 4.7084144280367186e-05 + 70: 0.26681502148037506 + 71: 0.006035012012626033 + 72: 0.07814222006271769 + 80: 0.002855498193863172 + 81: 0.0006155958086189918 + 99: 0.009923127583046915 + 252: 0.001789309418528068 + 253: 0.00012709999297008662 + 254: 0.00016059776092534436 + 255: 3.745553104802113e-05 + 256: 0.0 + 257: 0.00011351574470342043 + 258: 0.00010157861367183268 + 259: 4.3840131989471124e-05 +# classes that are indistinguishable from single scan or inconsistent in +# ground truth are mapped to their closest equivalent +learning_map: + 0 : 0 # 
"unlabeled" + 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped + 10: 1 # "car" + 11: 2 # "bicycle" + 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped + 15: 3 # "motorcycle" + 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped + 18: 4 # "truck" + 20: 5 # "other-vehicle" + 30: 6 # "person" + 31: 7 # "bicyclist" + 32: 8 # "motorcyclist" + 40: 9 # "road" + 44: 10 # "parking" + 48: 11 # "sidewalk" + 49: 12 # "other-ground" + 50: 13 # "building" + 51: 14 # "fence" + 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped + 60: 9 # "lane-marking" to "road" ---------------------------------mapped + 70: 15 # "vegetation" + 71: 16 # "trunk" + 72: 17 # "terrain" + 80: 18 # "pole" + 81: 19 # "traffic-sign" + 99: 0 # "other-object" to "unlabeled" ----------------------------mapped + 252: 1 # "moving-car" to "car" ------------------------------------mapped + 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped + 254: 6 # "moving-person" to "person" ------------------------------mapped + 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped + 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped + 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped + 258: 4 # "moving-truck" to "truck" --------------------------------mapped + 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped +learning_map_inv: # inverse of previous map + 0: 0 # "unlabeled", and others ignored + 1: 10 # "car" + 2: 11 # "bicycle" + 3: 15 # "motorcycle" + 4: 18 # "truck" + 5: 20 # "other-vehicle" + 6: 30 # "person" + 7: 31 # "bicyclist" + 8: 32 # "motorcyclist" + 9: 40 # "road" + 10: 44 # "parking" + 11: 48 # "sidewalk" + 12: 49 # "other-ground" + 13: 50 # "building" + 14: 51 # "fence" + 15: 70 # "vegetation" + 16: 71 # "trunk" + 17: 72 # "terrain" + 18: 80 # "pole" + 19: 81 # "traffic-sign" +learning_ignore: # Ignore classes + 0: True # "unlabeled", and others ignored + 1: False # "car" + 2: False # "bicycle" + 3: False # "motorcycle" + 4: False # "truck" + 5: False # "other-vehicle" + 6: False # "person" + 7: False # "bicyclist" + 8: False # "motorcyclist" + 9: False # "road" + 10: False # "parking" + 11: False # "sidewalk" + 12: False # "other-ground" + 13: False # "building" + 14: False # "fence" + 15: False # "vegetation" + 16: False # "trunk" + 17: False # "terrain" + 18: False # "pole" + 19: False # "traffic-sign" +split: # sequence numbers + train: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 9 + - 10 + valid: + - 8 + test: + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..d605eef --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,15 @@ + +import torch +from .semantic_kitti import SemanticKitti + + +def collate_fn(data): + keys = data[0][0].keys() + out_dict = {} + for key in keys: + if key in ['points', 'points_label']: + out_dict[key] = [d[0][key] for d in data] + else: + out_dict[key] = torch.stack([d[0][key] for d in data]) + idx = [d[1] for d in data] + return out_dict, idx \ No newline at end of file diff --git a/datasets/__pycache__/__init__.cpython-310.pyc b/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b72338fe34b2076941b92ec1821ab4da61ef49b GIT binary patch literal 821 zcmZ`%!D`z;5S>}cj&N-^p()rGldCRvKfstCatowTLklWFs9M=kEm_X)CT)rgfd-l$ 
z^9RY%zqHq$a_+6C&d7C>ly;Yyo!yyv`(~J;)oK8(ch}R&1p<8M%flt;@{&aFk&+-W zg8-zEPKLP%1cG#B?Gge<;u!q;9Yw+69U*~_wlz-jXl~N@AT^}Ld<{UXUsY zyIaW(+`3XCjE})}g9s3R_D=g`xjw@zdY@#E*m4(`@iUSi9a2-mJMxxTqJ%E- z>%z4-A=DLCpetAjEBY@i$`x8?mYl0C%KGUKj*C1Qe8@jW^XcGK9KU{h@Ma(*6Y0e0 zK^Ug<)P!Mwac&#KEY&71^2O*0|H>3P@VOTkSr(ZjoXq>ftcasbkI2}B{+G3?QNH4i zS4@=076Y%ZEwaTC=U)S!VVv-N=x#qv{`47!uyWl(4rOs$u zGlhzeZB3gfp82kIxglzsf&#BNHK9ynW1UofT4PnLnWt^sH-jtFoU)BBdbB&(6Ky)x tb9(e$wL?eH*1rcbobf~|p1vwzjr^Fe^@e#NPqSpibE5eN51ZI@egRqTo=X4# literal 0 HcmV?d00001 diff --git a/datasets/__pycache__/semantic_kitti.cpython-310.pyc b/datasets/__pycache__/semantic_kitti.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77ad94c98ceaae8433f291a5ce5f47dfcef2205b GIT binary patch literal 6713 zcmai2-H#hr6`wm_9*;lvBb#J*6N+i7g41TxCZP=>G-=bOgjNeQ{Qx60n)r^r;~md9 zcgE?iV@pVBA0xpdASCY_5fc9a5`O^V2?=Q)%L7{a(g%>BLJGfg$M)LZU14j^+;i`5 z&b{~Cdp_=sN+nyt?J~CrZua1~>7}@f^4Cws@J_yzmLD<=<9#k(WMEc*$ceyH#ivA8EA`SMMs- zg?;p zoVFjzm%EedZ8zQ)Bu0`iFJ9b=+TO)h`(rn3Ui^N&{@zFLe0Y&QjJ+hj7<+9uO#HgD z?I(%9(b-LNB*2L}7L9=9EHa%5(qml(8UIWa3_Nn6LlhDW73xtwRXJ;@YZx|XYfvd3 z_4Jt2h+9=2X#;f%uZ{53dVnYH0IyBq^$}j5!4r3Y*ZXQ;yQ%bbD9z{_P`j}y7{=MW zaO>>npS`kvW1qNvno~8KU)4<2PCeBFza7_1sfL}=dw2Bi2~x{}dPN>=YBjyYA$$>B zOsaZ`(Z5KVTL4?9uEsy=aiTdJEs4@%eWeAxVFv07@fz12s{_<>o$H@Mo7}jea1--k zRKwLzwd*G!y- zeyW;S#Vgetf!|5>deqrP6{#@IQ#7+Dtbc7r7sO`{a z&q6@_JctDYt}wm4{7=1T{Y%Fh>iOmW=sD|4st2c?eF_sl2WTD7%#RlHCLS4EJcVKj z;s~p0c{Q){RT)88O&>)NV-!J5p69lVCVGgxT>#9EG)f_K6=Swc_L>;&c(Yyq4>@V%ra!H7oYH_puvn_llXI zrii*BuUYrHe!zi9jRs5_T|e;PGo|8d)OEOB%O#!=E&`d}i2|=zOoZEL`1K31-wbcp z$&$xEDo@PhPhk3Rl#Y_Hp_*uY==hPkM>U9=Ex1aAB7_C-MqdLpWo=fIlVt=B)LXJ_ z$Z}qmO4Sq=ZX*tM40KMi_@Ir zhY!+xJqjX0jlH$+yjJ;vD`J2vldW*){rB%w(tK2}cROxa-|baCkbsAS3Ry;6X^5y@ ziM{(>TAsM_aL}{v#1$CGwMzA5nx}lh5mA(+`mWm!(tPYTJSV{7r+U=!LhMFD)T6NB zH_$u4l zd~!#?8>b}-xd*jc;DxCjc45Cj#!hNBC_E%pE7fB!XowATRa=}Iqy&n~<6(#^RQe{7 zSBOwV$Ve999gM_J9OpXf1c$%6$Skz1;NLqnSpb<`jfKpl z1LP+t2;CydI+6qI2&6&VnpouOo61jf1J8WgH;6Bc zS~1v5f_8VWG(pI^H|J1S#9N0@SH$f()K&5JA=FjzU8=LsZ;a>CGtO_^^TTQ}Gx1l^ zEG`it15P!#5OI}CW*mtGQ-|FtsfwIWCpVMot;i3h(?CA)9r?+&x0`Cv!klXm-$Lt0 zc;at>*s2A8pgy%Vo0Zivv)Os|{OqH>sOq2h&dhp-8E-a|6Vma4@>b~{Dj23uhd}8f zxcZn45S)ezGChQ+76Ra&0*BBVm|W|d50v-}*Xb>uLtr$ANH+H104OAdjOPZq3?*(R zMvKBC>r=?XDFi$m0!|=SA4gl{eK!yU5y}QuxR6*!p!0-&#DKRlm%{Uqlv^i9QsB9% zvYnM18XOg&i;z~5Wu;XXNwUzl;KtB%vN&AoTL|PBIUL^-jgJ?y@eP;haJ{Dv@^UOh zOLDw+-yYh$40+_(PVfb|(LWsea&b?E1f)@rP>>Rcu1E>Rei0H(NI;37!0EAw&|k#d zi;#0tq(Gh;o=%M6D$n=r)*8)p53Y7l>f8Mia0fl7{i3X)FW1Ev$r&I}!tBZ@m1nSD z082H5EsHOq2b)qEzVcWb*isiZ>4Kk>&RNdv=S-AkOHsC*LJKYjs;4*KN2V;#DuG9- zZE8p>MIM}akCZw4Pw|r}wphaElIeApR!bjn2X!^;ffq!4H}tBi_<%sCh}&%5(IyXi3cutS3QC&)$ygzIrwb+5+im#|gTLAfrrm*i|gPM^edvvV7j6 zbc7;#YSP*3$YkRa3izp+<^O2`lH1)ZHHCE4Ha)RUBcp6Zp606yBbVUw9q~2P))q#I zMG$y((s->H*PY2k!Nmt$s$sy$dg?9`>^x;(4t-rKjFd(GM|?<9m!{R{$a-=#$jTE3 zUDpcG5Rztj^DyU?vMULa+r=00dmuU-lf}yHdC+w{7E^!rjJknZ6K-meokHSbGaDyW zdQ-cp5{}9y5}a|V498Z499^AHZ^(&Jmow=N-F(m^{TCg(vA>b#3JpcvyL=n3Crs*d zehN9qL00EDzXc>F|7^0}$yv8KLLBOl(6U3ii=m677BZW0r#JDw?U#;!#WEuC5sg@4PVhtoU zU{aSa(LGm;69y8Bt`}GJYz=-$0Nj1;P`Jxaw(y*e#E$^`J)W41qO7lIXPK_k9>K0S zBD789%&Zh`fr*yh05hW{CbfhMkSQ&xGg=yRV=d9$s@8Mt%IG)gwUR^?$y7#*O4mg3 z=$KrQH=h-p7b-Y%BFhh(8yg!FA=iYLbVYb=QgPrltCol`>TD;?>R5EB;XV-oGD;#u zOce={xjnZ;wgdNk;3obERBVBiO{rbA_tn|Gf9Q*u9Y_{HeMl=fi$bCdae{y)s0ey& z!fAWZT4s#8cn&z)Hl2er8hk&cKY~2S{GhEQ!if*(&Of73UKJzdIph__phLr><(I=- zM1i~lnJHx-YZL#VX2(W4F^&h)KzV*3Yov=E+~f#c0A+^j#0k!c&A$LGt#^_2zJ6GL zM}9LiP#WY!j})M2pu$dAOC4$m&(J1PDcCg9Fdg~@-^W!N?ZoW5Sw_vOk;NrxgLOzj 
zB*JwQWuZQ}@%E5&qku9$!Ghypihkm?9p|rTm&#s+qwcNEPHGCbSygY}e!^~7%ch^^A4d(rnFIw#t@S)Fe_o1gct=S-5cM z`*(?41W64Ux-<7h?p%i8EFU;0+q;P$#Mz}^?oEUr+&E-Hj~G{KUC-wK8pR}eH(SIl QpM|ttQr)ub^7@JY0XV)Yp8x;= literal 0 HcmV?d00001 diff --git a/datasets/io_data.py b/datasets/io_data.py new file mode 100644 index 0000000..8995645 --- /dev/null +++ b/datasets/io_data.py @@ -0,0 +1,233 @@ +import numpy as np +import yaml +import imageio + + +def unpack(compressed): + ''' given a bit encoded voxel grid, make a normal voxel grid out of it. ''' + uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8) + uncompressed[::8] = compressed[:] >> 7 & 1 + uncompressed[1::8] = compressed[:] >> 6 & 1 + uncompressed[2::8] = compressed[:] >> 5 & 1 + uncompressed[3::8] = compressed[:] >> 4 & 1 + uncompressed[4::8] = compressed[:] >> 3 & 1 + uncompressed[5::8] = compressed[:] >> 2 & 1 + uncompressed[6::8] = compressed[:] >> 1 & 1 + uncompressed[7::8] = compressed[:] & 1 + + return uncompressed + + +def img_normalize(img, mean, std): + img = img.astype(np.float32) / 255.0 + img = img - mean + img = img / std + + return img + + +def pack(array): + """ convert a boolean array into a bitwise array. """ + array = array.reshape((-1)) + + #compressing bit flags. + # yapf: disable + compressed = array[::8] << 7 | array[1::8] << 6 | array[2::8] << 5 | array[3::8] << 4 | array[4::8] << 3 | array[5::8] << 2 | array[6::8] << 1 | array[7::8] + # yapf: enable + + return np.array(compressed, dtype=np.uint8) + + +def get_grid_coords(dims, resolution): + ''' + :param dims: the dimensions of the grid [x, y, z] (i.e. [256, 256, 32]) + :return coords_grid: is the center coords of voxels in the grid + ''' + + # The sensor in centered in X (we go to dims/2 + 1 for the histogramdd) + g_xx = np.arange(-dims[0] / 2, dims[0] / 2 + 1) + # The sensor is in Y=0 (we go to dims + 1 for the histogramdd) + g_yy = np.arange(0, dims[1] + 1) + # The sensor is in Z=1.73. I observed that the ground was to voxel levels above the grid bottom, so Z pose is at 10 + # if bottom voxel is 0. If we want the sensor to be at (0, 0, 0), then the bottom in z is -10, top is 22 + # (we go to 22 + 1 for the histogramdd) + # ATTENTION.. Is 11 for old grids.. 10 for new grids (v1.1) (https://github.com/PRBonn/semantic-kitti-api/issues/49) + sensor_pose = 10 + g_zz = np.arange(0 - sensor_pose, dims[2] - sensor_pose + 1) + + # Obtaining the grid with coords... + xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1]) + coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T + coords_grid = coords_grid.astype(np.float) + + coords_grid = (coords_grid * resolution) + resolution / 2 + + temp = np.copy(coords_grid) + temp[:, 0] = coords_grid[:, 1] + temp[:, 1] = coords_grid[:, 0] + coords_grid = np.copy(temp) + + return coords_grid, g_xx, g_yy, g_zz + + +def _get_remap_lut(config_path): + ''' + remap_lut to remap classes of semantic kitti for training... + :return: + ''' + + dataset_config = yaml.safe_load(open(config_path, 'r')) + # make lookup table for mapping + maxkey = max(dataset_config['learning_map'].keys()) + + # +100 hack making lut bigger just in case there are unknown labels + remap_lut = np.zeros((maxkey + 100), dtype=np.int32) + remap_lut[list(dataset_config['learning_map'].keys())] = list(dataset_config['learning_map'].values()) + + # in completion we have to distinguish empty and invalid voxels. + # Important: For voxels 0 corresponds to "empty" and not "unlabeled". 
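+    # A worked sketch of the two fix-ups below, using values from the learning_map
+    # in data/semantic-kitti.yaml above: raw labels [0, 1, 10, 252] first map to
+    # train ids [0, 0, 1, 1]; the next line turns every entry that landed on 0
+    # into 255 ('invalid'), and the line after restores index 0 itself, so the
+    # final lookup gives [0, 255, 1, 1], i.e. 'empty', 'invalid', 'car', 'car'.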
+    remap_lut[remap_lut == 0] = 255  # map 0 to 'invalid'
+    remap_lut[0] = 0  # only 'empty' stays 'empty'.
+
+    return remap_lut
+
+
+def _read_SemKITTI(path, dtype, do_unpack):
+    bin = np.fromfile(path, dtype=dtype)  # Flattened array
+    if do_unpack:
+        bin = unpack(bin)
+    return bin
+
+
+def _read_label_SemKITTI(path):
+    label = _read_SemKITTI(path, dtype=np.uint16, do_unpack=False).astype(np.float32)
+    return label
+
+
+def _read_label2_SemKITTI(path):
+    point_label = _read_SemKITTI(path, dtype=np.uint32, do_unpack=False).reshape((-1, 1))
+    point_label = point_label & 0xFFFF  # semantic label is in the lower 16 bits
+    return point_label
+
+
+def _read_invalid_SemKITTI(path):
+    invalid = _read_SemKITTI(path, dtype=np.uint8, do_unpack=True)
+    return invalid
+
+
+def _read_occluded_SemKITTI(path):
+    occluded = _read_SemKITTI(path, dtype=np.uint8, do_unpack=True)
+    return occluded
+
+
+def _read_occupancy_SemKITTI(path):
+    occupancy = _read_SemKITTI(path, dtype=np.uint8, do_unpack=True).astype(np.float32)
+    return occupancy
+
+
+def _read_point_SemKITTI(path):
+    point = _read_SemKITTI(path, dtype=np.float32, do_unpack=False).reshape(-1, 4)
+    return point
+
+
+def _read_rgb_SemKITTI(path):
+    rgb = np.asarray(imageio.imread(path))
+    return rgb
+
+
+def _read_pointcloud_SemKITTI(path):
+    'Return the SemanticKITTI pointcloud with remissions (x, y, z, intensity)'
+    pointcloud = _read_SemKITTI(path, dtype=np.float32, do_unpack=False)
+    pointcloud = pointcloud.reshape((-1, 4))
+    return pointcloud
+
+
+def _read_calib_SemKITTI(calib_path):
+    """
+    :param calib_path: Path to a calibration text file.
+    :return: dict with calibration matrices.
+    """
+    calib_all = {}
+    with open(calib_path, 'r') as f:
+        for line in f.readlines():
+            if line == '\n':
+                break
+            key, value = line.split(':', 1)
+            calib_all[key] = np.array([float(x) for x in value.split()])
+
+    # reshape matrices
+    calib_out = {}
+    calib_out['P2'] = calib_all['P2'].reshape(3, 4)  # 3x4 projection matrix for the left camera
+    calib_out['Tr'] = np.identity(4)  # 4x4 matrix
+    calib_out['Tr'][:3, :4] = calib_all['Tr'].reshape(3, 4)
+    return calib_out
+
+
+def get_remap_lut(path):
+    '''
+    remap_lut to remap classes of semantic kitti for training...
+    :return:
+    '''
+
+    dataset_config = yaml.safe_load(open(path, 'r'))
+
+    # make lookup table for mapping
+    maxkey = max(dataset_config['learning_map'].keys())
+
+    # +100 hack making lut bigger just in case there are unknown labels
+    remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
+    remap_lut[list(dataset_config['learning_map'].keys())] = list(dataset_config['learning_map'].values())
+
+    # in completion we have to distinguish empty and invalid voxels.
+    # Important: For voxels 0 corresponds to "empty" and not "unlabeled".
+    remap_lut[remap_lut == 0] = 255  # map 0 to 'invalid'
+    remap_lut[0] = 0  # only 'empty' stays 'empty'.
+
+    return remap_lut
+
+
+def data_augmentation_3Dflips(flip, data):
+    # The .copy() is done to avoid negative strides of the numpy array caused by the way numpy manages the data
+    # in memory. Negative strides give errors when trying to pass the array to torch tensors. Solution seen in:
+    # https://discuss.pytorch.org/t/torch-from-numpy-not-support-negative-strides/3663
+    # Dims -> {XZY}
+    # Flipping around the X axis...
+    if np.isclose(flip, 1):
+        data = np.flip(data, axis=0).copy()
+
+    # Flipping around the Y axis...
+    if np.isclose(flip, 2):
+        data = np.flip(data, 2).copy()
+
+    # Flipping around the X and the Y axis...
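+    # (Flip encoding, inferred from the branches in this function: flip == 1 flips
+    # X (axis 0), flip == 2 flips Y (axis 2, since dims are ordered {XZY} here),
+    # flip == 3 applies both, and any other value leaves the data untouched.)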
+ if np.isclose(flip, 3): + data = np.flip(np.flip(data, axis=0), axis=2).copy() + + return data + + +def get_cmap_semanticKITTI20(): + colors = np.array([ + # [0 , 0 , 0, 255], + [100, 150, 245, 255], + [100, 230, 245, 255], + [30, 60, 150, 255], + [80, 30, 180, 255], + [100, 80, 250, 255], + [255, 30, 30, 255], + [255, 40, 200, 255], + [150, 30, 90, 255], + [255, 0, 255, 255], + [255, 150, 255, 255], + [75, 0, 75, 255], + [175, 0, 75, 255], + [255, 200, 0, 255], + [255, 120, 50, 255], + [0, 175, 0, 255], + [135, 60, 0, 255], + [150, 240, 80, 255], + [255, 240, 150, 255], + [255, 0, 0, 255] + ]).astype(np.uint8) + + return colors diff --git a/datasets/label_downsample.py b/datasets/label_downsample.py new file mode 100644 index 0000000..5b097b1 --- /dev/null +++ b/datasets/label_downsample.py @@ -0,0 +1,110 @@ +from glob import glob +import os +import numpy as np +import time +import argparse +import sys + +# Append root directory to system path for imports +repo_path, _ = os.path.split(os.path.realpath(__file__)) +repo_path, _ = os.path.split(repo_path) +repo_path, _ = os.path.split(repo_path) +sys.path.append(repo_path) + +import datasets.io_data as SemanticKittiIO + + +def parse_args(): + parser = argparse.ArgumentParser(description='LMSCNet labels lower scales creation') + parser.add_argument( + '--dset_root', + dest='dataset_root', + default='', + metavar='DATASET', + help='path to dataset root folder', + type=str, + ) + args = parser.parse_args() + return args + + +def majority_pooling(grid, k_size=2): + result = np.zeros((grid.shape[0] // k_size, grid.shape[1] // k_size, grid.shape[2] // k_size)) + for xx in range(0, int(np.floor(grid.shape[0]/k_size))): + for yy in range(0, int(np.floor(grid.shape[1]/k_size))): + for zz in range(0, int(np.floor(grid.shape[2]/k_size))): + + sub_m = grid[(xx*k_size):(xx*k_size)+k_size, (yy*k_size):(yy*k_size)+k_size, (zz*k_size):(zz*k_size)+k_size] + unique, counts = np.unique(sub_m, return_counts=True) + if True in ((unique != 0) & (unique != 255)): + # Remove counts with 0 and 255 + counts = counts[((unique != 0) & (unique != 255))] + unique = unique[((unique != 0) & (unique != 255))] + else: + if True in (unique == 0): + counts = counts[(unique != 255)] + unique = unique[(unique != 255)] + value = unique[np.argmax(counts)] + result[xx, yy, zz] = value + return result + + +def downscale_data(LABEL, downscaling): + # Majority pooling labels downscaled in 3D + LABEL = majority_pooling(LABEL, k_size=downscaling) + # Reshape to 1D + LABEL = LABEL.reshape(-1) + # Invalid file downscaled + INVALID = np.zeros_like(LABEL) + INVALID[np.isclose(LABEL, 255)] = 1 + return LABEL, INVALID + + +def main(): + + args = parse_args() + + dset_root = args.dataset_root + remap_lut = SemanticKittiIO.get_remap_lut(os.path.join('data', 'semantic-kitti.yaml')) + sequences = sorted(glob(os.path.join(dset_root, 'dataset', 'sequences', '*'))) + # Selecting training/validation set sequences only (labels unavailable for test set) + sequences = sequences[:11] + grid_dimensions = [256, 256, 32] # [W, H, D] + + assert len(sequences) > 0, 'Error, no sequences on selected dataset root path' + + for sequence in sequences: + label_paths = sorted(glob(os.path.join(sequence, 'voxels', '*.label'))) + invalid_paths = sorted(glob(os.path.join(sequence, 'voxels', '*.invalid'))) + out_dir = os.path.join(sequence, 'voxels') + downscaling = {'1_2': 2, '1_4': 4, '1_8': 8} + + for i in range(len(label_paths)): + filename, _ = os.path.splitext(os.path.basename(label_paths[i])) + + LABEL = 
SemanticKittiIO._read_label_SemKITTI(label_paths[i])
+            INVALID = SemanticKittiIO._read_invalid_SemKITTI(invalid_paths[i])
+            LABEL = remap_lut[LABEL.astype(np.uint16)].astype(np.float32)  # Remap to the 20 SemanticKITTI SSC classes
+            LABEL[np.isclose(INVALID, 1)] = 255  # Set every voxel marked on the invalid mask to unknown
+            LABEL = LABEL.reshape(grid_dimensions)
+
+            for scale in downscaling:
+                label_filename = os.path.join(out_dir, filename + '.label_' + scale)
+                invalid_filename = os.path.join(out_dir, filename + '.invalid_' + scale)
+                # If the files have not been created yet...
+                if not (os.path.isfile(label_filename) & os.path.isfile(invalid_filename)):
+                    LABEL_ds, INVALID_ds = downscale_data(LABEL, downscaling[scale])
+                    SemanticKittiIO.pack(INVALID_ds.astype(dtype=np.uint8)).tofile(invalid_filename)
+                    print(time.strftime('%x %X') + ' -- => File {} - Sequence {} saved...'.format(filename + '.invalid_' + scale, os.path.basename(sequence)))
+                    LABEL_ds.astype(np.uint16).tofile(label_filename)
+                    print(time.strftime('%x %X') + ' -- => File {} - Sequence {} saved...'.format(filename + '.label_' + scale, os.path.basename(sequence)))
+
+        print(time.strftime('%x %X') + ' -- => All files saved for Sequence {}'.format(os.path.basename(sequence)))
+
+    print(time.strftime('%x %X') + ' -- => All files saved')
+
+    exit()
+
+
+if __name__ == '__main__':
+    main()
+
diff --git a/datasets/semantic_kitti.py b/datasets/semantic_kitti.py new file mode 100644 index 0000000..07be2e3 --- /dev/null +++ b/datasets/semantic_kitti.py @@ -0,0 +1,217 @@
+from glob import glob
+import torch
+
+import os
+import yaml
+import numpy as np
+
+def mask_op(data, x_min, x_max):
+    mask = (data > x_min) & (data < x_max)
+    return mask
+
+def get_mask(pc, lims):
+    mask_x = mask_op(pc[:, 0], lims[0][0] + 0.0001, lims[0][1] - 0.0001)
+    mask_y = mask_op(pc[:, 1], lims[1][0] + 0.0001, lims[1][1] - 0.0001)
+    mask_z = mask_op(pc[:, 2], lims[2][0] + 0.0001, lims[2][1] - 0.0001)
+    mask = mask_x & mask_y & mask_z
+    return mask
+
+
+def unpack(compressed):
+    ''' given a bit encoded voxel grid, make a normal voxel grid out of it.
''' + uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8) + uncompressed[::8] = compressed[:] >> 7 & 1 + uncompressed[1::8] = compressed[:] >> 6 & 1 + uncompressed[2::8] = compressed[:] >> 5 & 1 + uncompressed[3::8] = compressed[:] >> 4 & 1 + uncompressed[4::8] = compressed[:] >> 3 & 1 + uncompressed[5::8] = compressed[:] >> 2 & 1 + uncompressed[6::8] = compressed[:] >> 1 & 1 + uncompressed[7::8] = compressed[:] & 1 + + return uncompressed + +def augmentation_random_flip(data, flip_type, is_scan=False): + if flip_type==1: + if is_scan: + data[:, 0] = 51.2 - data[:, 0] + else: + data = np.flip(data, axis=0).copy() + elif flip_type==2: + if is_scan: + data[:, 1] = -data[:, 1] + else: + data = np.flip(data, axis=1).copy() + elif flip_type==3: + if is_scan: + data[:, 0] = 51.2 - data[:, 0] + data[:, 1] = -data[:, 1] + else: + data = np.flip(np.flip(data, axis=0), axis=1).copy() + return data + +class SemanticKitti(torch.utils.data.Dataset): + CLASSES = ('unlabeled', + 'car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', + 'person', 'bicyclist', 'motorcyclist', 'road', + 'parking', 'sidewalk', 'other-ground', 'building', 'fence', + 'vegetation', 'trunk', 'terrain', 'pole', 'traffic-sign') + + def __init__(self, data_root, data_config_file, setname, + lims, + sizes, + augmentation=False, + shuffle_index=False): + self.data_root = data_root + self.data_config = yaml.safe_load(open(data_config_file, 'r')) + self.sequences = self.data_config["split"][setname] + self.setname = setname + self.labels = self.data_config['labels'] + self.learning_map = self.data_config["learning_map"] + + self.learning_map_inv = self.data_config["learning_map_inv"] + self.color_map = self.data_config['color_map'] + + self.lims = lims + self.sizes = sizes + self.augmentation = augmentation + self.shuffle_index = shuffle_index + + self.filepaths = {} + print(f"=> Parsing SemanticKITTI {self.setname}") + self.get_filepaths() + self.num_files_ = len(self.filepaths['occupancy']) + print("Using {} scans from sequences {}".format(self.num_files_, self.sequences)) + print(f"Is aug: {self.augmentation}") + + def get_filepaths(self,): + # fill in with names, checking that all sequences are complete + if self.setname != 'test': + for key in ['label_1_1', 'invalid_1_1', 'label_1_2', 'invalid_1_2', 'label_1_4', 'invalid_1_4', 'label_1_8', 'invalid_1_8', 'occluded', 'occupancy']: + self.filepaths[key] = [] + else: + self.filepaths['occupancy'] = [] + for seq in self.sequences: + # to string + seq = '{0:02d}'.format(int(seq)) + print("parsing seq {}".format(seq)) + if self.setname != 'test': + # Scale 1_1 + self.filepaths['label_1_1'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.label'))) + self.filepaths['invalid_1_1'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.invalid'))) + # Scale 1_2 + self.filepaths['label_1_2'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.label_1_2'))) + self.filepaths['invalid_1_2'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.invalid_1_2'))) + # Scale 1_4 + self.filepaths['label_1_4'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.label_1_4'))) + self.filepaths['invalid_1_4'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.invalid_1_4'))) + # Scale 1_4 + self.filepaths['label_1_8'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.label_1_8'))) + self.filepaths['invalid_1_8'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.invalid_1_8'))) + + # occluded + 
self.filepaths['occluded'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.occluded'))) + + self.filepaths['occupancy'] += sorted(glob(os.path.join(self.data_root, seq, 'voxels', '*.bin'))) + + def get_data(self, idx, flip_type): + data_collection = {} + sc_remap_lut = self.get_remap_lut(completion=True) + ss_remap_lut = self.get_remap_lut() + for typ in self.filepaths.keys(): + scale = int(typ.split('_')[-1]) if 'label' in typ or 'invalid' in typ else 1 + if 'label' in typ: + scan_data = np.fromfile(self.filepaths[typ][idx], dtype=np.uint16) + if scale == 1: + scan_data = sc_remap_lut[scan_data] + else: + scan_data = unpack(np.fromfile(self.filepaths[typ][idx], dtype=np.uint8)) + scan_data = scan_data.reshape((self.sizes[0]//scale, self.sizes[1]//scale, self.sizes[2]//scale)) + scan_data = scan_data.astype(np.float32) + if self.augmentation: + scan_data = augmentation_random_flip(scan_data, flip_type) + data_collection[typ] = torch.from_numpy(scan_data) + + points_path = self.filepaths['occupancy'][idx].replace('voxels', 'velodyne') + points = np.fromfile(points_path, dtype=np.float32) + points = points.reshape((-1, 4)) + + if self.setname != 'test': + points_label_path = self.filepaths['occupancy'][idx].replace('voxels', 'labels').replace('.bin', '.label') + points_label = np.fromfile(points_label_path, dtype=np.uint32) + points_label = points_label.reshape((-1)) + points_label = points_label & 0xFFFF # semantic label in lower half + points_label = ss_remap_lut[points_label] + + if self.shuffle_index: + pt_idx = np.random.permutation(np.arange(0, points.shape[0])) + points = points[pt_idx] + if self.setname != 'test': + points_label = points_label[pt_idx] + + if self.lims: + filter_mask = get_mask(points, self.lims) + points = points[filter_mask] + if self.setname != 'test': + points_label = points_label[filter_mask] + + if self.augmentation: + points = augmentation_random_flip(points, flip_type, is_scan=True) + + data_collection['points'] = torch.from_numpy(points) + if self.setname != 'test': + data_collection['points_label'] = torch.from_numpy(points_label) + + return data_collection + + + def __len__(self): + return self.num_files_ + + def get_n_classes(self): + return len(self.learning_map_inv) + + def get_remap_lut(self, completion=False): + # put label from original values to xentropy + # or vice-versa, depending on dictionary values + # make learning map a lookup table + maxkey = max(self.learning_map.keys()) + + # +100 hack making lut bigger just in case there are unknown labels + remap_lut = np.zeros((maxkey + 100), dtype=np.int32) + remap_lut[list(self.learning_map.keys())] = list(self.learning_map.values()) + + # in completion we have to distinguish empty and invalid voxels. + # Important: For voxels 0 corresponds to "empty" and not "unlabeled". + if completion: + remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' + remap_lut[0] = 0 # only 'empty' stays 'empty'. + + return remap_lut + + def get_inv_remap_lut(self): + ''' + remap_lut to remap classes of semantic kitti for training... 
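+        (in practice this is the inverse lookup: it maps the 20 training ids back
+        to the original SemanticKITTI label ids via learning_map_inv)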
+ :return: + ''' + # make lookup table for mapping + maxkey = max(self.learning_map_inv.keys()) + + # +100 hack making lut bigger just in case there are unknown labels + remap_lut = np.zeros((maxkey + 1), dtype=np.int32) + remap_lut[list(self.learning_map_inv.keys())] = list(self.learning_map_inv.values()) + + return remap_lut + + def to_color(self, label): + # put label in original values + label = SemanticKitti.map(label, self.learning_map_inv) + # put label in color + return SemanticKitti.map(label, self.color_map) + + def get_xentropy_class_string(self, idx): + return self.labels[self.learning_map_inv[idx]] + + def __getitem__(self, idx): + flip_type = np.random.randint(0, 4) if self.augmentation else 0 + return self.get_data(idx, flip_type), idx diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/__pycache__/__init__.cpython-310.pyc b/networks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e40345c2e5f3d47c566e777e3a56d2a149762253 GIT binary patch literal 134 zcmd1j<>g`kg6N6qX(0MBh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2B#KO;XkRX;1Y zJTWg_-#PO2Tqh+-xn!NLFl D)l(bO literal 0 HcmV?d00001 diff --git a/networks/__pycache__/bev_net.cpython-310.pyc b/networks/__pycache__/bev_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82126f7337484988ae856f000f6fcb0cbc48a971 GIT binary patch literal 8523 zcmbVRTXP)8b)KHv&Rzg477z?kvV+jFtW_A&1eKdY(-I-mvMs}jNQw43$!NJV02bVv zdS*##vxpT#iS1Mnmt2)Rlc_uydCV_xmGhQ7=4oD%DyPbY$Ap!nN>#G*eP?EGARsE4 ztv+-5bWcxD_jfLRwmUahFmQe6{q3#)Id2$0XJz`bQF#d^`cII+2vp0cDY=_9Q|?yV zuGuPWbK16PD5LGxaw0KXUOQjQE8{ChUl1e72K(Z7pL$aJn zmLPc|Sb=0ElPm}7x>5bq5x(2ltg58=ithHVwz`em>#sB@oW88A5`6(oC`(i`zi(`* zKn3P|Ud>8Bzh;L{`q}%c=6>I>jam*b$tC%#Uw-piKWcV64K+0`c@u@hNAn=Dv2Wa0 z53t68aospnn^iM$qJA&biRb&xPBZp>%+-#<*4Cj)^1k0_)uYJw4~-9=yRg%3hZlC+ zck7+)3*T%sUcdSEH!pO;_-%P3|}?0ht30MkI$Lh-vdP+U+N5k z!diB{thb{V$2AJ?y}0(=)pKhb7tgJ|dT#Azb*8b~kNw6@z0(O>(Ti&v>xFUa>me4U zgR#aE&E-7$2-gw6OV5E+E#P6N-U}1^PBXllSW&-SbrW?jDNPTZsH;h_*%`koQLiTI z=J7>c%x2>2Fz7d8pmPDQh`99gg;iBCOQx*`PfeEjk^WVWxa`#1q3~d@qPYyjjj+}7{i!*6Xw}?(P5od^rsf$X65!VlCxy!yr#J;eZiL;FNa8CSb1s8+YjK zJ43^_Vw;KpiIYhjNMIJGBrYWQO-XW)I5QG2&Ic|e9>F)v4|758Acye_kUJRdxV#9t zjrSgxmmv2}kk7^Qy9>j--U;#@h0zNK0K8w07tvSko8T+){QafW3dhUA+`f9n*jEqO z>j~PV>ip)>OIVPjcS&HVK!DC88i=L=ojbG>?^d(b?1Xil6t;D@-}4*Y&YfhgABBDZ zBYTTRHgU=Ny`!IEoL`{G1>>Q5GbWPA=LTFyi7Fu5(7``+_7qL!`{vuQJOY_&iUyb+ z^ZXYQHH{-NAcf1VW)#E3^sZdwP#m{n44%n0er)*8U&id}t>_Av(E~@Ez~PE^a1$&i zZnGobWZ|CJLHBMaaY!$wxApXPDZRarnEhU2^?Mgd)=4grK;G#jIr<&wZj$4I2m>Oc z+hmf$H$QOv0LljyJ@rnQ!Xw$m3RUdP`Fo7+2a;swnXw5{9GaQ%k9 z0FsoOogQq^t$N(p@jG4J)?XEIqp2JHR$cqOdJr@_+ev=BRaz8L46+}l9VM$*me}!a|3_b#HW*X(=i{ua1P;4TUG0=O51 z`;95vDjP|QdPot+h7iU~Nm6+r6f#@e7llUn%?^`Ia z$`4d*?mE~t3DyoMbsrcHELgPMC+g*WqF#CQ!dn9C3{Z@>a{R&xvZ7c~;DivVBJR8$ zYD5Em)EtC@m}+TaX!KXmDJfhJ-|mN?5&2&&UIBpFD3T6yVs25xl;xY>y7WO`z6*w9S zOIl$lpz3UOV+U$1Dk_|b7|#5ZAv_HRc$f@C9auhqqH2G1k5BV3>$Tg1T2QN7VeNVA{X5 z?ZMLYdS#lYeNNyc^?C~pv?CMsTG5m}R=s#QVekXx;<%%W{5bq6|9`?CnA)jehv5&R z=qU-qpV@$j;SU3%NjT$;;R{>+9sshr_YekS}$<&*HIAmLB((eP&*htI!N zKQtHy$qA3gPYw+o z&yII}tcoYaKKgg?jDC^inEDX2`*v)JLqkXun6b;#>OOSsz(E^=AUHg0 z*Mnv`2jmbsF?fPDoYLn3f+;^ww{(2$z&*gbK5opDdrU^G<~CCV@v9EaBsU5#T{#&< zH-Oe0Lef^f5gsbtN3Uuj^?w?C8dlH=>RTiM2yAk*gP4F3>?Q(Oc)t8=Vg2?yT-U_r zS|2*8>vH4>oVbXiaTJySSpPo95flrc#wh*~f|>7gJ7 zcSGN0{T+~MF$MY$D3JY`1-hoW)_+VwLzg&&Hcptlt?NIb`~)2TJ>=0jkZBwdp2-@)SsL(P*=ZECb0}lckk0xj#SQm=tAt@ua4-fiiPhSuH!c=-)-X!L1!#%$Z334Jiu<=N-QPtQCtk|K$F 
z4c6!aCMJfmnol9h?>G)o0VY>LY?#=Is+&g2-=*|j60x?gle8mYb2$OxW-0O=$9`Dh>PsGKS_ z{o80eIiX|QndHAwzn2~;1=@+*?$OYE+`bA-PTSXi$2;WDf^=VJcXcqwY}x;1H<90a zpR+z~KmUV`vY#ax`mwvoHTXH_NR^aq#GICmLUraMmUu{qE%=9q(1alY()h5Qfd%}4pPe;IoI*x6g*=I(%Og!ku=Ne zn{grVkSNX{(C|p|LJ}B-3|7o6-oU>Sr3hKm(BG;z@UM^lOGuGuXf?NymFyzFp(A8w z-6+=0Ak=rziv9fS^hOjNaXHYVAh}`5h14Vg>GyQ^)+i&wHJ!jV@!se~^>(inCJP%t z>J8*Zqh2=(*8@TD#4Kt4Yo{5I=^K1ccAlSDDZwzmB(;3f{UWO;*sK4BTr!`)wY?1A z@{D@gw9Seto3801VzfUR6#Kmiwnx?g&oMtP=6{2O{0+!dP-x!AAAgeUR3mQHK(-^5 zGqIV;1t zRn)z#F8_#RyDNu3an5s?EX)VD0DnI+SH)0K(}hW3!k;wANLTB#a+mzDG)tsQY zwe{`v_N`On-6Dd3EHAQZt-}15mHSa8S3q`2oOuDuo9TDezH#lE{xM&0s@>+C{|hD4 zbct#yYoN?rqn#Rg@2{b?F%q&Td4@i`M`uB16_++U^}`Z0I-)3Mmg#d-Ps%YT7?uN# ztQt}R)B!yKSRnP8nxss6?OigulbyrV!)mJrn zyB#oGzx&JH?mu5=>|fNFzic#a;VC4DWRj1|gKSWSp%z)D-4F8Y z(G95^Dv_g@G_{jPPZNbtv5O?#)$Sk*Dj7-QECZ9^E=vBNjmB|s%zR$w1sWd7CyLdzFH-Pg7>nTTqY%NH&>4%C25+1_95V1 z-r+Vs>duxvQR#6-u7nf7)@RW$!PW<1ILzfJCB7Ag4@Xg2_n6&KZ<2}jWJlv9O~Y_* z0hiFLXcdY)iHw`A^I3;?9U~i4hp~ERO#2p-()6N;q{nRDeK33M*somGs{B({@>3zP zac77@)3z#X?9>6Dxg;-bDJt)o_>ldfbKIC{@kOnNS4+O{RQ{jIRjnQMHY`my!i^3O zlT4}>WV-Q5^yC+Lp5BmV?b=arnnVZpl82)tD+f{93)C`=ULo=d_tSyQA7@2GUZvgp z6n%wujT;x5&u^VMyXG)DNW!|Cja+X2n-?+ z5)dM|(j!EWzHH#_7-Vp<15MriVCMF>@rm(+@ij~{ehozXrHTgGqbQB1HXsG4)F(x! zS^N=rat;_c5f+Nl+P9@aBSIo@u7M4wbb(L;GCsMz6A${Hk^3`w1HHkCj~*XAKJc=h z#Q}y+dm5#68o&;)6xhmT=fr}fb;gXWjWKhwGaGZYziz}Ct+@k;W8qR|$)y9K{yNcx zN8z5E$$XhG-vo#BiQ16hsf%Q$P?YqgxztYVaxMRdzV^cea527N`HlS|P&@Dr^ z0gwg@8zxnl+|tPIs~X}JXCuwiqw3&IN=aw2KDL{syK*20#`Nln*8u!DNGDIMJzHStbv# zQ>U`QIgq)Pb?U*P#2IqNG4k%IU)ggKuX3g>043XRNDHp}Df9v}TLWCv%oc!aRSsrp zmhH-|yvpCRszw!5%{>$gC*mGEmrLh_3KF+=y4uT(CkWjhB9Y!t9)&W=a6C)YO!K`# ziPxYg;WxC)y^gwqms%zA10p{pvQ6YiL~ep0Q=tMVlT7`XMCRyzLfkDPKLydY9K>a> zRi8N0{o$xA)Hh5IGT(K|js?5}^hV#IRb8)NoAv8P-=C~)bH;7Wxb3;bWze@ZNsD2g z#mQ}0DZUNT1?bYfB35{d2cm8X;d7t6Ukdn(`_<9ampo#Au9#qQdQrZh2sU2vcLW^H z4do@y?H@QJBr7S_s3NccMAjPny>)CKJ2-nA>5o7VFfRQw2Ni>}AVGDuAaN@p8*A*; zI%CJ4v@7p`;pExAeZX$P&x|yHM9AhElg@#-$x65M_Tgo))H*>jm4R%ZSjRr34OD>C z2JBICtr~Lab8#GCba!F{FkLlaqj|wbQ^L!jwY*@%nb<(TvC%R%oQm3Hqg@5?<5ty{ zE1$#vk@khWG_llyr4B50E?DZoQU_YEEm(3Vme6l3Eg4I0MQyUwHI|mDu6(_=R5r2u zm)BSo$SV{ZizkEVSlv++6WV|DxDE#OGc*+rv@Ryblq1z!M5w^fjWoKCYWRA%IsS^; zO_Nmzk|e14EflbLxr4ndR|&XQ&7mZg$&;SwFF_BXW_}FPyeRrj62vGM&^VucYkuOb zdfe9HxUHARZO@O}uE%XJj@zD&tD{_)m$qpIh*A^5AJeyaDCdWxGSN<)=2@b?PeQ0= zdrE*c;Gn(9?~QK93KN8=s8oLVRJ)@?8I_6p2%5cR6Z`mkipmHzeq3rVc~X~CeqHv* zWxr7-QT!lL;pREJb%tIh*ULk;7ltgtkSVeIuwD>I2FOJnzGCi{xNHhRLZoxPdx00iV&(0>mdG{F37P@(U?-Cb2A zMaf~j-^9U;+%eJ6YoV~rRs3K*%j#PD(zLn0JIc)OUX<$KoUeVxFN vKZra_yPue!7gMjBg2XIkZsoJOK>IiAAD#CpWKE^rVk-dkmRPZZ_T}Jz)koRN literal 0 HcmV?d00001 diff --git a/networks/__pycache__/dsc.cpython-310.pyc b/networks/__pycache__/dsc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03db69bbe4c8cad5342b1ee90541bdbdb2cafcc0 GIT binary patch literal 5197 zcmZ`-OOG4J5uPVGd@lE`m6e1Mj93hmT}g@ISTG`6E6EmODYhs}$k@(ccDl(OikzW) zhP&FuFt3&1TsDt;4qV^#cXI4gfSmRi_>>^9^HmS`r6@y8cUN~!cU5)QSJjMbwX%li z)4y~(VoB5fLB#3D0`V^1^p;! zTT*SSUmjH4%Ao326>axxgSuN+?LvQM&~O`QYpmFx9n86Ny7s=t99DX!u@bj>^LwVd z!1T7(Djz};?f#*TvZc%VN5Mye4@DgCB#A}Kkj0y~AANG4r?NC{Y+VIkV?7=W`#cTf z=%(_QTr`+K1ob5eAI@@J<&cp1o57nf1DjJPk{a#zHop%=u_mfHxU;VAVs{506~qF4-Mk%@R1k5ac0gd!OA zec=s##=@wJ!;GPgejGdzbP#2Ec@s5%lp@HUkHiGu0~;^PiF!?Lu&x4c2jc)w*! 
zE8+c)G=ol8=_OXc;TDp_gKb{uvgGA<^Sl?@zh7V5j0b$JH`w;0?%FSc;KAdcKU#}; zx*dxr$r?+7)!~jfhZ)Eg$tsGnUeK#}>nQ6+UDXB9C%^N?1$|Dxpf~jLx4Tui-|CT} z1KRYXhh}x~uH#M4V#F{g(Fm%6RHu8N6*= z$gGK#6%KTGtx2tk&5Q#*Ga0;eS)=|h7E&F4V*zfgePdjF49v9f+cR}7a;|F0jcstz z@0aBJ+N78jGxNas%o;o6(zraXsP9WaTJ5>Pj1CZC;xLoh2apv0Q(|^kXN7%2cj5pd zKUZ8`gV}0L@~iTsk~vv5txlAOrPbcdf%<}(&k0eC18ZEve2vL0VNj@z>)@|5=N&C8 zp)R2=qb{Sapst|SCg{oZ17kb`{st^qg=}-sw8(0Z&l=BWvy*wW^sK?^>WhU8eE4Ec zX*HYGj-{N(e2dw9Hp6Bh&m5Hp=6E4n=q(WnLWgC%=enlqX^^ff+E@B`FBCPbRNr=$TxA1Uj3?oHJ}et%VZOF0;jKDb%y&j?vcg zHqdA3zAOymFiI0aAxzAmz*#m~I8Z=>xMZQ79FhmQm91^B!~Jw500X5Ujs=6KGgXbfr1+&8cN$2!r#eZ(1uBVd`()xi21U@2g7v73;bY{ z%OVhNFiN>BjiLmY$alGDVoLFSTFX`|DQFPqsl(a|`F4(jRiF#di+iM=EO&Y8rM^Jb zIwvj?Z;?OsDJ_uY9N?55fM^y%msq%^Vc>;oNa~?$5haab;FeFeMf`|FAvh6hG^MRJ z#Z@?#+rJ9)Ge7k`%$m9tP(aU|s3yV5khp^HF&-R`%Ys7RMCwlwCiu3ebaiX=eG>S6 zPJ85*NINy3_yLKbiZ@7>VkZ{czF6o zmYrG%{d#6h3I|#Scw^?i1>h^B#nkDQdS!%eopPvsBj*7}Vc&duqd&0C)zPlb%L!1z`eaA~u2XUGq~ zh3?e5ok!Se*hSpr1FD!Dh)bG31(++rwN_f|8Bms{&{QfAH zZPn#09qT-2$X86HfwTJ%lZsDJNd1ni1%l29k5UBa9dh@mD=1=#w}``tB0`i%S;Ftj z;xzx2W-=Pc0+P)UPg<4SX9`IGbrMGq05QO-#33;lk}k>xN6qa+2+a^YhZSy)G!bWNdf zn99ZzKO7(x@0%BAKCk1jmXlwb(IY?;I=kfFnDJ6ny5vwVB4vQNo)z|4mNVq3##WxTAPfZh!o-EGZPBI3^xpby-n& zTilnZhjv^11O>p-4xm5Xfyf|8kq}UU7z9CpLs;)6u(%7ZExgJcRjUE0*sUL#BbW2c ziJ_&)XOtZ%F=9HBb({}qYwW&#EbvrY6@8Ot#5tp25jWMX{1pl+@s)5#ZV&uK3&;eyUrd6jpaBlp4Xgs>#y zUk>3<_)hYiJ|olZUO2MXtJwap9OAb$!SAS0d-AHq=V@_evXgcHKwqeJ6~MiMqou2q zwKYTc-3ZX(d9v)$ZE=KD$dgsi`(otx^AUB`6K$IG11bpML_)=g3UU`gheZ5_ir-VQ zheFn-m&$-6!S+t=5;=#Uv`UmvP}-j4_rALBoOj-EW*yrx@hdrwQ*4-utFsl5ern zBKZ_6?HM~^>ul-v{rg$t!sV6N)HS1owlpp?qHNC z?Wluss$Dgbon7s8CL+<{?NQ#lljg-pwtel0L?uyekPYKumd}-BUpqy1m}>uirgmeQ z7TV6lepG`PV6;)l{E*2yHvQihZ|;tU>COJ|ev&0h|Ty3 zoGo~1Z^9?BRXW8=X$k8Jn_BP=i*4of-IABAZ`0bxHf$c*gnQ0b6AXLscB0jDoA?M{ z{M5o?cOG)>e32!^sfQLk;ue!}DzzWSS)Qpl)=r)b;ll}9g|_lsd$+SZO{De))w6j& z&3d~ku9|SZQ>25gu4j4N*-i319Td7Tny6(P|U9=!VH-v|L?(ubNiyv;B`!zQfo89TquUK5jH6Z?~>0dJ~;GOyzA6SA5OW4Qur`FV;2FmMGJOr@8 zE3k!AzYNhv82{DkT9?_hj&VNv)Vmgf*EzPPjk009bVqV*^P#0%GF20q!v}k*EK+TY zZ1|AN@8Rt!b!i9k3LctoYe%MoiF^tzNg&l8Jauo9>ZXahxacHG!J+lJ5)TnxHjM42 zeA@KS^6><&4ZhevIK+Yi(Kk$!Jo1x{%JveKYO7ahz90Gfaph{wW9{xGgGn0IF*B)?VatNwnxY%!~(ft91llgG9bMX$B6qu^~OL;zJgJ@IY(Y6 z9Sj!eGdp}1Zz~)U{8n)nMCU%U9E-R9x7Lzm|8AZjHeK)pA*9EI{Wb5}hAkZ7J|MPe z8mb0cfH%Si;0sX~4Lm)=9yMU1rCT3#K3EdiNS;22iT{lo*FWEAd-CT{xk#PtR3-9c4Q-ZeDrTFdnHfydX3$CMb-Nk1)@~=A~0OIZY*Kxth2|gm&jT= zk6AVhKh}557d#~o15an4lwvIk)o2_Kh`{6O6@A`2k}zSud3#f?lTL&+k3nS;oFs+? zqaxK5+5zPjx-}k+K$Ja-%K@zQp3Fo%OcA1m_D0=qkt*F#iR`7S>La&klJGr)5buJpt9<1yxbc|b{>ad;sUw%ZhBA{kZQY75-BuS@-ofUN_ssEqaa6K?CY><8$E#D zofO$92gMLS$Tvy6L4qKy1Cti$CrRQ}`4$;n^J}7DxfyA~XBq0|(@?w>FF=IW3V3|+ z7h0a>asM07X(0(gjhwfQ$=b_P*u4M+J^zeFq?P@kF0K;Cc*T&aJ)xsnd-fz95J% zbidMj_PKEZc2?~I$+_CaEI^O2Uq7LQ-ayLt0o!1cO~Y^5QHUMY(Z7lGUqceHK_V_C z5o05ZHMcZNTmG$Vluc8dQHlw6suUxi$0s`_B9K=h+Lm^OsntA@?q!|yj8NJw18I4k z`nDmW#xj!tK=O666EY0gB^59gWX$B`6H2Mu|_pihcL}! 
zMc$x=3{aP7)tiYB@wh|7<0Kc$x!yzHg2Hi~U*%hT!+Mp!gy*_N>4|E`!)s6A^*bi& z(8pO^gkGX$3a8>f5E+}=JvJr8nL-d0`U*;Li!drcnKSEzu!_7E*E5pLT zasi|wYQb^_tU|s|^U&{+A45b9kmFvK-k*aRJR7U=qH=03i?3Mqly*^urTYf|D8`cg z?vn_;obD@$`Yi6HIQb}fu-wLD4o{%bHufee5+wY`nc3kN(T?q%X$P^YQT^c=>F6upAGZFUOO< z7!Up2_{OwpINU~ThGnDYh|v7{maztilV;DJ+b)!?@`9*;g(FN^LkaXYVzEVhaAJLG zPgka^)3tJ?TrJm1uj?ST8aoIgiXS|FheyTm@LiPwwEIP(I=d8l#6?vt`3M)vA;Pe( z(dhtZ?)HVM#BoM>TF5`pD@%>ED{=J3Ic1RM@{eTnb3na284RkRG6lq2G;@>0o7BZE zRMNRCKZA}aHpPYsFq9Y~!--)eBOD^hISH}^*S>0l8lI_Qi>QqYMPuBK(fWdEB@-rB}sO97ouIS=#0{}_RTr)MxIxNdnXs?AE&;K8q-D6?U}wu=@hy}rzlHrO1|oAk0f6U^4p~&)mid& z-*_ZBWq-xb`Q{_pS@td8MoGob`vu%r(5L8^(5IC3siI`bFQcR!sN3g2WfdjMeg!3! ztmM2e-;?Uq6UBx#bPgSMX7x4p34xZUYDb{e$8lOH`7a1qV$V5B4Vl$Pj?HD6)^S*?U z8=EIGMs>HYBzC*!w;O@G6Ff|``)w~gp+W!WkAFnJfBtg4km%v=AP|Cz6Vr9uy>{ff zi4hJO{oY;}E4?R6ac-tPUk@1SnOPn2FS$=~*(#@1$EbZ`2J)$tw% zA_SGhdMogD9=;pA_1lS&mL|FgI=kz2Ez!cD(@eBG)K0(S_ID$p=_H zjigu}%R_ABNaiOsQbuZ|F)K;2601$+E7FmCgv|#VsM`jv+*li88@J6^b4}bW+$~VB zV{J^VtZp`F(gV%n6=n?UXbol(-+?#!|?+JhH7P#H(gs zl01?LR}dGx7xcouNYrq*o8(ev3!P#(z-D+IH{9|DL1N!07IDM&Vc_I5h6GneiPjQr zU(oF8x>zB~@jsVls}Jhb-o;eWgOA!=zfmwtmDkhM1`J*%254wSCHr=S~u`| z;#oW%mUeqPz5af$)(xVqzJKxrzq*kWqrM~+eaCKk_>%ju;QWbtLxEa)1NOS*U& zO+3&t(2~{V!}skPey>6ud)?Z>tF^uULD1Q6he2&3{__nmxL4D>2%sB!uhs^AOyO*M z9(7&eIf-AzD^E|uXnE0AAYMZi>kZQouj6@P+!hL^I`R*YnQZ$AY~U+nrA2r3SQ#q3 zxnKZ-1;^?)l%YD*xJ(5+$bF~~Ncd1`Di;AE(8l66`p}595do;o?+fnOy_k&KzOZ97_6{d?uE6wZW z@+!NHeqV&wHos82fAH|Io~`8;`jTI&kJAJC;Pi_I6`%RSyrH)z^>xAy=R zt$Xcl?#uMkS3RXKy^Y#)E9#LAaIV+s0g5kSe^_&f4LrTaq65oOOc@1p-R(sobM)p3 z$@OC;DP?u--X5%bz?G0Douw&JAx%jR_J@_dvD<3{EQnvG*SvzHZaK?m7Spk&i{w~5 zM4GcxoK9wzrG`8x$C}nUIeLUCaf-aqB*2{Ia`ZUO#pUTc+{hEHAbJxc3rRm-va0fP zkV=wG@l_Nwa4r035$G2_k7NQJeS^Uxpxq-2;1RqEC_4j>xkSBp_YGd=1}VnLkDlpL zaWOoJq$8jRhK&kTp#mOPk2F9zmDJp0c@z4E#49yQq<9lpSHfnQik?_b7*@Bj<@I`C z=cE9mIn6mXD|5DcPKwan-fma04P&#)hKzLl@6xbnzBbn5uz)E%9Td^<>AFMh$@(oB6)%}2ori0Mr!#oWlGxi`!<>DG8Ge9si z#=n7z-=XfQF)qnvO;A|_*TQeQr)wyl>}d;*)8EKVd0}UB?7+x#dut}zy*KZ^QP)mL zR|_)dC-lA((%21pn*7M*)3VZ1g}0!RX1k5OmR#C#>@;>=a$(kyrfvevW$Grx1oJJp z1NIn8?B=IWg{T-y+t!f`sPF=`s4+B<&yPxIv4jiy5gJmyB#jm1%46IKV4zK5z74gQ zv>OA{%F=@uClsRP^bN6E#kx?t87+*ZlD69>O+2N-u8tZR+hsgy$EmZkg8Zx%IYTh- zZgtY1T%;hWXyhc!8oymw}a$zqB zlM=JsR1T7*6iUttGy;602x5uxR=XE?A}M8@>59PLZ3Ibi(C>HJy%u4H4xsGY3>n`7 zHcoQQhD#thQJal0u>#0cAUxoTFA_m}<})Q4yFqFv=s8PF2imO@G{|2`lxBl5u8VsS zZ^9tRA%d*bb1XUHeUKAhCz?y?pr#WL#OqFOkPc4VpoUM8YU2Lz77|-FvVTc20OV{y zIvb!4rJp{hX^?qER-RDql7_ZVRuuw$-0~-fMPmA?I@nxbbb2WhDa3|7!`C0;>YHLk`wMpxxF5D1JUkL%33puS3oIjdJ&vR^n> zvUZ50q-!;O!%thID%Pl!t0;EVtkP*Ncf0dLf zd2t~*9PFdh_vf>7v%S4A*^Ret0Bz3Nn>U6e7&!FP_xEOO*rNbI?-)DHK ziAC6lKma)f8k~Yq5K9We-}Q5ed^gc3{vaGGbSe2gO5UL4E|LWBA_Q^?4g-MuUL<~< zY6znXvPh{^>tWh|Hv`qVI>BCXkJ51^N5J)At@L!ZJT&Ax+> z420YDy~txbO_@}=6f%9qjg#_x4qf-0O#FQ}(@{x2Q)sM(ce;)Wr$_uEji;0UU;8gO7_5%s&eJ3pfcjBg&#pK`h|o@7R^1G9fEI7o#or` za}4NA-y%Tg+kPH*-7ok>+zm+067D&U5}J^ha#FgN#i4G)^KNX_)f2=3B{9UCPHL18 z3kdRxHWM=v9(K~}END++Q^I(xJ=Ydt0)c`Bm=GbLL?-u$45I;`7!Hw-_yKY`V*3-vpqZT2#?QgVc(02)7Hu;i8Nr(IB+zoW8{AD-$U3h?;yzO)L)d(ul7F-8*> zpb3m90XhIY03?`C7E%@AR)~8E;!UAA;}gwz+Oz$M_CJT`7sY1^Fc#b!3PQ~P@LX?L?F}vghIeG7wgmx zlEeLh9<;GW*n>h06ka%OB~Ly@cqwzIX^5%Kk=az%t<=P@H6c4A-lU{Q2^$xRsEVsd z5*>lhFhB>=7)gF_f1>E2hLK?-B6tgbj`}>LA-uj#`pnuc`exb_duko)-%yA45NHuCOR*j2Pth|k z5yg3$MSeS4KIQFaTbDq0DTY=A(gae(2$nE{tT&KXI)d_WnXw|_#N|^Xu+NNOg-3vv zPZ`14)>VvPIj&*^=P-f_MsN-z$fv#IN;ZPk;rZbOUW3*6{HYP-&y3(AD6RN~9p(G7 z_($Gd-3PMvi&LDM?rYW8U=L_lXZ4@AyA%tS(cbI!07wM!{b?rX!>BW}@@aN($OHAdYj$y-itCn(Q?r3?t|{q$Hw* z;uhi_CEucC3kih*yEuHCYLR4#vYmdb9pMIF7ZI30qlTAfxH;8l4~WK8m9cum@!%lQ 
zJt4e@thET5MjPwmJJj@hlw6|Zhm?Fo$sbbkT}u9l5;94mOUYM}5TwVEk>2R^dx1En z;*Tl$6D0Lo>M8$#+F7h8$q`8G<}S`{IPaK){WK_>XavIrsZdZOtes@QI}Hyzh16iD zz!&`pG`M^Fh*R$0%!pZoxks+Yk~tOXxQ;gyq@O3Yo_X{svGt7D`i$7@5yW8fC&V_M zLTqD!*hWTdBQp;-FfGT*zyio18UlB~y|F%hdTILfibFOg4WOdIT;`?u%qypDNg)I0 zZok(E{uM6_iM7bqm*m?(13!;LI9shCm{54~qM`xde^$Pr*5ru`FmncGPpir*+y({I z%Wxygvhow-Q^f?rcyf68l)aiCR|7U_`YX`*aIv+z280b=hM*P@?zbU_@KNbd0V&xK z+1HT5?(nm5Q5x$*Lfmx30BlQV{LpRO17~sR9hweSPH)5hMMjj1H9Gu7XzmRhRA`gK z3E#TN5sZ-qB!M_WY{iyu?I^b--18Su7THlA7RzMcA74itE2zP@5kMwXQ%q~nPe>BqKFxnRXXu!7o2hyr z(Gq_}$zLN$3LMdgPM~8BF{1K6M*;<6_`x2&H~^T59^%6Staa)O&5j&So={6UG1{$O zUj)==KWMkMBFG_^kn9Erjtb^<^4WJCoqiZPIef`L52+4P_?Bo?QZJ|K;%`AZSw?(6 z`-~tZ93~2!*}tQq{5>TU^cCbdQs-Vv{3YcTN_MC)mwiN%6dMAxUAV-eKfrl(7zyzg zXi+bu5X~Yd-etX%3l6;Q01-lRn7PO)c~Cx4;Nie=;I##)@s|ygAz3tTZYi#Sren{V zI?L&C1+Q)h$sRm(Y#gxMc!{U)CnUp~a|Wi&48DAjVp|c^#(aTC-8FUE7MHLHO$R>z+7h@2&Mk>-xZNUSGd) zJ6~ol!;$wJ7#7sO=y&O;1K))FhzeAY7eB@m z`#h7A2`dhMCB)RM%{q$pHH)?mnwK=H_LKYpTkbWwZ z2FUqa1pZDSb#h+EFxiZ!KCzLtdWVW>CUcYLWniDhyp?{S@fuB%Pv7YKiu41X0zUF6 WAgIKtVqURLvtU}4<%+%f+W!OS{g#pd literal 0 HcmV?d00001 diff --git a/networks/bev_net.py b/networks/bev_net.py new file mode 100644 index 0000000..8b43c4a --- /dev/null +++ b/networks/bev_net.py @@ -0,0 +1,294 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from dropblock import DropBlock2D + + +class BEVFusion(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, bev_features, sem_features, com_features): + return torch.cat([bev_features, sem_features, com_features], dim=1) + + @staticmethod + def channel_reduction(x, out_channels): + """ + Args: + x: (B, C1, H, W) + out_channels: C2 + + Returns: + + """ + B, in_channels, H, W = x.shape + assert (in_channels % out_channels == 0) and (in_channels >= out_channels) + + x = x.view(B, out_channels, -1, H, W) + # x = torch.max(x, dim=2)[0] + x = x.sum(dim=2) + return x + + +class BEVUNet(nn.Module): + def __init__(self, n_class, n_height, dilation, bilinear, group_conv, input_batch_norm, dropout, circular_padding, dropblock): + super().__init__() + self.inc = inconv(64, 64, dilation, input_batch_norm, circular_padding) + self.down1 = down(64, 128, dilation, group_conv, circular_padding) + self.down2 = down(256, 256, dilation, group_conv, circular_padding) + self.down3 = down(512, 512, dilation, group_conv, circular_padding) + self.down4 = down(1024, 512, dilation, group_conv, circular_padding) + self.up1 = up(1536, 512, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up2 = up(1024, 256, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up3 = up(512, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up4 = up(192, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.dropout = nn.Dropout(p=0. 
if dropblock else dropout) + self.outc = outconv(128, n_class) + + self.bev_fusions = nn.ModuleList([BEVFusion() for _ in range(3)]) + + def forward(self, x, sem_fea_list, com_fea_list): + x1 = self.inc(x) # [B, 64, 256, 256] + x2 = self.down1(x1) # [B, 128, 128, 128] + x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) # 128, 64, 64 -> 256 + x3 = self.down2(x2_f) # [B, 256, 64, 64] + x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) # 256, 128, 128 -> 512 + x4 = self.down3(x3_f) # [B, 512, 32, 32] + x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) # 512, 256, 256 -> 1024 + x5 = self.down4(x4_f) # [B, 512, 16, 16] + x = self.up1(x5, x4_f) + x = self.up2(x, x3_f) + x = self.up3(x, x2_f) + x = self.up4(x, x1) + x = self.outc(self.dropout(x)) + return x + + +class BEVFusionv1(nn.Module): + def __init__(self, channel): + super().__init__() + + self.attention_bev = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channel, channel, kernel_size=1), + nn.Sigmoid() + ) + self.attention_sem = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channel, channel, kernel_size=1), + nn.Sigmoid() + ) + self.attention_com = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channel, channel, kernel_size=1), + nn.Sigmoid() + ) + + self.adapter_sem = nn.Conv2d(channel//2, channel, 1) + self.adapter_com = nn.Conv2d(channel//2, channel, 1) + + def forward(self, bev_features, sem_features, com_features): + sem_features = self.adapter_sem(sem_features) + com_features = self.adapter_com(com_features + ) + attn_bev = self.attention_bev(bev_features) + attn_sem = self.attention_sem(sem_features) + attn_com = self.attention_com(com_features) + + fusion_features = torch.mul(bev_features, attn_bev) \ + + torch.mul(sem_features, attn_sem) \ + + torch.mul(com_features, attn_com) + + return fusion_features + + +class BEVUNetv1(nn.Module): + def __init__(self, n_class, n_height, dilation, bilinear, group_conv, input_batch_norm, dropout, circular_padding, dropblock): + super().__init__() + self.inc = inconv(64, 64, dilation, input_batch_norm, circular_padding) + self.down1 = down(64, 128, dilation, group_conv, circular_padding) + self.down2 = down(128, 256, dilation, group_conv, circular_padding) + self.down3 = down(256, 512, dilation, group_conv, circular_padding) + self.down4 = down(512, 512, dilation, group_conv, circular_padding) + self.up1 = up(1024, 512, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up2 = up(768, 256, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up3 = up(384, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.up4 = up(192, 128, circular_padding, bilinear = bilinear, group_conv = group_conv, use_dropblock=dropblock, drop_p=dropout) + self.dropout = nn.Dropout(p=0. 
if dropblock else dropout) + self.outc = outconv(128, n_class) + + channels = [128, 256, 512] + self.bev_fusions = nn.ModuleList([BEVFusionv1(channels[i]) for i in range(3)]) + + def forward(self, x, sem_fea_list, com_fea_list): + x1 = self.inc(x) # [B, 64, 256, 256] + x2 = self.down1(x1) # [B, 128, 128, 128] + x2_f = self.bev_fusions[0](x2, sem_fea_list[0], com_fea_list[0]) # 128, 64, 64 -> 128 + x3 = self.down2(x2_f) # [B, 256, 64, 64] + x3_f = self.bev_fusions[1](x3, sem_fea_list[1], com_fea_list[1]) # 256, 128, 128 -> 256 + x4 = self.down3(x3_f) # [B, 512, 32, 32] + x4_f = self.bev_fusions[2](x4, sem_fea_list[2], com_fea_list[2]) # 512, 256, 256 -> 512 + x5 = self.down4(x4_f) # [B, 512, 16, 16] + x = self.up1(x5, x4_f) # 512, 512 + x = self.up2(x, x3_f) # 512, 256 + x = self.up3(x, x2_f) # 256, 128 + x = self.up4(x, x1) # 128, 64 + x = self.outc(self.dropout(x)) + return x + + +class double_conv(nn.Module): + '''(conv => BN => ReLU) * 2''' + def __init__(self, in_ch, out_ch,group_conv,dilation=1): + super(double_conv, self).__init__() + if group_conv: + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1,groups = min(out_ch,in_ch)), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_ch, out_ch, 3, padding=1,groups = out_ch), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True) + ) + else: + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True), + nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True) + ) + + def forward(self, x): + x = self.conv(x) + return x + +class double_conv_circular(nn.Module): + '''(conv => BN => ReLU) * 2''' + def __init__(self, in_ch, out_ch,group_conv,dilation=1): + super(double_conv_circular, self).__init__() + if group_conv: + self.conv1 = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=(1,0),groups = min(out_ch,in_ch)), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(out_ch, out_ch, 3, padding=(1,0),groups = out_ch), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True) + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=(1,0)), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(out_ch, out_ch, 3, padding=(1,0)), + nn.BatchNorm2d(out_ch), + nn.LeakyReLU(inplace=True) + ) + + def forward(self, x): + #add circular padding + x = F.pad(x,(1,1,0,0),mode = 'circular') + x = self.conv1(x) + x = F.pad(x,(1,1,0,0),mode = 'circular') + x = self.conv2(x) + return x + +class inconv(nn.Module): + def __init__(self, in_ch, out_ch, dilation, input_batch_norm, circular_padding): + super(inconv, self).__init__() + if input_batch_norm: + if circular_padding: + self.conv = nn.Sequential( + nn.BatchNorm2d(in_ch), + double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation) + ) + else: + self.conv = nn.Sequential( + nn.BatchNorm2d(in_ch), + double_conv(in_ch, out_ch,group_conv = False,dilation = dilation) + ) + else: + if circular_padding: + self.conv = double_conv_circular(in_ch, out_ch,group_conv = False,dilation = dilation) + else: + self.conv = double_conv(in_ch, out_ch,group_conv = False,dilation = dilation) + + def forward(self, x): + x = self.conv(x) + return x + +class down(nn.Module): + def __init__(self, in_ch, out_ch, dilation, group_conv, circular_padding): + super(down, self).__init__() + if circular_padding: + self.mpconv = nn.Sequential( + nn.MaxPool2d(2), + 
double_conv_circular(in_ch, out_ch,group_conv = group_conv,dilation = dilation) + ) + else: + self.mpconv = nn.Sequential( + nn.MaxPool2d(2), + double_conv(in_ch, out_ch, group_conv=group_conv, dilation=dilation) + ) + + def forward(self, x): + x = self.mpconv(x) + return x + +class up(nn.Module): + def __init__(self, in_ch, out_ch, circular_padding, bilinear=True, group_conv=False, use_dropblock=False, drop_p=0.5): + super(up, self).__init__() + + # would be a nice idea if the upsampling could be learned too, + # but my machine do not have enough memory to handle all those weights + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + elif group_conv: + self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2, groups = in_ch//2) + else: + self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) + + if circular_padding: + self.conv = double_conv_circular(in_ch, out_ch,group_conv = group_conv) + else: + self.conv = double_conv(in_ch, out_ch, group_conv = group_conv) + + self.use_dropblock = use_dropblock + if self.use_dropblock: + self.dropblock = DropBlock2D(block_size=7, drop_prob=drop_p) + + def forward(self, x1, x2): + x1 = self.up(x1) + + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, + diffY // 2, diffY - diffY//2)) + + # for padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + if self.use_dropblock: + x = self.dropblock(x) + return x + +class outconv(nn.Module): + def __init__(self, in_ch, out_ch): + super(outconv, self).__init__() + self.conv = nn.Conv2d(in_ch, out_ch, 1) + + def forward(self, x): + x = self.conv(x) + return x diff --git a/networks/completion.py b/networks/completion.py new file mode 100644 index 0000000..a4c1adf --- /dev/null +++ b/networks/completion.py @@ -0,0 +1,126 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch + +from utils.lovasz_losses import lovasz_softmax + + +class ResBlock(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, padding, stride, dilation=1): + super().__init__() + self.reduction = nn.Conv3d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation) + self.layer = nn.Conv3d(out_dim, out_dim, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation) + + def forward(self, x): + add = self.reduction(x) + out = self.layer(F.relu(add)) + out_res = F.relu(add + out) + return out_res + + +def make_layers(in_dim, out_dim, kernel_size=3, padding=1, stride=1, dilation=1,downsample=False, blocks=2): + layers = [] + if downsample: + layers.append(nn.MaxPool3d(2)) + layers.append(ResBlock(in_dim, out_dim, kernel_size, padding, stride, dilation)) + for _ in range(1, blocks): + layers.append(ResBlock(out_dim, out_dim, kernel_size, padding, stride, dilation)) + return nn.Sequential(*layers) + + +class CompletionBranch(nn.Module): + def __init__(self, init_size=32, nbr_class=20, phase='trainval'): + super().__init__() + self.nclass = nbr_class + self.in_layer = nn.Conv3d(1, 16, kernel_size=7, padding=3, stride=2, dilation=1) # 1/2, 16 + self.block_1 = make_layers(16, 16, kernel_size=3, padding=1, stride=1, dilation=1, blocks=1) # 1/2, 16 + self.block_2 = make_layers(16, 32, kernel_size=3, padding=1, 
stride=1, dilation=1, downsample=True, blocks=1) # 1/4, 32 + self.block_3 = make_layers(32, 64, kernel_size=3, padding=2, stride=1, dilation=2, downsample=True, blocks=1) # 1/8, 64 + + self.reduction_1 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(), + nn.Conv2d(128, 64, kernel_size=1), + nn.ReLU() + ) + self.reduction_2 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(), + ) + + self.phase = phase + if phase == 'trainval': + self.out2 = nn.Sequential( + nn.Conv3d(16, 16, kernel_size=1), + nn.ReLU(), + nn.Conv3d(16, 2, kernel_size=1)) + self.out4 = nn.Sequential( + nn.Conv3d(32, 32, kernel_size=1), + nn.ReLU(), + nn.Conv3d(32, 2, kernel_size=1)) + self.out8 = nn.Sequential( + nn.Conv3d(64, 32, kernel_size=1), + nn.ReLU(), + nn.Conv3d(32, 2, kernel_size=1)) + + def forward_once(self, inputs): + out = F.relu(self.in_layer(inputs)) + res1 = self.block_1(out) # B, 16, 16, 128, 128 + res2 = self.block_2(res1) # B, 32, 8, 64, 64 + res3 = self.block_3(res2) # B, 64, 4, 32, 32 + + bev_1 = self.reduction_1(res1.flatten(1, 2)) # B, 64, 128, 128 + bev_2 = self.reduction_2(res2.flatten(1, 2)) # B, 128, 64, 64 + bev_3 = res3.flatten(1, 2) # B, 256, 32, 32 + + if self.phase == 'trainval': + logits_2 = self.out2(res1) + logits_4 = self.out4(res2) + logits_8 = self.out8(res3) + + return dict( + mss_bev_dense = [bev_1, bev_2, bev_3], + mss_logits_list = [logits_2, logits_4, logits_8] + ) + + return dict( + mss_bev_dense = [bev_1, bev_2, bev_3] + ) + + def forward(self, data_dict, example): + if self.phase == 'trainval': + out_dict = self.forward_once(data_dict['vw_dense']) + teacher_2, teacher_4, teacher_8 = out_dict['mss_logits_list'] + teacher_2 = teacher_2.permute(0, 1, 4, 3, 2) + teacher_4 = teacher_4.permute(0, 1, 4, 3, 2) + teacher_8 = teacher_8.permute(0, 1, 4, 3, 2) + + sc_label_1_2_copy = example['label_1_2'].clone() + sc_label_1_2_copy = ((0 < sc_label_1_2_copy) & (sc_label_1_2_copy < self.nclass)).long() + sc_label_1_2_copy[example['invalid_1_2'] == 1] = 255 + scale_loss_1_2 = lovasz_softmax(F.softmax(teacher_2, dim=1), sc_label_1_2_copy, ignore=255) + focal_loss_1_2 = F.cross_entropy(teacher_2, sc_label_1_2_copy, ignore_index=255) + loss = {"1_2_lovasz_loss": scale_loss_1_2,"1_2_ce_loss": focal_loss_1_2} + + sc_label_1_4_copy = example['label_1_4'].clone() + sc_label_1_4_copy = ((0 < sc_label_1_4_copy) & (sc_label_1_4_copy < self.nclass)).long() + sc_label_1_4_copy[example['invalid_1_4'] == 1] = 255 + scale_loss_1_4 = lovasz_softmax(F.softmax(teacher_4, dim=1), sc_label_1_4_copy, ignore=255) + focal_loss_1_4 = F.cross_entropy(teacher_4, sc_label_1_4_copy, ignore_index=255) + loss.update({"1_4_lovasz_loss": scale_loss_1_4,"1_4_ce_loss": focal_loss_1_4}) + + sc_label_1_8_copy = example['label_1_8'].clone() + sc_label_1_8_copy = ((0 < sc_label_1_8_copy) & (sc_label_1_8_copy < self.nclass)).long() + sc_label_1_8_copy[example['invalid_1_8'] == 1] = 255 + scale_loss_1_8 = lovasz_softmax(F.softmax(teacher_8, dim=1), sc_label_1_8_copy, ignore=255) + focal_loss_1_8 = F.cross_entropy(teacher_8, sc_label_1_8_copy, ignore_index=255) + loss.update({"1_8_lovasz_loss": scale_loss_1_8,"1_8_ce_loss": focal_loss_1_8}) + + return dict( + mss_bev_dense=out_dict['mss_bev_dense'], + loss=loss + ) + else: + out_dict = self.forward_once(data_dict['vw_dense']) + return out_dict + diff --git a/networks/dsc.py b/networks/dsc.py new file mode 100644 index 0000000..013f64b --- /dev/null +++ b/networks/dsc.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import 
numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .preprocess import PcPreprocessor +from .bev_net import BEVUNet, BEVUNetv1 +from .completion import CompletionBranch +from .semantic_segmentation import SemanticBranch +from utils.lovasz_losses import lovasz_softmax + +class DSC(nn.Module): + def __init__(self, cfg, phase='trainval'): + super().__init__() + self.phase = phase + nbr_classes = cfg['DATASET']['NCLASS'] + self.nbr_classes = nbr_classes + self.class_frequencies = cfg['DATASET']['SC_CLASS_FREQ'] + ss_req = cfg['DATASET']['SS_CLASS_FREQ'] + + self.lims = cfg['DATASET']['LIMS'] # [[0, 51.2], [-25.6, 25.6], [-2, 4.4]] + self.sizes = cfg['DATASET']['SIZES'] # [256, 256, 32] # W, H, D (x, y, z) + self.grid_meters = cfg['DATASET']['GRID_METERS'] # [0.2, 0.2, 0.2] + self.n_height = self.sizes[-1] # 32 + self.dilation = 1 + self.bilinear = True + self.group_conv = False + self.input_batch_norm = True + self.dropout = 0.5 + self.circular_padding = False + self.dropblock = False + + self.preprocess = PcPreprocessor(lims=self.lims, sizes=self.sizes, grid_meters=self.grid_meters, init_size=self.n_height) + self.sem_branch = SemanticBranch(sizes=self.sizes, nbr_class=nbr_classes-1, init_size=self.n_height, class_frequencies=ss_req, phase=phase) + self.com_branch = CompletionBranch(init_size=self.n_height, nbr_class=nbr_classes, phase=phase) + self.bev_model = BEVUNetv1(self.nbr_classes*self.n_height, self.n_height, self.dilation, self.bilinear, self.group_conv, + self.input_batch_norm, self.dropout, self.circular_padding, self.dropblock) + + def forward(self, example): + batch_size = len(example['points']) + with torch.no_grad(): + indicator = [0] + pc_ibatch = [] + for i in range(batch_size): + pc_i = example['points'][i] + pc_ibatch.append(pc_i) + indicator.append(pc_i.size(0) + indicator[-1]) + pc = torch.cat(pc_ibatch, dim=0) + vw_feature, coord_ind, full_coord, info = self.preprocess(pc, indicator) # N, C; B, C, W, H, D + coord = torch.cat([coord_ind[:, 0].reshape(-1, 1), torch.flip(coord_ind, dims=[1])[:, :3]], dim=1) + bev_dense = self.sem_branch.bev_projection(vw_feature, coord, np.array(self.sizes, np.int32)[::-1], batch_size) # B, C, H, W + torch.cuda.empty_cache() + + ss_data_dict = {} + ss_data_dict['vw_features'] = vw_feature + ss_data_dict['coord_ind'] = coord_ind + ss_data_dict['full_coord'] = full_coord + ss_data_dict['info'] = info + ss_out_dict = self.sem_branch(ss_data_dict, example) # B, C, D, H, W + + sc_data_dict = {} + occupancy = example['occupancy'].permute(0, 3, 2, 1) # B, D, H, W + sc_data_dict['vw_dense'] = occupancy.unsqueeze(1) + sc_out_dict = self.com_branch(sc_data_dict, example) + + inputs = torch.cat([occupancy, bev_dense], dim=1) # B, C, H, W + x = self.bev_model(inputs, ss_out_dict['mss_bev_dense'], sc_out_dict['mss_bev_dense']) + new_shape = [x.shape[0], self.nbr_classes, self.n_height, *x.shape[-2:]] # [B, 20, 32, 256, 256] + x = x.view(new_shape) + out_scale_1_1 = x.permute(0,1,4,3,2) # [B,20,256,256,32] + + if self.phase == 'trainval': + loss_dict = self.compute_loss(out_scale_1_1, self.get_target(example)['1_1'], ss_out_dict['loss'], sc_out_dict['loss']) + return {'pred_semantic_1_1': out_scale_1_1}, loss_dict + + return {'pred_semantic_1_1': out_scale_1_1} + + def compute_loss(self, scores, labels, ss_loss_dict, sc_loss_dict): + ''' + :param: prediction: the predicted tensor, must be [BS, C, H, W, D] + ''' + class_weights = self.get_class_weights().to(device=scores.device, dtype=scores.dtype) + + loss_1_1 = 
F.cross_entropy(scores, labels.long(), weight=class_weights, ignore_index=255) + loss_1_1 += lovasz_softmax(torch.nn.functional.softmax(scores, dim=1), labels.long(), ignore=255) + loss_1_1 *= 3 + + loss_seg = sum(ss_loss_dict.values()) + loss_com = sum(sc_loss_dict.values()) + loss_total = loss_1_1 + loss_seg + loss_com + loss = {'total': loss_total, 'semantic_1_1': loss_1_1, 'semantic_seg': loss_seg, 'scene_completion': loss_com} + + return loss + + def weights_initializer(self, m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight) + nn.init.zeros_(m.bias) + + def weights_init(self): + self.apply(self.weights_initializer) + + def get_parameters(self): + return self.parameters() + + def get_class_weights(self): + ''' + Class weights being 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf) + ''' + epsilon_w = 0.001 # eps to avoid zero division + weights = torch.from_numpy(1 / np.log(np.array(self.class_frequencies) + epsilon_w)) + + return weights + + def get_target(self, data): + ''' + Return the target to use for evaluation of the model + ''' + label_copy = data['label_1_1'].clone() + label_copy[data['invalid_1_1'] == 1] = 255 + return {'1_1': label_copy} + + def get_scales(self): + ''' + Return scales needed to train the model + ''' + scales = ['1_1'] + return scales + + def get_validation_loss_keys(self): + return ['total', 'semantic_1_1', 'semantic_seg', 'scene_completion'] + + def get_train_loss_keys(self): + return ['total', 'semantic_1_1', 'semantic_seg', 'scene_completion'] + diff --git a/networks/preprocess.py b/networks/preprocess.py new file mode 100644 index 0000000..ac67c67 --- /dev/null +++ b/networks/preprocess.py @@ -0,0 +1,139 @@ +import torch +import torch_scatter +import torch.nn as nn +import torch.nn.functional as F + + +def quantitize(data, lim_min, lim_max, size, with_res=False): + idx = (data - lim_min) / (lim_max - lim_min) * size.float() + idxlong = idx.type(torch.cuda.LongTensor) + if with_res: + idx_res = idx - idxlong.float() + return idxlong, idx_res + else: + return idxlong + + +class VFELayerMinus(nn.Module): + def __init__(self, + in_channels, + out_channels, + name='', + last_vfe = False): + super().__init__() + self.name = 'VFELayerMinusSlim' + name + if not last_vfe: + out_channels = out_channels // 2 + self.units = out_channels + + self.linear = nn.Linear(in_channels, self.units, bias=True) + self.weight_linear = nn.Linear(6, self.units, bias=True) + + def forward(self, inputs, bxyz_indx, mean=None, activate=False, gs=None): + x = self.linear(inputs) + if activate: + x = F.relu(x) + if gs is not None: + x = x * gs + if mean is not None: + x_weight = self.weight_linear(mean) + if activate: + x_weight = F.relu(x_weight) + x = x * x_weight + _, value = torch.unique(bxyz_indx, return_inverse=True, dim=0) + max_feature, _ = torch_scatter.scatter_max(x, value, dim=0) + gather_max_feature = max_feature[value, :] + x_concated = torch.cat((x, gather_max_feature), dim=1) + return x_concated + + +class PcPreprocessor(nn.Module): + def __init__(self, lims, sizes, grid_meters, init_size=32, offset=0.5, pooling_scales=[0.5, 1, 2, 4, 6, 8]): + # todo move to cfg + super().__init__() + self.sizes = torch.tensor(sizes).float() + self.lims = lims + self.pooling_scales = pooling_scales + self.grid_meters = grid_meters + self.offset = offset + self.target_scale = 1 + + self.multi_scale_top_layers = nn.ModuleDict() + self.feature_list = { + 0.5: [10, init_size], + 1: [10, init_size], + } + self.target_scale = 1 + for scale in self.feature_list.keys(): + 
top_layer = VFELayerMinus(self.feature_list[scale][0], + self.feature_list[scale][1], + "top_layer_" + str(int(10*scale) if scale == 0.5 else scale)) + self.multi_scale_top_layers[str(int(10*scale) if scale == 0.5 else scale)] = top_layer + + self.aggtopmeanproj = nn.Linear(6, init_size, bias=True) + self.aggtopproj = nn.Linear(2*init_size, init_size, bias=True) + self.aggfusion = nn.Linear(init_size, init_size, bias=True) + + def add_pcmean_and_gridmean(self, pc, bxyz_indx, return_mean=False): + _, value = torch.unique(bxyz_indx, return_inverse=True, dim=0) + pc_mean = torch_scatter.scatter_mean(pc[:, :3], value, dim=0)[value] + pc_mean_minus = pc[:, :3] - pc_mean + + m_pergird = torch.tensor(self.grid_meters, dtype=torch.float, device=pc.device) + xmin_ymin_zmin = torch.tensor([self.lims[0][0], self.lims[1][0], self.lims[2][0]], dtype=torch.float, device=pc.device) + pc_gridmean = (bxyz_indx[:, 1:].type(torch.cuda.FloatTensor) + self.offset) * m_pergird + xmin_ymin_zmin + pc_gridmean_minus = pc[:, :3] - pc_gridmean + + pc_feature = torch.cat((pc, pc_mean_minus, pc_gridmean_minus), dim=1) # same input n, 10 + mean = torch.cat((pc_mean_minus, pc_gridmean_minus), dim=1) # different input_mean n, 6 + if return_mean: + return pc_feature, mean + else: + return pc_feature + + def extract_geometry_features(self, pc, info): + ms_mean_features = {} + ms_pc_features = [] + for scale in self.feature_list.keys(): + bxyz_indx = info[scale]['bxyz_indx'].long() + pc_feature, topview_mean = self.add_pcmean_and_gridmean(pc, bxyz_indx, return_mean=True) + pc_feature = self.multi_scale_top_layers[str(int(10*scale) if scale == 0.5 else scale)]( + pc_feature, bxyz_indx, mean=topview_mean) + ms_mean_features[scale] = topview_mean + ms_pc_features.append(pc_feature) + + agg_tpfeature = F.relu(self.aggtopmeanproj(ms_mean_features[self.target_scale])) \ + * F.relu(self.aggtopproj(torch.cat(ms_pc_features, dim=1))) + agg_tpfeature = self.aggfusion(agg_tpfeature) + + bxyz_indx_tgt = info[self.target_scale]['bxyz_indx'].long() + index, value = torch.unique(bxyz_indx_tgt, return_inverse=True, dim=0) + maxf = torch_scatter.scatter_max(agg_tpfeature, value, dim=0)[0] + + return maxf, index, value + + def forward(self, pc, indicator): + indicator_t = [] + tensor = torch.ones((1,), dtype=torch.long).to(pc) + for i in range(len(indicator) - 1): + indicator_t.append(tensor.new_full((indicator[i+1] - indicator[i],), i)) + indicator_t = torch.cat(indicator_t, dim=0) + info = {'batch': len(indicator)-1} + self.sizes = self.sizes.to(pc) + + for scale in self.pooling_scales: + xidx, xres = quantitize(pc[:, 0], self.lims[0][0], + self.lims[0][1], self.sizes[0] // scale, with_res=True) + yidx, yres = quantitize(pc[:, 1], self.lims[1][0], + self.lims[1][1], self.sizes[1] // scale, with_res=True) + zidx, zres = quantitize(pc[:, 2], self.lims[2][0], + self.lims[2][1], self.sizes[2] // scale, with_res=True) + bxyz_indx = torch.stack([indicator_t, xidx, yidx, zidx], dim=-1) + xyz_res = torch.stack([xres, yres, zres], dim=-1) + info[scale] = {'bxyz_indx': bxyz_indx, 'xyz_res': xyz_res} + + voxel_feature, coord_ind, full_coord = self.extract_geometry_features(pc, info) + + return voxel_feature, coord_ind, full_coord, info + + diff --git a/networks/semantic_segmentation.py b/networks/semantic_segmentation.py new file mode 100644 index 0000000..50856b9 --- /dev/null +++ b/networks/semantic_segmentation.py @@ -0,0 +1,295 @@ +import numpy as np +import torch.nn as nn +import torch +import torch.nn.functional as F +import torch_scatter + 
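+# spconv.pytorch provides the sparse tensors and submanifold 3D convolutions (SubMConv3d)
+# used by this branch; torch_scatter performs the per-voxel max/mean reductions.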
+import spconv.pytorch as spconv +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from utils.lovasz_losses import lovasz_softmax + +class BasicBlock(spconv.SparseModule): + def __init__(self, C_in, C_out, indice_key): + super(BasicBlock, self).__init__() + self.layers_in = spconv.SparseSequential( + spconv.SubMConv3d(C_in, C_out, 1, indice_key=indice_key, bias=False), + nn.BatchNorm1d(C_out), + ) + self.layers = spconv.SparseSequential( + spconv.SubMConv3d(C_in, C_out, 3, indice_key=indice_key, bias=False), + nn.BatchNorm1d(C_out), + nn.LeakyReLU(0.1), + spconv.SubMConv3d(C_out, C_out, 3, indice_key=indice_key, bias=False), + nn.BatchNorm1d(C_out) + ) + self.relu2 = spconv.SparseSequential( + nn.LeakyReLU(0.1) + ) + + def forward(self, x): + identity = self.layers_in(x) + out = self.layers(x) + output = spconv.SparseConvTensor(sum([i.features for i in [identity, out]]), + out.indices, out.spatial_shape, out.batch_size) + output.indice_dict = out.indice_dict + output.grid = out.grid + return self.relu2(output) + + +def make_layers_sp(C_in, C_out, blocks, indice_key): + layers = [] + layers.append(BasicBlock(C_in, C_out, indice_key)) + for _ in range(1, blocks): + layers.append(BasicBlock(C_out, C_out, indice_key)) + return spconv.SparseSequential(*layers) + + +def scatter(x, idx, method, dim=0): + if method == "max": + return torch_scatter.scatter_max(x, idx, dim=dim)[0] + elif method == "mean": + return torch_scatter.scatter_mean(x, idx, dim=dim) + elif method == "sum": + return torch_scatter.scatter_add(x, idx, dim=dim) + else: + print("unknown method") + exit(-1) + + +def gather(x, idx): + """ + :param x: voxelwise features + :param idx: + :return: pointwise features + """ + return x[idx] + + +def voxel_sem_target(point_voxel_coors, sem_label): + """make sparse voxel tensor of semantic labels + Args: + point_voxel_coors(N, bxyz): point-wise voxel coors + sem_label(N, ): point-wise semantic label + Return: + unq_sem(M, ): voxel-wise semantic label + unq_voxel(M, bxyz): voxel-wise voxel coors + """ + voxel_sem = torch.cat([point_voxel_coors, sem_label.reshape(-1, 1)], dim=-1) + unq_voxel_sem, unq_sem_count = torch.unique(voxel_sem, return_counts=True, dim=0) + unq_voxel, unq_ind = torch.unique(unq_voxel_sem[:, :4], return_inverse=True, dim=0) + label_max_ind = torch_scatter.scatter_max(unq_sem_count, unq_ind)[1] + unq_sem = unq_voxel_sem[:, -1][label_max_ind] + return unq_sem, unq_voxel + + +class SFE(spconv.SparseModule): + def __init__(self, in_channels, out_channels, layer_name, layer_num=1): + super().__init__() + self.spconv_layers = make_layers_sp(in_channels, out_channels, layer_num, layer_name) + + def forward(self, inputs): + conv_features = self.spconv_layers(inputs) + return conv_features + + +class SGFE(nn.Module): + def __init__(self, input_channels, output_channels, reduce_channels, name, p_scale=[2, 4, 6, 8]): + super().__init__() + self.inplanes = input_channels + self.input_channels = input_channels + self.output_channels = output_channels + self.name = name + + self.feature_reduce = nn.Linear(input_channels, reduce_channels) + self.pooling_scale = p_scale + self.fc_list = nn.ModuleList() + self.fcs = nn.ModuleList() + for _, _ in enumerate(self.pooling_scale): + self.fc_list.append(nn.Sequential( + nn.Linear(reduce_channels, reduce_channels//2), + nn.ReLU(), + )) + self.fcs.append(nn.Sequential(nn.Linear(reduce_channels//2, reduce_channels//2))) + self.scale_selection = nn.Sequential( + nn.Linear(len(self.pooling_scale) * 
reduce_channels//2, + reduce_channels),nn.ReLU(), + ) + self.fc = nn.Sequential(nn.Linear(reduce_channels//2, reduce_channels//2, bias=False), + nn.ReLU(inplace=False)) + self.out_fc = nn.Linear(reduce_channels//2, reduce_channels, bias=False) + self.linear_output = nn.Sequential( + nn.Linear(2 * reduce_channels, reduce_channels, bias=False), + nn.ReLU(), + nn.Linear(reduce_channels, output_channels), + ) + + def forward(self, coords_info, input_data, output_scale, input_coords=None, input_coords_inv=None): + + reduced_feature = F.relu(self.feature_reduce(input_data)) + output_list = [reduced_feature] + for j, ps in enumerate(self.pooling_scale): + index = torch.cat([input_coords[:, 0].unsqueeze(-1), + (input_coords[:, 1:] // ps).int()], dim=1) + unq, unq_inv = torch.unique(index, return_inverse=True, dim=0) + fkm = scatter(reduced_feature, unq_inv, method="mean", dim=0) + att = self.fc_list[j](fkm)[unq_inv] + out = ( att) + output_list.append(out) + scale_features = torch.stack(output_list[1:], dim=1) + feat_S = scale_features.sum(1) + feat_Z = self.fc(feat_S) + attention_vectors = [fc(feat_Z) for fc in self.fcs] + attention_vectors = torch.sigmoid(torch.stack(attention_vectors, dim=1)) + scale_features = self.out_fc(torch.sum(scale_features * attention_vectors, dim=1)) + + output_f = torch.cat([reduced_feature, scale_features], dim=1) + proj = self.linear_output(output_f) + proj = proj[input_coords_inv] + + index = torch.cat([coords_info[output_scale]['bxyz_indx'][:, 0].unsqueeze(-1), + torch.flip(coords_info[output_scale]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + + unq, unq_inv = torch.unique(index, return_inverse=True, dim=0) + tv_fmap = scatter(proj, unq_inv, method="max", dim=0) + return tv_fmap, unq, unq_inv + + +class SemanticBranch(nn.Module): + def __init__(self, sizes=[256, 256, 32], nbr_class=19, init_size=32, class_frequencies=None, phase='trainval'): + super().__init__() + self.class_frequencies = class_frequencies + self.sizes = sizes + self.nbr_class = nbr_class + self.conv1_block = SFE(init_size, init_size, "svpfe_0") + self.conv2_block = SFE(64, 64, "svpfe_1") + self.conv3_block = SFE(128, 128, "svpfe_2") + + self.proj1_block = SGFE(input_channels=init_size, output_channels=64,\ + reduce_channels=init_size, name="proj1") + self.proj2_block = SGFE(input_channels=64, output_channels=128,\ + reduce_channels=64, name="proj2") + self.proj3_block = SGFE(input_channels=128, output_channels=256,\ + reduce_channels=128, name="proj3") + + self.phase = phase + if phase == 'trainval': + num_class = self.nbr_class # SemanticKITTI: 19 + self.out2 = nn.Sequential( + nn.Linear(64, 64, bias=False), + nn.BatchNorm1d(64, ), + nn.LeakyReLU(0.1), + nn.Linear(64, num_class) + ) + self.out4 = nn.Sequential( + nn.Linear(128, 64, bias=False), + nn.BatchNorm1d(64, ), + nn.LeakyReLU(0.1), + nn.Linear(64, num_class) + ) + self.out8 = nn.Sequential( + nn.Linear(256, 64, bias=False), + nn.BatchNorm1d(64, ), + nn.LeakyReLU(0.1), + nn.Linear(64, num_class) + ) + + + def bev_projection(self, vw_features, vw_coord, sizes, batch_size): + unq, unq_inv = torch.unique( + torch.cat([vw_coord[:, 0].reshape(-1, 1), vw_coord[:, -2:]], dim=-1).int(), return_inverse=True, dim=0) + bev_fea = scatter(vw_features, unq_inv, method='max') + bev_dense = spconv.SparseConvTensor(bev_fea, unq.int(), sizes[-2:], batch_size).dense() # B, C, H, W + + return bev_dense + + def forward_once(self, vw_features, coord_ind, full_coord, pw_label, info): + batch_size = info['batch'] + if pw_label is not None: + pw_label = 
torch.cat(pw_label, dim=0) + + coord = torch.cat([coord_ind[:, 0].reshape(-1, 1), torch.flip(coord_ind, dims=[1])[:, :3]], dim=1) + input_tensor = spconv.SparseConvTensor( + vw_features, coord.int(), np.array(self.sizes, np.int32)[::-1], batch_size + ) + conv1_output = self.conv1_block(input_tensor) + proj1_vw, vw1_coord, pw1_coord = self.proj1_block(info, conv1_output.features, output_scale=2, input_coords=coord.int(), + input_coords_inv=full_coord) + proj1_bev = self.bev_projection(proj1_vw, vw1_coord, (np.array(self.sizes, np.int32) // 2)[::-1], batch_size) + + conv2_input_tensor = spconv.SparseConvTensor( + proj1_vw, vw1_coord.int(), (np.array(self.sizes, np.int32) // 2)[::-1], batch_size + ) + conv2_output = self.conv2_block(conv2_input_tensor) + proj2_vw, vw2_coord, pw2_coord = self.proj2_block(info, conv2_output.features, output_scale=4, input_coords=vw1_coord.int(), + input_coords_inv=pw1_coord) + proj2_bev = self.bev_projection(proj2_vw, vw2_coord, (np.array(self.sizes, np.int32) // 4)[::-1], batch_size) + + conv3_input_tensor = spconv.SparseConvTensor( + proj2_vw, vw2_coord.int(), (np.array(self.sizes, np.int32) // 4)[::-1], batch_size + ) + conv3_output = self.conv3_block(conv3_input_tensor) + proj3_vw, vw3_coord, _ = self.proj3_block(info, conv3_output.features, output_scale=8, input_coords=vw2_coord.int(), + input_coords_inv=pw2_coord) + proj3_bev = self.bev_projection(proj3_vw, vw3_coord, (np.array(self.sizes, np.int32) // 8)[::-1], batch_size) + + + if self.phase == 'trainval': + index_02 = torch.cat([info[2]['bxyz_indx'][:, 0].unsqueeze(-1), + torch.flip(info[2]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + index_04 = torch.cat([info[4]['bxyz_indx'][:, 0].unsqueeze(-1), + torch.flip(info[4]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + index_08 = torch.cat([info[8]['bxyz_indx'][:, 0].unsqueeze(-1), + torch.flip(info[8]['bxyz_indx'], dims=[1])[:, :3]], dim=1) + vw_label_02 = voxel_sem_target(index_02.int(), pw_label.int())[0] + vw_label_04 = voxel_sem_target(index_04.int(), pw_label.int())[0] + vw_label_08 = voxel_sem_target(index_08.int(), pw_label.int())[0] + return dict( + mss_bev_dense = [proj1_bev, proj2_bev, proj3_bev], + mss_logits_list = [ + [vw_label_02.clone(), self.out2(proj1_vw)], + [vw_label_04.clone(), self.out4(proj2_vw)], + [vw_label_08.clone(), self.out8(proj3_vw)]] + ) + + return dict( + mss_bev_dense = [proj1_bev, proj2_bev, proj3_bev] + ) + + def forward(self, data_dict, example): + if self.phase == 'trainval': + out_dict = self.forward_once(data_dict['vw_features'], + data_dict['coord_ind'], data_dict['full_coord'], example['points_label'], data_dict['info']) + all_teach_pair = out_dict['mss_logits_list'] + + class_weights = self.get_class_weights().to(device=data_dict['vw_features'].device, dtype=data_dict['vw_features'].dtype) + loss_dict = {} + for i in range(len(all_teach_pair)): + teach_pair = all_teach_pair[i] + voxel_labels_copy = teach_pair[0].long().clone() + voxel_labels_copy[voxel_labels_copy == 0] = 256 + voxel_labels_copy = voxel_labels_copy - 1 + + res04_loss = lovasz_softmax(F.softmax(teach_pair[1], dim=1), voxel_labels_copy, ignore=255) + res04_loss2 = F.cross_entropy(teach_pair[1], voxel_labels_copy, weight=class_weights, ignore_index=255) + loss_dict["vw_" + str(i) + "lovasz_loss"] = res04_loss + loss_dict["vw_" + str(i) + "ce_loss"] = res04_loss2 + return dict( + mss_bev_dense=out_dict['mss_bev_dense'], + loss=loss_dict + ) + else: + out_dict = self.forward_once(data_dict['vw_features'], + data_dict['coord_ind'], data_dict['full_coord'], None, 
data_dict['info']) + return out_dict + + def get_class_weights(self): + ''' + Class weights being 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf) + ''' + epsilon_w = 0.001 # eps to avoid zero division + weights = torch.from_numpy(1 / np.log(np.array(self.class_frequencies) + epsilon_w)) + + return weights diff --git a/scripts/run_train.sh b/scripts/run_train.sh new file mode 100644 index 0000000..4f6426c --- /dev/null +++ b/scripts/run_train.sh @@ -0,0 +1 @@ + python train.py --cfg /home/jmwang/OccRWKV/cfgs/2024.6.11.yaml --dset_root /home/jmwang/datasets/semantic_kitti/dataset/sequences diff --git a/test.py b/test.py new file mode 100644 index 0000000..6c9d916 --- /dev/null +++ b/test.py @@ -0,0 +1,139 @@ +import os +import argparse +import torch +import torch.nn as nn +import sys +import numpy as np +import time + +# Append root directory to system path for imports +repo_path, _ = os.path.split(os.path.realpath(__file__)) +repo_path, _ = os.path.split(repo_path) +sys.path.append(repo_path) + +from utils.seed import seed_all +from utils.config import CFG +from utils.dataset import get_dataset +from utils.model import get_model +from utils.logger import get_logger +from utils.io_tools import dict_to, _create_directory +import utils.checkpoint as checkpoint + + +def parse_args(): + parser = argparse.ArgumentParser(description='DSC validating') + parser.add_argument( + '--weights', + dest='weights_file', + default='', + metavar='FILE', + help='path to folder where model.pth file is', + type=str, + ) + parser.add_argument( + '--dset_root', + dest='dataset_root', + default=None, + metavar='DATASET', + help='path to dataset root folder', + type=str, + ) + parser.add_argument( + '--out_path', + dest='output_path', + default='', + metavar='OUT_PATH', + help='path to folder where predictions will be saved', + type=str, + ) + args = parser.parse_args() + return args + + +def test(model, dset, _cfg, logger, out_path_root): + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + # Moving optimizer and model to used device + model = model.to(device=device) + logger.info('=> Passing the network on the test set...') + model.eval() + inv_remap_lut = dset.dataset.get_inv_remap_lut() + time_list = [] + + with torch.no_grad(): + + for t, (data, indices) in enumerate(dset): + + data = dict_to(data, device) + # torch.cuda.synchronize() + start_time = time.time() + scores = model(data) # [b,20,256,256,32] + # torch.cuda.synchronize() + time_list.append(time.time() - start_time) + for key in scores: + scores[key] = torch.argmax(scores[key], dim=1).data.cpu().numpy() + + curr_index = 0 + for score in scores['pred_semantic_1_1']: + score = score.reshape(-1).astype(np.uint16) + score = inv_remap_lut[score].astype(np.uint16) + input_filename = dset.dataset.filepaths['occupancy'][indices[curr_index]] + filename, extension = os.path.splitext(os.path.basename(input_filename)) + sequence = os.path.dirname(input_filename).split('/')[-2] + out_filename = os.path.join(out_path_root, 'sequences', sequence, 'predictions', filename + '.label') + _create_directory(os.path.dirname(out_filename)) + score.tofile(out_filename) + logger.info('=> Sequence {} - File {} saved'.format(sequence, os.path.basename(out_filename))) + curr_index += 1 + + return time_list + + +def main(): + + # https://github.com/pytorch/pytorch/issues/27588 + torch.backends.cudnn.enabled = False + + seed_all(0) + + args = parse_args() + + weights_f = args.weights_file + dataset_f = args.dataset_root + out_path_root = 
args.output_path
+
+    assert os.path.isfile(weights_f), '=> No file found at {}'.format(weights_f)
+
+    checkpoint_dict = torch.load(weights_f)
+    config_dict = checkpoint_dict.pop('config_dict')
+    config_dict['DATASET']['DATA_ROOT'] = dataset_f
+
+    # Read train configuration file
+    _cfg = CFG()
+    _cfg.from_dict(config_dict)
+    # Setting the logger to print statements and also save them into logs file
+    logger = get_logger(out_path_root, 'logs_val.log')
+
+    logger.info('============ Test weights: "%s" ============\n' % weights_f)
+    dataset = get_dataset(_cfg._dict)['test']
+
+    logger.info('=> Loading network architecture...')
+    model = get_model(_cfg._dict, phase='test')
+    if torch.cuda.device_count() > 1:
+        model = nn.DataParallel(model)
+        model = model.module
+
+    logger.info('=> Loading network weights...')
+    model = checkpoint.load_model(model, weights_f, logger)
+
+    time_list = test(model, dataset, _cfg, logger, out_path_root)
+
+    logger.info('=> ============ Network Test Done ============')
+
+    logger.info('Inference time per frame is %.4f seconds\n' % (np.sum(time_list) / len(dataset.dataset)))
+
+    exit()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..b012155
--- /dev/null
+++ b/train.py
@@ -0,0 +1,332 @@
+# -*- coding:utf-8 -*-
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+import argparse
+import torch
+import torch.nn as nn
+from tensorboardX import SummaryWriter
+import sys
+import numpy as np
+
+# Append root directory to system path for imports
+repo_path, _ = os.path.split(os.path.realpath(__file__))
+repo_path, _ = os.path.split(repo_path)
+sys.path.append(repo_path)
+
+from utils.seed import seed_all
+from utils.config import CFG
+from utils.dataset import get_dataset
+from utils.model import get_model
+from utils.logger import get_logger
+from utils.optimizer import build_optimizer, build_scheduler
+from utils.io_tools import dict_to
+from utils.metrics import Metrics
+import utils.checkpoint as checkpoint
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='DSC training')
+    parser.add_argument(
+        '--cfg',
+        dest='config_file',
+        default='cfgs/SSC-RS.yaml',
+        metavar='FILE',
+        help='path to config file',
+        type=str,
+    )
+    parser.add_argument(
+        '--dset_root',
+        dest='dataset_root',
+        default=None,
+        metavar='DATASET',
+        help='path to dataset root folder',
+        type=str,
+    )
+    args = parser.parse_args()
+    return args
+
+def fast_hist_crop(output, target, unique_label):
+    hist = fast_hist(output.flatten(), target.flatten(),
+                     np.max(unique_label) + 1)  # 19*19
+    hist = hist[unique_label, :]
+    hist = hist[:, unique_label]
+    return hist
+
+def fast_hist(pred, label, n):  # n==19
+    k = (label >= 0) & (label < n)
+    bin_count = np.bincount(
+        n * label[k].astype(int) + pred[k],
+        minlength=n**2)
+    return bin_count[:n**2].reshape(n, n)
+
+def per_class_iu(hist):
+    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
+
+def train(model, optimizer, scheduler, dataset, _cfg, start_epoch, logger, tbwriter):
+    """
+    Train a model using the PyTorch Module API.
+    Inputs:
+    - model: A PyTorch Module giving the model to train.
+    - optimizer: An Optimizer object we will use to train the model
+    - scheduler: Scheduler for learning rate decay if used
+    - dataset: The dataset to load files
+    - _cfg: The configuration dictionary read from the config file
+    - start_epoch: The epoch at which to start training (from checkpoint)
+    - logger: The logger to save info
+    - tbwriter: The tensorboard writer to save plots
+    Returns: Nothing, but prints model accuracies during training.
+    """
+
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    # Moving optimizer and model to used device
+    model = model.to(device=device)
+    for state in optimizer.state.values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.to(device)
+
+    dset = dataset['train']
+
+    nbr_epochs = _cfg._dict['TRAIN']['EPOCHS']
+    nbr_iterations = len(dset)  # number of iterations depends on batch size
+
+    # Defining metrics class and initializing them...
+    metrics = Metrics(_cfg._dict['DATASET']['NCLASS'], nbr_iterations, model.get_scales())
+    metrics.reset_evaluator()
+    metrics.losses_track.set_validation_losses(model.get_validation_loss_keys())
+    metrics.losses_track.set_train_losses(model.get_train_loss_keys())
+
+    for epoch in range(start_epoch, nbr_epochs + 1):
+
+        logger.info('=> =============== Epoch [{}/{}] ==============='.format(epoch, nbr_epochs))
+        logger.info('=> Reminder - Output of routine on {}'.format(_cfg._dict['OUTPUT']['OUTPUT_PATH']))
+
+        # Print learning rate
+        logger.info('=> Learning rate: {}'.format(scheduler.get_last_lr()[0]))
+
+        model.train()  # put model to training mode
+
+        for t, (data, indices) in enumerate(dset):
+
+            data = dict_to(data, device)
+            scores, loss = model(data)  # [b,20,256,256,32]
+
+            # Zero out the gradients.
+            optimizer.zero_grad()
+            # Backward pass: gradient of loss w.r.t. each model parameter.
+            loss['total'].backward()
+            # update parameters of model by gradients.
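+            # optimizer.step() applies the update from the accumulated gradients; when
+            # SCHEDULER.FREQUENCY is 'iteration' in the config, the LR scheduler also steps once per batch (below).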
+            optimizer.step()
+
+            if _cfg._dict['SCHEDULER']['FREQUENCY'] == 'iteration':
+                scheduler.step()
+
+            for l_key in loss:
+                tbwriter.add_scalar('train_loss_batch/{}'.format(l_key), loss[l_key].item(), len(dset) * (epoch - 1) + t)
+            # Updating batch losses to then get mean for epoch loss
+            metrics.losses_track.update_train_losses(loss)
+
+            if (t + 1) % _cfg._dict['TRAIN']['SUMMARY_PERIOD'] == 0:
+                loss_print = '=> Epoch [{}/{}], Iteration [{}/{}], Learn Rate: {}, Train Losses: '\
+                    .format(epoch, nbr_epochs, t+1, len(dset), scheduler.get_last_lr()[0])
+                for key in loss.keys():
+                    loss_print += '{} = {:.6f}, '.format(key, loss[key])
+
+                logger.info(loss_print[:-3])
+
+            metrics.add_batch(prediction=scores, target=model.get_target(data))
+
+        for l_key in metrics.losses_track.train_losses:
+            tbwriter.add_scalar('train_loss_epoch/{}'.format(l_key), metrics.losses_track.train_losses[l_key].item() / metrics.losses_track.train_iteration_counts, epoch - 1)
+        tbwriter.add_scalar('lr/lr', scheduler.get_last_lr()[0], epoch - 1)
+
+        epoch_loss = metrics.losses_track.train_losses['total'] / metrics.losses_track.train_iteration_counts
+
+        for scale in metrics.evaluator.keys():
+            tbwriter.add_scalar('train_performance/{}/mIoU'.format(scale), metrics.get_semantics_mIoU(scale).item(), epoch - 1)
+            tbwriter.add_scalar('train_performance/{}/IoU'.format(scale), metrics.get_occupancy_IoU(scale).item(), epoch - 1)
+            # segmentation mIoU is not computed in this loop, so a 0 placeholder is logged (matching the console summary below)
+            tbwriter.add_scalar('train_performance/{}/Seg_mIoU'.format(scale), 0, epoch - 1)
+            tbwriter.add_scalar('train_performance/{}/Precision'.format(scale), metrics.get_occupancy_Precision(scale).item(), epoch-1)
+            tbwriter.add_scalar('train_performance/{}/Recall'.format(scale), metrics.get_occupancy_Recall(scale).item(), epoch-1)
+            tbwriter.add_scalar('train_performance/{}/F1'.format(scale), metrics.get_occupancy_F1(scale).item(), epoch-1)
+
+        logger.info('=> [Epoch {} - Total Train Loss = {}]'.format(epoch, epoch_loss))
+        for scale in metrics.evaluator.keys():
+            loss_scale = metrics.losses_track.train_losses['semantic_{}'.format(scale)].item() / metrics.losses_track.train_iteration_counts
+            logger.info('=> [Epoch {} - Scale {}: Loss = {:.6f} - mIoU = {:.6f} - IoU = {:.6f} - Seg_mIoU = {:.6f}'
+                        ' - P = {:.6f} - R = {:.6f} - F1 = {:.6f}]'.format(epoch, scale, loss_scale,
+                                                                           metrics.get_semantics_mIoU(scale).item(),
+                                                                           metrics.get_occupancy_IoU(scale).item(),
+                                                                           0,
+                                                                           metrics.get_occupancy_Precision(scale).item(),
+                                                                           metrics.get_occupancy_Recall(scale).item(),
+                                                                           metrics.get_occupancy_F1(scale).item(),
+                                                                           ))
+
+        logger.info('=> Epoch {} - Training set class-wise IoU:'.format(epoch))
+        for i in range(1, metrics.nbr_classes):
+            class_name = dset.dataset.get_xentropy_class_string(i)
+            class_score = metrics.evaluator['1_1'].getIoU()[1][i]
+            logger.info('   => IoU {}: {:.6f}'.format(class_name, class_score))
+
+        # Reset evaluator for validation...
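+        # the same Metrics object tracks both phases, so the evaluator is cleared before validating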
+        metrics.reset_evaluator()
+
+        checkpoint_info = validate(model, dataset['val'], _cfg, epoch, logger, tbwriter, metrics)
+        # Save checkpoints
+        for k in checkpoint_info.keys():
+            checkpoint_path = os.path.join(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'chkpt', k)
+            _cfg._dict['STATUS'][checkpoint_info[k]] = checkpoint_path
+            checkpoint.save(checkpoint_path, model, optimizer, scheduler, epoch, _cfg._dict)
+
+        # Save checkpoint if current epoch matches checkpoint period
+        if epoch % _cfg._dict['TRAIN']['CHECKPOINT_PERIOD'] == 0:
+            checkpoint_path = os.path.join(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'chkpt', str(epoch).zfill(2))
+            checkpoint.save(checkpoint_path, model, optimizer, scheduler, epoch, _cfg._dict)
+
+        # Reset evaluator and losses for next epoch...
+        metrics.reset_evaluator()
+        metrics.losses_track.restart_train_losses()
+        metrics.losses_track.restart_validation_losses()
+
+        if _cfg._dict['SCHEDULER']['FREQUENCY'] == 'epoch':
+            scheduler.step()
+
+        # Update config file
+        _cfg.update_config(resume=True)
+
+    return metrics.best_metric_record
+
+def validate(model, dset, _cfg, epoch, logger, tbwriter, metrics):
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    nbr_epochs = _cfg._dict['TRAIN']['EPOCHS']
+    logger.info('=> Passing the network on the validation set...')
+
+    model.eval()
+
+    with torch.no_grad():
+
+        for t, (data, indices) in enumerate(dset):
+
+            data = dict_to(data, device)
+            scores, loss = model(data)
+
+            for l_key in loss:
+                tbwriter.add_scalar('validation_loss_batch/{}'.format(l_key), loss[l_key].item(), len(dset) * (epoch - 1) + t)
+            # Updating batch losses to then get mean for epoch loss
+            metrics.losses_track.update_validaiton_losses(loss)
+
+            if (t + 1) % _cfg._dict['VAL']['SUMMARY_PERIOD'] == 0:
+                loss_print = '=> Epoch [{}/{}], Iteration [{}/{}], Validation Losses: '.format(epoch, nbr_epochs, t + 1, len(dset))
+                for key in loss.keys():
+                    loss_print += '{} = {:.6f}, '.format(key, loss[key])
+                logger.info(loss_print[:-3])
+
+            metrics.add_batch(prediction=scores, target=model.get_target(data))
+
+    for l_key in metrics.losses_track.validation_losses:
+        tbwriter.add_scalar('validation_loss_epoch/{}'.format(l_key), metrics.losses_track.validation_losses[l_key].item() / metrics.losses_track.validation_iteration_counts, epoch - 1)
+
+    epoch_loss = metrics.losses_track.validation_losses['total'] / metrics.losses_track.validation_iteration_counts
+
+    for scale in metrics.evaluator.keys():
+        tbwriter.add_scalar('validation_performance/{}/mIoU'.format(scale), metrics.get_semantics_mIoU(scale).item(), epoch - 1)
+        tbwriter.add_scalar('validation_performance/{}/IoU'.format(scale), metrics.get_occupancy_IoU(scale).item(), epoch - 1)
+
+    logger.info('=> [Epoch {} - Total Validation Loss = {}]'.format(epoch, epoch_loss))
+    for scale in metrics.evaluator.keys():
+        loss_scale = metrics.losses_track.validation_losses['semantic_{}'.format(scale)].item() / metrics.losses_track.validation_iteration_counts
+        logger.info('=> [Epoch {} - Scale {}: Loss = {:.6f} - mIoU = {:.6f} - IoU = {:.6f} - Seg_mIoU = {:.6f}'
+                    ' - P = {:.6f} - R = {:.6f} - F1 = {:.6f}]'.format(epoch, scale, loss_scale,
+                                                                       metrics.get_semantics_mIoU(scale).item(),
+                                                                       metrics.get_occupancy_IoU(scale).item(),
+                                                                       0,
+                                                                       metrics.get_occupancy_Precision(scale).item(),
+                                                                       metrics.get_occupancy_Recall(scale).item(),
+                                                                       metrics.get_occupancy_F1(scale).item(),
+                                                                       ))
+
+    logger.info('=> Epoch {} - Validation set class-wise IoU:'.format(epoch))
+    for i in range(1, metrics.nbr_classes):
+        class_name = 
dset.dataset.get_xentropy_class_string(i) + class_score = metrics.evaluator['1_1'].getIoU()[1][i] + logger.info(' => {}: {:.6f}'.format(class_name, class_score)) + + checkpoint_info = {} + + if epoch_loss < _cfg._dict['OUTPUT']['BEST_LOSS']: + logger.info('=> Best loss on validation set encountered: ({} < {})'.format(epoch_loss, _cfg._dict['OUTPUT']['BEST_LOSS'])) + _cfg._dict['OUTPUT']['BEST_LOSS'] = epoch_loss.item() + checkpoint_info['best-loss'] = 'BEST_LOSS' + + mIoU_1_1 = metrics.get_semantics_mIoU('1_1') + IoU_1_1 = metrics.get_occupancy_IoU('1_1') + if mIoU_1_1 > _cfg._dict['OUTPUT']['BEST_METRIC']: + logger.info('=> Best metric on validation set encountered: ({} > {})'.format(mIoU_1_1, _cfg._dict['OUTPUT']['BEST_METRIC'])) + _cfg._dict['OUTPUT']['BEST_METRIC'] = mIoU_1_1.item() + checkpoint_info['best-metric'] = 'BEST_METRIC' + metrics.update_best_metric_record(mIoU_1_1, IoU_1_1, epoch_loss.item(), epoch) + + checkpoint_info['last'] = 'LAST' + + return checkpoint_info + +def main(): + + # https://github.com/pytorch/pytorch/issues/27588 + torch.backends.cudnn.enabled = False + + seed_all(10) + + args = parse_args() + + train_f = args.config_file + dataset_f = args.dataset_root + + # Read train configuration file + _cfg = CFG() + _cfg.from_config_yaml(train_f) + + # Replace dataset path in config file by the one passed by argument + if dataset_f is not None: + _cfg._dict['DATASET']['DATA_ROOT'] = dataset_f + + # Create writer for Tensorboard + tbwriter = SummaryWriter(logdir=os.path.join(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'metrics')) + + # Setting the logger to print statements and also save them into logs file + logger = get_logger(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'logs_train.log') + + logger.info('============ Training routine: "%s" ============\n' % train_f) + dataset = get_dataset(_cfg._dict) + + logger.info('=> Loading network architecture...') + model = get_model(_cfg._dict, phase='trainval') + if torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + model = model.module + logger.info(f'=> Model Parameters: {sum(p.numel() for p in model.parameters())/1000000.0} M') + + logger.info('=> Loading optimizer...') + optimizer = build_optimizer(_cfg, model) + scheduler = build_scheduler(_cfg, optimizer) + + model, optimizer, scheduler, epoch = checkpoint.load(model, optimizer, scheduler, _cfg._dict['STATUS']['RESUME'], _cfg._dict['STATUS']['LAST'], logger) + + best_record = train(model, optimizer, scheduler, dataset, _cfg, epoch, logger, tbwriter) + + logger.info('=> ============ Network trained - all epochs passed... 
============')
+
+    logger.info('=> [Best performance: Epoch {} - mIoU = {} - IoU {}]'.format(best_record['epoch'], best_record['mIoU'], best_record['IoU']))
+
+    logger.info('=> Writing config file in output folder - deleting from config files folder')
+    _cfg.finish_config()
+    logger.info('=> Training routine completed...')
+
+    exit()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
zU^$ly1uKP$Rj?JRennWLT8KRT61I4zJpkme_-7V7RmJrLzouKnItchI@iV-2v}V;!*4V*{|tP%E`EHCv@s zsRi4^GYEhf{u z7g}plbBuu1TN_ySKYOoZ4SN3uI~b$Zg!ygHxUug?emIG^8ohZn!B?XTqA1{&xFw7E zJi^Z2e0C7LAP2#8g4Jno6k;t5@Fe6KtK+PUJQ=(3W>usmi&<8cc+jjj%VnODtfc8D zu6>O2j2>{6r48d+wD5x&wx9da-b*49apl;bVp~+4&i(z#0z0$UN91I_IGij(Pv!Qz zm?dk{WM5M5l$(N_RrP6FH{9q8)$p^81y#>X; zqxu0HzaNj*RgsR~7iV#~8qJeradLP(+ORBdMt#QJ51+ZSN*U>G9paitpoWs$?`=tJ z^`9komRvssrrBzQ9C`L7tttU?{(Fq8O`@t4zi+vC0Xp66v-E>Xpp; z%Gfw?Q7_BYC>g3wL>s%s*|}rAcc8a1*S`*0gJP&!=fjuDsNeb_O`ksd{>RqDX1Qr) zMb{QZZjQz?SM4j?#dsAu$uz1+IePdS`5d>d*qV*@5B6|%u%RLgKImaFwCqP8Yk>*jGU3J zt6uGJf($i5919nAX<6@5sNjW~xp11~ldA;o;i}IT}`+xr{ z*eV+Pg8dE$7S7oUF<&$Nna7t(7};1eLYSaI*6K@+tHRrcgz%bEDPQc$08tk617D+0 z)v%bVrp&ZTZJ}pP)Dvy8smf>G;8rFC7S+coifLfy$q*EAg9eXw%rzZQ(@ihor}8aN zPh2)pJopWmA9l4*1JzCNyNIC6#0#WhoY(Vsf0Uh#T>P7J*Tby1^}9O5vpYRcV5aUrsw4*av=E6|_Sn{7D1q0mL`VIe2EIwl0l Q_XQ$^DGH*ks9xXx7oguwga7~l literal 0 HcmV?d00001 diff --git a/utils/__pycache__/logger.cpython-310.pyc b/utils/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d665ba0fc8f5940ec93f36aef34a37a8d1b28a23 GIT binary patch literal 747 zcmY*X%We}f6!l{s&U8|!#HNKs7cDeMZGHfuN+=N(D3J)F5Ef`Ob|y2OM`gRDjV1^o zet{yrKOpu$92J zvFsAosKyJjz%!?IZuy&#dTH zkE=qovtkmJNxL7%uTEdQX7)TcJAu`=^6frqAAIet@f`pmPCoR*Xw8nuPj+wYb`6vQnurh5~)eeZj(x-&5m82Ekio0a94PZ`EP)5+-1$H{pd=@DTV(x@3yNV6r{ zX3cC{HA~RB)v{|g{@X35?bh5uo7eVDLm2H+Ef}Z|d6bL3Y3lpfNW^VJdeXmZ)b`0!vMeiiMQujT%Skzf+J1RjPRo6$&B__MAETa< zXXLDW3N3R#G31;)aMzFrVtdtJJ+OLk%c>nh|3P^O{ZGpUc~~An?XY}RJ|mw+?Ff30 z$`G|@PHyb!B)l!ULI zS$H$NwD|UoQ`Z--gztS2UX8mOaS|SX|H_TIn=O*POlz>_+9L@*SlGqhAS$P%{a-z&T^RDj6=P^VSFdUV&GlF z+w*EZ3)7CuVj0FtbyUhTEXHy8c57umT%hG%kJe+w?TTLbAr)Wa!0($MKYWB=)h?`7 zoD^2tZ5PgRs}p5~+vv7&T6K%kN>q?z8B2MSmH;eM1h2D7vSEth` ze8(3vVop?qFG67p^%NQvS>=s}gL3#QqeAB_yDw46?X;8|%OV%c!ZiCvU#vhHB+*=K zneP~%&3|h4-M*K5QmmP|OXYbjXMeOU_i}5^I&WOieWkpVF63rENc_yq{VjpAv?UH3 z{UCR?tvuKf!nie^*|`JF2_W5I)_`QK-W6W|E1eI*<&L5wUQwMcETqwjQZ2NlW)i8* zvyjzZie*~3 zt!OE3rG<0p{p&YAF5Koyf(Ix{P>XuA9j(NL-RyMJ!lvPc!=u$cnw2C`SmSC7K z0<_ij6F-OZ^dJh`nS_jH#H=`i|E`!eNp8MSFQD^sgA8o=n?Qx$g9Y~YsN~KQ1m=zi z>}?3l>-!L6DHFN5B_PBFEPIgKYnFO1_wxXP^r(a&VW$1^XdCJn$AYB3N?u9N=Ou{D zVv#*fb%?DDk=l8ML^^19F{Y%&3d1$;5-hl18t)J^xYEj;Z8vv!1dG}3y~ZMXTd8Q1 zjCLsRIBMJsk0-ELz4Wp=BKI*AGO-qfLW?+3TIm(&>}PIY;I?8*{KmXxE^70t3Uvtm zz(RBWjCvl`Z^`IftC?mEpvhZ5q3LMuS)+FVqMJY0>NJ3eZ=o_AqvmKjPaVL&!cnl< zSgRJvR_m9UtQSs1Z&#gAPgC<@8Z0iVBXsf%70*&}l!}lF(hh#QLNxbu>aj=V0YAXV zl&-ODtw%Q7z1F6Rm^00vSD;;!f=4%Km4-jEeL5^0>6<8WQ_W>y9*dUZG{8|vt!Hvsh=bu!ep z2A9C7k*RKWGi*ocS{fcd69Q5ehEwbs_1Ta_4@bhjgMBahLN<7~g*1yJNlHqMOHl@Y zOwK19sOZ^nqthI~Q`nT;H>Rny)?~u3(VCw@@e9(_{V-F}H!}k(F1AexGs;c5^Yz^uJ zO?eUpAlKD`XU*5jPHJD+w7^=eX)*KzR7P)d4bAB!3fpH~d}Q0xkKof@M^l4rbNDNv zvg^%qoM_u8AEhdqjYyLK2TTKNyaS__HrR$K?c9J)0~}|GeE`r_og8dMI?}ywK{7TN ziL>ewFUg%lM&Hf6+=V@OnZ502rEP%ej=&7wme?`-e(pm`Uhf}S`2SOLxc*&+0RPft z&hY)L0dszr9ccIVxFKk$JsWawm~{q<%`d~6_w{hWR|b{ynpyx-E!k`rfs9koP6_T> zEh~x?_FSzZ_$i!>cx+)4J1T4$Wf3zJT;9nFak;Qn-0D_K>Q%ayq0&k5Sn$Z2tQ8ab z|BznUu*J3_9Vu)gLW+#*ajYa|{39If9x7hNoEFpYxi0(KIbumPqCxgK{E_RW^NTpr zcThl!5WDP%KJaf)gPlsV3zturxC}1?9WmD|uv{azper`c>A#bM-B3SR&1bxrm$jv) zA<+jdUZ(j(Ss7R$Y>lOjh&Vif%+KO>X9E!d<6Riv>P9UR(J?KwV_HaL2*l^@jb?ls zMAX$FpiPgh&e@1sObsJDS*>^EG$lu#x% zY69_}x#v-8mb9deNXY&I4!g$uy_i_cDD1d!@zPbG!5Cb^tFQ=Y+LUM&L!_|Zd4viGzLojOfjc zJ;?lUnlYsyh#@UzrFa7sQiUC(Oo5G9Wyg6RAf5wi4ieN@^rQz6{!mHoa@*b^a%*2> zYr>;h?Pj9LClE7@mOwc1q-}%@gu(M@*Nnwc@1b5e%>>@5aA7nnaaQw^c%_buo}}}I z2TFV!f#h9Wn36UJ;zh`=_tf}%4n{#$uyRS%j_dUzsMp&a*=Py=^kD*^#MJ~ 
z#9&PuajP|4S9T*;(aTWI^7zH@?!Hk0yXCgAC8g!uMegG`(F7KVOz4X!BokY(Jwy$H zmJwXU#9~wYEPDUMIzK?!i;6;Jgnp{CzR7Vj_x|>>jv?57n>4_Yvh|TdGioxGMC&dn zYPbCas4FTOH|lrDx&GO3^{2gALb)2w3lEr+Vx7jqK~fMB1KZkpRB}u0Cp!2AfbEHd z5N>gV5Q;4r_dTO2z7*szNynpahP3egZYyh^ViO#2!`+MrJ>@`7>#DO94cOtuZxcj) zIS*QDwwhTJ(f%%5+(pqnCwYd+;zD(PFg-yi{jN|LjwPV@KT3xr@i?4J8|{$ zWg=`NJxvKNneWbh2j3+%LUciDG+@2#H5n}nZ}3(G`TGYnG1ZwfO@Jqd1X!z*LVgl1 zS94wj>YXKZ8Z8e@^(p>6#(dQR+Wr#92Mw_cNNE=mz=U!J`n9 z(YNJbkHfw_^ZjG$W&S-me!s8Kp$Tv>Nsbo;c#a70657XLl(GZl``I)78!ALsc0gD_ zfP>u7MQXPyK-{w0D(q+-3#fgQ_8-1Ys$AGYMKZpn}~+7h{s zWMB#RC}qKFIWGZgd|6`Z9Uvc3Uag??@-pVAz<*9;!S-YxWPUb9IR_%$oM||kP3QI& zg-AiPFvJ})_wR{4*H4BQN7*>59Lx>S^u17L=zE62=?ESF?L-DPF+~XlrXvqP=z%}# z6KbHMS6$QyIRwacAa@hYE$k$%O%0a|8h5VUK%(V`xTdi4KPO9dxYend6SX>uxOLPgZDXI@pDQou<` z(xSrRt7n}$WobpDW-hN3fo=qZO7&y9_&gPFqo|dZR|av#<7}90`mb=LL{Tc9ZwBzE z2Vs&G#1SU{r@ZSlA`2Z?96_6+tH#Wf-k1=P&GZ{oM$B}M%yY;16!8N`ok+>Sz)XrA zcI>{993G4%lUo>!bubn=*}-`5Bf2JcVKBIA)}ZSKSJ}8~ad4H*S5ahZY}*v+l7ru| ziC-b_fcO@XEKyt30Tv=bUYPnu+W-hY06{($`M}6?pCDq}-7#RYC4h*J|3O~L1CD@7 zy%TI3|7`@sqaKnAVTF81?BpS_Mb^}Z6#Xk5?RbiOi(00F78VljBw2H%O9@(!;zML~ z3J3kNb`Ae7v!$Uu-O42 z6yJ=IVz`Md!J^bj;dc4_RHDvzU!o3=Zh>}JhhTGQzW zwMq9NRH+5XPRI3y`pr&Oo54M(QpYYAthr5QOFK(Nnf8T|6+-Z{`NWXFGT2SzJ;Hy* zCA`NpU7odUpil+-0Pt7lB>odVsn0Ru>%aJlet-MpvArfjA#zg$P3f;dr^bw$HUV8Q zNlJ$C?h?5vFv&yAB(cTEOwyvACNRmber?<{au5A~lY94!H4}W!&C5syS)ZHC1S?3T zI-DkgGrEgiDcUK*+!J|4kGf~@xXH|2^@eTv1hoMk-TQD;0F|lCMMBnHE#;H>6pY=L zK6*+747r2s%k*ylR6a@lWkUw{;a(X!_T}X%W7`L}_24M~M0)+1WS*XL3VEE}C*7a# zf5Hd^=Ghq766u}@i%k8$oCOk9mKOD)&EQe9t`NGC;2<+Aw`w`WxX*piU=dY>+TLFm^zau zTuJo7!dEDq(O0zff*nhSL1tDfbibFWI7S8WU`^0}O*N+Se?T>6=|t*v=5qEi7~~+h z=J3<0EgH6qf=s(qliF5M)GCd7id-ntkYF=4uc2vJtpp-hZ%~_R(XbgjbDi7hyMQ|J z`og?%y)Z9d!*RVV!#Wq13=FK$j1((|E*xIEF)jHm# z>fzf8 z>EJ7yLTo&?(yXATX`Mq%<=eIh5X#XOy8Z~m7v_J=*hJrL99o?%whlzNgUaQqdE*)j z6fG1dU&ihWSE|-EX3N(o+fO~}S5(ZQsG65G&3K{k>UG&^=$!joiuTpm-28PEt6oo*`;L46iLr6 zZHZX*5T|G^0h$6m6etP_J-L^n=U#f~U(n;83iJ^8Z1mEL`}=*%rMMJ*2+}3?o0*;G z_nYrE!|wci)xh)q=bIb9J8c;MW?}rXP`HFQz6=r>k!l+aC8^mkCAAvXi$-Kd)+a_} z-7*_CN_OO+gDk!;;x?PNWEwwdI0H6REv%wt;(EBS*6wOzi=aHuZNtcc$l~GGL2sb zN!6!@ZzO8VOifTLQlFXlV)M(djSVHM@e=1#W6QXup2s`Z%Kg;XuN)W(NP$m#{kF=S zDB0~bLxrJ>$3%v6F5!&MC^aJioPlGw5ude?1J}dheBY&-as7J3hbNOmr)Z&-He* z>RP)MCt-KHcd^EKxro!o;L#x}=X}4_X(hgYZmr!7gLZronW978%Vs)$*bd^j8E0mv zmpOsf!EWZXI?4CGqg{-ZIU4gf0dps8_oF7t&A6G=ElmN_^CWC1i`FyD&4OfJF)^FP z9*Ao$T4&XqHKG<@*!%G#^rAq;U6gx4}Fh@#8{?_{vDP@vfC zs{>;X2&FV_#;=BZx@IKSzMT>`DCCwSsHoIDWdKK(6;R5{c4`W8mZ-hJq94TIn(7S( z$!BJV-QI4l1R@i5)Cx8;EADSYN+|wp(b>q%q?ef+d^<;2-)tsVk|2qnL~~)iYA&j( zRZ}%pRfnt>%}T6OxhRt&CNCqC8vDk9@qx7m2FcPMm`%--5Q~$r^e2ryC$&@OjIrk; zg}AR6sgtK?jC~W3Rnki8ZqQDd>tCFB_lrON#ro$ORQ?k z0q2c?z={w`AVrW8xk#O;5_w47XeO#6twgiY98xcuj~0;5L=Ql#Ygz5pF8$xz8W!w@ zUN^iNa+cy@p0D8l2r?8U-)LbUd?nO`eKCA^iNZ%fX~a_0oos<_rv+^Ys0(p+41r23 zl7iosX|uUR2dz#~I+2$~U9d0xP7-Hl#yU<`LIfxn`Xb2$%G{w>@QTQ}rmOhF-mgqD zJqi>`+*7p*t9qBlE=6ecVP#KvrY~AEt|MH8EdB_UaL^R*+j(q|+1Hx8ab|%|tSUQ> z2*MxZD2L(-`rf6_^VA&G;_Q(%PVIKz1#9kEv9L?bQ)(vHQsbcGd+;+h`YFiZ;i(H4 zZ+m2JE}zrh{T7ITH4i5Pl}$~@fSrXMUNN4A$s%wmu-$@tdKwbu!$pa`75{=Qy@Ubu zvmhfliIS!QPqMJo??I-TB`!WhySwxZG@>-8gGY~WGo{;o8PaqtnxcfYahLUTB(IW` zWiC5495>(KR?;n`J7nY6K?e2vt>tJR8o=VAB96rIQ4|VeC_+g`Pox#Fg%gK5j)|s~ z1Ns6w<){voW*VN$2yv1?m&Kxmo&x9sq|+s!j~!uRTECJtA7?)nL!7(XnU|>1?zVEhC+a zj>Ep6$SSYpr=^gRQ#@QfU+Suyfyd2yYC)c6_v}P8chk7VRHZiL%1aPx-^xurDcDUzHz&6 zl7B|6^bMX`hHt1T+Z@z}pdGpe3yq40$Q70@{s#(T=@$1aIXtPoUfG!%`VIjI*GM!R2ed8={vE7*$mD?{zp7BeT_D2joRZ_Hp2QQ$du37M0I0Mwp!JF09;d97^ENZnCuXa1_-PNSK+UjiL 
zHehUGm|urmeN5@QT7M}nH5kUNO{Ah2g95YH5`49|u+-GF=MdAd}CdjR(<9IzZ5C?K>rK)GoJ2P`l3;MDD0u=3;MQ&`1qJKuf@j;VdR4pN`Vk?{H9 z2@u4FJ--`<{a(-sceA;Ta;bAzrx|YOZ-X>=pDHG&*-50AcF^|^O`p+FYke3BREV$?ZfLj!I-xt(J{t zgK4g6g_6~sNhP~Oi>LKr-|4g-Hx+WehK0_$wNgG={e{rzkt8 zrsH|}LJ3Q1BfY6uTY7+TEI%;i&IZRx8YLJi-pAOf4JDA^+sZA69I!8-ZxIDhAu{DR z=H)uq-@-jYprghD1CWsuJ96c=NZtw;7yE4dHw+~Au8WKW)Zp|uPvxdZLfa8nbbsUu zzMS{p5|!sh7v9#A1LJWn-psjKLS9dcc^^Q`JM^6rAN?H-WpNUw+#8%M7B|r@B*w*4 z!kgDrjPSj%twAy@CAkCtO0qc{C?5FB-R!Ya_8(qdAK zR)(jdqk$X)C^?^35qtbbI1q3 z+=n9uc|V_p)Z231Y7uMy5#I@8#&8)Y4+l6P;g{fb@t9*sx2w__8z*xSl7(JYHf3=H zJ7D^|QLC;oYSZk?wo*@8KfgY_sGVQePBk5wdb0fpqssV2ui`Zq%`&tjCO%nIbC}4KlcrC-MP2`~Uy| literal 0 HcmV?d00001 diff --git a/utils/__pycache__/optimizer.cpython-310.pyc b/utils/__pycache__/optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66721224a351080b21d2fe9aad75f530fc75b93c GIT binary patch literal 1189 zcmZuw%}*0S6rYdXb{EQrC89($c+$m$#;YM9LIXrvNDCq>$!6JE+Tw1vncYet+nb3e zV!U|PBmahH_3m6fk&tjQCK`=zc5MUbH1G9&{N~Ls@4XEM2Ko@3XK!k|rwT$}rE~re za2|q$6QCGk*h4-xF-D9q`VyIhDNKcb%5W7KnG-;vvLkPGPZo>}$UI2+7F3M(pw1Xn z7(*S>^C)`>6^;(|dr zG#ClZ4}7;7wSAG>cN?{OWHHyV4-8eP7Hc?>P=i*~@Hv&>!I~RcEt}iE8@W8>vXMqs zD3-HZ6RICD*Ao=*Nog)y%{6&XPiiy{U4D%CFhqXbDb(TJ!k)ixH*1Ae$Ej>Seo|;h z4KFP8z*DUQF|^Zec+Bd>BK8r$gk}_Gpad7>`JYUO;96ImtA+X8%D_SUbHe^y>UlabaVnSmAQ{ zgysazFtVGGgTd_{X-lbsgpymbwWNF^gGrx`u~6Ms;M5Z@RaYP>v}w=wcbGkML*A8d zDCqEZ+N;p1yR7ZGd}`VY9NP$5ykaBJN^@neFx6fu%Mx&6uM37&H*aS$zO+$vlb%~Onx5xp79!a6CB-~aifXPYB z5GXU7a=b6}9`e!?-PP#jufe@=5LBLOco?Tzq5LL<48`n^0`AnOl1)l@Bs-F1&sCP& X$^W((78=?(Bc~C{RnH-f=0^Si1s4r^ literal 0 HcmV?d00001 diff --git a/utils/__pycache__/seed.cpython-310.pyc b/utils/__pycache__/seed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4871a66ef0414a429bc59026ce7154d8f6101326 GIT binary patch literal 503 zcmYjO%}T>S5Z+BTZH#RN!F%u)DCPx9QB)`htr$^^hZ3`Cn~>~=%|@y8rZ3>tryp(iZu zY>VOnq6W+`AN!8$28nd5B%zf%^Z6>B1U)Eqg996$m5SGMagM(bDrXWxhyDPHPjNQ@ literal 0 HcmV?d00001 diff --git a/utils/checkpoint.py b/utils/checkpoint.py new file mode 100644 index 0000000..8d7bf0c --- /dev/null +++ b/utils/checkpoint.py @@ -0,0 +1,98 @@ +from torch.nn.parallel import DataParallel, DistributedDataParallel +import torch +import os +from glob import glob + +from .io_tools import _remove_recursively, _create_directory + + +def load(model, optimizer, scheduler, resume, path, logger): + ''' + Load checkpoint file + ''' + + # If not resume, initialize model and return everything as it is + if not resume: + logger.info('=> No checkpoint. Initializing model from scratch') + model.weights_init() + epoch = 1 + return model, optimizer, scheduler, epoch + + # If resume, check that path exists and load everything to return + else: + file_path = glob(os.path.join(path, '*.pth'))[0] + assert os.path.isfile(file_path), '=> No checkpoint found at {}'.format(path) + checkpoint = torch.load(file_path) + epoch = checkpoint.pop('startEpoch') + + s = model.module.state_dict() if isinstance(model, (DataParallel, DistributedDataParallel)) else model.state_dict() + for key, val in checkpoint['model'].items(): + if key[:6] == 'module': + key = key[7:] + + if key in s and s[key].shape == val.shape: + s[key][...] 
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000..7df4f47
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,95 @@
+import yaml
+import os
+
+from .time import get_date_sting
+
+
+class CFG:
+
+    def __init__(self):
+        '''
+        Class constructor: initializes an empty config dict
+        '''
+
+        # Initializing dict...
+        self._dict = {}
+        return
+
+    def from_config_yaml(self, config_path):
+        '''
+        Populate the config from a YAML file
+        :param config_path: path to the YAML config file
+        '''
+
+        # Reading config file
+        self._dict = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader)
+
+        if 'OUTPUT_PATH' not in self._dict['OUTPUT']:
+            self.set_output_filename()
+        self.init_stats()
+        self._dict['STATUS']['CONFIG'] = config_path
+        self.update_config()
+
+        return
+
+    def from_dict(self, config_dict):
+        '''
+        Populate the config from an existing dictionary
+        :param config_dict: configuration dictionary
+        '''
+
+        # Reading config dict
+        self._dict = config_dict
+        return
+
+    def set_output_filename(self):
+        '''
+        Set output path in the form Model_Dataset_MMDD_HHMMSS
+        '''
+        datetime = get_date_sting()
+        model = self._dict['MODEL']['TYPE']
+        dataset = self._dict['DATASET']['TYPE']
+        OUT_PATH = os.path.join(self._dict['OUTPUT']['OUT_ROOT'], model + '_' + dataset + '_' + datetime)
+        self._dict['OUTPUT']['OUTPUT_PATH'] = OUT_PATH
+        return
+
+    def update_config(self, resume=False):
+        '''
+        Save the config back to its YAML file
+        '''
+        if resume:
+            self.set_resume()
+        yaml.dump(self._dict, open(self._dict['STATUS']['CONFIG'], 'w'))
+        return
+
+    def init_stats(self):
+        '''
+        Initialize training stats (i.e. epoch mean time, best loss, best metrics)
+        '''
+        self._dict['OUTPUT']['BEST_LOSS'] = 999999999999
+        self._dict['OUTPUT']['BEST_METRIC'] = -999999999999
+        self._dict['STATUS']['LAST'] = ''
+        return
+
+    def set_resume(self):
+        '''
+        Update the resume flag in the status dict
+        '''
+        if not self._dict['STATUS']['RESUME']:
+            self._dict['STATUS']['RESUME'] = True
+        return
+
+    def finish_config(self):
+        self.move_config(os.path.join(self._dict['OUTPUT']['OUTPUT_PATH'], 'config.yaml'))
+        return
+
+    def move_config(self, path):
+        # Remove from the original path
+        os.remove(self._dict['STATUS']['CONFIG'])
+        # Point ['STATUS']['CONFIG'] to the new path
+        self._dict['STATUS']['CONFIG'] = path
+        # Save to the routine output folder
+        yaml.dump(self._dict, open(path, 'w'))
+
+        return
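A minimal usage sketch for `CFG`, assuming the shipped `cfgs/DSC-Base.yaml` (which defines the MODEL/DATASET/OUTPUT/STATUS blocks this class reads); note that `from_config_yaml` writes the updated dict back to that file on first load:
```
from utils.config import CFG

_cfg = CFG()
_cfg.from_config_yaml('cfgs/DSC-Base.yaml')

print(_cfg._dict['DATASET']['DATA_ROOT'])
print(_cfg._dict['OUTPUT']['OUTPUT_PATH'])  # Model_Dataset_MMDD_HHMMSS under OUT_ROOT
```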
diff --git a/utils/dataset.py b/utils/dataset.py
new file mode 100644
index 0000000..31269cb
--- /dev/null
+++ b/utils/dataset.py
@@ -0,0 +1,24 @@
+from torch.utils.data import DataLoader
+
+from datasets import SemanticKitti, collate_fn
+
+
+def get_dataset(_cfg):
+    if _cfg['DATASET']['TYPE'] == 'SemanticKITTI':
+        data_root = _cfg['DATASET']['DATA_ROOT']
+        config_file = _cfg['DATASET']['CONFIG_FILE']
+        lims = _cfg['DATASET']['LIMS']
+        sizes = _cfg['DATASET']['SIZES']
+        ds_train = SemanticKitti(data_root, config_file, 'train', lims, sizes, augmentation=True, shuffle_index=True)
+        ds_val = SemanticKitti(data_root, config_file, 'valid', lims, sizes, augmentation=False, shuffle_index=False)
+        ds_test = SemanticKitti(data_root, config_file, 'test', lims, sizes, augmentation=False, shuffle_index=False)
+
+        dataset = {}
+        train_batch_size = _cfg['TRAIN']['BATCH_SIZE']
+        val_batch_size = _cfg['VAL']['BATCH_SIZE']
+        num_workers = _cfg['DATALOADER']['NUM_WORKERS']
+
+        dataset['train'] = DataLoader(ds_train, batch_size=train_batch_size, num_workers=num_workers, shuffle=True, collate_fn=collate_fn)
+        dataset['val'] = DataLoader(ds_val, batch_size=val_batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
+        dataset['test'] = DataLoader(ds_test, batch_size=2, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)  # test batch size is fixed at 2
+
+        return dataset
\ No newline at end of file
diff --git a/utils/io_tools.py b/utils/io_tools.py
new file mode 100644
index 0000000..4ea3bd4
--- /dev/null
+++ b/utils/io_tools.py
@@ -0,0 +1,44 @@
+import hashlib
+import os
+import torch
+
+
+def get_md5(filename):
+    '''
+    Compute the MD5 hash of a file
+    '''
+    hash_obj = hashlib.md5()
+    with open(filename, 'rb') as f:
+        hash_obj.update(f.read())
+    return hash_obj.hexdigest()
+
+
+def dict_to(_dict, device):
+    '''
+    Recursively move all tensors in a dict (or lists of tensors) to a device
+    '''
+    for key, value in _dict.items():
+        if isinstance(_dict[key], dict):
+            _dict[key] = dict_to(_dict[key], device)
+        elif isinstance(_dict[key], list):  # elif, so a nested dict is not also sent through .to()
+            _dict[key] = [v.to(device) for v in _dict[key]]
+        else:
+            _dict[key] = _dict[key].to(device)
+
+    return _dict
+
+
+def _remove_recursively(folder_path):
+    '''
+    Remove all files inside a directory
+    '''
+    if os.path.isdir(folder_path):
+        filelist = [f for f in os.listdir(folder_path)]
+        for f in filelist:
+            os.remove(os.path.join(folder_path, f))
+    return
+
+
+def _create_directory(directory):
+    '''
+    Create a directory if it doesn't exist
+    '''
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    return
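`dict_to` is the helper the training/validation loops use to push a whole batch to the GPU; a short sketch with a dummy batch (keys are illustrative):
```
import torch
from utils.io_tools import dict_to

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = {
    'points': torch.rand(1024, 4),
    'targets': {'1_1': torch.zeros(8, 8, 8)},   # nested dicts are recursed into
    'feats': [torch.rand(8), torch.rand(8)],    # lists are moved element-wise
}
batch = dict_to(batch, device)
```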
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..f777cb6
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,33 @@
+import errno
+import os
+import logging
+
+
+def get_logger(path, filename):
+
+    # Create the folder where the training information is to be saved, if it doesn't exist
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as exc:  # Guard against race condition
+            if exc.errno != errno.EEXIST:
+                raise
+
+    # Create the logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)  # In order to store logs of level INFO and above
+    # Create a file handler which logs INFO-level messages into the logs file
+    fh = logging.FileHandler(os.path.join(path, filename))
+    fh.setLevel(logging.INFO)
+    # Create a console handler
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    # Create a formatter and add it to the handlers
+    formatter = logging.Formatter('%(asctime)s -- %(message)s')
+    fh.setFormatter(formatter)
+    ch.setFormatter(formatter)
+    # Add the handlers to the logger
+    logger.addHandler(fh)
+    logger.addHandler(ch)
+
+    return logger
\ No newline at end of file
diff --git a/utils/lovasz_losses.py b/utils/lovasz_losses.py
new file mode 100644
index 0000000..4809554
--- /dev/null
+++ b/utils/lovasz_losses.py
@@ -0,0 +1,321 @@
+"""
+Lovasz-Softmax and Jaccard hinge loss in PyTorch
+Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
+"""
+
+from __future__ import print_function, division
+
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+import numpy as np
+
+try:
+    from itertools import ifilterfalse
+except ImportError:  # Python 3
+    from itertools import filterfalse as ifilterfalse
+
+
+def lovasz_grad(gt_sorted):
+    """
+    Computes the gradient of the Lovasz extension w.r.t sorted errors
+    See Alg. 1 in paper
+    """
+    p = len(gt_sorted)
+    gts = gt_sorted.sum()
+    intersection = gts - gt_sorted.float().cumsum(0)
+    union = gts + (1 - gt_sorted).float().cumsum(0)
+    jaccard = 1. - intersection / union
+    if p > 1:  # cover 1-pixel case
+        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+    return jaccard
+
+
+def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
+    """
+    IoU for the foreground class
+    binary: 1 foreground, 0 background
+    """
+    if not per_image:
+        preds, labels = (preds,), (labels,)
+    ious = []
+    for pred, label in zip(preds, labels):
+        intersection = ((label == 1) & (pred == 1)).sum()
+        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
+        if not union:
+            iou = EMPTY
+        else:
+            iou = float(intersection) / float(union)
+        ious.append(iou)
+    iou = mean(ious)  # mean across images if per_image
+    return 100 * iou
+
+
+def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
+    """
+    Array of IoU for each (non ignored) class
+    """
+    if not per_image:
+        preds, labels = (preds,), (labels,)
+    ious = []
+    for pred, label in zip(preds, labels):
+        iou = []
+        for i in range(C):
+            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
+                intersection = ((label == i) & (pred == i)).sum()
+                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
+                if not union:
+                    iou.append(EMPTY)
+                else:
+                    iou.append(float(intersection) / float(union))
+        ious.append(iou)
+    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
+    return 100 * np.array(ious)
+
+
+# --------------------------- BINARY LOSSES ---------------------------
+
+
+def lovasz_hinge(logits, labels, per_image=True, ignore=None):
+    """
+    Binary Lovasz hinge loss
+    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
+    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+    per_image: compute the loss per image instead of per batch
+    ignore: void class id
+    """
+    if per_image:
+        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
+                    for log, lab in zip(logits, labels))
+    else:
+        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
+    return loss
+
+
+def lovasz_hinge_flat(logits, labels):
+    """
+    Binary Lovasz hinge loss
+    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
+    labels: [P] Tensor, binary ground truth labels (0 or 1)
+    """
+    if len(labels) == 0:
+        # only void pixels, the gradients should be 0
+        return logits.sum() * 0.
+    signs = 2. * labels.float() - 1.
+    errors = (1. - logits * Variable(signs))
+    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+    perm = perm.data
+    gt_sorted = labels[perm]
+    grad = lovasz_grad(gt_sorted)
+    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
+    return loss
+
+
+def flatten_binary_scores(scores, labels, ignore=None):
+    """
+    Flattens predictions in the batch (binary case)
+    Removes labels equal to 'ignore'
+    """
+    scores = scores.view(-1)
+    labels = labels.view(-1)
+    if ignore is None:
+        return scores, labels
+    valid = (labels != ignore)
+    vscores = scores[valid]
+    vlabels = labels[valid]
+    return vscores, vlabels
+
+
+class StableBCELoss(torch.nn.modules.Module):
+    def __init__(self):
+        super(StableBCELoss, self).__init__()
+
+    def forward(self, input, target):
+        neg_abs = - input.abs()
+        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
+        return loss.mean()
+
+
+def binary_xloss(logits, labels, ignore=None):
+    """
+    Binary Cross entropy loss
+    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
+    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+    ignore: void class id
+    """
+    logits, labels = flatten_binary_scores(logits, labels, ignore)
+    loss = StableBCELoss()(logits, Variable(labels.float()))
+    return loss
+
+
+# --------------------------- MULTICLASS LOSSES ---------------------------
+
+
+def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
+    """
+    Multi-class Lovasz-Softmax loss
+    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
+            Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
+    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
+    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+    per_image: compute the loss per image instead of per batch
+    ignore: void class labels
+    """
+    if per_image:
+        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
+                    for prob, lab in zip(probas, labels))
+    else:
+        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
+    return loss
+
+
+def lovasz_softmax_flat(probas, labels, classes='present'):
+    """
+    Multi-class Lovasz-Softmax loss
+    probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
+    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
+    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+    """
+    if probas.numel() == 0:
+        # only void pixels, the gradients should be 0
+        return probas * 0.
+    C = probas.size(1)
+    losses = []
+    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+    for c in class_to_sum:
+        fg = (labels == c).float()  # foreground for class c
+        if classes == 'present' and fg.sum() == 0:
+            continue
+        if C == 1:
+            if len(classes) > 1:
+                raise ValueError('Sigmoid output possible only with 1 class')
+            class_pred = probas[:, 0]
+        else:
+            class_pred = probas[:, c]
+        errors = (Variable(fg) - class_pred).abs()
+        errors_sorted, perm = torch.sort(errors, 0, descending=True)
+        perm = perm.data
+        fg_sorted = fg[perm]
+        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
+    return mean(losses)
+
+
+def flatten_probas(probas, labels, ignore=None):
+    """
+    Flattens predictions in the batch
+    """
+    if probas.dim() == 3:
+        # assumes output of a sigmoid layer
+        B, H, W = probas.size()
+        probas = probas.view(B, 1, H, W)
+    elif probas.dim() == 5:
+        # 3D segmentation
+        B, C, L, H, W = probas.size()
+        probas = probas.contiguous().view(B, C, L, H*W)
+    if probas.dim() == 4:
+        B, C, H, W = probas.size()
+        probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
+        labels = labels.view(-1)
+    else:
+        probas = probas.contiguous()
+        labels = labels.view(-1)
+    if ignore is None:
+        return probas, labels
+    valid = (labels != ignore)
+    vprobas = probas[valid.nonzero().squeeze()]
+    vlabels = labels[valid]
+    return vprobas, vlabels
+
+
+def xloss(logits, labels, ignore=None):
+    """
+    Cross entropy loss
+    """
+    return F.cross_entropy(logits, Variable(labels), ignore_index=255)
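A quick sketch of calling `lovasz_softmax` on dummy voxel-shaped predictions (shapes are illustrative; `flatten_probas` handles the 5D case via its `dim() == 5` branch, and label 255 is excluded through `ignore`):
```
import torch
from utils.lovasz_losses import lovasz_softmax

B, C, L, H, W = 2, 20, 8, 16, 16
probas = torch.softmax(torch.randn(B, C, L, H, W), dim=1)  # class probabilities
labels = torch.randint(0, C, (B, L, H, W))
labels[0, 0, 0, 0] = 255                                   # an unknown voxel
loss = lovasz_softmax(probas, labels, classes='present', ignore=255)
```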
+
+
+def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None):
+    """
+    Multi-class smoothed Jaccard loss
+    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
+    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
+    ignore: void class labels
+    """
+    vprobas, vlabels = flatten_probas(probas, labels, ignore)
+
+    true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
+
+    if bk_class:
+        one_hot_assignment = torch.ones_like(vlabels)
+        one_hot_assignment[vlabels == bk_class] = 0
+        one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
+        true_1_hot = true_1_hot * one_hot_assignment
+
+    true_1_hot = true_1_hot.to(vprobas.device)
+    intersection = torch.sum(vprobas * true_1_hot)
+    cardinality = torch.sum(vprobas + true_1_hot)
+    # Parentheses fixed: the smoothing term belongs to the whole numerator
+    loss = ((intersection + smooth) / (cardinality - intersection + smooth)).mean()
+    return (1 - loss) * smooth
+
+
+def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
+    """
+    Multi-class Hinge Jaccard loss
+    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
+    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
+    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+    ignore: void class labels
+    """
+    vprobas, vlabels = flatten_probas(probas, labels, ignore)
+    C = vprobas.size(1)
+    losses = []
+    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+    for c in class_to_sum:
+        if c in vlabels:
+            c_sample_ind = vlabels == c
+            cprobas = vprobas[c_sample_ind, :]
+            non_c_ind = np.array([a for a in class_to_sum if a != c])
+            class_pred = cprobas[:, c]
+            max_non_class_pred = torch.max(cprobas[:, non_c_ind], dim=1)[0]
+            TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.) + smooth
+            FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min=-hinge) + hinge)
+
+            if (~c_sample_ind).sum() == 0:
+                FP = 0
+            else:
+                nonc_probas = vprobas[~c_sample_ind, :]
+                class_pred = nonc_probas[:, c]
+                max_non_class_pred = torch.max(nonc_probas[:, non_c_ind], dim=1)[0]
+                FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.)
+
+            losses.append(1 - TP / (TP + FP + FN))
+
+    if len(losses) == 0:
+        return 0
+    return mean(losses)
+
+
+# --------------------------- HELPER FUNCTIONS ---------------------------
+
+
+def isnan(x):
+    return x != x
+
+
+def mean(l, ignore_nan=False, empty=0):
+    """
+    nanmean compatible with generators.
+    """
+    l = iter(l)
+    if ignore_nan:
+        l = ifilterfalse(isnan, l)
+    try:
+        n = 1
+        acc = next(l)
+    except StopIteration:
+        if empty == 'raise':
+            raise ValueError('Empty mean')
+        return empty
+    for n, v in enumerate(l, 2):
+        acc += v
+    if n == 1:
+        return acc
+    return acc / n
diff --git a/utils/metrics.py b/utils/metrics.py
new file mode 100644
index 0000000..f550840
--- /dev/null
+++ b/utils/metrics.py
@@ -0,0 +1,201 @@
+# Some sections of this code reuse code from the SemanticKITTI development kit
+# https://github.com/PRBonn/semantic-kitti-api
+
+import numpy as np
+import torch
+import copy
+
+
+class iouEval:
+    def __init__(self, n_classes, ignore=None):
+        # classes
+        self.n_classes = n_classes
+
+        # What to include and ignore from the means
+        self.ignore = np.array(ignore, dtype=np.int64)
+        self.include = np.array([n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64)
+
+        # reset the class counters
+        self.reset()
+
+    def num_classes(self):
+        return self.n_classes
+
+    def reset(self):
+        self.conf_matrix = np.zeros((self.n_classes, self.n_classes), dtype=np.int64)
+
+    def addBatch(self, x, y):  # x=preds, y=targets
+
+        # sizes should be matching
+        assert x.shape == y.shape
+
+        x_row = x.reshape(-1)  # de-batchify
+        y_row = y.reshape(-1)  # de-batchify
+
+        # check
+        assert x_row.shape == y_row.shape
+
+        # create indexes
+        idxs = tuple(np.stack((x_row, y_row), axis=0))
+
+        # make confusion matrix (cols = gt, rows = pred)
+        np.add.at(self.conf_matrix, idxs, 1)
+
+    def getStats(self):
+        # remove fp from confusion on the ignore classes cols
+        conf = self.conf_matrix.copy()
+        conf[:, self.ignore] = 0
+
+        # get the clean stats
+        tp = np.diag(conf)
+        fp = conf.sum(axis=1) - tp
+        fn = conf.sum(axis=0) - tp
+        return tp, fp, fn
+
+    def getIoU(self):
+        tp, fp, fn = self.getStats()
+        intersection = tp
+        union = tp + fp + fn + 1e-15
+        iou = intersection / union
+        iou_mean = (intersection[self.include] / union[self.include]).mean()
+        return iou_mean, iou  # returns "iou mean", "iou per class" ALL CLASSES
+
+    def getacc(self):
+        tp, fp, fn = self.getStats()
+        total_tp = tp.sum()
+        total = tp[self.include].sum() + fp[self.include].sum() + 1e-15
+        acc_mean = total_tp / total
+        return acc_mean  # returns "acc mean"
+
+    def get_confusion(self):
+        return self.conf_matrix.copy()
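`iouEval` simply accumulates an n_classes x n_classes confusion matrix; a toy sketch of feeding it predictions and reading back the scores:
```
import numpy as np
from utils.metrics import iouEval

ev = iouEval(n_classes=3, ignore=[])
preds = np.array([0, 1, 2, 2, 1])
gts = np.array([0, 1, 2, 1, 1])
ev.addBatch(preds, gts)

miou, per_class = ev.getIoU()   # mean IoU over included classes, IoU per class
acc = ev.getacc()               # overall accuracy
```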
+
+
+class LossesTrackEpoch:
+    def __init__(self, num_iterations):
+        self.num_iterations = num_iterations
+        self.validation_losses = {}
+        self.train_losses = {}
+        self.train_iteration_counts = 0
+        self.validation_iteration_counts = 0
+
+    def set_validation_losses(self, keys):
+        for key in keys:
+            self.validation_losses[key] = 0
+        return
+
+    def set_train_losses(self, keys):
+        for key in keys:
+            self.train_losses[key] = 0
+        return
+
+    def update_train_losses(self, loss):
+        for key in loss:
+            self.train_losses[key] += loss[key]
+        self.train_iteration_counts += 1
+        return
+
+    def update_validaiton_losses(self, loss):
+        for key in loss:
+            self.validation_losses[key] += loss[key]
+        self.validation_iteration_counts += 1
+        return
+
+    def restart_train_losses(self):
+        for key in self.train_losses.keys():
+            self.train_losses[key] = 0
+        self.train_iteration_counts = 0
+        return
+
+    def restart_validation_losses(self):
+        for key in self.validation_losses.keys():
+            self.validation_losses[key] = 0
+        self.validation_iteration_counts = 0
+        return
+
+
+class Metrics:
+    def __init__(self, nbr_classes, num_iterations_epoch, scales):
+
+        self.nbr_classes = nbr_classes
+        self.evaluator = {}
+        for scale in scales:
+            self.evaluator[scale] = iouEval(self.nbr_classes, [])
+        self.losses_track = LossesTrackEpoch(num_iterations_epoch)
+        self.best_metric_record = {'mIoU': 0, 'IoU': 0, 'epoch': 0, 'loss': 99999999}
+
+        return
+
+    def add_batch(self, prediction, target):
+
+        # passing to cpu
+        for key in prediction:
+            prediction[key] = torch.argmax(prediction[key], dim=1).data.cpu().numpy()
+        for key in target:
+            target[key] = target[key].data.cpu().numpy()
+
+        for key in target:
+            prediction['pred_semantic_' + key] = prediction['pred_semantic_' + key].reshape(-1).astype('int64')
+            target[key] = target[key].reshape(-1).astype('int64')
+            lidar_mask = self.get_eval_mask_Lidar(target[key])
+            self.evaluator[key].addBatch(prediction['pred_semantic_' + key][lidar_mask], target[key][lidar_mask])
+
+        return
+
+    def get_eval_mask_Lidar(self, target):
+        '''
+        eval_mask_lidar is only meant to ignore unknown voxels in the ground truth
+        '''
+        mask = (target != 255)
+        return mask
+
+    def get_occupancy_IoU(self, scale):
+        conf = self.evaluator[scale].get_confusion()
+        tp_occupancy = np.sum(conf[1:, 1:])
+        fp_occupancy = np.sum(conf[1:, 0])
+        fn_occupancy = np.sum(conf[0, 1:])
+        intersection = tp_occupancy
+        union = tp_occupancy + fp_occupancy + fn_occupancy + 1e-15
+        iou_occupancy = intersection / union
+        return iou_occupancy  # returns occupancy IoU
+
+    def get_occupancy_Precision(self, scale):
+        conf = self.evaluator[scale].get_confusion()
+        tp_occupancy = np.sum(conf[1:, 1:])
+        fp_occupancy = np.sum(conf[1:, 0])
+        precision = tp_occupancy / (tp_occupancy + fp_occupancy + 1e-15)
+        return precision  # returns occupancy precision
+
+    def get_occupancy_Recall(self, scale):
+        conf = self.evaluator[scale].get_confusion()
+        tp_occupancy = np.sum(conf[1:, 1:])
+        fn_occupancy = np.sum(conf[0, 1:])
+        recall = tp_occupancy / (tp_occupancy + fn_occupancy + 1e-15)
+        return recall  # returns occupancy recall
+
+    def get_occupancy_F1(self, scale):
+        conf = self.evaluator[scale].get_confusion()
+        tp_occupancy = np.sum(conf[1:, 1:])
+        fn_occupancy = np.sum(conf[0, 1:])
+        fp_occupancy = np.sum(conf[1:, 0])
+        precision = tp_occupancy / (tp_occupancy + fp_occupancy + 1e-15)
+        recall = tp_occupancy / (tp_occupancy + fn_occupancy + 1e-15)
+        F1 = 2 * (precision * recall) / (precision + recall + 1e-15)
+        return F1  # returns occupancy F1
+
+    def get_semantics_mIoU(self, scale):
+        _, class_jaccard = self.evaluator[scale].getIoU()
+        mIoU_semantics = class_jaccard[1:].mean()  # ignore the free-space class (index 0)
+        return mIoU_semantics  # returns semantics mIoU
+
+    def reset_evaluator(self):
+        for key in self.evaluator:
+            self.evaluator[key].reset()
+
+    def update_best_metric_record(self, mIoU, IoU, loss, epoch):
+        self.best_metric_record['mIoU'] = mIoU
+        self.best_metric_record['IoU'] = IoU
+        self.best_metric_record['loss'] = loss
+        self.best_metric_record['epoch'] = epoch
+        return
diff --git a/utils/model.py b/utils/model.py
new file mode 100644
index 0000000..e63d541
--- /dev/null
+++ b/utils/model.py
@@ -0,0 +1,5 @@
+from networks.dsc import DSC
+
+
+def get_model(_cfg, phase='train'):
+    return DSC(_cfg, phase=phase)
diff --git a/utils/optimizer.py b/utils/optimizer.py
new file mode 100644
index 0000000..c3244cb
--- /dev/null
+++ b/utils/optimizer.py
@@ -0,0 +1,36 @@
+import torch.optim as optim
+
+
+def build_optimizer(_cfg, model):
+
+    opt = _cfg._dict['OPTIMIZER']['TYPE']
+    lr = _cfg._dict['OPTIMIZER']['BASE_LR']
+    if 'MOMENTUM' in _cfg._dict['OPTIMIZER']:
+        momentum = _cfg._dict['OPTIMIZER']['MOMENTUM']
+    if 'WEIGHT_DECAY' in _cfg._dict['OPTIMIZER']:
+        weight_decay = _cfg._dict['OPTIMIZER']['WEIGHT_DECAY']
+
+    if opt == 'Adam':
+        optimizer = optim.Adam(model.get_parameters(),
+                               lr=lr,
+                               betas=(0.9, 0.999))
+    elif opt == 'SGD':
+        # MOMENTUM and WEIGHT_DECAY must be set in the config when using SGD
+        optimizer = optim.SGD(model.get_parameters(),
+                              lr=lr,
+                              momentum=momentum,
+                              weight_decay=weight_decay)
+
+    return optimizer
+
+
+def build_scheduler(_cfg, optimizer):
+
+    # Constant learning rate
+    if _cfg._dict['SCHEDULER']['TYPE'] == 'constant':
+        lambda1 = lambda epoch: 1
+        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
+
+    # Learning rate scaled by 0.98^(epoch)
+    if _cfg._dict['SCHEDULER']['TYPE'] == 'power_iteration':
+        lambda1 = lambda epoch: (0.98) ** (epoch)
+        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
+
+    return scheduler
\ No newline at end of file
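`build_optimizer` expects the model to expose a `get_parameters()` accessor; a self-contained sketch with a placeholder model and illustrative hyperparameter values (not the shipped defaults):
```
import torch.nn as nn
from utils.config import CFG
from utils.optimizer import build_optimizer, build_scheduler

class TinyNet(nn.Module):            # placeholder model with the expected accessor
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
    def get_parameters(self):
        return self.parameters()

_cfg = CFG()
_cfg.from_dict({'OPTIMIZER': {'TYPE': 'SGD', 'BASE_LR': 0.01,
                              'MOMENTUM': 0.9, 'WEIGHT_DECAY': 0.0005},
                'SCHEDULER': {'TYPE': 'power_iteration'}})

optimizer = build_optimizer(_cfg, TinyNet())
scheduler = build_scheduler(_cfg, optimizer)  # lr decays as 0.98 ** epoch
```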
diff --git a/utils/seed.py b/utils/seed.py
new file mode 100644
index 0000000..21d1ea3
--- /dev/null
+++ b/utils/seed.py
@@ -0,0 +1,15 @@
+import random
+import torch
+import numpy as np
+import os
+
+
+def seed_all(seed):
+    '''
+    Set seeds for training reproducibility
+    '''
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
\ No newline at end of file
diff --git a/utils/time.py b/utils/time.py
new file mode 100644
index 0000000..991359d
--- /dev/null
+++ b/utils/time.py
@@ -0,0 +1,12 @@
+import datetime
+
+
+def get_date_sting():
+    '''
+    Return the current time as MMDD_HHMMSS, for printing and naming output folders
+    '''
+    _now = datetime.datetime.now()
+    _date = ('%.2i' % _now.month) + ('%.2i' % _now.day)  # ('%.4i' % _now.year)
+
+    _time = ('%.2i' % _now.hour) + ('%.2i' % _now.minute) + ('%.2i' % _now.second)
+    return (_date + '_' + _time)
\ No newline at end of file
diff --git a/validate.py b/validate.py
new file mode 100644
index 0000000..88f52c1
--- /dev/null
+++ b/validate.py
@@ -0,0 +1,151 @@
+import os
+import argparse
+import torch
+import torch.nn as nn
+import sys
+
+# Append root directory to system path for imports
+repo_path, _ = os.path.split(os.path.realpath(__file__))
+repo_path, _ = os.path.split(repo_path)
+sys.path.append(repo_path)
+
+from utils.seed import seed_all
+from utils.config import CFG
+from utils.dataset import get_dataset
+from utils.model import get_model
+from utils.logger import get_logger
+from utils.io_tools import dict_to
+from utils.metrics import Metrics
+import utils.checkpoint as checkpoint
+from tqdm import tqdm
+import time
+import numpy as np
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='DSC validating')
+    parser.add_argument(
+        '--weights',
+        dest='weights_file',
+        default='',
+        metavar='FILE',
+        help='path to the model .pth file',
+        type=str,
+    )
+    parser.add_argument(
+        '--dset_root',
+        dest='dataset_root',
+        default=None,
+        metavar='DATASET',
+        help='path to dataset root folder',
+        type=str,
+    )
+    args = parser.parse_args()
+    return args
+
+
+def validate(model, dset, _cfg, logger, metrics):
+
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dtype = torch.float32  # Tensor type to be used
+
+    # Moving the model to the device in use
+    model = model.to(device=device)
+
+    logger.info('=> Passing the network on the validation set...')
+    time_list = []
+    model.eval()
+
+    with torch.no_grad():
+        for t, (data, indices) in enumerate(tqdm(dset, ncols=100)):
+
+            data = dict_to(data, device)
+            start_time = time.time()
+            scores, loss = model(data)
+            time_list.append(time.time() - start_time)
+
+            # Updating batch losses to then get the mean for the epoch loss
+            metrics.losses_track.update_validaiton_losses(loss)
+
+            if (t + 1) % _cfg._dict['VAL']['SUMMARY_PERIOD'] == 0:
+                loss_print = '=> Iteration [{}/{}], Validation Losses: '.format(t + 1, len(dset))
+                for key in loss.keys():
+                    loss_print += '{} = {:.6f}, '.format(key, loss[key])
+                logger.info(loss_print[:-2])
+
+            metrics.add_batch(prediction=scores, target=model.get_target(data))
+
+    epoch_loss = metrics.losses_track.validation_losses['total'] / metrics.losses_track.validation_iteration_counts
+
+    logger.info('=> [Total Validation Loss = {}]'.format(epoch_loss))
+    for scale in metrics.evaluator.keys():
+        loss_scale = metrics.losses_track.validation_losses['semantic_{}'.format(scale)].item() / metrics.losses_track.validation_iteration_counts
+        logger.info('=> [Scale {}: Loss = {:.6f} - mIoU = {:.6f} - IoU = {:.6f} '
+                    '- P = {:.6f} - R = {:.6f} - F1 = {:.6f}]'.format(scale, loss_scale,
+                                                                      metrics.get_semantics_mIoU(scale).item(),
+                                                                      metrics.get_occupancy_IoU(scale).item(),
+                                                                      metrics.get_occupancy_Precision(scale).item(),
+                                                                      metrics.get_occupancy_Recall(scale).item(),
+                                                                      metrics.get_occupancy_F1(scale).item()))
+
+    logger.info('=> Validation set class-wise IoU:')
+    for i in range(1, metrics.nbr_classes):
+        class_name = dset.dataset.get_xentropy_class_string(i)
+        class_score = metrics.evaluator['1_1'].getIoU()[1][i]
+        logger.info('   => IoU {}: {:.6f}'.format(class_name, class_score))
+
+    return time_list
+
+
+def main():
+
+    # https://github.com/pytorch/pytorch/issues/27588
+    torch.backends.cudnn.enabled = False
+
+    seed_all(0)
+
+    args = parse_args()
+
+    weights_f = args.weights_file
+    dataset_f = args.dataset_root
+
+    assert os.path.isfile(weights_f), '=> No file found at {}'.format(weights_f)
+
+    checkpoint_dict = torch.load(weights_f)
+    config_dict = checkpoint_dict.pop('config_dict')
+    config_dict['DATASET']['DATA_ROOT'] = dataset_f
+
+    # Read train configuration file
+    _cfg = CFG()
+    _cfg.from_dict(config_dict)
+    # Setting the logger to print statements and also save them into a logs file
+    logger = get_logger(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'logs_val.log')
+
+    logger.info('============ Validation weights: "%s" ============\n' % weights_f)
+    dataset = get_dataset(_cfg._dict)
+
+    logger.info('=> Loading network architecture...')
+    model = get_model(_cfg._dict, phase='trainval')
+    if
torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + model = model.module + logger.info(f'=> Model Parameters: {sum(p.numel() for p in model.parameters())/1000000.0} M') + + logger.info('=> Loading network weights...') + model = checkpoint.load_model(model, weights_f, logger) + + nbr_iterations = len(dataset['val']) + metrics = Metrics(_cfg._dict['DATASET']['NCLASS'], nbr_iterations, model.get_scales()) + metrics.reset_evaluator() + metrics.losses_track.set_validation_losses(model.get_validation_loss_keys()) + metrics.losses_track.set_train_losses(model.get_train_loss_keys()) + + time_list = validate(model, dataset['val'], _cfg, logger, metrics) + + logger.info('=> ============ Network Validation Done ============') + logger.info('Inference time per frame is %.4f seconds\n' % (np.sum(time_list) / len(dataset['val'].dataset))) + + exit() + + +if __name__ == '__main__': + main()
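With the patch applied, a standalone validation pass would be launched along these lines (paths are placeholders):
```
$ python validate.py --weights /path/to/weights_epoch_XXX.pth --dset_root /path/to/SemanticKITTI/dataset/sequences
```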