Skip to content

Commit

Permalink
Enable rebinding on AllocatorHandle and fix Blocked copy constructor …
Browse files Browse the repository at this point in the history
…issue (#89)

This PR introduces the ability to rebind `float` and `Float16` on
`AllocatorHandle`. Additionally, it addresses an issue where the
`Blocked` copy constructor failed to copy the allocator from another
instance.
  • Loading branch information
dian-lun-lin authored Feb 24, 2025
1 parent 46a8a59 commit b08978c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
40 changes: 40 additions & 0 deletions include/svs/core/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,15 @@ class AllocatorInterface {

// covariant return type
virtual AllocatorInterface* clone() const = 0;
virtual AllocatorInterface* rebind_float() const = 0;
virtual AllocatorInterface* rebind_float16() const = 0;
};

template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterface {
public:
using rebind_allocator_float = lib::rebind_allocator_t<float, Impl>;
using rebind_allocator_float16 = lib::rebind_allocator_t<Float16, Impl>;

// pass by value due to clone()
explicit AllocatorImpl(Impl impl)
: AllocatorInterface{}
Expand All @@ -554,6 +559,14 @@ template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterfac

AllocatorImpl<Impl>* clone() const override { return new AllocatorImpl(impl_); }

AllocatorImpl<rebind_allocator_float>* rebind_float() const override {
return new AllocatorImpl<rebind_allocator_float>(rebind_allocator_float{impl_});
}

AllocatorImpl<rebind_allocator_float16>* rebind_float16() const override {
return new AllocatorImpl<rebind_allocator_float16>(rebind_allocator_float16{impl_});
}

private:
Impl impl_;
};
Expand All @@ -580,6 +593,33 @@ template <typename T> class AllocatorHandle {
AllocatorHandle& operator=(AllocatorHandle&&) = default;
~AllocatorHandle() = default;

// Enable rebinding of allocators.
template <typename U> friend class AllocatorHandle;

template <typename U>
AllocatorHandle(const AllocatorHandle<U>& other)
requires std::is_same_v<T, float> && (!std::is_same_v<U, T>)
: impl_{other.impl_->rebind_float()} {}
template <typename U>
AllocatorHandle(const AllocatorHandle<U>& other)
requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
: impl_{other.impl_->rebind_float16()} {}

template <typename U>
AllocatorHandle& operator=(const AllocatorHandle<U>& other)
requires std::is_same_v<T, float> && (!std::is_same_v<U, T>)
{
impl_.reset(other.impl_->rebind_float());
return *this;
}
template <typename U>
AllocatorHandle& operator=(const AllocatorHandle<U>& other)
requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
{
impl_.reset(other.impl_->rebind_float16());
return *this;
}

T* allocate(size_t n) {
if (impl_.get() == nullptr) {
throw ANNEXCEPTION("Empty allocator handle!");
Expand Down
3 changes: 2 additions & 1 deletion include/svs/core/data/simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,8 @@ template <typename Alloc> class Blocked {
template <typename U> friend class Blocked;
template <typename U>
Blocked(const Blocked<U>& other)
: parameters_{other.parameters_} {}
: parameters_{other.parameters_}
, allocator_{other.allocator_} {}

private:
BlockingParameters parameters_{};
Expand Down
14 changes: 14 additions & 0 deletions tests/svs/core/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,19 @@ CATCH_TEST_CASE("Testing Allocator", "[allocators]") {

CATCH_STATIC_REQUIRE(std::is_same_v<decltype(ptr), svs::Float16*>);
}
CATCH_SECTION("Rebind") {
auto alloc = svs::make_allocator_handle(svs::lib::Allocator<int>());
svs::lib::rebind_allocator_t<svs::Float16, decltype(alloc)> rebound_alloc{
alloc};
auto* ptr = rebound_alloc.allocate(num_elements);
rebound_alloc.deallocate(ptr, num_elements);
CATCH_STATIC_REQUIRE(std::is_same_v<decltype(ptr), svs::Float16*>);

svs::lib::rebind_allocator_t<float, decltype(alloc)> rebound_alloc2{
rebound_alloc};
auto* ptr2 = rebound_alloc2.allocate(num_elements);
rebound_alloc2.deallocate(ptr2, num_elements);
CATCH_STATIC_REQUIRE(std::is_same_v<decltype(ptr2), float*>);
}
}
}

0 comments on commit b08978c

Please sign in to comment.