1010#include < ATen/native/xpu/sycl/Philox4x32.h>
1111#include < ATen/native/xpu/sycl/TensorApplyUtils.h>
1212#include < ATen/ops/empty.h>
13+ #include < ATen/xpu/PhiloxXpuState.h>
1314#include < comm/DeviceProperties.h>
1415#include < comm/Runtime.h>
1516
@@ -23,50 +24,6 @@ using namespace at::xpu;
2324
2425const uint32_t rand4_engine_calls = 4 ;
2526
26- struct PhiloxState {
27- PhiloxState () = default ;
28- // Called if graph capture is not underway
29- PhiloxState (uint64_t seed, uint64_t offset) {
30- seed_ = seed;
31- offset_.val = offset;
32- }
33- // Called if graph capture is underway
34- PhiloxState (
35- uint64_t seed,
36- int64_t * offset_extragraph,
37- uint32_t offset_intragraph) {
38- seed_ = seed;
39- offset_.ptr = offset_extragraph;
40- offset_intragraph_ = offset_intragraph;
41- captured_ = true ;
42- }
43-
44- union Payload {
45- uint64_t val;
46- int64_t * ptr;
47- };
48-
49- uint64_t seed_ = 0 ;
50- Payload offset_;
51- uint32_t offset_intragraph_ = 0 ;
52- bool captured_ = false ;
53- };
54-
55- inline std::tuple<uint64_t , uint64_t > philox_unpack (PhiloxState arg) {
56- if (arg.captured_ ) {
57- // static_cast avoids "warning: invalid narrowing conversion from "long" to
58- // "unsigned long".
59- // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire
60- // kernel. For most threads' reads it will hit in cache, so it shouldn't
61- // hurt performance.
62- return std::make_tuple (
63- arg.seed_ ,
64- static_cast <uint64_t >(*(arg.offset_ .ptr ) + arg.offset_intragraph_ ));
65- } else {
66- return std::make_tuple (arg.seed_ , arg.offset_ .val );
67- }
68- }
69-
7027template <uint32_t UNROLL = rand4_engine_calls>
7128inline std::tuple<uint64_t , uint32_t , uint32_t > calc_execution_policy (
7229 int64_t total_elements) {
@@ -96,7 +53,7 @@ struct DistributionElementwiseKernelFunctor {
9653 int num_groups = item.get_group_range (0 );
9754 int idx = item.get_global_linear_id ();
9855
99- auto seeds = philox_unpack (philox_args_);
56+ auto seeds = at::xpu::philox::unpack (philox_args_);
10057 randStatePhilox4_32_10_t state;
10158 rand_init (std::get<0 >(seeds), idx, std::get<1 >(seeds), &state);
10259
@@ -125,23 +82,21 @@ struct DistributionElementwiseKernelFunctor {
12582 }
12683 DistributionElementwiseKernelFunctor (
12784 int64_t numel,
128- std::pair< uint64_t , uint64_t > rng_engine_inputs,
85+ PhiloxXpuState rng_engine_inputs,
12986 dist_t dist_func,
13087 transform_t transform_func,
13188 char * out_data,
13289 offset_calc_t offset_calc)
13390 : numel_(numel),
134- philox_args_ (PhiloxState(
135- std::get<0 >(rng_engine_inputs),
136- std::get<1>(rng_engine_inputs))),
91+ philox_args_ (rng_engine_inputs),
13792 dist_func_(dist_func),
13893 transform_func_(transform_func),
13994 out_data_(out_data),
14095 offset_calc_(offset_calc) {}
14196
14297 private:
14398 int64_t numel_;
144- PhiloxState philox_args_;
99+ PhiloxXpuState philox_args_;
145100 dist_t dist_func_;
146101 transform_t transform_func_;
147102 char * out_data_;
@@ -171,11 +126,11 @@ void distribution_nullary_kernel(
171126 auto num_groups = std::get<1 >(execution_policy);
172127 auto group_size = std::get<2 >(execution_policy);
173128
174- std::pair< uint64_t , uint64_t > rng_engine_inputs;
129+ PhiloxXpuState rng_engine_inputs;
175130 {
176131 // See Note [Acquire lock when using random generators]
177132 std::lock_guard<std::mutex> lock (gen->mutex_ );
178- rng_engine_inputs = gen->philox_engine_inputs (counter_offset);
133+ rng_engine_inputs = gen->philox_xpu_state (counter_offset);
179134 }
180135
181136 if (!iter.can_use_32bit_indexing ()) {
@@ -234,7 +189,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
234189 int global_size = item.get_global_range (0 );
235190 int global_idx = item.get_group (0 ) * group_size + item.get_local_id (0 );
236191
237- auto seeds = philox_unpack (philox_args_);
192+ auto seeds = at::xpu::philox::unpack (philox_args_);
238193 randStatePhilox4_32_10_t state;
239194 rand_init (std::get<0 >(seeds), global_idx, std::get<1 >(seeds), &state);
240195
@@ -247,7 +202,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
247202 DistributionUnaryElementwiseKernelFunctor (
248203 int numel,
249204 const func_t f,
250- PhiloxState philox_args,
205+ PhiloxXpuState philox_args,
251206 scalar1_t * output_data,
252207 const scalar2_t * input_data,
253208 inp_offset_calc_t input_offset_calculator,
@@ -263,7 +218,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
263218 private:
264219 int numel_;
265220 const func_t f_;
266- PhiloxState philox_args_;
221+ PhiloxXpuState philox_args_;
267222 scalar1_t * output_data_;
268223 const scalar2_t * input_data_;
269224 inp_offset_calc_t inp_calc_;
@@ -273,7 +228,7 @@ struct DistributionUnaryElementwiseKernelFunctor {
273228template <typename scalar1_t , typename scalar2_t , typename func_t >
274229void distribution_unary_kernel (
275230 TensorIterator& iter,
276- PhiloxState philox_args,
231+ PhiloxXpuState philox_args,
277232 func_t f) {
278233 if (!iter.can_use_32bit_indexing ()) {
279234 for (auto & sub_iter : iter.with_32bit_indexing ()) {
@@ -340,7 +295,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
340295 int global_size = item.get_global_range (0 );
341296 int global_idx = item.get_group (0 ) * group_size + item.get_local_id (0 );
342297
343- auto seeds = philox_unpack (philox_args_);
298+ auto seeds = at::xpu::philox::unpack (philox_args_);
344299
345300 randStatePhilox4_32_10_t state;
346301 rand_init (std::get<0 >(seeds), global_idx, std::get<1 >(seeds), &state);
@@ -356,7 +311,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
356311 DistributionBinaryElementwiseKernelFunctor (
357312 int numel,
358313 func_t f,
359- PhiloxState philox_args,
314+ PhiloxXpuState philox_args,
360315 output_t * output_data,
361316 const input_t_1* input_data_1,
362317 const input_t_2* input_data_2,
@@ -374,7 +329,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
374329 private:
375330 int64_t numel_;
376331 func_t f_;
377- PhiloxState philox_args_;
332+ PhiloxXpuState philox_args_;
378333 output_t * out_data_;
379334 const input_t_1* inp_data_1_;
380335 const input_t_2* inp_data_2_;
@@ -385,7 +340,7 @@ struct DistributionBinaryElementwiseKernelFunctor {
385340template <typename func_t >
386341void distribution_binary_kernel (
387342 TensorIteratorBase& iter,
388- PhiloxState philox_args,
343+ PhiloxXpuState philox_args,
389344 const func_t & f) {
390345 static_assert (
391346 std::is_same<
@@ -762,7 +717,7 @@ struct BernoulliTensorApplyFunctor {
762717 const prob_t & p2,
763718 const prob_t & p3,
764719 const prob_t & p4) const {
765- auto seeds = philox_unpack (philox_args_);
720+ auto seeds = at::xpu::philox::unpack (philox_args_);
766721 randStatePhilox4_32_10_t state;
767722 rand_init (
768723 std::get<0 >(seeds),
@@ -792,20 +747,18 @@ struct BernoulliTensorApplyFunctor {
792747 }
793748 }
794749 }
795- BernoulliTensorApplyFunctor (std::pair<uint64_t , uint64_t > rng_engine_inputs)
796- : philox_args_(
797- std::get<0 >(rng_engine_inputs),
798- std::get<1 >(rng_engine_inputs)) {}
750+ BernoulliTensorApplyFunctor (PhiloxXpuState rng_engine_inputs)
751+ : philox_args_(rng_engine_inputs) {}
799752
800753 private:
801- PhiloxState philox_args_;
754+ PhiloxXpuState philox_args_;
802755};
803756
804757template <typename scalar_t , typename prob_t >
805758void bernoulli_tensor_kernel (
806759 TensorBase& ret,
807760 TensorBase& p,
808- std::pair< uint64_t , uint64_t > rng_engine_inputs) {
761+ PhiloxXpuState rng_engine_inputs) {
809762 auto functor =
810763 BernoulliTensorApplyFunctor<scalar_t , prob_t >(rng_engine_inputs);
811764 // The template argument `4` below indicates that we want to operate on four
@@ -820,11 +773,11 @@ void bernoulli_tensor_kernel(
820773
821774template <typename RNG>
822775void bernoulli_kernel (const TensorBase& self, const TensorBase& p_, RNG gen) {
823- std::pair< uint64_t , uint64_t > rng_engine_inputs;
776+ PhiloxXpuState rng_engine_inputs;
824777 {
825778 // See Note [Acquire lock when using random generators]
826779 std::lock_guard<std::mutex> lock (gen->mutex_ );
827- rng_engine_inputs = gen->philox_engine_inputs (10 );
780+ rng_engine_inputs = gen->philox_xpu_state (10 );
828781 }
829782 TORCH_CHECK (
830783 at::isFloatingType (p_.scalar_type ()),
0 commit comments