From 43ffdab1721c66f7f2f93b9ef11c4255bcdd9c19 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 31 Jul 2024 16:18:25 -0700 Subject: [PATCH] fix rope and random (#1301) * fix rope and random * comment --- mlx/backend/metal/kernels/random.metal | 26 +++-- mlx/backend/metal/kernels/rope.metal | 146 +++++++++++++++++++------ mlx/backend/metal/rope.cpp | 39 +++++-- 3 files changed, 152 insertions(+), 59 deletions(-) diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal index b0397dd66..58073d649 100644 --- a/mlx/backend/metal/kernels/random.metal +++ b/mlx/backend/metal/kernels/random.metal @@ -43,20 +43,22 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { auto half_size = grid_dim.y - odd; out += index.x * bytes_per_key; bool drop_last = odd && (index.y == half_size); - auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y); - auto bits = threefry2x32_hash(key, count); + auto bits = threefry2x32_hash( + key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); + size_t idx = size_t(index.y) << 2; for (int i = 0; i < 4; ++i) { - out[4 * count.x + i] = bits.bytes[0][i]; + out[idx + i] = bits.bytes[0][i]; } if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { - out[4 * count.y + i] = bits.bytes[1][i]; + out[idx + i] = bits.bytes[1][i]; } } else { for (int i = 0; i < 4; ++i) { - out[4 * count.y + i] = bits.bytes[1][i]; + out[idx + i] = bits.bytes[1][i]; } } } @@ -77,22 +79,24 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim); auto key = uint2(keys[k1_elem], keys[k2_elem]); auto half_size = grid_dim.y - odd; - out += index.x * bytes_per_key; + out += size_t(index.x) * bytes_per_key; bool drop_last = odd && (index.y == half_size); - auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y); - auto bits = threefry2x32_hash(key, count); + auto bits = threefry2x32_hash( + key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); + size_t idx = size_t(index.y) << 2; for (int i = 0; i < 4; ++i) { - out[4 * count.x + i] = bits.bytes[0][i]; + out[idx + i] = bits.bytes[0][i]; } if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { - out[4 * count.y + i] = bits.bytes[1][i]; + out[idx + i] = bits.bytes[1][i]; } } else { for (int i = 0; i < 4; ++i) { - out[4 * count.y + i] = bits.bytes[1][i]; + out[idx + i] = bits.bytes[1][i]; } } } diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal index d53ec3845..ef218bf36 100644 --- a/mlx/backend/metal/kernels/rope.metal +++ b/mlx/backend/metal/kernels/rope.metal @@ -6,36 +6,17 @@ #include "mlx/backend/metal/kernels/utils.h" template -[[kernel]] void rope( +[[kernel]] void rope_single( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], - constant const size_t strides[3], - constant const size_t out_strides[3], constant const int& offset, constant const float& base, constant const float& scale, - uint3 pos [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - // Compute the input and output indices - uint in_index_1, in_index_2; - uint out_index_1, out_index_2; - if (traditional) { - out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + - pos.z * out_strides[0]; - out_index_2 = out_index_1 + 1; - in_index_1 = - 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; - in_index_2 = in_index_1 + strides[2]; - } else { - out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + - pos.z * out_strides[0]; - out_index_2 = out_index_1 + grid.x * out_strides[2]; - in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0]; - in_index_2 = in_index_1 + grid.x * strides[2]; - } - + constant const size_t& stride, + uint2 pos [[thread_position_in_grid]], + uint2 grid [[threads_per_grid]]) { // Figure out L and d. - float L = scale * static_cast(pos.y + offset); + float L = scale * static_cast(offset); float d = static_cast(pos.x) / static_cast(grid.x); // Compute costheta, sintheta @@ -43,6 +24,21 @@ template float costheta = metal::fast::cos(theta); float sintheta = metal::fast::sin(theta); + // Compute the input and output indices + uint in_index_1, in_index_2; + uint out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x + pos.y * stride; + out_index_2 = out_index_1 + 1; + in_index_1 = 2 * pos.x + pos.y * stride; + in_index_2 = in_index_1 + 1; + } else { + out_index_1 = pos.x + pos.y * stride; + out_index_2 = out_index_1 + grid.x; + in_index_1 = pos.x + pos.y * stride; + in_index_2 = in_index_1 + grid.x; + } + // Read and write the output float x1 = static_cast(in[in_index_1]); float x2 = static_cast(in[in_index_2]); @@ -59,19 +55,97 @@ template out[out_index_2] = static_cast(rx2); } -#define instantiate_rope(name, type, traditional, forward) \ - template [[host_name("rope_" #name)]] [[kernel]] void \ - rope( \ - const device type* in [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - constant const size_t strides[3], \ - constant const size_t out_strides[3], \ - constant const int& offset, \ - constant const float& base, \ - constant const float& scale, \ - uint3 pos [[thread_position_in_grid]], \ +template +[[kernel]] void rope( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const int& offset, + constant const float& base, + constant const float& scale, + constant const size_t strides[3], + constant const size_t out_strides[3], + constant const size_t& n_batch, + uint3 pos [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Figure out L and d. + float L = scale * static_cast(pos.y + offset); + float d = static_cast(pos.x) / static_cast(grid.x); + + // Compute costheta, sintheta + float theta = L * metal::exp2(-d * base); + float costheta = metal::fast::cos(theta); + float sintheta = metal::fast::sin(theta); + + // Compute the input and output indices + size_t in_index_1, in_index_2; + size_t out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + 1; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + grid.x * out_strides[2]; + in_index_1 = + pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + grid.x * strides[2]; + } + for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); + in_index_1 += strides[0]; + in_index_2 += strides[0]; + out_index_1 += out_strides[0]; + out_index_2 += out_strides[0]; + } +} + +#define instantiate_rope_g(name, type, traditional, forward) \ + template [[host_name("rope_" #name)]] [[kernel]] void \ + rope( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const int& offset, \ + constant const float& base, \ + constant const float& scale, \ + constant const size_t strides[3], \ + constant const size_t out_strides[3], \ + constant const size_t& n_batch, \ + uint3 pos [[thread_position_in_grid]], \ uint3 grid [[threads_per_grid]]); +#define instantiate_rope_s(name, type, traditional, forward) \ + template [[host_name("rope_single_" #name)]] [[kernel]] void \ + rope_single( \ + const device type* in [[buffer(0)]], \ + device type* out [[buffer(1)]], \ + constant const int& offset, \ + constant const float& base, \ + constant const float& scale, \ + constant const size_t& stride, \ + uint2 pos [[thread_position_in_grid]], \ + uint2 grid [[threads_per_grid]]); + +#define instantiate_rope(name, type, traditional, forward) \ + instantiate_rope_s(name, type, traditional, forward) \ + instantiate_rope_g(name, type, traditional, forward) + // clang-format off instantiate_rope(traditional_float16, half, true, true) instantiate_rope(traditional_bfloat16, bfloat16_t, true, true) @@ -84,4 +158,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false) instantiate_rope(vjp_traditional_float32, float, true, false) instantiate_rope(vjp_float16, half, false, false) instantiate_rope(vjp_bfloat16, bfloat16_t, false, false) -instantiate_rope(vjp_float32, float, false, false) // clang-format on \ No newline at end of file +instantiate_rope(vjp_float32, float, false, false) // clang-format on diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index c19ad52a4..4d51fd75c 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -5,6 +5,8 @@ namespace mlx::core::fast { +constexpr int n_per_thread = 4; + void RoPE::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -62,8 +64,11 @@ void RoPE::eval_gpu( out_strides[1] = out.strides()[ndim - 2]; out_strides[2] = out.strides()[ndim - 1]; + // Special case for inference (single time step and contiguous) + bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); + std::ostringstream kname; - kname << "rope_" << (forward_ ? "" : "vjp_") + kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_") << (traditional_ ? "traditional_" : "") << type_to_name(in); auto kernel = d.get_kernel(kname.str()); auto& compute_encoder = d.get_command_encoder(s.index); @@ -72,18 +77,28 @@ void RoPE::eval_gpu( compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2); - compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3); - compute_encoder->setBytes(&offset_, sizeof(int), 4); - compute_encoder->setBytes(&base, sizeof(float), 5); - compute_encoder->setBytes(&scale_, sizeof(float), 6); + compute_encoder->setBytes(&offset_, sizeof(int), 2); + compute_encoder->setBytes(&base, sizeof(float), 3); + compute_encoder->setBytes(&scale_, sizeof(float), 4); - int dim0 = dims_ / 2; - int dim1 = in.shape(-2); - int dim2 = in.size() / mat_size; - auto group_dims = get_block_dims(dim0, dim1, dim2); - auto grid_dims = MTL::Size(dim0, dim1, dim2); - compute_encoder.dispatchThreads(grid_dims, group_dims); + size_t n_batch = in.size() / mat_size; + if (single) { + compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5); + uint32_t dim0 = dims_ / 2; + auto group_dims = get_block_dims(dim0, n_batch, 1); + auto grid_dims = MTL::Size(dim0, n_batch, 1); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } else { + compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5); + compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6); + compute_encoder->setBytes(&n_batch, sizeof(size_t), 7); + uint32_t dim0 = dims_ / 2; + uint32_t dim1 = in.shape(-2); + uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; + auto group_dims = get_block_dims(dim0, dim1, dim2); + auto grid_dims = MTL::Size(dim0, dim1, dim2); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } } } // namespace mlx::core::fast