// Copyright © 2023 Apple Inc.
//
// Source file: mlx/mlx/backend/cpu/arg_reduce.cpp
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename InT, typename OpT>
2025-03-11 06:30:44 -07:00
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
2023-11-29 10:30:41 -08:00
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
2023-11-29 10:30:41 -08:00
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
2025-03-11 06:30:44 -07:00
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
2023-11-29 10:30:41 -08:00
}
2025-03-11 06:30:44 -07:00
out_ptr[i] = ind_v;
}
2023-11-29 10:30:41 -08:00
}
template <typename InT>
void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
2025-03-11 06:30:44 -07:00
int axis) {
2023-11-29 10:30:41 -08:00
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x < (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
2025-03-11 06:30:44 -07:00
arg_reduce<InT>(in, out, op, axis);
2023-11-29 10:30:41 -08:00
break;
}
case ArgReduce::ArgMax: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x > (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
2025-03-11 06:30:44 -07:00
arg_reduce<InT>(in, out, op, axis);
2023-11-29 10:30:41 -08:00
break;
}
}
}
} // namespace
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
2023-11-29 10:30:41 -08:00
assert(inputs.size() == 1);
auto& in = inputs[0];
2025-03-20 16:48:43 -07:00
out.set_data(allocator::malloc(out.nbytes()));
2025-03-11 06:30:44 -07:00
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
});
2023-11-29 10:30:41 -08:00
}
} // namespace mlx::core