From 6ee1112f30b259090f8163930d0ff5d7fe040fee Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 22 Mar 2024 17:28:26 -0700 Subject: [PATCH] Fix copy donation and add partial rope (#881) --- benchmarks/python/rope_bench.py | 8 ++++---- mlx/backend/metal/copy.cpp | 7 +++++++ mlx/backend/metal/kernels/rope.metal | 8 +++++--- mlx/backend/metal/rope.cpp | 30 ++++++++++++++++++---------- mlx/fast.cpp | 2 +- mlx/primitives.cpp | 7 +++++-- 6 files changed, 42 insertions(+), 20 deletions(-) diff --git a/benchmarks/python/rope_bench.py b/benchmarks/python/rope_bench.py index 62f01648e..35479c0b1 100644 --- a/benchmarks/python/rope_bench.py +++ b/benchmarks/python/rope_bench.py @@ -6,21 +6,21 @@ from time_utils import time_fn def time_rope(): - rope = nn.RoPE(4096) + rope = nn.RoPE(64) # vec - x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16) mx.eval(x) def rope_vec(x): for _ in range(32): - x = rope(x) + x = rope(x, offset=100) return x time_fn(rope_vec, x) # matrix - x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16) mx.eval(x) def rope_mat(x): diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 0885f5691..cb5fe289a 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -12,8 +12,15 @@ namespace mlx::core { void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { if (ctype == CopyType::Vector) { + // If the input is donateable, we are doing a vector copy and the types + // have the same size, then the input buffer can hold the output. if (in.is_donatable() && in.itemsize() == out.itemsize()) { out.move_shared_buffer(in); + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + if (in.dtype() == out.dtype()) { + return; + } } else { out.set_data( allocator::malloc_or_wait(in.data_size() * out.itemsize()), diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal index 484697b6d..52290e42c 100644 --- a/mlx/backend/metal/kernels/rope.metal +++ b/mlx/backend/metal/kernels/rope.metal @@ -10,6 +10,7 @@ template const device T *in [[buffer(0)]], device T * out [[buffer(1)]], constant const size_t strides[3], + constant const size_t out_strides[3], constant const int& offset, constant const float& base, constant const float& scale, @@ -19,13 +20,13 @@ template uint in_index_1, in_index_2; uint out_index_1, out_index_2; if (traditional) { - out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z)); + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0]; out_index_2 = out_index_1 + 1; in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; in_index_2 = in_index_1 + strides[2]; } else { - out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z)); - out_index_2 = out_index_1 + grid.x; + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0]; + out_index_2 = out_index_1 + grid.x * out_strides[2]; in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; in_index_2 = in_index_1 + grid.x * strides[2]; } @@ -54,6 +55,7 @@ template const device type* in [[buffer(0)]], \ device type* out [[buffer(1)]], \ constant const size_t strides[3], \ + constant const size_t out_strides[3], \ constant const int& offset, \ constant const float& base, \ constant const float& scale, \ diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index ba04f7241..54d056643 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -16,18 +16,24 @@ void RoPE::eval_gpu( if (in.ndim() < 3) { throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); } - if (dims_ != in.shape(-1)) { - throw std::runtime_error("[RoPE] Partial RoPE application not supported"); - } auto& s = out.primitive().stream(); auto& d = metal::device(s.device); size_t strides[3]; + size_t out_strides[3]; bool donated = false; int ndim = in.ndim(); - size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1]; - if (in.flags().row_contiguous) { + size_t mat_size = in.shape(-2) * in.shape(-1); + if (dims_ < in.shape(-1)) { + donated = true; + auto ctype = + (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } else if (in.flags().row_contiguous) { if (in.is_donatable()) { donated = true; out.move_shared_buffer(in); @@ -52,6 +58,9 @@ void RoPE::eval_gpu( strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; std::ostringstream kname; kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in); @@ -63,12 +72,13 @@ void RoPE::eval_gpu( set_array_buffer(compute_encoder, donated ? out : in, 0); set_array_buffer(compute_encoder, out, 1); compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2); - compute_encoder->setBytes(&offset_, sizeof(int), 3); - compute_encoder->setBytes(&base, sizeof(float), 4); - compute_encoder->setBytes(&scale_, sizeof(float), 5); + compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3); + compute_encoder->setBytes(&offset_, sizeof(int), 4); + compute_encoder->setBytes(&base, sizeof(float), 5); + compute_encoder->setBytes(&scale_, sizeof(float), 6); - int dim0 = in.shape()[ndim - 1] / 2; - int dim1 = in.shape()[ndim - 2]; + int dim0 = dims_ / 2; + int dim1 = in.shape(-2); int dim2 = in.size() / mat_size; auto group_dims = get_block_dims(dim0, dim1, dim2); auto grid_dims = MTL::Size(dim0, dim1, dim2); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 076213875..568e2604f 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -244,7 +244,7 @@ array rope( } }; auto stream = to_stream(s); - if (stream.device == Device::gpu && x.shape(-1) == dims) { + if (stream.device == Device::gpu) { return array( x.shape(), x.dtype(), diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index f7cb29f89..4fb0fc7f4 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2590,8 +2590,11 @@ std::vector Scatter::vjp( break; case Scatter::Max: case Scatter::Min: { - auto mask = where(result == values, array({1}), array({0})); - vjps.push_back(multiply(cotangents[0], mask)); + vjps.push_back(where( + equal(result, values, stream()), + cotangents[0], + array(0, cotangents[0].dtype()), + stream())); break; } default: