mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Allow offset to be an mx.array for mx.fast.rope
(#1724)
* allow offset for rope * comment
This commit is contained in:
parent
c3628eea49
commit
0308e9af71
@ -66,7 +66,7 @@ void RoPE::eval_gpu(
|
|||||||
// Special case for inference (single time step and contiguous)
|
// Special case for inference (single time step and contiguous)
|
||||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||||
|
|
||||||
bool with_freqs = inputs.size() == 2;
|
bool with_freqs = inputs.size() == 3;
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "rope_" << (single ? "single_" : "")
|
kname << "rope_" << (single ? "single_" : "")
|
||||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||||
@ -78,7 +78,7 @@ void RoPE::eval_gpu(
|
|||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
compute_encoder.set_bytes(offset_, 2);
|
compute_encoder.set_input_array(inputs[1], 2);
|
||||||
compute_encoder.set_bytes(scale_, 3);
|
compute_encoder.set_bytes(scale_, 3);
|
||||||
|
|
||||||
size_t n_batch = in.size() / mat_size;
|
size_t n_batch = in.size() / mat_size;
|
||||||
@ -101,7 +101,7 @@ void RoPE::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (with_freqs) {
|
if (with_freqs) {
|
||||||
auto& freqs = inputs[1];
|
auto& freqs = inputs[2];
|
||||||
compute_encoder.set_input_array(freqs, 10);
|
compute_encoder.set_input_array(freqs, 10);
|
||||||
auto freq_stride = freqs.strides()[0];
|
auto freq_stride = freqs.strides()[0];
|
||||||
compute_encoder.set_bytes(freq_stride, 11);
|
compute_encoder.set_bytes(freq_stride, 11);
|
||||||
|
81
mlx/fast.cpp
81
mlx/fast.cpp
@ -338,34 +338,50 @@ array rope(
|
|||||||
bool traditional,
|
bool traditional,
|
||||||
float base,
|
float base,
|
||||||
float scale,
|
float scale,
|
||||||
int offset,
|
|
||||||
bool forward,
|
bool forward,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
|
auto& offset = inputs[1];
|
||||||
if (x.ndim() < 3) {
|
if (x.ndim() < 3) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||||
<< x.ndim() << " dimensions.";
|
<< x.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (inputs.size() == 2 &&
|
if (offset.size() != 1) {
|
||||||
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
|
std::ostringstream msg;
|
||||||
|
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
|
||||||
|
<< ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (!issubdtype(offset.dtype(), integer)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[rope] offset must be an integer but got type " << offset.dtype()
|
||||||
|
<< ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (offset.dtype().size() != 4) {
|
||||||
|
inputs[1] = astype(offset, uint32, s);
|
||||||
|
}
|
||||||
|
if (inputs.size() == 3 &&
|
||||||
|
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rope] freqs must be one dimensional with size " << dims / 2
|
msg << "[rope] freqs must be one dimensional with size " << dims / 2
|
||||||
<< " but got shape " << inputs[1].shape() << ".";
|
<< " but got shape " << inputs[2].shape() << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
auto fallback = [dims, traditional, base, scale, forward, s](
|
||||||
std::vector<array> inputs) {
|
std::vector<array> inputs) {
|
||||||
auto& shape = inputs[0].shape();
|
auto& shape = inputs[0].shape();
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
|
auto x = flatten(inputs[0], 0, ndim - 3, s);
|
||||||
auto t = x.dtype();
|
auto t = x.dtype();
|
||||||
auto N = x.shape(1) + offset;
|
|
||||||
// Compute sines and cosines
|
// Compute sines and cosines
|
||||||
auto half_dims = dims / 2;
|
auto half_dims = dims / 2;
|
||||||
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
auto& offset = inputs[1];
|
||||||
|
auto positions =
|
||||||
|
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
|
||||||
|
|
||||||
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||||
return exp(
|
return exp(
|
||||||
@ -377,7 +393,7 @@ array rope(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto inv_freqs =
|
auto inv_freqs =
|
||||||
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
|
inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
|
||||||
auto theta =
|
auto theta =
|
||||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||||
auto coss = cos(theta, s);
|
auto coss = cos(theta, s);
|
||||||
@ -436,7 +452,7 @@ array rope(
|
|||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<RoPE>(
|
std::make_shared<RoPE>(
|
||||||
stream, fallback, dims, traditional, base, scale, offset, forward),
|
stream, fallback, dims, traditional, base, scale, forward),
|
||||||
std::move(inputs));
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
return fallback(std::move(inputs))[0];
|
return fallback(std::move(inputs))[0];
|
||||||
@ -448,10 +464,10 @@ array rope(
|
|||||||
bool traditional,
|
bool traditional,
|
||||||
std::optional<float> base,
|
std::optional<float> base,
|
||||||
float scale,
|
float scale,
|
||||||
int offset,
|
const array& offset,
|
||||||
const std::optional<array>& freqs /* = std::nullopt */,
|
const std::optional<array>& freqs /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
std::vector<array> inputs = {x};
|
std::vector<array> inputs = {x, offset};
|
||||||
if (freqs) {
|
if (freqs) {
|
||||||
inputs.push_back(astype(*freqs, float32, s));
|
inputs.push_back(astype(*freqs, float32, s));
|
||||||
if (base) {
|
if (base) {
|
||||||
@ -467,11 +483,23 @@ array rope(
|
|||||||
traditional,
|
traditional,
|
||||||
base.has_value() ? *base : 1.0,
|
base.has_value() ? *base : 1.0,
|
||||||
scale,
|
scale,
|
||||||
offset,
|
|
||||||
true,
|
true,
|
||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convenience overload of rope() (added in the right-hand column of this
// diff; the left-hand column is empty for these rows) that keeps the
// integer `offset` signature. It wraps the integer in a scalar int32 array
// and forwards every other argument unchanged to the overload that takes
// `const array& offset`.
array rope(
|
||||||
|
const array& x,
|
||||||
|
int dims,
|
||||||
|
bool traditional,
|
||||||
|
std::optional<float> base,
|
||||||
|
float scale,
|
||||||
|
int offset,
|
||||||
|
const std::optional<array>& freqs /* = std::nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
// Delegate to the array-offset overload; array(offset, int32) builds the
// scalar offset array from the integer argument.
return rope(
|
||||||
|
x, dims, traditional, base, scale, array(offset, int32), freqs, s);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> RoPE::vjp(
|
std::vector<array> RoPE::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
@ -482,29 +510,24 @@ std::vector<array> RoPE::vjp(
|
|||||||
traditional = traditional_,
|
traditional = traditional_,
|
||||||
base = base_,
|
base = base_,
|
||||||
scale = scale_,
|
scale = scale_,
|
||||||
offset = offset_,
|
|
||||||
forward = forward_,
|
forward = forward_,
|
||||||
s](std::vector<array> inputs) {
|
s](std::vector<array> inputs) {
|
||||||
return std::vector<array>{rope(
|
return std::vector<array>{
|
||||||
std::move(inputs),
|
rope(std::move(inputs), dims, traditional, base, scale, !forward, s)};
|
||||||
dims,
|
|
||||||
traditional,
|
|
||||||
base,
|
|
||||||
scale,
|
|
||||||
offset,
|
|
||||||
!forward,
|
|
||||||
s)};
|
|
||||||
};
|
};
|
||||||
|
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||||
auto inputs = cotangents;
|
throw std::invalid_argument(
|
||||||
if (primals.size() == 2) {
|
"[RoPE::vjp] vjp for offset or frequencies not supported");
|
||||||
inputs.push_back(primals[1]);
|
}
|
||||||
|
auto inputs = std::vector<array>{cotangents[0], primals[1]};
|
||||||
|
if (primals.size() == 3) {
|
||||||
|
inputs.push_back(primals[2]);
|
||||||
}
|
}
|
||||||
return {array(
|
return {array(
|
||||||
cotangents[0].shape(),
|
cotangents[0].shape(),
|
||||||
cotangents[0].dtype(),
|
cotangents[0].dtype(),
|
||||||
std::make_shared<RoPE>(
|
std::make_shared<RoPE>(
|
||||||
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
s, fallback, dims_, traditional_, base_, scale_, !forward_),
|
||||||
std::move(inputs))};
|
std::move(inputs))};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -513,7 +536,7 @@ bool RoPE::is_equivalent(const Primitive& other) const {
|
|||||||
return (
|
return (
|
||||||
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
||||||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
||||||
offset_ == a_other.offset_ && forward_ == a_other.forward_);
|
forward_ == a_other.forward_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||||
|
10
mlx/fast.h
10
mlx/fast.h
@ -31,6 +31,16 @@ array rope(
|
|||||||
const std::optional<array>& freqs = std::nullopt,
|
const std::optional<array>& freqs = std::nullopt,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
 * Declaration added by this change (right-hand diff column only): a rope
 * overload whose position `offset` is an array rather than an int. The
 * implementation shown in the mlx/fast.cpp part of this diff rejects an
 * offset whose size() is not 1 or whose dtype is not an integer type.
 **/
array rope(
|
||||||
|
const array& x,
|
||||||
|
int dims,
|
||||||
|
bool traditional,
|
||||||
|
std::optional<float> base,
|
||||||
|
float scale,
|
||||||
|
const array& offset,
|
||||||
|
const std::optional<array>& freqs = std::nullopt,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||||
array scaled_dot_product_attention(
|
array scaled_dot_product_attention(
|
||||||
const array& queries,
|
const array& queries,
|
||||||
|
@ -150,14 +150,12 @@ class RoPE : public Custom {
|
|||||||
bool traditional,
|
bool traditional,
|
||||||
float base,
|
float base,
|
||||||
float scale,
|
float scale,
|
||||||
int offset,
|
|
||||||
bool forward)
|
bool forward)
|
||||||
: Custom(stream, fallback),
|
: Custom(stream, fallback),
|
||||||
dims_(dims),
|
dims_(dims),
|
||||||
traditional_(traditional),
|
traditional_(traditional),
|
||||||
base_(base),
|
base_(base),
|
||||||
scale_(scale),
|
scale_(scale),
|
||||||
offset_(offset),
|
|
||||||
forward_(forward) {}
|
forward_(forward) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
@ -183,7 +181,6 @@ class RoPE : public Custom {
|
|||||||
bool traditional_;
|
bool traditional_;
|
||||||
float base_;
|
float base_;
|
||||||
float scale_;
|
float scale_;
|
||||||
int offset_;
|
|
||||||
bool forward_;
|
bool forward_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -79,7 +79,17 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"rope",
|
"rope",
|
||||||
&mx::fast::rope,
|
[](const mx::array& a,
|
||||||
|
int dims,
|
||||||
|
bool traditional,
|
||||||
|
std::optional<float> base,
|
||||||
|
float scale,
|
||||||
|
const ScalarOrArray& offset,
|
||||||
|
const std::optional<mx::array>& freqs /* = std::nullopt */,
|
||||||
|
mx::StreamOrDevice s /* = {} */) {
|
||||||
|
return mx::fast::rope(
|
||||||
|
a, dims, traditional, base, scale, to_array(offset), freqs, s);
|
||||||
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"dims"_a,
|
"dims"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -90,7 +100,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
"freqs"_a = nb::none(),
|
"freqs"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Apply rotary positional encoding to the input.
|
Apply rotary positional encoding to the input.
|
||||||
|
|
||||||
@ -104,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
each dimension in the positional encodings. Exactly one of ``base`` and
|
each dimension in the positional encodings. Exactly one of ``base`` and
|
||||||
``freqs`` must be ``None``.
|
``freqs`` must be ``None``.
|
||||||
scale (float): The scale used to scale the positions.
|
scale (float): The scale used to scale the positions.
|
||||||
offset (int): The position offset to start at.
|
offset (int or array): The position offset to start at.
|
||||||
freqs (array, optional): Optional frequencies to use with RoPE.
|
freqs (array, optional): Optional frequencies to use with RoPE.
|
||||||
If set, the ``base`` parameter must be ``None``. Default: ``None``.
|
If set, the ``base`` parameter must be ``None``. Default: ``None``.
|
||||||
|
|
||||||
|
@ -22,21 +22,25 @@ namespace mx = mlx::core;
|
|||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
using Scalar = std::variant<int, double>;
|
using Scalar = std::variant<bool, int, double>;
|
||||||
|
|
||||||
// Side-by-side diff of scalar_to_dtype: for each row the pre-change line
// comes first and the post-change line second.
// Pre-change:  an int alternative maps to mx::int32, anything else to
//              mx::float32.
// Post-change: the parameter is renamed to `s`; int maps to mx::int32,
//              double maps to mx::float32, and the remaining alternative
//              (bool, per the updated Scalar variant) maps to mx::bool_.
mx::Dtype scalar_to_dtype(Scalar scalar) {
|
mx::Dtype scalar_to_dtype(Scalar s) {
|
||||||
if (std::holds_alternative<int>(scalar)) {
|
if (std::holds_alternative<int>(s)) {
|
||||||
return mx::int32;
|
return mx::int32;
|
||||||
} else {
|
} else if (std::holds_alternative<double>(s)) {
|
||||||
return mx::float32;
|
return mx::float32;
|
||||||
|
} else {
|
||||||
|
return mx::bool_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Side-by-side diff of scalar_to_double: for each row the pre-change line
// comes first and the post-change line second.
// Pre-change:  returns the held double directly, otherwise casts the held
//              int to double.
// Post-change: probes the variant with std::get_if; an int is cast to
//              double, a double is returned as-is, and the final branch
//              casts the held bool to double.
double scalar_to_double(Scalar s) {
|
double scalar_to_double(Scalar s) {
|
||||||
if (std::holds_alternative<double>(s)) {
|
if (auto pv = std::get_if<int>(&s); pv) {
|
||||||
return std::get<double>(s);
|
return static_cast<double>(*pv);
|
||||||
|
} else if (auto pv = std::get_if<double>(&s); pv) {
|
||||||
|
return *pv;
|
||||||
} else {
|
} else {
|
||||||
return static_cast<double>(std::get<int>(s));
|
return static_cast<double>(std::get<bool>(s));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1367,9 +1371,9 @@ void init_ops(nb::module_& m) {
|
|||||||
dtype,
|
dtype,
|
||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
"start"_a,
|
"start"_a.noconvert(),
|
||||||
"stop"_a,
|
"stop"_a.noconvert(),
|
||||||
"step"_a = nb::none(),
|
"step"_a.noconvert() = nb::none(),
|
||||||
"dtype"_a = nb::none(),
|
"dtype"_a = nb::none(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -1413,8 +1417,8 @@ void init_ops(nb::module_& m) {
|
|||||||
dtype,
|
dtype,
|
||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
"stop"_a,
|
"stop"_a.noconvert(),
|
||||||
"step"_a = nb::none(),
|
"step"_a.noconvert() = nb::none(),
|
||||||
"dtype"_a = nb::none(),
|
"dtype"_a = nb::none(),
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
|
@ -8,6 +8,7 @@ import mlx_tests
|
|||||||
|
|
||||||
|
|
||||||
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
||||||
|
offset = offset.item() if isinstance(offset, mx.array) else offset
|
||||||
N = x.shape[-2] + offset
|
N = x.shape[-2] + offset
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
half_D = dims // 2
|
half_D = dims // 2
|
||||||
@ -76,7 +77,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
||||||
bases = [10000.0, 1000000.0]
|
bases = [10000.0, 1000000.0]
|
||||||
scales = [1.0, 2.0]
|
scales = [1.0, 2.0]
|
||||||
offsets = [0, 3]
|
offsets = [0, 3, mx.array(3)]
|
||||||
traditional = [True, False]
|
traditional = [True, False]
|
||||||
|
|
||||||
for traditional in [True, False]:
|
for traditional in [True, False]:
|
||||||
|
@ -1153,7 +1153,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
a = mx.arange(float("inf"), 1, float("inf"))
|
a = mx.arange(float("inf"), 1, float("inf"))
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
a = mx.arange(float("inf"), 1, 5)
|
a = mx.arange(float("inf"), 1, 5)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(TypeError):
|
||||||
INT_MAX = 2147483647
|
INT_MAX = 2147483647
|
||||||
a = mx.arange(0, INT_MAX + 1, 1)
|
a = mx.arange(0, INT_MAX + 1, 1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user