static const char *verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;

- // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
- const double high_watermark_upper_bound = 2.0;
-
static const char *high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
- m_high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
- TORCH_CHECK(m_high_watermark_ratio >= 0.0 && m_high_watermark_ratio <= high_watermark_upper_bound,
-             "invalid high watermark ratio ", m_high_watermark_ratio);
+ const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) :
+                                                                default_high_watermark_ratio;
+ setHighWatermarkRatio(high_watermark_ratio);

- m_max_total_allowed_size = (m_high_watermark_ratio == 0.0) ? std::numeric_limits<size_t>::max() :
-                            static_cast<size_t>(m_high_watermark_ratio * (double)max_device_size());
- // used for comparison with lower_watermark_ratio
- const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? high_watermark_upper_bound : m_high_watermark_ratio;
const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified :
                                                                       default_low_watermark_ratio_discrete;
static const char *low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
- m_low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
- TORCH_CHECK(m_low_watermark_ratio >= 0.0 && m_low_watermark_ratio <= high_watermark_limit,
-             "invalid low watermark ratio ", m_low_watermark_ratio);
+ const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
+ setLowWatermarkRatio(low_watermark_ratio);
+}
+
+void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio)
+{
+  TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
+  m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
+                             static_cast<size_t>(ratio * (double)max_device_size());
+  m_high_watermark_ratio = ratio;
+}
+
+void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio)
+{
+  // used for comparison with lower_watermark_ratio
+  const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
+  TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);

// we use this to detect if there's memory pressure
- m_low_watermark_limit = (m_low_watermark_ratio == 0.0) ? std::numeric_limits<size_t>::max() :
-                         static_cast<size_t>(m_low_watermark_ratio * (double)max_device_size());
+ m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
+                         static_cast<size_t>(ratio * (double)max_device_size());
+ m_low_watermark_ratio = ratio;
}
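For reference, both setters reduce a user-supplied ratio to a byte limit against max_device_size() (presumably the device's recommendedMaxWorkingSetSize, per the removed comment above), with 0.0 meaning "no limit". A minimal standalone sketch of that mapping follows; the helper name and the example sizes are illustrative, not part of the patch.

#include <cstddef>
#include <limits>

// Hypothetical helper mirroring the logic of setHighWatermarkRatio()/
// setLowWatermarkRatio(): a ratio of 0.0 disables the limit, any other
// ratio scales the device's maximum working-set size.
static std::size_t watermark_limit_bytes(double ratio, std::size_t device_size) {
  return (ratio == 0.0) ? std::numeric_limits<std::size_t>::max()
                        : static_cast<std::size_t>(ratio * static_cast<double>(device_size));
}

// Example: with an 8 GB working-set size, a high watermark ratio of 2.0 caps
// total allocations at 16 GB: watermark_limit_bytes(2.0, 8ULL << 30) == 16ULL << 30.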

HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params)
  return buffer_block->buffer;
}

- ssize_t MPSHeapAllocatorImpl::getRequestedBufferSize(void* ptr)
+ ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr)
{
  std::lock_guard<std::mutex> lock(m_mutex);

}

// MPS allocator struct to be registered with Pytorch
- struct TORCH_API MPSAllocator final : public at::Allocator {
+ struct TORCH_API MPSAllocator final : public IMPSAllocator {
public:
  explicit MPSAllocator(uint32_t Usage) :
    m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage)
  {
    if (_getAllocImpl().getDebugVerbosity()) {
      if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
-       const size_t max_total_allowed_size = _getAllocImpl().getMaxTotalAllowedSize();
-       const size_t low_watermark_limit = _getAllocImpl().getLowWatermarkLimit();
+       const size_t high_watermark_limit = _getAllocImpl().getHighWatermarkLimit();
+       const size_t low_watermark_limit = _getAllocImpl().getLowWatermarkLimit();
        std::cerr << "Initializing "
                  << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
                  << " heap allocator on "
                  << (m_has_unified_memory ? "unified" : "discrete")
                  << " device memory of size "
                  << _getAllocImpl().Device().recommendedMaxWorkingSetSize / 1048576UL << " MB"
                  << " (max allowed: "
-                 << (max_total_allowed_size == std::numeric_limits<size_t>::max() ? "unlimited" :
-                    (to_string(max_total_allowed_size / 1048576UL) + " MB"))
+                 << (high_watermark_limit == std::numeric_limits<size_t>::max() ? "unlimited" :
+                    (to_string(high_watermark_limit / 1048576UL) + " MB"))
                  << ", low watermark: "
                  << (low_watermark_limit == std::numeric_limits<size_t>::max() ? "unlimited" :
                     (to_string(low_watermark_limit / 1048576UL) + " MB")) << ")\n";
@@ -580,20 +588,28 @@ explicit MPSAllocator(uint32_t Usage) :
~MPSAllocator() override {
  _getAllocImpl().emptyCache();
}
+ DeleterFnPtr raw_deleter() const override { return &Delete; }

DataPtr allocate(const size_t nbytes) const override {
  __block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
  return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
}
-
- DataPtr allocate_scalar_buffer(void *value, size_t size) const {
+ DataPtr allocScalarBufferWithValue(void *value, size_t size) const override {
  id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
  return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
}
-
- DeleterFnPtr raw_deleter() const override { return &Delete; }
- bool is_shared(void* ptr) const { return _getAllocImpl().isSharedBuffer(ptr); }
- bool is_shared_storage_supported() const { return m_has_unified_memory; }
+ bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); }
+ bool isSharedStorageSupported() const override { return m_has_unified_memory; }
+ void emptyCache() const override { _getAllocImpl().emptyCache(); }
+ ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); }
+ IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); }
+ void setBufferShape(void* ptr, const IntArrayRef& shape) const override { _getAllocImpl().setBufferShape(ptr, shape); }
+ size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); }
+ ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); }
+ size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); }
+ size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); }
+ void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); }
+ void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); }

private:
  bool m_has_unified_memory;
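The IMPSAllocator base class itself is not part of this hunk; judging from the overrides above, it is presumably an abstract interface layered on at::Allocator, roughly along the lines of the sketch below (inferred from the overrides, not the actual declaration in the patch).

// Assumed shape of the interface, inferred purely from the overrides in
// MPSAllocator above; the real declaration lives elsewhere in the patch.
struct IMPSAllocator : public at::Allocator {
  virtual DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
  virtual bool isSharedBuffer(void* ptr) const = 0;
  virtual bool isSharedStorageSupported() const = 0;
  virtual void emptyCache() const = 0;
  virtual ssize_t getUnalignedBufferSize(void* ptr) const = 0;
  virtual IntArrayRef getBufferShape(void* ptr) const = 0;
  virtual void setBufferShape(void* ptr, const IntArrayRef& shape) const = 0;
  virtual size_t getTotalAllocatedMemory() const = 0;
  virtual ssize_t getLowWatermarkValue() const = 0;
  virtual size_t getLowWatermarkLimit() const = 0;
  virtual size_t getHighWatermarkLimit() const = 0;
  virtual void setLowWatermarkRatio(double ratio) const = 0;
  virtual void setHighWatermarkRatio(double ratio) const = 0;
};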
@@ -618,41 +634,17 @@ static void Delete(void* ptr) {
}
} // anonymous namespace

- at::Allocator* getMPSSharedAllocator()
- {
+ IMPSAllocator* getIMPSAllocator(bool sharedAllocator) {
+   if (!sharedAllocator) {
+     return &_getPrivateAllocator();
+   }
  auto& sa = _getSharedAllocator();
- if (sa.is_shared_storage_supported()) {
+ if (sa.isSharedStorageSupported()) {
    return &sa;
  }
-
  return nullptr;
}

- at::Allocator* getMPSPrivateAllocator() {
-   return &_getPrivateAllocator();
- }
-
- // TODO: create MPSHooks interface and move these there.
- ssize_t get_requested_buffer_size(void* ptr) {
-   return _getAllocImpl().getRequestedBufferSize(ptr);
- }
-
- void set_buffer_shape(void* ptr, const IntArrayRef& shape) {
-   _getAllocImpl().setBufferShape(ptr, shape);
- }
-
- IntArrayRef get_buffer_shape(void* ptr) {
-   return _getAllocImpl().getBufferShape(ptr);
- }
-
- DataPtr allocate_scalar_buffer(void *value, size_t size) {
-   return _getPrivateAllocator().allocate_scalar_buffer(value, size);
- }
-
- uint32_t get_adaptive_commit_threshold() {
-   return _getAllocImpl().getLowWatermarkValue();
- }
-
} // namespace mps
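With getIMPSAllocator() replacing the separate getMPSSharedAllocator()/getMPSPrivateAllocator() entry points, callers reach both allocators and the new watermark controls through a single interface. A hedged usage sketch; the include path and the ratio value are assumptions, not taken from this diff.

// Hypothetical caller-side usage of the new interface.
#include <ATen/mps/MPSAllocator.h>  // assumed header location

void tune_mps_allocator_example() {
  // false selects the private (device) allocator; true selects the shared
  // allocator, which is only returned on unified-memory devices.
  at::mps::IMPSAllocator* allocator = at::mps::getIMPSAllocator(false);

  // Cap allocations at 1.5x the device's working-set size, then read back
  // the resulting byte limits.
  allocator->setHighWatermarkRatio(1.5);
  const size_t high_limit = allocator->getHighWatermarkLimit();
  const size_t low_limit  = allocator->getLowWatermarkLimit();
  (void)high_limit;
  (void)low_limit;
}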

namespace native {
@@ -664,14 +656,14 @@ uint32_t get_adaptive_commit_threshold() {
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device)
{
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
- return at::mps::_getSharedAllocator().is_shared(self.storage().data());
+ return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data());
}

// torch.pin_memory() implementation
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device)
{
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
- auto* shared_allocator = at::mps::getMPSSharedAllocator();
+ auto* shared_allocator = at::mps::getIMPSAllocator(true);
  TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device");

  const size_t storage_size = detail::computeStorageNbytes(self.sizes(), self.strides(), self.dtype().itemsize());