Commit 49a7369: Merge pull request #38 from mcabbott/index
Allow indices to appear on RHS
2 parents: ca91cca + 3f679d9

7 files changed: +190 −51 lines changed

README.md (+5 −2)

````diff
@@ -6,8 +6,9 @@
 [![Build Status](https://github.com/mcabbott/TensorCast.jl/workflows/CI/badge.svg)](https://github.com/mcabbott/TensorCast.jl/actions?query=workflow%3ACI)
 
 This package lets you work with multi-dimensional arrays in index notation,
-by defining a few macros. The first is `@cast`, which deals both with "casting" into
-new shapes (including going to and from an array-of-arrays) and with broadcasting:
+by defining a few macros.
+
+The first is `@cast`, which deals both with "casting" into new shapes (including going to and from an array-of-arrays) and with broadcasting:
 
 ```julia
 @cast A[row][col] := B[row, col]         # slice a matrix B into rows, also @cast A[r] := B[r,:]
@@ -16,6 +17,8 @@ new shapes (including going to and from an array-of-arrays) and with broadcastin
 
 @cast E[φ,γ] = F[φ]^2 * exp(G[γ])        # broadcast E .= F.^2 .* exp.(G') into existing E
 
+@cast _[i] := isodd(i) ? log(i) : V[i]   # broadcast a function of the index values
+
 @cast T[x,y,n] := outer(M[:,n])[x,y]     # generalised mapslices, vector -> matrix function
```
````
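The new ternary line in this hunk broadcasts a function of the index values themselves. As a plain-Python sketch of what that computes (not from the package; the vector `V` and its values are invented, and Julia's 1-based index values are written out explicitly):

```python
import math

# Hypothetical stand-in for the Julia vector V (values invented for illustration).
V = [10.0, 20.0, 30.0, 40.0]

# Analogue of `@cast _[i] := isodd(i) ? log(i) : V[i]`: the ternary is applied
# elementwise, with the index i taking the values 1:length(V), as in Julia.
out = [math.log(i) if i % 2 == 1 else V[i - 1] for i in range(1, len(V) + 1)]

# Odd positions hold log(i); even positions copy the matching entry of V.
```

Here `out[0]` is `log(1) == 0.0` and `out[1]` is Julia's `V[2]`.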

docs/src/basics.md (+42)

````diff
@@ -236,6 +236,42 @@ julia> colwise == Int.(reduce(hcat, vec.(list)))
 true
 ```
 
+## Index values
+
+Mostly the indices appearing in `@cast` expressions are just notation, to indicate what permutation / reshape is required.
+But if an index appears outside of square brackets, it is understood as a value, implemented by broadcasting over a range (appropriately permuted):
+
+```jldoctest
+julia> @cast rat[i,j] := 0 * M[i,j] + i // j
+3×4 Array{Rational{Int64},2}:
+ 1//1  1//2  1//3  1//4
+ 2//1  1//1  2//3  1//2
+ 3//1  3//2  1//1  3//4
+
+julia> rat == @cast _[i,j] := axes(M,1)[i] // axes(M,2)[j]   # more verbose @cast
+true
+
+julia> rat == axes(M,1) .// transpose(axes(M,2))   # what it aims to generate
+true
+
+julia> @cast _[r,c] := r + 10c  (r in 1:2, c in 1:7)   # no array for range inference
+2×7 Array{Int64,2}:
+ 11  21  31  41  51  61  71
+ 12  22  32  42  52  62  72
+```
+
+Writing `$i` will interpolate the variable `i`, distinct from the index `i`:
+
+```jldoctest
+julia> i, k = 10, 100;
+
+julia> @cast ones(3)[i] = i + $i + k
+3-element Vector{Float64}:
+ 111.0
+ 112.0
+ 113.0
+```
+
 ## Reverse & shuffle
 
 A minus in front of an index will reverse that direction, and a tilde will shuffle it.
@@ -248,6 +284,9 @@ julia> @cast M2[i,j] := M[i,-j]
  11   8  5   2
  12   9  6   3
 
+julia> all(M2[i,j] == M[i, end+begin-j] for i in 1:3, j in 1:4)
+true
+
 julia> using Random; Random.seed!(42);
 
 julia> @cast M3[i,j] |= M[i,~j]
@@ -257,6 +296,9 @@ julia> @cast M3[i,j] |= M[i,~j]
   9  6  3  12
 ```
 
+Note that the minus is a slight deviation from the rule that left equals right for all indices;
+it should really be `M[i, end+1-j]`.
+
 ## Primes `'`
 
 Acting on indices, `A[i']` is normalised to `A[i′]`, unicode \prime (which looks identical in some fonts).
````
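The `end+begin-j` identity checked in the new doctest has a direct 0-based analogue. A plain-Python sketch (not from the docs) using the same 3×4 matrix, where Julia's `end+1-j` becomes `cols-1-j`:

```python
# The 3×4 matrix M from the docs, with M[i][j] reading down columns.
M = [[1, 4, 7, 10],
     [2, 5, 8, 11],
     [3, 6, 9, 12]]

# Analogue of `@cast M2[i,j] := M[i,-j]`: reverse the j direction of each row.
M2 = [row[::-1] for row in M]

# Julia's M2[i,j] == M[i, end+begin-j] becomes M2[i][j] == M[i][cols-1-j] here.
cols = len(M[0])
ok = all(M2[i][j] == M[i][cols - 1 - j] for i in range(3) for j in range(cols))
```

The middle row of `M2` is `[11, 8, 5, 2]`, matching the doctest output above.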

docs/src/index.md (+4)

````diff
@@ -19,6 +19,10 @@ Version 0.4 has significant changes:
 - It uses [LazyStack.jl](https://github.com/mcabbott/LazyStack.jl) to handle slices, simplifying earlier code. This is lazier by default; write `@cast A[i,k] := log(B[k][i]) lazy=false` (a new keyword option) to glue into an `Array` before broadcasting.
 - It uses [TransmuteDims.jl](https://github.com/mcabbott/TransmuteDims.jl) to handle all permutations & many reshapes. This is lazier by default -- the earlier code sometimes copied to avoid reshaping a `PermutedDimsArray`. This isn't always faster, though, and can be disabled by `lazy=false`.
 
+New features in 0.4:
+- Indices can appear outside of indexing: `@cast A[i,j] = i+j` translates to `A .= axes(A,1) .+ axes(A,2)'`
+- The ternary operator `? :` can appear on the right, and will be broadcast correctly.
+
 ## Pages
 
 1. Use of `@cast` for broadcasting, dealing with arrays of arrays, and generalising `mapslices`
````
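The translation quoted in the new feature list, `A .= axes(A,1) .+ axes(A,2)'`, is an outer sum of the two index ranges. A plain-Python sketch (sizes invented, not from the package), writing Julia's 1-based index values explicitly:

```python
# What `@cast A[i,j] = i + j` computes: every entry is the sum of its two
# (1-based) index values, i.e. the outer sum of axes(A,1) and axes(A,2).
rows, cols = 2, 3
A = [[i + j for j in range(1, cols + 1)] for i in range(1, rows + 1)]

# First row corresponds to i == 1: entries 1+1, 1+2, 1+3.
```

In Julia the same shape comes from broadcasting a column range against a transposed row range, with no explicit loop.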

docs/src/reduce.md (+9 −10)

````diff
@@ -17,7 +17,8 @@ julia> @reduce S[i] := sum(j) M[i,j] + 1000
 
 julia> @pretty @reduce S[i] := sum(j) M[i,j] + 1000
 begin
-    S = transmute(sum(@__dot__(M + 1000), dims = 2), (1,))
+    ndims(M) == 2 || throw(ArgumentError("expected a 2-tensor M[i, j]"))
+    S = dropdims(sum(@__dot__(M + 1000), dims = 2), dims = 2)
 end
 ```
@@ -26,9 +27,6 @@ Note that:
 * The sum applies to the whole right side (including the 1000 here).
 * And the summed dimensions are always dropped (unless you explicitly say `S[i,_] := ...`).
 
-Here `transmute(..., (1,))` is equivalent to `dropdims(..., dims=2)`,
-keeping just the first dimension by reshaping.
-
 ## Not just `sum`
 
 You may use any reduction function which understands the keyword `dims=...`, as `sum` does.
@@ -112,10 +110,10 @@ There is no need to name the intermediate array, here `termite[x]`, but you must
 ```julia-repl
 julia> @pretty @reduce sum(x,θ) L[x,θ] * p[θ] * log(L[x,θ] / @reduce _[x] := sum(θ′) L[x,θ′] * p[θ′])
 begin
-    local fish = transmute(p, (nothing, 1))
-    termite = transmute(sum(@__dot__(L * fish), dims = 2), (1,))
-    local wallaby = transmute(p, (nothing, 1))
-    rat = sum(@__dot__(L * wallaby * log(L / termite)))
+    ndims(L) == 2 || error()   # etc, some checks
+    local goshawk = transmute(p, (nothing, 1))
+    sandpiper = dropdims(sum(@__dot__(L * goshawk), dims = 2), dims = 2)   # inner sum
+    bison = sum(@__dot__(L * goshawk * log(L / sandpiper)))
 end
 ```
@@ -131,8 +129,9 @@ before summing over one index:
 ```julia-repl
 julia> @pretty @reduce R[i,k] := sum(j) M[i,j] * N[j,k]
 begin
-    local fish = transmute(N, (nothing, 1, 2))   # fish = reshape(N, 1, size(N)...)
-    R = transmute(sum(@__dot__(M * fish), dims = 2), (1, 3))   # R = dropdims(sum(...), dims=2)
+    size(M, 2) == size(N, 1) || error()   # etc, some checks
+    local fish = transmute(N, (nothing, 1, 2))   # fish = reshape(N, 1, size(N)...)
+    R = dropdims(sum(@__dot__(M * fish), dims = 2), dims = 2)
 end
 ```
````
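The last rewritten hunk is the classic broadcast-then-sum pattern: reshape `N` to line up the shared index `j`, multiply elementwise, then sum `j` away, which is exactly a matrix product. A plain-Python sketch of the same pattern (sample matrices invented):

```python
# Analogue of `@reduce R[i,k] := sum(j) M[i,j] * N[j,k]`:
# for each (i, k), sum the products over the shared index j.
M = [[1, 2],
     [3, 4]]   # 2×2
N = [[5, 6],
     [7, 8]]   # 2×2

I, J, K = len(M), len(M[0]), len(N[0])
assert J == len(N)   # the size check @pretty shows: size(M, 2) == size(N, 1)

R = [[sum(M[i][j] * N[j][k] for j in range(J)) for k in range(K)]
     for i in range(I)]

# R is the matrix product M * N; the sum over j plays the role of
# dropdims(sum(...; dims = 2), dims = 2) in the generated Julia code.
```

For these sample matrices `R == [[19, 22], [43, 50]]`.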
