mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00
Allow offset to be an mx.array for mx.fast.rope
(#1724)
* allow offset for rope * comment
This commit is contained in:
@@ -79,7 +79,17 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"rope",
|
||||
&mx::fast::rope,
|
||||
[](const mx::array& a,
|
||||
int dims,
|
||||
bool traditional,
|
||||
std::optional<float> base,
|
||||
float scale,
|
||||
const ScalarOrArray& offset,
|
||||
const std::optional<mx::array>& freqs /* = std::nullopt */,
|
||||
mx::StreamOrDevice s /* = {} */) {
|
||||
return mx::fast::rope(
|
||||
a, dims, traditional, base, scale, to_array(offset), freqs, s);
|
||||
},
|
||||
"a"_a,
|
||||
"dims"_a,
|
||||
nb::kw_only(),
|
||||
@@ -90,7 +100,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
"freqs"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Apply rotary positional encoding to the input.
|
||||
|
||||
@@ -104,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
each dimension in the positional encodings. Exactly one of ``base`` and
|
||||
``freqs`` must be ``None``.
|
||||
scale (float): The scale used to scale the positions.
|
||||
offset (int): The position offset to start at.
|
||||
offset (int or array): The position offset to start at.
|
||||
freqs (array, optional): Optional frequencies to use with RoPE.
|
||||
If set, the ``base`` parameter must be ``None``. Default: ``None``.
|
||||
|
||||
|
@@ -22,21 +22,25 @@ namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using Scalar = std::variant<int, double>;
|
||||
using Scalar = std::variant<bool, int, double>;
|
||||
|
||||
mx::Dtype scalar_to_dtype(Scalar scalar) {
|
||||
if (std::holds_alternative<int>(scalar)) {
|
||||
mx::Dtype scalar_to_dtype(Scalar s) {
|
||||
if (std::holds_alternative<int>(s)) {
|
||||
return mx::int32;
|
||||
} else {
|
||||
} else if (std::holds_alternative<double>(s)) {
|
||||
return mx::float32;
|
||||
} else {
|
||||
return mx::bool_;
|
||||
}
|
||||
}
|
||||
|
||||
double scalar_to_double(Scalar s) {
|
||||
if (std::holds_alternative<double>(s)) {
|
||||
return std::get<double>(s);
|
||||
if (auto pv = std::get_if<int>(&s); pv) {
|
||||
return static_cast<double>(*pv);
|
||||
} else if (auto pv = std::get_if<double>(&s); pv) {
|
||||
return *pv;
|
||||
} else {
|
||||
return static_cast<double>(std::get<int>(s));
|
||||
return static_cast<double>(std::get<bool>(s));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1367,9 +1371,9 @@ void init_ops(nb::module_& m) {
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
"start"_a.noconvert(),
|
||||
"stop"_a.noconvert(),
|
||||
"step"_a.noconvert() = nb::none(),
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -1413,8 +1417,8 @@ void init_ops(nb::module_& m) {
|
||||
dtype,
|
||||
s);
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
"stop"_a.noconvert(),
|
||||
"step"_a.noconvert() = nb::none(),
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
|
@@ -8,6 +8,7 @@ import mlx_tests
|
||||
|
||||
|
||||
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
||||
offset = offset.item() if isinstance(offset, mx.array) else offset
|
||||
N = x.shape[-2] + offset
|
||||
dtype = x.dtype
|
||||
half_D = dims // 2
|
||||
@@ -76,7 +77,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
||||
bases = [10000.0, 1000000.0]
|
||||
scales = [1.0, 2.0]
|
||||
offsets = [0, 3]
|
||||
offsets = [0, 3, mx.array(3)]
|
||||
traditional = [True, False]
|
||||
|
||||
for traditional in [True, False]:
|
||||
|
@@ -1153,7 +1153,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.arange(float("inf"), 1, float("inf"))
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.arange(float("inf"), 1, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(TypeError):
|
||||
INT_MAX = 2147483647
|
||||
a = mx.arange(0, INT_MAX + 1, 1)
|
||||
|
||||
|
Reference in New Issue
Block a user