Skip to content

Commit b08978c

Browse files
authored
Enable rebinding on AllocatorHandle and fix Blocked copy constructor issue (#89)
This PR introduces the ability to rebind `float` and `Float16` on `AllocatorHandle`. Additionally, it addresses an issue where the `Blocked` copy constructor failed to copy the allocator from another instance.
1 parent 46a8a59 commit b08978c

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

include/svs/core/allocator.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,15 @@ class AllocatorInterface {
537537

538538
// covariant return type
539539
virtual AllocatorInterface* clone() const = 0;
540+
virtual AllocatorInterface* rebind_float() const = 0;
541+
virtual AllocatorInterface* rebind_float16() const = 0;
540542
};
541543

542544
template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterface {
543545
public:
546+
using rebind_allocator_float = lib::rebind_allocator_t<float, Impl>;
547+
using rebind_allocator_float16 = lib::rebind_allocator_t<Float16, Impl>;
548+
544549
// pass by value due to clone()
545550
explicit AllocatorImpl(Impl impl)
546551
: AllocatorInterface{}
@@ -554,6 +559,14 @@ template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterfac
554559

555560
AllocatorImpl<Impl>* clone() const override { return new AllocatorImpl(impl_); }
556561

562+
AllocatorImpl<rebind_allocator_float>* rebind_float() const override {
563+
return new AllocatorImpl<rebind_allocator_float>(rebind_allocator_float{impl_});
564+
}
565+
566+
AllocatorImpl<rebind_allocator_float16>* rebind_float16() const override {
567+
return new AllocatorImpl<rebind_allocator_float16>(rebind_allocator_float16{impl_});
568+
}
569+
557570
private:
558571
Impl impl_;
559572
};
@@ -580,6 +593,33 @@ template <typename T> class AllocatorHandle {
580593
AllocatorHandle& operator=(AllocatorHandle&&) = default;
581594
~AllocatorHandle() = default;
582595

596+
// Enable rebinding of allocators.
597+
template <typename U> friend class AllocatorHandle;
598+
599+
template <typename U>
600+
AllocatorHandle(const AllocatorHandle<U>& other)
601+
requires std::is_same_v<T, float> && (!std::is_same_v<U, T>)
602+
: impl_{other.impl_->rebind_float()} {}
603+
template <typename U>
604+
AllocatorHandle(const AllocatorHandle<U>& other)
605+
requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
606+
: impl_{other.impl_->rebind_float16()} {}
607+
608+
template <typename U>
609+
AllocatorHandle& operator=(const AllocatorHandle<U>& other)
610+
requires std::is_same_v<T, float> && (!std::is_same_v<U, T>)
611+
{
612+
impl_.reset(other.impl_->rebind_float());
613+
return *this;
614+
}
615+
template <typename U>
616+
AllocatorHandle& operator=(const AllocatorHandle<U>& other)
617+
requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
618+
{
619+
impl_.reset(other.impl_->rebind_float16());
620+
return *this;
621+
}
622+
583623
T* allocate(size_t n) {
584624
if (impl_.get() == nullptr) {
585625
throw ANNEXCEPTION("Empty allocator handle!");

include/svs/core/data/simple.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,8 @@ template <typename Alloc> class Blocked {
565565
template <typename U> friend class Blocked;
566566
template <typename U>
567567
Blocked(const Blocked<U>& other)
568-
: parameters_{other.parameters_} {}
568+
: parameters_{other.parameters_}
569+
, allocator_{other.allocator_} {}
569570

570571
private:
571572
BlockingParameters parameters_{};

tests/svs/core/allocator.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,5 +196,19 @@ CATCH_TEST_CASE("Testing Allocator", "[allocators]") {
196196

197197
CATCH_STATIC_REQUIRE(std::is_same_v<decltype(ptr), svs::Float16*>);
198198
}
199+
CATCH_SECTION("Rebind") {
200+
auto alloc = svs::make_allocator_handle(svs::lib::Allocator<int>());
201+
svs::lib::rebind_allocator_t<svs::Float16, decltype(alloc)> rebound_alloc{
202+
alloc};
203+
auto* ptr = rebound_alloc.allocate(num_elements);
204+
rebound_alloc.deallocate(ptr, num_elements);
205+
CATCH_STATIC_REQUIRE(std::is_same_v<decltype(ptr), svs::Float16*>);
206+
207+
svs::lib::rebind_allocator_t<float, decltype(alloc)> rebound_alloc2{
208+
rebound_alloc};
209+
auto* ptr2 = rebound_alloc2.allocate(num_elements);
210+
rebound_alloc2.deallocate(ptr2, num_elements);
211+
CATCH_STATIC_REQUIRE(std::is_same_v<decltype(ptr2), float*>);
212+
}
199213
}
200214
}

0 commit comments

Comments
 (0)