mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix arg_reduce
This commit is contained in:
@@ -17,12 +17,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
|||||||
Strides strides = remove_index(in.strides(), axis);
|
Strides strides = remove_index(in.strides(), axis);
|
||||||
Shape shape = remove_index(in.shape(), axis);
|
Shape shape = remove_index(in.shape(), axis);
|
||||||
auto in_ptr = in.data<InT>();
|
auto in_ptr = in.data<InT>();
|
||||||
auto out_ptr = out.data<int64_t>();
|
auto out_ptr = out.data<uint32_t>();
|
||||||
|
|
||||||
for (int64_t i = 0; i < out.size(); ++i) {
|
for (int64_t i = 0; i < out.size(); ++i) {
|
||||||
auto loc = elem_to_loc(i, shape, strides);
|
auto loc = elem_to_loc(i, shape, strides);
|
||||||
auto local_in_ptr = in_ptr + loc;
|
auto local_in_ptr = in_ptr + loc;
|
||||||
int64_t ind_v = 0;
|
uint32_t ind_v = 0;
|
||||||
InT v = (*local_in_ptr);
|
InT v = (*local_in_ptr);
|
||||||
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||||
op(j, (*local_in_ptr), &ind_v, &v);
|
op(j, (*local_in_ptr), &ind_v, &v);
|
||||||
|
|||||||
Reference in New Issue
Block a user