From ec412fe4361f61fea09126a3e96205482b05a9c7 Mon Sep 17 00:00:00 2001 From: "doron.haviv12@gmail.com" Date: Fri, 15 Nov 2024 15:23:26 -0500 Subject: [PATCH 1/2] add pot --- .../riemannian_wasserstein/utils_Geom.py | 273 ++++++++++++++++++ .../wasserstein/DefaultConfig.py | 2 +- .../wasserstein/WassersteinFlowMatching.py | 10 +- .../wasserstein/utils_OT.py | 24 +- tutorials/tutorial_point_cloud_wfm.ipynb | 2 +- 5 files changed, 296 insertions(+), 15 deletions(-) diff --git a/src/wassersteinflowmatching/riemannian_wasserstein/utils_Geom.py b/src/wassersteinflowmatching/riemannian_wasserstein/utils_Geom.py index 3d36989..9a087e8 100644 --- a/src/wassersteinflowmatching/riemannian_wasserstein/utils_Geom.py +++ b/src/wassersteinflowmatching/riemannian_wasserstein/utils_Geom.py @@ -213,3 +213,276 @@ def general_step(): return new_p +class hyperbolic: + def project_to_geometry(self, P): + # Project points to ensure they lie within the Poincaré ball + # Normalize points that lie outside the unit ball + norm = jnp.linalg.norm(P, axis=-1, keepdims=True) + return jnp.where(norm >= 1.0, P / (norm + 1e-5), P) + + def mobius_addition(self, x, y): + """ + Möbius addition in the Poincaré ball model. + Formula: ((1 + 2⟨x, y⟩ + |y|²)x + (1 - |x|²)y) / (1 + 2⟨x, y⟩ + |x|²|y|²) + """ + x_norm_sq = jnp.sum(x**2) + y_norm_sq = jnp.sum(y**2) + dot_prod = jnp.dot(x, y) + numerator = (1 + 2*dot_prod + y_norm_sq)*x + (1 - x_norm_sq)*y + denominator = 1 + 2*dot_prod + x_norm_sq*y_norm_sq + return numerator / denominator + + def mobius_addition_batch(self, x, y): + """ + Vectorized Möbius addition for batches of points. 
+ x: shape (n, d) or (d,) + y: shape (m, d) or (d,) + Returns: shape (n, m, d) or (n, d) depending on input shapes + """ + # Add batch dimensions if needed + if x.ndim == 1: + x = x[None, :] + if y.ndim == 1: + y = y[None, :] + + # Reshape for broadcasting + x = x[:, None, :] # (n, 1, d) + y = y[None, :, :] # (1, m, d) + + # Compute norms and dot products + x_norm_sq = jnp.sum(x**2, axis=-1, keepdims=True) # (n, 1, 1) + y_norm_sq = jnp.sum(y**2, axis=-1, keepdims=True) # (1, m, 1) + dot_prod = jnp.sum(x * y, axis=-1, keepdims=True) # (n, m, 1) + + # Compute Möbius addition + numerator = (1 + 2*dot_prod + y_norm_sq)*x + (1 - x_norm_sq)*y + denominator = 1 + 2*dot_prod + x_norm_sq*y_norm_sq + + return numerator / denominator + + def distance(self, P0, P1): + """ + Compute the hyperbolic distance between two points in the Poincaré ball. + Formula: d(x,y) = 2 * arctanh(|(-x) ⊕ y|) + """ + # Project points to ensure they're in the unit ball + P0 = self.project_to_geometry(P0) + P1 = self.project_to_geometry(P1) + + # Compute the Möbius addition of -P0 and P1 + minus_p0 = -P0 + mobius_sum = self.mobius_addition(minus_p0, P1) + + # Compute the norm of the result + norm = jnp.linalg.norm(mobius_sum) + + # Clip to avoid numerical issues + norm = jnp.clip(norm, 0.0, 1.0 - 1e-5) + + # Return the hyperbolic distance + return 2 * jnp.arctanh(norm) + + def distance_matrix(self, P0, P1): + """ + Compute pairwise hyperbolic distances between two sets of points. 
+ P0: shape (n, d) + P1: shape (m, d) + Returns: shape (n, m) + """ + # Project points to ensure they're in the unit ball + P0 = self.project_to_geometry(P0) + P1 = self.project_to_geometry(P1) + + # Compute the Möbius addition of -P0 and P1 for all pairs + minus_P0 = -P0 + mobius_sums = self.mobius_addition_batch(minus_P0, P1) # shape (n, m, d) + + # Compute the norms of the results + norms = jnp.linalg.norm(mobius_sums, axis=-1) # shape (n, m) + + # Clip to avoid numerical issues + norms = jnp.clip(norms, 0.0, 1.0 - 1e-5) + + # Return the hyperbolic distances + return 2 * jnp.arctanh(norms) + + def interpolant(self, P0, P1, t): + """ + Compute geodesic interpolation in the Poincaré ball. + This is the geodesic from P0 to P1 at time t. + """ + # Project points to ensure they're in the unit ball + P0 = self.project_to_geometry(P0) + P1 = self.project_to_geometry(P1) + + # If points are very close, return linear interpolation + if jnp.allclose(P0, P1): + return (1 - t) * P0 + t * P1 + + + # Compute the geodesic + P0_norm = jnp.linalg.norm(P0) + P1_norm = jnp.linalg.norm(P1) + + # Handle special cases + if P0_norm < 1e-6: # P0 is near origin + return t * P1 + if P1_norm < 1e-6: # P1 is near origin + return (1 - t) * P0 + + # General case: compute the geodesic using the exponential map + initial_velocity = self.log_map(P0, P1) + return self.exponential_map(P0, initial_velocity, t) + + def velocity(self, P0, P1, t): + """ + Compute the velocity vector at time t along the geodesic from P0 to P1. 
+ + Args: + P0: Starting point in the Poincaré ball + P1: Ending point in the Poincaré ball + t: Time parameter in [0,1] + + Returns: + Velocity vector at the point gamma(t) where gamma is the geodesic from P0 to P1 + """ + # Project points to ensure they're in the unit ball + P0 = self.project_to_geometry(P0) + P1 = self.project_to_geometry(P1) + + # If points are very close, return zero velocity + if jnp.allclose(P0, P1): + return jnp.zeros_like(P0) + + # Compute the initial velocity using the log map + initial_velocity = self.log_map(P0, P1) + + # Get the point at time t along the geodesic + Pt = self.interpolant(P0, P1, t) + + # Compute the conformal factors + lambda_P0 = 2 / (1 - jnp.sum(P0**2)) + lambda_Pt = 2 / (1 - jnp.sum(Pt**2)) + + # Compute the parallel transport from P0 to Pt + # First, get the squared norms + P0_norm_sq = jnp.sum(P0**2) + Pt_norm_sq = jnp.sum(Pt**2) + + # Compute the inner product + inner_prod = jnp.sum(P0 * Pt) + + # Compute the parallel transport scaling factor + # This accounts for the change in the metric tensor along the geodesic + scaling = lambda_P0 / lambda_Pt * ( + (1 - P0_norm_sq) / (1 - Pt_norm_sq) * + (1 + 2 * inner_prod + Pt_norm_sq) / + (1 + 2 * inner_prod + P0_norm_sq) + ) + + # For numerical stability, clip the scaling factor + scaling = jnp.clip(scaling, 1e-6, 1e6) + + # Parallel transport the initial velocity to Pt + transported_velocity = scaling * initial_velocity + + # Project the transported velocity onto the tangent space at Pt + # This ensures the velocity remains tangent to the manifold + Pt_component = jnp.sum(transported_velocity * Pt) * Pt + tangent_velocity = transported_velocity - Pt_component + + return tangent_velocity + + def tangent_norm(self, v, w, p): + """ + Compute the norm of the difference between two tangent vectors v and w at point x + in the Poincaré ball model. 
+ + Args: + v: First tangent vector + w: Second tangent vector + x: Base point in the Poincaré ball where these vectors are tangent + + Returns: + The squared norm of the difference between the tangent vectors + under the hyperbolic metric + """ + # Project base point to ensure it's in the unit ball + p = self.project_to_geometry(p) + + # Ensure vectors are tangent by projecting out radial components + p_norm_sq = jnp.sum(p**2) + + # Project v onto tangent space at x + v_dot_p = jnp.sum(v * p) + v_tangent = v - (v_dot_p * p) + + # Project w onto tangent space at x + w_dot_p = jnp.sum(w * p) + w_tangent = w - (w_dot_p * p) + + # Compute the difference between the tangent vectors + diff = v_tangent - w_tangent + + # Compute the conformal factor (hyperbolic metric tensor) + # In the Poincaré ball, the metric tensor is scaled by lambda_x^2 + lambda_p = 2 / (1 - p_norm_sq) + + # Compute the squared norm under the hyperbolic metric + # ||v||^2 = ⟨v, v⟩_x = λ_x^2 ⟨v, v⟩_euclidean + return lambda_p**2 * jnp.sum(diff**2) + + def exponential_map(self, p, v, delta_t=1.0): + """ + Compute the exponential map in the Poincaré ball. + Maps a tangent vector v at point p to a point in the manifold. + """ + # Project p to ensure it's in the unit ball + p = self.project_to_geometry(p) + + # Compute the norm of v + v_norm = jnp.linalg.norm(v) + + # If v is very small, return p + if v_norm < 1e-6: + return p + + # Compute the conformal factor + lambda_p = 2 / (1 - jnp.sum(p**2)) + + # Scale the vector by delta_t + v = v * delta_t + + # Compute the exponential map + v_norm = jnp.linalg.norm(v) + coef = jnp.tanh(v_norm / (2 * lambda_p)) / v_norm + + # Return the result + result = self.mobius_addition(p, coef * v) + return self.project_to_geometry(result) + + def log_map(self, p, q): + """ + Compute the logarithmic map in the Poincaré ball. + Maps a point q to a tangent vector at point p. 
+ """ + # Project points to ensure they're in the unit ball + p = self.project_to_geometry(p) + q = self.project_to_geometry(q) + + # If points are very close, return zero vector + if jnp.allclose(p, q): + return jnp.zeros_like(p) + + # Compute the Möbius addition of -p and q + minus_p = -p + mobius_diff = self.mobius_addition(minus_p, q) + + # Compute the norm of the difference + diff_norm = jnp.linalg.norm(mobius_diff) + + # Compute the conformal factor + lambda_p = 2 / (1 - jnp.sum(p**2)) + + # Compute the logarithmic map + return 2 * lambda_p * jnp.arctanh(diff_norm) * mobius_diff / diff_norm diff --git a/src/wassersteinflowmatching/wasserstein/DefaultConfig.py b/src/wassersteinflowmatching/wasserstein/DefaultConfig.py index 9be4243..9b08e66 100644 --- a/src/wassersteinflowmatching/wasserstein/DefaultConfig.py +++ b/src/wassersteinflowmatching/wasserstein/DefaultConfig.py @@ -37,7 +37,7 @@ class DefaultConfig: minibatch_ot_lse: bool = True noise_type: str = 'chol_normal' scaling: str = 'None' - noise_df_scale: float = 2 + noise_df_scale: float = 2.0 factor: float = 1.0 embedding_dim: int = 512 num_layers: int = 6 diff --git a/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py b/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py index 906cbc3..e0ff9fa 100644 --- a/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py +++ b/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py @@ -234,7 +234,7 @@ def create_train_state(self, model, peak_lr, end_lr, training_steps, warmup_step # learning_rate, decay_steps, 0.97, staircase = False, # ) - tx = optax.adamw(lr_sched) # + tx = optax.adam(lr_sched) # return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) @@ -264,7 +264,6 @@ def minibatch_ot(self, point_clouds, point_cloud_weights, noise, noise_weights, - @partial(jit, static_argnums=(0,)) def train_step(self, state, point_clouds_batch, weights_batch, labels_batch=None, 
noise_samples=None, noise_weights=None, key=random.key(0)): """ @@ -387,6 +386,7 @@ def train( print(f'Sampling {shape_sample} points from each point cloud') sample_points = jax.vmap(self.sample_single_batch, in_axes=(0, 0, 0, None)) + tq = trange(training_steps - self.state.step, leave=True, desc="") self.losses = [] for training_step in tq: @@ -500,12 +500,6 @@ def generate_samples(self, size = None, num_samples = 10, timesteps = 100, gener init_noise = init_noise[None, :, :] noise = [init_noise] else: - - # noise = self.noise_func(size =[num_samples, size, self.space_dim], - # minval = self.min_val, - # maxval = self.max_val, key = subkey) - - noise = self.noise_func(size = [num_samples, size, self.space_dim], noise_config = self.noise_config, key = subkey) diff --git a/src/wassersteinflowmatching/wasserstein/utils_OT.py b/src/wassersteinflowmatching/wasserstein/utils_OT.py index 441896c..68b6350 100644 --- a/src/wassersteinflowmatching/wasserstein/utils_OT.py +++ b/src/wassersteinflowmatching/wasserstein/utils_OT.py @@ -4,6 +4,8 @@ import jax # type: ignore from jax import lax # type: ignore from jax import random # type: ignore +import ot # type: ignore +import numpy as np # type: ignore def argmax_row_iter(M): """ @@ -134,8 +136,8 @@ def entropic_ot_distance(pc_x, pc_y, eps = 0.1, lse_mode = False): def euclidean_distance(pc_x, pc_y): - pc_x, w_x = pc_x[0], pc_x[1] - pc_y, w_y = pc_y[0], pc_y[1] + pc_x, _ = pc_x[0], pc_x[1] + pc_y, _ = pc_y[0], pc_y[1] dist = jnp.mean(jnp.sum((pc_x - pc_y)**2, axis = 1)) return(dist) @@ -187,7 +189,6 @@ def chamfer_distance(pc_x, pc_y): chamfer_dist = jnp.sum(pairwise_dist.min(axis = 0) * w_y) + jnp.sum(pairwise_dist.min(axis = 1) * w_x) return chamfer_dist - def ot_mat_from_distance(distance_matrix, eps = 0.002, lse_mode = True): ot_solve = linear.solve( ott.geometry.geometry.Geometry(cost_matrix = distance_matrix, epsilon = eps, scale_cost = 'max_cost'), @@ -263,6 +264,7 @@ def transport_plan_rowiter(pc_x, pc_y, eps = 0.01, 
lse_mode = False, num_iterati delta = pc_y[map_ind]-pc_x return(delta, ot_solve) + def transport_plan_sample(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200): pc_x, w_x = pc_x[0], pc_x[1] pc_y, w_y = pc_y[0], pc_y[1] @@ -277,9 +279,21 @@ def transport_plan_sample(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteratio return(ot_solve.matrix, ot_solve) -def transport_plan_euclidean(pc_x, pc_y): +def transport_plan_unreg(pc_x, pc_y): pc_x, w_x = pc_x[0], pc_x[1] pc_y, w_y = pc_y[0], pc_y[1] + T = ot.emd(w_x, w_y, ot.dist(pc_x, pc_y)) + + map_ind = np.argmax(T, axis=1) + delta = pc_y[map_ind] - pc_x + + return(delta, T) + +def transport_plan_euclidean(pc_x, pc_y): + pc_x, _ = pc_x[0], pc_x[1] + pc_y, _ = pc_y[0], pc_y[1] + delta = pc_y - pc_x - return(delta, 0) \ No newline at end of file + return(delta, 0) + diff --git a/tutorials/tutorial_point_cloud_wfm.ipynb b/tutorials/tutorial_point_cloud_wfm.ipynb index 501a8ca..0a01934 100644 --- a/tutorials/tutorial_point_cloud_wfm.ipynb +++ b/tutorials/tutorial_point_cloud_wfm.ipynb @@ -185,7 +185,7 @@ "FlowMatchingModel.train(batch_size = 64, \n", " shape_sample = 1000, \n", " training_steps = 500000, \n", - " warmup_steps = 5000)" + " warmup_steps = 50000)" ] }, { From 612f670f9ca3b3ccb8ac179674278920f4b587ba Mon Sep 17 00:00:00 2001 From: "doron.haviv12@gmail.com" Date: Fri, 15 Nov 2024 15:25:58 -0500 Subject: [PATCH 2/2] fix toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d71d04b..f28732d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ ott-jax = "^0.4.6" clu = "^0.0.12" tqdm = "^4.66.2" tensorflow_probability = "^0.24.0" +pot = "^0.9.4" [build-system] requires = ["poetry-core"]