Changed calculation of log_sum_exp(x1, x2) #1772

Merged
merged 14 commits into from
Mar 19, 2020

Conversation

pgree
Contributor

@pgree pgree commented Mar 10, 2020

Summary

The partial derivatives of log_sum_exp(x, y) were losing accuracy in the regime where x and y were equal and greater than around 1e14. In the original log_sum_exp(x1, x2) function, the derivative of log_sum_exp with respect to x1 was evaluated by this formula:

exp(x1 - (y + log(exp(x1 - y) + exp(x2 - y))))

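where y is the maximum of x1 and x2. This loses accuracy when y is close to x1 and log(exp(x1 - y) + exp(x2 - y)) is small relative to y, because the log term is rounded away when it is added to y. For example, when x1 = 1e20 and x2 = 1e20

x1 - (y + log(exp(x1 - y) + exp(x2 - y)))

evaluates to 0, so the computed derivative is exp(0) = 1 rather than the correct value 0.5.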

I changed the calculation of the partial derivative of log_sum_exp with respect to x1 to be

1 / (1 + exp(x2 - x1))

and the partial derivative of log_sum_exp with respect to x2 to be

1 / (1 + exp(x1 - x2)).

Tests

I wrote a couple of new tests to check that the partials are accurate with large and equal argument values as in @martinmodrak 's issue raised in #1679.

Side Effects

None


@pgree pgree changed the title from "Bugfix/1679 log sum exp" to "Changed calculation of log_sum_exp(x1, x2)" Mar 10, 2020
@stan-buildbot
Contributor


Name Old Result New Result Ratio Performance change( 1 - new / old )
gp_pois_regr/gp_pois_regr.stan 4.86 4.81 1.01 1.03% faster
low_dim_corr_gauss/low_dim_corr_gauss.stan 0.02 0.02 1.02 1.72% faster
eight_schools/eight_schools.stan 0.09 0.09 0.98 -2.23% slower
gp_regr/gp_regr.stan 0.22 0.22 1.0 -0.3% slower
irt_2pl/irt_2pl.stan 6.46 6.44 1.0 0.28% faster
performance.compilation 87.39 86.67 1.01 0.82% faster
low_dim_gauss_mix_collapse/low_dim_gauss_mix_collapse.stan 7.6 7.54 1.01 0.71% faster
pkpd/one_comp_mm_elim_abs.stan 19.98 21.4 0.93 -7.13% slower
sir/sir.stan 96.15 91.63 1.05 4.7% faster
gp_regr/gen_gp_data.stan 0.05 0.05 0.96 -3.83% slower
low_dim_gauss_mix/low_dim_gauss_mix.stan 2.96 2.95 1.0 0.3% faster
pkpd/sim_one_comp_mm_elim_abs.stan 0.31 0.33 0.96 -4.57% slower
arK/arK.stan 1.74 1.74 1.0 -0.45% slower
arma/arma.stan 0.67 0.66 1.01 0.57% faster
garch/garch.stan 0.51 0.51 1.0 -0.24% slower
Mean result: 0.995026828704

Jenkins Console Log
Blue Ocean
Commit hash: b513358


Machine information ProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010

CPU:
Intel(R) Xeon(R) CPU E5-1680 v2 @ 3.00GHz

G++:
Configured with: --prefix=/Applications/Xcode.app/Contents/Developer/usr --with-gxx-include-dir=/usr/include/c++/4.2.1
Apple LLVM version 7.0.2 (clang-700.1.81)
Target: x86_64-apple-darwin15.6.0
Thread model: posix

Clang:
Apple LLVM version 7.0.2 (clang-700.1.81)
Target: x86_64-apple-darwin15.6.0
Thread model: posix

@stan-buildbot
Contributor


Name Old Result New Result Ratio Performance change( 1 - new / old )
gp_pois_regr/gp_pois_regr.stan 4.85 4.82 1.01 0.56% faster
low_dim_corr_gauss/low_dim_corr_gauss.stan 0.02 0.02 1.0 0.05% faster
eight_schools/eight_schools.stan 0.09 0.09 1.03 2.61% faster
gp_regr/gp_regr.stan 0.22 0.22 1.01 0.94% faster
irt_2pl/irt_2pl.stan 6.43 6.48 0.99 -0.79% slower
performance.compilation 88.55 86.36 1.03 2.48% faster
low_dim_gauss_mix_collapse/low_dim_gauss_mix_collapse.stan 7.54 7.63 0.99 -1.29% slower
pkpd/one_comp_mm_elim_abs.stan 21.13 20.13 1.05 4.76% faster
sir/sir.stan 93.48 94.66 0.99 -1.26% slower
gp_regr/gen_gp_data.stan 0.05 0.05 1.02 1.56% faster
low_dim_gauss_mix/low_dim_gauss_mix.stan 2.96 2.96 1.0 -0.03% slower
pkpd/sim_one_comp_mm_elim_abs.stan 0.32 0.31 1.02 1.59% faster
arK/arK.stan 1.74 1.74 1.0 0.09% faster
arma/arma.stan 0.66 0.66 1.0 -0.1% slower
garch/garch.stan 0.51 0.51 1.01 0.78% faster
Mean result: 1.00830238348

Jenkins Console Log
Blue Ocean
Commit hash: b513358



Member

@bbbales2 bbbales2 left a comment

Here yah go!

@@ -22,8 +22,8 @@ class log_sum_exp_vv_vari : public op_vv_vari {
log_sum_exp_vv_vari(vari* avi, vari* bvi)
: op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {}
void chain() {
avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);

Member

Would it be possible to get rid of calculate_chain everywhere? It seems like a strange function.

Also used here (which should be fixed as part of this pull):

avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);

In stan/math/rev/fun/log_diff_exp.hpp and in stan/math/rev/fun/log1p_exp.hpp.

EXPECT_FLOAT_EQ(a.adj(), 1.0);

var a1 = 1e50;
var a2 = 1;

Member

Once you remove the other calculate_chain thing we'll need some tests where a1 is a var and a2 is a double.

@bob-carpenter
Member

The first rule of naming is that names should be short conventional placeholders where possible and mean something everywhere else. The second rule of naming is that names should be as short as possible, because big ones obfuscate the structure of your code and we want to do everything in our power to make the actual code readable. The third rule is to use your judgement in balancing the first and second rules.

With calculate_chain, the calculate part isn't really adding anything because every function does computation. That goes for other generic naming conventions like _helper and do_it.

The chain part is misleading because there's no object named "chain" being computed. The method chain() is so named because it's responsible for applying the chain rule.

What is being computed? A partial derivative. So we can call the function something like log_sum_exp_partial, which describes what it returns and will read well in code.

Finally, always let your functions pull things apart if it doesn't involve extra overhead. So you can pass the avi_ pointer and let the function pull out avi_->val_. That helps get rid of code duplication when it's called twice.

P.S. This was just to give you an idea of how to think about naming. We're not going to micromanage every name in every program!
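
As a rough sketch of the naming and pointer-passing suggestion, assuming a stripped-down stand-in for vari: mock_vari and log_sum_exp_chain are hypothetical names for this illustration, and log_sum_exp_partial is the name proposed above, not an existing Stan Math function.

```cpp
#include <cmath>

// Hypothetical stand-in for Stan Math's vari, with only the two fields used here.
struct mock_vari {
  double val_;  // value
  double adj_;  // adjoint
};

// Partial of log_sum_exp(a, b) with respect to a. The function pulls the
// values out of the pointers itself, so call sites stay short.
inline double log_sum_exp_partial(const mock_vari* a, const mock_vari* b) {
  return 1.0 / (1.0 + std::exp(b->val_ - a->val_));
}

// What a chain() body could look like: the helper is called once per operand.
inline void log_sum_exp_chain(mock_vari* a, mock_vari* b, double adj) {
  a->adj_ += adj * log_sum_exp_partial(a, b);
  b->adj_ += adj * log_sum_exp_partial(b, a);
}
```

The name says what is returned (a partial derivative), and both call sites read the same way with the arguments swapped.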

@pgree
Contributor Author

pgree commented Mar 11, 2020

Thanks @bob-carpenter and @bbbales2

Is it better to change the function calculate_chain(x, y) to compute the new formula for the derivative (and change the name of the function)? Or should I replace the call to the function everywhere it appears with something like

1 / (1 + exp(x2 - x1))?

@bob-carpenter
Member

Keep the function (it removes code duplication) and rename it (so it self documents).

@bbbales2
Member

calculate_chain(x, y)

I think it makes sense to get rid of this one.

@bob-carpenter
Member

Ben's the official reviewer, so his opinion trumps mine in this one!

It's really not as big a deal as we're making it here. It's just good to set out conventions when you start coding so you don't have to think later.

@bbbales2
Member

Mm, I'm not convinced we should keep it. Really not much code duplication.

stan/math/rev/fun/log1p_exp.hpp:  void chain() { avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); }

stan/math/rev/fun/log_sum_exp.hpp:    avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
stan/math/rev/fun/log_sum_exp.hpp:    bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_);
stan/math/rev/fun/log_sum_exp.hpp:      avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
stan/math/rev/fun/log_sum_exp.hpp:      bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_);
stan/math/rev/fun/log_sum_exp.hpp:      vis_[i]->adj_ += adj_ * calculate_chain(vis_[i]->val_, val_);

stan/math/rev/fun/log_diff_exp.hpp:    avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
stan/math/rev/fun/log_diff_exp.hpp:      avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);

It also lives in rev, which I think is the wrong place and the comment on the implementation doesn't seem right:

inline double calculate_chain(double x, double val) {
  return std::exp(x - val);  // works out to inv_logit(x)
}

@bbbales2
Member

And honestly is that just wrong? It just seems wrong to me. What is happening here.

exp(x - val) is not 1 / (1 + exp(x - val)) (or 1 / (1 + exp(val - x)) or whatever)

@pgree
Contributor Author

pgree commented Mar 11, 2020

In log_sum_exp.hpp the function calculate_chain(x, val) was being called with

val = log_sum_exp(x, y).

So it works out that

calculate_chain(x, log_sum_exp(x, y)) = 1 / (1 + exp(y - x)).

I need to check how it's being used in the other functions to make sure that this switch would work.
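
For reference, the identity above can be written out step by step, taking val = log_sum_exp(x, y) = log(e^x + e^y):

```latex
\begin{align*}
\exp\bigl(x - \log(e^{x} + e^{y})\bigr)
  &= \frac{e^{x}}{e^{x} + e^{y}} \\
  &= \frac{1}{1 + e^{y - x}}
\end{align*}
```

The last step divides the numerator and denominator by e^x. The result is the logistic function evaluated at x - y.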

@bbbales2
Member

Aaaah, thanks for that clarification.

@andrjohns
Collaborator

Quick suggestion (not a blocking review comment): some of these could use inv_logit, which would be more resistant to overflow and underflow.

For log_sum_exp:

  void chain() {
    avi_->adj_ += adj_ * inv_logit(avi_->val_ - bvi_->val_);
    bvi_->adj_ += adj_ * inv_logit(bvi_->val_ - avi_->val_);
  }

For log1p_exp:

  void chain() { avi_->adj_ += adj_ * inv_logit(avi_->val_); }
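
For readers unfamiliar with the function, a common way to evaluate a logistic function so that exp is never called on a large positive argument is sketched below. inv_logit_sketch is a made-up name for this illustration; it is an assumption about the general technique, not a copy of Stan Math's inv_logit.

```cpp
#include <cmath>

// Illustrative logistic function: 1 / (1 + exp(-u)).
// Branching on the sign of u keeps the argument of exp non-positive,
// so exp can underflow to 0 but never overflow to infinity.
inline double inv_logit_sketch(double u) {
  if (u >= 0.0) {
    return 1.0 / (1.0 + std::exp(-u));
  }
  const double e = std::exp(u);
  return e / (1.0 + e);
}
```

Under this sketch, the partial of log_sum_exp(x1, x2) with respect to x1 is inv_logit_sketch(x1 - x2), which gives exactly 0.5 when the two arguments are equal, however large they are.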

@stan-buildbot
Contributor


Name Old Result New Result Ratio Performance change( 1 - new / old )
gp_pois_regr/gp_pois_regr.stan 4.87 4.87 1.0 -0.0% slower
low_dim_corr_gauss/low_dim_corr_gauss.stan 0.02 0.02 0.98 -2.17% slower
eight_schools/eight_schools.stan 0.09 0.09 0.99 -0.75% slower
gp_regr/gp_regr.stan 0.22 0.22 0.98 -1.77% slower
irt_2pl/irt_2pl.stan 6.49 6.43 1.01 0.88% faster
performance.compilation 87.39 86.31 1.01 1.23% faster
low_dim_gauss_mix_collapse/low_dim_gauss_mix_collapse.stan 7.52 7.52 1.0 -0.07% slower
pkpd/one_comp_mm_elim_abs.stan 21.38 21.24 1.01 0.68% faster
sir/sir.stan 95.73 93.72 1.02 2.1% faster
gp_regr/gen_gp_data.stan 0.05 0.05 1.01 0.91% faster
low_dim_gauss_mix/low_dim_gauss_mix.stan 2.96 2.95 1.0 0.13% faster
pkpd/sim_one_comp_mm_elim_abs.stan 0.31 0.32 0.96 -3.84% slower
arK/arK.stan 1.75 1.92 0.92 -9.22% slower
arma/arma.stan 0.66 0.68 0.98 -2.5% slower
garch/garch.stan 0.51 0.52 0.99 -0.78% slower
Mean result: 0.990644009267

Jenkins Console Log
Blue Ocean
Commit hash: 088ee0c



@martinmodrak
Contributor

Thanks for handling this! A few thoughts:

I actually played with this issue in #1677 (as it was a convenient example to play with). However, as the tools for testing identities are not going to be ready soon (we decided with @bob-carpenter to take the longer route there), I am not sure you can get much from it (although I do use a few more identities for testing).

I second @andrjohns on the inv_logit suggestion.

Generally, I would suggest making each test loop over a range of values instead of using just one value.
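
One possible shape for such a looped check, sketched on plain doubles rather than with the project's test framework. The function names and the choice of identities are assumptions for illustration: the two partials of log_sum_exp should sum to 1, and equal arguments should give exactly 0.5.

```cpp
#include <cmath>
#include <vector>

// Illustrative partial of log_sum_exp(x1, x2) with respect to x1.
inline double lse_partial(double x1, double x2) {
  return 1.0 / (1.0 + std::exp(x2 - x1));
}

// Checks every ordered pair drawn from `values`:
//  - the two partials sum to 1 within `tol`
//  - equal arguments give a partial of exactly 0.5
inline bool partials_consistent(const std::vector<double>& values,
                                double tol = 1e-12) {
  for (double a : values) {
    for (double b : values) {
      const double da = lse_partial(a, b);
      const double db = lse_partial(b, a);
      if (std::fabs(da + db - 1.0) > tol) return false;
      if (a == b && da != 0.5) return false;
    }
  }
  return true;
}
```

Calling partials_consistent({-1e20, -100.0, -1.0, 0.0, 1.0, 100.0, 1e20, 1e50}) covers small, moderate, and very large magnitudes in one pass.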

@stan-buildbot
Contributor


Name Old Result New Result Ratio Performance change( 1 - new / old )
gp_pois_regr/gp_pois_regr.stan 4.91 4.86 1.01 0.95% faster
low_dim_corr_gauss/low_dim_corr_gauss.stan 0.02 0.02 1.0 0.19% faster
eight_schools/eight_schools.stan 0.09 0.09 1.04 3.74% faster
gp_regr/gp_regr.stan 0.22 0.22 1.01 0.95% faster
irt_2pl/irt_2pl.stan 6.5 6.43 1.01 1.03% faster
performance.compilation 87.27 86.35 1.01 1.05% faster
low_dim_gauss_mix_collapse/low_dim_gauss_mix_collapse.stan 7.59 7.52 1.01 0.93% faster
pkpd/one_comp_mm_elim_abs.stan 20.55 21.52 0.96 -4.71% slower
sir/sir.stan 94.45 93.79 1.01 0.7% faster
gp_regr/gen_gp_data.stan 0.05 0.05 0.99 -1.02% slower
low_dim_gauss_mix/low_dim_gauss_mix.stan 2.95 2.95 1.0 -0.08% slower
pkpd/sim_one_comp_mm_elim_abs.stan 0.31 0.31 1.0 -0.13% slower
arK/arK.stan 1.75 1.74 1.0 0.35% faster
arma/arma.stan 0.67 0.66 1.02 1.77% faster
garch/garch.stan 0.51 0.52 1.0 -0.12% slower
Mean result: 1.0040328315

Jenkins Console Log
Blue Ocean
Commit hash: a793873



@stan-buildbot
Contributor


Name Old Result New Result Ratio Performance change( 1 - new / old )
gp_pois_regr/gp_pois_regr.stan 4.92 4.84 1.02 1.57% faster
low_dim_corr_gauss/low_dim_corr_gauss.stan 0.02 0.02 0.98 -2.19% slower
eight_schools/eight_schools.stan 0.09 0.09 1.04 4.19% faster
gp_regr/gp_regr.stan 0.22 0.22 0.99 -1.3% slower
irt_2pl/irt_2pl.stan 6.43 6.45 1.0 -0.3% slower
performance.compilation 87.71 86.45 1.01 1.44% faster
low_dim_gauss_mix_collapse/low_dim_gauss_mix_collapse.stan 7.51 7.57 0.99 -0.74% slower
pkpd/one_comp_mm_elim_abs.stan 21.2 21.26 1.0 -0.27% slower
sir/sir.stan 93.88 93.97 1.0 -0.1% slower
gp_regr/gen_gp_data.stan 0.05 0.05 1.04 3.4% faster
low_dim_gauss_mix/low_dim_gauss_mix.stan 2.95 2.96 1.0 -0.17% slower
pkpd/sim_one_comp_mm_elim_abs.stan 0.31 0.31 1.01 0.98% faster
arK/arK.stan 1.76 1.75 1.01 0.59% faster
arma/arma.stan 0.66 0.67 0.99 -1.46% slower
garch/garch.stan 0.52 0.52 1.01 0.55% faster
Mean result: 1.00444677354

Jenkins Console Log
Blue Ocean
Commit hash: a793873



Member

@bbbales2 bbbales2 left a comment

One question and then this is good! Thanks!

@@ -17,7 +16,7 @@ class log_diff_exp_vv_vari : public op_vv_vari {
   log_diff_exp_vv_vari(vari* avi, vari* bvi)
       : op_vv_vari(log_diff_exp(avi->val_, bvi->val_), avi, bvi) {}
   void chain() {
-    avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
+    avi_->adj_ -= adj_ / expm1(bvi_->val_ - avi_->val_);

Member

There's no way to use inv_logit here? If so, that's fine, just checking.

Member

I don't see how you could, because the 1 is subtracted in one denominator and added in the other.

1 / expm1(x) = 1 / (exp(x) - 1)
inv_logit(u) = 1 / (1 + exp(-u))
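
For context, the expm1 form follows from differentiating log_diff_exp(a, b) = log(exp(a) - exp(b)) with a > b: the partial with respect to a is exp(a) / (exp(a) - exp(b)) = 1 / (1 - exp(b - a)) = -1 / expm1(b - a). A minimal sketch on plain doubles, with function names that are made up for this illustration:

```cpp
#include <cmath>

// Illustrative partial of log_diff_exp(a, b) with respect to a, for a > b:
//   exp(a) / (exp(a) - exp(b)) = -1 / expm1(b - a)
inline double log_diff_exp_partial_a(double a, double b) {
  return -1.0 / std::expm1(b - a);
}

// Illustrative partial with respect to b, for a > b:
//   -exp(b) / (exp(a) - exp(b)) = -1 / expm1(a - b)
inline double log_diff_exp_partial_b(double a, double b) {
  return -1.0 / std::expm1(a - b);
}
```

std::expm1 evaluates exp(x) - 1 without the cancellation that the literal subtraction suffers when x is near 0, which matters here when a and b are close together.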

Member

Oh good point. I was too lazy to actually think through the math at all.

@bbbales2 bbbales2 merged commit c26cda1 into stan-dev:develop Mar 19, 2020