mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal * cuda rope (#2576)
This commit is contained in:
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
const int* offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const cuda::std::array<int64_t, 3> strides,
|
||||
const cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
auto n_head_up = N * ((n_head + N - 1) / N);
|
||||
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
|
||||
auto batch_idx = (pos.z * N) / n_head_up;
|
||||
auto batch_offset = offset[batch_idx * offset_stride];
|
||||
float L = scale * static_cast<float>(pos.y + batch_offset);
|
||||
auto mat_idx = batch_idx * n_head + head_idx;
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
|
||||
@@ -123,20 +129,19 @@ __device__ void rope_impl(
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
@@ -167,7 +172,8 @@ __global__ void rope(
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
@@ -182,12 +188,13 @@ __global__ void rope(
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
cuda::std::array<int64_t, 3> strides;
|
||||
cuda::std::array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
|
||||
int B = in.shape(0);
|
||||
int T = in.shape(-2);
|
||||
int D = in.shape(-1);
|
||||
size_t mat_size = T * D;
|
||||
int dispatch_ndim = ndim;
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
int N = 1;
|
||||
for (int i = 1; i < (ndim - 2); ++i) {
|
||||
N *= in.shape(i);
|
||||
}
|
||||
|
||||
// We apply rope to less that the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
if (dims_ < D) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
|
||||
if (single && !with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_single<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
uint2 dims = make_uint2(dims_ / 2, N);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
|
||||
} else if (single) {
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
uint2 dims = make_uint2(dims_ / 2, N);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
|
||||
} else if (with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_freqs<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
int n_per_thread = 4;
|
||||
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
uint3 dims = make_uint3(dims_ / 2, T, dimz);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
offset_stride,
|
||||
N,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
int n_per_thread = 4;
|
||||
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
uint3 dims = make_uint3(dims_ / 2, T, dimz);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
offset_stride,
|
||||
N,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user