RoPE with frequencies as optional input (#1337)

* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
This commit is contained in:
Awni Hannun
2024-08-19 18:30:50 -07:00
committed by GitHub
parent 9d26441224
commit bb1b76d9dc
6 changed files with 319 additions and 69 deletions

View File

@@ -4,23 +4,20 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
void rope_single_impl(
const device T* in,
device T* out,
constant const int& offset,
constant const float& base,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
// Figure out L and d.
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
out[out_index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
// Kernel entry point for the "single" RoPE variant when no frequency buffer
// is supplied. Each thread computes d = pos.x / grid.x, turns it into
// inv_freq = exp2(-d * base), and forwards everything to rope_single_impl.
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
// NOTE(review): `base` is declared again below with [[buffer(10)]]; this
// unindexed declaration looks like a pre-change line retained by the diff
// rendering -- confirm against the actual file.
constant const float& base,
constant const float& scale,
constant const size_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
// Fraction of the grid width at this thread's x index.
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Inverse frequency for this thread, derived from `base`.
float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward>
[[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Figure out L and d.
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
@@ -116,37 +143,115 @@ template <typename T, bool traditional, bool forward, int N = 4>
}
}
// General (non-"single") RoPE kernel for the case where no frequency buffer
// is supplied. With d = pos.x / grid.x, each thread computes
// exp2(-d * base) and hands it to rope_impl as the inverse frequency; all
// remaining arguments are forwarded as received.
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const int& offset,
    constant const float& scale,
    constant const size_t strides[3],
    constant const size_t out_strides[3],
    constant const size_t& n_batch,
    constant const float& base [[buffer(10)]],
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Fraction of the grid width at this thread's x index.
  const float grid_fraction =
      static_cast<float>(pos.x) / static_cast<float>(grid.x);
  const float freq_inverse = metal::exp2(-grid_fraction * base);
  rope_impl<T, traditional, forward, N>(
      in, out, offset, freq_inverse, scale, strides, out_strides, n_batch,
      pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>(
in,
out,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
pos,
grid);
}
// clang-format off
// Explicitly instantiates the general kernels for one
// (type, traditional, forward) combination:
//   - rope<type, traditional, forward>, exported as "rope_" + name
//   - rope_freqs<type, traditional, forward>, exported as "rope_freqs_" + name
// NOTE(review): the first instantiation lists `base` twice (once without a
// buffer index, once with [[buffer(10)]]); the unindexed entry looks like a
// pre-change line retained by the diff rendering -- confirm against the
// actual file.
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
// Explicitly instantiates the "single" kernels for one
// (type, traditional, forward) combination:
//   - rope_single<type, traditional, forward>, exported as
//     "rope_single_" + name
//   - rope_single_freqs<type, traditional, forward>, exported as
//     "rope_single_freqs_" + name
// NOTE(review): the macro header appears twice below. The first, unfinished
// copy (with `base` ahead of `scale`) looks like the pre-change definition
// retained by the diff rendering and runs into the second copy through a
// line continuation -- confirm against the actual file.
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t& stride, \
uint2 pos [[thread_position_in_grid]], \
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
// Convenience macro: emits both the "single" and the general kernel
// instantiations for one (name, type, traditional, forward) combination.
// NOTE(review): the instantiate_rope_g line appears twice (once with a
// trailing continuation, once without); this looks like an old/new pair
// retained by the diff rendering -- confirm against the actual file.
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
instantiate_rope_g(name, type, traditional, forward)
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)

View File

@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool with_freqs = inputs.size() == 2;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&base, sizeof(float), 3);
compute_encoder->setBytes(&scale_, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
uint32_t dim0 = dims_ / 2;
auto group_dims = get_block_dims(dim0, n_batch, 1);
auto grid_dims = MTL::Size(dim0, n_batch, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}
if (with_freqs) {
auto& freqs = inputs[1];
compute_encoder.set_input_array(freqs, 10);
auto freq_stride = freqs.strides()[0];
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
} else {
compute_encoder->setBytes(&base, sizeof(float), 10);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast