Skip to content

Commit

Permalink
Merge pull request #8 from DoronHav/remove_old_code
Browse files Browse the repository at this point in the history
add pot
  • Loading branch information
DoronHav authored Nov 15, 2024
2 parents c9a6c73 + 612f670 commit a58df40
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 15 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ott-jax = "^0.4.6"
clu = "^0.0.12"
tqdm = "^4.66.2"
tensorflow_probability = "^0.24.0"
pot = "^0.9.4"

[build-system]
requires = ["poetry-core"]
Expand Down
273 changes: 273 additions & 0 deletions src/wassersteinflowmatching/riemannian_wasserstein/utils_Geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,276 @@ def general_step():

return new_p

class hyperbolic:
def project_to_geometry(self, P):
# Project points to ensure they lie within the Poincaré ball
# Normalize points that lie outside the unit ball
norm = jnp.linalg.norm(P, axis=-1, keepdims=True)
return jnp.where(norm >= 1.0, P / (norm + 1e-5), P)

def mobius_addition(self, x, y):
"""
Möbius addition in the Poincaré ball model.
Formula: (1 + 2<x,y> + |y|²)x + (1 - |x|²)y / (1 + 2<x,y> + |x|²|y|²)
"""
x_norm_sq = jnp.sum(x**2)
y_norm_sq = jnp.sum(y**2)
dot_prod = jnp.dot(x, y)
numerator = (1 + 2*dot_prod + y_norm_sq)*x + (1 - x_norm_sq)*y
denominator = 1 + 2*dot_prod + x_norm_sq*y_norm_sq
return numerator / denominator

def mobius_addition_batch(self, x, y):
"""
Vectorized Möbius addition for batches of points.
x: shape (n, d) or (d,)
y: shape (m, d) or (d,)
Returns: shape (n, m, d) or (n, d) depending on input shapes
"""
# Add batch dimensions if needed
if x.ndim == 1:
x = x[None, :]
if y.ndim == 1:
y = y[None, :]

# Reshape for broadcasting
x = x[:, None, :] # (n, 1, d)
y = y[None, :, :] # (1, m, d)

# Compute norms and dot products
x_norm_sq = jnp.sum(x**2, axis=-1, keepdims=True) # (n, 1, 1)
y_norm_sq = jnp.sum(y**2, axis=-1, keepdims=True) # (1, m, 1)
dot_prod = jnp.sum(x * y, axis=-1, keepdims=True) # (n, m, 1)

# Compute Möbius addition
numerator = (1 + 2*dot_prod + y_norm_sq)*x + (1 - x_norm_sq)*y
denominator = 1 + 2*dot_prod + x_norm_sq*y_norm_sq

return numerator / denominator

def distance(self, P0, P1):
"""
Compute the hyperbolic distance between two points in the Poincaré ball.
Formula: d(x,y) = 2 * arctanh(|(-x) ⊕ y|)
"""
# Project points to ensure they're in the unit ball
P0 = self.project_to_geometry(P0)
P1 = self.project_to_geometry(P1)

# Compute the Möbius addition of -P0 and P1
minus_p0 = -P0
mobius_sum = self.mobius_addition(minus_p0, P1)

# Compute the norm of the result
norm = jnp.linalg.norm(mobius_sum)

# Clip to avoid numerical issues
norm = jnp.clip(norm, 0.0, 1.0 - 1e-5)

# Return the hyperbolic distance
return 2 * jnp.arctanh(norm)

def distance_matrix(self, P0, P1):
"""
Compute pairwise hyperbolic distances between two sets of points.
P0: shape (n, d)
P1: shape (m, d)
Returns: shape (n, m)
"""
# Project points to ensure they're in the unit ball
P0 = self.project_to_geometry(P0)
P1 = self.project_to_geometry(P1)

# Compute the Möbius addition of -P0 and P1 for all pairs
minus_P0 = -P0
mobius_sums = self.mobius_addition_batch(minus_P0, P1) # shape (n, m, d)

# Compute the norms of the results
norms = jnp.linalg.norm(mobius_sums, axis=-1) # shape (n, m)

# Clip to avoid numerical issues
norms = jnp.clip(norms, 0.0, 1.0 - 1e-5)

# Return the hyperbolic distances
return 2 * jnp.arctanh(norms)

def interpolant(self, P0, P1, t):
"""
Compute geodesic interpolation in the Poincaré ball.
This is the geodesic from P0 to P1 at time t.
"""
# Project points to ensure they're in the unit ball
P0 = self.project_to_geometry(P0)
P1 = self.project_to_geometry(P1)

# If points are very close, return linear interpolation
if jnp.allclose(P0, P1):
return (1 - t) * P0 + t * P1


# Compute the geodesic
P0_norm = jnp.linalg.norm(P0)
P1_norm = jnp.linalg.norm(P1)

