diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index 195d29c2e..1ca3597e5 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -66,7 +66,7 @@ void RoPE::eval_gpu(
   // Special case for inference (single time step and contiguous)
   bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
-  bool with_freqs = inputs.size() == 2;
+  bool with_freqs = inputs.size() == 3;
   std::ostringstream kname;
   kname << "rope_" << (single ? "single_" : "")
         << ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
@@ -78,7 +78,7 @@ void RoPE::eval_gpu(
   compute_encoder.set_compute_pipeline_state(kernel);
   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
-  compute_encoder.set_bytes(offset_, 2);
+  compute_encoder.set_input_array(inputs[1], 2);
   compute_encoder.set_bytes(scale_, 3);

   size_t n_batch = in.size() / mat_size;
@@ -101,7 +101,7 @@ void RoPE::eval_gpu(
   }
   if (with_freqs) {
-    auto& freqs = inputs[1];
+    auto& freqs = inputs[2];
     compute_encoder.set_input_array(freqs, 10);
     auto freq_stride = freqs.strides()[0];
     compute_encoder.set_bytes(freq_stride, 11);
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 01089d82d..0b50b4282 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -338,34 +338,50 @@ array rope(
     bool traditional,
     float base,
     float scale,
-    int offset,
     bool forward,
     StreamOrDevice s) {
   auto& x = inputs[0];
+  auto& offset = inputs[1];
   if (x.ndim() < 3) {
     std::ostringstream msg;
     msg << "[rope] Input must have at least 3 dimensions but got input with "
         << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (inputs.size() == 2 &&
-      (inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
+  if (offset.size() != 1) {
+    std::ostringstream msg;
+    msg << "[rope] offset must be a scalar but has shape " << offset.shape()
+        << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (!issubdtype(offset.dtype(), integer)) {
+    std::ostringstream msg;
+    msg << "[rope] offset must be an integer but got type " << offset.dtype()
+        << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (offset.dtype().size() != 4) {
+    inputs[1] = astype(offset, uint32, s);
+  }
+  if (inputs.size() == 3 &&
+      (inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
     std::ostringstream msg;
     msg << "[rope] freqs must be one dimensional with size " << dims / 2
-        << " but got shape " << inputs[1].shape() << ".";
+        << " but got shape " << inputs[2].shape() << ".";
     throw std::invalid_argument(msg.str());
   }

-  auto fallback = [dims, traditional, base, scale, offset, forward, s](
+  auto fallback = [dims, traditional, base, scale, forward, s](
                      std::vector<array> inputs) {
     auto& shape = inputs[0].shape();
     int ndim = shape.size();
-    auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
+    auto x = flatten(inputs[0], 0, ndim - 3, s);
     auto t = x.dtype();
-    auto N = x.shape(1) + offset;
     // Compute sines and cosines
     auto half_dims = dims / 2;
-    auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
+    auto& offset = inputs[1];
+    auto positions =
+        multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);

     auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
       return exp(
@@ -377,7 +393,7 @@ array rope(
     };

     auto inv_freqs =
-        inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
+        inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
     auto theta =
         multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
     auto coss = cos(theta, s);
@@ -436,7 +452,7 @@ array rope(
         x.shape(),
         x.dtype(),
         std::make_shared<RoPE>(
-            stream, fallback, dims, traditional, base, scale, offset, forward),
+            stream, fallback, dims, traditional, base, scale, forward),
         std::move(inputs));
   }
   return fallback(std::move(inputs))[0];
@@ -448,10 +464,10 @@ array rope(
     bool traditional,
     std::optional<float> base,
     float scale,
-    int offset,
+    const array& offset,
     const std::optional<array>& freqs /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
-  std::vector<array> inputs = {x};
+  std::vector<array> inputs = {x, offset};
   if (freqs) {
     inputs.push_back(astype(*freqs, float32, s));
     if (base) {
@@ -467,11 +483,23 @@ array rope(
       traditional,
       base.has_value() ? *base : 1.0,
       scale,
-      offset,
       true,
       s);
 }

+array rope(
+    const array& x,
+    int dims,
+    bool traditional,
+    std::optional<float> base,
+    float scale,
+    int offset,
+    const std::optional<array>& freqs /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  return rope(
+      x, dims, traditional, base, scale, array(offset, int32), freqs, s);
+}
+
 std::vector<array> RoPE::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
@@ -482,29 +510,24 @@ std::vector<array> RoPE::vjp(
        traditional = traditional_,
        base = base_,
        scale = scale_,
-       offset = offset_,
        forward = forward_,
        s](std::vector<array> inputs) {
-    return std::vector<array>{rope(
-        std::move(inputs),
-        dims,
-        traditional,
-        base,
-        scale,
-        offset,
-        !forward,
-        s)};
+    return std::vector<array>{
+        rope(std::move(inputs), dims, traditional, base, scale, !forward, s)};
   };
-
-  auto inputs = cotangents;
-  if (primals.size() == 2) {
-    inputs.push_back(primals[1]);
+  if (argnums.size() > 1 || argnums[0] != 0) {
+    throw std::invalid_argument(
+        "[RoPE::vjp] vjp for offset or frequencies not supported");
+  }
+  auto inputs = std::vector<array>{cotangents[0], primals[1]};
+  if (primals.size() == 3) {
+    inputs.push_back(primals[2]);
   }
   return {array(
       cotangents[0].shape(),
       cotangents[0].dtype(),
       std::make_shared<RoPE>(
-          s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
+          s, fallback, dims_, traditional_, base_, scale_, !forward_),
       std::move(inputs))};
 }
@@ -513,7 +536,7 @@ bool RoPE::is_equivalent(const Primitive& other) const {
   return (
       dims_ == a_other.dims_ && base_ == a_other.base_ &&
       scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
-      offset_ == a_other.offset_ && forward_ == a_other.forward_);
+      forward_ == a_other.forward_);
 }

 /** Computes: O = softmax(Q @ K.T) @ V **/
diff --git a/mlx/fast.h b/mlx/fast.h
index 0b9608eec..9e6586cf6 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -31,6 +31,16 @@ array rope(
     const std::optional<array>& freqs = std::nullopt,
     StreamOrDevice s = {});

+array rope(
+    const array& x,
+    int dims,
+    bool traditional,
+    std::optional<float> base,
+    float scale,
+    const array& offset,
+    const std::optional<array>& freqs = std::nullopt,
+    StreamOrDevice s = {});
+
 /** Computes: O = softmax(Q @ K.T) @ V **/
 array scaled_dot_product_attention(
     const array& queries,
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 0ec316327..0ddc815d9 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -150,14 +150,12 @@ class RoPE : public Custom {
       bool traditional,
       float base,
       float scale,
-      int offset,
       bool forward)
       : Custom(stream, fallback),
         dims_(dims),
         traditional_(traditional),
         base_(base),
         scale_(scale),
-        offset_(offset),
         forward_(forward) {}

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -183,7 +181,6 @@ class RoPE : public Custom {
   bool traditional_;
   float base_;
   float scale_;
-  int offset_;
   bool forward_;
 };
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 5f04f3e69..91a7293fb 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -79,7 +79,17 @@ void init_fast(nb::module_& parent_module) {

   m.def(
       "rope",
-      &mx::fast::rope,
+      [](const mx::array& a,
+         int dims,
+         bool traditional,
+         std::optional<float> base,
+         float scale,
+         const ScalarOrArray& offset,
+         const std::optional<mx::array>& freqs /* = std::nullopt */,
+         mx::StreamOrDevice s /* = {} */) {
+        return mx::fast::rope(
+            a, dims, traditional, base, scale, to_array(offset), freqs, s);
+      },
       "a"_a,
       "dims"_a,
       nb::kw_only(),
@@ -90,7 +100,7 @@ void init_fast(nb::module_& parent_module) {
       "freqs"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Apply rotary positional encoding to the input.
@@ -104,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
             each dimension in the positional encodings. Exactly one of
             ``base`` and ``freqs`` must be ``None``.
         scale (float): The scale used to scale the positions.
-        offset (int): The position offset to start at.
+        offset (int or array): The position offset to start at.
         freqs (array, optional): Optional frequencies to use with RoPE.
             If set, the ``base`` parameter must be ``None``. Default: ``None``.
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 1f5ecb5cf..99a81fb8e 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -22,21 +22,25 @@ namespace mx = mlx::core;
 namespace nb = nanobind;
 using namespace nb::literals;

-using Scalar = std::variant<int, double>;
+using Scalar = std::variant<bool, int, double>;

-mx::Dtype scalar_to_dtype(Scalar scalar) {
-  if (std::holds_alternative<int>(scalar)) {
+mx::Dtype scalar_to_dtype(Scalar s) {
+  if (std::holds_alternative<int>(s)) {
     return mx::int32;
-  } else {
+  } else if (std::holds_alternative<double>(s)) {
     return mx::float32;
+  } else {
+    return mx::bool_;
   }
 }

 double scalar_to_double(Scalar s) {
-  if (std::holds_alternative<double>(s)) {
-    return std::get<double>(s);
+  if (auto pv = std::get_if<int>(&s); pv) {
+    return static_cast<double>(*pv);
+  } else if (auto pv = std::get_if<double>(&s); pv) {
+    return *pv;
   } else {
-    return static_cast<double>(std::get<int>(s));
+    return static_cast<double>(std::get<bool>(s));
   }
 }
@@ -1367,9 +1371,9 @@ void init_ops(nb::module_& m) {
             dtype,
             s);
       },
-      "start"_a,
-      "stop"_a,
-      "step"_a = nb::none(),
+      "start"_a.noconvert(),
+      "stop"_a.noconvert(),
+      "step"_a.noconvert() = nb::none(),
       "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
@@ -1413,8 +1417,8 @@ void init_ops(nb::module_& m) {
             dtype,
             s);
       },
-      "stop"_a,
-      "step"_a = nb::none(),
+      "stop"_a.noconvert(),
+      "step"_a.noconvert() = nb::none(),
       "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index 56006e59f..524b43d1b 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -8,6 +8,7 @@ import mlx_tests


 def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
+    offset = offset.item() if isinstance(offset, mx.array) else offset
     N = x.shape[-2] + offset
     dtype = x.dtype
     half_D = dims // 2
@@ -76,7 +77,7 @@ class TestFast(mlx_tests.MLXTestCase):
         dtypes = [mx.float32, mx.float16, mx.bfloat16]
         bases = [10000.0, 1000000.0]
         scales = [1.0, 2.0]
-        offsets = [0, 3]
+        offsets = [0, 3, mx.array(3)]
         traditional = [True, False]

         for traditional in [True, False]:
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 7d010135a..d76a8143e 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1153,7 +1153,7 @@ class TestOps(mlx_tests.MLXTestCase):
             a = mx.arange(float("inf"), 1, float("inf"))
         with self.assertRaises(ValueError):
             a = mx.arange(float("inf"), 1, 5)
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
             INT_MAX = 2147483647
             a = mx.arange(0, INT_MAX + 1, 1)
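Usage sketch (not part of the diff; the variable names and shapes below are illustrative assumptions): with this change, offset can be passed to mx.fast.rope either as a Python int or as a scalar integer mx.array, so the position offset can live on-device instead of being baked into the RoPE primitive as an attribute.

import mlx.core as mx

# The input must have at least 3 dimensions; (batch, sequence, feature) here.
x = mx.random.normal(shape=(1, 8, 32))

# Host-side int offset (the pre-existing form; per the new int overload in
# mlx/fast.cpp it is now routed through array(offset, int32)).
y_int = mx.fast.rope(
    x, 32, traditional=False, base=10000.0, scale=1.0, offset=3)

# Scalar integer array offset (new). Integer offsets that are not 4 bytes
# wide are cast internally, per the astype(offset, uint32, s) path above.
y_arr = mx.fast.rope(
    x, 32, traditional=False, base=10000.0, scale=1.0, offset=mx.array(3))

assert mx.allclose(y_int, y_arr)  # both forms should agree

The updated test_fast.py exercises the same equivalence by adding mx.array(3) to the offsets list.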