Skip to content

Commit c981b7e

Browse files
razarmehr authored and pytorchmergebot committed
[MPS] Add MPSAllocatorInterface to access methods of MPSAllocator (pytorch#94327)
This is a prerequisite for the upcoming PR's for the MPS Modules and Memory Leak Detection features. Also added pragma once to headers. Pull Request resolved: pytorch#94327 Approved by: https://github.com/kulinseth
1 parent 51b487b commit c981b7e

File tree

8 files changed

+141
-110
lines changed

8 files changed

+141
-110
lines changed

aten/src/ATen/mps/MPSAllocator.h

+26-35
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// Copyright © 2022 Apple Inc.
22

3+
#pragma once
4+
5+
#include <ATen/mps/MPSAllocatorInterface.h>
36
#include <ATen/mps/MPSStream.h>
47
#include <cstdio>
58
#include <mutex>
@@ -9,27 +12,10 @@
912

1013
// this implementation is based on CUDACachingAllocator.
1114
// It utilizes Metal Heaps to improve the performance with buffer allocation.
15+
// Do not include this header. Use MPSAllocatorInterface.h instead.
1216
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
1317
namespace at {
1418
namespace mps {
15-
16-
class IMpsAllocatorCallback {
17-
public:
18-
enum class EventType {
19-
ALLOCATED, // buffer got allocated to be used immediately
20-
RECYCLED, // buffer pulled from free list to be reused
21-
FREED, // buffer put to free list for future recycling
22-
RELEASED, // buffer memory released
23-
};
24-
virtual ~IMpsAllocatorCallback() = default;
25-
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
26-
};
27-
28-
// MPS allocator will execute every registered callback when a block of memory is freed.
29-
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
30-
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
31-
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
32-
3319
namespace HeapAllocator {
3420

3521
#define MB(x) round_page(x * 1048576UL)
@@ -263,27 +249,44 @@ class MPSHeapAllocatorImpl
263249

264250
// interface exposed to at::Allocator
265251
id<MTLBuffer> malloc(size_t size, uint32_t usage);
252+
// frees a buffer and returns it into buffer pool
266253
void free(void* ptr);
254+
// releases all the cached buffers and their associated heaps
267255
void emptyCache();
268-
// interface exposed to internal MPS operations
256+
// returns true if buffer was allocated from the shared pool
269257
bool isSharedBuffer(void* ptr);
270-
ssize_t getRequestedBufferSize(void* ptr);
258+
// get the requested unaligned size of an MTLBuffer
259+
ssize_t getUnalignedBufferSize(void* ptr);
260+
// set the shape of a base tensor from a view tensor
271261
void setBufferShape(void* ptr, const IntArrayRef& shape);
262+
// retrieve the shape of a base tensor from a view tensor
272263
IntArrayRef getBufferShape(void* ptr);
264+
// allocate a buffer from a specialized pool to import CPU scalars into GPU
273265
id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
274266
// this indicates how far (in Megabytes) the current total allocations are from the
275267
// low watermark limit which is used to detect if we're under memory pressure
276268
// This returns zero if we've reached the low watermark limit
277269
ssize_t getLowWatermarkValue();
278-
279-
bool getDebugVerbosity() const { return m_debug_verbosity; }
280-
size_t getMaxTotalAllowedSize() const { return m_max_total_allowed_size; }
270+
// (see m_low_watermark_ratio for description)
271+
void setLowWatermarkRatio(double ratio);
272+
// (see m_high_watermark_ratio for description)
273+
void setHighWatermarkRatio(double ratio);
274+
// (see m_low_watermark_limit for description)
281275
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
276+
// (see m_max_total_allowed_size for description)
277+
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
278+
// (see m_total_allocated_memory for description)
279+
size_t getTotalAllocatedMemory() const {return m_total_allocated_memory; }
280+
// (see enum DebugVerbosity for description)
281+
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
282+
// returns the device that we allocate from
282283
inline id<MTLDevice> Device() const { return m_device; }
283284

284285
private:
285286
// (see m_high_watermark_ratio for description)
286287
constexpr static double default_high_watermark_ratio = 1.7;
288+
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
289+
constexpr static double default_high_watermark_upper_bound = 2.0;
287290
// (see m_low_watermark_ratio for description)
288291
// on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
289292
constexpr static double default_low_watermark_ratio_unified = 1.4;
@@ -375,17 +378,5 @@ class MPSHeapAllocatorImpl
375378
};
376379

377380
} // namespace HeapAllocator
378-
379-
// interface exposed to internal MPS operations
380-
381-
// get the requested non-aligned size of an MTL buffer
382-
ssize_t get_requested_buffer_size(void* ptr);
383-
// retrieve the shape of a base tensor from a view tensor
384-
IntArrayRef get_buffer_shape(void* ptr);
385-
// set the shape of a base tensor from a view tensor
386-
void set_buffer_shape(void* ptr, const IntArrayRef& shape);
387-
// allocate a buffer from a specialized pool to import CPU scalars into GPU
388-
DataPtr allocate_scalar_buffer(void* value, size_t size);
389-
390381
} // namespace mps
391382
} // namespace at

