reused existing util for implementation of linalg.norm

This commit is contained in:
Gabrijel Boduljak 2023-12-22 04:58:51 +01:00 committed by Awni Hannun
parent 145a4d143d
commit f82ab0eec9
4 changed files with 19 additions and 23 deletions

View File

@ -96,8 +96,7 @@ array norm(
if (num_axes == 0 || num_axes == 1 || num_axes == 2) if (num_axes == 0 || num_axes == 1 || num_axes == 2)
return sqrt( return sqrt(
sum(abs(a, s) * abs(a, s), sum(abs(a, s) * abs(a, s),
num_axes ? axis num_axes ? axis : get_reduce_axes({}, a.ndim()),
: get_shape_reducing_over_all_axes(a.shape().size()),
keepdims, keepdims,
s), s),
s); s);
@ -116,7 +115,7 @@ array norm(
std::vector<int> ax = axis; std::vector<int> ax = axis;
if (axis.empty()) if (axis.empty())
ax = get_shape_reducing_over_all_axes(a.ndim()); ax = get_reduce_axes({}, a.ndim());
else else
ax = normalize_axes(ax, a.ndim()); ax = normalize_axes(ax, a.ndim());
@ -140,7 +139,7 @@ array norm(
std::vector<int> ax = axis; std::vector<int> ax = axis;
if (axis.empty()) if (axis.empty())
ax = get_shape_reducing_over_all_axes(a.ndim()); ax = get_reduce_axes({}, a.ndim());
else else
ax = normalize_axes(ax, a.ndim()); ax = normalize_axes(ax, a.ndim());

View File

@ -2,6 +2,7 @@
#pragma once #pragma once
#include <numeric>
#include "array.h" #include "array.h"
#include "device.h" #include "device.h"
#include "dtype.h" #include "dtype.h"
@ -42,8 +43,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v); return os << static_cast<float>(v);
} }
/**
* Returns the axes vector [0, 1, ... ndim). using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
*/ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> get_shape_reducing_over_all_axes(int ndim); std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -7,6 +7,7 @@
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/utils.h"
namespace py = pybind11; namespace py = pybind11;
using namespace py::literals; using namespace py::literals;

View File

@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include <numeric>
#include <variant> #include <variant>
#include <pybind11/complex.h> #include <pybind11/complex.h>
@ -14,24 +13,10 @@ namespace py = pybind11;
using namespace mlx::core; using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std:: using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>; variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{}; static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
inline array to_array_with_accessor(py::object obj) { inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) { if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>(); return obj.attr("__mlx_array__")().cast<array>();