RoPE with frequencies as optional input (#1337)

* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
Awni Hannun 2024-08-19 18:30:50 -07:00 committed by GitHub
parent 9d26441224
commit bb1b76d9dc
6 changed files with 319 additions and 69 deletions

View File

@@ -4,23 +4,20 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
void rope_single_impl(
const device T* in,
device T* out,
constant const int& offset,
constant const float& base,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
// Figure out L and d.
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
out[out_index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& base,
constant const float& scale,
constant const size_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Figure out L and d.
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -116,17 +143,85 @@ template <typename T, bool traditional, bool forward, int N = 4>
}
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
@@ -136,9 +231,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
@@ -146,7 +252,6 @@ template <typename T, bool traditional, bool forward, int N = 4>
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
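To summarize the kernel change in plain terms: every variant now receives a per-dimension inverse frequency, either derived from the base (the original behavior) or read from the optional freqs buffer. Below is a minimal Python sketch of the equivalent angle computation, mirroring the rope_orig reference helper in the test file further down; the helper name and signature are illustrative only.

import math
import mlx.core as mx

def rope_angles(dims, offset, length, scale, base=None, freqs=None):
    # Positions are scaled exactly as in the kernels: L = scale * (offset + t).
    positions = mx.arange(offset, offset + length, dtype=mx.float32) * scale
    half_D = dims // 2
    if freqs is None:
        # Default path: inv_freq_i = base ** (-i / half_D).
        inv_freqs = mx.exp(-mx.arange(0.0, half_D) * (math.log(base) / half_D))
    else:
        # Explicit path: inv_freq_i = 1 / freqs_i.
        inv_freqs = 1 / freqs
    # theta[t, i] = position_t * inv_freq_i; its cos/sin rotate each feature pair.
    return mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))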

View File

@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 2;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&base, sizeof(float), 3);
compute_encoder->setBytes(&scale_, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2;
auto group_dims = get_block_dims(dim0, n_batch, 1);
auto grid_dims = MTL::Size(dim0, n_batch, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}
if (with_freqs) {
auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
} else {
compute_encoder->setBytes(&base, sizeof(float), 10);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast
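The kernel-name selection above concatenates optional segments; a rough Python rendering of that logic (illustrative only, not part of the change):

def rope_kernel_name(single, with_freqs, forward, traditional, type_name):
    # Mirrors the std::ostringstream construction in RoPE::eval_gpu.
    return ("rope_"
            + ("single_" if single else "")
            + ("freqs_" if with_freqs else "")
            + ("" if forward else "vjp_")
            + ("traditional_" if traditional else "")
            + type_name)

# e.g. rope_kernel_name(True, True, True, False, "float32") == "rope_single_freqs_float32"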

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <numeric>
@@ -323,7 +322,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
}
array rope(
const array& x,
std::vector<array> inputs,
int dims,
bool traditional,
float base,
@@ -331,15 +330,23 @@ array rope(
int offset,
bool forward,
StreamOrDevice s) {
auto& x = inputs[0];
if (x.ndim() < 3) {
std::ostringstream msg;
msg << "[rope] Input must have at least 3 dimensions but got input with "
<< x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (inputs.size() == 2 &&
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
std::ostringstream msg;
msg << "[rope] freqs must be one dimensional with size " << dims
<< " but got shape " << inputs[1].shape() << ".";
throw std::invalid_argument(msg.str());
}
auto fallback = [dims, traditional, base, scale, offset, forward, s](
const std::vector<array>& inputs) {
std::vector<array> inputs) {
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
@@ -348,10 +355,20 @@ array rope(
// Compute sines and cosines
auto half_dims = dims / 2;
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
auto freqs = negative(arange(0, half_dims, t, s), s);
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
return exp(
multiply(
arange(0, -half_dims, -1, t, s),
array(std::log(base) / half_dims, t),
s),
s);
};
auto inv_freqs =
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
@@ -409,20 +426,39 @@ array rope(
x.dtype(),
std::make_shared<RoPE>(
stream, fallback, dims, traditional, base, scale, offset, forward),
{x});
std::move(inputs));
}
return fallback({x})[0];
return fallback(std::move(inputs))[0];
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return rope(x, dims, traditional, base, scale, offset, true, s);
std::vector<array> inputs = {x};
if (freqs) {
inputs.push_back(astype(*freqs, float32, s));
if (base) {
throw std::invalid_argument(
"[rope] Only one of base or freqs can have a value.");
}
} else if (!base) {
throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
}
return rope(
std::move(inputs),
dims,
traditional,
base.has_value() ? *base : 1.0,
scale,
offset,
true,
s);
}
std::vector<array> RoPE::vjp(
@@ -438,16 +474,27 @@ std::vector<array> RoPE::vjp(
offset = offset_,
forward = forward_,
s](std::vector<array> inputs) {
return std::vector<array>{
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
return std::vector<array>{rope(
std::move(inputs),
dims,
traditional,
base,
scale,
offset,
!forward,
s)};
};
auto inputs = cotangents;
if (primals.size() == 2) {
inputs.push_back(primals[1]);
}
return {array(
cotangents[0].shape(),
cotangents[0].dtype(),
std::make_shared<RoPE>(
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
cotangents)};
std::move(inputs))};
}
bool RoPE::is_equivalent(const Primitive& other) const {

View File

@@ -25,9 +25,10 @@ array rope(
const array& x,
int dims,
bool traditional,
float base,
std::optional<float> base,
float scale,
int offset,
const std::optional<array>& freqs = std::nullopt,
StreamOrDevice s = {});
/** Computes: O = softmax(Q @ K.T) @ V **/

View File

@@ -79,12 +79,13 @@ void init_fast(nb::module_& parent_module) {
"dims"_a,
nb::kw_only(),
"traditional"_a,
"base"_a,
"base"_a.none(),
"scale"_a,
"offset"_a,
"freqs"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Apply rotary positional encoding to the input.
@@ -94,11 +95,13 @@ void init_fast(nb::module_& parent_module) {
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
implementation which rotates consecutive dimensions.
base (float): The base used to compute angular frequency for
each dimension in the positional encodings.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int): The position offset to start at.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. ``Default: None``.
Returns:
array: The output array.
)pbdoc");
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.

View File

@@ -7,13 +7,18 @@ import mlx.core as mx
import mlx_tests
def rope_orig(x, dims, traditional, base, scale, offset):
def rope_orig(x, dims, traditional, base, scale, offset, freqs=None):
N = x.shape[1] + offset
dtype = x.dtype
half_D = dims // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
if freqs is None:
inv_freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
else:
inv_freqs = 1 / freqs
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1))
costheta, sintheta = mx.cos(theta), mx.sin(theta)
if traditional:
x1 = x[..., :dims:2]
@@ -138,6 +143,84 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_rope_with_freqs(self):
# Check throws
T = 4
dims = 8
x = mx.random.uniform(shape=(2, T, dims))
with self.assertRaises(ValueError):
freqs = mx.random.uniform(shape=(dims - 1,))
mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
with self.assertRaises(ValueError):
freqs = mx.random.uniform(shape=(1, dims))
mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
freqs = mx.random.uniform(shape=(dims // 2,))
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
rx_fast = mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
# Test single vector
x = mx.random.uniform(shape=(1, 1, dims))
rx = rope_orig(x, dims, False, None, 1.0, 0, freqs)
rx_fast = mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
# Test grad with freqs
f1 = lambda x, y: (rope_orig(x, dims, False, None, 1.0, 0, freqs) * y).sum()
f2 = lambda x, y: (
mx.fast.rope(
x,
dims,
traditional=False,
base=None,
scale=1.0,
offset=0,
freqs=freqs,
)
* y
).sum()
x = mx.random.uniform(shape=(2, 4, dims))
y = mx.random.uniform(shape=(2, 4, dims))
g1 = mx.grad(f1)(x, y)
g2 = mx.grad(f2)(x, y)
self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
def test_rope_grad(self):
D = 32
defaults = (D, 10000.0, 1.0, 0, False)