Files
mlx/mlx/distributed/reduction_ops.h
Angelos Katharopoulos 4dbffb3954 All gather
2025-12-04 14:20:51 -08:00

39 lines
735 B
C++

// Copyright © 2025 Apple Inc.
namespace mlx::core::distributed::detail {
/**
 * Element-wise additive accumulation functor.
 *
 * Calling the functor adds each of the first N elements of `input` into the
 * element at the same position in `output` (i.e. `output[i] += input[i]`).
 *
 * @tparam T Element type; must support `operator+=`.
 */
template <typename T>
struct SumOp {
  /**
   * @param input  Pointer to the N values to add; never written to.
   * @param output Pointer to the N accumulators, updated in place.
   * @param N      Number of elements to process; 0 leaves `output` untouched.
   */
  void operator()(const T* input, T* output, size_t N) const {
    for (size_t i = 0; i < N; ++i) {
      output[i] += input[i];
    }
  }
};
/**
 * Element-wise maximum accumulation functor.
 *
 * Calling the functor replaces each of the first N elements of `output` with
 * `std::max(output[i], input[i])`.
 *
 * @tparam T Element type; must be usable with `std::max`.
 */
template <typename T>
struct MaxOp {
  /**
   * @param input  Pointer to the N candidate values; never written to.
   * @param output Pointer to the N running maxima, updated in place.
   * @param N      Number of elements to process; 0 leaves `output` untouched.
   */
  void operator()(const T* input, T* output, size_t N) const {
    for (size_t i = 0; i < N; ++i) {
      // The current output is passed first, so it is the value kept whenever
      // the input does not compare strictly greater.
      output[i] = std::max(output[i], input[i]);
    }
  }
};
/**
 * Element-wise minimum accumulation functor.
 *
 * Calling the functor replaces each of the first N elements of `output` with
 * `std::min(output[i], input[i])`.
 *
 * @tparam T Element type; must be usable with `std::min`.
 */
template <typename T>
struct MinOp {
  /**
   * @param input  Pointer to the N candidate values; never written to.
   * @param output Pointer to the N running minima, updated in place.
   * @param N      Number of elements to process; 0 leaves `output` untouched.
   */
  void operator()(const T* input, T* output, size_t N) const {
    for (size_t i = 0; i < N; ++i) {
      // The current output is passed first, so it is the value kept whenever
      // the input does not compare strictly less.
      output[i] = std::min(output[i], input[i]);
    }
  }
};
} // namespace mlx::core::distributed::detail