-
Notifications
You must be signed in to change notification settings - Fork 513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] OT barycenters for generic transport costs #715
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #715 +/- ##
========================================
Coverage 97.13% 97.14%
========================================
Files 100 100
Lines 20369 20545 +176
========================================
+ Hits 19786 19959 +173
- Misses 583 586 +3 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @eloitanguy this is very nice work as usual. I have a couple comments below
ot/lp/_barycenter_solvers.py
Outdated
List of K arrays of measure weights, each of shape (m_k). | ||
X_init : array-like | ||
Array of shape (n, d) representing initial barycenter points. | ||
cost_list : list of callable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you accept a callable (if same cost for eveyone) and a list?
ot/lp/_barycenter_solvers.py
Outdated
@@ -426,3 +426,172 @@ def generalized_free_support_barycenter( | |||
return Y, log_dict | |||
else: | |||
return Y | |||
|
|||
|
|||
class StoppingCriterionReached(Exception): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not a fan of teh exception, because of the try there nuerical error in teh loop might not be detected... could you do a classical keak with variable stopingcriterion please?
|
||
|
||
# ground barycenter function | ||
def B(y, its=150, lr=1, stop_threshold=stop_threshold): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cold you please do a function fr that (torch only) in pot that can be used if B is not passed (deafult=None). This will simplify the call for torch cost and keep the possibility to compute closed form.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that it would be nice to have B accept a warm start for X0 also
axis = [-4, 4, -2, 6] | ||
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) | ||
for k in range(K): | ||
draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we have different colors for te siferent soure distribs ()and maybe black/gray for the barycenter?
After discussion with @rflamary we decided to expand the PR with additional features, namely two barycenter solvers:
|
Types of changes
free_support
that accepts any cost function (implements this paper)ot.gmm
for fast computation of GMM barycentersREADME.md
ground_bary=None
and to accept different notions of barycentric projections w.r.t. costs different to L2.Motivation and context / Related issue
How has this been tested (if it applies)
test/test_ot.py
andtest/test_gmm.py
PR checklist