Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
Fast RMS Norm (#862)
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
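As a quick orientation (not part of the diff): the op added here is exposed in Python as mx.fast.rms_norm, and the tests added below compare it against a float32 reference. A minimal sketch with illustrative shapes and dtypes, mirroring the reference used in python/tests/test_fast.py:

import mlx.core as mx

def rms_norm_ref(x, weight, eps):
    # Reference from the tests in this commit: accumulate in float32,
    # then cast back and scale by the weight.
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)

x = mx.random.uniform(shape=(2, 32)).astype(mx.float16)
w = mx.random.uniform(shape=(32,)).astype(mx.float16)
out_fast = mx.fast.rms_norm(x, w, 1e-5)  # fused Metal kernel on GPU, fallback elsewhere
out_ref = rms_norm_ref(x, w, 1e-5)
print(mx.abs(out_fast - out_ref).max())  # expected within ~1e-3 for float16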
This commit is contained in:
Parent: 4650d94d98
Commit: a54f06b16f
@@ -44,7 +44,6 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -1,13 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
@@ -67,11 +67,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
     }
   };
   array in = check_input(std::move(inputs[0]));
-  out.set_data(
-      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
-      in.data_size(),
-      in.strides(),
-      in.flags());
+  if (in.is_donatable()) {
+    out.copy_shared_buffer(in);
+  } else {
+    out.set_data(
+        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
+        in.data_size(),
+        in.strides(),
+        in.flags());
+  }

   switch (in.dtype()) {
     case bool_:
@@ -33,6 +33,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -23,6 +23,7 @@ set(
   "gemv"
   "quantized"
   "random"
+  "rms_norm"
   "rope"
   "scan"
   "scaled_dot_product_attention"
@@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
 static MTL_CONST constexpr int REDUCE_N_READS = 16;
 static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
 static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
+static MTL_CONST constexpr int RMS_N_READS = 4;
+static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
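Not part of the diff: a rough sketch of how these two constants drive kernel selection in RMSNorm::eval_gpu (shown later in this commit). The names mirror that host code; the Python is only illustrative.

RMS_N_READS = 4
RMS_LOOPED_LIMIT = 4096

def rms_kernel_name(axis_size, type_suffix):
    # Mirrors the op_name construction in mlx/backend/metal/rms_norm.cpp:
    # rows up to RMS_LOOPED_LIMIT use the single-row kernel, longer rows
    # use the looped kernel, where each thread strides over the row.
    name = "rms"
    if axis_size > RMS_LOOPED_LIMIT:
        name += "_looped"
    return name + type_suffix

assert rms_kernel_name(32, "float32") == "rmsfloat32"
assert rms_kernel_name(4099, "float32") == "rms_loopedfloat32"  # the > 4096 test case below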
mlx/backend/metal/kernels/rms_norm.metal (new file, 194 lines)
@@ -0,0 +1,194 @@
// Copyright © 2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_single_row(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float acc = 0;
  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      float xi = x[i];
      acc += xi * xi;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        float xi = x[i];
        acc += xi * xi;
      }
    }
  }
  acc = simd_sum(acc);
  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    acc = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_looped(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float acc = 0;
  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        acc += xi * xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          acc += xi * xi;
        }
      }
    }
  }
  acc = simd_sum(acc);
  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = acc;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    acc = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[r + i] = w[w_stride * (i + r)] *
            static_cast<T>(x[r + i] * local_inv_mean[0]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          out[r + i] = w[w_stride * (i + r)] *
              static_cast<T>(x[r + i] * local_inv_mean[0]);
        }
      }
    }
  }
}

// clang-format off
#define instantiate_rms_single_row(name, itype)              \
  template [[host_name("rms" #name)]] [[kernel]] void        \
  rms_single_row<itype>(                                     \
      const device itype* x,                                 \
      const device itype* w,                                 \
      device itype* out,                                     \
      constant float& eps,                                   \
      constant uint& axis_size,                              \
      constant uint& w_stride,                               \
      threadgroup float* local_inv_mean [[threadgroup(0)]],  \
      threadgroup float* local_sums [[threadgroup(1)]],      \
      uint gid [[thread_position_in_grid]],                  \
      uint lid [[thread_position_in_threadgroup]],           \
      uint simd_lane_id [[thread_index_in_simdgroup]],       \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms_looped(name, itype)                  \
  template [[host_name("rms_looped" #name)]] [[kernel]] void \
  rms_looped<itype>(                                         \
      const device itype* x,                                 \
      const device itype* w,                                 \
      device itype* out,                                     \
      constant float& eps,                                   \
      constant uint& axis_size,                              \
      constant uint& w_stride,                               \
      threadgroup float* local_inv_mean [[threadgroup(0)]],  \
      threadgroup float* local_sums [[threadgroup(1)]],      \
      uint gid [[thread_position_in_grid]],                  \
      uint lid [[thread_position_in_threadgroup]],           \
      uint lsize [[threads_per_threadgroup]],                \
      uint simd_lane_id [[thread_index_in_simdgroup]],       \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms(name, itype)        \
  instantiate_rms_single_row(name, itype)   \
  instantiate_rms_looped(name, itype)

instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on
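As a reading aid (not MLX code): each threadgroup normalizes one row. Per-thread sums of squares are accumulated in float32, reduced with simd_sum and the threadgroup buffers, and a single inverse RMS factor is then applied to the whole row. A NumPy sketch of the per-row math:

import numpy as np

def rms_row(x, w, eps):
    xf = x.astype(np.float32)
    acc = np.sum(xf * xf)                       # squares accumulated in float32
    inv_mean = 1.0 / np.sqrt(acc / x.size + eps)  # metal::precise::rsqrt in the kernel
    return w * (xf * inv_mean).astype(x.dtype)  # scale, cast back to T, apply weight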
@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.

#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>

@@ -224,5 +223,6 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
   instantiate_softmax_single_row(name, itype) \
   instantiate_softmax_looped(name, itype)

-instantiate_softmax(float32, float) instantiate_softmax(float16, half)
-    instantiate_softmax(bfloat16, bfloat16_t)
+instantiate_softmax(float32, float)
+instantiate_softmax(float16, half)
+instantiate_softmax(bfloat16, bfloat16_t)
mlx/backend/metal/rms_norm.cpp (new file, 98 lines)
@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void RMSNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];

  if (x.is_donatable()) {
    out.move_shared_buffer(x);
  } else {
    out.set_data(
        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "rms";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = w.strides()[0];
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(
        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(compute_encoder, out, 2);
    compute_encoder->setBytes(&eps_, sizeof(float), 3);
    compute_encoder->setBytes(&axis_size, sizeof(int), 4);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
    compute_encoder->setThreadgroupMemoryLength(
        16 * 8, 0); // minimum of 16 bytes
    compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core::fast
@@ -37,11 +37,15 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };
   const array& in = check_input(inputs[0]);
-  out.set_data(
-      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
-      in.data_size(),
-      in.strides(),
-      in.flags());
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in);
+  } else {
+    out.set_data(
+        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
+        in.data_size(),
+        in.strides(),
+        in.flags());
+  }

   int axis_size = in.shape().back();
   int n_rows = in.data_size() / axis_size;
@@ -75,6 +79,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(
      compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&axis_size, sizeof(int), 2);
@@ -102,6 +102,7 @@ NO_GPU(Transpose)
 NO_GPU(Inverse)

 namespace fast {
+NO_GPU_MULTI(RMSNorm)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)
 } // namespace fast
mlx/fast.cpp (53 changed lines)
@@ -46,6 +46,59 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
  return {outputs, out_axes};
}

array rms_norm(
    const array& x,
    const array& weight,
    float eps,
    StreamOrDevice s_ /* = {} */) {
  if (x.ndim() == 0) {
    std::ostringstream msg;
    msg << "[rms_norm] Input must have at least 1 dimension but got input with "
           "0 dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (weight.ndim() != 1) {
    std::ostringstream msg;
    msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  auto out_type = result_type({x, weight});
  if (!is_floating_point(out_type) || is_complex(out_type)) {
    std::ostringstream msg;
    msg << "[rms_norm] Received unsupported type " << out_type << ".";
    throw std::invalid_argument(msg.str());
  }

  auto s = to_stream(s_);
  auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
    auto x = astype(inputs[0], float32, s);
    x = multiply(
        x,
        rsqrt(
            add(mean(square(x, s), -1, /* keepdims */ true, s),
                array(eps, float32),
                s),
            s),
        s);
    x = astype(x, out_type, s);
    return std::vector<array>{multiply(inputs[1], x, s)};
  };
  if (s.device == Device::gpu) {
    return array(
        x.shape(),
        x.dtype(),
        std::make_unique<RMSNorm>(s, fallback, eps),
        {astype(x, out_type, s), astype(weight, out_type, s)});
  }
  return fallback({x, weight})[0];
}

bool RMSNorm::is_equivalent(const Primitive& other) const {
  const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
  return eps_ == a_other.eps_;
}

array rope(
    const array& x,
    int dims,
@@ -8,6 +8,12 @@

 namespace mlx::core::fast {

+array rms_norm(
+    const array& x,
+    const array& weight,
+    float eps,
+    StreamOrDevice s = {});
+
 array rope(
     const array& x,
     int dims,
@@ -15,7 +21,7 @@ array rope(
     float base,
     float scale,
     int offset,
-    StreamOrDevice s /* = {} */);
+    StreamOrDevice s = {});

 /** Computes: O = softmax(Q @ K.T) @ V **/
 array scaled_dot_product_attention(
@@ -1,3 +1,5 @@
+// Copyright © 2024 Apple Inc.
+
 #include "mlx/primitives.h"

 namespace mlx::core::fast {
@@ -31,6 +33,29 @@ class Custom : public Primitive {
  std::function<std::vector<array>(std::vector<array>)> fallback_;
};

class RMSNorm : public Custom {
 public:
  RMSNorm(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback), eps_(eps){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  };
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(RMSNorm)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

class RoPE : public Custom {
 public:
  RoPE(
@@ -49,7 +74,9 @@ class RoPE : public Custom {
       offset_(offset){};

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
+      override {
+    throw std::runtime_error("NYI");
+  };
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -117,6 +117,8 @@ class RMSNorm(Module):
     where :math:`\gamma` is a learned per feature dimension parameter initialized at
     1.

+    Note the accumulation for the mean is done in 32-bit precision.
+
     [1]: https://arxiv.org/abs/1910.07467

     Args:
@@ -133,18 +135,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"

     def __call__(self, x):
-        # S is 1/sqrt(N) where N is the size of the features of x and is used
-        # to compute a numerically more stable RMS of x by multiplying with S
-        # first and summing.
-        #
-        # This way we prefer underflow over overflow which is controlled with
-        # the parameter epsilon anyway.
-        S = 1 / x.shape[-1] ** 0.5
-
-        n = (x * S).square().sum(axis=-1, keepdims=True)
-        n = mx.rsqrt(n + self.eps)
-
-        return self.weight * x * n
+        return mx.fast.rms_norm(x, self.weight, self.eps)


 class GroupNorm(Module):
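Context for the removed lines (not part of the diff): the old __call__ scaled by S = 1/sqrt(N) before squaring so that a low-precision reduction prefers underflow over overflow, while the fused op squares and accumulates in float32 (per the commit message and the docstring note above). An illustrative comparison with assumed shapes and magnitudes:

import mlx.core as mx

x = (mx.random.uniform(shape=(2, 256)) * 400).astype(mx.float16)
w = mx.ones((256,), dtype=mx.float16)

# Without the S trick, x.square() in float16 can overflow to inf for values
# above ~255 (float16 tops out near 65504), collapsing the result.
naive = w * x * mx.rsqrt(x.square().mean(-1, keepdims=True) + 1e-5)

# The fused op accumulates the squares in float32, so the same inputs stay finite.
fused = mx.fast.rms_norm(x, w, 1e-5)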
@@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) {
  auto m =
      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");

  m.def(
      "rms_norm",
      [](const array& x,
         const array& weight,
         float eps,
         const StreamOrDevice& s /* = {} */) {
        return fast::rms_norm(x, weight, eps, s);
      },
      "x"_a,
      "weight"_a,
      "eps"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Root Mean Square normalization (RMS norm).

        The normalization is with respect to the last axis of the input ``x``.

        Args:
            x (array): Input array.
            weight (array): A multiplicative weight to scale the result by.
              The ``weight`` should be one-dimensional with the same size
              as the last axis of ``x``.
            eps (float): A small additive constant for numerical stability.

        Returns:
            array: The output array.
      )pbdoc");

  m.def(
      "rope",
      [](const array& a,
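A short usage sketch matching the signature documented above (values are illustrative; stream is keyword-only, and the explicit-device call assumes a Metal device is available):

import mlx.core as mx

x = mx.random.uniform(shape=(4, 64))
weight = mx.ones((64,))
y = mx.fast.rms_norm(x, weight, eps=1e-5)                  # default stream/device
y_gpu = mx.fast.rms_norm(x, weight, 1e-5, stream=mx.gpu)   # explicit device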
@@ -115,6 +115,57 @@ class TestFast(mlx_tests.MLXTestCase):
        )
        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_rms_norm(self):
        def rms_norm(x, weight, eps):
            x = x.astype(mx.float32)
            x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
            return weight * x.astype(weight.dtype)

        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        epss = [1e-3, 1e-5]
        dimss = [31, 32, 33]
        defaults = (mx.float32, 1e-5, 32)

        for dtype in dtypes:
            _, eps, dims = defaults
            x = mx.random.uniform(
                shape=(
                    2,
                    dims,
                )
            ).astype(dtype)
            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
            rx = rms_norm(x, weight, eps)
            rx_fast = mx.fast.rms_norm(x, weight, eps)
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

        for eps in epss:
            dtype, _, dims = defaults
            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
            rx = rms_norm(x, weight, eps)
            rx_fast = mx.fast.rms_norm(x, weight, eps)
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

        for dims in dimss:
            dtype, eps, _ = defaults
            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
            rx = rms_norm(x, weight, eps)
            rx_fast = mx.fast.rms_norm(x, weight, eps)
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

        # Test > 4096
        dims, dtype, eps = 4099, mx.float32, 1e-5
        x = mx.random.uniform(shape=(dims,)).astype(dtype)
        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
        rx = rms_norm(x, weight, eps)
        rx_fast = mx.fast.rms_norm(x, weight, eps)
        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)

    def test_fast_transforms(self):
        x = mx.random.uniform(shape=(2, 2, 8))