no copy batch rope (#2595)

Awni Hannun
2025-09-15 14:23:48 -07:00
committed by GitHub
parent 8afb6d62f2
commit caecbe876a
2 changed files with 95 additions and 61 deletions

mlx/backend/metal/kernels/rope.metal

@@ -3,7 +3,12 @@
 #include <metal_math>
 
 #include "mlx/backend/metal/kernels/utils.h"
 
-template <typename T, bool traditional, bool forward>
+constant bool forward [[function_constant(1)]];
+constant bool traditional [[function_constant(2)]];
+constant bool hs_transpose [[function_constant(3)]];
+
+template <typename T>
 void rope_single_impl(
     const device T* in,
     device T* out,
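For context on the new declarations: Metal function constants move specialization from C++ template-instantiation time to pipeline-build time, so a single compiled kernel source serves every flag combination. A minimal host-side sketch of the pattern with metal-cpp (illustrative names only; MLX wraps this in its own device layer, shown further down in this commit):

// Sketch only: specializing a kernel via function constants with metal-cpp.
// `library` is an already-built MTL::Library*; the indices match the
// [[function_constant(n)]] declarations above.
#include <Metal/Metal.hpp>

MTL::Function* specialize_rope(MTL::Library* library) {
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  bool forward = true, traditional = false, hs_transpose = false;
  consts->setConstantValue(&forward, MTL::DataType::DataTypeBool, 1);
  consts->setConstantValue(&traditional, MTL::DataType::DataTypeBool, 2);
  consts->setConstantValue(&hs_transpose, MTL::DataType::DataTypeBool, 3);

  NS::Error* error = nullptr;
  auto* fn = library->newFunction(
      NS::String::string("rope_float32", NS::UTF8StringEncoding),
      consts,
      &error);
  consts->release();
  return fn; // one compiled source now serves every flag combination
}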
@@ -46,7 +51,7 @@ void rope_single_impl(
   out[index_2] = static_cast<T>(rx2);
 }
 
-template <typename T, bool traditional, bool forward>
+template <typename T>
 [[kernel]] void rope_single(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
@@ -58,11 +63,10 @@ template <typename T, bool traditional, bool forward>
     uint2 grid [[threads_per_grid]]) {
   float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
   float inv_freq = metal::exp2(-d * base);
-  rope_single_impl<T, traditional, forward>(
-      in, out, offset, inv_freq, scale, stride, pos, grid);
+  rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
 }
 
-template <typename T, bool traditional, bool forward>
+template <typename T>
 [[kernel]] void rope_single_freqs(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
@@ -74,11 +78,10 @@ template <typename T, bool traditional, bool forward>
     uint2 pos [[thread_position_in_grid]],
     uint2 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
-  rope_single_impl<T, traditional, forward>(
-      in, out, offset, inv_freq, scale, stride, pos, grid);
+  rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
 }
 
-template <typename T, bool traditional, bool forward, int N = 4>
+template <typename T, typename IdxT, int N = 4>
 void rope_impl(
     const device T* in,
     device T* out,
@@ -102,23 +105,29 @@ void rope_impl(
   float theta = L * inv_freq;
   float costheta = metal::fast::cos(theta);
   float sintheta = metal::fast::sin(theta);
 
   // Compute the input and output indices
-  size_t in_index_1, in_index_2;
-  size_t out_index_1, out_index_2;
-  if (traditional) {
-    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-        mat_idx * out_strides[0];
-    out_index_2 = out_index_1 + 1;
-    in_index_1 =
-        2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
-    in_index_2 = in_index_1 + strides[2];
-  } else {
-    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-        mat_idx * out_strides[0];
-    out_index_2 = out_index_1 + grid.x * out_strides[2];
-    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
-    in_index_2 = in_index_1 + grid.x * strides[2];
-  }
+  IdxT in_index_1;
+  if (hs_transpose) {
+    IdxT batch_stride = grid.y * IdxT(strides[1]);
+    in_index_1 =
+        batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
+  } else {
+    in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
+  }
+  IdxT in_index_2;
+  IdxT out_index_1 =
+      pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
+  IdxT out_index_2;
+  if (traditional) {
+    out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
+    out_index_2 = out_index_1 + 1;
+    in_index_1 += 2 * pos.x * IdxT(strides[2]);
+    in_index_2 = in_index_1 + IdxT(strides[2]);
+  } else {
+    out_index_1 += pos.x * IdxT(out_strides[2]);
+    out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
+    in_index_1 += pos.x * IdxT(strides[2]);
+    in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
+  }
 
   for (int i = 0; i < N && head_idx + i < n_head; ++i) {
     // Read and write the output
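To make the index arithmetic concrete: in the non-traditional layout, thread x pairs feature x with feature x + grid.x (i.e. i and i + D/2), while the traditional layout pairs adjacent features 2x and 2x + 1. A small standalone sketch of the non-traditional path with made-up sizes (plain C++; `pos` and `grid` stand in for the Metal thread coordinates):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed toy layout: one row-contiguous (seq, feature) matrix with
  // T = 2 sequence positions and D = 8 features, so grid.x = D / 2 = 4.
  int64_t strides[3] = {16, 8, 1}; // matrix, sequence, feature strides
  uint32_t grid_x = 4;             // feature pairs per row
  uint32_t pos_x = 1, pos_y = 1;   // second pair, second sequence position
  int64_t mat_idx = 0;

  // Non-traditional RoPE pairs feature i with feature i + D/2.
  int64_t in_index_1 = pos_y * strides[1] + mat_idx * strides[0];
  in_index_1 += pos_x * strides[2];                       // 8 + 1 = 9
  int64_t in_index_2 = in_index_1 + grid_x * strides[2];  // 9 + 4 = 13

  std::printf(
      "pair = (%lld, %lld)\n", (long long)in_index_1, (long long)in_index_2);
  return 0;
}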
@@ -135,14 +144,14 @@ void rope_impl(
     }
     out[out_index_1] = static_cast<T>(rx1);
     out[out_index_2] = static_cast<T>(rx2);
-    in_index_1 += strides[0];
-    in_index_2 += strides[0];
-    out_index_1 += out_strides[0];
-    out_index_2 += out_strides[0];
+    in_index_1 += IdxT(strides[0]);
+    in_index_2 += IdxT(strides[0]);
+    out_index_1 += IdxT(out_strides[0]);
+    out_index_2 += IdxT(out_strides[0]);
   }
 }
 
-template <typename T, bool traditional, bool forward, int N = 4>
+template <typename T, typename IdxT, int N = 4>
 [[kernel]] void rope(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
@@ -157,7 +166,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
     uint3 grid [[threads_per_grid]]) {
   float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
   float inv_freq = metal::exp2(-d * base);
-  rope_impl<T, traditional, forward, N>(
+  rope_impl<T, IdxT, N>(
       in,
       out,
       offset,
@@ -171,7 +180,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
       grid);
 }
 
-template <typename T, bool traditional, bool forward, int N = 4>
+template <typename T, typename IdxT, int N = 4>
 [[kernel]] void rope_freqs(
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
@@ -186,7 +195,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
     uint3 pos [[thread_position_in_grid]],
     uint3 grid [[threads_per_grid]]) {
   float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
-  rope_impl<T, traditional, forward, N>(
+  rope_impl<T, IdxT, N>(
       in,
       out,
       offset,
@@ -201,27 +210,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
 }
 
 // clang-format off
-#define instantiate_rope_g(name, type, traditional, forward) \
-  instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
-  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
+#define instantiate_rope_g(name, type) \
+  instantiate_kernel("rope_" #name, rope, type, int32_t) \
+  instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
+  instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
+  instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)
 
-#define instantiate_rope_s(name, type, traditional, forward) \
-  instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
-  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
+#define instantiate_rope_s(name, type) \
+  instantiate_kernel("rope_single_" #name, rope_single, type) \
+  instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)
 
-#define instantiate_rope(name, type, traditional, forward) \
-  instantiate_rope_s(name, type, traditional, forward) \
-  instantiate_rope_g(name, type, traditional, forward)
+#define instantiate_rope(name, type) \
+  instantiate_rope_s(name, type) \
+  instantiate_rope_g(name, type)
 
-instantiate_rope(traditional_float16, half, true, true)
-instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
-instantiate_rope(traditional_float32, float, true, true)
-instantiate_rope(float16, half, false, true)
-instantiate_rope(bfloat16, bfloat16_t, false, true)
-instantiate_rope(float32, float, false, true)
-instantiate_rope(vjp_traditional_float16, half, true, false)
-instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
-instantiate_rope(vjp_traditional_float32, float, true, false)
-instantiate_rope(vjp_float16, half, false, false)
-instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
-instantiate_rope(vjp_float32, float, false, false) // clang-format on
+instantiate_rope(float16, half)
+instantiate_rope(bfloat16, bfloat16_t)
+instantiate_rope(float32, float) // clang-format on
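Spelled out: previously there were twelve instantiate_rope calls (3 dtypes × traditional × forward), each expanding to four kernels, for 48 compiled variants; now three calls each expand to six kernels, for 18, since the flags moved into function constants and the "large_" variants add 64-bit indexing. A sketch of what instantiate_rope(float16, half) expands to under the new macros (the template arguments are read off the definitions above):

// Expansion sketch for instantiate_rope(float16, half):
//
//   rope_single_float16        -> rope_single<half>
//   rope_single_freqs_float16  -> rope_single_freqs<half>
//   rope_float16               -> rope<half, int32_t>
//   rope_freqs_float16         -> rope_freqs<half, int32_t>
//   rope_large_float16         -> rope<half, int64_t>
//   rope_freqs_large_float16   -> rope_freqs<half, int64_t>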

mlx/backend/metal/rope.cpp

@@ -29,6 +29,7 @@ void RoPE::eval_gpu(
   int T = in.shape(-2);
   int D = in.shape(-1);
   size_t mat_size = T * D;
+  bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;
 
   int dispatch_ndim = ndim;
   while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
@@ -40,6 +41,8 @@ void RoPE::eval_gpu(
     N *= in.shape(i);
   }
 
+  bool head_seq_transpose = false;
+
   if (dims_ < D) {
     donated = true;
     auto ctype =
@@ -64,6 +67,17 @@ void RoPE::eval_gpu(
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
+  } else if (
+      ndim == 4 &&
+      // batch dim is regularly strided
+      in.strides()[0] == T * N * D &&
+      // sequence and head dimensions are transposed
+      in.strides()[1] == D && in.strides()[2] == N * D) {
+    head_seq_transpose = true;
+    out.set_data(allocator::malloc(out.nbytes()));
+    strides[0] = in.strides()[1];
+    strides[1] = in.strides()[2];
+    strides[2] = in.strides()[3];
   } else {
     // Copy non-contiguous > 3D inputs into the output and treat
     // input as donated
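To see which layouts the new branch catches: it matches a 4-D input whose head and sequence axes were swapped by a view rather than a copy, which is why no copy is needed here (the commit's "no copy batch rope"). A small standalone check with made-up sizes (N = 4 heads, T = 8 tokens, D = 64 features):

#include <cstdint>
#include <cstdio>

int main() {
  // A row-contiguous (B, T, N, D) array has strides (T*N*D, N*D, D, 1).
  // Swapping the T and N axes as a view gives shape (B, N, T, D) with
  // strides (T*N*D, D, N*D, 1) -- exactly the pattern the branch tests.
  const int64_t T = 8, N = 4, D = 64;
  int64_t strides[4] = {T * N * D, D, N * D, 1}; // = {2048, 64, 256, 1}

  bool head_seq_transpose =
      strides[0] == T * N * D && strides[1] == D && strides[2] == N * D;
  std::printf("transpose detected: %d\n", head_seq_transpose); // prints 1
  return 0;
}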
@@ -77,15 +91,33 @@ void RoPE::eval_gpu(
   out_strides[1] = out.strides()[ndim - 2];
   out_strides[2] = out.strides()[ndim - 1];
 
-  // Special case for inference (single batch, single time step, and contiguous)
-  bool single = in.flags().row_contiguous && B == 1 && T == 1;
+  // Special case for inference (single time step, contiguous, one offset)
+  auto& offset = inputs[1];
+  bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1;
 
   bool with_freqs = inputs.size() == 3;
-  std::ostringstream kname;
-  kname << "rope_" << (single ? "single_" : "")
-        << ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
-        << (traditional_ ? "traditional_" : "") << type_to_name(in);
-  auto kernel = d.get_kernel(kname.str());
+  std::string kname;
+  concatenate(
+      kname,
+      "rope_",
+      single ? "single_" : "",
+      (with_freqs) ? "freqs_" : "",
+      large ? "large_" : "",
+      type_to_name(in));
+  std::string hash_name;
+  concatenate(
+      hash_name,
+      kname,
+      "_",
+      forward_ ? "" : "vjp_",
+      traditional_ ? "traditional_" : "",
+      head_seq_transpose ? "transpose" : "");
+  metal::MTLFCList func_consts = {
+      {&forward_, MTL::DataType::DataTypeBool, 1},
+      {&traditional_, MTL::DataType::DataTypeBool, 2},
+      {&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};
+  auto kernel = d.get_kernel(kname, hash_name, func_consts);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   float base = std::log2(base_);
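As an illustration of the split between the compiled function name (kname) and the pipeline cache key (hash_name), here is a plain C++ rendering of the same string logic with assumed flag values (a backward, traditional rope over a float16 input with explicit frequencies):

#include <iostream>
#include <string>

int main() {
  // Assumed example: float16 dtype, freqs input, 32-bit indexing,
  // backward (vjp) pass of a traditional rope, no head/seq transpose.
  bool single = false, with_freqs = true, large = false;
  bool forward = false, traditional = true, head_seq_transpose = false;

  std::string kname = std::string("rope_") + (single ? "single_" : "") +
      (with_freqs ? "freqs_" : "") + (large ? "large_" : "") + "float16";
  std::string hash_name = kname + "_" + (forward ? "" : "vjp_") +
      (traditional ? "traditional_" : "") +
      (head_seq_transpose ? "transpose" : "");

  std::cout << kname << "\n";     // rope_freqs_float16 (compiled function)
  std::cout << hash_name << "\n"; // rope_freqs_float16_vjp_traditional_
  return 0;
}

The function constants then specialize that one compiled function per distinct hash_name at pipeline-creation time, instead of compiling every flag combination up front.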
@@ -93,7 +125,7 @@ void RoPE::eval_gpu(
   compute_encoder.set_input_array(donated ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
-  compute_encoder.set_input_array(inputs[1], 2);
+  compute_encoder.set_input_array(offset, 2);
   compute_encoder.set_bytes(scale_, 3);
 
   MTL::Size group_dims;
@@ -107,8 +139,8 @@ void RoPE::eval_gpu(
   compute_encoder.set_bytes(strides, 3, 4);
   compute_encoder.set_bytes(out_strides, 3, 5);
   int64_t offset_stride = 0;
-  if (inputs[1].ndim() > 0) {
-    offset_stride = inputs[1].strides()[0];
+  if (offset.ndim() > 0) {
+    offset_stride = offset.strides()[0];
   }
   compute_encoder.set_bytes(offset_stride, 6);
   compute_encoder.set_bytes(N, 7);