Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,9 @@ class ShapeReducer(TensorReducerBase):
def __init__(self, inplace: bool = False):
super().__init__(inplace=inplace)

def _reduce_out_of_place(self, x: list[TensorType]) -> list[tuple[int, ...]]:
return [x[0].shape]
def _reduce_out_of_place(self, x: list[TensorType]) -> list[TensorType]:
# Return as tensor for consistency, because in-place reducer returns a tensor
return [fns.tensor(x[0].shape, backend=x[0].backend, dtype=TensorDataType.int32, device=x[0].device)]

def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
return None
Expand Down Expand Up @@ -561,25 +562,24 @@ def __hash__(self) -> int:


class NoopAggregator(AggregatorBase):
def __init__(self, num_samples: Optional[int]):
super().__init__(None, num_samples=num_samples)
def __init__(self, num_samples: Optional[int], return_first: bool = False):
"""
Creates an aggregator that only accumulates data without any additional processing.
:param num_samples: The number of samples to collect. If None, all samples are collected.
:param return_first: If True, the first collected sample is returned on aggregate call.
If False, all collected samples are returned as a list.
"""
if return_first and num_samples is not None and num_samples != 1:
msg = "NoopAggregator with return_first=True should not have num_samples > 1"
raise nncf.InternalError(msg)
super().__init__(None, num_samples=1 if return_first else num_samples)
self._return_first = return_first

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container.append(x)

def _aggregate_impl(self):
return self._container


class ShapeAggregator(AggregatorBase):
def __init__(self):
super().__init__(None, num_samples=1)

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container = x

def _aggregate_impl(self):
return self._container.shape
return self._container[0] if self._return_first else self._container


class OnlineAggregatorBase(AggregatorBase, ABC):
Expand Down
19 changes: 11 additions & 8 deletions src/nncf/experimental/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import nncf
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from nncf.tensor import functions as fns


Expand Down Expand Up @@ -108,31 +109,34 @@ def __eq__(self, other: TensorStatistic):
return False


@dataclass
@dataclass(init=False)
class MeanTensorStatistic(TensorStatistic):
MEAN_STAT: ClassVar[str] = "mean_values"
SHAPE_STAT: ClassVar[str] = "shape"

mean_values: Tensor
shape: tuple[int, ...]

def __init__(self, mean_values: Tensor, shape: Tensor) -> None:
self.mean_values = mean_values
self.shape = tuple(shape.tolist())

def __eq__(self, other: TensorStatistic):
if isinstance(other, MeanTensorStatistic):
return self.shape == other.shape and fns.allclose(self.mean_values, other.mean_values)
return False

def _get_serialized_data(self) -> dict[str, Tensor]:
backend = self.mean_values.backend
dtype = self.mean_values.dtype
device = self.mean_values.device
return {
self.MEAN_STAT: self.mean_values,
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=dtype, device=device),
self.SHAPE_STAT: fns.tensor(self.shape, backend=backend, dtype=TensorDataType.int32, device=device),
}

def load_data(self, loaded_data: dict[str, Tensor]) -> None:
self.mean_values = loaded_data[self.MEAN_STAT]
self.shape_values = tuple(loaded_data[self.SHAPE_STAT].tolist())
self.shape = tuple(loaded_data[self.SHAPE_STAT].tolist())


@dataclass
Expand Down Expand Up @@ -270,14 +274,13 @@ def __eq__(self, other: Any) -> bool:

