@@ -537,10 +537,15 @@ class AllocatorInterface {
537537
538538 // covariant return type
539539 virtual AllocatorInterface* clone () const = 0;
540+ virtual AllocatorInterface* rebind_float () const = 0;
541+ virtual AllocatorInterface* rebind_float16 () const = 0;
540542};
541543
542544template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterface {
543545 public:
546+ using rebind_allocator_float = lib::rebind_allocator_t <float , Impl>;
547+ using rebind_allocator_float16 = lib::rebind_allocator_t <Float16, Impl>;
548+
544549 // pass by value due to clone()
545550 explicit AllocatorImpl (Impl impl)
546551 : AllocatorInterface{}
@@ -554,6 +559,14 @@ template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterfac
554559
555560 AllocatorImpl<Impl>* clone () const override { return new AllocatorImpl (impl_); }
556561
562+ AllocatorImpl<rebind_allocator_float>* rebind_float () const override {
563+ return new AllocatorImpl<rebind_allocator_float>(rebind_allocator_float{impl_});
564+ }
565+
566+ AllocatorImpl<rebind_allocator_float16>* rebind_float16 () const override {
567+ return new AllocatorImpl<rebind_allocator_float16>(rebind_allocator_float16{impl_});
568+ }
569+
557570 private:
558571 Impl impl_;
559572};
@@ -580,6 +593,33 @@ template <typename T> class AllocatorHandle {
580593 AllocatorHandle& operator =(AllocatorHandle&&) = default ;
581594 ~AllocatorHandle () = default ;
582595
596+ // Enable rebinding of allocators.
597+ template <typename U> friend class AllocatorHandle ;
598+
599+ template <typename U>
600+ AllocatorHandle (const AllocatorHandle<U>& other)
601+ requires std::is_same_v<T, float > && (!std::is_same_v<U, T>)
602+ : impl_{other.impl_ ->rebind_float ()} {}
603+ template <typename U>
604+ AllocatorHandle (const AllocatorHandle<U>& other)
605+ requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
606+ : impl_{other.impl_ ->rebind_float16 ()} {}
607+
608+ template <typename U>
609+ AllocatorHandle& operator =(const AllocatorHandle<U>& other)
610+ requires std::is_same_v<T, float > && (!std::is_same_v<U, T>)
611+ {
612+ impl_.reset (other.impl_ ->rebind_float ());
613+ return *this ;
614+ }
615+ template <typename U>
616+ AllocatorHandle& operator =(const AllocatorHandle<U>& other)
617+ requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
618+ {
619+ impl_.reset (other.impl_ ->rebind_float16 ());
620+ return *this ;
621+ }
622+
583623 T* allocate (size_t n) {
584624 if (impl_.get () == nullptr ) {
585625 throw ANNEXCEPTION (" Empty allocator handle!" );
0 commit comments