mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-30 13:41:14 +08:00
218 lines
6.1 KiB
C++
218 lines
6.1 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#include <cassert>
|
|
#include <functional>
|
|
#include <limits>
|
|
|
|
#include "mlx/backend/common/reduce.h"
|
|
#include "mlx/primitives.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
namespace {
|
|
|
|
// Seed values for the Min and Max reductions of element type T.
//
// `max` is the starting accumulator for a Min reduction and `min` is the
// starting accumulator for a Max reduction. The primary template only declares
// the two constants; the values come from the per-type specializations.
template <typename T>
struct Limits {
  static const T max;
  static const T min;
};
|
|
|
|
// Specializes Limits for a type whose extremes are taken directly from
// std::numeric_limits (used below for the fixed-width integer types).
#define instantiate_default_limit(T)                         \
  template <>                                                \
  struct Limits<T> {                                         \
    static constexpr T max = std::numeric_limits<T>::max();  \
    static constexpr T min = std::numeric_limits<T>::min();  \
  };
|
|
|
|
// Integer element types, paired by width (unsigned first, then signed).
instantiate_default_limit(uint8_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int64_t);
|
|
|
|
// Specializes Limits for a type whose extremes are not compile-time
// constants here: the members are only declared and receive their values from
// the out-of-class definitions further down in this file.
#define instantiate_float_limit(T) \
  template <>                      \
  struct Limits<T> {               \
    static const T max;            \
    static const T min;            \
  };
|
|
|
|
instantiate_float_limit(float16_t);
|
|
instantiate_float_limit(bfloat16_t);
|
|
instantiate_float_limit(float);
|
|
instantiate_float_limit(complex64_t);
|
|
|
|
template <>
|
|
struct Limits<bool> {
|
|
static constexpr bool max = true;
|
|
static constexpr bool min = false;
|
|
};
|
|
|
|
const float Limits<float>::max = std::numeric_limits<float>::infinity();
|
|
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
|
|
const bfloat16_t Limits<bfloat16_t>::max =
|
|
std::numeric_limits<float>::infinity();
|
|
const bfloat16_t Limits<bfloat16_t>::min =
|
|
-std::numeric_limits<float>::infinity();
|
|
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
|
|
const float16_t Limits<float16_t>::min =
|
|
-std::numeric_limits<float>::infinity();
|
|
const complex64_t Limits<complex64_t>::max =
|
|
std::numeric_limits<float>::infinity();
|
|
const complex64_t Limits<complex64_t>::min =
|
|
-std::numeric_limits<float>::infinity();
|
|
|
|
// Accumulation step for a logical-AND reduction into a bool.
struct AndReduce {
  // Generic element: the element counts as true when it compares unequal to 0.
  template <typename T>
  void operator()(bool* acc, T value) {
    const bool nonzero = (value != 0);
    *acc = *acc && nonzero;
  }

  // bool element: combined directly, no comparison against 0 needed.
  void operator()(bool* acc, bool value) {
    *acc = *acc && value;
  }
};
|
|
|
|
// Accumulation step for a logical-OR reduction into a bool.
struct OrReduce {
  // Generic element: the element counts as true when it compares unequal to 0.
  template <typename T>
  void operator()(bool* acc, T value) {
    const bool nonzero = (value != 0);
    *acc = *acc || nonzero;
  }

  // bool element: combined directly, no comparison against 0 needed.
  void operator()(bool* acc, bool value) {
    *acc = *acc || value;
  }
};
|
|
|
|
template <typename InT>
|
|
void reduce_dispatch_out(
|
|
const array& in,
|
|
array& out,
|
|
Reduce::ReduceType rtype,
|
|
const std::vector<int>& axes) {
|
|
switch (rtype) {
|
|
case Reduce::And: {
|
|
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
|
break;
|
|
}
|
|
case Reduce::Or: {
|
|
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
|
break;
|
|
}
|
|
case Reduce::Sum: {
|
|
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
|
switch (out.dtype()) {
|
|
case bool_:
|
|
reduction_op<InT, bool>(in, out, axes, false, op);
|
|
break;
|
|
case uint8:
|
|
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
|
|
break;
|
|
case uint16:
|
|
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
|
|
break;
|
|
case uint32:
|
|
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
|
|
break;
|
|
case uint64:
|
|
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
|
|
break;
|
|
case int8:
|
|
reduction_op<InT, int8_t>(in, out, axes, 0, op);
|
|
break;
|
|
case int16:
|
|
reduction_op<InT, int16_t>(in, out, axes, 0, op);
|
|
break;
|
|
case int32:
|
|
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
|
break;
|
|
case int64:
|
|
reduction_op<InT, int64_t>(in, out, axes, 0, op);
|
|
break;
|
|
case float16:
|
|
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
|
|
break;
|
|
case float32:
|
|
reduction_op<InT, float>(in, out, axes, 0.0f, op);
|
|
break;
|
|
case bfloat16:
|
|
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
|
|
break;
|
|
case complex64:
|
|
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
|
|
break;
|
|
}
|
|
} break;
|
|
case Reduce::Prod: {
|
|
auto op = [](auto y, auto x) { (*y) *= x; };
|
|
reduction_op<InT, InT>(in, out, axes, 1, op);
|
|
break;
|
|
}
|
|
case Reduce::Max: {
|
|
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
|
|
auto init = Limits<InT>::min;
|
|
reduction_op<InT, InT>(in, out, axes, init, op);
|
|
break;
|
|
}
|
|
case Reduce::Min: {
|
|
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
|
|
auto init = Limits<InT>::max;
|
|
reduction_op<InT, InT>(in, out, axes, init, op);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
|
assert(inputs.size() == 1);
|
|
auto& in = inputs[0];
|
|
switch (in.dtype()) {
|
|
case bool_:
|
|
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case uint8:
|
|
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case uint16:
|
|
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case uint32:
|
|
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case uint64:
|
|
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case int8:
|
|
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case int16:
|
|
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case int32:
|
|
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case int64:
|
|
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case float16:
|
|
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case float32:
|
|
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case bfloat16:
|
|
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
case complex64:
|
|
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
|
|
break;
|
|
}
|
|
}
|
|
|
|
} // namespace mlx::core
|