mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -104,48 +104,14 @@ void reduce_dispatch_out(
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
reduction_op<InT, bool>(in, out, axes, false, op);
|
||||
break;
|
||||
case uint8:
|
||||
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint16:
|
||||
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint32:
|
||||
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint64:
|
||||
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int8:
|
||||
reduction_op<InT, int8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int16:
|
||||
reduction_op<InT, int16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int32:
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int64:
|
||||
reduction_op<InT, int64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case float16:
|
||||
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case float32:
|
||||
reduction_op<InT, float>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case complex64:
|
||||
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
|
||||
break;
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
} break;
|
||||
break;
|
||||
}
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
@@ -168,6 +134,29 @@ void reduce_dispatch_out(
|
||||
|
||||
} // namespace
|
||||
|
||||
// Invoke `callback` once for every index tuple of an N-dimensional iteration
// space, passing the linear offset of that tuple.
//
// The iteration space is described by `shape` (extent of each dimension) and
// `strides` (offset step of each dimension). Dimension 0 is the outermost
// loop and the last dimension is the innermost, so for index tuple
// (i_0, ..., i_{n-1}) the callback receives sum_d(i_d * strides[d]).
//
// Behavior notes:
//   * An empty `shape` describes a single element: the callback is invoked
//     exactly once with offset 0. (Previously this case read shape[0] and
//     strides[0] out of bounds because `shape.size() - 1` wrapped around.)
//   * If any visited dimension has extent <= 0 its loop body never runs, so
//     a zero-sized space produces no callbacks.
//   * Offsets are accumulated in size_t arithmetic and then narrowed to int,
//     exactly as the callback signature requires.
//
// NOTE(review): `strides` is indexed with the same indices as `shape`, so it
// is assumed to hold at least shape.size() entries -- confirm at call sites.
void nd_loop(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  // Rank-0 space: one element at offset 0, and nothing to index into.
  if (shape.empty()) {
    callback(0);
    return;
  }

  // Signed on purpose: avoids the unsigned wrap-around of `size() - 1` and the
  // signed/unsigned comparison against `dim`.
  const int last_dim = static_cast<int>(shape.size()) - 1;

  std::function<void(int, int)> loop_inner;
  loop_inner = [&](int dim, int offset) {
    const int size = shape[dim];
    const size_t stride = strides[dim];
    for (int i = 0; i < size; i++) {
      // Computed in size_t (usual arithmetic conversions), then narrowed.
      const int next_offset = static_cast<int>(offset + i * stride);
      if (dim < last_dim) {
        // Not the innermost dimension yet: descend one level.
        loop_inner(dim + 1, next_offset);
      } else {
        // Innermost dimension: report the fully resolved offset.
        callback(next_offset);
      }
    }
  };
  loop_inner(0, 0);
}
|
||||
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
Reference in New Issue
Block a user