Fast RMS Norm (#862)

* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2024-03-21 07:20:54 -07:00 committed by GitHub
parent 4650d94d98
commit a54f06b16f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 493 additions and 41 deletions

View File

@ -44,7 +44,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

View File

@ -1,13 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@ -67,11 +67,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
switch (in.dtype()) {
case bool_:

View File

@ -33,6 +33,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

View File

@ -23,6 +23,7 @@ set(
"gemv"
"quantized"
"random"
"rms_norm"
"rope"
"scan"
"scaled_dot_product_attention"

View File

@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

View File

@ -0,0 +1,194 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_single_row(
const device T* x,
const device T* w,
device T* out,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float acc = 0;
x += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i];
acc += xi * xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
float xi = x[i];
acc += xi * xi;
}
}
}
acc = simd_sum(acc);
// Initialize shared memory
if (simd_group_id == 0) {
local_sums[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sums[simd_group_id] = acc;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
acc = simd_sum(local_sums[simd_lane_id]);
if (simd_lane_id == 0) {
local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the outputs
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_looped(
const device T* x,
const device T* w,
device T* out,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float acc = 0;
x += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
acc += xi * xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
acc += xi * xi;
}
}
}
}
acc = simd_sum(acc);
// Initialize shared memory
if (simd_group_id == 0) {
local_sums[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sums[simd_group_id] = acc;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
acc = simd_sum(local_sums[simd_lane_id]);
if (simd_lane_id == 0) {
local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the outputs
out += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[r + i] = w[w_stride * (i + r)] *
static_cast<T>(x[r + i] * local_inv_mean[0]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
out[r + i] = w[w_stride * (i + r)] *
static_cast<T>(x[r + i] * local_inv_mean[0]);
}
}
}
}
}
// clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)
instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on

View File

@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>
@ -224,5 +223,6 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& out = outputs[0];
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
if (x.is_donatable()) {
out.move_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "rms";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&eps_, sizeof(float), 3);
compute_encoder->setBytes(&axis_size, sizeof(int), 4);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
compute_encoder->setThreadgroupMemoryLength(
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core::fast

View File

@ -37,11 +37,15 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};
const array& in = check_input(inputs[0]);
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
@ -75,6 +79,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);

View File

@ -102,6 +102,7 @@ NO_GPU(Transpose)
NO_GPU(Inverse)
namespace fast {
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
} // namespace fast

View File

@ -46,6 +46,59 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
return {outputs, out_axes};
}
array rms_norm(
const array& x,
const array& weight,
float eps,
StreamOrDevice s_ /* = {} */) {
if (x.ndim() == 0) {
std::ostringstream msg;
msg << "[rms_norm] Input must have at least 1 dimension but got input with "
"0 dimensions.";
throw std::invalid_argument(msg.str());
}
if (weight.ndim() != 1) {
std::ostringstream msg;
msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto out_type = result_type({x, weight});
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[rms_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
auto x = astype(inputs[0], float32, s);
x = multiply(
x,
rsqrt(
add(mean(square(x, s), -1, /* keepdims */ true, s),
array(eps, float32),
s),
s),
s);
x = astype(x, out_type, s);
return std::vector<array>{multiply(inputs[1], x, s)};
};
if (s.device == Device::gpu) {
return array(
x.shape(),
x.dtype(),
std::make_unique<RMSNorm>(s, fallback, eps),
{astype(x, out_type, s), astype(weight, out_type, s)});
}
return fallback({x, weight})[0];
}
bool RMSNorm::is_equivalent(const Primitive& other) const {
const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
return eps_ == a_other.eps_;
}
array rope(
const array& x,
int dims,

View File

@ -8,6 +8,12 @@
namespace mlx::core::fast {
array rms_norm(
const array& x,
const array& weight,
float eps,
StreamOrDevice s = {});
array rope(
const array& x,
int dims,
@ -15,7 +21,7 @@ array rope(
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */);
StreamOrDevice s = {});
/** Computes: O = softmax(Q @ K.T) @ V **/
array scaled_dot_product_attention(

View File

@ -1,3 +1,5 @@
// Copyright © 2024 Apple Inc.
#include "mlx/primitives.h"
namespace mlx::core::fast {
@ -31,6 +33,29 @@ class Custom : public Primitive {
std::function<std::vector<array>(std::vector<array>)> fallback_;
};
class RMSNorm : public Custom {
public:
RMSNorm(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RMSNorm)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
};
class RoPE : public Custom {
public:
RoPE(
@ -49,7 +74,9 @@ class RoPE : public Custom {
offset_(offset){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;

View File

@ -117,6 +117,8 @@ class RMSNorm(Module):
where :math:`\gamma` is a learned per feature dimension parameter initialized at
1.
Note the accumulation for the mean is done in 32-bit precision.
[1]: https://arxiv.org/abs/1910.07467
Args:
@ -133,18 +135,7 @@ class RMSNorm(Module):
return f"{self.weight.shape[0]}, eps={self.eps}"
def __call__(self, x):
# S is 1/sqrt(N) where N is the size of the features of x and is used
# to compute a numerically more stable RMS of x by multiplying with S
# first and summing.
#
# This way we prefer underflow over overflow which is controlled with
# the parameter epsilon anyway.
S = 1 / x.shape[-1] ** 0.5
n = (x * S).square().sum(axis=-1, keepdims=True)
n = mx.rsqrt(n + self.eps)
return self.weight * x * n
return mx.fast.rms_norm(x, self.weight, self.eps)
class GroupNorm(Module):

View File

@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) {
auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
m.def(
"rms_norm",
[](const array& x,
const array& weight,
float eps,
const StreamOrDevice& s /* = {} */) {
return fast::rms_norm(x, weight, eps, s);
},
"x"_a,
"weight"_a,
"eps"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Root Mean Square normalization (RMS norm).
The normalization is with respect to the last axis of the input ``x``.
Args:
x (array): Input array.
weight (array): A multiplicative weight to scale the result by.
The ``weight`` should be one-dimensional with the same size
as the last axis of ``x``.
eps (float): A small additive constant for numerical stability.
Returns:
array: The output array.
)pbdoc");
m.def(
"rope",
[](const array& a,

View File

@ -115,6 +115,57 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_rms_norm(self):
def rms_norm(x, weight, eps):
x = x.astype(mx.float32)
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return weight * x.astype(weight.dtype)
# Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
dtypes = [mx.float32, mx.float16, mx.bfloat16]
epss = [1e-3, 1e-5]
dimss = [31, 32, 33]
defaults = (mx.float32, 1e-5, 32)
for dtype in dtypes:
_, eps, dims = defaults
x = mx.random.uniform(
shape=(
2,
dims,
)
).astype(dtype)
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
for eps in epss:
dtype, _, dims = defaults
x = mx.random.uniform(shape=(2, dims)).astype(dtype)
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
for dims in dimss:
dtype, eps, _ = defaults
x = mx.random.uniform(shape=(2, dims)).astype(dtype)
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
# Test > 4096
dims, dtype, eps = 4099, mx.float32, 1e-5
x = mx.random.uniform(shape=(dims,)).astype(dtype)
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
rx = rms_norm(x, weight, eps)
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
def test_fast_transforms(self):
x = mx.random.uniform(shape=(2, 2, 8))