Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(wip) Enumerable#sample #2

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions ext/enumerable/statistics/extension/statistics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,186 @@ enum_stdev(int argc, VALUE* argv, VALUE obj)
return stdev;
}

#if SIZEOF_SIZE_T == SIZEOF_LONG
static inline size_t
random_usize_limited(VALUE rnd, size_t max)
{
return (size_t)rb_random_ulong_limited(rnd, max);
}
#else
static inline size_t
random_usize_limited(VALUE rnd, size_t max)
{
if (max <= ULONG_MAX) {
return (size_t)rb_random_ulong_limited(rnd, (unsigned long)max);
}
else {
VALUE num = rb_random_int(rnd, SIZET2NUM(max));
return NUM2SIZET(num);
}
}
#endif

struct enum_sample_memo {
size_t k;
long n;
VALUE sample;
VALUE random;
};

static VALUE
enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
{
struct enum_sample_memo *memo = (struct enum_sample_memo *)data;
ENUM_WANT_SVALUE();

if (++memo->k <= 1) {
memo->sample = e;
}
else {
size_t j = random_usize_limited(memo->random, memo->k - 1);
if (j == 1) {
memo->sample = e;
}
}

return Qnil;
}

static VALUE
enum_sample_single(VALUE obj, VALUE random)
{
struct enum_sample_memo memo;

memo.k = 0;
memo.n = 1;
memo.sample = Qundef;
memo.random = random;

rb_block_call(obj, id_each, 0, 0, enum_sample_single_i, (VALUE)&memo);

return memo.sample;
}

static VALUE
enum_sample_multiple_without_replace_unweighted_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
{
struct enum_sample_memo *memo = (struct enum_sample_memo *)data;
ENUM_WANT_SVALUE();

if (++memo->k <= memo->n) {
rb_ary_push(memo->sample, e);
}
else {
size_t j = random_usize_limited(memo->random, memo->k - 1);
if (j <= memo->n) {
rb_ary_store(memo->sample, (long)(j - 1), e);
}
}

return Qnil;
}

static VALUE
enum_sample_multiple_unweighted(VALUE obj, long size, VALUE random, int replace_p)
{
struct enum_sample_memo memo;

assert(size > 1);

memo.k = 0;
memo.n = size;
memo.sample = rb_ary_new_capa(size);
memo.random = random;

if (replace_p) {
return Qnil;
}
else {
rb_block_call(obj, id_each, 0, 0, enum_sample_multiple_without_replace_unweighted_i, (VALUE)&memo);
}

return memo.sample;
}

/* call-seq:
* enum.sample -> obj
* enum.sample(random: rng) -> obj
* enum.sample(n) -> ary
* enum.sample(n, random: rng) -> ary
* enum.sample(n, random: rng, replace: true) -> ary
*
* Choose a random element or +n+ random elements from the enumerable.
*
* The enumerable is completely scanned just once for choosing random elements
* even if +n+ is ommitted or +n+ is +1+. This means this method cannot be
* applicable to an infinite enumerable.
*
* +replace:+ keyword specifies whether the sample is with or without
* replacement.
*
* On without-replacement sampling, the elements are chosen by using random
* in order to ensure that an element doesn't repeat itself unless the
* enumerable already contained duplicated elements.
*
* On with-replacement sampling, the elements are chosen by using random, and
* indices into the array can be duplicated even if the enumerable didn't contain
* duplicated elements.
*
* If the enumerable is empty the first two forms return +nil+, and the latter
* forms with +n+ return an empty array.
*
* The optional +rng+ argument will be used as the random number generator.
*/
static VALUE
enum_sample(int argc, VALUE *argv, VALUE obj)
{
VALUE size_v, random_v, replace_v, weights_v, opts;
long size;
int replace_p;

random_v = rb_cRandom;
replace_v = Qundef;
weights_v = Qundef;

if (argc == 0) goto single;

rb_scan_args(argc, argv, "01:", &size_v, &opts);
size = NIL_P(size_v) ? 1 : NUM2LONG(size_v);

if (size == 1 && NIL_P(opts)) {
goto single;
}

if (!NIL_P(opts)) {
static ID keywords[3];
VALUE kwargs[3];
if (!keywords[0]) {
keywords[0] = rb_intern("random");
keywords[1] = rb_intern("replace");
/* keywords[2] = rb_intern("weights"); */
}
rb_get_kwargs(opts, keywords, 0, 2, kwargs);
random_v = kwargs[0];
replace_v = kwargs[1];
/* weights_v = kwargs[2]; */
}

if (random_v == Qundef) {
random_v = rb_cRandom;
}

if (size == 1) {
single:
return enum_sample_single(obj, random_v);
}

replace_p = (replace_v == Qundef) ? 0 : RTEST(replace_v);

return enum_sample_multiple_unweighted(obj, size, random_v, replace_p);
}


