mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
reused existing util for implementation of linalg.norm
This commit is contained in:
parent
145a4d143d
commit
f82ab0eec9
@ -96,8 +96,7 @@ array norm(
|
||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||
return sqrt(
|
||||
sum(abs(a, s) * abs(a, s),
|
||||
num_axes ? axis
|
||||
: get_shape_reducing_over_all_axes(a.shape().size()),
|
||||
num_axes ? axis : get_reduce_axes({}, a.ndim()),
|
||||
keepdims,
|
||||
s),
|
||||
s);
|
||||
@ -116,7 +115,7 @@ array norm(
|
||||
std::vector<int> ax = axis;
|
||||
|
||||
if (axis.empty())
|
||||
ax = get_shape_reducing_over_all_axes(a.ndim());
|
||||
ax = get_reduce_axes({}, a.ndim());
|
||||
else
|
||||
ax = normalize_axes(ax, a.ndim());
|
||||
|
||||
@ -140,7 +139,7 @@ array norm(
|
||||
std::vector<int> ax = axis;
|
||||
|
||||
if (axis.empty())
|
||||
ax = get_shape_reducing_over_all_axes(a.ndim());
|
||||
ax = get_reduce_axes({}, a.ndim());
|
||||
else
|
||||
ax = normalize_axes(ax, a.ndim());
|
||||
|
||||
|
19
mlx/utils.h
19
mlx/utils.h
@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "dtype.h"
|
||||
@ -42,8 +43,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||
return os << static_cast<float>(v);
|
||||
}
|
||||
/**
|
||||
* Returns the axes vector [0, 1, ... ndim).
|
||||
*/
|
||||
std::vector<int> get_shape_reducing_over_all_axes(int ndim);
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
if (std::holds_alternative<std::monostate>(v)) {
|
||||
axes.resize(dims);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||
axes.push_back(*pv);
|
||||
} else {
|
||||
axes = std::get<std::vector<int>>(v);
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
} // namespace mlx::core
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
|
@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
#include <pybind11/complex.h>
|
||||
@ -14,24 +13,10 @@ namespace py = pybind11;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
using ScalarOrArray = std::
|
||||
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||
static constexpr std::monostate none{};
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
if (std::holds_alternative<std::monostate>(v)) {
|
||||
axes.resize(dims);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||
axes.push_back(*pv);
|
||||
} else {
|
||||
axes = std::get<std::vector<int>>(v);
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
|
||||
inline array to_array_with_accessor(py::object obj) {
|
||||
if (py::hasattr(obj, "__mlx_array__")) {
|
||||
return obj.attr("__mlx_array__")().cast<array>();
|
||||
|
Loading…
Reference in New Issue
Block a user