mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00

RoPE with frequencies as optional input (#1337)

* start rope with freq input
* rope with frequencies
* nits
* fix bug
* fix bug + test
* cleanup
* optional base

This commit is contained in:
parent 9d26441224
commit bb1b76d9dc
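In short, the commit makes the per-dimension rotation frequencies an explicit, optional input to mx.fast.rope: pass freqs (an array of size dims // 2) and set base=None, or keep base and omit freqs. A minimal usage sketch, mirroring the Python bindings and tests in this diff (shapes are illustrative):

import mlx.core as mx

x = mx.random.uniform(shape=(2, 4, 8))  # (batch, sequence, feature)

# Default path: frequencies derived from `base`.
y1 = mx.fast.rope(x, 8, traditional=False, base=10000.0, scale=1.0, offset=0)

# New path: explicit frequencies of size dims // 2; `base` must be None.
freqs = mx.random.uniform(shape=(4,))
y2 = mx.fast.rope(
    x, 8, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs
)

Passing both base and freqs, or neither, raises an error, as enforced in mlx/fast.cpp below.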
@@ -4,23 +4,20 @@
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 template <typename T, bool traditional, bool forward>
-[[kernel]] void rope_single(
-    const device T* in [[buffer(0)]],
-    device T* out [[buffer(1)]],
+void rope_single_impl(
+    const device T* in,
+    device T* out,
     constant const int& offset,
-    constant const float& base,
+    const float inv_freq,
     constant const float& scale,
     constant const size_t& stride,
-    uint2 pos [[thread_position_in_grid]],
-    uint2 grid [[threads_per_grid]]) {
-  // Figure out L and d.
+    uint2 pos,
+    uint2 grid) {
   float L = scale * static_cast<float>(offset);
-  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
 
   // Compute costheta, sintheta
-  float theta = L * metal::exp2(-d * base);
+  float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
 
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
   out[out_index_2] = static_cast<T>(rx2);
 }
 
-template <typename T, bool traditional, bool forward, int N = 4>
-[[kernel]] void rope(
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_single(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
     constant const int& offset,
-    constant const float& base,
     constant const float& scale,
+    constant const size_t& stride,
+    constant const float& base [[buffer(10)]],
+    uint2 pos [[thread_position_in_grid]],
+    uint2 grid [[threads_per_grid]]) {
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+  float inv_freq = metal::exp2(-d * base);
+  rope_single_impl<T, traditional, forward>(
+      in, out, offset, inv_freq, scale, stride, pos, grid);
+}
+
+template <typename T, bool traditional, bool forward>
+[[kernel]] void rope_single_freqs(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t& stride,
+    const device float* freqs [[buffer(10)]],
+    constant const size_t& freq_stride [[buffer(11)]],
+    uint2 pos [[thread_position_in_grid]],
+    uint2 grid [[threads_per_grid]]) {
+  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
+  rope_single_impl<T, traditional, forward>(
+      in, out, offset, inv_freq, scale, stride, pos, grid);
+}
+
+template <typename T, bool traditional, bool forward, int N = 4>
+void rope_impl(
+    const device T* in,
+    device T* out,
+    constant const int& offset,
+    const float inv_freq,
+    constant const float& scale,
     constant const size_t strides[3],
     constant const size_t out_strides[3],
     constant const size_t& n_batch,
-    uint3 pos [[thread_position_in_grid]],
-    uint3 grid [[threads_per_grid]]) {
-  // Figure out L and d.
+    uint3 pos,
+    uint3 grid) {
   float L = scale * static_cast<float>(pos.y + offset);
-  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
 
   // Compute costheta, sintheta
-  float theta = L * metal::exp2(-d * base);
+  float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
 
@@ -116,17 +143,85 @@ template <typename T, bool traditional, bool forward, int N = 4>
   }
 }
 
+template <typename T, bool traditional, bool forward, int N = 4>
+[[kernel]] void rope(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t strides[3],
+    constant const size_t out_strides[3],
+    constant const size_t& n_batch,
+    constant const float& base [[buffer(10)]],
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
+  float inv_freq = metal::exp2(-d * base);
+  rope_impl<T, traditional, forward, N>(
+      in,
+      out,
+      offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      grid);
+}
+
+template <typename T, bool traditional, bool forward, int N = 4>
+[[kernel]] void rope_freqs(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    constant const int& offset,
+    constant const float& scale,
+    constant const size_t strides[3],
+    constant const size_t out_strides[3],
+    constant const size_t& n_batch,
+    const device float* freqs [[buffer(10)]],
+    constant const size_t& freq_stride [[buffer(11)]],
+    uint3 pos [[thread_position_in_grid]],
+    uint3 grid [[threads_per_grid]]) {
+  float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
+  rope_impl<T, traditional, forward, N>(
+      in,
+      out,
+      offset,
+      inv_freq,
+      scale,
+      strides,
+      out_strides,
+      n_batch,
+      pos,
+      grid);
+}
+
+// clang-format off
 #define instantiate_rope_g(name, type, traditional, forward) \
   template [[host_name("rope_" #name)]] [[kernel]] void \
   rope<type, traditional, forward>( \
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const int& offset, \
-      constant const float& base, \
      constant const float& scale, \
       constant const size_t strides[3], \
       constant const size_t out_strides[3], \
       constant const size_t& n_batch, \
+      constant const float& base [[buffer(10)]], \
       uint3 pos [[thread_position_in_grid]], \
-      uint3 grid [[threads_per_grid]]);
+      uint3 grid [[threads_per_grid]]); \
+  template [[host_name("rope_freqs_" #name)]] \
+  [[kernel]] void rope_freqs<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t strides[3], \
+      constant const size_t out_strides[3], \
+      constant const size_t& n_batch, \
+      const device float* freqs [[buffer(10)]], \
+      constant const size_t& freq_stride [[buffer(11)]], \
+      uint3 pos [[thread_position_in_grid]], \
+      uint3 grid [[threads_per_grid]]);
 
@@ -136,9 +231,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
       const device type* in [[buffer(0)]], \
       device type* out [[buffer(1)]], \
       constant const int& offset, \
-      constant const float& base, \
       constant const float& scale, \
       constant const size_t& stride, \
+      constant const float& base [[buffer(10)]], \
       uint2 pos [[thread_position_in_grid]], \
-      uint2 grid [[threads_per_grid]]);
+      uint2 grid [[threads_per_grid]]); \
+  template [[host_name("rope_single_freqs_" #name)]] \
+  [[kernel]] void rope_single_freqs<type, traditional, forward>( \
+      const device type* in [[buffer(0)]], \
+      device type* out [[buffer(1)]], \
+      constant const int& offset, \
+      constant const float& scale, \
+      constant const size_t& stride, \
+      const device float* freqs [[buffer(10)]], \
+      constant const size_t& freq_stride [[buffer(11)]], \
+      uint2 pos [[thread_position_in_grid]], \
+      uint2 grid [[threads_per_grid]]);
 
@@ -146,7 +252,6 @@ template <typename T, bool traditional, bool forward, int N = 4>
   instantiate_rope_s(name, type, traditional, forward) \
   instantiate_rope_g(name, type, traditional, forward)
 
-// clang-format off
 instantiate_rope(traditional_float16, half, true, true)
 instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
 instantiate_rope(traditional_float32, float, true, true)
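Both kernel families above funnel into the same rope_single_impl/rope_impl cores; only the derivation of inv_freq differs. In the default path the host is assumed to pass base as the log2 of the user-facing base (which is why the kernel can use metal::exp2; that host-side line is outside this diff), so the two paths coincide when freqs[i] = base ** (i / (dims / 2)). A small sketch of that equivalence:

import math

def inv_freq_default(i, half_dims, base):
    # kernel: inv_freq = metal::exp2(-d * log2(base)) with d = i / half_dims
    return 2.0 ** (-(i / half_dims) * math.log2(base))

def inv_freq_explicit(freq):
    # kernel: inv_freq = 1.0 / freqs[i]
    return 1.0 / freq

half_dims, base = 4, 10000.0
for i in range(half_dims):
    assert math.isclose(
        inv_freq_default(i, half_dims, base),
        inv_freq_explicit(base ** (i / half_dims)),
    )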
@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
   // Special case for inference (single time step and contiguous)
   bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
 
+  bool with_freqs = inputs.size() == 2;
   std::ostringstream kname;
-  kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
+  kname << "rope_" << (single ? "single_" : "")
+        << ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
         << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&offset_, sizeof(int), 2);
-  compute_encoder->setBytes(&base, sizeof(float), 3);
-  compute_encoder->setBytes(&scale_, sizeof(float), 4);
+  compute_encoder->setBytes(&scale_, sizeof(float), 3);
 
   size_t n_batch = in.size() / mat_size;
+  MTL::Size group_dims;
+  MTL::Size grid_dims;
   if (single) {
-    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
+    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
     uint32_t dim0 = dims_ / 2;
-    auto group_dims = get_block_dims(dim0, n_batch, 1);
-    auto grid_dims = MTL::Size(dim0, n_batch, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    group_dims = get_block_dims(dim0, n_batch, 1);
+    grid_dims = MTL::Size(dim0, n_batch, 1);
   } else {
-    compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
-    compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
-    compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
+    compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
+    compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
+    compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
     uint32_t dim0 = dims_ / 2;
     uint32_t dim1 = in.shape(-2);
     uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
-    auto group_dims = get_block_dims(dim0, dim1, dim2);
-    auto grid_dims = MTL::Size(dim0, dim1, dim2);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    group_dims = get_block_dims(dim0, dim1, dim2);
+    grid_dims = MTL::Size(dim0, dim1, dim2);
   }
+
+  if (with_freqs) {
+    auto& freqs = inputs[1];
+    compute_encoder.set_input_array(freqs, 10);
+    auto freq_stride = freqs.strides()[0];
+    compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
+  } else {
+    compute_encoder->setBytes(&base, sizeof(float), 10);
+  }
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core::fast
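Kernel selection here is purely by name: the code above concatenates the flags into strings like rope_single_freqs_float16 or rope_vjp_traditional_float32, which must line up with the host_name instantiations in the Metal file. A sketch of the naming scheme (illustrative helper, not part of the diff):

def rope_kernel_name(single, with_freqs, forward, traditional, tname):
    # Mirrors the kname << ... chain in RoPE::eval_gpu.
    return (
        "rope_"
        + ("single_" if single else "")
        + ("freqs_" if with_freqs else "")
        + ("" if forward else "vjp_")
        + ("traditional_" if traditional else "")
        + tname
    )

assert rope_kernel_name(True, True, True, False, "float16") == "rope_single_freqs_float16"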
mlx/fast.cpp (73 changed lines)
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
 #include <numeric>
 
@@ -323,7 +322,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
 }
 
 array rope(
-    const array& x,
+    std::vector<array> inputs,
     int dims,
     bool traditional,
     float base,
@@ -331,15 +330,23 @@ array rope(
     int offset,
     bool forward,
     StreamOrDevice s) {
+  auto& x = inputs[0];
   if (x.ndim() < 3) {
     std::ostringstream msg;
     msg << "[rope] Input must have at least 3 dimensions but got input with "
         << x.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
+  if (inputs.size() == 2 &&
+      (inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
+    std::ostringstream msg;
+    msg << "[rope] freqs must be one dimensional with size " << dims
+        << " but got shape " << inputs[1].shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
   auto fallback = [dims, traditional, base, scale, offset, forward, s](
-                      const std::vector<array>& inputs) {
+                      std::vector<array> inputs) {
     auto& shape = inputs[0].shape();
     int ndim = shape.size();
     auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
@@ -348,10 +355,20 @@ array rope(
     // Compute sines and cosines
     auto half_dims = dims / 2;
     auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
-    auto freqs = negative(arange(0, half_dims, t, s), s);
-    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
+
+    auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
+      return exp(
+          multiply(
+              arange(0, -half_dims, -1, t, s),
+              array(std::log(base) / half_dims, t),
+              s),
+          s);
+    };
+
+    auto inv_freqs =
+        inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
     auto theta =
-        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
+        multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
     auto coss = cos(theta, s);
     auto sins = sin(theta, s);
 
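The rewritten fallback produces the same default inverse frequencies as the code it replaces: exp(-i * ln(base) / half_dims) = base ** (-i / half_dims) for i = 0 .. half_dims - 1; arange(0, -half_dims, -1) simply folds the old negative() call into the range. A quick check with illustrative values:

import math

base, half_dims = 10000.0, 4
old = [math.exp(-i * math.log(base) / half_dims) for i in range(half_dims)]
new = [math.exp(i * math.log(base) / half_dims) for i in range(0, -half_dims, -1)]
assert all(math.isclose(a, b) for a, b in zip(old, new))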
@@ -409,20 +426,39 @@ array rope(
         x.dtype(),
         std::make_shared<RoPE>(
             stream, fallback, dims, traditional, base, scale, offset, forward),
-        {x});
+        std::move(inputs));
   }
-  return fallback({x})[0];
+  return fallback(std::move(inputs))[0];
 }
 
 array rope(
     const array& x,
     int dims,
     bool traditional,
-    float base,
+    std::optional<float> base,
     float scale,
     int offset,
+    const std::optional<array>& freqs /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
-  return rope(x, dims, traditional, base, scale, offset, true, s);
+  std::vector<array> inputs = {x};
+  if (freqs) {
+    inputs.push_back(astype(*freqs, float32, s));
+    if (base) {
+      throw std::invalid_argument(
+          "[rope] Only one of base or freqs can have a value.");
+    }
+  } else if (!base) {
+    throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
+  }
+  return rope(
+      std::move(inputs),
+      dims,
+      traditional,
+      base.has_value() ? *base : 1.0,
+      scale,
+      offset,
+      true,
+      s);
 }
 
 std::vector<array> RoPE::vjp(
@@ -438,16 +474,27 @@ std::vector<array> RoPE::vjp(
       offset = offset_,
       forward = forward_,
       s](std::vector<array> inputs) {
-    return std::vector<array>{
-        rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
+    return std::vector<array>{rope(
+        std::move(inputs),
+        dims,
+        traditional,
+        base,
+        scale,
+        offset,
+        !forward,
+        s)};
   };
 
+  auto inputs = cotangents;
+  if (primals.size() == 2) {
+    inputs.push_back(primals[1]);
+  }
   return {array(
       cotangents[0].shape(),
       cotangents[0].dtype(),
       std::make_shared<RoPE>(
           s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
-      cotangents)};
+      std::move(inputs))};
 }
 
 bool RoPE::is_equivalent(const Primitive& other) const {
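The VJP runs RoPE in reverse (forward = false) and must rotate by the same angles as the primal, so when the primal carried explicit frequencies the cotangent list is extended with primals[1] before re-dispatch. A sketch of exercising that path (the mx.vjp call is illustrative; the tests below check the same thing through mx.grad):

import mlx.core as mx

dims = 8
freqs = mx.random.uniform(shape=(dims // 2,))
x = mx.random.uniform(shape=(2, 4, dims))

def f(x):
    return mx.fast.rope(
        x, dims, traditional=False, base=None, scale=1.0, offset=0, freqs=freqs
    )

# The backward pass re-enters RoPE with !forward and the same freqs.
out, vjps = mx.vjp(f, [x], [mx.ones_like(x)])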
@@ -25,9 +25,10 @@ array rope(
     const array& x,
     int dims,
     bool traditional,
-    float base,
+    std::optional<float> base,
     float scale,
     int offset,
+    const std::optional<array>& freqs = std::nullopt,
     StreamOrDevice s = {});
 
 /** Computes: O = softmax(Q @ K.T) @ V **/
@@ -79,12 +79,13 @@ void init_fast(nb::module_& parent_module) {
       "dims"_a,
       nb::kw_only(),
       "traditional"_a,
-      "base"_a,
+      "base"_a.none(),
       "scale"_a,
       "offset"_a,
+      "freqs"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Apply rotary positional encoding to the input.
 
@@ -94,11 +95,13 @@ void init_fast(nb::module_& parent_module) {
             is larger than dims then the rest is left unchanged.
           traditional (bool): If set to ``True`` choose the traditional
             implementation which rotates consecutive dimensions.
-          base (float): The base used to compute angular frequency for
-            each dimension in the positional encodings.
+          base (float, optional): The base used to compute angular frequency for
+            each dimension in the positional encodings. Exactly one of ``base`` and
+            ``freqs`` must be ``None``.
           scale (float): The scale used to scale the positions.
           offset (int): The position offset to start at.
+          freqs (array, optional): Optional frequencies to use with RoPE.
+            If set, the ``base`` parameter must be ``None``. ``Default: None``.
 
       Returns:
           array: The output array.
       )pbdoc");
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
       "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
 
@@ -7,13 +7,18 @@ import mlx.core as mx
 import mlx_tests
 
 
-def rope_orig(x, dims, traditional, base, scale, offset):
+def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
     N = x.shape[1] + offset
     dtype = x.dtype
     half_D = dims // 2
     positions = mx.arange(offset, N, dtype=dtype) * scale
-    freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
-    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
+    if freqs is None:
+        inv_freqs = mx.exp(
+            -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
+        )
+    else:
+        inv_freqs = 1 / freqs
+    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
         x1 = x[..., :dims:2]
@@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
+    def test_rope_with_freqs(self):
+        # Check throws
+        T = 4
+        dims = 8
+        x = mx.random.uniform(shape=(2, T, dims))
+
+        with self.assertRaises(ValueError):
+            freqs = mx.random.uniform(shape=(dims - 1,))
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+        with self.assertRaises(ValueError):
+            freqs = mx.random.uniform(shape=(1, dims))
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+
+        freqs = mx.random.uniform(shape=(dims // 2,))
+
+        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
+        rx_fast = mx.fast.rope(
+            x,
+            dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=0,
+            freqs=freqs,
+        )
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        # Test single vector
+        x = mx.random.uniform(shape=(1, 1, dims))
+        rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
+        rx_fast = mx.fast.rope(
+            x,
+            dims,
+            traditional=False,
+            base=None,
+            scale=1.0,
+            offset=0,
+            freqs=freqs,
+        )
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
+
+        # Test grad with freqs
+        f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()
+        f2 = lambda x, y: (
+            mx.fast.rope(
+                x,
+                dims,
+                traditional=False,
+                base=None,
+                scale=1.0,
+                offset=0,
+                freqs=freqs,
+            )
+            * y
+        ).sum()
+
+        x = mx.random.uniform(shape=(2, 4, dims))
+        y = mx.random.uniform(shape=(2, 4, dims))
+        g1 = mx.grad(f1)(x, y)
+        g2 = mx.grad(f2)(x, y)
+        self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
+
     def test_rope_grad(self):
         D = 32
         defaults = (D, 10000.0, 1.0, 0, False)