mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix arg_reduce
This commit is contained in:
@@ -17,12 +17,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<int64_t>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
for (int64_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
int64_t ind_v = 0;
|
||||
uint32_t ind_v = 0;
|
||||
InT v = (*local_in_ptr);
|
||||
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
op(j, (*local_in_ptr), &ind_v, &v);
|
||||
|
||||
Reference in New Issue
Block a user