From 8c43d820d99af8cc7b273a08fa801970b3494e42 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Thu, 21 Dec 2023 18:33:23 +0100 Subject: [PATCH] completed the implementation of the norm --- mlx/linalg.cpp | 140 +++++++++++++++++++++++++++++++++++++++++++++---- mlx/linalg.h | 12 +++++ mlx/utils.cpp | 13 +++++ mlx/utils.h | 5 ++ 4 files changed, 159 insertions(+), 11 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index b49713afa..1847896d2 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include "mlx/array.h" @@ -14,10 +14,77 @@ namespace mlx::core::linalg { -inline std::vector get_shape_reducing_over_all_dims(int num_axes) { - std::vector shape(num_axes); - std::iota(shape.begin(), shape.end(), 0); - return shape; +inline array vector_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == 0.0) + return sum(a != 0, axis, keepdims, s); + else if (ord == 1.0) + return sum(abs(a, s), axis, keepdims, s); + else if (ord == 2.0) + return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); + else + return power( + sum(power(abs(a, s), array(ord), s), axis, keepdims, s), + array(1.0 / ord)); +} + +inline array vector_norm( + const array& a, + const std::string& ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == "inf") + return max(abs(a, s), axis, keepdims, s); + else if (ord == "-inf") + return min(abs(a, s), axis, keepdims, s); + std::stringstream error_stream; + error_stream << "Invalid ord value " << ord; + throw std::invalid_argument(error_stream.str()); +} + +inline array matrix_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + auto row_axis = axis[0]; + auto col_axis = axis[1]; + if (!keepdims && col_axis > row_axis) + col_axis -= 1; + if (ord == -1.0) + return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); + if (ord == 1.0) + return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); + if (ord == 2.0 || ord == -2.0) + throw std::logic_error("Singular value norms are not implemented."); + std::stringstream error_stream; + error_stream << "Invalid ord value " << ord << " for matrix norm"; + throw std::invalid_argument(error_stream.str()); +} + +inline array matrix_norm( + const array& a, + const std::string& ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == "f" || ord == "fro") + return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); + else if (ord == "inf") + return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s); + else if (ord == "-inf") + return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s); + if (ord == "nuc") + throw std::logic_error("Nuclear norm is not implemented."); + std::stringstream error_stream; + error_stream << "Invalid ord value " << ord << " for matrix norm"; + throw std::invalid_argument(error_stream.str()); } array norm( @@ -28,14 +95,65 @@ array norm( auto num_axes = axis.size(); if (num_axes == 0 || num_axes == 1 || num_axes == 2) - return sqrt(sum( - abs(a, s) * abs(a, s), - num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()), - keepdims, - s)); + return sqrt( + sum(abs(a, s) * abs(a, s), + num_axes ? axis + : get_shape_reducing_over_all_axes(a.shape().size()), + keepdims, + s), + s); std::stringstream error_stream; - error_stream << "Invalid axis values" << axis; + error_stream << "Invalid axis values " << axis; throw std::invalid_argument(error_stream.str()); } + +array norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + std::vector ax = axis; + + if (axis.empty()) + ax = get_shape_reducing_over_all_axes(a.ndim()); + else + ax = normalize_axes(ax, a.ndim()); + + auto num_axes = ax.size(); + if (num_axes == 1) + return vector_norm(a, ord, ax, keepdims, s); + else if (num_axes == 2) + return matrix_norm(a, ord, ax, keepdims, s); + + std::stringstream error_stream; + error_stream << "Invalid axis values " << ax; + throw std::invalid_argument(error_stream.str()); +} + +array norm( + const array& a, + const std::string& ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + std::vector ax = axis; + + if (axis.empty()) + ax = get_shape_reducing_over_all_axes(a.ndim()); + else + ax = normalize_axes(ax, a.ndim()); + + auto num_axes = ax.size(); + if (num_axes == 1) + return vector_norm(a, ord, ax, keepdims, s); + else if (num_axes == 2) + return matrix_norm(a, ord, ax, keepdims, s); + + std::stringstream error_stream; + error_stream << "Invalid axis values " << ax; + throw std::invalid_argument(error_stream.str()); +} + } // namespace mlx::core::linalg \ No newline at end of file diff --git a/mlx/linalg.h b/mlx/linalg.h index fa9658bbb..690df343c 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -11,6 +11,18 @@ #include "string.h" namespace mlx::core::linalg { +array norm( + const array& a, + const double ord, + const std::vector& axis = {}, + bool keepdims = false, + StreamOrDevice s = {}); +array norm( + const array& a, + const std::string& ord, + const std::vector& axis = {}, + bool keepdims = false, + StreamOrDevice s = {}); array norm( const array& a, const std::vector& axis = {}, diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 1fbc67c8e..ddcb41ba8 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,5 +1,6 @@ // Copyright © 2023 Apple Inc. +#include #include #include @@ -73,6 +74,12 @@ int normalize_axis(int axis, int ndim) { } return axis; } +std::vector normalize_axes(const std::vector& axes, int ndim) { + std::vector canonical; + for (int ax : axes) + canonical.push_back(normalize_axis(ax, ndim)); + return canonical; +} std::ostream& operator<<(std::ostream& os, const Device& d) { os << "Device("; @@ -279,4 +286,10 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } +std::vector get_shape_reducing_over_all_axes(int ndim) { + std::vector shape(ndim); + std::iota(shape.begin(), shape.end(), 0); + return shape; +} + } // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h index 823b4c872..1158b7c42 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -24,6 +24,7 @@ bool is_same_shape(const std::vector& arrays); * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html */ int normalize_axis(int axis, int ndim); +std::vector normalize_axes(const std::vector& axes, int ndim); std::ostream& operator<<(std::ostream& os, const Device& d); std::ostream& operator<<(std::ostream& os, const Stream& s); @@ -41,4 +42,8 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } +/** + * Returns the axes vector [0, 1, ... ndim). + */ +std::vector get_shape_reducing_over_all_axes(int ndim); } // namespace mlx::core