
Commit 5784970 ("fix: code")
1 parent 6bc531b

2 files changed: 85 additions, 31 deletions

docs/Project.toml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
+Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [sources]

docs/src/tutorials/sharding.md

Lines changed: 84 additions & 31 deletions
@@ -17,7 +17,8 @@
 
 ## Basics
 
-Sharding is one mechanism supported within Reactant that tries to make it easy to program for multiple devices (including [multiple nodes](@ref distributed)).
+Sharding is one mechanism supported within Reactant that tries to make it easy to program
+for multiple devices (including [multiple nodes](@ref distributed)).
 
 ```@example sharding_tutorial
 using Reactant
@@ -26,32 +27,52 @@ using Reactant
 Reactant.devices()
 ```
 
-Sharding provides Reactant users a [PGAS (parallel-global address space)](https://en.wikipedia.org/wiki/Partitioned_global_address_space) programming model. Let's understand what this means through example.
+Sharding provides Reactant users a
+[PGAS (parallel-global address space)](https://en.wikipedia.org/wiki/Partitioned_global_address_space)
+programming model. Let's understand what this means through example.
 
-Suppose we have a function that takes a large input array and computes sin for all elements of the array.
+Suppose we have a function that takes a large input array and computes sin for all elements
+of the array.
 
 ```@example sharding_tutorial
 function big_sin(data)
-    data .= sin(data)
+    data .= sin.(data)
     return nothing
 end
 
 N = 1600
-x = Reactant.to_array(reshape(collect(1:N), 40, 40))
+x = Reactant.to_rarray(reshape(collect(Float32, 1:N), 40, 40))
 
 compiled_big_sin = @compile big_sin(x)
 
 compiled_big_sin(x)
 ```

-This successfully allocates the array `x` on one device, and executes it on the same device. However, suppose we want to execute this computation on multiple devices. Perhaps this is because the size of our inputs (`N`) is too large to fit on a single device. Or alternatively the function we execute is computationally expensive and we want to leverage the computing power of multiple devices.
-
-Unlike more explicit communication libraries like MPI, the sharding model used by Reactant aims to let you execute a program on multiple devices without significant modifications to the single-device program. In particular, you do not need to write explicit communication calls (e.g. `MPI.Send` or `MPI.Recv`). Instead you write your program as if it executes on a very large single-node and Reactant will automatically determine how to subdivide the data, computation, and required communication.
-
-When using sharding, the one thing you need to change about your code is how arrays are allocated. In particular, you need to specify how the array is partitioned amongst available devices. For example, suppose you are on a machine with 4 GPUs. In the example above, we computed `sin` for all elements of a 40x40 grid. One partitioning we could select is to have it partitioned along the first axis, such that each GPU has a slice of 10x40 elements. We could accomplish this as follows. No change is required to the original function. However, the compiled function is specific to the sharding so we need to compile a new version for our sharded array.
+This successfully allocates the array `x` on one device, and executes it on the same device.
+However, suppose we want to execute this computation on multiple devices. Perhaps this is
+because the size of our inputs (`N`) is too large to fit on a single device. Or
+alternatively the function we execute is computationally expensive and we want to leverage
+the computing power of multiple devices.
+
+Unlike more explicit communication libraries like MPI, the sharding model used by Reactant
+aims to let you execute a program on multiple devices without significant modifications to
+the single-device program. In particular, you do not need to write explicit communication
+calls (e.g. `MPI.Send` or `MPI.Recv`). Instead you write your program as if it executes on a
+very large single node and Reactant will automatically determine how to subdivide the data,
+computation, and required communication.
+
+When using sharding, the one thing you need to change about your code is how arrays are
+allocated. In particular, you need to specify how the array is partitioned amongst available
+devices. For example, suppose you are on a machine with 4 GPUs. In the example above, we
+computed `sin` for all elements of a 40x40 grid. One partitioning we could select is to have
+it partitioned along the first axis, such that each GPU has a slice of 10x40 elements. We
+could accomplish this as follows. No change is required to the original function. However,
+the compiled function is specific to the sharding so we need to compile a new version for
+our sharded array.
 
 ```@example sharding_tutorial
 N = 1600
+
 x_sharded_first = Reactant.to_array(
     reshape(collect(1:N), 40, 40),
     sharding=Sharding.NamedSharding(
@@ -65,7 +86,10 @@ compiled_big_sin_sharded_first = @compile big_sin(x_sharded_first)
 compiled_big_sin_sharded_first(x_sharded_first)
 ```
 
-Alternatively, we can parition the data in a different form. In particular, we could subdivide the data on both axes. As a result each GPU would have a slice of 20x20 elements. Again no change is required to the original function, but we would change the allocation as follows:
+Alternatively, we can partition the data in a different form. In particular, we could
+subdivide the data on both axes. As a result each GPU would have a slice of 20x20 elements.
+Again no change is required to the original function, but we would change the allocation as
+follows:
 
 ```@example sharding_tutorial
 N = 1600
@@ -82,37 +106,66 @@ compiled_big_sin_sharded_both = @compile big_sin(x_sharded_both)
 compiled_big_sin_sharded_both(x_sharded_both)
 ```
 
-Sharding in reactant requires you to specify how the data is sharded across devices on a mesh. We start by specifying the mesh [`Sharding.Mesh`](@ref) which is a collection of the devices reshaped into an N-D grid. Additionally, we can specify names for each axis of the mesh, that are then referenced when specifying how the data is sharded.
+Sharding in Reactant requires you to specify how the data is sharded across devices on a
+mesh. We start by specifying the mesh [`Sharding.Mesh`](@ref), which is a collection of the
+devices reshaped into an N-D grid. Additionally, we can specify names for each axis of the
+mesh, which are then referenced when specifying how the data is sharded.
 
-1. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))`: Creates a 2D grid of 4 devices arranged in a 2x2 grid. The first axis is named `:x` and the second axis is named `:y`.
-2. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 4, 1), (:x, :y))`: Creates a 2D grid of 4 devices arranged in a 4x1 grid. The first axis is named `:x` and the second axis is named `:y`.
+1. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))`: Creates a 2D grid of 4
+   devices arranged in a 2x2 grid. The first axis is named `:x` and the second axis is named
+   `:y`.
+2. `Sharding.Mesh(reshape(Reactant.devices()[1:4], 4, 1), (:x, :y))`: Creates a 2D grid of 4
+   devices arranged in a 4x1 grid. The first axis is named `:x` and the second axis is
+   named `:y`.
 
-Given the mesh, we will specify how the data is sharded across the devices.
+Given the mesh, we will specify how the data is sharded across the devices.
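
As a rough sketch, the first of these meshes could be paired with a named sharding to reproduce the 20x20-per-device tiling used above. This assumes at least 4 visible devices; the variable names and the exact `NamedSharding` arguments here are illustrative, not taken from the tutorial's full example:

```julia
# Illustrative sketch only: assumes at least 4 devices are visible to Reactant.
mesh = Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))

# Shard the first array axis along mesh axis :x and the second along :y, so each
# device would hold a 20x20 tile of the 40x40 array.
x_tiled = Reactant.to_rarray(
    reshape(collect(Float32, 1:1600), 40, 40);
    sharding=Sharding.NamedSharding(mesh, (:x, :y)),
)
```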
 
 <!-- TODO describe how arrays are the "global data arrays, even though data is itself only stored on relevant device and computation is performed only devices with the required data (effectively showing under the hood how execution occurs) -->
 
 <!-- TODO make a simple conway's game of life, or heat equation using sharding simulation example to show how a ``typical MPI'' simulation can be written using sharding. -->
 
 ## Simple 1-Dimensional Heat Equation
 
-So far we chose a function which was perfectly parallelizable (e.g. each elemnt of the array only accesses its own data). Let's consider a more realistic example where an updated element requires data from its neighbors. In the distributed case, this requires communicating the data along the boundaries.
+So far we chose a function which was perfectly parallelizable (e.g. each element of the
+array only accesses its own data). Let's consider a more realistic example where an updated
+element requires data from its neighbors. In the distributed case, this requires
+communicating the data along the boundaries.
 
-In particular, let's implement a one-dimensional [heat equation](https://en.wikipedia.org/wiki/Heat_equation) simulation. In this code you initialize the temperature of all points of the simulation and over time the code will simulate how the heat is transfered across space. In particular points of high temperature will transfer energy to points of low energy.
+In particular, let's implement a one-dimensional
+[heat equation](https://en.wikipedia.org/wiki/Heat_equation) simulation. In this code you
+initialize the temperature of all points of the simulation and over time the code will
+simulate how the heat is transferred across space. In particular, points of high temperature
+will transfer energy to points of lower temperature.
 
 As an example, here is a visualization of a 2-dimensional heat equation:
 
 ![Heat Equation Animation](https://upload.wikimedia.org/wikipedia/commons/a/a9/Heat_eqn.gif)
 
-TODO we should animate the above -- and even more ideally have one we generate ourselves.
+<!-- TODO we should animate the above -- and even more ideally have one we generate ourselves. -->
 
-To keep things simple, let's implement a 1-dimensional heat equation here. We start off with an array for the temperature at each point, and will compute the next version of the temperatures according to the equation `x[i, t] = 0.x * [i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]`.
+To keep things simple, let's implement a 1-dimensional heat equation here. We start off with
+an array for the temperature at each point, and will compute the next version of the
+temperatures according to the equation
+`x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]`.
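
Spelled out on a plain Julia array, one step of this update might look like the following sketch (the function name is ours, and the two boundary points are simply left untouched):

```julia
# One time step of the 1-D heat equation on an ordinary array. The right-hand side
# slices are copies, so every interior point is updated from the previous values.
function heat_equation_step!(x)
    x[2:end-1] .= 0.5 .* x[2:end-1] .+ 0.25 .* x[1:end-2] .+ 0.25 .* x[3:end]
    return nothing
end
```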
 
-Let's consider how this can be implemented with explicit MPI communication. Each node will contain a subset of the total data. For example, if we simulate with 100 points, and have 4 devices, each device will contain 25 data points. We're going to allocate some extra room at each end of the buffer to store the ``halo'', or the data at the boundary. Each time step that we take will first copy in the data from its neighbors into the halo via an explicit MPI send and recv call. We'll then compute the updated data for our slice of the data.
+Let's consider how this can be implemented with explicit MPI communication. Each node will
+contain a subset of the total data. For example, if we simulate with 100 points, and have 4
+devices, each device will contain 25 data points. We're going to allocate some extra room at
+each end of the buffer to store the "halo", i.e. the data at the boundary. Each time step
+will first copy the data from our neighbors into the halo via explicit MPI send and recv
+calls. We'll then compute the updated data for our slice of the data.
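
To make the halo bookkeeping concrete, here is a toy single-process illustration in plain Julia (no MPI; the names and the periodic boundary are our own choices, and in a real distributed version each copy below would become a send/recv pair between neighboring ranks):

```julia
# Toy halo exchange: 100 points split across 4 "ranks", each chunk padded with one
# halo cell on either side.
nranks, npoints = 4, 100
chunklen = npoints ÷ nranks
chunks = [zeros(Float32, chunklen + 2) for _ in 1:nranks]

function exchange_halos!(chunks)
    n = length(chunks)
    for r in 1:n
        left  = chunks[mod1(r - 1, n)]   # neighbor to the left (periodic)
        right = chunks[mod1(r + 1, n)]   # neighbor to the right (periodic)
        chunks[r][1]   = left[end-1]     # left neighbor's last interior cell
        chunks[r][end] = right[2]        # right neighbor's first interior cell
    end
    return nothing
end

exchange_halos!(chunks)
```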
 
-With sharding, things are a bit more simple. We can write the code as if we only had one device. No explicit send or recv's are necessary
-as they will be added automatically by Reactant when it deduces they are needed. In fact, Reactant will attempt to optimize the placement of the communicatinos to minimize total runtime. While Reactant tries to do a good job (which could be faster than an initial implementation -- especially for complex codebases), an expert may be able to find a better placement of the communication.
+With sharding, things are a bit simpler. We can write the code as if we only had one
+device. No explicit sends or recvs are necessary, as they will be added automatically by
+Reactant when it deduces they are needed. In fact, Reactant will attempt to optimize the
+placement of the communications to minimize total runtime. While Reactant tries to do a
+good job (which could be faster than an initial implementation -- especially for complex
+codebases), an expert may be able to find a better placement of the communication.
 
-The only difference for the sharded code again occurs during allocation. Here we explicitly specify that we want to subdivide the initial grid of 100 amongst all devices. Analagously if we had 4 devices to work with, each device would have 25 elements in its local storage. From the user's standpoint, however, all arrays give access to the entire dataset.
+The only difference for the sharded code again occurs during allocation. Here we explicitly
+specify that we want to subdivide the initial grid of 100 amongst all devices. Analogously,
+if we had 4 devices to work with, each device would have 25 elements in its local storage.
+From the user's standpoint, however, all arrays give access to the entire dataset.
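
The allocation could look roughly like the sketch below; the full, runnable versions (and the exact axis names and sharding spec they use) are in the code group that follows:

```julia
# Rough sketch: shard a 100-point grid across every available device along one mesh axis.
ndevices = length(Reactant.devices())
mesh = Sharding.Mesh(reshape(Reactant.devices(), ndevices, 1), (:x, :y))
data = Reactant.to_rarray(
    zeros(Float32, 100);
    sharding=Sharding.NamedSharding(mesh, (:x,)),
)
```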
 
 ::: code-group
 
@@ -182,7 +235,7 @@ data = Reactant.to_rarray(
 )
 
 function simulate(data, time_steps)
-    @traced for i in 1:time_steps
+    @trace for i in 1:time_steps
         one_dim_heat_equation_time_step_sharded!(data)
     end
 end
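
As with the earlier examples, the traced loop is compiled once and then executed; a hypothetical invocation (the step count here is arbitrary) looks like:

```julia
# Hypothetical usage, following the same pattern as the earlier @compile examples.
compiled_simulate = @compile simulate(data, 100)
compiled_simulate(data, 100)
```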
@@ -192,6 +245,13 @@ end
 
 :::
 
+One nice feature of Reactant's handling of multiple devices is that you don't need to
+specify how the data is transferred. All devices from all nodes are available for use by
+Reactant. Given the topology of the devices, Reactant will automatically determine the right
+type of communication primitive to use to send data between the relevant nodes. For example,
+between GPUs on the same host Reactant may use the faster `cudaMemcpy`, whereas for GPUs on
+different nodes it will copy the data through the network via NCCL, and for CPUs on
+different machines it can use MPI to send the data between computers.
+
+The fact that you don't need to specify how the communication occurs enables code written
+with Reactant to be run on a different topology (e.g. moving from a single multi-GPU host to
+several nodes) without modification.
+
 ## Devices
 
 You can query the available devices that Reactant can access as follows using
@@ -211,13 +271,6 @@ Reactant.addressable_devices()
 
 You can inspect the type of the device, as well as its properties.
 
-MPI to send the data. between computers When using GPUs on different devices, one needs to copy the data through the network via NCCL instead of the `cuda.
-
-All devices from all nodes are available for use by Reactant. Given the topology of the devices, Reactant will automatically determine the right type of communication primitive to use to send data between the relevant nodes. For example, between GPUs on the same host Reactant may use the faster `cudaMemcpy` whereas for GPUs on different nodes Reactant will use NCCL.
-
-The fact that you doesn't need to specify how the communication is occuring enables code written with Reactant to be run on a different topology (e.g. moving fro
-One nice feature about how Reactant's handling of multiple devices is that you don't need to specify how the data is transfered. For example, when using multiple GPUs on the same host it might be efficient to copy data using a `cudaMemcpy` to transfer between devices directly. When using CPUs on multiple different nodes, one can use
-
 ## Generating Distributed Data by Concatenating Local-Worker Data
 
 ## Handling Replicated Tensors
