Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 753ada4

Browse files
authored
Improve the update of esample and fix issues with overlap! (#28)
1 parent 5b5942d commit 753ada4

13 files changed

+315
-96
lines changed

src/DiffinDiffsBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export cb,
5151
ScaledVector,
5252
ScaledMatrix,
5353
scale,
54+
align,
5455

5556
TreatmentSharpness,
5657
SharpDesign,
@@ -74,6 +75,7 @@ export cb,
7475
NotYetTreatedParallel,
7576
notyettreated,
7677
istreated,
78+
istreated!,
7779

7880
TermSet,
7981
termset,
@@ -84,6 +86,7 @@ export cb,
8486
findcell,
8587
cellrows,
8688
settime,
89+
aligntime,
8790
PanelStructure,
8891
setpanel,
8992
findlag!,

src/ScaledArrays.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,19 @@ ScaledArray(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where {T,R,RA<:Abstra
3333
const ScaledVector{T,R} = ScaledArray{T,R,1}
3434
const ScaledMatrix{T,R} = ScaledArray{T,R,2}
3535

36-
const ScaledArrOrSub = Union{ScaledArray, SubArray{<:Any, <:Any, <:ScaledArray}}
36+
const ScaledArrOrSub{T,R,N,RA,P} = Union{ScaledArray{T,R,N,RA,P},
37+
SubArray{<:Any, <:Any, <:ScaledArray{T,R,N,RA,P}}}
3738

39+
"""
40+
scale(sa::ScaledArrOrSub)
41+
42+
Return the step size of the `pool` of `sa`.
43+
"""
3844
scale(sa::ScaledArrOrSub) = step(DataAPI.refpool(sa))
3945

46+
Base.size(sa::ScaledArray) = size(sa.refs)
47+
Base.IndexStyle(::Type{<:ScaledArray{T,R,N,RA}}) where {T,R,N,RA} = IndexStyle(RA)
48+
4049
function _validmin(min, xmin, isstart::Bool)
4150
if min === nothing
4251
min = xmin
@@ -180,9 +189,6 @@ ScaledArray(sa::ScaledArray, step=nothing; reftype::Type=eltype(refarray(sa)),
180189
start=nothing, stop=nothing, xtype::Type=eltype(sa), usepool::Bool=true) =
181190
ScaledArray(sa, reftype, xtype, start, step, stop, usepool)
182191

183-
Base.size(sa::ScaledArray) = size(sa.refs)
184-
Base.IndexStyle(::Type{<:ScaledArray{T,R,N,RA}}) where {T,R,N,RA} = IndexStyle(RA)
185-
186192
Base.similar(sa::ScaledArray{T,R}, dims::Dims=size(sa)) where {T,R} =
187193
ScaledArray(RefArray(ones(R, dims)), DataAPI.refpool(sa), Dict{T,R}())
188194

@@ -191,6 +197,28 @@ Base.similar(sa::SubArray{<:Any, <:Any, <:ScaledArray{T,R}}, dims::Dims=size(sa)
191197

192198
Base.similar(sa::ScaledArrOrSub, dims::Int...) = similar(sa, dims)
193199

200+
"""
201+
align(xs::AbstractArray, sa::ScaledArrOrSub)
202+
203+
Convert `xs` into a [`ScaledArray`](@ref) with a `pool`
204+
that has the same first element and step size as the `pool` from `sa`.
205+
"""
206+
function align(xs::AbstractArray, sa::ScaledArrOrSub)
207+
pool = DataAPI.refpool(sa)
208+
invpool = DataAPI.invrefpool(sa)
209+
step = scale(sa)
210+
xmin, xmax = extrema(xs)
211+
start = first(pool)
212+
stop = last(pool)
213+
start < stop && xmin < start && throw(ArgumentError(
214+
"the minimum of xs $xmin is smaller than the minimum of pool $start"))
215+
start > stop && xmax > start && throw(ArgumentError(
216+
"the maximum of xs $xmax is greater than the maximum of pool $start"))
217+
refs = similar(DataAPI.refarray(sa), size(xs))
218+
_scaledlabel!(refs, invpool, xs, start, step)
219+
return ScaledArray(RefArray(refs), pool, invpool)
220+
end
221+
194222
DataAPI.refarray(sa::ScaledArray) = sa.refs
195223
DataAPI.refvalue(sa::ScaledArray, n::Integer) = getindex(DataAPI.refpool(sa), n)
196224
DataAPI.refpool(sa::ScaledArray) = sa.pool

src/operations.jl

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function cellrows(cols::VecColumnTable, refrows::IdDict)
9797
columns = Vector{AbstractVector}(undef, ncol)
9898
for i in 1:ncol
9999
c = cols[i]
100-
if typeof(c) <: ScaledArray || typeof(c) <: SubArray{<:Any,1,<:ScaledArray}
100+
if typeof(c) <: ScaledArrOrSub
101101
columns[i] = similar(c, ncell)
102102
else
103103
columns[i] = Vector{eltype(c)}(undef, ncell)
@@ -127,13 +127,15 @@ function cellrows(cols::VecColumnTable, refrows::IdDict)
127127
end
128128

129129
"""
130-
settime(data, timename; step, reftype, rotation)
131-
settime(time::AbstractArray; step, reftype, rotation)
130+
settime(data, timename; step, start, stop, reftype, rotation)
131+
settime(time::AbstractArray; step, start, stop, reftype, rotation)
132132
133-
Return a [`ScaledArray`](@ref) that represents discretized time periods.
133+
Convert a column of time values to a [`ScaledArray`](@ref)
134+
for representing discretized time periods of uniform length.
134135
Time values can be provided either as a table containing the relevant column or as an array.
135136
The returned array ensures well-defined time intervals for operations involving relative time
136137
(such as [`lag`](@ref) and [`diff`](@ref)).
138+
See also [`aligntime`](@ref).
137139
138140
# Arguments
139141
- `data`: a Tables.jl-compatible data table.
@@ -142,15 +144,18 @@ The returned array ensures well-defined time intervals for operations involving
142144
143145
# Keywords
144146
- `step=nothing`: the length of each time interval; try step=1 if not specified.
147+
- `start=nothing`: the first element of the `pool` of the returned [`ScaledArray`](@ref).
148+
- `stop=nothing`: the last element of the `pool` of the returned [`ScaledArray`](@ref).
145149
- `reftype::Type{<:Signed}=Int32`: the element type of the reference values for the returned [`ScaledArray`](@ref).
146150
- `rotation=nothing`: rotation groups in a rotating sampling design; use [`RotatingTimeValue`](@ref)s as reference values.
147151
"""
148-
function settime(time::AbstractArray; step=nothing, reftype::Type{<:Signed}=Int32, rotation=nothing)
152+
function settime(time::AbstractArray; step=nothing, start=nothing, stop=nothing,
153+
reftype::Type{<:Signed}=Int32, rotation=nothing)
149154
T = eltype(time)
150155
T <: ValidTimeType && !(T <: RotatingTimeValue) ||
151156
throw(ArgumentError("unaccepted element type $T from time column"))
152157
step === nothing && (step = one(T))
153-
time = ScaledArray(time, step; reftype=reftype)
158+
time = ScaledArray(time, start, step, stop; reftype=reftype)
154159
if rotation !== nothing
155160
refs = rotatingtime(rotation, time.refs)
156161
rots = unique(rotation)
@@ -168,10 +173,30 @@ function settime(time::AbstractArray; step=nothing, reftype::Type{<:Signed}=Int3
168173
return time
169174
end
170175

171-
function settime(data, timename::Union{Symbol,Integer}; step=nothing,
176+
function settime(data, timename::Union{Symbol,Integer};
177+
step=nothing, start=nothing, stop=nothing,
172178
reftype::Type{<:Signed}=Int32, rotation=nothing)
173179
checktable(data)
174-
return settime(getcolumn(data, timename); step=step, reftype=reftype, rotation=rotation)
180+
return settime(getcolumn(data, timename);
181+
step=step, start=start, stop=stop, reftype=reftype, rotation=rotation)
182+
end
183+
184+
"""
185+
aligntime(data, colname::Union{Symbol,Integer}, timename::Union{Symbol,Integer})
186+
187+
Convert a column of time values indexed by `colname` from `data` table
188+
to a [`ScaledArray`](@ref) with a `pool`
189+
that has the same first element and step size as the `pool` from
190+
the [`ScaledArray`](@ref) indexed by `timename`.
191+
See also [`settime`](@ref).
192+
193+
This is useful for representing all discretized time periods with the same scale
194+
so that the underlying reference values returned by `DataAPI.refarray`
195+
can be directly comparable across the columns.
196+
"""
197+
function aligntime(data, colname::Union{Symbol,Integer}, timename::Union{Symbol,Integer})
198+
checktable(data)
199+
return align(getcolumn(data, colname), getcolumn(data, timename))
175200
end
176201

177202
"""

src/parallels.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,25 @@ assume a parallel trends assumption holds over all the relevant time periods.
8484
"""
8585
abstract type TrendParallel{C,S} <: AbstractParallel{C,S} end
8686

87+
"""
88+
istreated(pr::TrendParallel, x)
89+
90+
Test whether `x` represents the treatment time
91+
for a group of units that are not treated.
92+
See also [`istreated!`](@ref).
93+
"""
94+
function istreated end
95+
96+
"""
97+
istreated!(out::AbstractVector{Bool}, pr::TrendParallel, x::AbstractArray)
98+
99+
For each element in `x`,
100+
test whether it represents the treatment time
101+
for a group of units that are not treated and save the result in `out`.
102+
See also [`istreated`](@ref).
103+
"""
104+
function istreated! end
105+
87106
"""
88107
NeverTreatedParallel{C,S} <: TrendParallel{C,S}
89108
@@ -110,6 +129,20 @@ end
110129

111130
istreated(pr::NeverTreatedParallel, x) = !(x in pr.e)
112131

132+
function istreated!(out::AbstractVector{Bool}, pr::NeverTreatedParallel,
133+
x::AbstractArray{<:Union{ValidTimeType, Missing}})
134+
e = Set(pr.e)
135+
out .= .!(x .∈ Ref(e))
136+
end
137+
138+
function istreated!(out::AbstractVector{Bool}, pr::NeverTreatedParallel,
139+
x::ScaledArray{<:Union{ValidTimeType, Missing}})
140+
refs = refarray(x)
141+
invpool = invrefpool(x)
142+
e = Set(invpool[c] for c in pr.e if haskey(invpool, c))
143+
out .= .!(refs .∈ Ref(e))
144+
end
145+
113146
show(io::IO, pr::NeverTreatedParallel) =
114147
print(IOContext(io, :compact=>true), "NeverTreated{", pr.c, ",", pr.s, "}",
115148
length(pr.e)==1 ? string("(", pr.e[1], ")") : pr.e)
@@ -196,6 +229,20 @@ end
196229

197230
istreated(pr::NotYetTreatedParallel, x) = !(x in pr.e)
198231

232+
function istreated!(out::AbstractVector{Bool}, pr::NotYetTreatedParallel,
233+
x::AbstractArray{<:Union{ValidTimeType, Missing}})
234+
e = Set(pr.e)
235+
out .= .!(x .∈ Ref(e))
236+
end
237+
238+
function istreated!(out::AbstractVector{Bool}, pr::NotYetTreatedParallel,
239+
x::ScaledArray{<:Union{ValidTimeType, Missing}})
240+
refs = refarray(x)
241+
invpool = invrefpool(x)
242+
e = Set(invpool[c] for c in pr.e if haskey(invpool, c))
243+
out .= .!(refs .∈ Ref(e))
244+
end
245+
199246
show(io::IO, pr::NotYetTreatedParallel) =
200247
print(IOContext(io, :compact=>true), "NotYetTreated{", pr.c, ",", pr.s, "}",
201248
length(pr.e)==1 ? string("(", pr.e[1], ")") : pr.e)

0 commit comments

Comments
 (0)