mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 21:16:47 +08:00
some style and API consistency updates to linalg norm
This commit is contained in:
parent
4bae4a8239
commit
f7cea9563d
@ -57,6 +57,7 @@ are the CPU and GPU.
|
|||||||
python/random
|
python/random
|
||||||
python/transforms
|
python/transforms
|
||||||
python/fft
|
python/fft
|
||||||
|
python/linalg
|
||||||
python/nn
|
python/nn
|
||||||
python/optimizers
|
python/optimizers
|
||||||
python/tree_utils
|
python/tree_utils
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
.. _linalg:
|
.. _linalg:
|
||||||
|
|
||||||
Linear Algebra
|
Linear Algebra
|
||||||
=====
|
==============
|
||||||
|
|
||||||
.. currentmodule:: mlx.core.linalg
|
.. currentmodule:: mlx.core.linalg
|
||||||
|
|
||||||
|
190
mlx/linalg.cpp
190
mlx/linalg.cpp
@ -1,47 +1,42 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <sstream>
|
#include <numeric>
|
||||||
#include <string>
|
#include <ostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/utils.h"
|
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
|
Dtype at_least_float(const Dtype& d) {
|
||||||
|
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||||
|
}
|
||||||
|
|
||||||
inline array vector_norm(
|
inline array vector_norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const double ord,
|
const double ord,
|
||||||
const std::vector<int>& axis,
|
const std::vector<int>& axis,
|
||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (ord == 0.0)
|
auto dtype = at_least_float(a.dtype());
|
||||||
return sum(a != 0, axis, keepdims, s);
|
if (ord == 0.0) {
|
||||||
else if (ord == 1.0)
|
return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
|
||||||
return sum(abs(a, s), axis, keepdims, s);
|
} else if (ord == 1.0) {
|
||||||
else if (ord == 2.0)
|
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
|
||||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
} else if (ord == 2.0) {
|
||||||
else
|
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||||
|
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||||
|
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
|
||||||
|
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||||
|
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
|
||||||
|
} else {
|
||||||
return power(
|
return power(
|
||||||
sum(power(abs(a, s), array(ord), s), axis, keepdims, s),
|
sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
|
||||||
array(1.0 / ord));
|
array(1.0 / ord, dtype),
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline array vector_norm(
|
|
||||||
const array& a,
|
|
||||||
const std::string& ord,
|
|
||||||
const std::vector<int>& axis,
|
|
||||||
bool keepdims,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
if (ord == "inf")
|
|
||||||
return max(abs(a, s), axis, keepdims, s);
|
|
||||||
else if (ord == "-inf")
|
|
||||||
return min(abs(a, s), axis, keepdims, s);
|
|
||||||
std::ostringstream error_stream;
|
|
||||||
error_stream << "Invalid ord value " << ord;
|
|
||||||
throw std::invalid_argument(error_stream.str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline array matrix_norm(
|
inline array matrix_norm(
|
||||||
@ -50,19 +45,30 @@ inline array matrix_norm(
|
|||||||
const std::vector<int>& axis,
|
const std::vector<int>& axis,
|
||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto row_axis = axis[0];
|
auto row_axis = axis[0];
|
||||||
auto col_axis = axis[1];
|
auto col_axis = axis[1];
|
||||||
if (!keepdims && col_axis > row_axis)
|
if (!keepdims && col_axis > row_axis && col_axis > 0) {
|
||||||
col_axis -= 1;
|
col_axis -= 1;
|
||||||
if (ord == -1.0)
|
}
|
||||||
return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
if (ord == -1.0) {
|
||||||
if (ord == 1.0)
|
return astype(
|
||||||
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||||
if (ord == 2.0 || ord == -2.0)
|
dtype,
|
||||||
throw std::logic_error("Singular value norms are not implemented.");
|
s);
|
||||||
std::ostringstream error_stream;
|
} else if (ord == 1.0) {
|
||||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
return astype(
|
||||||
throw std::invalid_argument(error_stream.str());
|
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
|
} else if (ord == 2.0 || ord == -2.0) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[linalg::norm] Singular value norms are not implemented.");
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline array matrix_norm(
|
inline array matrix_norm(
|
||||||
@ -71,85 +77,77 @@ inline array matrix_norm(
|
|||||||
const std::vector<int>& axis,
|
const std::vector<int>& axis,
|
||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (ord == "f" || ord == "fro")
|
if (ord == "f" || ord == "fro") {
|
||||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||||
else if (ord == "inf")
|
} else if (ord == "nuc") {
|
||||||
return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
|
throw std::runtime_error(
|
||||||
else if (ord == "-inf")
|
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||||
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
} else {
|
||||||
if (ord == "nuc")
|
std::ostringstream msg;
|
||||||
throw std::logic_error("Nuclear norm is not implemented.");
|
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
|
||||||
std::ostringstream error_stream;
|
throw std::invalid_argument(msg.str());
|
||||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
}
|
||||||
throw std::invalid_argument(error_stream.str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axis,
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
bool keepdims,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto num_axes = axis.size();
|
if (!axis) {
|
||||||
|
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
|
||||||
|
}
|
||||||
|
|
||||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
if (axis.value().size() > 2) {
|
||||||
return sqrt(
|
throw std::invalid_argument(
|
||||||
sum(abs(a, s) * abs(a, s),
|
"[linalg::norm] Received too many axes for norm");
|
||||||
num_axes ? axis : get_reduce_axes({}, a.ndim()),
|
}
|
||||||
keepdims,
|
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
|
||||||
s),
|
|
||||||
s);
|
|
||||||
|
|
||||||
std::ostringstream error_stream;
|
|
||||||
error_stream << "Invalid axis values " << axis;
|
|
||||||
throw std::invalid_argument(error_stream.str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const double ord,
|
const double ord,
|
||||||
const std::vector<int>& axis,
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
bool keepdims,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s /* = {} */) {
|
||||||
std::vector<int> ax = axis;
|
std::vector<int> ax;
|
||||||
|
if (!axis) {
|
||||||
if (axis.empty())
|
ax.resize(a.ndim());
|
||||||
ax = get_reduce_axes({}, a.ndim());
|
std::iota(ax.begin(), ax.end(), 0);
|
||||||
else
|
} else {
|
||||||
ax = normalize_axes(ax, a.ndim());
|
ax = axis.value();
|
||||||
|
}
|
||||||
auto num_axes = ax.size();
|
if (ax.size() == 1) {
|
||||||
if (num_axes == 1)
|
|
||||||
return vector_norm(a, ord, ax, keepdims, s);
|
return vector_norm(a, ord, ax, keepdims, s);
|
||||||
else if (num_axes == 2)
|
} else if (ax.size() == 2) {
|
||||||
return matrix_norm(a, ord, ax, keepdims, s);
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
} else {
|
||||||
std::ostringstream error_stream;
|
throw std::invalid_argument(
|
||||||
error_stream << "Invalid axis values " << ax;
|
"[linalg::norm] Received too many axes for norm");
|
||||||
throw std::invalid_argument(error_stream.str());
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::string& ord,
|
const std::string& ord,
|
||||||
const std::vector<int>& axis,
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
bool keepdims,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s /* = {} */) {
|
||||||
std::vector<int> ax = axis;
|
std::vector<int> ax;
|
||||||
|
if (!axis) {
|
||||||
if (axis.empty())
|
ax.resize(a.ndim());
|
||||||
ax = get_reduce_axes({}, a.ndim());
|
std::iota(ax.begin(), ax.end(), 0);
|
||||||
else
|
} else {
|
||||||
ax = normalize_axes(ax, a.ndim());
|
ax = axis.value();
|
||||||
|
}
|
||||||
auto num_axes = ax.size();
|
if (ax.size() != 2) {
|
||||||
if (num_axes == 1)
|
std::ostringstream msg;
|
||||||
return vector_norm(a, ord, ax, keepdims, s);
|
msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
|
||||||
else if (num_axes == 2)
|
<< " but received " << ax.size() << " axis/axes.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
return matrix_norm(a, ord, ax, keepdims, s);
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
|
||||||
std::ostringstream error_stream;
|
|
||||||
error_stream << "Invalid axis values " << ax;
|
|
||||||
throw std::invalid_argument(error_stream.str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
40
mlx/linalg.h
40
mlx/linalg.h
@ -2,27 +2,61 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
#include "stream.h"
|
#include "stream.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Compute vector or matrix norms.
|
||||||
|
*
|
||||||
|
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
||||||
|
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
||||||
|
* - If axis is provided, but ord is not, then the 2-norm is computed along the
|
||||||
|
* given axes. At most 2 axes can be specified.
|
||||||
|
* - If both axis and ord are provided, then the corresponding matrix of vector
|
||||||
|
* norm is computed. At most 2 axes can be specified.
|
||||||
|
*/
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const double ord,
|
const double ord,
|
||||||
const std::vector<int>& axis = {},
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
inline array norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
||||||
|
}
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::string& ord,
|
const std::string& ord,
|
||||||
const std::vector<int>& axis = {},
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
inline array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
||||||
|
}
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axis = {},
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
inline array
|
||||||
|
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
|
||||||
|
return norm(a, std::vector<int>{axis}, keepdims, s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <numeric>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -74,12 +73,6 @@ int normalize_axis(int axis, int ndim) {
|
|||||||
}
|
}
|
||||||
return axis;
|
return axis;
|
||||||
}
|
}
|
||||||
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
|
|
||||||
std::vector<int> canonical;
|
|
||||||
for (int ax : axes)
|
|
||||||
canonical.push_back(normalize_axis(ax, ndim));
|
|
||||||
return canonical;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||||
os << "Device(";
|
os << "Device(";
|
||||||
|
16
mlx/utils.h
16
mlx/utils.h
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <numeric>
|
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
#include "dtype.h"
|
#include "dtype.h"
|
||||||
@ -25,7 +24,6 @@ bool is_same_shape(const std::vector<array>& arrays);
|
|||||||
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||||
*/
|
*/
|
||||||
int normalize_axis(int axis, int ndim);
|
int normalize_axis(int axis, int ndim);
|
||||||
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
|
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||||
@ -43,18 +41,4 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
|||||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||||
return os << static_cast<float>(v);
|
return os << static_cast<float>(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
|
||||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
|
||||||
std::vector<int> axes;
|
|
||||||
if (std::holds_alternative<std::monostate>(v)) {
|
|
||||||
axes.resize(dims);
|
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
|
||||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
|
||||||
axes.push_back(*pv);
|
|
||||||
} else {
|
|
||||||
axes = std::get<std::vector<int>>(v);
|
|
||||||
}
|
|
||||||
return axes;
|
|
||||||
}
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
|
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/utils.h"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace py::literals;
|
using namespace py::literals;
|
||||||
|
@ -26,106 +26,78 @@ using namespace mlx::core;
|
|||||||
using namespace mlx::core::linalg;
|
using namespace mlx::core::linalg;
|
||||||
|
|
||||||
void init_linalg(py::module_& parent_module) {
|
void init_linalg(py::module_& parent_module) {
|
||||||
|
py::options options;
|
||||||
|
options.disable_function_signatures();
|
||||||
|
|
||||||
auto m =
|
auto m =
|
||||||
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"norm",
|
"norm",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
const std::variant<std::monostate, int, double, std::string>& ord,
|
const std::variant<std::monostate, int, double, std::string>& ord_,
|
||||||
const std::variant<std::monostate, int, std::vector<int>>& axis,
|
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
||||||
const bool keepdims,
|
const bool keepdims,
|
||||||
const StreamOrDevice stream) {
|
const StreamOrDevice stream) {
|
||||||
return std::visit(
|
std::optional<std::vector<int>> axis = std::nullopt;
|
||||||
overloaded{
|
if (auto pv = std::get_if<int>(&axis_); pv) {
|
||||||
[&](const double p) {
|
axis = std::vector<int>{*pv};
|
||||||
if (std::isinf((float)p) || std::isinf(p)) {
|
} else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
|
||||||
if (p > 0) {
|
axis = *pv;
|
||||||
return norm(
|
|
||||||
a,
|
|
||||||
"inf",
|
|
||||||
std::holds_alternative<std::monostate>(axis)
|
|
||||||
? std::vector<int>()
|
|
||||||
: get_reduce_axes(axis, a.ndim()),
|
|
||||||
keepdims,
|
|
||||||
stream);
|
|
||||||
}
|
}
|
||||||
return norm(
|
|
||||||
a,
|
if (std::holds_alternative<std::monostate>(ord_)) {
|
||||||
"-inf",
|
return norm(a, axis, keepdims, stream);
|
||||||
get_reduce_axes(axis, a.ndim()),
|
} else {
|
||||||
keepdims,
|
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
||||||
stream);
|
return norm(a, *pv, axis, keepdims, stream);
|
||||||
|
}
|
||||||
|
double ord;
|
||||||
|
if (auto pv = std::get_if<int>(&ord_); pv) {
|
||||||
|
ord = *pv;
|
||||||
|
} else {
|
||||||
|
ord = std::get<double>(ord_);
|
||||||
|
}
|
||||||
|
return norm(a, ord, axis, keepdims, stream);
|
||||||
}
|
}
|
||||||
return norm(
|
|
||||||
a,
|
|
||||||
p,
|
|
||||||
std::holds_alternative<std::monostate>(axis)
|
|
||||||
? std::vector<int>()
|
|
||||||
: get_reduce_axes(axis, a.ndim()),
|
|
||||||
keepdims,
|
|
||||||
stream);
|
|
||||||
},
|
|
||||||
[&](const std::string& p) {
|
|
||||||
return norm(
|
|
||||||
a,
|
|
||||||
p,
|
|
||||||
std::holds_alternative<std::monostate>(axis)
|
|
||||||
? std::vector<int>()
|
|
||||||
: get_reduce_axes(axis, a.ndim()),
|
|
||||||
keepdims,
|
|
||||||
stream);
|
|
||||||
},
|
|
||||||
[&](const std::monostate _) {
|
|
||||||
return norm(
|
|
||||||
a,
|
|
||||||
std::holds_alternative<std::monostate>(axis)
|
|
||||||
? std::vector<int>()
|
|
||||||
: get_reduce_axes(axis, a.ndim()),
|
|
||||||
keepdims,
|
|
||||||
stream);
|
|
||||||
}},
|
|
||||||
ord);
|
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
"ord"_a = none,
|
"ord"_a = none,
|
||||||
"axis"_a = none,
|
"axis"_a = none,
|
||||||
"keepdims"_a = false,
|
"keepdims"_a = false,
|
||||||
|
py::kw_only(),
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Matrix or vector norm.
|
Matrix or vector norm.
|
||||||
|
|
||||||
This function is able to return matrix or vector norms,
|
This function computes vector or matrix norms depending on the value of
|
||||||
depending on the value of the ``ord`` parameter.
|
the ``ord`` and ``axis`` parameters.
|
||||||
|
|
||||||
Parameters
|
Args:
|
||||||
----------
|
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
|
||||||
a : array_like
|
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
||||||
Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
|
2-norm of ``a.flatten`` will be returned.
|
||||||
is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
|
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
||||||
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
|
If ``None``, the 2-norm will be computed along the given ``axis``.
|
||||||
Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None.
|
Default: ``None``.
|
||||||
axis : {None, int, 2-tuple of ints}, optional.
|
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
||||||
If `axis` is an integer, it specifies the axis of `a` along which to
|
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
|
||||||
compute the vector norms. If `axis` is a 2-tuple, it specifies the
|
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
|
||||||
axes that hold 2-D matrices, and the matrix norms of these matrices
|
norms of these matrices are computed. If `axis` is ``None`` then
|
||||||
are computed. If `axis` is None then either a vector norm (when `a`
|
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
|
||||||
is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default
|
2-D) is returned. Default: ``None``.
|
||||||
is None.
|
keepdims (bool, optional): If ``True``, the axes which are normed over are
|
||||||
keepdims : bool, optional
|
left in the result as dimensions with size one. Default ``False``.
|
||||||
If this is set to True, the axes which are normed over are left in the
|
|
||||||
result as dimensions with size one. With this option the result will
|
|
||||||
broadcast correctly against the original `a`.
|
|
||||||
|
|
||||||
Returns
|
Returns:
|
||||||
-------
|
array: The output containing the norm(s).
|
||||||
n : array
|
|
||||||
Norm of the matrix or vector(s).
|
|
||||||
|
|
||||||
Notes
|
Notes:
|
||||||
-----
|
|
||||||
For values of ``ord < 1``, the result is, strictly speaking, not a
|
For values of ``ord < 1``, the result is, strictly speaking, not a
|
||||||
mathematical 'norm', but it may still be useful for various numerical
|
mathematical norm, but it may still be useful for various numerical
|
||||||
purposes.
|
purposes.
|
||||||
|
|
||||||
The following norms can be calculated:
|
The following norms can be calculated:
|
||||||
@ -145,26 +117,25 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
other -- sum(abs(x)**ord)**(1./ord)
|
other -- sum(abs(x)**ord)**(1./ord)
|
||||||
===== ============================ ==========================
|
===== ============================ ==========================
|
||||||
|
|
||||||
|
.. warning::
|
||||||
Nuclear norm and norms based on singular values are not yet implemented.
|
Nuclear norm and norms based on singular values are not yet implemented.
|
||||||
|
|
||||||
The Frobenius norm is given by [1]_:
|
The Frobenius norm is given by [1]_:
|
||||||
|
|
||||||
:math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
|
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
|
||||||
|
|
||||||
The nuclear norm is the sum of the singular values.
|
The nuclear norm is the sum of the singular values.
|
||||||
|
|
||||||
Both the Frobenius and nuclear norm orders are only defined for
|
Both the Frobenius and nuclear norm orders are only defined for
|
||||||
matrices and raise a ValueError when ``a.ndim != 2``.
|
matrices and raise a ``ValueError`` when ``a.ndim != 2``.
|
||||||
|
|
||||||
References
|
References:
|
||||||
----------
|
|
||||||
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
|
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
|
||||||
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
|
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
|
||||||
|
|
||||||
Examples
|
Examples:
|
||||||
--------
|
|
||||||
>>> import mlx.core as mx
|
>>> import mlx.core as mx
|
||||||
>>> from mlx.core import linalg as LA
|
>>> from mlx.core import linalg as la
|
||||||
>>> a = mx.arange(9) - 4
|
>>> a = mx.arange(9) - 4
|
||||||
>>> a
|
>>> a
|
||||||
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
|
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
|
||||||
@ -173,46 +144,46 @@ void init_linalg(py::module_& parent_module) {
|
|||||||
array([[-4, -3, -2],
|
array([[-4, -3, -2],
|
||||||
[-1, 0, 1],
|
[-1, 0, 1],
|
||||||
[ 2, 3, 4]], dtype=int32)
|
[ 2, 3, 4]], dtype=int32)
|
||||||
>>> LA.norm(a)
|
>>> la.norm(a)
|
||||||
array(7.74597, dtype=float32)
|
array(7.74597, dtype=float32)
|
||||||
>>> LA.norm(b)
|
>>> la.norm(b)
|
||||||
array(7.74597, dtype=float32)
|
array(7.74597, dtype=float32)
|
||||||
>>> LA.norm(b, 'fro')
|
>>> la.norm(b, 'fro')
|
||||||
array(7.74597, dtype=float32)
|
array(7.74597, dtype=float32)
|
||||||
>>> LA.norm(a, float("inf"))
|
>>> la.norm(a, float("inf"))
|
||||||
array(4, dtype=int32)
|
array(4, dtype=float32)
|
||||||
>>> LA.norm(b, float("inf"))
|
>>> la.norm(b, float("inf"))
|
||||||
array(9, dtype=int32)
|
array(9, dtype=float32)
|
||||||
>>> LA.norm(a, -float("inf"))
|
>>> la.norm(a, -float("inf"))
|
||||||
array(0, dtype=int32)
|
|
||||||
>>> LA.norm(b, -float("inf"))
|
|
||||||
array(2, dtype=int32)
|
|
||||||
>>> LA.norm(a, 1)
|
|
||||||
array(20, dtype=int32)
|
|
||||||
>>> LA.norm(b, 1)
|
|
||||||
array(7, dtype=int32)
|
|
||||||
>>> LA.norm(a, -1)
|
|
||||||
array(0, dtype=float32)
|
array(0, dtype=float32)
|
||||||
>>> LA.norm(b, -1)
|
>>> la.norm(b, -float("inf"))
|
||||||
array(6, dtype=int32)
|
array(2, dtype=float32)
|
||||||
>>> LA.norm(a, 2)
|
>>> la.norm(a, 1)
|
||||||
|
array(20, dtype=float32)
|
||||||
|
>>> la.norm(b, 1)
|
||||||
|
array(7, dtype=float32)
|
||||||
|
>>> la.norm(a, -1)
|
||||||
|
array(0, dtype=float32)
|
||||||
|
>>> la.norm(b, -1)
|
||||||
|
array(6, dtype=float32)
|
||||||
|
>>> la.norm(a, 2)
|
||||||
array(7.74597, dtype=float32)
|
array(7.74597, dtype=float32)
|
||||||
>>> LA.norm(a, 3)
|
>>> la.norm(a, 3)
|
||||||
array(5.84804, dtype=float32)
|
array(5.84804, dtype=float32)
|
||||||
>>> LA.norm(a, -3)
|
>>> la.norm(a, -3)
|
||||||
array(0, dtype=float32)
|
array(0, dtype=float32)
|
||||||
>>> c = mx.array([[ 1, 2, 3],
|
>>> c = mx.array([[ 1, 2, 3],
|
||||||
... [-1, 1, 4]])
|
... [-1, 1, 4]])
|
||||||
>>> LA.norm(c, axis=0)
|
>>> la.norm(c, axis=0)
|
||||||
array([1.41421, 2.23607, 5], dtype=float32)
|
array([1.41421, 2.23607, 5], dtype=float32)
|
||||||
>>> LA.norm(c, axis=1)
|
>>> la.norm(c, axis=1)
|
||||||
array([3.74166, 4.24264], dtype=float32)
|
array([3.74166, 4.24264], dtype=float32)
|
||||||
>>> LA.norm(c, ord=1, axis=1)
|
>>> la.norm(c, ord=1, axis=1)
|
||||||
array([6, 6], dtype=int32)
|
array([6, 6], dtype=float32)
|
||||||
>>> m = mx.arange(8).reshape(2,2,2)
|
>>> m = mx.arange(8).reshape(2,2,2)
|
||||||
>>> LA.norm(m, axis=(1,2))
|
>>> la.norm(m, axis=(1,2))
|
||||||
array([3.74166, 11.225], dtype=float32)
|
array([3.74166, 11.225], dtype=float32)
|
||||||
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
|
>>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
|
||||||
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
|
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include <numeric>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include <pybind11/complex.h>
|
#include <pybind11/complex.h>
|
||||||
@ -13,10 +14,24 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||||
using ScalarOrArray = std::
|
using ScalarOrArray = std::
|
||||||
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||||
static constexpr std::monostate none{};
|
static constexpr std::monostate none{};
|
||||||
|
|
||||||
|
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||||
|
std::vector<int> axes;
|
||||||
|
if (std::holds_alternative<std::monostate>(v)) {
|
||||||
|
axes.resize(dims);
|
||||||
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
|
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||||
|
axes.push_back(*pv);
|
||||||
|
} else {
|
||||||
|
axes = std::get<std::vector<int>>(v);
|
||||||
|
}
|
||||||
|
return axes;
|
||||||
|
}
|
||||||
|
|
||||||
inline array to_array_with_accessor(py::object obj) {
|
inline array to_array_with_accessor(py::object obj) {
|
||||||
if (py::hasattr(obj, "__mlx_array__")) {
|
if (py::hasattr(obj, "__mlx_array__")) {
|
||||||
return obj.attr("__mlx_array__")().cast<array>();
|
return obj.attr("__mlx_array__")().cast<array>();
|
||||||
|
@ -3,170 +3,155 @@
|
|||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iostream>
|
|
||||||
#include "mlx/linalg.h"
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
using namespace mlx::core::linalg;
|
using namespace mlx::core::linalg;
|
||||||
|
|
||||||
TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
||||||
array arr_one_d({1, 2, 3});
|
// Zero dimensions
|
||||||
array arr_two_d = reshape(arange(9), {3, 3});
|
array x(2.0);
|
||||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
CHECK_EQ(norm(x).item<float>(), 2.0f);
|
||||||
|
CHECK_THROWS(norm(x, 0));
|
||||||
|
|
||||||
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
|
x = array({1, 2, 3});
|
||||||
CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9)))
|
float expected = std::sqrt(1 + 4 + 9);
|
||||||
.item<bool>());
|
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, -1, true).ndim(), 1);
|
||||||
|
CHECK_THROWS(norm(x, 1));
|
||||||
|
|
||||||
|
x = reshape(arange(9), {3, 3});
|
||||||
|
expected =
|
||||||
|
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
|
||||||
|
|
||||||
|
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
norm(arr_two_d, {}, false),
|
norm(x, 0, false),
|
||||||
array(sqrt(
|
|
||||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_two_d, {0}, false),
|
|
||||||
array(
|
array(
|
||||||
{sqrt(0 + 3 * 3 + 6 * 6),
|
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||||
sqrt(1 + 4 * 4 + 7 * 7),
|
std::sqrt(1 + 4 * 4 + 7 * 7),
|
||||||
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(allclose(
|
||||||
norm(arr_two_d, {1}, false),
|
norm(x, 1, false),
|
||||||
array(
|
array(
|
||||||
{sqrt(0 + 1 + 2 * 2),
|
{std::sqrt(0 + 1 + 2 * 2),
|
||||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_two_d, {0, 1}, false),
|
x = reshape(arange(18), {2, 3, 3});
|
||||||
array(sqrt(
|
CHECK(allclose(
|
||||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
norm(x, 2, false),
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_three_d, {2}, false),
|
|
||||||
array(
|
array(
|
||||||
{
|
{
|
||||||
sqrt(0 + 1 + 2 * 2),
|
std::sqrt(0 + 1 + 2 * 2),
|
||||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||||
sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||||
sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||||
sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||||
},
|
},
|
||||||
{2, 3}))
|
{2, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, {1}, false),
|
norm(x, std::vector<int>{1, 2}, false),
|
||||||
array(
|
array(
|
||||||
{
|
{std::sqrt(
|
||||||
sqrt(0 + 3 * 3 + 6 * 6),
|
|
||||||
sqrt(1 + 4 * 4 + 7 * 7),
|
|
||||||
sqrt(2 * 2 + 5 * 5 + 8 * 8),
|
|
||||||
sqrt(9 * 9 + 12 * 12 + 15 * 15),
|
|
||||||
sqrt(10 * 10 + 13 * 13 + 16 * 16),
|
|
||||||
sqrt(11 * 11 + 14 * 14 + 17 * 17),
|
|
||||||
},
|
|
||||||
{2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_three_d, {0}, false),
|
|
||||||
array(
|
|
||||||
{
|
|
||||||
sqrt(0 + 9 * 9),
|
|
||||||
sqrt(1 + 10 * 10),
|
|
||||||
sqrt(2 * 2 + 11 * 11),
|
|
||||||
sqrt(3 * 3 + 12 * 12),
|
|
||||||
sqrt(4 * 4 + 13 * 13),
|
|
||||||
sqrt(5 * 5 + 14 * 14),
|
|
||||||
sqrt(6 * 6 + 15 * 15),
|
|
||||||
sqrt(7 * 7 + 16 * 16),
|
|
||||||
sqrt(8 * 8 + 17 * 17),
|
|
||||||
},
|
|
||||||
{3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_three_d, {1, 2}, false),
|
|
||||||
array(
|
|
||||||
{sqrt(
|
|
||||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
|
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
|
||||||
8 * 8),
|
8 * 8),
|
||||||
sqrt(
|
std::sqrt(
|
||||||
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
|
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
|
||||||
15 * 15 + 16 * 16 + 17 * 17)},
|
15 * 15 + 16 * 16 + 17 * 17)},
|
||||||
{2}))
|
{2}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
|
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||||
array arr_one_d({1, 2, 3});
|
CHECK_THROWS(norm(array(0), 2.0));
|
||||||
array arr_two_d = reshape(arange(9), {3, 3});
|
|
||||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
|
||||||
|
|
||||||
CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item<bool>());
|
array x({1, 2, 3});
|
||||||
CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item<bool>());
|
|
||||||
CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item<bool>());
|
|
||||||
|
|
||||||
CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9)))
|
float expected = std::sqrt(1 + 4 + 9);
|
||||||
.item<bool>());
|
CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
|
||||||
CHECK(array_equal(
|
CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
|
||||||
norm(arr_two_d, 2.0, {0}, false),
|
CHECK_THROWS(norm(x, 2.0, 1));
|
||||||
|
|
||||||
|
expected = 1 + 2 + 3;
|
||||||
|
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
|
||||||
|
|
||||||
|
expected = 3;
|
||||||
|
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
|
||||||
|
|
||||||
|
expected = 3;
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
|
||||||
|
doctest::Approx(expected));
|
||||||
|
|
||||||
|
expected = 1;
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
|
||||||
|
doctest::Approx(expected));
|
||||||
|
|
||||||
|
x = reshape(arange(9), {3, 3});
|
||||||
|
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 2.0, 0, false),
|
||||||
array(
|
array(
|
||||||
{sqrt(0 + 3 * 3 + 6 * 6),
|
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||||
sqrt(1 + 4 * 4 + 7 * 7),
|
std::sqrt(1 + 4 * 4 + 7 * 7),
|
||||||
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(allclose(
|
||||||
norm(arr_two_d, 2.0, {1}, false),
|
norm(x, 2.0, 1, false),
|
||||||
array(
|
array(
|
||||||
{sqrt(0 + 1 + 2 * 2),
|
{sqrt(0 + 1 + 2 * 2),
|
||||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_three_d, 2.0, {2}, false),
|
|
||||||
array(
|
|
||||||
{
|
|
||||||
sqrt(0 + 1 + 2 * 2),
|
|
||||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
|
||||||
sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
|
||||||
sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
|
||||||
sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
|
||||||
sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
|
||||||
},
|
|
||||||
{2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_three_d, 2.0, {1}, false),
|
|
||||||
array(
|
|
||||||
{
|
|
||||||
sqrt(0 + 3 * 3 + 6 * 6),
|
|
||||||
sqrt(1 + 4 * 4 + 7 * 7),
|
|
||||||
sqrt(2 * 2 + 5 * 5 + 8 * 8),
|
|
||||||
sqrt(9 * 9 + 12 * 12 + 15 * 15),
|
|
||||||
sqrt(10 * 10 + 13 * 13 + 16 * 16),
|
|
||||||
sqrt(11 * 11 + 14 * 14 + 17 * 17),
|
|
||||||
},
|
|
||||||
{2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
norm(arr_three_d, 2.0, {0}, false),
|
|
||||||
array(
|
|
||||||
{
|
|
||||||
sqrt(0 + 9 * 9),
|
|
||||||
sqrt(1 + 10 * 10),
|
|
||||||
sqrt(2 * 2 + 11 * 11),
|
|
||||||
sqrt(3 * 3 + 12 * 12),
|
|
||||||
sqrt(4 * 4 + 13 * 13),
|
|
||||||
sqrt(5 * 5 + 14 * 14),
|
|
||||||
sqrt(6 * 6 + 15 * 15),
|
|
||||||
sqrt(7 * 7 + 16 * 16),
|
|
||||||
sqrt(8 * 8 + 17 * 17),
|
|
||||||
},
|
|
||||||
{3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
|
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
|
||||||
|
doctest::Approx(15.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
|
||||||
|
doctest::Approx(21.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
|
||||||
|
doctest::Approx(9.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
|
||||||
|
doctest::Approx(3.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||||
|
doctest::Approx(9.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||||
|
doctest::Approx(15.0));
|
||||||
|
|
||||||
|
x = reshape(arange(18), {2, 3, 3});
|
||||||
|
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, 3.0, {0}),
|
norm(x, 3.0, 0),
|
||||||
array(
|
array(
|
||||||
{9.,
|
{9.,
|
||||||
10.00333222,
|
10.00333222,
|
||||||
@ -179,15 +164,8 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
|||||||
17.57113899},
|
17.57113899},
|
||||||
{3, 3}))
|
{3, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(
|
|
||||||
allclose(
|
|
||||||
norm(arr_three_d, 3.0, {1}),
|
|
||||||
array(
|
|
||||||
{6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893},
|
|
||||||
{2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, 3.0, {2}),
|
norm(x, 3.0, 2),
|
||||||
array(
|
array(
|
||||||
{2.08008382,
|
{2.08008382,
|
||||||
6.,
|
6.,
|
||||||
@ -197,110 +175,76 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
|||||||
23.13593104},
|
23.13593104},
|
||||||
{2, 3}))
|
{2, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
|
||||||
norm(arr_three_d, 0.0, {0}),
|
|
||||||
array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(
|
CHECK(
|
||||||
allclose(
|
allclose(
|
||||||
norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(
|
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||||
allclose(
|
.item<bool>());
|
||||||
norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, 1.0, {0}),
|
norm(x, 1.0, 0),
|
||||||
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
|
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
|
||||||
norm(arr_three_d, 1.0, {1}),
|
|
||||||
array({9., 12., 15., 36., 39., 42.}, {2, 3}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
|
||||||
norm(arr_three_d, 1.0, {2}),
|
|
||||||
array({3., 12., 21., 30., 39., 48.}, {2, 3}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
|
|
||||||
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item<bool>());
|
CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
|
||||||
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item<bool>());
|
|
||||||
|
|
||||||
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1}))
|
CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1}))
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1}))
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
|
||||||
CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0))
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
//
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
|
||||||
CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(
|
|
||||||
allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(
|
|
||||||
allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item<bool>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||||
array arr_one_d({1, 2, 3});
|
array x({1, 2, 3});
|
||||||
array arr_two_d = reshape(arange(9), {3, 3});
|
CHECK_THROWS(norm(x, "fro"));
|
||||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
|
||||||
|
|
||||||
CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item<bool>());
|
x = reshape(arange(9), {3, 3});
|
||||||
CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item<bool>());
|
CHECK_THROWS(norm(x, "bad ord"));
|
||||||
|
|
||||||
CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857}))
|
CHECK_EQ(
|
||||||
.item<bool>());
|
norm(x, "f", std::vector<int>{0, 1}).item<float>(),
|
||||||
CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857}))
|
doctest::Approx(14.2828568570857));
|
||||||
.item<bool>());
|
CHECK_EQ(
|
||||||
CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item<bool>());
|
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
|
||||||
CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item<bool>());
|
doctest::Approx(14.2828568570857));
|
||||||
|
|
||||||
|
x = reshape(arange(18), {2, 3, 3});
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, "fro", {0, 1}),
|
norm(x, "fro", std::vector<int>{0, 1}),
|
||||||
array({22.24859546, 24.31049156, 26.43860813}))
|
array({22.24859546, 24.31049156, 26.43860813}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907}))
|
norm(x, "fro", std::vector<int>{1, 2}),
|
||||||
|
array({14.28285686, 39.7617907}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, "f", {0, 1}),
|
norm(x, "f", std::vector<int>{0, 1}),
|
||||||
array({22.24859546, 24.31049156, 26.43860813}))
|
array({22.24859546, 24.31049156, 26.43860813}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(allclose(
|
CHECK(allclose(
|
||||||
norm(arr_three_d, "f", {1, 0}),
|
norm(x, "f", std::vector<int>{1, 0}),
|
||||||
array({22.24859546, 24.31049156, 26.43860813}))
|
array({22.24859546, 24.31049156, 26.43860813}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(
|
CHECK(allclose(
|
||||||
allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907}))
|
norm(x, "f", std::vector<int>{1, 2}),
|
||||||
|
array({14.28285686, 39.7617907}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(
|
CHECK(allclose(
|
||||||
allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907}))
|
norm(x, "f", std::vector<int>{2, 1}),
|
||||||
.item<bool>());
|
array({14.28285686, 39.7617907}))
|
||||||
CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user