# Handle special cases
if P0_norm < 1e-6: # P0 is near origin
return t * P1
if P1_norm < 1e-6: # P1 is near origin
return (1 - t) * P0

# General case: compute the geodesic using the exponential map
initial_velocity = self.log_map(P0, P1)
return self.exponential_map(P0, initial_velocity, t)

def velocity(self, P0, P1, t):
"""
Compute the velocity vector at time t along the geodesic from P0 to P1.
Args:
P0: Starting point in the Poincaré ball
P1: Ending point in the Poincaré ball
t: Time parameter in [0,1]
Returns:
Velocity vector at the point gamma(t) where gamma is the geodesic from P0 to P1
"""
# Project points to ensure they're in the unit ball
P0 = self.project_to_geometry(P0)
P1 = self.project_to_geometry(P1)

# If points are very close, return zero velocity
if jnp.allclose(P0, P1):
return jnp.zeros_like(P0)

# Compute the initial velocity using the log map
initial_velocity = self.log_map(P0, P1)

# Get the point at time t along the geodesic
Pt = self.interpolant(P0, P1, t)

# Compute the conformal factors
lambda_P0 = 2 / (1 - jnp.sum(P0**2))
lambda_Pt = 2 / (1 - jnp.sum(Pt**2))

# Compute the parallel transport from P0 to Pt
# First, get the squared norms
P0_norm_sq = jnp.sum(P0**2)
Pt_norm_sq = jnp.sum(Pt**2)

# Compute the inner product
inner_prod = jnp.sum(P0 * Pt)

# Compute the parallel transport scaling factor
# This accounts for the change in the metric tensor along the geodesic
scaling = lambda_P0 / lambda_Pt * (
(1 - P0_norm_sq) / (1 - Pt_norm_sq) *
(1 + 2 * inner_prod + Pt_norm_sq) /
(1 + 2 * inner_prod + P0_norm_sq)
)

# For numerical stability, clip the scaling factor
scaling = jnp.clip(scaling, 1e-6, 1e6)

# Parallel transport the initial velocity to Pt
transported_velocity = scaling * initial_velocity

# Project the transported velocity onto the tangent space at Pt
# This ensures the velocity remains tangent to the manifold
Pt_component = jnp.sum(transported_velocity * Pt) * Pt
tangent_velocity = transported_velocity - Pt_component

return tangent_velocity

def tangent_norm(self, v, w, p):
"""
Compute the norm of the difference between two tangent vectors v and w at point x
in the Poincaré ball model.
Args:
v: First tangent vector
w: Second tangent vector
x: Base point in the Poincaré ball where these vectors are tangent
Returns:
The squared norm of the difference between the tangent vectors
under the hyperbolic metric
"""
# Project base point to ensure it's in the unit ball
p = self.project_to_geometry(p)

# Ensure vectors are tangent by projecting out radial components
p_norm_sq = jnp.sum(p**2)

# Project v onto tangent space at x
v_dot_p = jnp.sum(v * p)
v_tangent = v - (v_dot_p * p)

# Project w onto tangent space at x
w_dot_p = jnp.sum(w * p)
w_tangent = w - (w_dot_p * p)

# Compute the difference between the tangent vectors
diff = v_tangent - w_tangent

# Compute the conformal factor (hyperbolic metric tensor)
# In the Poincaré ball, the metric tensor is scaled by lambda_x^2
lambda_p = 2 / (1 - p_norm_sq)

# Compute the squared norm under the hyperbolic metric
# ||v||^2 = <v,v>_x = λ_x^2 <v,v>_euclidean
return lambda_p**2 * jnp.sum(diff**2)

def exponential_map(self, p, v, delta_t=1.0):
"""
Compute the exponential map in the Poincaré ball.
Maps a tangent vector v at point p to a point in the manifold.
"""
# Project p to ensure it's in the unit ball
p = self.project_to_geometry(p)

# Compute the norm of v
v_norm = jnp.linalg.norm(v)

# If v is very small, return p
if v_norm < 1e-6:
return p

# Compute the conformal factor
lambda_p = 2 / (1 - jnp.sum(p**2))

# Scale the vector by delta_t
v = v * delta_t

# Compute the exponential map
v_norm = jnp.linalg.norm(v)
coef = jnp.tanh(v_norm / (2 * lambda_p)) / v_norm

# Return the result
result = self.mobius_addition(p, coef * v)
return self.project_to_geometry(result)

def log_map(self, p, q):
"""
Compute the logarithmic map in the Poincaré ball.
Maps a point q to a tangent vector at point p.
"""
# Project points to ensure they're in the unit ball
p = self.project_to_geometry(p)
q = self.project_to_geometry(q)

# If points are very close, return zero vector
if jnp.allclose(p, q):
return jnp.zeros_like(p)

# Compute the Möbius addition of -p and q
minus_p = -p
mobius_diff = self.mobius_addition(minus_p, q)

