|
1 | 1 | import torch
|
| 2 | +import torch.nn.functional as F |
2 | 3 | import torch_geometric
|
3 | 4 |
|
4 | 5 |
|
@@ -56,3 +57,104 @@ def forward(
|
56 | 57 | The lifted data.
|
57 | 58 | """
|
58 | 59 | return self.lift_features(data)
|
| 60 | + |
| 61 | + |
class ElementwiseMean(torch_geometric.transforms.BaseTransform):
    r"""Lifts r-cell features to r+1-cells by taking the mean of the lower
    dimensional features.

    Parameters
    ----------
    **kwargs : optional
        Additional arguments for the class.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def lift_features(
        self, data: torch_geometric.data.Data | dict
    ) -> torch_geometric.data.Data | dict:
        r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix.

        Parameters
        ----------
        data : torch_geometric.data.Data | dict
            The input data to be lifted. Expected to contain node features
            ``x_0`` of shape (num_nodes, feat_dim) and, for each rank r, an
            integer index tensor ``x_idx_r`` of shape
            (num_r_simplices, nodes_per_r_simplex) listing the vertices of
            each r-simplex.

        Returns
        -------
        torch_geometric.data.Data | dict
            The lifted data, with ``x_r`` (r >= 1) set to the mean of the
            node features of each r-simplex.
        """
        # Highest simplex rank present in the input (keys "x_idx_r").
        max_dim = max(
            int(key.split("_")[-1]) for key in data if "x_idx" in key
        )

        # One (num_simplices_r, nodes_per_simplex_r) index tensor per rank.
        x_idx_tensors = [data[f"x_idx_{i}"] for i in range(max_dim + 1)]

        # Common padded shape across all ranks.
        max_simplices = max(tensor.size(0) for tensor in x_idx_tensors)
        max_nodes = max(tensor.size(1) for tensor in x_idx_tensors)

        # Pad with -1, NOT the default 0: node index 0 is a valid vertex,
        # so padding with 0 would silently drop node 0's features from
        # every mean it participates in.
        padded_tensors = [
            F.pad(
                tensor,
                (
                    0,
                    max_nodes - tensor.size(1),
                    0,
                    max_simplices - tensor.size(0),
                ),
                value=-1,
            )
            for tensor in x_idx_tensors
        ]

        # Shape: (max_dim + 1, max_simplices, max_nodes).
        all_indices = torch.stack(padded_tensors)

        # Valid entries are the non-sentinel (>= 0) indices.
        mask = all_indices >= 0

        # Clamp the -1 sentinel to a safe index so the gather below cannot
        # raise; the mask zeroes those embeddings out afterwards.
        all_indices = all_indices.clamp(min=0)

        # Gather node embeddings at once:
        # (max_dim + 1, max_simplices, max_nodes, feat_dim).
        all_embeddings = data["x_0"][all_indices]

        # Zero out embeddings that came from padding (dtype-safe cast).
        all_embeddings = all_embeddings * mask.unsqueeze(-1).to(
            all_embeddings.dtype
        )

        # Mean over each simplex's vertices, counting only real entries.
        embedding_sum = all_embeddings.sum(dim=2)
        count = mask.sum(dim=2).clamp(min=1)  # avoid division by zero
        mean_embeddings = embedding_sum / count.unsqueeze(-1)

        # Write back x_r for r >= 1, trimming the padded simplex rows.
        for i in range(1, max_dim + 1):
            original_size = x_idx_tensors[i].size(0)
            data[f"x_{i}"] = mean_embeddings[i, :original_size]

        return data

    def forward(
        self, data: torch_geometric.data.Data | dict
    ) -> torch_geometric.data.Data | dict:
        r"""Applies the lifting to the input data.

        Parameters
        ----------
        data : torch_geometric.data.Data | dict
            The input data to be lifted.

        Returns
        -------
        torch_geometric.data.Data | dict
            The lifted data.
        """
        return self.lift_features(data)
0 commit comments