@@ -34,24 +34,14 @@ Base.length(a::Tensorizer) = mapreduce(sum,*,a.blocks)
3434
3535
3636function start (a:: TrivialTensorizer{d} ) where {d}
37- if d== 2
38- return invoke (start, Tuple{Tensorizer2D}, a)
39- else
40- # ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
41- block = SVector {d} (Ones {Int} (d))
42- return (block, (0 , nothing , nothing )), (0 ,length (a))
43- end
37+ # ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
38+ block = SVector {d} (Ones {Int} (d))
39+ return (block, (0 , nothing , nothing )), (0 ,length (a))
4440end
4541
4642function next (a:: TrivialTensorizer{d} , iterator_tuple) where {d}
47-
48- if d== 2
49- return invoke (next, Tuple{Tensorizer2D, Tuple}, a, iterator_tuple)
50- end
51-
5243 (block, (j, iterator, iter_state)), (i,tot) = iterator_tuple
5344
54-
5545 @inline function check_block_finished (j, iterator, block)
5646 if iterator === nothing
5747 return true
@@ -82,19 +72,22 @@ function next(a::TrivialTensorizer{d}, iterator_tuple) where {d}
8272end
8373
8474
85- function done (a:: TrivialTensorizer{d} , iterator_tuple) where {d}
86- if d== 2
87- return invoke (done, Tuple{Tensorizer2D, Tuple}, a, iterator_tuple)
88- end
89- (_, (i,tot)) = iterator_tuple
75+ function done (a:: TrivialTensorizer , iterator_tuple)
76+ i, tot = last (iterator_tuple)
9077 return i ≥ tot
9178end
9279
9380
9481# (blockrow,blockcol), (subrow,subcol), (rowshift,colshift), (numblockrows,numblockcols), (itemssofar, length)
95- start (a:: Tensorizer2D{AA, BB} ) where {AA,BB} = (1 ,1 ), (1 ,1 ), (0 ,0 ), (a. blocks[1 ][1 ],a. blocks[2 ][1 ]), (0 ,length (a))
82+ start (a:: Tensorizer2D ) = _start (a:: Tensorizer2D )
83+ start (a:: TrivialTensorizer{2} ) = _start (a:: Tensorizer2D )
84+
85+ _start (a) = (1 ,1 ), (1 ,1 ), (0 ,0 ), (a. blocks[1 ][1 ],a. blocks[2 ][1 ]), (0 ,length (a))
86+
87+ next (a:: Tensorizer2D , state) = _next (a, state)
88+ next (a:: TrivialTensorizer{2} , state) = _next (a, state)
9689
97- function next (a :: Tensorizer2D{AA, BB} , ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))) where {AA,BB}
90+ function _next (a, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot)))
9891 ret = k+ rsh,j+ csh
9992 if k== n && j== m # end of block
10093 if J == 1 || K == length (a. blocks[1 ]) # end of new block
@@ -118,8 +111,10 @@ function next(a::Tensorizer2D{AA, BB}, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))
118111 ret, ((K,J), (k,j), (rsh,csh), (n,m), (i+ 1 ,tot))
119112end
120113
114+ done (a:: Tensorizer2D , state) = _done (a, state)
115+ done (a:: TrivialTensorizer{2} , state) = _done (a, state)
121116
122- done (a :: Tensorizer2D , ((K,J), (k,j), (rsh,csh), (n,m) , (i,tot))) = i ≥ tot
117+ _done (a , (_, _, _, _ , (i,tot))) = i ≥ tot
123118
124119iterate (a:: Tensorizer ) = next (a, start (a))
125120function iterate (a:: Tensorizer , st)
580575function totensor (it:: Tensorizer ,M:: AbstractVector )
581576 n= length (M)
582577 B= block (it,n)
583- ds = dimensions (it)
584578
585579 # ret=zeros(eltype(M),[sum(it.blocks[i][1:min(B.n[1],length(it.blocks[i]))]) for i=1:length(it.blocks)]...)
586580
0 commit comments