aten/src/ATen/mps/MPSAllocator.mm

+50-58
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,35 @@
2222
static const char *verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
2323
m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;
2424

25-
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
26-
const double high_watermark_upper_bound = 2.0;
27-
2825
static const char *high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
29-
m_high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
30-
TORCH_CHECK(m_high_watermark_ratio >= 0.0 && m_high_watermark_ratio <= high_watermark_upper_bound,
31-
"invalid high watermark ratio ", m_high_watermark_ratio);
26+
const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) :
27+
default_high_watermark_ratio;
28+
setHighWatermarkRatio(high_watermark_ratio);
3229

33-
m_max_total_allowed_size = (m_high_watermark_ratio == 0.0) ? std::numeric_limits<size_t>::max() :
34-
static_cast<size_t>(m_high_watermark_ratio * (double)max_device_size());
35-
// used for comparison with lower_watermark_ratio
36-
const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? high_watermark_upper_bound : m_high_watermark_ratio;
3730
const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified :
3831
default_low_watermark_ratio_discrete;
3932
static const char *low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
40-
m_low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
41-
TORCH_CHECK(m_low_watermark_ratio >= 0.0 && m_low_watermark_ratio <= high_watermark_limit,
42-
"invalid low watermark ratio ", m_low_watermark_ratio);
33+
const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
34+
setLowWatermarkRatio(low_watermark_ratio);
35+
}
36+
37+
void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio)
38+
{
39+
TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
40+
m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
41+
static_cast<size_t>(ratio * (double)max_device_size());
42+
m_high_watermark_ratio = ratio;
43+
}
44+
45+
void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio)
46+
{
47+
// used for comparison with lower_watermark_ratio
48+
const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
49+
TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
4350
// we use this to detect if there's memory pressure
44-
m_low_watermark_limit = (m_low_watermark_ratio == 0.0) ? std::numeric_limits<size_t>::max() :
45-
static_cast<size_t>(m_low_watermark_ratio * (double)max_device_size());
51+
m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits<size_t>::max() :
52+
static_cast<size_t>(ratio * (double)max_device_size());
53+
m_low_watermark_ratio = ratio;
4654
}
4755

4856
HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params)
@@ -470,7 +478,7 @@
470478
return buffer_block->buffer;
471479
}
472480

473-
ssize_t MPSHeapAllocatorImpl::getRequestedBufferSize(void* ptr)
481+
ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr)
474482
{
475483
std::lock_guard<std::mutex> lock(m_mutex);
476484

@@ -552,24 +560,24 @@
552560
}
553561

