Skip to content

Commit a8f20e3

Browse files
committed
Added philox4x32x10 and mcg31m1 engines
1 parent 3468340 commit a8f20e3

File tree

8 files changed

+140
-5
lines changed

8 files changed

+140
-5
lines changed

dpnp/backend/extensions/rng/device/dispatch/table_builder.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class Dispatch3DTableBuilder
8080
void populate(funcPtrT table[][_no_of_types][_no_of_methods]) const
8181
{
8282
const auto map_by_engine = {table_per_type_and_method<mkl_rng_dev::mrg32k3a<8>>(),
83+
table_per_type_and_method<mkl_rng_dev::philox4x32x10<8>>(),
84+
table_per_type_and_method<mkl_rng_dev::mcg31m1<8>>(),
8385
table_per_type_and_method<mkl_rng_dev::mcg59<8>>()};
8486
assert(map_by_engine.size() == _no_of_engines);
8587

dpnp/backend/extensions/rng/device/engine/base_engine.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class EngineType {
3434
public:
3535
enum Type : std::uint8_t {
3636
MRG32k3a = 0,
37+
PHILOX4x32x10,
38+
MCG31M1,
3739
MCG59,
3840
Base, // must be the last always
3941
};

dpnp/backend/extensions/rng/device/engine/builder/philox4x32x10.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ class Builder<mkl_rng_dev::philox4x32x10<VecSize>> : public BaseBuilder<mkl_rng_
3939
public:
4040
using EngineType = mkl_rng_dev::philox4x32x10<VecSize>;
4141

42-
Builder(EngineBase *engine) : BaseBuilder<EngineType, std::uint32_t, std::uint64_t>(engine) {}
42+
Builder(EngineBase *engine) : BaseBuilder<EngineType, std::uint64_t, std::uint64_t>(engine) {}
4343
};
4444
} // dpnp::backend::ext::rng::device::engine::builder
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include "base_engine.hpp"
29+
30+
31+
namespace dpnp::backend::ext::rng::device::engine
32+
{
33+
class MCG31M1 : public EngineBase {
34+
private:
35+
sycl::queue q_;
36+
std::vector<std::uint64_t> seed_vec{};
37+
std::vector<std::uint64_t> offset_vec{};
38+
39+
public:
40+
MCG31M1(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q) {
41+
seed_vec.push_back(seed);
42+
offset_vec.push_back(offset);
43+
}
44+
45+
sycl::queue &get_queue() override {
46+
return q_;
47+
}
48+
49+
virtual EngineType get_type() const noexcept override {
50+
return EngineType::MCG31M1;
51+
}
52+
53+
virtual std::vector<std::uint64_t> get_seeds() const noexcept override {
54+
return seed_vec;
55+
}
56+
57+
virtual std::vector<std::uint64_t> get_offsets() const noexcept override {
58+
return offset_vec;
59+
}
60+
};
61+
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/engine/mcg59_engine.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ namespace dpnp::backend::ext::rng::device::engine
3333
class MCG59 : public EngineBase {
3434
private:
3535
sycl::queue q_;
36-
std::vector<std::uint64_t> seed_vec;
37-
std::vector<std::uint64_t> offset_vec;
36+
std::vector<std::uint64_t> seed_vec{};
37+
std::vector<std::uint64_t> offset_vec{};
3838

3939
public:
4040
MCG59(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q) {

dpnp/backend/extensions/rng/device/engine/mrg32k3a_engine.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ namespace dpnp::backend::ext::rng::device::engine
3333
class MRG32k3a : public EngineBase {
3434
private:
3535
sycl::queue q_;
36-
std::vector<std::uint64_t> seed_vec;
37-
std::vector<std::uint64_t> offset_vec;
36+
std::vector<std::uint64_t> seed_vec{};
37+
std::vector<std::uint64_t> offset_vec{};
3838

3939
public:
4040
MRG32k3a(sycl::queue &q, std::uint32_t seed, std::uint64_t offset = 0) : q_(q) {
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include "base_engine.hpp"
29+
30+
31+
namespace dpnp::backend::ext::rng::device::engine
32+
{
33+
class PHILOX4x32x10 : public EngineBase {
34+
private:
35+
sycl::queue q_;
36+
std::vector<std::uint64_t> seed_vec{};
37+
std::vector<std::uint64_t> offset_vec{};
38+
39+
public:
40+
PHILOX4x32x10(sycl::queue &q, std::uint64_t seed, std::uint64_t offset = 0) : q_(q) {
41+
seed_vec.push_back(seed);
42+
offset_vec.push_back(offset);
43+
}
44+
45+
sycl::queue &get_queue() override {
46+
return q_;
47+
}
48+
49+
virtual EngineType get_type() const noexcept override {
50+
return EngineType::PHILOX4x32x10;
51+
}
52+
53+
virtual std::vector<std::uint64_t> get_seeds() const noexcept override {
54+
return seed_vec;
55+
}
56+
57+
virtual std::vector<std::uint64_t> get_offsets() const noexcept override {
58+
return offset_vec;
59+
}
60+
};
61+
} // dpnp::backend::ext::rng::device::engine

dpnp/backend/extensions/rng/device/rng_py.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@
3737
#include "gaussian.hpp"
3838

3939
#include "engine/mrg32k3a_engine.hpp"
40+
#include "engine/philox4x32x10_engine.hpp"
41+
#include "engine/mcg31m1_engine.hpp"
4042
#include "engine/mcg59_engine.hpp"
4143

44+
4245
namespace mkl_rng = oneapi::mkl::rng;
4346
namespace rng_dev_ext = dpnp::backend::ext::rng::device;
4447
namespace rng_dev_engine = dpnp::backend::ext::rng::device::engine;
@@ -84,6 +87,12 @@ PYBIND11_MODULE(_rng_dev_impl, m)
8487
py::class_<rng_dev_engine::MRG32k3a, rng_dev_engine::EngineBase>(m, "MRG32k3a")
8588
.def(py::init<sycl::queue &, std::uint32_t, std::uint64_t>());
8689

90+
py::class_<rng_dev_engine::PHILOX4x32x10, rng_dev_engine::EngineBase>(m, "PHILOX4x32x10")
91+
.def(py::init<sycl::queue &, std::uint64_t, std::uint64_t>());
92+
93+
py::class_<rng_dev_engine::MCG31M1, rng_dev_engine::EngineBase>(m, "MCG31M1")
94+
.def(py::init<sycl::queue &, std::uint32_t, std::uint64_t>());
95+
8796
py::class_<rng_dev_engine::MCG59, rng_dev_engine::EngineBase>(m, "MCG59")
8897
.def(py::init<sycl::queue &, std::uint32_t, std::uint64_t>());
8998

0 commit comments

Comments
 (0)