mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 00:08:09 +08:00
no copy batch rope (#2595)
This commit is contained in:
@@ -3,7 +3,12 @@
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
template <typename T, bool traditional, bool forward>
|
||||
|
||||
constant bool forward [[function_constant(1)]];
|
||||
constant bool traditional [[function_constant(2)]];
|
||||
constant bool hs_transpose [[function_constant(3)]];
|
||||
|
||||
template <typename T>
|
||||
void rope_single_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
@@ -46,7 +51,7 @@ void rope_single_impl(
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
template <typename T>
|
||||
[[kernel]] void rope_single(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -58,11 +63,10 @@ template <typename T, bool traditional, bool forward>
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
float inv_freq = metal::exp2(-d * base);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
template <typename T>
|
||||
[[kernel]] void rope_single_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -74,11 +78,10 @@ template <typename T, bool traditional, bool forward>
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
template <typename T, typename IdxT, int N = 4>
|
||||
void rope_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
@@ -102,23 +105,29 @@ void rope_impl(
|
||||
float theta = L * inv_freq;
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
size_t in_index_1, in_index_2;
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
IdxT in_index_1;
|
||||
if (hs_transpose) {
|
||||
IdxT batch_stride = grid.y * IdxT(strides[1]);
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
|
||||
}
|
||||
IdxT in_index_2;
|
||||
IdxT out_index_1 =
|
||||
pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
|
||||
IdxT out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 += 2 * pos.x * IdxT(strides[2]);
|
||||
in_index_2 = in_index_1 + IdxT(strides[2]);
|
||||
} else {
|
||||
out_index_1 += pos.x * IdxT(out_strides[2]);
|
||||
out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
|
||||
in_index_1 += pos.x * IdxT(strides[2]);
|
||||
in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
|
||||
}
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
@@ -135,14 +144,14 @@ void rope_impl(
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
in_index_1 += strides[0];
|
||||
in_index_2 += strides[0];
|
||||
out_index_1 += out_strides[0];
|
||||
out_index_2 += out_strides[0];
|
||||
in_index_1 += IdxT(strides[0]);
|
||||
in_index_2 += IdxT(strides[0]);
|
||||
out_index_1 += IdxT(out_strides[0]);
|
||||
out_index_2 += IdxT(out_strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
template <typename T, typename IdxT, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -157,7 +166,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
float inv_freq = metal::exp2(-d * base);
|
||||
rope_impl<T, traditional, forward, N>(
|
||||
rope_impl<T, IdxT, N>(
|
||||
in,
|
||||
out,
|
||||
offset,
|
||||
@@ -171,7 +180,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
template <typename T, typename IdxT, int N = 4>
|
||||
[[kernel]] void rope_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
@@ -186,7 +195,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
rope_impl<T, traditional, forward, N>(
|
||||
rope_impl<T, IdxT, N>(
|
||||
in,
|
||||
out,
|
||||
offset,
|
||||
@@ -201,27 +210,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
|
||||
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
|
||||
#define instantiate_rope_g(name, type) \
|
||||
instantiate_kernel("rope_" #name, rope, type, int32_t) \
|
||||
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
|
||||
instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
|
||||
instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)
|
||||
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
|
||||
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
|
||||
#define instantiate_rope_s(name, type) \
|
||||
instantiate_kernel("rope_single_" #name, rope_single, type) \
|
||||
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_rope_g(name, type, traditional, forward)
|
||||
#define instantiate_rope(name, type) \
|
||||
instantiate_rope_s(name, type) \
|
||||
instantiate_rope_g(name, type)
|
||||
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
instantiate_rope(float16, half, false, true)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false, true)
|
||||
instantiate_rope(float32, float, false, true)
|
||||
instantiate_rope(vjp_traditional_float16, half, true, false)
|
||||
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false) // clang-format on
|
||||
instantiate_rope(float16, half)
|
||||
instantiate_rope(bfloat16, bfloat16_t)
|
||||
instantiate_rope(float32, float) // clang-format on
|
||||
|
@@ -29,6 +29,7 @@ void RoPE::eval_gpu(
|
||||
int T = in.shape(-2);
|
||||
int D = in.shape(-1);
|
||||
size_t mat_size = T * D;
|
||||
bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;
|
||||
|
||||
int dispatch_ndim = ndim;
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
@@ -40,6 +41,8 @@ void RoPE::eval_gpu(
|
||||
N *= in.shape(i);
|
||||
}
|
||||
|
||||
bool head_seq_transpose = false;
|
||||
|
||||
if (dims_ < D) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
@@ -64,6 +67,17 @@ void RoPE::eval_gpu(
|
||||
strides[0] = in.strides()[ndim - 3];
|
||||
strides[1] = in.strides()[ndim - 2];
|
||||
strides[2] = in.strides()[ndim - 1];
|
||||
} else if (
|
||||
ndim == 4 &&
|
||||
// batch dim is regularly strided
|
||||
in.strides()[0] == T * N * D &&
|
||||
// sequence and head dimensions are transposed
|
||||
in.strides()[1] == D && in.strides()[2] == N * D) {
|
||||
head_seq_transpose = true;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
strides[0] = in.strides()[1];
|
||||
strides[1] = in.strides()[2];
|
||||
strides[2] = in.strides()[3];
|
||||
} else {
|
||||
// Copy non-contiguous > 3D inputs into the output and treat
|
||||
// input as donated
|
||||
@@ -77,15 +91,33 @@ void RoPE::eval_gpu(
|
||||
out_strides[1] = out.strides()[ndim - 2];
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Special case for inference (single batch, single time step, and contiguous)
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
// Special case for inference (single time step, contiguous, one offset)
|
||||
auto& offset = inputs[1];
|
||||
bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1;
|
||||
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (single ? "single_" : "")
|
||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
std::string kname;
|
||||
concatenate(
|
||||
kname,
|
||||
"rope_",
|
||||
single ? "single_" : "",
|
||||
(with_freqs) ? "freqs_" : "",
|
||||
large ? "large_" : "",
|
||||
type_to_name(in));
|
||||
std::string hash_name;
|
||||
concatenate(
|
||||
hash_name,
|
||||
kname,
|
||||
"_",
|
||||
forward_ ? "" : "vjp_",
|
||||
traditional_ ? "traditional_" : "",
|
||||
head_seq_transpose ? "transpose" : "");
|
||||
metal::MTLFCList func_consts = {
|
||||
{&forward_, MTL::DataType::DataTypeBool, 1},
|
||||
{&traditional_, MTL::DataType::DataTypeBool, 2},
|
||||
{&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};
|
||||
|
||||
auto kernel = d.get_kernel(kname, hash_name, func_consts);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
float base = std::log2(base_);
|
||||
@@ -93,7 +125,7 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder.set_input_array(inputs[1], 2);
|
||||
compute_encoder.set_input_array(offset, 2);
|
||||
compute_encoder.set_bytes(scale_, 3);
|
||||
|
||||
MTL::Size group_dims;
|
||||
@@ -107,8 +139,8 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_bytes(strides, 3, 4);
|
||||
compute_encoder.set_bytes(out_strides, 3, 5);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
if (offset.ndim() > 0) {
|
||||
offset_stride = offset.strides()[0];
|
||||
}
|
||||
compute_encoder.set_bytes(offset_stride, 6);
|
||||
compute_encoder.set_bytes(N, 7);
|
||||
|
Reference in New Issue
Block a user