diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index f35588144..df5301244 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -19,7 +19,7 @@ void RMSNorm::eval_gpu(
 
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) {
+  auto check_input = [&copies, &s](const array& x) -> const array& {
     bool no_copy = x.strides()[x.ndim() - 1] == 1;
     if (x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -28,10 +28,9 @@ void RMSNorm::eval_gpu(
     if (no_copy) {
       return x;
     } else {
-      array x_copy(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
-      copies.push_back(x_copy);
-      return x_copy;
+      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+      copy_gpu(x, copies.back(), CopyType::General, s);
+      return copies.back();
     }
   };
   const array& x = check_input(inputs[0]);
@@ -106,15 +105,13 @@ void RMSNormVJP::eval_gpu(
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
   std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) {
+  auto check_input = [&copies, &s](const array& x) -> const array& {
     if (x.flags().row_contiguous) {
       return x;
     }
-
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
-    copies.push_back(x_copy);
-    return x_copy;
+    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+    copy_gpu(x, copies.back(), CopyType::General, s);
+    return copies.back();
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
@@ -149,8 +146,11 @@ void RMSNormVJP::eval_gpu(
     gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
   }
   copies.push_back(gw_temp);
-  array zero(0, gw.dtype());
-  copy_gpu(zero, gw, CopyType::Scalar, s);
+  {
+    array zero(0, gw.dtype());
+    copy_gpu(zero, gw, CopyType::Scalar, s);
+    copies.push_back(std::move(zero));
+  }
 
   const int simd_size = 32;
   const int n_reads = RMS_N_READS;
@@ -212,7 +212,7 @@ void LayerNorm::eval_gpu(
 
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) {
+  auto check_input = [&copies, &s](const array& x) -> const array& {
     bool no_copy = x.strides()[x.ndim() - 1] == 1;
     if (x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -221,10 +221,9 @@ void LayerNorm::eval_gpu(
     if (no_copy) {
      return x;
    } else {
-      array x_copy(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
-      copies.push_back(x_copy);
-      return x_copy;
+      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+      copy_gpu(x, copies.back(), CopyType::General, s);
+      return copies.back();
     }
   };
   const array& x = check_input(inputs[0]);
@@ -300,15 +299,13 @@ void LayerNormVJP::eval_gpu(
   // is contiguous (no broadcasts or holes) and that the input strides are the
   // same as the cotangent strides but for now this is simpler.
   std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) {
+  auto check_input = [&copies, &s](const array& x) -> const array& {
     if (x.flags().row_contiguous) {
       return x;
     }
-
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
-    copies.push_back(x_copy);
-    return x_copy;
+    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+    copy_gpu(x, copies.back(), CopyType::General, s);
+    return copies.back();
   };
   const array& x = check_input(inputs[0]);
   const array& w = inputs[1];
@@ -345,9 +342,12 @@ void LayerNormVJP::eval_gpu(
     gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
   }
   copies.push_back(gw_temp);
-  array zero(0, gw.dtype());
-  copy_gpu(zero, gw, CopyType::Scalar, s);
-  copy_gpu(zero, gb, CopyType::Scalar, s);
+  {
+    array zero(0, gw.dtype());
+    copy_gpu(zero, gw, CopyType::Scalar, s);
+    copy_gpu(zero, gb, CopyType::Scalar, s);
+    copies.push_back(std::move(zero));
+  }
 
   // Finish with the gradient for b in case we had a b
   auto compute_encoder = d.get_command_encoder(s.index);
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index a934a9d2d..12d89a665 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -21,7 +21,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
-  auto check_input = [&copies, &s](const array& x) {
+  auto check_input = [&copies, &s](const array& x) -> const array& {
     bool no_copy = x.strides()[x.ndim() - 1] == 1;
     if (x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
@@ -30,10 +30,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (no_copy) {
       return x;
     } else {
-      array x_copy(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
-      copies.push_back(x_copy);
-      return x_copy;
+      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+      copy_gpu(x, copies.back(), CopyType::General, s);
+      return copies.back();
     }
   };
   const array& in = check_input(inputs[0]);
@@ -81,7 +80,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setComputePipelineState(kernel);
     set_array_buffer(
         compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
-    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
     compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
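Note (not part of the diff): the recurring change above gives each `check_input` lambda an explicit `-> const array&` return type and stores the staging copy in `copies` before returning a reference to it, and it parks temporaries such as `zero` in `copies` as well so they outlive the enqueued GPU work. The standalone sketch below illustrates that ownership pattern with a stand-in `Buf` type and a hypothetical `enqueue_async_copy`; it does not use MLX's `array`/`copy_gpu` APIs.

```cpp
// Sketch of the "push into copies, then return a reference" pattern.
// `Buf` and `enqueue_async_copy` are stand-ins, not MLX APIs.
//
// Two points it illustrates:
//   1. Without the trailing `-> const Buf&`, the lambda's return type is
//      deduced by value, so `return x;` would copy. With an explicit
//      reference return, the staging copy must live somewhere that outlasts
//      the call, hence it is pushed into `copies` before the reference is
//      handed out.
//   2. Keeping temporaries in `copies` also keeps their buffers alive for
//      work that runs asynchronously after the function returns.
#include <cassert>
#include <vector>

struct Buf {
  int value;
  bool contiguous;
};

// Stand-in for an asynchronous copy that materializes a contiguous copy.
void enqueue_async_copy(const Buf& src, Buf& dst) {
  dst.value = src.value; // dst keeps its own (contiguous) layout
}

int main() {
  std::vector<Buf> copies; // owns staging copies until all work is done

  auto check_input = [&copies](const Buf& x) -> const Buf& {
    if (x.contiguous) {
      return x; // no staging copy needed
    }
    // Push first, then return a reference to the vector-owned element.
    copies.push_back(Buf{0, true});
    enqueue_async_copy(x, copies.back());
    return copies.back();
  };

  Buf strided{42, false};
  const Buf& in = check_input(strided);
  assert(&in == &copies.back()); // refers to the copy owned by `copies`
  assert(in.contiguous && in.value == 42);
  return 0;
}
```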