@@ -537,10 +537,15 @@ class AllocatorInterface {
537
537
538
538
// covariant return type
539
539
virtual AllocatorInterface* clone () const = 0;
540
+ virtual AllocatorInterface* rebind_float () const = 0;
541
+ virtual AllocatorInterface* rebind_float16 () const = 0;
540
542
};
541
543
542
544
template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterface {
543
545
public:
546
+ using rebind_allocator_float = lib::rebind_allocator_t <float , Impl>;
547
+ using rebind_allocator_float16 = lib::rebind_allocator_t <Float16, Impl>;
548
+
544
549
// pass by value due to clone()
545
550
explicit AllocatorImpl (Impl impl)
546
551
: AllocatorInterface{}
@@ -554,6 +559,14 @@ template <detail::Allocator Impl> class AllocatorImpl : public AllocatorInterfac
554
559
555
560
AllocatorImpl<Impl>* clone () const override { return new AllocatorImpl (impl_); }
556
561
562
+ AllocatorImpl<rebind_allocator_float>* rebind_float () const override {
563
+ return new AllocatorImpl<rebind_allocator_float>(rebind_allocator_float{impl_});
564
+ }
565
+
566
+ AllocatorImpl<rebind_allocator_float16>* rebind_float16 () const override {
567
+ return new AllocatorImpl<rebind_allocator_float16>(rebind_allocator_float16{impl_});
568
+ }
569
+
557
570
private:
558
571
Impl impl_;
559
572
};
@@ -580,6 +593,33 @@ template <typename T> class AllocatorHandle {
580
593
AllocatorHandle& operator =(AllocatorHandle&&) = default ;
581
594
~AllocatorHandle () = default ;
582
595
596
+ // Enable rebinding of allocators.
597
+ template <typename U> friend class AllocatorHandle ;
598
+
599
+ template <typename U>
600
+ AllocatorHandle (const AllocatorHandle<U>& other)
601
+ requires std::is_same_v<T, float > && (!std::is_same_v<U, T>)
602
+ : impl_{other.impl_ ->rebind_float ()} {}
603
+ template <typename U>
604
+ AllocatorHandle (const AllocatorHandle<U>& other)
605
+ requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
606
+ : impl_{other.impl_ ->rebind_float16 ()} {}
607
+
608
+ template <typename U>
609
+ AllocatorHandle& operator =(const AllocatorHandle<U>& other)
610
+ requires std::is_same_v<T, float > && (!std::is_same_v<U, T>)
611
+ {
612
+ impl_.reset (other.impl_ ->rebind_float ());
613
+ return *this ;
614
+ }
615
+ template <typename U>
616
+ AllocatorHandle& operator =(const AllocatorHandle<U>& other)
617
+ requires std::is_same_v<T, Float16> && (!std::is_same_v<U, T>)
618
+ {
619
+ impl_.reset (other.impl_ ->rebind_float16 ());
620
+ return *this ;
621
+ }
622
+
583
623
T* allocate (size_t n) {
584
624
if (impl_.get () == nullptr ) {
585
625
throw ANNEXCEPTION (" Empty allocator handle!" );
0 commit comments