mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Custom primitive + RoPE fat op (#676)
* extensions start * rope custom op * fix build * docs + rope benchmark * fix test * Add a Metal kernel for RoPE * Fix position of traditional * transform tests * Move rope computation to float and fix tests * Fix the test and a typo * change to fast * fix no metal build --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -3,9 +3,10 @@ target_sources(
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
|
@@ -11,6 +11,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
|
14
mlx/backend/common/rope.cpp
Normal file
14
mlx/backend/common/rope.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void RoPE::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -32,6 +32,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
|
@@ -23,6 +23,7 @@ set(
|
||||
"quantized"
|
||||
"random"
|
||||
"reduce"
|
||||
"rope"
|
||||
"scan"
|
||||
"softmax"
|
||||
"sort"
|
||||
|
68
mlx/backend/metal/kernels/rope.metal
Normal file
68
mlx/backend/metal/kernels/rope.metal
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
constant const size_t strides[3],
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
constant const float& scale,
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Compute the input and output indices
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
|
||||
out_index_2 = out_index_1 + grid.x;
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
|
||||
// Figure out L and d.
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1 = x1 * costheta - x2 * sintheta;
|
||||
float rx2 = x1 * sintheta + x2 * costheta;
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
instantiate_rope(traditional_float16, half, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
|
||||
instantiate_rope(traditional_float32, float, true)
|
||||
instantiate_rope(float16, half, false)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false)
|
||||
instantiate_rope(float32, float, false)
|
55
mlx/backend/metal/rope.cpp
Normal file
55
mlx/backend/metal/rope.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() != 3) {
|
||||
throw std::runtime_error(
|
||||
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
|
||||
}
|
||||
if (dims_ != in.shape(-1)) {
|
||||
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
|
||||
}
|
||||
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
float base = std::log2(base_);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, donated ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&base, sizeof(float), 4);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 5);
|
||||
|
||||
int dim0 = in.shape(2) / 2;
|
||||
int dim1 = in.shape(1);
|
||||
int dim2 = in.shape(0);
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/fast.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
@@ -95,4 +96,8 @@ NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Transpose)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(RoPE)
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
128
mlx/fast.cpp
Normal file
128
mlx/fast.cpp
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/transforms.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
std::vector<array> Custom::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
|
||||
std::vector<array> vjp_outs;
|
||||
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
vjp_outs.push_back(vjps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return vjp_outs;
|
||||
}
|
||||
|
||||
std::vector<array> Custom::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
|
||||
std::vector<array> jvp_outs;
|
||||
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
jvp_outs.push_back(jvps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return jvp_outs;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
|
||||
auto out_axes = std::vector<int>(outputs.size(), 0);
|
||||
return {outputs, out_axes};
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (x.ndim() != 3) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (traditional && x.shape(-1) != dims) {
|
||||
throw std::invalid_argument(
|
||||
"[rope] Does not support partial traditional application.");
|
||||
}
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, offset, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto t = x.dtype();
|
||||
auto N = x.shape(1) + offset;
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
||||
auto freqs = negative(arange(0, half_dims, t, s), s);
|
||||
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
if (traditional) {
|
||||
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
auto x1 = slice(x, {0, 0, 0}, out_s, s);
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
return std::vector<array>{concatenate(outs, 2, s)};
|
||||
}
|
||||
};
|
||||
// TODO change to condition for using custom prim
|
||||
auto stream = to_stream(s);
|
||||
if (stream.device == Device::gpu && x.shape(-1) == dims) {
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_unique<RoPE>(
|
||||
stream, fallback, dims, traditional, base, scale, offset),
|
||||
{x});
|
||||
}
|
||||
return fallback({x})[0];
|
||||
}
|
||||
|
||||
bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
const RoPE& a_other = static_cast<const RoPE&>(other);
|
||||
return (
|
||||
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
||||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
||||
offset_ == a_other.offset_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
82
mlx/fast.h
Normal file
82
mlx/fast.h
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
// Custom primitive accepts a fallback function which it uses for
|
||||
// transformations. Transformations are virtual so that derived classes may to
|
||||
// override the default behavior
|
||||
class Custom : public Primitive {
|
||||
public:
|
||||
explicit Custom(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback)
|
||||
: Primitive(stream), fallback_(fallback){};
|
||||
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
virtual std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
virtual std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
};
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */);
|
||||
|
||||
class RoPE : public Custom {
|
||||
public:
|
||||
RoPE(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset)
|
||||
: Custom(stream, fallback),
|
||||
dims_(dims),
|
||||
traditional_(traditional),
|
||||
base_(base),
|
||||
scale_(scale),
|
||||
offset_(offset){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(RoPE)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
int dims_;
|
||||
bool traditional_;
|
||||
float base_;
|
||||
float scale_;
|
||||
int offset_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::fast
|
Reference in New Issue
Block a user