Adds mx.fast.layer_norm (#870)

committed by GitHub

parent 105d236889
commit 2225374060

@@ -33,7 +33,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

@@ -24,6 +24,7 @@ set(
   "quantized"
   "random"
   "rms_norm"
+  "layer_norm"
   "rope"
   "scan"
   "scaled_dot_product_attention"

mlx/backend/metal/kernels/layer_norm.metal  (new file, 251 lines)

@@ -0,0 +1,251 @@
// Copyright © 2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_single_row(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;
  float thread_x[N_READS];

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      sumx2 += thread_x[i] * thread_x[i];
      sumx += thread_x[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        sumx2 += thread_x[i] * thread_x[i];
        sumx += thread_x[i];
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
        out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_looped(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        sumx2 += xi * xi;
        sumx += xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          sumx2 += xi * xi;
          sumx += xi;
        }
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = (x[r + i] - mean) * normalizer;
        out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = (x[r + i] - mean) * normalizer;
          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
        }
      }
    }
  }
}

// clang-format off
#define instantiate_layer_norm_single_row(name, itype)        \
  template [[host_name("layer_norm" #name)]] [[kernel]] void  \
  layer_norm_single_row<itype>(                                \
      const device itype* x,                                   \
      const device itype* w,                                   \
      const device itype* b,                                   \
      device itype* out,                                       \
      constant float& eps,                                     \
      constant uint& axis_size,                                \
      constant uint& w_stride,                                 \
      constant uint& b_stride,                                 \
      uint gid [[thread_position_in_grid]],                    \
      uint lid [[thread_position_in_threadgroup]],             \
      uint simd_lane_id [[thread_index_in_simdgroup]],         \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm_looped(name, itype)                   \
  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void  \
  layer_norm_looped<itype>(                                          \
      const device itype* x,                                         \
      const device itype* w,                                         \
      const device itype* b,                                         \
      device itype* out,                                             \
      constant float& eps,                                           \
      constant uint& axis_size,                                      \
      constant uint& w_stride,                                       \
      constant uint& b_stride,                                       \
      uint gid [[thread_position_in_grid]],                          \
      uint lid [[thread_position_in_threadgroup]],                   \
      uint lsize [[threads_per_threadgroup]],                        \
      uint simd_lane_id [[thread_index_in_simdgroup]],               \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm(name, itype)      \
  instantiate_layer_norm_single_row(name, itype) \
  instantiate_layer_norm_looped(name, itype)

instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
// clang-format on

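Both kernels above implement the same per-row computation. For a row x of length d (the axis_size argument), with per-feature weight w and bias b, the output is

\[
\mu = \frac{1}{d}\sum_{i} x_i, \qquad
\sigma^2 = \frac{1}{d}\sum_{i} x_i^2 - \mu^2, \qquad
\mathrm{out}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\, w_i + b_i .
\]

The partial sums sumx and sumx2 are reduced first within each SIMD group (simd_sum) and then across SIMD groups through threadgroup memory; layer_norm_single_row covers the row in one pass of N_READS elements per thread, while layer_norm_looped strides over rows too long for a single threadgroup pass.
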
@@ -95,4 +95,91 @@ void RMSNorm::eval_gpu(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

void LayerNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];
  const array& b = inputs[2];

  if (x.is_donatable()) {
    out.move_shared_buffer(x);
  } else {
    out.set_data(
        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "layer_norm";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
    uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(
        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(compute_encoder, b, 2);
    set_array_buffer(compute_encoder, out, 3);
    compute_encoder->setBytes(&eps_, sizeof(float), 4);
    compute_encoder->setBytes(&axis_size, sizeof(int), 5);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
    compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core::fast

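To make the threadgroup sizing above concrete, here is a small standalone sketch of the same rounding. The n_reads and looped_limit values are assumptions chosen for illustration; the real constants come from RMS_N_READS and RMS_LOOPED_LIMIT in the Metal defines.

// Standalone illustration of the dispatch arithmetic in LayerNorm::eval_gpu.
// The constants below are assumed example values, not the library's.
#include <cstdio>

int main() {
  const int simd_size = 32;
  const int n_reads = 4;          // assumed stand-in for RMS_N_READS
  const int looped_limit = 4096;  // assumed stand-in for RMS_LOOPED_LIMIT

  const int sizes[] = {300, 1024, 8192};
  for (int axis_size : sizes) {
    if (axis_size <= looped_limit) {
      // Round the row up to whole SIMD groups of n_reads elements per thread.
      int threads_needed = (axis_size + n_reads - 1) / n_reads;
      int simds_needed = (threads_needed + simd_size - 1) / simd_size;
      int threadgroup_size = simd_size * simds_needed;
      std::printf("axis %5d -> layer_norm,        threadgroup %d\n",
                  axis_size, threadgroup_size);
    } else {
      // Long rows fall through to the looped kernel at the device maximum.
      std::printf("axis %5d -> layer_norm_looped, threadgroup = device max\n",
                  axis_size);
    }
  }
  return 0;
}
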
@@ -102,6 +102,7 @@ NO_GPU(Transpose)
 NO_GPU(Inverse)

 namespace fast {
+NO_GPU_MULTI(LayerNorm)
 NO_GPU_MULTI(RMSNorm)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)

mlx/fast.cpp  (84 lines changed)

@@ -87,7 +87,7 @@ array rms_norm(
   if (s.device == Device::gpu) {
     return array(
         x.shape(),
-        x.dtype(),
+        out_type,
         std::make_unique<RMSNorm>(s, fallback, eps),
         {astype(x, out_type, s), astype(weight, out_type, s)});
   }

@@ -99,6 +99,88 @@ bool RMSNorm::is_equivalent(const Primitive& other) const {
  return eps_ == a_other.eps_;
}

array layer_norm(
    const array& x,
    const std::optional<array>& weight,
    const std::optional<array>& bias,
    float eps,
    StreamOrDevice s_ /* = {} */) {
  if (x.ndim() == 0) {
    std::ostringstream msg;
    msg << "[layer_norm] Input must have at least 1 dimension but got input with "
           "0 dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (weight.has_value() && (*weight).ndim() != 1) {
    std::ostringstream msg;
    msg << "[layer_norm] weight must have 1 dimension but has "
        << (*weight).ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (bias.has_value() && (*bias).ndim() != 1) {
    std::ostringstream msg;
    msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  auto out_type = (weight.has_value())
      ? ((bias.has_value()) ? result_type({x, *weight, *bias})
                            : result_type({x, *weight}))
      : x.dtype();
  if (!is_floating_point(out_type) || is_complex(out_type)) {
    std::ostringstream msg;
    msg << "[layer_norm] Received unsupported type " << out_type << ".";
    throw std::invalid_argument(msg.str());
  }

  auto s = to_stream(s_);
  bool has_weight = weight.has_value();
  bool has_bias = bias.has_value();
  auto fallback = [has_weight, has_bias, eps, out_type, s](
                      const std::vector<array>& inputs) {
    auto x = astype(inputs[0], float32, s);

    // Should I not be smart here and leave the double mean to simplify()?
    auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
    auto mu2 = square(mu, s);
    auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
    auto v = subtract(x2, mu2, s);

    x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
    x = astype(x, out_type, s);

    // If the LN is affine then transform x according to the weight and bias
    if (has_weight) {
      x = multiply(x, inputs[1], s);
    }
    if (has_bias) {
      x = add(x, inputs[2], s);
    }

    return std::vector<array>{x};
  };

  auto passed_weight =
      astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
  auto passed_bias =
      astype((bias.has_value()) ? *bias : array(0, out_type), out_type);

  if (s.device == Device::gpu) {
    return array(
        x.shape(),
        out_type,
        std::make_unique<LayerNorm>(s, fallback, eps),
        {astype(x, out_type, s), passed_weight, passed_bias});
  }
  return fallback({x, passed_weight, passed_bias})[0];
}

bool LayerNorm::is_equivalent(const Primitive& other) const {
  const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
  return eps_ == a_other.eps_;
}

array rope(
    const array& x,
    int dims,

@@ -14,6 +14,13 @@ array rms_norm(
    float eps,
    StreamOrDevice s = {});

array layer_norm(
    const array& x,
    const std::optional<array>& weight,
    const std::optional<array>& bias,
    float eps,
    StreamOrDevice s = {});

array rope(
    const array& x,
    int dims,

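For orientation, a minimal sketch of calling the declaration above from C++. The umbrella header, the mx alias, the shapes, and the eps value are assumptions for illustration, not part of the diff.

// Sketch only: exercises fast::layer_norm as declared in mlx/fast.h above.
#include "mlx/mlx.h"
#include "mlx/fast.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::random::normal({8, 512}); // rows normalized over the last axis
  auto w = mx::ones({512});              // optional per-feature scale
  auto b = mx::zeros({512});             // optional per-feature shift
  // weight and bias are std::optional<array>; pass std::nullopt to skip either.
  auto y = mx::fast::layer_norm(x, w, b, /* eps = */ 1e-5f);
  mx::eval({y});
  return 0;
}

On GPU streams this dispatches the LayerNorm primitive added in this commit; elsewhere the fallback closure in fast.cpp computes the same result with regular ops.
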
@@ -56,6 +56,29 @@ class RMSNorm : public Custom {
  float eps_;
};

class LayerNorm : public Custom {
 public:
  LayerNorm(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      float eps)
      : Custom(stream, fallback), eps_(eps){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  };
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(LayerNorm)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  float eps_;
};

class RoPE : public Custom {
 public:
  RoPE(