No reshape rope (#838)

* no reshape rope
Awni Hannun 2024-03-18 17:03:07 -07:00 committed by GitHub
parent eaba55c9bf
commit 16546c70d8
3 changed files with 47 additions and 24 deletions

mlx/backend/metal/rope.cpp

@@ -1,5 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
+#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
@@ -13,39 +13,63 @@ void RoPE::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];
-  if (in.ndim() != 3) {
-    throw std::runtime_error(
-        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
+  if (in.ndim() < 3) {
+    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
   }
   if (dims_ != in.shape(-1)) {
     throw std::runtime_error("[RoPE] Partial RoPE application not supported");
   }
-  if (in.flags().row_contiguous && in.is_donatable()) {
-    out.move_shared_buffer(in);
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  }
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
+  size_t strides[3];
+  bool donated = false;
+  int ndim = in.ndim();
+  size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
+  if (in.flags().row_contiguous) {
+    if (in.is_donatable()) {
+      donated = true;
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
+    strides[0] = mat_size;
+    strides[1] = in.strides()[ndim - 2];
+    strides[2] = in.strides()[ndim - 1];
+  } else if (ndim == 3) {
+    // Handle non-contiguous 3D inputs
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    strides[0] = in.strides()[0];
+    strides[1] = in.strides()[1];
+    strides[2] = in.strides()[2];
+  } else {
+    // Copy non-contiguous > 3D inputs into the output and treat
+    // input as donated
+    donated = true;
+    copy_gpu(in, out, CopyType::General, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  }
   std::ostringstream kname;
   kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
-  bool donated = in.data_shared_ptr() == nullptr;
   float base = std::log2(base_);
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, donated ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
+  compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
   compute_encoder->setBytes(&offset_, sizeof(int), 3);
   compute_encoder->setBytes(&base, sizeof(float), 4);
   compute_encoder->setBytes(&scale_, sizeof(float), 5);
-  int dim0 = in.shape(2) / 2;
-  int dim1 = in.shape(1);
-  int dim2 = in.shape(0);
+  int dim0 = in.shape()[ndim - 1] / 2;
+  int dim1 = in.shape()[ndim - 2];
+  int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
   compute_encoder->dispatchThreads(grid_dims, group_dims);
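The launch geometry above folds every leading dimension into a single batch axis, so an N-D input maps onto the same 3-D kernel. A minimal sketch of that shape math (plain Python, hypothetical helper name, not code from this commit):

```python
def rope_grid_dims(shape, size):
    # Elements in one (sequence x dims) matrix.
    mat_size = shape[-2] * shape[-1]
    dim0 = shape[-1] // 2    # one thread per rotated pair
    dim1 = shape[-2]         # sequence length
    dim2 = size // mat_size  # all leading (batch-like) dims collapsed
    return dim0, dim1, dim2

# e.g. a (2, 8, 128, 64) input dispatches a (32, 128, 16) grid
assert rope_grid_dims((2, 8, 128, 64), 2 * 8 * 128 * 64) == (32, 128, 16)
```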

mlx/fast.cpp

@@ -54,10 +54,10 @@ array rope(
     float scale,
     int offset,
     StreamOrDevice s /* = {} */) {
-  if (x.ndim() != 3) {
+  if (x.ndim() < 3) {
     std::ostringstream msg;
-    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
-        << " dimensions.";
+    msg << "[rope] Input must have at least 3 dimensions but got input with "
+        << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
   if (traditional && x.shape(-1) != dims) {
@@ -67,7 +67,9 @@ array rope(
   auto fallback = [dims, traditional, base, scale, offset, s](
                       const std::vector<array>& inputs) {
-    auto& x = inputs[0];
+    auto& shape = inputs[0].shape();
+    int ndim = shape.size();
+    auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
     auto t = x.dtype();
     auto N = x.shape(1) + offset;
     // Compute sines and cosines
@@ -89,7 +91,7 @@ array rope(
       for (auto& o : outs) {
         o = expand_dims(o, 3, s);
       }
-      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
+      return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
     } else {
       auto out_s = x.shape();
       out_s.back() = half_dims;
@@ -103,7 +105,7 @@ array rope(
       if (dims < x.shape(-1)) {
         outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
       }
-      return std::vector<array>{concatenate(outs, 2, s)};
+      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
     }
   };
   auto stream = to_stream(s);
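The fallback lambda now collapses the leading dimensions to one batch axis itself and restores the caller's shape on the way out. A rough Python reference of that pattern (not the library's code; assumes dims == x.shape[-1] and the default non-traditional convention):

```python
import math
import mlx.core as mx

def rope_reference(x, dims, *, base=10000.0, scale=1.0, offset=0):
    shape = x.shape
    # Collapse all leading dims to (batch, sequence, dims).
    x = mx.reshape(x, (-1, shape[-2], shape[-1]))
    half = dims // 2
    positions = scale * mx.arange(offset, x.shape[1] + offset)
    freqs = mx.exp(-mx.arange(0.0, half) * (math.log(base) / half))
    theta = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)
    coss, sins = mx.cos(theta), mx.sin(theta)
    x1, x2 = x[..., :half], x[..., half:dims]
    out = mx.concatenate([x1 * coss - x2 * sins, x1 * sins + x2 * coss], axis=-1)
    # Restore the original N-D shape.
    return mx.reshape(out, shape)
```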

python/mlx/nn/layers/positional_encoding.py

@@ -44,9 +44,7 @@ class RoPE(Module):
         return f"{self.dims}, traditional={self.traditional}"

     def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        x = mx.fast.rope(
+        return mx.fast.rope(
             x,
             self.dims,
             traditional=self.traditional,
@@ -54,7 +52,6 @@ class RoPE(Module):
             scale=self.scale,
             offset=offset,
         )
-        return mx.reshape(x, shape)

 class SinusoidalPositionalEncoding(Module):
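With the layer no longer reshaping, inputs with extra leading dimensions pass straight through. A small usage sketch (illustrative shapes, not part of this commit):

```python
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)
x = mx.random.normal((2, 8, 128, 64))  # batch x heads x sequence x head_dim
y = rope(x, offset=0)                  # no manual reshape to 3-D and back
print(y.shape)                         # same shape as x
```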