mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
@@ -22,23 +22,18 @@ void rope_single_impl(
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
uint index_1, index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x + pos.y * stride;
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x + pos.y * stride;
|
||||
in_index_2 = in_index_1 + 1;
|
||||
index_1 = 2 * pos.x + pos.y * stride;
|
||||
index_2 = index_1 + 1;
|
||||
} else {
|
||||
out_index_1 = pos.x + pos.y * stride;
|
||||
out_index_2 = out_index_1 + grid.x;
|
||||
in_index_1 = pos.x + pos.y * stride;
|
||||
in_index_2 = in_index_1 + grid.x;
|
||||
index_1 = pos.x + pos.y * stride;
|
||||
index_2 = index_1 + grid.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float x1 = static_cast<float>(in[index_1]);
|
||||
float x2 = static_cast<float>(in[index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
@@ -48,8 +43,8 @@ void rope_single_impl(
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
out[index_1] = static_cast<T>(rx1);
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
|
@@ -86,7 +86,7 @@ void RoPE::eval_gpu(
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
if (single) {
|
||||
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
|
Reference in New Issue
Block a user