# Compute the norm of the difference
diff_norm = jnp.linalg.norm(mobius_diff)

# Compute the conformal factor
lambda_p = 2 / (1 - jnp.sum(p**2))

# Compute the logarithmic map
return 2 * lambda_p * jnp.arctanh(diff_norm) * mobius_diff / diff_norm
2 changes: 1 addition & 1 deletion src/wassersteinflowmatching/wasserstein/DefaultConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DefaultConfig:
minibatch_ot_lse: bool = True
noise_type: str = 'chol_normal'
scaling: str = 'None'
noise_df_scale: float = 2
noise_df_scale: float = 2.0
factor: float = 1.0
embedding_dim: int = 512
num_layers: int = 6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def create_train_state(self, model, peak_lr, end_lr, training_steps, warmup_step
# learning_rate, decay_steps, 0.97, staircase = False,
# )

tx = optax.adamw(lr_sched) #
tx = optax.adam(lr_sched) #

return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

Expand Down Expand Up @@ -264,7 +264,6 @@ def minibatch_ot(self, point_clouds, point_cloud_weights, noise, noise_weights,




@partial(jit, static_argnums=(0,))
def train_step(self, state, point_clouds_batch, weights_batch, labels_batch=None, noise_samples=None, noise_weights=None, key=random.key(0)):
"""
Expand Down Expand Up @@ -387,6 +386,7 @@ def train(
print(f'Sampling {shape_sample} points from each point cloud')
sample_points = jax.vmap(self.sample_single_batch, in_axes=(0, 0, 0, None))


tq = trange(training_steps - self.state.step, leave=True, desc="")
self.losses = []
for training_step in tq:
Expand Down Expand Up @@ -500,12 +500,6 @@ def generate_samples(self, size = None, num_samples = 10, timesteps = 100, gener
init_noise = init_noise[None, :, :]
noise = [init_noise]
else:

# noise = self.noise_func(size =[num_samples, size, self.space_dim],
# minval = self.min_val,
# maxval = self.max_val, key = subkey)


noise = self.noise_func(size = [num_samples, size, self.space_dim],
noise_config = self.noise_config,
key = subkey)
Expand Down
24 changes: 19 additions & 5 deletions src/wassersteinflowmatching/wasserstein/utils_OT.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax # type: ignore
from jax import lax # type: ignore
from jax import random # type: ignore
import ot # type: ignore
import numpy as np # type: ignore

def argmax_row_iter(M):
"""
Expand Down Expand Up @@ -134,8 +136,8 @@ def entropic_ot_distance(pc_x, pc_y, eps = 0.1, lse_mode = False):


def euclidean_distance(pc_x, pc_y):
pc_x, w_x = pc_x[0], pc_x[1]
pc_y, w_y = pc_y[0], pc_y[1]
pc_x, _ = pc_x[0], pc_x[1]
pc_y, _ = pc_y[0], pc_y[1]

dist = jnp.mean(jnp.sum((pc_x - pc_y)**2, axis = 1))
return(dist)
Expand Down Expand Up @@ -187,7 +189,6 @@ def chamfer_distance(pc_x, pc_y):
chamfer_dist = jnp.sum(pairwise_dist.min(axis = 0) * w_y) + jnp.sum(pairwise_dist.min(axis = 1) * w_x)
return chamfer_dist


def ot_mat_from_distance(distance_matrix, eps = 0.002, lse_mode = True):
ot_solve = linear.solve(
ott.geometry.geometry.Geometry(cost_matrix = distance_matrix, epsilon = eps, scale_cost = 'max_cost'),
Expand Down Expand Up @@ -263,6 +264,7 @@ def transport_plan_rowiter(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iterati
delta = pc_y[map_ind]-pc_x
return(delta, ot_solve)


def transport_plan_sample(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200):
pc_x, w_x = pc_x[0], pc_x[1]
pc_y, w_y = pc_y[0], pc_y[1]
Expand All @@ -277,9 +279,21 @@ def transport_plan_sample(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteratio

return(ot_solve.matrix, ot_solve)

def transport_plan_euclidean(pc_x, pc_y):
def transport_plan_unreg(pc_x, pc_y):
pc_x, w_x = pc_x[0], pc_x[1]
pc_y, w_y = pc_y[0], pc_y[1]

T = ot.emd(w_x, w_y, ot.dist(pc_x, pc_y))

map_ind = np.argmax(T, axis=1)
delta = pc_y[map_ind] - pc_x

return(delta, T)

def transport_plan_euclidean(pc_x, pc_y):
pc_x, _ = pc_x[0], pc_x[1]
pc_y, _ = pc_y[0], pc_y[1]

delta = pc_y - pc_x
return(delta, 0)
return(delta, 0)

Loading

0 comments on commit a58df40

Please sign in to comment.