This repository contains the official PyTorch implementation of Gradient Estimation with Stochastic Softmax Tricks.
It provides code to reproduce the NRI for Graph Layout experiments, and includes:
- Spanning tree SST with exponential family entropy relaxation
- Spanning tree SMT (batched PyTorch and C++ implementations of Kruskal's algorithm)
- Top k SST with exponential family entropy relaxation
- Directed spanning tree SMT (C++ implementation of Chu–Liu/Edmonds' algorithm)
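To illustrate the idea behind the spanning tree SMT listed above, here is a minimal pure-Python sketch: each edge logit is perturbed with random noise and Kruskal's algorithm returns the maximum spanning tree of the perturbed weights. It is not the batched PyTorch or C++ code in this repository. The function names, the use of Gumbel noise, and the toy graph are all assumptions made for illustration.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one standard Gumbel(0, 1) sample by inverse transform sampling."""
    # Clamp away from 0 so that log(u) stays finite.
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))


def kruskal_max_spanning_tree(num_vertices, weighted_edges):
    """Return the edges of a maximum spanning tree via Kruskal's algorithm.

    weighted_edges is a list of (weight, u, v) tuples for an undirected graph.
    """
    parent = list(range(num_vertices))

    def find(x):
        # Path halving keeps the union-find trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    tree = []
    # Visit edges from heaviest to lightest; keep an edge if it joins two components.
    for weight, u, v in sorted(weighted_edges, reverse=True):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            tree.append((u, v))
    return tree


def sample_spanning_tree(num_vertices, edge_logits, rng):
    """Perturb each edge logit with Gumbel noise, then take the maximum spanning tree."""
    perturbed = [
        (logit + sample_gumbel(rng), u, v) for (u, v), logit in edge_logits.items()
    ]
    return kruskal_max_spanning_tree(num_vertices, perturbed)


# Hypothetical example: a complete graph on 4 vertices with all logits equal to zero.
rng = random.Random(0)
logits = {(u, v): 0.0 for u in range(4) for v in range(u + 1, 4)}
tree = sample_spanning_tree(4, logits, rng)
print(tree)  # three edges that connect all four vertices
```

With all logits equal, every edge is equally likely to carry the largest perturbed weight; raising a logit makes its edge more likely to appear in the sampled tree.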
Requirements:
- PyTorch 1.4
- Python 3.7
- CVXPY (pip install cvxpy)
- cvxpylayers (pip install cvxpylayers)
- NetworkX (pip install networkx)
To import core/kruskals/kruskals.py, the C++ implementation of Kruskal's algorithm must first be compiled and installed:
cd core/kruskals
python setup_kruskals.py install
The graph layout data can be generated using the script data_scripts/generate_graph_layout_dataset.py. To generate the exact data used in the experiments of the paper, run:
python -m data_scripts.generate_graph_layout_dataset --num_timesteps 10 --num_vertices 10
python -m data_scripts.generate_graph_layout_dataset --num_timesteps 20 --num_vertices 10
To reproduce the spanning tree SST results for T=10, run:
python train.py --suffix _novar_1skip_10t_1r_graph10 --timesteps 10 \
--prediction_steps 9 --sst tree --relaxation exp_family_entropy \
--max_range 15 --symmeterize_logits True
and tune --lr and --temp by random search, sampling uniformly in log scale over a hypercube-shaped search space.
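The log-scale random search described above can be sketched as follows. This is only an illustration: the helper names, the number of trials, and in particular the search ranges are assumptions and are not taken from the paper or this repository.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Sample a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_configs(bounds, num_trials, seed=0):
    """Draw independent log-uniform samples for every hyperparameter in bounds."""
    rng = random.Random(seed)
    return [
        {name: sample_log_uniform(low, high, rng) for name, (low, high) in bounds.items()}
        for _ in range(num_trials)
    ]


# Assumed search ranges, for illustration only; they are not taken from the paper.
bounds = {"lr": (1e-5, 1e-2), "temp": (1e-1, 1e1)}
for config in sample_configs(bounds, num_trials=3):
    # Each line gives the flags to append to the train.py command for one trial.
    print("--lr {lr:.2e} --temp {temp:.2e}".format(**config))
```

Because each hyperparameter is sampled independently within its own interval, the search space is an axis-aligned box (a hypercube in log coordinates).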
To reproduce the topk SST results for T=10, run:
python train.py --suffix _novar_1skip_10t_1r_graph10 --timesteps 10 \
--prediction_steps 9 --sst topk --relaxation exp_family_entropy \
--symmeterize_logits True
and tune --lr and --temp by random search, sampling uniformly in log scale over a hypercube-shaped search space. In addition, tune --eps_for_finitediff by random search over [1e0, 1e-1, 1e-2, 1e-3, 1e-4].
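As a rough sketch of how one top-k trial could be assembled, the snippet below picks --eps_for_finitediff uniformly from the candidate values listed above and appends it to the fixed arguments of the top-k command. The helper name build_command is hypothetical, and the --lr and --temp values are placeholders rather than recommended settings. The snippet only prints the command; it does not run train.py.

```python
import random
import shlex

# Candidate values listed above for --eps_for_finitediff.
EPS_CANDIDATES = [1e0, 1e-1, 1e-2, 1e-3, 1e-4]

# Fixed arguments copied from the top-k command above.
BASE_ARGS = [
    "python", "train.py",
    "--suffix", "_novar_1skip_10t_1r_graph10",
    "--timesteps", "10",
    "--prediction_steps", "9",
    "--sst", "topk",
    "--relaxation", "exp_family_entropy",
    "--symmeterize_logits", "True",
]


def build_command(lr, temp, rng):
    """Append tuned flags, picking --eps_for_finitediff uniformly from the candidates."""
    eps = rng.choice(EPS_CANDIDATES)
    tuned = ["--lr", repr(lr), "--temp", repr(temp), "--eps_for_finitediff", repr(eps)]
    return BASE_ARGS + tuned


rng = random.Random(0)
# The lr and temp values here are placeholders, not recommended settings.
command = build_command(lr=1e-3, temp=1.0, rng=rng)
print(" ".join(shlex.quote(arg) for arg in command))
```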
To reproduce the independent edges results for T=10, run:
python train.py --suffix _novar_1skip_10t_1r_graph10 --timesteps 10 \
--prediction_steps 9 --sst indep --symmeterize_logits False
and tune --lr and --temp by random search, sampling uniformly in log scale over a hypercube-shaped search space.
To reproduce the spanning tree SST results for T=20, replace the corresponding arguments in the commands above with --suffix _novar_1skip_20t_1r_graph10 --timesteps 20 --prediction_steps 10.
If you use this code for your research, please cite our paper:
@article{paulus2020gradient,
title={Gradient Estimation with Stochastic Softmax Tricks},
author={Paulus, Max B and Choi, Dami and Tarlow, Daniel and Krause, Andreas and Maddison, Chris J},
journal={Advances in Neural Information Processing Systems},
year={2020}
}