def _get_serialized_data(self) -> dict[str, Tensor]:
backend = self.mean_values[0].backend
dtype = self.mean_values[0].dtype
device = self.mean_values[0].device
return {
self.MEAN_STAT: fns.stack(self.mean_values),
self.SHAPE_STAT: fns.tensor(
[[dim.data for dim in shape] for shape in self.shape_values],
self.shape_values,
backend=backend,
dtype=dtype,
dtype=TensorDataType.int32,
device=device,
),
}
Expand All @@ -292,5 +295,5 @@ def from_config(cls, config: dict[str, Any]) -> TensorStatistic:
if cls.MEAN_STAT in config and config[cls.MEAN_STAT] is not None:
mean_values = [fns.squeeze(it) for it in config[cls.MEAN_STAT]]
if cls.SHAPE_STAT in config and config[cls.SHAPE_STAT] is not None:
shape_values = [tuple(it) for it in config[cls.SHAPE_STAT]]
shape_values = [tuple(it.tolist()) for it in config[cls.SHAPE_STAT]]
return cls(mean_values=mean_values, shape_values=shape_values)
8 changes: 4 additions & 4 deletions src/nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from nncf.experimental.common.tensor_statistics.collectors import MeanPerChReducer
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic
Expand All @@ -40,19 +40,19 @@ def get_mean_statistic_collector(
reducer = BatchMeanReducer(inplace)
else:
reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
raw_reducer = RawReducer()
shape_reducer = ShapeReducer(inplace=inplace)

kwargs = {
"num_samples": num_samples,
"window_size": window_size,
}

aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
return collector


Expand Down
7 changes: 3 additions & 4 deletions src/nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
Expand Down Expand Up @@ -120,18 +119,18 @@ def get_mean_statistic_collector(
reducer = OVBatchMeanReducer(inplace)
else:
reducer = OVMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
raw_reducer = RawReducer()
shape_reducer = OVShapeReducer(inplace=inplace)

kwargs = {
"num_samples": num_samples,
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
return collector


Expand Down
10 changes: 10 additions & 0 deletions src/nncf/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,13 @@ def as_numpy_tensor(a: Tensor) -> Tensor:
:param a: Tensor to change backend for.
:return: Tensor in numpy backend.
"""


@tensor_dispatcher
def tolist(a: Tensor) -> Any:
"""
Returns the tensor as a nested list.
For scalars, a standard Python number is returned, just like with item().

:return: The tensor as a nested list.
"""
5 changes: 5 additions & 0 deletions src/nncf/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,8 @@ def tensor(
validate_device(device)
np_dtype = convert_to_numpy_dtype(dtype)
return np.array(data, dtype=np_dtype)


@numeric.tolist.register
def _(a: T_NUMPY) -> Any:
return a.tolist()
5 changes: 5 additions & 0 deletions src/nncf/tensor/functions/tf_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,8 @@ def tensor(
@numeric.as_numpy_tensor.register
def _(a: tf.Tensor) -> npt.NDArray[Any]:
return a.numpy()


@numeric.tolist.register
def _(a: tf.Tensor) -> Any:
return a.numpy().tolist()
5 changes: 5 additions & 0 deletions src/nncf/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,8 @@ def tensor(
@numeric.as_numpy_tensor.register
def _(a: torch.Tensor) -> NDArray[Any]:
return a.cpu().detach().numpy()


@numeric.tolist.register
def _(a: torch.Tensor) -> Any:
return a.tolist()
3 changes: 3 additions & 0 deletions src/nncf/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def clone(self) -> Tensor:
def as_numpy_tensor(self) -> Tensor:
return cast(Tensor, _call_function("as_numpy_tensor", self))

def tolist(self) -> Any:
return _call_function("tolist", self)

def as_openvino_tensor(self) -> Tensor:
x = self
if x.backend == TensorBackend.numpy:
Expand Down
8 changes: 4 additions & 4 deletions src/nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import MedianMADTensorStatistic
Expand Down Expand Up @@ -306,18 +306,18 @@ def get_mean_statistic_collector(
reducer = BatchMeanReducer()
else:
reducer = MeanPerChReducer(channel_axis=channel_axis)
raw_reducer = RawReducer()
shape_reducer = ShapeReducer()

kwargs = {
"num_samples": num_samples,
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
aggregate_noop = NoopAggregator(num_samples=1, return_first=True)

collector = TensorCollector(MeanTensorStatistic)
collector.register_statistic_branch(MeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, raw_reducer, aggregate_shape)
collector.register_statistic_branch(MeanTensorStatistic.SHAPE_STAT, shape_reducer, aggregate_noop)
return collector


Expand Down
21 changes: 14 additions & 7 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from abc import abstractmethod
from collections import deque
from dataclasses import dataclass
from functools import partial
from itertools import product
Expand Down Expand Up @@ -37,7 +38,6 @@
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.tensor import Tensor
Expand Down Expand Up @@ -296,19 +296,26 @@ def test_noop_aggregator(self):

assert aggregator._collected_samples == 3
aggregated = aggregator.aggregate()
assert isinstance(aggregated, deque)
assert len(aggregated) == 3
for val in aggregated:
assert fns.allclose(val, self.get_nncf_tensor(input_))

def test_shape_aggregator(self):
aggregator = ShapeAggregator()
@pytest.mark.parametrize("num_samples,should_fail", [(1, False), (None, False), (3, True)])
def test_noop_aggregator_return_first(self, num_samples, should_fail):
if should_fail:
with pytest.raises(nncf.InternalError):
NoopAggregator(num_samples, return_first=True)
return
aggregator = NoopAggregator(num_samples, return_first=True)

ref_shape = (1, 3, 5, 7, 9)
input_ = np.empty(ref_shape)
for _ in range(3):
aggregator.register_reduced_input(self.get_nncf_tensor(input_))
input_ = np.arange(np.prod(ref_shape)).reshape(ref_shape)
aggregator.register_reduced_input(self.get_nncf_tensor(input_))

assert aggregator._collected_samples == 1
assert ref_shape == aggregator.aggregate()
aggregated = aggregator.aggregate()
assert fns.allclose(aggregated, self.get_nncf_tensor(input_))

@pytest.mark.parametrize(
"offline_aggregators_test_desc",
Expand Down
4 changes: 2 additions & 2 deletions tests/common/experimental/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,10 @@ def test_mean_max_stat_building(self):
tensor_collector.register_statistic_branch(
MeanTensorStatistic.SHAPE_STAT, DummyTensorReducer("B"), DummyTensorAggregator()
)
tensor_collector.register_input_for_all_reducers(Tensor(np.array(1)))
tensor_collector.register_input_for_all_reducers(Tensor(np.array([1])))
statistic = tensor_collector.get_statistics()
assert isinstance(statistic, MeanTensorStatistic)
assert statistic.mean_values == statistic.shape == Tensor(np.array(1))
assert statistic.mean_values == statistic.shape == Tensor(np.array([1]))

def test_median_mad_stat_building(self):
class DummyMADPercentileAggregator(DummyTensorAggregator):
Expand Down
7 changes: 7 additions & 0 deletions tests/cross_fw/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2390,3 +2390,10 @@ def test_as_numpy_tensor(self):
assert tensor1.shape == tensor2.shape
assert tensor2.device == TensorDeviceType.CPU
assert fns.allclose(tensor1, tensor2)

def test_tolist(self):
inp_list = [[1.0, 2.0], [3.0, 4.0]]
tensor = Tensor(self.to_tensor(inp_list))
assert tensor.tolist() == inp_list
assert tensor[0].tolist() == inp_list[0]
assert tensor[0][0].tolist() == inp_list[0][0]