mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
RoPE with frequencies as optional input (#1337)
* start rope with freq input * rope with frequencies * nits * fix bug * fix bug + test * cleanup * optional base
This commit is contained in:
parent
9d26441224
commit
bb1b76d9dc
@ -4,23 +4,20 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward>
|
template <typename T, bool traditional, bool forward>
|
||||||
[[kernel]] void rope_single(
|
void rope_single_impl(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in,
|
||||||
device T* out [[buffer(1)]],
|
device T* out,
|
||||||
constant const int& offset,
|
constant const int& offset,
|
||||||
constant const float& base,
|
const float inv_freq,
|
||||||
constant const float& scale,
|
constant const float& scale,
|
||||||
constant const size_t& stride,
|
constant const size_t& stride,
|
||||||
uint2 pos [[thread_position_in_grid]],
|
uint2 pos,
|
||||||
uint2 grid [[threads_per_grid]]) {
|
uint2 grid) {
|
||||||
// Figure out L and d.
|
|
||||||
float L = scale * static_cast<float>(offset);
|
float L = scale * static_cast<float>(offset);
|
||||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
|
||||||
|
|
||||||
// Compute costheta, sintheta
|
// Compute costheta, sintheta
|
||||||
float theta = L * metal::exp2(-d * base);
|
float theta = L * inv_freq;
|
||||||
float costheta = metal::fast::cos(theta);
|
float costheta = metal::fast::cos(theta);
|
||||||
float sintheta = metal::fast::sin(theta);
|
float sintheta = metal::fast::sin(theta);
|
||||||
|
|
||||||
@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
|
|||||||
out[out_index_2] = static_cast<T>(rx2);
|
out[out_index_2] = static_cast<T>(rx2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward, int N = 4>
|
template <typename T, bool traditional, bool forward>
|
||||||
[[kernel]] void rope(
|
[[kernel]] void rope_single(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
constant const int& offset,
|
constant const int& offset,
|
||||||
constant const float& base,
|
constant const float& scale,
|
||||||
|
constant const size_t& stride,
|
||||||
|
constant const float& base [[buffer(10)]],
|
||||||
|
uint2 pos [[thread_position_in_grid]],
|
||||||
|
uint2 grid [[threads_per_grid]]) {
|
||||||
|
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||||
|
float inv_freq = metal::exp2(-d * base);
|
||||||
|
rope_single_impl<T, traditional, forward>(
|
||||||
|
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool traditional, bool forward>
|
||||||
|
[[kernel]] void rope_single_freqs(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device T* out [[buffer(1)]],
|
||||||
|
constant const int& offset,
|
||||||
|
constant const float& scale,
|
||||||
|
constant const size_t& stride,
|
||||||
|
const device float* freqs [[buffer(10)]],
|
||||||
|
constant const size_t& freq_stride [[buffer(11)]],
|
||||||
|
uint2 pos [[thread_position_in_grid]],
|
||||||
|
uint2 grid [[threads_per_grid]]) {
|
||||||
|
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||||
|
rope_single_impl<T, traditional, forward>(
|
||||||
|
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool traditional, bool forward, int N = 4>
|
||||||
|
void rope_impl(
|
||||||
|
const device T* in,
|
||||||
|
device T* out,
|
||||||
|
constant const int& offset,
|
||||||
|
const float inv_freq,
|
||||||
constant const float& scale,
|
constant const float& scale,
|
||||||
constant const size_t strides[3],
|
constant const size_t strides[3],
|
||||||
constant const size_t out_strides[3],
|
constant const size_t out_strides[3],
|
||||||
constant const size_t& n_batch,
|
constant const size_t& n_batch,
|
||||||
uint3 pos [[thread_position_in_grid]],
|
uint3 pos,
|
||||||
uint3 grid [[threads_per_grid]]) {
|
uint3 grid) {
|
||||||
// Figure out L and d.
|
|
||||||
float L = scale * static_cast<float>(pos.y + offset);
|
float L = scale * static_cast<float>(pos.y + offset);
|
||||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
|
||||||
|
|
||||||
// Compute costheta, sintheta
|
// Compute costheta, sintheta
|
||||||
float theta = L * metal::exp2(-d * base);
|
float theta = L * inv_freq;
|
||||||
float costheta = metal::fast::cos(theta);
|
float costheta = metal::fast::cos(theta);
|
||||||
float sintheta = metal::fast::sin(theta);
|
float sintheta = metal::fast::sin(theta);
|
||||||
|
|
||||||
@ -116,17 +143,85 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, bool traditional, bool forward, int N = 4>
|
||||||
|
[[kernel]] void rope(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device T* out [[buffer(1)]],
|
||||||
|
constant const int& offset,
|
||||||
|
constant const float& scale,
|
||||||
|
constant const size_t strides[3],
|
||||||
|
constant const size_t out_strides[3],
|
||||||
|
constant const size_t& n_batch,
|
||||||
|
constant const float& base [[buffer(10)]],
|
||||||
|
uint3 pos [[thread_position_in_grid]],
|
||||||
|
uint3 grid [[threads_per_grid]]) {
|
||||||
|
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||||
|
float inv_freq = metal::exp2(-d * base);
|
||||||
|
rope_impl<T, traditional, forward, N>(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
offset,
|
||||||
|
inv_freq,
|
||||||
|
scale,
|
||||||
|
strides,
|
||||||
|
out_strides,
|
||||||
|
n_batch,
|
||||||
|
pos,
|
||||||
|
grid);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool traditional, bool forward, int N = 4>
|
||||||
|
[[kernel]] void rope_freqs(
|
||||||
|
const device T* in [[buffer(0)]],
|
||||||
|
device T* out [[buffer(1)]],
|
||||||
|
constant const int& offset,
|
||||||
|
constant const float& scale,
|
||||||
|
constant const size_t strides[3],
|
||||||
|
constant const size_t out_strides[3],
|
||||||
|
constant const size_t& n_batch,
|
||||||
|
const device float* freqs [[buffer(10)]],
|
||||||
|
constant const size_t& freq_stride [[buffer(11)]],
|
||||||
|
uint3 pos [[thread_position_in_grid]],
|
||||||
|
uint3 grid [[threads_per_grid]]) {
|
||||||
|
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||||
|
rope_impl<T, traditional, forward, N>(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
offset,
|
||||||
|
inv_freq,
|
||||||
|
scale,
|
||||||
|
strides,
|
||||||
|
out_strides,
|
||||||
|
n_batch,
|
||||||
|
pos,
|
||||||
|
grid);
|
||||||
|
}
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||||
rope<type, traditional, forward>( \
|
rope<type, traditional, forward>( \
|
||||||
const device type* in [[buffer(0)]], \
|
const device type* in [[buffer(0)]], \
|
||||||
device type* out [[buffer(1)]], \
|
device type* out [[buffer(1)]], \
|
||||||
constant const int& offset, \
|
constant const int& offset, \
|
||||||
constant const float& base, \
|
|
||||||
constant const float& scale, \
|
constant const float& scale, \
|
||||||
constant const size_t strides[3], \
|
constant const size_t strides[3], \
|
||||||
constant const size_t out_strides[3], \
|
constant const size_t out_strides[3], \
|
||||||
constant const size_t& n_batch, \
|
constant const size_t& n_batch, \
|
||||||
|
constant const float& base [[buffer(10)]], \
|
||||||
|
uint3 pos [[thread_position_in_grid]], \
|
||||||
|
uint3 grid [[threads_per_grid]]); \
|
||||||
|
template [[host_name("rope_freqs_" #name)]] \
|
||||||
|
[[kernel]] void rope_freqs<type, traditional, forward>( \
|
||||||
|
const device type* in [[buffer(0)]], \
|
||||||
|
device type* out [[buffer(1)]], \
|
||||||
|
constant const int& offset, \
|
||||||
|
constant const float& scale, \
|
||||||
|
constant const size_t strides[3], \
|
||||||
|
constant const size_t out_strides[3], \
|
||||||
|
constant const size_t& n_batch, \
|
||||||
|
const device float* freqs [[buffer(10)]], \
|
||||||
|
constant const size_t& freq_stride [[buffer(11)]], \
|
||||||
uint3 pos [[thread_position_in_grid]], \
|
uint3 pos [[thread_position_in_grid]], \
|
||||||
uint3 grid [[threads_per_grid]]);
|
uint3 grid [[threads_per_grid]]);
|
||||||
|
|
||||||
@ -136,9 +231,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
const device type* in [[buffer(0)]], \
|
const device type* in [[buffer(0)]], \
|
||||||
device type* out [[buffer(1)]], \
|
device type* out [[buffer(1)]], \
|
||||||
constant const int& offset, \
|
constant const int& offset, \
|
||||||
constant const float& base, \
|
|
||||||
constant const float& scale, \
|
constant const float& scale, \
|
||||||
constant const size_t& stride, \
|
constant const size_t& stride, \
|
||||||
|
constant const float& base [[buffer(10)]], \
|
||||||
|
uint2 pos [[thread_position_in_grid]], \
|
||||||
|
uint2 grid [[threads_per_grid]]); \
|
||||||
|
template [[host_name("rope_single_freqs_" #name)]] \
|
||||||
|
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
|
||||||
|
const device type* in [[buffer(0)]], \
|
||||||
|
device type* out [[buffer(1)]], \
|
||||||
|
constant const int& offset, \
|
||||||
|
constant const float& scale, \
|
||||||
|
constant const size_t& stride, \
|
||||||
|
const device float* freqs [[buffer(10)]], \
|
||||||
|
constant const size_t& freq_stride [[buffer(11)]], \
|
||||||
uint2 pos [[thread_position_in_grid]], \
|
uint2 pos [[thread_position_in_grid]], \
|
||||||
uint2 grid [[threads_per_grid]]);
|
uint2 grid [[threads_per_grid]]);
|
||||||
|
|
||||||
@ -146,7 +252,6 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
instantiate_rope_s(name, type, traditional, forward) \
|
instantiate_rope_s(name, type, traditional, forward) \
|
||||||
instantiate_rope_g(name, type, traditional, forward)
|
instantiate_rope_g(name, type, traditional, forward)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
instantiate_rope(traditional_float16, half, true, true)
|
instantiate_rope(traditional_float16, half, true, true)
|
||||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||||
instantiate_rope(traditional_float32, float, true, true)
|
instantiate_rope(traditional_float32, float, true, true)
|
||||||
|
@ -67,8 +67,10 @@ void RoPE::eval_gpu(
|
|||||||
// Special case for inference (single time step and contiguous)
|
// Special case for inference (single time step and contiguous)
|
||||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||||
|
|
||||||
|
bool with_freqs = inputs.size() == 2;
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
|
kname << "rope_" << (single ? "single_" : "")
|
||||||
|
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
@ -78,27 +80,36 @@ void RoPE::eval_gpu(
|
|||||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
compute_encoder->setBytes(&offset_, sizeof(int), 2);
|
compute_encoder->setBytes(&offset_, sizeof(int), 2);
|
||||||
compute_encoder->setBytes(&base, sizeof(float), 3);
|
compute_encoder->setBytes(&scale_, sizeof(float), 3);
|
||||||
compute_encoder->setBytes(&scale_, sizeof(float), 4);
|
|
||||||
|
|
||||||
size_t n_batch = in.size() / mat_size;
|
size_t n_batch = in.size() / mat_size;
|
||||||
|
MTL::Size group_dims;
|
||||||
|
MTL::Size grid_dims;
|
||||||
if (single) {
|
if (single) {
|
||||||
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
|
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
|
||||||
uint32_t dim0 = dims_ / 2;
|
uint32_t dim0 = dims_ / 2;
|
||||||
auto group_dims = get_block_dims(dim0, n_batch, 1);
|
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||||
auto grid_dims = MTL::Size(dim0, n_batch, 1);
|
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
|
||||||
} else {
|
} else {
|
||||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
|
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
|
||||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
|
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
|
||||||
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
|
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
|
||||||
uint32_t dim0 = dims_ / 2;
|
uint32_t dim0 = dims_ / 2;
|
||||||
uint32_t dim1 = in.shape(-2);
|
uint32_t dim1 = in.shape(-2);
|
||||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (with_freqs) {
|
||||||
|
auto& freqs = inputs[1];
|
||||||
|
compute_encoder.set_input_array(freqs, 10);
|
||||||
|
auto freq_stride = freqs.strides()[0];
|
||||||
|
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
|
||||||
|
} else {
|
||||||
|
compute_encoder->setBytes(&base, sizeof(float), 10);
|
||||||
|
}
|
||||||
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
73
mlx/fast.cpp
73
mlx/fast.cpp
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
@ -323,7 +322,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array rope(
|
array rope(
|
||||||
const array& x,
|
std::vector<array> inputs,
|
||||||
int dims,
|
int dims,
|
||||||
bool traditional,
|
bool traditional,
|
||||||
float base,
|
float base,
|
||||||
@ -331,15 +330,23 @@ array rope(
|
|||||||
int offset,
|
int offset,
|
||||||
bool forward,
|
bool forward,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
|
auto& x = inputs[0];
|
||||||
if (x.ndim() < 3) {
|
if (x.ndim() < 3) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||||
<< x.ndim() << " dimensions.";
|
<< x.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
if (inputs.size() == 2 &&
|
||||||
|
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[rope] freqs must be one dimensional with size " << dims
|
||||||
|
<< " but got shape " << inputs[1].shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
||||||
const std::vector<array>& inputs) {
|
std::vector<array> inputs) {
|
||||||
auto& shape = inputs[0].shape();
|
auto& shape = inputs[0].shape();
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
|
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
|
||||||
@ -348,10 +355,20 @@ array rope(
|
|||||||
// Compute sines and cosines
|
// Compute sines and cosines
|
||||||
auto half_dims = dims / 2;
|
auto half_dims = dims / 2;
|
||||||
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
||||||
auto freqs = negative(arange(0, half_dims, t, s), s);
|
|
||||||
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
|
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||||
|
return exp(
|
||||||
|
multiply(
|
||||||
|
arange(0, -half_dims, -1, t, s),
|
||||||
|
array(std::log(base) / half_dims, t),
|
||||||
|
s),
|
||||||
|
s);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto inv_freqs =
|
||||||
|
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
|
||||||
auto theta =
|
auto theta =
|
||||||
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
|
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||||
auto coss = cos(theta, s);
|
auto coss = cos(theta, s);
|
||||||
auto sins = sin(theta, s);
|
auto sins = sin(theta, s);
|
||||||
|
|
||||||
@ -409,20 +426,39 @@ array rope(
|
|||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<RoPE>(
|
std::make_shared<RoPE>(
|
||||||
stream, fallback, dims, traditional, base, scale, offset, forward),
|
stream, fallback, dims, traditional, base, scale, offset, forward),
|
||||||
{x});
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
return fallback({x})[0];
|
return fallback(std::move(inputs))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
array rope(
|
array rope(
|
||||||
const array& x,
|
const array& x,
|
||||||
int dims,
|
int dims,
|
||||||
bool traditional,
|
bool traditional,
|
||||||
float base,
|
std::optional<float> base,
|
||||||
float scale,
|
float scale,
|
||||||
int offset,
|
int offset,
|
||||||
|
const std::optional<array>& freqs /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return rope(x, dims, traditional, base, scale, offset, true, s);
|
std::vector<array> inputs = {x};
|
||||||
|
if (freqs) {
|
||||||
|
inputs.push_back(astype(*freqs, float32, s));
|
||||||
|
if (base) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[rope] Only one of base or freqs can have a value.");
|
||||||
|
}
|
||||||
|
} else if (!base) {
|
||||||
|
throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
|
||||||
|
}
|
||||||
|
return rope(
|
||||||
|
std::move(inputs),
|
||||||
|
dims,
|
||||||
|
traditional,
|
||||||
|
base.has_value() ? *base : 1.0,
|
||||||
|
scale,
|
||||||
|
offset,
|
||||||
|
true,
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> RoPE::vjp(
|
std::vector<array> RoPE::vjp(
|
||||||
@ -438,16 +474,27 @@ std::vector<array> RoPE::vjp(
|
|||||||
offset = offset_,
|
offset = offset_,
|
||||||
forward = forward_,
|
forward = forward_,
|
||||||
s](std::vector<array> inputs) {
|
s](std::vector<array> inputs) {
|
||||||
return std::vector<array>{
|
return std::vector<array>{rope(
|
||||||
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
|
std::move(inputs),
|
||||||
|
dims,
|
||||||
|
traditional,
|
||||||
|
base,
|
||||||
|
scale,
|
||||||
|
offset,
|
||||||
|
!forward,
|
||||||
|
s)};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto inputs = cotangents;
|
||||||
|
if (primals.size() == 2) {
|
||||||
|
inputs.push_back(primals[1]);
|
||||||
|
}
|
||||||
return {array(
|
return {array(
|
||||||
cotangents[0].shape(),
|
cotangents[0].shape(),
|
||||||
cotangents[0].dtype(),
|
cotangents[0].dtype(),
|
||||||
std::make_shared<RoPE>(
|
std::make_shared<RoPE>(
|
||||||
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
||||||
cotangents)};
|
std::move(inputs))};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool RoPE::is_equivalent(const Primitive& other) const {
|
bool RoPE::is_equivalent(const Primitive& other) const {
|
||||||
|
@ -25,9 +25,10 @@ array rope(
|
|||||||
const array& x,
|
const array& x,
|
||||||
int dims,
|
int dims,
|
||||||
bool traditional,
|
bool traditional,
|
||||||
float base,
|
std::optional<float> base,
|
||||||
float scale,
|
float scale,
|
||||||
int offset,
|
int offset,
|
||||||
|
const std::optional<array>& freqs = std::nullopt,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||||
|
@ -79,12 +79,13 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
"dims"_a,
|
"dims"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"traditional"_a,
|
"traditional"_a,
|
||||||
"base"_a,
|
"base"_a.none(),
|
||||||
"scale"_a,
|
"scale"_a,
|
||||||
"offset"_a,
|
"offset"_a,
|
||||||
|
"freqs"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Apply rotary positional encoding to the input.
|
Apply rotary positional encoding to the input.
|
||||||
|
|
||||||
@ -94,11 +95,13 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
is larger than dims then the rest is left unchanged.
|
is larger than dims then the rest is left unchanged.
|
||||||
traditional (bool): If set to ``True`` choose the traditional
|
traditional (bool): If set to ``True`` choose the traditional
|
||||||
implementation which rotates consecutive dimensions.
|
implementation which rotates consecutive dimensions.
|
||||||
base (float): The base used to compute angular frequency for
|
base (float, optional): The base used to compute angular frequency for
|
||||||
each dimension in the positional encodings.
|
each dimension in the positional encodings. Exactly one of ``base`` and
|
||||||
|
``freqs`` must be ``None``.
|
||||||
scale (float): The scale used to scale the positions.
|
scale (float): The scale used to scale the positions.
|
||||||
offset (int): The position offset to start at.
|
offset (int): The position offset to start at.
|
||||||
|
freqs (array, optional): Optional frequencies to use with RoPE.
|
||||||
|
If set, the ``base`` parameter must be ``None``. ``Default: None``.
|
||||||
Returns:
|
Returns:
|
||||||
array: The output array.
|
array: The output array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
"memory_efficient_threshold"_a = nb::none(),
|
"memory_efficient_threshold"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
||||||
|
|
||||||
|
@ -7,13 +7,18 @@ import mlx.core as mx
|
|||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
|
||||||
|
|
||||||
def rope_orig(x, dims, traditional, base, scale, offset):
|
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
|
||||||
N = x.shape[1] + offset
|
N = x.shape[1] + offset
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
half_D = dims // 2
|
half_D = dims // 2
|
||||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||||
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
|
if freqs is None:
|
||||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
inv_freqs = mx.exp(
|
||||||
|
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inv_freqs = 1 / freqs
|
||||||
|
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
|
||||||
costheta, sintheta = mx.cos(theta), mx.sin(theta)
|
costheta, sintheta = mx.cos(theta), mx.sin(theta)
|
||||||
if traditional:
|
if traditional:
|
||||||
x1 = x[..., :dims:2]
|
x1 = x[..., :dims:2]
|
||||||
@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||||
|
|
||||||
|
def test_rope_with_freqs(self):
|
||||||
|
# Check throws
|
||||||
|
T = 4
|
||||||
|
dims = 8
|
||||||
|
x = mx.random.uniform(shape=(2, T, dims))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
freqs = mx.random.uniform(shape=(dims - 1,))
|
||||||
|
mx.fast.rope(
|
||||||
|
x,
|
||||||
|
dims,
|
||||||
|
traditional=False,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=0,
|
||||||
|
freqs=freqs,
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
freqs = mx.random.uniform(shape=(1, dims))
|
||||||
|
mx.fast.rope(
|
||||||
|
x,
|
||||||
|
dims,
|
||||||
|
traditional=False,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=0,
|
||||||
|
freqs=freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = mx.random.uniform(shape=(dims // 2,))
|
||||||
|
|
||||||
|
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
|
||||||
|
rx_fast = mx.fast.rope(
|
||||||
|
x,
|
||||||
|
dims,
|
||||||
|
traditional=False,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=0,
|
||||||
|
freqs=freqs,
|
||||||
|
)
|
||||||
|
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
||||||
|
|
||||||
|
# Test single vector
|
||||||
|
x = mx.random.uniform(shape=(1, 1, dims))
|
||||||
|
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
|
||||||
|
rx_fast = mx.fast.rope(
|
||||||
|
x,
|
||||||
|
dims,
|
||||||
|
traditional=False,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=0,
|
||||||
|
freqs=freqs,
|
||||||
|
)
|
||||||
|
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
|
||||||
|
|
||||||
|
# Test grad with freqs
|
||||||
|
f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()
|
||||||
|
f2 = lambda x, y: (
|
||||||
|
mx.fast.rope(
|
||||||
|
x,
|
||||||
|
dims,
|
||||||
|
traditional=False,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=0,
|
||||||
|
freqs=freqs,
|
||||||
|
)
|
||||||
|
* y
|
||||||
|
).sum()
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 4, dims))
|
||||||
|
y = mx.random.uniform(shape=(2, 4, dims))
|
||||||
|
g1 = mx.grad(f1)(x, y)
|
||||||
|
g2 = mx.grad(f2)(x, y)
|
||||||
|
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
|
||||||
|
|
||||||
def test_rope_grad(self):
|
def test_rope_grad(self):
|
||||||
D = 32
|
D = 32
|
||||||
defaults = (D, 10000.0, 1.0, 0, False)
|
defaults = (D, 10000.0, 1.0, 0, False)
|
||||||
|
Loading…
Reference in New Issue
Block a user