mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-08 21:16:25 +08:00
Compare commits
1 Commits
main
...
batch_rope
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c076794a22 |
@@ -230,9 +230,6 @@ jobs:
|
||||
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||
rm -rf ccache-4.11.3-linux-x86_64
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Set CCache size
|
||||
command: ccache --max-size 1G
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
@@ -263,6 +260,7 @@ jobs:
|
||||
command: |
|
||||
ccache --show-stats
|
||||
ccache --zero-stats
|
||||
ccache --max-size 400MB
|
||||
ccache --cleanup
|
||||
- save_cache:
|
||||
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||
|
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
void rope_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const size_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
uint3 pos,
|
||||
uint3 grid) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
auto n_head_up = N * ((n_head + N - 1) / N);
|
||||
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
|
||||
auto batch_idx = (pos.z * N) / n_head_up;
|
||||
auto batch_offset = offset[batch_idx * offset_stride];
|
||||
float L = scale * static_cast<float>(pos.y + batch_offset);
|
||||
auto mat_idx = batch_idx * n_head + head_idx;
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
@@ -102,20 +108,19 @@ void rope_impl(
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const size_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
@@ -168,11 +175,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const size_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_freqs_" #name)]] \
|
||||
[[kernel]] void rope_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
|
||||
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
|
||||
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_single_freqs_" #name)]] \
|
||||
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]);
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
|
@@ -18,30 +18,26 @@ void RoPE::eval_gpu(
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t strides[3];
|
||||
size_t out_strides[3];
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
if (dims_ < in.shape(-1)) {
|
||||
int n_batch = in.shape(0);
|
||||
int n_head = in.shape(1);
|
||||
int n_seq = in.shape(2);
|
||||
int n_dim = in.shape(3);
|
||||
size_t mat_size = n_seq * n_dim;
|
||||
|
||||
if (dims_ < n_dim) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
copy_gpu(in, out, ctype, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
strides[1] = out.strides()[2];
|
||||
strides[2] = out.strides()[3];
|
||||
} else if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
donated = true;
|
||||
@@ -50,29 +46,29 @@ void RoPE::eval_gpu(
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
strides[0] = mat_size;
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (dispatch_ndim == 3) {
|
||||
strides[1] = in.strides()[2];
|
||||
strides[2] = in.strides()[3];
|
||||
} else if (n_batch == 1) {
|
||||
// Handle non-contiguous 3D inputs
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
strides[0] = in.strides()[1];
|
||||
strides[1] = in.strides()[2];
|
||||
strides[2] = in.strides()[3];
|
||||
} else {
|
||||
// Copy non-contiguous > 3D inputs into the output and treat
|
||||
// input as donated
|
||||
donated = true;
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
strides[0] = mat_size;
|
||||
strides[1] = out.strides()[ndim - 2];
|
||||
strides[2] = out.strides()[ndim - 1];
|
||||
strides[1] = out.strides()[2];
|
||||
strides[2] = out.strides()[3];
|
||||
}
|
||||
out_strides[0] = mat_size;
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
out_strides[1] = out.strides()[2];
|
||||
out_strides[2] = out.strides()[3];
|
||||
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool single = in.flags().row_contiguous && n_batch == 1 && n_seq == 1;
|
||||
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
std::ostringstream kname;
|
||||
@@ -86,24 +82,30 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
// Copy offset and offset stride
|
||||
|
||||
compute_encoder.set_input_array(inputs[1], 2);
|
||||
compute_encoder.set_bytes(scale_, 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
if (single) {
|
||||
compute_encoder.set_bytes(out_strides, 1, 4);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
group_dims = get_block_dims(dim0, n_head, 1);
|
||||
grid_dims = MTL::Size(dim0, n_head, 1);
|
||||
} else {
|
||||
compute_encoder.set_bytes(strides, 3, 4);
|
||||
compute_encoder.set_bytes(out_strides, 3, 5);
|
||||
compute_encoder.set_bytes(n_batch, 6);
|
||||
size_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
compute_encoder.set_bytes(offset_stride, 6);
|
||||
compute_encoder.set_bytes(n_head, 7);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
uint32_t dim1 = in.shape(-2);
|
||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||
uint32_t dim1 = n_seq;
|
||||
uint32_t dim2 = n_batch * ((n_head + n_per_thread - 1) / n_per_thread);
|
||||
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
}
|
||||
|
61
mlx/fast.cpp
61
mlx/fast.cpp
@@ -355,10 +355,10 @@ array rope(
|
||||
StreamOrDevice s) {
|
||||
auto& x = inputs[0];
|
||||
auto& offset = inputs[1];
|
||||
if (x.ndim() < 3) {
|
||||
if (x.ndim() != 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||
<< x.ndim() << " dimensions.";
|
||||
msg << "[rope] Input must have 4 dimensions but got input with " << x.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(x.dtype(), floating)) {
|
||||
@@ -366,10 +366,16 @@ array rope(
|
||||
msg << "[rope] Input must be a floating type but got " << x.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1) {
|
||||
if (offset.ndim() > 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar but has shape " << offset.shape()
|
||||
<< ".";
|
||||
msg << "[rope] offset must have at most one dimension but has shape "
|
||||
<< offset.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.size() != 1 && offset.size() != x.shape(0)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] offset must be a scalar or vector with " << x.shape(0)
|
||||
<< " elements but has shape " << offset.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!issubdtype(offset.dtype(), integer)) {
|
||||
@@ -379,7 +385,7 @@ array rope(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (offset.dtype().size() != 4) {
|
||||
inputs[1] = astype(offset, uint32, s);
|
||||
inputs[1] = astype(offset, int32, s);
|
||||
}
|
||||
if (inputs.size() == 3 &&
|
||||
(inputs[2].ndim() != 1 || inputs[2].shape(0) != dims / 2)) {
|
||||
@@ -391,15 +397,19 @@ array rope(
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, forward, s](
|
||||
std::vector<array> inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
auto x = flatten(inputs[0], 0, ndim - 3, s);
|
||||
auto x = inputs[0];
|
||||
auto B = x.shape(0);
|
||||
auto N = x.shape(1);
|
||||
auto T = x.shape(2);
|
||||
auto t = x.dtype();
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto& offset = inputs[1];
|
||||
auto offset = inputs[1];
|
||||
if (offset.size() > 1) {
|
||||
offset = expand_dims(offset, {-1, -2}, s);
|
||||
}
|
||||
auto positions =
|
||||
multiply(add(arange(x.shape(1), t, s), offset, s), array(scale, t), s);
|
||||
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s);
|
||||
|
||||
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||
return exp(
|
||||
@@ -412,8 +422,7 @@ array rope(
|
||||
|
||||
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s)
|
||||
: default_inv_freqs();
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
@@ -436,32 +445,30 @@ array rope(
|
||||
};
|
||||
|
||||
if (traditional) {
|
||||
auto x1 =
|
||||
slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x2 =
|
||||
slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
|
||||
auto x1 = slice(x, {0, 0, 0, 0}, {B, N, T, dims}, {1, 1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 0, 1}, {B, N, T, dims}, {1, 1, 1, 2}, s);
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
o = expand_dims(o, -1, s);
|
||||
}
|
||||
auto out = concatenate(outs, 3, s);
|
||||
auto out = reshape(concatenate(outs, -1, s), {B, N, T, dims}, s);
|
||||
if (dims < x.shape(-1)) {
|
||||
out = reshape(out, {x.shape(0), x.shape(1), dims});
|
||||
out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
|
||||
out =
|
||||
concatenate({out, slice(x, {0, 0, 0, dims}, x.shape(), s)}, -1, s);
|
||||
}
|
||||
return std::vector<array>{reshape(out, shape, s)};
|
||||
return std::vector<array>{out};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
auto x1 = slice(x, {0, 0, 0}, out_s, s);
|
||||
auto x1 = slice(x, {0, 0, 0, 0}, out_s, s);
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
auto x2 = slice(x, {0, 0, 0, half_dims}, out_s, s);
|
||||
|
||||
auto outs = apply_rope(x1, x2, coss, sins);
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
outs.push_back(slice(x, {0, 0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
|
||||
return std::vector<array>{concatenate(outs, -1, s)};
|
||||
}
|
||||
};
|
||||
auto stream = to_stream(s);
|
||||
|
@@ -164,8 +164,13 @@ void init_fast(nb::module_& parent_module) {
|
||||
R"pbdoc(
|
||||
Apply rotary positional encoding to the input.
|
||||
|
||||
The input is expected to be 4D with shape ``(B, *, T, D)`` where:
|
||||
* ``B`` is the batch size.
|
||||
* ``T`` is the sequence length.
|
||||
* ``D`` is the feature dimension.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
a (array): A 4-D input array.
|
||||
dims (int): The feature dimensions to be rotated. If the input feature
|
||||
is larger than dims then the rest is left unchanged.
|
||||
traditional (bool): If set to ``True`` choose the traditional
|
||||
@@ -174,7 +179,9 @@ void init_fast(nb::module_& parent_module) {
|
||||
each dimension in the positional encodings. Exactly one of ``base`` and
|
||||
``freqs`` must be ``None``.
|
||||
scale (float): The scale used to scale the positions.
|
||||
offset (int or array): The position offset to start at.
|
||||
offset (int or array): The position offset to start at. If an
|
||||
:obj:`array` is given it can be a scalar or vector of ``B``
|
||||
offsets for each example in the batch.
|
||||
freqs (array, optional): Optional frequencies to use with RoPE.
|
||||
If set, the ``base`` parameter must be ``None``. Default: ``None``.
|
||||
|
||||
|
@@ -91,7 +91,7 @@ mx::array to_array_with_accessor(nb::object obj) {
|
||||
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
||||
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
||||
<< " received in array initialization.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
@@ -8,18 +8,22 @@ import mlx_tests
|
||||
|
||||
|
||||
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
||||
offset = offset.item() if isinstance(offset, mx.array) else offset
|
||||
N = x.shape[-2] + offset
|
||||
N = x.shape[-2]
|
||||
dtype = x.dtype
|
||||
half_D = dims // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
positions = mx.arange(N, dtype=dtype)
|
||||
if isinstance(offset, mx.array) and offset.size > 1:
|
||||
positions = offset[:, None, None] + positions
|
||||
else:
|
||||
positions = offset + positions
|
||||
positions = positions * scale
|
||||
if freqs is None:
|
||||
inv_freqs = mx.exp(
|
||||
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
|
||||
)
|
||||
else:
|
||||
inv_freqs = (1 / freqs).astype(x.dtype)
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
|
||||
theta = mx.expand_dims(positions, -1) * inv_freqs
|
||||
costheta, sintheta = mx.cos(theta), mx.sin(theta)
|
||||
if traditional:
|
||||
x1 = x[..., :dims:2]
|
||||
@@ -83,7 +87,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
for traditional in [True, False]:
|
||||
dims, dtype, _, scale, offset, _ = defaults
|
||||
for base in bases:
|
||||
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
|
||||
x = mx.random.uniform(shape=(1, 2, T, dims)).astype(dtype)
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
rx_fast = mx.fast.rope(
|
||||
x,
|
||||
@@ -97,7 +101,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
dims, _, base, scale, offset, _ = defaults
|
||||
for dtype in dtypes:
|
||||
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
|
||||
x = mx.random.uniform(shape=(1, 2, T, dims)).astype(dtype)
|
||||
ry = rope_orig(
|
||||
x.astype(mx.float32), dims, traditional, base, scale, offset
|
||||
)
|
||||
@@ -118,7 +122,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
dims, dtype, base, scale, _, _ = defaults
|
||||
for offset in offsets:
|
||||
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
|
||||
x = mx.random.uniform(shape=(1, 2, T, dims)).astype(dtype)
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
rx_fast = mx.fast.rope(
|
||||
x,
|
||||
@@ -132,7 +136,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
dims, dtype, base, _, offset, _ = defaults
|
||||
for scale in scales:
|
||||
x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
|
||||
x = mx.random.uniform(shape=(1, 2, T, dims)).astype(dtype)
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
rx_fast = mx.fast.rope(
|
||||
x,
|
||||
@@ -160,7 +164,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
|
||||
# Test raises with integer inputs
|
||||
dims, _, base, scale, offset, traditional = defaults
|
||||
x = (mx.random.uniform(shape=(2, T, dims)) * 10).astype(mx.int32)
|
||||
x = (mx.random.uniform(shape=(1, 2, T, dims)) * 10).astype(mx.int32)
|
||||
with self.assertRaises(ValueError):
|
||||
y = mx.fast.rope(
|
||||
x, dims, traditional=traditional, base=base, scale=scale, offset=offset
|
||||
@@ -172,7 +176,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
# Check throws
|
||||
T = 4
|
||||
dims = 8
|
||||
x = mx.random.uniform(shape=(2, T, dims))
|
||||
x = mx.random.uniform(shape=(1, 2, T, dims))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
freqs = mx.random.uniform(shape=(dims - 1,))
|
||||
@@ -214,9 +218,10 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertEqual(dtype, rx.dtype)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
return
|
||||
|
||||
# Test single vector
|
||||
x = mx.random.uniform(shape=(1, 1, dims))
|
||||
x = mx.random.uniform(shape=(1, 1, 1, dims))
|
||||
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
|
||||
rx_fast = mx.fast.rope(
|
||||
x,
|
||||
@@ -244,8 +249,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
* y
|
||||
).sum()
|
||||
|
||||
x = mx.random.uniform(shape=(2, 4, dims))
|
||||
y = mx.random.uniform(shape=(2, 4, dims))
|
||||
x = mx.random.uniform(shape=(1, 2, 4, dims))
|
||||
y = mx.random.uniform(shape=(1, 2, 4, dims))
|
||||
g1 = mx.grad(f1)(x, y)
|
||||
g2 = mx.grad(f2)(x, y)
|
||||
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
|
||||
@@ -271,12 +276,61 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
* y
|
||||
).sum()
|
||||
|
||||
x = mx.random.uniform(shape=(2, 100, D))
|
||||
y = mx.random.uniform(shape=(2, 100, D))
|
||||
x = mx.random.uniform(shape=(1, 2, 100, D))
|
||||
y = mx.random.uniform(shape=(1, 2, 100, D))
|
||||
g1 = mx.grad(f1)(x, y)
|
||||
g2 = mx.grad(f2)(x, y)
|
||||
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
|
||||
|
||||
def test_rope_batch(self):
|
||||
T = 4
|
||||
base = 10000.0
|
||||
scale = 1.0
|
||||
traditional = True
|
||||
batch_sizes = [3, 8, 11]
|
||||
num_heads = [1, 3, 5]
|
||||
dims = 32
|
||||
|
||||
x = mx.random.uniform(shape=(8, 4, T, dims))
|
||||
|
||||
offset = mx.array([1, 2, 3])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.fast.rope(
|
||||
x,
|
||||
dims,
|
||||
traditional=traditional,
|
||||
base=base,
|
||||
scale=scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for n_head in num_heads:
|
||||
x = mx.random.uniform(shape=(batch_size, n_head, T, dims))
|
||||
offset = mx.arange(batch_size)
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
rx_fast = mx.fast.rope(
|
||||
x,
|
||||
dims,
|
||||
traditional=traditional,
|
||||
base=base,
|
||||
scale=scale,
|
||||
offset=offset,
|
||||
)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
||||
x = mx.random.normal(shape=(2, 6, 8, 64)).transpose(0, 2, 1, 3)
|
||||
dims = 64
|
||||
offset = 0
|
||||
rx_fast = mx.fast.rope(
|
||||
x, dims, traditional=traditional, scale=scale, base=base, offset=offset
|
||||
)
|
||||
rx_fast_single = mx.fast.rope(
|
||||
x[0:1], dims, traditional=traditional, scale=scale, base=base, offset=offset
|
||||
)
|
||||
|
||||
rx = rope_orig(x, dims, traditional, base, scale, offset)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
||||
|
||||
def test_rms_norm(self):
|
||||
# Per dtype absolute tolerance
|
||||
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
|
||||
@@ -544,7 +598,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
|
||||
|
||||
def test_fast_transforms(self):
|
||||
x = mx.random.uniform(shape=(2, 2, 8))
|
||||
x = mx.random.uniform(shape=(1, 2, 2, 8))
|
||||
|
||||
defaults = (8, False, 10000.0, 1.0, 0)
|
||||
dims, traditional, base, scale, offset = defaults
|
||||
@@ -572,7 +626,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))
|
||||
|
||||
# VMAP
|
||||
x = mx.random.uniform(shape=(2, 2, 2, 8))
|
||||
x = mx.random.uniform(shape=(2, 2, 2, 2, 8))
|
||||
vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
|
||||
vmap_fast_out = mx.vmap(
|
||||
lambda x: mx.fast.rope(
|
||||
|
@@ -1069,7 +1069,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_rope(self):
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
|
||||
rope = nn.RoPE(4, **kwargs)
|
||||
shape = (1, 3, 4)
|
||||
shape = (1, 1, 3, 4)
|
||||
x = mx.random.uniform(shape=shape)
|
||||
y = rope(x)
|
||||
self.assertEqual(y.shape, shape)
|
||||
|
Reference in New Issue
Block a user