Skip to content

Commit be99339

Browse files
authored
Faster indexing with ReinterpretArrays (#37277)
There's at least one branch that can never actually be taken but which LLVM doesn't elide. The implementations here add specialized fast paths for the cases where the sizes of the old and new element types line up with one another, i.e. one size is an integer multiple of the other.
1 parent ddd08cd commit be99339

File tree

1 file changed

+68
-39
lines changed

1 file changed

+68
-39
lines changed

base/reinterpretarray.jl

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,32 @@ end
151151
GC.@preserve t s begin
152152
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
153153
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
154-
i = 1
155-
nbytes_copied = 0
156-
# This is a bit complicated to deal with partial elements
157-
# at both the start and the end. LLVM will fold as appropriate,
158-
# once it knows the data layout
159-
while nbytes_copied < sizeof(T)
160-
s[] = a.parent[ind_start + i, tailinds...]
161-
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
162-
_memcpy!(tptr + nbytes_copied, sptr + sidx, nb)
163-
nbytes_copied += nb
164-
sidx = 0
165-
i += 1
154+
# Optimizations that avoid branches
155+
if sizeof(T) % sizeof(S) == 0
156+
# T is bigger than S and contains an integer number of them
157+
n = sizeof(T) ÷ sizeof(S)
158+
for i = 1:n
159+
s[] = a.parent[ind_start + i, tailinds...]
160+
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
161+
end
162+
elseif sizeof(S) % sizeof(T) == 0
163+
# S is bigger than T and contains an integer number of them
164+
s[] = a.parent[ind_start + 1, tailinds...]
165+
_memcpy!(tptr, sptr + sidx, sizeof(T))
166+
else
167+
i = 1
168+
nbytes_copied = 0
169+
# This is a bit complicated to deal with partial elements
170+
# at both the start and the end. LLVM will fold as appropriate,
171+
# once it knows the data layout
172+
while nbytes_copied < sizeof(T)
173+
s[] = a.parent[ind_start + i, tailinds...]
174+
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
175+
_memcpy!(tptr + nbytes_copied, sptr + sidx, nb)
176+
nbytes_copied += nb
177+
sidx = 0
178+
i += 1
179+
end
166180
end
167181
end
168182
return t[]
@@ -200,33 +214,48 @@ end
200214
GC.@preserve t s begin
201215
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
202216
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
203-
nbytes_copied = 0
204-
i = 1
205-
# Deal with any partial elements at the start. We'll have to copy in the
206-
# element from the original array and overwrite the relevant parts
207-
if sidx != 0
208-
s[] = a.parent[ind_start + i, tailinds...]
209-
nb = min(sizeof(S) - sidx, sizeof(T))
210-
_memcpy!(sptr + sidx, tptr, nb)
211-
nbytes_copied += nb
212-
a.parent[ind_start + i, tailinds...] = s[]
213-
i += 1
214-
sidx = 0
215-
end
216-
# Deal with the main body of elements
217-
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
218-
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
219-
_memcpy!(sptr, tptr + nbytes_copied, nb)
220-
nbytes_copied += nb
221-
a.parent[ind_start + i, tailinds...] = s[]
222-
i += 1
223-
end
224-
# Deal with trailing partial elements
225-
if nbytes_copied < sizeof(T)
226-
s[] = a.parent[ind_start + i, tailinds...]
227-
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
228-
_memcpy!(sptr, tptr + nbytes_copied, nb)
229-
a.parent[ind_start + i, tailinds...] = s[]
217+
# Optimizations that avoid branches
218+
if sizeof(T) % sizeof(S) == 0
219+
# sizeof(T) is an integer multiple of sizeof(S) (possibly equal), so T spans n whole S elements
220+
n = sizeof(T) ÷ sizeof(S)
221+
for i = 0:n-1
222+
_memcpy!(sptr, tptr + i*sizeof(S), sizeof(S))
223+
a.parent[ind_start + i + 1, tailinds...] = s[]
224+
end
225+
elseif sizeof(S) % sizeof(T) == 0
226+
# S is bigger than T and contains an integer number of them
227+
s[] = a.parent[ind_start + 1, tailinds...]
228+
_memcpy!(sptr + sidx, tptr, sizeof(T))
229+
a.parent[ind_start + 1, tailinds...] = s[]
230+
else
231+
nbytes_copied = 0
232+
i = 1
233+
# Deal with any partial elements at the start. We'll have to copy in the
234+
# element from the original array and overwrite the relevant parts
235+
if sidx != 0
236+
s[] = a.parent[ind_start + i, tailinds...]
237+
nb = min((sizeof(S) - sidx) % UInt, sizeof(T) % UInt)
238+
_memcpy!(sptr + sidx, tptr, nb)
239+
nbytes_copied += nb
240+
a.parent[ind_start + i, tailinds...] = s[]
241+
i += 1
242+
sidx = 0
243+
end
244+
# Deal with the main body of elements
245+
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
246+
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
247+
_memcpy!(sptr, tptr + nbytes_copied, nb)
248+
nbytes_copied += nb
249+
a.parent[ind_start + i, tailinds...] = s[]
250+
i += 1
251+
end
252+
# Deal with trailing partial elements
253+
if nbytes_copied < sizeof(T)
254+
s[] = a.parent[ind_start + i, tailinds...]
255+
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
256+
_memcpy!(sptr, tptr + nbytes_copied, nb)
257+
a.parent[ind_start + i, tailinds...] = s[]
258+
end
230259
end
231260
end
232261
end

0 commit comments

Comments
 (0)