mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Custom logsumexp (#2028)
* initial custom logsumexp * more tests * comments + fix
This commit is contained in:
parent
ec2854b13a
commit
de5f38fd48
@ -58,6 +58,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
|
140
mlx/backend/cpu/logsumexp.cpp
Normal file
140
mlx/backend/cpu/logsumexp.cpp
Normal file
@ -0,0 +1,140 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/types/limits.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void logsumexp(const array& in, array& out, Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
|
||||
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
const T* current_in_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
}
|
||||
// Normalize
|
||||
*out_ptr = std::isinf(maximum)
|
||||
? static_cast<T>(maximum)
|
||||
: static_cast<T>(std::log(normalizer) + maximum);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
logsumexp<float, float>(in, out, stream());
|
||||
break;
|
||||
case float16:
|
||||
logsumexp<float16_t, float>(in, out, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
logsumexp<bfloat16_t, float>(in, out, stream());
|
||||
break;
|
||||
case float64:
|
||||
logsumexp<double, double>(in, out, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[logsumexp] only supports floating point types");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -119,12 +119,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto set_output = [s = stream(), &out](const array& x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = set_output(inputs[0]);
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, float>(in, out, stream());
|
||||
break;
|
||||
@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case float64:
|
||||
softmax<double, double>(in, out, stream());
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
"[Softmax] Not yet implemented for complex64");
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[softmax] Only defined for floating point types.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
|
||||
make_jit_source(logsumexp)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
@ -95,6 +96,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
|
@ -1,9 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view arange_kernels = R"(
|
||||
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
|
||||
constant const {1}& start,
|
||||
constant const {1}& step,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
@ -20,6 +20,7 @@ const char* copy();
|
||||
const char* fft();
|
||||
const char* gather_axis();
|
||||
const char* hadamard();
|
||||
const char* logsumexp();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
|
@ -1,23 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view softmax_kernels = R"(
|
||||
template [[host_name("block_{0}")]] [[kernel]] void
|
||||
softmax_single_row<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("looped_{0}")]] [[kernel]] void
|
||||
softmax_looped<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
@ -1,8 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
@ -21,13 +19,11 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::arange()
|
||||
<< fmt::format(
|
||||
arange_kernels,
|
||||
kernel_name,
|
||||
get_type_string(out.dtype()));
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::arange();
|
||||
kernel_source += get_template_definition(
|
||||
kernel_name, "arange", get_type_string(out.dtype()));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@ -259,14 +255,34 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::softmax()
|
||||
<< fmt::format(
|
||||
softmax_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
get_type_string(precise ? float32 : out.dtype()));
|
||||
return kernel_source.str();
|
||||
std::string kernel_source = metal::utils();
|
||||
auto in_type = get_type_string(out.dtype());
|
||||
auto acc_type = get_type_string(precise ? float32 : out.dtype());
|
||||
kernel_source += metal::softmax();
|
||||
kernel_source += get_template_definition(
|
||||
"block_" + lib_name, "softmax_single_row", in_type, acc_type);
|
||||
kernel_source += get_template_definition(
|
||||
"looped_" + lib_name, "softmax_looped", in_type, acc_type);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
auto t_str = get_type_string(out.dtype());
|
||||
std::string kernel_source;
|
||||
kernel_source = metal::utils();
|
||||
kernel_source += metal::logsumexp();
|
||||
kernel_source +=
|
||||
get_template_definition("block_" + lib_name, "logsumexp", t_str);
|
||||
kernel_source += get_template_definition(
|
||||
"looped_" + lib_name, "logsumexp_looped", t_str);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
@ -59,6 +59,11 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
bool precise,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(quantized quantized.h ${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
|
@ -5,11 +5,7 @@
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||
constant const type& start, \
|
||||
constant const type& step, \
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
instantiate_kernel("arange" #tname, arange, type)
|
||||
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
instantiate_arange(uint16, uint16_t)
|
||||
|
@ -493,71 +493,11 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_layer_norm_single_row(name, itype) \
|
||||
template [[host_name("layer_norm" #name)]] [[kernel]] void \
|
||||
layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm_looped(name, itype) \
|
||||
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
|
||||
layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* b, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
constant uint& b_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
|
||||
vjp_layer_norm_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gb, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
instantiate_layer_norm_single_row(name, itype) \
|
||||
instantiate_layer_norm_looped(name, itype)
|
||||
#define instantiate_layer_norm(name, itype) \
|
||||
instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \
|
||||
instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \
|
||||
instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \
|
||||
instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype)
|
||||
|
||||
instantiate_layer_norm(float32, float)
|
||||
instantiate_layer_norm(float16, half)
|
||||
|
142
mlx/backend/metal/kernels/logsumexp.h
Normal file
142
mlx/backend/metal/kernels/logsumexp.h
Normal file
@ -0,0 +1,142 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
template <typename T, typename AccT = float, int N_READS = 4>
|
||||
[[kernel]] void logsumexp(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * size_t(axis_size) + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<AccT>::min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
|
||||
// Get the max
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
maxval = simd_max(maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[0] = maxval;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
AccT normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += fast::exp(ld[i] - maxval);
|
||||
}
|
||||
normalizer = simd_sum(normalizer);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AccT = float, int N_READS = 4>
|
||||
[[kernel]] void logsumexp_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * size_t(axis_size);
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
AccT vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = AccT(in[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
|
||||
: Limits<AccT>::finite_min;
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < vals[i]) ? vals[i] : maxval;
|
||||
}
|
||||
normalizer *= fast::exp(prevmax - maxval);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
normalizer += fast::exp(vals[i] - maxval);
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
maxval = simd_max(maxval);
|
||||
normalizer *= fast::exp(prevmax - maxval);
|
||||
normalizer = simd_sum(normalizer);
|
||||
|
||||
prevmax = maxval;
|
||||
if (simd_lane_id == 0) {
|
||||
local_max[simd_group_id] = maxval;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
maxval = simd_max(local_max[simd_lane_id]);
|
||||
normalizer *= fast::exp(prevmax - maxval);
|
||||
if (simd_lane_id == 0) {
|
||||
local_normalizer[simd_group_id] = normalizer;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
|
||||
if (simd_group_id == 0) {
|
||||
normalizer = simd_sum(local_normalizer[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
|
||||
}
|
||||
}
|
||||
}
|
18
mlx/backend/metal/kernels/logsumexp.metal
Normal file
18
mlx/backend/metal/kernels/logsumexp.metal
Normal file
@ -0,0 +1,18 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/logsumexp.h"
|
||||
|
||||
#define instantiate_logsumexp(name, itype) \
|
||||
instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
|
||||
instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
|
||||
|
||||
instantiate_logsumexp(float32, float)
|
||||
instantiate_logsumexp(float16, half)
|
||||
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on
|
@ -380,69 +380,11 @@ template <typename T, int N_READS = RMS_N_READS>
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rms_single_row(name, itype) \
|
||||
template [[host_name("rms" #name)]] [[kernel]] void \
|
||||
rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms" #name)]] [[kernel]] void \
|
||||
vjp_rms_single_row<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms_looped(name, itype) \
|
||||
template [[host_name("rms_looped" #name)]] [[kernel]] void \
|
||||
rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
device itype* out, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
\
|
||||
template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
|
||||
vjp_rms_looped<itype>( \
|
||||
const device itype* x, \
|
||||
const device itype* w, \
|
||||
const device itype* g, \
|
||||
device itype* gx, \
|
||||
device itype* gw, \
|
||||
constant float& eps, \
|
||||
constant uint& axis_size, \
|
||||
constant uint& w_stride, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_rms_single_row(name, itype) \
|
||||
instantiate_rms_looped(name, itype)
|
||||
#define instantiate_rms(name, itype) \
|
||||
instantiate_kernel("rms" #name, rms_single_row, itype) \
|
||||
instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \
|
||||
instantiate_kernel("rms_looped" #name, rms_looped, itype) \
|
||||
instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype)
|
||||
|
||||
instantiate_rms(float32, float)
|
||||
instantiate_rms(float16, half)
|
||||
|
@ -40,7 +40,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
local_max[simd_lane_id] = Limits<AccT>::min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
|
@ -9,47 +9,13 @@ using namespace metal;
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/softmax.h"
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
|
||||
instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
|
||||
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
|
||||
instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
|
||||
|
||||
instantiate_softmax(float32, float)
|
||||
instantiate_softmax(float16, half)
|
||||
|
96
mlx/backend/metal/logsumexp.cpp
Normal file
96
mlx/backend/metal/logsumexp.cpp
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int LOGSUMEXP_LOOPED_LIMIT = 4096;
|
||||
|
||||
void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[logsumexp] Does not support non-floating point types.");
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto ensure_contiguous = [&s, &d](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = 4;
|
||||
const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
|
||||
|
||||
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
|
||||
kernel_name += "logsumexp_";
|
||||
kernel_name += type_to_name(out);
|
||||
|
||||
auto kernel = get_logsumexp_kernel(d, kernel_name, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
|
||||
size_t threadgroup_size = simd_size * simds_needed;
|
||||
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
} else {
|
||||
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
size_t n_threads = n_rows * threadgroup_size;
|
||||
grid_dims = MTL::Size(n_threads, 1, 1);
|
||||
group_dims = MTL::Size(threadgroup_size, 1, 1);
|
||||
}
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(axis_size, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -72,6 +72,13 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_logsumexp_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@ -23,12 +23,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
|
@ -82,6 +82,7 @@ NO_CPU(LogicalNot)
|
||||
NO_CPU(LogicalAnd)
|
||||
NO_CPU(LogicalOr)
|
||||
NO_CPU(LogAddExp)
|
||||
NO_CPU(LogSumExp)
|
||||
NO_CPU_MULTI(LUF)
|
||||
NO_CPU(Matmul)
|
||||
NO_CPU(Maximum)
|
||||
|
@ -82,6 +82,7 @@ NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU(LogSumExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
|
@ -278,6 +278,7 @@ struct PrimitiveFactory {
|
||||
SERIALIZE_PRIMITIVE(LogicalAnd),
|
||||
SERIALIZE_PRIMITIVE(LogicalOr),
|
||||
SERIALIZE_PRIMITIVE(LogAddExp),
|
||||
SERIALIZE_PRIMITIVE(LogSumExp),
|
||||
SERIALIZE_PRIMITIVE(Matmul),
|
||||
SERIALIZE_PRIMITIVE(Maximum),
|
||||
SERIALIZE_PRIMITIVE(Minimum),
|
||||
|
33
mlx/ops.cpp
33
mlx/ops.cpp
@ -2359,6 +2359,29 @@ array logsumexp(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument("[logsumexp] Received empty array.");
|
||||
}
|
||||
if (a.ndim() == 0 && !axes.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
|
||||
}
|
||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||
if (!is_complex && axes.size() == 1 &&
|
||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto out_shape = a.shape();
|
||||
out_shape.back() = 1;
|
||||
auto out = array(
|
||||
std::move(out_shape),
|
||||
dtype,
|
||||
std::make_shared<LogSumExp>(to_stream(s)),
|
||||
{astype(a, dtype, s)});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, -1, s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
auto maxval = stop_gradient(max(a, axes, true, s), s);
|
||||
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
|
||||
out = add(out, reshape(maxval, out.shape(), s), s);
|
||||
@ -3347,8 +3370,14 @@ array softmax(
|
||||
if (a.size() == 0) {
|
||||
return a;
|
||||
}
|
||||
if (a.ndim() == 0 && !axes.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[softmax] Received non-empty axes for array with 0 dimensions.");
|
||||
}
|
||||
|
||||
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||
if (!is_complex && axes.size() == 1 &&
|
||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
a.shape(),
|
||||
@ -3357,7 +3386,7 @@ array softmax(
|
||||
{astype(a, dtype, s)});
|
||||
} else {
|
||||
auto in = a;
|
||||
if (precise) {
|
||||
if (precise && !is_complex) {
|
||||
in = astype(a, float32, s);
|
||||
}
|
||||
auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
|
||||
|
@ -2509,6 +2509,49 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
|
||||
return {{logaddexp(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> LogSumExp::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0];
|
||||
auto in = inputs[0];
|
||||
if (ax == (in.ndim() - 1)) {
|
||||
in = swapaxes(in, -1, -2, stream());
|
||||
ax = in.ndim() - 2;
|
||||
}
|
||||
return {{logsumexp(in, -1, true, stream())}, {ax}};
|
||||
}
|
||||
|
||||
std::vector<array> LogSumExp::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
return {multiply(
|
||||
cotangents[0],
|
||||
softmax(primals[0], std::vector<int>{-1}, true, stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> LogSumExp::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(tangents.size() == 1);
|
||||
return {multiply(
|
||||
tangents[0],
|
||||
softmax(primals[0], std::vector<int>{-1}, true, stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {
|
||||
auto s = inputs[0].shape();
|
||||
s.back() = 1;
|
||||
return {s};
|
||||
}
|
||||
|
||||
std::vector<array> Matmul::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -1350,6 +1350,20 @@ class LogAddExp : public UnaryPrimitive {
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
};
|
||||
|
||||
class LogSumExp : public UnaryPrimitive {
|
||||
public:
|
||||
explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogSumExp)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
};
|
||||
|
||||
class Matmul : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
@ -690,15 +690,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.array_equal(b_npy, b_mlx))
|
||||
|
||||
def test_logsumexp(self):
|
||||
def logsumexp(x, axes=None):
|
||||
maxs = mx.max(x, axis=axes, keepdims=True)
|
||||
return mx.log(mx.sum(mx.exp(x - maxs), axis=axes, keepdims=True)) + maxs
|
||||
|
||||
x = mx.array(
|
||||
[
|
||||
[1.0, 2.0],
|
||||
[3.0, 4.0],
|
||||
]
|
||||
)
|
||||
xnp = np.array(x.tolist(), dtype=np.float32)
|
||||
expected = np.log(np.sum(np.exp(xnp)))
|
||||
self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item()))
|
||||
self.assertTrue(math.isclose(mx.logsumexp(x).item(), logsumexp(x).item()))
|
||||
|
||||
x = mx.random.uniform(shape=(1025,))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
# Transposed
|
||||
x = mx.random.uniform(shape=(2, 2, 8))
|
||||
x = x.swapaxes(0, 1)
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
# Broadcast
|
||||
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
# Large
|
||||
x = mx.random.uniform(shape=(1025,))
|
||||
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
|
||||
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
|
||||
|
||||
def test_mean(self):
|
||||
x = mx.array(
|
||||
@ -1643,6 +1662,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x = mx.full((n,), vals=-float("inf"))
|
||||
self.assertTrue(mx.all(mx.isnan(mx.softmax(x))))
|
||||
|
||||
# Transposed inputs
|
||||
a = mx.random.uniform(shape=(32, 32, 32))
|
||||
b = mx.softmax(a, axis=-1)
|
||||
c = mx.softmax(a.swapaxes(0, 1), axis=-1).swapaxes(0, 1)
|
||||
self.assertEqual((b - c).abs().max().item(), 0.0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.softmax(mx.array(1.0), axis=-1)
|
||||
|
||||
def test_concatenate(self):
|
||||
a_npy = np.random.randn(32, 32, 32)
|
||||
b_npy = np.random.randn(32, 32, 32)
|
||||
|
Loading…
Reference in New Issue
Block a user