Skip to content

Commit 8342cdd

Browse files
authored
Change philox state API usage to support XPU graph (#2085)
Switch philox_engine_inputs usage to philox_xpu_state per XPU graph request. This PR requires stock pytorch XPUGeneratorImpl [PR](pytorch/pytorch#163332) merged. Signed-off-by: Ma, Jing1 <[email protected]>
1 parent 369a0c9 commit 8342cdd

File tree

6 files changed

+65
-127
lines changed

6 files changed

+65
-127
lines changed

src/ATen/native/xpu/sycl/DistributionTemplates.h

Lines changed: 22 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <ATen/native/xpu/sycl/Philox4x32.h>
1111
#include <ATen/native/xpu/sycl/TensorApplyUtils.h>
1212
#include <ATen/ops/empty.h>
13+
#include <ATen/xpu/PhiloxXpuState.h>
1314
#include <comm/DeviceProperties.h>
1415
#include <comm/Runtime.h>
1516

@@ -23,50 +24,6 @@ using namespace at::xpu;
2324

2425
const uint32_t rand4_engine_calls = 4;
2526

26-
struct PhiloxState {
27-
PhiloxState() = default;
28-
// Called if graph capture is not underway
29-
PhiloxState(uint64_t seed, uint64_t offset) {
30-
seed_ = seed;
31-
offset_.val = offset;
32-
}
33-
// Called if graph capture is underway
34-
PhiloxState(
35-
uint64_t seed,
36-
int64_t* offset_extragraph,
37-
uint32_t offset_intragraph) {
38-
seed_ = seed;
39-
offset_.ptr = offset_extragraph;
40-
offset_intragraph_ = offset_intragraph;
41-
captured_ = true;
42-
}
43-
44-
union Payload {
45-
uint64_t val;
46-
int64_t* ptr;
47-
};
48-
49-
uint64_t seed_ = 0;
50-
Payload offset_;
51-
uint32_t offset_intragraph_ = 0;
52-
bool captured_ = false;
53-
};
54-
55-
inline std::tuple<uint64_t, uint64_t> philox_unpack(PhiloxState arg) {
56-
if (arg.captured_) {
57-
// static_cast avoids "warning: invalid narrowing conversion from "long" to
58-
// "unsigned long".
59-
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire
60-
// kernel. For most threads' reads it will hit in cache, so it shouldn't
61-
// hurt performance.
62-
return std::make_tuple(
63-
arg.seed_,
64-
static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
65-
} else {
66-
return std::make_tuple(arg.seed_, arg.offset_.val);
67-
}
68-
}
69-
7027
template <uint32_t UNROLL = rand4_engine_calls>
7128
inline std::tuple<uint64_t, uint32_t, uint32_t> calc_execution_policy(
7229
int64_t total_elements) {
@@ -96,7 +53,7 @@ struct DistributionElementwiseKernelFunctor {
9653
int num_groups = item.get_group_range(0);
9754
int idx = item.get_global_linear_id();
9855

99-
auto seeds = philox_unpack(philox_args_);
56+
auto seeds = at::xpu::philox::unpack(philox_args_);
10057
randStatePhilox4_32_10_t state;
10158
rand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);
10259

@@ -125,23 +82,21 @@ struct DistributionElementwiseKernelFunctor {
12582
}
12683
DistributionElementwiseKernelFunctor(
12784
int64_t numel,
128-
std::pair<uint64_t, uint64_t> rng_engine_inputs,
85+
PhiloxXpuState rng_engine_inputs,
12986
dist_t dist_func,
13087
transform_t transform_func,
13188
char* out_data,
13289
offset_calc_t offset_calc)
13390
: numel_(numel),
134-
philox_args_(PhiloxState(
135-
std::get<0>(rng_engine_inputs),
136-
std::get<1>(rng_engine_inputs))),
91+
philox_args_(rng_engine_inputs),
13792
dist_func_(dist_func),
13893
transform_func_(transform_func),
13994
out_data_(out_data),
14095
offset_calc_(offset_calc) {}
14196

14297
private:
14398
int64_t numel_;
144-
PhiloxState philox_args_;
99+
PhiloxXpuState philox_args_;
145100
dist_t dist_func_;
146101
transform_t transform_func_;
147102
char* out_data_;
@@ -171,11 +126,11 @@ void distribution_nullary_kernel(
171126
auto num_groups = std::get<1>(execution_policy);
172127
auto group_size = std::get<2>(execution_policy);
173128

174-
std::pair<uint64_t, uint64_t> rng_engine_inputs;
129+
PhiloxXpuState rng_engine_inputs;
175130
{
176131
// See Note [Acquire lock when using random generators]
177132
std::lock_guard<std::mutex> lock(gen->mutex_);
178-
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
133+
rng_engine_inputs = gen->philox_xpu_state(counter_offset);
179134
}
180135

181136
if (!iter.can_use_32bit_indexing()) {
@@ -234,7 +189,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
234189
int global_size = item.get_global_range(0);
235190
int global_idx = item.get_group(0) * group_size + item.get_local_id(0);
236191

237-
auto seeds = philox_unpack(philox_args_);
192+
auto seeds = at::xpu::philox::unpack(philox_args_);
238193
randStatePhilox4_32_10_t state;
239194
rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
240195

@@ -247,7 +202,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
247202
DistributionUnaryElementwiseKernelFunctor(
248203
int numel,
249204
const func_t f,
250-
PhiloxState philox_args,
205+
PhiloxXpuState philox_args,
251206
scalar1_t* output_data,
252207
const scalar2_t* input_data,
253208
inp_offset_calc_t input_offset_calculator,
@@ -263,7 +218,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
263218
private:
264219
int numel_;
265220
const func_t f_;
266-
PhiloxState philox_args_;
221+
PhiloxXpuState philox_args_;
267222
scalar1_t* output_data_;
268223
const scalar2_t* input_data_;
269224
inp_offset_calc_t inp_calc_;
@@ -273,7 +228,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
273228
template <typename scalar1_t, typename scalar2_t, typename func_t>
274229
void distribution_unary_kernel(
275230
TensorIterator& iter,
276-
PhiloxState philox_args,
231+
PhiloxXpuState philox_args,
277232
func_t f) {
278233
if (!iter.can_use_32bit_indexing()) {
279234
for (auto& sub_iter : iter.with_32bit_indexing()) {
@@ -340,7 +295,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
340295
int global_size = item.get_global_range(0);
341296
int global_idx = item.get_group(0) * group_size + item.get_local_id(0);
342297

343-
auto seeds = philox_unpack(philox_args_);
298+
auto seeds = at::xpu::philox::unpack(philox_args_);
344299

345300
randStatePhilox4_32_10_t state;
346301
rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
@@ -356,7 +311,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
356311
DistributionBinaryElementwiseKernelFunctor(
357312
int numel,
358313
func_t f,
359-
PhiloxState philox_args,
314+
PhiloxXpuState philox_args,
360315
output_t* output_data,
361316
const input_t_1* input_data_1,
362317
const input_t_2* input_data_2,
@@ -374,7 +329,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
374329
private:
375330
int64_t numel_;
376331
func_t f_;
377-
PhiloxState philox_args_;
332+
PhiloxXpuState philox_args_;
378333
output_t* out_data_;
379334
const input_t_1* inp_data_1_;
380335
const input_t_2* inp_data_2_;
@@ -385,7 +340,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
385340
template <typename func_t>
386341
void distribution_binary_kernel(
387342
TensorIteratorBase& iter,
388-
PhiloxState philox_args,
343+
PhiloxXpuState philox_args,
389344
const func_t& f) {
390345
static_assert(
391346
std::is_same<
@@ -762,7 +717,7 @@ struct BernoulliTensorApplyFunctor {
762717
const prob_t& p2,
763718
const prob_t& p3,
764719
const prob_t& p4) const {
765-
auto seeds = philox_unpack(philox_args_);
720+
auto seeds = at::xpu::philox::unpack(philox_args_);
766721
randStatePhilox4_32_10_t state;
767722
rand_init(
768723
std::get<0>(seeds),
@@ -792,20 +747,18 @@ struct BernoulliTensorApplyFunctor {
792747
}
793748
}
794749
}
795-
BernoulliTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
796-
: philox_args_(
797-
std::get<0>(rng_engine_inputs),
798-
std::get<1>(rng_engine_inputs)) {}
750+
BernoulliTensorApplyFunctor(PhiloxXpuState rng_engine_inputs)
751+
: philox_args_(rng_engine_inputs) {}
799752

800753
private:
801-
PhiloxState philox_args_;
754+
PhiloxXpuState philox_args_;
802755
};
803756

804757
template <typename scalar_t, typename prob_t>
805758
void bernoulli_tensor_kernel(
806759
TensorBase& ret,
807760
TensorBase& p,
808-
std::pair<uint64_t, uint64_t> rng_engine_inputs) {
761+
PhiloxXpuState rng_engine_inputs) {
809762
auto functor =
810763
BernoulliTensorApplyFunctor<scalar_t, prob_t>(rng_engine_inputs);
811764
// The template argument `4` below indicates that we want to operate on four
@@ -820,11 +773,11 @@ void bernoulli_tensor_kernel(
820773

821774
template <typename RNG>
822775
void bernoulli_kernel(const TensorBase& self, const TensorBase& p_, RNG gen) {
823-
std::pair<uint64_t, uint64_t> rng_engine_inputs;
776+
PhiloxXpuState rng_engine_inputs;
824777
{
825778
// See Note [Acquire lock when using random generators]
826779
std::lock_guard<std::mutex> lock(gen->mutex_);
827-
rng_engine_inputs = gen->philox_engine_inputs(10);
780+
rng_engine_inputs = gen->philox_xpu_state(10);
828781
}
829782
TORCH_CHECK(
830783
at::isFloatingType(p_.scalar_type()),

src/ATen/native/xpu/sycl/Distributions.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct PoissonTensorApplyFunctor {
1717
SYCL_KERNEL_ASSERT(
1818
lambda >= 0 &&
1919
"invalid Poisson rate, expected rate to be non-negative");
20-
auto seeds = philox_unpack(philox_args_);
20+
auto seeds = at::xpu::philox::unpack(philox_args_);
2121
randStatePhilox4_32_10_t state;
2222
rand_init(
2323
std::get<0>(seeds),
@@ -26,20 +26,18 @@ struct PoissonTensorApplyFunctor {
2626
&state);
2727
ret_val = static_cast<scalar_t>(rand_poisson(&state, lambda));
2828
}
29-
PoissonTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
30-
: philox_args_(
31-
std::get<0>(rng_engine_inputs),
32-
std::get<1>(rng_engine_inputs)) {}
29+
PoissonTensorApplyFunctor(PhiloxXpuState rng_engine_inputs)
30+
: philox_args_(rng_engine_inputs) {}
3331

3432
private:
35-
PhiloxState philox_args_;
33+
PhiloxXpuState philox_args_;
3634
};
3735

3836
template <typename scalar_t>
3937
void poisson_kernel(
4038
const at::TensorBase& ret,
4139
const at::TensorBase& lambda,
42-
std::pair<uint64_t, uint64_t> rng_engine_inputs) {
40+
PhiloxXpuState rng_engine_inputs) {
4341
auto functor = PoissonTensorApplyFunctor<scalar_t>(rng_engine_inputs);
4442
at::native::xpu::tensor_apply2<
4543
scalar_t,
@@ -55,11 +53,11 @@ void launch_poisson_kernel(
5553
const TensorBase& ret,
5654
const TensorBase& lambda,
5755
at::XPUGeneratorImpl* gen) {
58-
std::pair<uint64_t, uint64_t> rng_engine_inputs;
56+
PhiloxXpuState rng_engine_inputs;
5957
{
6058
// See Note [Acquire lock when using random generators]
6159
std::lock_guard<std::mutex> lock(gen->mutex_);
62-
rng_engine_inputs = gen->philox_engine_inputs(20);
60+
rng_engine_inputs = gen->philox_xpu_state(20);
6361
}
6462
AT_DISPATCH_FLOATING_TYPES_AND2(
6563
at::ScalarType::Half,
@@ -101,21 +99,19 @@ struct BinomialFunctor {
10199
};
102100

103101
template <typename scalar_t>
104-
void binomial_kernel(TensorIteratorBase& iter, PhiloxState philox_args) {
102+
void binomial_kernel(TensorIteratorBase& iter, PhiloxXpuState philox_args) {
105103
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
106104
BinomialFunctor<scalar_t, accscalar_t> f;
107105
at::native::xpu::distribution_binary_kernel(iter, philox_args, f);
108106
}
109107

110108
void launch_binomial_kernel(TensorIteratorBase& iter, XPUGeneratorImpl* gen) {
111-
std::pair<uint64_t, uint64_t> engine_inputs;
109+
PhiloxXpuState rng_engine_inputs;
112110
{
113111
// See Note [Acquire lock when using random generators]
114112
std::lock_guard<std::mutex> lock(gen->mutex_);
115-
engine_inputs = gen->philox_engine_inputs(42);
113+
rng_engine_inputs = gen->philox_xpu_state(42);
116114
}
117-
PhiloxState rng_engine_inputs(
118-
std::get<0>(engine_inputs), std::get<1>(engine_inputs));
119115
AT_DISPATCH_FLOATING_TYPES_AND2(
120116
at::ScalarType::Half,
121117
at::ScalarType::BFloat16,
@@ -130,7 +126,7 @@ struct GammaTensorApplyFunctor {
130126
sycl::nd_item<1> item,
131127
scalar_t& ret_val,
132128
const scalar_t& alpha) const {
133-
auto seeds = philox_unpack(philox_args_);
129+
auto seeds = at::xpu::philox::unpack(philox_args_);
134130
randStatePhilox4_32_10_t state;
135131
rand_init(
136132
std::get<0>(seeds),
@@ -155,18 +151,18 @@ struct GammaTensorApplyFunctor {
155151
ret_val = (min_value > sample) ? min_value : sample;
156152
}
157153

158-
GammaTensorApplyFunctor(PhiloxState philox_args)
154+
GammaTensorApplyFunctor(PhiloxXpuState philox_args)
159155
: philox_args_(philox_args) {}
160156

161157
private:
162-
PhiloxState philox_args_;
158+
PhiloxXpuState philox_args_;
163159
};
164160

165161
template <typename scalar_t>
166162
void gamma_kernel(
167163
const at::TensorBase& ret,
168164
const at::TensorBase& alpha,
169-
PhiloxState philox_args) {
165+
PhiloxXpuState philox_args) {
170166
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
171167
GammaTensorApplyFunctor<scalar_t, accscalar_t> functor(philox_args);
172168
at::native::xpu::tensor_apply2<
@@ -183,18 +179,17 @@ void launch_gamma_kernel(
183179
Tensor& ret,
184180
const Tensor& alpha,
185181
XPUGeneratorImpl* gen) {
186-
std::pair<uint64_t, uint64_t> engine_inputs;
182+
PhiloxXpuState rng_engine_inputs;
187183
{
188184
// See Note [Acquire lock when using random generators]
189185
std::lock_guard<std::mutex> lock(gen->mutex_);
190186
// Using a seed value of 10 for the Philox random engine initialization.
191187
// This seed was chosen to ensure consistent random number generation
192188
// behavior for this specific kernel. Modify with caution as it affects
193189
// reproducibility of results.
194-
engine_inputs = gen->philox_engine_inputs(10);
190+
rng_engine_inputs = gen->philox_xpu_state(10);
195191
}
196-
PhiloxState rng_engine_inputs(
197-
std::get<0>(engine_inputs), std::get<1>(engine_inputs));
192+
198193
AT_DISPATCH_FLOATING_TYPES_AND2(
199194
at::ScalarType::Half,
200195
at::ScalarType::BFloat16,

0 commit comments

Comments
 (0)