Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 120963f

Browse files
authored
Fix performance issues with specset (#11)
* Determine equality of StatsStep arguments by `===`
* Add `copyargs` and avoid unnecessary copying in `proceed`
1 parent dc9ea77 commit 120963f

File tree

8 files changed

+120
-94
lines changed

8 files changed

+120
-94
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
99
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1010
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
1111
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
12-
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
1312
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1413
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
1514
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
@@ -21,7 +20,6 @@ DataFrames = "0.22"
2120
MacroTools = "0.5"
2221
Missings = "0.4"
2322
Reexport = "0.2, 1"
24-
SplitApplyCombine = "1.1"
2523
StatsBase = "0.33"
2624
StatsModels = "0.6.18"
2725
Tables = "1.2"

src/DiffinDiffsBase.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using Combinatorics: combinations
55
using MacroTools: @capture, isexpr, postwalk
66
using Missings: disallowmissing
77
using Reexport
8-
using SplitApplyCombine: groupfind, groupview
98
using StatsBase: Weights, uweights
109
@reexport using StatsModels
1110
using Tables: istable, getcolumn, columntable, columnnames

src/StatsProcedures.jl

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,14 @@ See also [`proceed`](@ref).
9494
"""
9595
combinedargs(::StatsStep, ::Any) = ()
9696

97+
copyargs(::StatsStep) = ()
98+
9799
function (step::StatsStep{A,F})(@nospecialize(ntargs::NamedTuple);
98100
verbose::Bool=false) where {A,F}
99101
haskey(ntargs, :verbose) && (verbose = ntargs.verbose)
100102
verbose && printstyled("Running ", step, "\n", color=:green)
101103
ret = F.instance(groupargs(step, ntargs)..., combinedargs(step, (ntargs,))...)
102-
if ret isa Tuple{<:NamedTuple, Bool}
103-
return merge(ntargs, ret[1])
104-
else
105-
error("unexpected $(typeof(ret)) returned from $(F.name.mt.name) associated with StatsStep $A")
106-
end
104+
return merge(ntargs, ret)
107105
end
108106

109107
(step::StatsStep)(; verbose::Bool=false) = step(NamedTuple(), verbose=verbose)
@@ -197,6 +195,7 @@ _sharedby(s::SharedStatsStep) = s.ids
197195
_f(s::SharedStatsStep) = _f(s.step)
198196
groupargs(s::SharedStatsStep, @nospecialize(ntargs::NamedTuple)) = groupargs(s.step, ntargs)
199197
combinedargs(s::SharedStatsStep, v::AbstractArray) = combinedargs(s.step, v)
198+
copyargs(s::SharedStatsStep) = copyargs(s.step)
200199

201200
==(x::SharedStatsStep, y::SharedStatsStep) =
202201
x.step == y.step && x.ids == y.ids
@@ -320,6 +319,10 @@ function pool(ps::AbstractStatsProcedure...)
320319
return PooledStatsProcedure(ps, shared)
321320
end
322321

322+
# A shortcut for the simple case
323+
pool(p::AbstractStatsProcedure) =
324+
PooledStatsProcedure((p,), [SharedStatsStep(s, 1) for s in p])
325+
323326
length(p::PooledStatsProcedure) = length(p.steps)
324327
eltype(::Type{PooledStatsProcedure}) = SharedStatsStep
325328
firstindex(p::PooledStatsProcedure) = firstindex(p.steps)
@@ -439,6 +442,11 @@ function show(io::IO, ::MIME"text/plain", sp::StatsSpec{T}) where T
439442
_show_args(io, sp)
440443
end
441444

445+
function _count!(objcount::IdDict, obj)
446+
count = get(objcount, obj, 0)
447+
objcount[obj] = count + 1
448+
end
449+
442450
"""
443451
proceed(sps::AbstractVector{<:StatsSpec}; kwargs...)
444452
@@ -464,50 +472,65 @@ function proceed(sps::AbstractVector{<:StatsSpec};
464472
verbose::Bool=false, keep=nothing, keepall::Bool=false)
465473
nsps = length(sps)
466474
nsps == 0 && throw(ArgumentError("expect a nonempty vector"))
475+
476+
gids = IdDict{AbstractStatsProcedure, Vector{Int}}()
477+
objcount = IdDict{Any, Int}()
467478
traces = Vector{NamedTuple}(undef, nsps)
468-
for i in 1:nsps
469-
traces[i] = sps[i].args
479+
@inbounds for i in 1:nsps
480+
push!(get!(Vector{Int}, gids, _procedure(sps[i])()), i)
481+
args = sps[i].args
482+
foreach(x->_count!(objcount, x), args)
483+
traces[i] = args
470484
end
471-
gids = groupfind(r->_procedure(r)(), sps)
485+
472486
steps = pool((p for p in keys(gids))...)
487+
tasks = IdDict{Tuple, Vector{Int}}()
473488
ntask_total = 0
474-
for step in steps
489+
@inbounds for step in steps
475490
ntask = 0
476-
verbose && printstyled("Running ", step, "...")
477-
taskids = vcat((gids[steps.procs[i]] for i in _sharedby(step))...)
478-
tasks = groupview(r->groupargs(step, r), view(traces, taskids))
479-
for (ins, subtb) in pairs(tasks)
480-
ret = _f(step)(ins..., combinedargs(step, subtb)...)
481-
if ret isa Tuple{<:NamedTuple, Bool}
482-
ret, share = ret
483-
else
484-
fname = typeof(_f(step)).name.mt.name
485-
stepname = typeof(step.step).parameters[1]
486-
error("unexpected type $(typeof(ret)) of object returned from $fname associated with StatsStep $stepname")
491+
verbose && print("Running ", step, "...")
492+
# Group arguments by ===
493+
for i in _sharedby(step)
494+
taskids = gids[steps.procs[i]]
495+
for j in taskids
496+
push!(get!(Vector{Int}, tasks, groupargs(step, traces[j])), j)
487497
end
488-
ntask += 1
489-
ntask_total += 1
490-
if share
491-
for i in eachindex(subtb)
492-
subtb[i] = merge(subtb[i], ret)
493-
end
494-
else
495-
for i in eachindex(subtb)
496-
subtb[i] = merge(subtb[i], deepcopy(ret))
498+
end
499+
500+
for (gargs, ids) in tasks
501+
# Handle potential in-place operations on mutable objects
502+
nids = length(ids)
503+
icopy = copyargs(step)
504+
if length(icopy) > 0
505+
gargs = Any[gargs...]
506+
for i in copyargs(step)
507+
a = gargs[i]
508+
objcount[a] > nids && (gargs[i] = deepcopy(a))
497509
end
498510
end
511+
512+
ret = _f(step)(gargs..., combinedargs(step, view(traces, ids))...)
513+
for id in ids
514+
foreach(x->_count!(objcount, x), ret)
515+
traces[id] = merge(traces[id], ret)
516+
end
499517
end
518+
ntask = length(tasks)
519+
ntask_total += ntask
520+
empty!(tasks)
500521
nprocs = length(_sharedby(step))
501-
verbose && printstyled("Finished ", ntask, ntask > 1 ? " tasks" : " task", " for ",
522+
verbose && print("Finished ", ntask, ntask > 1 ? " tasks" : " task", " for ",
502523
nprocs, nprocs > 1 ? " procedures\n" : " procedure\n")
503524
end
525+
504526
nprocs = length(steps.procs)
505527
verbose && printstyled("All steps finished (", ntask_total,
506528
ntask_total > 1 ? " tasks" : " task", " for ", nprocs,
507529
nprocs > 1 ? " procedures)\n" : " procedure)\n", bold=true, color=:green)
508-
for i in 1:nsps
530+
@inbounds for i in 1:nsps
509531
traces[i] = result(_procedure(sps[i]), traces[i])
510532
end
533+
511534
if keepall
512535
return traces
513536
elseif keep === nothing
@@ -517,11 +540,11 @@ function proceed(sps::AbstractVector{<:StatsSpec};
517540
if keep isa Symbol
518541
keep = (keep,)
519542
else
520-
eltype(keep) == Symbol ||
521-
throw(ArgumentError("expect Symbol or collections of Symbols for the value of option `keep`"))
543+
eltype(keep) == Symbol || throw(ArgumentError(
544+
"expect Symbol or collections of Symbols for the value of option `keep`"))
522545
end
523546
in(:result, keep) || (keep = (keep..., :result))
524-
for i in 1:nsps
547+
@inbounds for i in 1:nsps
525548
names = ((n for n in keep if haskey(traces[i], n))...,)
526549
traces[i] = NamedTuple{names}(traces[i])
527550
end

src/procedures.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function checkdata(data, subset::Union{AbstractVector, Nothing},
2525
end
2626

2727
sum(esample) == 0 && error("no nonmissing data")
28-
return (esample=esample,), false
28+
return (esample=esample,)
2929
end
3030

3131
"""
@@ -97,7 +97,7 @@ function checkvars!(data, tr::AbstractTreatment, pr::AbstractParallel,
9797

9898
overlap!(esample, tr_rows, tr, pr, treatname, data)
9999
sum(esample) == 0 && error("no nonmissing data")
100-
return (esample=esample, tr_rows=tr_rows), false
100+
return (esample=esample, tr_rows=tr_rows)
101101
end
102102

103103
"""
@@ -109,6 +109,7 @@ const CheckVars = StatsStep{:CheckVars, typeof(checkvars!)}
109109

110110
required(::CheckVars) = (:data, :tr, :pr, :yterm, :treatname, :esample)
111111
default(::CheckVars) = (treatintterms=(), xterms=())
112+
copyargs(::CheckVars) = (6,)
112113

113114
"""
114115
makeweights(args...)
@@ -119,19 +120,18 @@ See also [`MakeWeights`](@ref).
119120
function makeweights(data, esample::BitVector, weightname::Symbol)
120121
weights = Weights(convert(Vector{Float64}, view(getcolumn(data, weightname), esample)))
121122
all(isfinite, weights) || error("data column $weightname contain not-a-number values")
122-
(weights=weights,), true
123+
return (weights=weights,)
123124
end
124125

125126
function makeweights(data, esample::BitVector, weightname::Nothing)
126127
weights = uweights(sum(esample))
127-
(weights=weights,), true
128+
return (weights=weights,)
128129
end
129130

130131
"""
131132
MakeWeights <: StatsStep
132133
133134
Call [`DiffinDiffsBase.makeweights`](@ref) to create a generic `Weights` vector.
134-
The returned object named `weights` may be shared across multiple specifications.
135135
"""
136136
const MakeWeights = StatsStep{:MakeWeights, typeof(makeweights)}
137137

test/StatsProcedures.jl

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,31 @@
11
using DiffinDiffsBase: _f, _get, groupargs, _get_default,
22
_sharedby, _show_args, _args_kwargs, _parse!, pool, proceed
3-
import DiffinDiffsBase: required, default, transformed, combinedargs
3+
import DiffinDiffsBase: required, default, transformed, combinedargs, copyargs
44

5-
testvoidstep(a::String) = NamedTuple(), false
5+
testvoidstep(a::String) = NamedTuple()
66
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep)}
77
required(::TestVoidStep) = (:a,)
88

9-
testregstep(a::String, b::String) = (c=a*b,), false
9+
testregstep(a::String, b::String) = (c=a*b,)
1010
const TestRegStep = StatsStep{:TestRegStep, typeof(testregstep)}
1111
default(::TestRegStep) = (a="a", b="b")
1212

13-
testlaststep(a::String, c::String) = (result=a*c,), false
13+
testlaststep(a::String, c::String) = (result=a*c,)
1414
const TestLastStep = StatsStep{:TestLastStep, typeof(testlaststep)}
1515
default(::TestLastStep) = (a="a",)
1616
transformed(::TestLastStep, ntargs::NamedTuple) = (ntargs.c,)
1717

18-
testarraystep(a::String) = (result=[a],), false
19-
const TestArrayStep = StatsStep{:TestArrayStep, typeof(testarraystep)}
20-
default(::TestArrayStep) = (a="a",)
21-
22-
testcombinestep(a::String, bs::String...) = (c=collect(bs),), true
18+
testcombinestep(a::String, bs::String...) = (c=collect(bs),)
2319
const TestCombineStep = StatsStep{:TestCombineStep, typeof(testcombinestep)}
2420
default(::TestCombineStep) = (a="a",)
2521
combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs]
2622

27-
testinvalidstep(a::String, b::String) = b, false
28-
const TestInvalidStep = StatsStep{:TestInvalidStep, typeof(testinvalidstep)}
29-
default(::TestInvalidStep) = (a="a", b="b")
23+
testarraystep(a::String, c::Array) = (result=c,)
24+
const TestArrayStep = StatsStep{:TestArrayStep, typeof(testarraystep)}
25+
required(::TestArrayStep) = (:a, :c)
26+
copyargs(::TestArrayStep) = (2,)
3027

31-
const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testinvalidstep)}
28+
const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testregstep)}
3229

3330
@testset "StatsStep" begin
3431
@testset "_get" begin
@@ -64,14 +61,14 @@ const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testinvalidstep)}
6461
@test TestLastStep()((a="a", b="a", c="ab")) ==
6562
(a="a", b="a", c="ab", result="aab")
6663

67-
@test TestArrayStep()((a="a",)) == (a="a", result=["a"])
6864
@test TestCombineStep()((a="a", b="b")) == (a="a", b="b", c=["b"])
65+
c = ["c"]
66+
ret = TestArrayStep()((a="a", c=c,))
67+
@test ret.result === c
6968

7069
@test sprint(show, TestVoidStep()) == "TestVoidStep"
7170
@test sprint(show, MIME("text/plain"), TestVoidStep()) ==
7271
"TestVoidStep (StatsStep that calls testvoidstep)"
73-
74-
@test_throws ErrorException TestInvalidStep()()
7572
end
7673
end
7774

@@ -86,8 +83,8 @@ const NP = TestProcedure{:NullProcedure,Tuple{}}
8683
const np = NP()
8784
const CP = TestProcedure{:CombineProcedure,Tuple{TestCombineStep,TestArrayStep}}
8885
const cp = CP()
89-
const EP = TestProcedure{:InvalidProcedure,Tuple{TestInvalidStep}}
90-
const ep = EP()
86+
const AP = TestProcedure{:ArrayProcedure,Tuple{TestArrayStep}}
87+
const ap = AP()
9188

9289
@testset "AbstractStatsProcedure" begin
9390
@test length(rp) == 3
@@ -266,8 +263,10 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a
266263
s5 = StatsSpec("s5", IP, (a="a", b="b"))
267264
s6 = StatsSpec("s6", CP, (a="a", b="b1"))
268265
s7 = StatsSpec("s7", CP, (a="a", b="b2"))
269-
s8 = StatsSpec("s8", NP, NamedTuple())
270-
s9 = StatsSpec("s9", EP, (a="a", b="b"))
266+
c = ["c"]
267+
s8 = StatsSpec("s8", AP, (a="a", c=c))
268+
s9 = StatsSpec("s9", AP, (a="a1", c=c))
269+
s10 = StatsSpec("s10", NP, NamedTuple())
271270

272271
@test proceed([s1]) == ["aab"]
273272
@test proceed([s1,s2], verbose=true) == ["aab", "aab"]
@@ -290,19 +289,23 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a
290289
@test proceed([s1,s4], keepall=true) ==
291290
[(a="a", b="b", c="ab", result="aab"), (a="a", b="b", c="ab")]
292291

293-
@test proceed([s6], keepall=true) == [(a="a", b="b1", c=["b1"], result=["a"])]
292+
@test proceed([s6], keepall=true) == [(a="a", b="b1", c=["b1"], result=["b1"])]
294293
ret = proceed([s6,s7], keepall=true)
295-
@test ret == [(a="a", b="b1", c=["b1", "b2"], result=["a"]),
296-
(a="a", b="b2", c=["b1", "b2"], result=["a"])]
294+
@test ret == [(a="a", b="b1", c=["b1", "b2"], result=["b1", "b2"]),
295+
(a="a", b="b2", c=["b1", "b2"], result=["b1", "b2"])]
297296
@test ret[1].c === ret[2].c
298-
@test ret[1].result !== ret[2].result
297+
@test ret[1].result === ret[2].result
298+
299+
ret = proceed([s8], keepall=true)
300+
@test ret[1].c === ret[1].result
301+
ret = proceed([s8,s9], keepall=true)
302+
@test ret[1].c !== ret[1].result
299303

300-
@test proceed([s8]) == [nothing]
301-
@test proceed([s8], keepall=true) == NamedTuple[NamedTuple()]
302-
@test proceed([s8], keep=:result) == NamedTuple[NamedTuple()]
304+
@test proceed([s10]) == [nothing]
305+
@test proceed([s10], keepall=true) == NamedTuple[NamedTuple()]
306+
@test proceed([s10], keep=:result) == NamedTuple[NamedTuple()]
303307

304308
@test_throws ArgumentError proceed(StatsSpec[])
305-
@test_throws ErrorException proceed([s9])
306309
end
307310

308311
@testset "_parse!" begin

0 commit comments

Comments
 (0)