mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Custom primitive + RoPE fat op (#676)
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent 1a48713d32
commit ccf1645995
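In brief: the commit adds a new mx.fast submodule whose first op fuses rotary positional encoding (RoPE) into a single Metal kernel, with a pure-op fallback used on the CPU and for function transformations. As orientation, a minimal usage sketch (not from the commit; the 3-D (batch, sequence, dims) shape requirement comes from mlx/fast.cpp below):

```python
import mlx.core as mx

x = mx.random.uniform(shape=(2, 8, 64))  # (batch, sequence, dims)
y = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)
print(y.shape)  # (2, 8, 64)
```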
benchmarks/python/rope_bench.py (new file, 35 lines)
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(4096)

    # vec
    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
mlx/CMakeLists.txt
@@ -3,9 +3,10 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
mlx/backend/common/CMakeLists.txt
@@ -11,6 +11,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
mlx/backend/common/rope.cpp (new file, 14 lines)
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
mlx/backend/metal/CMakeLists.txt
@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
mlx/backend/metal/kernels/CMakeLists.txt
@@ -23,6 +23,7 @@ set(
   "quantized"
   "random"
   "reduce"
+  "rope"
   "scan"
   "softmax"
   "sort"
mlx/backend/metal/kernels/rope.metal (new file, 68 lines)
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
[[kernel]] void rope(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const size_t strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute the input and output indices
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + 1;
    in_index_1 =
        2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x + 2 * (grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + grid.x;
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }

  // Figure out L and d.
  float L = scale * static_cast<float>(pos.y + offset);
  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);

  // Compute costheta, sintheta
  float theta = L * metal::exp2(-d * base);
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1 = x1 * costheta - x2 * sintheta;
  float rx2 = x1 * sintheta + x2 * costheta;
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional)    \
  template [[host_name("rope_" #name)]]              \
  [[kernel]] void rope<type, traditional>(           \
      const device type* in [[buffer(0)]],           \
      device type* out [[buffer(1)]],                \
      constant const size_t strides[3],              \
      constant const int& offset,                    \
      constant const float& base,                    \
      constant const float& scale,                   \
      uint3 pos [[thread_position_in_grid]],         \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
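To decode the index arithmetic above: each thread's x-coordinate picks one feature pair to rotate, and the two template branches mirror the two slicing schemes of the fallback in mlx/fast.cpp. A small illustration of the pair layout (not from the commit):

```python
# Feature pairs rotated by thread x for dims = 8 (so half = grid.x = 4).
half = 4
nontraditional = [(x, x + half) for x in range(half)]    # x[..., :d//2] with x[..., d//2:]
traditional = [(2 * x, 2 * x + 1) for x in range(half)]  # x[..., ::2] with x[..., 1::2]
print(nontraditional)  # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(traditional)     # [(0, 1), (2, 3), (4, 5), (6, 7)]
```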
mlx/backend/metal/rope.cpp (new file, 55 lines)
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  if (in.ndim() != 3) {
    throw std::runtime_error(
        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
  }
  if (dims_ != in.shape(-1)) {
    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
  }
  // Donate the input buffer to the output when possible so the kernel
  // rotates in place.
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  std::ostringstream kname;
  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  bool donated = in.data_shared_ptr() == nullptr;
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, donated ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&offset_, sizeof(int), 3);
  compute_encoder->setBytes(&base, sizeof(float), 4);
  compute_encoder->setBytes(&scale_, sizeof(float), 5);

  // One thread per (feature pair, position, batch) element.
  int dim0 = in.shape(2) / 2;
  int dim1 = in.shape(1);
  int dim2 = in.shape(0);
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
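Note the `float base = std::log2(base_);` above: the host converts the base once so the kernel can call metal::exp2 instead of a per-thread pow, using the identity base^(-d) = exp2(-d * log2(base)). A quick sanity check of that identity (not from the commit):

```python
import math

base, half_dims = 10000.0, 32
for i in range(half_dims):
    d = i / half_dims
    # Reference frequency vs. what the Metal kernel computes.
    assert math.isclose(base ** -d, 2.0 ** (-d * math.log2(base)))
```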
mlx/backend/no_metal/primitives.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/primitives.h"
+#include "mlx/fast.h"

 #define NO_GPU_MULTI(func) \
   void func::eval_gpu(     \
@@ -95,4 +96,8 @@ NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)

+namespace fast {
+NO_GPU_MULTI(RoPE)
+} // namespace fast
+
 } // namespace mlx::core
mlx/fast.cpp (new file, 128 lines)
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/transforms.h"

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
  std::vector<array> vjp_outs;
  for (int i = 0, j = 0; i < vjps.size(); ++i) {
    if (i < argnums.size() && i == argnums[j]) {
      vjp_outs.push_back(vjps[i]);
      j++;
    }
  }
  return vjp_outs;
}

std::vector<array> Custom::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
  std::vector<array> jvp_outs;
  for (int i = 0, j = 0; i < jvps.size(); ++i) {
    if (i < argnums.size() && i == argnums[j]) {
      jvp_outs.push_back(jvps[i]);
      j++;
    }
  }
  return jvp_outs;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
  auto out_axes = std::vector<int>(outputs.size(), 0);
  return {outputs, out_axes};
}

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (traditional && x.shape(-1) != dims) {
    throw std::invalid_argument(
        "[rope] Does not support partial traditional application.");
  }

  auto fallback = [dims, traditional, base, scale, offset, s](
                      const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto t = x.dtype();
    auto N = x.shape(1) + offset;
    // Compute sines and cosines
    auto half_dims = dims / 2;
    auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
    auto freqs = negative(arange(0, half_dims, t, s), s);
    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
    auto theta =
        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    if (traditional) {
      auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
      auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
    } else {
      auto out_s = x.shape();
      out_s.back() = half_dims;
      auto x1 = slice(x, {0, 0, 0}, out_s, s);
      out_s.back() = dims;
      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);

      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      if (dims < x.shape(-1)) {
        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
      }
      return std::vector<array>{concatenate(outs, 2, s)};
    }
  };
  // TODO change to condition for using custom prim
  auto stream = to_stream(s);
  if (stream.device == Device::gpu && x.shape(-1) == dims) {
    return array(
        x.shape(),
        x.dtype(),
        std::make_unique<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset),
        {x});
  }
  return fallback({x})[0];
}

bool RoPE::is_equivalent(const Primitive& other) const {
  const RoPE& a_other = static_cast<const RoPE&>(other);
  return (
      dims_ == a_other.dims_ && base_ == a_other.base_ &&
      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
      offset_ == a_other.offset_);
}

} // namespace mlx::core::fast
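The dispatch at the end of rope() is the crux of the design: a GPU stream with full-width rotation constructs the fused RoPE primitive, while CPU streams and partial rotation evaluate the fallback graph directly (vjp, jvp, and vmap also route through the fallback via Custom). A sketch of the resulting equivalence, assuming a machine where the Metal backend is available (not from the commit):

```python
import mlx.core as mx

x = mx.random.uniform(shape=(1, 16, 64))
kwargs = dict(traditional=False, base=10000.0, scale=1.0, offset=0)
y_gpu = mx.fast.rope(x, 64, stream=mx.gpu, **kwargs)  # fused Metal kernel
y_cpu = mx.fast.rope(x, 64, stream=mx.cpu, **kwargs)  # fallback graph
print(mx.allclose(y_gpu, y_cpu, atol=1e-5))  # should agree up to float tolerance
```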
mlx/fast.h (new file, 82 lines)
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/ops.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may
// override the default behavior.
class Custom : public Primitive {
 public:
  explicit Custom(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback)
      : Primitive(stream), fallback_(fallback){};

  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  virtual std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
};

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */);

class RoPE : public Custom {
 public:
  RoPE(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int dims,
      bool traditional,
      float base,
      float scale,
      int offset)
      : Custom(stream, fallback),
        dims_(dims),
        traditional_(traditional),
        base_(base),
        scale_(scale),
        offset_(offset){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(RoPE)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int dims_;
  bool traditional_;
  float base_;
  float scale_;
  int offset_;
};

} // namespace mlx::core::fast
mlx/mlx.h
@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
+#include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
 #include "mlx/linalg.h"
python/mlx/nn/layers/positional_encoding.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
 from typing import Optional
@@ -20,20 +20,13 @@ class RoPE(Module):
     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
-        traditional (bool, optional): If set to True choose the traditional
+        traditional (bool, optional): If set to ``True`` choose the traditional
             implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
             each dimension in the positional encodings. Default: ``10000``.
         scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
-
-    Attributes:
-        _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
-        _cos_sin_theta_value (tuple): Cached cosine and sine values.
     """

-    _cos_sin_theta_key = None
-    _cos_sin_theta_value = None
-
     def __init__(
         self,
         dims: int,
@@ -50,69 +43,18 @@ class RoPE(Module):
     def _extra_repr(self):
         return f"{self.dims}, traditional={self.traditional}"

-    def _compute_rope(self, costheta, sintheta, x):
-        x1 = x[..., : self.dims // 2]
-        x2 = x[..., self.dims // 2 : self.dims]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
-        else:
-            rx = mx.concatenate([rx1, rx2], axis=-1)
-
-        return rx
-
-    def _compute_traditional_rope(self, costheta, sintheta, x):
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            raise NotImplementedError(
-                "RoPE doesn't implement partial traditional application"
-            )
-
-        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
-
-        return rx
-
     def __call__(self, x, offset: int = 0):
         shape = x.shape
         x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
+        x = mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=self.base,
+            scale=self.scale,
+            offset=offset,
         )
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-    @classmethod
-    def create_cos_sin_theta(
-        cls,
-        N: int,
-        D: int,
-        offset: int = 0,
-        base: float = 10000,
-        scale: float = 1.0,
-        dtype=mx.float32,
-    ):
-        if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
-            half_D = D // 2
-            positions = mx.arange(offset, N, dtype=dtype) * scale
-            freqs = mx.exp(
-                -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
-            )
-            theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
-            cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
-            cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
-        return cls._cos_sin_theta_value
+        return mx.reshape(x, shape)


 class SinusoidalPositionalEncoding(Module):
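The module keeps its flexible input shape by flattening leading dimensions into the 3-D (batch, sequence, dims) layout that mx.fast.rope requires and restoring the shape afterwards. A small sketch of that contract (not from the commit):

```python
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(64)
x = mx.random.uniform(shape=(2, 8, 10, 64))  # (batch, heads, sequence, dims)
y = rope(x)  # internally reshaped to (-1, 10, 64) before mx.fast.rope
print(y.shape)  # (2, 8, 10, 64)
```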
python/src/CMakeLists.txt
@@ -3,6 +3,7 @@ pybind11_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
python/src/fast.cpp (new file, 59 lines)
// Copyright © 2023-2024 Apple Inc.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mlx/fast.h"
#include "mlx/ops.h"
#include "python/src/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

void init_extensions(py::module_& parent_module) {
  py::options options;
  options.disable_function_signatures();

  auto m =
      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");

  m.def(
      "rope",
      [](const array& a,
         int dims,
         bool traditional,
         float base,
         float scale,
         int offset,
         const StreamOrDevice& s /* = {} */) {
        return fast::rope(a, dims, traditional, base, scale, offset, s);
      },
      "a"_a,
      "dims"_a,
      py::kw_only(),
      "traditional"_a,
      "base"_a,
      "scale"_a,
      "offset"_a,
      "stream"_a = none,
      R"pbdoc(
        rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array

        Apply rotary positional encoding to the input.

        Args:
            a (array): Input array.
            dims (int): The feature dimensions to be rotated. If the input feature
                is larger than dims then the rest is left unchanged.
            traditional (bool): If set to ``True`` choose the traditional
                implementation which rotates consecutive dimensions.
            base (float): The base used to compute angular frequency for
                each dimension in the positional encodings.
            scale (float): The scale used to scale the positions.
            offset (int): The position offset to start at.

        Returns:
            array: The output array.
      )pbdoc");
}
python/src/mlx.cpp
@@ -17,6 +17,7 @@ void init_random(py::module_&);
 void init_fft(py::module_&);
 void init_linalg(py::module_&);
 void init_constants(py::module_&);
+void init_extensions(py::module_&);

 PYBIND11_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -33,5 +34,6 @@ PYBIND11_MODULE(core, m) {
   init_fft(m);
   init_linalg(m);
   init_constants(m);
+  init_extensions(m);
   m.attr("__version__") = TOSTRING(_VERSION_);
 }
python/src/random.cpp
@@ -133,7 +133,7 @@ void init_random(py::module_& parent_module) {
            low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
            high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
            shape (list(int), optional): Shape of the output. Default is ``()``.
-           key (array, optional): A PRNG key. Default: None.
+           key (array, optional): A PRNG key. Default: ``None``.
            dtype (Dtype, optional): Type of the output. Default is ``float32``.

        Returns:
python/tests/test_fast.py (new file, 158 lines)
# Copyright © 2023-2024 Apple Inc.

import math
import unittest

import mlx.core as mx
import mlx_tests


def rope_orig(x, dims, traditional, base, scale, offset):
    N = x.shape[1] + offset
    dtype = x.dtype
    half_D = dims // 2
    positions = mx.arange(offset, N, dtype=dtype) * scale
    freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    costheta, sintheta = mx.cos(theta), mx.sin(theta)
    if traditional:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        return mx.reshape(rx, x.shape)
    else:
        x1 = x[..., : dims // 2]
        x2 = x[..., dims // 2 : dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        if dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)
        return rx


class TestFast(mlx_tests.MLXTestCase):
    def test_rope(self):
        T = 4

        # Defaults: dims, dtype, base, scale, offset, traditional
        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)

        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

        # Test cases:
        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        bases = [10000.0, 1000000.0]
        scales = [1.0, 2.0]
        offsets = [0, 3]
        traditional = [True, False]

        for traditional in [True, False]:
            dims, dtype, _, scale, offset, _ = defaults
            for base in bases:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, _, base, scale, offset, _ = defaults
            for dtype in dtypes:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                ry = rope_orig(
                    x.astype(mx.float32), dims, traditional, base, scale, offset
                )
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                if dtype != mx.float32:
                    self.assertLessEqual(
                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
                    )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, dtype, base, scale, _, _ = defaults
            for offset in offsets:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, dtype, base, _, offset, _ = defaults
            for scale in scales:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_fast_transforms(self):
        x = mx.random.uniform(shape=(2, 2, 8))

        defaults = (8, False, 10000.0, 1.0, 0)
        dims, traditional, base, scale, offset = defaults

        # VJP
        _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, vjp_fast_out = mx.vjp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))

        # JVP
        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, jvp_fast_out = mx.jvp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))

        # VMAP
        x = mx.random.uniform(shape=(2, 2, 2, 8))
        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
        vmap_fast_out = mx.vmap(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            )
        )(x)
        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))


if __name__ == "__main__":
    unittest.main()