RoPE with frequencies as optional input (#1337)

* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
Awni Hannun 2024-08-19 18:30:50 -07:00 committed by GitHub
parent 9d26441224
commit bb1b76d9dc
6 changed files with 319 additions and 69 deletions


@@ -4,23 +4,20 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
void rope_single_impl(
const device T* in,
device T* out,
constant const int& offset,
constant const float& base,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
// Figure out L and d.
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
out[out_index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& base,
constant const float& scale,
constant const size_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Figure out L and d.
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -116,37 +143,115 @@ template <typename T, bool traditional, bool forward, int N = 4>
}
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t& stride, \
uint2 pos [[thread_position_in_grid]], \
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
instantiate_rope_g(name, type, traditional, forward)
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)


@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 2;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
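For reference, the kernel name assembled above must match one of the host names instantiated in the Metal source earlier in this commit (e.g. "rope_single_freqs_traditional_float16"). A small Python sketch of the same selection logic, purely illustrative and not part of the change:

def rope_kernel_name(single, with_freqs, forward, traditional, type_name):
    # Mirrors the kname construction above; with_freqs picks the *_freqs_* variants
    # added in this commit, and the vjp_ prefix selects the backward kernels.
    return (
        "rope_"
        + ("single_" if single else "")
        + ("freqs_" if with_freqs else "")
        + ("" if forward else "vjp_")
        + ("traditional_" if traditional else "")
        + type_name
    )

print(rope_kernel_name(True, True, True, True, "float16"))  # rope_single_freqs_traditional_float16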
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&base, sizeof(float), 3);
compute_encoder->setBytes(&scale_, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2;
auto group_dims = get_block_dims(dim0, n_batch, 1);
auto grid_dims = MTL::Size(dim0, n_batch, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}
if (with_freqs) {
auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
} else {
compute_encoder->setBytes(&base, sizeof(float), 10);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast


@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <numeric>
@@ -323,7 +322,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
}
array rope(
const array& x,
std::vector<array> inputs,
int dims,
bool traditional,
float base,
@@ -331,15 +330,23 @@ array rope(
int offset,
bool forward,
StreamOrDevice s) {
auto& x = inputs[0];
if (x.ndim() < 3) {
std::ostringstream msg;
msg << "[rope] Input must have at least 3 dimensions but got input with "
<< x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (inputs.size() == 2 &&
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
std::ostringstream msg;
msg << "[rope] freqs must be one dimensional with size " << dims
<< " but got shape " << inputs[1].shape() << ".";
throw std::invalid_argument(msg.str());
}
auto fallback = [dims, traditional, base, scale, offset, forward, s](
const std::vector<array>& inputs) {
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
@@ -348,10 +355,20 @@ array rope(
// Compute sines and cosines
auto half_dims = dims / 2;
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
auto freqs = negative(arange(0, half_dims, t, s), s);
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
multiply(
arange(0, -half_dims, -1, t, s),
array(std::log(base) / half_dims, t),
s),
s);
};
auto inv_freqs =
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
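In other words, the rotation angle the fallback computes for each (scaled, already offset) position and dimension pair comes either from base or directly from the reciprocal of the user-supplied frequencies. A minimal Python sketch of that formula (rope_angle is a hypothetical helper, not part of the diff):

import math

def rope_angle(pos, i, dims, scale, base=None, freqs=None):
    # Angle applied to dimension pair i at absolute position pos (offset included),
    # matching default_inv_freqs / reciprocal(freqs) above.
    # Exactly one of base / freqs should be provided.
    half = dims // 2
    inv_freq = math.exp(-i * math.log(base) / half) if freqs is None else 1.0 / freqs[i]
    return scale * pos * inv_freq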
@@ -409,20 +426,39 @@ array rope(
x.dtype(),
std::make_shared<RoPE>(
stream, fallback, dims, traditional, base, scale, offset, forward),
{x});
std::move(inputs));
}
return fallback({x})[0];
return fallback(std::move(inputs))[0];
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return rope(x, dims, traditional, base, scale, offset, true, s);
std::vector<array> inputs = {x};
if (freqs) {
inputs.push_back(astype(*freqs, float32, s));
if (base) {
throw std::invalid_argument(
"[rope] Only one of base or freqs can have a value.");
}
} else if (!base) {
throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
}
return rope(
std::move(inputs),
dims,
traditional,
base.has_value() ? *base : 1.0,
scale,
offset,
true,
s);
}
std::vector<array> RoPE::vjp(
@@ -438,16 +474,27 @@ std::vector<array> RoPE::vjp(
offset = offset_,
forward = forward_,
s](std::vector<array> inputs) {
return std::vector<array>{
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
return std::vector<array>{rope(
std::move(inputs),
dims,
traditional,
base,
scale,
offset,
!forward,
s)};
};
auto inputs = cotangents;
if (primals.size() == 2) {
inputs.push_back(primals[1]);
}
return {array(
cotangents[0].shape(),
cotangents[0].dtype(),
std::make_shared<RoPE>(
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
cotangents)};
std::move(inputs))};
}
bool RoPE::is_equivalent(const Primitive& other) const {


@@ -25,9 +25,10 @@ array rope(
const array& x,
int dims,
bool traditional,
float base,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
/** Computes: O = softmax(Q @ K.T) @ V **/


@@ -79,26 +79,29 @@ void init_fast(nb::module_& parent_module) {
"dims"_a,
nb::kw_only(),
"traditional"_a,
"base"_a,
"base"_a.none(),
"scale"_a,
"offset"_a,
"freqs"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Apply rotary positional encoding to the input.
Args:
a (array): Input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
implementation which rotates consecutive dimensions.
base (float): The base used to compute angular frequency for
each dimension in the positional encodings.
implementation which rotates consecutive dimensions.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int): The position offset to start at.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.
Returns:
array: The output array.
)pbdoc");
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.


@@ -7,13 +7,18 @@ import mlx.core as mx
import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset):
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
N = x.shape[1] + offset
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = 1 / freqs
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_rope_with_freqs(self):
# Check throws
T = 4
dims = 8
x = mx.random.uniform(shape=(2, T, dims))
with self.assertRaises(ValueError):
freqs = mx.random.uniform(shape=(dims - 1,))
mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
with self.assertRaises(ValueError):
freqs = mx.random.uniform(shape=(1, dims))
mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
freqs = mx.random.uniform(shape=(dims // 2,))
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
rx_fast = mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
# Test single vector
x = mx.random.uniform(shape=(1, 1, dims))
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
rx_fast = mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
# Test grad with freqs
f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()
f2 = lambda x, y: (
mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
* y
).sum()
x = mx.random.uniform(shape=(2, 4, dims))
y = mx.random.uniform(shape=(2, 4, dims))
g1 = mx.grad(f1)(x, y)
g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_grad(self):
D = 32
defaults = (D, 10000.0, 1.0, 0, False)