Custom logsumexp (#2028)

* initial custom logsumexp

* more tests

* comments + fix
Awni Hannun 2025-03-31 07:36:55 -07:00 committed by GitHub
parent ec2854b13a
commit de5f38fd48
27 changed files with 590 additions and 255 deletions

View File

@@ -58,6 +58,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"

namespace mlx::core {

namespace {

using namespace mlx::core::simd;

template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();

  int M = in.shape().back();
  int L = in.data_size() / M;

  encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
    constexpr int N = std::min(max_size<AccT>, max_size<T>);

    const T* current_in_ptr;

    for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
      // Find the maximum
      current_in_ptr = in_ptr;
      Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
      size_t s = M;
      while (s >= N) {
        Simd<AccT, N> vals = load<T, N>(current_in_ptr);
        vmaximum = maximum(vals, vmaximum);
        current_in_ptr += N;
        s -= N;
      }

      AccT maximum = max(vmaximum);
      while (s-- > 0) {
        maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
        current_in_ptr++;
      }

      // Compute the normalizer and the exponentials
      Simd<AccT, N> vnormalizer(0.0);
      current_in_ptr = in_ptr;
      s = M;
      while (s >= N) {
        Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
        vexp = exp(vexp - maximum);
        vnormalizer = vnormalizer + vexp;
        current_in_ptr += N;
        s -= N;
      }
      AccT normalizer = sum(vnormalizer);
      while (s-- > 0) {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        normalizer += _exp;
        current_in_ptr++;
      }

      // Normalize
      *out_ptr = std::isinf(maximum)
          ? static_cast<T>(maximum)
          : static_cast<T>(std::log(normalizer) + maximum);
    }
  });
}

} // namespace

void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto s = stream();
  auto& encoder = cpu::get_command_encoder(s);
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto in = ensure_contiguous(inputs[0]);
  if (in.flags().row_contiguous) {
    out.set_data(allocator::malloc(out.nbytes()));
  } else {
    auto n = in.shape(-1);
    auto flags = in.flags();
    auto strides = in.strides();
    for (auto& s : strides) {
      s /= n;
    }
    bool col_contig = strides[0] == 1;
    for (int i = 1; col_contig && i < strides.size(); ++i) {
      col_contig &=
          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
    }
    flags.col_contiguous = col_contig;
    out.set_data(
        allocator::malloc(in.nbytes() / n),
        in.data_size() / n,
        std::move(strides),
        flags);
  }

  switch (in.dtype()) {
    case float32:
      logsumexp<float, float>(in, out, stream());
      break;
    case float16:
      logsumexp<float16_t, float>(in, out, stream());
      break;
    case bfloat16:
      logsumexp<bfloat16_t, float>(in, out, stream());
      break;
    case float64:
      logsumexp<double, double>(in, out, stream());
      break;
    default:
      throw std::runtime_error(
          "[logsumexp] only supports floating point types");
      break;
  }
}

} // namespace mlx::core
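The CPU kernel above is the standard two-pass, max-shifted reduction: one pass finds the row maximum, a second accumulates exp(x - max), and the result is log(sum) + max, with an infinite row maximum passed through directly so that inf - inf never turns into a NaN. A minimal NumPy sketch of the same algorithm (the helper name is ours, for illustration only):

import numpy as np

def logsumexp_rows(x):
    # Pass 1: row maximum; pass 2: sum of exp(x - max).
    m = x.max(axis=-1, keepdims=True)
    with np.errstate(invalid="ignore"):
        s = np.exp(x - m).sum(axis=-1, keepdims=True)
        # Pass an infinite max straight through, as the kernel does.
        return np.where(np.isinf(m), m, np.log(s) + m)

x = np.random.randn(4, 1024).astype(np.float32)
assert np.allclose(
    logsumexp_rows(x), np.log(np.exp(x).sum(axis=-1, keepdims=True)))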

View File

@@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   auto set_output = [s = stream(), &out](const array& x) {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
-      auto s = x.strides()[x.ndim() - 2];
-      no_copy &= (s == 0 || s == x.shape().back());
-    }
-    if (no_copy) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto in = set_output(inputs[0]);
   switch (in.dtype()) {
-    case bool_:
-    case uint8:
-    case uint16:
-    case uint32:
-    case uint64:
-    case int8:
-    case int16:
-    case int32:
-    case int64:
-      throw std::runtime_error(
-          "Softmax is defined only for floating point types");
-      break;
     case float32:
       softmax<float, float>(in, out, stream());
       break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float64:
       softmax<double, double>(in, out, stream());
       break;
-    case complex64:
-      throw std::invalid_argument(
-          "[Softmax] Not yet implemented for complex64");
+    default:
+      throw std::runtime_error(
+          "[softmax] Only defined for floating point types.");
       break;
   }
 }

View File

@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
   make_jit_source(binary)
   make_jit_source(binary_two)
   make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
+  make_jit_source(logsumexp)
   make_jit_source(ternary)
   make_jit_source(softmax)
   make_jit_source(scan)
@@ -95,6 +96,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
    constant const {1}& start,
    constant const {1}& step,
    device {1}* out,
    uint index [[thread_position_in_grid]]);
)";

View File

@@ -20,6 +20,7 @@ const char* copy();
 const char* fft();
 const char* gather_axis();
 const char* hadamard();
+const char* logsumexp();
 const char* quantized();
 const char* ternary();
 const char* scan();

View File

@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
    const device {1}* in,
    device {1}* out,
    constant int& axis_size,
    uint gid [[thread_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
    const device {1}* in,
    device {1}* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -1,8 +1,6 @@
 // Copyright © 2024 Apple Inc.
 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/includes.h"
-#include "mlx/backend/metal/jit/softmax.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
     const std::string& kernel_name,
     const array& out) {
   auto lib = d.get_library(kernel_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::arange()
-                  << fmt::format(
-                         arange_kernels,
-                         kernel_name,
-                         get_type_string(out.dtype()));
-    return kernel_source.str();
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::arange();
+    kernel_source += get_template_definition(
+        kernel_name, "arange", get_type_string(out.dtype()));
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&] {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::softmax()
-                  << fmt::format(
-                         softmax_kernels,
-                         lib_name,
-                         get_type_string(out.dtype()),
-                         get_type_string(precise ? float32 : out.dtype()));
-    return kernel_source.str();
+    std::string kernel_source = metal::utils();
+    auto in_type = get_type_string(out.dtype());
+    auto acc_type = get_type_string(precise ? float32 : out.dtype());
+    kernel_source += metal::softmax();
+    kernel_source += get_template_definition(
+        "block_" + lib_name, "softmax_single_row", in_type, acc_type);
+    kernel_source += get_template_definition(
+        "looped_" + lib_name, "softmax_looped", in_type, acc_type);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
+
+MTL::ComputePipelineState* get_logsumexp_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name, [&] {
+    auto t_str = get_type_string(out.dtype());
+    std::string kernel_source;
+    kernel_source = metal::utils();
+    kernel_source += metal::logsumexp();
+    kernel_source +=
+        get_template_definition("block_" + lib_name, "logsumexp", t_str);
+    kernel_source += get_template_definition(
+        "looped_" + lib_name, "logsumexp_looped", t_str);
+    return kernel_source;
+  });
+  return d.get_kernel(kernel_name, lib);
+}

View File

@@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
     bool precise,
     const array& out);
+
+MTL::ComputePipelineState* get_logsumexp_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out);
 MTL::ComputePipelineState* get_scan_kernel(
     metal::Device& d,
     const std::string& kernel_name,

View File

@@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
   build_kernel(quantized quantized.h ${STEEL_HEADERS})
   build_kernel(scan scan.h)
   build_kernel(softmax softmax.h)
+  build_kernel(logsumexp logsumexp.h)
   build_kernel(sort sort.h)
   build_kernel(ternary ternary.h ternary_ops.h)
   build_kernel(unary unary.h unary_ops.h)

View File

@@ -5,11 +5,7 @@
 #include "mlx/backend/metal/kernels/arange.h"
 #define instantiate_arange(tname, type) \
-  template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
-      constant const type& start, \
-      constant const type& step, \
-      device type* out, \
-      uint index [[thread_position_in_grid]]);
+  instantiate_kernel("arange" #tname, arange, type)
 instantiate_arange(uint8, uint8_t)
 instantiate_arange(uint16, uint16_t)

View File

@@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
 }
 // clang-format off
-#define instantiate_layer_norm_single_row(name, itype) \
-  template [[host_name("layer_norm" #name)]] [[kernel]] void \
-  layer_norm_single_row<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      const device itype* b, \
-      device itype* out, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      constant uint& b_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
-  template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
-  vjp_layer_norm_single_row<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      const device itype* g, \
-      device itype* gx, \
-      device itype* gw, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-#define instantiate_layer_norm_looped(name, itype) \
-  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
-  layer_norm_looped<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      const device itype* b, \
-      device itype* out, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      constant uint& b_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
-  template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
-  vjp_layer_norm_looped<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      const device itype* g, \
-      device itype* gx, \
-      device itype* gb, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 #define instantiate_layer_norm(name, itype) \
-  instantiate_layer_norm_single_row(name, itype) \
-  instantiate_layer_norm_looped(name, itype)
+  instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
+  instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
+  instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
+  instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
 instantiate_layer_norm(float32, float)
 instantiate_layer_norm(float16, half)

View File

@@ -0,0 +1,142 @@
// Copyright © 2025 Apple Inc.

template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      ld[i] =
          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
    }
  }
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<AccT>::min;
    local_normalizer[simd_lane_id] = 0;
  }

  // Get the max
  AccT maxval = Limits<AccT>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    normalizer += fast::exp(ld[i] - maxval);
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
    }
  }
}

template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * size_t(axis_size);

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    AccT vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = AccT(in[offset + i]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
                                           : Limits<AccT>::finite_min;
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    normalizer *= fast::exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += fast::exp(vals[i] - maxval);
    }
  }
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= fast::exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= fast::exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
    }
  }
}
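The looped kernel above computes the max and the normalizer in a single pass over the row: whenever a new chunk raises the running maximum, the accumulated sum of exponentials is rescaled by exp(prevmax - maxval) before the chunk's own terms are added, and the same rescaling is reapplied when merging across SIMD lanes and simdgroups. A small single-threaded Python sketch of that online update (our naming, for illustration only):

import numpy as np

def streaming_logsumexp(x, chunk=4):
    # Running max and running normalizer over fixed-size chunks.
    maxval = -np.inf
    normalizer = 0.0
    for i in range(0, len(x), chunk):
        vals = x[i : i + chunk]
        prevmax = maxval
        maxval = max(maxval, float(vals.max()))
        # Re-center the sum accumulated so far on the new maximum.
        normalizer *= np.exp(prevmax - maxval)
        normalizer += float(np.exp(vals - maxval).sum())
    return maxval if np.isinf(maxval) else float(np.log(normalizer)) + maxval

x = np.random.randn(1025)
assert np.isclose(streaming_logsumexp(x), np.log(np.exp(x).sum()))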

View File

@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

using namespace metal;

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"

#define instantiate_logsumexp(name, itype)                               \
  instantiate_kernel("block_logsumexp_" #name, logsumexp, itype)         \
  instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype)

instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on

View File

@@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
 }
 // clang-format off
-#define instantiate_rms_single_row(name, itype) \
-  template [[host_name("rms" #name)]] [[kernel]] void \
-  rms_single_row<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      device itype* out, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
-  \
-  template [[host_name("vjp_rms" #name)]] [[kernel]] void \
-  vjp_rms_single_row<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      const device itype* g, \
-      device itype* gx, \
-      device itype* gw, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-#define instantiate_rms_looped(name, itype) \
-  template [[host_name("rms_looped" #name)]] [[kernel]] void \
-  rms_looped<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      device itype* out, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
-  \
-  template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
-  vjp_rms_looped<itype>( \
-      const device itype* x, \
-      const device itype* w, \
-      const device itype* g, \
-      device itype* gx, \
-      device itype* gw, \
-      constant float& eps, \
-      constant uint& axis_size, \
-      constant uint& w_stride, \
-      uint gid [[thread_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 #define instantiate_rms(name, itype) \
-  instantiate_rms_single_row(name, itype) \
-  instantiate_rms_looped(name, itype)
+  instantiate_kernel("rms" #name, rms_single_row, itype) \
+  instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
+  instantiate_kernel("rms_looped" #name, rms_looped, itype) \
+  instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
 instantiate_rms(float32, float)
 instantiate_rms(float16, half)

View File

@@ -40,7 +40,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
     local_max[simd_lane_id] = Limits<AccT>::min;
     local_normalizer[simd_lane_id] = 0;
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
   // Get the max
   AccT maxval = Limits<AccT>::finite_min;

View File

@@ -10,46 +10,12 @@ using namespace metal;
 #include "mlx/backend/metal/kernels/softmax.h"
 #define instantiate_softmax(name, itype) \
-  template [[host_name("block_softmax_" #name)]] [[kernel]] void \
-  softmax_single_row<itype>( \
-      const device itype* in, \
-      device itype* out, \
-      constant int& axis_size, \
-      uint gid [[thread_position_in_grid]], \
-      uint _lid [[thread_position_in_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
-  template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
-  softmax_looped<itype>( \
-      const device itype* in, \
-      device itype* out, \
-      constant int& axis_size, \
-      uint gid [[threadgroup_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+  instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
+  instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
 #define instantiate_softmax_precise(name, itype) \
-  template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
-  softmax_single_row<itype, float>( \
-      const device itype* in, \
-      device itype* out, \
-      constant int& axis_size, \
-      uint gid [[thread_position_in_grid]], \
-      uint _lid [[thread_position_in_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
-  template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
-  softmax_looped<itype, float>( \
-      const device itype* in, \
-      device itype* out, \
-      constant int& axis_size, \
-      uint gid [[threadgroup_position_in_grid]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+  instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
+  instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
 instantiate_softmax(float32, float)
 instantiate_softmax(float16, half)

View File

@@ -0,0 +1,96 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;

void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[logsumexp] Does not support non-floating point types.");
  }
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Make sure that the last dimension is contiguous
  auto ensure_contiguous = [&s, &d](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      d.add_temporary(x_copy, s.index);
      return x_copy;
    }
  };

  auto in = ensure_contiguous(inputs[0]);
  if (in.flags().row_contiguous) {
    out.set_data(allocator::malloc(out.nbytes()));
  } else {
    auto n = in.shape(-1);
    auto flags = in.flags();
    auto strides = in.strides();
    for (auto& s : strides) {
      s /= n;
    }
    bool col_contig = strides[0] == 1;
    for (int i = 1; col_contig && i < strides.size(); ++i) {
      col_contig &=
          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
    }
    flags.col_contiguous = col_contig;
    out.set_data(
        allocator::malloc(in.nbytes() / n),
        in.data_size() / n,
        std::move(strides),
        flags);
  }

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = 4;
  const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;

  std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
  kernel_name += "logsumexp_";
  kernel_name += type_to_name(out);
  auto kernel = get_logsumexp_kernel(d, kernel_name, out);
  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }
    compute_encoder.set_compute_pipeline_state(kernel);
    compute_encoder.set_input_array(in, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder.set_bytes(axis_size, 2);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }
}

} // namespace mlx::core
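Rows up to LOGSUMEXP_LOOPED_LIMIT (4096) elements use the block kernel, with a single threadgroup sized so that each thread covers n_reads = 4 elements, rounded up to whole simdgroups of 32; longer rows fall back to the looped kernel with the largest threadgroup the pipeline allows. A sketch of just that sizing arithmetic, mirroring the constants above (the helper and its max_threads default are hypothetical):

def pick_launch(axis_size, n_rows, max_threads=1024):
    # Mirrors the dispatch logic above: 4 reads per thread, simd width 32.
    if axis_size <= 4096:
        threads_needed = (axis_size + 4 - 1) // 4
        simds_needed = (threads_needed + 32 - 1) // 32
        tg = 32 * simds_needed
        return ("block", tg, n_rows * tg)
    tg = max_threads
    return ("looped", tg, n_rows * tg)

# ceil(513 / 4) = 129 threads -> 5 simdgroups -> threadgroup of 160.
print(pick_launch(513, 8))  # ('block', 160, 1280)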

View File

@@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
   return d.get_kernel(kernel_name);
 }
+MTL::ComputePipelineState* get_logsumexp_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array&) {
+  return d.get_kernel(kernel_name);
+}
+
 MTL::ComputePipelineState* get_scan_kernel(
     metal::Device& d,
     const std::string& kernel_name,

View File

@@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   auto set_output = [&s, &out](const array& x) {
-    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
-    if (no_copy && x.ndim() > 1) {
-      auto s = x.strides()[x.ndim() - 2];
-      no_copy &= (s == 0 || s == x.shape().back());
-    }
-    if (no_copy) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {

View File

@@ -82,6 +82,7 @@ NO_CPU(LogicalNot)
 NO_CPU(LogicalAnd)
 NO_CPU(LogicalOr)
 NO_CPU(LogAddExp)
+NO_CPU(LogSumExp)
 NO_CPU_MULTI(LUF)
 NO_CPU(Matmul)
 NO_CPU(Maximum)

View File

@@ -82,6 +82,7 @@ NO_GPU(LogicalNot)
 NO_GPU(LogicalAnd)
 NO_GPU(LogicalOr)
 NO_GPU(LogAddExp)
+NO_GPU(LogSumExp)
 NO_GPU_MULTI(LUF)
 NO_GPU(Matmul)
 NO_GPU(Maximum)

View File

@@ -278,6 +278,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(LogicalAnd),
       SERIALIZE_PRIMITIVE(LogicalOr),
       SERIALIZE_PRIMITIVE(LogAddExp),
+      SERIALIZE_PRIMITIVE(LogSumExp),
       SERIALIZE_PRIMITIVE(Matmul),
       SERIALIZE_PRIMITIVE(Maximum),
       SERIALIZE_PRIMITIVE(Minimum),

View File

@@ -2359,6 +2359,29 @@ array logsumexp(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
+  if (a.size() == 0) {
+    throw std::invalid_argument("[logsumexp] Received empty array.");
+  }
+  if (a.ndim() == 0 && !axes.empty()) {
+    throw std::invalid_argument(
+        "[logsumexp] Received non-empty axes for array with 0 dimensions.");
+  }
+  bool is_complex = issubdtype(a.dtype(), complexfloating);
+  if (!is_complex && axes.size() == 1 &&
+      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+    auto dtype = at_least_float(a.dtype());
+    auto out_shape = a.shape();
+    out_shape.back() = 1;
+    auto out = array(
+        std::move(out_shape),
+        dtype,
+        std::make_shared<LogSumExp>(to_stream(s)),
+        {astype(a, dtype, s)});
+    if (!keepdims) {
+      out = squeeze(out, -1, s);
+    }
+    return out;
+  }
   auto maxval = stop_gradient(max(a, axes, true, s), s);
   auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
   out = add(out, reshape(maxval, out.shape(), s), s);
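Note the fast path above only fires for a single reduction axis that is the last axis of a real (non-complex) input; everything else still lowers to the composite max-shifted graph below. In Python, both routes are reachable through the same op (a usage sketch, assuming an installed mlx):

import mlx.core as mx

x = mx.random.uniform(shape=(4, 1024))
y = mx.logsumexp(x, axis=-1)  # last axis: dispatches the new LogSumExp primitive
z = mx.logsumexp(x, axis=0)   # other axes: composite log(sum(exp(x - max))) + max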
@@ -3347,8 +3370,14 @@ array softmax(
   if (a.size() == 0) {
     return a;
   }
+  if (a.ndim() == 0 && !axes.empty()) {
+    throw std::invalid_argument(
+        "[softmax] Received non-empty axes for array with 0 dimensions.");
+  }
-  if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  bool is_complex = issubdtype(a.dtype(), complexfloating);
+  if (!is_complex && axes.size() == 1 &&
+      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
@@ -3357,7 +3386,7 @@ array softmax(
         {astype(a, dtype, s)});
   } else {
     auto in = a;
-    if (precise) {
+    if (precise && !is_complex) {
       in = astype(a, float32, s);
     }
     auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);

View File

@@ -2509,6 +2509,49 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
   return {{logaddexp(a, b, stream())}, {to_ax}};
 }
+std::pair<std::vector<array>, std::vector<int>> LogSumExp::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0];
+  auto in = inputs[0];
+  if (ax == (in.ndim() - 1)) {
+    in = swapaxes(in, -1, -2, stream());
+    ax = in.ndim() - 2;
+  }
+  return {{logsumexp(in, -1, true, stream())}, {ax}};
+}
+
+std::vector<array> LogSumExp::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  assert(primals.size() == 1);
+  assert(cotangents.size() == 1);
+  return {multiply(
+      cotangents[0],
+      softmax(primals[0], std::vector<int>{-1}, true, stream()),
+      stream())};
+}
+
+std::vector<array> LogSumExp::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 1);
+  assert(tangents.size() == 1);
+  return {multiply(
+      tangents[0],
+      softmax(primals[0], std::vector<int>{-1}, true, stream()),
+      stream())};
+}
+
+std::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {
+  auto s = inputs[0].shape();
+  s.back() = 1;
+  return {s};
+}
+
 std::vector<array> Matmul::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
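Both vjp and jvp above rely on the identity d logsumexp(x) / dx_i = exp(x_i - logsumexp(x)) = softmax(x)_i, so the cotangent (or tangent) is simply multiplied by a softmax along the reduced axis. A quick numerical check of that identity (assumes an installed mlx):

import mlx.core as mx

x = mx.random.uniform(shape=(8,))
g = mx.grad(lambda a: mx.logsumexp(a, axis=-1))(x)
# The gradient of logsumexp is exactly softmax.
assert mx.allclose(g, mx.softmax(x, axis=-1))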

View File

@@ -1350,6 +1350,20 @@ class LogAddExp : public UnaryPrimitive {
   DEFINE_INPUT_OUTPUT_SHAPE()
 };
+
+class LogSumExp : public UnaryPrimitive {
+ public:
+  explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(LogSumExp)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+};
 class Matmul : public UnaryPrimitive {
  public:
   explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}

View File

@@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertTrue(np.array_equal(b_npy, b_mlx))

     def test_logsumexp(self):
+        def logsumexp(x, axes=None):
+            maxs = mx.max(x, axis=axes, keepdims=True)
+            return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs
+
         x = mx.array(
             [
                 [1.0, 2.0],
                 [3.0, 4.0],
             ]
         )
-        xnp = np.array(x.tolist(), dtype=np.float32)
-        expected = np.log(np.sum(np.exp(xnp)))
-        self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item()))
+        self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item()))
+
+        x = mx.random.uniform(shape=(1025,))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
+        # Transposed
+        x = mx.random.uniform(shape=(2, 2, 8))
+        x = x.swapaxes(0, 1)
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
+        # Broadcast
+        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
+        # Large
+        x = mx.random.uniform(shape=(1025,))
+        x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

     def test_mean(self):
         x = mx.array(
@@ -1643,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase):
         x = mx.full((n,), vals=-float("inf"))
         self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))

+        # Transposed inputs
+        a = mx.random.uniform(shape=(32, 32, 32))
+        b = mx.softmax(a, axis=-1)
+        c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1)
+        self.assertEqual((b - c).abs().max().item(), 0.0)
+
+        with self.assertRaises(ValueError):
+            mx.softmax(mx.array(1.0), axis=-1)
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)