Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] OT barycenters for generic transport costs #715

Open
wants to merge 28 commits into
base: master
Choose a base branch
from

Conversation

eloitanguy
Copy link
Collaborator

@eloitanguy eloitanguy commented Jan 20, 2025

Types of changes

  • contribute a generalisation of free_support that accepts any cost function (implements this paper)
  • contribute a specific version of this function in ot.gmm for fast computation of GMM barycenters
  • example for new barycenter function
  • example of GMM barycenters
  • add summary of features and example in README.md
  • allow the barycentric projection solver to accept ground_bary=None and to accept different notions of barycentric projections w.r.t. costs different to L2.
  • simple BCD solver with GD on the positions X

Motivation and context / Related issue

How has this been tested (if it applies)

  • full coverage in test/test_ot.py and test/test_gmm.py

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@github-actions github-actions bot added the ot.lp label Jan 20, 2025
Copy link

codecov bot commented Jan 20, 2025

Codecov Report

Attention: Patch coverage is 98.31461% with 3 lines in your changes missing coverage. Please review.

Project coverage is 97.14%. Comparing base (cbdf979) to head (371e3e7).

Additional details and impacted files
@@           Coverage Diff            @@
##           master     #715    +/-   ##
========================================
  Coverage   97.13%   97.14%            
========================================
  Files         100      100            
  Lines       20369    20545   +176     
========================================
+ Hits        19786    19959   +173     
- Misses        583      586     +3     
🚀 New features to boost your workflow:
  • Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@github-actions github-actions bot added the Tests label Jan 21, 2025
@github-actions github-actions bot added the CI label Jan 21, 2025
@eloitanguy eloitanguy changed the title [WIP] OT barycenters for generic transport costs [MRG] OT barycenters for generic transport costs Jan 21, 2025
@github-actions github-actions bot added the ot.gmm label Mar 3, 2025
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @eloitanguy this is very nice work as usual. I have a couple comments below

List of K arrays of measure weights, each of shape (m_k).
X_init : array-like
Array of shape (n, d) representing initial barycenter points.
cost_list : list of callable
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you accept a callable (if same cost for eveyone) and a list?

@@ -426,3 +426,172 @@ def generalized_free_support_barycenter(
return Y, log_dict
else:
return Y


class StoppingCriterionReached(Exception):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a fan of teh exception, because of the try there nuerical error in teh loop might not be detected... could you do a classical keak with variable stopingcriterion please?



# ground barycenter function
def B(y, its=150, lr=1, stop_threshold=stop_threshold):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cold you please do a function fr that (torch only) in pot that can be used if B is not passed (deafult=None). This will simplify the call for torch cost and keep the possibility to compute closed form.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that it would be nice to have B accept a warm start for X0 also

axis = [-4, 4, -2, 6]
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
for k in range(K):
draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we have different colors for te siferent soure distribs ()and maybe black/gray for the barycenter?

@eloitanguy
Copy link
Collaborator Author

After discussion with @rflamary we decided to expand the PR with additional features, namely two barycenter solvers:

  • a solver for the iterates of H of 76 with an optional ground_bary function computing barycenters between the points on different spaces, and with a user-defined notion of cost on each space $\mathcal{Y}_k$ with respect to which the barycentric projection is performed (defaults to L2, which is fast but often not accurate/meaningful depending on the spaces and costs)
  • another BCD solver which alternates between the plans and the the positions X (optimisation w.r.t. the positions with a pytorch optimiser)

@eloitanguy eloitanguy changed the title [MRG] OT barycenters for generic transport costs [WIP] OT barycenters for generic transport costs Mar 18, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants