mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal * cuda rope (#2576)
This commit is contained in:
60
mlx/fast.cpp
60
mlx/fast.cpp
@@ -366,10 +366,16 @@ array rope(
|
||||
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1) {
|
||||
if (offset.ndim() > 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
|
||||
<< ".";
|
||||
msg << "[rope] offset must have at most one dimension but has shape "
|
||||
<< offset.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1 && offset.size() != x.shape(0)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
|
||||
<< " elements but has shape " << offset.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(offset.dtype(), integer)) {
|
||||
@@ -379,7 +385,7 @@ array rope(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.dtype().size() != 4) {
|
||||
inputs[1] = astype(offset, uint32, s);
|
||||
inputs[1] = astype(offset, int32, s);
|
||||
}
|
||||
if (inputs.size() == 3 &&
|
||||
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
|
||||
@@ -391,15 +397,26 @@ array rope(
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, forward, s](
|
||||
std::vector<array> inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
auto x = flatten(inputs[0], 0, ndim - 3, s);
|
||||
auto x = inputs[0];
|
||||
auto shape = x.shape();
|
||||
if (x.ndim() == 3) {
|
||||
x = expand_dims(x, 1, s);
|
||||
} else if (x.ndim() > 4) {
|
||||
x = flatten(x, 1, 1 + (x.ndim() - 4), s);
|
||||
}
|
||||
|
||||
auto B = x.shape(0);
|
||||
auto N = x.shape(1);
|
||||
auto T = x.shape(2);
|
||||
auto t = x.dtype();
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto& offset = inputs[1];
|
||||
auto offset = inputs[1];
|
||||
if (offset.size() > 1) {
|
||||
offset = expand_dims(offset, {-1, -2}, s);
|
||||
}
|
||||
auto positions =
|
||||
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
|
||||
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
|
||||
|
||||
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||
return exp(
|
||||
@@ -412,8 +429,7 @@ array rope(
|
||||
|
||||
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
||||
: default_inv_freqs();
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
@@ -436,32 +452,30 @@ array rope(
|
||||
};
|
||||
|
||||
if (traditional) {
|
||||
auto x1 =
|
||||
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x2 =
|
||||
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
o = expand_dims(o, -1, s);
|
||||
}
|
||||
auto out = concatenate(outs, 3, s);
|
||||
auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
|
||||
if (dims < x.shape(-1)) {
|
||||
out = reshape(out, {x.shape(0), x.shape(1), dims});
|
||||
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
|
||||
out =
|
||||
concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
|
||||
}
|
||||
return std::vector<array>{reshape(out, shape, s)};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
auto x1 = slice(x, {0, 0, 0}, out_s, s);
|
||||
auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
|
||||
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
|
||||
return std::vector<array>{reshape(concatenate(outs, -1, s), shape, s)};
|
||||
}
|
||||
};
|
||||
auto stream = to_stream(s);
|
||||
|
||||
Reference in New Issue
Block a user