mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Allow offset to be an mx.array for mx.fast.rope (#1724)
* allow offset for rope
* comment
This commit is contained in:
81
mlx/fast.cpp
81
mlx/fast.cpp
@@ -338,34 +338,50 @@ array rope(
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
bool forward,
|
||||
StreamOrDevice s) {
|
||||
auto& x = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
if (x.ndim() < 3) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||
<< x.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (inputs.size() == 2 &&
|
||||
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
|
||||
if (offset.size() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
|
||||
<< ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(offset.dtype(), integer)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be an integer but got type " << offset.dtype()
|
||||
<< ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.dtype().size() != 4) {
|
||||
inputs[1] = astype(offset, uint32, s);
|
||||
}
|
||||
if (inputs.size() == 3 &&
|
||||
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] freqs must be one dimensional with size " << dims / 2
|
||||
<< " but got shape " << inputs[1].shape() << ".";
|
||||
<< " but got shape " << inputs[2].shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
||||
auto fallback = [dims, traditional, base, scale, forward, s](
|
||||
std::vector<array> inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
|
||||
auto x = flatten(inputs[0], 0, ndim - 3, s);
|
||||
auto t = x.dtype();
|
||||
auto N = x.shape(1) + offset;
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
||||
auto& offset = inputs[1];
|
||||
auto positions =
|
||||
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
|
||||
|
||||
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||
return exp(
|
||||
@@ -377,7 +393,7 @@ array rope(
|
||||
};
|
||||
|
||||
auto inv_freqs =
|
||||
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
|
||||
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||
auto coss = cos(theta, s);
|
||||
@@ -436,7 +452,7 @@ array rope(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
stream, fallback, dims, traditional, base, scale, offset, forward),
|
||||
stream, fallback, dims, traditional, base, scale, forward),
|
||||
std::move(inputs));
|
||||
}
|
||||
return fallback(std::move(inputs))[0];
|
||||
@@ -448,10 +464,10 @@ array rope(
|
||||
bool traditional,
|
||||
std::optional<float> base,
|
||||
float scale,
|
||||
int offset,
|
||||
const array& offset,
|
||||
const std::optional<array>& freqs /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> inputs = {x};
|
||||
std::vector<array> inputs = {x, offset};
|
||||
if (freqs) {
|
||||
inputs.push_back(astype(*freqs, float32, s));
|
||||
if (base) {
|
||||
@@ -467,11 +483,23 @@ array rope(
|
||||
traditional,
|
||||
base.has_value() ? *base : 1.0,
|
||||
scale,
|
||||
offset,
|
||||
true,
|
||||
s);
|
||||
}
|
||||
|
||||
// Convenience overload of rope() that accepts the position offset as a plain
// `int` rather than an `array`.
//
// The integer `offset` is wrapped in an int32 `array` and, together with every
// other argument passed through unchanged, forwarded to the rope() overload
// that takes `const array& offset`. The result of that call is returned as-is.
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
std::optional<float> base,
|
||||
float scale,
|
||||
int offset,
|
||||
const std::optional<array>& freqs /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return rope(
|
||||
x, dims, traditional, base, scale, array(offset, int32), freqs, s);
|
||||
}
|
||||
|
||||
std::vector<array> RoPE::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@@ -482,29 +510,24 @@ std::vector<array> RoPE::vjp(
|
||||
traditional = traditional_,
|
||||
base = base_,
|
||||
scale = scale_,
|
||||
offset = offset_,
|
||||
forward = forward_,
|
||||
s](std::vector<array> inputs) {
|
||||
return std::vector<array>{rope(
|
||||
std::move(inputs),
|
||||
dims,
|
||||
traditional,
|
||||
base,
|
||||
scale,
|
||||
offset,
|
||||
!forward,
|
||||
s)};
|
||||
return std::vector<array>{
|
||||
rope(std::move(inputs), dims, traditional, base, scale, !forward, s)};
|
||||
};
|
||||
|
||||
auto inputs = cotangents;
|
||||
if (primals.size() == 2) {
|
||||
inputs.push_back(primals[1]);
|
||||
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||
throw std::invalid_argument(
|
||||
"[RoPE::vjp] vjp for offset or frequencies not supported");
|
||||
}
|
||||
auto inputs = std::vector<array>{cotangents[0], primals[1]};
|
||||
if (primals.size() == 3) {
|
||||
inputs.push_back(primals[2]);
|
||||
}
|
||||
return {array(
|
||||
cotangents[0].shape(),
|
||||
cotangents[0].dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
||||
s, fallback, dims_, traditional_, base_, scale_, !forward_),
|
||||
std::move(inputs))};
|
||||
}
|
||||
|
||||
@@ -513,7 +536,7 @@ bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
return (
|
||||
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
||||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
||||
offset_ == a_other.offset_ && forward_ == a_other.forward_);
|
||||
forward_ == a_other.forward_);
|
||||
}
|
||||
|
||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||
|
||||
Reference in New Issue
Block a user