From 16546c70d8983e9860f7456c0dfc63176fc95568 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 18 Mar 2024 17:03:07 -0700
Subject: [PATCH] No reshape rope (#838)

* no reshape rope

* no reshape rope
---
 mlx/backend/metal/rope.cpp                  | 52 +++++++++++++++------
 mlx/fast.cpp                                | 14 +++---
 python/mlx/nn/layers/positional_encoding.py |  5 +-
 3 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index fdea57985..ba04f7241 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -1,5 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
-
+#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 
@@ -13,39 +13,63 @@ void RoPE::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];
 
-  if (in.ndim() != 3) {
-    throw std::runtime_error(
-        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
+  if (in.ndim() < 3) {
+    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
   }
   if (dims_ != in.shape(-1)) {
     throw std::runtime_error("[RoPE] Partial RoPE application not supported");
   }
-  if (in.flags().row_contiguous && in.is_donatable()) {
-    out.move_shared_buffer(in);
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  }
 
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
+
+  size_t strides[3];
+  bool donated = false;
+  int ndim = in.ndim();
+  size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
+  if (in.flags().row_contiguous) {
+    if (in.is_donatable()) {
+      donated = true;
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
+    strides[0] = mat_size;
+    strides[1] = in.strides()[ndim - 2];
+    strides[2] = in.strides()[ndim - 1];
+  } else if (ndim == 3) {
+    // Handle non-contiguous 3D inputs
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    strides[0] = in.strides()[0];
+    strides[1] = in.strides()[1];
+    strides[2] = in.strides()[2];
+  } else {
+    // Copy non-contiguous > 3D inputs into the output and treat
+    // input as donated
+    donated = true;
+    copy_gpu(in, out, CopyType::General, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  }
+
   std::ostringstream kname;
   kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
 
-  bool donated = in.data_shared_ptr() == nullptr;
   float base = std::log2(base_);
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, donated ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
+  compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
   compute_encoder->setBytes(&offset_, sizeof(int), 3);
   compute_encoder->setBytes(&base, sizeof(float), 4);
   compute_encoder->setBytes(&scale_, sizeof(float), 5);
 
-  int dim0 = in.shape(2) / 2;
-  int dim1 = in.shape(1);
-  int dim2 = in.shape(0);
+  int dim0 = in.shape()[ndim - 1] / 2;
+  int dim1 = in.shape()[ndim - 2];
+  int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
   compute_encoder->dispatchThreads(grid_dims, group_dims);
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 76c8bdf1a..cfdae139a 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -54,10 +54,10 @@ array rope(
     float scale,
     int offset,
     StreamOrDevice s /* = {} */) {
-  if (x.ndim() != 3) {
+  if (x.ndim() < 3) {
     std::ostringstream msg;
-    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
-        << " dimensions.";
+    msg << "[rope] Input must have at least 3 dimensions but got input with "
+        << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
   if (traditional && x.shape(-1) != dims) {
@@ -67,7 +67,9 @@ array rope(
 
   auto fallback = [dims, traditional, base, scale, offset, s](
                       const std::vector<array>& inputs) {
-    auto& x = inputs[0];
+    auto& shape = inputs[0].shape();
+    int ndim = shape.size();
+    auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
     auto t = x.dtype();
     auto N = x.shape(1) + offset;
     // Compute sines and cosines
@@ -89,7 +91,7 @@ array rope(
       for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
-      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
+      return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
     } else {
       auto out_s = x.shape();
       out_s.back() = half_dims;
@@ -103,7 +105,7 @@ array rope(
       if (dims < x.shape(-1)) {
         outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
       }
-      return std::vector<array>{concatenate(outs, 2, s)};
+      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
     }
   };
   auto stream = to_stream(s);
diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py
index f0bb92863..470366e19 100644
--- a/python/mlx/nn/layers/positional_encoding.py
+++ b/python/mlx/nn/layers/positional_encoding.py
@@ -44,9 +44,7 @@ class RoPE(Module):
         return f"{self.dims}, traditional={self.traditional}"
 
     def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        x = mx.fast.rope(
+        return mx.fast.rope(
             x,
             self.dims,
             traditional=self.traditional,
@@ -54,7 +52,6 @@ class RoPE(Module):
             scale=self.scale,
             offset=offset,
         )
-        return mx.reshape(x, shape)
 
 
 class SinusoidalPositionalEncoding(Module):
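Context beyond the patch itself: the change drops the reshape round-trip around mx.fast.rope, so inputs with more than three dimensions are passed to the kernel as-is. The following is a minimal, illustrative usage sketch (not part of the patch), assuming the mlx Python package and a hypothetical 4D activation shaped (batch, heads, sequence, head_dim).

import mlx.core as mx
import mlx.nn as nn

# Hypothetical 4D input: (batch, num_heads, sequence_length, head_dim)
x = mx.random.normal((2, 8, 16, 64))

# After this patch the layer forwards x to mx.fast.rope directly,
# with no reshape to 3D and back.
rope = nn.RoPE(dims=64, traditional=False, base=10000, scale=1.0)
y = rope(x)
print(y.shape)  # (2, 8, 16, 64)

# Equivalent direct call to the fused op.
z = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)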