Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,29 @@ void CommandEncoder::end_encoding() {
next_outputs_.clear();
concurrent_outputs_.clear();
all_inputs_.clear();

check_error();
}

bool CommandEncoder::needs_commit() const {
auto [max_ops, max_mb] = device_.get_max_ops_mb_per_buffer();
return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb);
}

void CommandEncoder::commit() {
void CommandEncoder::commit(std::function<void()> completion) {
buffer_->addCompletedHandler(
[this, completion = std::move(completion)](MTL::CommandBuffer* cbuf) {
if (completion) {
completion();
}
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::atomic_store(
&error_,
std::make_shared<std::string>(fmt::format(
"[METAL] Command buffer execution failed: {}.",
cbuf->error()->localizedDescription()->utf8String())));
}
});
buffer_->commit();
buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());
buffer_ops_ = 0;
Expand All @@ -456,22 +471,32 @@ void CommandEncoder::commit() {

void CommandEncoder::synchronize() {
auto pool = new_scoped_memory_pool();
auto cb = NS::RetainPtr(get_command_buffer());
auto cbuf = buffer_; // retained
end_encoding();
commit();
cb->waitUntilCompleted();
if (!exiting_) {
if (cb->status() == MTL::CommandBufferStatusError) {
throw std::runtime_error(
fmt::format(
"[METAL] Command buffer execution failed: {}.",
cb->error()->localizedDescription()->utf8String()));
}
cbuf->waitUntilCompleted();
check_error();
}

void CommandEncoder::check_error() {
// Do not check error during encoding, otherwise it would leave the program in
// corrupted state.
if (encoder_) {
return;
}
// When exiting with pending GPU commands, errors will happen, ignore them.
if (exiting_) {
return;
}
auto error = std::atomic_exchange(&error_, {});
if (error) {
throw std::runtime_error(*error);
}
}

MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() {
if (!encoder_) {
check_error();
encoder_ = NS::RetainPtr(
buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent));
fence_ = NS::TransferPtr(device_.mtl_device()->newFence());
Expand Down
9 changes: 5 additions & 4 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,10 @@ class MLX_API CommandEncoder {
void barrier();
void end_encoding();
bool needs_commit() const;
void commit();
void commit(std::function<void()> completion = nullptr);
void synchronize();
void check_error();

MTL::CommandQueue* get_command_queue() const {
return queue_.get();
}
MTL::CommandBuffer* get_command_buffer() const {
return buffer_.get();
}
Expand All @@ -113,6 +111,9 @@ class MLX_API CommandEncoder {
int buffer_ops_{0};
size_t buffer_sizes_{0};

// Error from previous commited command buffer.
std::shared_ptr<std::string> error_;

// Encoder for issuing GPU commands.
// The members are used within a single ComputeCommandEncoder and will be
// reset after calling end_encoding().
Expand Down
23 changes: 4 additions & 19 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@ void new_stream(Stream s) {
encoders.try_emplace(s.index, d, s.index, d.residency_set());
}

inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}

void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
Expand Down Expand Up @@ -60,17 +51,12 @@ void eval(array& arr) {
if (encoder.needs_commit()) {
encoder.end_encoding();
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
encoder.commit();
encoder.commit([s, buffers = std::move(buffers)]() {
scheduler::notify_task_completion(s);
});
} else {
command_buffer->addCompletedHandler(
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {});
}
}

Expand All @@ -79,7 +65,6 @@ void finalize(Stream s) {
auto& encoder = metal::get_command_encoder(s);
auto* cb = encoder.get_command_buffer();
encoder.end_encoding();
cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
encoder.commit();
}

Expand Down
10 changes: 7 additions & 3 deletions mlx/backend/metal/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ Event::Event(Stream stream) : stream_(stream) {
}

void Event::wait() {
if (!static_cast<MTL::SharedEvent*>(event_.get())
->waitUntilSignaledValue(value(), -1)) {
throw std::runtime_error("[Event::wait] Timed out");
auto* event = static_cast<MTL::SharedEvent*>(event_.get());
// When error happened in command buffer, the event would wait indefinitely
// if we don't set a timeout.
while (!event->waitUntilSignaledValue(value(), 5 * 1000)) {
for (auto& [_, encoder] : metal::get_command_encoders()) {
encoder.check_error();
}
}
}

Expand Down
9 changes: 6 additions & 3 deletions mlx/backend/metal/fence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ void Fence::wait(Stream stream, const array& x) {
scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {
auto& f = *static_cast<FenceImpl*>(fence_.get());
if (!f.use_fast) {
if (!static_cast<MTL::SharedEvent*>(f.fence)->waitUntilSignaledValue(
count, -1)) {
throw std::runtime_error("[Fence::wait] Timed out");
// Same with Event::wait
auto* event = static_cast<MTL::SharedEvent*>(f.fence);
while (!event->waitUntilSignaledValue(count, 5 * 1000)) {
for (auto& [_, encoder] : metal::get_command_encoders()) {
encoder.check_error();
}
}
return;
}
Expand Down
Loading