mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal * cuda rope (#2576)
This commit is contained in:
@@ -10,7 +10,7 @@ void rope_single_impl(
|
||||
constant const int& offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const int64_t& stride,
|
||||
uint2 pos,
|
||||
uint2 grid) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
@@ -52,7 +52,7 @@ template <typename T, bool traditional, bool forward>
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const int64_t& stride,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
@@ -68,9 +68,9 @@ template <typename T, bool traditional, bool forward>
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const int64_t& stride,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
constant const int64_t& freq_stride [[buffer(11)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
@@ -82,15 +82,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
void rope_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const int64_t strides[3],
|
||||
constant const int64_t out_strides[3],
|
||||
constant const int64_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
uint3 pos,
|
||||
uint3 grid) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
auto n_head_up = N * ((n_head + N - 1) / N);
|
||||
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
|
||||
auto batch_idx = (pos.z * N) / n_head_up;
|
||||
auto batch_offset = offset[batch_idx * offset_stride];
|
||||
float L = scale * static_cast<float>(pos.y + batch_offset);
|
||||
auto mat_idx = batch_idx * n_head + head_idx;
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
@@ -102,20 +108,19 @@ void rope_impl(
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
@@ -141,11 +146,12 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const int64_t strides[3],
|
||||
constant const int64_t out_strides[3],
|
||||
constant const int64_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
@@ -159,7 +165,8 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
@@ -168,13 +175,14 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
const device int* offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const int64_t strides[3],
|
||||
constant const int64_t out_strides[3],
|
||||
constant const int64_t& offset_stride,
|
||||
constant const int& n_head,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
constant const int64_t& freq_stride [[buffer(11)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
@@ -186,61 +194,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_freqs_" #name)]] \
|
||||
[[kernel]] void rope_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
|
||||
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
|
||||
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_single_freqs_" #name)]] \
|
||||
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]);
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
|
@@ -18,23 +18,29 @@ void RoPE::eval_gpu(
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t strides[3];
|
||||
size_t out_strides[3];
|
||||
int64_t strides[3];
|
||||
int64_t out_strides[3];
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
int B = in.shape(0);
|
||||
int T = in.shape(-2);
|
||||
int D = in.shape(-1);
|
||||
size_t mat_size = T * D;
|
||||
|
||||
int dispatch_ndim = ndim;
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
if (dims_ < in.shape(-1)) {
|
||||
|
||||
int N = 1;
|
||||
for (int i = 1; i < (ndim - 2); ++i) {
|
||||
N *= in.shape(i);
|
||||
}
|
||||
|
||||
if (dims_ < D) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
@@ -71,8 +77,8 @@ void RoPE::eval_gpu(
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
// Special case for inference (single batch, single time step, and contiguous)
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
std::ostringstream kname;
|
||||
@@ -86,24 +92,29 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder.set_input_array(inputs[1], 2);
|
||||
compute_encoder.set_bytes(scale_, 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
if (single) {
|
||||
compute_encoder.set_bytes(out_strides, 1, 4);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
group_dims = get_block_dims(dim0, N, 1);
|
||||
grid_dims = MTL::Size(dim0, N, 1);
|
||||
} else {
|
||||
compute_encoder.set_bytes(strides, 3, 4);
|
||||
compute_encoder.set_bytes(out_strides, 3, 5);
|
||||
compute_encoder.set_bytes(n_batch, 6);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
compute_encoder.set_bytes(offset_stride, 6);
|
||||
compute_encoder.set_bytes(N, 7);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
uint32_t dim1 = in.shape(-2);
|
||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||
uint32_t dim1 = T;
|
||||
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
}
|
||||
|
Reference in New Issue
Block a user