No reshape rope (#838)

* no reshape rope

* no reshape rope
This commit is contained in:
Awni Hannun 2024-03-18 17:03:07 -07:00 committed by GitHub
parent eaba55c9bf
commit 16546c70d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 24 deletions

View File

@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
@ -13,39 +13,63 @@ void RoPE::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
auto& out = outputs[0]; auto& out = outputs[0];
if (in.ndim() != 3) { if (in.ndim() < 3) {
throw std::runtime_error( throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
} }
if (dims_ != in.shape(-1)) { if (dims_ != in.shape(-1)) {
throw std::runtime_error("[RoPE] Partial RoPE application not supported"); throw std::runtime_error("[RoPE] Partial RoPE application not supported");
} }
if (in.flags().row_contiguous && in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
size_t strides[3];
bool donated = false;
int ndim = in.ndim();
size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc_or_wait(out.nbytes()));
strides[0] = in.strides()[0];
strides[1] = in.strides()[1];
strides[2] = in.strides()[2];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
donated = true;
copy_gpu(in, out, CopyType::General, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
std::ostringstream kname; std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in); kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
bool donated = in.data_shared_ptr() == nullptr;
float base = std::log2(base_); float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, donated ? out : in, 0); set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1); set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2); compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&offset_, sizeof(int), 3); compute_encoder->setBytes(&offset_, sizeof(int), 3);
compute_encoder->setBytes(&base, sizeof(float), 4); compute_encoder->setBytes(&base, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 5); compute_encoder->setBytes(&scale_, sizeof(float), 5);
int dim0 = in.shape(2) / 2; int dim0 = in.shape()[ndim - 1] / 2;
int dim1 = in.shape(1); int dim1 = in.shape()[ndim - 2];
int dim2 = in.shape(0); int dim2 = in.size() / mat_size;
auto group_dims = get_block_dims(dim0, dim1, dim2); auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2); auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);

View File

@ -54,10 +54,10 @@ array rope(
float scale, float scale,
int offset, int offset,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (x.ndim() != 3) { if (x.ndim() < 3) {
std::ostringstream msg; std::ostringstream msg;
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim() msg << "[rope] Input must have at least 3 dimensions but got input with "
<< " dimensions."; << x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (traditional && x.shape(-1) != dims) { if (traditional && x.shape(-1) != dims) {
@ -67,7 +67,9 @@ array rope(
auto fallback = [dims, traditional, base, scale, offset, s]( auto fallback = [dims, traditional, base, scale, offset, s](
const std::vector<array>& inputs) { const std::vector<array>& inputs) {
auto& x = inputs[0]; auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
auto t = x.dtype(); auto t = x.dtype();
auto N = x.shape(1) + offset; auto N = x.shape(1) + offset;
// Compute sines and cosines // Compute sines and cosines
@ -89,7 +91,7 @@ array rope(
for (auto& o : outs) { for (auto& o : outs) {
o = expand_dims(o, 3, s); o = expand_dims(o, 3, s);
} }
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)}; return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
} else { } else {
auto out_s = x.shape(); auto out_s = x.shape();
out_s.back() = half_dims; out_s.back() = half_dims;
@ -103,7 +105,7 @@ array rope(
if (dims < x.shape(-1)) { if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s)); outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
} }
return std::vector<array>{concatenate(outs, 2, s)}; return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
} }
}; };
auto stream = to_stream(s); auto stream = to_stream(s);

View File

@ -44,9 +44,7 @@ class RoPE(Module):
return f"{self.dims}, traditional={self.traditional}" return f"{self.dims}, traditional={self.traditional}"
def __call__(self, x, offset: int = 0): def __call__(self, x, offset: int = 0):
shape = x.shape return mx.fast.rope(
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
x = mx.fast.rope(
x, x,
self.dims, self.dims,
traditional=self.traditional, traditional=self.traditional,
@ -54,7 +52,6 @@ class RoPE(Module):
scale=self.scale, scale=self.scale,
offset=offset, offset=offset,
) )
return mx.reshape(x, shape)
class SinusoidalPositionalEncoding(Module): class SinusoidalPositionalEncoding(Module):