diff --git a/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py b/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py index e0ff9fa..53c420d 100644 --- a/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py +++ b/src/wassersteinflowmatching/wasserstein/WassersteinFlowMatching.py @@ -73,9 +73,9 @@ def __init__( lse_mode = self.config.wasserstein_lse, num_iteration = self.config.num_sinkhorn_iters), (0, 0), 0) - elif(self.monge_map == 'row_iter'): - print(f"Using row_iter map with {self.num_sinkhorn_iters} iterations and {self.config.wasserstein_eps} epsilon") - self.transport_plan_jit = jax.vmap(partial(utils_OT.transport_plan_rowiter, + elif(self.monge_map == 'rounded_matching'): + print(f"Using rounded_matching map with {self.num_sinkhorn_iters} iterations and {self.config.wasserstein_eps} epsilon") + self.transport_plan_jit = jax.vmap(partial(utils_OT.transport_plan_rounded, eps = self.config.wasserstein_eps, lse_mode = self.config.wasserstein_lse, num_iteration = self.config.num_sinkhorn_iters), @@ -144,12 +144,19 @@ def __init__( if(labels is not None): - self.label_to_num = {label: i for i, label in enumerate(np.unique(labels))} - self.num_to_label = {i: label for i, label in enumerate(np.unique(labels))} - self.labels = jnp.array([self.label_to_num[label] for label in labels]) - self.label_dim = len(np.unique(labels)) - self.config.label_dim = self.label_dim self.mini_batch_ot_mode = False + if(isinstance(labels, (str, int))): + self.discrete_labels = True + self.config.discrete_labels = True + self.label_to_num = {label: i for i, label in enumerate(np.unique(labels))} + self.num_to_label = {i: label for i, label in enumerate(np.unique(labels))} + self.labels = jnp.array([self.label_to_num[label] for label in labels]) + self.label_dim = len(np.unique(labels)) + self.config.label_dim = self.label_dim + else: + self.discrete_labels = False + self.config.discrete_labels = False + self.labels = labels[None, :] if labels.ndim == 1 else labels else: self.labels = None self.label_dim = -1 @@ -207,11 +214,12 @@ def create_train_state(self, model, peak_lr, end_lr, training_steps, warmup_step subkey, key = random.split(key) if(self.labels is not None): + labels_input = self.labels[np.random.choice(self.labels.shape[0], attn_inputs.shape[0])] params = model.init(rngs={"params": subkey}, point_cloud = attn_inputs, t = jnp.ones((attn_inputs.shape[0])), masks = jnp.ones((attn_inputs.shape[0], attn_inputs.shape[1])), - labels = jnp.ones((attn_inputs.shape[0])), + labels = labels_input, deterministic = True)['params'] else: params = model.init(rngs={"params": subkey}, @@ -456,7 +464,14 @@ def get_flow(self, params, point_clouds, weights, t, labels = None): labels = labels, deterministic = True)) return(flow) - + + def transform_labels(self, labels, inverse = False): + if(self.discrete_labels): + if(inverse): + return [self.num_to_label[label] for label in labels] + return jnp.array([self.label_to_num[label] for label in labels]) + else: + return labels def generate_samples(self, size = None, num_samples = 10, timesteps = 100, generate_labels = None, init_noise = None, key = random.key(0)): """ @@ -480,19 +495,29 @@ def generate_samples(self, size = None, num_samples = 10, timesteps = 100, gener subkey, key = random.split(key) noise_weights = random.choice(subkey, self.weights, [num_samples]) else: - if(generate_labels is None): - generate_labels = random.choice(key, self.label_dim, [num_samples], replace = True) - elif(isinstance(generate_labels, (str, int))): - generate_labels = jnp.array([self.label_to_num[generate_labels]] * num_samples) + if(self.discrete_labels): + if(generate_labels is None): + generate_labels = random.choice(key, self.label_dim, [num_samples], replace = True) + elif(isinstance(generate_labels, (str, int))): + generate_labels = jnp.repeat(self.transform_labels([generate_labels]), num_samples) + else: + generate_labels = self.transform_labels(generate_labels) + + if(noise_weights is None): + noise_weights = [] + for label in generate_labels: + subkey, key = random.split(key) + noise_weights.append(random.choice(subkey, self.weights[self.labels == label])) + noise_weights = jnp.vstack(noise_weights) else: - generate_labels = jnp.array([self.label_to_num[label] for label in generate_labels]) - - if(noise_weights is None): - noise_weights = [] - for label in generate_labels: + if(generate_labels is None): + generate_labels = self.labels[np.random.choice(self.labels.shape[0], num_samples, replace = False)] + elif(generate_labels.ndim == 1): + generate_labels = np.tile(generate_labels[None, :], [num_samples, 1]) + + if(noise_weights is None): subkey, key = random.split(key) - noise_weights.append(random.choice(subkey, self.weights[self.labels == label])) - noise_weights = jnp.vstack(noise_weights) + noise_weights = random.choice(subkey, self.weights, [num_samples]) subkey, key = random.split(key) if(init_noise is not None): @@ -515,4 +540,4 @@ def generate_samples(self, size = None, num_samples = 10, timesteps = 100, gener noise.append(noise[-1] + dt * grad_fn) if(generate_labels is None): return noise, noise_weights - return noise, noise_weights, [self.num_to_label[label] for label in np.array(generate_labels)] \ No newline at end of file + return noise, noise_weights, self.transform_labels(generate_labels, inverse = False) \ No newline at end of file diff --git a/src/wassersteinflowmatching/wasserstein/_utils_Transformer.py b/src/wassersteinflowmatching/wasserstein/_utils_Transformer.py index e31e531..7b6df23 100644 --- a/src/wassersteinflowmatching/wasserstein/_utils_Transformer.py +++ b/src/wassersteinflowmatching/wasserstein/_utils_Transformer.py @@ -89,7 +89,10 @@ def __call__(self, point_cloud, t, masks = None, labels = None, deterministic = if(labels is not None): - l_emb = nn.Dense(features = concat_dim)(jax.nn.one_hot(labels, config.label_dim)) + if(config.discrete_labels): + l_emb = nn.Dense(features = concat_dim)(jax.nn.one_hot(labels, config.label_dim)) + else: + l_emb = nn.Dense(features = concat_dim)(labels) x = jnp.concatenate([x, jnp.tile(l_emb[:, None, :], [1, point_cloud.shape[1], 1])], axis = -1) diff --git a/src/wassersteinflowmatching/wasserstein/utils_OT.py b/src/wassersteinflowmatching/wasserstein/utils_OT.py index 68b6350..3fdf59d 100644 --- a/src/wassersteinflowmatching/wasserstein/utils_OT.py +++ b/src/wassersteinflowmatching/wasserstein/utils_OT.py @@ -7,7 +7,7 @@ import ot # type: ignore import numpy as np # type: ignore -def argmax_row_iter(M): +def rounded_matching(M): """ Convert a soft assignment matrix M to a hard assignment vector by iteratively finding the largest value in M and making assignments. @@ -195,7 +195,7 @@ def ot_mat_from_distance(distance_matrix, eps = 0.002, lse_mode = True): lse_mode = lse_mode, min_iterations = 200, max_iterations = 200) - map_ind = argmax_row_iter(ot_solve.matrix) + map_ind = rounded_matching(ot_solve.matrix) return(map_ind) def sample_ot_matrix(pc_x, pc_y, mat, key): @@ -248,7 +248,7 @@ def transport_plan_argmax(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteratio delta = pc_y[map_ind]-pc_x return(delta, ot_solve) -def transport_plan_rowiter(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200): +def transport_plan_rounded(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iteration = 200): pc_x, w_x = pc_x[0], pc_x[1] pc_y, w_y = pc_y[0], pc_y[1] @@ -260,7 +260,7 @@ def transport_plan_rowiter(pc_x, pc_y, eps = 0.01, lse_mode = False, num_iterati max_iterations = num_iteration, lse_mode = lse_mode) - map_ind = argmax_row_iter(ot_solve.matrix) + map_ind = rounded_matching(ot_solve.matrix) delta = pc_y[map_ind]-pc_x return(delta, ot_solve) diff --git a/tutorials/tutorial_spatial_niche_wfm.ipynb b/tutorials/tutorial_spatial_niche_wfm.ipynb index 0241683..a70151c 100644 --- a/tutorials/tutorial_spatial_niche_wfm.ipynb +++ b/tutorials/tutorial_spatial_niche_wfm.ipynb @@ -15,11 +15,6 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" # see issue #152\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\"\n", - "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"0.95\"\n", - "\n", "import jax\n", "import jax.random as random\n", "import numpy as np\n", @@ -53,19 +48,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Annotated data can be downloaded from https://dp-lab-data-public.s3.amazonaws.com/WassersteinFM/st_data.h5ad " + "Annotated data can be downloaded from https://dp-lab-data-public.s3.us-east-1.amazonaws.com/WassersteinFM/st_data.h5ad" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "st_data = sc.read('/data/peer/DoronHaviv/merFISHCortex/st_data_envi_sst.h5ad')\n", - "#st_data = sc.read('st_data.h5ad')" + "st_data = sc.read('st_data.h5ad')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "markdown", "metadata": {},