554562
// MPS allocator struct to be registered with Pytorch
555-
struct TORCH_API MPSAllocator final : public at::Allocator {
563+
struct TORCH_API MPSAllocator final : public IMPSAllocator {
556564
public:
557565
explicit MPSAllocator(uint32_t Usage) :
558566
m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage)
559567
{
560568
if (_getAllocImpl().getDebugVerbosity()) {
561569
if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
562-
const size_t max_total_allowed_size = _getAllocImpl().getMaxTotalAllowedSize();
563-
const size_t low_watermark_limit = _getAllocImpl().getLowWatermarkLimit();
570+
const size_t high_watermark_limit = _getAllocImpl().getHighWatermarkLimit();
571+
const size_t low_watermark_limit = _getAllocImpl().getLowWatermarkLimit();
564572
std::cerr << "Initializing "
565573
<< ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
566574
<< " heap allocator on "
567575
<< (m_has_unified_memory ? "unified" : "discrete")
568576
<< " device memory of size "
569577
<< _getAllocImpl().Device().recommendedMaxWorkingSetSize / 1048576UL << " MB"
570578
<< " (max allowed: "
571-
<< (max_total_allowed_size == std::numeric_limits<size_t>::max() ? "unlimited" :
572-
(to_string(max_total_allowed_size / 1048576UL) + " MB"))
579+
<< (high_watermark_limit == std::numeric_limits<size_t>::max() ? "unlimited" :
580+
(to_string(high_watermark_limit / 1048576UL) + " MB"))
573581
<< ", low watermark: "
574582
<< (low_watermark_limit == std::numeric_limits<size_t>::max() ? "unlimited" :
575583
(to_string(low_watermark_limit / 1048576UL) + " MB")) << ")\n";
@@ -580,20 +588,28 @@ explicit MPSAllocator(uint32_t Usage) :
580588
~MPSAllocator() override {
581589
_getAllocImpl().emptyCache();
582590
}
591+
DeleterFnPtr raw_deleter() const override { return &Delete; }
583592

584593
DataPtr allocate(const size_t nbytes) const override {
585594
__block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
586595
return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
587596
}
588-
589-
DataPtr allocate_scalar_buffer(void *value, size_t size) const {
597+
DataPtr allocScalarBufferWithValue(void *value, size_t size) const override {
590598
id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
591599
return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
592600
}
593-
594-
DeleterFnPtr raw_deleter() const override { return &Delete; }
595-
bool is_shared(void* ptr) const { return _getAllocImpl().isSharedBuffer(ptr); }
596-
bool is_shared_storage_supported() const { return m_has_unified_memory; }
601+
bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); }
602+
bool isSharedStorageSupported() const override { return m_has_unified_memory; }
603+
void emptyCache() const override { _getAllocImpl().emptyCache(); }
604+
ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); }
605+
IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); }
606+
void setBufferShape(void* ptr, const IntArrayRef& shape) const override { _getAllocImpl().setBufferShape(ptr, shape); }
607+
size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); }
608+
ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); }
609+
size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); }
610+
size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); }
611+
void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); }
612+
void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); }
597613

598614
private:
599615
bool m_has_unified_memory;
@@ -618,41 +634,17 @@ static void Delete(void* ptr) {
618634
}
619635
} // anonymous namespace
620636

