Custom primitive + RoPE fat op (#676)
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
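For context on what the new kernel computes: RoPE rotates each pair of feature components by a position-dependent angle theta = scale * (offset + t) * base^(-2i/dims). Below is a minimal CPU sketch of that rotation in C++, assuming the same (batch x sequence x dims) row-major layout as the Metal kernel in this commit; the function name rope_reference and its signature are illustrative, not part of this commit.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the Metal kernel below (shown for the
// "traditional" pairing, where adjacent elements 2i and 2i+1 form a pair).
std::vector<float> rope_reference(
    const std::vector<float>& x, // flattened [batch, seq, dims], row-major
    int batch, int seq, int dims,
    int offset, float base, float scale) {
  std::vector<float> out(x.size());
  int half = dims / 2;
  for (int b = 0; b < batch; ++b) {
    for (int t = 0; t < seq; ++t) {
      for (int i = 0; i < half; ++i) {
        // Same angle the kernel computes: L * base^(-d).
        float L = scale * static_cast<float>(t + offset);
        float d = static_cast<float>(i) / static_cast<float>(half);
        float theta = L * std::pow(base, -d);
        std::size_t idx =
            (static_cast<std::size_t>(b) * seq + t) * dims + 2 * i;
        float x1 = x[idx];
        float x2 = x[idx + 1];
        out[idx] = x1 * std::cos(theta) - x2 * std::sin(theta);
        out[idx + 1] = x1 * std::sin(theta) + x2 * std::cos(theta);
      }
    }
  }
  return out;
}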
@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
@@ -23,6 +23,7 @@ set(
   "quantized"
   "random"
   "reduce"
+  "rope"
   "scan"
   "softmax"
   "sort"
mlx/backend/metal/kernels/rope.metal (new file, 68 lines)
@@ -0,0 +1,68 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
[[kernel]] void rope(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const size_t strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute the input and output indices
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + 1;
    in_index_1 =
        2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x + 2 * (grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + grid.x;
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }

  // Figure out L and d.
  float L = scale * static_cast<float>(pos.y + offset);
  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);

  // Compute costheta, sintheta
  float theta = L * metal::exp2(-d * base);
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1 = x1 * costheta - x2 * sintheta;
  float rx2 = x1 * sintheta + x2 * costheta;
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional) \
  template [[host_name("rope_" #name)]] \
  [[kernel]] void rope<type, traditional>( \
      const device type* in [[buffer(0)]], \
      device type* out [[buffer(1)]], \
      constant const size_t strides[3], \
      constant const int& offset, \
      constant const float& base, \
      constant const float& scale, \
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
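The traditional template flag only changes how components are paired: the traditional variant rotates adjacent elements (2i, 2i+1), while the default variant rotates element i together with element i + dims/2 (the split-halves layout). For a contiguous row this reduces to the index pairing sketched below; the helper rope_pair is illustrative, not part of the commit.

// Which two elements of a dims-length row are rotated together, for pair
// index i in [0, dims / 2). Mirrors the kernel's two branches in the
// contiguous case (strides[2] == 1).
struct Pair {
  int first;
  int second;
};

Pair rope_pair(int i, int dims, bool traditional) {
  if (traditional) {
    return {2 * i, 2 * i + 1}; // adjacent elements
  }
  return {i, i + dims / 2}; // split halves
}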
mlx/backend/metal/rope.cpp (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  if (in.ndim() != 3) {
    throw std::runtime_error(
        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
  }
  if (dims_ != in.shape(-1)) {
    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
  }
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  std::ostringstream kname;
  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  bool donated = in.data_shared_ptr() == nullptr;
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, donated ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&offset_, sizeof(int), 3);
  compute_encoder->setBytes(&base, sizeof(float), 4);
  compute_encoder->setBytes(&scale_, sizeof(float), 5);

  int dim0 = in.shape(2) / 2;
  int dim1 = in.shape(1);
  int dim2 = in.shape(0);
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
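Two details of this host code are worth noting. The dispatch grid is (dims/2, sequence, batch), so each thread rotates exactly one pair. And the host passes log2(base) rather than base itself, letting the kernel use the cheaper exp2 instead of pow, since exp2(-d * log2(base)) = base^(-d). A quick standalone check of that identity, with illustrative values:

#include <cassert>
#include <cmath>

int main() {
  float base = 10000.0f; // common RoPE base; value chosen for illustration
  float d = 0.25f;       // fractional frequency index i / (dims / 2)
  // What the host/kernel pair computes:
  float via_exp2 = std::exp2(-d * std::log2(base));
  // The mathematically equivalent direct form:
  float via_pow = std::pow(base, -d);
  assert(std::fabs(via_exp2 - via_pow) < 1e-6f);
  return 0;
}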