mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
RoPE with frequencies as optional input (#1337)
* start rope with freq input * rope with frequencies * nits * fix bug * fix bug + test * cleanup * optional base
This commit is contained in:
@@ -4,23 +4,20 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope_single(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
void rope_single_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
// Figure out L and d.
|
||||
uint2 pos,
|
||||
uint2 grid) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float theta = L * inv_freq;
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
@@ -55,24 +52,54 @@ template <typename T, bool traditional, bool forward>
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope_single(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
float inv_freq = metal::exp2(-d * base);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope_single_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
void rope_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Figure out L and d.
|
||||
uint3 pos,
|
||||
uint3 grid) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float theta = L * inv_freq;
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
@@ -116,37 +143,115 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
float inv_freq = metal::exp2(-d * base);
|
||||
rope_impl<T, traditional, forward, N>(
|
||||
in,
|
||||
out,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
rope_impl<T, traditional, forward, N>(
|
||||
in,
|
||||
out,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_freqs_" #name)]] \
|
||||
[[kernel]] void rope_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_single_freqs_" #name)]] \
|
||||
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_rope_g(name, type, traditional, forward)
|
||||
instantiate_rope_g(name, type, traditional, forward)
|
||||
|
||||
// clang-format off
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
|
@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
|
||||
bool with_freqs = inputs.size() == 2;
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
|
||||
kname << "rope_" << (single ? "single_" : "")
|
||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&base, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 4);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
if (single) {
|
||||
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
auto group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
auto grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
} else {
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
uint32_t dim1 = in.shape(-2);
|
||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
}
|
||||
|
||||
if (with_freqs) {
|
||||
auto& freqs = inputs[1];
|
||||
compute_encoder.set_input_array(freqs, 10);
|
||||
auto freq_stride = freqs.strides()[0];
|
||||
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
|
||||
} else {
|
||||
compute_encoder->setBytes(&base, sizeof(float), 10);
|
||||
}
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
Reference in New Issue
Block a user