mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
reused existing util for implementation of linalg.norm
This commit is contained in:
parent
145a4d143d
commit
f82ab0eec9
@ -96,8 +96,7 @@ array norm(
|
|||||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||||
return sqrt(
|
return sqrt(
|
||||||
sum(abs(a, s) * abs(a, s),
|
sum(abs(a, s) * abs(a, s),
|
||||||
num_axes ? axis
|
num_axes ? axis : get_reduce_axes({}, a.ndim()),
|
||||||
: get_shape_reducing_over_all_axes(a.shape().size()),
|
|
||||||
keepdims,
|
keepdims,
|
||||||
s),
|
s),
|
||||||
s);
|
s);
|
||||||
@ -116,7 +115,7 @@ array norm(
|
|||||||
std::vector<int> ax = axis;
|
std::vector<int> ax = axis;
|
||||||
|
|
||||||
if (axis.empty())
|
if (axis.empty())
|
||||||
ax = get_shape_reducing_over_all_axes(a.ndim());
|
ax = get_reduce_axes({}, a.ndim());
|
||||||
else
|
else
|
||||||
ax = normalize_axes(ax, a.ndim());
|
ax = normalize_axes(ax, a.ndim());
|
||||||
|
|
||||||
@ -140,7 +139,7 @@ array norm(
|
|||||||
std::vector<int> ax = axis;
|
std::vector<int> ax = axis;
|
||||||
|
|
||||||
if (axis.empty())
|
if (axis.empty())
|
||||||
ax = get_shape_reducing_over_all_axes(a.ndim());
|
ax = get_reduce_axes({}, a.ndim());
|
||||||
else
|
else
|
||||||
ax = normalize_axes(ax, a.ndim());
|
ax = normalize_axes(ax, a.ndim());
|
||||||
|
|
||||||
|
19
mlx/utils.h
19
mlx/utils.h
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
#include "dtype.h"
|
#include "dtype.h"
|
||||||
@ -42,8 +43,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
|||||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||||
return os << static_cast<float>(v);
|
return os << static_cast<float>(v);
|
||||||
}
|
}
|
||||||
/**
|
|
||||||
* Returns the axes vector [0, 1, ... ndim).
|
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||||
*/
|
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||||
std::vector<int> get_shape_reducing_over_all_axes(int ndim);
|
std::vector<int> axes;
|
||||||
|
if (std::holds_alternative<std::monostate>(v)) {
|
||||||
|
axes.resize(dims);
|
||||||
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
|
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||||
|
axes.push_back(*pv);
|
||||||
|
} else {
|
||||||
|
axes = std::get<std::vector<int>>(v);
|
||||||
|
}
|
||||||
|
return axes;
|
||||||
|
}
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace py::literals;
|
using namespace py::literals;
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <numeric>
|
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include <pybind11/complex.h>
|
#include <pybind11/complex.h>
|
||||||
@ -14,24 +13,10 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
|
||||||
using ScalarOrArray = std::
|
using ScalarOrArray = std::
|
||||||
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||||
static constexpr std::monostate none{};
|
static constexpr std::monostate none{};
|
||||||
|
|
||||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
|
||||||
std::vector<int> axes;
|
|
||||||
if (std::holds_alternative<std::monostate>(v)) {
|
|
||||||
axes.resize(dims);
|
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
|
||||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
|
||||||
axes.push_back(*pv);
|
|
||||||
} else {
|
|
||||||
axes = std::get<std::vector<int>>(v);
|
|
||||||
}
|
|
||||||
return axes;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline array to_array_with_accessor(py::object obj) {
|
inline array to_array_with_accessor(py::object obj) {
|
||||||
if (py::hasattr(obj, "__mlx_array__")) {
|
if (py::hasattr(obj, "__mlx_array__")) {
|
||||||
return obj.attr("__mlx_array__")().cast<array>();
|
return obj.attr("__mlx_array__")().cast<array>();
|
||||||
|
Loading…
Reference in New Issue
Block a user