mlx/mlx/backend/cpu/softmax.cpp
2025-03-20 16:48:43 -07:00

189 lines
4.9 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void softmax(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
});
}
} // namespace
// CPU evaluation of the Softmax primitive over the last axis.
// Prepares a layout the row kernel can consume: rows of the last axis must
// be unit-stride and either packed back-to-back or broadcast (row stride 0).
// Such inputs are used as-is. Any other layout is first copied into a fresh
// contiguous array. The output reuses the input (or copy) buffer when
// possible, since the kernel is safe to run in place.
// Throws std::runtime_error for integer/bool inputs and
// std::invalid_argument for complex64.
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Returns the array the kernel should read from, and sets up `out`'s
  // storage to match its layout.
  auto set_output = [s = stream(), &out](const array& x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      // Named `row_stride` (not `s`) so it does not shadow the captured
      // stream used by copy() below.
      auto row_stride = x.strides()[x.ndim() - 2];
      no_copy &= (row_stride == 0 || row_stride == x.shape().back());
    }
    if (no_copy) {
      if (x.is_donatable()) {
        // Nobody else needs the input: write the result into its buffer.
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            allocator::malloc(x.data_size() * x.itemsize()),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    }
    // Unsupported layout: make a contiguous copy and compute in place on it.
    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy(x, x_copy, CopyType::General, s);
    out.copy_shared_buffer(x_copy);
    return x_copy;
  };

  auto in = set_output(inputs[0]);
  switch (in.dtype()) {
    case bool_:
    case uint8:
    case uint16:
    case uint32:
    case uint64:
    case int8:
    case int16:
    case int32:
    case int64:
      throw std::runtime_error(
          "Softmax is defined only for floating point types");
    case float32:
      softmax<float, float>(in, out, stream());
      break;
    case float16:
      // precise_ accumulates half-precision inputs in float32.
      if (precise_) {
        softmax<float16_t, float>(in, out, stream());
      } else {
        softmax<float16_t, float16_t>(in, out, stream());
      }
      break;
    case bfloat16:
      if (precise_) {
        softmax<bfloat16_t, float>(in, out, stream());
      } else {
        softmax<bfloat16_t, bfloat16_t>(in, out, stream());
      }
      break;
    case float64:
      softmax<double, double>(in, out, stream());
      break;
    case complex64:
      throw std::invalid_argument(
          "[Softmax] Not yet implemented for complex64");
  }
}
} // namespace mlx::core