Commit a83b538: Initial Commit
agrimgupta92 committed Jun 29, 2018 (0 parents)
Showing 20 changed files with 1,998 additions and 0 deletions.
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2018 Agrim Gupta, Justin Johnson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
25 changes: 25 additions & 0 deletions MODEL_ZOO.md
@@ -0,0 +1,25 @@
# Social GAN Model Zoo

We refer to our method as SGAN-kVP-N, where kV indicates whether the model was trained with the variety loss (k = 1 essentially means no variety loss) and P indicates the use of our proposed pooling module. At test time we sample multiple times from the model and choose the best prediction in the L2 sense for quantitative evaluation; N is the number of samples drawn from the model at test time. We report two error metrics, Average Displacement Error (ADE) and Final Displacement Error (FDE), for t<sub>pred</sub> = 8 and 12, in meters.

These results are better than those reported in the paper. You can use `scripts/print_args.py` to get the hyper-parameters used for training. For SGAN-20VP-20 we used the 'global' option instead of the 'local' option used in the paper.
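
For reference, here is a minimal sketch of the best-of-N metric computation for a single pedestrian. This is a simplified illustration, not the repository's evaluation code, which operates on whole scenes (see `scripts/evaluate_model.py`):

```python
import torch

def best_of_n_ade_fde(pred, gt):
    # pred: (N, T, 2) -- N sampled trajectories of length T, in meters
    # gt:   (T, 2)    -- ground-truth trajectory
    dist = torch.norm(pred - gt.unsqueeze(0), dim=2)  # (N, T) L2 error per step
    ade = dist.mean(dim=1)                            # average error per sample
    fde = dist[:, -1]                                 # final-step error per sample
    best = torch.argmin(ade)                          # best sample in the L2 sense
    return ade[best].item(), fde[best].item()
```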

**SGAN-20V-20**

| Dataset | ADE<sub>8</sub> | ADE<sub>12</sub> | FDE<sub>8</sub> | FDE<sub>12</sub> |
|-----|-----|--- |--- |--- |
| `ETH`| 0.58 |0.71 |1.13 |1.29 |
| `Hotel`| 0.36 |0.48 |0.71 |1.02|
| `Univ`| 0.33 |0.56 |0.70 |1.18 |
| `Zara1`| 0.21 |0.34 |0.42 |0.69|
| `Zara2`| 0.21 |0.31|0.42 |0.64|

**SGAN-20VP-20**

| Dataset | ADE<sub>8</sub> | ADE<sub>12</sub> | FDE<sub>8</sub> | FDE<sub>12</sub> |
|-----|-----|--- |--- |--- |
| `ETH`| 0.57 |0.77|1.14 |1.39|
| `Hotel`| 0.38 |0.43|0.73 |0.88|
| `Univ`| 0.42 |0.75|0.79 |1.50|
| `Zara1`| 0.22 |0.34|0.43 |0.68|
| `Zara2`| 0.24 |0.36|0.48 |0.73|
73 changes: 73 additions & 0 deletions README.md
@@ -0,0 +1,73 @@
# Social GAN

This is the code for the paper

**<a href="https://arxiv.org/abs/1803.10892">Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks</a>**
<br>
<a href="http://web.stanford.edu/~agrim/">Agrim Gupta</a>,
<a href="http://cs.stanford.edu/people/jcjohns/">Justin Johnson</a>,
<a href="http://vision.stanford.edu/feifeili/">Fei-Fei Li</a>,
<a href="http://cvgl.stanford.edu/silvio/">Silvio Savarese</a>,
<a href="http://web.stanford.edu/~alahi/">Alexandre Alahi</a>
<br>
Presented at [CVPR 2018](http://cvpr2018.thecvf.com/)

Human motion is interpersonal, multimodal, and follows social conventions. In this paper we tackle the trajectory prediction problem by combining tools from sequence prediction and generative adversarial networks: a recurrent sequence-to-sequence model observes motion histories and predicts future behavior, using a novel pooling mechanism to aggregate information across people.

Below we show examples of socially acceptable predictions made by our model in complex scenarios. Each person is denoted by a different color; observed trajectories are shown as dots and predicted trajectories as stars.
<div align='center'>
<img src="images/2.gif"></img>
<img src="images/3.gif"></img>
</div>

If you find this code useful in your research then please cite
```
@inproceedings{gupta2018social,
title={Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks},
author={Gupta, Agrim and Johnson, Justin and Fei-Fei, Li and Savarese, Silvio and Alahi, Alexandre},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
number={CONF},
year={2018}
}
```

## Model
Our model consists of three key components: a Generator (G), a Pooling Module (PM), and a Discriminator (D). G is based on an encoder-decoder framework in which the hidden states of the encoder and decoder are linked via the PM. G takes as input the trajectories of all people in a scene and outputs the corresponding predicted trajectories. D takes as input the entire sequence, comprising both the observed trajectory and the future prediction, and classifies it as "real/fake".

<div align='center'>
<img src='images/model.png' width='1000px'>
</div>
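
The sketch below is only a toy approximation of this flow (hypothetical module names; mean-pooling stands in for the actual Pooling Module, and noise injection and the discriminator are omitted). It is not the implementation in `sgan/models.py`:

```python
import torch
import torch.nn as nn

class ToyGenerator(nn.Module):
    """Toy encoder -> pooling -> decoder flow; not the real TrajectoryGenerator."""
    def __init__(self, emb_dim=16, h_dim=64):
        super().__init__()
        self.embed = nn.Linear(2, emb_dim)            # embed (x, y) displacements
        self.encoder = nn.LSTM(emb_dim, h_dim)        # per-person motion history
        self.pool = nn.Sequential(                    # stand-in pooling module
            nn.Linear(2 * h_dim, h_dim), nn.ReLU())
        self.decoder = nn.LSTMCell(emb_dim, h_dim)
        self.out = nn.Linear(h_dim, 2)

    def forward(self, obs_rel, pred_len=8):
        # obs_rel: (obs_len, num_peds, 2) relative displacements for one scene
        _, (h, _) = self.encoder(self.embed(obs_rel))
        h = h.squeeze(0)                              # (num_peds, h_dim)
        # Mix each person's state with a summary of everyone else (a mean here,
        # whereas the paper pools relative positions with an MLP + max).
        ctx = h.mean(dim=0, keepdim=True).expand_as(h)
        h = self.pool(torch.cat([h, ctx], dim=1))
        c = torch.zeros_like(h)
        step, preds = obs_rel[-1], []
        for _ in range(pred_len):                     # decode all people jointly
            h, c = self.decoder(self.embed(step), (h, c))
            step = self.out(h)
            preds.append(step)
        return torch.stack(preds)                     # (pred_len, num_peds, 2)
```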

## Setup
All code was developed and tested on Ubuntu 16.04 with Python 3.5 and PyTorch 0.4.

You can setup a virtual environment to run the code like this:

```bash
python3 -m venv env # Create a virtual environment
source env/bin/activate # Activate virtual environment
pip install -r requirements.txt # Install dependencies
echo $PWD > env/lib/python3.5/site-packages/sgan.pth # Add current directory to python path
# Work for a while ...
deactivate # Exit virtual environment
```

## Pretrained Models
You can download pretrained models by running the script `bash scripts/download_models.sh`. This will download the following models:

- `sgan-models/<dataset_name>_<pred_len>.pt`: Contains 10 pretrained models, one for each of the five datasets and each prediction length. These models correspond to SGAN-20V-20 in Table 1.
- `sgan-p-models/<dataset_name>_<pred_len>.pt`: Contains 10 pretrained models, one for each of the five datasets and each prediction length. These models correspond to SGAN-20VP-20 in Table 1.

Please refer to [Model Zoo](MODEL_ZOO.md) for results.
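
Each checkpoint file also stores the training arguments, so you can inspect the hyper-parameters of any downloaded model. A small sketch (the `eth_8.pt` filename simply follows the naming pattern above):

```python
import torch

# Each .pt file bundles the training args together with the model weights.
checkpoint = torch.load('models/sgan-models/eth_8.pt', map_location='cpu')
for k, v in sorted(checkpoint['args'].items()):
    print(k, v)
```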

## Running Models
You can use the script `scripts/evaluate_model.py` to easily run any of the pretrained models on any of the datasets. For example, you can replicate the Table 1 results for all datasets for SGAN-20V-20 like this:

```bash
python scripts/evaluate_model.py \
--model_path models/sgan-models
```
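
The SGAN-20VP-20 models can be evaluated the same way, assuming they were extracted to `models/sgan-p-models` as described above:

```bash
python scripts/evaluate_model.py \
  --model_path models/sgan-p-models
```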

## Training new models
Instructions for training new models can be [found here](TRAINING.md).
81 changes: 81 additions & 0 deletions TRAINING.md
@@ -0,0 +1,81 @@
You can train your own model by following these instructions:

## Step 1: Preparing Data
Run the following script to download the dataset:

```bash
bash scripts/download_data.sh
```

This will create the directory `datasets/<dataset_name>` with `train/`, `val/`, and `test/` splits. All datasets are pre-processed to be in world coordinates, i.e. in meters. We support five datasets: ETH, ZARA1, ZARA2, HOTEL, and UNIV. We use a leave-one-out approach: train on four sets and test on the remaining set. We observe the trajectory for 8 time steps (3.2 seconds) and show prediction results for 8 (3.2 seconds) and 12 (4.8 seconds) time steps.
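
For example, after downloading you should see something like the following (assuming lowercase directory names matching the `--dataset_name` values):

```bash
ls datasets/zara1
# train/  val/  test/
```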

## Step 2: Train a model

Now you can train a new model by running the script:

```bash
python scripts/train.py
```

By default this will train a model on Zara1, periodically saving checkpoint files `checkpoint_with_model.pt` and `checkpoint_no_model.pt` to the current working directory. The training script has a number of command-line flags that you can use to configure the model architecture, hyperparameters, and input/output settings, described in the sections below.
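
For example, a run that overrides a few of the documented flags (all other options keep their defaults) might look like this:

```bash
python scripts/train.py \
  --dataset_name eth \
  --pred_len 12 \
  --num_iterations 20000 \
  --checkpoint_name sgan_eth_12
```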

### Optimization

- `--batch_size`: How many sequences to use in each minibatch during training. Default is 64.
- `--num_iterations`: Number of training iterations. Default is 10,000.
- `--num_epochs`: Number of training epochs. Default is 200.

### Dataset options

- `--dataset_name`: The dataset to use for training; must be one of the five supported datasets. Default is `zara1`.
- `--delim`: Delimiter used in the files of the dataset. Default is ' '.
- `--obs_len`: Number of time-steps in input trajectories. Default is 8.
- `--pred_len`: Number of time-steps in output trajectories. Default is 8.
- `--loader_num_workers`: The number of background threads to use for data loading. Default is 4.
- `--skip`: Number of frames to skip while making the dataset. For example, if Sequence<sub>1</sub> in the dataset spans Frame<sub>1</sub> - Frame<sub>N</sub> and skip = 2, then Sequence<sub>2</sub> will span Frame<sub>3</sub> - Frame<sub>N+2</sub>. Default is 1.

### Model options
Our model consists of three components: (1) Generator, (2) Pooling Module, and (3) Discriminator. These flags control the architecture hyperparameters shared by the generator and discriminator.

- `--embedding_dim`: Integer giving the dimension for the embedding layer for input (x, y) coordinates. Default is 64.
- `--num_layers`: Number of layers in LSTM. We only support num_layers = 1.
- `--dropout`: Float value specifying the amount of dropout. Default is 0 (no dropout).
- `--batch_norm`: Boolean flag indicating if MLP has batch norm. Default is False.
- `--mlp_dim`: Integer giving the dimensions of the MLP. Default is 1024. The same MLP options are used across all three components of the model.

**Generator Options**: The generator takes as input all the trajectories for a given sequence and jointly predicts socially acceptable trajectories. These flags control architecture hyperparameters specific to the generator:
- `--encoder_h_dim_g`: Integer giving the dimensions of the hidden layer in the encoder. Default is 64.
- `--decoder_h_dim_g`: Integer giving the dimensions of the hidden layer in the decoder. Default is 64.
- `--noise_dim`: Integer tuple giving the dimensions of the noise added to the input of the decoder. Default is None.
- `--noise_type`: Type of noise to be added. We support two options "uniform" and "gaussian" noise. Default is "gaussian".
- `--noise_mix_type`: The added noise can either be shared across all pedestrians in a scene ("global") or drawn separately for each person ("ped"). Default value is "ped" (see the example command after these flags).
- `--clipping_threshold_g`: Float value indicating the threshold at which the gradients should be clipped. Default is 0.
- `--g_learning_rate`: Learning rate for the generator. Default is 5e-4.
- `--g_steps`: An iteration consists of g_steps forward-backward passes on the generator. Default is 1.
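
For example, to inject 8-dimensional Gaussian noise shared by all pedestrians in a scene (a sketch using only the flags above):

```bash
python scripts/train.py \
  --noise_dim 8 \
  --noise_type gaussian \
  --noise_mix_type global
```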

**Pooling Options**: Our design is general enough to support any pooling type. We support two pooling modules: (1) Social Pooling and (2) Pool Net. These flags control architecture hyperparameters specific to the pooling modules (see the example command after this list):
- `--pooling_type`: Type of pooling module to use. We support two options "pool_net" and "spool". Default is "pool_net".
- `--pool_every_timestep`: We can pool the hidden states at every time step or only once after obs_len. Default is False.
- `--bottleneck_dim`: Output dimensions of the pooled vector for Pool Net. Default is 1024.
- `--neighborhood_size`: Neighborhood size to consider in social pooling. Please refer to S-LSTM paper for details. Default is 2.
- `--grid_size`: The neighborhood is divided into grid_size x grid_size grids. Default is 8.
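
For example, to switch from Pool Net to Social Pooling with its default grid settings:

```bash
python scripts/train.py \
  --pooling_type spool \
  --neighborhood_size 2 \
  --grid_size 8
```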

**Discriminator Options**: These flags control architecture hyperparameters specific to the discriminator:
- `--d_type`: The discriminator can either treat each trajectory independently as described in the paper (option "local") or it can follow something very similar to the generator and pool the information across trajectories to determine if they are real/fake (option "global"). Default is "local".
- `--encoder_h_dim_d`: Integer giving the dimensions of the hidden layer in the encoder. Default is 64.
- `--d_learning_rate`: Learning rate for the discriminator. Default is 5e-4.
- `--d_steps`: An iteration consists of d_steps forward-backward passes on the discriminator. Default is 2.
- `--clipping_threshold_d`: Float value indicating the threshold at which the gradients should be clipped. Default is 0.

### Output Options
These flags control outputs from the training script:

- `--output_dir`: Directory to which checkpoints will be saved. Default is current directory.
- `--print_every`: Training losses are printed and recorded every `--print_every` iterations. Default is 10.
- `--timing`: If this flag is set to 1, measure and print the time that each model component takes to execute.
- `--checkpoint_every`: Checkpoints are saved to disk every `--checkpoint_every` iterations. Default is 100. Each checkpoint contains a history of training losses, error metrics such as ADE and FDE, the current state of the generator, discriminator, and optimizers, as well as all other state information needed to resume training in case it is interrupted. We actually save two checkpoints: one with all information, and one without model parameters; the latter is much smaller, and is convenient for exploring the results of a large hyperparameter sweep without actually loading model parameters.
- `--checkpoint_name`: Base filename for saved checkpoints; default is 'checkpoint', so the filename for the checkpoint with model parameters will be 'checkpoint_with_model.pt' and the filename for the checkpoint without model parameters will be 'checkpoint_no_model.pt'.
- `--restore_from_checkpoint`: Default behavior is to start training from scratch, and overwrite the output checkpoint path if it already exists. If this flag is set to 1, training instead resumes from the output checkpoint file if it already exists. This is useful when running in an environment where jobs can be preempted (see the example after this list).
- `--checkpoint_start_from`: Default behavior is to start training from scratch; if this flag is given, training instead resumes from the specified checkpoint. This takes precedence over `--restore_from_checkpoint` if both are given.
- `--num_samples_check`: When calculating metrics on the training dataset, limit the number of samples to evaluate on, so that checkpointing stays fast for big datasets.
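
For example, to launch a run that can be safely resumed after preemption (using only the flags documented above):

```bash
python scripts/train.py \
  --dataset_name zara1 \
  --checkpoint_name gan_test \
  --restore_from_checkpoint 1
```
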
Binary file added images/2.gif
Binary file added images/3.gif
Binary file added images/model.png
7 changes: 7 additions & 0 deletions requirements.txt
@@ -0,0 +1,7 @@
attrdict==2.0.0
numpy==1.14.5
Pillow==5.1.0
pkg-resources==0.0.0
six==1.11.0
torch==0.4.0
torchvision==0.2.1
3 changes: 3 additions & 0 deletions scripts/download_data.sh
@@ -0,0 +1,3 @@
wget -O datasets.zip 'https://www.dropbox.com/s/8n02xqv3l9q18r1/datasets.zip?dl=0'
unzip -q datasets.zip
rm -rf datasets.zip
3 changes: 3 additions & 0 deletions scripts/download_models.sh
@@ -0,0 +1,3 @@
wget -O models.zip 'https://www.dropbox.com/s/h8q5z4axfgzx9eb/models.zip?dl=0'
unzip -q models.zip
rm -rf models.zip
117 changes: 117 additions & 0 deletions scripts/evaluate_model.py
@@ -0,0 +1,117 @@
import argparse
import os
import torch

from attrdict import AttrDict

from sgan.data.loader import data_loader
from sgan.models import TrajectoryGenerator
from sgan.losses import displacement_error, final_displacement_error
from sgan.utils import relative_to_abs, get_dset_path

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str)
parser.add_argument('--num_samples', default=20, type=int)
parser.add_argument('--dset_type', default='test', type=str)


def get_generator(checkpoint):
    args = AttrDict(checkpoint['args'])
    generator = TrajectoryGenerator(
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        embedding_dim=args.embedding_dim,
        encoder_h_dim=args.encoder_h_dim_g,
        decoder_h_dim=args.decoder_h_dim_g,
        mlp_dim=args.mlp_dim,
        num_layers=args.num_layers,
        noise_dim=args.noise_dim,
        noise_type=args.noise_type,
        noise_mix_type=args.noise_mix_type,
        pooling_type=args.pooling_type,
        pool_every_timestep=args.pool_every_timestep,
        dropout=args.dropout,
        bottleneck_dim=args.bottleneck_dim,
        neighborhood_size=args.neighborhood_size,
        grid_size=args.grid_size,
        batch_norm=args.batch_norm)
    generator.load_state_dict(checkpoint['g_state'])
    generator.cuda()
    generator.train()
    return generator


def evaluate_helper(error, seq_start_end):
    # For each scene, sum the per-pedestrian errors and keep the best
    # (minimum) value over the sampled predictions.
    sum_ = 0
    error = torch.stack(error, dim=1)

    for (start, end) in seq_start_end:
        start = start.item()
        end = end.item()
        _error = error[start:end]
        _error = torch.sum(_error, dim=0)
        _error = torch.min(_error)
        sum_ += _error
    return sum_


def evaluate(args, loader, generator, num_samples):
    ade_outer, fde_outer = [], []
    total_traj = 0
    with torch.no_grad():
        for batch in loader:
            batch = [tensor.cuda() for tensor in batch]
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             non_linear_ped, loss_mask, seq_start_end) = batch

            ade, fde = [], []
            total_traj += pred_traj_gt.size(1)

            # Draw num_samples predictions and collect raw per-pedestrian errors.
            for _ in range(num_samples):
                pred_traj_fake_rel = generator(
                    obs_traj, obs_traj_rel, seq_start_end
                )
                pred_traj_fake = relative_to_abs(
                    pred_traj_fake_rel, obs_traj[-1]
                )
                ade.append(displacement_error(
                    pred_traj_fake, pred_traj_gt, mode='raw'
                ))
                fde.append(final_displacement_error(
                    pred_traj_fake[-1], pred_traj_gt[-1], mode='raw'
                ))

            ade_sum = evaluate_helper(ade, seq_start_end)
            fde_sum = evaluate_helper(fde, seq_start_end)

            ade_outer.append(ade_sum)
            fde_outer.append(fde_sum)
        ade = sum(ade_outer) / (total_traj * args.pred_len)
        fde = sum(fde_outer) / (total_traj)
        return ade, fde


def main(args):
    if os.path.isdir(args.model_path):
        filenames = os.listdir(args.model_path)
        filenames.sort()
        paths = [
            os.path.join(args.model_path, file_) for file_ in filenames
        ]
    else:
        paths = [args.model_path]

    for path in paths:
        checkpoint = torch.load(path)
        generator = get_generator(checkpoint)
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, args.dset_type)
        _, loader = data_loader(_args, path)
        ade, fde = evaluate(_args, loader, generator, args.num_samples)
        print('Dataset: {}, Pred Len: {}, ADE: {:.2f}, FDE: {:.2f}'.format(
            _args.dataset_name, _args.pred_len, ade, fde))


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
20 changes: 20 additions & 0 deletions scripts/print_args.py
@@ -0,0 +1,20 @@
import argparse
import torch

"""
Tiny utility to print the command-line args used for a checkpoint
"""

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint')


def main(args):
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    for k, v in checkpoint['args'].items():
        print(k, v)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
34 changes: 34 additions & 0 deletions scripts/run_traj.sh
@@ -0,0 +1,34 @@
python train.py \
  --dataset_name 'zara1' \
  --delim tab \
  --d_type 'local' \
  --pred_len 8 \
  --encoder_h_dim_g 32 \
  --encoder_h_dim_d 64 \
  --decoder_h_dim 32 \
  --embedding_dim 16 \
  --bottleneck_dim 32 \
  --mlp_dim 64 \
  --num_layers 1 \
  --noise_dim 8 \
  --noise_type gaussian \
  --noise_mix_type global \
  --pool_every_timestep 0 \
  --l2_loss_weight 1 \
  --batch_norm 0 \
  --dropout 0 \
  --batch_size 32 \
  --g_learning_rate 1e-3 \
  --g_steps 1 \
  --d_learning_rate 1e-3 \
  --d_steps 2 \
  --checkpoint_every 10 \
  --print_every 50 \
  --num_iterations 20000 \
  --num_epochs 500 \
  --pooling_type 'pool_net' \
  --clipping_threshold_g 1.5 \
  --best_k 10 \
  --gpu_num 1 \
  --checkpoint_name gan_test \
  --restore_from_checkpoint 0