Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 13 additions & 68 deletions connectomics/common/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@

from collections import abc
import numbers
from typing import Any, List, Tuple, TypeVar, Type, Union
from typing import Any, List, Tuple, TypeVar, Union

from connectomics.common import array_mixins
import numpy as np
import numpy.typing as npt

Expand All @@ -43,12 +42,11 @@
IndexExpOrPointLookups = Union[ArbitrarySlice, PointLookups]
CanonicalSliceOrPointLookups = Union[CanonicalSlice, PointLookups]

ArrayLike = Union[npt.ArrayLike, 'ImmutableArray', 'MutableArray']
ArrayLike = npt.ArrayLike
Tuple3f = Tuple[float, float, float]
Tuple3i = Tuple[int, int, int]
Tuple4i = Tuple[int, int, int, int]
ArrayLike3d = Union[npt.ArrayLike, 'ImmutableArray', 'MutableArray', Tuple3f,
Tuple3i]
ArrayLike3d = Union[npt.ArrayLike, Tuple3f, Tuple3i]


def is_point_lookup(ind: IndexExpOrPointLookups) -> bool:
Expand Down Expand Up @@ -120,75 +118,22 @@ def process_slice_ind(slice_ind: ArbitrarySlice, limit: int) -> slice:
return slice(*slice_ind.indices(limit))


# TODO(timblakely): Make these typed by using Generic[T]
class ImmutableArray(array_mixins.ImmutableArrayMixin, np.ndarray):
"""Strongly typed, immutable NumPy NDArray."""

def __new__(cls: Type['ImmutableArray'],
input_array: 'ArrayLike',
*args,
zero_copy=False,
**kwargs) -> 'ImmutableArray':
if zero_copy:
obj = np.asanyarray(input_array, *args, **kwargs).view(cls)
else:
obj = np.array(input_array, *args, **kwargs).view(cls)
obj.flags.writeable = False
return obj

def __init__(self, *args, zero_copy=False, **kwargs):
# Needed for mixin construction.
super().__init__() # pylint: disable=no-value-for-parameter

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get('out', ())

# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(_to_np_compatible(x) for x in inputs)
if out:
kwargs['out'] = tuple(_to_np_compatible(x) for x in out)
result = getattr(ufunc, method)(*inputs, **kwargs)

if method == 'at':
# no return value
return None

return MutableArray(result)

def copy(self, *args, **kwargs) -> 'MutableArray':
return MutableArray(np.asarray(self).copy(*args, **kwargs))

def __str__(self):
return np.ndarray.__repr__(self)


class MutableArray(array_mixins.MutableArrayMixin, ImmutableArray):
"""Strongly typed mutable version of np.ndarray."""

def __new__(cls: Type['MutableArray'],
input_array: 'ArrayLike',
*args,
zero_copy=False,
**kwargs) -> 'MutableArray':
if zero_copy:
obj = np.asanyarray(input_array, *args, **kwargs).view(cls)
else:
obj = np.array(input_array, *args, **kwargs).view(cls)
obj.flags.writeable = True
return obj
# Kludge to get around the fact that Jax uses its own ArrayImpl class, which
# doesn't fulfill either the abc.Sequence or np.ndarray contracts.
def _is_indexable(obj: Any) -> bool:
return hasattr(obj, '__getitem__') and hasattr(obj, '__setitem__')


def is_arraylike(obj):
# Technically sequences, but this is intended to check for numeric sequences.
if isinstance(obj, str) or isinstance(obj, bytes):
return False
return isinstance(obj, abc.Sequence) or isinstance(
obj, np.ndarray) or isinstance(obj, ImmutableArray) or isinstance(
obj, MutableArray)
return (
_is_indexable(obj)
or isinstance(obj, abc.Sequence)
or isinstance(obj, np.ndarray)
)


def _to_np_compatible(array_like) -> np.ndarray:
if isinstance(array_like, ImmutableArray) or isinstance(
array_like, MutableArray):
return np.asarray(array_like)
return array_like
return np.asarray(array_like)
264 changes: 0 additions & 264 deletions connectomics/common/array_mixins.py

This file was deleted.

Loading