Merge pull request #9 from DoronHav/remove_old_code
fix aws link
DoronHav authored Dec 2, 2024
2 parents a58df40 + f7f6706 commit 9211f04
Showing 4 changed files with 63 additions and 36 deletions.
69 changes: 47 additions & 22 deletions src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py
@@ -73,9 +73,9 @@ def __init__(
                                                    lse_mode = self.config.wasserstein_lse,
                                                    num_iteration = self.config.num_sinkhorn_iters),
                                                    (0, 0), 0)
-        elif(self.monge_map == 'row_iter'):
-            print(f"Using row_iter map with {self.num_sinkhorn_iters} iterations and {self.config.wasserstein_eps} epsilon")
-            self.transport_plan_jit = jax.vmap(partial(utils_OT.transport_plan_rowiter,
+        elif(self.monge_map == 'rounded_matching'):
+            print(f"Using rounded_matching map with {self.num_sinkhorn_iters} iterations and {self.config.wasserstein_eps} epsilon")
+            self.transport_plan_jit = jax.vmap(partial(utils_OT.transport_plan_rounded,
                                                    eps = self.config.wasserstein_eps,
                                                    lse_mode = self.config.wasserstein_lse,
                                                    num_iteration = self.config.num_sinkhorn_iters),
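Note: the `jax.vmap(partial(...), (0, 0), 0)` pattern above builds the batched transport-plan function: `partial` freezes the Sinkhorn hyperparameters, and `vmap` maps the two point-cloud arguments over their leading batch axis. A minimal sketch of the same pattern, with a hypothetical nearest-neighbor stand-in for the real `transport_plan_rounded`:

```python
from functools import partial

import jax
import jax.numpy as jnp

def pairwise_displacement(pc_x, pc_y, eps=0.01):
    # Stand-in for utils_OT.transport_plan_rounded: match each x-point
    # to its nearest y-point and return the displacement field.
    # (eps is kept only to mirror the frozen-hyperparameter pattern.)
    cost = jnp.sum((pc_x[:, None, :] - pc_y[None, :, :]) ** 2, axis=-1)
    map_ind = jnp.argmin(cost, axis=1)
    return pc_y[map_ind] - pc_x

# Freeze hyperparameters, then map over the leading (batch) axis of both inputs.
batched_plan = jax.vmap(partial(pairwise_displacement, eps=0.01), (0, 0), 0)

x = jnp.zeros((8, 32, 2))  # batch of 8 point clouds, 32 points each, in 2-D
y = jnp.ones((8, 32, 2))
deltas = batched_plan(x, y)  # shape (8, 32, 2)
```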
@@ -144,12 +144,19 @@ def __init__(


        if(labels is not None):
-            self.label_to_num = {label: i for i, label in enumerate(np.unique(labels))}
-            self.num_to_label = {i: label for i, label in enumerate(np.unique(labels))}
-            self.labels = jnp.array([self.label_to_num[label] for label in labels])
-            self.label_dim = len(np.unique(labels))
-            self.config.label_dim = self.label_dim
-            self.mini_batch_ot_mode = False
+            if(isinstance(labels[0], (str, int, np.integer))):
+                self.discrete_labels = True
+                self.config.discrete_labels = True
+                self.label_to_num = {label: i for i, label in enumerate(np.unique(labels))}
+                self.num_to_label = {i: label for i, label in enumerate(np.unique(labels))}
+                self.labels = jnp.array([self.label_to_num[label] for label in labels])
+                self.label_dim = len(np.unique(labels))
+                self.config.label_dim = self.label_dim
+            else:
+                self.discrete_labels = False
+                self.config.discrete_labels = False
+                self.labels = labels[None, :] if labels.ndim == 1 else labels
        else:
            self.labels = None
            self.label_dim = -1
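To make the two branches concrete: the discrete path builds the label-to-index dictionaries and integer-encodes the labels, while the continuous path only promotes a 1-D conditioning array to 2-D. A standalone sketch of what each branch computes (the example labels are made up):

```python
import numpy as np
import jax.numpy as jnp

# Discrete case: per-cloud string labels are re-encoded as integers.
labels = np.array(['Astro', 'Sst', 'Astro', 'L2/3 IT'])
label_to_num = {label: i for i, label in enumerate(np.unique(labels))}
num_to_label = {i: label for i, label in enumerate(np.unique(labels))}
encoded = jnp.array([label_to_num[l] for l in labels])  # [0 2 0 1]
label_dim = len(np.unique(labels))                      # 3

# Continuous case: a 1-D conditioning vector is promoted to 2-D.
cont = np.random.randn(16).astype(np.float32)
cont = cont[None, :] if cont.ndim == 1 else cont        # shape (1, 16)
```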
@@ -207,11 +214,12 @@ def create_train_state(self, model, peak_lr, end_lr, training_steps, warmup_step
            subkey, key = random.split(key)

            if(self.labels is not None):
+                labels_input = self.labels[np.random.choice(self.labels.shape[0], attn_inputs.shape[0])]
                params = model.init(rngs={"params": subkey},
                                    point_cloud = attn_inputs,
                                    t = jnp.ones((attn_inputs.shape[0])),
                                    masks = jnp.ones((attn_inputs.shape[0], attn_inputs.shape[1])),
-                                    labels = jnp.ones((attn_inputs.shape[0])),
+                                    labels = labels_input,
                                    deterministic = True)['params']
            else:
                params = model.init(rngs={"params": subkey},
@@ -456,7 +464,14 @@ def get_flow(self, params, point_clouds, weights, t, labels = None):
                                    labels = labels,
                                    deterministic = True))
        return(flow)

+    def transform_labels(self, labels, inverse = False):
+        if(self.discrete_labels):
+            if(inverse):
+                return [self.num_to_label[label] for label in labels]
+            return jnp.array([self.label_to_num[label] for label in labels])
+        else:
+            return labels

    def generate_samples(self, size = None, num_samples = 10, timesteps = 100, generate_labels = None, init_noise = None, key = random.key(0)):
        """
@@ -480,19 +495,29 @@ def generate_samples(self, size = None, num_samples = 10, timesteps = 100, generate_labels = None, init_noise = None, key = random.key(0)):
            subkey, key = random.split(key)
            noise_weights = random.choice(subkey, self.weights, [num_samples])
        else:
-            if(generate_labels is None):
-                generate_labels = random.choice(key, self.label_dim, [num_samples], replace = True)
-            elif(isinstance(generate_labels, (str, int))):
-                generate_labels = jnp.array([self.label_to_num[generate_labels]] * num_samples)
-            else:
-                generate_labels = jnp.array([self.label_to_num[label] for label in generate_labels])
-
-            if(noise_weights is None):
-                noise_weights = []
-                for label in generate_labels:
-                    subkey, key = random.split(key)
-                    noise_weights.append(random.choice(subkey, self.weights[self.labels == label]))
-                noise_weights = jnp.vstack(noise_weights)
+            if(self.discrete_labels):
+                if(generate_labels is None):
+                    generate_labels = random.choice(key, self.label_dim, [num_samples], replace = True)
+                elif(isinstance(generate_labels, (str, int))):
+                    generate_labels = jnp.repeat(self.transform_labels([generate_labels]), num_samples)
+                else:
+                    generate_labels = self.transform_labels(generate_labels)
+
+                if(noise_weights is None):
+                    noise_weights = []
+                    for label in generate_labels:
+                        subkey, key = random.split(key)
+                        noise_weights.append(random.choice(subkey, self.weights[self.labels == label]))
+                    noise_weights = jnp.vstack(noise_weights)
+            else:
+                if(generate_labels is None):
+                    generate_labels = self.labels[np.random.choice(self.labels.shape[0], num_samples, replace = False)]
+                elif(generate_labels.ndim == 1):
+                    generate_labels = np.tile(generate_labels[None, :], [num_samples, 1])
+
+                if(noise_weights is None):
+                    subkey, key = random.split(key)
+                    noise_weights = random.choice(subkey, self.weights, [num_samples])
        subkey, key = random.split(key)

        if(init_noise is not None):
@@ -515,4 +540,4 @@ def generate_samples(self, size = None, num_samples = 10, timesteps = 100, generate_labels = None, init_noise = None, key = random.key(0)):
            noise.append(noise[-1] + dt * grad_fn)
        if(generate_labels is None):
            return noise, noise_weights
-        return noise, noise_weights, [self.num_to_label[label] for label in np.array(generate_labels)]
+        return noise, noise_weights, self.transform_labels(np.array(generate_labels), inverse = True)
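End to end, the new label handling means sampling works for both label types. A hedged usage sketch (`wfm` is a hypothetical fitted model and `cond` a made-up conditioning vector; the `generate_samples` signature is as shown in this diff):

```python
import numpy as np
import jax.random as random

# Discrete conditioning: a single string label is broadcast to all samples,
# and the returned labels are decoded back through transform_labels.
paths, weights, labels = wfm.generate_samples(
    num_samples=10, timesteps=100, generate_labels='Sst', key=random.key(0))
final_clouds = paths[-1]  # last Euler step of the integration loop above

# Continuous conditioning: a single 1-D vector is tiled across samples.
cond = np.random.randn(16).astype(np.float32)
paths, weights, labels = wfm.generate_samples(
    num_samples=10, generate_labels=cond, key=random.key(1))
```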
@@ -89,7 +89,10 @@ def __call__(self, point_cloud, t, masks = None, labels = None, deterministic =


        if(labels is not None):
-            l_emb = nn.Dense(features = concat_dim)(jax.nn.one_hot(labels, config.label_dim))
+            if(config.discrete_labels):
+                l_emb = nn.Dense(features = concat_dim)(jax.nn.one_hot(labels, config.label_dim))
+            else:
+                l_emb = nn.Dense(features = concat_dim)(labels)
            x = jnp.concatenate([x,
                                 jnp.tile(l_emb[:, None, :], [1, point_cloud.shape[1], 1])], axis = -1)
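For intuition, both branches above produce an embedding of the same width: discrete labels are one-hot encoded before the Dense projection, while continuous labels are projected directly. A minimal Flax sketch of the same idea (the module and field names here are illustrative, not the repository's):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LabelEmbed(nn.Module):
    concat_dim: int = 64
    discrete_labels: bool = True
    label_dim: int = 5  # number of classes when discrete

    @nn.compact
    def __call__(self, labels):
        if self.discrete_labels:
            # integer class ids -> one-hot -> linear projection
            return nn.Dense(features=self.concat_dim)(
                jax.nn.one_hot(labels, self.label_dim))
        # continuous conditioning vectors -> linear projection
        return nn.Dense(features=self.concat_dim)(labels)

# Both calls yield a (batch, concat_dim) embedding.
disc = LabelEmbed(discrete_labels=True)
p1 = disc.init(jax.random.key(0), jnp.array([0, 3, 1]))
cont = LabelEmbed(discrete_labels=False)
p2 = cont.init(jax.random.key(1), jnp.ones((3, 16)))
```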
8 changes: 4 additions & 4 deletions src/wassersteinflowmatching/wasserstein/utils_OT.py
@@ -7,7 +7,7 @@
import ot # type: ignore
import numpy as np # type: ignore

-def argmax_row_iter(M):
+def rounded_matching(M):
    """
    Convert a soft assignment matrix M to a hard assignment vector
    by iteratively finding the largest value in M and making assignments.
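The docstring describes a greedy rounding scheme: repeatedly take the globally largest entry of the coupling matrix, fix that row-to-column assignment, and exclude the row and column from further consideration. A NumPy sketch of that procedure, written from the docstring (the repository's own implementation may differ in details):

```python
import numpy as np

def greedy_hard_assignment(M):
    """Round a soft n-by-n assignment matrix to a hard assignment vector
    by repeatedly taking the largest remaining entry."""
    M = np.array(M, dtype=float)
    n = M.shape[0]
    assignment = np.full(n, -1, dtype=int)
    for _ in range(n):
        i, j = np.unravel_index(np.argmax(M), M.shape)
        assignment[i] = j
        M[i, :] = -np.inf  # row i is now assigned
        M[:, j] = -np.inf  # column j is now taken
    return assignment

soft = np.array([[0.6, 0.3, 0.1],
                 [0.2, 0.5, 0.3],
                 [0.1, 0.2, 0.7]])
print(greedy_hard_assignment(soft))  # [0 1 2]
```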
@@ -195,7 +195,7 @@ def ot_mat_from_distance(distance_matrix, eps = 0.002, lse_mode = True):
                           lse_mode = lse_mode,
                           min_iterations = 200,
                           max_iterations = 200)
-    map_ind = argmax_row_iter(ot_solve.matrix)
+    map_ind = rounded_matching(ot_solve.matrix)
    return(map_ind)

def sample_ot_matrix(pc_x, pc_y, mat, key):
@@ -248,7 +248,7 @@ def transport_plan_argmax(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200):
    delta = pc_y[map_ind]-pc_x
    return(delta, ot_solve)

-def transport_plan_rowiter(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200):
+def transport_plan_rounded(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200):
    pc_x, w_x = pc_x[0], pc_x[1]
    pc_y, w_y = pc_y[0], pc_y[1]
@@ -260,7 +260,7 @@ def transport_plan_rowiter(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200):
                           max_iterations = num_iteration,
                           lse_mode = lse_mode)

-    map_ind = argmax_row_iter(ot_solve.matrix)
+    map_ind = rounded_matching(ot_solve.matrix)
    delta = pc_y[map_ind]-pc_x
    return(delta, ot_solve)
17 changes: 8 additions & 9 deletions tutorials/tutorial_spatial_niche_wfm.ipynb
@@ -15,11 +15,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
-    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" # see issue #152\n",
-    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\"\n",
-    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"0.95\"\n",
-    "\n",
     "import jax\n",
     "import jax.random as random\n",
     "import numpy as np\n",
@@ -53,19 +48,23 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Annotated data can be downloaded from https://dp-lab-data-public.s3.amazonaws.com/WassersteinFM/st_data.h5ad "
+    "Annotated data can be downloaded from https://dp-lab-data-public.s3.us-east-1.amazonaws.com/WassersteinFM/st_data.h5ad"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "st_data = sc.read('/data/peer/DoronHaviv/merFISHCortex/st_data_envi_sst.h5ad')\n",
-    "#st_data = sc.read('st_data.h5ad')"
+    "st_data = sc.read('st_data.h5ad')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
