mirror of https://github.com/ml-explore/mlx.git
parent eaba55c9bf
commit 16546c70d8
mlx/backend/metal/rope.cpp
@@ -1,5 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
-
@@ -13,39 +13,63 @@ void RoPE::eval_gpu(
   auto& in = inputs[0];
   auto& out = outputs[0];
 
-  if (in.ndim() != 3) {
-    throw std::runtime_error(
-        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
+  if (in.ndim() < 3) {
+    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
   }
   if (dims_ != in.shape(-1)) {
     throw std::runtime_error("[RoPE] Partial RoPE application not supported");
   }
-  if (in.flags().row_contiguous && in.is_donatable()) {
-    out.move_shared_buffer(in);
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  }
 
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
 
+  size_t strides[3];
+  bool donated = false;
+  int ndim = in.ndim();
+  size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
+  if (in.flags().row_contiguous) {
+    if (in.is_donatable()) {
+      donated = true;
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    }
+    strides[0] = mat_size;
+    strides[1] = in.strides()[ndim - 2];
+    strides[2] = in.strides()[ndim - 1];
+  } else if (ndim == 3) {
+    // Handle non-contiguous 3D inputs
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    strides[0] = in.strides()[0];
+    strides[1] = in.strides()[1];
+    strides[2] = in.strides()[2];
+  } else {
+    // Copy non-contiguous > 3D inputs into the output and treat
+    // input as donated
+    donated = true;
+    copy_gpu(in, out, CopyType::General, s);
+    strides[0] = mat_size;
+    strides[1] = out.strides()[ndim - 2];
+    strides[2] = out.strides()[ndim - 1];
+  }
+
   std::ostringstream kname;
   kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
 
-  bool donated = in.data_shared_ptr() == nullptr;
   float base = std::log2(base_);
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, donated ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
+  compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
   compute_encoder->setBytes(&offset_, sizeof(int), 3);
   compute_encoder->setBytes(&base, sizeof(float), 4);
   compute_encoder->setBytes(&scale_, sizeof(float), 5);
 
-  int dim0 = in.shape(2) / 2;
-  int dim1 = in.shape(1);
-  int dim2 = in.shape(0);
+  int dim0 = in.shape()[ndim - 1] / 2;
+  int dim1 = in.shape()[ndim - 2];
+  int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
   compute_encoder->dispatchThreads(grid_dims, group_dims);
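Why three strides are enough for any ndim: once the input is row contiguous, every axis before the last two collapses into a single matrix index, and the kernel walks in.size() / mat_size matrices of shape (sequence, dims). A small NumPy sketch of that addressing (shapes and names here are illustrative, not part of the change):

import numpy as np

# Illustrative shapes: (batch, heads, seq, head_dim).
x = np.arange(2 * 4 * 8 * 6).reshape(2, 4, 8, 6)
seq, dims = x.shape[-2], x.shape[-1]

mat_size = seq * dims          # strides[0] in the contiguous branch
n_mats = x.size // mat_size    # dim2 = in.size() / mat_size

# Every leading axis collapses into one "matrix index" m, so the kernel
# only ever needs three strides: (mat_size, dims, 1) in this layout.
flat = x.reshape(-1)
stacked = x.reshape(n_mats, seq, dims)
for m in range(n_mats):
    for r in range(seq):
        for c in range(dims):
            assert flat[m * mat_size + r * dims + c] == stacked[m, r, c]
print("3-stride addressing covers all", n_mats, "matrices")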
mlx/fast.cpp
@@ -54,10 +54,10 @@ array rope(
     float scale,
     int offset,
     StreamOrDevice s /* = {} */) {
-  if (x.ndim() != 3) {
+  if (x.ndim() < 3) {
     std::ostringstream msg;
-    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
-        << " dimensions.";
+    msg << "[rope] Input must have at least 3 dimensions but got input with "
+        << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
   if (traditional && x.shape(-1) != dims) {
@@ -67,7 +67,9 @@ array rope(
 
   auto fallback = [dims, traditional, base, scale, offset, s](
                       const std::vector<array>& inputs) {
-    auto& x = inputs[0];
+    auto& shape = inputs[0].shape();
+    int ndim = shape.size();
+    auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
     auto t = x.dtype();
     auto N = x.shape(1) + offset;
     // Compute sines and cosines
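The fallback now does, in graph ops, what the Python module used to do by hand: flatten the leading axes into one batch axis, run the 3-D math, and restore the saved shape on the way out (the reshape(..., shape, s) edits in the next two hunks handle the restore). A rough Python equivalent of that round trip, assuming a recent mlx install (the rope_nd wrapper name is made up for illustration):

import mlx.core as mx

def rope_nd(x, dims, *, traditional=False, base=10000.0, scale=1.0, offset=0):
    # Collapse every leading axis into one batch axis, mirroring the
    # fallback's reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s).
    shape = x.shape
    y = mx.fast.rope(
        mx.reshape(x, (-1, shape[-2], shape[-1])),
        dims,
        traditional=traditional,
        base=base,
        scale=scale,
        offset=offset,
    )
    # ...then restore the caller's shape, as both return statements now do.
    return mx.reshape(y, shape)

out = rope_nd(mx.random.normal((2, 4, 8, 64)), 64)
assert out.shape == (2, 4, 8, 64)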
@@ -89,7 +91,7 @@ array rope(
       for (auto& o : outs) {
         o = expand_dims(o, 3, s);
       }
-      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
+      return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
     } else {
       auto out_s = x.shape();
       out_s.back() = half_dims;
@@ -103,7 +105,7 @@ array rope(
       if (dims < x.shape(-1)) {
         outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
       }
-      return std::vector<array>{concatenate(outs, 2, s)};
+      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
     }
   };
   auto stream = to_stream(s);
python/mlx/nn/layers/positional_encoding.py
@@ -44,9 +44,7 @@ class RoPE(Module):
         return f"{self.dims}, traditional={self.traditional}"
 
     def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        x = mx.fast.rope(
+        return mx.fast.rope(
             x,
             self.dims,
             traditional=self.traditional,
@@ -54,7 +52,6 @@ class RoPE(Module):
             scale=self.scale,
             offset=offset,
         )
-        return mx.reshape(x, shape)
 
 
 class SinusoidalPositionalEncoding(Module):
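With the reshape round trip deleted from __call__, multi-head activations pass straight through the module. A quick usage sketch (shapes arbitrary, assumes an mlx build that includes this change):

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)

# A 4-D (batch, heads, seq, head_dim) activation now flows directly
# through the op; the module no longer reshapes to 3-D and back.
x = mx.random.normal((2, 8, 16, 64))
y = rope(x, offset=0)
print(y.shape)  # (2, 8, 16, 64)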