mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Adds mx.fast.layer_norm (#870)
This commit is contained in:
parent
105d236889
commit
2225374060
@ -33,7 +33,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
|
@ -24,6 +24,7 @@ set(
|
||||
"quantized"
|
||||
"random"
|
||||
"rms_norm"
|
||||
"layer_norm"
|
||||
"rope"
|
||||
"scan"
|
||||
"scaled_dot_product_attention"
|
||||
|
251
mlx/backend/metal/kernels/layer_norm.metal
Normal file
251
mlx/backend/metal/kernels/layer_norm.metal
Normal file
@ -0,0 +1,251 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void layer_norm_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* b,
|
||||
device T* out,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
constant uint& b_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
float thread_x[N_READS];
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
b += b_stride * lid * N_READS;
|
||||
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumx += thread_x[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = x[i];
|
||||
sumx2 += thread_x[i] * thread_x[i];
|
||||
sumx += thread_x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
|
||||
// Write the outputs
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void layer_norm_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
const device T* b,
|
||||
device T* out,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
constant uint& b_stride,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
float sumx = 0;
|
||||
float sumx2 = 0;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup float local_sumx[SIMD_SIZE];
|
||||
threadgroup float local_sumx2[SIMD_SIZE];
|
||||
threadgroup float local_mean[1];
|
||||
threadgroup float local_normalizer[1];
|
||||
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
b += b_stride * lid * N_READS;
|
||||
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
sumx2 += xi * xi;
|
||||
sumx += xi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
sumx2 += xi * xi;
|
||||
sumx += xi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sumx = simd_sum(sumx);
|
||||
sumx2 = simd_sum(sumx2);
|
||||
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sumx[simd_lane_id] = 0;
|
||||
local_sumx2[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sumx[simd_group_id] = sumx;
|
||||
local_sumx2[simd_group_id] = sumx2;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
sumx = simd_sum(local_sumx[simd_lane_id]);
|
||||
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
float mean = sumx / axis_size;
|
||||
float variance = sumx2 / axis_size - mean * mean;
|
||||
|
||||
local_mean[0] = mean;
|
||||
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = local_mean[0];
|
||||
float normalizer = local_normalizer[0];
|
||||
|
||||
// Write the outputs
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = (x[r + i] - mean) * normalizer;
|
||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = (x[r + i] - mean) * normalizer;
|
||||
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
instantiate_layer_norm_single_row(name, itype) \
|
||||
instantiate_layer_norm_looped(name, itype)
|
||||
|
||||
instantiate_layer_norm(float32, float)
|
||||
instantiate_layer_norm(float16, half)
|
||||
instantiate_layer_norm(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
|
@ -95,4 +95,91 @@ void RMSNorm::eval_gpu(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void LayerNorm::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
if (x.is_donatable()) {
|
||||
out.move_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
const int looped_limit = RMS_LOOPED_LIMIT;
|
||||
std::string op_name = "layer_norm";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "_looped";
|
||||
}
|
||||
op_name += type_to_name(out);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(
|
||||
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
|
||||
set_array_buffer(compute_encoder, w, 1);
|
||||
set_array_buffer(compute_encoder, b, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
compute_encoder->setBytes(&eps_, sizeof(float), 4);
|
||||
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
|
||||
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@ -102,6 +102,7 @@ NO_GPU(Transpose)
|
||||
NO_GPU(Inverse)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(LayerNorm)
|
||||
NO_GPU_MULTI(RMSNorm)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
|
84
mlx/fast.cpp
84
mlx/fast.cpp
@ -87,7 +87,7 @@ array rms_norm(
|
||||
if (s.device == Device::gpu) {
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
out_type,
|
||||
std::make_unique<RMSNorm>(s, fallback, eps),
|
||||
{astype(x, out_type, s), astype(weight, out_type, s)});
|
||||
}
|
||||
@ -99,6 +99,88 @@ bool RMSNorm::is_equivalent(const Primitive& other) const {
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
array layer_norm(
|
||||
const array& x,
|
||||
const std::optional<array>& weight,
|
||||
const std::optional<array>& bias,
|
||||
float eps,
|
||||
StreamOrDevice s_ /* = {} */) {
|
||||
if (x.ndim() == 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] Input must have at least 1 dimension but got input with "
|
||||
"0 dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (weight.has_value() && (*weight).ndim() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] weight must have 1 dimension but has "
|
||||
<< (*weight).ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bias.has_value() && (*bias).ndim() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = (weight.has_value())
|
||||
? ((bias.has_value()) ? result_type({x, *weight, *bias})
|
||||
: result_type({x, *weight}))
|
||||
: x.dtype();
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] Received unsupported type " << out_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
bool has_weight = weight.has_value();
|
||||
bool has_bias = bias.has_value();
|
||||
auto fallback = [has_weight, has_bias, eps, out_type, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto x = astype(inputs[0], float32, s);
|
||||
|
||||
// Should I not be smart here and leave the double mean to simplify()?
|
||||
auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto mu2 = square(mu, s);
|
||||
auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
|
||||
auto v = subtract(x2, mu2, s);
|
||||
|
||||
x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
|
||||
x = astype(x, out_type, s);
|
||||
|
||||
// If the LN is affine then transform x according to the weight and bias
|
||||
if (has_weight) {
|
||||
x = multiply(x, inputs[1], s);
|
||||
}
|
||||
if (has_bias) {
|
||||
x = add(x, inputs[2], s);
|
||||
}
|
||||
|
||||
return std::vector<array>{x};
|
||||
};
|
||||
|
||||
auto passed_weight =
|
||||
astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
|
||||
auto passed_bias =
|
||||
astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
|
||||
|
||||
if (s.device == Device::gpu) {
|
||||
return array(
|
||||
x.shape(),
|
||||
out_type,
|
||||
std::make_unique<LayerNorm>(s, fallback, eps),
|
||||
{astype(x, out_type, s), passed_weight, passed_bias});
|
||||
}
|
||||
return fallback({x, passed_weight, passed_bias})[0];
|
||||
}
|
||||
|
||||
bool LayerNorm::is_equivalent(const Primitive& other) const {
|
||||
const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
|
||||
return eps_ == a_other.eps_;
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
|
@ -14,6 +14,13 @@ array rms_norm(
|
||||
float eps,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array layer_norm(
|
||||
const array& x,
|
||||
const std::optional<array>& weight,
|
||||
const std::optional<array>& bias,
|
||||
float eps,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
|
@ -56,6 +56,29 @@ class RMSNorm : public Custom {
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNorm : public Custom {
|
||||
public:
|
||||
LayerNorm(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(LayerNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RoPE : public Custom {
|
||||
public:
|
||||
RoPE(
|
||||
|
@ -85,13 +85,19 @@ class LayerNorm(Module):
|
||||
eps (float): A small additive constant for numerical stability
|
||||
affine (bool): If True learn an affine transform to apply after the
|
||||
normalization
|
||||
bias (bool): If True include a translation to the affine
|
||||
transformation. If set to False the transformation is not really affine
|
||||
just scaling.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
|
||||
def __init__(
|
||||
self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
if affine:
|
||||
self.bias = mx.zeros((dims,))
|
||||
self.weight = mx.ones((dims,))
|
||||
if bias:
|
||||
self.bias = mx.zeros((dims,))
|
||||
self.eps = eps
|
||||
self.dims = dims
|
||||
|
||||
@ -99,10 +105,9 @@ class LayerNorm(Module):
|
||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
||||
|
||||
def __call__(self, x):
|
||||
means = mx.mean(x, axis=-1, keepdims=True)
|
||||
var = mx.var(x, axis=-1, keepdims=True)
|
||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||
return (self.weight * x + self.bias) if "weight" in self else x
|
||||
weight = self.weight if "weight" in self else None
|
||||
bias = self.bias if "bias" in self else None
|
||||
return mx.fast.layer_norm(x, weight, bias, self.eps)
|
||||
|
||||
|
||||
class RMSNorm(Module):
|
||||
|
@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) {
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"layer_norm",
|
||||
[](const array& x,
|
||||
const std::optional<array>& weight,
|
||||
const std::optional<array>& bias,
|
||||
float eps,
|
||||
const StreamOrDevice& s /* = {} */) {
|
||||
return fast::layer_norm(x, weight, bias, eps, s);
|
||||
},
|
||||
"x"_a,
|
||||
"weight"_a.none(),
|
||||
"bias"_a.none(),
|
||||
"eps"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Layer normalization.
|
||||
|
||||
The normalization is with respect to the last axis of the input ``x``.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
weight (array, optional): A multiplicative weight to scale the result by.
|
||||
The ``weight`` should be one-dimensional with the same size
|
||||
as the last axis of ``x``. If set to ``None`` then no scaling happens.
|
||||
bias (array, optional): An additive offset to be added to the result.
|
||||
The ``bias`` should be one-dimensional with the same size
|
||||
as the last axis of ``x``. If set to ``None`` then no translation happens.
|
||||
eps (float): A small additive constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"rope",
|
||||
[](const array& a,
|
||||
|
@ -166,6 +166,105 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
|
||||
|
||||
def test_layer_norm(self):
|
||||
def layer_norm(x, weight, bias, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
mean = x.mean(axis=-1, keepdims=True)
|
||||
var = x.var(axis=-1, keepdims=True)
|
||||
x = (x - mean) * mx.rsqrt(var + eps)
|
||||
x = x.astype(ot)
|
||||
if weight is not None:
|
||||
x = x * weight
|
||||
if bias is not None:
|
||||
x = x + bias
|
||||
return x
|
||||
|
||||
# Per dtype absolute tolerance
|
||||
tolerances = {mx.float32: 2e-6, mx.float16: 2e-3, mx.bfloat16: 2e-2}
|
||||
|
||||
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
||||
epss = [1e-3, 1e-5]
|
||||
dimss = [31, 32, 33]
|
||||
defaults = (mx.float32, 1e-5, 32)
|
||||
|
||||
for dtype in dtypes:
|
||||
_, eps, dims = defaults
|
||||
x = mx.random.uniform(
|
||||
shape=(
|
||||
2,
|
||||
dims,
|
||||
)
|
||||
).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
for eps in epss:
|
||||
dtype, _, dims = defaults
|
||||
x = mx.random.uniform(shape=(2, dims)).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
for dims in dimss:
|
||||
dtype, eps, _ = defaults
|
||||
x = mx.random.uniform(shape=(2, dims)).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
# Test > 4096
|
||||
dims, dtype, eps = 4099, mx.float32, 1e-5
|
||||
x = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
def test_fast_transforms(self):
|
||||
x = mx.random.uniform(shape=(2, 2, 8))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user