Allow offset to be an mx.array for mx.fast.rope (#1724)

* allow offset for rope

* comment
This commit is contained in:
Awni Hannun
2024-12-19 15:51:44 -08:00
committed by GitHub
parent c3628eea49
commit 0308e9af71
8 changed files with 97 additions and 52 deletions

View File

@@ -79,7 +79,17 @@ void init_fast(nb::module_& parent_module) {
m.def(
"rope",
&mx::fast::rope,
[](const mx::array& a,
int dims,
bool traditional,
std::optional<float> base,
float scale,
const ScalarOrArray& offset,
const std::optional<mx::array>& freqs /* = std::nullopt */,
mx::StreamOrDevice s /* = {} */) {
return mx::fast::rope(
a, dims, traditional, base, scale, to_array(offset), freqs, s);
},
"a"_a,
"dims"_a,
nb::kw_only(),
@@ -90,7 +100,7 @@ void init_fast(nb::module_& parent_module) {
"freqs"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Apply rotary positional encoding to the input.
@@ -104,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int): The position offset to start at.
offset (int or array): The position offset to start at.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.

View File

@@ -22,21 +22,25 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using Scalar = std::variant<int, double>;
using Scalar = std::variant<bool, int, double>;
mx::Dtype scalar_to_dtype(Scalar scalar) {
if (std::holds_alternative<int>(scalar)) {
mx::Dtype scalar_to_dtype(Scalar s) {
if (std::holds_alternative<int>(s)) {
return mx::int32;
} else {
} else if (std::holds_alternative<double>(s)) {
return mx::float32;
} else {
return mx::bool_;
}
}
double scalar_to_double(Scalar s) {
if (std::holds_alternative<double>(s)) {
return std::get<double>(s);
if (auto pv = std::get_if<int>(&s); pv) {
return static_cast<double>(*pv);
} else if (auto pv = std::get_if<double>(&s); pv) {
return *pv;
} else {
return static_cast<double>(std::get<int>(s));
return static_cast<double>(std::get<bool>(s));
}
}
@@ -1367,9 +1371,9 @@ void init_ops(nb::module_& m) {
dtype,
s);
},
"start"_a,
"stop"_a,
"step"_a = nb::none(),
"start"_a.noconvert(),
"stop"_a.noconvert(),
"step"_a.noconvert() = nb::none(),
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
@@ -1413,8 +1417,8 @@ void init_ops(nb::module_& m) {
dtype,
s);
},
"stop"_a,
"step"_a = nb::none(),
"stop"_a.noconvert(),
"step"_a.noconvert() = nb::none(),
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),

View File

@@ -8,6 +8,7 @@ import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
offset = offset.item() if isinstance(offset, mx.array) else offset
N = x.shape[-2] + offset
dtype = x.dtype
half_D = dims // 2
@@ -76,7 +77,7 @@ class TestFast(mlx_tests.MLXTestCase):
dtypes = [mx.float32, mx.float16, mx.bfloat16]
bases = [10000.0, 1000000.0]
scales = [1.0, 2.0]
offsets = [0, 3]
offsets = [0, 3, mx.array(3)]
traditional = [True, False]
for traditional in [True, False]:

View File

@@ -1153,7 +1153,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.arange(float("inf"), 1, float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, 5)
with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
INT_MAX = 2147483647
a = mx.arange(0, INT_MAX + 1, 1)