/* call-seq:
* ary.mean_stdev(population: false)
*
Expand Down Expand Up @@ -1479,6 +1659,7 @@ Init_extension(void)
rb_define_method(rb_mEnumerable, "variance", enum_variance, -1);
rb_define_method(rb_mEnumerable, "mean_stdev", enum_mean_stdev, -1);
rb_define_method(rb_mEnumerable, "stdev", enum_stdev, -1);
rb_define_method(rb_mEnumerable, "sample", enum_sample, -1);

#ifndef HAVE_ARRAY_SUM
rb_define_method(rb_cArray, "sum", ary_sum, -1);
Expand Down
201 changes: 201 additions & 0 deletions spec/enum/sample_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
require 'spec_helper'
require 'enumerable/statistics'

RSpec.describe Enumerable, '#sample' do
let(:random) { Random.new }
let(:n) { 20 }

let(:replace) { nil }
let(:weights) { nil }
let(:opts) { {} }

before do
opts[:replace] = replace if replace
opts[:weights] = weights if weights
end

context 'when the receiver has 1 item' do
let(:enum) { 1.upto(1) }

shared_examples_for '1-item enumerable' do
context 'without replacement' do
specify { expect(opts).not_to include(:replace) }

specify do
expect(enum.sample(**opts)).to eq(1)

expect(enum.sample(10, **opts)).to eq([1])
expect(enum.sample(20, **opts)).to eq([1])
end
end

context 'with replacement' do
let(:replace) { true }

specify { expect(opts).to include(replace: true) }

specify do
expect(enum.sample(10, **opts)).to eq(Array.new(10, 1))
expect(enum.sample(20, **opts)).to eq(Array.new(20, 1))
end
end
end

context 'without weights' do
specify { expect(opts).not_to include(:weights) }

include_examples '1-item enumerable'
end

# TODO: weights
xcontext 'with weights' do
let(:weights) do
{ 1 => 1.0 }
end

specify { expect(opts).to include(weights: weights) }

include_examples '1-item enumerable'
end
end

context 'when the receiver has 2 item' do
let(:enum) { 1.upto(2) }

shared_examples_for 'sample from 2-item enumerable without replacement' do
specify { expect(opts).not_to include(:replace) }

specify do
expect(Array.new(100) { enum.sample(**opts) }).to all(eq(1).or eq(2))

expect(enum.sample(10, **opts)).to contain_exactly(1, 2)
expect(enum.sample(20, **opts)).to contain_exactly(1, 2)
end
end

context 'without weights' do
context 'without replacement' do
it_behaves_like 'sample from 2-item enumerable without replacement'
end

context 'with replacement' do
let(:replace) { true }

specify { expect(opts).to include(replace: true) }

specify do
expect(enum.sample(10, **opts)).to have_attributes(length: 10).and all(eq(1).or eq(2))
expect(enum.sample(20, **opts)).to have_attributes(length: 20).and all(eq(1).or eq(2))
end
end
end

# TODO: weights
xcontext 'with weights' do
specify { expect(opts).to include(weights: weights) }

context 'without replacement' do
it_behaves_like 'sample from 2-item enumerable without replacement'
end

context 'with replacement' do
let(:replace) { true }

specify { expect(opts).to include(replace: true) }
end
end
end

context 'without weight' do
let(:enum) { 1.upto(100000) }

specify { expect(opts).not_to include(:weights) }

context 'without replacement' do
specify { expect(opts).not_to include(:replace) }

context 'without size' do
context 'without rng' do
specify do
result = enum.sample
expect(result).to be_an(Integer)
other_results = Array.new(100) { enum.sample }
expect(other_results).not_to be_all {|i| i == result }
end
end

context 'with rng' do
specify do
save_random = random.dup
result = enum.sample(random: random)
expect(result).to be_an(Integer)
other_results = Array.new(100) { enum.sample(random: save_random.dup) }
expect(other_results).to be_all {|i| i == result }
end
end
end

context 'with size (== 1)' do
context 'without rng' do
specify do
result = enum.sample(1)
expect(result).to be_an(Integer)
other_results = Array.new(100) { enum.sample(1) }
expect(other_results).not_to be_all {|i| i == result }
end
end

context 'with rng' do
specify do
save_random = random.dup
result = enum.sample(1, random: random)
expect(result).to be_an(Integer)
other_results = Array.new(100) { enum.sample(1, random: save_random.dup) }
expect(other_results).to be_all {|i| i == result }
end
end
end

context 'with size (> 1)' do
context 'without rng' do
subject(:result) { enum.sample(n) }

specify do
result = enum.sample(n)
expect(result).to be_an(Array)
expect(result.length).to eq(n)
expect(result.uniq.length).to eq(n)
other_results = Array.new(100) { enum.sample(n) }
expect(other_results).not_to be_all {|i| i == result }
end
end

context 'with rng' do
subject(:result) { enum.sample(n, random: random) }

specify do
save_random = random.dup
result = enum.sample(n, random: random)
expect(result).to be_an(Array)
expect(result.length).to eq(n)
expect(result.uniq.length).to eq(n)
other_results = Array.new(100) { enum.sample(n, random: save_random.dup) }
expect(other_results).to be_all {|i| i == result }
end
end
end
end

context 'with replacement' do
let(:replace) { true }

specify { expect(opts).to include(replace: true) }

pending
end
end

context 'with weight' do
pending
end
end