Skip to content

Commit

Permalink
v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ZikangZhou committed May 14, 2022
1 parent fe49ab8 commit 68ac163
Show file tree
Hide file tree
Showing 30 changed files with 2,030 additions and 1 deletion.
113 changes: 113 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.github

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDEs
.idea
.vscode

# seed project
lightning_logs/
.DS_Store
90 changes: 89 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction
Official Implementation of "HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction", which will be published in CVPR 2022.
This repository is the official implementation of "HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction" published in CVPR 2022.

![](https://github.com/ZikangZhou/HiVT/assets/overview.png)

```
@inproceedings{zhou2022hivt,
Expand All @@ -9,3 +11,89 @@ Official Implementation of "HiVT: Hierarchical Vector Transformer for Multi-Agen
year={2022}
}
```

## Gettting Started

1\. Clone this repository:
```
git clone https://github.com/ZikangZhou/HiVT.git
cd HiVT
```

2\. Create a conda environment and install the dependencies:
```
conda create -n HiVT python=3.8
conda activate HiVT
conda install pytorch==1.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge
conda install pytorch-geometric==1.7.2 -c rusty1s -c conda-forge
```

3\. Download [Argoverse Motion Forecasting Dataset v1.1](https://www.argoverse.org/av1.html). After downloading and extracting the tar.gz files, the dataset directory should be organized as follows:
```
/path/to/dataset_root/
├── train/
| └── data/
| ├── 1.csv
| ├── 2.csv
| ├── ...
└── val/
└── data/
├── 1.csv
├── 2.csv
├── ...
```

4\. Install [Argoverse 1 API](https://github.com/argoai/argoverse-api).

## Training

To train HiVT-64, run:
```
python train.py --root /path/to/dataset_root/ --embed_dim 64
```

To train HiVT-128, run:
```
python train.py --root /path/to/dataset_root/ --embed_dim 128
```

**Note**: When running the training command for the first time, it will take several hours to preprocess the data (~3.5 hours on my machine). Training on an RTX 2080 Ti GPU takes 35-40 minutes per epoch.

During training, the checkpoints will be saved in `lightning_logs/` automatically. To monitor the training process, run:
```
tensorboard --log_dir lightning_logs/
```

## Evaluation

To evaluate the prediction performance, run:
```
python eval.py --root /path/to/dataset_root/ --batch_size 32 --ckpt_path /path/to/your_checkpoint.ckpt
```

## Pretrained Models

We provide the pretrained HiVT-64 and HiVT-128 in [checkpoints/](https://github.com/ZikangZhou/HiVT/checkpoints). You can evaluate the pretrained models using the aforementioned evaluation command, or have a look at the training process via TensorBoard:
```
tensorboard --log_dir checkpoints/
```

## Results

### Quantitative Results

For this repository, the expected performance on Argoverse 1 validation set is:

| Models | minADE | minFDE | MR |
| :--- | :---: | :---: | :---: |
| HiVT-64 | 0.69 | 1.03 | 0.10 |
| HiVT-128 | 0.66 | 0.97 | 0.09 |

### Qualitative Results

![](https://github.com/ZikangZhou/HiVT/assets/visualization.png)

## License

This repository is licensed under [Apache 2.0](https://github.com/ZikangZhou/HiVT/LICENSE).

Binary file added assets/overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/visualization.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
27 changes: 27 additions & 0 deletions checkpoints/HiVT-128/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
historical_steps: 20
future_steps: 30
num_modes: 6
rotate: true
node_dim: 2
edge_dim: 2
embed_dim: 128
num_heads: 8
dropout: 0.1
num_temporal_layers: 4
num_global_layers: 3
local_radius: 50
parallel: false
lr: 0.0005
weight_decay: 0.0001
T_max: 64
root: /
train_batch_size: 32
val_batch_size: 32
shuffle: true
num_workers: 8
pin_memory: true
persistent_workers: true
gpus: 1
max_epochs: 64
monitor: val_minFDE
save_top_k: 5
Binary file not shown.
Binary file not shown.
27 changes: 27 additions & 0 deletions checkpoints/HiVT-64/hparams.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
historical_steps: 20
future_steps: 30
num_modes: 6
rotate: true
node_dim: 2
edge_dim: 2
embed_dim: 64
num_heads: 8
dropout: 0.1
num_temporal_layers: 4
num_global_layers: 3
local_radius: 50
parallel: false
lr: 0.0005
weight_decay: 0.0001
T_max: 64
root: /
train_batch_size: 32
val_batch_size: 32
shuffle: true
num_workers: 8
pin_memory: true
persistent_workers: true
gpus: 1
max_epochs: 64
monitor: val_minFDE
save_top_k: 5
14 changes: 14 additions & 0 deletions datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datamodules.argoverse_v1_datamodule import ArgoverseV1DataModule
62 changes: 62 additions & 0 deletions datamodules/argoverse_v1_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional

from pytorch_lightning import LightningDataModule
from torch_geometric.data import DataLoader

from datasets import ArgoverseV1Dataset


class ArgoverseV1DataModule(LightningDataModule):

def __init__(self,
root: str,
train_batch_size: int,
val_batch_size: int,
shuffle: bool = True,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = True,
train_transform: Optional[Callable] = None,
val_transform: Optional[Callable] = None,
local_radius: float = 50) -> None:
super(ArgoverseV1DataModule, self).__init__()
self.root = root
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.num_workers = num_workers
self.train_transform = train_transform
self.val_transform = val_transform
self.local_radius = local_radius

def prepare_data(self) -> None:
ArgoverseV1Dataset(self.root, 'train', self.train_transform, self.local_radius)
ArgoverseV1Dataset(self.root, 'val', self.val_transform, self.local_radius)

def setup(self, stage: Optional[str] = None) -> None:
self.train_dataset = ArgoverseV1Dataset(self.root, 'train', self.train_transform, self.local_radius)
self.val_dataset = ArgoverseV1Dataset(self.root, 'val', self.val_transform, self.local_radius)

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
num_workers=self.num_workers, pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers)

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers,
pin_memory=self.pin_memory, persistent_workers=self.persistent_workers)
14 changes: 14 additions & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022, Zikang Zhou. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets.argoverse_v1_dataset import ArgoverseV1Dataset
Loading

0 comments on commit 68ac163

Please sign in to comment.