Fast RMS Norm (#862)
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
  parent 4650d94d98
  commit a54f06b16f
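For reference, a minimal Python sketch (not part of the diff) of the computation this commit accelerates. It mirrors the reference implementation used in the new test: the squaring and mean are done in float32, as the commit message notes, so that float16/bfloat16 inputs do not underflow, and the learned weight is applied in the output dtype.

import mlx.core as mx

def rms_norm_reference(x, weight, eps):
    # Accumulate the mean of squares in float32 regardless of the input dtype;
    # squaring small float16/bfloat16 values directly can underflow to zero.
    x32 = x.astype(mx.float32)
    inv_rms = mx.rsqrt(x32.square().mean(-1, keepdims=True) + eps)
    # Cast back to the weight's dtype and apply the per-feature scale.
    return weight * (x32 * inv_rms).astype(weight.dtype)

x = mx.random.uniform(shape=(2, 32)).astype(mx.float16)
w = mx.ones((32,)).astype(mx.float16)
print(mx.allclose(rms_norm_reference(x, w, 1e-5), mx.fast.rms_norm(x, w, 1e-5), atol=1e-3))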
@@ -44,7 +44,6 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -1,13 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include "mlx/fast_primitives.h"
-
-namespace mlx::core::fast {
-
-void RoPE::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  throw std::runtime_error("NYI");
-}
-
-} // namespace mlx::core::fast
@@ -67,11 +67,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
     }
   };
   array in = check_input(std::move(inputs[0]));
+  if (in.is_donatable()) {
+    out.copy_shared_buffer(in);
+  } else {
     out.set_data(
         allocator::malloc_or_wait(in.data_size() * in.itemsize()),
         in.data_size(),
         in.strides(),
         in.flags());
+  }

   switch (in.dtype()) {
     case bool_:
@@ -33,6 +33,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -23,6 +23,7 @@ set(
   "gemv"
   "quantized"
   "random"
+  "rms_norm"
   "rope"
   "scan"
   "scaled_dot_product_attention"
@@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
 static MTL_CONST constexpr int REDUCE_N_READS = 16;
 static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
 static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
+static MTL_CONST constexpr int RMS_N_READS = 4;
+static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
mlx/backend/metal/kernels/rms_norm.metal (new file, 194 lines)
@@ -0,0 +1,194 @@
// Copyright © 2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_single_row(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float acc = 0;
  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      float xi = x[i];
      acc += xi * xi;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        float xi = x[i];
        acc += xi * xi;
      }
    }
  }
  acc = simd_sum(acc);
  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    acc = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_looped(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float acc = 0;
  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        acc += xi * xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          acc += xi * xi;
        }
      }
    }
  }
  acc = simd_sum(acc);
  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    acc = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[r + i] = w[w_stride * (i + r)] *
            static_cast<T>(x[r + i] * local_inv_mean[0]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          out[r + i] = w[w_stride * (i + r)] *
              static_cast<T>(x[r + i] * local_inv_mean[0]);
        }
      }
    }
  }
}

// clang-format off
#define instantiate_rms_single_row(name, itype)              \
  template [[host_name("rms" #name)]] [[kernel]] void        \
  rms_single_row<itype>(                                      \
      const device itype* x,                                  \
      const device itype* w,                                  \
      device itype* out,                                      \
      constant float& eps,                                    \
      constant uint& axis_size,                               \
      constant uint& w_stride,                                \
      threadgroup float* local_inv_mean [[threadgroup(0)]],   \
      threadgroup float* local_sums [[threadgroup(1)]],       \
      uint gid [[thread_position_in_grid]],                   \
      uint lid [[thread_position_in_threadgroup]],            \
      uint simd_lane_id [[thread_index_in_simdgroup]],        \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms_looped(name, itype)                    \
  template [[host_name("rms_looped" #name)]] [[kernel]] void  \
  rms_looped<itype>(                                           \
      const device itype* x,                                   \
      const device itype* w,                                   \
      device itype* out,                                       \
      constant float& eps,                                     \
      constant uint& axis_size,                                \
      constant uint& w_stride,                                 \
      threadgroup float* local_inv_mean [[threadgroup(0)]],    \
      threadgroup float* local_sums [[threadgroup(1)]],        \
      uint gid [[thread_position_in_grid]],                    \
      uint lid [[thread_position_in_threadgroup]],             \
      uint lsize [[threads_per_threadgroup]],                  \
      uint simd_lane_id [[thread_index_in_simdgroup]],         \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms(name, itype)       \
  instantiate_rms_single_row(name, itype)  \
  instantiate_rms_looped(name, itype)

instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on
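To make the data flow of rms_single_row easier to follow, here is an illustrative Python emulation (not MLX code; all names below are made up for the example). Each thread accumulates the squares of N_READS contiguous values, simd_sum collapses each 32-wide simdgroup to one partial sum in local_sums, and simdgroup 0 reduces those partials and broadcasts 1/sqrt(mean + eps) through threadgroup memory.

import numpy as np

def rms_single_row_emulated(x_row, weight, eps, n_reads=4, simd_size=32):
    axis_size = x_row.shape[0]
    n_threads = (axis_size + n_reads - 1) // n_reads
    n_simdgroups = (n_threads + simd_size - 1) // simd_size

    # Per-thread partial sums of squares, accumulated in float32.
    partial = np.zeros(n_threads, dtype=np.float32)
    for lid in range(n_threads):
        chunk = x_row[lid * n_reads : lid * n_reads + n_reads].astype(np.float32)
        partial[lid] = np.sum(chunk * chunk)

    # simd_sum: one partial sum per simdgroup lands in "local_sums".
    local_sums = [partial[g * simd_size : (g + 1) * simd_size].sum()
                  for g in range(n_simdgroups)]

    # Simdgroup 0 reduces the partials and broadcasts the inverse RMS.
    inv_mean = 1.0 / np.sqrt(sum(local_sums) / axis_size + eps)
    return weight * (x_row.astype(np.float32) * inv_mean)

row = np.random.rand(33).astype(np.float16)
print(rms_single_row_emulated(row, np.ones(33, dtype=np.float32), 1e-5)[:4])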
@@ -1,6 +1,5 @@
 // Copyright © 2023 Apple Inc.

-#include <metal_atomic>
 #include <metal_common>
 #include <metal_simdgroup>

@@ -224,5 +223,6 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
   instantiate_softmax_single_row(name, itype) \
   instantiate_softmax_looped(name, itype)

-instantiate_softmax(float32, float) instantiate_softmax(float16, half)
+instantiate_softmax(float32, float)
+instantiate_softmax(float16, half)
 instantiate_softmax(bfloat16, bfloat16_t)
mlx/backend/metal/rms_norm.cpp (new file, 98 lines)
@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void RMSNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];

  if (x.is_donatable()) {
    out.move_shared_buffer(x);
  } else {
    out.set_data(
        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "rms";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = w.strides()[0];
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(
        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(compute_encoder, out, 2);
    compute_encoder->setBytes(&eps_, sizeof(float), 3);
    compute_encoder->setBytes(&axis_size, sizeof(int), 4);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
    compute_encoder->setThreadgroupMemoryLength(
        16 * 8, 0); // minimum of 16 bytes
    compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core::fast
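A hedged sketch of the launch-size arithmetic above (illustrative Python, not MLX code): rows of up to RMS_LOOPED_LIMIT elements get the single-pass kernel with a threadgroup rounded up to a multiple of the 32-wide simdgroup, while longer rows use the looped kernel at the pipeline's maximum threadgroup size; the 1024 below is an assumed device limit.

SIMD_SIZE = 32
RMS_N_READS = 4
RMS_LOOPED_LIMIT = 4096
MAX_THREADS_PER_THREADGROUP = 1024  # assumed; the real value comes from the pipeline state

def choose_launch(axis_size):
    if axis_size <= RMS_LOOPED_LIMIT:
        # Single-row kernel: each thread reads RMS_N_READS elements in one pass.
        threads_needed = (axis_size + RMS_N_READS - 1) // RMS_N_READS
        simds_needed = (threads_needed + SIMD_SIZE - 1) // SIMD_SIZE
        return "rms", SIMD_SIZE * simds_needed
    # Looped kernel: a fixed-size threadgroup strides over the row.
    return "rms_looped", MAX_THREADS_PER_THREADGROUP

for n in (32, 33, 4096, 4099):
    print(n, choose_launch(n))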
@@ -37,11 +37,15 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };
   const array& in = check_input(inputs[0]);
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in);
+  } else {
     out.set_data(
         allocator::malloc_or_wait(in.data_size() * in.itemsize()),
         in.data_size(),
         in.strides(),
         in.flags());
+  }

   int axis_size = in.shape().back();
   int n_rows = in.data_size() / axis_size;
@@ -75,6 +79,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
+  set_array_buffer(
+      compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
   compute_encoder->setBytes(&axis_size, sizeof(int), 2);
@@ -102,6 +102,7 @@ NO_GPU(Transpose)
 NO_GPU(Inverse)

 namespace fast {
+NO_GPU_MULTI(RMSNorm)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)
 } // namespace fast
mlx/fast.cpp
@@ -46,6 +46,59 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
   return {outputs, out_axes};
 }

+array rms_norm(
+    const array& x,
+    const array& weight,
+    float eps,
+    StreamOrDevice s_ /* = {} */) {
+  if (x.ndim() == 0) {
+    std::ostringstream msg;
+    msg << "[rms_norm] Input must have at least 1 dimension but got input with "
+           "0 dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (weight.ndim() != 1) {
+    std::ostringstream msg;
+    msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  auto out_type = result_type({x, weight});
+  if (!is_floating_point(out_type) || is_complex(out_type)) {
+    std::ostringstream msg;
+    msg << "[rms_norm] Received unsupported type " << out_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto s = to_stream(s_);
+  auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
+    auto x = astype(inputs[0], float32, s);
+    x = multiply(
+        x,
+        rsqrt(
+            add(mean(square(x, s), -1, /* keepdims */ true, s),
+                array(eps, float32),
+                s),
+            s),
+        s);
+    x = astype(x, out_type, s);
+    return std::vector<array>{multiply(inputs[1], x, s)};
+  };
+  if (s.device == Device::gpu) {
+    return array(
+        x.shape(),
+        x.dtype(),
+        std::make_unique<RMSNorm>(s, fallback, eps),
+        {astype(x, out_type, s), astype(weight, out_type, s)});
+  }
+  return fallback({x, weight})[0];
+}
+
+bool RMSNorm::is_equivalent(const Primitive& other) const {
+  const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
+  return eps_ == a_other.eps_;
+}
+
 array rope(
     const array& x,
     int dims,
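Note that rms_norm only builds an RMSNorm primitive when the stream targets the GPU; on a CPU stream it evaluates the float32 fallback closure directly. A short usage sketch of that routing (assumes a machine with a Metal GPU so the default stream runs the fused kernel):

import mlx.core as mx

x = mx.random.uniform(shape=(4, 256))
w = mx.ones((256,))

# Default stream (GPU on Apple silicon): runs the fused Metal kernel.
y_fast = mx.fast.rms_norm(x, w, 1e-5)
# Explicit CPU stream: the same call evaluates the fallback graph instead.
y_fallback = mx.fast.rms_norm(x, w, 1e-5, stream=mx.cpu)
print(mx.allclose(y_fast, y_fallback, atol=1e-5))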
@@ -8,6 +8,12 @@

 namespace mlx::core::fast {

+array rms_norm(
+    const array& x,
+    const array& weight,
+    float eps,
+    StreamOrDevice s = {});
+
 array rope(
     const array& x,
     int dims,
@@ -15,7 +21,7 @@ array rope(
     float base,
     float scale,
     int offset,
-    StreamOrDevice s /* = {} */);
+    StreamOrDevice s = {});

 /** Computes: O = softmax(Q @ K.T) @ V **/
 array scaled_dot_product_attention(
@@ -1,3 +1,5 @@
+// Copyright © 2024 Apple Inc.
+
 #include "mlx/primitives.h"

 namespace mlx::core::fast {
@@ -31,6 +33,29 @@ class Custom : public Primitive {
   std::function<std::vector<array>(std::vector<array>)> fallback_;
 };

+class RMSNorm : public Custom {
+ public:
+  RMSNorm(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float eps)
+      : Custom(stream, fallback), eps_(eps){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  };
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(RMSNorm)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  float eps_;
+};
+
 class RoPE : public Custom {
  public:
   RoPE(
@@ -49,7 +74,9 @@ class RoPE : public Custom {
       offset_(offset){};

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
+      override {
+    throw std::runtime_error("NYI");
+  };
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;

@@ -117,6 +117,8 @@ class RMSNorm(Module):
     where :math:`\gamma` is a learned per feature dimension parameter initialized at
     1.

+    Note the accumulation for the mean is done in 32-bit precision.
+
     [1]: https://arxiv.org/abs/1910.07467

     Args:
@@ -133,18 +135,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"

     def __call__(self, x):
-        # S is 1/sqrt(N) where N is the size of the features of x and is used
-        # to compute a numerically more stable RMS of x by multiplying with S
-        # first and summing.
-        #
-        # This way we prefer underflow over overflow which is controlled with
-        # the parameter epsilon anyway.
-        S = 1 / x.shape[-1] ** 0.5
-
-        n = (x * S).square().sum(axis=-1, keepdims=True)
-        n = mx.rsqrt(n + self.eps)
-
-        return self.weight * x * n
+        return mx.fast.rms_norm(x, self.weight, self.eps)


 class GroupNorm(Module):
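With this change the module body is a single call into the fused op. A small usage sketch (illustrative; the equivalence below holds by construction since the module forwards its own weight and eps):

import mlx.core as mx
import mlx.nn as nn

layer = nn.RMSNorm(dims=512, eps=1e-5)
x = mx.random.uniform(shape=(8, 512))

# The module now just forwards to mx.fast.rms_norm with its learned weight.
print(mx.allclose(layer(x), mx.fast.rms_norm(x, layer.weight, layer.eps)))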
@@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) {
   auto m =
       parent_module.def_submodule("fast", "mlx.core.fast: fast operations");

+  m.def(
+      "rms_norm",
+      [](const array& x,
+         const array& weight,
+         float eps,
+         const StreamOrDevice& s /* = {} */) {
+        return fast::rms_norm(x, weight, eps, s);
+      },
+      "x"_a,
+      "weight"_a,
+      "eps"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Root Mean Square normalization (RMS norm).
+
+        The normalization is with respect to the last axis of the input ``x``.
+
+        Args:
+            x (array): Input array.
+            weight (array): A multiplicative weight to scale the result by.
+              The ``weight`` should be one-dimensional with the same size
+              as the last axis of ``x``.
+            eps (float): A small additive constant for numerical stability.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
+
   m.def(
       "rope",
       [](const array& a,
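Per the docstring above, normalization is over the last axis and weight must be one-dimensional with the size of that axis. A brief example of the documented call (eps is positional-or-keyword, stream is keyword-only):

import mlx.core as mx

x = mx.random.uniform(shape=(2, 3, 64))
w = mx.random.uniform(shape=(64,))  # 1-D, sized like the last axis of x

y = mx.fast.rms_norm(x, w, eps=1e-5)  # normalizes over the last axis only
print(y.shape)  # same shape as x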
@@ -115,6 +115,57 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

+    def test_rms_norm(self):
+        def rms_norm(x, weight, eps):
+            x = x.astype(mx.float32)
+            x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+            return weight * x.astype(weight.dtype)
+
+        # Per dtype absolute tolerance
+        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
+
+        dtypes = [mx.float32, mx.float16, mx.bfloat16]
+        epss = [1e-3, 1e-5]
+        dimss = [31, 32, 33]
+        defaults = (mx.float32, 1e-5, 32)
+
+        for dtype in dtypes:
+            _, eps, dims = defaults
+            x = mx.random.uniform(
+                shape=(
+                    2,
+                    dims,
+                )
+            ).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = rms_norm(x, weight, eps)
+            rx_fast = mx.fast.rms_norm(x, weight, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for eps in epss:
+            dtype, _, dims = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = rms_norm(x, weight, eps)
+            rx_fast = mx.fast.rms_norm(x, weight, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for dims in dimss:
+            dtype, eps, _ = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = rms_norm(x, weight, eps)
+            rx_fast = mx.fast.rms_norm(x, weight, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        # Test > 4096
+        dims, dtype, eps = 4099, mx.float32, 1e-5
+        x = mx.random.uniform(shape=(dims,)).astype(dtype)
+        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+        rx = rms_norm(x, weight, eps)
+        rx_fast = mx.fast.rms_norm(x, weight, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))