Skip to content

Commit a7abdd9

Browse files
authored
rework sum, appender, and mean (#240)
* rework sum and mean * fixup
1 parent cbdc8d9 commit a7abdd9

File tree

3 files changed

+247
-156
lines changed

3 files changed

+247
-156
lines changed

source/mir/appender.d

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ struct ScopedBuffer(T, size_t bytes = 4096)
3535
private size_t _currentLength;
3636
private align(T.alignof) ubyte[_bufferLength * T.sizeof] _scopeBufferPayload = void;
3737

38-
private ref T[_bufferLength] _scopeBuffer() @trusted scope
38+
private ref inout(T[_bufferLength]) _scopeBuffer() inout @trusted scope
3939
{
40-
return *cast(T[_bufferLength]*)&_scopeBufferPayload;
40+
return *cast(inout(T[_bufferLength])*)&_scopeBufferPayload;
4141
}
4242

4343
private T[] prepare(size_t n) @trusted scope
@@ -71,15 +71,41 @@ struct ScopedBuffer(T, size_t bytes = 4096)
7171
else
7272
private alias R = T;
7373

74-
///
74+
/// Copy constructor is enabled only if `T` is mutable type without eleborate assign.
75+
static if (isMutable!T && !hasElaborateAssign!T)
76+
this(this)
77+
{
78+
import mir.internal.memory: malloc;
79+
if (_buffer.ptr)
80+
{
81+
auto buffer = (cast(T*)malloc(T.sizeof * _buffer.length))[0 .. _buffer.length];
82+
buffer[0 .. _currentLength] = _buffer[0 .. _currentLength];
83+
_buffer = buffer;
84+
}
85+
}
86+
else
7587
@disable this(this);
7688

7789
///
7890
~this()
7991
{
8092
import mir.internal.memory: free;
8193
data._mir_destroy;
82-
(() @trusted => free(cast(void*)_buffer.ptr))();
94+
(() @trusted { if (_buffer.ptr) free(cast(void*)_buffer.ptr); })();
95+
}
96+
97+
///
98+
void shrinkTo(size_t length)
99+
{
100+
assert(length <= _currentLength);
101+
data[length .. _currentLength]._mir_destroy;
102+
_currentLength = length;
103+
}
104+
105+
///
106+
size_t length() scope const @property
107+
{
108+
return _currentLength;
83109
}
84110

85111
///
@@ -160,7 +186,7 @@ struct ScopedBuffer(T, size_t bytes = 4096)
160186
}
161187

162188
///
163-
T[] data() @property @safe scope
189+
inout(T)[] data() inout @property @safe scope
164190
{
165191
return _buffer.length ? _buffer[0 .. _currentLength] : _scopeBuffer[0 .. _currentLength];
166192
}

source/mir/math/stat.d

