Skip to content

Commit a192ac4

Browse files
committed
Add user_cpu_context and ability to provide host_policy via it
1 parent 4eb47a5 commit a192ac4

File tree

3 files changed

+13
-0
lines changed

3 files changed

+13
-0
lines changed

cpp/oneapi/dal/compute.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include "oneapi/dal/detail/compute_ops.hpp"
20+
#include "oneapi/dal/detail/user_policy.hpp"
2021
#include "oneapi/dal/detail/spmd_policy.hpp"
2122
#include "oneapi/dal/spmd/communicator.hpp"
2223

@@ -28,6 +29,11 @@ auto compute(Args&&... args) {
2829
return dal::detail::compute_dispatch(std::forward<Args>(args)...);
2930
}
3031

32+
template <typename... Args>
33+
auto compute(detail::user_cpu_context uctx, Args&&... args) {
34+
return dal::detail::compute_dispatch(uctx.get_host_policy(), std::forward<Args>(args)...);
35+
}
36+
3137
#ifdef ONEDAL_DATA_PARALLEL
3238
template <typename... Args>
3339
auto compute(sycl::queue& queue, Args&&... args) {

cpp/oneapi/dal/detail/policy.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class ONEDAL_EXPORT host_policy : public base {
103103
}
104104
host_policy(const host_policy&) = default;
105105
host_policy(host_policy&&) = default;
106+
host_policy& operator= (const host_policy&) = default;
106107

107108
static host_policy get_default() {
108109
return host_policy(make_default_impl());

cpp/oneapi/dal/train.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include "oneapi/dal/detail/train_ops.hpp"
20+
#include "oneapi/dal/detail/user_policy.hpp"
2021
#include "oneapi/dal/detail/spmd_policy.hpp"
2122
#include "oneapi/dal/spmd/communicator.hpp"
2223

@@ -28,6 +29,11 @@ auto train(Args&&... args) {
2829
return dal::detail::train_dispatch(std::forward<Args>(args)...);
2930
}
3031

32+
template <typename... Args>
33+
auto train(detail::user_cpu_context uctx, Args&&... args) {
34+
return dal::detail::train_dispatch(uctx.get_host_policy(), std::forward<Args>(args)...);
35+
}
36+
3137
#ifdef ONEDAL_DATA_PARALLEL
3238
template <typename... Args>
3339
auto train(sycl::queue& queue, Args&&... args) {

0 commit comments

Comments
 (0)