diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 0b561728d..f6f374686 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -33,7 +33,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 1a112731f..262f396e3 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -24,6 +24,7 @@ set(
   "quantized"
   "random"
   "rms_norm"
+  "layer_norm"
   "rope"
   "scan"
   "scaled_dot_product_attention"
diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal
new file mode 100644
index 000000000..c9fb7131d
--- /dev/null
+++ b/mlx/backend/metal/kernels/layer_norm.metal
@@ -0,0 +1,251 @@
+// Copyright © 2024 Apple Inc.
+
+#include <metal_math>
+#include <metal_simdgroup>
+
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+using namespace metal;
+
+template <typename T, int N_READS = RMS_N_READS>
+[[kernel]] void layer_norm_single_row(
+    const device T* x,
+    const device T* w,
+    const device T* b,
+    device T* out,
+    constant float& eps,
+    constant uint& axis_size,
+    constant uint& w_stride,
+    constant uint& b_stride,
+    uint gid [[threadgroup_position_in_grid]],
+    uint lid [[thread_position_in_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  float sumx = 0;
+  float sumx2 = 0;
+  float thread_x[N_READS];
+
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup float local_sumx[SIMD_SIZE];
+  threadgroup float local_sumx2[SIMD_SIZE];
+  threadgroup float local_mean[1];
+  threadgroup float local_normalizer[1];
+
+  x += gid * axis_size + lid * N_READS;
+  w += w_stride * lid * N_READS;
+  b += b_stride * lid * N_READS;
+
+  if (lid * N_READS + N_READS <= axis_size) {
+    for (int i = 0; i < N_READS; i++) {
+      thread_x[i] = x[i];
+      sumx2 += thread_x[i] * thread_x[i];
+      sumx += thread_x[i];
+    }
+  } else {
+    for (int i = 0; i < N_READS; i++) {
+      if ((lid * N_READS + i) < axis_size) {
+        thread_x[i] = x[i];
+        sumx2 += thread_x[i] * thread_x[i];
+        sumx += thread_x[i];
+      }
+    }
+  }
+
+  sumx = simd_sum(sumx);
+  sumx2 = simd_sum(sumx2);
+
+  // Initialize shared memory
+  if (simd_group_id == 0) {
+    local_sumx[simd_lane_id] = 0;
+    local_sumx2[simd_lane_id] = 0;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Write simd accumulations into shared memory
+  if (simd_lane_id == 0) {
+    local_sumx[simd_group_id] = sumx;
+    local_sumx2[simd_group_id] = sumx2;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Accumulate over simd groups
+  if (simd_group_id == 0) {
+    sumx = simd_sum(local_sumx[simd_lane_id]);
+    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
+    if (simd_lane_id == 0) {
+      float mean = sumx / axis_size;
+      float variance = sumx2 / axis_size - mean * mean;
+
+      local_mean[0] = mean;
+      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  float mean = local_mean[0];
+  float normalizer = local_normalizer[0];
+
+  // Write the outputs
+  out += gid * axis_size + lid * N_READS;
+  if (lid * N_READS + N_READS <= axis_size) {
+    for (int i = 0; i < N_READS; i++) {
+      thread_x[i] = (thread_x[i] - mean) * normalizer;
+      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
+    }
+  } else {
+    for (int i = 0; i < N_READS; i++) {
+      if ((lid * N_READS + i) < axis_size) {
+        thread_x[i] = (thread_x[i] - mean) * normalizer;
+        out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
+      }
+    }
+  }
+}
+
+template <typename T, int N_READS = RMS_N_READS>
+[[kernel]] void layer_norm_looped(
+    const device T* x,
+    const device T* w,
+    const device T* b,
+    device T* out,
+    constant float& eps,
+    constant uint& axis_size,
+    constant uint& w_stride,
+    constant uint& b_stride,
+    uint gid [[threadgroup_position_in_grid]],
+    uint lid [[thread_position_in_threadgroup]],
+    uint lsize [[threads_per_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  float sumx = 0;
+  float sumx2 = 0;
+
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup float local_sumx[SIMD_SIZE];
+  threadgroup float local_sumx2[SIMD_SIZE];
+  threadgroup float local_mean[1];
+  threadgroup float local_normalizer[1];
+
+  x += gid * axis_size + lid * N_READS;
+  w += w_stride * lid * N_READS;
+  b += b_stride * lid * N_READS;
+
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        float xi = x[i + r];
+        sumx2 += xi * xi;
+        sumx += xi;
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          float xi = x[i + r];
+          sumx2 += xi * xi;
+          sumx += xi;
+        }
+      }
+    }
+  }
+
+  sumx = simd_sum(sumx);
+  sumx2 = simd_sum(sumx2);
+
+  // Initialize shared memory
+  if (simd_group_id == 0) {
+    local_sumx[simd_lane_id] = 0;
+    local_sumx2[simd_lane_id] = 0;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Write simd accumulations into shared memory
+  if (simd_lane_id == 0) {
+    local_sumx[simd_group_id] = sumx;
+    local_sumx2[simd_group_id] = sumx2;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Accumulate over simd groups
+  if (simd_group_id == 0) {
+    sumx = simd_sum(local_sumx[simd_lane_id]);
+    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
+    if (simd_lane_id == 0) {
+      float mean = sumx / axis_size;
+      float variance = sumx2 / axis_size - mean * mean;
+
+      local_mean[0] = mean;
+      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  float mean = local_mean[0];
+  float normalizer = local_normalizer[0];
+
+  // Write the outputs
+  out += gid * axis_size + lid * N_READS;
+  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
+    if (r + lid * N_READS + N_READS <= axis_size) {
+      for (int i = 0; i < N_READS; i++) {
+        float xi = (x[r + i] - mean) * normalizer;
+        out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
+      }
+    } else {
+      for (int i = 0; i < N_READS; i++) {
+        if ((r + lid * N_READS + i) < axis_size) {
+          float xi = (x[r + i] - mean) * normalizer;
+          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
+        }
+      }
+    }
+  }
+}
+
+
+// clang-format off
+#define instantiate_layer_norm_single_row(name, itype)           \
+  template [[host_name("layer_norm" #name)]] [[kernel]] void     \
+  layer_norm_single_row<itype>(                                   \
+      const device itype* x,                                      \
+      const device itype* w,                                      \
+      const device itype* b,                                      \
+      device itype* out,                                          \
+      constant float& eps,                                        \
+      constant uint& axis_size,                                   \
+      constant uint& w_stride,                                    \
+      constant uint& b_stride,                                    \
+      uint gid [[thread_position_in_grid]],                       \
+      uint lid [[thread_position_in_threadgroup]],                \
+      uint simd_lane_id [[thread_index_in_simdgroup]],            \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+
+#define instantiate_layer_norm_looped(name, itype)                  \
+  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
+  layer_norm_looped<itype>(                                          \
+      const device itype* x,                                         \
+      const device itype* w,                                         \
+      const device itype* b,                                         \
+      device itype* out,                                             \
+      constant float& eps,                                           \
+      constant uint& axis_size,                                      \
+      constant uint& w_stride,                                       \
+      constant uint& b_stride,                                       \
+      uint gid [[thread_position_in_grid]],                          \
+      uint lid [[thread_position_in_threadgroup]],                   \
+      uint lsize [[threads_per_threadgroup]],                        \
+      uint simd_lane_id [[thread_index_in_simdgroup]],               \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+
+#define instantiate_layer_norm(name, itype)          \
+  instantiate_layer_norm_single_row(name, itype)     \
+  instantiate_layer_norm_looped(name, itype)
+
+instantiate_layer_norm(float32, float)
+instantiate_layer_norm(float16, half)
+instantiate_layer_norm(bfloat16, bfloat16_t)
+// clang-format on
diff --git a/mlx/backend/metal/rms_norm.cpp b/mlx/backend/metal/normalization.cpp
similarity index 51%
rename from mlx/backend/metal/rms_norm.cpp
rename to mlx/backend/metal/normalization.cpp
index a3a783f31..dd6c641a2 100644
--- a/mlx/backend/metal/rms_norm.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -95,4 +95,91 @@ void RMSNorm::eval_gpu(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }
 
+void LayerNorm::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  auto& out = outputs[0];
+
+  // Make sure that the last dimension is contiguous
+  std::vector<array> copies;
+  auto check_input = [&copies, &s](const array& x) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
+      return x;
+    } else {
+      array x_copy(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      copies.push_back(x_copy);
+      return x_copy;
+    }
+  };
+  const array& x = check_input(inputs[0]);
+  const array& w = inputs[1];
+  const array& b = inputs[2];
+
+  if (x.is_donatable()) {
+    out.move_shared_buffer(x);
+  } else {
+    out.set_data(
+        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        x.data_size(),
+        x.strides(),
+        x.flags());
+  }
+
+  auto axis_size = static_cast<uint32_t>(x.shape().back());
+  int n_rows = x.data_size() / axis_size;
+
+  const int simd_size = 32;
+  const int n_reads = RMS_N_READS;
+  const int looped_limit = RMS_LOOPED_LIMIT;
+  std::string op_name = "layer_norm";
+  if (axis_size > looped_limit) {
+    op_name += "_looped";
+  }
+  op_name += type_to_name(out);
+  auto compute_encoder = d.get_command_encoder(s.index);
+  {
+    auto kernel = d.get_kernel(op_name);
+
+    MTL::Size grid_dims, group_dims;
+    if (axis_size <= looped_limit) {
+      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
+      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
+      size_t threadgroup_size = simd_size * simds_needed;
+      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+      size_t n_threads = n_rows * threadgroup_size;
+      grid_dims = MTL::Size(n_threads, 1, 1);
+      group_dims = MTL::Size(threadgroup_size, 1, 1);
+    } else {
+      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
+      size_t n_threads = n_rows * threadgroup_size;
+      grid_dims = MTL::Size(n_threads, 1, 1);
+      group_dims = MTL::Size(threadgroup_size, 1, 1);
+    }
+
+    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
+    uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
+    compute_encoder->setComputePipelineState(kernel);
+    set_array_buffer(
+        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
+    set_array_buffer(compute_encoder, w, 1);
+    set_array_buffer(compute_encoder, b, 2);
+    set_array_buffer(compute_encoder, out, 3);
+    compute_encoder->setBytes(&eps_, sizeof(float), 4);
+    compute_encoder->setBytes(&axis_size, sizeof(int), 5);
+    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
+    compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+}
+
 } // namespace mlx::core::fast
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 6bbde8a50..248fef031 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -102,6 +102,7 @@ NO_GPU(Transpose)
 NO_GPU(Inverse)
 
 namespace fast {
+NO_GPU_MULTI(LayerNorm)
 NO_GPU_MULTI(RMSNorm)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index cd1f4ecce..076213875 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -87,7 +87,7 @@ array rms_norm(
   if (s.device == Device::gpu) {
     return array(
         x.shape(),
-        x.dtype(),
+        out_type,
         std::make_unique<RMSNorm>(s, fallback, eps),
         {astype(x, out_type, s), astype(weight, out_type, s)});
   }
@@ -99,6 +99,88 @@ bool RMSNorm::is_equivalent(const Primitive& other) const {
   return eps_ == a_other.eps_;
 }
 
+array layer_norm(
+    const array& x,
+    const std::optional<array>& weight,
+    const std::optional<array>& bias,
+    float eps,
+    StreamOrDevice s_ /* = {} */) {
+  if (x.ndim() == 0) {
+    std::ostringstream msg;
+    msg << "[layer_norm] Input must have at least 1 dimension but got input with "
+           "0 dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (weight.has_value() && (*weight).ndim() != 1) {
+    std::ostringstream msg;
+    msg << "[layer_norm] weight must have 1 dimension but has "
+        << (*weight).ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (bias.has_value() && (*bias).ndim() != 1) {
+    std::ostringstream msg;
+    msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto out_type = (weight.has_value())
+      ? ((bias.has_value()) ? result_type({x, *weight, *bias})
+                            : result_type({x, *weight}))
+      : x.dtype();
+  if (!is_floating_point(out_type) || is_complex(out_type)) {
+    std::ostringstream msg;
+    msg << "[layer_norm] Received unsupported type " << out_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto s = to_stream(s_);
+  bool has_weight = weight.has_value();
+  bool has_bias = bias.has_value();
+  auto fallback = [has_weight, has_bias, eps, out_type, s](
+                      const std::vector<array>& inputs) {
+    auto x = astype(inputs[0], float32, s);
+
+    // Should I not be smart here and leave the double mean to simplify()?
+    auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
+    auto mu2 = square(mu, s);
+    auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
+    auto v = subtract(x2, mu2, s);
+
+    x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
+    x = astype(x, out_type, s);
+
+    // If the LN is affine then transform x according to the weight and bias
+    if (has_weight) {
+      x = multiply(x, inputs[1], s);
+    }
+    if (has_bias) {
+      x = add(x, inputs[2], s);
+    }
+
+    return std::vector<array>{x};
+  };
+
+  auto passed_weight =
+      astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
+  auto passed_bias =
+      astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
+
+  if (s.device == Device::gpu) {
+    return array(
+        x.shape(),
+        out_type,
+        std::make_unique<LayerNorm>(s, fallback, eps),
+        {astype(x, out_type, s), passed_weight, passed_bias});
+  }
+  return fallback({x, passed_weight, passed_bias})[0];
+}
+
+bool LayerNorm::is_equivalent(const Primitive& other) const {
+  const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
+  return eps_ == a_other.eps_;
+}
+
 array rope(
     const array& x,
     int dims,
diff --git a/mlx/fast.h b/mlx/fast.h
index 7e08533ca..4d73de581 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -14,6 +14,13 @@ array rms_norm(
     float eps,
     StreamOrDevice s = {});
 
+array layer_norm(
+    const array& x,
+    const std::optional<array>& weight,
+    const std::optional<array>& bias,
+    float eps,
+    StreamOrDevice s = {});
+
 array rope(
     const array& x,
     int dims,
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index ea2b56d05..3a1836889 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -56,6 +56,29 @@ class RMSNorm : public Custom {
   float eps_;
 };
 
+class LayerNorm : public Custom {
+ public:
+  LayerNorm(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float eps)
+      : Custom(stream, fallback), eps_(eps){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  };
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(LayerNorm)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  float eps_;
+};
+
 class RoPE : public Custom {
  public:
   RoPE(
diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 60d033b0d..7b53c2c5e 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -85,13 +85,19 @@ class LayerNorm(Module):
         eps (float): A small additive constant for numerical stability
         affine (bool): If True learn an affine transform to apply after
             the normalization
+        bias (bool): If True include a translation (bias) in the affine
+            transformation. If set to False the transformation is only a
+            scaling, not a full affine transform.
""" - def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True): + def __init__( + self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True + ): super().__init__() if affine: - self.bias = mx.zeros((dims,)) self.weight = mx.ones((dims,)) + if bias: + self.bias = mx.zeros((dims,)) self.eps = eps self.dims = dims @@ -99,10 +105,9 @@ class LayerNorm(Module): return f"{self.dims}, eps={self.eps}, affine={'weight' in self}" def __call__(self, x): - means = mx.mean(x, axis=-1, keepdims=True) - var = mx.var(x, axis=-1, keepdims=True) - x = (x - means) * mx.rsqrt(var + self.eps) - return (self.weight * x + self.bias) if "weight" in self else x + weight = self.weight if "weight" in self else None + bias = self.bias if "bias" in self else None + return mx.fast.layer_norm(x, weight, bias, self.eps) class RMSNorm(Module): diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 5bc274ca1..cbdbcde47 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) { array: The output array. )pbdoc"); + m.def( + "layer_norm", + [](const array& x, + const std::optional& weight, + const std::optional& bias, + float eps, + const StreamOrDevice& s /* = {} */) { + return fast::layer_norm(x, weight, bias, eps, s); + }, + "x"_a, + "weight"_a.none(), + "bias"_a.none(), + "eps"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Layer normalization. + + The normalization is with respect to the last axis of the input ``x``. + + Args: + x (array): Input array. + weight (array, optional): A multiplicative weight to scale the result by. + The ``weight`` should be one-dimensional with the same size + as the last axis of ``x``. If set to ``None`` then no scaling happens. + bias (array, optional): An additive offset to be added to the result. + The ``bias`` should be one-dimensional with the same size + as the last axis of ``x``. If set to ``None`` then no translation happens. + eps (float): A small additive constant for numerical stability. + + Returns: + array: The output array. 
+ )pbdoc"); + m.def( "rope", [](const array& a, diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index d3285e993..77b5c721f 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -166,6 +166,105 @@ class TestFast(mlx_tests.MLXTestCase): rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6) + def test_layer_norm(self): + def layer_norm(x, weight, bias, eps): + ot = x.dtype + x = x.astype(mx.float32) + mean = x.mean(axis=-1, keepdims=True) + var = x.var(axis=-1, keepdims=True) + x = (x - mean) * mx.rsqrt(var + eps) + x = x.astype(ot) + if weight is not None: + x = x * weight + if bias is not None: + x = x + bias + return x + + # Per dtype absolute tolerance + tolerances = {mx.float32: 2e-6, mx.float16: 2e-3, mx.bfloat16: 2e-2} + + dtypes = [mx.float32, mx.float16, mx.bfloat16] + epss = [1e-3, 1e-5] + dimss = [31, 32, 33] + defaults = (mx.float32, 1e-5, 32) + + for dtype in dtypes: + _, eps, dims = defaults + x = mx.random.uniform( + shape=( + 2, + dims, + ) + ).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + bias = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = layer_norm(x, weight, bias, eps) + rx_fast = mx.fast.layer_norm(x, weight, bias, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, weight, None, eps) + rx_fast = mx.fast.layer_norm(x, weight, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, None, bias, eps) + rx_fast = mx.fast.layer_norm(x, None, bias, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, None, None, eps) + rx_fast = mx.fast.layer_norm(x, None, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + for eps in epss: + dtype, _, dims = defaults + x = mx.random.uniform(shape=(2, dims)).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + bias = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = layer_norm(x, weight, bias, eps) + rx_fast = mx.fast.layer_norm(x, weight, bias, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, weight, None, eps) + rx_fast = mx.fast.layer_norm(x, weight, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, None, bias, eps) + rx_fast = mx.fast.layer_norm(x, None, bias, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, None, None, eps) + rx_fast = mx.fast.layer_norm(x, None, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + for dims in dimss: + dtype, eps, _ = defaults + x = mx.random.uniform(shape=(2, dims)).astype(dtype) + weight = mx.random.uniform(shape=(dims,)).astype(dtype) + bias = mx.random.uniform(shape=(dims,)).astype(dtype) + rx = layer_norm(x, weight, bias, eps) + rx_fast = mx.fast.layer_norm(x, weight, bias, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, weight, None, eps) + rx_fast = mx.fast.layer_norm(x, weight, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, None, bias, eps) + rx_fast = mx.fast.layer_norm(x, None, bias, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + rx = layer_norm(x, None, None, eps) + rx_fast = mx.fast.layer_norm(x, None, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + + # Test > 4096 + dims, dtype, eps = 4099, mx.float32, 1e-5 + x = 
+        x = mx.random.uniform(shape=(dims,)).astype(dtype)
+        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+        bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+        rx = layer_norm(x, weight, bias, eps)
+        rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, weight, None, eps)
+        rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, None, bias, eps)
+        rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, None, None, eps)
+        rx_fast = mx.fast.layer_norm(x, None, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
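A minimal usage sketch of the new op from Python, using only calls that appear in this diff (the 2e-3 threshold mirrors the float16 tolerance in the tests above); shapes and values here are illustrative only:

    import mlx.core as mx

    # Normalize over the last axis of x with the fused kernel.
    x = mx.random.uniform(shape=(2, 32)).astype(mx.float16)
    weight = mx.random.uniform(shape=(32,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(32,)).astype(mx.float16)
    y_fast = mx.fast.layer_norm(x, weight, bias, 1e-5)

    # Unfused reference in float32, as in test_layer_norm.
    x32 = x.astype(mx.float32)
    mean = x32.mean(axis=-1, keepdims=True)
    var = x32.var(axis=-1, keepdims=True)
    y_ref = ((x32 - mean) * mx.rsqrt(var + 1e-5)).astype(x.dtype) * weight + bias

    assert mx.abs(y_fast - y_ref).max().item() < 2e-3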