Lines changed: 113 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ at expense of precision, one can use $(REF_ALTTEXT $(TT Summation.fast), Summati
77
88
License: $(LINK2 http://boost.org/LICENSE_1_0.txt, Boost License 1.0).
99
10-
Authors: Ilya Yaroshenko
10+
Authors: Ilya Yaroshenko, John Michael Hall
1111
1212
Copyright: 2019 Symmetry Investments Group and Kaleidic Associates Advisory Limited.
1313
@@ -23,41 +23,116 @@ import mir.math.common: fmamath;
2323
import mir.math.sum;
2424
import mir.primitives;
2525
import std.range.primitives: isInputRange;
26-
import std.traits: isArray, isFloatingPoint;
26+
import std.traits: isArray, isFloatingPoint, isMutable, isIterable;
2727

2828
/++
29-
Computes the average of `r`, which must be a finite iterable.
30-
31-
Returns:
32-
The average of all the elements in the range r.
29+
Output range for mean.
3330
+/
34-
template mean(Summation summation = Summation.appropriate)
31+
struct MeanAccumulator(T, Summation summation)
32+
if (isMutable!T)
3533
{
3634
///
37-
@safe @fmamath sumType!Range
38-
mean(Range)(Range r)
39-
if (hasLength!Range
40-
|| summation == Summation.appropriate
41-
|| summation == Summation.fast
42-
|| summation == Summation.naive)
35+
size_t count;
36+
///
37+
Summator!(T, summation) sumAccumulator;
38+
39+
///
40+
F mean(F = T)() @property
4341
{
44-
static if (hasLength!Range)
42+
return cast(F) sumAccumulator.sum / cast(F) count;
43+
}
44+
45+
///
46+
void put(Range)(Range r)
47+
if (isIterable!Range)
48+
{
49+
static if (hasShape!Range)
4550
{
46-
auto n = r.length;
47-
return sum!summation(r.move) / cast(sumType!Range) n;
51+
count += r.elementCount;
52+
sumAccumulator.put(r);
4853
}
4954
else
5055
{
51-
auto s = cast(typeof(return)) 0;
52-
size_t length;
53-
foreach (e; r)
56+
foreach(x; r)
5457
{
55-
length++;
56-
s += e;
58+
count++;
59+
sumAccumulator.put(x);
5760
}
58-
return s / cast(sumType!Range) length;
5961
}
6062
}
63+
64+
///
65+
void put()(T x)
66+
{
67+
count++;
68+
sumAccumulator.put(x);
69+
}
70+
}
71+
72+
///
73+
version(mir_test)
74+
@safe pure nothrow unittest
75+
{
76+
import mir.ndslice.slice : sliced;
77+
78+
MeanAccumulator!(double, Summation.pairwise) x;
79+
x.put([0.0, 1, 2, 3, 4].sliced);
80+
assert(x.mean == 2);
81+
x.put(5);
82+
assert(x.mean == 2.5);
83+
}
84+
85+
version(mir_test)
86+
@safe pure nothrow unittest
87+
{
88+
import mir.ndslice.slice : sliced;
89+
90+
MeanAccumulator!(float, Summation.pairwise) x;
91+
x.put([0, 1, 2, 3, 4].sliced);
92+
assert(x.mean == 2);
93+
x.put(5);
94+
assert(x.mean == 2.5);
95+
}
96+
97+
/++
98+
Computes the average of `r`, which must be a finite iterable.
99+
100+
Returns:
101+
The average of all the elements in the range r.
102+
+/
103+
template mean(F, Summation summation = Summation.appropriate)
104+
{
105+
/++
106+
Params:
107+
r = range
108+
+/
109+
F mean(Range)(Range r)
110+
if (isIterable!Range)
111+
{
112+
MeanAccumulator!(F, ResolveSummationType!(summation, Range, sumType!Range)) mean;
113+
mean.put(r.move);
114+
return mean.mean;
115+
}
116+
}
117+
118+
/// ditto
119+
template mean(Summation summation = Summation.appropriate)
120+
{
121+
/++
122+
Params:
123+
r = range
124+
+/
125+
sumType!Range mean(Range)(Range r)
126+
if (isIterable!Range)
127+
{
128+
return .mean!(sumType!Range, summation)(r.move);
129+
}
130+
}
131+
132+
///ditto
133+
template mean(F, string summation)
134+
{
135+
mixin("alias mean = .mean!(F, Summation." ~ summation ~ ");");
61136
}
62137

63138
///ditto
@@ -67,10 +142,23 @@ template mean(string summation)
67142
}
68143

69144
///
70-
version(mir_test) @safe pure nothrow unittest
145+
version(mir_test)
146+
@safe pure nothrow unittest
71147
{
148+
import mir.ndslice.slice : sliced;
149+
72150
assert(mean([1.0, 2, 3]) == 2);
73151
assert(mean([1.0 + 3i, 2, 3]) == 2 + 1i);
152+
153+
assert(mean!float([0, 1, 2, 3, 4, 5].sliced(3, 2)) == 2.5);
154+
155+
assert(is(typeof(mean!float([1, 2, 3])) == float));
156+
}
157+
158+
version(mir_test)
159+
@safe pure nothrow unittest
160+
{
161+
assert([1.0, 2, 3, 4].mean == 2.5);
74162
}
75163

76164
/++
@@ -157,7 +245,8 @@ template simpleLinearRegression(string summation)
157245
}
158246

159247
///
160-
version(mir_test) @safe pure nothrow @nogc unittest
248+
version(mir_test)
249+
@safe pure nothrow @nogc unittest
161250
{
162251
import mir.math.common: approxEqual;
163252
static immutable x = [0, 1, 2, 3];

0 commit comments

Comments
 (0)