621-
at::Allocator* getMPSSharedAllocator()
622-
{
637+
IMPSAllocator* getIMPSAllocator(bool sharedAllocator) {
638+
if (!sharedAllocator) {
639+
return &_getPrivateAllocator();
640+
}
623641
auto& sa = _getSharedAllocator();
624-
if (sa.is_shared_storage_supported()) {
642+
if (sa.isSharedStorageSupported()) {
625643
return &sa;
626644
}
627-
628645
return nullptr;
629646
}
630647

631-
at::Allocator* getMPSPrivateAllocator() {
632-
return &_getPrivateAllocator();
633-
}
634-
635-
// TODO: create MPSHooks interface and move these there.
636-
ssize_t get_requested_buffer_size(void* ptr) {
637-
return _getAllocImpl().getRequestedBufferSize(ptr);
638-
}
639-
640-
void set_buffer_shape(void* ptr, const IntArrayRef& shape) {
641-
_getAllocImpl().setBufferShape(ptr, shape);
642-
}
643-
644-
IntArrayRef get_buffer_shape(void* ptr) {
645-
return _getAllocImpl().getBufferShape(ptr);
646-
}
647-
648-
DataPtr allocate_scalar_buffer(void *value, size_t size) {
649-
return _getPrivateAllocator().allocate_scalar_buffer(value, size);
650-
}
651-
652-
uint32_t get_adaptive_commit_threshold() {
653-
return _getAllocImpl().getLowWatermarkValue();
654-
}
655-
656648
} // namespace mps
657649

658650
namespace native {
@@ -664,14 +656,14 @@ uint32_t get_adaptive_commit_threshold() {
664656
bool is_pinned_mps(const Tensor& self, c10::optional<Device> device)
665657
{
666658
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
667-
return at::mps::_getSharedAllocator().is_shared(self.storage().data());
659+
return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data());
668660
}
669661

670662
// torch.pin_memory() implementation
671663
Tensor _pin_memory_mps(const Tensor& self, c10::optional<Device> device)
672664
{
673665
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps());
674-
auto* shared_allocator = at::mps::getMPSSharedAllocator();
666+
auto* shared_allocator = at::mps::getIMPSAllocator(true);
675667
TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device");
676668

677669
const size_t storage_size = detail::computeStorageNbytes(self.sizes(), self.strides(), self.dtype().itemsize());
+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright © 2023 Apple Inc.
2+
3+
#pragma once
4+
5+
#include <c10/core/Allocator.h>
6+
#include <c10/util/Registry.h>
7+
#include <ATen/core/ATen_fwd.h>
8+
9+
namespace at {
10+
namespace mps {
11+
12+
// this is a public interface to access MPSAllocator.
13+
// Do not declare methods that would depend on MPS or Metal frameworks.
14+
class IMPSAllocator : public c10::Allocator {
15+
public:
16+
// see the comments in MPSAllocator.h for the description of these methods.
17+
virtual void emptyCache() const = 0;
18+
virtual ssize_t getUnalignedBufferSize(void* ptr) const = 0;
19+
virtual IntArrayRef getBufferShape(void* ptr) const = 0;
20+
virtual void setBufferShape(void* ptr, const IntArrayRef& shape) const = 0;
21+
virtual bool isSharedBuffer(void* ptr) const = 0;
22+
virtual bool isSharedStorageSupported() const = 0;
23+
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
24+
virtual void setLowWatermarkRatio(double ratio) const = 0;
25+
virtual void setHighWatermarkRatio(double ratio) const = 0;
26+
virtual ssize_t getLowWatermarkValue() const = 0;
27+
virtual size_t getLowWatermarkLimit() const = 0;
28+
virtual size_t getHighWatermarkLimit() const = 0;
29+
virtual size_t getTotalAllocatedMemory() const = 0;
30+
};
31+
32+
class IMpsAllocatorCallback {
33+
public:
34+
enum class EventType {
35+
ALLOCATED, // buffer got allocated to be used immediately
36+
RECYCLED, // buffer pulled from free list to be reused
37+
FREED, // buffer put to free list for future recycling
38+
RELEASED, // buffer memory released
39+
};
40+
virtual ~IMpsAllocatorCallback() = default;
41+
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
42+
};
43+
44+
// MPS allocator will execute every registered callback when a block of memory is freed.
45+
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
46+
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
47+
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
48+
49+
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
50+
51+
} // namespace mps
52+
} // namespace at

aten/src/ATen/mps/MPSDevice.mm

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <c10/util/CallOnce.h>
44

55
#include <ATen/mps/MPSDevice.h>
6+
#include <ATen/mps/MPSAllocatorInterface.h>
67
#include <ATen/mps/IndexKernels.h>
78

89
namespace at {
@@ -94,10 +95,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
9495
return _macos13plus;
9596
}
9697

97-
at::Allocator* getMPSSharedAllocator();
98-
at::Allocator* getMPSPrivateAllocator();
9998
at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
100-
return useSharedAllocator ? getMPSSharedAllocator() : getMPSPrivateAllocator();
99+
return getIMPSAllocator(useSharedAllocator);
101100
}
102101

103102
bool is_available() {

aten/src/ATen/mps/MPSStream.mm

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
// Copyright © 2022 Apple Inc.
22

33
#include <ATen/mps/MPSStream.h>
4+
#include <ATen/mps/MPSAllocatorInterface.h>
45

56
namespace at {
67
namespace mps {
78

89
#define USE_COMMIT_AND_CONTINUE 1
910

10-
// the frequency that we commit the command buffer calculated based on low watermark ratio in MPSAllocator
11-
uint32_t get_adaptive_commit_threshold();
12-
1311
//-----------------------------------------------------------------
1412
// MPSStream
1513
//-----------------------------------------------------------------
@@ -52,7 +50,7 @@
5250
break;
5351
case SyncType::COMMIT_ADAPTIVE:
5452
// the adaptive commit only commits if we hit the low watermark memory threshold
55-
if (get_adaptive_commit_threshold() <= 1) {
53+
if (getIMPSAllocator()->getLowWatermarkValue() <= 1) {
5654
#if USE_COMMIT_AND_CONTINUE
5755
commitAndContinue();
5856
#else

0 commit comments

Comments
 (0)