@@ -145,15 +145,54 @@ function _broadcast_mul_ldiv(::Tuple{ScalarLayout,ApplyLayout{typeof(*)}}, A, B)
145
145
a * (A \ b)
146
146
end
147
147
148
- _broadcast_mul_ldiv (:: Tuple{ScalarLayout,AbstractBasisLayout} , A, B) =
149
- _broadcast_mul_ldiv ((ScalarLayout (),UnknownLayout ()), A, B)
148
+ _broadcast_mul_ldiv (:: Tuple{ScalarLayout,AbstractBasisLayout} , A, B) = _broadcast_mul_ldiv ((ScalarLayout (),UnknownLayout ()), A, B)
150
149
_broadcast_mul_ldiv (_, A, B) = copy (Ldiv {typeof(MemoryLayout(A)),UnknownLayout} (A,B))
151
150
152
151
copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}} ) = _broadcast_mul_ldiv (map (MemoryLayout,arguments (L. B)), L. A, L. B)
153
152
copy (L:: Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)}} ) = _broadcast_mul_ldiv (map (MemoryLayout,arguments (L. B)), L. A, L. B)
154
153
155
154
156
155
156
+ # multiplication operators, reexpand in basis A
157
+ @inline function _broadcast_mul_adj (:: Tuple{Any,AbstractBasisLayout} , Ac, B)
158
+ a,b = arguments (B)
159
+ @assert a isa AbstractQuasiVector # Only works for vec .* mat
160
+ A = Ac'
161
+ ab = (A * (A \ a)) .* b # broadcasted should be overloaded
162
+ MemoryLayout (ab) isa BroadcastLayout && return Ac* transform_ldiv (A, ab)
163
+ Ac* ab
164
+ end
165
+
166
+ @inline function _broadcast_mul_adj (:: Tuple{Any,ApplyLayout{typeof(*)}} , Ac, B)
167
+ a,b = arguments (B)
168
+ @assert a isa AbstractQuasiVector # Only works for vec .* mat
169
+ args = arguments (* , b)
170
+ * (Ac* (a .* first (args)), tail (args)... )
171
+ end
172
+
173
+
174
+ function _broadcast_mul_adj (:: Tuple{ScalarLayout,Any} , Ac, B)
175
+ a,b = arguments (B)
176
+ a * (Ac* b)
177
+ end
178
+
179
+ function _broadcast_mul_adj (:: Tuple{ScalarLayout,ApplyLayout{typeof(*)}} , Ac, B)
180
+ a,b = arguments (B)
181
+ a * (Ac* b)
182
+ end
183
+
184
+ _broadcast_mul_adj (:: Tuple{ScalarLayout,AbstractBasisLayout} , A, B) = _broadcast_mul_adj ((ScalarLayout (),UnknownLayout ()), A, B)
185
+ _broadcast_mul_adj (_, A, B) = copy (Mul {typeof(MemoryLayout(A)),UnknownLayout} (A,B))
186
+
187
+ _broadcast_mul_adj_simplifiable (_, :: AbstractBasisLayout ) = Val (true )
188
+ _broadcast_mul_adj_simplifiable (_, :: ApplyLayout{typeof(*)} ) = Val (true )
189
+ _broadcast_mul_adj_simplifiable (:: ScalarLayout , _) = Val (true )
190
+ _broadcast_mul_adj_simplifiable (:: ScalarLayout , :: ApplyLayout{typeof(*)} ) = Val (true )
191
+ _broadcast_mul_adj_simplifiable (:: ScalarLayout , :: AbstractBasisLayout ) = Val (true )
192
+ _broadcast_mul_adj_simplifiable (_, _) = Val (false )
193
+
194
+ simplifiable (L:: Mul{<:AdjointBasisLayout,BroadcastLayout{typeof(*)}} ) = _broadcast_mul_adj_simplifiable (map (MemoryLayout,arguments (L. B))... )
195
+ copy (L:: Mul{<:AdjointBasisLayout,BroadcastLayout{typeof(*)}} ) = _broadcast_mul_adj (map (MemoryLayout,arguments (L. B)), L. A, L. B)
157
196
158
197
159
198
"""
@@ -651,6 +690,7 @@ diff_layout(::ExpansionLayout, A, dims...) = diff_layout(ApplyLayout{typeof(*)}(
651
690
# ###
652
691
653
692
simplifiable (:: Mul{<:AdjointBasisLayout, <:AbstractBasisLayout} ) = Val (true )
693
+ @inline simplifiable (L:: Mul{<:AdjointBasisLayout,ApplyLayout{typeof(*)}} ) = simplifiable (* , L. A, first (arguments (* , L. B)))
654
694
function copy (M:: Mul{<:AdjointBasisLayout, <:AbstractBasisLayout} )
655
695
A = (M. A)'
656
696
A == M. B && return grammatrix (A)
0 commit comments