Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-02 23:31:16 +08:00

* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
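The commit adds a precise option so that half-precision softmax accumulates in float32. A minimal sketch of how the option might be exercised through the C++ ops layer, assuming the softmax(a, axis, precise) overload and the arange/astype/eval helpers used here (check mlx/ops.h for the exact signatures):

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // A float16 row, reduced once with the default (storage-type) accumulation
  // and once with float32 accumulation via the precise flag.
  auto x = mx::astype(mx::arange(8.0), mx::float16);
  auto fast = mx::softmax(x, -1);
  auto accurate = mx::softmax(x, -1, /* precise = */ true);
  mx::eval({fast, accurate});
  return 0;
}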
128 lines
3.2 KiB
C++
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {
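// T is the storage type of the input and output; AccT is the accumulation
// type used for the max, the exponentials, and the normalizer. The "precise"
// option instantiates this with AccT = float for half-precision T.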
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
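  // N is the length of the (contiguous) softmax axis; M is the number of rows
  // reduced independently.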
  int N = in.shape().back();
  int M = in.data_size() / N;
  const T* current_in_ptr;
  T* current_out_ptr;

  for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
    // Find the maximum
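    // (subtracting the row maximum keeps every exponential <= 1 and avoids
    // overflow for large inputs)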
    current_in_ptr = in_ptr;
    AccT maximum = *current_in_ptr;
    for (int j = 0; j < N; j++, current_in_ptr++) {
      maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
                                            : maximum;
    }

    // Compute the normalizer and the exponentials
    AccT normalizer = 0;
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
      AccT expv = std::exp(*current_in_ptr - maximum);
      normalizer += expv;
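      // When T == AccT the exponential is stored now and only rescaled in the
      // last pass; in the precise case (AccT wider than T) it is not stored,
      // so no intermediate value is rounded to the narrower type.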
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr = expv;
      }
    }
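    // Invert once and multiply below instead of dividing every element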
    normalizer = 1 / normalizer;

    // Normalize
    current_in_ptr = in_ptr;
    current_out_ptr = out_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++) {
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr *= normalizer;
      } else {
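        // Precise path: recompute the exponential in AccT and round to T only
        // once, when writing the final value.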
        auto v = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(v * normalizer);
        current_in_ptr++;
      }
    }
  }
}

} // namespace

void Softmax::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
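    // The kernel needs unit stride along the softmax axis; along the
    // next-to-last axis rows may either repeat (stride 0) or be packed back
    // to back. Anything else gets copied into a contiguous temporary.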
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General);
      return x_copy;
    }
  };
  array in = check_input(std::move(inputs[0]));
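  // Reuse the input buffer when it can be donated; otherwise allocate fresh
  // memory with the input's size, strides, and flags.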
  if (in.is_donatable()) {
    out.copy_shared_buffer(in);
  } else {
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  }
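  // Dispatch on dtype: integer and bool inputs are rejected, float32 always
  // accumulates in float32, and for float16/bfloat16 the precise_ flag
  // selects float accumulation.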
  switch (in.dtype()) {
    case bool_:
    case uint8:
    case uint16:
    case uint32:
    case uint64:
    case int8:
    case int16:
    case int32:
    case int64:
      throw std::invalid_argument(
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float, float>(in, out);
      break;
    case float16:
      if (precise_) {
        softmax<float16_t, float>(in, out);
      } else {
        softmax<float16_t, float16_t>(in, out);
      }
      break;
    case bfloat16:
      if (precise_) {
        softmax<bfloat16_t, float>(in, out);
      } else {
        softmax<bfloat16_t, bfloat16_t>(in, out);
      }
      break;
    case complex64:
      throw std::invalid_argument(
          "[Softmax] Not yet implemented for complex64");
      break;
  }
}

} // namespace mlx::core