diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal
index ef218bf36..d6f44591e 100644
--- a/mlx/backend/metal/kernels/rope.metal
+++ b/mlx/backend/metal/kernels/rope.metal
@@ -4,23 +4,20 @@
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
-
 template <typename T, bool traditional, bool forward>
-[[kernel]] void rope_single(
-    const device T* in [[buffer(0)]],
-    device T* out [[buffer(1)]],
+void rope_single_impl(
+    const device T* in,
+    device T* out,
     constant const int& offset,
-    constant const float& base,
+    const float inv_freq,
     constant const float& scale,
     constant const size_t& stride,
-    uint2 pos [[thread_position_in_grid]],
-    uint2 grid [[threads_per_grid]]) {
-  // Figure out L and d.
+    uint2 pos,
+    uint2 grid) {
   float L = scale * static_cast<float>(offset);
-  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
 
   // Compute costheta, sintheta
-  float theta = L * metal::exp2(-d * base);
+  float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
   out[out_index_2] = static_cast<T>(rx2);
 }
 
-template <typename T, bool traditional, bool forward>
-[[kernel]] void rope(
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_single(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
     constant const int& offset,
-    constant const float& base,
+    constant const float& scale,
+    constant const size_t& stride,
+    constant const float& base [[buffer(10)]],
+    uint2 pos [[thread_position_in_grid]],
+    uint2 grid [[threads_per_grid]]) {
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+  float inv_freq = metal::exp2(-d * base);
+  rope_single_impl<T, traditional, forward>(
+      in, out, offset, inv_freq, scale, stride, pos, grid);
+}
+
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_single_freqs(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t& stride,
+    const device float* freqs [[buffer(10)]],
+    constant const size_t& freq_stride [[buffer(11)]],
+    uint2 pos [[thread_position_in_grid]],
+    uint2 grid [[threads_per_grid]]) {
+  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
+  rope_single_impl<T, traditional, forward>(
+      in, out, offset, inv_freq, scale, stride, pos, grid);
+}
+
+template <typename T, bool traditional, bool forward>
+void rope_impl(
+    const device T* in,
+    device T* out,
+    constant const int& offset,
+    const float inv_freq,
     constant const float& scale,
     constant const size_t strides[3],
     constant const size_t out_strides[3],
     constant const size_t& n_batch,
-    uint3 pos [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]) {
-  // Figure out L and d.
+    uint3 pos,
+    uint3 grid) {
   float L = scale * static_cast<float>(pos.y + offset);
-  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
 
   // Compute costheta, sintheta
-  float theta = L * metal::exp2(-d * base);
+  float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
@@ -116,37 +143,115 @@ template <typename T, bool traditional, bool forward>
   }
 }
 
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t strides[3],
+    constant const size_t out_strides[3],
+    constant const size_t& n_batch,
+    constant const float& base [[buffer(10)]],
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+  float inv_freq = metal::exp2(-d * base);
+  rope_impl<T, traditional, forward>(
+      in,
+      out,
+      offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      grid);
+}
+
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_freqs(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t strides[3],
+    constant const size_t out_strides[3],
+    constant const size_t& n_batch,
+    const device float* freqs [[buffer(10)]],
+    constant const size_t& freq_stride [[buffer(11)]],
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
+  rope_impl<T, traditional, forward>(
+      in,
+      out,
+      offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      grid);
+}
+
+// clang-format off
 #define instantiate_rope_g(name, type, traditional, forward) \
   template [[host_name("rope_" #name)]] [[kernel]] void \
   rope<type, traditional, forward>( \
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const int& offset, \
-      constant const float& base, \
       constant const float& scale, \
       constant const size_t strides[3], \
       constant const size_t out_strides[3], \
       constant const size_t& n_batch, \
+      constant const float& base [[buffer(10)]], \
+      uint3 pos [[thread_position_in_grid]], \
+      uint3 grid [[threads_per_grid]]); \
+  template [[host_name("rope_freqs_" #name)]] \
+  [[kernel]] void rope_freqs<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t strides[3], \
+      constant const size_t out_strides[3], \
+      constant const size_t& n_batch, \
+      const device float* freqs [[buffer(10)]], \
+      constant const size_t& freq_stride [[buffer(11)]], \
       uint3 pos [[thread_position_in_grid]], \
       uint3 grid [[threads_per_grid]]);
 
-#define instantiate_rope_s(name, type, traditional, forward) \
-  template [[host_name("rope_single_" #name)]] [[kernel]] void \
-  rope_single<type, traditional, forward>( \
-      const device type* in [[buffer(0)]], \
-      device type* out [[buffer(1)]], \
-      constant const int& offset, \
-      constant const float& base, \
-      constant const float& scale, \
-      constant const size_t& stride, \
-      uint2 pos [[thread_position_in_grid]], \
+#define instantiate_rope_s(name, type, traditional, forward) \
+  template [[host_name("rope_single_" #name)]] [[kernel]] void \
+  rope_single<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t& stride, \
+      constant const float& base [[buffer(10)]], \
+      uint2 pos [[thread_position_in_grid]], \
+      uint2 grid [[threads_per_grid]]); \
+  template [[host_name("rope_single_freqs_" #name)]] \
+  [[kernel]] void rope_single_freqs<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t& stride, \
+      const device float* freqs [[buffer(10)]], \
+      constant const size_t& freq_stride [[buffer(11)]], \
+      uint2 pos [[thread_position_in_grid]], \
       uint2 grid [[threads_per_grid]]);
 
 #define instantiate_rope(name, type, traditional, forward) \
   instantiate_rope_s(name, type, traditional, forward) \
-  instantiate_rope_g(name, type, traditional, forward)
+  instantiate_rope_g(name, type, traditional, forward)
 
-// clang-format off
 instantiate_rope(traditional_float16, half, true, true)
 instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
 instantiate_rope(traditional_float32, float, true, true)
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index 4d51fd75c..174f63bbc 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
   // Special case for inference (single time step and contiguous)
   bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
 
+  bool with_freqs = inputs.size() == 2;
   std::ostringstream kname;
-  kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
+  kname << "rope_" << (single ? "single_" : "")
+        << ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
         << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&offset_, sizeof(int), 2);
-  compute_encoder->setBytes(&base, sizeof(float), 3);
-  compute_encoder->setBytes(&scale_, sizeof(float), 4);
+  compute_encoder->setBytes(&scale_, sizeof(float), 3);
 
   size_t n_batch = in.size() / mat_size;
+  MTL::Size group_dims;
+  MTL::Size grid_dims;
   if (single) {
-    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
+    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
     uint32_t dim0 = dims_ / 2;
-    auto group_dims = get_block_dims(dim0, n_batch, 1);
-    auto grid_dims = MTL::Size(dim0, n_batch, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    group_dims = get_block_dims(dim0, n_batch, 1);
+    grid_dims = MTL::Size(dim0, n_batch, 1);
   } else {
-    compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
-    compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
-    compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
+    compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
+    compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
+    compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
     uint32_t dim0 = dims_ / 2;
     uint32_t dim1 = in.shape(-2);
     uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
-    auto group_dims = get_block_dims(dim0, dim1, dim2);
-    auto grid_dims = MTL::Size(dim0, dim1, dim2);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    group_dims = get_block_dims(dim0, dim1, dim2);
+    grid_dims = MTL::Size(dim0, dim1, dim2);
   }
+
+  if (with_freqs) {
+    auto& freqs = inputs[1];
+    compute_encoder.set_input_array(freqs, 10);
+    auto freq_stride = freqs.strides()[0];
+    compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
+  } else {
+    compute_encoder->setBytes(&base, sizeof(float), 10);
+  }
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core::fast
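Note (reviewer aid, not part of the patch): the host code above selects a Metal kernel by concatenating name fragments, so every variant instantiated in rope.metal needs a matching [[host_name]]. A minimal sketch of the naming scheme, assuming only the fragments visible in this diff:

    # Kernel name = "rope_" + single + freqs + vjp + traditional + type
    single, with_freqs, forward, traditional = True, True, True, True
    kname = (
        "rope_"
        + ("single_" if single else "")
        + ("freqs_" if with_freqs else "")
        + ("" if forward else "vjp_")
        + ("traditional_" if traditional else "")
        + "float16"
    )
    # -> "rope_single_freqs_traditional_float16", which must match the
    # [[host_name("rope_single_freqs_" #name)]] instantiation above.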
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index e067c9b4b..dfe68827f 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 #include
 #include
@@ -323,7 +322,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
 }
 
 array rope(
-    const array& x,
+    std::vector<array> inputs,
     int dims,
     bool traditional,
     float base,
@@ -331,15 +330,23 @@
     int offset,
     bool forward,
     StreamOrDevice s) {
+  auto& x = inputs[0];
   if (x.ndim() < 3) {
     std::ostringstream msg;
     msg << "[rope] Input must have at least 3 dimensions but got input with "
         << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
+  if (inputs.size() == 2 &&
+      (inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
+    std::ostringstream msg;
+    msg << "[rope] freqs must be one dimensional with size " << dims
+        << " but got shape " << inputs[1].shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
   auto fallback = [dims, traditional, base, scale, offset, forward, s](
-                      const std::vector<array>& inputs) {
+                      std::vector<array> inputs) {
     auto& shape = inputs[0].shape();
     int ndim = shape.size();
     auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
@@ -348,10 +355,20 @@ array rope(
     // Compute sines and cosines
     auto half_dims = dims / 2;
     auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
-    auto freqs = negative(arange(0, half_dims, t, s), s);
-    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
+
+    auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
+      return exp(
+          multiply(
+              arange(0, -half_dims, -1, t, s),
+              array(std::log(base) / half_dims, t),
+              s),
+          s);
+    };
+
+    auto inv_freqs =
+        inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
     auto theta =
-        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
+        multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
     auto coss = cos(theta, s);
     auto sins = sin(theta, s);
@@ -409,20 +426,39 @@ array rope(
         x.dtype(),
         std::make_shared<RoPE>(
             stream, fallback, dims, traditional, base, scale, offset, forward),
-        {x});
+        std::move(inputs));
   }
-  return fallback({x})[0];
+  return fallback(std::move(inputs))[0];
 }
 
 array rope(
     const array& x,
     int dims,
     bool traditional,
-    float base,
+    std::optional<float> base,
     float scale,
     int offset,
+    const std::optional<array>& freqs /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
-  return rope(x, dims, traditional, base, scale, offset, true, s);
+  std::vector<array> inputs = {x};
+  if (freqs) {
+    inputs.push_back(astype(*freqs, float32, s));
+    if (base) {
+      throw std::invalid_argument(
+          "[rope] Only one of base or freqs can have a value.");
+    }
+  } else if (!base) {
+    throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
+  }
+  return rope(
+      std::move(inputs),
+      dims,
+      traditional,
+      base.has_value() ? *base : 1.0,
+      scale,
+      offset,
+      true,
+      s);
 }
 
 std::vector<array> RoPE::vjp(
@@ -438,16 +474,27 @@ std::vector<array> RoPE::vjp(
       offset = offset_,
       forward = forward_,
       s](std::vector<array> inputs) {
-    return std::vector<array>{
-        rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
+    return std::vector<array>{rope(
+        std::move(inputs),
+        dims,
+        traditional,
+        base,
+        scale,
+        offset,
+        !forward,
+        s)};
   };
+  auto inputs = cotangents;
+  if (primals.size() == 2) {
+    inputs.push_back(primals[1]);
+  }
   return {array(
       cotangents[0].shape(),
       cotangents[0].dtype(),
       std::make_shared<RoPE>(
           s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
-      cotangents)};
+      std::move(inputs))};
 }
 
 bool RoPE::is_equivalent(const Primitive& other) const {
diff --git a/mlx/fast.h b/mlx/fast.h
index 48e95a768..0274bf6dd 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -25,9 +25,10 @@ array rope(
     const array& x,
     int dims,
     bool traditional,
-    float base,
+    std::optional<float> base,
     float scale,
     int offset,
+    const std::optional<array>& freqs = std::nullopt,
     StreamOrDevice s = {});
 
 /** Computes: O = softmax(Q @ K.T) @ V **/
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 349618c23..4c25c7ac7 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -79,26 +79,29 @@ void init_fast(nb::module_& parent_module) {
       "dims"_a,
       nb::kw_only(),
       "traditional"_a,
-      "base"_a,
+      "base"_a.none(),
       "scale"_a,
       "offset"_a,
+      "freqs"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Apply rotary positional encoding to the input.
 
         Args:
             a (array): Input array.
             dims (int): The feature dimensions to be rotated. If the input feature
-                is larger than dims then the rest is left unchanged.
+              is larger than dims then the rest is left unchanged.
             traditional (bool): If set to ``True`` choose the traditional
-                implementation which rotates consecutive dimensions.
-            base (float): The base used to compute angular frequency for
-                each dimension in the positional encodings.
+              implementation which rotates consecutive dimensions.
+            base (float, optional): The base used to compute angular frequency for
+              each dimension in the positional encodings. Exactly one of ``base`` and
+              ``freqs`` must be ``None``.
             scale (float): The scale used to scale the positions.
             offset (int): The position offset to start at.
-
+            freqs (array, optional): Optional frequencies to use with RoPE.
+              If set, the ``base`` parameter must be ``None``. ``Default: None``.
         Returns:
             array: The output array.
       )pbdoc");
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
       "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index 6ee5b7e3d..c68f7e423 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -7,13 +7,18 @@ import mlx.core as mx
 import mlx_tests
 
 
-def rope_orig(x, dims, traditional, base, scale, offset):
+def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
     N = x.shape[1] + offset
     dtype = x.dtype
     half_D = dims // 2
     positions = mx.arange(offset, N, dtype=dtype) * scale
-    freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
-    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+    if freqs is None:
+        inv_freqs = mx.exp(
+            -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
+        )
+    else:
+        inv_freqs = 1 / freqs
+    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
         x1 = x[..., :dims:2]
@@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
+    def test_rope_with_freqs(self):
+        # Check throws
+        T = 4
+        dims = 8
+        x = mx.random.uniform(shape=(2, T, dims))
+
+        with self.assertRaises(ValueError):
+            freqs = mx.random.uniform(shape=(dims - 1,))
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+        with self.assertRaises(ValueError):
+            freqs = mx.random.uniform(shape=(1, dims))
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+
+        freqs = mx.random.uniform(shape=(dims // 2,))
+
+        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
+        rx_fast = mx.fast.rope(
+            x,
+            dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=0,
+            freqs=freqs,
+        )
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        # Test single vector
+        x = mx.random.uniform(shape=(1, 1, dims))
+        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
+        rx_fast = mx.fast.rope(
+            x,
+            dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=0,
+            freqs=freqs,
+        )
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        # Test grad with freqs
+        f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()
+        f2 = lambda x, y: (
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+            * y
+        ).sum()
+
+        x = mx.random.uniform(shape=(2, 4, dims))
+        y = mx.random.uniform(shape=(2, 4, dims))
+        g1 = mx.grad(f1)(x, y)
+        g2 = mx.grad(f2)(x, y)
+        self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
+
     def test_rope_grad(self):
         D = 32
         defaults = (D, 10000.0, 1.0, 0, False)
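Usage note (illustrative sketch, not part of the patch): after this change, mx.fast.rope takes either a scalar base or an explicit freqs array of shape (dims // 2,), and exactly one of the two must be None. Following the semantics of rope_orig above, a custom frequency f for dimension i gives a rotation angle of position * scale / f:

    import mlx.core as mx

    dims = 32
    x = mx.random.uniform(shape=(1, 8, dims))  # (batch, sequence, features)

    # Default path: angles derived from `base`; `freqs` is left as None.
    y_base = mx.fast.rope(
        x, dims, traditional=False, base=10000.0, scale=1.0, offset=0
    )

    # New path: explicit per-dimension frequencies; `base` must be None.
    freqs = mx.random.uniform(low=1.0, high=100.0, shape=(dims // 2,))
    y_freqs = mx.fast.rope(
        x, dims, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs
    )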