Skip to content

Commit

Permalink
feat: 2d cyclic edges and simplices
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Apr 1, 2024
1 parent c5c51f3 commit 842c8f7
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions kamui/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
import numpy as np
from typing import Tuple, Optional, Iterable
from typing import Tuple, Optional, Iterable, Union

__all__ = ["get_2d_edges_and_simplices", "get_3d_edges_and_simplices"]


def get_2d_edges_and_simplices(
shape: Tuple[int, int]
shape: Tuple[int, int], cyclical_axis: Union[int, Tuple[int, int]] = ()
) -> Tuple[np.ndarray, Iterable[Iterable[int]]]:
nodes = np.arange(np.prod(shape)).reshape(shape)
if type(cyclical_axis) is int:
cyclical_axis = (cyclical_axis,)

# if the axis length <= 2, then the axis is already cyclical
cyclical_axis = tuple(filter(lambda ax: shape[ax] > 2, cyclical_axis))

edges = np.concatenate(
(
np.stack([nodes[:, :-1].ravel(), nodes[:, 1:].ravel()], axis=1),
np.stack([nodes[:-1, :].ravel(), nodes[1:, :].ravel()], axis=1),
)
+ tuple(
np.stack(
[
np.take(nodes, [0], axis=ax).ravel(),
np.take(nodes, [-1], axis=ax).ravel(),
],
axis=1,
)
for ax in cyclical_axis
),
axis=0,
)
Expand All @@ -24,6 +40,30 @@ def get_2d_edges_and_simplices(
),
axis=1,
).tolist()
if len(cyclical_axis) > 0:
pairs = [
(
np.squeeze(np.take(nodes, [0], axis=ax), axis=ax),
np.squeeze(np.take(nodes, [-1], axis=ax), axis=ax),
)
for ax in cyclical_axis
]
simplices += np.concatenate(
tuple(
np.stack(
(
x[:-1],
y[:-1],
y[1:],
x[1:],
),
axis=1,
)
for x, y in pairs
),
axis=0,
).tolist()

return edges, simplices


Expand Down

0 comments on commit 842c8f7

Please sign in to comment.