Allow offset to be an mx.array for mx.fast.rope (#1724)

* allow offset for rope

* comment
Awni Hannun authored on 2024-12-19 15:51:44 -08:00; committed by GitHub
parent c3628eea49
commit 0308e9af71
8 changed files with 97 additions and 52 deletions
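For context: after this change the public rope entry point accepts the position offset either as an int or as an array (see the new overload declared in the header hunk below). A minimal C++ usage sketch, assuming a build that includes this commit; the shapes and hyperparameters are illustrative, not from the diff:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // RoPE requires at least 3 dimensions: here (batch, heads, seq, head_dim).
  auto x = random::normal({1, 8, 1, 64});

  // Offset as a device array: the current position can be produced by
  // other graph ops without a host round trip.
  auto offset = array(5, int32);
  auto y = fast::rope(
      x, /* dims */ 64, /* traditional */ false,
      /* base */ 10000.0f, /* scale */ 1.0f, offset);

  // The int overload is kept for convenience; it wraps the value as
  // array(offset, int32) and forwards to the array overload.
  auto z = fast::rope(x, 64, false, 10000.0f, 1.0f, /* offset */ 5);
  eval(y, z);
  return 0;
}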

View File

@@ -66,7 +66,7 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
-bool with_freqs = inputs.size() == 2;
+bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
@@ -78,7 +78,7 @@ void RoPE::eval_gpu(
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
-compute_encoder.set_bytes(offset_, 2);
+compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
@@ -101,7 +101,7 @@ void RoPE::eval_gpu(
}
if (with_freqs) {
-auto& freqs = inputs[1];
+auto& freqs = inputs[2];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder.set_bytes(freq_stride, 11);
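The two-line swap at argument index 2 is the core of the GPU-side change; restated with descriptive comments (buffer indices as in the hunk above):

// Before: offset_ was a primitive attribute, pushed by value into the
// command stream on every encode.
compute_encoder.set_bytes(offset_, 2);

// After: the offset is an ordinary input array bound as a device buffer,
// so upstream graph nodes can produce it without a host synchronization.
compute_encoder.set_input_array(inputs[1], 2);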

View File

@@ -338,34 +338,50 @@ array rope(
bool traditional,
float base,
float scale,
-int offset,
bool forward,
StreamOrDevice s) {
auto& x = inputs[0];
+auto& offset = inputs[1];
if (x.ndim() < 3) {
std::ostringstream msg;
msg << "[rope] Input must have at least 3 dimensions but got input with "
<< x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
-if (inputs.size() == 2 &&
-(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
+if (offset.size() != 1) {
+std::ostringstream msg;
+msg << "[rope] offset must be a scalar but has shape " << offset.shape()
+<< ".";
+throw std::invalid_argument(msg.str());
+}
+if (!issubdtype(offset.dtype(), integer)) {
+std::ostringstream msg;
+msg << "[rope] offset must be an integer but got type " << offset.dtype()
+<< ".";
+throw std::invalid_argument(msg.str());
+}
+if (offset.dtype().size() != 4) {
+inputs[1] = astype(offset, uint32, s);
+}
+if (inputs.size() == 3 &&
+(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
std::ostringstream msg;
msg << "[rope] freqs must be one dimensional with size " << dims / 2
<< " but got shape " << inputs[1].shape() << ".";
<< " but got shape " << inputs[2].shape() << ".";
throw std::invalid_argument(msg.str());
}
-auto fallback = [dims, traditional, base, scale, offset, forward, s](
+auto fallback = [dims, traditional, base, scale, forward, s](
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
-auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
+auto x = flatten(inputs[0], 0, ndim - 3, s);
auto t = x.dtype();
-auto N = x.shape(1) + offset;
// Compute sines and cosines
auto half_dims = dims / 2;
-auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
+auto& offset = inputs[1];
+auto positions =
+multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
@@ -377,7 +393,7 @@ array rope(
};
auto inv_freqs =
-inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
+inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s);
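Taken together, the fallback now derives the rotation angles from the runtime offset: positions[i] = (i + offset) * scale and theta[i, j] = positions[i] * inv_freqs[j]. A self-contained sketch of that computation using the same ops as the hunk above; the function name rope_theta and the length parameter L are hypothetical, not part of the diff:

#include "mlx/mlx.h"

using namespace mlx::core;

// Angles for a length-L window starting at `offset` (a scalar int array).
array rope_theta(
    int L, const array& offset, const array& inv_freqs, float scale) {
  auto t = float32;
  // positions[i] = (i + offset) * scale, for i in [0, L)
  auto positions =
      multiply(add(arange(L, t), astype(offset, t)), array(scale, t));
  // theta[i, j] = positions[i] * inv_freqs[j]
  return multiply(expand_dims(positions, 1), expand_dims(inv_freqs, 0));
}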
@@ -436,7 +452,7 @@ array rope(
x.shape(),
x.dtype(),
std::make_shared<RoPE>(
-stream, fallback, dims, traditional, base, scale, offset, forward),
+stream, fallback, dims, traditional, base, scale, forward),
std::move(inputs));
}
return fallback(std::move(inputs))[0];
@@ -448,10 +464,10 @@ array rope(
bool traditional,
std::optional<float> base,
float scale,
-int offset,
+const array& offset,
const std::optional<array>& freqs /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
-std::vector<array> inputs = {x};
+std::vector<array> inputs = {x, offset};
if (freqs) {
inputs.push_back(astype(*freqs, float32, s));
if (base) {
@@ -467,11 +483,23 @@ array rope(
traditional,
base.has_value() ? *base : 1.0,
scale,
-offset,
true,
s);
}
+array rope(
+const array& x,
+int dims,
+bool traditional,
+std::optional<float> base,
+float scale,
+int offset,
+const std::optional<array>& freqs /* = std::nullopt */,
+StreamOrDevice s /* = {} */) {
+return rope(
+x, dims, traditional, base, scale, array(offset, int32), freqs, s);
+}
std::vector<array> RoPE::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -482,29 +510,24 @@ std::vector<array> RoPE::vjp(
traditional = traditional_,
base = base_,
scale = scale_,
-offset = offset_,
forward = forward_,
s](std::vector<array> inputs) {
-return std::vector<array>{rope(
-std::move(inputs),
-dims,
-traditional,
-base,
-scale,
-offset,
-!forward,
-s)};
+return std::vector<array>{
+rope(std::move(inputs), dims, traditional, base, scale, !forward, s)};
};
-auto inputs = cotangents;
-if (primals.size() == 2) {
-inputs.push_back(primals[1]);
+if (argnums.size() > 1 || argnums[0] != 0) {
+throw std::invalid_argument(
+"[RoPE::vjp] vjp for offset or frequencies not supported");
+}
+auto inputs = std::vector<array>{cotangents[0], primals[1]};
+if (primals.size() == 3) {
+inputs.push_back(primals[2]);
}
return {array(
cotangents[0].shape(),
cotangents[0].dtype(),
std::make_shared<RoPE>(
-s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
+s, fallback, dims_, traditional_, base_, scale_, !forward_),
std::move(inputs))};
}
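The new argnums guard in this vjp means the gradient is only defined with respect to the input x (argnum 0); asking for a gradient through offset or freqs now throws. A sketch, assuming the single-argument grad transform from mlx/transforms.h:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = random::normal({1, 8, 4, 64});
  auto offset = array(0, int32);
  // Differentiate a scalar function of rope's output w.r.t. x only.
  std::function<array(const array&)> f = [&](const array& xi) {
    return sum(fast::rope(xi, 64, false, 10000.0f, 1.0f, offset));
  };
  auto dx = grad(f)(x);
  eval(dx);
  return 0;
}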
@@ -513,7 +536,7 @@ bool RoPE::is_equivalent(const Primitive& other) const {
return (
dims_ == a_other.dims_ && base_ == a_other.base_ &&
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
-offset_ == a_other.offset_ && forward_ == a_other.forward_);
+forward_ == a_other.forward_);
}
/** Computes: O = softmax(Q @ K.T) @ V **/

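A side effect of the is_equivalent change just above: with offset demoted from a compared attribute to a regular input, two RoPE primitives that differ only in their offset now compare as equivalent, so graph transforms that key on primitive equivalence should be able to treat them interchangeably. A hypothetical illustration:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = random::normal({1, 8, 1, 64});
  // Identical attributes (dims, base, scale, ...); only the offset input
  // differs, so these two RoPE nodes are now considered equivalent.
  auto a = fast::rope(x, 64, false, 10000.0f, 1.0f, array(0, int32));
  auto b = fast::rope(x, 64, false, 10000.0f, 1.0f, array(7, int32));
  eval(a, b);
  return 0;
}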
View File

@@ -31,6 +31,16 @@ array rope(
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
+array rope(
+const array& x,
+int dims,
+bool traditional,
+std::optional<float> base,
+float scale,
+const array& offset,
+const std::optional<array>& freqs = std::nullopt,
+StreamOrDevice s = {});
/** Computes: O = softmax(Q @ K.T) @ V **/
array scaled_dot_product_attention(
const array& queries,

View File

@@ -150,14 +150,12 @@ class RoPE : public Custom {
bool traditional,
float base,
float scale,
-int offset,
bool forward)
: Custom(stream, fallback),
dims_(dims),
traditional_(traditional),
base_(base),
scale_(scale),
-offset_(offset),
forward_(forward) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -183,7 +181,6 @@ class RoPE : public Custom {
bool traditional_;
float base_;
float scale_;
-int offset_;
bool forward_;
};