diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 614e6f79c..f541c6214 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -96,8 +96,7 @@ array norm( if (num_axes == 0 || num_axes == 1 || num_axes == 2) return sqrt( sum(abs(a, s) * abs(a, s), - num_axes ? axis - : get_shape_reducing_over_all_axes(a.shape().size()), + num_axes ? axis : get_reduce_axes({}, a.ndim()), keepdims, s), s); @@ -116,7 +115,7 @@ array norm( std::vector ax = axis; if (axis.empty()) - ax = get_shape_reducing_over_all_axes(a.ndim()); + ax = get_reduce_axes({}, a.ndim()); else ax = normalize_axes(ax, a.ndim()); @@ -140,7 +139,7 @@ array norm( std::vector ax = axis; if (axis.empty()) - ax = get_shape_reducing_over_all_axes(a.ndim()); + ax = get_reduce_axes({}, a.ndim()); else ax = normalize_axes(ax, a.ndim()); diff --git a/mlx/utils.h b/mlx/utils.h index 1158b7c42..0b0ae9e93 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,6 +2,7 @@ #pragma once +#include #include "array.h" #include "device.h" #include "dtype.h" @@ -42,8 +43,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } -/** - * Returns the axes vector [0, 1, ... ndim). - */ -std::vector get_shape_reducing_over_all_axes(int ndim); + +using IntOrVec = std::variant>; +inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { + std::vector axes; + if (std::holds_alternative(v)) { + axes.resize(dims); + std::iota(axes.begin(), axes.end(), 0); + } else if (auto pv = std::get_if(&v); pv) { + axes.push_back(*pv); + } else { + axes = std::get>(v); + } + return axes; +} } // namespace mlx::core diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 42ad37633..6b3739ae6 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -7,6 +7,7 @@ #include "mlx/fft.h" #include "mlx/ops.h" +#include "mlx/utils.h" namespace py = pybind11; using namespace py::literals; diff --git a/python/src/utils.h b/python/src/utils.h index 5ac878979..9751b2d6e 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -1,7 +1,6 @@ // Copyright © 2023 Apple Inc. #pragma once -#include #include #include @@ -14,24 +13,10 @@ namespace py = pybind11; using namespace mlx::core; -using IntOrVec = std::variant>; using ScalarOrArray = std:: variant, py::object>; static constexpr std::monostate none{}; -inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { - std::vector axes; - if (std::holds_alternative(v)) { - axes.resize(dims); - std::iota(axes.begin(), axes.end(), 0); - } else if (auto pv = std::get_if(&v); pv) { - axes.push_back(*pv); - } else { - axes = std::get>(v); - } - return axes; -} - inline array to_array_with_accessor(py::object obj) { if (py::hasattr(obj, "__mlx_array__")) { return obj.attr("__mlx_array__")().cast();