diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 2974d010b..71662c8ae 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -914,10 +914,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Clear copies
-  if (copies.size() > 0) {
+  if (!copies.empty()) {
     auto command_buffer = d.get_command_buffer(s.index);
     command_buffer->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
   }
 }
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index 9fac2e89f..d7e9b143f 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -81,7 +81,9 @@ void CustomKernel::eval_gpu(
 
   if (!copies.empty()) {
     d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
   }
 }
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 202c83544..791e1fa00 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -576,7 +576,9 @@ void fft_op(
   if (plan.four_step) {
     four_step_fft(in, out, axis, inverse, real, plan, copies, s);
     d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
     return;
   }
 
@@ -741,8 +743,13 @@ void fft_op(
         MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
   }
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void fft_op(
diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp
index 46d77b03e..b4ae377d5 100644
--- a/mlx/backend/metal/hadamard.cpp
+++ b/mlx/backend/metal/hadamard.cpp
@@ -196,8 +196,12 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
         s);
   }
 
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 } // namespace mlx::core
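Every hunk in this patch applies the same two-part fix: the completed handler is only registered when there are temporaries to release, and the `copies` vector is moved into the lambda rather than copied. Because the handler (and thus the captured vector) lives until the command buffer completes, the temporary arrays stay alive exactly as long as the GPU may still read them. A minimal standalone sketch of the pattern, using a hypothetical `on_complete` callback in place of `MTL::CommandBuffer::addCompletedHandler`:

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for mlx::core::array: a cheap handle that shares its storage.
using Buffer = std::shared_ptr<std::vector<float>>;

int main() {
  std::vector<Buffer> copies;
  copies.push_back(std::make_shared<std::vector<float>>(1024, 0.0f));

  // Hypothetical completion callback; MTL::CommandBuffer::addCompletedHandler
  // plays this role in the patch.
  std::function<void()> on_complete;

  if (!copies.empty()) {
    // Move the vector into the handler: no per-element copies, and the
    // buffers stay alive until the handler runs, i.e. until the enqueued
    // work that reads them has finished.
    on_complete = [copies = std::move(copies)]() mutable { copies.clear(); };
  }

  // ... the driver would invoke the handler once the command buffer completes.
  if (on_complete) {
    on_complete();
  }
  std::cout << "temporaries released after completion\n";
  return 0;
}
```

The `mutable` qualifier is required in the sketch and in the real handlers alike, because `clear()` mutates the captured vector.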
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 2cfeb2b26..dbcf1b4df 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -227,9 +227,12 @@ void steel_matmul_conv_groups(
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
-  return;
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void steel_matmul(
@@ -379,8 +382,12 @@ void steel_matmul(
       compute_encoder.dispatchThreads(grid_dims, group_dims);
     }
 
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+            copies.clear();
+          });
+    }
     return;
   }
 
@@ -507,9 +514,12 @@ void steel_matmul(
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
-  return;
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -680,8 +690,12 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+            copies.clear();
+          });
+    }
     return;
   }
 
   /////////////////////////////////////////////////////////////////////////////
@@ -886,8 +900,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+            copies.clear();
+          });
+    }
     return;
   }
 
@@ -1000,8 +1018,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder.dispatchThreads(grid_dims, group_dims);
     }
 
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+            copies.clear();
+          });
+    }
     return;
   }
 
@@ -1136,9 +1158,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
-  return;
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1433,8 +1458,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+            copies.clear();
+          });
+    }
     return;
   }
 
@@ -1545,9 +1574,12 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
-  return;
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1773,8 +1805,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+            copies.clear();
+          });
+    }
     return;
   }
 
@@ -1914,9 +1950,12 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
-  return;
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 } // namespace mlx::core
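The matmul hunks above also show why the move capture matters beyond lifetime: an `mlx::core::array` is a cheap handle, but copy-capturing the whole `copies` vector still duplicates every element (one refcount bump each) only to throw the duplicates away later. A small instrumented sketch, with an illustrative `Tracked` type standing in for `array` (not MLX API):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Illustrative element type standing in for mlx::core::array.
struct Tracked {
  static inline int copy_count = 0; // requires C++17
  Tracked() = default;
  Tracked(const Tracked&) { ++copy_count; }
  Tracked(Tracked&&) noexcept = default;
};

int main() {
  std::vector<Tracked> copies(3);

  Tracked::copy_count = 0;
  auto by_copy = [copies]() {}; // copy capture: clones all 3 elements
  (void)by_copy;
  std::printf("copy capture: %d element copies\n", Tracked::copy_count);

  Tracked::copy_count = 0;
  auto by_move = [copies = std::move(copies)]() {}; // move: steals the buffer
  (void)by_move;
  std::printf("move capture: %d element copies\n", Tracked::copy_count);
  return 0;
}
```

Expected output is 3 element copies for the copy capture and 0 for the move capture, since moving a vector only transfers ownership of its buffer.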
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index 8338c0dbe..b9053fa4f 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -91,8 +91,12 @@ void RMSNorm::eval_gpu(
     compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void RMSNormVJP::eval_gpu(
@@ -109,6 +113,12 @@ void RMSNormVJP::eval_gpu(
     if (x.flags().row_contiguous) {
       return x;
     }
+    // Make sure we'll only ever allocate once. The point of this goes beyond
+    // the minor optimization: we need to ensure that no reallocation happens,
+    // so the references stay valid when we push_back(...). In short, there
+    // are at most 3 possible copies: x, g and gw_temp.
+    copies.reserve(3);
+
     copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
     copy_gpu(x, copies.back(), CopyType::General, s);
     return copies.back();
@@ -195,7 +205,9 @@ void RMSNormVJP::eval_gpu(
       gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
 
   d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+      [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+        copies.clear();
+      });
 }
 
 void LayerNorm::eval_gpu(
@@ -280,8 +292,12 @@ void LayerNorm::eval_gpu(
     compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void LayerNormVJP::eval_gpu(
@@ -298,6 +314,12 @@ void LayerNormVJP::eval_gpu(
     if (x.flags().row_contiguous) {
       return x;
     }
+    // Make sure we'll only ever allocate once. The point of this goes beyond
+    // the minor optimization: we need to ensure that no reallocation happens,
+    // so the references stay valid when we push_back(...). In short, there
+    // are at most 3 possible copies: x, g and gw_temp.
+    copies.reserve(3);
+
     copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
     copy_gpu(x, copies.back(), CopyType::General, s);
     return copies.back();
@@ -404,7 +426,9 @@ void LayerNormVJP::eval_gpu(
   }
 
   d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+      [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+        copies.clear();
+      });
 }
 
 } // namespace mlx::core::fast
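The `copies.reserve(3)` added to both VJP kernels is a correctness fix, not an optimization: `check_input` hands out references into `copies`, and a later `push_back` that forces a reallocation would leave them dangling. Reserving the maximum number of elements (x, g and gw_temp) up front pins the storage. A minimal sketch of the failure mode and the fix, with illustrative names:

```cpp
#include <cassert>
#include <vector>

int main() {
  {
    // Without reserve: the reference taken after the first push_back can
    // dangle once a later push_back forces the vector to reallocate.
    std::vector<int> copies;
    copies.push_back(1);
    int* first = &copies.back(); // may be invalidated below
    copies.push_back(2);         // possible reallocation
    copies.push_back(3);
    // Reading *first here would be undefined behavior if storage moved.
    (void)first;
  }
  {
    // With reserve: capacity is fixed up front, so no push_back within
    // that budget reallocates, and references stay valid.
    std::vector<int> copies;
    copies.reserve(3); // at most 3 copies, as in the patch
    copies.push_back(1);
    int* first = &copies.back();
    copies.push_back(2);
    copies.push_back(3);
    assert(*first == 1); // guaranteed: the storage never moved
  }
  return 0;
}
```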
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index bd5c19fe3..de70c7562 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -209,8 +209,12 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   }
 
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -599,8 +603,12 @@ void fast::AffineQuantize::eval_gpu(
       : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index c62b84206..5881effa4 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -662,7 +662,9 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   if (!copies.empty()) {
     d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
   }
 }
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index 00596c687..a113353a7 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -107,10 +107,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
-  if (copies.size() > 0) {
+  if (!copies.empty()) {
     auto command_buffer = d.get_command_buffer(s.index);
     command_buffer->addCompletedHandler(
-        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
   }
 }
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 706343ff7..7af5d6f53 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -88,8 +88,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp
index eecf1cef5..0de69f9c5 100644
--- a/mlx/backend/metal/sort.cpp
+++ b/mlx/backend/metal/sort.cpp
@@ -254,7 +254,9 @@ void multi_block_sort(
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+      [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+        copies.clear();
+      });
 }
 
 void gpu_merge_sort(