diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt
index e36e0567a..96159dfa8 100644
--- a/mlx/backend/cpu/CMakeLists.txt
+++ b/mlx/backend/cpu/CMakeLists.txt
@@ -58,6 +58,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
diff --git a/mlx/backend/cpu/logsumexp.cpp b/mlx/backend/cpu/logsumexp.cpp
new file mode 100644
index 000000000..56f0dab9f
--- /dev/null
+++ b/mlx/backend/cpu/logsumexp.cpp
@@ -0,0 +1,140 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <cassert>
+#include <cmath>
+
+#include "mlx/backend/cpu/copy.h"
+#include "mlx/backend/cpu/encoder.h"
+#include "mlx/backend/cpu/simd/simd.h"
+#include "mlx/primitives.h"
+#include "mlx/types/limits.h"
+
+namespace mlx::core {
+
+namespace {
+
+using namespace mlx::core::simd;
+
+template <typename T, typename AccT>
+void logsumexp(const array& in, array& out, Stream stream) {
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+
+  const T* in_ptr = in.data<T>();
+  T* out_ptr = out.data<T>();
+
+  int M = in.shape().back();
+  int L = in.data_size() / M;
+
+  encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
+    constexpr int N = std::min(max_size<T>, max_size<AccT>);
+
+    const T* current_in_ptr;
+
+    for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
+      // Find the maximum
+      current_in_ptr = in_ptr;
+      Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
+      size_t s = M;
+      while (s >= N) {
+        Simd<AccT, N> vals = load<T, N>(current_in_ptr);
+        vmaximum = maximum(vals, vmaximum);
+        current_in_ptr += N;
+        s -= N;
+      }
+
+      AccT maximum = max(vmaximum);
+      while (s-- > 0) {
+        maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
+        current_in_ptr++;
+      }
+
+      // Compute the normalizer and the exponentials
+      Simd<AccT, N> vnormalizer(0.0);
+      current_in_ptr = in_ptr;
+      s = M;
+      while (s >= N) {
+        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
+        vexp = exp(vexp - maximum);
+        vnormalizer = vnormalizer + vexp;
+        current_in_ptr += N;
+        s -= N;
+      }
+      AccT normalizer = sum(vnormalizer);
+      while (s-- > 0) {
+        AccT _exp = std::exp(*current_in_ptr - maximum);
+        normalizer += _exp;
+        current_in_ptr++;
+      }
+      // Normalize
+      *out_ptr = std::isinf(maximum)
+          ? static_cast<T>(maximum)
+          : static_cast<T>(std::log(normalizer) + maximum);
+    }
+  });
+}
+
+} // namespace
+
+void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+
+  // Make sure that the last dimension is contiguous
+  auto s = stream();
+  auto& encoder = cpu::get_command_encoder(s);
+  auto ensure_contiguous = [&s, &encoder](const array& x) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
+      return x;
+    } else {
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy(x, x_copy, CopyType::General, s);
+      encoder.add_temporary(x_copy);
+      return x_copy;
+    }
+  };
+
+  auto in = ensure_contiguous(inputs[0]);
+  if (in.flags().row_contiguous) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  } else {
+    auto n = in.shape(-1);
+    auto flags = in.flags();
+    auto strides = in.strides();
+    for (auto& s : strides) {
+      s /= n;
+    }
+    bool col_contig = strides[0] == 1;
+    for (int i = 1; col_contig && i < strides.size(); ++i) {
+      col_contig &=
+          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
+    }
+    flags.col_contiguous = col_contig;
+    out.set_data(
+        allocator::malloc(in.nbytes() / n),
+        in.data_size() / n,
+        std::move(strides),
+        flags);
+  }
+
+  switch (in.dtype()) {
+    case float32:
+      logsumexp<float, float>(in, out, stream());
+      break;
+    case float16:
+      logsumexp<float16_t, float>(in, out, stream());
+      break;
+    case bfloat16:
+      logsumexp<bfloat16_t, float>(in, out, stream());
+      break;
+    case float64:
+      logsumexp<double, double>(in, out, stream());
+      break;
+    default:
+      throw std::runtime_error(
+          "[logsumexp] only supports floating point types");
+      break;
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp
index 78e4a3e68..41d14f556 100644
--- a/mlx/backend/cpu/softmax.cpp
+++ b/mlx/backend/cpu/softmax.cpp
@@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto set_output = [s = stream(), &out](const array& x) {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
-      auto s = x.strides()[x.ndim() - 2];
-      no_copy &= (s == 0 || s == x.shape().back());
-    }
-    if (no_copy) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto in = set_output(inputs[0]);
 
   switch (in.dtype()) {
-    case bool_:
-    case uint8:
-    case uint16:
-    case uint32:
-    case uint64:
-    case int8:
-    case int16:
-    case int32:
-    case int64:
-      throw std::runtime_error(
-          "Softmax is defined only for floating point types");
-      break;
     case float32:
       softmax<float, float>(in, out, stream());
       break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float64:
       softmax<double, double>(in, out, stream());
       break;
-    case complex64:
-      throw std::invalid_argument(
-          "[Softmax] Not yet implemented for complex64");
+    default:
+      throw std::runtime_error(
+          "[softmax] Only defined for floating point types.");
       break;
   }
 }
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index e49201277..7985396c4 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
   make_jit_source(binary)
   make_jit_source(binary_two)
   make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
+  make_jit_source(logsumexp)
   make_jit_source(ternary)
   make_jit_source(softmax)
   make_jit_source(scan)
@@ -95,6 +96,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp diff --git a/mlx/backend/metal/jit/arange.h b/mlx/backend/metal/jit/arange.h deleted file mode 100644 index 0c224dca4..000000000 --- a/mlx/backend/metal/jit/arange.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view arange_kernels = R"( -template [[host_name("{0}")]] [[kernel]] void arange<{1}>( - constant const {1}& start, - constant const {1}& step, - device {1}* out, - uint index [[thread_position_in_grid]]); -)"; diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index b14aa567b..921ce50ce 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -20,6 +20,7 @@ const char* copy(); const char* fft(); const char* gather_axis(); const char* hadamard(); +const char* logsumexp(); const char* quantized(); const char* ternary(); const char* scan(); diff --git a/mlx/backend/metal/jit/softmax.h b/mlx/backend/metal/jit/softmax.h deleted file mode 100644 index a9672a050..000000000 --- a/mlx/backend/metal/jit/softmax.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view softmax_kernels = R"( -template [[host_name("block_{0}")]] [[kernel]] void -softmax_single_row<{1}, {2}>( - const device {1}* in, - device {1}* out, - constant int& axis_size, - uint gid [[thread_position_in_grid]], - uint _lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]); -template [[host_name("looped_{0}")]] [[kernel]] void -softmax_looped<{1}, {2}>( - const device {1}* in, - device {1}* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]); -)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index a9cc267e1..204bb14e7 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1,8 +1,6 @@ // Copyright © 2024 Apple Inc. 
#include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/includes.h" -#include "mlx/backend/metal/jit/softmax.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" @@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel( const std::string& kernel_name, const array& out) { auto lib = d.get_library(kernel_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::arange() - << fmt::format( - arange_kernels, - kernel_name, - get_type_string(out.dtype())); - return kernel_source.str(); + std::string kernel_source = metal::utils(); + kernel_source += metal::arange(); + kernel_source += get_template_definition( + kernel_name, "arange", get_type_string(out.dtype())); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel( const array& out) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&] { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::softmax() - << fmt::format( - softmax_kernels, - lib_name, - get_type_string(out.dtype()), - get_type_string(precise ? float32 : out.dtype())); - return kernel_source.str(); + std::string kernel_source = metal::utils(); + auto in_type = get_type_string(out.dtype()); + auto acc_type = get_type_string(precise ? float32 : out.dtype()); + kernel_source += metal::softmax(); + kernel_source += get_template_definition( + "block_" + lib_name, "softmax_single_row", in_type, acc_type); + kernel_source += get_template_definition( + "looped_" + lib_name, "softmax_looped", in_type, acc_type); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_logsumexp_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name, [&] { + auto t_str = get_type_string(out.dtype()); + std::string kernel_source; + kernel_source = metal::utils(); + kernel_source += metal::logsumexp(); + kernel_source += + get_template_definition("block_" + lib_name, "logsumexp", t_str); + kernel_source += get_template_definition( + "looped_" + lib_name, "logsumexp_looped", t_str); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 63d17f959..1638a4496 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel( bool precise, const array& out); +MTL::ComputePipelineState* get_logsumexp_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + MTL::ComputePipelineState* get_scan_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 7498c137a..309a840f8 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT) build_kernel(quantized quantized.h ${STEEL_HEADERS}) build_kernel(scan scan.h) build_kernel(softmax softmax.h) + build_kernel(logsumexp logsumexp.h) build_kernel(sort sort.h) build_kernel(ternary ternary.h ternary_ops.h) build_kernel(unary unary.h unary_ops.h) diff --git a/mlx/backend/metal/kernels/arange.metal 
b/mlx/backend/metal/kernels/arange.metal index c2e325697..fb56c1c5f 100644 --- a/mlx/backend/metal/kernels/arange.metal +++ b/mlx/backend/metal/kernels/arange.metal @@ -5,11 +5,7 @@ #include "mlx/backend/metal/kernels/arange.h" #define instantiate_arange(tname, type) \ - template [[host_name("arange" #tname)]] [[kernel]] void arange( \ - constant const type& start, \ - constant const type& step, \ - device type* out, \ - uint index [[thread_position_in_grid]]); + instantiate_kernel("arange" #tname, arange, type) instantiate_arange(uint8, uint8_t) instantiate_arange(uint16, uint16_t) diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 4674a4228..51570e48d 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -493,71 +493,11 @@ template } // clang-format off -#define instantiate_layer_norm_single_row(name, itype) \ - template [[host_name("layer_norm" #name)]] [[kernel]] void \ - layer_norm_single_row( \ - const device itype* x, \ - const device itype* w, \ - const device itype* b, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - constant uint& b_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \ - vjp_layer_norm_single_row( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gw, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_layer_norm_looped(name, itype) \ - template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \ - layer_norm_looped( \ - const device itype* x, \ - const device itype* w, \ - const device itype* b, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - constant uint& b_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \ - vjp_layer_norm_looped( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gb, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_layer_norm(name, itype) \ - instantiate_layer_norm_single_row(name, itype) \ - instantiate_layer_norm_looped(name, itype) +#define instantiate_layer_norm(name, itype) \ + instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \ + instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \ + instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \ + 
instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype) instantiate_layer_norm(float32, float) instantiate_layer_norm(float16, half) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h new file mode 100644 index 000000000..374bbcd41 --- /dev/null +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -0,0 +1,142 @@ +// Copyright © 2025 Apple Inc. + +template +[[kernel]] void logsumexp( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(ld[i] - maxval); + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); + } + } +} + +template +[[kernel]] void logsumexp_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) + : Limits::finite_min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? 
vals[i] : maxval; + } + normalizer *= fast::exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(vals[i] - maxval); + } + } + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= fast::exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= fast::exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); + } + } +} diff --git a/mlx/backend/metal/kernels/logsumexp.metal b/mlx/backend/metal/kernels/logsumexp.metal new file mode 100644 index 000000000..eb76436cf --- /dev/null +++ b/mlx/backend/metal/kernels/logsumexp.metal @@ -0,0 +1,18 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/logsumexp.h" + +#define instantiate_logsumexp(name, itype) \ + instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \ + instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \ + +instantiate_logsumexp(float32, float) +instantiate_logsumexp(float16, half) +instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index f4c1536de..62f2457b7 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ b/mlx/backend/metal/kernels/rms_norm.metal @@ -380,69 +380,11 @@ template } // clang-format off -#define instantiate_rms_single_row(name, itype) \ - template [[host_name("rms" #name)]] [[kernel]] void \ - rms_single_row( \ - const device itype* x, \ - const device itype* w, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - \ - template [[host_name("vjp_rms" #name)]] [[kernel]] void \ - vjp_rms_single_row( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gw, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_rms_looped(name, itype) \ - template [[host_name("rms_looped" #name)]] [[kernel]] void \ - rms_looped( \ - const device itype* x, \ - const device itype* w, \ - device itype* out, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - \ - template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \ - 
vjp_rms_looped( \ - const device itype* x, \ - const device itype* w, \ - const device itype* g, \ - device itype* gx, \ - device itype* gw, \ - constant float& eps, \ - constant uint& axis_size, \ - constant uint& w_stride, \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_rms(name, itype) \ - instantiate_rms_single_row(name, itype) \ - instantiate_rms_looped(name, itype) +#define instantiate_rms(name, itype) \ + instantiate_kernel("rms" #name, rms_single_row, itype) \ + instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \ + instantiate_kernel("rms_looped" #name, rms_looped, itype) \ + instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype) instantiate_rms(float32, float) instantiate_rms(float16, half) diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index b36b73bd8..43e593d0e 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -40,7 +40,6 @@ template local_max[simd_lane_id] = Limits::min; local_normalizer[simd_lane_id] = 0; } - threadgroup_barrier(mem_flags::mem_threadgroup); // Get the max AccT maxval = Limits::finite_min; diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal index 1b64d59a1..79d5d3fca 100644 --- a/mlx/backend/metal/kernels/softmax.metal +++ b/mlx/backend/metal/kernels/softmax.metal @@ -9,47 +9,13 @@ using namespace metal; #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/softmax.h" -#define instantiate_softmax(name, itype) \ - template [[host_name("block_softmax_" #name)]] [[kernel]] void \ - softmax_single_row( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[thread_position_in_grid]], \ - uint _lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("looped_softmax_" #name)]] [[kernel]] void \ - softmax_looped( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[threadgroup_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_softmax(name, itype) \ + instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \ + instantiate_kernel("looped_softmax_" #name, softmax_looped, itype) -#define instantiate_softmax_precise(name, itype) \ - template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \ - softmax_single_row( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[thread_position_in_grid]], \ - uint _lid [[thread_position_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \ - softmax_looped( \ - const device itype* in, \ - device itype* out, \ - constant int& axis_size, \ - uint gid [[threadgroup_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - 
          uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+#define instantiate_softmax_precise(name, itype)                                        \
+  instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float)  \
+  instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
 
 instantiate_softmax(float32, float)
 instantiate_softmax(float16, half)
diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp
new file mode 100644
index 000000000..4901190e1
--- /dev/null
+++ b/mlx/backend/metal/logsumexp.cpp
@@ -0,0 +1,96 @@
+// Copyright © 2023-2024 Apple Inc.
+#include <cassert>
+
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
+
+void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  if (!issubdtype(out.dtype(), floating)) {
+    throw std::runtime_error(
+        "[logsumexp] Does not support non-floating point types.");
+  }
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  // Make sure that the last dimension is contiguous
+  auto ensure_contiguous = [&s, &d](const array& x) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
+      return x;
+    } else {
+      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
+      copy_gpu(x, x_copy, CopyType::General, s);
+      d.add_temporary(x_copy, s.index);
+      return x_copy;
+    }
+  };
+
+  auto in = ensure_contiguous(inputs[0]);
+  if (in.flags().row_contiguous) {
+    out.set_data(allocator::malloc(out.nbytes()));
+  } else {
+    auto n = in.shape(-1);
+    auto flags = in.flags();
+    auto strides = in.strides();
+    for (auto& s : strides) {
+      s /= n;
+    }
+    bool col_contig = strides[0] == 1;
+    for (int i = 1; col_contig && i < strides.size(); ++i) {
+      col_contig &=
+          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
+    }
+    flags.col_contiguous = col_contig;
+    out.set_data(
+        allocator::malloc(in.nbytes() / n),
+        in.data_size() / n,
+        std::move(strides),
+        flags);
+  }
+
+  int axis_size = in.shape().back();
+  int n_rows = in.data_size() / axis_size;
+
+  const int simd_size = 32;
+  const int n_reads = 4;
+  const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
+
+  std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
+  kernel_name += "logsumexp_";
+  kernel_name += type_to_name(out);
+
+  auto kernel = get_logsumexp_kernel(d, kernel_name, out);
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  {
+    MTL::Size grid_dims, group_dims;
+    if (axis_size <= looped_limit) {
+      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
+      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
+      size_t threadgroup_size = simd_size * simds_needed;
+      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+      size_t n_threads = n_rows * threadgroup_size;
+      grid_dims = MTL::Size(n_threads, 1, 1);
+      group_dims = MTL::Size(threadgroup_size, 1, 1);
+    } else {
+      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
+      size_t n_threads = n_rows * threadgroup_size;
+      grid_dims = MTL::Size(n_threads, 1, 1);
+      group_dims = MTL::Size(threadgroup_size, 1, 1);
+    }
+
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
+    compute_encoder.set_bytes(axis_size, 2);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index ff561374d..2d6077ed1 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
   return d.get_kernel(kernel_name);
 }
 
+MTL::ComputePipelineState* get_logsumexp_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array&) {
+  return d.get_kernel(kernel_name);
+}
+
 MTL::ComputePipelineState* get_scan_kernel(
     metal::Device& d,
     const std::string& kernel_name,
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index b089188b8..224721a50 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto set_output = [&s, &out](const array& x) {
-    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
-    if (no_copy && x.ndim() > 1) {
-      auto s = x.strides()[x.ndim() - 2];
-      no_copy &= (s == 0 || s == x.shape().back());
-    }
-    if (no_copy) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index 2f1ae566f..84372b096 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -82,6 +82,7 @@ NO_CPU(LogicalNot)
 NO_CPU(LogicalAnd)
 NO_CPU(LogicalOr)
 NO_CPU(LogAddExp)
+NO_CPU(LogSumExp)
 NO_CPU_MULTI(LUF)
 NO_CPU(Matmul)
 NO_CPU(Maximum)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 6e37a1d2b..6826c97f6 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -82,6 +82,7 @@ NO_GPU(LogicalNot)
 NO_GPU(LogicalAnd)
 NO_GPU(LogicalOr)
 NO_GPU(LogAddExp)
+NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
 NO_GPU(Matmul)
 NO_GPU(Maximum)
diff --git a/mlx/export.cpp b/mlx/export.cpp
index 4eb3ff99a..8051f786c 100644
--- a/mlx/export.cpp
+++ b/mlx/export.cpp
@@ -278,6 +278,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(LogicalAnd),
       SERIALIZE_PRIMITIVE(LogicalOr),
       SERIALIZE_PRIMITIVE(LogAddExp),
+      SERIALIZE_PRIMITIVE(LogSumExp),
       SERIALIZE_PRIMITIVE(Matmul),
       SERIALIZE_PRIMITIVE(Maximum),
       SERIALIZE_PRIMITIVE(Minimum),
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index fe1852e0a..2f3997e7b 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2359,6 +2359,29 @@ array logsumexp(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
+  if (a.size() == 0) {
+    throw std::invalid_argument("[logsumexp] Received empty array.");
+  }
+  if (a.ndim() == 0 && !axes.empty()) {
+    throw std::invalid_argument(
+        "[logsumexp] Received non-empty axes for array with 0 dimensions.");
+  }
+  bool is_complex = issubdtype(a.dtype(), complexfloating);
+  if (!is_complex && axes.size() == 1 &&
+      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+    auto dtype = at_least_float(a.dtype());
+    auto out_shape = a.shape();
+    out_shape.back() = 1;
+    auto out = array(
+        std::move(out_shape),
+        dtype,
+        std::make_shared<LogSumExp>(to_stream(s)),
+        {astype(a, dtype, s)});
+    if (!keepdims) {
+      out = squeeze(out, -1, s);
+    }
+    return out;
+  }
   auto maxval = stop_gradient(max(a, axes, true, s), s);
   auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
   out = add(out, reshape(maxval, out.shape(), s), s);
@@ -3347,8 +3370,14 @@ array softmax(
   if (a.size() == 0) {
     return a;
   }
+  if (a.ndim() == 0 && !axes.empty()) {
+    throw std::invalid_argument(
+        "[softmax] Received non-empty axes for array with 0 dimensions.");
+  }
 
-  if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  bool is_complex = issubdtype(a.dtype(), complexfloating);
+  if (!is_complex && axes.size() == 1 &&
+      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
@@ -3357,7 +3386,7 @@ array softmax(
         {astype(a, dtype, s)});
   } else {
     auto in = a;
-    if (precise) {
+    if (precise && !is_complex) {
       in = astype(a, float32, s);
     }
     auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index b5e5ec82e..6f9e45313 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2509,6 +2509,49 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
   return {{logaddexp(a, b, stream())}, {to_ax}};
 }
 
+std::pair<std::vector<array>, std::vector<int>> LogSumExp::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0];
+  auto in = inputs[0];
+  if (ax == (in.ndim() - 1)) {
+    in = swapaxes(in, -1, -2, stream());
+    ax = in.ndim() - 2;
+  }
+  return {{logsumexp(in, -1, true, stream())}, {ax}};
+}
+
+std::vector<array> LogSumExp::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  assert(primals.size() == 1);
+  assert(cotangents.size() == 1);
+  return {multiply(
+      cotangents[0],
+      softmax(primals[0], std::vector<int>{-1}, true, stream()),
+      stream())};
+}
+
+std::vector<array> LogSumExp::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(tangents.size() == 1);
+  return {multiply(
+      tangents[0],
+      softmax(primals[0], std::vector<int>{-1}, true, stream()),
+      stream())};
+}
+
+std::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {
+  auto s = inputs[0].shape();
+  s.back() = 1;
+  return {s};
+}
+
 std::vector<array> Matmul::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index bb0ca8080..c7b2de878 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1350,6 +1350,20 @@ class LogAddExp : public UnaryPrimitive {
   DEFINE_INPUT_OUTPUT_SHAPE()
 };
 
+class LogSumExp : public UnaryPrimitive {
+ public:
+  explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(LogSumExp)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+};
+
 class Matmul : public UnaryPrimitive {
  public:
   explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index d7c79d9db..a71d2c253 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase):
             self.assertTrue(np.array_equal(b_npy, b_mlx))
 
     def test_logsumexp(self):
+        def logsumexp(x, axes=None):
+            maxs = mx.max(x, axis=axes, keepdims=True)
+            return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs
+
         x = mx.array(
             [
                 [1.0, 2.0],
                 [3.0, 4.0],
             ]
         )
-        xnp = np.array(x.tolist(), dtype=np.float32)
-        expected = np.log(np.sum(np.exp(xnp)))
-        self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item()))
+        self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item()))
+
+        x = mx.random.uniform(shape=(1025,))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
+        # Transposed
+        x = mx.random.uniform(shape=(2, 2, 8))
+        x = x.swapaxes(0, 1)
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
+        # Broadcast
+        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
+        # Large
+        x = mx.random.uniform(shape=(1025,))
+        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
 
     def test_mean(self):
         x = mx.array(
@@ -1643,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase):
             x = mx.full((n,), vals=-float("inf"))
             self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
 
+        # Transposed inputs
+        a = mx.random.uniform(shape=(32, 32, 32))
+        b = mx.softmax(a, axis=-1)
+        c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1)
+        self.assertEqual((b - c).abs().max().item(), 0.0)
+
+        with self.assertRaises(ValueError):
+            mx.softmax(mx.array(1.0), axis=-1)
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)
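Reviewer note on the looped kernel: the block kernel and the CPU path use the usual two passes (find the row maximum, then sum the shifted exponentials), while `logsumexp_looped` makes a single pass by keeping a running maximum and rescaling the partial normalizer by exp(prevmax - maxval) whenever the maximum grows. The sketch below mirrors that update rule in NumPy so it can be checked against the two-pass result; it is illustrative only, and the function name and `chunk_size` stand in for the kernel's N_READS/threadgroup tiling rather than any MLX API.

import numpy as np

def streaming_logsumexp(row, chunk_size=8):
    # Running maximum and normalizer, updated one chunk at a time.
    maxval = np.finfo(row.dtype).min  # finite_min, as in the kernel
    normalizer = 0.0
    for start in range(0, row.size, chunk_size):
        chunk = row[start:start + chunk_size]
        prevmax = maxval
        maxval = max(maxval, chunk.max())
        # Rescale the sum accumulated under the previous maximum.
        normalizer = normalizer * np.exp(prevmax - maxval)
        normalizer += np.exp(chunk - maxval).sum()
    return np.log(normalizer) + maxval

row = np.random.uniform(size=1025).astype(np.float32)
two_pass = np.log(np.exp(row - row.max()).sum()) + row.max()
assert np.allclose(streaming_logsumexp(row), two_pass)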