RoPE with frequencies as optional input (#1337)

* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
Awni Hannun 2024-08-19 18:30:50 -07:00 committed by GitHub
parent 9d26441224
commit bb1b76d9dc
6 changed files with 319 additions and 69 deletions
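
In short, `mx.fast.rope` can now take an explicit `freqs` array in place of the `base` argument; exactly one of the two must be ``None``. A minimal usage sketch based on the new binding and tests below (shapes and values are illustrative, not part of the commit):

    import mlx.core as mx

    x = mx.random.uniform(shape=(2, 4, 8))  # (batch, sequence, feature)
    dims = 8

    # Existing behavior: give a base and leave freqs unset.
    y_base = mx.fast.rope(x, dims, traditional=False, base=10000.0, scale=1.0, offset=0)

    # New behavior: give explicit per-pair frequencies and set base=None.
    # freqs must be one dimensional with size dims // 2.
    freqs = mx.random.uniform(shape=(dims // 2,))
    y_freqs = mx.fast.rope(
        x, dims, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs
    )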


@@ -4,23 +4,20 @@
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 template <typename T, bool traditional, bool forward>
-[[kernel]] void rope_single(
-    const device T* in [[buffer(0)]],
-    device T* out [[buffer(1)]],
+void rope_single_impl(
+    const device T* in,
+    device T* out,
     constant const int& offset,
-    constant const float& base,
+    const float inv_freq,
     constant const float& scale,
     constant const size_t& stride,
-    uint2 pos [[thread_position_in_grid]],
-    uint2 grid [[threads_per_grid]]) {
-  // Figure out L and d.
+    uint2 pos,
+    uint2 grid) {
   float L = scale * static_cast<float>(offset);
-  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
 
   // Compute costheta, sintheta
-  float theta = L * metal::exp2(-d * base);
+  float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
   out[out_index_2] = static_cast<T>(rx2);
 }
 
-template <typename T, bool traditional, bool forward, int N = 4>
-[[kernel]] void rope(
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_single(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
     constant const int& offset,
-    constant const float& base,
+    constant const float& scale,
+    constant const size_t& stride,
+    constant const float& base [[buffer(10)]],
+    uint2 pos [[thread_position_in_grid]],
+    uint2 grid [[threads_per_grid]]) {
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+  float inv_freq = metal::exp2(-d * base);
+  rope_single_impl<T, traditional, forward>(
+      in, out, offset, inv_freq, scale, stride, pos, grid);
+}
+
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_single_freqs(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t& stride,
+    const device float* freqs [[buffer(10)]],
+    constant const size_t& freq_stride [[buffer(11)]],
+    uint2 pos [[thread_position_in_grid]],
+    uint2 grid [[threads_per_grid]]) {
+  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
+  rope_single_impl<T, traditional, forward>(
+      in, out, offset, inv_freq, scale, stride, pos, grid);
+}
+
+template <typename T, bool traditional, bool forward, int N = 4>
+void rope_impl(
+    const device T* in,
+    device T* out,
+    constant const int& offset,
+    const float inv_freq,
     constant const float& scale,
     constant const size_t strides[3],
     constant const size_t out_strides[3],
     constant const size_t& n_batch,
-    uint3 pos [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]) {
-  // Figure out L and d.
+    uint3 pos,
+    uint3 grid) {
   float L = scale * static_cast<float>(pos.y + offset);
-  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
 
   // Compute costheta, sintheta
-  float theta = L * metal::exp2(-d * base);
+  float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
@@ -116,37 +143,115 @@ template <typename T, bool traditional, bool forward, int N = 4>
   }
 }
 
+template <typename T, bool traditional, bool forward, int N = 4>
+[[kernel]] void rope(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t strides[3],
+    constant const size_t out_strides[3],
+    constant const size_t& n_batch,
+    constant const float& base [[buffer(10)]],
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+  float inv_freq = metal::exp2(-d * base);
+  rope_impl<T, traditional, forward, N>(
+      in,
+      out,
+      offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      grid);
+}
+
+template <typename T, bool traditional, bool forward, int N = 4>
+[[kernel]] void rope_freqs(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t strides[3],
+    constant const size_t out_strides[3],
+    constant const size_t& n_batch,
+    const device float* freqs [[buffer(10)]],
+    constant const size_t& freq_stride [[buffer(11)]],
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
+  rope_impl<T, traditional, forward, N>(
+      in,
+      out,
+      offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      grid);
+}
+
+// clang-format off
 #define instantiate_rope_g(name, type, traditional, forward) \
   template [[host_name("rope_" #name)]] [[kernel]] void \
   rope<type, traditional, forward>( \
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const int& offset, \
-      constant const float& base, \
       constant const float& scale, \
       constant const size_t strides[3], \
       constant const size_t out_strides[3], \
       constant const size_t& n_batch, \
+      constant const float& base [[buffer(10)]], \
       uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]);
+      uint3 grid [[threads_per_grid]]); \
+  template [[host_name("rope_freqs_" #name)]] \
+  [[kernel]] void rope_freqs<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t strides[3], \
+      constant const size_t out_strides[3], \
+      constant const size_t& n_batch, \
+      const device float* freqs [[buffer(10)]], \
+      constant const size_t& freq_stride [[buffer(11)]], \
+      uint3 pos [[thread_position_in_grid]], \
+      uint3 grid [[threads_per_grid]]);
 
 #define instantiate_rope_s(name, type, traditional, forward) \
   template [[host_name("rope_single_" #name)]] [[kernel]] void \
   rope_single<type, traditional, forward>( \
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const int& offset, \
-      constant const float& base, \
       constant const float& scale, \
       constant const size_t& stride, \
+      constant const float& base [[buffer(10)]], \
       uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]);
+      uint2 grid [[threads_per_grid]]); \
+  template [[host_name("rope_single_freqs_" #name)]] \
+  [[kernel]] void rope_single_freqs<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t& stride, \
+      const device float* freqs [[buffer(10)]], \
+      constant const size_t& freq_stride [[buffer(11)]], \
      uint2 pos [[thread_position_in_grid]], \
      uint2 grid [[threads_per_grid]]);
 
 #define instantiate_rope(name, type, traditional, forward) \
   instantiate_rope_s(name, type, traditional, forward) \
   instantiate_rope_g(name, type, traditional, forward)
 
-// clang-format off
 instantiate_rope(traditional_float16, half, true, true)
 instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
 instantiate_rope(traditional_float32, float, true, true)


@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
   // Special case for inference (single time step and contiguous)
   bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
 
+  bool with_freqs = inputs.size() == 2;
   std::ostringstream kname;
-  kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
+  kname << "rope_" << (single ? "single_" : "")
+        << ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
         << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&offset_, sizeof(int), 2);
-  compute_encoder->setBytes(&base, sizeof(float), 3);
-  compute_encoder->setBytes(&scale_, sizeof(float), 4);
+  compute_encoder->setBytes(&scale_, sizeof(float), 3);
 
   size_t n_batch = in.size() / mat_size;
+  MTL::Size group_dims;
+  MTL::Size grid_dims;
   if (single) {
-    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
+    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
     uint32_t dim0 = dims_ / 2;
-    auto group_dims = get_block_dims(dim0, n_batch, 1);
-    auto grid_dims = MTL::Size(dim0, n_batch, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    group_dims = get_block_dims(dim0, n_batch, 1);
+    grid_dims = MTL::Size(dim0, n_batch, 1);
   } else {
-    compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
-    compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
-    compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
+    compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
+    compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
+    compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
     uint32_t dim0 = dims_ / 2;
     uint32_t dim1 = in.shape(-2);
     uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
-    auto group_dims = get_block_dims(dim0, dim1, dim2);
-    auto grid_dims = MTL::Size(dim0, dim1, dim2);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    group_dims = get_block_dims(dim0, dim1, dim2);
+    grid_dims = MTL::Size(dim0, dim1, dim2);
   }
+
+  if (with_freqs) {
+    auto& freqs = inputs[1];
+    compute_encoder.set_input_array(freqs, 10);
+    auto freq_stride = freqs.strides()[0];
+    compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
+  } else {
+    compute_encoder->setBytes(&base, sizeof(float), 10);
+  }
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core::fast


@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 
 #include <cassert>
 #include <numeric>
@@ -323,7 +322,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
 }
 
 array rope(
-    const array& x,
+    std::vector<array> inputs,
     int dims,
     bool traditional,
     float base,
@@ -331,15 +330,23 @@ array rope(
     int offset,
     bool forward,
     StreamOrDevice s) {
+  auto& x = inputs[0];
   if (x.ndim() < 3) {
     std::ostringstream msg;
     msg << "[rope] Input must have at least 3 dimensions but got input with "
         << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
+  if (inputs.size() == 2 &&
+      (inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
+    std::ostringstream msg;
+    msg << "[rope] freqs must be one dimensional with size " << dims
+        << " but got shape " << inputs[1].shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
   auto fallback = [dims, traditional, base, scale, offset, forward, s](
-                      const std::vector<array>& inputs) {
+                      std::vector<array> inputs) {
     auto& shape = inputs[0].shape();
     int ndim = shape.size();
     auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
@@ -348,10 +355,20 @@ array rope(
     // Compute sines and cosines
     auto half_dims = dims / 2;
     auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
-    auto freqs = negative(arange(0, half_dims, t, s), s);
-    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
+
+    auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
+      return exp(
+          multiply(
+              arange(0, -half_dims, -1, t, s),
+              array(std::log(base) / half_dims, t),
+              s),
+          s);
+    };
+
+    auto inv_freqs =
+        inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
 
     auto theta =
-        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
+        multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
     auto coss = cos(theta, s);
     auto sins = sin(theta, s);
@@ -409,20 +426,39 @@ array rope(
         x.dtype(),
         std::make_shared<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset, forward),
-        {x});
+        std::move(inputs));
   }
-  return fallback({x})[0];
+  return fallback(std::move(inputs))[0];
 }
 
 array rope(
     const array& x,
     int dims,
     bool traditional,
-    float base,
+    std::optional<float> base,
     float scale,
     int offset,
+    const std::optional<array>& freqs /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
-  return rope(x, dims, traditional, base, scale, offset, true, s);
+  std::vector<array> inputs = {x};
+  if (freqs) {
+    inputs.push_back(astype(*freqs, float32, s));
+    if (base) {
+      throw std::invalid_argument(
+          "[rope] Only one of base or freqs can have a value.");
+    }
+  } else if (!base) {
+    throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
+  }
+  return rope(
+      std::move(inputs),
+      dims,
+      traditional,
+      base.has_value() ? *base : 1.0,
+      scale,
+      offset,
+      true,
+      s);
 }
 
 std::vector<array> RoPE::vjp(
@@ -438,16 +474,27 @@ std::vector<array> RoPE::vjp(
       offset = offset_,
       forward = forward_,
       s](std::vector<array> inputs) {
-    return std::vector<array>{
-        rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
+    return std::vector<array>{rope(
+        std::move(inputs),
+        dims,
+        traditional,
+        base,
+        scale,
+        offset,
+        !forward,
+        s)};
   };
 
+  auto inputs = cotangents;
+  if (primals.size() == 2) {
+    inputs.push_back(primals[1]);
+  }
   return {array(
       cotangents[0].shape(),
       cotangents[0].dtype(),
       std::make_shared<RoPE>(
          s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
-      cotangents)};
+      std::move(inputs))};
 }
 
 bool RoPE::is_equivalent(const Primitive& other) const {
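
A quick sanity check (not part of the commit) of the relationship implied by the fallback above: frequencies built as the reciprocal of the default inverse frequencies, i.e. the inverse of exp(-i * log(base) / (dims // 2)), should reproduce the base-only path up to numerical precision. The tolerance and shapes here are illustrative:

    import math

    import mlx.core as mx

    dims, base = 8, 10000.0
    half = dims // 2
    x = mx.random.uniform(shape=(2, 6, dims))

    # Reciprocal of the default inverse frequencies exp(-i * log(base) / half).
    freqs = mx.exp(mx.arange(0.0, half) * (math.log(base) / half))

    y_base = mx.fast.rope(x, dims, traditional=False, base=base, scale=1.0, offset=0)
    y_freqs = mx.fast.rope(
        x, dims, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs
    )
    print(mx.abs(y_base - y_freqs).max())  # expected to be very small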


@@ -25,9 +25,10 @@ array rope(
     const array& x,
     int dims,
     bool traditional,
-    float base,
+    std::optional<float> base,
     float scale,
     int offset,
+    const std::optional<array>& freqs = std::nullopt,
     StreamOrDevice s = {});
 
 /** Computes: O = softmax(Q @ K.T) @ V **/


@@ -79,26 +79,29 @@ void init_fast(nb::module_& parent_module) {
       "dims"_a,
       nb::kw_only(),
       "traditional"_a,
-      "base"_a,
+      "base"_a.none(),
       "scale"_a,
       "offset"_a,
+      "freqs"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Apply rotary positional encoding to the input.
 
         Args:
             a (array): Input array.
             dims (int): The feature dimensions to be rotated. If the input feature
                 is larger than dims then the rest is left unchanged.
             traditional (bool): If set to ``True`` choose the traditional
                 implementation which rotates consecutive dimensions.
-            base (float): The base used to compute angular frequency for
-                each dimension in the positional encodings.
+            base (float, optional): The base used to compute angular frequency for
+                each dimension in the positional encodings. Exactly one of ``base`` and
+                ``freqs`` must be ``None``.
             scale (float): The scale used to scale the positions.
             offset (int): The position offset to start at.
+            freqs (array, optional): Optional frequencies to use with RoPE.
+                If set, the ``base`` parameter must be ``None``. Default: ``None``.
 
         Returns:
             array: The output array.
       )pbdoc");
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
       "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.


@@ -7,13 +7,18 @@ import mlx.core as mx
 import mlx_tests
 
 
-def rope_orig(x, dims, traditional, base, scale, offset):
+def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
     N = x.shape[1] + offset
     dtype = x.dtype
     half_D = dims // 2
     positions = mx.arange(offset, N, dtype=dtype) * scale
-    freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
-    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+    if freqs is None:
+        inv_freqs = mx.exp(
+            -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
+        )
+    else:
+        inv_freqs = 1 / freqs
+    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
         x1 = x[..., :dims:2]
@@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
+    def test_rope_with_freqs(self):
+        # Check throws
+        T = 4
+        dims = 8
+        x = mx.random.uniform(shape=(2, T, dims))
+
+        with self.assertRaises(ValueError):
+            freqs = mx.random.uniform(shape=(dims - 1,))
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+        with self.assertRaises(ValueError):
+            freqs = mx.random.uniform(shape=(1, dims))
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+
+        freqs = mx.random.uniform(shape=(dims // 2,))
+
+        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
+        rx_fast = mx.fast.rope(
+            x,
+            dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=0,
+            freqs=freqs,
+        )
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        # Test single vector
+        x = mx.random.uniform(shape=(1, 1, dims))
+        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
+        rx_fast = mx.fast.rope(
+            x,
+            dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=0,
+            freqs=freqs,
+        )
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        # Test grad with freqs
+        f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()
+        f2 = lambda x, y: (
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+            * y
+        ).sum()
+        x = mx.random.uniform(shape=(2, 4, dims))
+        y = mx.random.uniform(shape=(2, 4, dims))
+        g1 = mx.grad(f1)(x, y)
+        g2 = mx.grad(f2)(x, y)
+        self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
+
     def test_rope_grad(self):
         D = 32
         defaults = (D, 10000.0, 1.0, 0, False)