diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 352958a89..5f2bb73d3 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -471,6 +471,10 @@ void Device::end_encoding(int index) { CommandEncoder& Device::get_command_encoder(int index) { auto& stream = get_stream_(index); if (stream.encoder == nullptr) { + // Ensure there is an active command buffer + if (stream.buffer == nullptr) { + get_command_buffer(index); + } stream.encoder = std::make_unique(stream); stream.fence = std::make_shared(device_->newFence()); }