mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 00:08:09 +08:00
no copy batch rope (#2595)
This commit is contained in:
@@ -3,7 +3,12 @@
|
|||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
template <typename T, bool traditional, bool forward>
|
|
||||||
|
constant bool forward [[function_constant(1)]];
|
||||||
|
constant bool traditional [[function_constant(2)]];
|
||||||
|
constant bool hs_transpose [[function_constant(3)]];
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void rope_single_impl(
|
void rope_single_impl(
|
||||||
const device T* in,
|
const device T* in,
|
||||||
device T* out,
|
device T* out,
|
||||||
@@ -46,7 +51,7 @@ void rope_single_impl(
|
|||||||
out[index_2] = static_cast<T>(rx2);
|
out[index_2] = static_cast<T>(rx2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward>
|
template <typename T>
|
||||||
[[kernel]] void rope_single(
|
[[kernel]] void rope_single(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
@@ -58,11 +63,10 @@ template <typename T, bool traditional, bool forward>
|
|||||||
uint2 grid [[threads_per_grid]]) {
|
uint2 grid [[threads_per_grid]]) {
|
||||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||||
float inv_freq = metal::exp2(-d * base);
|
float inv_freq = metal::exp2(-d * base);
|
||||||
rope_single_impl<T, traditional, forward>(
|
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward>
|
template <typename T>
|
||||||
[[kernel]] void rope_single_freqs(
|
[[kernel]] void rope_single_freqs(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
@@ -74,11 +78,10 @@ template <typename T, bool traditional, bool forward>
|
|||||||
uint2 pos [[thread_position_in_grid]],
|
uint2 pos [[thread_position_in_grid]],
|
||||||
uint2 grid [[threads_per_grid]]) {
|
uint2 grid [[threads_per_grid]]) {
|
||||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||||
rope_single_impl<T, traditional, forward>(
|
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward, int N = 4>
|
template <typename T, typename IdxT, int N = 4>
|
||||||
void rope_impl(
|
void rope_impl(
|
||||||
const device T* in,
|
const device T* in,
|
||||||
device T* out,
|
device T* out,
|
||||||
@@ -102,23 +105,29 @@ void rope_impl(
|
|||||||
float theta = L * inv_freq;
|
float theta = L * inv_freq;
|
||||||
float costheta = metal::fast::cos(theta);
|
float costheta = metal::fast::cos(theta);
|
||||||
float sintheta = metal::fast::sin(theta);
|
float sintheta = metal::fast::sin(theta);
|
||||||
|
|
||||||
// Compute the input and output indices
|
// Compute the input and output indices
|
||||||
size_t in_index_1, in_index_2;
|
IdxT in_index_1;
|
||||||
size_t out_index_1, out_index_2;
|
if (hs_transpose) {
|
||||||
if (traditional) {
|
IdxT batch_stride = grid.y * IdxT(strides[1]);
|
||||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
|
||||||
mat_idx * out_strides[0];
|
|
||||||
out_index_2 = out_index_1 + 1;
|
|
||||||
in_index_1 =
|
in_index_1 =
|
||||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
|
||||||
in_index_2 = in_index_1 + strides[2];
|
|
||||||
} else {
|
} else {
|
||||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
|
||||||
mat_idx * out_strides[0];
|
}
|
||||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
IdxT in_index_2;
|
||||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
IdxT out_index_1 =
|
||||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
|
||||||
|
IdxT out_index_2;
|
||||||
|
if (traditional) {
|
||||||
|
out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
|
||||||
|
out_index_2 = out_index_1 + 1;
|
||||||
|
in_index_1 += 2 * pos.x * IdxT(strides[2]);
|
||||||
|
in_index_2 = in_index_1 + IdxT(strides[2]);
|
||||||
|
} else {
|
||||||
|
out_index_1 += pos.x * IdxT(out_strides[2]);
|
||||||
|
out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
|
||||||
|
in_index_1 += pos.x * IdxT(strides[2]);
|
||||||
|
in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
|
||||||
}
|
}
|
||||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||||
// Read and write the output
|
// Read and write the output
|
||||||
@@ -135,14 +144,14 @@ void rope_impl(
|
|||||||
}
|
}
|
||||||
out[out_index_1] = static_cast<T>(rx1);
|
out[out_index_1] = static_cast<T>(rx1);
|
||||||
out[out_index_2] = static_cast<T>(rx2);
|
out[out_index_2] = static_cast<T>(rx2);
|
||||||
in_index_1 += strides[0];
|
in_index_1 += IdxT(strides[0]);
|
||||||
in_index_2 += strides[0];
|
in_index_2 += IdxT(strides[0]);
|
||||||
out_index_1 += out_strides[0];
|
out_index_1 += IdxT(out_strides[0]);
|
||||||
out_index_2 += out_strides[0];
|
out_index_2 += IdxT(out_strides[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward, int N = 4>
|
template <typename T, typename IdxT, int N = 4>
|
||||||
[[kernel]] void rope(
|
[[kernel]] void rope(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
@@ -157,7 +166,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
uint3 grid [[threads_per_grid]]) {
|
uint3 grid [[threads_per_grid]]) {
|
||||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||||
float inv_freq = metal::exp2(-d * base);
|
float inv_freq = metal::exp2(-d * base);
|
||||||
rope_impl<T, traditional, forward, N>(
|
rope_impl<T, IdxT, N>(
|
||||||
in,
|
in,
|
||||||
out,
|
out,
|
||||||
offset,
|
offset,
|
||||||
@@ -171,7 +180,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
grid);
|
grid);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool traditional, bool forward, int N = 4>
|
template <typename T, typename IdxT, int N = 4>
|
||||||
[[kernel]] void rope_freqs(
|
[[kernel]] void rope_freqs(
|
||||||
const device T* in [[buffer(0)]],
|
const device T* in [[buffer(0)]],
|
||||||
device T* out [[buffer(1)]],
|
device T* out [[buffer(1)]],
|
||||||
@@ -186,7 +195,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
uint3 pos [[thread_position_in_grid]],
|
uint3 pos [[thread_position_in_grid]],
|
||||||
uint3 grid [[threads_per_grid]]) {
|
uint3 grid [[threads_per_grid]]) {
|
||||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||||
rope_impl<T, traditional, forward, N>(
|
rope_impl<T, IdxT, N>(
|
||||||
in,
|
in,
|
||||||
out,
|
out,
|
||||||
offset,
|
offset,
|
||||||
@@ -201,27 +210,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
|||||||
}
|
}
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
#define instantiate_rope_g(name, type) \
|
||||||
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
|
instantiate_kernel("rope_" #name, rope, type, int32_t) \
|
||||||
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
|
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
|
||||||
|
instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
|
||||||
|
instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)
|
||||||
|
|
||||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
#define instantiate_rope_s(name, type) \
|
||||||
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
|
instantiate_kernel("rope_single_" #name, rope_single, type) \
|
||||||
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
|
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)
|
||||||
|
|
||||||
#define instantiate_rope(name, type, traditional, forward) \
|
#define instantiate_rope(name, type) \
|
||||||
instantiate_rope_s(name, type, traditional, forward) \
|
instantiate_rope_s(name, type) \
|
||||||
instantiate_rope_g(name, type, traditional, forward)
|
instantiate_rope_g(name, type)
|
||||||
|
|
||||||
instantiate_rope(traditional_float16, half, true, true)
|
instantiate_rope(float16, half)
|
||||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
instantiate_rope(bfloat16, bfloat16_t)
|
||||||
instantiate_rope(traditional_float32, float, true, true)
|
instantiate_rope(float32, float) // clang-format on
|
||||||
instantiate_rope(float16, half, false, true)
|
|
||||||
instantiate_rope(bfloat16, bfloat16_t, false, true)
|
|
||||||
instantiate_rope(float32, float, false, true)
|
|
||||||
instantiate_rope(vjp_traditional_float16, half, true, false)
|
|
||||||
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
|
||||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
|
||||||
instantiate_rope(vjp_float16, half, false, false)
|
|
||||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
|
||||||
instantiate_rope(vjp_float32, float, false, false) // clang-format on
|
|
||||||
|
@@ -29,6 +29,7 @@ void RoPE::eval_gpu(
|
|||||||
int T = in.shape(-2);
|
int T = in.shape(-2);
|
||||||
int D = in.shape(-1);
|
int D = in.shape(-1);
|
||||||
size_t mat_size = T * D;
|
size_t mat_size = T * D;
|
||||||
|
bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;
|
||||||
|
|
||||||
int dispatch_ndim = ndim;
|
int dispatch_ndim = ndim;
|
||||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||||
@@ -40,6 +41,8 @@ void RoPE::eval_gpu(
|
|||||||
N *= in.shape(i);
|
N *= in.shape(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool head_seq_transpose = false;
|
||||||
|
|
||||||
if (dims_ < D) {
|
if (dims_ < D) {
|
||||||
donated = true;
|
donated = true;
|
||||||
auto ctype =
|
auto ctype =
|
||||||
@@ -64,6 +67,17 @@ void RoPE::eval_gpu(
|
|||||||
strides[0] = in.strides()[ndim - 3];
|
strides[0] = in.strides()[ndim - 3];
|
||||||
strides[1] = in.strides()[ndim - 2];
|
strides[1] = in.strides()[ndim - 2];
|
||||||
strides[2] = in.strides()[ndim - 1];
|
strides[2] = in.strides()[ndim - 1];
|
||||||
|
} else if (
|
||||||
|
ndim == 4 &&
|
||||||
|
// batch dim is regularly strided
|
||||||
|
in.strides()[0] == T * N * D &&
|
||||||
|
// sequence and head dimensions are transposed
|
||||||
|
in.strides()[1] == D && in.strides()[2] == N * D) {
|
||||||
|
head_seq_transpose = true;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
strides[0] = in.strides()[1];
|
||||||
|
strides[1] = in.strides()[2];
|
||||||
|
strides[2] = in.strides()[3];
|
||||||
} else {
|
} else {
|
||||||
// Copy non-contiguous > 3D inputs into the output and treat
|
// Copy non-contiguous > 3D inputs into the output and treat
|
||||||
// input as donated
|
// input as donated
|
||||||
@@ -77,15 +91,33 @@ void RoPE::eval_gpu(
|
|||||||
out_strides[1] = out.strides()[ndim - 2];
|
out_strides[1] = out.strides()[ndim - 2];
|
||||||
out_strides[2] = out.strides()[ndim - 1];
|
out_strides[2] = out.strides()[ndim - 1];
|
||||||
|
|
||||||
// Special case for inference (single batch, single time step, and contiguous)
|
// Special case for inference (single time step, contiguous, one offset)
|
||||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
auto& offset = inputs[1];
|
||||||
|
bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1;
|
||||||
|
|
||||||
bool with_freqs = inputs.size() == 3;
|
bool with_freqs = inputs.size() == 3;
|
||||||
std::ostringstream kname;
|
std::string kname;
|
||||||
kname << "rope_" << (single ? "single_" : "")
|
concatenate(
|
||||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
kname,
|
||||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
"rope_",
|
||||||
auto kernel = d.get_kernel(kname.str());
|
single ? "single_" : "",
|
||||||
|
(with_freqs) ? "freqs_" : "",
|
||||||
|
large ? "large_" : "",
|
||||||
|
type_to_name(in));
|
||||||
|
std::string hash_name;
|
||||||
|
concatenate(
|
||||||
|
hash_name,
|
||||||
|
kname,
|
||||||
|
"_",
|
||||||
|
forward_ ? "" : "vjp_",
|
||||||
|
traditional_ ? "traditional_" : "",
|
||||||
|
head_seq_transpose ? "transpose" : "");
|
||||||
|
metal::MTLFCList func_consts = {
|
||||||
|
{&forward_, MTL::DataType::DataTypeBool, 1},
|
||||||
|
{&traditional_, MTL::DataType::DataTypeBool, 2},
|
||||||
|
{&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kname, hash_name, func_consts);
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|
||||||
float base = std::log2(base_);
|
float base = std::log2(base_);
|
||||||
@@ -93,7 +125,7 @@ void RoPE::eval_gpu(
|
|||||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
|
||||||
compute_encoder.set_input_array(inputs[1], 2);
|
compute_encoder.set_input_array(offset, 2);
|
||||||
compute_encoder.set_bytes(scale_, 3);
|
compute_encoder.set_bytes(scale_, 3);
|
||||||
|
|
||||||
MTL::Size group_dims;
|
MTL::Size group_dims;
|
||||||
@@ -107,8 +139,8 @@ void RoPE::eval_gpu(
|
|||||||
compute_encoder.set_bytes(strides, 3, 4);
|
compute_encoder.set_bytes(strides, 3, 4);
|
||||||
compute_encoder.set_bytes(out_strides, 3, 5);
|
compute_encoder.set_bytes(out_strides, 3, 5);
|
||||||
int64_t offset_stride = 0;
|
int64_t offset_stride = 0;
|
||||||
if (inputs[1].ndim() > 0) {
|
if (offset.ndim() > 0) {
|
||||||
offset_stride = inputs[1].strides()[0];
|
offset_stride = offset.strides()[0];
|
||||||
}
|
}
|
||||||
compute_encoder.set_bytes(offset_stride, 6);
|
compute_encoder.set_bytes(offset_stride, 6);
|
||||||
compute_encoder.set_bytes(N, 7);
|
compute_encoder.set_bytes(N, 7);
|
||||||
|
Reference in New Issue
Block a user